Completely refactor the LLM integration method to easily replace the LLM used by each module and optimize the retransmission mechanism.

This commit is contained in:
666ghj
2025-10-09 13:45:39 +08:00
parent ce74f00137
commit 154b29c0d7
73 changed files with 942 additions and 51758 deletions
+7 -11
View File
@@ -9,7 +9,7 @@ import logging
from datetime import datetime
from typing import Optional, Dict, Any, List
from .llms import GeminiLLM, BaseLLM
from .llms import LLMClient
from .nodes import (
TemplateSelectionNode,
HTMLGenerationNode
@@ -186,17 +186,13 @@ class ReportAgent:
}
self.file_baseline.initialize_baseline(directories)
def _initialize_llm(self) -> BaseLLM:
def _initialize_llm(self) -> LLMClient:
"""初始化LLM客户端"""
if self.config.default_llm_provider == "gemini":
return GeminiLLM(
api_key=self.config.gemini_api_key,
model_name=self.config.gemini_model,
base_url=self.config.gemini_base_url,
config=self.config # 传入配置对象以支持动态超时设置
)
else:
raise ValueError(f"不支持的LLM提供商: {self.config.default_llm_provider}")
return LLMClient(
api_key=self.config.llm_api_key,
model_name=self.config.llm_model_name,
base_url=self.config.llm_base_url,
)
def _initialize_nodes(self):
"""初始化处理节点"""
+3 -5
View File
@@ -1,9 +1,7 @@
"""
Report Engine LLM模块
包含各种大语言模型的接口实现
LLM module for the Report Engine.
"""
from .base import BaseLLM
from .gemini_llm import GeminiLLM
from .base import LLMClient
__all__ = ["BaseLLM", "GeminiLLM"]
__all__ = ["LLMClient"]
+80 -86
View File
@@ -1,95 +1,89 @@
"""
Report Engine LLM基类
定义所有LLM实现的基础接口
Unified OpenAI-compatible LLM client for the Report Engine, with retry support.
"""
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
import os
import sys
from typing import Any, Dict, Optional
from openai import OpenAI
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(os.path.dirname(current_dir))
utils_dir = os.path.join(project_root, "utils")
if utils_dir not in sys.path:
sys.path.append(utils_dir)
try:
from retry_helper import with_retry, LLM_RETRY_CONFIG
except ImportError:
def with_retry(config=None):
def decorator(func):
return func
return decorator
LLM_RETRY_CONFIG = None
class BaseLLM(ABC):
"""LLM基类"""
def __init__(self, api_key: str, model_name: Optional[str] = None):
"""
初始化LLM客户端
Args:
api_key: API密钥
model_name: 模型名称
"""
class LLMClient:
"""Minimal wrapper around the OpenAI-compatible chat completion API."""
def __init__(self, api_key: str, model_name: str, base_url: Optional[str] = None):
if not api_key:
raise ValueError("Report Engine LLM API key is required.")
if not model_name:
raise ValueError("Report Engine model name is required.")
self.api_key = api_key
self.base_url = base_url
self.model_name = model_name
@abstractmethod
self.provider = model_name
timeout_fallback = os.getenv("LLM_REQUEST_TIMEOUT") or os.getenv("REPORT_ENGINE_REQUEST_TIMEOUT") or "180"
try:
self.timeout = float(timeout_fallback)
except ValueError:
self.timeout = 300.0
client_kwargs: Dict[str, Any] = {
"api_key": api_key,
"max_retries": 0,
}
if base_url:
client_kwargs["base_url"] = base_url
self.client = OpenAI(**client_kwargs)
@with_retry(LLM_RETRY_CONFIG)
def invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
"""
调用LLM生成回复
Args:
system_prompt: 系统提示词
user_prompt: 用户输入
**kwargs: 其他参数
Returns:
生成的回复文本
"""
pass
@abstractmethod
def get_model_info(self) -> Dict[str, Any]:
"""
获取当前模型信息
Returns:
模型信息字典
"""
pass
@abstractmethod
def get_default_model(self) -> str:
"""
获取默认模型名称
Returns:
默认模型名称
"""
pass
def validate_response(self, response: str) -> str:
"""
验证和清理响应内容
Args:
response: 原始响应
Returns:
清理后的响应
"""
if not response:
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
allowed_keys = {"temperature", "top_p", "presence_penalty", "frequency_penalty", "stream"}
extra_params = {key: value for key, value in kwargs.items() if key in allowed_keys and value is not None}
timeout = kwargs.pop("timeout", self.timeout)
response = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
timeout=timeout,
**extra_params,
)
if response.choices and response.choices[0].message:
return self.validate_response(response.choices[0].message.content)
return ""
@staticmethod
def validate_response(response: Optional[str]) -> str:
if response is None:
return ""
# 移除多余的空白字符
response = response.strip()
# 确保响应不为空
if not response:
return "抱歉,生成的内容为空。"
return response
def estimate_tokens(self, text: str) -> int:
"""
估算文本的token数量(简单实现)
Args:
text: 输入文本
Returns:
估算的token数量
"""
# 简单估算:中文字符按1.5个token计算,英文单词按1个token计算
chinese_chars = len([c for c in text if '\u4e00' <= c <= '\u9fff'])
english_words = len(text.split()) - chinese_chars
return int(chinese_chars * 1.5 + english_words)
return response.strip()
def get_model_info(self) -> Dict[str, Any]:
return {
"provider": self.provider,
"model": self.model_name,
"api_base": self.base_url or "default",
}
-203
View File
@@ -1,203 +0,0 @@
"""
Report Engine Gemini LLM实现
使用Gemini 2.5-pro中转API进行文本生成
"""
import os
import sys
from typing import Optional, Dict, Any
from openai import OpenAI
from .base import BaseLLM
DEFAULT_GEMINI_BASE_URL = "https://www.chataiapi.com/v1"
# 导入根目录的config
try:
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.dirname(os.path.dirname(current_dir))
if root_dir not in sys.path:
sys.path.append(root_dir)
import config
except ImportError:
config = None
# 添加utils目录到Python路径并导入重试模块
try:
if root_dir:
utils_dir = os.path.join(root_dir, 'utils')
if utils_dir not in sys.path:
sys.path.append(utils_dir)
from retry_helper import with_retry, with_graceful_retry, LLM_RETRY_CONFIG, RetryConfig
# 创建动态重试配置生成函数
def create_report_retry_config(config=None):
"""创建ReportEngine专用的重试配置,适应7分钟平均生成时间"""
return RetryConfig(
max_retries=config.max_retries if config and hasattr(config, 'max_retries') else 8,
initial_delay=8.0, # 初始延迟增加到8秒,适应长时间生成
backoff_factor=2.0, # 保持2倍退避
max_delay=config.max_retry_delay if config and hasattr(config, 'max_retry_delay') else 180.0
)
# 创建默认配置用于模块导入时的向后兼容
REPORT_LLM_RETRY_CONFIG = create_report_retry_config()
except ImportError:
# 如果无法导入重试模块,使用空装饰器避免报错
def with_retry(config):
def decorator(func):
return func
return decorator
LLM_RETRY_CONFIG = None
REPORT_LLM_RETRY_CONFIG = None
class GeminiLLM(BaseLLM):
"""Report Engine Gemini LLM实现类"""
def __init__(self, api_key: Optional[str] = None, model_name: Optional[str] = None, base_url: Optional[str] = None, config=None):
"""
初始化Gemini客户端
Args:
api_key: Gemini API密钥,如果不提供则从config或环境变量读取
model_name: 模型名称,默认使用gemini-2.5-pro
base_url: Gemini API基础地址
config: 配置对象,用于获取超时设置
"""
if api_key is None:
# 优先从根目录config读取
if config and hasattr(config, 'GEMINI_API_KEY'):
api_key = config.GEMINI_API_KEY
else:
# 备选方案:从环境变量读取
api_key = os.getenv("GEMINI_API_KEY")
if not api_key:
raise ValueError("Gemini API Key未找到!请在config.py中设置GEMINI_API_KEY或设置环境变量")
super().__init__(api_key, model_name)
# 存储配置对象
self.config = config
# 从配置获取超时时间,默认15分钟(适应7分钟平均生成时间)
timeout = config.api_timeout if config and hasattr(config, 'api_timeout') else 900.0
self.base_url = (
base_url
or (getattr(self.config, 'gemini_base_url', None) if self.config else None)
or os.getenv('GEMINI_BASE_URL')
or DEFAULT_GEMINI_BASE_URL
)
# 创建针对此实例的重试配置
self.retry_config = create_report_retry_config(config)
# 初始化OpenAI客户端,使用Gemini的中转endpoint
# 专门为报告生成设置长超时(15分钟),适应7分钟平均生成时间
self.client = OpenAI(
api_key=self.api_key,
base_url=self.base_url,
timeout=timeout
)
self.default_model = model_name or self.get_default_model()
def get_default_model(self) -> str:
"""获取默认模型名称"""
return "gemini-2.5-pro"
def _make_api_call(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
"""
内部API调用方法
Args:
system_prompt: 系统提示词
user_prompt: 用户输入
**kwargs: 其他参数
Returns:
API响应内容
"""
# 构建消息
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
]
# 设置默认参数
params = {
"model": self.default_model,
"messages": messages,
"temperature": kwargs.get("temperature", 0.7),
"max_tokens": kwargs.get("max_tokens", 50000),
"stream": False
}
# 调用API
response = self.client.chat.completions.create(**params)
# 提取回复内容
if response.choices and response.choices[0].message:
content = response.choices[0].message.content
return self.validate_response(content)
else:
return ""
def invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
"""
调用Gemini API生成回复(带动态重试配置)
Args:
system_prompt: 系统提示词
user_prompt: 用户输入
**kwargs: 其他参数,如temperature、max_tokens等
Returns:
Gemini生成的回复文本
"""
import time
last_exception = None
for attempt in range(self.retry_config.max_retries + 1):
try:
result = self._make_api_call(system_prompt, user_prompt, **kwargs)
if attempt > 0:
print(f"Report Engine Gemini API在第 {attempt + 1} 次尝试后成功")
return result
except Exception as e:
last_exception = e
if attempt == self.retry_config.max_retries:
print(f"Report Engine Gemini API在 {self.retry_config.max_retries + 1} 次尝试后仍然失败")
print(f"最终错误: {str(e)}")
raise e
# 计算延迟时间
delay = min(
self.retry_config.initial_delay * (self.retry_config.backoff_factor ** attempt),
self.retry_config.max_delay
)
print(f"Report Engine Gemini API第 {attempt + 1} 次尝试失败: {str(e)}")
print(f"将在 {delay:.1f} 秒后进行第 {attempt + 2} 次尝试...")
time.sleep(delay)
# 这里不应该到达,但作为安全网
if last_exception:
raise last_exception
def get_model_info(self) -> Dict[str, Any]:
"""
获取当前模型信息
Returns:
模型信息字典
"""
return {
"provider": "Gemini",
"model": self.default_model,
"api_base": self.base_url,
"purpose": "Report Generation"
}
+2 -2
View File
@@ -6,14 +6,14 @@ Report Engine节点基类
import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from ..llms.base import BaseLLM
from ..llms.base import LLMClient
from ..state.state import ReportState
class BaseNode(ABC):
"""节点基类"""
def __init__(self, llm_client: BaseLLM, node_name: str = ""):
def __init__(self, llm_client: LLMClient, node_name: str = ""):
"""
初始化节点
+2 -2
View File
@@ -8,7 +8,7 @@ from datetime import datetime
from typing import Dict, Any
from .base_node import StateMutationNode
from ..llms.base import BaseLLM
from ..llms.base import LLMClient
from ..state.state import ReportState
from ..prompts import SYSTEM_PROMPT_HTML_GENERATION
# 不再需要text_processing依赖
@@ -17,7 +17,7 @@ from ..prompts import SYSTEM_PROMPT_HTML_GENERATION
class HTMLGenerationNode(StateMutationNode):
"""HTML生成处理节点"""
def __init__(self, llm_client: BaseLLM):
def __init__(self, llm_client: LLMClient):
"""
初始化HTML生成节点
+104 -103
View File
@@ -1,6 +1,5 @@
"""
Report Engine配置管理模块
处理环境变量和配置参数
Configuration management module for the Report Engine.
"""
import os
@@ -8,144 +7,146 @@ from dataclasses import dataclass
from typing import Optional
def _get_value(source, key: str, default=None, *fallback_keys: str):
candidates = (key,) + fallback_keys
value = None
for candidate in candidates:
if isinstance(source, dict):
value = source.get(candidate)
else:
value = getattr(source, candidate, None)
if value not in (None, ""):
break
if value in (None, ""):
for candidate in candidates:
env_val = os.getenv(candidate)
if env_val not in (None, ""):
value = env_val
break
return value if value not in (None, "") else default
@dataclass
class Config:
"""Report Engine配置类"""
# API密钥
gemini_api_key: Optional[str] = None
gemini_base_url: str = "https://www.chataiapi.com/v1"
# 模型配置
default_llm_provider: str = "gemini"
gemini_model: str = "gemini-2.5-pro"
# 报告配置
"""Report Engine configuration."""
llm_api_key: Optional[str] = None
llm_base_url: Optional[str] = None
llm_model_name: Optional[str] = None
llm_provider: Optional[str] = None # compatibility
max_content_length: int = 200000
output_dir: str = "final_reports"
template_dir: str = "ReportEngine/report_template"
# 超时配置 - 专门为长报告生成优化(平均生成时间7分钟)
api_timeout: float = 900.0 # API调用超时时间(秒),设置为15分钟,适应7分钟平均生成时间
max_retry_delay: float = 180.0 # 最大重试延迟(秒),设置为3分钟
max_retries: int = 8 # 最大重试次数,增加到8次
# 日志配置
api_timeout: float = 900.0
max_retry_delay: float = 180.0
max_retries: int = 8
log_file: str = "logs/report.log"
# HTML导出配置
enable_pdf_export: bool = True
chart_style: str = "modern" # modern, classic, minimal
chart_style: str = "modern"
def __post_init__(self):
if not self.llm_provider and self.llm_model_name:
self.llm_provider = self.llm_model_name
def validate(self) -> bool:
"""验证配置"""
if not self.gemini_api_key:
print("错误: Gemini API Key未设置")
if not self.llm_api_key:
print("错误: Report Engine LLM API Key 未设置 (REPORT_ENGINE_API_KEY)。")
return False
if not self.llm_model_name:
print("错误: Report Engine 模型名称未设置 (REPORT_ENGINE_MODEL_NAME)。")
return False
return True
@classmethod
def from_file(cls, config_file: str) -> "Config":
"""从配置文件创建配置"""
if config_file.endswith('.py'):
# Python配置文件
if config_file.endswith(".py"):
import importlib.util
# 动态导入配置文件
spec = importlib.util.spec_from_file_location("config", config_file)
config_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config_module)
return cls(
gemini_api_key=getattr(config_module, "GEMINI_API_KEY", None),
gemini_base_url=getattr(config_module, "GEMINI_BASE_URL", "https://www.chataiapi.com/v1"),
default_llm_provider=getattr(config_module, "DEFAULT_LLM_PROVIDER", "gemini"),
gemini_model=getattr(config_module, "GEMINI_MODEL", "gemini-2.5-pro"),
max_content_length=getattr(config_module, "MAX_CONTENT_LENGTH", 200000),
output_dir=getattr(config_module, "REPORT_OUTPUT_DIR", "final_reports"),
template_dir=getattr(config_module, "TEMPLATE_DIR", "ReportEngine/report_template"),
api_timeout=getattr(config_module, "REPORT_API_TIMEOUT", 900.0),
max_retry_delay=getattr(config_module, "REPORT_MAX_RETRY_DELAY", 180.0),
max_retries=getattr(config_module, "REPORT_MAX_RETRIES", 8),
log_file=getattr(config_module, "REPORT_LOG_FILE", "logs/report.log"),
enable_pdf_export=getattr(config_module, "ENABLE_PDF_EXPORT", True),
chart_style=getattr(config_module, "CHART_STYLE", "modern")
)
else:
# .env格式配置文件
config_dict = {}
if os.path.exists(config_file):
with open(config_file, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if line and not line.startswith('#') and '=' in line:
key, value = line.split('=', 1)
config_dict[key.strip()] = value.strip()
return cls(
gemini_api_key=config_dict.get("GEMINI_API_KEY"),
gemini_base_url=config_dict.get("GEMINI_BASE_URL", "https://www.chataiapi.com/v1"),
default_llm_provider=config_dict.get("DEFAULT_LLM_PROVIDER", "gemini"),
gemini_model=config_dict.get("GEMINI_MODEL", "gemini-2.5-pro"),
max_content_length=int(config_dict.get("MAX_CONTENT_LENGTH", "200000")),
output_dir=config_dict.get("REPORT_OUTPUT_DIR", "final_reports"),
template_dir=config_dict.get("TEMPLATE_DIR", "ReportEngine/report_template"),
api_timeout=float(config_dict.get("REPORT_API_TIMEOUT", "900.0")),
max_retry_delay=float(config_dict.get("REPORT_MAX_RETRY_DELAY", "180.0")),
max_retries=int(config_dict.get("REPORT_MAX_RETRIES", "8")),
log_file=config_dict.get("REPORT_LOG_FILE", "logs/report.log"),
enable_pdf_export=config_dict.get("ENABLE_PDF_EXPORT", "true").lower() == "true",
chart_style=config_dict.get("CHART_STYLE", "modern")
llm_api_key=_get_value(config_module, "REPORT_ENGINE_API_KEY"),
llm_base_url=_get_value(config_module, "REPORT_ENGINE_BASE_URL"),
llm_model_name=_get_value(config_module, "REPORT_ENGINE_MODEL_NAME"),
max_content_length=int(_get_value(config_module, "MAX_CONTENT_LENGTH", 200000)),
output_dir=_get_value(config_module, "REPORT_OUTPUT_DIR", "final_reports"),
template_dir=_get_value(config_module, "TEMPLATE_DIR", "ReportEngine/report_template"),
api_timeout=float(_get_value(config_module, "REPORT_API_TIMEOUT", 900.0)),
max_retry_delay=float(_get_value(config_module, "REPORT_MAX_RETRY_DELAY", 180.0)),
max_retries=int(_get_value(config_module, "REPORT_MAX_RETRIES", 8)),
log_file=_get_value(config_module, "REPORT_LOG_FILE", "logs/report.log"),
enable_pdf_export=str(
_get_value(config_module, "ENABLE_PDF_EXPORT", "true")
).lower()
in ("true", "1", "yes"),
chart_style=_get_value(config_module, "CHART_STYLE", "modern"),
)
config_dict = {}
if os.path.exists(config_file):
with open(config_file, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line and not line.startswith("#") and "=" in line:
key, value = line.split("=", 1)
config_dict[key.strip()] = value.strip()
return cls(
llm_api_key=_get_value(config_dict, "REPORT_ENGINE_API_KEY"),
llm_base_url=_get_value(config_dict, "REPORT_ENGINE_BASE_URL"),
llm_model_name=_get_value(config_dict, "REPORT_ENGINE_MODEL_NAME"),
max_content_length=int(_get_value(config_dict, "MAX_CONTENT_LENGTH", 200000)),
output_dir=_get_value(config_dict, "REPORT_OUTPUT_DIR", "final_reports"),
template_dir=_get_value(config_dict, "TEMPLATE_DIR", "ReportEngine/report_template"),
api_timeout=float(_get_value(config_dict, "REPORT_API_TIMEOUT", 900.0)),
max_retry_delay=float(_get_value(config_dict, "REPORT_MAX_RETRY_DELAY", 180.0)),
max_retries=int(_get_value(config_dict, "REPORT_MAX_RETRIES", 8)),
log_file=_get_value(config_dict, "REPORT_LOG_FILE", "logs/report.log"),
enable_pdf_export=str(
_get_value(config_dict, "ENABLE_PDF_EXPORT", "true")
).lower()
in ("true", "1", "yes"),
chart_style=_get_value(config_dict, "CHART_STYLE", "modern"),
)
def load_config(config_file: Optional[str] = None) -> Config:
"""
加载配置
Args:
config_file: 配置文件路径,如果不指定则使用默认路径
Returns:
配置对象
"""
# 确定配置文件路径
if config_file:
if not os.path.exists(config_file):
raise FileNotFoundError(f"配置文件不存在: {config_file}")
file_to_load = config_file
else:
# 尝试加载常见的配置文件
for config_path in ["config.py", "config.env", ".env"]:
if os.path.exists(config_path):
file_to_load = config_path
for candidate in ("config.py", "config.env", ".env"):
if os.path.exists(candidate):
file_to_load = candidate
print(f"已找到配置文件: {candidate}")
break
else:
raise FileNotFoundError("未找到配置文件,请创建 config.py 文件")
# 创建配置对象
raise FileNotFoundError("未找到配置文件,请创建 config.py")
config = Config.from_file(file_to_load)
# 验证配置
if not config.validate():
raise ValueError("Report Engine配置验失败,请检查配置文件中的API密钥")
raise ValueError("Report Engine 配置验失败,请检查 config.py 中的相关配置。")
return config
def print_config(config: Config):
"""打印配置信息(隐藏敏感信息)"""
print("\n=== Report Engine配置 ===")
print(f"LLM提供商: {config.default_llm_provider}")
print(f"Gemini模型: {config.gemini_model}")
print("\n=== Report Engine 配置 ===")
print(f"LLM 模型: {config.llm_model_name}")
print(f"LLM Base URL: {config.llm_base_url or '(默认)'}")
print(f"最大内容长度: {config.max_content_length}")
print(f"输出目录: {config.output_dir}")
print(f"模板目录: {config.template_dir}")
print(f"API超时时间: {config.api_timeout}{config.api_timeout/60:.1f}分钟)")
print(f"最大重试延迟: {config.max_retry_delay}{config.max_retry_delay/60:.1f}分钟)")
print(f"最大重试次数: {config.max_retries}")
print(f"API 超时时间: {config.api_timeout} ")
print(f"最大重试间隔: {config.max_retry_delay} ")
print(f"最大重试次数: {config.max_retries}")
print(f"日志文件: {config.log_file}")
print(f"PDF导出: {config.enable_pdf_export}")
print(f"PDF 导出: {config.enable_pdf_export}")
print(f"图表样式: {config.chart_style}")
print(f"Gemini API Key: {'' if config.gemini_api_key else ''}")
print(f"LLM API Key: {'' if config.llm_api_key else ''}")
print("========================\n")