Completely refactor the LLM integration method to easily replace the LLM used by each module and optimize the retransmission mechanism.

This commit is contained in:
666ghj
2025-10-09 13:45:39 +08:00
parent ce74f00137
commit 154b29c0d7
73 changed files with 942 additions and 51758 deletions
+80 -86
View File
@@ -1,95 +1,89 @@
"""
Report Engine LLM基类
定义所有LLM实现的基础接口
Unified OpenAI-compatible LLM client for the Report Engine, with retry support.
"""
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
import os
import sys
from typing import Any, Dict, Optional
from openai import OpenAI
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(os.path.dirname(current_dir))
utils_dir = os.path.join(project_root, "utils")
if utils_dir not in sys.path:
sys.path.append(utils_dir)
try:
from retry_helper import with_retry, LLM_RETRY_CONFIG
except ImportError:
def with_retry(config=None):
def decorator(func):
return func
return decorator
LLM_RETRY_CONFIG = None
class BaseLLM(ABC):
"""LLM基类"""
def __init__(self, api_key: str, model_name: Optional[str] = None):
"""
初始化LLM客户端
Args:
api_key: API密钥
model_name: 模型名称
"""
class LLMClient:
"""Minimal wrapper around the OpenAI-compatible chat completion API."""
def __init__(self, api_key: str, model_name: str, base_url: Optional[str] = None):
if not api_key:
raise ValueError("Report Engine LLM API key is required.")
if not model_name:
raise ValueError("Report Engine model name is required.")
self.api_key = api_key
self.base_url = base_url
self.model_name = model_name
@abstractmethod
self.provider = model_name
timeout_fallback = os.getenv("LLM_REQUEST_TIMEOUT") or os.getenv("REPORT_ENGINE_REQUEST_TIMEOUT") or "180"
try:
self.timeout = float(timeout_fallback)
except ValueError:
self.timeout = 300.0
client_kwargs: Dict[str, Any] = {
"api_key": api_key,
"max_retries": 0,
}
if base_url:
client_kwargs["base_url"] = base_url
self.client = OpenAI(**client_kwargs)
@with_retry(LLM_RETRY_CONFIG)
def invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
"""
调用LLM生成回复
Args:
system_prompt: 系统提示词
user_prompt: 用户输入
**kwargs: 其他参数
Returns:
生成的回复文本
"""
pass
@abstractmethod
def get_model_info(self) -> Dict[str, Any]:
"""
获取当前模型信息
Returns:
模型信息字典
"""
pass
@abstractmethod
def get_default_model(self) -> str:
"""
获取默认模型名称
Returns:
默认模型名称
"""
pass
def validate_response(self, response: str) -> str:
"""
验证和清理响应内容
Args:
response: 原始响应
Returns:
清理后的响应
"""
if not response:
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
allowed_keys = {"temperature", "top_p", "presence_penalty", "frequency_penalty", "stream"}
extra_params = {key: value for key, value in kwargs.items() if key in allowed_keys and value is not None}
timeout = kwargs.pop("timeout", self.timeout)
response = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
timeout=timeout,
**extra_params,
)
if response.choices and response.choices[0].message:
return self.validate_response(response.choices[0].message.content)
return ""
@staticmethod
def validate_response(response: Optional[str]) -> str:
if response is None:
return ""
# 移除多余的空白字符
response = response.strip()
# 确保响应不为空
if not response:
return "抱歉,生成的内容为空。"
return response
def estimate_tokens(self, text: str) -> int:
"""
估算文本的token数量(简单实现)
Args:
text: 输入文本
Returns:
估算的token数量
"""
# 简单估算:中文字符按1.5个token计算,英文单词按1个token计算
chinese_chars = len([c for c in text if '\u4e00' <= c <= '\u9fff'])
english_words = len(text.split()) - chinese_chars
return int(chinese_chars * 1.5 + english_words)
return response.strip()
def get_model_info(self) -> Dict[str, Any]:
return {
"provider": self.provider,
"model": self.model_name,
"api_base": self.base_url or "default",
}