The Insight Engine agent is now essentially complete.
This commit is contained in:
+102
-34
@@ -19,7 +19,7 @@ from .nodes import (
|
||||
ReportFormattingNode
|
||||
)
|
||||
from .state import State
|
||||
from .tools import MediaCrawlerDB, DBResponse
|
||||
from .tools import MediaCrawlerDB, DBResponse, keyword_optimizer
|
||||
from .utils import Config, load_config, format_search_results_for_prompt
|
||||
|
||||
|
||||
@@ -113,7 +113,7 @@ class DeepSearchAgent:
|
||||
|
||||
def execute_search_tool(self, tool_name: str, query: str, **kwargs) -> DBResponse:
|
||||
"""
|
||||
执行指定的数据库查询工具
|
||||
执行指定的数据库查询工具(集成关键词优化中间件)
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称,可选值:
|
||||
@@ -130,34 +130,102 @@ class DeepSearchAgent:
|
||||
"""
|
||||
print(f" → 执行数据库查询工具: {tool_name}")
|
||||
|
||||
# 对于热点内容搜索,不需要关键词优化(因为不需要query参数)
|
||||
if tool_name == "search_hot_content":
|
||||
time_period = kwargs.get("time_period", "week")
|
||||
limit = kwargs.get("limit", 10)
|
||||
limit = kwargs.get("limit", 100)
|
||||
return self.search_agency.search_hot_content(time_period=time_period, limit=limit)
|
||||
elif tool_name == "search_topic_globally":
|
||||
limit_per_table = kwargs.get("limit_per_table", 5)
|
||||
return self.search_agency.search_topic_globally(topic=query, limit_per_table=limit_per_table)
|
||||
elif tool_name == "search_topic_by_date":
|
||||
start_date = kwargs.get("start_date")
|
||||
end_date = kwargs.get("end_date")
|
||||
limit_per_table = kwargs.get("limit_per_table", 10)
|
||||
if not start_date or not end_date:
|
||||
raise ValueError("search_topic_by_date工具需要start_date和end_date参数")
|
||||
return self.search_agency.search_topic_by_date(topic=query, start_date=start_date, end_date=end_date, limit_per_table=limit_per_table)
|
||||
elif tool_name == "get_comments_for_topic":
|
||||
limit = kwargs.get("limit", 50)
|
||||
return self.search_agency.get_comments_for_topic(topic=query, limit=limit)
|
||||
elif tool_name == "search_topic_on_platform":
|
||||
platform = kwargs.get("platform")
|
||||
start_date = kwargs.get("start_date")
|
||||
end_date = kwargs.get("end_date")
|
||||
limit = kwargs.get("limit", 20)
|
||||
if not platform:
|
||||
raise ValueError("search_topic_on_platform工具需要platform参数")
|
||||
return self.search_agency.search_topic_on_platform(platform=platform, topic=query, start_date=start_date, end_date=end_date, limit=limit)
|
||||
else:
|
||||
print(f" ⚠️ 未知的搜索工具: {tool_name},使用默认全局搜索")
|
||||
return self.search_agency.search_topic_globally(topic=query)
|
||||
|
||||
# 对于需要搜索词的工具,使用关键词优化中间件
|
||||
optimized_response = keyword_optimizer.optimize_keywords(
|
||||
original_query=query,
|
||||
context=f"使用{tool_name}工具进行查询"
|
||||
)
|
||||
|
||||
print(f" 🔍 原始查询: '{query}'")
|
||||
print(f" ✨ 优化后关键词: {optimized_response.optimized_keywords}")
|
||||
|
||||
# 使用优化后的关键词进行多次查询并整合结果
|
||||
all_results = []
|
||||
total_count = 0
|
||||
|
||||
for keyword in optimized_response.optimized_keywords:
|
||||
print(f" 查询关键词: '{keyword}'")
|
||||
|
||||
try:
|
||||
if tool_name == "search_topic_globally":
|
||||
limit_per_table = kwargs.get("limit_per_table", 100)
|
||||
response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=limit_per_table)
|
||||
elif tool_name == "search_topic_by_date":
|
||||
start_date = kwargs.get("start_date")
|
||||
end_date = kwargs.get("end_date")
|
||||
limit_per_table = kwargs.get("limit_per_table", 100)
|
||||
if not start_date or not end_date:
|
||||
raise ValueError("search_topic_by_date工具需要start_date和end_date参数")
|
||||
response = self.search_agency.search_topic_by_date(topic=keyword, start_date=start_date, end_date=end_date, limit_per_table=limit_per_table)
|
||||
elif tool_name == "get_comments_for_topic":
|
||||
limit = kwargs.get("limit", 500) // len(optimized_response.optimized_keywords)
|
||||
limit = max(limit, 50)
|
||||
response = self.search_agency.get_comments_for_topic(topic=keyword, limit=limit)
|
||||
elif tool_name == "search_topic_on_platform":
|
||||
platform = kwargs.get("platform")
|
||||
start_date = kwargs.get("start_date")
|
||||
end_date = kwargs.get("end_date")
|
||||
limit = kwargs.get("limit", 200) // len(optimized_response.optimized_keywords)
|
||||
limit = max(limit, 30)
|
||||
if not platform:
|
||||
raise ValueError("search_topic_on_platform工具需要platform参数")
|
||||
response = self.search_agency.search_topic_on_platform(platform=platform, topic=keyword, start_date=start_date, end_date=end_date, limit=limit)
|
||||
else:
|
||||
print(f" 未知的搜索工具: {tool_name},使用默认全局搜索")
|
||||
response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=100)
|
||||
|
||||
# 收集结果
|
||||
if response.results:
|
||||
print(f" 找到 {len(response.results)} 条结果")
|
||||
all_results.extend(response.results)
|
||||
total_count += len(response.results)
|
||||
else:
|
||||
print(f" 未找到结果")
|
||||
|
||||
except Exception as e:
|
||||
print(f" 查询'{keyword}'时出错: {str(e)}")
|
||||
continue
|
||||
|
||||
# 去重和整合结果
|
||||
unique_results = self._deduplicate_results(all_results)
|
||||
print(f" 总计找到 {total_count} 条结果,去重后 {len(unique_results)} 条")
|
||||
|
||||
# 构建整合后的响应
|
||||
integrated_response = DBResponse(
|
||||
tool_name=f"{tool_name}_optimized",
|
||||
parameters={
|
||||
"original_query": query,
|
||||
"optimized_keywords": optimized_response.optimized_keywords,
|
||||
"optimization_reasoning": optimized_response.reasoning,
|
||||
**kwargs
|
||||
},
|
||||
results=unique_results,
|
||||
results_count=len(unique_results)
|
||||
)
|
||||
|
||||
return integrated_response
|
||||
|
||||
def _deduplicate_results(self, results: List) -> List:
|
||||
"""
|
||||
去重搜索结果
|
||||
"""
|
||||
seen = set()
|
||||
unique_results = []
|
||||
|
||||
for result in results:
|
||||
# 使用URL或内容作为去重标识
|
||||
identifier = result.url if result.url else result.title_or_content[:100]
|
||||
if identifier not in seen:
|
||||
seen.add(identifier)
|
||||
unique_results.append(result)
|
||||
|
||||
return unique_results
|
||||
|
||||
def research(self, query: str, save_report: bool = True) -> str:
|
||||
"""
|
||||
@@ -291,14 +359,14 @@ class DeepSearchAgent:
|
||||
# 处理限制参数
|
||||
if search_tool == "search_hot_content":
|
||||
time_period = search_output.get("time_period", "week")
|
||||
limit = search_output.get("limit", 10)
|
||||
limit = search_output.get("limit", 100)
|
||||
search_kwargs["time_period"] = time_period
|
||||
search_kwargs["limit"] = limit
|
||||
elif search_tool in ["search_topic_globally", "search_topic_by_date"]:
|
||||
limit_per_table = search_output.get("limit_per_table", 5)
|
||||
limit_per_table = search_output.get("limit_per_table", 100)
|
||||
search_kwargs["limit_per_table"] = limit_per_table
|
||||
elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]:
|
||||
limit = search_output.get("limit", 20)
|
||||
limit = search_output.get("limit", 200)
|
||||
search_kwargs["limit"] = limit
|
||||
|
||||
search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs)
|
||||
@@ -306,8 +374,8 @@ class DeepSearchAgent:
|
||||
# 转换为兼容格式
|
||||
search_results = []
|
||||
if search_response and search_response.results:
|
||||
# 每种搜索工具都有其特定的结果数量,这里取前10个作为上限
|
||||
max_results = min(len(search_response.results), 10)
|
||||
# 每种搜索工具都有其特定的结果数量,这里取前100个作为上限
|
||||
max_results = min(len(search_response.results), 100)
|
||||
for result in search_response.results[:max_results]:
|
||||
search_results.append({
|
||||
'title': result.title_or_content,
|
||||
@@ -426,8 +494,8 @@ class DeepSearchAgent:
|
||||
# 转换为兼容格式
|
||||
search_results = []
|
||||
if search_response and search_response.results:
|
||||
# 每种搜索工具都有其特定的结果数量,这里取前10个作为上限
|
||||
max_results = min(len(search_response.results), 10)
|
||||
# 每种搜索工具都有其特定的结果数量,这里取前100个作为上限
|
||||
max_results = min(len(search_response.results), 100)
|
||||
for result in search_response.results[:max_results]:
|
||||
search_results.append({
|
||||
'title': result.title_or_content,
|
||||
|
||||
Reference in New Issue
Block a user