Reconfiguration of the basic multi-agent architecture.
This commit is contained in:
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
工具函数模块
|
||||
提供文本处理、JSON解析等辅助功能
|
||||
"""
|
||||
|
||||
from .text_processing import (
|
||||
clean_json_tags,
|
||||
clean_markdown_tags,
|
||||
remove_reasoning_from_output,
|
||||
extract_clean_response,
|
||||
update_state_with_search_results,
|
||||
format_search_results_for_prompt
|
||||
)
|
||||
|
||||
from .config import Config, load_config
|
||||
|
||||
__all__ = [
|
||||
"clean_json_tags",
|
||||
"clean_markdown_tags",
|
||||
"remove_reasoning_from_output",
|
||||
"extract_clean_response",
|
||||
"update_state_with_search_results",
|
||||
"format_search_results_for_prompt",
|
||||
"Config",
|
||||
"load_config"
|
||||
]
|
||||
@@ -0,0 +1,162 @@
|
||||
"""
|
||||
配置管理模块
|
||||
处理环境变量和配置参数
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
"""配置类"""
|
||||
# API密钥
|
||||
deepseek_api_key: Optional[str] = None
|
||||
openai_api_key: Optional[str] = None
|
||||
tavily_api_key: Optional[str] = None
|
||||
|
||||
# 模型配置
|
||||
default_llm_provider: str = "deepseek" # deepseek 或 openai
|
||||
deepseek_model: str = "deepseek-chat"
|
||||
openai_model: str = "gpt-4o-mini"
|
||||
|
||||
# 搜索配置
|
||||
search_timeout: int = 240
|
||||
max_content_length: int = 20000
|
||||
|
||||
# Agent配置
|
||||
max_reflections: int = 2
|
||||
max_paragraphs: int = 5
|
||||
|
||||
# 输出配置
|
||||
output_dir: str = "reports"
|
||||
save_intermediate_states: bool = True
|
||||
|
||||
def validate(self) -> bool:
|
||||
"""验证配置"""
|
||||
# 检查必需的API密钥
|
||||
if self.default_llm_provider == "deepseek" and not self.deepseek_api_key:
|
||||
print("错误: DeepSeek API Key未设置")
|
||||
return False
|
||||
|
||||
if self.default_llm_provider == "openai" and not self.openai_api_key:
|
||||
print("错误: OpenAI API Key未设置")
|
||||
return False
|
||||
|
||||
if not self.tavily_api_key:
|
||||
print("错误: Tavily API Key未设置")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, config_file: str) -> "Config":
|
||||
"""从配置文件创建配置"""
|
||||
if config_file.endswith('.py'):
|
||||
# Python配置文件
|
||||
import importlib.util
|
||||
|
||||
# 动态导入配置文件
|
||||
spec = importlib.util.spec_from_file_location("config", config_file)
|
||||
config_module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(config_module)
|
||||
|
||||
return cls(
|
||||
deepseek_api_key=getattr(config_module, "DEEPSEEK_API_KEY", None),
|
||||
openai_api_key=getattr(config_module, "OPENAI_API_KEY", None),
|
||||
tavily_api_key=getattr(config_module, "TAVILY_API_KEY", None),
|
||||
default_llm_provider=getattr(config_module, "DEFAULT_LLM_PROVIDER", "deepseek"),
|
||||
deepseek_model=getattr(config_module, "DEEPSEEK_MODEL", "deepseek-chat"),
|
||||
openai_model=getattr(config_module, "OPENAI_MODEL", "gpt-4o-mini"),
|
||||
|
||||
search_timeout=getattr(config_module, "SEARCH_TIMEOUT", 240),
|
||||
max_content_length=getattr(config_module, "SEARCH_CONTENT_MAX_LENGTH", 20000),
|
||||
max_reflections=getattr(config_module, "MAX_REFLECTIONS", 2),
|
||||
max_paragraphs=getattr(config_module, "MAX_PARAGRAPHS", 5),
|
||||
output_dir=getattr(config_module, "OUTPUT_DIR", "reports"),
|
||||
save_intermediate_states=getattr(config_module, "SAVE_INTERMEDIATE_STATES", True)
|
||||
)
|
||||
else:
|
||||
# .env格式配置文件
|
||||
config_dict = {}
|
||||
|
||||
if os.path.exists(config_file):
|
||||
with open(config_file, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith('#') and '=' in line:
|
||||
key, value = line.split('=', 1)
|
||||
config_dict[key.strip()] = value.strip()
|
||||
|
||||
return cls(
|
||||
deepseek_api_key=config_dict.get("DEEPSEEK_API_KEY"),
|
||||
openai_api_key=config_dict.get("OPENAI_API_KEY"),
|
||||
tavily_api_key=config_dict.get("TAVILY_API_KEY"),
|
||||
default_llm_provider=config_dict.get("DEFAULT_LLM_PROVIDER", "deepseek"),
|
||||
deepseek_model=config_dict.get("DEEPSEEK_MODEL", "deepseek-chat"),
|
||||
openai_model=config_dict.get("OPENAI_MODEL", "gpt-4o-mini"),
|
||||
|
||||
search_timeout=int(config_dict.get("SEARCH_TIMEOUT", "240")),
|
||||
max_content_length=int(config_dict.get("SEARCH_CONTENT_MAX_LENGTH", "20000")),
|
||||
max_reflections=int(config_dict.get("MAX_REFLECTIONS", "2")),
|
||||
max_paragraphs=int(config_dict.get("MAX_PARAGRAPHS", "5")),
|
||||
output_dir=config_dict.get("OUTPUT_DIR", "reports"),
|
||||
save_intermediate_states=config_dict.get("SAVE_INTERMEDIATE_STATES", "true").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
def load_config(config_file: Optional[str] = None) -> Config:
|
||||
"""
|
||||
加载配置
|
||||
|
||||
Args:
|
||||
config_file: 配置文件路径,如果不指定则使用默认路径
|
||||
|
||||
Returns:
|
||||
配置对象
|
||||
"""
|
||||
# 确定配置文件路径
|
||||
if config_file:
|
||||
if not os.path.exists(config_file):
|
||||
raise FileNotFoundError(f"配置文件不存在: {config_file}")
|
||||
file_to_load = config_file
|
||||
else:
|
||||
# 尝试加载常见的配置文件
|
||||
for config_path in ["config.py", "config.env", ".env"]:
|
||||
if os.path.exists(config_path):
|
||||
file_to_load = config_path
|
||||
print(f"已找到配置文件: {config_path}")
|
||||
break
|
||||
else:
|
||||
raise FileNotFoundError("未找到配置文件,请创建 config.py 文件")
|
||||
|
||||
# 创建配置对象
|
||||
config = Config.from_file(file_to_load)
|
||||
|
||||
# 验证配置
|
||||
if not config.validate():
|
||||
raise ValueError("配置验证失败,请检查配置文件中的API密钥")
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def print_config(config: Config):
|
||||
"""打印配置信息(隐藏敏感信息)"""
|
||||
print("\n=== 当前配置 ===")
|
||||
print(f"LLM提供商: {config.default_llm_provider}")
|
||||
print(f"DeepSeek模型: {config.deepseek_model}")
|
||||
print(f"OpenAI模型: {config.openai_model}")
|
||||
print(f"最大搜索结果数: {config.max_search_results}")
|
||||
print(f"搜索超时: {config.search_timeout}秒")
|
||||
print(f"最大内容长度: {config.max_content_length}")
|
||||
print(f"最大反思次数: {config.max_reflections}")
|
||||
print(f"最大段落数: {config.max_paragraphs}")
|
||||
print(f"输出目录: {config.output_dir}")
|
||||
print(f"保存中间状态: {config.save_intermediate_states}")
|
||||
|
||||
# 显示API密钥状态(不显示实际密钥)
|
||||
print(f"DeepSeek API Key: {'已设置' if config.deepseek_api_key else '未设置'}")
|
||||
print(f"OpenAI API Key: {'已设置' if config.openai_api_key else '未设置'}")
|
||||
print(f"Tavily API Key: {'已设置' if config.tavily_api_key else '未设置'}")
|
||||
print("==================\n")
|
||||
@@ -0,0 +1,308 @@
|
||||
"""
|
||||
文本处理工具函数
|
||||
用于清理LLM输出、解析JSON等
|
||||
"""
|
||||
|
||||
import re
|
||||
import json
|
||||
from typing import Dict, Any, List
|
||||
from json.decoder import JSONDecodeError
|
||||
|
||||
|
||||
def clean_json_tags(text: str) -> str:
|
||||
"""
|
||||
清理文本中的JSON标签
|
||||
|
||||
Args:
|
||||
text: 原始文本
|
||||
|
||||
Returns:
|
||||
清理后的文本
|
||||
"""
|
||||
# 移除```json 和 ```标签
|
||||
text = re.sub(r'```json\s*', '', text)
|
||||
text = re.sub(r'```\s*$', '', text)
|
||||
text = re.sub(r'```', '', text)
|
||||
|
||||
return text.strip()
|
||||
|
||||
|
||||
def clean_markdown_tags(text: str) -> str:
|
||||
"""
|
||||
清理文本中的Markdown标签
|
||||
|
||||
Args:
|
||||
text: 原始文本
|
||||
|
||||
Returns:
|
||||
清理后的文本
|
||||
"""
|
||||
# 移除```markdown 和 ```标签
|
||||
text = re.sub(r'```markdown\s*', '', text)
|
||||
text = re.sub(r'```\s*$', '', text)
|
||||
text = re.sub(r'```', '', text)
|
||||
|
||||
return text.strip()
|
||||
|
||||
|
||||
def remove_reasoning_from_output(text: str) -> str:
|
||||
"""
|
||||
移除输出中的推理过程文本
|
||||
|
||||
Args:
|
||||
text: 原始文本
|
||||
|
||||
Returns:
|
||||
清理后的文本
|
||||
"""
|
||||
# 查找JSON开始位置
|
||||
json_start = -1
|
||||
|
||||
# 尝试找到第一个 { 或 [
|
||||
for i, char in enumerate(text):
|
||||
if char in '{[':
|
||||
json_start = i
|
||||
break
|
||||
|
||||
if json_start != -1:
|
||||
# 从JSON开始位置截取
|
||||
return text[json_start:].strip()
|
||||
|
||||
# 如果没有找到JSON标记,尝试其他方法
|
||||
# 移除常见的推理标识
|
||||
patterns = [
|
||||
r'(?:reasoning|推理|思考|分析)[::]\s*.*?(?=\{|\[)', # 移除推理部分
|
||||
r'(?:explanation|解释|说明)[::]\s*.*?(?=\{|\[)', # 移除解释部分
|
||||
r'^.*?(?=\{|\[)', # 移除JSON前的所有文本
|
||||
]
|
||||
|
||||
for pattern in patterns:
|
||||
text = re.sub(pattern, '', text, flags=re.IGNORECASE | re.DOTALL)
|
||||
|
||||
return text.strip()
|
||||
|
||||
|
||||
def extract_clean_response(text: str) -> Dict[str, Any]:
|
||||
"""
|
||||
提取并清理响应中的JSON内容
|
||||
|
||||
Args:
|
||||
text: 原始响应文本
|
||||
|
||||
Returns:
|
||||
解析后的JSON字典
|
||||
"""
|
||||
# 清理文本
|
||||
cleaned_text = clean_json_tags(text)
|
||||
cleaned_text = remove_reasoning_from_output(cleaned_text)
|
||||
|
||||
# 尝试直接解析
|
||||
try:
|
||||
return json.loads(cleaned_text)
|
||||
except JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 尝试修复不完整的JSON
|
||||
fixed_text = fix_incomplete_json(cleaned_text)
|
||||
if fixed_text:
|
||||
try:
|
||||
return json.loads(fixed_text)
|
||||
except JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 尝试查找JSON对象
|
||||
json_pattern = r'\{.*\}'
|
||||
match = re.search(json_pattern, cleaned_text, re.DOTALL)
|
||||
if match:
|
||||
try:
|
||||
return json.loads(match.group())
|
||||
except JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 尝试查找JSON数组
|
||||
array_pattern = r'\[.*\]'
|
||||
match = re.search(array_pattern, cleaned_text, re.DOTALL)
|
||||
if match:
|
||||
try:
|
||||
return json.loads(match.group())
|
||||
except JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 如果所有方法都失败,返回错误信息
|
||||
print(f"无法解析JSON响应: {cleaned_text[:200]}...")
|
||||
return {"error": "JSON解析失败", "raw_text": cleaned_text}
|
||||
|
||||
|
||||
def fix_incomplete_json(text: str) -> str:
|
||||
"""
|
||||
修复不完整的JSON响应
|
||||
|
||||
Args:
|
||||
text: 原始文本
|
||||
|
||||
Returns:
|
||||
修复后的JSON文本,如果无法修复则返回空字符串
|
||||
"""
|
||||
# 移除多余的逗号和空白
|
||||
text = re.sub(r',\s*}', '}', text)
|
||||
text = re.sub(r',\s*]', ']', text)
|
||||
|
||||
# 检查是否已经是有效的JSON
|
||||
try:
|
||||
json.loads(text)
|
||||
return text
|
||||
except JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 检查是否缺少开头的数组符号
|
||||
if text.strip().startswith('{') and not text.strip().startswith('['):
|
||||
# 如果以对象开始,尝试包装成数组
|
||||
if text.count('{') > 1:
|
||||
# 多个对象,包装成数组
|
||||
text = '[' + text + ']'
|
||||
else:
|
||||
# 单个对象,包装成数组
|
||||
text = '[' + text + ']'
|
||||
|
||||
# 检查是否缺少结尾的数组符号
|
||||
if text.strip().endswith('}') and not text.strip().endswith(']'):
|
||||
# 如果以对象结束,尝试包装成数组
|
||||
if text.count('}') > 1:
|
||||
# 多个对象,包装成数组
|
||||
text = '[' + text + ']'
|
||||
else:
|
||||
# 单个对象,包装成数组
|
||||
text = '[' + text + ']'
|
||||
|
||||
# 检查括号是否匹配
|
||||
open_braces = text.count('{')
|
||||
close_braces = text.count('}')
|
||||
open_brackets = text.count('[')
|
||||
close_brackets = text.count(']')
|
||||
|
||||
# 修复不匹配的括号
|
||||
if open_braces > close_braces:
|
||||
text += '}' * (open_braces - close_braces)
|
||||
if open_brackets > close_brackets:
|
||||
text += ']' * (open_brackets - close_brackets)
|
||||
|
||||
# 验证修复后的JSON是否有效
|
||||
try:
|
||||
json.loads(text)
|
||||
return text
|
||||
except JSONDecodeError:
|
||||
# 如果仍然无效,尝试更激进的修复
|
||||
return fix_aggressive_json(text)
|
||||
|
||||
|
||||
def fix_aggressive_json(text: str) -> str:
|
||||
"""
|
||||
更激进的JSON修复方法
|
||||
|
||||
Args:
|
||||
text: 原始文本
|
||||
|
||||
Returns:
|
||||
修复后的JSON文本
|
||||
"""
|
||||
# 查找所有可能的JSON对象
|
||||
objects = re.findall(r'\{[^{}]*\}', text)
|
||||
|
||||
if len(objects) >= 2:
|
||||
# 如果有多个对象,包装成数组
|
||||
return '[' + ','.join(objects) + ']'
|
||||
elif len(objects) == 1:
|
||||
# 如果只有一个对象,包装成数组
|
||||
return '[' + objects[0] + ']'
|
||||
else:
|
||||
# 如果没有找到对象,返回空数组
|
||||
return '[]'
|
||||
|
||||
|
||||
def update_state_with_search_results(search_results: List[Dict[str, Any]],
|
||||
paragraph_index: int, state: Any) -> Any:
|
||||
"""
|
||||
将搜索结果更新到状态中
|
||||
|
||||
Args:
|
||||
search_results: 搜索结果列表
|
||||
paragraph_index: 段落索引
|
||||
state: 状态对象
|
||||
|
||||
Returns:
|
||||
更新后的状态对象
|
||||
"""
|
||||
if 0 <= paragraph_index < len(state.paragraphs):
|
||||
# 获取最后一次搜索的查询(假设是当前查询)
|
||||
current_query = ""
|
||||
if search_results:
|
||||
# 从搜索结果推断查询(这里需要改进以获取实际查询)
|
||||
current_query = "搜索查询"
|
||||
|
||||
# 添加搜索结果到状态
|
||||
state.paragraphs[paragraph_index].research.add_search_results(
|
||||
current_query, search_results
|
||||
)
|
||||
|
||||
return state
|
||||
|
||||
|
||||
def validate_json_schema(data: Dict[str, Any], required_fields: List[str]) -> bool:
|
||||
"""
|
||||
验证JSON数据是否包含必需字段
|
||||
|
||||
Args:
|
||||
data: 要验证的数据
|
||||
required_fields: 必需字段列表
|
||||
|
||||
Returns:
|
||||
验证是否通过
|
||||
"""
|
||||
return all(field in data for field in required_fields)
|
||||
|
||||
|
||||
def truncate_content(content: str, max_length: int = 20000) -> str:
|
||||
"""
|
||||
截断内容到指定长度
|
||||
|
||||
Args:
|
||||
content: 原始内容
|
||||
max_length: 最大长度
|
||||
|
||||
Returns:
|
||||
截断后的内容
|
||||
"""
|
||||
if len(content) <= max_length:
|
||||
return content
|
||||
|
||||
# 尝试在单词边界截断
|
||||
truncated = content[:max_length]
|
||||
last_space = truncated.rfind(' ')
|
||||
|
||||
if last_space > max_length * 0.8: # 如果最后一个空格位置合理
|
||||
return truncated[:last_space] + "..."
|
||||
else:
|
||||
return truncated + "..."
|
||||
|
||||
|
||||
def format_search_results_for_prompt(search_results: List[Dict[str, Any]],
|
||||
max_length: int = 20000) -> List[str]:
|
||||
"""
|
||||
格式化搜索结果用于提示词
|
||||
|
||||
Args:
|
||||
search_results: 搜索结果列表
|
||||
max_length: 每个结果的最大长度
|
||||
|
||||
Returns:
|
||||
格式化后的内容列表
|
||||
"""
|
||||
formatted_results = []
|
||||
|
||||
for result in search_results:
|
||||
content = result.get('content', '')
|
||||
if content:
|
||||
truncated_content = truncate_content(content, max_length)
|
||||
formatted_results.append(truncated_content)
|
||||
|
||||
return formatted_results
|
||||
Reference in New Issue
Block a user