Completely refactor the LLM integration method to easily replace the LLM used by each module and optimize the retransmission mechanism.

This commit is contained in:
666ghj
2025-10-09 13:45:39 +08:00
parent ce74f00137
commit 154b29c0d7
73 changed files with 942 additions and 51758 deletions
+3 -7
View File
@@ -1,11 +1,7 @@
"""
LLM调用模块
支持多种大语言模型的统一接口
LLM module for the Media Engine.
"""
from .base import BaseLLM
from .deepseek import DeepSeekLLM
from .openai_llm import OpenAILLM
from .gemini_llm import GeminiLLM
from .base import LLMClient
__all__ = ["BaseLLM", "DeepSeekLLM", "OpenAILLM", "GeminiLLM"]
__all__ = ["LLMClient"]
+81 -50
View File
@@ -1,61 +1,92 @@
"""
LLM基础抽象类
定义所有LLM实现需要遵循的接口标准
Unified OpenAI-compatible LLM client for the Media Engine, with retry support.
"""
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any
import os
import sys
from typing import Any, Dict, Optional
from openai import OpenAI
# Ensure project-level retry helper is importable
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(os.path.dirname(current_dir))
utils_dir = os.path.join(project_root, "utils")
if utils_dir not in sys.path:
sys.path.append(utils_dir)
try:
from retry_helper import with_retry, LLM_RETRY_CONFIG
except ImportError:
def with_retry(config=None):
def decorator(func):
return func
return decorator
LLM_RETRY_CONFIG = None
class BaseLLM(ABC):
"""LLM基础抽象类"""
def __init__(self, api_key: str, model_name: Optional[str] = None):
"""
初始化LLM客户端
Args:
api_key: API密钥
model_name: 模型名称,如果不指定则使用默认模型
"""
class LLMClient:
"""
Minimal wrapper around the OpenAI-compatible chat completion API.
"""
def __init__(self, api_key: str, model_name: str, base_url: Optional[str] = None):
if not api_key:
raise ValueError("Media Engine LLM API key is required.")
if not model_name:
raise ValueError("Media Engine model name is required.")
self.api_key = api_key
self.base_url = base_url
self.model_name = model_name
@abstractmethod
self.provider = model_name
timeout_fallback = os.getenv("LLM_REQUEST_TIMEOUT") or os.getenv("MEDIA_ENGINE_REQUEST_TIMEOUT") or "180"
try:
self.timeout = float(timeout_fallback)
except ValueError:
self.timeout = 300.0
client_kwargs: Dict[str, Any] = {
"api_key": api_key,
"max_retries": 0,
}
if base_url:
client_kwargs["base_url"] = base_url
self.client = OpenAI(**client_kwargs)
@with_retry(LLM_RETRY_CONFIG)
def invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
"""
调用LLM生成回复
Args:
system_prompt: 系统提示词
user_prompt: 用户输入
**kwargs: 其他参数,如temperature、max_tokens等
Returns:
LLM生成的回复文本
"""
pass
@abstractmethod
def get_default_model(self) -> str:
"""
获取默认模型名称
Returns:
默认模型名称
"""
pass
def validate_response(self, response: str) -> str:
"""
验证和清理响应内容
Args:
response: LLM原始响应
Returns:
清理后的响应内容
"""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
allowed_keys = {"temperature", "top_p", "presence_penalty", "frequency_penalty", "stream"}
extra_params = {key: value for key, value in kwargs.items() if key in allowed_keys and value is not None}
timeout = kwargs.pop("timeout", self.timeout)
response = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
timeout=timeout,
**extra_params,
)
if response.choices and response.choices[0].message:
return self.validate_response(response.choices[0].message.content)
return ""
@staticmethod
def validate_response(response: Optional[str]) -> str:
if response is None:
return ""
return response.strip()
def get_model_info(self) -> Dict[str, Any]:
return {
"provider": self.provider,
"model": self.model_name,
"api_base": self.base_url or "default",
}
-118
View File
@@ -1,118 +0,0 @@
"""
DeepSeek LLM实现
使用DeepSeek API进行文本生成
"""
import os
import sys
from typing import Optional, Dict, Any
from openai import OpenAI
from .base import BaseLLM
DEFAULT_DEEPSEEK_BASE_URL = "https://api.deepseek.com"
# 添加utils目录到Python路径并导入重试模块
try:
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.dirname(os.path.dirname(current_dir))
utils_dir = os.path.join(root_dir, 'utils')
if utils_dir not in sys.path:
sys.path.append(utils_dir)
from retry_helper import with_retry, with_graceful_retry, LLM_RETRY_CONFIG
except ImportError:
# 如果无法导入重试模块,使用空装饰器避免报错
def with_retry(config):
def decorator(func):
return func
return decorator
LLM_RETRY_CONFIG = None
class DeepSeekLLM(BaseLLM):
"""DeepSeek LLM实现类"""
def __init__(self, api_key: Optional[str] = None, model_name: Optional[str] = None, base_url: Optional[str] = None):
"""
初始化DeepSeek客户端
Args:
api_key: DeepSeek API密钥,如果不提供则从环境变量读取
model_name: 模型名称,默认使用deepseek-chat
base_url: DeepSeek API基础地址
"""
if api_key is None:
api_key = os.getenv("DEEPSEEK_API_KEY")
if not api_key:
raise ValueError("DeepSeek API Key未找到!请设置DEEPSEEK_API_KEY环境变量或在初始化时提供")
super().__init__(api_key, model_name)
self.base_url = base_url or os.getenv("DEEPSEEK_BASE_URL") or DEFAULT_DEEPSEEK_BASE_URL
# 初始化OpenAI客户端,使用DeepSeek的endpoint
self.client = OpenAI(
api_key=self.api_key,
base_url=self.base_url
)
self.default_model = model_name or self.get_default_model()
def get_default_model(self) -> str:
"""获取默认模型名称"""
return "deepseek-chat"
@with_retry(LLM_RETRY_CONFIG)
def invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
"""
调用DeepSeek API生成回复
Args:
system_prompt: 系统提示词
user_prompt: 用户输入
**kwargs: 其他参数,如temperature、max_tokens等
Returns:
DeepSeek生成的回复文本
"""
try:
# 构建消息
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
]
# 设置默认参数
params = {
"model": self.default_model,
"messages": messages,
"temperature": kwargs.get("temperature", 0.7),
"max_tokens": kwargs.get("max_tokens", 30000), # 提高到30000以支持一万字报告
"stream": False
}
# 调用API
response = self.client.chat.completions.create(**params)
# 提取回复内容
if response.choices and response.choices[0].message:
content = response.choices[0].message.content
return self.validate_response(content)
else:
return ""
except Exception as e:
print(f"DeepSeek API调用错误: {str(e)}")
raise e
def get_model_info(self) -> Dict[str, Any]:
"""
获取当前模型信息
Returns:
模型信息字典
"""
return {
"provider": "DeepSeek",
"model": self.default_model,
"api_base": self.base_url
}
-118
View File
@@ -1,118 +0,0 @@
"""
Gemini LLM实现
使用Gemini 2.5-pro中转API进行文本生成
"""
import os
import sys
from typing import Optional, Dict, Any
from openai import OpenAI
from .base import BaseLLM
DEFAULT_GEMINI_BASE_URL = "https://www.chataiapi.com/v1"
# 添加utils目录到Python路径并导入重试模块
try:
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.dirname(os.path.dirname(current_dir))
utils_dir = os.path.join(root_dir, 'utils')
if utils_dir not in sys.path:
sys.path.append(utils_dir)
from retry_helper import with_retry, with_graceful_retry, LLM_RETRY_CONFIG
except ImportError:
# 如果无法导入重试模块,使用空装饰器避免报错
def with_retry(config):
def decorator(func):
return func
return decorator
LLM_RETRY_CONFIG = None
class GeminiLLM(BaseLLM):
"""Gemini LLM实现类"""
def __init__(self, api_key: Optional[str] = None, model_name: Optional[str] = None, base_url: Optional[str] = None):
"""
初始化Gemini客户端
Args:
api_key: Gemini API密钥,如果不提供则从环境变量读取
model_name: 模型名称,默认使用gemini-2.5-pro
base_url: Gemini API基础地址
"""
if api_key is None:
api_key = os.getenv("GEMINI_API_KEY")
if not api_key:
raise ValueError("Gemini API Key未找到!请设置GEMINI_API_KEY环境变量或在初始化时提供")
super().__init__(api_key, model_name)
self.base_url = base_url or os.getenv("GEMINI_BASE_URL") or DEFAULT_GEMINI_BASE_URL
# 初始化OpenAI客户端,使用Gemini的中转endpoint
self.client = OpenAI(
api_key=self.api_key,
base_url=self.base_url
)
self.default_model = model_name or self.get_default_model()
def get_default_model(self) -> str:
"""获取默认模型名称"""
return "gemini-2.5-pro"
@with_retry(LLM_RETRY_CONFIG)
def invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
"""
调用Gemini API生成回复
Args:
system_prompt: 系统提示词
user_prompt: 用户输入
**kwargs: 其他参数,如temperature、max_tokens等
Returns:
Gemini生成的回复文本
"""
try:
# 构建消息
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
]
# 设置默认参数
params = {
"model": self.default_model,
"messages": messages,
"temperature": kwargs.get("temperature", 0.7),
"max_tokens": kwargs.get("max_tokens", 30000), # 提高到30000以支持一万字报告
"stream": False
}
# 调用API
response = self.client.chat.completions.create(**params)
# 提取回复内容
if response.choices and response.choices[0].message:
content = response.choices[0].message.content
return self.validate_response(content)
else:
return ""
except Exception as e:
print(f"Gemini API调用错误: {str(e)}")
raise e
def get_model_info(self) -> Dict[str, Any]:
"""
获取当前模型信息
Returns:
模型信息字典
"""
return {
"provider": "Gemini",
"model": self.default_model,
"api_base": self.base_url
}
-108
View File
@@ -1,108 +0,0 @@
"""
OpenAI LLM实现
使用OpenAI API进行文本生成
"""
import os
import sys
from typing import Optional, Dict, Any
from openai import OpenAI
from .base import BaseLLM
# 添加utils目录到Python路径并导入重试模块
try:
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.dirname(os.path.dirname(current_dir))
utils_dir = os.path.join(root_dir, 'utils')
if utils_dir not in sys.path:
sys.path.append(utils_dir)
from retry_helper import with_retry, with_graceful_retry, LLM_RETRY_CONFIG
except ImportError:
# 如果无法导入重试模块,使用空装饰器避免报错
def with_retry(config):
def decorator(func):
return func
return decorator
LLM_RETRY_CONFIG = None
class OpenAILLM(BaseLLM):
"""OpenAI LLM实现类"""
def __init__(self, api_key: Optional[str] = None, model_name: Optional[str] = None):
"""
初始化OpenAI客户端
Args:
api_key: OpenAI API密钥,如果不提供则从环境变量读取
model_name: 模型名称,默认使用gpt-4o-mini
"""
if api_key is None:
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise ValueError("OpenAI API Key未找到!请设置OPENAI_API_KEY环境变量或在初始化时提供")
super().__init__(api_key, model_name)
# 初始化OpenAI客户端
self.client = OpenAI(api_key=self.api_key)
self.default_model = model_name or self.get_default_model()
def get_default_model(self) -> str:
"""获取默认模型名称"""
return "gpt-4o-mini"
@with_retry(LLM_RETRY_CONFIG)
def invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
"""
调用OpenAI API生成回复
Args:
system_prompt: 系统提示词
user_prompt: 用户输入
**kwargs: 其他参数,如temperature、max_tokens等
Returns:
OpenAI生成的回复文本
"""
try:
# 构建消息
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
]
# 设置默认参数
params = {
"model": self.default_model,
"messages": messages,
"temperature": kwargs.get("temperature", 0.7),
"max_tokens": kwargs.get("max_tokens", 30000) # 提高到30000以支持一万字报告
}
# 调用API
response = self.client.chat.completions.create(**params)
# 提取回复内容
if response.choices and response.choices[0].message:
content = response.choices[0].message.content
return self.validate_response(content)
else:
return ""
except Exception as e:
print(f"OpenAI API调用错误: {str(e)}")
raise e
def get_model_info(self) -> Dict[str, Any]:
"""
获取当前模型信息
Returns:
模型信息字典
"""
return {
"provider": "OpenAI",
"model": self.default_model,
"api_base": "https://api.openai.com"
}