Reconfiguration of the basic multi-agent architecture.
This commit is contained in:
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Deep Search Agent
|
||||
一个无框架的深度搜索AI代理实现
|
||||
"""
|
||||
|
||||
from .agent import DeepSearchAgent, create_agent
|
||||
from .utils.config import Config, load_config
|
||||
|
||||
__version__ = "1.0.0"
|
||||
__author__ = "Deep Search Agent Team"
|
||||
|
||||
__all__ = ["DeepSearchAgent", "create_agent", "Config", "load_config"]
|
||||
@@ -0,0 +1,478 @@
|
||||
"""
|
||||
Deep Search Agent主类
|
||||
整合所有模块,实现完整的深度搜索流程
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
from .llms import DeepSeekLLM, OpenAILLM, BaseLLM
|
||||
from .nodes import (
|
||||
ReportStructureNode,
|
||||
FirstSearchNode,
|
||||
ReflectionNode,
|
||||
FirstSummaryNode,
|
||||
ReflectionSummaryNode,
|
||||
ReportFormattingNode
|
||||
)
|
||||
from .state import State
|
||||
from .tools import TavilyNewsAgency, TavilyResponse
|
||||
from .utils import Config, load_config, format_search_results_for_prompt
|
||||
|
||||
|
||||
class DeepSearchAgent:
|
||||
"""Deep Search Agent主类"""
|
||||
|
||||
def __init__(self, config: Optional[Config] = None):
|
||||
"""
|
||||
初始化Deep Search Agent
|
||||
|
||||
Args:
|
||||
config: 配置对象,如果不提供则自动加载
|
||||
"""
|
||||
# 加载配置
|
||||
self.config = config or load_config()
|
||||
|
||||
# 初始化LLM客户端
|
||||
self.llm_client = self._initialize_llm()
|
||||
|
||||
# 初始化搜索工具集
|
||||
self.search_agency = TavilyNewsAgency(api_key=self.config.tavily_api_key)
|
||||
|
||||
# 初始化节点
|
||||
self._initialize_nodes()
|
||||
|
||||
# 状态
|
||||
self.state = State()
|
||||
|
||||
# 确保输出目录存在
|
||||
os.makedirs(self.config.output_dir, exist_ok=True)
|
||||
|
||||
print(f"Deep Search Agent 已初始化")
|
||||
print(f"使用LLM: {self.llm_client.get_model_info()}")
|
||||
print(f"搜索工具集: TavilyNewsAgency (支持6种搜索工具)")
|
||||
|
||||
def _initialize_llm(self) -> BaseLLM:
|
||||
"""初始化LLM客户端"""
|
||||
if self.config.default_llm_provider == "deepseek":
|
||||
return DeepSeekLLM(
|
||||
api_key=self.config.deepseek_api_key,
|
||||
model_name=self.config.deepseek_model
|
||||
)
|
||||
elif self.config.default_llm_provider == "openai":
|
||||
return OpenAILLM(
|
||||
api_key=self.config.openai_api_key,
|
||||
model_name=self.config.openai_model
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"不支持的LLM提供商: {self.config.default_llm_provider}")
|
||||
|
||||
def _initialize_nodes(self):
|
||||
"""初始化处理节点"""
|
||||
self.first_search_node = FirstSearchNode(self.llm_client)
|
||||
self.reflection_node = ReflectionNode(self.llm_client)
|
||||
self.first_summary_node = FirstSummaryNode(self.llm_client)
|
||||
self.reflection_summary_node = ReflectionSummaryNode(self.llm_client)
|
||||
self.report_formatting_node = ReportFormattingNode(self.llm_client)
|
||||
|
||||
def _validate_date_format(self, date_str: str) -> bool:
|
||||
"""
|
||||
验证日期格式是否为YYYY-MM-DD
|
||||
|
||||
Args:
|
||||
date_str: 日期字符串
|
||||
|
||||
Returns:
|
||||
是否为有效格式
|
||||
"""
|
||||
if not date_str:
|
||||
return False
|
||||
|
||||
# 检查格式
|
||||
pattern = r'^\d{4}-\d{2}-\d{2}$'
|
||||
if not re.match(pattern, date_str):
|
||||
return False
|
||||
|
||||
# 检查日期是否有效
|
||||
try:
|
||||
datetime.strptime(date_str, '%Y-%m-%d')
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def execute_search_tool(self, tool_name: str, query: str, **kwargs) -> TavilyResponse:
|
||||
"""
|
||||
执行指定的搜索工具
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称,可选值:
|
||||
- "basic_search_news": 基础新闻搜索(快速、通用)
|
||||
- "deep_search_news": 深度新闻分析
|
||||
- "search_news_last_24_hours": 24小时内最新新闻
|
||||
- "search_news_last_week": 本周新闻
|
||||
- "search_images_for_news": 新闻图片搜索
|
||||
- "search_news_by_date": 按日期范围搜索新闻
|
||||
query: 搜索查询
|
||||
**kwargs: 额外参数(如start_date, end_date, max_results)
|
||||
|
||||
Returns:
|
||||
TavilyResponse对象
|
||||
"""
|
||||
print(f" → 执行搜索工具: {tool_name}")
|
||||
|
||||
if tool_name == "basic_search_news":
|
||||
max_results = kwargs.get("max_results", 7)
|
||||
return self.search_agency.basic_search_news(query, max_results)
|
||||
elif tool_name == "deep_search_news":
|
||||
return self.search_agency.deep_search_news(query)
|
||||
elif tool_name == "search_news_last_24_hours":
|
||||
return self.search_agency.search_news_last_24_hours(query)
|
||||
elif tool_name == "search_news_last_week":
|
||||
return self.search_agency.search_news_last_week(query)
|
||||
elif tool_name == "search_images_for_news":
|
||||
return self.search_agency.search_images_for_news(query)
|
||||
elif tool_name == "search_news_by_date":
|
||||
start_date = kwargs.get("start_date")
|
||||
end_date = kwargs.get("end_date")
|
||||
if not start_date or not end_date:
|
||||
raise ValueError("search_news_by_date工具需要start_date和end_date参数")
|
||||
return self.search_agency.search_news_by_date(query, start_date, end_date)
|
||||
else:
|
||||
print(f" ⚠️ 未知的搜索工具: {tool_name},使用默认基础搜索")
|
||||
return self.search_agency.basic_search_news(query)
|
||||
|
||||
def research(self, query: str, save_report: bool = True) -> str:
|
||||
"""
|
||||
执行深度研究
|
||||
|
||||
Args:
|
||||
query: 研究查询
|
||||
save_report: 是否保存报告到文件
|
||||
|
||||
Returns:
|
||||
最终报告内容
|
||||
"""
|
||||
print(f"\n{'='*60}")
|
||||
print(f"开始深度研究: {query}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
try:
|
||||
# Step 1: 生成报告结构
|
||||
self._generate_report_structure(query)
|
||||
|
||||
# Step 2: 处理每个段落
|
||||
self._process_paragraphs()
|
||||
|
||||
# Step 3: 生成最终报告
|
||||
final_report = self._generate_final_report()
|
||||
|
||||
# Step 4: 保存报告
|
||||
if save_report:
|
||||
self._save_report(final_report)
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print("深度研究完成!")
|
||||
print(f"{'='*60}")
|
||||
|
||||
return final_report
|
||||
|
||||
except Exception as e:
|
||||
print(f"研究过程中发生错误: {str(e)}")
|
||||
raise e
|
||||
|
||||
def _generate_report_structure(self, query: str):
|
||||
"""生成报告结构"""
|
||||
print(f"\n[步骤 1] 生成报告结构...")
|
||||
|
||||
# 创建报告结构节点
|
||||
report_structure_node = ReportStructureNode(self.llm_client, query)
|
||||
|
||||
# 生成结构并更新状态
|
||||
self.state = report_structure_node.mutate_state(state=self.state)
|
||||
|
||||
print(f"报告结构已生成,共 {len(self.state.paragraphs)} 个段落:")
|
||||
for i, paragraph in enumerate(self.state.paragraphs, 1):
|
||||
print(f" {i}. {paragraph.title}")
|
||||
|
||||
def _process_paragraphs(self):
|
||||
"""处理所有段落"""
|
||||
total_paragraphs = len(self.state.paragraphs)
|
||||
|
||||
for i in range(total_paragraphs):
|
||||
print(f"\n[步骤 2.{i+1}] 处理段落: {self.state.paragraphs[i].title}")
|
||||
print("-" * 50)
|
||||
|
||||
# 初始搜索和总结
|
||||
self._initial_search_and_summary(i)
|
||||
|
||||
# 反思循环
|
||||
self._reflection_loop(i)
|
||||
|
||||
# 标记段落完成
|
||||
self.state.paragraphs[i].research.mark_completed()
|
||||
|
||||
progress = (i + 1) / total_paragraphs * 100
|
||||
print(f"段落处理完成 ({progress:.1f}%)")
|
||||
|
||||
def _initial_search_and_summary(self, paragraph_index: int):
|
||||
"""执行初始搜索和总结"""
|
||||
paragraph = self.state.paragraphs[paragraph_index]
|
||||
|
||||
# 准备搜索输入
|
||||
search_input = {
|
||||
"title": paragraph.title,
|
||||
"content": paragraph.content
|
||||
}
|
||||
|
||||
# 生成搜索查询和工具选择
|
||||
print(" - 生成搜索查询...")
|
||||
search_output = self.first_search_node.run(search_input)
|
||||
search_query = search_output["search_query"]
|
||||
search_tool = search_output.get("search_tool", "basic_search_news") # 默认工具
|
||||
reasoning = search_output["reasoning"]
|
||||
|
||||
print(f" - 搜索查询: {search_query}")
|
||||
print(f" - 选择的工具: {search_tool}")
|
||||
print(f" - 推理: {reasoning}")
|
||||
|
||||
# 执行搜索
|
||||
print(" - 执行网络搜索...")
|
||||
|
||||
# 处理search_news_by_date的特殊参数
|
||||
search_kwargs = {}
|
||||
if search_tool == "search_news_by_date":
|
||||
start_date = search_output.get("start_date")
|
||||
end_date = search_output.get("end_date")
|
||||
|
||||
if start_date and end_date:
|
||||
# 验证日期格式
|
||||
if self._validate_date_format(start_date) and self._validate_date_format(end_date):
|
||||
search_kwargs["start_date"] = start_date
|
||||
search_kwargs["end_date"] = end_date
|
||||
print(f" - 时间范围: {start_date} 到 {end_date}")
|
||||
else:
|
||||
print(f" ⚠️ 日期格式错误(应为YYYY-MM-DD),改用基础搜索")
|
||||
print(f" 提供的日期: start_date={start_date}, end_date={end_date}")
|
||||
search_tool = "basic_search_news"
|
||||
else:
|
||||
print(f" ⚠️ search_news_by_date工具缺少时间参数,改用基础搜索")
|
||||
search_tool = "basic_search_news"
|
||||
|
||||
search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs)
|
||||
|
||||
# 转换为兼容格式
|
||||
search_results = []
|
||||
if search_response and search_response.results:
|
||||
# 每种搜索工具都有其特定的结果数量,这里取前10个作为上限
|
||||
max_results = min(len(search_response.results), 10)
|
||||
for result in search_response.results[:max_results]:
|
||||
search_results.append({
|
||||
'title': result.title,
|
||||
'url': result.url,
|
||||
'content': result.content,
|
||||
'score': result.score,
|
||||
'raw_content': result.raw_content,
|
||||
'published_date': result.published_date # 新增字段
|
||||
})
|
||||
|
||||
if search_results:
|
||||
print(f" - 找到 {len(search_results)} 个搜索结果")
|
||||
for j, result in enumerate(search_results, 1):
|
||||
date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else ""
|
||||
print(f" {j}. {result['title'][:50]}...{date_info}")
|
||||
else:
|
||||
print(" - 未找到搜索结果")
|
||||
|
||||
# 更新状态中的搜索历史
|
||||
paragraph.research.add_search_results(search_query, search_results)
|
||||
|
||||
# 生成初始总结
|
||||
print(" - 生成初始总结...")
|
||||
summary_input = {
|
||||
"title": paragraph.title,
|
||||
"content": paragraph.content,
|
||||
"search_query": search_query,
|
||||
"search_results": format_search_results_for_prompt(
|
||||
search_results, self.config.max_content_length
|
||||
)
|
||||
}
|
||||
|
||||
# 更新状态
|
||||
self.state = self.first_summary_node.mutate_state(
|
||||
summary_input, self.state, paragraph_index
|
||||
)
|
||||
|
||||
print(" - 初始总结完成")
|
||||
|
||||
def _reflection_loop(self, paragraph_index: int):
|
||||
"""执行反思循环"""
|
||||
paragraph = self.state.paragraphs[paragraph_index]
|
||||
|
||||
for reflection_i in range(self.config.max_reflections):
|
||||
print(f" - 反思 {reflection_i + 1}/{self.config.max_reflections}...")
|
||||
|
||||
# 准备反思输入
|
||||
reflection_input = {
|
||||
"title": paragraph.title,
|
||||
"content": paragraph.content,
|
||||
"paragraph_latest_state": paragraph.research.latest_summary
|
||||
}
|
||||
|
||||
# 生成反思搜索查询
|
||||
reflection_output = self.reflection_node.run(reflection_input)
|
||||
search_query = reflection_output["search_query"]
|
||||
search_tool = reflection_output.get("search_tool", "basic_search_news") # 默认工具
|
||||
reasoning = reflection_output["reasoning"]
|
||||
|
||||
print(f" 反思查询: {search_query}")
|
||||
print(f" 选择的工具: {search_tool}")
|
||||
print(f" 反思推理: {reasoning}")
|
||||
|
||||
# 执行反思搜索
|
||||
# 处理search_news_by_date的特殊参数
|
||||
search_kwargs = {}
|
||||
if search_tool == "search_news_by_date":
|
||||
start_date = reflection_output.get("start_date")
|
||||
end_date = reflection_output.get("end_date")
|
||||
|
||||
if start_date and end_date:
|
||||
# 验证日期格式
|
||||
if self._validate_date_format(start_date) and self._validate_date_format(end_date):
|
||||
search_kwargs["start_date"] = start_date
|
||||
search_kwargs["end_date"] = end_date
|
||||
print(f" 时间范围: {start_date} 到 {end_date}")
|
||||
else:
|
||||
print(f" ⚠️ 日期格式错误(应为YYYY-MM-DD),改用基础搜索")
|
||||
print(f" 提供的日期: start_date={start_date}, end_date={end_date}")
|
||||
search_tool = "basic_search_news"
|
||||
else:
|
||||
print(f" ⚠️ search_news_by_date工具缺少时间参数,改用基础搜索")
|
||||
search_tool = "basic_search_news"
|
||||
|
||||
search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs)
|
||||
|
||||
# 转换为兼容格式
|
||||
search_results = []
|
||||
if search_response and search_response.results:
|
||||
# 每种搜索工具都有其特定的结果数量,这里取前10个作为上限
|
||||
max_results = min(len(search_response.results), 10)
|
||||
for result in search_response.results[:max_results]:
|
||||
search_results.append({
|
||||
'title': result.title,
|
||||
'url': result.url,
|
||||
'content': result.content,
|
||||
'score': result.score,
|
||||
'raw_content': result.raw_content,
|
||||
'published_date': result.published_date
|
||||
})
|
||||
|
||||
if search_results:
|
||||
print(f" 找到 {len(search_results)} 个反思搜索结果")
|
||||
for j, result in enumerate(search_results, 1):
|
||||
date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else ""
|
||||
print(f" {j}. {result['title'][:50]}...{date_info}")
|
||||
else:
|
||||
print(" 未找到反思搜索结果")
|
||||
|
||||
# 更新搜索历史
|
||||
paragraph.research.add_search_results(search_query, search_results)
|
||||
|
||||
# 生成反思总结
|
||||
reflection_summary_input = {
|
||||
"title": paragraph.title,
|
||||
"content": paragraph.content,
|
||||
"search_query": search_query,
|
||||
"search_results": format_search_results_for_prompt(
|
||||
search_results, self.config.max_content_length
|
||||
),
|
||||
"paragraph_latest_state": paragraph.research.latest_summary
|
||||
}
|
||||
|
||||
# 更新状态
|
||||
self.state = self.reflection_summary_node.mutate_state(
|
||||
reflection_summary_input, self.state, paragraph_index
|
||||
)
|
||||
|
||||
print(f" 反思 {reflection_i + 1} 完成")
|
||||
|
||||
def _generate_final_report(self) -> str:
|
||||
"""生成最终报告"""
|
||||
print(f"\n[步骤 3] 生成最终报告...")
|
||||
|
||||
# 准备报告数据
|
||||
report_data = []
|
||||
for paragraph in self.state.paragraphs:
|
||||
report_data.append({
|
||||
"title": paragraph.title,
|
||||
"paragraph_latest_state": paragraph.research.latest_summary
|
||||
})
|
||||
|
||||
# 格式化报告
|
||||
try:
|
||||
final_report = self.report_formatting_node.run(report_data)
|
||||
except Exception as e:
|
||||
print(f"LLM格式化失败,使用备用方法: {str(e)}")
|
||||
final_report = self.report_formatting_node.format_report_manually(
|
||||
report_data, self.state.report_title
|
||||
)
|
||||
|
||||
# 更新状态
|
||||
self.state.final_report = final_report
|
||||
self.state.mark_completed()
|
||||
|
||||
print("最终报告生成完成")
|
||||
return final_report
|
||||
|
||||
def _save_report(self, report_content: str):
|
||||
"""保存报告到文件"""
|
||||
# 生成文件名
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
query_safe = "".join(c for c in self.state.query if c.isalnum() or c in (' ', '-', '_')).rstrip()
|
||||
query_safe = query_safe.replace(' ', '_')[:30]
|
||||
|
||||
filename = f"deep_search_report_{query_safe}_{timestamp}.md"
|
||||
filepath = os.path.join(self.config.output_dir, filename)
|
||||
|
||||
# 保存报告
|
||||
with open(filepath, 'w', encoding='utf-8') as f:
|
||||
f.write(report_content)
|
||||
|
||||
print(f"报告已保存到: {filepath}")
|
||||
|
||||
# 保存状态(如果配置允许)
|
||||
if self.config.save_intermediate_states:
|
||||
state_filename = f"state_{query_safe}_{timestamp}.json"
|
||||
state_filepath = os.path.join(self.config.output_dir, state_filename)
|
||||
self.state.save_to_file(state_filepath)
|
||||
print(f"状态已保存到: {state_filepath}")
|
||||
|
||||
def get_progress_summary(self) -> Dict[str, Any]:
|
||||
"""获取进度摘要"""
|
||||
return self.state.get_progress_summary()
|
||||
|
||||
def load_state(self, filepath: str):
|
||||
"""从文件加载状态"""
|
||||
self.state = State.load_from_file(filepath)
|
||||
print(f"状态已从 {filepath} 加载")
|
||||
|
||||
def save_state(self, filepath: str):
|
||||
"""保存状态到文件"""
|
||||
self.state.save_to_file(filepath)
|
||||
print(f"状态已保存到 {filepath}")
|
||||
|
||||
|
||||
def create_agent(config_file: Optional[str] = None) -> DeepSearchAgent:
|
||||
"""
|
||||
创建Deep Search Agent实例的便捷函数
|
||||
|
||||
Args:
|
||||
config_file: 配置文件路径
|
||||
|
||||
Returns:
|
||||
DeepSearchAgent实例
|
||||
"""
|
||||
config = load_config(config_file)
|
||||
return DeepSearchAgent(config)
|
||||
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
LLM调用模块
|
||||
支持多种大语言模型的统一接口
|
||||
"""
|
||||
|
||||
from .base import BaseLLM
|
||||
from .deepseek import DeepSeekLLM
|
||||
from .openai_llm import OpenAILLM
|
||||
|
||||
__all__ = ["BaseLLM", "DeepSeekLLM", "OpenAILLM"]
|
||||
@@ -0,0 +1,61 @@
|
||||
"""
|
||||
LLM基础抽象类
|
||||
定义所有LLM实现需要遵循的接口标准
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
|
||||
class BaseLLM(ABC):
|
||||
"""LLM基础抽象类"""
|
||||
|
||||
def __init__(self, api_key: str, model_name: Optional[str] = None):
|
||||
"""
|
||||
初始化LLM客户端
|
||||
|
||||
Args:
|
||||
api_key: API密钥
|
||||
model_name: 模型名称,如果不指定则使用默认模型
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.model_name = model_name
|
||||
|
||||
@abstractmethod
|
||||
def invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
|
||||
"""
|
||||
调用LLM生成回复
|
||||
|
||||
Args:
|
||||
system_prompt: 系统提示词
|
||||
user_prompt: 用户输入
|
||||
**kwargs: 其他参数,如temperature、max_tokens等
|
||||
|
||||
Returns:
|
||||
LLM生成的回复文本
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_default_model(self) -> str:
|
||||
"""
|
||||
获取默认模型名称
|
||||
|
||||
Returns:
|
||||
默认模型名称
|
||||
"""
|
||||
pass
|
||||
|
||||
def validate_response(self, response: str) -> str:
|
||||
"""
|
||||
验证和清理响应内容
|
||||
|
||||
Args:
|
||||
response: LLM原始响应
|
||||
|
||||
Returns:
|
||||
清理后的响应内容
|
||||
"""
|
||||
if response is None:
|
||||
return ""
|
||||
return response.strip()
|
||||
@@ -0,0 +1,95 @@
|
||||
"""
|
||||
DeepSeek LLM实现
|
||||
使用DeepSeek API进行文本生成
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional, Dict, Any
|
||||
from openai import OpenAI
|
||||
from .base import BaseLLM
|
||||
|
||||
|
||||
class DeepSeekLLM(BaseLLM):
|
||||
"""DeepSeek LLM实现类"""
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None, model_name: Optional[str] = None):
|
||||
"""
|
||||
初始化DeepSeek客户端
|
||||
|
||||
Args:
|
||||
api_key: DeepSeek API密钥,如果不提供则从环境变量读取
|
||||
model_name: 模型名称,默认使用deepseek-chat
|
||||
"""
|
||||
if api_key is None:
|
||||
api_key = os.getenv("DEEPSEEK_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError("DeepSeek API Key未找到!请设置DEEPSEEK_API_KEY环境变量或在初始化时提供")
|
||||
|
||||
super().__init__(api_key, model_name)
|
||||
|
||||
# 初始化OpenAI客户端,使用DeepSeek的endpoint
|
||||
self.client = OpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url="https://api.deepseek.com"
|
||||
)
|
||||
|
||||
self.default_model = model_name or self.get_default_model()
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
"""获取默认模型名称"""
|
||||
return "deepseek-chat"
|
||||
|
||||
def invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
|
||||
"""
|
||||
调用DeepSeek API生成回复
|
||||
|
||||
Args:
|
||||
system_prompt: 系统提示词
|
||||
user_prompt: 用户输入
|
||||
**kwargs: 其他参数,如temperature、max_tokens等
|
||||
|
||||
Returns:
|
||||
DeepSeek生成的回复文本
|
||||
"""
|
||||
try:
|
||||
# 构建消息
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
]
|
||||
|
||||
# 设置默认参数
|
||||
params = {
|
||||
"model": self.default_model,
|
||||
"messages": messages,
|
||||
"temperature": kwargs.get("temperature", 0.7),
|
||||
"max_tokens": kwargs.get("max_tokens", 4000),
|
||||
"stream": False
|
||||
}
|
||||
|
||||
# 调用API
|
||||
response = self.client.chat.completions.create(**params)
|
||||
|
||||
# 提取回复内容
|
||||
if response.choices and response.choices[0].message:
|
||||
content = response.choices[0].message.content
|
||||
return self.validate_response(content)
|
||||
else:
|
||||
return ""
|
||||
|
||||
except Exception as e:
|
||||
print(f"DeepSeek API调用错误: {str(e)}")
|
||||
raise e
|
||||
|
||||
def get_model_info(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取当前模型信息
|
||||
|
||||
Returns:
|
||||
模型信息字典
|
||||
"""
|
||||
return {
|
||||
"provider": "DeepSeek",
|
||||
"model": self.default_model,
|
||||
"api_base": "https://api.deepseek.com"
|
||||
}
|
||||
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
OpenAI LLM实现
|
||||
使用OpenAI API进行文本生成
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional, Dict, Any
|
||||
from openai import OpenAI
|
||||
from .base import BaseLLM
|
||||
|
||||
|
||||
class OpenAILLM(BaseLLM):
|
||||
"""OpenAI LLM实现类"""
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None, model_name: Optional[str] = None):
|
||||
"""
|
||||
初始化OpenAI客户端
|
||||
|
||||
Args:
|
||||
api_key: OpenAI API密钥,如果不提供则从环境变量读取
|
||||
model_name: 模型名称,默认使用gpt-4o-mini
|
||||
"""
|
||||
if api_key is None:
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError("OpenAI API Key未找到!请设置OPENAI_API_KEY环境变量或在初始化时提供")
|
||||
|
||||
super().__init__(api_key, model_name)
|
||||
|
||||
# 初始化OpenAI客户端
|
||||
self.client = OpenAI(api_key=self.api_key)
|
||||
self.default_model = model_name or self.get_default_model()
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
"""获取默认模型名称"""
|
||||
return "gpt-4o-mini"
|
||||
|
||||
def invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
|
||||
"""
|
||||
调用OpenAI API生成回复
|
||||
|
||||
Args:
|
||||
system_prompt: 系统提示词
|
||||
user_prompt: 用户输入
|
||||
**kwargs: 其他参数,如temperature、max_tokens等
|
||||
|
||||
Returns:
|
||||
OpenAI生成的回复文本
|
||||
"""
|
||||
try:
|
||||
# 构建消息
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
]
|
||||
|
||||
# 设置默认参数
|
||||
params = {
|
||||
"model": self.default_model,
|
||||
"messages": messages,
|
||||
"temperature": kwargs.get("temperature", 0.7),
|
||||
"max_tokens": kwargs.get("max_tokens", 4000)
|
||||
}
|
||||
|
||||
# 调用API
|
||||
response = self.client.chat.completions.create(**params)
|
||||
|
||||
# 提取回复内容
|
||||
if response.choices and response.choices[0].message:
|
||||
content = response.choices[0].message.content
|
||||
return self.validate_response(content)
|
||||
else:
|
||||
return ""
|
||||
|
||||
except Exception as e:
|
||||
print(f"OpenAI API调用错误: {str(e)}")
|
||||
raise e
|
||||
|
||||
def get_model_info(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取当前模型信息
|
||||
|
||||
Returns:
|
||||
模型信息字典
|
||||
"""
|
||||
return {
|
||||
"provider": "OpenAI",
|
||||
"model": self.default_model,
|
||||
"api_base": "https://api.openai.com"
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
"""
|
||||
节点处理模块
|
||||
实现Deep Search Agent的各个处理步骤
|
||||
"""
|
||||
|
||||
from .base_node import BaseNode
|
||||
from .report_structure_node import ReportStructureNode
|
||||
from .search_node import FirstSearchNode, ReflectionNode
|
||||
from .summary_node import FirstSummaryNode, ReflectionSummaryNode
|
||||
from .formatting_node import ReportFormattingNode
|
||||
|
||||
__all__ = [
|
||||
"BaseNode",
|
||||
"ReportStructureNode",
|
||||
"FirstSearchNode",
|
||||
"ReflectionNode",
|
||||
"FirstSummaryNode",
|
||||
"ReflectionSummaryNode",
|
||||
"ReportFormattingNode"
|
||||
]
|
||||
@@ -0,0 +1,89 @@
|
||||
"""
|
||||
节点基类
|
||||
定义所有处理节点的基础接口
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional
|
||||
from ..llms.base import BaseLLM
|
||||
from ..state.state import State
|
||||
|
||||
|
||||
class BaseNode(ABC):
|
||||
"""节点基类"""
|
||||
|
||||
def __init__(self, llm_client: BaseLLM, node_name: str = ""):
|
||||
"""
|
||||
初始化节点
|
||||
|
||||
Args:
|
||||
llm_client: LLM客户端
|
||||
node_name: 节点名称
|
||||
"""
|
||||
self.llm_client = llm_client
|
||||
self.node_name = node_name or self.__class__.__name__
|
||||
|
||||
@abstractmethod
|
||||
def run(self, input_data: Any, **kwargs) -> Any:
|
||||
"""
|
||||
执行节点处理逻辑
|
||||
|
||||
Args:
|
||||
input_data: 输入数据
|
||||
**kwargs: 额外参数
|
||||
|
||||
Returns:
|
||||
处理结果
|
||||
"""
|
||||
pass
|
||||
|
||||
def validate_input(self, input_data: Any) -> bool:
|
||||
"""
|
||||
验证输入数据
|
||||
|
||||
Args:
|
||||
input_data: 输入数据
|
||||
|
||||
Returns:
|
||||
验证是否通过
|
||||
"""
|
||||
return True
|
||||
|
||||
def process_output(self, output: Any) -> Any:
|
||||
"""
|
||||
处理输出数据
|
||||
|
||||
Args:
|
||||
output: 原始输出
|
||||
|
||||
Returns:
|
||||
处理后的输出
|
||||
"""
|
||||
return output
|
||||
|
||||
def log_info(self, message: str):
|
||||
"""记录信息日志"""
|
||||
print(f"[{self.node_name}] {message}")
|
||||
|
||||
def log_error(self, message: str):
|
||||
"""记录错误日志"""
|
||||
print(f"[{self.node_name}] 错误: {message}")
|
||||
|
||||
|
||||
class StateMutationNode(BaseNode):
|
||||
"""带状态修改功能的节点基类"""
|
||||
|
||||
@abstractmethod
|
||||
def mutate_state(self, input_data: Any, state: State, **kwargs) -> State:
|
||||
"""
|
||||
修改状态
|
||||
|
||||
Args:
|
||||
input_data: 输入数据
|
||||
state: 当前状态
|
||||
**kwargs: 额外参数
|
||||
|
||||
Returns:
|
||||
修改后的状态
|
||||
"""
|
||||
pass
|
||||
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
报告格式化节点
|
||||
负责将最终研究结果格式化为美观的Markdown报告
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from .base_node import BaseNode
|
||||
from ..prompts import SYSTEM_PROMPT_REPORT_FORMATTING
|
||||
from ..utils.text_processing import (
|
||||
remove_reasoning_from_output,
|
||||
clean_markdown_tags
|
||||
)
|
||||
|
||||
|
||||
class ReportFormattingNode(BaseNode):
|
||||
"""格式化最终报告的节点"""
|
||||
|
||||
def __init__(self, llm_client):
|
||||
"""
|
||||
初始化报告格式化节点
|
||||
|
||||
Args:
|
||||
llm_client: LLM客户端
|
||||
"""
|
||||
super().__init__(llm_client, "ReportFormattingNode")
|
||||
|
||||
def validate_input(self, input_data: Any) -> bool:
|
||||
"""验证输入数据"""
|
||||
if isinstance(input_data, str):
|
||||
try:
|
||||
data = json.loads(input_data)
|
||||
return isinstance(data, list) and all(
|
||||
isinstance(item, dict) and "title" in item and "paragraph_latest_state" in item
|
||||
for item in data
|
||||
)
|
||||
except:
|
||||
return False
|
||||
elif isinstance(input_data, list):
|
||||
return all(
|
||||
isinstance(item, dict) and "title" in item and "paragraph_latest_state" in item
|
||||
for item in input_data
|
||||
)
|
||||
return False
|
||||
|
||||
def run(self, input_data: Any, **kwargs) -> str:
|
||||
"""
|
||||
调用LLM生成Markdown格式报告
|
||||
|
||||
Args:
|
||||
input_data: 包含所有段落信息的列表
|
||||
**kwargs: 额外参数
|
||||
|
||||
Returns:
|
||||
格式化的Markdown报告
|
||||
"""
|
||||
try:
|
||||
if not self.validate_input(input_data):
|
||||
raise ValueError("输入数据格式错误,需要包含title和paragraph_latest_state的列表")
|
||||
|
||||
# 准备输入数据
|
||||
if isinstance(input_data, str):
|
||||
message = input_data
|
||||
else:
|
||||
message = json.dumps(input_data, ensure_ascii=False)
|
||||
|
||||
self.log_info("正在格式化最终报告")
|
||||
|
||||
# 调用LLM
|
||||
response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_FORMATTING, message)
|
||||
|
||||
# 处理响应
|
||||
processed_response = self.process_output(response)
|
||||
|
||||
self.log_info("成功生成格式化报告")
|
||||
return processed_response
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"报告格式化失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
def process_output(self, output: str) -> str:
|
||||
"""
|
||||
处理LLM输出,清理Markdown格式
|
||||
|
||||
Args:
|
||||
output: LLM原始输出
|
||||
|
||||
Returns:
|
||||
清理后的Markdown报告
|
||||
"""
|
||||
try:
|
||||
# 清理响应文本
|
||||
cleaned_output = remove_reasoning_from_output(output)
|
||||
cleaned_output = clean_markdown_tags(cleaned_output)
|
||||
|
||||
# 确保报告有基本结构
|
||||
if not cleaned_output.strip():
|
||||
return "# 报告生成失败\n\n无法生成有效的报告内容。"
|
||||
|
||||
# 如果没有标题,添加一个默认标题
|
||||
if not cleaned_output.strip().startswith('#'):
|
||||
cleaned_output = "# 深度研究报告\n\n" + cleaned_output
|
||||
|
||||
return cleaned_output.strip()
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"处理输出失败: {str(e)}")
|
||||
return "# 报告处理失败\n\n报告格式化过程中发生错误。"
|
||||
|
||||
def format_report_manually(self, paragraphs_data: List[Dict[str, str]],
|
||||
report_title: str = "深度研究报告") -> str:
|
||||
"""
|
||||
手动格式化报告(备用方法)
|
||||
|
||||
Args:
|
||||
paragraphs_data: 段落数据列表
|
||||
report_title: 报告标题
|
||||
|
||||
Returns:
|
||||
格式化的Markdown报告
|
||||
"""
|
||||
try:
|
||||
self.log_info("使用手动格式化方法")
|
||||
|
||||
# 构建报告
|
||||
report_lines = [
|
||||
f"# {report_title}",
|
||||
"",
|
||||
"---",
|
||||
""
|
||||
]
|
||||
|
||||
# 添加各个段落
|
||||
for i, paragraph in enumerate(paragraphs_data, 1):
|
||||
title = paragraph.get("title", f"段落 {i}")
|
||||
content = paragraph.get("paragraph_latest_state", "")
|
||||
|
||||
if content:
|
||||
report_lines.extend([
|
||||
f"## {title}",
|
||||
"",
|
||||
content,
|
||||
"",
|
||||
"---",
|
||||
""
|
||||
])
|
||||
|
||||
# 添加结论
|
||||
if len(paragraphs_data) > 1:
|
||||
report_lines.extend([
|
||||
"## 结论",
|
||||
"",
|
||||
"本报告通过深度搜索和研究,对相关主题进行了全面分析。"
|
||||
"以上各个方面的内容为理解该主题提供了重要参考。",
|
||||
""
|
||||
])
|
||||
|
||||
return "\n".join(report_lines)
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"手动格式化失败: {str(e)}")
|
||||
return "# 报告生成失败\n\n无法完成报告格式化。"
|
||||
@@ -0,0 +1,203 @@
|
||||
"""
|
||||
报告结构生成节点
|
||||
负责根据查询生成报告的整体结构
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Dict, Any, List
|
||||
from json.decoder import JSONDecodeError
|
||||
|
||||
from .base_node import StateMutationNode
|
||||
from ..state.state import State
|
||||
from ..prompts import SYSTEM_PROMPT_REPORT_STRUCTURE
|
||||
from ..utils.text_processing import (
|
||||
remove_reasoning_from_output,
|
||||
clean_json_tags,
|
||||
extract_clean_response,
|
||||
fix_incomplete_json
|
||||
)
|
||||
|
||||
|
||||
class ReportStructureNode(StateMutationNode):
|
||||
"""生成报告结构的节点"""
|
||||
|
||||
def __init__(self, llm_client, query: str):
|
||||
"""
|
||||
初始化报告结构节点
|
||||
|
||||
Args:
|
||||
llm_client: LLM客户端
|
||||
query: 用户查询
|
||||
"""
|
||||
super().__init__(llm_client, "ReportStructureNode")
|
||||
self.query = query
|
||||
|
||||
def validate_input(self, input_data: Any) -> bool:
|
||||
"""验证输入数据"""
|
||||
return isinstance(self.query, str) and len(self.query.strip()) > 0
|
||||
|
||||
def run(self, input_data: Any = None, **kwargs) -> List[Dict[str, str]]:
|
||||
"""
|
||||
调用LLM生成报告结构
|
||||
|
||||
Args:
|
||||
input_data: 输入数据(这里不使用,使用初始化时的query)
|
||||
**kwargs: 额外参数
|
||||
|
||||
Returns:
|
||||
报告结构列表
|
||||
"""
|
||||
try:
|
||||
self.log_info(f"正在为查询生成报告结构: {self.query}")
|
||||
|
||||
# 调用LLM
|
||||
response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query)
|
||||
|
||||
# 处理响应
|
||||
processed_response = self.process_output(response)
|
||||
|
||||
self.log_info(f"成功生成 {len(processed_response)} 个段落结构")
|
||||
return processed_response
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"生成报告结构失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
def process_output(self, output: str) -> List[Dict[str, str]]:
|
||||
"""
|
||||
处理LLM输出,提取报告结构
|
||||
|
||||
Args:
|
||||
output: LLM原始输出
|
||||
|
||||
Returns:
|
||||
处理后的报告结构列表
|
||||
"""
|
||||
try:
|
||||
# 清理响应文本
|
||||
cleaned_output = remove_reasoning_from_output(output)
|
||||
cleaned_output = clean_json_tags(cleaned_output)
|
||||
|
||||
# 记录清理后的输出用于调试
|
||||
self.log_info(f"清理后的输出: {cleaned_output[:200]}...")
|
||||
|
||||
# 解析JSON
|
||||
try:
|
||||
report_structure = json.loads(cleaned_output)
|
||||
self.log_info("JSON解析成功")
|
||||
except JSONDecodeError as e:
|
||||
self.log_info(f"JSON解析失败: {str(e)}")
|
||||
# 使用更强大的提取方法
|
||||
report_structure = extract_clean_response(cleaned_output)
|
||||
if "error" in report_structure:
|
||||
self.log_error("JSON解析失败,尝试修复...")
|
||||
# 尝试修复JSON
|
||||
fixed_json = fix_incomplete_json(cleaned_output)
|
||||
if fixed_json:
|
||||
try:
|
||||
report_structure = json.loads(fixed_json)
|
||||
self.log_info("JSON修复成功")
|
||||
except JSONDecodeError:
|
||||
self.log_error("JSON修复失败")
|
||||
# 返回默认结构
|
||||
return self._generate_default_structure()
|
||||
else:
|
||||
self.log_error("无法修复JSON,使用默认结构")
|
||||
return self._generate_default_structure()
|
||||
|
||||
# 验证结构
|
||||
if not isinstance(report_structure, list):
|
||||
self.log_info("报告结构不是列表,尝试转换...")
|
||||
if isinstance(report_structure, dict):
|
||||
# 如果是单个对象,包装成列表
|
||||
report_structure = [report_structure]
|
||||
else:
|
||||
self.log_error("报告结构格式无效,使用默认结构")
|
||||
return self._generate_default_structure()
|
||||
|
||||
# 验证每个段落
|
||||
validated_structure = []
|
||||
for i, paragraph in enumerate(report_structure):
|
||||
if not isinstance(paragraph, dict):
|
||||
self.log_warning(f"段落 {i+1} 不是字典格式,跳过")
|
||||
continue
|
||||
|
||||
title = paragraph.get("title", f"段落 {i+1}")
|
||||
content = paragraph.get("content", "")
|
||||
|
||||
if not title or not content:
|
||||
self.log_warning(f"段落 {i+1} 缺少标题或内容,跳过")
|
||||
continue
|
||||
|
||||
validated_structure.append({
|
||||
"title": title,
|
||||
"content": content
|
||||
})
|
||||
|
||||
if not validated_structure:
|
||||
self.log_warning("没有有效的段落结构,使用默认结构")
|
||||
return self._generate_default_structure()
|
||||
|
||||
self.log_info(f"成功验证 {len(validated_structure)} 个段落结构")
|
||||
return validated_structure
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"处理输出失败: {str(e)}")
|
||||
return self._generate_default_structure()
|
||||
|
||||
def _generate_default_structure(self) -> List[Dict[str, str]]:
|
||||
"""
|
||||
生成默认的报告结构
|
||||
|
||||
Returns:
|
||||
默认的报告结构列表
|
||||
"""
|
||||
self.log_info("生成默认报告结构")
|
||||
return [
|
||||
{
|
||||
"title": "研究概述",
|
||||
"content": "对查询主题进行总体概述和分析"
|
||||
},
|
||||
{
|
||||
"title": "深度分析",
|
||||
"content": "深入分析查询主题的各个方面"
|
||||
}
|
||||
]
|
||||
|
||||
def mutate_state(self, input_data: Any = None, state: State = None, **kwargs) -> State:
|
||||
"""
|
||||
将报告结构写入状态
|
||||
|
||||
Args:
|
||||
input_data: 输入数据
|
||||
state: 当前状态,如果为None则创建新状态
|
||||
**kwargs: 额外参数
|
||||
|
||||
Returns:
|
||||
更新后的状态
|
||||
"""
|
||||
if state is None:
|
||||
state = State()
|
||||
|
||||
try:
|
||||
# 生成报告结构
|
||||
report_structure = self.run(input_data, **kwargs)
|
||||
|
||||
# 设置查询和报告标题
|
||||
state.query = self.query
|
||||
if not state.report_title:
|
||||
state.report_title = f"关于'{self.query}'的深度研究报告"
|
||||
|
||||
# 添加段落到状态
|
||||
for paragraph_data in report_structure:
|
||||
state.add_paragraph(
|
||||
title=paragraph_data["title"],
|
||||
content=paragraph_data["content"]
|
||||
)
|
||||
|
||||
self.log_info(f"已将 {len(report_structure)} 个段落添加到状态中")
|
||||
return state
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"状态更新失败: {str(e)}")
|
||||
raise e
|
||||
@@ -0,0 +1,285 @@
|
||||
"""
|
||||
搜索节点实现
|
||||
负责生成搜索查询和反思查询
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Dict, Any
|
||||
from json.decoder import JSONDecodeError
|
||||
|
||||
from .base_node import BaseNode
|
||||
from ..prompts import SYSTEM_PROMPT_FIRST_SEARCH, SYSTEM_PROMPT_REFLECTION
|
||||
from ..utils.text_processing import (
|
||||
remove_reasoning_from_output,
|
||||
clean_json_tags,
|
||||
extract_clean_response,
|
||||
fix_incomplete_json
|
||||
)
|
||||
|
||||
|
||||
class FirstSearchNode(BaseNode):
|
||||
"""为段落生成首次搜索查询的节点"""
|
||||
|
||||
def __init__(self, llm_client):
|
||||
"""
|
||||
初始化首次搜索节点
|
||||
|
||||
Args:
|
||||
llm_client: LLM客户端
|
||||
"""
|
||||
super().__init__(llm_client, "FirstSearchNode")
|
||||
|
||||
def validate_input(self, input_data: Any) -> bool:
|
||||
"""验证输入数据"""
|
||||
if isinstance(input_data, str):
|
||||
try:
|
||||
data = json.loads(input_data)
|
||||
return "title" in data and "content" in data
|
||||
except JSONDecodeError:
|
||||
return False
|
||||
elif isinstance(input_data, dict):
|
||||
return "title" in input_data and "content" in input_data
|
||||
return False
|
||||
|
||||
def run(self, input_data: Any, **kwargs) -> Dict[str, str]:
|
||||
"""
|
||||
调用LLM生成搜索查询和理由
|
||||
|
||||
Args:
|
||||
input_data: 包含title和content的字符串或字典
|
||||
**kwargs: 额外参数
|
||||
|
||||
Returns:
|
||||
包含search_query和reasoning的字典
|
||||
"""
|
||||
try:
|
||||
if not self.validate_input(input_data):
|
||||
raise ValueError("输入数据格式错误,需要包含title和content字段")
|
||||
|
||||
# 准备输入数据
|
||||
if isinstance(input_data, str):
|
||||
message = input_data
|
||||
else:
|
||||
message = json.dumps(input_data, ensure_ascii=False)
|
||||
|
||||
self.log_info("正在生成首次搜索查询")
|
||||
|
||||
# 调用LLM
|
||||
response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SEARCH, message)
|
||||
|
||||
# 处理响应
|
||||
processed_response = self.process_output(response)
|
||||
|
||||
self.log_info(f"生成搜索查询: {processed_response.get('search_query', 'N/A')}")
|
||||
return processed_response
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"生成首次搜索查询失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
def process_output(self, output: str) -> Dict[str, str]:
|
||||
"""
|
||||
处理LLM输出,提取搜索查询和推理
|
||||
|
||||
Args:
|
||||
output: LLM原始输出
|
||||
|
||||
Returns:
|
||||
包含search_query和reasoning的字典
|
||||
"""
|
||||
try:
|
||||
# 清理响应文本
|
||||
cleaned_output = remove_reasoning_from_output(output)
|
||||
cleaned_output = clean_json_tags(cleaned_output)
|
||||
|
||||
# 记录清理后的输出用于调试
|
||||
self.log_info(f"清理后的输出: {cleaned_output[:200]}...")
|
||||
|
||||
# 解析JSON
|
||||
try:
|
||||
result = json.loads(cleaned_output)
|
||||
self.log_info("JSON解析成功")
|
||||
except JSONDecodeError as e:
|
||||
self.log_info(f"JSON解析失败: {str(e)}")
|
||||
# 使用更强大的提取方法
|
||||
result = extract_clean_response(cleaned_output)
|
||||
if "error" in result:
|
||||
self.log_error("JSON解析失败,尝试修复...")
|
||||
# 尝试修复JSON
|
||||
fixed_json = fix_incomplete_json(cleaned_output)
|
||||
if fixed_json:
|
||||
try:
|
||||
result = json.loads(fixed_json)
|
||||
self.log_info("JSON修复成功")
|
||||
except JSONDecodeError:
|
||||
self.log_error("JSON修复失败")
|
||||
# 返回默认查询
|
||||
return self._get_default_search_query()
|
||||
else:
|
||||
self.log_error("无法修复JSON,使用默认查询")
|
||||
return self._get_default_search_query()
|
||||
|
||||
# 验证和清理结果
|
||||
search_query = result.get("search_query", "")
|
||||
reasoning = result.get("reasoning", "")
|
||||
|
||||
if not search_query:
|
||||
self.log_warning("未找到搜索查询,使用默认查询")
|
||||
return self._get_default_search_query()
|
||||
|
||||
return {
|
||||
"search_query": search_query,
|
||||
"reasoning": reasoning
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"处理输出失败: {str(e)}")
|
||||
# 返回默认查询
|
||||
return self._get_default_search_query()
|
||||
|
||||
def _get_default_search_query(self) -> Dict[str, str]:
|
||||
"""
|
||||
获取默认搜索查询
|
||||
|
||||
Returns:
|
||||
默认的搜索查询字典
|
||||
"""
|
||||
return {
|
||||
"search_query": "相关主题研究",
|
||||
"reasoning": "由于解析失败,使用默认搜索查询"
|
||||
}
|
||||
|
||||
|
||||
class ReflectionNode(BaseNode):
|
||||
"""反思段落并生成新搜索查询的节点"""
|
||||
|
||||
def __init__(self, llm_client):
|
||||
"""
|
||||
初始化反思节点
|
||||
|
||||
Args:
|
||||
llm_client: LLM客户端
|
||||
"""
|
||||
super().__init__(llm_client, "ReflectionNode")
|
||||
|
||||
def validate_input(self, input_data: Any) -> bool:
|
||||
"""验证输入数据"""
|
||||
if isinstance(input_data, str):
|
||||
try:
|
||||
data = json.loads(input_data)
|
||||
required_fields = ["title", "content", "paragraph_latest_state"]
|
||||
return all(field in data for field in required_fields)
|
||||
except JSONDecodeError:
|
||||
return False
|
||||
elif isinstance(input_data, dict):
|
||||
required_fields = ["title", "content", "paragraph_latest_state"]
|
||||
return all(field in input_data for field in required_fields)
|
||||
return False
|
||||
|
||||
def run(self, input_data: Any, **kwargs) -> Dict[str, str]:
|
||||
"""
|
||||
调用LLM反思并生成搜索查询
|
||||
|
||||
Args:
|
||||
input_data: 包含title、content和paragraph_latest_state的字符串或字典
|
||||
**kwargs: 额外参数
|
||||
|
||||
Returns:
|
||||
包含search_query和reasoning的字典
|
||||
"""
|
||||
try:
|
||||
if not self.validate_input(input_data):
|
||||
raise ValueError("输入数据格式错误,需要包含title、content和paragraph_latest_state字段")
|
||||
|
||||
# 准备输入数据
|
||||
if isinstance(input_data, str):
|
||||
message = input_data
|
||||
else:
|
||||
message = json.dumps(input_data, ensure_ascii=False)
|
||||
|
||||
self.log_info("正在进行反思并生成新搜索查询")
|
||||
|
||||
# 调用LLM
|
||||
response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION, message)
|
||||
|
||||
# 处理响应
|
||||
processed_response = self.process_output(response)
|
||||
|
||||
self.log_info(f"反思生成搜索查询: {processed_response.get('search_query', 'N/A')}")
|
||||
return processed_response
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"反思生成搜索查询失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
def process_output(self, output: str) -> Dict[str, str]:
|
||||
"""
|
||||
处理LLM输出,提取搜索查询和推理
|
||||
|
||||
Args:
|
||||
output: LLM原始输出
|
||||
|
||||
Returns:
|
||||
包含search_query和reasoning的字典
|
||||
"""
|
||||
try:
|
||||
# 清理响应文本
|
||||
cleaned_output = remove_reasoning_from_output(output)
|
||||
cleaned_output = clean_json_tags(cleaned_output)
|
||||
|
||||
# 记录清理后的输出用于调试
|
||||
self.log_info(f"清理后的输出: {cleaned_output[:200]}...")
|
||||
|
||||
# 解析JSON
|
||||
try:
|
||||
result = json.loads(cleaned_output)
|
||||
self.log_info("JSON解析成功")
|
||||
except JSONDecodeError as e:
|
||||
self.log_info(f"JSON解析失败: {str(e)}")
|
||||
# 使用更强大的提取方法
|
||||
result = extract_clean_response(cleaned_output)
|
||||
if "error" in result:
|
||||
self.log_error("JSON解析失败,尝试修复...")
|
||||
# 尝试修复JSON
|
||||
fixed_json = fix_incomplete_json(cleaned_output)
|
||||
if fixed_json:
|
||||
try:
|
||||
result = json.loads(fixed_json)
|
||||
self.log_info("JSON修复成功")
|
||||
except JSONDecodeError:
|
||||
self.log_error("JSON修复失败")
|
||||
# 返回默认查询
|
||||
return self._get_default_reflection_query()
|
||||
else:
|
||||
self.log_error("无法修复JSON,使用默认查询")
|
||||
return self._get_default_reflection_query()
|
||||
|
||||
# 验证和清理结果
|
||||
search_query = result.get("search_query", "")
|
||||
reasoning = result.get("reasoning", "")
|
||||
|
||||
if not search_query:
|
||||
self.log_warning("未找到搜索查询,使用默认查询")
|
||||
return self._get_default_reflection_query()
|
||||
|
||||
return {
|
||||
"search_query": search_query,
|
||||
"reasoning": reasoning
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"处理输出失败: {str(e)}")
|
||||
# 返回默认查询
|
||||
return self._get_default_reflection_query()
|
||||
|
||||
def _get_default_reflection_query(self) -> Dict[str, str]:
|
||||
"""
|
||||
获取默认反思搜索查询
|
||||
|
||||
Returns:
|
||||
默认的反思搜索查询字典
|
||||
"""
|
||||
return {
|
||||
"search_query": "深度研究补充信息",
|
||||
"reasoning": "由于解析失败,使用默认反思搜索查询"
|
||||
}
|
||||
@@ -0,0 +1,312 @@
|
||||
"""
|
||||
总结节点实现
|
||||
负责根据搜索结果生成和更新段落内容
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Dict, Any, List
|
||||
from json.decoder import JSONDecodeError
|
||||
|
||||
from .base_node import StateMutationNode
|
||||
from ..state.state import State
|
||||
from ..prompts import SYSTEM_PROMPT_FIRST_SUMMARY, SYSTEM_PROMPT_REFLECTION_SUMMARY
|
||||
from ..utils.text_processing import (
|
||||
remove_reasoning_from_output,
|
||||
clean_json_tags,
|
||||
extract_clean_response,
|
||||
fix_incomplete_json,
|
||||
format_search_results_for_prompt
|
||||
)
|
||||
|
||||
|
||||
class FirstSummaryNode(StateMutationNode):
|
||||
"""根据搜索结果生成段落首次总结的节点"""
|
||||
|
||||
def __init__(self, llm_client):
|
||||
"""
|
||||
初始化首次总结节点
|
||||
|
||||
Args:
|
||||
llm_client: LLM客户端
|
||||
"""
|
||||
super().__init__(llm_client, "FirstSummaryNode")
|
||||
|
||||
def validate_input(self, input_data: Any) -> bool:
|
||||
"""验证输入数据"""
|
||||
if isinstance(input_data, str):
|
||||
try:
|
||||
data = json.loads(input_data)
|
||||
required_fields = ["title", "content", "search_query", "search_results"]
|
||||
return all(field in data for field in required_fields)
|
||||
except JSONDecodeError:
|
||||
return False
|
||||
elif isinstance(input_data, dict):
|
||||
required_fields = ["title", "content", "search_query", "search_results"]
|
||||
return all(field in input_data for field in required_fields)
|
||||
return False
|
||||
|
||||
def run(self, input_data: Any, **kwargs) -> str:
|
||||
"""
|
||||
调用LLM生成段落总结
|
||||
|
||||
Args:
|
||||
input_data: 包含title、content、search_query和search_results的数据
|
||||
**kwargs: 额外参数
|
||||
|
||||
Returns:
|
||||
段落总结内容
|
||||
"""
|
||||
try:
|
||||
if not self.validate_input(input_data):
|
||||
raise ValueError("输入数据格式错误")
|
||||
|
||||
# 准备输入数据
|
||||
if isinstance(input_data, str):
|
||||
message = input_data
|
||||
else:
|
||||
message = json.dumps(input_data, ensure_ascii=False)
|
||||
|
||||
self.log_info("正在生成首次段落总结")
|
||||
|
||||
# 调用LLM
|
||||
response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SUMMARY, message)
|
||||
|
||||
# 处理响应
|
||||
processed_response = self.process_output(response)
|
||||
|
||||
self.log_info("成功生成首次段落总结")
|
||||
return processed_response
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"生成首次总结失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
def process_output(self, output: str) -> str:
|
||||
"""
|
||||
处理LLM输出,提取段落内容
|
||||
|
||||
Args:
|
||||
output: LLM原始输出
|
||||
|
||||
Returns:
|
||||
段落内容
|
||||
"""
|
||||
try:
|
||||
# 清理响应文本
|
||||
cleaned_output = remove_reasoning_from_output(output)
|
||||
cleaned_output = clean_json_tags(cleaned_output)
|
||||
|
||||
# 记录清理后的输出用于调试
|
||||
self.log_info(f"清理后的输出: {cleaned_output[:200]}...")
|
||||
|
||||
# 解析JSON
|
||||
try:
|
||||
result = json.loads(cleaned_output)
|
||||
self.log_info("JSON解析成功")
|
||||
except JSONDecodeError as e:
|
||||
self.log_info(f"JSON解析失败: {str(e)}")
|
||||
# 尝试修复JSON
|
||||
fixed_json = fix_incomplete_json(cleaned_output)
|
||||
if fixed_json:
|
||||
try:
|
||||
result = json.loads(fixed_json)
|
||||
self.log_info("JSON修复成功")
|
||||
except JSONDecodeError:
|
||||
self.log_info("JSON修复失败,直接使用清理后的文本")
|
||||
# 如果不是JSON格式,直接返回清理后的文本
|
||||
return cleaned_output
|
||||
else:
|
||||
self.log_info("无法修复JSON,直接使用清理后的文本")
|
||||
# 如果不是JSON格式,直接返回清理后的文本
|
||||
return cleaned_output
|
||||
|
||||
# 提取段落内容
|
||||
if isinstance(result, dict):
|
||||
paragraph_content = result.get("paragraph_latest_state", "")
|
||||
if paragraph_content:
|
||||
return paragraph_content
|
||||
|
||||
# 如果提取失败,返回原始清理后的文本
|
||||
return cleaned_output
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"处理输出失败: {str(e)}")
|
||||
return "段落总结生成失败"
|
||||
|
||||
def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State:
|
||||
"""
|
||||
更新段落的最新总结到状态
|
||||
|
||||
Args:
|
||||
input_data: 输入数据
|
||||
state: 当前状态
|
||||
paragraph_index: 段落索引
|
||||
**kwargs: 额外参数
|
||||
|
||||
Returns:
|
||||
更新后的状态
|
||||
"""
|
||||
try:
|
||||
# 生成总结
|
||||
summary = self.run(input_data, **kwargs)
|
||||
|
||||
# 更新状态
|
||||
if 0 <= paragraph_index < len(state.paragraphs):
|
||||
state.paragraphs[paragraph_index].research.latest_summary = summary
|
||||
self.log_info(f"已更新段落 {paragraph_index} 的首次总结")
|
||||
else:
|
||||
raise ValueError(f"段落索引 {paragraph_index} 超出范围")
|
||||
|
||||
state.update_timestamp()
|
||||
return state
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"状态更新失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
|
||||
class ReflectionSummaryNode(StateMutationNode):
|
||||
"""根据反思搜索结果更新段落总结的节点"""
|
||||
|
||||
def __init__(self, llm_client):
|
||||
"""
|
||||
初始化反思总结节点
|
||||
|
||||
Args:
|
||||
llm_client: LLM客户端
|
||||
"""
|
||||
super().__init__(llm_client, "ReflectionSummaryNode")
|
||||
|
||||
def validate_input(self, input_data: Any) -> bool:
|
||||
"""验证输入数据"""
|
||||
if isinstance(input_data, str):
|
||||
try:
|
||||
data = json.loads(input_data)
|
||||
required_fields = ["title", "content", "search_query", "search_results", "paragraph_latest_state"]
|
||||
return all(field in data for field in required_fields)
|
||||
except JSONDecodeError:
|
||||
return False
|
||||
elif isinstance(input_data, dict):
|
||||
required_fields = ["title", "content", "search_query", "search_results", "paragraph_latest_state"]
|
||||
return all(field in input_data for field in required_fields)
|
||||
return False
|
||||
|
||||
def run(self, input_data: Any, **kwargs) -> str:
|
||||
"""
|
||||
调用LLM更新段落内容
|
||||
|
||||
Args:
|
||||
input_data: 包含完整反思信息的数据
|
||||
**kwargs: 额外参数
|
||||
|
||||
Returns:
|
||||
更新后的段落内容
|
||||
"""
|
||||
try:
|
||||
if not self.validate_input(input_data):
|
||||
raise ValueError("输入数据格式错误")
|
||||
|
||||
# 准备输入数据
|
||||
if isinstance(input_data, str):
|
||||
message = input_data
|
||||
else:
|
||||
message = json.dumps(input_data, ensure_ascii=False)
|
||||
|
||||
self.log_info("正在生成反思总结")
|
||||
|
||||
# 调用LLM
|
||||
response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION_SUMMARY, message)
|
||||
|
||||
# 处理响应
|
||||
processed_response = self.process_output(response)
|
||||
|
||||
self.log_info("成功生成反思总结")
|
||||
return processed_response
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"生成反思总结失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
def process_output(self, output: str) -> str:
|
||||
"""
|
||||
处理LLM输出,提取更新后的段落内容
|
||||
|
||||
Args:
|
||||
output: LLM原始输出
|
||||
|
||||
Returns:
|
||||
更新后的段落内容
|
||||
"""
|
||||
try:
|
||||
# 清理响应文本
|
||||
cleaned_output = remove_reasoning_from_output(output)
|
||||
cleaned_output = clean_json_tags(cleaned_output)
|
||||
|
||||
# 记录清理后的输出用于调试
|
||||
self.log_info(f"清理后的输出: {cleaned_output[:200]}...")
|
||||
|
||||
# 解析JSON
|
||||
try:
|
||||
result = json.loads(cleaned_output)
|
||||
self.log_info("JSON解析成功")
|
||||
except JSONDecodeError as e:
|
||||
self.log_info(f"JSON解析失败: {str(e)}")
|
||||
# 尝试修复JSON
|
||||
fixed_json = fix_incomplete_json(cleaned_output)
|
||||
if fixed_json:
|
||||
try:
|
||||
result = json.loads(fixed_json)
|
||||
self.log_info("JSON修复成功")
|
||||
except JSONDecodeError:
|
||||
self.log_info("JSON修复失败,直接使用清理后的文本")
|
||||
# 如果不是JSON格式,直接返回清理后的文本
|
||||
return cleaned_output
|
||||
else:
|
||||
self.log_info("无法修复JSON,直接使用清理后的文本")
|
||||
# 如果不是JSON格式,直接返回清理后的文本
|
||||
return cleaned_output
|
||||
|
||||
# 提取更新后的段落内容
|
||||
if isinstance(result, dict):
|
||||
updated_content = result.get("updated_paragraph_latest_state", "")
|
||||
if updated_content:
|
||||
return updated_content
|
||||
|
||||
# 如果提取失败,返回原始清理后的文本
|
||||
return cleaned_output
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"处理输出失败: {str(e)}")
|
||||
return "反思总结生成失败"
|
||||
|
||||
def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State:
|
||||
"""
|
||||
将更新后的总结写入状态
|
||||
|
||||
Args:
|
||||
input_data: 输入数据
|
||||
state: 当前状态
|
||||
paragraph_index: 段落索引
|
||||
**kwargs: 额外参数
|
||||
|
||||
Returns:
|
||||
更新后的状态
|
||||
"""
|
||||
try:
|
||||
# 生成更新后的总结
|
||||
updated_summary = self.run(input_data, **kwargs)
|
||||
|
||||
# 更新状态
|
||||
if 0 <= paragraph_index < len(state.paragraphs):
|
||||
state.paragraphs[paragraph_index].research.latest_summary = updated_summary
|
||||
state.paragraphs[paragraph_index].research.increment_reflection()
|
||||
self.log_info(f"已更新段落 {paragraph_index} 的反思总结")
|
||||
else:
|
||||
raise ValueError(f"段落索引 {paragraph_index} 超出范围")
|
||||
|
||||
state.update_timestamp()
|
||||
return state
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"状态更新失败: {str(e)}")
|
||||
raise e
|
||||
@@ -0,0 +1,34 @@
|
||||
"""
|
||||
Prompt模块
|
||||
定义Deep Search Agent各个阶段使用的系统提示词
|
||||
"""
|
||||
|
||||
from .prompts import (
|
||||
SYSTEM_PROMPT_REPORT_STRUCTURE,
|
||||
SYSTEM_PROMPT_FIRST_SEARCH,
|
||||
SYSTEM_PROMPT_FIRST_SUMMARY,
|
||||
SYSTEM_PROMPT_REFLECTION,
|
||||
SYSTEM_PROMPT_REFLECTION_SUMMARY,
|
||||
SYSTEM_PROMPT_REPORT_FORMATTING,
|
||||
output_schema_report_structure,
|
||||
output_schema_first_search,
|
||||
output_schema_first_summary,
|
||||
output_schema_reflection,
|
||||
output_schema_reflection_summary,
|
||||
input_schema_report_formatting
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SYSTEM_PROMPT_REPORT_STRUCTURE",
|
||||
"SYSTEM_PROMPT_FIRST_SEARCH",
|
||||
"SYSTEM_PROMPT_FIRST_SUMMARY",
|
||||
"SYSTEM_PROMPT_REFLECTION",
|
||||
"SYSTEM_PROMPT_REFLECTION_SUMMARY",
|
||||
"SYSTEM_PROMPT_REPORT_FORMATTING",
|
||||
"output_schema_report_structure",
|
||||
"output_schema_first_search",
|
||||
"output_schema_first_summary",
|
||||
"output_schema_reflection",
|
||||
"output_schema_reflection_summary",
|
||||
"input_schema_report_formatting"
|
||||
]
|
||||
@@ -0,0 +1,285 @@
|
||||
"""
|
||||
Deep Search Agent 的所有提示词定义
|
||||
包含各个阶段的系统提示词和JSON Schema定义
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
# ===== JSON Schema 定义 =====
|
||||
|
||||
# 报告结构输出Schema
|
||||
output_schema_report_structure = {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {"type": "string"},
|
||||
"content": {"type": "string"}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# 首次搜索输入Schema
|
||||
input_schema_first_search = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {"type": "string"},
|
||||
"content": {"type": "string"}
|
||||
}
|
||||
}
|
||||
|
||||
# 首次搜索输出Schema
|
||||
output_schema_first_search = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"search_query": {"type": "string"},
|
||||
"search_tool": {"type": "string"},
|
||||
"reasoning": {"type": "string"},
|
||||
"start_date": {"type": "string", "description": "开始日期,格式YYYY-MM-DD,仅search_news_by_date工具需要"},
|
||||
"end_date": {"type": "string", "description": "结束日期,格式YYYY-MM-DD,仅search_news_by_date工具需要"}
|
||||
},
|
||||
"required": ["search_query", "search_tool", "reasoning"]
|
||||
}
|
||||
|
||||
# 首次总结输入Schema
|
||||
input_schema_first_summary = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {"type": "string"},
|
||||
"content": {"type": "string"},
|
||||
"search_query": {"type": "string"},
|
||||
"search_results": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# 首次总结输出Schema
|
||||
output_schema_first_summary = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"paragraph_latest_state": {"type": "string"}
|
||||
}
|
||||
}
|
||||
|
||||
# 反思输入Schema
|
||||
input_schema_reflection = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {"type": "string"},
|
||||
"content": {"type": "string"},
|
||||
"paragraph_latest_state": {"type": "string"}
|
||||
}
|
||||
}
|
||||
|
||||
# 反思输出Schema
|
||||
output_schema_reflection = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"search_query": {"type": "string"},
|
||||
"search_tool": {"type": "string"},
|
||||
"reasoning": {"type": "string"},
|
||||
"start_date": {"type": "string", "description": "开始日期,格式YYYY-MM-DD,仅search_news_by_date工具需要"},
|
||||
"end_date": {"type": "string", "description": "结束日期,格式YYYY-MM-DD,仅search_news_by_date工具需要"}
|
||||
},
|
||||
"required": ["search_query", "search_tool", "reasoning"]
|
||||
}
|
||||
|
||||
# 反思总结输入Schema
|
||||
input_schema_reflection_summary = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {"type": "string"},
|
||||
"content": {"type": "string"},
|
||||
"search_query": {"type": "string"},
|
||||
"search_results": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"}
|
||||
},
|
||||
"paragraph_latest_state": {"type": "string"}
|
||||
}
|
||||
}
|
||||
|
||||
# 反思总结输出Schema
|
||||
output_schema_reflection_summary = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"updated_paragraph_latest_state": {"type": "string"}
|
||||
}
|
||||
}
|
||||
|
||||
# 报告格式化输入Schema
|
||||
input_schema_report_formatting = {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {"type": "string"},
|
||||
"paragraph_latest_state": {"type": "string"}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# ===== 系统提示词定义 =====
|
||||
|
||||
# 生成报告结构的系统提示词
|
||||
SYSTEM_PROMPT_REPORT_STRUCTURE = f"""
|
||||
你是一位深度研究助手。给定一个查询,你需要规划一个报告的结构和其中包含的段落。最多五个段落。
|
||||
确保段落的排序合理有序。
|
||||
一旦大纲创建完成,你将获得工具来分别为每个部分搜索网络并进行反思。
|
||||
请按照以下JSON模式定义格式化输出:
|
||||
|
||||
<OUTPUT JSON SCHEMA>
|
||||
{json.dumps(output_schema_report_structure, indent=2, ensure_ascii=False)}
|
||||
</OUTPUT JSON SCHEMA>
|
||||
|
||||
标题和内容属性将用于更深入的研究。
|
||||
确保输出是一个符合上述输出JSON模式定义的JSON对象。
|
||||
只返回JSON对象,不要有解释或额外文本。
|
||||
"""
|
||||
|
||||
# 每个段落第一次搜索的系统提示词
|
||||
SYSTEM_PROMPT_FIRST_SEARCH = f"""
|
||||
你是一位深度研究助手。你将获得报告中的一个段落,其标题和预期内容将按照以下JSON模式定义提供:
|
||||
|
||||
<INPUT JSON SCHEMA>
|
||||
{json.dumps(input_schema_first_search, indent=2, ensure_ascii=False)}
|
||||
</INPUT JSON SCHEMA>
|
||||
|
||||
你可以使用以下6种专业的新闻搜索工具:
|
||||
|
||||
1. **basic_search_news** - 基础新闻搜索工具
|
||||
- 适用于:一般性的新闻搜索,不确定需要何种特定搜索时
|
||||
- 特点:快速、标准的通用搜索,是最常用的基础工具
|
||||
|
||||
2. **deep_search_news** - 深度新闻分析工具
|
||||
- 适用于:需要全面深入了解某个主题时
|
||||
- 特点:提供最详细的分析结果,包含高级AI摘要
|
||||
|
||||
3. **search_news_last_24_hours** - 24小时最新新闻工具
|
||||
- 适用于:需要了解最新动态、突发事件时
|
||||
- 特点:只搜索过去24小时的新闻
|
||||
|
||||
4. **search_news_last_week** - 本周新闻工具
|
||||
- 适用于:需要了解近期发展趋势时
|
||||
- 特点:搜索过去一周的新闻报道
|
||||
|
||||
5. **search_images_for_news** - 图片搜索工具
|
||||
- 适用于:需要可视化信息、图片资料时
|
||||
- 特点:提供相关图片和图片描述
|
||||
|
||||
6. **search_news_by_date** - 按日期范围搜索工具
|
||||
- 适用于:需要研究特定历史时期时
|
||||
- 特点:可以指定开始和结束日期进行搜索
|
||||
- 特殊要求:需要提供start_date和end_date参数,格式为'YYYY-MM-DD'
|
||||
- 注意:只有这个工具需要额外的时间参数
|
||||
|
||||
你的任务是:
|
||||
1. 根据段落主题选择最合适的搜索工具
|
||||
2. 制定最佳的搜索查询
|
||||
3. 如果选择search_news_by_date工具,必须同时提供start_date和end_date参数(格式:YYYY-MM-DD)
|
||||
4. 解释你的选择理由
|
||||
|
||||
注意:除了search_news_by_date工具外,其他工具都不需要额外参数。
|
||||
请按照以下JSON模式定义格式化输出(文字请使用中文):
|
||||
|
||||
<OUTPUT JSON SCHEMA>
|
||||
{json.dumps(output_schema_first_search, indent=2, ensure_ascii=False)}
|
||||
</OUTPUT JSON SCHEMA>
|
||||
|
||||
确保输出是一个符合上述输出JSON模式定义的JSON对象。
|
||||
只返回JSON对象,不要有解释或额外文本。
|
||||
"""
|
||||
|
||||
# 每个段落第一次总结的系统提示词
|
||||
SYSTEM_PROMPT_FIRST_SUMMARY = f"""
|
||||
你是一位深度研究助手。你将获得搜索查询、搜索结果以及你正在研究的报告段落,数据将按照以下JSON模式定义提供:
|
||||
|
||||
<INPUT JSON SCHEMA>
|
||||
{json.dumps(input_schema_first_summary, indent=2, ensure_ascii=False)}
|
||||
</INPUT JSON SCHEMA>
|
||||
|
||||
你的任务是作为研究者,使用搜索结果撰写与段落主题一致的内容,并适当地组织结构以便纳入报告中。
|
||||
请按照以下JSON模式定义格式化输出:
|
||||
|
||||
<OUTPUT JSON SCHEMA>
|
||||
{json.dumps(output_schema_first_summary, indent=2, ensure_ascii=False)}
|
||||
</OUTPUT JSON SCHEMA>
|
||||
|
||||
确保输出是一个符合上述输出JSON模式定义的JSON对象。
|
||||
只返回JSON对象,不要有解释或额外文本。
|
||||
"""
|
||||
|
||||
# 反思(Reflect)的系统提示词
|
||||
SYSTEM_PROMPT_REFLECTION = f"""
|
||||
你是一位深度研究助手。你负责为研究报告构建全面的段落。你将获得段落标题、计划内容摘要,以及你已经创建的段落最新状态,所有这些都将按照以下JSON模式定义提供:
|
||||
|
||||
<INPUT JSON SCHEMA>
|
||||
{json.dumps(input_schema_reflection, indent=2, ensure_ascii=False)}
|
||||
</INPUT JSON SCHEMA>
|
||||
|
||||
你可以使用以下6种专业的新闻搜索工具:
|
||||
|
||||
1. **basic_search_news** - 基础新闻搜索工具
|
||||
2. **deep_search_news** - 深度新闻分析工具
|
||||
3. **search_news_last_24_hours** - 24小时最新新闻工具
|
||||
4. **search_news_last_week** - 本周新闻工具
|
||||
5. **search_images_for_news** - 图片搜索工具
|
||||
6. **search_news_by_date** - 按日期范围搜索工具(需要时间参数)
|
||||
|
||||
你的任务是:
|
||||
1. 反思段落文本的当前状态,思考是否遗漏了主题的某些关键方面
|
||||
2. 选择最合适的搜索工具来补充缺失信息
|
||||
3. 制定精确的搜索查询
|
||||
4. 如果选择search_news_by_date工具,必须同时提供start_date和end_date参数(格式:YYYY-MM-DD)
|
||||
5. 解释你的选择和推理
|
||||
|
||||
注意:除了search_news_by_date工具外,其他工具都不需要额外参数。
|
||||
请按照以下JSON模式定义格式化输出:
|
||||
|
||||
<OUTPUT JSON SCHEMA>
|
||||
{json.dumps(output_schema_reflection, indent=2, ensure_ascii=False)}
|
||||
</OUTPUT JSON SCHEMA>
|
||||
|
||||
确保输出是一个符合上述输出JSON模式定义的JSON对象。
|
||||
只返回JSON对象,不要有解释或额外文本。
|
||||
"""
|
||||
|
||||
# 总结反思的系统提示词
|
||||
SYSTEM_PROMPT_REFLECTION_SUMMARY = f"""
|
||||
你是一位深度研究助手。
|
||||
你将获得搜索查询、搜索结果、段落标题以及你正在研究的报告段落的预期内容。
|
||||
你正在迭代完善这个段落,并且段落的最新状态也会提供给你。
|
||||
数据将按照以下JSON模式定义提供:
|
||||
|
||||
<INPUT JSON SCHEMA>
|
||||
{json.dumps(input_schema_reflection_summary, indent=2, ensure_ascii=False)}
|
||||
</INPUT JSON SCHEMA>
|
||||
|
||||
你的任务是根据搜索结果和预期内容丰富段落的当前最新状态。
|
||||
不要删除最新状态中的关键信息,尽量丰富它,只添加缺失的信息。
|
||||
适当地组织段落结构以便纳入报告中。
|
||||
请按照以下JSON模式定义格式化输出:
|
||||
|
||||
<OUTPUT JSON SCHEMA>
|
||||
{json.dumps(output_schema_reflection_summary, indent=2, ensure_ascii=False)}
|
||||
</OUTPUT JSON SCHEMA>
|
||||
|
||||
确保输出是一个符合上述输出JSON模式定义的JSON对象。
|
||||
只返回JSON对象,不要有解释或额外文本。
|
||||
"""
|
||||
|
||||
# 最终研究报告格式化的系统提示词
|
||||
SYSTEM_PROMPT_REPORT_FORMATTING = f"""
|
||||
你是一位深度研究助手。你已经完成了研究并构建了报告中所有段落的最终版本。
|
||||
你将获得以下JSON格式的数据:
|
||||
|
||||
<INPUT JSON SCHEMA>
|
||||
{json.dumps(input_schema_report_formatting, indent=2, ensure_ascii=False)}
|
||||
</INPUT JSON SCHEMA>
|
||||
|
||||
你的任务是将报告格式化为美观的形式,并以Markdown格式返回。
|
||||
如果没有结论段落,请根据其他段落的最新状态在报告末尾添加一个结论。
|
||||
使用段落标题来创建报告的标题。
|
||||
"""
|
||||
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
状态管理模块
|
||||
定义Deep Search Agent的状态数据结构
|
||||
"""
|
||||
|
||||
from .state import State, Paragraph, Research, Search
|
||||
|
||||
__all__ = ["State", "Paragraph", "Research", "Search"]
|
||||
@@ -0,0 +1,258 @@
|
||||
"""
|
||||
Deep Search Agent状态管理
|
||||
定义所有状态数据结构和操作方法
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Any, Optional
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class Search:
|
||||
"""单个搜索结果的状态"""
|
||||
query: str = "" # 搜索查询
|
||||
url: str = "" # 搜索结果的链接
|
||||
title: str = "" # 搜索结果标题
|
||||
content: str = "" # 搜索返回的内容
|
||||
score: Optional[float] = None # 相关度评分
|
||||
timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典格式"""
|
||||
return {
|
||||
"query": self.query,
|
||||
"url": self.url,
|
||||
"title": self.title,
|
||||
"content": self.content,
|
||||
"score": self.score,
|
||||
"timestamp": self.timestamp
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "Search":
|
||||
"""从字典创建Search对象"""
|
||||
return cls(
|
||||
query=data.get("query", ""),
|
||||
url=data.get("url", ""),
|
||||
title=data.get("title", ""),
|
||||
content=data.get("content", ""),
|
||||
score=data.get("score"),
|
||||
timestamp=data.get("timestamp", datetime.now().isoformat())
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Research:
|
||||
"""段落研究过程的状态"""
|
||||
search_history: List[Search] = field(default_factory=list) # 搜索记录列表
|
||||
latest_summary: str = "" # 当前段落的最新总结
|
||||
reflection_iteration: int = 0 # 反思迭代次数
|
||||
is_completed: bool = False # 是否完成研究
|
||||
|
||||
def add_search(self, search: Search):
|
||||
"""添加搜索记录"""
|
||||
self.search_history.append(search)
|
||||
|
||||
def add_search_results(self, query: str, results: List[Dict[str, Any]]):
|
||||
"""批量添加搜索结果"""
|
||||
for result in results:
|
||||
search = Search(
|
||||
query=query,
|
||||
url=result.get("url", ""),
|
||||
title=result.get("title", ""),
|
||||
content=result.get("content", ""),
|
||||
score=result.get("score")
|
||||
)
|
||||
self.add_search(search)
|
||||
|
||||
def get_search_count(self) -> int:
|
||||
"""获取搜索次数"""
|
||||
return len(self.search_history)
|
||||
|
||||
def increment_reflection(self):
|
||||
"""增加反思次数"""
|
||||
self.reflection_iteration += 1
|
||||
|
||||
def mark_completed(self):
|
||||
"""标记为完成"""
|
||||
self.is_completed = True
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典格式"""
|
||||
return {
|
||||
"search_history": [search.to_dict() for search in self.search_history],
|
||||
"latest_summary": self.latest_summary,
|
||||
"reflection_iteration": self.reflection_iteration,
|
||||
"is_completed": self.is_completed
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "Research":
|
||||
"""从字典创建Research对象"""
|
||||
search_history = [Search.from_dict(search_data) for search_data in data.get("search_history", [])]
|
||||
return cls(
|
||||
search_history=search_history,
|
||||
latest_summary=data.get("latest_summary", ""),
|
||||
reflection_iteration=data.get("reflection_iteration", 0),
|
||||
is_completed=data.get("is_completed", False)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Paragraph:
|
||||
"""报告中单个段落的状态"""
|
||||
title: str = "" # 段落标题
|
||||
content: str = "" # 段落的预期内容(初始规划)
|
||||
research: Research = field(default_factory=Research) # 研究进度
|
||||
order: int = 0 # 段落顺序
|
||||
|
||||
def is_completed(self) -> bool:
|
||||
"""检查段落是否完成"""
|
||||
return self.research.is_completed and bool(self.research.latest_summary)
|
||||
|
||||
def get_final_content(self) -> str:
|
||||
"""获取最终内容"""
|
||||
return self.research.latest_summary or self.content
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典格式"""
|
||||
return {
|
||||
"title": self.title,
|
||||
"content": self.content,
|
||||
"research": self.research.to_dict(),
|
||||
"order": self.order
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "Paragraph":
|
||||
"""从字典创建Paragraph对象"""
|
||||
research_data = data.get("research", {})
|
||||
research = Research.from_dict(research_data) if research_data else Research()
|
||||
|
||||
return cls(
|
||||
title=data.get("title", ""),
|
||||
content=data.get("content", ""),
|
||||
research=research,
|
||||
order=data.get("order", 0)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class State:
|
||||
"""整个报告的状态"""
|
||||
query: str = "" # 原始查询
|
||||
report_title: str = "" # 报告标题
|
||||
paragraphs: List[Paragraph] = field(default_factory=list) # 段落列表
|
||||
final_report: str = "" # 最终报告内容
|
||||
is_completed: bool = False # 是否完成
|
||||
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
|
||||
def add_paragraph(self, title: str, content: str) -> int:
|
||||
"""
|
||||
添加段落
|
||||
|
||||
Args:
|
||||
title: 段落标题
|
||||
content: 段落内容
|
||||
|
||||
Returns:
|
||||
段落索引
|
||||
"""
|
||||
order = len(self.paragraphs)
|
||||
paragraph = Paragraph(title=title, content=content, order=order)
|
||||
self.paragraphs.append(paragraph)
|
||||
self.update_timestamp()
|
||||
return order
|
||||
|
||||
def get_paragraph(self, index: int) -> Optional[Paragraph]:
|
||||
"""获取指定索引的段落"""
|
||||
if 0 <= index < len(self.paragraphs):
|
||||
return self.paragraphs[index]
|
||||
return None
|
||||
|
||||
def get_completed_paragraphs_count(self) -> int:
|
||||
"""获取已完成段落数量"""
|
||||
return sum(1 for p in self.paragraphs if p.is_completed())
|
||||
|
||||
def get_total_paragraphs_count(self) -> int:
|
||||
"""获取总段落数量"""
|
||||
return len(self.paragraphs)
|
||||
|
||||
def is_all_paragraphs_completed(self) -> bool:
|
||||
"""检查是否所有段落都完成"""
|
||||
return all(p.is_completed() for p in self.paragraphs) if self.paragraphs else False
|
||||
|
||||
def mark_completed(self):
|
||||
"""标记整个报告为完成"""
|
||||
self.is_completed = True
|
||||
self.update_timestamp()
|
||||
|
||||
def update_timestamp(self):
|
||||
"""更新时间戳"""
|
||||
self.updated_at = datetime.now().isoformat()
|
||||
|
||||
def get_progress_summary(self) -> Dict[str, Any]:
|
||||
"""获取进度摘要"""
|
||||
completed = self.get_completed_paragraphs_count()
|
||||
total = self.get_total_paragraphs_count()
|
||||
|
||||
return {
|
||||
"total_paragraphs": total,
|
||||
"completed_paragraphs": completed,
|
||||
"progress_percentage": (completed / total * 100) if total > 0 else 0,
|
||||
"is_completed": self.is_completed,
|
||||
"created_at": self.created_at,
|
||||
"updated_at": self.updated_at
|
||||
}
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典格式"""
|
||||
return {
|
||||
"query": self.query,
|
||||
"report_title": self.report_title,
|
||||
"paragraphs": [p.to_dict() for p in self.paragraphs],
|
||||
"final_report": self.final_report,
|
||||
"is_completed": self.is_completed,
|
||||
"created_at": self.created_at,
|
||||
"updated_at": self.updated_at
|
||||
}
|
||||
|
||||
def to_json(self, indent: int = 2) -> str:
|
||||
"""转换为JSON字符串"""
|
||||
return json.dumps(self.to_dict(), indent=indent, ensure_ascii=False)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "State":
|
||||
"""从字典创建State对象"""
|
||||
paragraphs = [Paragraph.from_dict(p_data) for p_data in data.get("paragraphs", [])]
|
||||
|
||||
return cls(
|
||||
query=data.get("query", ""),
|
||||
report_title=data.get("report_title", ""),
|
||||
paragraphs=paragraphs,
|
||||
final_report=data.get("final_report", ""),
|
||||
is_completed=data.get("is_completed", False),
|
||||
created_at=data.get("created_at", datetime.now().isoformat()),
|
||||
updated_at=data.get("updated_at", datetime.now().isoformat())
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json_str: str) -> "State":
|
||||
"""从JSON字符串创建State对象"""
|
||||
data = json.loads(json_str)
|
||||
return cls.from_dict(data)
|
||||
|
||||
def save_to_file(self, filepath: str):
|
||||
"""保存状态到文件"""
|
||||
with open(filepath, 'w', encoding='utf-8') as f:
|
||||
f.write(self.to_json())
|
||||
|
||||
@classmethod
|
||||
def load_from_file(cls, filepath: str) -> "State":
|
||||
"""从文件加载状态"""
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
json_str = f.read()
|
||||
return cls.from_json(json_str)
|
||||
@@ -0,0 +1,20 @@
|
||||
"""
|
||||
工具调用模块
|
||||
提供外部工具接口,如网络搜索等
|
||||
"""
|
||||
|
||||
from .search import (
|
||||
TavilyNewsAgency,
|
||||
SearchResult,
|
||||
TavilyResponse,
|
||||
ImageResult,
|
||||
print_response_summary
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"TavilyNewsAgency",
|
||||
"SearchResult",
|
||||
"TavilyResponse",
|
||||
"ImageResult",
|
||||
"print_response_summary"
|
||||
]
|
||||
@@ -0,0 +1,240 @@
|
||||
"""
|
||||
专为 AI Agent 设计的舆情搜索工具集 (Tavily)
|
||||
|
||||
版本: 1.5
|
||||
最后更新: 2025-08-22
|
||||
|
||||
此脚本将复杂的Tavily搜索功能分解为一系列目标明确、参数极少的独立工具,
|
||||
专为AI Agent调用而设计。Agent只需根据任务意图选择合适的工具,
|
||||
无需理解复杂的参数组合。所有工具默认搜索“新闻”(topic='news')。
|
||||
|
||||
新特性:
|
||||
- 新增 `basic_search_news` 工具,用于执行标准、通用的新闻搜索。
|
||||
- 每个搜索结果现在都包含 `published_date` (新闻发布日期)。
|
||||
|
||||
主要工具:
|
||||
- basic_search_news: (新增) 执行标准、快速的通用新闻搜索。
|
||||
- deep_search_news: 对主题进行最全面的深度分析。
|
||||
- search_news_last_24_hours: 获取24小时内的最新动态。
|
||||
- search_news_last_week: 获取过去一周的主要报道。
|
||||
- search_images_for_news: 查找与新闻主题相关的图片。
|
||||
- search_news_by_date: 在指定的历史日期范围内搜索。
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List, Dict, Any, Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
# 运行前请确保已安装Tavily库: pip install tavily-python
|
||||
try:
|
||||
from tavily import TavilyClient
|
||||
except ImportError:
|
||||
raise ImportError("Tavily库未安装,请运行 `pip install tavily-python` 进行安装。")
|
||||
|
||||
# --- 1. 数据结构定义 ---
|
||||
|
||||
@dataclass
|
||||
class SearchResult:
|
||||
"""
|
||||
网页搜索结果数据类
|
||||
包含 published_date 属性来存储新闻发布日期
|
||||
"""
|
||||
title: str
|
||||
url: str
|
||||
content: str
|
||||
score: Optional[float] = None
|
||||
raw_content: Optional[str] = None
|
||||
published_date: Optional[str] = None
|
||||
|
||||
@dataclass
|
||||
class ImageResult:
|
||||
"""图片搜索结果数据类"""
|
||||
url: str
|
||||
description: Optional[str] = None
|
||||
|
||||
@dataclass
|
||||
class TavilyResponse:
|
||||
"""封装Tavily API的完整返回结果,以便在工具间传递"""
|
||||
query: str
|
||||
answer: Optional[str] = None
|
||||
results: List[SearchResult] = field(default_factory=list)
|
||||
images: List[ImageResult] = field(default_factory=list)
|
||||
response_time: Optional[float] = None
|
||||
|
||||
|
||||
# --- 2. 核心客户端与专用工具集 ---
|
||||
|
||||
class TavilyNewsAgency:
|
||||
"""
|
||||
一个包含多种专用新闻舆情搜索工具的客户端。
|
||||
每个公共方法都设计为供 AI Agent 独立调用的工具。
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None):
|
||||
"""
|
||||
初始化客户端。
|
||||
Args:
|
||||
api_key: Tavily API密钥,若不提供则从环境变量 TAVILY_API_KEY 读取。
|
||||
"""
|
||||
if api_key is None:
|
||||
api_key = os.getenv("TAVILY_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError("Tavily API Key未找到!请设置TAVILY_API_KEY环境变量或在初始化时提供")
|
||||
self._client = TavilyClient(api_key=api_key)
|
||||
|
||||
def _search_internal(self, **kwargs) -> TavilyResponse:
|
||||
"""内部通用的搜索执行器,所有工具最终都调用此方法"""
|
||||
try:
|
||||
kwargs['topic'] = 'general'
|
||||
api_params = {k: v for k, v in kwargs.items() if v is not None}
|
||||
response_dict = self._client.search(**api_params)
|
||||
|
||||
search_results = [
|
||||
SearchResult(
|
||||
title=item.get('title'),
|
||||
url=item.get('url'),
|
||||
content=item.get('content'),
|
||||
score=item.get('score'),
|
||||
raw_content=item.get('raw_content'),
|
||||
published_date=item.get('published_date')
|
||||
) for item in response_dict.get('results', [])
|
||||
]
|
||||
|
||||
image_results = [ImageResult(url=item.get('url'), description=item.get('description')) for item in response_dict.get('images', [])]
|
||||
|
||||
return TavilyResponse(
|
||||
query=response_dict.get('query'), answer=response_dict.get('answer'),
|
||||
results=search_results, images=image_results,
|
||||
response_time=response_dict.get('response_time')
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"搜索时发生错误: {str(e)}")
|
||||
return TavilyResponse(query=kwargs.get("query", "Unknown Query"))
|
||||
|
||||
# --- Agent 可用的工具方法 ---
|
||||
|
||||
def basic_search_news(self, query: str, max_results: int = 7) -> TavilyResponse:
|
||||
"""
|
||||
【工具】基础新闻搜索: 执行一次标准、快速的新闻搜索。
|
||||
这是最常用的通用搜索工具,适用于不确定需要何种特定搜索时。
|
||||
Agent可提供搜索查询(query)和可选的最大结果数(max_results)。
|
||||
"""
|
||||
print(f"--- TOOL: 基础新闻搜索 (query: {query}) ---")
|
||||
return self._search_internal(
|
||||
query=query,
|
||||
max_results=max_results,
|
||||
search_depth="basic",
|
||||
include_answer=False
|
||||
)
|
||||
|
||||
def deep_search_news(self, query: str) -> TavilyResponse:
|
||||
"""
|
||||
【工具】深度新闻分析: 对一个主题进行最全面、最深入的搜索。
|
||||
返回AI生成的“高级”详细摘要答案和最多20条最相关的新闻结果。适用于需要全面了解某个事件背景的场景。
|
||||
Agent只需提供搜索查询(query)。
|
||||
"""
|
||||
print(f"--- TOOL: 深度新闻分析 (query: {query}) ---")
|
||||
return self._search_internal(
|
||||
query=query, search_depth="advanced", max_results=20, include_answer="advanced"
|
||||
)
|
||||
|
||||
def search_news_last_24_hours(self, query: str) -> TavilyResponse:
|
||||
"""
|
||||
【工具】搜索24小时内新闻: 获取关于某个主题的最新动态。
|
||||
此工具专门查找过去24小时内发布的新闻。适用于追踪突发事件或最新进展。
|
||||
Agent只需提供搜索查询(query)。
|
||||
"""
|
||||
print(f"--- TOOL: 搜索24小时内新闻 (query: {query}) ---")
|
||||
return self._search_internal(query=query, time_range='d', max_results=10)
|
||||
|
||||
def search_news_last_week(self, query: str) -> TavilyResponse:
|
||||
"""
|
||||
【工具】搜索本周新闻: 获取关于某个主题过去一周内的主要新闻报道。
|
||||
适用于进行周度舆情总结或回顾。
|
||||
Agent只需提供搜索查询(query)。
|
||||
"""
|
||||
print(f"--- TOOL: 搜索本周新闻 (query: {query}) ---")
|
||||
return self._search_internal(query=query, time_range='w', max_results=10)
|
||||
|
||||
def search_images_for_news(self, query: str) -> TavilyResponse:
|
||||
"""
|
||||
【工具】查找新闻图片: 搜索与某个新闻主题相关的图片。
|
||||
此工具会返回图片链接及描述,适用于需要为报告或文章配图的场景。
|
||||
Agent只需提供搜索查询(query)。
|
||||
"""
|
||||
print(f"--- TOOL: 查找新闻图片 (query: {query}) ---")
|
||||
return self._search_internal(
|
||||
query=query, include_images=True, include_image_descriptions=True, max_results=5
|
||||
)
|
||||
|
||||
def search_news_by_date(self, query: str, start_date: str, end_date: str) -> TavilyResponse:
|
||||
"""
|
||||
【工具】按指定日期范围搜索新闻: 在一个明确的历史时间段内搜索新闻。
|
||||
这是唯一需要Agent提供详细时间参数的工具。适用于需要对特定历史事件进行分析的场景。
|
||||
Agent需要提供查询(query)、开始日期(start_date)和结束日期(end_date),格式均为 'YYYY-MM-DD'。
|
||||
"""
|
||||
print(f"--- TOOL: 按指定日期范围搜索新闻 (query: {query}, from: {start_date}, to: {end_date}) ---")
|
||||
return self._search_internal(
|
||||
query=query, start_date=start_date, end_date=end_date, max_results=15
|
||||
)
|
||||
|
||||
|
||||
# --- 3. 测试与使用示例 ---
|
||||
|
||||
def print_response_summary(response: TavilyResponse):
|
||||
"""简化的打印函数,用于展示测试结果,现在会显示发布日期"""
|
||||
if not response or not response.query:
|
||||
print("未能获取有效响应。")
|
||||
return
|
||||
|
||||
print(f"\n查询: '{response.query}' | 耗时: {response.response_time}s")
|
||||
if response.answer:
|
||||
print(f"AI摘要: {response.answer[:120]}...")
|
||||
print(f"找到 {len(response.results)} 条网页, {len(response.images)} 张图片。")
|
||||
if response.results:
|
||||
first_result = response.results[0]
|
||||
date_info = f"(发布于: {first_result.published_date})" if first_result.published_date else ""
|
||||
print(f"第一条结果: {first_result.title} {date_info}")
|
||||
print("-" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 在运行前,请确保您已设置 TAVILY_API_KEY 环境变量
|
||||
|
||||
try:
|
||||
# 初始化“新闻社”客户端,它内部包含了所有工具
|
||||
agency = TavilyNewsAgency()
|
||||
|
||||
# 场景1: Agent 进行一次常规、快速的搜索
|
||||
response1 = agency.basic_search_news(query="奥运会最新赛况", max_results=5)
|
||||
print_response_summary(response1)
|
||||
|
||||
# 场景2: Agent 需要全面了解“全球芯片技术竞争”的背景
|
||||
response2 = agency.deep_search_news(query="全球芯片技术竞争")
|
||||
print_response_summary(response2)
|
||||
|
||||
# 场景3: Agent 需要追踪“GTC大会”的最新消息
|
||||
response3 = agency.search_news_last_24_hours(query="Nvidia GTC大会 最新发布")
|
||||
print_response_summary(response3)
|
||||
|
||||
# 场景4: Agent 需要为一篇关于“自动驾驶”的周报查找素材
|
||||
response4 = agency.search_news_last_week(query="自动驾驶商业化落地")
|
||||
print_response_summary(response4)
|
||||
|
||||
# 场景5: Agent 需要查找“韦伯太空望远镜”的新闻图片
|
||||
response5 = agency.search_images_for_news(query="韦伯太空望远镜最新发现")
|
||||
print_response_summary(response5)
|
||||
|
||||
# 场景6: Agent 需要研究2025年第一季度关于“人工智能法规”的新闻
|
||||
response6 = agency.search_news_by_date(
|
||||
query="人工智能法规",
|
||||
start_date="2025-01-01",
|
||||
end_date="2025-03-31"
|
||||
)
|
||||
print_response_summary(response6)
|
||||
|
||||
except ValueError as e:
|
||||
print(f"初始化失败: {e}")
|
||||
print("请确保 TAVILY_API_KEY 环境变量已正确设置。")
|
||||
except Exception as e:
|
||||
print(f"测试过程中发生未知错误: {e}")
|
||||
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
工具函数模块
|
||||
提供文本处理、JSON解析等辅助功能
|
||||
"""
|
||||
|
||||
from .text_processing import (
|
||||
clean_json_tags,
|
||||
clean_markdown_tags,
|
||||
remove_reasoning_from_output,
|
||||
extract_clean_response,
|
||||
update_state_with_search_results,
|
||||
format_search_results_for_prompt
|
||||
)
|
||||
|
||||
from .config import Config, load_config
|
||||
|
||||
__all__ = [
|
||||
"clean_json_tags",
|
||||
"clean_markdown_tags",
|
||||
"remove_reasoning_from_output",
|
||||
"extract_clean_response",
|
||||
"update_state_with_search_results",
|
||||
"format_search_results_for_prompt",
|
||||
"Config",
|
||||
"load_config"
|
||||
]
|
||||
@@ -0,0 +1,162 @@
|
||||
"""
|
||||
配置管理模块
|
||||
处理环境变量和配置参数
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
"""配置类"""
|
||||
# API密钥
|
||||
deepseek_api_key: Optional[str] = None
|
||||
openai_api_key: Optional[str] = None
|
||||
tavily_api_key: Optional[str] = None
|
||||
|
||||
# 模型配置
|
||||
default_llm_provider: str = "deepseek" # deepseek 或 openai
|
||||
deepseek_model: str = "deepseek-chat"
|
||||
openai_model: str = "gpt-4o-mini"
|
||||
|
||||
# 搜索配置
|
||||
search_timeout: int = 240
|
||||
max_content_length: int = 20000
|
||||
|
||||
# Agent配置
|
||||
max_reflections: int = 2
|
||||
max_paragraphs: int = 5
|
||||
|
||||
# 输出配置
|
||||
output_dir: str = "reports"
|
||||
save_intermediate_states: bool = True
|
||||
|
||||
def validate(self) -> bool:
|
||||
"""验证配置"""
|
||||
# 检查必需的API密钥
|
||||
if self.default_llm_provider == "deepseek" and not self.deepseek_api_key:
|
||||
print("错误: DeepSeek API Key未设置")
|
||||
return False
|
||||
|
||||
if self.default_llm_provider == "openai" and not self.openai_api_key:
|
||||
print("错误: OpenAI API Key未设置")
|
||||
return False
|
||||
|
||||
if not self.tavily_api_key:
|
||||
print("错误: Tavily API Key未设置")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, config_file: str) -> "Config":
|
||||
"""从配置文件创建配置"""
|
||||
if config_file.endswith('.py'):
|
||||
# Python配置文件
|
||||
import importlib.util
|
||||
|
||||
# 动态导入配置文件
|
||||
spec = importlib.util.spec_from_file_location("config", config_file)
|
||||
config_module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(config_module)
|
||||
|
||||
return cls(
|
||||
deepseek_api_key=getattr(config_module, "DEEPSEEK_API_KEY", None),
|
||||
openai_api_key=getattr(config_module, "OPENAI_API_KEY", None),
|
||||
tavily_api_key=getattr(config_module, "TAVILY_API_KEY", None),
|
||||
default_llm_provider=getattr(config_module, "DEFAULT_LLM_PROVIDER", "deepseek"),
|
||||
deepseek_model=getattr(config_module, "DEEPSEEK_MODEL", "deepseek-chat"),
|
||||
openai_model=getattr(config_module, "OPENAI_MODEL", "gpt-4o-mini"),
|
||||
|
||||
search_timeout=getattr(config_module, "SEARCH_TIMEOUT", 240),
|
||||
max_content_length=getattr(config_module, "SEARCH_CONTENT_MAX_LENGTH", 20000),
|
||||
max_reflections=getattr(config_module, "MAX_REFLECTIONS", 2),
|
||||
max_paragraphs=getattr(config_module, "MAX_PARAGRAPHS", 5),
|
||||
output_dir=getattr(config_module, "OUTPUT_DIR", "reports"),
|
||||
save_intermediate_states=getattr(config_module, "SAVE_INTERMEDIATE_STATES", True)
|
||||
)
|
||||
else:
|
||||
# .env格式配置文件
|
||||
config_dict = {}
|
||||
|
||||
if os.path.exists(config_file):
|
||||
with open(config_file, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith('#') and '=' in line:
|
||||
key, value = line.split('=', 1)
|
||||
config_dict[key.strip()] = value.strip()
|
||||
|
||||
return cls(
|
||||
deepseek_api_key=config_dict.get("DEEPSEEK_API_KEY"),
|
||||
openai_api_key=config_dict.get("OPENAI_API_KEY"),
|
||||
tavily_api_key=config_dict.get("TAVILY_API_KEY"),
|
||||
default_llm_provider=config_dict.get("DEFAULT_LLM_PROVIDER", "deepseek"),
|
||||
deepseek_model=config_dict.get("DEEPSEEK_MODEL", "deepseek-chat"),
|
||||
openai_model=config_dict.get("OPENAI_MODEL", "gpt-4o-mini"),
|
||||
|
||||
search_timeout=int(config_dict.get("SEARCH_TIMEOUT", "240")),
|
||||
max_content_length=int(config_dict.get("SEARCH_CONTENT_MAX_LENGTH", "20000")),
|
||||
max_reflections=int(config_dict.get("MAX_REFLECTIONS", "2")),
|
||||
max_paragraphs=int(config_dict.get("MAX_PARAGRAPHS", "5")),
|
||||
output_dir=config_dict.get("OUTPUT_DIR", "reports"),
|
||||
save_intermediate_states=config_dict.get("SAVE_INTERMEDIATE_STATES", "true").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
def load_config(config_file: Optional[str] = None) -> Config:
|
||||
"""
|
||||
加载配置
|
||||
|
||||
Args:
|
||||
config_file: 配置文件路径,如果不指定则使用默认路径
|
||||
|
||||
Returns:
|
||||
配置对象
|
||||
"""
|
||||
# 确定配置文件路径
|
||||
if config_file:
|
||||
if not os.path.exists(config_file):
|
||||
raise FileNotFoundError(f"配置文件不存在: {config_file}")
|
||||
file_to_load = config_file
|
||||
else:
|
||||
# 尝试加载常见的配置文件
|
||||
for config_path in ["config.py", "config.env", ".env"]:
|
||||
if os.path.exists(config_path):
|
||||
file_to_load = config_path
|
||||
print(f"已找到配置文件: {config_path}")
|
||||
break
|
||||
else:
|
||||
raise FileNotFoundError("未找到配置文件,请创建 config.py 文件")
|
||||
|
||||
# 创建配置对象
|
||||
config = Config.from_file(file_to_load)
|
||||
|
||||
# 验证配置
|
||||
if not config.validate():
|
||||
raise ValueError("配置验证失败,请检查配置文件中的API密钥")
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def print_config(config: Config):
|
||||
"""打印配置信息(隐藏敏感信息)"""
|
||||
print("\n=== 当前配置 ===")
|
||||
print(f"LLM提供商: {config.default_llm_provider}")
|
||||
print(f"DeepSeek模型: {config.deepseek_model}")
|
||||
print(f"OpenAI模型: {config.openai_model}")
|
||||
print(f"最大搜索结果数: {config.max_search_results}")
|
||||
print(f"搜索超时: {config.search_timeout}秒")
|
||||
print(f"最大内容长度: {config.max_content_length}")
|
||||
print(f"最大反思次数: {config.max_reflections}")
|
||||
print(f"最大段落数: {config.max_paragraphs}")
|
||||
print(f"输出目录: {config.output_dir}")
|
||||
print(f"保存中间状态: {config.save_intermediate_states}")
|
||||
|
||||
# 显示API密钥状态(不显示实际密钥)
|
||||
print(f"DeepSeek API Key: {'已设置' if config.deepseek_api_key else '未设置'}")
|
||||
print(f"OpenAI API Key: {'已设置' if config.openai_api_key else '未设置'}")
|
||||
print(f"Tavily API Key: {'已设置' if config.tavily_api_key else '未设置'}")
|
||||
print("==================\n")
|
||||
@@ -0,0 +1,308 @@
|
||||
"""
|
||||
文本处理工具函数
|
||||
用于清理LLM输出、解析JSON等
|
||||
"""
|
||||
|
||||
import re
|
||||
import json
|
||||
from typing import Dict, Any, List
|
||||
from json.decoder import JSONDecodeError
|
||||
|
||||
|
||||
def clean_json_tags(text: str) -> str:
|
||||
"""
|
||||
清理文本中的JSON标签
|
||||
|
||||
Args:
|
||||
text: 原始文本
|
||||
|
||||
Returns:
|
||||
清理后的文本
|
||||
"""
|
||||
# 移除```json 和 ```标签
|
||||
text = re.sub(r'```json\s*', '', text)
|
||||
text = re.sub(r'```\s*$', '', text)
|
||||
text = re.sub(r'```', '', text)
|
||||
|
||||
return text.strip()
|
||||
|
||||
|
||||
def clean_markdown_tags(text: str) -> str:
|
||||
"""
|
||||
清理文本中的Markdown标签
|
||||
|
||||
Args:
|
||||
text: 原始文本
|
||||
|
||||
Returns:
|
||||
清理后的文本
|
||||
"""
|
||||
# 移除```markdown 和 ```标签
|
||||
text = re.sub(r'```markdown\s*', '', text)
|
||||
text = re.sub(r'```\s*$', '', text)
|
||||
text = re.sub(r'```', '', text)
|
||||
|
||||
return text.strip()
|
||||
|
||||
|
||||
def remove_reasoning_from_output(text: str) -> str:
|
||||
"""
|
||||
移除输出中的推理过程文本
|
||||
|
||||
Args:
|
||||
text: 原始文本
|
||||
|
||||
Returns:
|
||||
清理后的文本
|
||||
"""
|
||||
# 查找JSON开始位置
|
||||
json_start = -1
|
||||
|
||||
# 尝试找到第一个 { 或 [
|
||||
for i, char in enumerate(text):
|
||||
if char in '{[':
|
||||
json_start = i
|
||||
break
|
||||
|
||||
if json_start != -1:
|
||||
# 从JSON开始位置截取
|
||||
return text[json_start:].strip()
|
||||
|
||||
# 如果没有找到JSON标记,尝试其他方法
|
||||
# 移除常见的推理标识
|
||||
patterns = [
|
||||
r'(?:reasoning|推理|思考|分析)[::]\s*.*?(?=\{|\[)', # 移除推理部分
|
||||
r'(?:explanation|解释|说明)[::]\s*.*?(?=\{|\[)', # 移除解释部分
|
||||
r'^.*?(?=\{|\[)', # 移除JSON前的所有文本
|
||||
]
|
||||
|
||||
for pattern in patterns:
|
||||
text = re.sub(pattern, '', text, flags=re.IGNORECASE | re.DOTALL)
|
||||
|
||||
return text.strip()
|
||||
|
||||
|
||||
def extract_clean_response(text: str) -> Dict[str, Any]:
|
||||
"""
|
||||
提取并清理响应中的JSON内容
|
||||
|
||||
Args:
|
||||
text: 原始响应文本
|
||||
|
||||
Returns:
|
||||
解析后的JSON字典
|
||||
"""
|
||||
# 清理文本
|
||||
cleaned_text = clean_json_tags(text)
|
||||
cleaned_text = remove_reasoning_from_output(cleaned_text)
|
||||
|
||||
# 尝试直接解析
|
||||
try:
|
||||
return json.loads(cleaned_text)
|
||||
except JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 尝试修复不完整的JSON
|
||||
fixed_text = fix_incomplete_json(cleaned_text)
|
||||
if fixed_text:
|
||||
try:
|
||||
return json.loads(fixed_text)
|
||||
except JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 尝试查找JSON对象
|
||||
json_pattern = r'\{.*\}'
|
||||
match = re.search(json_pattern, cleaned_text, re.DOTALL)
|
||||
if match:
|
||||
try:
|
||||
return json.loads(match.group())
|
||||
except JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 尝试查找JSON数组
|
||||
array_pattern = r'\[.*\]'
|
||||
match = re.search(array_pattern, cleaned_text, re.DOTALL)
|
||||
if match:
|
||||
try:
|
||||
return json.loads(match.group())
|
||||
except JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 如果所有方法都失败,返回错误信息
|
||||
print(f"无法解析JSON响应: {cleaned_text[:200]}...")
|
||||
return {"error": "JSON解析失败", "raw_text": cleaned_text}
|
||||
|
||||
|
||||
def fix_incomplete_json(text: str) -> str:
|
||||
"""
|
||||
修复不完整的JSON响应
|
||||
|
||||
Args:
|
||||
text: 原始文本
|
||||
|
||||
Returns:
|
||||
修复后的JSON文本,如果无法修复则返回空字符串
|
||||
"""
|
||||
# 移除多余的逗号和空白
|
||||
text = re.sub(r',\s*}', '}', text)
|
||||
text = re.sub(r',\s*]', ']', text)
|
||||
|
||||
# 检查是否已经是有效的JSON
|
||||
try:
|
||||
json.loads(text)
|
||||
return text
|
||||
except JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 检查是否缺少开头的数组符号
|
||||
if text.strip().startswith('{') and not text.strip().startswith('['):
|
||||
# 如果以对象开始,尝试包装成数组
|
||||
if text.count('{') > 1:
|
||||
# 多个对象,包装成数组
|
||||
text = '[' + text + ']'
|
||||
else:
|
||||
# 单个对象,包装成数组
|
||||
text = '[' + text + ']'
|
||||
|
||||
# 检查是否缺少结尾的数组符号
|
||||
if text.strip().endswith('}') and not text.strip().endswith(']'):
|
||||
# 如果以对象结束,尝试包装成数组
|
||||
if text.count('}') > 1:
|
||||
# 多个对象,包装成数组
|
||||
text = '[' + text + ']'
|
||||
else:
|
||||
# 单个对象,包装成数组
|
||||
text = '[' + text + ']'
|
||||
|
||||
# 检查括号是否匹配
|
||||
open_braces = text.count('{')
|
||||
close_braces = text.count('}')
|
||||
open_brackets = text.count('[')
|
||||
close_brackets = text.count(']')
|
||||
|
||||
# 修复不匹配的括号
|
||||
if open_braces > close_braces:
|
||||
text += '}' * (open_braces - close_braces)
|
||||
if open_brackets > close_brackets:
|
||||
text += ']' * (open_brackets - close_brackets)
|
||||
|
||||
# 验证修复后的JSON是否有效
|
||||
try:
|
||||
json.loads(text)
|
||||
return text
|
||||
except JSONDecodeError:
|
||||
# 如果仍然无效,尝试更激进的修复
|
||||
return fix_aggressive_json(text)
|
||||
|
||||
|
||||
def fix_aggressive_json(text: str) -> str:
|
||||
"""
|
||||
更激进的JSON修复方法
|
||||
|
||||
Args:
|
||||
text: 原始文本
|
||||
|
||||
Returns:
|
||||
修复后的JSON文本
|
||||
"""
|
||||
# 查找所有可能的JSON对象
|
||||
objects = re.findall(r'\{[^{}]*\}', text)
|
||||
|
||||
if len(objects) >= 2:
|
||||
# 如果有多个对象,包装成数组
|
||||
return '[' + ','.join(objects) + ']'
|
||||
elif len(objects) == 1:
|
||||
# 如果只有一个对象,包装成数组
|
||||
return '[' + objects[0] + ']'
|
||||
else:
|
||||
# 如果没有找到对象,返回空数组
|
||||
return '[]'
|
||||
|
||||
|
||||
def update_state_with_search_results(search_results: List[Dict[str, Any]],
|
||||
paragraph_index: int, state: Any) -> Any:
|
||||
"""
|
||||
将搜索结果更新到状态中
|
||||
|
||||
Args:
|
||||
search_results: 搜索结果列表
|
||||
paragraph_index: 段落索引
|
||||
state: 状态对象
|
||||
|
||||
Returns:
|
||||
更新后的状态对象
|
||||
"""
|
||||
if 0 <= paragraph_index < len(state.paragraphs):
|
||||
# 获取最后一次搜索的查询(假设是当前查询)
|
||||
current_query = ""
|
||||
if search_results:
|
||||
# 从搜索结果推断查询(这里需要改进以获取实际查询)
|
||||
current_query = "搜索查询"
|
||||
|
||||
# 添加搜索结果到状态
|
||||
state.paragraphs[paragraph_index].research.add_search_results(
|
||||
current_query, search_results
|
||||
)
|
||||
|
||||
return state
|
||||
|
||||
|
||||
def validate_json_schema(data: Dict[str, Any], required_fields: List[str]) -> bool:
|
||||
"""
|
||||
验证JSON数据是否包含必需字段
|
||||
|
||||
Args:
|
||||
data: 要验证的数据
|
||||
required_fields: 必需字段列表
|
||||
|
||||
Returns:
|
||||
验证是否通过
|
||||
"""
|
||||
return all(field in data for field in required_fields)
|
||||
|
||||
|
||||
def truncate_content(content: str, max_length: int = 20000) -> str:
|
||||
"""
|
||||
截断内容到指定长度
|
||||
|
||||
Args:
|
||||
content: 原始内容
|
||||
max_length: 最大长度
|
||||
|
||||
Returns:
|
||||
截断后的内容
|
||||
"""
|
||||
if len(content) <= max_length:
|
||||
return content
|
||||
|
||||
# 尝试在单词边界截断
|
||||
truncated = content[:max_length]
|
||||
last_space = truncated.rfind(' ')
|
||||
|
||||
if last_space > max_length * 0.8: # 如果最后一个空格位置合理
|
||||
return truncated[:last_space] + "..."
|
||||
else:
|
||||
return truncated + "..."
|
||||
|
||||
|
||||
def format_search_results_for_prompt(search_results: List[Dict[str, Any]],
|
||||
max_length: int = 20000) -> List[str]:
|
||||
"""
|
||||
格式化搜索结果用于提示词
|
||||
|
||||
Args:
|
||||
search_results: 搜索结果列表
|
||||
max_length: 每个结果的最大长度
|
||||
|
||||
Returns:
|
||||
格式化后的内容列表
|
||||
"""
|
||||
formatted_results = []
|
||||
|
||||
for result in search_results:
|
||||
content = result.get('content', '')
|
||||
if content:
|
||||
truncated_content = truncate_content(content, max_length)
|
||||
formatted_results.append(truncated_content)
|
||||
|
||||
return formatted_results
|
||||
Reference in New Issue
Block a user