Completely refactor the LLM integration method to easily replace the LLM used by each module and optimize the retransmission mechanism.

This commit is contained in:
666ghj
2025-10-09 13:45:39 +08:00
parent ce74f00137
commit 154b29c0d7
73 changed files with 942 additions and 51758 deletions
+9 -21
View File
@@ -9,7 +9,7 @@ import re
from datetime import datetime
from typing import Optional, Dict, Any, List
from .llms import DeepSeekLLM, OpenAILLM, GeminiLLM, BaseLLM
from .llms import LLMClient
from .nodes import (
ReportStructureNode,
FirstSearchNode,
@@ -35,6 +35,8 @@ class DeepSearchAgent:
"""
# 加载配置
self.config = config or load_config()
os.environ["BOCHA_API_KEY"] = self.config.bocha_api_key or ""
os.environ["BOCHA_WEB_SEARCH_API_KEY"] = self.config.bocha_api_key or ""
# 初始化LLM客户端
self.llm_client = self._initialize_llm()
@@ -55,27 +57,13 @@ class DeepSearchAgent:
print(f"使用LLM: {self.llm_client.get_model_info()}")
print(f"搜索工具集: BochaMultimodalSearch (支持5种多模态搜索工具)")
def _initialize_llm(self) -> BaseLLM:
def _initialize_llm(self) -> LLMClient:
"""初始化LLM客户端"""
if self.config.default_llm_provider == "deepseek":
return DeepSeekLLM(
api_key=self.config.deepseek_api_key,
model_name=self.config.deepseek_model,
base_url=self.config.deepseek_base_url
)
elif self.config.default_llm_provider == "openai":
return OpenAILLM(
api_key=self.config.openai_api_key,
model_name=self.config.openai_model
)
elif self.config.default_llm_provider == "gemini":
return GeminiLLM(
api_key=self.config.gemini_api_key,
model_name=self.config.gemini_model,
base_url=self.config.gemini_base_url
)
else:
raise ValueError(f"不支持的LLM提供商: {self.config.default_llm_provider}")
return LLMClient(
api_key=self.config.llm_api_key,
model_name=self.config.llm_model_name,
base_url=self.config.llm_base_url,
)
def _initialize_nodes(self):
"""初始化处理节点"""
+3 -7
View File
@@ -1,11 +1,7 @@
"""
LLM调用模块
支持多种大语言模型的统一接口
LLM module for the Media Engine.
"""
from .base import BaseLLM
from .deepseek import DeepSeekLLM
from .openai_llm import OpenAILLM
from .gemini_llm import GeminiLLM
from .base import LLMClient
__all__ = ["BaseLLM", "DeepSeekLLM", "OpenAILLM", "GeminiLLM"]
__all__ = ["LLMClient"]
+81 -50
View File
@@ -1,61 +1,92 @@
"""
LLM基础抽象类
定义所有LLM实现需要遵循的接口标准
Unified OpenAI-compatible LLM client for the Media Engine, with retry support.
"""
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any
import os
import sys
from typing import Any, Dict, Optional
from openai import OpenAI
# Ensure project-level retry helper is importable
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(os.path.dirname(current_dir))
utils_dir = os.path.join(project_root, "utils")
if utils_dir not in sys.path:
sys.path.append(utils_dir)
try:
from retry_helper import with_retry, LLM_RETRY_CONFIG
except ImportError:
def with_retry(config=None):
def decorator(func):
return func
return decorator
LLM_RETRY_CONFIG = None
class BaseLLM(ABC):
"""LLM基础抽象类"""
def __init__(self, api_key: str, model_name: Optional[str] = None):
"""
初始化LLM客户端
Args:
api_key: API密钥
model_name: 模型名称,如果不指定则使用默认模型
"""
class LLMClient:
"""
Minimal wrapper around the OpenAI-compatible chat completion API.
"""
def __init__(self, api_key: str, model_name: str, base_url: Optional[str] = None):
if not api_key:
raise ValueError("Media Engine LLM API key is required.")
if not model_name:
raise ValueError("Media Engine model name is required.")
self.api_key = api_key
self.base_url = base_url
self.model_name = model_name
@abstractmethod
self.provider = model_name
timeout_fallback = os.getenv("LLM_REQUEST_TIMEOUT") or os.getenv("MEDIA_ENGINE_REQUEST_TIMEOUT") or "180"
try:
self.timeout = float(timeout_fallback)
except ValueError:
self.timeout = 300.0
client_kwargs: Dict[str, Any] = {
"api_key": api_key,
"max_retries": 0,
}
if base_url:
client_kwargs["base_url"] = base_url
self.client = OpenAI(**client_kwargs)
@with_retry(LLM_RETRY_CONFIG)
def invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
"""
调用LLM生成回复
Args:
system_prompt: 系统提示词
user_prompt: 用户输入
**kwargs: 其他参数,如temperature、max_tokens等
Returns:
LLM生成的回复文本
"""
pass
@abstractmethod
def get_default_model(self) -> str:
"""
获取默认模型名称
Returns:
默认模型名称
"""
pass
def validate_response(self, response: str) -> str:
"""
验证和清理响应内容
Args:
response: LLM原始响应
Returns:
清理后的响应内容
"""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
allowed_keys = {"temperature", "top_p", "presence_penalty", "frequency_penalty", "stream"}
extra_params = {key: value for key, value in kwargs.items() if key in allowed_keys and value is not None}
timeout = kwargs.pop("timeout", self.timeout)
response = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
timeout=timeout,
**extra_params,
)
if response.choices and response.choices[0].message:
return self.validate_response(response.choices[0].message.content)
return ""
@staticmethod
def validate_response(response: Optional[str]) -> str:
if response is None:
return ""
return response.strip()
def get_model_info(self) -> Dict[str, Any]:
return {
"provider": self.provider,
"model": self.model_name,
"api_base": self.base_url or "default",
}
-118
View File
@@ -1,118 +0,0 @@
"""
DeepSeek LLM实现
使用DeepSeek API进行文本生成
"""
import os
import sys
from typing import Optional, Dict, Any
from openai import OpenAI
from .base import BaseLLM
DEFAULT_DEEPSEEK_BASE_URL = "https://api.deepseek.com"
# 添加utils目录到Python路径并导入重试模块
try:
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.dirname(os.path.dirname(current_dir))
utils_dir = os.path.join(root_dir, 'utils')
if utils_dir not in sys.path:
sys.path.append(utils_dir)
from retry_helper import with_retry, with_graceful_retry, LLM_RETRY_CONFIG
except ImportError:
# 如果无法导入重试模块,使用空装饰器避免报错
def with_retry(config):
def decorator(func):
return func
return decorator
LLM_RETRY_CONFIG = None
class DeepSeekLLM(BaseLLM):
"""DeepSeek LLM实现类"""
def __init__(self, api_key: Optional[str] = None, model_name: Optional[str] = None, base_url: Optional[str] = None):
"""
初始化DeepSeek客户端
Args:
api_key: DeepSeek API密钥,如果不提供则从环境变量读取
model_name: 模型名称,默认使用deepseek-chat
base_url: DeepSeek API基础地址
"""
if api_key is None:
api_key = os.getenv("DEEPSEEK_API_KEY")
if not api_key:
raise ValueError("DeepSeek API Key未找到!请设置DEEPSEEK_API_KEY环境变量或在初始化时提供")
super().__init__(api_key, model_name)
self.base_url = base_url or os.getenv("DEEPSEEK_BASE_URL") or DEFAULT_DEEPSEEK_BASE_URL
# 初始化OpenAI客户端,使用DeepSeek的endpoint
self.client = OpenAI(
api_key=self.api_key,
base_url=self.base_url
)
self.default_model = model_name or self.get_default_model()
def get_default_model(self) -> str:
"""获取默认模型名称"""
return "deepseek-chat"
@with_retry(LLM_RETRY_CONFIG)
def invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
"""
调用DeepSeek API生成回复
Args:
system_prompt: 系统提示词
user_prompt: 用户输入
**kwargs: 其他参数,如temperature、max_tokens等
Returns:
DeepSeek生成的回复文本
"""
try:
# 构建消息
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
]
# 设置默认参数
params = {
"model": self.default_model,
"messages": messages,
"temperature": kwargs.get("temperature", 0.7),
"max_tokens": kwargs.get("max_tokens", 30000), # 提高到30000以支持一万字报告
"stream": False
}
# 调用API
response = self.client.chat.completions.create(**params)
# 提取回复内容
if response.choices and response.choices[0].message:
content = response.choices[0].message.content
return self.validate_response(content)
else:
return ""
except Exception as e:
print(f"DeepSeek API调用错误: {str(e)}")
raise e
def get_model_info(self) -> Dict[str, Any]:
"""
获取当前模型信息
Returns:
模型信息字典
"""
return {
"provider": "DeepSeek",
"model": self.default_model,
"api_base": self.base_url
}
-118
View File
@@ -1,118 +0,0 @@
"""
Gemini LLM实现
使用Gemini 2.5-pro中转API进行文本生成
"""
import os
import sys
from typing import Optional, Dict, Any
from openai import OpenAI
from .base import BaseLLM
DEFAULT_GEMINI_BASE_URL = "https://www.chataiapi.com/v1"
# 添加utils目录到Python路径并导入重试模块
try:
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.dirname(os.path.dirname(current_dir))
utils_dir = os.path.join(root_dir, 'utils')
if utils_dir not in sys.path:
sys.path.append(utils_dir)
from retry_helper import with_retry, with_graceful_retry, LLM_RETRY_CONFIG
except ImportError:
# 如果无法导入重试模块,使用空装饰器避免报错
def with_retry(config):
def decorator(func):
return func
return decorator
LLM_RETRY_CONFIG = None
class GeminiLLM(BaseLLM):
"""Gemini LLM实现类"""
def __init__(self, api_key: Optional[str] = None, model_name: Optional[str] = None, base_url: Optional[str] = None):
"""
初始化Gemini客户端
Args:
api_key: Gemini API密钥,如果不提供则从环境变量读取
model_name: 模型名称,默认使用gemini-2.5-pro
base_url: Gemini API基础地址
"""
if api_key is None:
api_key = os.getenv("GEMINI_API_KEY")
if not api_key:
raise ValueError("Gemini API Key未找到!请设置GEMINI_API_KEY环境变量或在初始化时提供")
super().__init__(api_key, model_name)
self.base_url = base_url or os.getenv("GEMINI_BASE_URL") or DEFAULT_GEMINI_BASE_URL
# 初始化OpenAI客户端,使用Gemini的中转endpoint
self.client = OpenAI(
api_key=self.api_key,
base_url=self.base_url
)
self.default_model = model_name or self.get_default_model()
def get_default_model(self) -> str:
"""获取默认模型名称"""
return "gemini-2.5-pro"
@with_retry(LLM_RETRY_CONFIG)
def invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
"""
调用Gemini API生成回复
Args:
system_prompt: 系统提示词
user_prompt: 用户输入
**kwargs: 其他参数,如temperature、max_tokens等
Returns:
Gemini生成的回复文本
"""
try:
# 构建消息
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
]
# 设置默认参数
params = {
"model": self.default_model,
"messages": messages,
"temperature": kwargs.get("temperature", 0.7),
"max_tokens": kwargs.get("max_tokens", 30000), # 提高到30000以支持一万字报告
"stream": False
}
# 调用API
response = self.client.chat.completions.create(**params)
# 提取回复内容
if response.choices and response.choices[0].message:
content = response.choices[0].message.content
return self.validate_response(content)
else:
return ""
except Exception as e:
print(f"Gemini API调用错误: {str(e)}")
raise e
def get_model_info(self) -> Dict[str, Any]:
"""
获取当前模型信息
Returns:
模型信息字典
"""
return {
"provider": "Gemini",
"model": self.default_model,
"api_base": self.base_url
}
-108
View File
@@ -1,108 +0,0 @@
"""
OpenAI LLM实现
使用OpenAI API进行文本生成
"""
import os
import sys
from typing import Optional, Dict, Any
from openai import OpenAI
from .base import BaseLLM
# 添加utils目录到Python路径并导入重试模块
try:
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.dirname(os.path.dirname(current_dir))
utils_dir = os.path.join(root_dir, 'utils')
if utils_dir not in sys.path:
sys.path.append(utils_dir)
from retry_helper import with_retry, with_graceful_retry, LLM_RETRY_CONFIG
except ImportError:
# 如果无法导入重试模块,使用空装饰器避免报错
def with_retry(config):
def decorator(func):
return func
return decorator
LLM_RETRY_CONFIG = None
class OpenAILLM(BaseLLM):
"""OpenAI LLM实现类"""
def __init__(self, api_key: Optional[str] = None, model_name: Optional[str] = None):
"""
初始化OpenAI客户端
Args:
api_key: OpenAI API密钥,如果不提供则从环境变量读取
model_name: 模型名称,默认使用gpt-4o-mini
"""
if api_key is None:
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise ValueError("OpenAI API Key未找到!请设置OPENAI_API_KEY环境变量或在初始化时提供")
super().__init__(api_key, model_name)
# 初始化OpenAI客户端
self.client = OpenAI(api_key=self.api_key)
self.default_model = model_name or self.get_default_model()
def get_default_model(self) -> str:
"""获取默认模型名称"""
return "gpt-4o-mini"
@with_retry(LLM_RETRY_CONFIG)
def invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
"""
调用OpenAI API生成回复
Args:
system_prompt: 系统提示词
user_prompt: 用户输入
**kwargs: 其他参数,如temperature、max_tokens等
Returns:
OpenAI生成的回复文本
"""
try:
# 构建消息
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
]
# 设置默认参数
params = {
"model": self.default_model,
"messages": messages,
"temperature": kwargs.get("temperature", 0.7),
"max_tokens": kwargs.get("max_tokens", 30000) # 提高到30000以支持一万字报告
}
# 调用API
response = self.client.chat.completions.create(**params)
# 提取回复内容
if response.choices and response.choices[0].message:
content = response.choices[0].message.content
return self.validate_response(content)
else:
return ""
except Exception as e:
print(f"OpenAI API调用错误: {str(e)}")
raise e
def get_model_info(self) -> Dict[str, Any]:
"""
获取当前模型信息
Returns:
模型信息字典
"""
return {
"provider": "OpenAI",
"model": self.default_model,
"api_base": "https://api.openai.com"
}
+2 -2
View File
@@ -5,14 +5,14 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from ..llms.base import BaseLLM
from ..llms.base import LLMClient
from ..state.state import State
class BaseNode(ABC):
"""节点基类"""
def __init__(self, llm_client: BaseLLM, node_name: str = ""):
def __init__(self, llm_client: LLMClient, node_name: str = ""):
"""
初始化节点
+3 -4
View File
@@ -67,11 +67,10 @@ class ReportFormattingNode(BaseNode):
self.log_info("正在格式化最终报告")
# 调用LLM,传递更大的max_tokens以支持长文本报告
# 调用LLM生成Markdown格式
response = self.llm_client.invoke(
SYSTEM_PROMPT_REPORT_FORMATTING,
message,
max_tokens=30000 # 支持一万字的报告输出
SYSTEM_PROMPT_REPORT_FORMATTING,
message,
)
# 处理响应
+4 -6
View File
@@ -98,11 +98,10 @@ class FirstSummaryNode(StateMutationNode):
self.log_info("正在生成首次段落总结")
# 调用LLM,增加max_tokens以支持更长的总结
# 调用LLM生成总结
response = self.llm_client.invoke(
SYSTEM_PROMPT_FIRST_SUMMARY,
SYSTEM_PROMPT_FIRST_SUMMARY,
message,
max_tokens=15000 # 支持更长的总结内容
)
# 处理响应
@@ -267,11 +266,10 @@ class ReflectionSummaryNode(StateMutationNode):
self.log_info("正在生成反思总结")
# 调用LLM,增加max_tokens以支持更长的总结
# 调用LLM生成总结
response = self.llm_client.invoke(
SYSTEM_PROMPT_REFLECTION_SUMMARY,
SYSTEM_PROMPT_REFLECTION_SUMMARY,
message,
max_tokens=15000 # 支持更长的总结内容
)
# 处理响应
+106 -128
View File
@@ -1,6 +1,5 @@
"""
配置管理模块
处理环境变量和配置参数
Configuration management module for the Media Engine.
"""
import os
@@ -8,172 +7,151 @@ from dataclasses import dataclass
from typing import Optional
def _get_value(source, key: str, default=None, *fallback_keys: str):
candidates = (key,) + fallback_keys
value = None
for candidate in candidates:
if isinstance(source, dict):
value = source.get(candidate)
else:
value = getattr(source, candidate, None)
if value not in (None, ""):
break
if value in (None, ""):
for candidate in candidates:
env_val = os.getenv(candidate)
if env_val not in (None, ""):
value = env_val
break
return value if value not in (None, "") else default
@dataclass
class Config:
"""配置类"""
# API密钥
deepseek_api_key: Optional[str] = None
openai_api_key: Optional[str] = None
gemini_api_key: Optional[str] = None
"""Media Engine configuration."""
llm_api_key: Optional[str] = None
llm_base_url: Optional[str] = None
llm_model_name: Optional[str] = None
llm_provider: Optional[str] = None # compatibility
bocha_api_key: Optional[str] = None
deepseek_base_url: str = "https://api.deepseek.com"
openai_base_url: Optional[str] = None
gemini_base_url: str = "https://www.chataiapi.com/v1"
# 模型配置
default_llm_provider: str = "deepseek" # deepseek、openai 或 gemini
deepseek_model: str = "deepseek-chat"
openai_model: str = "gpt-4o-mini"
gemini_model: str = "gemini-2.5-pro"
# 搜索配置
search_timeout: int = 240
max_content_length: int = 20000
# Agent配置
max_reflections: int = 2
max_paragraphs: int = 5
# 输出配置
output_dir: str = "reports"
save_intermediate_states: bool = True
def __post_init__(self):
if not self.llm_provider and self.llm_model_name:
self.llm_provider = self.llm_model_name
def validate(self) -> bool:
"""验证配置"""
# 检查必需的API密钥
if self.default_llm_provider == "deepseek" and not self.deepseek_api_key:
print("错误: DeepSeek API Key未设置")
if not self.llm_api_key:
print("错误: Media Engine LLM API Key 未设置 (MEDIA_ENGINE_API_KEY)。")
return False
if self.default_llm_provider == "openai" and not self.openai_api_key:
print("错误: OpenAI API Key未设置")
if not self.llm_model_name:
print("错误: Media Engine 模型名称未设置 (MEDIA_ENGINE_MODEL_NAME)。")
return False
if self.default_llm_provider == "gemini" and not self.gemini_api_key:
print("错误: Gemini API Key未设置")
return False
if not self.bocha_api_key:
print("错误: Bocha API Key未设置")
print("错误: Bocha API Key 未设置 (BOCHA_WEB_SEARCH_API_KEY)。")
return False
return True
@classmethod
def from_file(cls, config_file: str) -> "Config":
"""从配置文件创建配置"""
if config_file.endswith('.py'):
# Python配置文件
if config_file.endswith(".py"):
import importlib.util
# 动态导入配置文件
spec = importlib.util.spec_from_file_location("config", config_file)
config_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config_module)
return cls(
deepseek_api_key=getattr(config_module, "DEEPSEEK_API_KEY", None),
openai_api_key=getattr(config_module, "OPENAI_API_KEY", None),
gemini_api_key=getattr(config_module, "GEMINI_API_KEY", None),
deepseek_base_url=getattr(config_module, "DEEPSEEK_BASE_URL", "https://api.deepseek.com"),
openai_base_url=getattr(config_module, "OPENAI_BASE_URL", None),
gemini_base_url=getattr(config_module, "GEMINI_BASE_URL", "https://www.chataiapi.com/v1"),
bocha_api_key=getattr(config_module, "BOCHA_API_KEY", None),
default_llm_provider=getattr(config_module, "DEFAULT_LLM_PROVIDER", "deepseek"),
deepseek_model=getattr(config_module, "DEEPSEEK_MODEL", "deepseek-chat"),
openai_model=getattr(config_module, "OPENAI_MODEL", "gpt-4o-mini"),
gemini_model=getattr(config_module, "GEMINI_MODEL", "gemini-2.5-pro"),
search_timeout=getattr(config_module, "SEARCH_TIMEOUT", 240),
max_content_length=getattr(config_module, "SEARCH_CONTENT_MAX_LENGTH", 20000),
max_reflections=getattr(config_module, "MAX_REFLECTIONS", 2),
max_paragraphs=getattr(config_module, "MAX_PARAGRAPHS", 5),
output_dir=getattr(config_module, "OUTPUT_DIR", "reports"),
save_intermediate_states=getattr(config_module, "SAVE_INTERMEDIATE_STATES", True)
)
else:
# .env格式配置文件
config_dict = {}
if os.path.exists(config_file):
with open(config_file, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if line and not line.startswith('#') and '=' in line:
key, value = line.split('=', 1)
config_dict[key.strip()] = value.strip()
return cls(
deepseek_api_key=config_dict.get("DEEPSEEK_API_KEY"),
openai_api_key=config_dict.get("OPENAI_API_KEY"),
gemini_api_key=config_dict.get("GEMINI_API_KEY"),
deepseek_base_url=config_dict.get("DEEPSEEK_BASE_URL", "https://api.deepseek.com"),
openai_base_url=config_dict.get("OPENAI_BASE_URL"),
gemini_base_url=config_dict.get("GEMINI_BASE_URL", "https://www.chataiapi.com/v1"),
bocha_api_key=config_dict.get("BOCHA_API_KEY"),
default_llm_provider=config_dict.get("DEFAULT_LLM_PROVIDER", "deepseek"),
deepseek_model=config_dict.get("DEEPSEEK_MODEL", "deepseek-chat"),
openai_model=config_dict.get("OPENAI_MODEL", "gpt-4o-mini"),
gemini_model=config_dict.get("GEMINI_MODEL", "gemini-2.5-pro"),
search_timeout=int(config_dict.get("SEARCH_TIMEOUT", "240")),
max_content_length=int(config_dict.get("SEARCH_CONTENT_MAX_LENGTH", "20000")),
max_reflections=int(config_dict.get("MAX_REFLECTIONS", "2")),
max_paragraphs=int(config_dict.get("MAX_PARAGRAPHS", "5")),
output_dir=config_dict.get("OUTPUT_DIR", "reports"),
save_intermediate_states=config_dict.get("SAVE_INTERMEDIATE_STATES", "true").lower() == "true"
llm_api_key=_get_value(config_module, "MEDIA_ENGINE_API_KEY"),
llm_base_url=_get_value(config_module, "MEDIA_ENGINE_BASE_URL"),
llm_model_name=_get_value(config_module, "MEDIA_ENGINE_MODEL_NAME"),
bocha_api_key=_get_value(
config_module,
"BOCHA_WEB_SEARCH_API_KEY",
None,
"BOCHA_API_KEY",
),
search_timeout=int(_get_value(config_module, "SEARCH_TIMEOUT", 240)),
max_content_length=int(_get_value(config_module, "SEARCH_CONTENT_MAX_LENGTH", 20000)),
max_reflections=int(_get_value(config_module, "MAX_REFLECTIONS", 2)),
max_paragraphs=int(_get_value(config_module, "MAX_PARAGRAPHS", 5)),
output_dir=_get_value(config_module, "OUTPUT_DIR", "reports"),
save_intermediate_states=str(
_get_value(config_module, "SAVE_INTERMEDIATE_STATES", "true")
).lower()
in ("true", "1", "yes"),
)
config_dict = {}
if os.path.exists(config_file):
with open(config_file, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line and not line.startswith("#") and "=" in line:
key, value = line.split("=", 1)
config_dict[key.strip()] = value.strip()
return cls(
llm_api_key=_get_value(config_dict, "MEDIA_ENGINE_API_KEY"),
llm_base_url=_get_value(config_dict, "MEDIA_ENGINE_BASE_URL"),
llm_model_name=_get_value(config_dict, "MEDIA_ENGINE_MODEL_NAME"),
bocha_api_key=_get_value(
config_dict,
"BOCHA_WEB_SEARCH_API_KEY",
None,
"BOCHA_API_KEY",
),
search_timeout=int(_get_value(config_dict, "SEARCH_TIMEOUT", 240)),
max_content_length=int(_get_value(config_dict, "SEARCH_CONTENT_MAX_LENGTH", 20000)),
max_reflections=int(_get_value(config_dict, "MAX_REFLECTIONS", 2)),
max_paragraphs=int(_get_value(config_dict, "MAX_PARAGRAPHS", 5)),
output_dir=_get_value(config_dict, "OUTPUT_DIR", "reports"),
save_intermediate_states=str(
_get_value(config_dict, "SAVE_INTERMEDIATE_STATES", "true")
).lower()
in ("true", "1", "yes"),
)
def load_config(config_file: Optional[str] = None) -> Config:
"""
加载配置
Args:
config_file: 配置文件路径,如果不指定则使用默认路径
Returns:
配置对象
"""
# 确定配置文件路径
if config_file:
if not os.path.exists(config_file):
raise FileNotFoundError(f"配置文件不存在: {config_file}")
file_to_load = config_file
else:
# 尝试加载常见的配置文件
for config_path in ["config.py", "config.env", ".env"]:
if os.path.exists(config_path):
file_to_load = config_path
print(f"已找到配置文件: {config_path}")
for candidate in ("config.py", "config.env", ".env"):
if os.path.exists(candidate):
file_to_load = candidate
print(f"已找到配置文件: {candidate}")
break
else:
raise FileNotFoundError("未找到配置文件,请创建 config.py 文件")
# 创建配置对象
raise FileNotFoundError("未找到配置文件,请创建 config.py")
config = Config.from_file(file_to_load)
# 验证配置
if not config.validate():
raise ValueError("配置验失败,请检查配置文件中的API密钥")
raise ValueError("配置验失败,请检查 config.py 中的相关配置。")
return config
def print_config(config: Config):
"""打印配置信息(隐藏敏感信息)"""
print("\n=== 当前配置 ===")
print(f"LLM提供商: {config.default_llm_provider}")
print(f"DeepSeek模型: {config.deepseek_model}")
print(f"OpenAI模型: {config.openai_model}")
print(f"搜索超时: {config.search_timeout}")
print(f"最大内容长度: {config.max_content_length}")
print("\n=== Media Engine 配置 ===")
print(f"LLM 模型: {config.llm_model_name}")
print(f"LLM Base URL: {config.llm_base_url or '(默认)'}")
print(f"Bocha API Key: {'已配置' if config.bocha_api_key else '未配置'}")
print(f"搜索超时: {config.search_timeout}")
print(f"最长内容长度: {config.max_content_length}")
print(f"最大反思次数: {config.max_reflections}")
print(f"最大段落数: {config.max_paragraphs}")
print(f"输出目录: {config.output_dir}")
print(f"保存中间状态: {config.save_intermediate_states}")
# 显示API密钥状态(不显示实际密钥)
print(f"DeepSeek API Key: {'已设置' if config.deepseek_api_key else '未设置'}")
print(f"OpenAI API Key: {'已设置' if config.openai_api_key else '未设置'}")
print(f"Bocha API Key: {'已设置' if config.bocha_api_key else '未设置'}")
print("==================\n")
print(f"LLM API Key: {'已配置' if config.llm_api_key else '未配置'}")
print("========================\n")