Reconfiguration of the basic multi-agent architecture.

This commit is contained in:
戒酒的李白
2025-08-22 22:04:08 +08:00
parent bec01f8930
commit 7ae863a781
70 changed files with 6792 additions and 648 deletions
+26
View File
@@ -0,0 +1,26 @@
"""
工具函数模块
提供文本处理、JSON解析等辅助功能
"""
from .text_processing import (
clean_json_tags,
clean_markdown_tags,
remove_reasoning_from_output,
extract_clean_response,
update_state_with_search_results,
format_search_results_for_prompt
)
from .config import Config, load_config
__all__ = [
"clean_json_tags",
"clean_markdown_tags",
"remove_reasoning_from_output",
"extract_clean_response",
"update_state_with_search_results",
"format_search_results_for_prompt",
"Config",
"load_config"
]
+162
View File
@@ -0,0 +1,162 @@
"""
配置管理模块
处理环境变量和配置参数
"""
import os
from dataclasses import dataclass
from typing import Optional
@dataclass
class Config:
"""配置类"""
# API密钥
deepseek_api_key: Optional[str] = None
openai_api_key: Optional[str] = None
tavily_api_key: Optional[str] = None
# 模型配置
default_llm_provider: str = "deepseek" # deepseek 或 openai
deepseek_model: str = "deepseek-chat"
openai_model: str = "gpt-4o-mini"
# 搜索配置
search_timeout: int = 240
max_content_length: int = 20000
# Agent配置
max_reflections: int = 2
max_paragraphs: int = 5
# 输出配置
output_dir: str = "reports"
save_intermediate_states: bool = True
def validate(self) -> bool:
"""验证配置"""
# 检查必需的API密钥
if self.default_llm_provider == "deepseek" and not self.deepseek_api_key:
print("错误: DeepSeek API Key未设置")
return False
if self.default_llm_provider == "openai" and not self.openai_api_key:
print("错误: OpenAI API Key未设置")
return False
if not self.tavily_api_key:
print("错误: Tavily API Key未设置")
return False
return True
@classmethod
def from_file(cls, config_file: str) -> "Config":
"""从配置文件创建配置"""
if config_file.endswith('.py'):
# Python配置文件
import importlib.util
# 动态导入配置文件
spec = importlib.util.spec_from_file_location("config", config_file)
config_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config_module)
return cls(
deepseek_api_key=getattr(config_module, "DEEPSEEK_API_KEY", None),
openai_api_key=getattr(config_module, "OPENAI_API_KEY", None),
tavily_api_key=getattr(config_module, "TAVILY_API_KEY", None),
default_llm_provider=getattr(config_module, "DEFAULT_LLM_PROVIDER", "deepseek"),
deepseek_model=getattr(config_module, "DEEPSEEK_MODEL", "deepseek-chat"),
openai_model=getattr(config_module, "OPENAI_MODEL", "gpt-4o-mini"),
search_timeout=getattr(config_module, "SEARCH_TIMEOUT", 240),
max_content_length=getattr(config_module, "SEARCH_CONTENT_MAX_LENGTH", 20000),
max_reflections=getattr(config_module, "MAX_REFLECTIONS", 2),
max_paragraphs=getattr(config_module, "MAX_PARAGRAPHS", 5),
output_dir=getattr(config_module, "OUTPUT_DIR", "reports"),
save_intermediate_states=getattr(config_module, "SAVE_INTERMEDIATE_STATES", True)
)
else:
# .env格式配置文件
config_dict = {}
if os.path.exists(config_file):
with open(config_file, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if line and not line.startswith('#') and '=' in line:
key, value = line.split('=', 1)
config_dict[key.strip()] = value.strip()
return cls(
deepseek_api_key=config_dict.get("DEEPSEEK_API_KEY"),
openai_api_key=config_dict.get("OPENAI_API_KEY"),
tavily_api_key=config_dict.get("TAVILY_API_KEY"),
default_llm_provider=config_dict.get("DEFAULT_LLM_PROVIDER", "deepseek"),
deepseek_model=config_dict.get("DEEPSEEK_MODEL", "deepseek-chat"),
openai_model=config_dict.get("OPENAI_MODEL", "gpt-4o-mini"),
search_timeout=int(config_dict.get("SEARCH_TIMEOUT", "240")),
max_content_length=int(config_dict.get("SEARCH_CONTENT_MAX_LENGTH", "20000")),
max_reflections=int(config_dict.get("MAX_REFLECTIONS", "2")),
max_paragraphs=int(config_dict.get("MAX_PARAGRAPHS", "5")),
output_dir=config_dict.get("OUTPUT_DIR", "reports"),
save_intermediate_states=config_dict.get("SAVE_INTERMEDIATE_STATES", "true").lower() == "true"
)
def load_config(config_file: Optional[str] = None) -> Config:
"""
加载配置
Args:
config_file: 配置文件路径,如果不指定则使用默认路径
Returns:
配置对象
"""
# 确定配置文件路径
if config_file:
if not os.path.exists(config_file):
raise FileNotFoundError(f"配置文件不存在: {config_file}")
file_to_load = config_file
else:
# 尝试加载常见的配置文件
for config_path in ["config.py", "config.env", ".env"]:
if os.path.exists(config_path):
file_to_load = config_path
print(f"已找到配置文件: {config_path}")
break
else:
raise FileNotFoundError("未找到配置文件,请创建 config.py 文件")
# 创建配置对象
config = Config.from_file(file_to_load)
# 验证配置
if not config.validate():
raise ValueError("配置验证失败,请检查配置文件中的API密钥")
return config
def print_config(config: Config):
"""打印配置信息(隐藏敏感信息)"""
print("\n=== 当前配置 ===")
print(f"LLM提供商: {config.default_llm_provider}")
print(f"DeepSeek模型: {config.deepseek_model}")
print(f"OpenAI模型: {config.openai_model}")
print(f"最大搜索结果数: {config.max_search_results}")
print(f"搜索超时: {config.search_timeout}")
print(f"最大内容长度: {config.max_content_length}")
print(f"最大反思次数: {config.max_reflections}")
print(f"最大段落数: {config.max_paragraphs}")
print(f"输出目录: {config.output_dir}")
print(f"保存中间状态: {config.save_intermediate_states}")
# 显示API密钥状态(不显示实际密钥)
print(f"DeepSeek API Key: {'已设置' if config.deepseek_api_key else '未设置'}")
print(f"OpenAI API Key: {'已设置' if config.openai_api_key else '未设置'}")
print(f"Tavily API Key: {'已设置' if config.tavily_api_key else '未设置'}")
print("==================\n")
+308
View File
@@ -0,0 +1,308 @@
"""
文本处理工具函数
用于清理LLM输出、解析JSON等
"""
import re
import json
from typing import Dict, Any, List
from json.decoder import JSONDecodeError
def clean_json_tags(text: str) -> str:
"""
清理文本中的JSON标签
Args:
text: 原始文本
Returns:
清理后的文本
"""
# 移除```json 和 ```标签
text = re.sub(r'```json\s*', '', text)
text = re.sub(r'```\s*$', '', text)
text = re.sub(r'```', '', text)
return text.strip()
def clean_markdown_tags(text: str) -> str:
"""
清理文本中的Markdown标签
Args:
text: 原始文本
Returns:
清理后的文本
"""
# 移除```markdown 和 ```标签
text = re.sub(r'```markdown\s*', '', text)
text = re.sub(r'```\s*$', '', text)
text = re.sub(r'```', '', text)
return text.strip()
def remove_reasoning_from_output(text: str) -> str:
"""
移除输出中的推理过程文本
Args:
text: 原始文本
Returns:
清理后的文本
"""
# 查找JSON开始位置
json_start = -1
# 尝试找到第一个 { 或 [
for i, char in enumerate(text):
if char in '{[':
json_start = i
break
if json_start != -1:
# 从JSON开始位置截取
return text[json_start:].strip()
# 如果没有找到JSON标记,尝试其他方法
# 移除常见的推理标识
patterns = [
r'(?:reasoning|推理|思考|分析)[:]\s*.*?(?=\{|\[)', # 移除推理部分
r'(?:explanation|解释|说明)[:]\s*.*?(?=\{|\[)', # 移除解释部分
r'^.*?(?=\{|\[)', # 移除JSON前的所有文本
]
for pattern in patterns:
text = re.sub(pattern, '', text, flags=re.IGNORECASE | re.DOTALL)
return text.strip()
def extract_clean_response(text: str) -> Dict[str, Any]:
"""
提取并清理响应中的JSON内容
Args:
text: 原始响应文本
Returns:
解析后的JSON字典
"""
# 清理文本
cleaned_text = clean_json_tags(text)
cleaned_text = remove_reasoning_from_output(cleaned_text)
# 尝试直接解析
try:
return json.loads(cleaned_text)
except JSONDecodeError:
pass
# 尝试修复不完整的JSON
fixed_text = fix_incomplete_json(cleaned_text)
if fixed_text:
try:
return json.loads(fixed_text)
except JSONDecodeError:
pass
# 尝试查找JSON对象
json_pattern = r'\{.*\}'
match = re.search(json_pattern, cleaned_text, re.DOTALL)
if match:
try:
return json.loads(match.group())
except JSONDecodeError:
pass
# 尝试查找JSON数组
array_pattern = r'\[.*\]'
match = re.search(array_pattern, cleaned_text, re.DOTALL)
if match:
try:
return json.loads(match.group())
except JSONDecodeError:
pass
# 如果所有方法都失败,返回错误信息
print(f"无法解析JSON响应: {cleaned_text[:200]}...")
return {"error": "JSON解析失败", "raw_text": cleaned_text}
def fix_incomplete_json(text: str) -> str:
"""
修复不完整的JSON响应
Args:
text: 原始文本
Returns:
修复后的JSON文本,如果无法修复则返回空字符串
"""
# 移除多余的逗号和空白
text = re.sub(r',\s*}', '}', text)
text = re.sub(r',\s*]', ']', text)
# 检查是否已经是有效的JSON
try:
json.loads(text)
return text
except JSONDecodeError:
pass
# 检查是否缺少开头的数组符号
if text.strip().startswith('{') and not text.strip().startswith('['):
# 如果以对象开始,尝试包装成数组
if text.count('{') > 1:
# 多个对象,包装成数组
text = '[' + text + ']'
else:
# 单个对象,包装成数组
text = '[' + text + ']'
# 检查是否缺少结尾的数组符号
if text.strip().endswith('}') and not text.strip().endswith(']'):
# 如果以对象结束,尝试包装成数组
if text.count('}') > 1:
# 多个对象,包装成数组
text = '[' + text + ']'
else:
# 单个对象,包装成数组
text = '[' + text + ']'
# 检查括号是否匹配
open_braces = text.count('{')
close_braces = text.count('}')
open_brackets = text.count('[')
close_brackets = text.count(']')
# 修复不匹配的括号
if open_braces > close_braces:
text += '}' * (open_braces - close_braces)
if open_brackets > close_brackets:
text += ']' * (open_brackets - close_brackets)
# 验证修复后的JSON是否有效
try:
json.loads(text)
return text
except JSONDecodeError:
# 如果仍然无效,尝试更激进的修复
return fix_aggressive_json(text)
def fix_aggressive_json(text: str) -> str:
"""
更激进的JSON修复方法
Args:
text: 原始文本
Returns:
修复后的JSON文本
"""
# 查找所有可能的JSON对象
objects = re.findall(r'\{[^{}]*\}', text)
if len(objects) >= 2:
# 如果有多个对象,包装成数组
return '[' + ','.join(objects) + ']'
elif len(objects) == 1:
# 如果只有一个对象,包装成数组
return '[' + objects[0] + ']'
else:
# 如果没有找到对象,返回空数组
return '[]'
def update_state_with_search_results(search_results: List[Dict[str, Any]],
paragraph_index: int, state: Any) -> Any:
"""
将搜索结果更新到状态中
Args:
search_results: 搜索结果列表
paragraph_index: 段落索引
state: 状态对象
Returns:
更新后的状态对象
"""
if 0 <= paragraph_index < len(state.paragraphs):
# 获取最后一次搜索的查询(假设是当前查询)
current_query = ""
if search_results:
# 从搜索结果推断查询(这里需要改进以获取实际查询)
current_query = "搜索查询"
# 添加搜索结果到状态
state.paragraphs[paragraph_index].research.add_search_results(
current_query, search_results
)
return state
def validate_json_schema(data: Dict[str, Any], required_fields: List[str]) -> bool:
"""
验证JSON数据是否包含必需字段
Args:
data: 要验证的数据
required_fields: 必需字段列表
Returns:
验证是否通过
"""
return all(field in data for field in required_fields)
def truncate_content(content: str, max_length: int = 20000) -> str:
"""
截断内容到指定长度
Args:
content: 原始内容
max_length: 最大长度
Returns:
截断后的内容
"""
if len(content) <= max_length:
return content
# 尝试在单词边界截断
truncated = content[:max_length]
last_space = truncated.rfind(' ')
if last_space > max_length * 0.8: # 如果最后一个空格位置合理
return truncated[:last_space] + "..."
else:
return truncated + "..."
def format_search_results_for_prompt(search_results: List[Dict[str, Any]],
max_length: int = 20000) -> List[str]:
"""
格式化搜索结果用于提示词
Args:
search_results: 搜索结果列表
max_length: 每个结果的最大长度
Returns:
格式化后的内容列表
"""
formatted_results = []
for result in search_results:
content = result.get('content', '')
if content:
truncated_content = truncate_content(content, max_length)
formatted_results.append(truncated_content)
return formatted_results