The multimodal agent has been basically completed.

This commit is contained in:
戒酒的李白
2025-08-22 23:39:11 +08:00
parent 062f66cb2e
commit 431cf6a12c
5 changed files with 99 additions and 144 deletions
+51 -85
View File
@@ -19,7 +19,7 @@ from .nodes import (
ReportFormattingNode ReportFormattingNode
) )
from .state import State from .state import State
from .tools import TavilyNewsAgency, TavilyResponse from .tools import BochaMultimodalSearch, BochaResponse
from .utils import Config, load_config, format_search_results_for_prompt from .utils import Config, load_config, format_search_results_for_prompt
@@ -40,7 +40,7 @@ class DeepSearchAgent:
self.llm_client = self._initialize_llm() self.llm_client = self._initialize_llm()
# 初始化搜索工具集 # 初始化搜索工具集
self.search_agency = TavilyNewsAgency(api_key=self.config.tavily_api_key) self.search_agency = BochaMultimodalSearch(api_key=self.config.bocha_api_key)
# 初始化节点 # 初始化节点
self._initialize_nodes() self._initialize_nodes()
@@ -53,7 +53,7 @@ class DeepSearchAgent:
print(f"Deep Search Agent 已初始化") print(f"Deep Search Agent 已初始化")
print(f"使用LLM: {self.llm_client.get_model_info()}") print(f"使用LLM: {self.llm_client.get_model_info()}")
print(f"搜索工具集: TavilyNewsAgency (支持6种搜索工具)") print(f"搜索工具集: BochaMultimodalSearch (支持5种多模态搜索工具)")
def _initialize_llm(self) -> BaseLLM: def _initialize_llm(self) -> BaseLLM:
"""初始化LLM客户端""" """初始化LLM客户端"""
@@ -103,46 +103,40 @@ class DeepSearchAgent:
except ValueError: except ValueError:
return False return False
def execute_search_tool(self, tool_name: str, query: str, **kwargs) -> TavilyResponse: def execute_search_tool(self, tool_name: str, query: str, **kwargs) -> BochaResponse:
""" """
执行指定的搜索工具 执行指定的搜索工具
Args: Args:
tool_name: 工具名称,可选值: tool_name: 工具名称,可选值:
- "basic_search_news": 基础新闻搜索(快速、通用 - "comprehensive_search": 全面综合搜索(默认
- "deep_search_news": 深度新闻分析 - "web_search_only": 纯网页搜索
- "search_news_last_24_hours": 24小时内最新新闻 - "search_for_structured_data": 结构化数据查询
- "search_news_last_week": 本周新闻 - "search_last_24_hours": 24小时内最新信息
- "search_images_for_news": 新闻图片搜索 - "search_last_week": 本周信息
- "search_news_by_date": 按日期范围搜索新闻
query: 搜索查询 query: 搜索查询
**kwargs: 额外参数(如start_date, end_date, max_results **kwargs: 额外参数(如max_results
Returns: Returns:
TavilyResponse对象 BochaResponse对象
""" """
print(f" → 执行搜索工具: {tool_name}") print(f" → 执行搜索工具: {tool_name}")
if tool_name == "basic_search_news": if tool_name == "comprehensive_search":
max_results = kwargs.get("max_results", 7) max_results = kwargs.get("max_results", 10)
return self.search_agency.basic_search_news(query, max_results) return self.search_agency.comprehensive_search(query, max_results)
elif tool_name == "deep_search_news": elif tool_name == "web_search_only":
return self.search_agency.deep_search_news(query) max_results = kwargs.get("max_results", 15)
elif tool_name == "search_news_last_24_hours": return self.search_agency.web_search_only(query, max_results)
return self.search_agency.search_news_last_24_hours(query) elif tool_name == "search_for_structured_data":
elif tool_name == "search_news_last_week": return self.search_agency.search_for_structured_data(query)
return self.search_agency.search_news_last_week(query) elif tool_name == "search_last_24_hours":
elif tool_name == "search_images_for_news": return self.search_agency.search_last_24_hours(query)
return self.search_agency.search_images_for_news(query) elif tool_name == "search_last_week":
elif tool_name == "search_news_by_date": return self.search_agency.search_last_week(query)
start_date = kwargs.get("start_date")
end_date = kwargs.get("end_date")
if not start_date or not end_date:
raise ValueError("search_news_by_date工具需要start_date和end_date参数")
return self.search_agency.search_news_by_date(query, start_date, end_date)
else: else:
print(f" ⚠️ 未知的搜索工具: {tool_name},使用默认基础搜索") print(f" ⚠️ 未知的搜索工具: {tool_name},使用默认综合搜索")
return self.search_agency.basic_search_news(query) return self.search_agency.comprehensive_search(query)
def research(self, query: str, save_report: bool = True) -> str: def research(self, query: str, save_report: bool = True) -> str:
""" """
@@ -231,7 +225,7 @@ class DeepSearchAgent:
print(" - 生成搜索查询...") print(" - 生成搜索查询...")
search_output = self.first_search_node.run(search_input) search_output = self.first_search_node.run(search_input)
search_query = search_output["search_query"] search_query = search_output["search_query"]
search_tool = search_output.get("search_tool", "basic_search_news") # 默认工具 search_tool = search_output.get("search_tool", "comprehensive_search") # 默认工具
reasoning = search_output["reasoning"] reasoning = search_output["reasoning"]
print(f" - 搜索查询: {search_query}") print(f" - 搜索查询: {search_query}")
@@ -241,41 +235,27 @@ class DeepSearchAgent:
# 执行搜索 # 执行搜索
print(" - 执行网络搜索...") print(" - 执行网络搜索...")
# 处理search_news_by_date的特殊参数 # 处理特殊参数(新的工具集不需要日期参数处理)
search_kwargs = {} search_kwargs = {}
if search_tool == "search_news_by_date": if search_tool in ["comprehensive_search", "web_search_only"]:
start_date = search_output.get("start_date") # 这些工具支持max_results参数
end_date = search_output.get("end_date") search_kwargs["max_results"] = 10
if start_date and end_date:
# 验证日期格式
if self._validate_date_format(start_date) and self._validate_date_format(end_date):
search_kwargs["start_date"] = start_date
search_kwargs["end_date"] = end_date
print(f" - 时间范围: {start_date}{end_date}")
else:
print(f" ⚠️ 日期格式错误(应为YYYY-MM-DD),改用基础搜索")
print(f" 提供的日期: start_date={start_date}, end_date={end_date}")
search_tool = "basic_search_news"
else:
print(f" ⚠️ search_news_by_date工具缺少时间参数,改用基础搜索")
search_tool = "basic_search_news"
search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs) search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs)
# 转换为兼容格式 # 转换为兼容格式
search_results = [] search_results = []
if search_response and search_response.results: if search_response and search_response.webpages:
# 每种搜索工具都有其特定的结果数量,这里取前10个作为上限 # 每种搜索工具都有其特定的结果数量,这里取前10个作为上限
max_results = min(len(search_response.results), 10) max_results = min(len(search_response.webpages), 10)
for result in search_response.results[:max_results]: for result in search_response.webpages[:max_results]:
search_results.append({ search_results.append({
'title': result.title, 'title': result.name,
'url': result.url, 'url': result.url,
'content': result.content, 'content': result.snippet,
'score': result.score, 'score': None, # Bocha API不提供score
'raw_content': result.raw_content, 'raw_content': result.snippet,
'published_date': result.published_date # 新增字段 'published_date': result.date_last_crawled # 使用爬取日期
}) })
if search_results: if search_results:
@@ -324,7 +304,7 @@ class DeepSearchAgent:
# 生成反思搜索查询 # 生成反思搜索查询
reflection_output = self.reflection_node.run(reflection_input) reflection_output = self.reflection_node.run(reflection_input)
search_query = reflection_output["search_query"] search_query = reflection_output["search_query"]
search_tool = reflection_output.get("search_tool", "basic_search_news") # 默认工具 search_tool = reflection_output.get("search_tool", "comprehensive_search") # 默认工具
reasoning = reflection_output["reasoning"] reasoning = reflection_output["reasoning"]
print(f" 反思查询: {search_query}") print(f" 反思查询: {search_query}")
@@ -332,41 +312,27 @@ class DeepSearchAgent:
print(f" 反思推理: {reasoning}") print(f" 反思推理: {reasoning}")
# 执行反思搜索 # 执行反思搜索
# 处理search_news_by_date的特殊参数 # 处理特殊参数
search_kwargs = {} search_kwargs = {}
if search_tool == "search_news_by_date": if search_tool in ["comprehensive_search", "web_search_only"]:
start_date = reflection_output.get("start_date") # 这些工具支持max_results参数
end_date = reflection_output.get("end_date") search_kwargs["max_results"] = 10
if start_date and end_date:
# 验证日期格式
if self._validate_date_format(start_date) and self._validate_date_format(end_date):
search_kwargs["start_date"] = start_date
search_kwargs["end_date"] = end_date
print(f" 时间范围: {start_date}{end_date}")
else:
print(f" ⚠️ 日期格式错误(应为YYYY-MM-DD),改用基础搜索")
print(f" 提供的日期: start_date={start_date}, end_date={end_date}")
search_tool = "basic_search_news"
else:
print(f" ⚠️ search_news_by_date工具缺少时间参数,改用基础搜索")
search_tool = "basic_search_news"
search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs) search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs)
# 转换为兼容格式 # 转换为兼容格式
search_results = [] search_results = []
if search_response and search_response.results: if search_response and search_response.webpages:
# 每种搜索工具都有其特定的结果数量,这里取前10个作为上限 # 每种搜索工具都有其特定的结果数量,这里取前10个作为上限
max_results = min(len(search_response.results), 10) max_results = min(len(search_response.webpages), 10)
for result in search_response.results[:max_results]: for result in search_response.webpages[:max_results]:
search_results.append({ search_results.append({
'title': result.title, 'title': result.name,
'url': result.url, 'url': result.url,
'content': result.content, 'content': result.snippet,
'score': result.score, 'score': None, # Bocha API不提供score
'raw_content': result.raw_content, 'raw_content': result.snippet,
'published_date': result.published_date 'published_date': result.date_last_crawled
}) })
if search_results: if search_results:
+27 -40
View File
@@ -34,9 +34,7 @@ output_schema_first_search = {
"properties": { "properties": {
"search_query": {"type": "string"}, "search_query": {"type": "string"},
"search_tool": {"type": "string"}, "search_tool": {"type": "string"},
"reasoning": {"type": "string"}, "reasoning": {"type": "string"}
"start_date": {"type": "string", "description": "开始日期,格式YYYY-MM-DD,仅search_news_by_date工具需要"},
"end_date": {"type": "string", "description": "结束日期,格式YYYY-MM-DD,仅search_news_by_date工具需要"}
}, },
"required": ["search_query", "search_tool", "reasoning"] "required": ["search_query", "search_tool", "reasoning"]
} }
@@ -79,9 +77,7 @@ output_schema_reflection = {
"properties": { "properties": {
"search_query": {"type": "string"}, "search_query": {"type": "string"},
"search_tool": {"type": "string"}, "search_tool": {"type": "string"},
"reasoning": {"type": "string"}, "reasoning": {"type": "string"}
"start_date": {"type": "string", "description": "开始日期,格式YYYY-MM-DD,仅search_news_by_date工具需要"},
"end_date": {"type": "string", "description": "结束日期,格式YYYY-MM-DD,仅search_news_by_date工具需要"}
}, },
"required": ["search_query", "search_tool", "reasoning"] "required": ["search_query", "search_tool", "reasoning"]
} }
@@ -147,41 +143,34 @@ SYSTEM_PROMPT_FIRST_SEARCH = f"""
{json.dumps(input_schema_first_search, indent=2, ensure_ascii=False)} {json.dumps(input_schema_first_search, indent=2, ensure_ascii=False)}
</INPUT JSON SCHEMA> </INPUT JSON SCHEMA>
你可以使用以下6种专业的新闻搜索工具: 你可以使用以下5种专业的多模态搜索工具:
1. **basic_search_news** - 基础新闻搜索工具 1. **comprehensive_search** - 全面综合搜索工具
- 适用于:一般性的新闻搜索,不确定需要何种特定搜索 - 适用于:一般性的研究需求,需要完整信息
- 特点:快速、标准的通用搜索,是最常用的基础工具 - 特点:返回网页、图片、AI总结、追问建议和可能的结构化数据,是最常用的基础工具
2. **deep_search_news** - 深度新闻分析工具 2. **web_search_only** - 纯网页搜索工具
- 适用于:需要全面深入了解某个主题 - 适用于:需要网页链接和摘要,不需要AI分析
- 特点:提供最详细的分析结果,包含高级AI摘要 - 特点:速度更快,成本更低,只返回网页结果
3. **search_news_last_24_hours** - 24小时最新新闻工具 3. **search_for_structured_data** - 结构化数据查询工具
- 适用于:查询天气、股票、汇率、百科定义等结构化信息时
- 特点:专门用于触发"模态卡"的查询,返回结构化数据
4. **search_last_24_hours** - 24小时内信息搜索工具
- 适用于:需要了解最新动态、突发事件时 - 适用于:需要了解最新动态、突发事件时
- 特点:只搜索过去24小时的新闻 - 特点:只搜索过去24小时内发布的内容
4. **search_news_last_week** - 本周新闻工具 5. **search_last_week** - 本周信息搜索工具
- 适用于:需要了解近期发展趋势时 - 适用于:需要了解近期发展趋势时
- 特点:搜索过去一周的新闻报道 - 特点:搜索过去一周内的主要报道
5. **search_images_for_news** - 图片搜索工具
- 适用于:需要可视化信息、图片资料时
- 特点:提供相关图片和图片描述
6. **search_news_by_date** - 按日期范围搜索工具
- 适用于:需要研究特定历史时期时
- 特点:可以指定开始和结束日期进行搜索
- 特殊要求:需要提供start_date和end_date参数,格式为'YYYY-MM-DD'
- 注意:只有这个工具需要额外的时间参数
你的任务是: 你的任务是:
1. 根据段落主题选择最合适的搜索工具 1. 根据段落主题选择最合适的搜索工具
2. 制定最佳的搜索查询 2. 制定最佳的搜索查询
3. 如果选择search_news_by_date工具,必须同时提供start_date和end_date参数(格式:YYYY-MM-DD 3. 解释你的选择理由
4. 解释你的选择理由
注意:除了search_news_by_date工具外,其他工具都不需要额外参数 注意:所有工具都不需要额外参数,选择工具主要基于搜索意图和需要的信息类型
请按照以下JSON模式定义格式化输出(文字请使用中文): 请按照以下JSON模式定义格式化输出(文字请使用中文):
<OUTPUT JSON SCHEMA> <OUTPUT JSON SCHEMA>
@@ -219,23 +208,21 @@ SYSTEM_PROMPT_REFLECTION = f"""
{json.dumps(input_schema_reflection, indent=2, ensure_ascii=False)} {json.dumps(input_schema_reflection, indent=2, ensure_ascii=False)}
</INPUT JSON SCHEMA> </INPUT JSON SCHEMA>
你可以使用以下6种专业的新闻搜索工具: 你可以使用以下5种专业的多模态搜索工具:
1. **basic_search_news** - 基础新闻搜索工具 1. **comprehensive_search** - 全面综合搜索工具
2. **deep_search_news** - 深度新闻分析工具 2. **web_search_only** - 纯网页搜索工具
3. **search_news_last_24_hours** - 24小时最新新闻工具 3. **search_for_structured_data** - 结构化数据查询工具
4. **search_news_last_week** - 本周新闻工具 4. **search_last_24_hours** - 24小时内信息搜索工具
5. **search_images_for_news** - 图片搜索工具 5. **search_last_week** - 本周信息搜索工具
6. **search_news_by_date** - 按日期范围搜索工具(需要时间参数)
你的任务是: 你的任务是:
1. 反思段落文本的当前状态,思考是否遗漏了主题的某些关键方面 1. 反思段落文本的当前状态,思考是否遗漏了主题的某些关键方面
2. 选择最合适的搜索工具来补充缺失信息 2. 选择最合适的搜索工具来补充缺失信息
3. 制定精确的搜索查询 3. 制定精确的搜索查询
4. 如果选择search_news_by_date工具,必须同时提供start_date和end_date参数(格式:YYYY-MM-DD 4. 解释你的选择和推理
5. 解释你的选择和推理
注意:除了search_news_by_date工具外,其他工具都不需要额外参数 注意:所有工具都不需要额外参数,选择工具主要基于搜索意图和需要的信息类型
请按照以下JSON模式定义格式化输出: 请按照以下JSON模式定义格式化输出:
<OUTPUT JSON SCHEMA> <OUTPUT JSON SCHEMA>
+9 -7
View File
@@ -1,20 +1,22 @@
""" """
工具调用模块 工具调用模块
提供外部工具接口,如网络搜索等 提供外部工具接口,如多模态搜索等
""" """
from .search import ( from .search import (
TavilyNewsAgency, BochaMultimodalSearch,
SearchResult, WebpageResult,
TavilyResponse,
ImageResult, ImageResult,
ModalCardResult,
BochaResponse,
print_response_summary print_response_summary
) )
__all__ = [ __all__ = [
"TavilyNewsAgency", "BochaMultimodalSearch",
"SearchResult", "WebpageResult",
"TavilyResponse",
"ImageResult", "ImageResult",
"ModalCardResult",
"BochaResponse",
"print_response_summary" "print_response_summary"
] ]
+7 -7
View File
@@ -14,7 +14,7 @@ class Config:
# API密钥 # API密钥
deepseek_api_key: Optional[str] = None deepseek_api_key: Optional[str] = None
openai_api_key: Optional[str] = None openai_api_key: Optional[str] = None
tavily_api_key: Optional[str] = None bocha_api_key: Optional[str] = None
# 模型配置 # 模型配置
default_llm_provider: str = "deepseek" # deepseek 或 openai default_llm_provider: str = "deepseek" # deepseek 或 openai
@@ -44,8 +44,8 @@ class Config:
print("错误: OpenAI API Key未设置") print("错误: OpenAI API Key未设置")
return False return False
if not self.tavily_api_key: if not self.bocha_api_key:
print("错误: Tavily API Key未设置") print("错误: Bocha API Key未设置")
return False return False
return True return True
@@ -65,7 +65,7 @@ class Config:
return cls( return cls(
deepseek_api_key=getattr(config_module, "DEEPSEEK_API_KEY", None), deepseek_api_key=getattr(config_module, "DEEPSEEK_API_KEY", None),
openai_api_key=getattr(config_module, "OPENAI_API_KEY", None), openai_api_key=getattr(config_module, "OPENAI_API_KEY", None),
tavily_api_key=getattr(config_module, "TAVILY_API_KEY", None), bocha_api_key=getattr(config_module, "BOCHA_API_KEY", None),
default_llm_provider=getattr(config_module, "DEFAULT_LLM_PROVIDER", "deepseek"), default_llm_provider=getattr(config_module, "DEFAULT_LLM_PROVIDER", "deepseek"),
deepseek_model=getattr(config_module, "DEEPSEEK_MODEL", "deepseek-chat"), deepseek_model=getattr(config_module, "DEEPSEEK_MODEL", "deepseek-chat"),
openai_model=getattr(config_module, "OPENAI_MODEL", "gpt-4o-mini"), openai_model=getattr(config_module, "OPENAI_MODEL", "gpt-4o-mini"),
@@ -92,7 +92,7 @@ class Config:
return cls( return cls(
deepseek_api_key=config_dict.get("DEEPSEEK_API_KEY"), deepseek_api_key=config_dict.get("DEEPSEEK_API_KEY"),
openai_api_key=config_dict.get("OPENAI_API_KEY"), openai_api_key=config_dict.get("OPENAI_API_KEY"),
tavily_api_key=config_dict.get("TAVILY_API_KEY"), bocha_api_key=config_dict.get("BOCHA_API_KEY"),
default_llm_provider=config_dict.get("DEFAULT_LLM_PROVIDER", "deepseek"), default_llm_provider=config_dict.get("DEFAULT_LLM_PROVIDER", "deepseek"),
deepseek_model=config_dict.get("DEEPSEEK_MODEL", "deepseek-chat"), deepseek_model=config_dict.get("DEEPSEEK_MODEL", "deepseek-chat"),
openai_model=config_dict.get("OPENAI_MODEL", "gpt-4o-mini"), openai_model=config_dict.get("OPENAI_MODEL", "gpt-4o-mini"),
@@ -147,7 +147,7 @@ def print_config(config: Config):
print(f"LLM提供商: {config.default_llm_provider}") print(f"LLM提供商: {config.default_llm_provider}")
print(f"DeepSeek模型: {config.deepseek_model}") print(f"DeepSeek模型: {config.deepseek_model}")
print(f"OpenAI模型: {config.openai_model}") print(f"OpenAI模型: {config.openai_model}")
print(f"最大搜索结果数: {config.max_search_results}")
print(f"搜索超时: {config.search_timeout}") print(f"搜索超时: {config.search_timeout}")
print(f"最大内容长度: {config.max_content_length}") print(f"最大内容长度: {config.max_content_length}")
print(f"最大反思次数: {config.max_reflections}") print(f"最大反思次数: {config.max_reflections}")
@@ -158,5 +158,5 @@ def print_config(config: Config):
# 显示API密钥状态(不显示实际密钥) # 显示API密钥状态(不显示实际密钥)
print(f"DeepSeek API Key: {'已设置' if config.deepseek_api_key else '未设置'}") print(f"DeepSeek API Key: {'已设置' if config.deepseek_api_key else '未设置'}")
print(f"OpenAI API Key: {'已设置' if config.openai_api_key else '未设置'}") print(f"OpenAI API Key: {'已设置' if config.openai_api_key else '未设置'}")
print(f"Tavily API Key: {'已设置' if config.tavily_api_key else '未设置'}") print(f"Bocha API Key: {'已设置' if config.bocha_api_key else '未设置'}")
print("==================\n") print("==================\n")
+5 -5
View File
@@ -12,8 +12,8 @@ import json
# 添加src目录到Python路径 # 添加src目录到Python路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '.')) sys.path.insert(0, os.path.join(os.path.dirname(__file__), '.'))
from QueryEngine import DeepSearchAgent, Config from MediaEngine import DeepSearchAgent, Config
from config import DEEPSEEK_API_KEY, TAVILY_API_KEY from config import DEEPSEEK_API_KEY, BOCHA_Web_Search_API_KEY
def main(): def main():
@@ -98,19 +98,19 @@ def main():
# 自动使用配置文件中的API密钥 # 自动使用配置文件中的API密钥
deepseek_key = DEEPSEEK_API_KEY deepseek_key = DEEPSEEK_API_KEY
tavily_key = TAVILY_API_KEY bocha_key = BOCHA_Web_Search_API_KEY
# 创建配置 # 创建配置
config = Config( config = Config(
deepseek_api_key=deepseek_key if llm_provider == "deepseek" else None, deepseek_api_key=deepseek_key if llm_provider == "deepseek" else None,
openai_api_key=openai_key if llm_provider == "openai" else None, openai_api_key=openai_key if llm_provider == "openai" else None,
tavily_api_key=tavily_key, bocha_api_key=bocha_key,
default_llm_provider=llm_provider, default_llm_provider=llm_provider,
deepseek_model=model_name if llm_provider == "deepseek" else "deepseek-chat", deepseek_model=model_name if llm_provider == "deepseek" else "deepseek-chat",
openai_model=model_name if llm_provider == "openai" else "gpt-4o-mini", openai_model=model_name if llm_provider == "openai" else "gpt-4o-mini",
max_reflections=max_reflections, max_reflections=max_reflections,
max_content_length=max_content_length, max_content_length=max_content_length,
output_dir="query_engine_streamlit_reports" output_dir="media_engine_streamlit_reports"
) )
# 执行研究 # 执行研究