Reconfiguration of the basic multi-agent architecture.

This commit is contained in:
戒酒的李白
2025-08-22 22:04:08 +08:00
parent bec01f8930
commit 7ae863a781
70 changed files with 6792 additions and 648 deletions
+12
View File
@@ -0,0 +1,12 @@
"""
Deep Search Agent
一个无框架的深度搜索AI代理实现
"""
from .agent import DeepSearchAgent, create_agent
from .utils.config import Config, load_config
__version__ = "1.0.0"
__author__ = "Deep Search Agent Team"
__all__ = ["DeepSearchAgent", "create_agent", "Config", "load_config"]
+478
View File
@@ -0,0 +1,478 @@
"""
Deep Search Agent主类
整合所有模块,实现完整的深度搜索流程
"""
import json
import os
import re
from datetime import datetime
from typing import Optional, Dict, Any, List
from .llms import DeepSeekLLM, OpenAILLM, BaseLLM
from .nodes import (
ReportStructureNode,
FirstSearchNode,
ReflectionNode,
FirstSummaryNode,
ReflectionSummaryNode,
ReportFormattingNode
)
from .state import State
from .tools import TavilyNewsAgency, TavilyResponse
from .utils import Config, load_config, format_search_results_for_prompt
class DeepSearchAgent:
"""Deep Search Agent主类"""
def __init__(self, config: Optional[Config] = None):
"""
初始化Deep Search Agent
Args:
config: 配置对象,如果不提供则自动加载
"""
# 加载配置
self.config = config or load_config()
# 初始化LLM客户端
self.llm_client = self._initialize_llm()
# 初始化搜索工具集
self.search_agency = TavilyNewsAgency(api_key=self.config.tavily_api_key)
# 初始化节点
self._initialize_nodes()
# 状态
self.state = State()
# 确保输出目录存在
os.makedirs(self.config.output_dir, exist_ok=True)
print(f"Deep Search Agent 已初始化")
print(f"使用LLM: {self.llm_client.get_model_info()}")
print(f"搜索工具集: TavilyNewsAgency (支持6种搜索工具)")
def _initialize_llm(self) -> BaseLLM:
"""初始化LLM客户端"""
if self.config.default_llm_provider == "deepseek":
return DeepSeekLLM(
api_key=self.config.deepseek_api_key,
model_name=self.config.deepseek_model
)
elif self.config.default_llm_provider == "openai":
return OpenAILLM(
api_key=self.config.openai_api_key,
model_name=self.config.openai_model
)
else:
raise ValueError(f"不支持的LLM提供商: {self.config.default_llm_provider}")
def _initialize_nodes(self):
"""初始化处理节点"""
self.first_search_node = FirstSearchNode(self.llm_client)
self.reflection_node = ReflectionNode(self.llm_client)
self.first_summary_node = FirstSummaryNode(self.llm_client)
self.reflection_summary_node = ReflectionSummaryNode(self.llm_client)
self.report_formatting_node = ReportFormattingNode(self.llm_client)
def _validate_date_format(self, date_str: str) -> bool:
"""
验证日期格式是否为YYYY-MM-DD
Args:
date_str: 日期字符串
Returns:
是否为有效格式
"""
if not date_str:
return False
# 检查格式
pattern = r'^\d{4}-\d{2}-\d{2}$'
if not re.match(pattern, date_str):
return False
# 检查日期是否有效
try:
datetime.strptime(date_str, '%Y-%m-%d')
return True
except ValueError:
return False
def execute_search_tool(self, tool_name: str, query: str, **kwargs) -> TavilyResponse:
"""
执行指定的搜索工具
Args:
tool_name: 工具名称,可选值:
- "basic_search_news": 基础新闻搜索(快速、通用)
- "deep_search_news": 深度新闻分析
- "search_news_last_24_hours": 24小时内最新新闻
- "search_news_last_week": 本周新闻
- "search_images_for_news": 新闻图片搜索
- "search_news_by_date": 按日期范围搜索新闻
query: 搜索查询
**kwargs: 额外参数(如start_date, end_date, max_results
Returns:
TavilyResponse对象
"""
print(f" → 执行搜索工具: {tool_name}")
if tool_name == "basic_search_news":
max_results = kwargs.get("max_results", 7)
return self.search_agency.basic_search_news(query, max_results)
elif tool_name == "deep_search_news":
return self.search_agency.deep_search_news(query)
elif tool_name == "search_news_last_24_hours":
return self.search_agency.search_news_last_24_hours(query)
elif tool_name == "search_news_last_week":
return self.search_agency.search_news_last_week(query)
elif tool_name == "search_images_for_news":
return self.search_agency.search_images_for_news(query)
elif tool_name == "search_news_by_date":
start_date = kwargs.get("start_date")
end_date = kwargs.get("end_date")
if not start_date or not end_date:
raise ValueError("search_news_by_date工具需要start_date和end_date参数")
return self.search_agency.search_news_by_date(query, start_date, end_date)
else:
print(f" ⚠️ 未知的搜索工具: {tool_name},使用默认基础搜索")
return self.search_agency.basic_search_news(query)
def research(self, query: str, save_report: bool = True) -> str:
"""
执行深度研究
Args:
query: 研究查询
save_report: 是否保存报告到文件
Returns:
最终报告内容
"""
print(f"\n{'='*60}")
print(f"开始深度研究: {query}")
print(f"{'='*60}")
try:
# Step 1: 生成报告结构
self._generate_report_structure(query)
# Step 2: 处理每个段落
self._process_paragraphs()
# Step 3: 生成最终报告
final_report = self._generate_final_report()
# Step 4: 保存报告
if save_report:
self._save_report(final_report)
print(f"\n{'='*60}")
print("深度研究完成!")
print(f"{'='*60}")
return final_report
except Exception as e:
print(f"研究过程中发生错误: {str(e)}")
raise e
def _generate_report_structure(self, query: str):
"""生成报告结构"""
print(f"\n[步骤 1] 生成报告结构...")
# 创建报告结构节点
report_structure_node = ReportStructureNode(self.llm_client, query)
# 生成结构并更新状态
self.state = report_structure_node.mutate_state(state=self.state)
print(f"报告结构已生成,共 {len(self.state.paragraphs)} 个段落:")
for i, paragraph in enumerate(self.state.paragraphs, 1):
print(f" {i}. {paragraph.title}")
def _process_paragraphs(self):
"""处理所有段落"""
total_paragraphs = len(self.state.paragraphs)
for i in range(total_paragraphs):
print(f"\n[步骤 2.{i+1}] 处理段落: {self.state.paragraphs[i].title}")
print("-" * 50)
# 初始搜索和总结
self._initial_search_and_summary(i)
# 反思循环
self._reflection_loop(i)
# 标记段落完成
self.state.paragraphs[i].research.mark_completed()
progress = (i + 1) / total_paragraphs * 100
print(f"段落处理完成 ({progress:.1f}%)")
def _initial_search_and_summary(self, paragraph_index: int):
"""执行初始搜索和总结"""
paragraph = self.state.paragraphs[paragraph_index]
# 准备搜索输入
search_input = {
"title": paragraph.title,
"content": paragraph.content
}
# 生成搜索查询和工具选择
print(" - 生成搜索查询...")
search_output = self.first_search_node.run(search_input)
search_query = search_output["search_query"]
search_tool = search_output.get("search_tool", "basic_search_news") # 默认工具
reasoning = search_output["reasoning"]
print(f" - 搜索查询: {search_query}")
print(f" - 选择的工具: {search_tool}")
print(f" - 推理: {reasoning}")
# 执行搜索
print(" - 执行网络搜索...")
# 处理search_news_by_date的特殊参数
search_kwargs = {}
if search_tool == "search_news_by_date":
start_date = search_output.get("start_date")
end_date = search_output.get("end_date")
if start_date and end_date:
# 验证日期格式
if self._validate_date_format(start_date) and self._validate_date_format(end_date):
search_kwargs["start_date"] = start_date
search_kwargs["end_date"] = end_date
print(f" - 时间范围: {start_date}{end_date}")
else:
print(f" ⚠️ 日期格式错误(应为YYYY-MM-DD),改用基础搜索")
print(f" 提供的日期: start_date={start_date}, end_date={end_date}")
search_tool = "basic_search_news"
else:
print(f" ⚠️ search_news_by_date工具缺少时间参数,改用基础搜索")
search_tool = "basic_search_news"
search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs)
# 转换为兼容格式
search_results = []
if search_response and search_response.results:
# 每种搜索工具都有其特定的结果数量,这里取前10个作为上限
max_results = min(len(search_response.results), 10)
for result in search_response.results[:max_results]:
search_results.append({
'title': result.title,
'url': result.url,
'content': result.content,
'score': result.score,
'raw_content': result.raw_content,
'published_date': result.published_date # 新增字段
})
if search_results:
print(f" - 找到 {len(search_results)} 个搜索结果")
for j, result in enumerate(search_results, 1):
date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else ""
print(f" {j}. {result['title'][:50]}...{date_info}")
else:
print(" - 未找到搜索结果")
# 更新状态中的搜索历史
paragraph.research.add_search_results(search_query, search_results)
# 生成初始总结
print(" - 生成初始总结...")
summary_input = {
"title": paragraph.title,
"content": paragraph.content,
"search_query": search_query,
"search_results": format_search_results_for_prompt(
search_results, self.config.max_content_length
)
}
# 更新状态
self.state = self.first_summary_node.mutate_state(
summary_input, self.state, paragraph_index
)
print(" - 初始总结完成")
def _reflection_loop(self, paragraph_index: int):
"""执行反思循环"""
paragraph = self.state.paragraphs[paragraph_index]
for reflection_i in range(self.config.max_reflections):
print(f" - 反思 {reflection_i + 1}/{self.config.max_reflections}...")
# 准备反思输入
reflection_input = {
"title": paragraph.title,
"content": paragraph.content,
"paragraph_latest_state": paragraph.research.latest_summary
}
# 生成反思搜索查询
reflection_output = self.reflection_node.run(reflection_input)
search_query = reflection_output["search_query"]
search_tool = reflection_output.get("search_tool", "basic_search_news") # 默认工具
reasoning = reflection_output["reasoning"]
print(f" 反思查询: {search_query}")
print(f" 选择的工具: {search_tool}")
print(f" 反思推理: {reasoning}")
# 执行反思搜索
# 处理search_news_by_date的特殊参数
search_kwargs = {}
if search_tool == "search_news_by_date":
start_date = reflection_output.get("start_date")
end_date = reflection_output.get("end_date")
if start_date and end_date:
# 验证日期格式
if self._validate_date_format(start_date) and self._validate_date_format(end_date):
search_kwargs["start_date"] = start_date
search_kwargs["end_date"] = end_date
print(f" 时间范围: {start_date}{end_date}")
else:
print(f" ⚠️ 日期格式错误(应为YYYY-MM-DD),改用基础搜索")
print(f" 提供的日期: start_date={start_date}, end_date={end_date}")
search_tool = "basic_search_news"
else:
print(f" ⚠️ search_news_by_date工具缺少时间参数,改用基础搜索")
search_tool = "basic_search_news"
search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs)
# 转换为兼容格式
search_results = []
if search_response and search_response.results:
# 每种搜索工具都有其特定的结果数量,这里取前10个作为上限
max_results = min(len(search_response.results), 10)
for result in search_response.results[:max_results]:
search_results.append({
'title': result.title,
'url': result.url,
'content': result.content,
'score': result.score,
'raw_content': result.raw_content,
'published_date': result.published_date
})
if search_results:
print(f" 找到 {len(search_results)} 个反思搜索结果")
for j, result in enumerate(search_results, 1):
date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else ""
print(f" {j}. {result['title'][:50]}...{date_info}")
else:
print(" 未找到反思搜索结果")
# 更新搜索历史
paragraph.research.add_search_results(search_query, search_results)
# 生成反思总结
reflection_summary_input = {
"title": paragraph.title,
"content": paragraph.content,
"search_query": search_query,
"search_results": format_search_results_for_prompt(
search_results, self.config.max_content_length
),
"paragraph_latest_state": paragraph.research.latest_summary
}
# 更新状态
self.state = self.reflection_summary_node.mutate_state(
reflection_summary_input, self.state, paragraph_index
)
print(f" 反思 {reflection_i + 1} 完成")
def _generate_final_report(self) -> str:
"""生成最终报告"""
print(f"\n[步骤 3] 生成最终报告...")
# 准备报告数据
report_data = []
for paragraph in self.state.paragraphs:
report_data.append({
"title": paragraph.title,
"paragraph_latest_state": paragraph.research.latest_summary
})
# 格式化报告
try:
final_report = self.report_formatting_node.run(report_data)
except Exception as e:
print(f"LLM格式化失败,使用备用方法: {str(e)}")
final_report = self.report_formatting_node.format_report_manually(
report_data, self.state.report_title
)
# 更新状态
self.state.final_report = final_report
self.state.mark_completed()
print("最终报告生成完成")
return final_report
def _save_report(self, report_content: str):
"""保存报告到文件"""
# 生成文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
query_safe = "".join(c for c in self.state.query if c.isalnum() or c in (' ', '-', '_')).rstrip()
query_safe = query_safe.replace(' ', '_')[:30]
filename = f"deep_search_report_{query_safe}_{timestamp}.md"
filepath = os.path.join(self.config.output_dir, filename)
# 保存报告
with open(filepath, 'w', encoding='utf-8') as f:
f.write(report_content)
print(f"报告已保存到: {filepath}")
# 保存状态(如果配置允许)
if self.config.save_intermediate_states:
state_filename = f"state_{query_safe}_{timestamp}.json"
state_filepath = os.path.join(self.config.output_dir, state_filename)
self.state.save_to_file(state_filepath)
print(f"状态已保存到: {state_filepath}")
def get_progress_summary(self) -> Dict[str, Any]:
"""获取进度摘要"""
return self.state.get_progress_summary()
def load_state(self, filepath: str):
"""从文件加载状态"""
self.state = State.load_from_file(filepath)
print(f"状态已从 {filepath} 加载")
def save_state(self, filepath: str):
"""保存状态到文件"""
self.state.save_to_file(filepath)
print(f"状态已保存到 {filepath}")
def create_agent(config_file: Optional[str] = None) -> DeepSearchAgent:
"""
创建Deep Search Agent实例的便捷函数
Args:
config_file: 配置文件路径
Returns:
DeepSearchAgent实例
"""
config = load_config(config_file)
return DeepSearchAgent(config)
+10
View File
@@ -0,0 +1,10 @@
"""
LLM调用模块
支持多种大语言模型的统一接口
"""
from .base import BaseLLM
from .deepseek import DeepSeekLLM
from .openai_llm import OpenAILLM
__all__ = ["BaseLLM", "DeepSeekLLM", "OpenAILLM"]
+61
View File
@@ -0,0 +1,61 @@
"""
LLM基础抽象类
定义所有LLM实现需要遵循的接口标准
"""
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any
class BaseLLM(ABC):
"""LLM基础抽象类"""
def __init__(self, api_key: str, model_name: Optional[str] = None):
"""
初始化LLM客户端
Args:
api_key: API密钥
model_name: 模型名称,如果不指定则使用默认模型
"""
self.api_key = api_key
self.model_name = model_name
@abstractmethod
def invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
"""
调用LLM生成回复
Args:
system_prompt: 系统提示词
user_prompt: 用户输入
**kwargs: 其他参数,如temperature、max_tokens等
Returns:
LLM生成的回复文本
"""
pass
@abstractmethod
def get_default_model(self) -> str:
"""
获取默认模型名称
Returns:
默认模型名称
"""
pass
def validate_response(self, response: str) -> str:
"""
验证和清理响应内容
Args:
response: LLM原始响应
Returns:
清理后的响应内容
"""
if response is None:
return ""
return response.strip()
+95
View File
@@ -0,0 +1,95 @@
"""
DeepSeek LLM实现
使用DeepSeek API进行文本生成
"""
import os
from typing import Optional, Dict, Any
from openai import OpenAI
from .base import BaseLLM
class DeepSeekLLM(BaseLLM):
"""DeepSeek LLM实现类"""
def __init__(self, api_key: Optional[str] = None, model_name: Optional[str] = None):
"""
初始化DeepSeek客户端
Args:
api_key: DeepSeek API密钥,如果不提供则从环境变量读取
model_name: 模型名称,默认使用deepseek-chat
"""
if api_key is None:
api_key = os.getenv("DEEPSEEK_API_KEY")
if not api_key:
raise ValueError("DeepSeek API Key未找到!请设置DEEPSEEK_API_KEY环境变量或在初始化时提供")
super().__init__(api_key, model_name)
# 初始化OpenAI客户端,使用DeepSeek的endpoint
self.client = OpenAI(
api_key=self.api_key,
base_url="https://api.deepseek.com"
)
self.default_model = model_name or self.get_default_model()
def get_default_model(self) -> str:
"""获取默认模型名称"""
return "deepseek-chat"
def invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
"""
调用DeepSeek API生成回复
Args:
system_prompt: 系统提示词
user_prompt: 用户输入
**kwargs: 其他参数,如temperature、max_tokens等
Returns:
DeepSeek生成的回复文本
"""
try:
# 构建消息
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
]
# 设置默认参数
params = {
"model": self.default_model,
"messages": messages,
"temperature": kwargs.get("temperature", 0.7),
"max_tokens": kwargs.get("max_tokens", 4000),
"stream": False
}
# 调用API
response = self.client.chat.completions.create(**params)
# 提取回复内容
if response.choices and response.choices[0].message:
content = response.choices[0].message.content
return self.validate_response(content)
else:
return ""
except Exception as e:
print(f"DeepSeek API调用错误: {str(e)}")
raise e
def get_model_info(self) -> Dict[str, Any]:
"""
获取当前模型信息
Returns:
模型信息字典
"""
return {
"provider": "DeepSeek",
"model": self.default_model,
"api_base": "https://api.deepseek.com"
}
+90
View File
@@ -0,0 +1,90 @@
"""
OpenAI LLM实现
使用OpenAI API进行文本生成
"""
import os
from typing import Optional, Dict, Any
from openai import OpenAI
from .base import BaseLLM
class OpenAILLM(BaseLLM):
"""OpenAI LLM实现类"""
def __init__(self, api_key: Optional[str] = None, model_name: Optional[str] = None):
"""
初始化OpenAI客户端
Args:
api_key: OpenAI API密钥,如果不提供则从环境变量读取
model_name: 模型名称,默认使用gpt-4o-mini
"""
if api_key is None:
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise ValueError("OpenAI API Key未找到!请设置OPENAI_API_KEY环境变量或在初始化时提供")
super().__init__(api_key, model_name)
# 初始化OpenAI客户端
self.client = OpenAI(api_key=self.api_key)
self.default_model = model_name or self.get_default_model()
def get_default_model(self) -> str:
"""获取默认模型名称"""
return "gpt-4o-mini"
def invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
"""
调用OpenAI API生成回复
Args:
system_prompt: 系统提示词
user_prompt: 用户输入
**kwargs: 其他参数,如temperature、max_tokens等
Returns:
OpenAI生成的回复文本
"""
try:
# 构建消息
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
]
# 设置默认参数
params = {
"model": self.default_model,
"messages": messages,
"temperature": kwargs.get("temperature", 0.7),
"max_tokens": kwargs.get("max_tokens", 4000)
}
# 调用API
response = self.client.chat.completions.create(**params)
# 提取回复内容
if response.choices and response.choices[0].message:
content = response.choices[0].message.content
return self.validate_response(content)
else:
return ""
except Exception as e:
print(f"OpenAI API调用错误: {str(e)}")
raise e
def get_model_info(self) -> Dict[str, Any]:
"""
获取当前模型信息
Returns:
模型信息字典
"""
return {
"provider": "OpenAI",
"model": self.default_model,
"api_base": "https://api.openai.com"
}
+20
View File
@@ -0,0 +1,20 @@
"""
节点处理模块
实现Deep Search Agent的各个处理步骤
"""
from .base_node import BaseNode
from .report_structure_node import ReportStructureNode
from .search_node import FirstSearchNode, ReflectionNode
from .summary_node import FirstSummaryNode, ReflectionSummaryNode
from .formatting_node import ReportFormattingNode
__all__ = [
"BaseNode",
"ReportStructureNode",
"FirstSearchNode",
"ReflectionNode",
"FirstSummaryNode",
"ReflectionSummaryNode",
"ReportFormattingNode"
]
+89
View File
@@ -0,0 +1,89 @@
"""
节点基类
定义所有处理节点的基础接口
"""
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from ..llms.base import BaseLLM
from ..state.state import State
class BaseNode(ABC):
"""节点基类"""
def __init__(self, llm_client: BaseLLM, node_name: str = ""):
"""
初始化节点
Args:
llm_client: LLM客户端
node_name: 节点名称
"""
self.llm_client = llm_client
self.node_name = node_name or self.__class__.__name__
@abstractmethod
def run(self, input_data: Any, **kwargs) -> Any:
"""
执行节点处理逻辑
Args:
input_data: 输入数据
**kwargs: 额外参数
Returns:
处理结果
"""
pass
def validate_input(self, input_data: Any) -> bool:
"""
验证输入数据
Args:
input_data: 输入数据
Returns:
验证是否通过
"""
return True
def process_output(self, output: Any) -> Any:
"""
处理输出数据
Args:
output: 原始输出
Returns:
处理后的输出
"""
return output
def log_info(self, message: str):
"""记录信息日志"""
print(f"[{self.node_name}] {message}")
def log_error(self, message: str):
"""记录错误日志"""
print(f"[{self.node_name}] 错误: {message}")
class StateMutationNode(BaseNode):
"""带状态修改功能的节点基类"""
@abstractmethod
def mutate_state(self, input_data: Any, state: State, **kwargs) -> State:
"""
修改状态
Args:
input_data: 输入数据
state: 当前状态
**kwargs: 额外参数
Returns:
修改后的状态
"""
pass
+164
View File
@@ -0,0 +1,164 @@
"""
报告格式化节点
负责将最终研究结果格式化为美观的Markdown报告
"""
import json
from typing import List, Dict, Any
from .base_node import BaseNode
from ..prompts import SYSTEM_PROMPT_REPORT_FORMATTING
from ..utils.text_processing import (
remove_reasoning_from_output,
clean_markdown_tags
)
class ReportFormattingNode(BaseNode):
"""格式化最终报告的节点"""
def __init__(self, llm_client):
"""
初始化报告格式化节点
Args:
llm_client: LLM客户端
"""
super().__init__(llm_client, "ReportFormattingNode")
def validate_input(self, input_data: Any) -> bool:
"""验证输入数据"""
if isinstance(input_data, str):
try:
data = json.loads(input_data)
return isinstance(data, list) and all(
isinstance(item, dict) and "title" in item and "paragraph_latest_state" in item
for item in data
)
except:
return False
elif isinstance(input_data, list):
return all(
isinstance(item, dict) and "title" in item and "paragraph_latest_state" in item
for item in input_data
)
return False
def run(self, input_data: Any, **kwargs) -> str:
"""
调用LLM生成Markdown格式报告
Args:
input_data: 包含所有段落信息的列表
**kwargs: 额外参数
Returns:
格式化的Markdown报告
"""
try:
if not self.validate_input(input_data):
raise ValueError("输入数据格式错误,需要包含title和paragraph_latest_state的列表")
# 准备输入数据
if isinstance(input_data, str):
message = input_data
else:
message = json.dumps(input_data, ensure_ascii=False)
self.log_info("正在格式化最终报告")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_FORMATTING, message)
# 处理响应
processed_response = self.process_output(response)
self.log_info("成功生成格式化报告")
return processed_response
except Exception as e:
self.log_error(f"报告格式化失败: {str(e)}")
raise e
def process_output(self, output: str) -> str:
"""
处理LLM输出,清理Markdown格式
Args:
output: LLM原始输出
Returns:
清理后的Markdown报告
"""
try:
# 清理响应文本
cleaned_output = remove_reasoning_from_output(output)
cleaned_output = clean_markdown_tags(cleaned_output)
# 确保报告有基本结构
if not cleaned_output.strip():
return "# 报告生成失败\n\n无法生成有效的报告内容。"
# 如果没有标题,添加一个默认标题
if not cleaned_output.strip().startswith('#'):
cleaned_output = "# 深度研究报告\n\n" + cleaned_output
return cleaned_output.strip()
except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
return "# 报告处理失败\n\n报告格式化过程中发生错误。"
def format_report_manually(self, paragraphs_data: List[Dict[str, str]],
report_title: str = "深度研究报告") -> str:
"""
手动格式化报告(备用方法)
Args:
paragraphs_data: 段落数据列表
report_title: 报告标题
Returns:
格式化的Markdown报告
"""
try:
self.log_info("使用手动格式化方法")
# 构建报告
report_lines = [
f"# {report_title}",
"",
"---",
""
]
# 添加各个段落
for i, paragraph in enumerate(paragraphs_data, 1):
title = paragraph.get("title", f"段落 {i}")
content = paragraph.get("paragraph_latest_state", "")
if content:
report_lines.extend([
f"## {title}",
"",
content,
"",
"---",
""
])
# 添加结论
if len(paragraphs_data) > 1:
report_lines.extend([
"## 结论",
"",
"本报告通过深度搜索和研究,对相关主题进行了全面分析。"
"以上各个方面的内容为理解该主题提供了重要参考。",
""
])
return "\n".join(report_lines)
except Exception as e:
self.log_error(f"手动格式化失败: {str(e)}")
return "# 报告生成失败\n\n无法完成报告格式化。"
+203
View File
@@ -0,0 +1,203 @@
"""
报告结构生成节点
负责根据查询生成报告的整体结构
"""
import json
from typing import Dict, Any, List
from json.decoder import JSONDecodeError
from .base_node import StateMutationNode
from ..state.state import State
from ..prompts import SYSTEM_PROMPT_REPORT_STRUCTURE
from ..utils.text_processing import (
remove_reasoning_from_output,
clean_json_tags,
extract_clean_response,
fix_incomplete_json
)
class ReportStructureNode(StateMutationNode):
"""生成报告结构的节点"""
def __init__(self, llm_client, query: str):
"""
初始化报告结构节点
Args:
llm_client: LLM客户端
query: 用户查询
"""
super().__init__(llm_client, "ReportStructureNode")
self.query = query
def validate_input(self, input_data: Any) -> bool:
"""验证输入数据"""
return isinstance(self.query, str) and len(self.query.strip()) > 0
def run(self, input_data: Any = None, **kwargs) -> List[Dict[str, str]]:
"""
调用LLM生成报告结构
Args:
input_data: 输入数据(这里不使用,使用初始化时的query)
**kwargs: 额外参数
Returns:
报告结构列表
"""
try:
self.log_info(f"正在为查询生成报告结构: {self.query}")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query)
# 处理响应
processed_response = self.process_output(response)
self.log_info(f"成功生成 {len(processed_response)} 个段落结构")
return processed_response
except Exception as e:
self.log_error(f"生成报告结构失败: {str(e)}")
raise e
def process_output(self, output: str) -> List[Dict[str, str]]:
"""
处理LLM输出,提取报告结构
Args:
output: LLM原始输出
Returns:
处理后的报告结构列表
"""
try:
# 清理响应文本
cleaned_output = remove_reasoning_from_output(output)
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output[:200]}...")
# 解析JSON
try:
report_structure = json.loads(cleaned_output)
self.log_info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
# 使用更强大的提取方法
report_structure = extract_clean_response(cleaned_output)
if "error" in report_structure:
self.log_error("JSON解析失败,尝试修复...")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
report_structure = json.loads(fixed_json)
self.log_info("JSON修复成功")
except JSONDecodeError:
self.log_error("JSON修复失败")
# 返回默认结构
return self._generate_default_structure()
else:
self.log_error("无法修复JSON,使用默认结构")
return self._generate_default_structure()
# 验证结构
if not isinstance(report_structure, list):
self.log_info("报告结构不是列表,尝试转换...")
if isinstance(report_structure, dict):
# 如果是单个对象,包装成列表
report_structure = [report_structure]
else:
self.log_error("报告结构格式无效,使用默认结构")
return self._generate_default_structure()
# 验证每个段落
validated_structure = []
for i, paragraph in enumerate(report_structure):
if not isinstance(paragraph, dict):
self.log_warning(f"段落 {i+1} 不是字典格式,跳过")
continue
title = paragraph.get("title", f"段落 {i+1}")
content = paragraph.get("content", "")
if not title or not content:
self.log_warning(f"段落 {i+1} 缺少标题或内容,跳过")
continue
validated_structure.append({
"title": title,
"content": content
})
if not validated_structure:
self.log_warning("没有有效的段落结构,使用默认结构")
return self._generate_default_structure()
self.log_info(f"成功验证 {len(validated_structure)} 个段落结构")
return validated_structure
except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
return self._generate_default_structure()
def _generate_default_structure(self) -> List[Dict[str, str]]:
"""
生成默认的报告结构
Returns:
默认的报告结构列表
"""
self.log_info("生成默认报告结构")
return [
{
"title": "研究概述",
"content": "对查询主题进行总体概述和分析"
},
{
"title": "深度分析",
"content": "深入分析查询主题的各个方面"
}
]
def mutate_state(self, input_data: Any = None, state: State = None, **kwargs) -> State:
"""
将报告结构写入状态
Args:
input_data: 输入数据
state: 当前状态,如果为None则创建新状态
**kwargs: 额外参数
Returns:
更新后的状态
"""
if state is None:
state = State()
try:
# 生成报告结构
report_structure = self.run(input_data, **kwargs)
# 设置查询和报告标题
state.query = self.query
if not state.report_title:
state.report_title = f"关于'{self.query}'的深度研究报告"
# 添加段落到状态
for paragraph_data in report_structure:
state.add_paragraph(
title=paragraph_data["title"],
content=paragraph_data["content"]
)
self.log_info(f"已将 {len(report_structure)} 个段落添加到状态中")
return state
except Exception as e:
self.log_error(f"状态更新失败: {str(e)}")
raise e
+285
View File
@@ -0,0 +1,285 @@
"""
搜索节点实现
负责生成搜索查询和反思查询
"""
import json
from typing import Dict, Any
from json.decoder import JSONDecodeError
from .base_node import BaseNode
from ..prompts import SYSTEM_PROMPT_FIRST_SEARCH, SYSTEM_PROMPT_REFLECTION
from ..utils.text_processing import (
remove_reasoning_from_output,
clean_json_tags,
extract_clean_response,
fix_incomplete_json
)
class FirstSearchNode(BaseNode):
"""为段落生成首次搜索查询的节点"""
def __init__(self, llm_client):
"""
初始化首次搜索节点
Args:
llm_client: LLM客户端
"""
super().__init__(llm_client, "FirstSearchNode")
def validate_input(self, input_data: Any) -> bool:
"""验证输入数据"""
if isinstance(input_data, str):
try:
data = json.loads(input_data)
return "title" in data and "content" in data
except JSONDecodeError:
return False
elif isinstance(input_data, dict):
return "title" in input_data and "content" in input_data
return False
def run(self, input_data: Any, **kwargs) -> Dict[str, str]:
"""
调用LLM生成搜索查询和理由
Args:
input_data: 包含title和content的字符串或字典
**kwargs: 额外参数
Returns:
包含search_query和reasoning的字典
"""
try:
if not self.validate_input(input_data):
raise ValueError("输入数据格式错误,需要包含title和content字段")
# 准备输入数据
if isinstance(input_data, str):
message = input_data
else:
message = json.dumps(input_data, ensure_ascii=False)
self.log_info("正在生成首次搜索查询")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SEARCH, message)
# 处理响应
processed_response = self.process_output(response)
self.log_info(f"生成搜索查询: {processed_response.get('search_query', 'N/A')}")
return processed_response
except Exception as e:
self.log_error(f"生成首次搜索查询失败: {str(e)}")
raise e
def process_output(self, output: str) -> Dict[str, str]:
"""
处理LLM输出,提取搜索查询和推理
Args:
output: LLM原始输出
Returns:
包含search_query和reasoning的字典
"""
try:
# 清理响应文本
cleaned_output = remove_reasoning_from_output(output)
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output[:200]}...")
# 解析JSON
try:
result = json.loads(cleaned_output)
self.log_info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
# 使用更强大的提取方法
result = extract_clean_response(cleaned_output)
if "error" in result:
self.log_error("JSON解析失败,尝试修复...")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
result = json.loads(fixed_json)
self.log_info("JSON修复成功")
except JSONDecodeError:
self.log_error("JSON修复失败")
# 返回默认查询
return self._get_default_search_query()
else:
self.log_error("无法修复JSON,使用默认查询")
return self._get_default_search_query()
# 验证和清理结果
search_query = result.get("search_query", "")
reasoning = result.get("reasoning", "")
if not search_query:
self.log_warning("未找到搜索查询,使用默认查询")
return self._get_default_search_query()
return {
"search_query": search_query,
"reasoning": reasoning
}
except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
# 返回默认查询
return self._get_default_search_query()
def _get_default_search_query(self) -> Dict[str, str]:
"""
获取默认搜索查询
Returns:
默认的搜索查询字典
"""
return {
"search_query": "相关主题研究",
"reasoning": "由于解析失败,使用默认搜索查询"
}
class ReflectionNode(BaseNode):
"""反思段落并生成新搜索查询的节点"""
def __init__(self, llm_client):
"""
初始化反思节点
Args:
llm_client: LLM客户端
"""
super().__init__(llm_client, "ReflectionNode")
def validate_input(self, input_data: Any) -> bool:
"""验证输入数据"""
if isinstance(input_data, str):
try:
data = json.loads(input_data)
required_fields = ["title", "content", "paragraph_latest_state"]
return all(field in data for field in required_fields)
except JSONDecodeError:
return False
elif isinstance(input_data, dict):
required_fields = ["title", "content", "paragraph_latest_state"]
return all(field in input_data for field in required_fields)
return False
def run(self, input_data: Any, **kwargs) -> Dict[str, str]:
"""
调用LLM反思并生成搜索查询
Args:
input_data: 包含title、content和paragraph_latest_state的字符串或字典
**kwargs: 额外参数
Returns:
包含search_query和reasoning的字典
"""
try:
if not self.validate_input(input_data):
raise ValueError("输入数据格式错误,需要包含title、content和paragraph_latest_state字段")
# 准备输入数据
if isinstance(input_data, str):
message = input_data
else:
message = json.dumps(input_data, ensure_ascii=False)
self.log_info("正在进行反思并生成新搜索查询")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION, message)
# 处理响应
processed_response = self.process_output(response)
self.log_info(f"反思生成搜索查询: {processed_response.get('search_query', 'N/A')}")
return processed_response
except Exception as e:
self.log_error(f"反思生成搜索查询失败: {str(e)}")
raise e
def process_output(self, output: str) -> Dict[str, str]:
"""
处理LLM输出,提取搜索查询和推理
Args:
output: LLM原始输出
Returns:
包含search_query和reasoning的字典
"""
try:
# 清理响应文本
cleaned_output = remove_reasoning_from_output(output)
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output[:200]}...")
# 解析JSON
try:
result = json.loads(cleaned_output)
self.log_info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
# 使用更强大的提取方法
result = extract_clean_response(cleaned_output)
if "error" in result:
self.log_error("JSON解析失败,尝试修复...")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
result = json.loads(fixed_json)
self.log_info("JSON修复成功")
except JSONDecodeError:
self.log_error("JSON修复失败")
# 返回默认查询
return self._get_default_reflection_query()
else:
self.log_error("无法修复JSON,使用默认查询")
return self._get_default_reflection_query()
# 验证和清理结果
search_query = result.get("search_query", "")
reasoning = result.get("reasoning", "")
if not search_query:
self.log_warning("未找到搜索查询,使用默认查询")
return self._get_default_reflection_query()
return {
"search_query": search_query,
"reasoning": reasoning
}
except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
# 返回默认查询
return self._get_default_reflection_query()
def _get_default_reflection_query(self) -> Dict[str, str]:
"""
获取默认反思搜索查询
Returns:
默认的反思搜索查询字典
"""
return {
"search_query": "深度研究补充信息",
"reasoning": "由于解析失败,使用默认反思搜索查询"
}
+312
View File
@@ -0,0 +1,312 @@
"""
总结节点实现
负责根据搜索结果生成和更新段落内容
"""
import json
from typing import Dict, Any, List
from json.decoder import JSONDecodeError
from .base_node import StateMutationNode
from ..state.state import State
from ..prompts import SYSTEM_PROMPT_FIRST_SUMMARY, SYSTEM_PROMPT_REFLECTION_SUMMARY
from ..utils.text_processing import (
remove_reasoning_from_output,
clean_json_tags,
extract_clean_response,
fix_incomplete_json,
format_search_results_for_prompt
)
class FirstSummaryNode(StateMutationNode):
"""根据搜索结果生成段落首次总结的节点"""
def __init__(self, llm_client):
"""
初始化首次总结节点
Args:
llm_client: LLM客户端
"""
super().__init__(llm_client, "FirstSummaryNode")
def validate_input(self, input_data: Any) -> bool:
"""验证输入数据"""
if isinstance(input_data, str):
try:
data = json.loads(input_data)
required_fields = ["title", "content", "search_query", "search_results"]
return all(field in data for field in required_fields)
except JSONDecodeError:
return False
elif isinstance(input_data, dict):
required_fields = ["title", "content", "search_query", "search_results"]
return all(field in input_data for field in required_fields)
return False
def run(self, input_data: Any, **kwargs) -> str:
"""
调用LLM生成段落总结
Args:
input_data: 包含title、content、search_query和search_results的数据
**kwargs: 额外参数
Returns:
段落总结内容
"""
try:
if not self.validate_input(input_data):
raise ValueError("输入数据格式错误")
# 准备输入数据
if isinstance(input_data, str):
message = input_data
else:
message = json.dumps(input_data, ensure_ascii=False)
self.log_info("正在生成首次段落总结")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SUMMARY, message)
# 处理响应
processed_response = self.process_output(response)
self.log_info("成功生成首次段落总结")
return processed_response
except Exception as e:
self.log_error(f"生成首次总结失败: {str(e)}")
raise e
def process_output(self, output: str) -> str:
"""
处理LLM输出,提取段落内容
Args:
output: LLM原始输出
Returns:
段落内容
"""
try:
# 清理响应文本
cleaned_output = remove_reasoning_from_output(output)
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output[:200]}...")
# 解析JSON
try:
result = json.loads(cleaned_output)
self.log_info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
result = json.loads(fixed_json)
self.log_info("JSON修复成功")
except JSONDecodeError:
self.log_info("JSON修复失败,直接使用清理后的文本")
# 如果不是JSON格式,直接返回清理后的文本
return cleaned_output
else:
self.log_info("无法修复JSON,直接使用清理后的文本")
# 如果不是JSON格式,直接返回清理后的文本
return cleaned_output
# 提取段落内容
if isinstance(result, dict):
paragraph_content = result.get("paragraph_latest_state", "")
if paragraph_content:
return paragraph_content
# 如果提取失败,返回原始清理后的文本
return cleaned_output
except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
return "段落总结生成失败"
def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State:
"""
更新段落的最新总结到状态
Args:
input_data: 输入数据
state: 当前状态
paragraph_index: 段落索引
**kwargs: 额外参数
Returns:
更新后的状态
"""
try:
# 生成总结
summary = self.run(input_data, **kwargs)
# 更新状态
if 0 <= paragraph_index < len(state.paragraphs):
state.paragraphs[paragraph_index].research.latest_summary = summary
self.log_info(f"已更新段落 {paragraph_index} 的首次总结")
else:
raise ValueError(f"段落索引 {paragraph_index} 超出范围")
state.update_timestamp()
return state
except Exception as e:
self.log_error(f"状态更新失败: {str(e)}")
raise e
class ReflectionSummaryNode(StateMutationNode):
"""根据反思搜索结果更新段落总结的节点"""
def __init__(self, llm_client):
"""
初始化反思总结节点
Args:
llm_client: LLM客户端
"""
super().__init__(llm_client, "ReflectionSummaryNode")
def validate_input(self, input_data: Any) -> bool:
"""验证输入数据"""
if isinstance(input_data, str):
try:
data = json.loads(input_data)
required_fields = ["title", "content", "search_query", "search_results", "paragraph_latest_state"]
return all(field in data for field in required_fields)
except JSONDecodeError:
return False
elif isinstance(input_data, dict):
required_fields = ["title", "content", "search_query", "search_results", "paragraph_latest_state"]
return all(field in input_data for field in required_fields)
return False
def run(self, input_data: Any, **kwargs) -> str:
"""
调用LLM更新段落内容
Args:
input_data: 包含完整反思信息的数据
**kwargs: 额外参数
Returns:
更新后的段落内容
"""
try:
if not self.validate_input(input_data):
raise ValueError("输入数据格式错误")
# 准备输入数据
if isinstance(input_data, str):
message = input_data
else:
message = json.dumps(input_data, ensure_ascii=False)
self.log_info("正在生成反思总结")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION_SUMMARY, message)
# 处理响应
processed_response = self.process_output(response)
self.log_info("成功生成反思总结")
return processed_response
except Exception as e:
self.log_error(f"生成反思总结失败: {str(e)}")
raise e
def process_output(self, output: str) -> str:
"""
处理LLM输出,提取更新后的段落内容
Args:
output: LLM原始输出
Returns:
更新后的段落内容
"""
try:
# 清理响应文本
cleaned_output = remove_reasoning_from_output(output)
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output[:200]}...")
# 解析JSON
try:
result = json.loads(cleaned_output)
self.log_info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
result = json.loads(fixed_json)
self.log_info("JSON修复成功")
except JSONDecodeError:
self.log_info("JSON修复失败,直接使用清理后的文本")
# 如果不是JSON格式,直接返回清理后的文本
return cleaned_output
else:
self.log_info("无法修复JSON,直接使用清理后的文本")
# 如果不是JSON格式,直接返回清理后的文本
return cleaned_output
# 提取更新后的段落内容
if isinstance(result, dict):
updated_content = result.get("updated_paragraph_latest_state", "")
if updated_content:
return updated_content
# 如果提取失败,返回原始清理后的文本
return cleaned_output
except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
return "反思总结生成失败"
def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State:
"""
将更新后的总结写入状态
Args:
input_data: 输入数据
state: 当前状态
paragraph_index: 段落索引
**kwargs: 额外参数
Returns:
更新后的状态
"""
try:
# 生成更新后的总结
updated_summary = self.run(input_data, **kwargs)
# 更新状态
if 0 <= paragraph_index < len(state.paragraphs):
state.paragraphs[paragraph_index].research.latest_summary = updated_summary
state.paragraphs[paragraph_index].research.increment_reflection()
self.log_info(f"已更新段落 {paragraph_index} 的反思总结")
else:
raise ValueError(f"段落索引 {paragraph_index} 超出范围")
state.update_timestamp()
return state
except Exception as e:
self.log_error(f"状态更新失败: {str(e)}")
raise e
+34
View File
@@ -0,0 +1,34 @@
"""
Prompt模块
定义Deep Search Agent各个阶段使用的系统提示词
"""
from .prompts import (
SYSTEM_PROMPT_REPORT_STRUCTURE,
SYSTEM_PROMPT_FIRST_SEARCH,
SYSTEM_PROMPT_FIRST_SUMMARY,
SYSTEM_PROMPT_REFLECTION,
SYSTEM_PROMPT_REFLECTION_SUMMARY,
SYSTEM_PROMPT_REPORT_FORMATTING,
output_schema_report_structure,
output_schema_first_search,
output_schema_first_summary,
output_schema_reflection,
output_schema_reflection_summary,
input_schema_report_formatting
)
__all__ = [
"SYSTEM_PROMPT_REPORT_STRUCTURE",
"SYSTEM_PROMPT_FIRST_SEARCH",
"SYSTEM_PROMPT_FIRST_SUMMARY",
"SYSTEM_PROMPT_REFLECTION",
"SYSTEM_PROMPT_REFLECTION_SUMMARY",
"SYSTEM_PROMPT_REPORT_FORMATTING",
"output_schema_report_structure",
"output_schema_first_search",
"output_schema_first_summary",
"output_schema_reflection",
"output_schema_reflection_summary",
"input_schema_report_formatting"
]
+285
View File
@@ -0,0 +1,285 @@
"""
Deep Search Agent 的所有提示词定义
包含各个阶段的系统提示词和JSON Schema定义
"""
import json
# ===== JSON Schema 定义 =====
# 报告结构输出Schema
output_schema_report_structure = {
"type": "array",
"items": {
"type": "object",
"properties": {
"title": {"type": "string"},
"content": {"type": "string"}
}
}
}
# 首次搜索输入Schema
input_schema_first_search = {
"type": "object",
"properties": {
"title": {"type": "string"},
"content": {"type": "string"}
}
}
# 首次搜索输出Schema
output_schema_first_search = {
"type": "object",
"properties": {
"search_query": {"type": "string"},
"search_tool": {"type": "string"},
"reasoning": {"type": "string"},
"start_date": {"type": "string", "description": "开始日期,格式YYYY-MM-DD,仅search_news_by_date工具需要"},
"end_date": {"type": "string", "description": "结束日期,格式YYYY-MM-DD,仅search_news_by_date工具需要"}
},
"required": ["search_query", "search_tool", "reasoning"]
}
# 首次总结输入Schema
input_schema_first_summary = {
"type": "object",
"properties": {
"title": {"type": "string"},
"content": {"type": "string"},
"search_query": {"type": "string"},
"search_results": {
"type": "array",
"items": {"type": "string"}
}
}
}
# 首次总结输出Schema
output_schema_first_summary = {
"type": "object",
"properties": {
"paragraph_latest_state": {"type": "string"}
}
}
# 反思输入Schema
input_schema_reflection = {
"type": "object",
"properties": {
"title": {"type": "string"},
"content": {"type": "string"},
"paragraph_latest_state": {"type": "string"}
}
}
# 反思输出Schema
output_schema_reflection = {
"type": "object",
"properties": {
"search_query": {"type": "string"},
"search_tool": {"type": "string"},
"reasoning": {"type": "string"},
"start_date": {"type": "string", "description": "开始日期,格式YYYY-MM-DD,仅search_news_by_date工具需要"},
"end_date": {"type": "string", "description": "结束日期,格式YYYY-MM-DD,仅search_news_by_date工具需要"}
},
"required": ["search_query", "search_tool", "reasoning"]
}
# 反思总结输入Schema
input_schema_reflection_summary = {
"type": "object",
"properties": {
"title": {"type": "string"},
"content": {"type": "string"},
"search_query": {"type": "string"},
"search_results": {
"type": "array",
"items": {"type": "string"}
},
"paragraph_latest_state": {"type": "string"}
}
}
# 反思总结输出Schema
output_schema_reflection_summary = {
"type": "object",
"properties": {
"updated_paragraph_latest_state": {"type": "string"}
}
}
# 报告格式化输入Schema
input_schema_report_formatting = {
"type": "array",
"items": {
"type": "object",
"properties": {
"title": {"type": "string"},
"paragraph_latest_state": {"type": "string"}
}
}
}
# ===== 系统提示词定义 =====
# 生成报告结构的系统提示词
SYSTEM_PROMPT_REPORT_STRUCTURE = f"""
你是一位深度研究助手。给定一个查询,你需要规划一个报告的结构和其中包含的段落。最多五个段落。
确保段落的排序合理有序。
一旦大纲创建完成,你将获得工具来分别为每个部分搜索网络并进行反思。
请按照以下JSON模式定义格式化输出:
<OUTPUT JSON SCHEMA>
{json.dumps(output_schema_report_structure, indent=2, ensure_ascii=False)}
</OUTPUT JSON SCHEMA>
标题和内容属性将用于更深入的研究。
确保输出是一个符合上述输出JSON模式定义的JSON对象。
只返回JSON对象,不要有解释或额外文本。
"""
# 每个段落第一次搜索的系统提示词
SYSTEM_PROMPT_FIRST_SEARCH = f"""
你是一位深度研究助手。你将获得报告中的一个段落,其标题和预期内容将按照以下JSON模式定义提供:
<INPUT JSON SCHEMA>
{json.dumps(input_schema_first_search, indent=2, ensure_ascii=False)}
</INPUT JSON SCHEMA>
你可以使用以下6种专业的新闻搜索工具:
1. **basic_search_news** - 基础新闻搜索工具
- 适用于:一般性的新闻搜索,不确定需要何种特定搜索时
- 特点:快速、标准的通用搜索,是最常用的基础工具
2. **deep_search_news** - 深度新闻分析工具
- 适用于:需要全面深入了解某个主题时
- 特点:提供最详细的分析结果,包含高级AI摘要
3. **search_news_last_24_hours** - 24小时最新新闻工具
- 适用于:需要了解最新动态、突发事件时
- 特点:只搜索过去24小时的新闻
4. **search_news_last_week** - 本周新闻工具
- 适用于:需要了解近期发展趋势时
- 特点:搜索过去一周的新闻报道
5. **search_images_for_news** - 图片搜索工具
- 适用于:需要可视化信息、图片资料时
- 特点:提供相关图片和图片描述
6. **search_news_by_date** - 按日期范围搜索工具
- 适用于:需要研究特定历史时期时
- 特点:可以指定开始和结束日期进行搜索
- 特殊要求:需要提供start_date和end_date参数,格式为'YYYY-MM-DD'
- 注意:只有这个工具需要额外的时间参数
你的任务是:
1. 根据段落主题选择最合适的搜索工具
2. 制定最佳的搜索查询
3. 如果选择search_news_by_date工具,必须同时提供start_date和end_date参数(格式:YYYY-MM-DD
4. 解释你的选择理由
注意:除了search_news_by_date工具外,其他工具都不需要额外参数。
请按照以下JSON模式定义格式化输出(文字请使用中文):
<OUTPUT JSON SCHEMA>
{json.dumps(output_schema_first_search, indent=2, ensure_ascii=False)}
</OUTPUT JSON SCHEMA>
确保输出是一个符合上述输出JSON模式定义的JSON对象。
只返回JSON对象,不要有解释或额外文本。
"""
# 每个段落第一次总结的系统提示词
SYSTEM_PROMPT_FIRST_SUMMARY = f"""
你是一位深度研究助手。你将获得搜索查询、搜索结果以及你正在研究的报告段落,数据将按照以下JSON模式定义提供:
<INPUT JSON SCHEMA>
{json.dumps(input_schema_first_summary, indent=2, ensure_ascii=False)}
</INPUT JSON SCHEMA>
你的任务是作为研究者,使用搜索结果撰写与段落主题一致的内容,并适当地组织结构以便纳入报告中。
请按照以下JSON模式定义格式化输出:
<OUTPUT JSON SCHEMA>
{json.dumps(output_schema_first_summary, indent=2, ensure_ascii=False)}
</OUTPUT JSON SCHEMA>
确保输出是一个符合上述输出JSON模式定义的JSON对象。
只返回JSON对象,不要有解释或额外文本。
"""
# 反思(Reflect)的系统提示词
SYSTEM_PROMPT_REFLECTION = f"""
你是一位深度研究助手。你负责为研究报告构建全面的段落。你将获得段落标题、计划内容摘要,以及你已经创建的段落最新状态,所有这些都将按照以下JSON模式定义提供:
<INPUT JSON SCHEMA>
{json.dumps(input_schema_reflection, indent=2, ensure_ascii=False)}
</INPUT JSON SCHEMA>
你可以使用以下6种专业的新闻搜索工具:
1. **basic_search_news** - 基础新闻搜索工具
2. **deep_search_news** - 深度新闻分析工具
3. **search_news_last_24_hours** - 24小时最新新闻工具
4. **search_news_last_week** - 本周新闻工具
5. **search_images_for_news** - 图片搜索工具
6. **search_news_by_date** - 按日期范围搜索工具(需要时间参数)
你的任务是:
1. 反思段落文本的当前状态,思考是否遗漏了主题的某些关键方面
2. 选择最合适的搜索工具来补充缺失信息
3. 制定精确的搜索查询
4. 如果选择search_news_by_date工具,必须同时提供start_date和end_date参数(格式:YYYY-MM-DD
5. 解释你的选择和推理
注意:除了search_news_by_date工具外,其他工具都不需要额外参数。
请按照以下JSON模式定义格式化输出:
<OUTPUT JSON SCHEMA>
{json.dumps(output_schema_reflection, indent=2, ensure_ascii=False)}
</OUTPUT JSON SCHEMA>
确保输出是一个符合上述输出JSON模式定义的JSON对象。
只返回JSON对象,不要有解释或额外文本。
"""
# 总结反思的系统提示词
SYSTEM_PROMPT_REFLECTION_SUMMARY = f"""
你是一位深度研究助手。
你将获得搜索查询、搜索结果、段落标题以及你正在研究的报告段落的预期内容。
你正在迭代完善这个段落,并且段落的最新状态也会提供给你。
数据将按照以下JSON模式定义提供:
<INPUT JSON SCHEMA>
{json.dumps(input_schema_reflection_summary, indent=2, ensure_ascii=False)}
</INPUT JSON SCHEMA>
你的任务是根据搜索结果和预期内容丰富段落的当前最新状态。
不要删除最新状态中的关键信息,尽量丰富它,只添加缺失的信息。
适当地组织段落结构以便纳入报告中。
请按照以下JSON模式定义格式化输出:
<OUTPUT JSON SCHEMA>
{json.dumps(output_schema_reflection_summary, indent=2, ensure_ascii=False)}
</OUTPUT JSON SCHEMA>
确保输出是一个符合上述输出JSON模式定义的JSON对象。
只返回JSON对象,不要有解释或额外文本。
"""
# 最终研究报告格式化的系统提示词
SYSTEM_PROMPT_REPORT_FORMATTING = f"""
你是一位深度研究助手。你已经完成了研究并构建了报告中所有段落的最终版本。
你将获得以下JSON格式的数据:
<INPUT JSON SCHEMA>
{json.dumps(input_schema_report_formatting, indent=2, ensure_ascii=False)}
</INPUT JSON SCHEMA>
你的任务是将报告格式化为美观的形式,并以Markdown格式返回。
如果没有结论段落,请根据其他段落的最新状态在报告末尾添加一个结论。
使用段落标题来创建报告的标题。
"""
+8
View File
@@ -0,0 +1,8 @@
"""
状态管理模块
定义Deep Search Agent的状态数据结构
"""
from .state import State, Paragraph, Research, Search
__all__ = ["State", "Paragraph", "Research", "Search"]
+258
View File
@@ -0,0 +1,258 @@
"""
Deep Search Agent状态管理
定义所有状态数据结构和操作方法
"""
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional
import json
from datetime import datetime
@dataclass
class Search:
"""单个搜索结果的状态"""
query: str = "" # 搜索查询
url: str = "" # 搜索结果的链接
title: str = "" # 搜索结果标题
content: str = "" # 搜索返回的内容
score: Optional[float] = None # 相关度评分
timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式"""
return {
"query": self.query,
"url": self.url,
"title": self.title,
"content": self.content,
"score": self.score,
"timestamp": self.timestamp
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "Search":
"""从字典创建Search对象"""
return cls(
query=data.get("query", ""),
url=data.get("url", ""),
title=data.get("title", ""),
content=data.get("content", ""),
score=data.get("score"),
timestamp=data.get("timestamp", datetime.now().isoformat())
)
@dataclass
class Research:
"""段落研究过程的状态"""
search_history: List[Search] = field(default_factory=list) # 搜索记录列表
latest_summary: str = "" # 当前段落的最新总结
reflection_iteration: int = 0 # 反思迭代次数
is_completed: bool = False # 是否完成研究
def add_search(self, search: Search):
"""添加搜索记录"""
self.search_history.append(search)
def add_search_results(self, query: str, results: List[Dict[str, Any]]):
"""批量添加搜索结果"""
for result in results:
search = Search(
query=query,
url=result.get("url", ""),
title=result.get("title", ""),
content=result.get("content", ""),
score=result.get("score")
)
self.add_search(search)
def get_search_count(self) -> int:
"""获取搜索次数"""
return len(self.search_history)
def increment_reflection(self):
"""增加反思次数"""
self.reflection_iteration += 1
def mark_completed(self):
"""标记为完成"""
self.is_completed = True
def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式"""
return {
"search_history": [search.to_dict() for search in self.search_history],
"latest_summary": self.latest_summary,
"reflection_iteration": self.reflection_iteration,
"is_completed": self.is_completed
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "Research":
"""从字典创建Research对象"""
search_history = [Search.from_dict(search_data) for search_data in data.get("search_history", [])]
return cls(
search_history=search_history,
latest_summary=data.get("latest_summary", ""),
reflection_iteration=data.get("reflection_iteration", 0),
is_completed=data.get("is_completed", False)
)
@dataclass
class Paragraph:
"""报告中单个段落的状态"""
title: str = "" # 段落标题
content: str = "" # 段落的预期内容(初始规划)
research: Research = field(default_factory=Research) # 研究进度
order: int = 0 # 段落顺序
def is_completed(self) -> bool:
"""检查段落是否完成"""
return self.research.is_completed and bool(self.research.latest_summary)
def get_final_content(self) -> str:
"""获取最终内容"""
return self.research.latest_summary or self.content
def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式"""
return {
"title": self.title,
"content": self.content,
"research": self.research.to_dict(),
"order": self.order
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "Paragraph":
"""从字典创建Paragraph对象"""
research_data = data.get("research", {})
research = Research.from_dict(research_data) if research_data else Research()
return cls(
title=data.get("title", ""),
content=data.get("content", ""),
research=research,
order=data.get("order", 0)
)
@dataclass
class State:
"""整个报告的状态"""
query: str = "" # 原始查询
report_title: str = "" # 报告标题
paragraphs: List[Paragraph] = field(default_factory=list) # 段落列表
final_report: str = "" # 最终报告内容
is_completed: bool = False # 是否完成
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
def add_paragraph(self, title: str, content: str) -> int:
"""
添加段落
Args:
title: 段落标题
content: 段落内容
Returns:
段落索引
"""
order = len(self.paragraphs)
paragraph = Paragraph(title=title, content=content, order=order)
self.paragraphs.append(paragraph)
self.update_timestamp()
return order
def get_paragraph(self, index: int) -> Optional[Paragraph]:
"""获取指定索引的段落"""
if 0 <= index < len(self.paragraphs):
return self.paragraphs[index]
return None
def get_completed_paragraphs_count(self) -> int:
"""获取已完成段落数量"""
return sum(1 for p in self.paragraphs if p.is_completed())
def get_total_paragraphs_count(self) -> int:
"""获取总段落数量"""
return len(self.paragraphs)
def is_all_paragraphs_completed(self) -> bool:
"""检查是否所有段落都完成"""
return all(p.is_completed() for p in self.paragraphs) if self.paragraphs else False
def mark_completed(self):
"""标记整个报告为完成"""
self.is_completed = True
self.update_timestamp()
def update_timestamp(self):
"""更新时间戳"""
self.updated_at = datetime.now().isoformat()
def get_progress_summary(self) -> Dict[str, Any]:
"""获取进度摘要"""
completed = self.get_completed_paragraphs_count()
total = self.get_total_paragraphs_count()
return {
"total_paragraphs": total,
"completed_paragraphs": completed,
"progress_percentage": (completed / total * 100) if total > 0 else 0,
"is_completed": self.is_completed,
"created_at": self.created_at,
"updated_at": self.updated_at
}
def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式"""
return {
"query": self.query,
"report_title": self.report_title,
"paragraphs": [p.to_dict() for p in self.paragraphs],
"final_report": self.final_report,
"is_completed": self.is_completed,
"created_at": self.created_at,
"updated_at": self.updated_at
}
def to_json(self, indent: int = 2) -> str:
"""转换为JSON字符串"""
return json.dumps(self.to_dict(), indent=indent, ensure_ascii=False)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "State":
"""从字典创建State对象"""
paragraphs = [Paragraph.from_dict(p_data) for p_data in data.get("paragraphs", [])]
return cls(
query=data.get("query", ""),
report_title=data.get("report_title", ""),
paragraphs=paragraphs,
final_report=data.get("final_report", ""),
is_completed=data.get("is_completed", False),
created_at=data.get("created_at", datetime.now().isoformat()),
updated_at=data.get("updated_at", datetime.now().isoformat())
)
@classmethod
def from_json(cls, json_str: str) -> "State":
"""从JSON字符串创建State对象"""
data = json.loads(json_str)
return cls.from_dict(data)
def save_to_file(self, filepath: str):
"""保存状态到文件"""
with open(filepath, 'w', encoding='utf-8') as f:
f.write(self.to_json())
@classmethod
def load_from_file(cls, filepath: str) -> "State":
"""从文件加载状态"""
with open(filepath, 'r', encoding='utf-8') as f:
json_str = f.read()
return cls.from_json(json_str)
+20
View File
@@ -0,0 +1,20 @@
"""
工具调用模块
提供外部工具接口,如网络搜索等
"""
from .search import (
TavilyNewsAgency,
SearchResult,
TavilyResponse,
ImageResult,
print_response_summary
)
__all__ = [
"TavilyNewsAgency",
"SearchResult",
"TavilyResponse",
"ImageResult",
"print_response_summary"
]
+240
View File
@@ -0,0 +1,240 @@
"""
专为 AI Agent 设计的舆情搜索工具集 (Tavily)
版本: 1.5
最后更新: 2025-08-22
此脚本将复杂的Tavily搜索功能分解为一系列目标明确、参数极少的独立工具,
专为AI Agent调用而设计。Agent只需根据任务意图选择合适的工具,
无需理解复杂的参数组合。所有工具默认搜索“新闻”(topic='news')。
新特性:
- 新增 `basic_search_news` 工具,用于执行标准、通用的新闻搜索。
- 每个搜索结果现在都包含 `published_date` (新闻发布日期)。
主要工具:
- basic_search_news: (新增) 执行标准、快速的通用新闻搜索。
- deep_search_news: 对主题进行最全面的深度分析。
- search_news_last_24_hours: 获取24小时内的最新动态。
- search_news_last_week: 获取过去一周的主要报道。
- search_images_for_news: 查找与新闻主题相关的图片。
- search_news_by_date: 在指定的历史日期范围内搜索。
"""
import os
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field
# 运行前请确保已安装Tavily库: pip install tavily-python
try:
from tavily import TavilyClient
except ImportError:
raise ImportError("Tavily库未安装,请运行 `pip install tavily-python` 进行安装。")
# --- 1. 数据结构定义 ---
@dataclass
class SearchResult:
"""
网页搜索结果数据类
包含 published_date 属性来存储新闻发布日期
"""
title: str
url: str
content: str
score: Optional[float] = None
raw_content: Optional[str] = None
published_date: Optional[str] = None
@dataclass
class ImageResult:
"""图片搜索结果数据类"""
url: str
description: Optional[str] = None
@dataclass
class TavilyResponse:
"""封装Tavily API的完整返回结果,以便在工具间传递"""
query: str
answer: Optional[str] = None
results: List[SearchResult] = field(default_factory=list)
images: List[ImageResult] = field(default_factory=list)
response_time: Optional[float] = None
# --- 2. 核心客户端与专用工具集 ---
class TavilyNewsAgency:
"""
一个包含多种专用新闻舆情搜索工具的客户端。
每个公共方法都设计为供 AI Agent 独立调用的工具。
"""
def __init__(self, api_key: Optional[str] = None):
"""
初始化客户端。
Args:
api_key: Tavily API密钥,若不提供则从环境变量 TAVILY_API_KEY 读取。
"""
if api_key is None:
api_key = os.getenv("TAVILY_API_KEY")
if not api_key:
raise ValueError("Tavily API Key未找到!请设置TAVILY_API_KEY环境变量或在初始化时提供")
self._client = TavilyClient(api_key=api_key)
def _search_internal(self, **kwargs) -> TavilyResponse:
"""内部通用的搜索执行器,所有工具最终都调用此方法"""
try:
kwargs['topic'] = 'general'
api_params = {k: v for k, v in kwargs.items() if v is not None}
response_dict = self._client.search(**api_params)
search_results = [
SearchResult(
title=item.get('title'),
url=item.get('url'),
content=item.get('content'),
score=item.get('score'),
raw_content=item.get('raw_content'),
published_date=item.get('published_date')
) for item in response_dict.get('results', [])
]
image_results = [ImageResult(url=item.get('url'), description=item.get('description')) for item in response_dict.get('images', [])]
return TavilyResponse(
query=response_dict.get('query'), answer=response_dict.get('answer'),
results=search_results, images=image_results,
response_time=response_dict.get('response_time')
)
except Exception as e:
print(f"搜索时发生错误: {str(e)}")
return TavilyResponse(query=kwargs.get("query", "Unknown Query"))
# --- Agent 可用的工具方法 ---
def basic_search_news(self, query: str, max_results: int = 7) -> TavilyResponse:
"""
【工具】基础新闻搜索: 执行一次标准、快速的新闻搜索。
这是最常用的通用搜索工具,适用于不确定需要何种特定搜索时。
Agent可提供搜索查询(query)和可选的最大结果数(max_results)。
"""
print(f"--- TOOL: 基础新闻搜索 (query: {query}) ---")
return self._search_internal(
query=query,
max_results=max_results,
search_depth="basic",
include_answer=False
)
def deep_search_news(self, query: str) -> TavilyResponse:
"""
【工具】深度新闻分析: 对一个主题进行最全面、最深入的搜索。
返回AI生成的“高级”详细摘要答案和最多20条最相关的新闻结果。适用于需要全面了解某个事件背景的场景。
Agent只需提供搜索查询(query)。
"""
print(f"--- TOOL: 深度新闻分析 (query: {query}) ---")
return self._search_internal(
query=query, search_depth="advanced", max_results=20, include_answer="advanced"
)
def search_news_last_24_hours(self, query: str) -> TavilyResponse:
"""
【工具】搜索24小时内新闻: 获取关于某个主题的最新动态。
此工具专门查找过去24小时内发布的新闻。适用于追踪突发事件或最新进展。
Agent只需提供搜索查询(query)。
"""
print(f"--- TOOL: 搜索24小时内新闻 (query: {query}) ---")
return self._search_internal(query=query, time_range='d', max_results=10)
def search_news_last_week(self, query: str) -> TavilyResponse:
"""
【工具】搜索本周新闻: 获取关于某个主题过去一周内的主要新闻报道。
适用于进行周度舆情总结或回顾。
Agent只需提供搜索查询(query)。
"""
print(f"--- TOOL: 搜索本周新闻 (query: {query}) ---")
return self._search_internal(query=query, time_range='w', max_results=10)
def search_images_for_news(self, query: str) -> TavilyResponse:
"""
【工具】查找新闻图片: 搜索与某个新闻主题相关的图片。
此工具会返回图片链接及描述,适用于需要为报告或文章配图的场景。
Agent只需提供搜索查询(query)。
"""
print(f"--- TOOL: 查找新闻图片 (query: {query}) ---")
return self._search_internal(
query=query, include_images=True, include_image_descriptions=True, max_results=5
)
def search_news_by_date(self, query: str, start_date: str, end_date: str) -> TavilyResponse:
"""
【工具】按指定日期范围搜索新闻: 在一个明确的历史时间段内搜索新闻。
这是唯一需要Agent提供详细时间参数的工具。适用于需要对特定历史事件进行分析的场景。
Agent需要提供查询(query)、开始日期(start_date)和结束日期(end_date),格式均为 'YYYY-MM-DD'
"""
print(f"--- TOOL: 按指定日期范围搜索新闻 (query: {query}, from: {start_date}, to: {end_date}) ---")
return self._search_internal(
query=query, start_date=start_date, end_date=end_date, max_results=15
)
# --- 3. 测试与使用示例 ---
def print_response_summary(response: TavilyResponse):
"""简化的打印函数,用于展示测试结果,现在会显示发布日期"""
if not response or not response.query:
print("未能获取有效响应。")
return
print(f"\n查询: '{response.query}' | 耗时: {response.response_time}s")
if response.answer:
print(f"AI摘要: {response.answer[:120]}...")
print(f"找到 {len(response.results)} 条网页, {len(response.images)} 张图片。")
if response.results:
first_result = response.results[0]
date_info = f"(发布于: {first_result.published_date})" if first_result.published_date else ""
print(f"第一条结果: {first_result.title} {date_info}")
print("-" * 60)
if __name__ == "__main__":
# 在运行前,请确保您已设置 TAVILY_API_KEY 环境变量
try:
# 初始化“新闻社”客户端,它内部包含了所有工具
agency = TavilyNewsAgency()
# 场景1: Agent 进行一次常规、快速的搜索
response1 = agency.basic_search_news(query="奥运会最新赛况", max_results=5)
print_response_summary(response1)
# 场景2: Agent 需要全面了解“全球芯片技术竞争”的背景
response2 = agency.deep_search_news(query="全球芯片技术竞争")
print_response_summary(response2)
# 场景3: Agent 需要追踪“GTC大会”的最新消息
response3 = agency.search_news_last_24_hours(query="Nvidia GTC大会 最新发布")
print_response_summary(response3)
# 场景4: Agent 需要为一篇关于“自动驾驶”的周报查找素材
response4 = agency.search_news_last_week(query="自动驾驶商业化落地")
print_response_summary(response4)
# 场景5: Agent 需要查找“韦伯太空望远镜”的新闻图片
response5 = agency.search_images_for_news(query="韦伯太空望远镜最新发现")
print_response_summary(response5)
# 场景6: Agent 需要研究2025年第一季度关于“人工智能法规”的新闻
response6 = agency.search_news_by_date(
query="人工智能法规",
start_date="2025-01-01",
end_date="2025-03-31"
)
print_response_summary(response6)
except ValueError as e:
print(f"初始化失败: {e}")
print("请确保 TAVILY_API_KEY 环境变量已正确设置。")
except Exception as e:
print(f"测试过程中发生未知错误: {e}")
+26
View File
@@ -0,0 +1,26 @@
"""
工具函数模块
提供文本处理、JSON解析等辅助功能
"""
from .text_processing import (
clean_json_tags,
clean_markdown_tags,
remove_reasoning_from_output,
extract_clean_response,
update_state_with_search_results,
format_search_results_for_prompt
)
from .config import Config, load_config
__all__ = [
"clean_json_tags",
"clean_markdown_tags",
"remove_reasoning_from_output",
"extract_clean_response",
"update_state_with_search_results",
"format_search_results_for_prompt",
"Config",
"load_config"
]
+162
View File
@@ -0,0 +1,162 @@
"""
配置管理模块
处理环境变量和配置参数
"""
import os
from dataclasses import dataclass
from typing import Optional
@dataclass
class Config:
"""配置类"""
# API密钥
deepseek_api_key: Optional[str] = None
openai_api_key: Optional[str] = None
tavily_api_key: Optional[str] = None
# 模型配置
default_llm_provider: str = "deepseek" # deepseek 或 openai
deepseek_model: str = "deepseek-chat"
openai_model: str = "gpt-4o-mini"
# 搜索配置
search_timeout: int = 240
max_content_length: int = 20000
# Agent配置
max_reflections: int = 2
max_paragraphs: int = 5
# 输出配置
output_dir: str = "reports"
save_intermediate_states: bool = True
def validate(self) -> bool:
"""验证配置"""
# 检查必需的API密钥
if self.default_llm_provider == "deepseek" and not self.deepseek_api_key:
print("错误: DeepSeek API Key未设置")
return False
if self.default_llm_provider == "openai" and not self.openai_api_key:
print("错误: OpenAI API Key未设置")
return False
if not self.tavily_api_key:
print("错误: Tavily API Key未设置")
return False
return True
@classmethod
def from_file(cls, config_file: str) -> "Config":
"""从配置文件创建配置"""
if config_file.endswith('.py'):
# Python配置文件
import importlib.util
# 动态导入配置文件
spec = importlib.util.spec_from_file_location("config", config_file)
config_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config_module)
return cls(
deepseek_api_key=getattr(config_module, "DEEPSEEK_API_KEY", None),
openai_api_key=getattr(config_module, "OPENAI_API_KEY", None),
tavily_api_key=getattr(config_module, "TAVILY_API_KEY", None),
default_llm_provider=getattr(config_module, "DEFAULT_LLM_PROVIDER", "deepseek"),
deepseek_model=getattr(config_module, "DEEPSEEK_MODEL", "deepseek-chat"),
openai_model=getattr(config_module, "OPENAI_MODEL", "gpt-4o-mini"),
search_timeout=getattr(config_module, "SEARCH_TIMEOUT", 240),
max_content_length=getattr(config_module, "SEARCH_CONTENT_MAX_LENGTH", 20000),
max_reflections=getattr(config_module, "MAX_REFLECTIONS", 2),
max_paragraphs=getattr(config_module, "MAX_PARAGRAPHS", 5),
output_dir=getattr(config_module, "OUTPUT_DIR", "reports"),
save_intermediate_states=getattr(config_module, "SAVE_INTERMEDIATE_STATES", True)
)
else:
# .env格式配置文件
config_dict = {}
if os.path.exists(config_file):
with open(config_file, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if line and not line.startswith('#') and '=' in line:
key, value = line.split('=', 1)
config_dict[key.strip()] = value.strip()
return cls(
deepseek_api_key=config_dict.get("DEEPSEEK_API_KEY"),
openai_api_key=config_dict.get("OPENAI_API_KEY"),
tavily_api_key=config_dict.get("TAVILY_API_KEY"),
default_llm_provider=config_dict.get("DEFAULT_LLM_PROVIDER", "deepseek"),
deepseek_model=config_dict.get("DEEPSEEK_MODEL", "deepseek-chat"),
openai_model=config_dict.get("OPENAI_MODEL", "gpt-4o-mini"),
search_timeout=int(config_dict.get("SEARCH_TIMEOUT", "240")),
max_content_length=int(config_dict.get("SEARCH_CONTENT_MAX_LENGTH", "20000")),
max_reflections=int(config_dict.get("MAX_REFLECTIONS", "2")),
max_paragraphs=int(config_dict.get("MAX_PARAGRAPHS", "5")),
output_dir=config_dict.get("OUTPUT_DIR", "reports"),
save_intermediate_states=config_dict.get("SAVE_INTERMEDIATE_STATES", "true").lower() == "true"
)
def load_config(config_file: Optional[str] = None) -> Config:
"""
加载配置
Args:
config_file: 配置文件路径,如果不指定则使用默认路径
Returns:
配置对象
"""
# 确定配置文件路径
if config_file:
if not os.path.exists(config_file):
raise FileNotFoundError(f"配置文件不存在: {config_file}")
file_to_load = config_file
else:
# 尝试加载常见的配置文件
for config_path in ["config.py", "config.env", ".env"]:
if os.path.exists(config_path):
file_to_load = config_path
print(f"已找到配置文件: {config_path}")
break
else:
raise FileNotFoundError("未找到配置文件,请创建 config.py 文件")
# 创建配置对象
config = Config.from_file(file_to_load)
# 验证配置
if not config.validate():
raise ValueError("配置验证失败,请检查配置文件中的API密钥")
return config
def print_config(config: Config):
"""打印配置信息(隐藏敏感信息)"""
print("\n=== 当前配置 ===")
print(f"LLM提供商: {config.default_llm_provider}")
print(f"DeepSeek模型: {config.deepseek_model}")
print(f"OpenAI模型: {config.openai_model}")
print(f"最大搜索结果数: {config.max_search_results}")
print(f"搜索超时: {config.search_timeout}")
print(f"最大内容长度: {config.max_content_length}")
print(f"最大反思次数: {config.max_reflections}")
print(f"最大段落数: {config.max_paragraphs}")
print(f"输出目录: {config.output_dir}")
print(f"保存中间状态: {config.save_intermediate_states}")
# 显示API密钥状态(不显示实际密钥)
print(f"DeepSeek API Key: {'已设置' if config.deepseek_api_key else '未设置'}")
print(f"OpenAI API Key: {'已设置' if config.openai_api_key else '未设置'}")
print(f"Tavily API Key: {'已设置' if config.tavily_api_key else '未设置'}")
print("==================\n")
+308
View File
@@ -0,0 +1,308 @@
"""
文本处理工具函数
用于清理LLM输出、解析JSON等
"""
import re
import json
from typing import Dict, Any, List
from json.decoder import JSONDecodeError
def clean_json_tags(text: str) -> str:
"""
清理文本中的JSON标签
Args:
text: 原始文本
Returns:
清理后的文本
"""
# 移除```json 和 ```标签
text = re.sub(r'```json\s*', '', text)
text = re.sub(r'```\s*$', '', text)
text = re.sub(r'```', '', text)
return text.strip()
def clean_markdown_tags(text: str) -> str:
"""
清理文本中的Markdown标签
Args:
text: 原始文本
Returns:
清理后的文本
"""
# 移除```markdown 和 ```标签
text = re.sub(r'```markdown\s*', '', text)
text = re.sub(r'```\s*$', '', text)
text = re.sub(r'```', '', text)
return text.strip()
def remove_reasoning_from_output(text: str) -> str:
"""
移除输出中的推理过程文本
Args:
text: 原始文本
Returns:
清理后的文本
"""
# 查找JSON开始位置
json_start = -1
# 尝试找到第一个 { 或 [
for i, char in enumerate(text):
if char in '{[':
json_start = i
break
if json_start != -1:
# 从JSON开始位置截取
return text[json_start:].strip()
# 如果没有找到JSON标记,尝试其他方法
# 移除常见的推理标识
patterns = [
r'(?:reasoning|推理|思考|分析)[:]\s*.*?(?=\{|\[)', # 移除推理部分
r'(?:explanation|解释|说明)[:]\s*.*?(?=\{|\[)', # 移除解释部分
r'^.*?(?=\{|\[)', # 移除JSON前的所有文本
]
for pattern in patterns:
text = re.sub(pattern, '', text, flags=re.IGNORECASE | re.DOTALL)
return text.strip()
def extract_clean_response(text: str) -> Dict[str, Any]:
"""
提取并清理响应中的JSON内容
Args:
text: 原始响应文本
Returns:
解析后的JSON字典
"""
# 清理文本
cleaned_text = clean_json_tags(text)
cleaned_text = remove_reasoning_from_output(cleaned_text)
# 尝试直接解析
try:
return json.loads(cleaned_text)
except JSONDecodeError:
pass
# 尝试修复不完整的JSON
fixed_text = fix_incomplete_json(cleaned_text)
if fixed_text:
try:
return json.loads(fixed_text)
except JSONDecodeError:
pass
# 尝试查找JSON对象
json_pattern = r'\{.*\}'
match = re.search(json_pattern, cleaned_text, re.DOTALL)
if match:
try:
return json.loads(match.group())
except JSONDecodeError:
pass
# 尝试查找JSON数组
array_pattern = r'\[.*\]'
match = re.search(array_pattern, cleaned_text, re.DOTALL)
if match:
try:
return json.loads(match.group())
except JSONDecodeError:
pass
# 如果所有方法都失败,返回错误信息
print(f"无法解析JSON响应: {cleaned_text[:200]}...")
return {"error": "JSON解析失败", "raw_text": cleaned_text}
def fix_incomplete_json(text: str) -> str:
"""
修复不完整的JSON响应
Args:
text: 原始文本
Returns:
修复后的JSON文本,如果无法修复则返回空字符串
"""
# 移除多余的逗号和空白
text = re.sub(r',\s*}', '}', text)
text = re.sub(r',\s*]', ']', text)
# 检查是否已经是有效的JSON
try:
json.loads(text)
return text
except JSONDecodeError:
pass
# 检查是否缺少开头的数组符号
if text.strip().startswith('{') and not text.strip().startswith('['):
# 如果以对象开始,尝试包装成数组
if text.count('{') > 1:
# 多个对象,包装成数组
text = '[' + text + ']'
else:
# 单个对象,包装成数组
text = '[' + text + ']'
# 检查是否缺少结尾的数组符号
if text.strip().endswith('}') and not text.strip().endswith(']'):
# 如果以对象结束,尝试包装成数组
if text.count('}') > 1:
# 多个对象,包装成数组
text = '[' + text + ']'
else:
# 单个对象,包装成数组
text = '[' + text + ']'
# 检查括号是否匹配
open_braces = text.count('{')
close_braces = text.count('}')
open_brackets = text.count('[')
close_brackets = text.count(']')
# 修复不匹配的括号
if open_braces > close_braces:
text += '}' * (open_braces - close_braces)
if open_brackets > close_brackets:
text += ']' * (open_brackets - close_brackets)
# 验证修复后的JSON是否有效
try:
json.loads(text)
return text
except JSONDecodeError:
# 如果仍然无效,尝试更激进的修复
return fix_aggressive_json(text)
def fix_aggressive_json(text: str) -> str:
"""
更激进的JSON修复方法
Args:
text: 原始文本
Returns:
修复后的JSON文本
"""
# 查找所有可能的JSON对象
objects = re.findall(r'\{[^{}]*\}', text)
if len(objects) >= 2:
# 如果有多个对象,包装成数组
return '[' + ','.join(objects) + ']'
elif len(objects) == 1:
# 如果只有一个对象,包装成数组
return '[' + objects[0] + ']'
else:
# 如果没有找到对象,返回空数组
return '[]'
def update_state_with_search_results(search_results: List[Dict[str, Any]],
paragraph_index: int, state: Any) -> Any:
"""
将搜索结果更新到状态中
Args:
search_results: 搜索结果列表
paragraph_index: 段落索引
state: 状态对象
Returns:
更新后的状态对象
"""
if 0 <= paragraph_index < len(state.paragraphs):
# 获取最后一次搜索的查询(假设是当前查询)
current_query = ""
if search_results:
# 从搜索结果推断查询(这里需要改进以获取实际查询)
current_query = "搜索查询"
# 添加搜索结果到状态
state.paragraphs[paragraph_index].research.add_search_results(
current_query, search_results
)
return state
def validate_json_schema(data: Dict[str, Any], required_fields: List[str]) -> bool:
"""
验证JSON数据是否包含必需字段
Args:
data: 要验证的数据
required_fields: 必需字段列表
Returns:
验证是否通过
"""
return all(field in data for field in required_fields)
def truncate_content(content: str, max_length: int = 20000) -> str:
"""
截断内容到指定长度
Args:
content: 原始内容
max_length: 最大长度
Returns:
截断后的内容
"""
if len(content) <= max_length:
return content
# 尝试在单词边界截断
truncated = content[:max_length]
last_space = truncated.rfind(' ')
if last_space > max_length * 0.8: # 如果最后一个空格位置合理
return truncated[:last_space] + "..."
else:
return truncated + "..."
def format_search_results_for_prompt(search_results: List[Dict[str, Any]],
max_length: int = 20000) -> List[str]:
"""
格式化搜索结果用于提示词
Args:
search_results: 搜索结果列表
max_length: 每个结果的最大长度
Returns:
格式化后的内容列表
"""
formatted_results = []
for result in search_results:
content = result.get('content', '')
if content:
truncated_content = truncate_content(content, max_length)
formatted_results.append(truncated_content)
return formatted_results