1. 统一为使用基于pydantic的.env环境变量管理配置
2. 全项目基于loguru进行日志管理
This commit is contained in:
+65
-63
@@ -8,7 +8,7 @@ import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
from loguru import logger
|
||||
from .llms import LLMClient
|
||||
from .nodes import (
|
||||
ReportStructureNode,
|
||||
@@ -20,29 +20,26 @@ from .nodes import (
|
||||
)
|
||||
from .state import State
|
||||
from .tools import BochaMultimodalSearch, BochaResponse
|
||||
from .utils import Config, load_config, format_search_results_for_prompt
|
||||
from .utils import settings, Settings, format_search_results_for_prompt
|
||||
|
||||
|
||||
class DeepSearchAgent:
|
||||
"""Deep Search Agent主类"""
|
||||
|
||||
def __init__(self, config: Optional[Config] = None):
|
||||
def __init__(self, config: Optional[Settings] = None):
|
||||
"""
|
||||
初始化Deep Search Agent
|
||||
|
||||
Args:
|
||||
config: 配置对象,如果不提供则自动加载
|
||||
"""
|
||||
# 加载配置
|
||||
self.config = config or load_config()
|
||||
os.environ["BOCHA_API_KEY"] = self.config.bocha_api_key or ""
|
||||
os.environ["BOCHA_WEB_SEARCH_API_KEY"] = self.config.bocha_api_key or ""
|
||||
self.config = config or settings
|
||||
|
||||
# 初始化LLM客户端
|
||||
self.llm_client = self._initialize_llm()
|
||||
|
||||
# 初始化搜索工具集
|
||||
self.search_agency = BochaMultimodalSearch(api_key=self.config.bocha_api_key)
|
||||
self.search_agency = BochaMultimodalSearch(api_key=(self.config.BOCHA_API_KEY or self.config.BOCHA_WEB_SEARCH_API_KEY))
|
||||
|
||||
# 初始化节点
|
||||
self._initialize_nodes()
|
||||
@@ -51,18 +48,18 @@ class DeepSearchAgent:
|
||||
self.state = State()
|
||||
|
||||
# 确保输出目录存在
|
||||
os.makedirs(self.config.output_dir, exist_ok=True)
|
||||
os.makedirs(self.config.OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
print(f"Meida Agent已初始化")
|
||||
print(f"使用LLM: {self.llm_client.get_model_info()}")
|
||||
print(f"搜索工具集: BochaMultimodalSearch (支持5种多模态搜索工具)")
|
||||
logger.info(f"Meida Agent已初始化")
|
||||
logger.info(f"使用LLM: {self.llm_client.get_model_info()}")
|
||||
logger.info(f"搜索工具集: BochaMultimodalSearch (支持5种多模态搜索工具)")
|
||||
|
||||
def _initialize_llm(self) -> LLMClient:
|
||||
"""初始化LLM客户端"""
|
||||
return LLMClient(
|
||||
api_key=self.config.llm_api_key,
|
||||
model_name=self.config.llm_model_name,
|
||||
base_url=self.config.llm_base_url,
|
||||
api_key=(self.config.MEDIA_ENGINE_API_KEY or self.config.MINDSPIDER_API_KEY),
|
||||
model_name=(self.config.MEDIA_ENGINE_MODEL_NAME or self.config.MINDSPIDER_MODEL_NAME),
|
||||
base_url=(self.config.MEDIA_ENGINE_BASE_URL or self.config.MINDSPIDER_BASE_URL),
|
||||
)
|
||||
|
||||
def _initialize_nodes(self):
|
||||
@@ -115,7 +112,7 @@ class DeepSearchAgent:
|
||||
Returns:
|
||||
BochaResponse对象
|
||||
"""
|
||||
print(f" → 执行搜索工具: {tool_name}")
|
||||
logger.info(f" → 执行搜索工具: {tool_name}")
|
||||
|
||||
if tool_name == "comprehensive_search":
|
||||
max_results = kwargs.get("max_results", 10)
|
||||
@@ -130,7 +127,7 @@ class DeepSearchAgent:
|
||||
elif tool_name == "search_last_week":
|
||||
return self.search_agency.search_last_week(query)
|
||||
else:
|
||||
print(f" ⚠️ 未知的搜索工具: {tool_name},使用默认综合搜索")
|
||||
logger.info(f" ⚠️ 未知的搜索工具: {tool_name},使用默认综合搜索")
|
||||
return self.search_agency.comprehensive_search(query)
|
||||
|
||||
def research(self, query: str, save_report: bool = True) -> str:
|
||||
@@ -144,9 +141,9 @@ class DeepSearchAgent:
|
||||
Returns:
|
||||
最终报告内容
|
||||
"""
|
||||
print(f"\n{'='*60}")
|
||||
print(f"开始深度研究: {query}")
|
||||
print(f"{'='*60}")
|
||||
logger.info(f"\n{'='*60}")
|
||||
logger.info(f"开始深度研究: {query}")
|
||||
logger.info(f"{'='*60}")
|
||||
|
||||
try:
|
||||
# Step 1: 生成报告结构
|
||||
@@ -162,19 +159,21 @@ class DeepSearchAgent:
|
||||
if save_report:
|
||||
self._save_report(final_report)
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print("深度研究完成!")
|
||||
print(f"{'='*60}")
|
||||
logger.info(f"\n{'='*60}")
|
||||
logger.info("深度研究完成!")
|
||||
logger.info(f"{'='*60}")
|
||||
|
||||
return final_report
|
||||
|
||||
except Exception as e:
|
||||
print(f"研究过程中发生错误: {str(e)}")
|
||||
import traceback
|
||||
error_traceback = traceback.format_exc()
|
||||
logger.error(f"研究过程中发生错误: {str(e)} \n错误堆栈: {error_traceback}")
|
||||
raise e
|
||||
|
||||
def _generate_report_structure(self, query: str):
|
||||
"""生成报告结构"""
|
||||
print(f"\n[步骤 1] 生成报告结构...")
|
||||
logger.info(f"\n[步骤 1] 生成报告结构...")
|
||||
|
||||
# 创建报告结构节点
|
||||
report_structure_node = ReportStructureNode(self.llm_client, query)
|
||||
@@ -182,17 +181,18 @@ class DeepSearchAgent:
|
||||
# 生成结构并更新状态
|
||||
self.state = report_structure_node.mutate_state(state=self.state)
|
||||
|
||||
print(f"报告结构已生成,共 {len(self.state.paragraphs)} 个段落:")
|
||||
_message = f"报告结构已生成,共 {len(self.state.paragraphs)} 个段落:"
|
||||
for i, paragraph in enumerate(self.state.paragraphs, 1):
|
||||
print(f" {i}. {paragraph.title}")
|
||||
_message += f"\n {i}. {paragraph.title}"
|
||||
logger.info(_message)
|
||||
|
||||
def _process_paragraphs(self):
|
||||
"""处理所有段落"""
|
||||
total_paragraphs = len(self.state.paragraphs)
|
||||
|
||||
for i in range(total_paragraphs):
|
||||
print(f"\n[步骤 2.{i+1}] 处理段落: {self.state.paragraphs[i].title}")
|
||||
print("-" * 50)
|
||||
logger.info(f"\n[步骤 2.{i+1}] 处理段落: {self.state.paragraphs[i].title}")
|
||||
logger.info("-" * 50)
|
||||
|
||||
# 初始搜索和总结
|
||||
self._initial_search_and_summary(i)
|
||||
@@ -204,7 +204,7 @@ class DeepSearchAgent:
|
||||
self.state.paragraphs[i].research.mark_completed()
|
||||
|
||||
progress = (i + 1) / total_paragraphs * 100
|
||||
print(f"段落处理完成 ({progress:.1f}%)")
|
||||
logger.info(f"段落处理完成 ({progress:.1f}%)")
|
||||
|
||||
def _initial_search_and_summary(self, paragraph_index: int):
|
||||
"""执行初始搜索和总结"""
|
||||
@@ -217,18 +217,18 @@ class DeepSearchAgent:
|
||||
}
|
||||
|
||||
# 生成搜索查询和工具选择
|
||||
print(" - 生成搜索查询...")
|
||||
logger.info(" - 生成搜索查询...")
|
||||
search_output = self.first_search_node.run(search_input)
|
||||
search_query = search_output["search_query"]
|
||||
search_tool = search_output.get("search_tool", "comprehensive_search") # 默认工具
|
||||
reasoning = search_output["reasoning"]
|
||||
|
||||
print(f" - 搜索查询: {search_query}")
|
||||
print(f" - 选择的工具: {search_tool}")
|
||||
print(f" - 推理: {reasoning}")
|
||||
logger.info(f" - 搜索查询: {search_query}")
|
||||
logger.info(f" - 选择的工具: {search_tool}")
|
||||
logger.info(f" - 推理: {reasoning}")
|
||||
|
||||
# 执行搜索
|
||||
print(" - 执行网络搜索...")
|
||||
logger.info(" - 执行网络搜索...")
|
||||
|
||||
# 处理特殊参数(新的工具集不需要日期参数处理)
|
||||
search_kwargs = {}
|
||||
@@ -254,24 +254,25 @@ class DeepSearchAgent:
|
||||
})
|
||||
|
||||
if search_results:
|
||||
print(f" - 找到 {len(search_results)} 个搜索结果")
|
||||
_message = f" - 找到 {len(search_results)} 个搜索结果"
|
||||
for j, result in enumerate(search_results, 1):
|
||||
date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else ""
|
||||
print(f" {j}. {result['title'][:50]}...{date_info}")
|
||||
_message += f"\n {j}. {result['title'][:50]}...{date_info}"
|
||||
logger.info(_message)
|
||||
else:
|
||||
print(" - 未找到搜索结果")
|
||||
logger.info(" - 未找到搜索结果")
|
||||
|
||||
# 更新状态中的搜索历史
|
||||
paragraph.research.add_search_results(search_query, search_results)
|
||||
|
||||
# 生成初始总结
|
||||
print(" - 生成初始总结...")
|
||||
logger.info(" - 生成初始总结...")
|
||||
summary_input = {
|
||||
"title": paragraph.title,
|
||||
"content": paragraph.content,
|
||||
"search_query": search_query,
|
||||
"search_results": format_search_results_for_prompt(
|
||||
search_results, self.config.max_content_length
|
||||
search_results, self.config.SEARCH_CONTENT_MAX_LENGTH
|
||||
)
|
||||
}
|
||||
|
||||
@@ -280,14 +281,14 @@ class DeepSearchAgent:
|
||||
summary_input, self.state, paragraph_index
|
||||
)
|
||||
|
||||
print(" - 初始总结完成")
|
||||
logger.info(" - 初始总结完成")
|
||||
|
||||
def _reflection_loop(self, paragraph_index: int):
|
||||
"""执行反思循环"""
|
||||
paragraph = self.state.paragraphs[paragraph_index]
|
||||
|
||||
for reflection_i in range(self.config.max_reflections):
|
||||
print(f" - 反思 {reflection_i + 1}/{self.config.max_reflections}...")
|
||||
for reflection_i in range(self.config.MAX_REFLECTIONS):
|
||||
logger.info(f" - 反思 {reflection_i + 1}/{self.config.MAX_REFLECTIONS}...")
|
||||
|
||||
# 准备反思输入
|
||||
reflection_input = {
|
||||
@@ -302,9 +303,9 @@ class DeepSearchAgent:
|
||||
search_tool = reflection_output.get("search_tool", "comprehensive_search") # 默认工具
|
||||
reasoning = reflection_output["reasoning"]
|
||||
|
||||
print(f" 反思查询: {search_query}")
|
||||
print(f" 选择的工具: {search_tool}")
|
||||
print(f" 反思推理: {reasoning}")
|
||||
logger.info(f" 反思查询: {search_query}")
|
||||
logger.info(f" 选择的工具: {search_tool}")
|
||||
logger.info(f" 反思推理: {reasoning}")
|
||||
|
||||
# 执行反思搜索
|
||||
# 处理特殊参数
|
||||
@@ -331,12 +332,13 @@ class DeepSearchAgent:
|
||||
})
|
||||
|
||||
if search_results:
|
||||
print(f" 找到 {len(search_results)} 个反思搜索结果")
|
||||
_message = f" 找到 {len(search_results)} 个反思搜索结果"
|
||||
for j, result in enumerate(search_results, 1):
|
||||
date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else ""
|
||||
print(f" {j}. {result['title'][:50]}...{date_info}")
|
||||
_message += f"\n {j}. {result['title'][:50]}...{date_info}"
|
||||
logger.info(_message)
|
||||
else:
|
||||
print(" 未找到反思搜索结果")
|
||||
logger.info(" 未找到反思搜索结果")
|
||||
|
||||
# 更新搜索历史
|
||||
paragraph.research.add_search_results(search_query, search_results)
|
||||
@@ -347,7 +349,7 @@ class DeepSearchAgent:
|
||||
"content": paragraph.content,
|
||||
"search_query": search_query,
|
||||
"search_results": format_search_results_for_prompt(
|
||||
search_results, self.config.max_content_length
|
||||
search_results, self.config.SEARCH_CONTENT_MAX_LENGTH
|
||||
),
|
||||
"paragraph_latest_state": paragraph.research.latest_summary
|
||||
}
|
||||
@@ -357,11 +359,11 @@ class DeepSearchAgent:
|
||||
reflection_summary_input, self.state, paragraph_index
|
||||
)
|
||||
|
||||
print(f" 反思 {reflection_i + 1} 完成")
|
||||
logger.info(f" 反思 {reflection_i + 1} 完成")
|
||||
|
||||
def _generate_final_report(self) -> str:
|
||||
"""生成最终报告"""
|
||||
print(f"\n[步骤 3] 生成最终报告...")
|
||||
logger.info(f"\n[步骤 3] 生成最终报告...")
|
||||
|
||||
# 准备报告数据
|
||||
report_data = []
|
||||
@@ -375,7 +377,7 @@ class DeepSearchAgent:
|
||||
try:
|
||||
final_report = self.report_formatting_node.run(report_data)
|
||||
except Exception as e:
|
||||
print(f"LLM格式化失败,使用备用方法: {str(e)}")
|
||||
logger.info(f"LLM格式化失败,使用备用方法: {str(e)}")
|
||||
final_report = self.report_formatting_node.format_report_manually(
|
||||
report_data, self.state.report_title
|
||||
)
|
||||
@@ -384,7 +386,7 @@ class DeepSearchAgent:
|
||||
self.state.final_report = final_report
|
||||
self.state.mark_completed()
|
||||
|
||||
print("最终报告生成完成")
|
||||
logger.info("最终报告生成完成")
|
||||
return final_report
|
||||
|
||||
def _save_report(self, report_content: str):
|
||||
@@ -395,20 +397,20 @@ class DeepSearchAgent:
|
||||
query_safe = query_safe.replace(' ', '_')[:30]
|
||||
|
||||
filename = f"deep_search_report_{query_safe}_{timestamp}.md"
|
||||
filepath = os.path.join(self.config.output_dir, filename)
|
||||
filepath = os.path.join(self.config.OUTPUT_DIR, filename)
|
||||
|
||||
# 保存报告
|
||||
with open(filepath, 'w', encoding='utf-8') as f:
|
||||
f.write(report_content)
|
||||
|
||||
print(f"报告已保存到: {filepath}")
|
||||
logger.info(f"报告已保存到: {filepath}")
|
||||
|
||||
# 保存状态(如果配置允许)
|
||||
if self.config.save_intermediate_states:
|
||||
if self.config.SAVE_INTERMEDIATE_STATES:
|
||||
state_filename = f"state_{query_safe}_{timestamp}.json"
|
||||
state_filepath = os.path.join(self.config.output_dir, state_filename)
|
||||
state_filepath = os.path.join(self.config.OUTPUT_DIR, state_filename)
|
||||
self.state.save_to_file(state_filepath)
|
||||
print(f"状态已保存到: {state_filepath}")
|
||||
logger.info(f"状态已保存到: {state_filepath}")
|
||||
|
||||
def get_progress_summary(self) -> Dict[str, Any]:
|
||||
"""获取进度摘要"""
|
||||
@@ -417,12 +419,12 @@ class DeepSearchAgent:
|
||||
def load_state(self, filepath: str):
|
||||
"""从文件加载状态"""
|
||||
self.state = State.load_from_file(filepath)
|
||||
print(f"状态已从 {filepath} 加载")
|
||||
logger.info(f"状态已从 {filepath} 加载")
|
||||
|
||||
def save_state(self, filepath: str):
|
||||
"""保存状态到文件"""
|
||||
self.state.save_to_file(filepath)
|
||||
print(f"状态已保存到 {filepath}")
|
||||
logger.info(f"状态已保存到 {filepath}")
|
||||
|
||||
|
||||
def create_agent(config_file: Optional[str] = None) -> DeepSearchAgent:
|
||||
@@ -435,5 +437,5 @@ def create_agent(config_file: Optional[str] = None) -> DeepSearchAgent:
|
||||
Returns:
|
||||
DeepSearchAgent实例
|
||||
"""
|
||||
config = load_config(config_file)
|
||||
return DeepSearchAgent(config)
|
||||
settings = Settings()
|
||||
return DeepSearchAgent(settings)
|
||||
|
||||
Reference in New Issue
Block a user