Completely refactor the LLM integration method to easily replace the LLM used by each module and optimize the retransmission mechanism.
This commit is contained in:
+7
-11
@@ -9,7 +9,7 @@ import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
from .llms import GeminiLLM, BaseLLM
|
||||
from .llms import LLMClient
|
||||
from .nodes import (
|
||||
TemplateSelectionNode,
|
||||
HTMLGenerationNode
|
||||
@@ -186,17 +186,13 @@ class ReportAgent:
|
||||
}
|
||||
self.file_baseline.initialize_baseline(directories)
|
||||
|
||||
def _initialize_llm(self) -> BaseLLM:
|
||||
def _initialize_llm(self) -> LLMClient:
|
||||
"""初始化LLM客户端"""
|
||||
if self.config.default_llm_provider == "gemini":
|
||||
return GeminiLLM(
|
||||
api_key=self.config.gemini_api_key,
|
||||
model_name=self.config.gemini_model,
|
||||
base_url=self.config.gemini_base_url,
|
||||
config=self.config # 传入配置对象以支持动态超时设置
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"不支持的LLM提供商: {self.config.default_llm_provider}")
|
||||
return LLMClient(
|
||||
api_key=self.config.llm_api_key,
|
||||
model_name=self.config.llm_model_name,
|
||||
base_url=self.config.llm_base_url,
|
||||
)
|
||||
|
||||
def _initialize_nodes(self):
|
||||
"""初始化处理节点"""
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
"""
|
||||
Report Engine LLM模块
|
||||
包含各种大语言模型的接口实现
|
||||
LLM module for the Report Engine.
|
||||
"""
|
||||
|
||||
from .base import BaseLLM
|
||||
from .gemini_llm import GeminiLLM
|
||||
from .base import LLMClient
|
||||
|
||||
__all__ = ["BaseLLM", "GeminiLLM"]
|
||||
__all__ = ["LLMClient"]
|
||||
|
||||
+80
-86
@@ -1,95 +1,89 @@
|
||||
"""
|
||||
Report Engine LLM基类
|
||||
定义所有LLM实现的基础接口
|
||||
Unified OpenAI-compatible LLM client for the Report Engine, with retry support.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, Optional
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
project_root = os.path.dirname(os.path.dirname(current_dir))
|
||||
utils_dir = os.path.join(project_root, "utils")
|
||||
if utils_dir not in sys.path:
|
||||
sys.path.append(utils_dir)
|
||||
|
||||
try:
|
||||
from retry_helper import with_retry, LLM_RETRY_CONFIG
|
||||
except ImportError:
|
||||
def with_retry(config=None):
|
||||
def decorator(func):
|
||||
return func
|
||||
return decorator
|
||||
|
||||
LLM_RETRY_CONFIG = None
|
||||
|
||||
|
||||
class BaseLLM(ABC):
|
||||
"""LLM基类"""
|
||||
|
||||
def __init__(self, api_key: str, model_name: Optional[str] = None):
|
||||
"""
|
||||
初始化LLM客户端
|
||||
|
||||
Args:
|
||||
api_key: API密钥
|
||||
model_name: 模型名称
|
||||
"""
|
||||
class LLMClient:
|
||||
"""Minimal wrapper around the OpenAI-compatible chat completion API."""
|
||||
|
||||
def __init__(self, api_key: str, model_name: str, base_url: Optional[str] = None):
|
||||
if not api_key:
|
||||
raise ValueError("Report Engine LLM API key is required.")
|
||||
if not model_name:
|
||||
raise ValueError("Report Engine model name is required.")
|
||||
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self.model_name = model_name
|
||||
|
||||
@abstractmethod
|
||||
self.provider = model_name
|
||||
timeout_fallback = os.getenv("LLM_REQUEST_TIMEOUT") or os.getenv("REPORT_ENGINE_REQUEST_TIMEOUT") or "180"
|
||||
try:
|
||||
self.timeout = float(timeout_fallback)
|
||||
except ValueError:
|
||||
self.timeout = 300.0
|
||||
|
||||
client_kwargs: Dict[str, Any] = {
|
||||
"api_key": api_key,
|
||||
"max_retries": 0,
|
||||
}
|
||||
if base_url:
|
||||
client_kwargs["base_url"] = base_url
|
||||
self.client = OpenAI(**client_kwargs)
|
||||
|
||||
@with_retry(LLM_RETRY_CONFIG)
|
||||
def invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
|
||||
"""
|
||||
调用LLM生成回复
|
||||
|
||||
Args:
|
||||
system_prompt: 系统提示词
|
||||
user_prompt: 用户输入
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
生成的回复文本
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model_info(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取当前模型信息
|
||||
|
||||
Returns:
|
||||
模型信息字典
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_default_model(self) -> str:
|
||||
"""
|
||||
获取默认模型名称
|
||||
|
||||
Returns:
|
||||
默认模型名称
|
||||
"""
|
||||
pass
|
||||
|
||||
def validate_response(self, response: str) -> str:
|
||||
"""
|
||||
验证和清理响应内容
|
||||
|
||||
Args:
|
||||
response: 原始响应
|
||||
|
||||
Returns:
|
||||
清理后的响应
|
||||
"""
|
||||
if not response:
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
]
|
||||
|
||||
allowed_keys = {"temperature", "top_p", "presence_penalty", "frequency_penalty", "stream"}
|
||||
extra_params = {key: value for key, value in kwargs.items() if key in allowed_keys and value is not None}
|
||||
|
||||
timeout = kwargs.pop("timeout", self.timeout)
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
timeout=timeout,
|
||||
**extra_params,
|
||||
)
|
||||
|
||||
if response.choices and response.choices[0].message:
|
||||
return self.validate_response(response.choices[0].message.content)
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def validate_response(response: Optional[str]) -> str:
|
||||
if response is None:
|
||||
return ""
|
||||
|
||||
# 移除多余的空白字符
|
||||
response = response.strip()
|
||||
|
||||
# 确保响应不为空
|
||||
if not response:
|
||||
return "抱歉,生成的内容为空。"
|
||||
|
||||
return response
|
||||
|
||||
def estimate_tokens(self, text: str) -> int:
|
||||
"""
|
||||
估算文本的token数量(简单实现)
|
||||
|
||||
Args:
|
||||
text: 输入文本
|
||||
|
||||
Returns:
|
||||
估算的token数量
|
||||
"""
|
||||
# 简单估算:中文字符按1.5个token计算,英文单词按1个token计算
|
||||
chinese_chars = len([c for c in text if '\u4e00' <= c <= '\u9fff'])
|
||||
english_words = len(text.split()) - chinese_chars
|
||||
|
||||
return int(chinese_chars * 1.5 + english_words)
|
||||
return response.strip()
|
||||
|
||||
def get_model_info(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"provider": self.provider,
|
||||
"model": self.model_name,
|
||||
"api_base": self.base_url or "default",
|
||||
}
|
||||
|
||||
@@ -1,203 +0,0 @@
|
||||
"""
|
||||
Report Engine Gemini LLM实现
|
||||
使用Gemini 2.5-pro中转API进行文本生成
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import Optional, Dict, Any
|
||||
from openai import OpenAI
|
||||
from .base import BaseLLM
|
||||
|
||||
DEFAULT_GEMINI_BASE_URL = "https://www.chataiapi.com/v1"
|
||||
|
||||
# 导入根目录的config
|
||||
try:
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
root_dir = os.path.dirname(os.path.dirname(current_dir))
|
||||
if root_dir not in sys.path:
|
||||
sys.path.append(root_dir)
|
||||
import config
|
||||
except ImportError:
|
||||
config = None
|
||||
|
||||
# 添加utils目录到Python路径并导入重试模块
|
||||
try:
|
||||
if root_dir:
|
||||
utils_dir = os.path.join(root_dir, 'utils')
|
||||
if utils_dir not in sys.path:
|
||||
sys.path.append(utils_dir)
|
||||
from retry_helper import with_retry, with_graceful_retry, LLM_RETRY_CONFIG, RetryConfig
|
||||
# 创建动态重试配置生成函数
|
||||
def create_report_retry_config(config=None):
|
||||
"""创建ReportEngine专用的重试配置,适应7分钟平均生成时间"""
|
||||
return RetryConfig(
|
||||
max_retries=config.max_retries if config and hasattr(config, 'max_retries') else 8,
|
||||
initial_delay=8.0, # 初始延迟增加到8秒,适应长时间生成
|
||||
backoff_factor=2.0, # 保持2倍退避
|
||||
max_delay=config.max_retry_delay if config and hasattr(config, 'max_retry_delay') else 180.0
|
||||
)
|
||||
# 创建默认配置用于模块导入时的向后兼容
|
||||
REPORT_LLM_RETRY_CONFIG = create_report_retry_config()
|
||||
except ImportError:
|
||||
# 如果无法导入重试模块,使用空装饰器避免报错
|
||||
def with_retry(config):
|
||||
def decorator(func):
|
||||
return func
|
||||
return decorator
|
||||
LLM_RETRY_CONFIG = None
|
||||
REPORT_LLM_RETRY_CONFIG = None
|
||||
|
||||
|
||||
class GeminiLLM(BaseLLM):
|
||||
"""Report Engine Gemini LLM实现类"""
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None, model_name: Optional[str] = None, base_url: Optional[str] = None, config=None):
|
||||
"""
|
||||
初始化Gemini客户端
|
||||
|
||||
Args:
|
||||
api_key: Gemini API密钥,如果不提供则从config或环境变量读取
|
||||
model_name: 模型名称,默认使用gemini-2.5-pro
|
||||
base_url: Gemini API基础地址
|
||||
config: 配置对象,用于获取超时设置
|
||||
"""
|
||||
if api_key is None:
|
||||
# 优先从根目录config读取
|
||||
if config and hasattr(config, 'GEMINI_API_KEY'):
|
||||
api_key = config.GEMINI_API_KEY
|
||||
else:
|
||||
# 备选方案:从环境变量读取
|
||||
api_key = os.getenv("GEMINI_API_KEY")
|
||||
|
||||
if not api_key:
|
||||
raise ValueError("Gemini API Key未找到!请在config.py中设置GEMINI_API_KEY或设置环境变量")
|
||||
|
||||
super().__init__(api_key, model_name)
|
||||
|
||||
# 存储配置对象
|
||||
self.config = config
|
||||
|
||||
# 从配置获取超时时间,默认15分钟(适应7分钟平均生成时间)
|
||||
timeout = config.api_timeout if config and hasattr(config, 'api_timeout') else 900.0
|
||||
|
||||
self.base_url = (
|
||||
base_url
|
||||
or (getattr(self.config, 'gemini_base_url', None) if self.config else None)
|
||||
or os.getenv('GEMINI_BASE_URL')
|
||||
or DEFAULT_GEMINI_BASE_URL
|
||||
)
|
||||
|
||||
# 创建针对此实例的重试配置
|
||||
self.retry_config = create_report_retry_config(config)
|
||||
|
||||
# 初始化OpenAI客户端,使用Gemini的中转endpoint
|
||||
# 专门为报告生成设置长超时(15分钟),适应7分钟平均生成时间
|
||||
self.client = OpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
self.default_model = model_name or self.get_default_model()
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
"""获取默认模型名称"""
|
||||
return "gemini-2.5-pro"
|
||||
|
||||
def _make_api_call(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
|
||||
"""
|
||||
内部API调用方法
|
||||
|
||||
Args:
|
||||
system_prompt: 系统提示词
|
||||
user_prompt: 用户输入
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
API响应内容
|
||||
"""
|
||||
# 构建消息
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
]
|
||||
|
||||
# 设置默认参数
|
||||
params = {
|
||||
"model": self.default_model,
|
||||
"messages": messages,
|
||||
"temperature": kwargs.get("temperature", 0.7),
|
||||
"max_tokens": kwargs.get("max_tokens", 50000),
|
||||
"stream": False
|
||||
}
|
||||
|
||||
# 调用API
|
||||
response = self.client.chat.completions.create(**params)
|
||||
|
||||
# 提取回复内容
|
||||
if response.choices and response.choices[0].message:
|
||||
content = response.choices[0].message.content
|
||||
return self.validate_response(content)
|
||||
else:
|
||||
return ""
|
||||
|
||||
def invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
|
||||
"""
|
||||
调用Gemini API生成回复(带动态重试配置)
|
||||
|
||||
Args:
|
||||
system_prompt: 系统提示词
|
||||
user_prompt: 用户输入
|
||||
**kwargs: 其他参数,如temperature、max_tokens等
|
||||
|
||||
Returns:
|
||||
Gemini生成的回复文本
|
||||
"""
|
||||
import time
|
||||
|
||||
last_exception = None
|
||||
|
||||
for attempt in range(self.retry_config.max_retries + 1):
|
||||
try:
|
||||
result = self._make_api_call(system_prompt, user_prompt, **kwargs)
|
||||
if attempt > 0:
|
||||
print(f"Report Engine Gemini API在第 {attempt + 1} 次尝试后成功")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
|
||||
if attempt == self.retry_config.max_retries:
|
||||
print(f"Report Engine Gemini API在 {self.retry_config.max_retries + 1} 次尝试后仍然失败")
|
||||
print(f"最终错误: {str(e)}")
|
||||
raise e
|
||||
|
||||
# 计算延迟时间
|
||||
delay = min(
|
||||
self.retry_config.initial_delay * (self.retry_config.backoff_factor ** attempt),
|
||||
self.retry_config.max_delay
|
||||
)
|
||||
|
||||
print(f"Report Engine Gemini API第 {attempt + 1} 次尝试失败: {str(e)}")
|
||||
print(f"将在 {delay:.1f} 秒后进行第 {attempt + 2} 次尝试...")
|
||||
|
||||
time.sleep(delay)
|
||||
|
||||
# 这里不应该到达,但作为安全网
|
||||
if last_exception:
|
||||
raise last_exception
|
||||
|
||||
def get_model_info(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取当前模型信息
|
||||
|
||||
Returns:
|
||||
模型信息字典
|
||||
"""
|
||||
return {
|
||||
"provider": "Gemini",
|
||||
"model": self.default_model,
|
||||
"api_base": self.base_url,
|
||||
"purpose": "Report Generation"
|
||||
}
|
||||
@@ -6,14 +6,14 @@ Report Engine节点基类
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional
|
||||
from ..llms.base import BaseLLM
|
||||
from ..llms.base import LLMClient
|
||||
from ..state.state import ReportState
|
||||
|
||||
|
||||
class BaseNode(ABC):
|
||||
"""节点基类"""
|
||||
|
||||
def __init__(self, llm_client: BaseLLM, node_name: str = ""):
|
||||
def __init__(self, llm_client: LLMClient, node_name: str = ""):
|
||||
"""
|
||||
初始化节点
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from datetime import datetime
|
||||
from typing import Dict, Any
|
||||
|
||||
from .base_node import StateMutationNode
|
||||
from ..llms.base import BaseLLM
|
||||
from ..llms.base import LLMClient
|
||||
from ..state.state import ReportState
|
||||
from ..prompts import SYSTEM_PROMPT_HTML_GENERATION
|
||||
# 不再需要text_processing依赖
|
||||
@@ -17,7 +17,7 @@ from ..prompts import SYSTEM_PROMPT_HTML_GENERATION
|
||||
class HTMLGenerationNode(StateMutationNode):
|
||||
"""HTML生成处理节点"""
|
||||
|
||||
def __init__(self, llm_client: BaseLLM):
|
||||
def __init__(self, llm_client: LLMClient):
|
||||
"""
|
||||
初始化HTML生成节点
|
||||
|
||||
|
||||
+104
-103
@@ -1,6 +1,5 @@
|
||||
"""
|
||||
Report Engine配置管理模块
|
||||
处理环境变量和配置参数
|
||||
Configuration management module for the Report Engine.
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -8,144 +7,146 @@ from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def _get_value(source, key: str, default=None, *fallback_keys: str):
|
||||
candidates = (key,) + fallback_keys
|
||||
value = None
|
||||
for candidate in candidates:
|
||||
if isinstance(source, dict):
|
||||
value = source.get(candidate)
|
||||
else:
|
||||
value = getattr(source, candidate, None)
|
||||
if value not in (None, ""):
|
||||
break
|
||||
if value in (None, ""):
|
||||
for candidate in candidates:
|
||||
env_val = os.getenv(candidate)
|
||||
if env_val not in (None, ""):
|
||||
value = env_val
|
||||
break
|
||||
return value if value not in (None, "") else default
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
"""Report Engine配置类"""
|
||||
# API密钥
|
||||
gemini_api_key: Optional[str] = None
|
||||
gemini_base_url: str = "https://www.chataiapi.com/v1"
|
||||
|
||||
# 模型配置
|
||||
default_llm_provider: str = "gemini"
|
||||
gemini_model: str = "gemini-2.5-pro"
|
||||
|
||||
# 报告配置
|
||||
"""Report Engine configuration."""
|
||||
|
||||
llm_api_key: Optional[str] = None
|
||||
llm_base_url: Optional[str] = None
|
||||
llm_model_name: Optional[str] = None
|
||||
llm_provider: Optional[str] = None # compatibility
|
||||
|
||||
max_content_length: int = 200000
|
||||
output_dir: str = "final_reports"
|
||||
template_dir: str = "ReportEngine/report_template"
|
||||
|
||||
# 超时配置 - 专门为长报告生成优化(平均生成时间7分钟)
|
||||
api_timeout: float = 900.0 # API调用超时时间(秒),设置为15分钟,适应7分钟平均生成时间
|
||||
max_retry_delay: float = 180.0 # 最大重试延迟(秒),设置为3分钟
|
||||
max_retries: int = 8 # 最大重试次数,增加到8次
|
||||
|
||||
# 日志配置
|
||||
|
||||
api_timeout: float = 900.0
|
||||
max_retry_delay: float = 180.0
|
||||
max_retries: int = 8
|
||||
|
||||
log_file: str = "logs/report.log"
|
||||
|
||||
# HTML导出配置
|
||||
enable_pdf_export: bool = True
|
||||
chart_style: str = "modern" # modern, classic, minimal
|
||||
|
||||
chart_style: str = "modern"
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.llm_provider and self.llm_model_name:
|
||||
self.llm_provider = self.llm_model_name
|
||||
|
||||
def validate(self) -> bool:
|
||||
"""验证配置"""
|
||||
if not self.gemini_api_key:
|
||||
print("错误: Gemini API Key未设置")
|
||||
if not self.llm_api_key:
|
||||
print("错误: Report Engine LLM API Key 未设置 (REPORT_ENGINE_API_KEY)。")
|
||||
return False
|
||||
if not self.llm_model_name:
|
||||
print("错误: Report Engine 模型名称未设置 (REPORT_ENGINE_MODEL_NAME)。")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, config_file: str) -> "Config":
|
||||
"""从配置文件创建配置"""
|
||||
if config_file.endswith('.py'):
|
||||
# Python配置文件
|
||||
if config_file.endswith(".py"):
|
||||
import importlib.util
|
||||
|
||||
# 动态导入配置文件
|
||||
|
||||
spec = importlib.util.spec_from_file_location("config", config_file)
|
||||
config_module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(config_module)
|
||||
|
||||
|
||||
return cls(
|
||||
gemini_api_key=getattr(config_module, "GEMINI_API_KEY", None),
|
||||
gemini_base_url=getattr(config_module, "GEMINI_BASE_URL", "https://www.chataiapi.com/v1"),
|
||||
default_llm_provider=getattr(config_module, "DEFAULT_LLM_PROVIDER", "gemini"),
|
||||
gemini_model=getattr(config_module, "GEMINI_MODEL", "gemini-2.5-pro"),
|
||||
max_content_length=getattr(config_module, "MAX_CONTENT_LENGTH", 200000),
|
||||
output_dir=getattr(config_module, "REPORT_OUTPUT_DIR", "final_reports"),
|
||||
template_dir=getattr(config_module, "TEMPLATE_DIR", "ReportEngine/report_template"),
|
||||
api_timeout=getattr(config_module, "REPORT_API_TIMEOUT", 900.0),
|
||||
max_retry_delay=getattr(config_module, "REPORT_MAX_RETRY_DELAY", 180.0),
|
||||
max_retries=getattr(config_module, "REPORT_MAX_RETRIES", 8),
|
||||
log_file=getattr(config_module, "REPORT_LOG_FILE", "logs/report.log"),
|
||||
enable_pdf_export=getattr(config_module, "ENABLE_PDF_EXPORT", True),
|
||||
chart_style=getattr(config_module, "CHART_STYLE", "modern")
|
||||
)
|
||||
else:
|
||||
# .env格式配置文件
|
||||
config_dict = {}
|
||||
|
||||
if os.path.exists(config_file):
|
||||
with open(config_file, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith('#') and '=' in line:
|
||||
key, value = line.split('=', 1)
|
||||
config_dict[key.strip()] = value.strip()
|
||||
|
||||
return cls(
|
||||
gemini_api_key=config_dict.get("GEMINI_API_KEY"),
|
||||
gemini_base_url=config_dict.get("GEMINI_BASE_URL", "https://www.chataiapi.com/v1"),
|
||||
default_llm_provider=config_dict.get("DEFAULT_LLM_PROVIDER", "gemini"),
|
||||
gemini_model=config_dict.get("GEMINI_MODEL", "gemini-2.5-pro"),
|
||||
max_content_length=int(config_dict.get("MAX_CONTENT_LENGTH", "200000")),
|
||||
output_dir=config_dict.get("REPORT_OUTPUT_DIR", "final_reports"),
|
||||
template_dir=config_dict.get("TEMPLATE_DIR", "ReportEngine/report_template"),
|
||||
api_timeout=float(config_dict.get("REPORT_API_TIMEOUT", "900.0")),
|
||||
max_retry_delay=float(config_dict.get("REPORT_MAX_RETRY_DELAY", "180.0")),
|
||||
max_retries=int(config_dict.get("REPORT_MAX_RETRIES", "8")),
|
||||
log_file=config_dict.get("REPORT_LOG_FILE", "logs/report.log"),
|
||||
enable_pdf_export=config_dict.get("ENABLE_PDF_EXPORT", "true").lower() == "true",
|
||||
chart_style=config_dict.get("CHART_STYLE", "modern")
|
||||
llm_api_key=_get_value(config_module, "REPORT_ENGINE_API_KEY"),
|
||||
llm_base_url=_get_value(config_module, "REPORT_ENGINE_BASE_URL"),
|
||||
llm_model_name=_get_value(config_module, "REPORT_ENGINE_MODEL_NAME"),
|
||||
max_content_length=int(_get_value(config_module, "MAX_CONTENT_LENGTH", 200000)),
|
||||
output_dir=_get_value(config_module, "REPORT_OUTPUT_DIR", "final_reports"),
|
||||
template_dir=_get_value(config_module, "TEMPLATE_DIR", "ReportEngine/report_template"),
|
||||
api_timeout=float(_get_value(config_module, "REPORT_API_TIMEOUT", 900.0)),
|
||||
max_retry_delay=float(_get_value(config_module, "REPORT_MAX_RETRY_DELAY", 180.0)),
|
||||
max_retries=int(_get_value(config_module, "REPORT_MAX_RETRIES", 8)),
|
||||
log_file=_get_value(config_module, "REPORT_LOG_FILE", "logs/report.log"),
|
||||
enable_pdf_export=str(
|
||||
_get_value(config_module, "ENABLE_PDF_EXPORT", "true")
|
||||
).lower()
|
||||
in ("true", "1", "yes"),
|
||||
chart_style=_get_value(config_module, "CHART_STYLE", "modern"),
|
||||
)
|
||||
|
||||
config_dict = {}
|
||||
if os.path.exists(config_file):
|
||||
with open(config_file, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith("#") and "=" in line:
|
||||
key, value = line.split("=", 1)
|
||||
config_dict[key.strip()] = value.strip()
|
||||
|
||||
return cls(
|
||||
llm_api_key=_get_value(config_dict, "REPORT_ENGINE_API_KEY"),
|
||||
llm_base_url=_get_value(config_dict, "REPORT_ENGINE_BASE_URL"),
|
||||
llm_model_name=_get_value(config_dict, "REPORT_ENGINE_MODEL_NAME"),
|
||||
max_content_length=int(_get_value(config_dict, "MAX_CONTENT_LENGTH", 200000)),
|
||||
output_dir=_get_value(config_dict, "REPORT_OUTPUT_DIR", "final_reports"),
|
||||
template_dir=_get_value(config_dict, "TEMPLATE_DIR", "ReportEngine/report_template"),
|
||||
api_timeout=float(_get_value(config_dict, "REPORT_API_TIMEOUT", 900.0)),
|
||||
max_retry_delay=float(_get_value(config_dict, "REPORT_MAX_RETRY_DELAY", 180.0)),
|
||||
max_retries=int(_get_value(config_dict, "REPORT_MAX_RETRIES", 8)),
|
||||
log_file=_get_value(config_dict, "REPORT_LOG_FILE", "logs/report.log"),
|
||||
enable_pdf_export=str(
|
||||
_get_value(config_dict, "ENABLE_PDF_EXPORT", "true")
|
||||
).lower()
|
||||
in ("true", "1", "yes"),
|
||||
chart_style=_get_value(config_dict, "CHART_STYLE", "modern"),
|
||||
)
|
||||
|
||||
|
||||
def load_config(config_file: Optional[str] = None) -> Config:
|
||||
"""
|
||||
加载配置
|
||||
|
||||
Args:
|
||||
config_file: 配置文件路径,如果不指定则使用默认路径
|
||||
|
||||
Returns:
|
||||
配置对象
|
||||
"""
|
||||
# 确定配置文件路径
|
||||
if config_file:
|
||||
if not os.path.exists(config_file):
|
||||
raise FileNotFoundError(f"配置文件不存在: {config_file}")
|
||||
file_to_load = config_file
|
||||
else:
|
||||
# 尝试加载常见的配置文件
|
||||
for config_path in ["config.py", "config.env", ".env"]:
|
||||
if os.path.exists(config_path):
|
||||
file_to_load = config_path
|
||||
for candidate in ("config.py", "config.env", ".env"):
|
||||
if os.path.exists(candidate):
|
||||
file_to_load = candidate
|
||||
print(f"已找到配置文件: {candidate}")
|
||||
break
|
||||
else:
|
||||
raise FileNotFoundError("未找到配置文件,请创建 config.py 文件")
|
||||
|
||||
# 创建配置对象
|
||||
raise FileNotFoundError("未找到配置文件,请创建 config.py。")
|
||||
|
||||
config = Config.from_file(file_to_load)
|
||||
|
||||
# 验证配置
|
||||
if not config.validate():
|
||||
raise ValueError("Report Engine配置验证失败,请检查配置文件中的API密钥")
|
||||
|
||||
raise ValueError("Report Engine 配置校验失败,请检查 config.py 中的相关配置。")
|
||||
return config
|
||||
|
||||
|
||||
def print_config(config: Config):
|
||||
"""打印配置信息(隐藏敏感信息)"""
|
||||
print("\n=== Report Engine配置 ===")
|
||||
print(f"LLM提供商: {config.default_llm_provider}")
|
||||
print(f"Gemini模型: {config.gemini_model}")
|
||||
print("\n=== Report Engine 配置 ===")
|
||||
print(f"LLM 模型: {config.llm_model_name}")
|
||||
print(f"LLM Base URL: {config.llm_base_url or '(默认)'}")
|
||||
print(f"最大内容长度: {config.max_content_length}")
|
||||
print(f"输出目录: {config.output_dir}")
|
||||
print(f"模板目录: {config.template_dir}")
|
||||
print(f"API超时时间: {config.api_timeout}秒({config.api_timeout/60:.1f}分钟)")
|
||||
print(f"最大重试延迟: {config.max_retry_delay}秒({config.max_retry_delay/60:.1f}分钟)")
|
||||
print(f"最大重试次数: {config.max_retries}次")
|
||||
print(f"API 超时时间: {config.api_timeout} 秒")
|
||||
print(f"最大重试间隔: {config.max_retry_delay} 秒")
|
||||
print(f"最大重试次数: {config.max_retries}")
|
||||
print(f"日志文件: {config.log_file}")
|
||||
print(f"PDF导出: {config.enable_pdf_export}")
|
||||
print(f"PDF 导出: {config.enable_pdf_export}")
|
||||
print(f"图表样式: {config.chart_style}")
|
||||
print(f"Gemini API Key: {'已设置' if config.gemini_api_key else '未设置'}")
|
||||
print(f"LLM API Key: {'已配置' if config.llm_api_key else '未配置'}")
|
||||
print("========================\n")
|
||||
|
||||
Reference in New Issue
Block a user