1. 统一使用基于 pydantic 的 .env 环境变量来管理配置

2. 全项目基于loguru进行日志管理
This commit is contained in:
Doiiars
2025-11-05 14:56:49 +08:00
parent 1d2e23d8c1
commit 537d682861
50 changed files with 1404 additions and 1731 deletions
+114 -119
View File
@@ -8,6 +8,7 @@ import os
import re
from datetime import datetime
from typing import Optional, Dict, Any, List, Union
from loguru import logger
from .llms import LLMClient
from .nodes import (
@@ -20,32 +21,25 @@ from .nodes import (
)
from .state import State
from .tools import MediaCrawlerDB, DBResponse, keyword_optimizer, multilingual_sentiment_analyzer
from .utils import Config, load_config, format_search_results_for_prompt
from .utils.config import settings, Settings
from .utils import format_search_results_for_prompt
class DeepSearchAgent:
"""Deep Search Agent主类"""
def __init__(self, config: Optional[Config] = None):
def __init__(self, config: Optional[Settings] = None):
"""
初始化Deep Search Agent
Args:
config: 配置对象,如果不提供则自动加载
config: 可选配置对象(不填则用全局settings)
"""
# 加载配置
self.config = config or load_config()
self.config = config or settings
# 初始化LLM客户端
self.llm_client = self._initialize_llm()
# 设置数据库环境变量
os.environ["DB_HOST"] = self.config.db_host or ""
os.environ["DB_USER"] = self.config.db_user or ""
os.environ["DB_PASSWORD"] = self.config.db_password or ""
os.environ["DB_NAME"] = self.config.db_name or ""
os.environ["DB_PORT"] = str(self.config.db_port)
os.environ["DB_CHARSET"] = self.config.db_charset
# 初始化搜索工具集
self.search_agency = MediaCrawlerDB()
@@ -60,19 +54,19 @@ class DeepSearchAgent:
self.state = State()
# 确保输出目录存在
os.makedirs(self.config.output_dir, exist_ok=True)
os.makedirs(self.config.OUTPUT_DIR, exist_ok=True)
print(f"Insight Agent已初始化")
print(f"使用LLM: {self.llm_client.get_model_info()}")
print(f"搜索工具集: MediaCrawlerDB (支持5种本地数据库查询工具)")
print(f"情感分析: WeiboMultilingualSentiment (支持22种语言的情感分析)")
logger.info(f"Insight Agent已初始化")
logger.info(f"使用LLM: {self.llm_client.get_model_info()}")
logger.info(f"搜索工具集: MediaCrawlerDB (支持5种本地数据库查询工具)")
logger.info(f"情感分析: WeiboMultilingualSentiment (支持22种语言的情感分析)")
def _initialize_llm(self) -> LLMClient:
"""初始化LLM客户端"""
return LLMClient(
api_key=self.config.llm_api_key,
model_name=self.config.llm_model_name,
base_url=self.config.llm_base_url,
api_key=self.config.INSIGHT_ENGINE_API_KEY,
model_name=self.config.INSIGHT_ENGINE_MODEL_NAME,
base_url=self.config.INSIGHT_ENGINE_BASE_URL,
)
def _initialize_nodes(self):
@@ -127,7 +121,7 @@ class DeepSearchAgent:
Returns:
DBResponse对象(可能包含情感分析结果)
"""
print(f" → 执行数据库查询工具: {tool_name}")
logger.info(f" → 执行数据库查询工具: {tool_name}")
# 对于热点内容搜索,不需要关键词优化(因为不需要query参数)
if tool_name == "search_hot_content":
@@ -138,12 +132,12 @@ class DeepSearchAgent:
# 检查是否需要进行情感分析
enable_sentiment = kwargs.get("enable_sentiment", True)
if enable_sentiment and response.results and len(response.results) > 0:
print(f" 🎭 开始对热点内容进行情感分析...")
logger.info(f" 🎭 开始对热点内容进行情感分析...")
sentiment_analysis = self._perform_sentiment_analysis(response.results)
if sentiment_analysis:
# 将情感分析结果添加到响应的parameters中
response.parameters["sentiment_analysis"] = sentiment_analysis
print(f" ✅ 情感分析完成")
logger.info(f" ✅ 情感分析完成")
return response
@@ -170,32 +164,32 @@ class DeepSearchAgent:
context=f"使用{tool_name}工具进行查询"
)
print(f" 🔍 原始查询: '{query}'")
print(f" ✨ 优化后关键词: {optimized_response.optimized_keywords}")
logger.info(f" 🔍 原始查询: '{query}'")
logger.info(f" ✨ 优化后关键词: {optimized_response.optimized_keywords}")
# 使用优化后的关键词进行多次查询并整合结果
all_results = []
total_count = 0
for keyword in optimized_response.optimized_keywords:
print(f" 查询关键词: '{keyword}'")
logger.info(f" 查询关键词: '{keyword}'")
try:
if tool_name == "search_topic_globally":
# 使用配置文件中的默认值,忽略agent提供的limit_per_table参数
limit_per_table = self.config.default_search_topic_globally_limit_per_table
limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE
response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=limit_per_table)
elif tool_name == "search_topic_by_date":
start_date = kwargs.get("start_date")
end_date = kwargs.get("end_date")
# 使用配置文件中的默认值,忽略agent提供的limit_per_table参数
limit_per_table = self.config.default_search_topic_by_date_limit_per_table
limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE
if not start_date or not end_date:
raise ValueError("search_topic_by_date工具需要start_date和end_date参数")
response = self.search_agency.search_topic_by_date(topic=keyword, start_date=start_date, end_date=end_date, limit_per_table=limit_per_table)
elif tool_name == "get_comments_for_topic":
# 使用配置文件中的默认值,按关键词数量分配,但保证最小值
limit = self.config.default_get_comments_for_topic_limit // len(optimized_response.optimized_keywords)
limit = self.config.DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT // len(optimized_response.optimized_keywords)
limit = max(limit, 50)
response = self.search_agency.get_comments_for_topic(topic=keyword, limit=limit)
elif tool_name == "search_topic_on_platform":
@@ -203,30 +197,30 @@ class DeepSearchAgent:
start_date = kwargs.get("start_date")
end_date = kwargs.get("end_date")
# 使用配置文件中的默认值,按关键词数量分配,但保证最小值
limit = self.config.default_search_topic_on_platform_limit // len(optimized_response.optimized_keywords)
limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT // len(optimized_response.optimized_keywords)
limit = max(limit, 30)
if not platform:
raise ValueError("search_topic_on_platform工具需要platform参数")
response = self.search_agency.search_topic_on_platform(platform=platform, topic=keyword, start_date=start_date, end_date=end_date, limit=limit)
else:
print(f" 未知的搜索工具: {tool_name},使用默认全局搜索")
response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=self.config.default_search_topic_globally_limit_per_table)
logger.info(f" 未知的搜索工具: {tool_name},使用默认全局搜索")
response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE)
# 收集结果
if response.results:
print(f" 找到 {len(response.results)} 条结果")
logger.info(f" 找到 {len(response.results)} 条结果")
all_results.extend(response.results)
total_count += len(response.results)
else:
print(f" 未找到结果")
logger.info(f" 未找到结果")
except Exception as e:
print(f" 查询'{keyword}'时出错: {str(e)}")
logger.error(f" 查询'{keyword}'时出错: {str(e)}")
continue
# 去重和整合结果
unique_results = self._deduplicate_results(all_results)
print(f" 总计找到 {total_count} 条结果,去重后 {len(unique_results)}")
logger.info(f" 总计找到 {total_count} 条结果,去重后 {len(unique_results)}")
# 构建整合后的响应
integrated_response = DBResponse(
@@ -244,12 +238,12 @@ class DeepSearchAgent:
# 检查是否需要进行情感分析
enable_sentiment = kwargs.get("enable_sentiment", True)
if enable_sentiment and unique_results and len(unique_results) > 0:
print(f" 🎭 开始对搜索结果进行情感分析...")
logger.info(f" 🎭 开始对搜索结果进行情感分析...")
sentiment_analysis = self._perform_sentiment_analysis(unique_results)
if sentiment_analysis:
# 将情感分析结果添加到响应的parameters中
integrated_response.parameters["sentiment_analysis"] = sentiment_analysis
print(f" ✅ 情感分析完成")
logger.info(f" ✅ 情感分析完成")
return integrated_response
@@ -282,11 +276,11 @@ class DeepSearchAgent:
try:
# 初始化情感分析器(如果尚未初始化且未被禁用)
if not self.sentiment_analyzer.is_initialized and not self.sentiment_analyzer.is_disabled:
print(" 初始化情感分析模型...")
logger.info(" 初始化情感分析模型...")
if not self.sentiment_analyzer.initialize():
print(" 情感分析模型初始化失败,将直接透传原始文本")
logger.info(" 情感分析模型初始化失败,将直接透传原始文本")
elif self.sentiment_analyzer.is_disabled:
print(" 情感分析功能已禁用,直接透传原始文本")
logger.info(" 情感分析功能已禁用,直接透传原始文本")
# 将查询结果转换为字典格式
results_dict = []
@@ -310,7 +304,7 @@ class DeepSearchAgent:
return sentiment_analysis.get("sentiment_analysis")
except Exception as e:
print(f" ❌ 情感分析过程中发生错误: {str(e)}")
logger.exception(f" ❌ 情感分析过程中发生错误: {str(e)}")
return None
def analyze_sentiment_only(self, texts: Union[str, List[str]]) -> Dict[str, Any]:
@@ -323,16 +317,16 @@ class DeepSearchAgent:
Returns:
情感分析结果
"""
print(f" → 执行独立情感分析")
logger.info(f" → 执行独立情感分析")
try:
# 初始化情感分析器(如果尚未初始化且未被禁用)
if not self.sentiment_analyzer.is_initialized and not self.sentiment_analyzer.is_disabled:
print(" 初始化情感分析模型...")
logger.info(" 初始化情感分析模型...")
if not self.sentiment_analyzer.initialize():
print(" 情感分析模型初始化失败,将直接透传原始文本")
logger.info(" 情感分析模型初始化失败,将直接透传原始文本")
elif self.sentiment_analyzer.is_disabled:
print(" 情感分析功能已禁用,直接透传原始文本")
logger.warning(" 情感分析功能已禁用,直接透传原始文本")
# 执行分析
if isinstance(texts, str):
@@ -368,7 +362,7 @@ class DeepSearchAgent:
return response
except Exception as e:
print(f" ❌ 情感分析过程中发生错误: {str(e)}")
logger.exception(f" ❌ 情感分析过程中发生错误: {str(e)}")
return {
"success": False,
"error": str(e),
@@ -386,9 +380,9 @@ class DeepSearchAgent:
Returns:
最终报告内容
"""
print(f"\n{'='*60}")
print(f"开始深度研究: {query}")
print(f"{'='*60}")
logger.info(f"\n{'='*60}")
logger.info(f"开始深度研究: {query}")
logger.info(f"{'='*60}")
try:
# Step 1: 生成报告结构
@@ -403,20 +397,18 @@ class DeepSearchAgent:
# Step 4: 保存报告
if save_report:
self._save_report(final_report)
print(f"\n{'='*60}")
print("深度研究完成!")
print(f"{'='*60}")
logger.info("深度研究完成!")
return final_report
except Exception as e:
print(f"研究过程中发生错误: {str(e)}")
logger.exception(f"研究过程中发生错误: {str(e)}")
raise e
def _generate_report_structure(self, query: str):
"""生成报告结构"""
print(f"\n[步骤 1] 生成报告结构...")
logger.info(f"\n[步骤 1] 生成报告结构...")
# 创建报告结构节点
report_structure_node = ReportStructureNode(self.llm_client, query)
@@ -424,17 +416,18 @@ class DeepSearchAgent:
# 生成结构并更新状态
self.state = report_structure_node.mutate_state(state=self.state)
print(f"报告结构已生成,共 {len(self.state.paragraphs)} 个段落:")
_message = f"报告结构已生成,共 {len(self.state.paragraphs)} 个段落:"
for i, paragraph in enumerate(self.state.paragraphs, 1):
print(f" {i}. {paragraph.title}")
_message += f"\n {i}. {paragraph.title}"
logger.info(_message)
def _process_paragraphs(self):
"""处理所有段落"""
total_paragraphs = len(self.state.paragraphs)
for i in range(total_paragraphs):
print(f"\n[步骤 2.{i+1}] 处理段落: {self.state.paragraphs[i].title}")
print("-" * 50)
logger.info(f"\n[步骤 2.{i+1}] 处理段落: {self.state.paragraphs[i].title}")
logger.info("-" * 50)
# 初始搜索和总结
self._initial_search_and_summary(i)
@@ -446,7 +439,7 @@ class DeepSearchAgent:
self.state.paragraphs[i].research.mark_completed()
progress = (i + 1) / total_paragraphs * 100
print(f"段落处理完成 ({progress:.1f}%)")
logger.info(f"段落处理完成 ({progress:.1f}%)")
def _initial_search_and_summary(self, paragraph_index: int):
"""执行初始搜索和总结"""
@@ -459,18 +452,18 @@ class DeepSearchAgent:
}
# 生成搜索查询和工具选择
print(" - 生成搜索查询...")
logger.info(" - 生成搜索查询...")
search_output = self.first_search_node.run(search_input)
search_query = search_output["search_query"]
search_tool = search_output.get("search_tool", "search_topic_globally") # 默认工具
reasoning = search_output["reasoning"]
print(f" - 搜索查询: {search_query}")
print(f" - 选择的工具: {search_tool}")
print(f" - 推理: {reasoning}")
logger.info(f" - 搜索查询: {search_query}")
logger.info(f" - 选择的工具: {search_tool}")
logger.info(f" - 推理: {reasoning}")
# 执行搜索
print(" - 执行数据库查询...")
logger.info(" - 执行数据库查询...")
# 处理特殊参数
search_kwargs = {}
@@ -485,13 +478,13 @@ class DeepSearchAgent:
if self._validate_date_format(start_date) and self._validate_date_format(end_date):
search_kwargs["start_date"] = start_date
search_kwargs["end_date"] = end_date
print(f" - 时间范围: {start_date}{end_date}")
logger.info(f" - 时间范围: {start_date}{end_date}")
else:
print(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索")
print(f" 提供的日期: start_date={start_date}, end_date={end_date}")
logger.info(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索")
logger.info(f" 提供的日期: start_date={start_date}, end_date={end_date}")
search_tool = "search_topic_globally"
elif search_tool == "search_topic_by_date":
print(f" search_topic_by_date工具缺少时间参数,改用全局搜索")
logger.info(f" search_topic_by_date工具缺少时间参数,改用全局搜索")
search_tool = "search_topic_globally"
# 处理需要平台参数的工具
@@ -499,28 +492,28 @@ class DeepSearchAgent:
platform = search_output.get("platform")
if platform:
search_kwargs["platform"] = platform
print(f" - 指定平台: {platform}")
logger.info(f" - 指定平台: {platform}")
else:
print(f" search_topic_on_platform工具缺少平台参数,改用全局搜索")
logger.warning(f" search_topic_on_platform工具缺少平台参数,改用全局搜索")
search_tool = "search_topic_globally"
# 处理限制参数,使用配置文件中的默认值而不是agent提供的参数
if search_tool == "search_hot_content":
time_period = search_output.get("time_period", "week")
limit = self.config.default_search_hot_content_limit
limit = self.config.DEFAULT_SEARCH_HOT_CONTENT_LIMIT
search_kwargs["time_period"] = time_period
search_kwargs["limit"] = limit
elif search_tool in ["search_topic_globally", "search_topic_by_date"]:
if search_tool == "search_topic_globally":
limit_per_table = self.config.default_search_topic_globally_limit_per_table
limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE
else: # search_topic_by_date
limit_per_table = self.config.default_search_topic_by_date_limit_per_table
limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE
search_kwargs["limit_per_table"] = limit_per_table
elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]:
if search_tool == "get_comments_for_topic":
limit = self.config.default_get_comments_for_topic_limit
limit = self.config.DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT
else: # search_topic_on_platform
limit = self.config.default_search_topic_on_platform_limit
limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT
search_kwargs["limit"] = limit
search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs)
@@ -529,8 +522,8 @@ class DeepSearchAgent:
search_results = []
if search_response and search_response.results:
# 使用配置文件控制传递给LLM的结果数量,0表示不限制
if self.config.max_search_results_for_llm > 0:
max_results = min(len(search_response.results), self.config.max_search_results_for_llm)
if self.config.MAX_SEARCH_RESULTS_FOR_LLM > 0:
max_results = min(len(search_response.results), self.config.MAX_SEARCH_RESULTS_FOR_LLM)
else:
max_results = len(search_response.results) # 不限制,传递所有结果
for result in search_response.results[:max_results]:
@@ -548,24 +541,25 @@ class DeepSearchAgent:
})
if search_results:
print(f" - 找到 {len(search_results)} 个搜索结果")
_message = f" - 找到 {len(search_results)} 个搜索结果"
for j, result in enumerate(search_results, 1):
date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else ""
print(f" {j}. {result['title'][:50]}...{date_info}")
_message += f"\n {j}. {result['title'][:50]}...{date_info}"
logger.info(_message)
else:
print(" - 未找到搜索结果")
logger.info(" - 未找到搜索结果")
# 更新状态中的搜索历史
paragraph.research.add_search_results(search_query, search_results)
# 生成初始总结
print(" - 生成初始总结...")
logger.info(" - 生成初始总结...")
summary_input = {
"title": paragraph.title,
"content": paragraph.content,
"search_query": search_query,
"search_results": format_search_results_for_prompt(
search_results, self.config.max_content_length
search_results, self.config.MAX_CONTENT_LENGTH
)
}
@@ -574,14 +568,14 @@ class DeepSearchAgent:
summary_input, self.state, paragraph_index
)
print(" - 初始总结完成")
logger.info(" - 初始总结完成")
def _reflection_loop(self, paragraph_index: int):
"""执行反思循环"""
paragraph = self.state.paragraphs[paragraph_index]
for reflection_i in range(self.config.max_reflections):
print(f" - 反思 {reflection_i + 1}/{self.config.max_reflections}...")
for reflection_i in range(self.config.MAX_REFLECTIONS):
logger.info(f" - 反思 {reflection_i + 1}/{self.config.MAX_REFLECTIONS}...")
# 准备反思输入
reflection_input = {
@@ -596,9 +590,9 @@ class DeepSearchAgent:
search_tool = reflection_output.get("search_tool", "search_topic_globally") # 默认工具
reasoning = reflection_output["reasoning"]
print(f" 反思查询: {search_query}")
print(f" 选择的工具: {search_tool}")
print(f" 反思推理: {reasoning}")
logger.info(f" 反思查询: {search_query}")
logger.info(f" 选择的工具: {search_tool}")
logger.info(f" 反思推理: {reasoning}")
# 执行反思搜索
# 处理特殊参数
@@ -614,13 +608,13 @@ class DeepSearchAgent:
if self._validate_date_format(start_date) and self._validate_date_format(end_date):
search_kwargs["start_date"] = start_date
search_kwargs["end_date"] = end_date
print(f" 时间范围: {start_date}{end_date}")
logger.info(f" 时间范围: {start_date}{end_date}")
else:
print(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索")
print(f" 提供的日期: start_date={start_date}, end_date={end_date}")
logger.info(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索")
logger.info(f" 提供的日期: start_date={start_date}, end_date={end_date}")
search_tool = "search_topic_globally"
elif search_tool == "search_topic_by_date":
print(f" search_topic_by_date工具缺少时间参数,改用全局搜索")
logger.warning(f" search_topic_by_date工具缺少时间参数,改用全局搜索")
search_tool = "search_topic_globally"
# 处理需要平台参数的工具
@@ -628,31 +622,31 @@ class DeepSearchAgent:
platform = reflection_output.get("platform")
if platform:
search_kwargs["platform"] = platform
print(f" 指定平台: {platform}")
logger.info(f" 指定平台: {platform}")
else:
print(f" search_topic_on_platform工具缺少平台参数,改用全局搜索")
logger.warning(f" search_topic_on_platform工具缺少平台参数,改用全局搜索")
search_tool = "search_topic_globally"
# 处理限制参数
if search_tool == "search_hot_content":
time_period = reflection_output.get("time_period", "week")
# 使用配置文件中的默认值,不允许agent控制limit参数
limit = self.config.default_search_hot_content_limit
limit = self.config.DEFAULT_SEARCH_HOT_CONTENT_LIMIT
search_kwargs["time_period"] = time_period
search_kwargs["limit"] = limit
elif search_tool in ["search_topic_globally", "search_topic_by_date"]:
# 使用配置文件中的默认值,不允许agent控制limit_per_table参数
if search_tool == "search_topic_globally":
limit_per_table = self.config.default_search_topic_globally_limit_per_table
limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE
else: # search_topic_by_date
limit_per_table = self.config.default_search_topic_by_date_limit_per_table
limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE
search_kwargs["limit_per_table"] = limit_per_table
elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]:
# 使用配置文件中的默认值,不允许agent控制limit参数
if search_tool == "get_comments_for_topic":
limit = self.config.default_get_comments_for_topic_limit
limit = self.config.DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT
else: # search_topic_on_platform
limit = self.config.default_search_topic_on_platform_limit
limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT
search_kwargs["limit"] = limit
search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs)
@@ -661,8 +655,8 @@ class DeepSearchAgent:
search_results = []
if search_response and search_response.results:
# 使用配置文件控制传递给LLM的结果数量,0表示不限制
if self.config.max_search_results_for_llm > 0:
max_results = min(len(search_response.results), self.config.max_search_results_for_llm)
if self.config.MAX_SEARCH_RESULTS_FOR_LLM > 0:
max_results = min(len(search_response.results), self.config.MAX_SEARCH_RESULTS_FOR_LLM)
else:
max_results = len(search_response.results) # 不限制,传递所有结果
for result in search_response.results[:max_results]:
@@ -680,12 +674,13 @@ class DeepSearchAgent:
})
if search_results:
print(f" 找到 {len(search_results)} 个反思搜索结果")
_message = f" 找到 {len(search_results)} 个反思搜索结果"
for j, result in enumerate(search_results, 1):
date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else ""
print(f" {j}. {result['title'][:50]}...{date_info}")
_message += f"\n {j}. {result['title'][:50]}...{date_info}"
logger.info(_message)
else:
print(" 未找到反思搜索结果")
logger.info(" 未找到反思搜索结果")
# 更新搜索历史
paragraph.research.add_search_results(search_query, search_results)
@@ -696,7 +691,7 @@ class DeepSearchAgent:
"content": paragraph.content,
"search_query": search_query,
"search_results": format_search_results_for_prompt(
search_results, self.config.max_content_length
search_results, self.config.MAX_CONTENT_LENGTH
),
"paragraph_latest_state": paragraph.research.latest_summary
}
@@ -706,11 +701,11 @@ class DeepSearchAgent:
reflection_summary_input, self.state, paragraph_index
)
print(f" 反思 {reflection_i + 1} 完成")
logger.info(f" 反思 {reflection_i + 1} 完成")
def _generate_final_report(self) -> str:
"""生成最终报告"""
print(f"\n[步骤 3] 生成最终报告...")
logger.info(f"\n[步骤 3] 生成最终报告...")
# 准备报告数据
report_data = []
@@ -724,7 +719,7 @@ class DeepSearchAgent:
try:
final_report = self.report_formatting_node.run(report_data)
except Exception as e:
print(f"LLM格式化失败,使用备用方法: {str(e)}")
logger.exception(f"LLM格式化失败,使用备用方法: {str(e)}")
final_report = self.report_formatting_node.format_report_manually(
report_data, self.state.report_title
)
@@ -733,7 +728,7 @@ class DeepSearchAgent:
self.state.final_report = final_report
self.state.mark_completed()
print("最终报告生成完成")
logger.info("最终报告生成完成")
return final_report
def _save_report(self, report_content: str):
@@ -744,20 +739,20 @@ class DeepSearchAgent:
query_safe = query_safe.replace(' ', '_')[:30]
filename = f"deep_search_report_{query_safe}_{timestamp}.md"
filepath = os.path.join(self.config.output_dir, filename)
filepath = os.path.join(self.config.OUTPUT_DIR, filename)
# 保存报告
with open(filepath, 'w', encoding='utf-8') as f:
f.write(report_content)
print(f"报告已保存到: {filepath}")
logger.info(f"报告已保存到: {filepath}")
# 保存状态(如果配置允许)
if self.config.save_intermediate_states:
if self.config.SAVE_INTERMEDIATE_STATES:
state_filename = f"state_{query_safe}_{timestamp}.json"
state_filepath = os.path.join(self.config.output_dir, state_filename)
state_filepath = os.path.join(self.config.OUTPUT_DIR, state_filename)
self.state.save_to_file(state_filepath)
print(f"状态已保存到: {state_filepath}")
logger.info(f"状态已保存到: {state_filepath}")
def get_progress_summary(self) -> Dict[str, Any]:
"""获取进度摘要"""
@@ -766,12 +761,12 @@ class DeepSearchAgent:
def load_state(self, filepath: str):
"""从文件加载状态"""
self.state = State.load_from_file(filepath)
print(f"状态已从 {filepath} 加载")
logger.info(f"状态已从 {filepath} 加载")
def save_state(self, filepath: str):
"""保存状态到文件"""
self.state.save_to_file(filepath)
print(f"状态已保存到 {filepath}")
logger.info(f"状态已保存到 {filepath}")
def create_agent(config_file: Optional[str] = None) -> DeepSearchAgent:
@@ -784,5 +779,5 @@ def create_agent(config_file: Optional[str] = None) -> DeepSearchAgent:
Returns:
DeepSearchAgent实例
"""
config = load_config(config_file)
config = settings
return DeepSearchAgent(config)