The Insight Engine agent has been basically completed.

This commit is contained in:
戒酒的李白
2025-08-23 15:11:51 +08:00
parent c35a6baf05
commit 4e33224633
7 changed files with 437 additions and 54 deletions
+102 -34
View File
@@ -19,7 +19,7 @@ from .nodes import (
ReportFormattingNode
)
from .state import State
from .tools import MediaCrawlerDB, DBResponse
from .tools import MediaCrawlerDB, DBResponse, keyword_optimizer
from .utils import Config, load_config, format_search_results_for_prompt
@@ -113,7 +113,7 @@ class DeepSearchAgent:
def execute_search_tool(self, tool_name: str, query: str, **kwargs) -> DBResponse:
"""
执行指定的数据库查询工具
执行指定的数据库查询工具(集成关键词优化中间件)
Args:
tool_name: 工具名称,可选值:
@@ -130,34 +130,102 @@ class DeepSearchAgent:
"""
print(f" → 执行数据库查询工具: {tool_name}")
# 对于热点内容搜索,不需要关键词优化(因为不需要query参数)
if tool_name == "search_hot_content":
time_period = kwargs.get("time_period", "week")
limit = kwargs.get("limit", 10)
limit = kwargs.get("limit", 100)
return self.search_agency.search_hot_content(time_period=time_period, limit=limit)
elif tool_name == "search_topic_globally":
limit_per_table = kwargs.get("limit_per_table", 5)
return self.search_agency.search_topic_globally(topic=query, limit_per_table=limit_per_table)
elif tool_name == "search_topic_by_date":
start_date = kwargs.get("start_date")
end_date = kwargs.get("end_date")
limit_per_table = kwargs.get("limit_per_table", 10)
if not start_date or not end_date:
raise ValueError("search_topic_by_date工具需要start_date和end_date参数")
return self.search_agency.search_topic_by_date(topic=query, start_date=start_date, end_date=end_date, limit_per_table=limit_per_table)
elif tool_name == "get_comments_for_topic":
limit = kwargs.get("limit", 50)
return self.search_agency.get_comments_for_topic(topic=query, limit=limit)
elif tool_name == "search_topic_on_platform":
platform = kwargs.get("platform")
start_date = kwargs.get("start_date")
end_date = kwargs.get("end_date")
limit = kwargs.get("limit", 20)
if not platform:
raise ValueError("search_topic_on_platform工具需要platform参数")
return self.search_agency.search_topic_on_platform(platform=platform, topic=query, start_date=start_date, end_date=end_date, limit=limit)
else:
print(f" ⚠️ 未知的搜索工具: {tool_name},使用默认全局搜索")
return self.search_agency.search_topic_globally(topic=query)
# 对于需要搜索词的工具,使用关键词优化中间件
optimized_response = keyword_optimizer.optimize_keywords(
original_query=query,
context=f"使用{tool_name}工具进行查询"
)
print(f" 🔍 原始查询: '{query}'")
print(f" ✨ 优化后关键词: {optimized_response.optimized_keywords}")
# 使用优化后的关键词进行多次查询并整合结果
all_results = []
total_count = 0
for keyword in optimized_response.optimized_keywords:
print(f" 查询关键词: '{keyword}'")
try:
if tool_name == "search_topic_globally":
limit_per_table = kwargs.get("limit_per_table", 100)
response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=limit_per_table)
elif tool_name == "search_topic_by_date":
start_date = kwargs.get("start_date")
end_date = kwargs.get("end_date")
limit_per_table = kwargs.get("limit_per_table", 100)
if not start_date or not end_date:
raise ValueError("search_topic_by_date工具需要start_date和end_date参数")
response = self.search_agency.search_topic_by_date(topic=keyword, start_date=start_date, end_date=end_date, limit_per_table=limit_per_table)
elif tool_name == "get_comments_for_topic":
limit = kwargs.get("limit", 500) // len(optimized_response.optimized_keywords)
limit = max(limit, 50)
response = self.search_agency.get_comments_for_topic(topic=keyword, limit=limit)
elif tool_name == "search_topic_on_platform":
platform = kwargs.get("platform")
start_date = kwargs.get("start_date")
end_date = kwargs.get("end_date")
limit = kwargs.get("limit", 200) // len(optimized_response.optimized_keywords)
limit = max(limit, 30)
if not platform:
raise ValueError("search_topic_on_platform工具需要platform参数")
response = self.search_agency.search_topic_on_platform(platform=platform, topic=keyword, start_date=start_date, end_date=end_date, limit=limit)
else:
print(f" 未知的搜索工具: {tool_name},使用默认全局搜索")
response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=100)
# 收集结果
if response.results:
print(f" 找到 {len(response.results)} 条结果")
all_results.extend(response.results)
total_count += len(response.results)
else:
print(f" 未找到结果")
except Exception as e:
print(f" 查询'{keyword}'时出错: {str(e)}")
continue
# 去重和整合结果
unique_results = self._deduplicate_results(all_results)
print(f" 总计找到 {total_count} 条结果,去重后 {len(unique_results)}")
# 构建整合后的响应
integrated_response = DBResponse(
tool_name=f"{tool_name}_optimized",
parameters={
"original_query": query,
"optimized_keywords": optimized_response.optimized_keywords,
"optimization_reasoning": optimized_response.reasoning,
**kwargs
},
results=unique_results,
results_count=len(unique_results)
)
return integrated_response
def _deduplicate_results(self, results: List) -> List:
"""
去重搜索结果
"""
seen = set()
unique_results = []
for result in results:
# 使用URL或内容作为去重标识
identifier = result.url if result.url else result.title_or_content[:100]
if identifier not in seen:
seen.add(identifier)
unique_results.append(result)
return unique_results
def research(self, query: str, save_report: bool = True) -> str:
"""
@@ -291,14 +359,14 @@ class DeepSearchAgent:
# 处理限制参数
if search_tool == "search_hot_content":
time_period = search_output.get("time_period", "week")
limit = search_output.get("limit", 10)
limit = search_output.get("limit", 100)
search_kwargs["time_period"] = time_period
search_kwargs["limit"] = limit
elif search_tool in ["search_topic_globally", "search_topic_by_date"]:
limit_per_table = search_output.get("limit_per_table", 5)
limit_per_table = search_output.get("limit_per_table", 100)
search_kwargs["limit_per_table"] = limit_per_table
elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]:
limit = search_output.get("limit", 20)
limit = search_output.get("limit", 200)
search_kwargs["limit"] = limit
search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs)
@@ -306,8 +374,8 @@ class DeepSearchAgent:
# 转换为兼容格式
search_results = []
if search_response and search_response.results:
# 每种搜索工具都有其特定的结果数量,这里取前10个作为上限
max_results = min(len(search_response.results), 10)
# 每种搜索工具都有其特定的结果数量,这里取前100个作为上限
max_results = min(len(search_response.results), 100)
for result in search_response.results[:max_results]:
search_results.append({
'title': result.title_or_content,
@@ -426,8 +494,8 @@ class DeepSearchAgent:
# 转换为兼容格式
search_results = []
if search_response and search_response.results:
# 每种搜索工具都有其特定的结果数量,这里取前10个作为上限
max_results = min(len(search_response.results), 10)
# 每种搜索工具都有其特定的结果数量,这里取前100个作为上限
max_results = min(len(search_response.results), 100)
for result in search_response.results[:max_results]:
search_results.append({
'title': result.title_or_content,