The private database analysis agent has been basically completed.
This commit is contained in:
+131
-58
@@ -19,7 +19,7 @@ from .nodes import (
|
||||
ReportFormattingNode
|
||||
)
|
||||
from .state import State
|
||||
from .tools import TavilyNewsAgency, TavilyResponse
|
||||
from .tools import MediaCrawlerDB, DBResponse
|
||||
from .utils import Config, load_config, format_search_results_for_prompt
|
||||
|
||||
|
||||
@@ -39,8 +39,16 @@ class DeepSearchAgent:
|
||||
# 初始化LLM客户端
|
||||
self.llm_client = self._initialize_llm()
|
||||
|
||||
# 设置数据库环境变量
|
||||
os.environ["DB_HOST"] = self.config.db_host or ""
|
||||
os.environ["DB_USER"] = self.config.db_user or ""
|
||||
os.environ["DB_PASSWORD"] = self.config.db_password or ""
|
||||
os.environ["DB_NAME"] = self.config.db_name or ""
|
||||
os.environ["DB_PORT"] = str(self.config.db_port)
|
||||
os.environ["DB_CHARSET"] = self.config.db_charset
|
||||
|
||||
# 初始化搜索工具集
|
||||
self.search_agency = TavilyNewsAgency(api_key=self.config.tavily_api_key)
|
||||
self.search_agency = MediaCrawlerDB()
|
||||
|
||||
# 初始化节点
|
||||
self._initialize_nodes()
|
||||
@@ -53,7 +61,7 @@ class DeepSearchAgent:
|
||||
|
||||
print(f"Deep Search Agent 已初始化")
|
||||
print(f"使用LLM: {self.llm_client.get_model_info()}")
|
||||
print(f"搜索工具集: TavilyNewsAgency (支持6种搜索工具)")
|
||||
print(f"搜索工具集: MediaCrawlerDB (支持5种本地数据库查询工具)")
|
||||
|
||||
def _initialize_llm(self) -> BaseLLM:
|
||||
"""初始化LLM客户端"""
|
||||
@@ -103,46 +111,53 @@ class DeepSearchAgent:
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def execute_search_tool(self, tool_name: str, query: str, **kwargs) -> TavilyResponse:
|
||||
def execute_search_tool(self, tool_name: str, query: str, **kwargs) -> DBResponse:
|
||||
"""
|
||||
执行指定的搜索工具
|
||||
执行指定的数据库查询工具
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称,可选值:
|
||||
- "basic_search_news": 基础新闻搜索(快速、通用)
|
||||
- "deep_search_news": 深度新闻分析
|
||||
- "search_news_last_24_hours": 24小时内最新新闻
|
||||
- "search_news_last_week": 本周新闻
|
||||
- "search_images_for_news": 新闻图片搜索
|
||||
- "search_news_by_date": 按日期范围搜索新闻
|
||||
query: 搜索查询
|
||||
**kwargs: 额外参数(如start_date, end_date, max_results)
|
||||
- "search_hot_content": 查找热点内容
|
||||
- "search_topic_globally": 全局话题搜索
|
||||
- "search_topic_by_date": 按日期搜索话题
|
||||
- "get_comments_for_topic": 获取话题评论
|
||||
- "search_topic_on_platform": 平台定向搜索
|
||||
query: 搜索关键词/话题
|
||||
**kwargs: 额外参数(如start_date, end_date, platform, limit等)
|
||||
|
||||
Returns:
|
||||
TavilyResponse对象
|
||||
DBResponse对象
|
||||
"""
|
||||
print(f" → 执行搜索工具: {tool_name}")
|
||||
print(f" → 执行数据库查询工具: {tool_name}")
|
||||
|
||||
if tool_name == "basic_search_news":
|
||||
max_results = kwargs.get("max_results", 7)
|
||||
return self.search_agency.basic_search_news(query, max_results)
|
||||
elif tool_name == "deep_search_news":
|
||||
return self.search_agency.deep_search_news(query)
|
||||
elif tool_name == "search_news_last_24_hours":
|
||||
return self.search_agency.search_news_last_24_hours(query)
|
||||
elif tool_name == "search_news_last_week":
|
||||
return self.search_agency.search_news_last_week(query)
|
||||
elif tool_name == "search_images_for_news":
|
||||
return self.search_agency.search_images_for_news(query)
|
||||
elif tool_name == "search_news_by_date":
|
||||
if tool_name == "search_hot_content":
|
||||
time_period = kwargs.get("time_period", "week")
|
||||
limit = kwargs.get("limit", 10)
|
||||
return self.search_agency.search_hot_content(time_period=time_period, limit=limit)
|
||||
elif tool_name == "search_topic_globally":
|
||||
limit_per_table = kwargs.get("limit_per_table", 5)
|
||||
return self.search_agency.search_topic_globally(topic=query, limit_per_table=limit_per_table)
|
||||
elif tool_name == "search_topic_by_date":
|
||||
start_date = kwargs.get("start_date")
|
||||
end_date = kwargs.get("end_date")
|
||||
limit_per_table = kwargs.get("limit_per_table", 10)
|
||||
if not start_date or not end_date:
|
||||
raise ValueError("search_news_by_date工具需要start_date和end_date参数")
|
||||
return self.search_agency.search_news_by_date(query, start_date, end_date)
|
||||
raise ValueError("search_topic_by_date工具需要start_date和end_date参数")
|
||||
return self.search_agency.search_topic_by_date(topic=query, start_date=start_date, end_date=end_date, limit_per_table=limit_per_table)
|
||||
elif tool_name == "get_comments_for_topic":
|
||||
limit = kwargs.get("limit", 50)
|
||||
return self.search_agency.get_comments_for_topic(topic=query, limit=limit)
|
||||
elif tool_name == "search_topic_on_platform":
|
||||
platform = kwargs.get("platform")
|
||||
start_date = kwargs.get("start_date")
|
||||
end_date = kwargs.get("end_date")
|
||||
limit = kwargs.get("limit", 20)
|
||||
if not platform:
|
||||
raise ValueError("search_topic_on_platform工具需要platform参数")
|
||||
return self.search_agency.search_topic_on_platform(platform=platform, topic=query, start_date=start_date, end_date=end_date, limit=limit)
|
||||
else:
|
||||
print(f" ⚠️ 未知的搜索工具: {tool_name},使用默认基础搜索")
|
||||
return self.search_agency.basic_search_news(query)
|
||||
print(f" ⚠️ 未知的搜索工具: {tool_name},使用默认全局搜索")
|
||||
return self.search_agency.search_topic_globally(topic=query)
|
||||
|
||||
def research(self, query: str, save_report: bool = True) -> str:
|
||||
"""
|
||||
@@ -231,7 +246,7 @@ class DeepSearchAgent:
|
||||
print(" - 生成搜索查询...")
|
||||
search_output = self.first_search_node.run(search_input)
|
||||
search_query = search_output["search_query"]
|
||||
search_tool = search_output.get("search_tool", "basic_search_news") # 默认工具
|
||||
search_tool = search_output.get("search_tool", "search_topic_globally") # 默认工具
|
||||
reasoning = search_output["reasoning"]
|
||||
|
||||
print(f" - 搜索查询: {search_query}")
|
||||
@@ -239,11 +254,13 @@ class DeepSearchAgent:
|
||||
print(f" - 推理: {reasoning}")
|
||||
|
||||
# 执行搜索
|
||||
print(" - 执行网络搜索...")
|
||||
print(" - 执行数据库查询...")
|
||||
|
||||
# 处理search_news_by_date的特殊参数
|
||||
# 处理特殊参数
|
||||
search_kwargs = {}
|
||||
if search_tool == "search_news_by_date":
|
||||
|
||||
# 处理需要日期的工具
|
||||
if search_tool in ["search_topic_by_date", "search_topic_on_platform"]:
|
||||
start_date = search_output.get("start_date")
|
||||
end_date = search_output.get("end_date")
|
||||
|
||||
@@ -254,12 +271,35 @@ class DeepSearchAgent:
|
||||
search_kwargs["end_date"] = end_date
|
||||
print(f" - 时间范围: {start_date} 到 {end_date}")
|
||||
else:
|
||||
print(f" ⚠️ 日期格式错误(应为YYYY-MM-DD),改用基础搜索")
|
||||
print(f" ⚠️ 日期格式错误(应为YYYY-MM-DD),改用全局搜索")
|
||||
print(f" 提供的日期: start_date={start_date}, end_date={end_date}")
|
||||
search_tool = "basic_search_news"
|
||||
search_tool = "search_topic_globally"
|
||||
elif search_tool == "search_topic_by_date":
|
||||
print(f" ⚠️ search_topic_by_date工具缺少时间参数,改用全局搜索")
|
||||
search_tool = "search_topic_globally"
|
||||
|
||||
# 处理需要平台参数的工具
|
||||
if search_tool == "search_topic_on_platform":
|
||||
platform = search_output.get("platform")
|
||||
if platform:
|
||||
search_kwargs["platform"] = platform
|
||||
print(f" - 指定平台: {platform}")
|
||||
else:
|
||||
print(f" ⚠️ search_news_by_date工具缺少时间参数,改用基础搜索")
|
||||
search_tool = "basic_search_news"
|
||||
print(f" ⚠️ search_topic_on_platform工具缺少平台参数,改用全局搜索")
|
||||
search_tool = "search_topic_globally"
|
||||
|
||||
# 处理限制参数
|
||||
if search_tool == "search_hot_content":
|
||||
time_period = search_output.get("time_period", "week")
|
||||
limit = search_output.get("limit", 10)
|
||||
search_kwargs["time_period"] = time_period
|
||||
search_kwargs["limit"] = limit
|
||||
elif search_tool in ["search_topic_globally", "search_topic_by_date"]:
|
||||
limit_per_table = search_output.get("limit_per_table", 5)
|
||||
search_kwargs["limit_per_table"] = limit_per_table
|
||||
elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]:
|
||||
limit = search_output.get("limit", 20)
|
||||
search_kwargs["limit"] = limit
|
||||
|
||||
search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs)
|
||||
|
||||
@@ -270,12 +310,16 @@ class DeepSearchAgent:
|
||||
max_results = min(len(search_response.results), 10)
|
||||
for result in search_response.results[:max_results]:
|
||||
search_results.append({
|
||||
'title': result.title,
|
||||
'url': result.url,
|
||||
'content': result.content,
|
||||
'score': result.score,
|
||||
'raw_content': result.raw_content,
|
||||
'published_date': result.published_date # 新增字段
|
||||
'title': result.title_or_content,
|
||||
'url': result.url or "",
|
||||
'content': result.title_or_content,
|
||||
'score': result.hotness_score,
|
||||
'raw_content': result.title_or_content,
|
||||
'published_date': result.publish_time.isoformat() if result.publish_time else None,
|
||||
'platform': result.platform,
|
||||
'content_type': result.content_type,
|
||||
'author': result.author_nickname,
|
||||
'engagement': result.engagement
|
||||
})
|
||||
|
||||
if search_results:
|
||||
@@ -324,7 +368,7 @@ class DeepSearchAgent:
|
||||
# 生成反思搜索查询
|
||||
reflection_output = self.reflection_node.run(reflection_input)
|
||||
search_query = reflection_output["search_query"]
|
||||
search_tool = reflection_output.get("search_tool", "basic_search_news") # 默认工具
|
||||
search_tool = reflection_output.get("search_tool", "search_topic_globally") # 默认工具
|
||||
reasoning = reflection_output["reasoning"]
|
||||
|
||||
print(f" 反思查询: {search_query}")
|
||||
@@ -332,9 +376,11 @@ class DeepSearchAgent:
|
||||
print(f" 反思推理: {reasoning}")
|
||||
|
||||
# 执行反思搜索
|
||||
# 处理search_news_by_date的特殊参数
|
||||
# 处理特殊参数
|
||||
search_kwargs = {}
|
||||
if search_tool == "search_news_by_date":
|
||||
|
||||
# 处理需要日期的工具
|
||||
if search_tool in ["search_topic_by_date", "search_topic_on_platform"]:
|
||||
start_date = reflection_output.get("start_date")
|
||||
end_date = reflection_output.get("end_date")
|
||||
|
||||
@@ -345,12 +391,35 @@ class DeepSearchAgent:
|
||||
search_kwargs["end_date"] = end_date
|
||||
print(f" 时间范围: {start_date} 到 {end_date}")
|
||||
else:
|
||||
print(f" ⚠️ 日期格式错误(应为YYYY-MM-DD),改用基础搜索")
|
||||
print(f" ⚠️ 日期格式错误(应为YYYY-MM-DD),改用全局搜索")
|
||||
print(f" 提供的日期: start_date={start_date}, end_date={end_date}")
|
||||
search_tool = "basic_search_news"
|
||||
search_tool = "search_topic_globally"
|
||||
elif search_tool == "search_topic_by_date":
|
||||
print(f" ⚠️ search_topic_by_date工具缺少时间参数,改用全局搜索")
|
||||
search_tool = "search_topic_globally"
|
||||
|
||||
# 处理需要平台参数的工具
|
||||
if search_tool == "search_topic_on_platform":
|
||||
platform = reflection_output.get("platform")
|
||||
if platform:
|
||||
search_kwargs["platform"] = platform
|
||||
print(f" 指定平台: {platform}")
|
||||
else:
|
||||
print(f" ⚠️ search_news_by_date工具缺少时间参数,改用基础搜索")
|
||||
search_tool = "basic_search_news"
|
||||
print(f" ⚠️ search_topic_on_platform工具缺少平台参数,改用全局搜索")
|
||||
search_tool = "search_topic_globally"
|
||||
|
||||
# 处理限制参数
|
||||
if search_tool == "search_hot_content":
|
||||
time_period = reflection_output.get("time_period", "week")
|
||||
limit = reflection_output.get("limit", 10)
|
||||
search_kwargs["time_period"] = time_period
|
||||
search_kwargs["limit"] = limit
|
||||
elif search_tool in ["search_topic_globally", "search_topic_by_date"]:
|
||||
limit_per_table = reflection_output.get("limit_per_table", 5)
|
||||
search_kwargs["limit_per_table"] = limit_per_table
|
||||
elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]:
|
||||
limit = reflection_output.get("limit", 20)
|
||||
search_kwargs["limit"] = limit
|
||||
|
||||
search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs)
|
||||
|
||||
@@ -361,12 +430,16 @@ class DeepSearchAgent:
|
||||
max_results = min(len(search_response.results), 10)
|
||||
for result in search_response.results[:max_results]:
|
||||
search_results.append({
|
||||
'title': result.title,
|
||||
'url': result.url,
|
||||
'content': result.content,
|
||||
'score': result.score,
|
||||
'raw_content': result.raw_content,
|
||||
'published_date': result.published_date
|
||||
'title': result.title_or_content,
|
||||
'url': result.url or "",
|
||||
'content': result.title_or_content,
|
||||
'score': result.hotness_score,
|
||||
'raw_content': result.title_or_content,
|
||||
'published_date': result.publish_time.isoformat() if result.publish_time else None,
|
||||
'platform': result.platform,
|
||||
'content_type': result.content_type,
|
||||
'author': result.author_nickname,
|
||||
'engagement': result.engagement
|
||||
})
|
||||
|
||||
if search_results:
|
||||
|
||||
Reference in New Issue
Block a user