1. 统一为使用基于pydantic的.env环境变量管理配置
2. 全项目基于loguru进行日志管理
This commit is contained in:
@@ -4,9 +4,9 @@ Deep Search Agent
|
||||
"""
|
||||
|
||||
from .agent import DeepSearchAgent, create_agent
|
||||
from .utils.config import Config, load_config
|
||||
from .utils.config import Settings
|
||||
|
||||
__version__ = "1.0.0"
|
||||
__author__ = "Deep Search Agent Team"
|
||||
|
||||
__all__ = ["DeepSearchAgent", "create_agent", "Config", "load_config"]
|
||||
__all__ = ["DeepSearchAgent", "create_agent", "Settings"]
|
||||
|
||||
+74
-73
@@ -20,13 +20,13 @@ from .nodes import (
|
||||
)
|
||||
from .state import State
|
||||
from .tools import TavilyNewsAgency, TavilyResponse
|
||||
from .utils import Config, load_config, format_search_results_for_prompt
|
||||
|
||||
from .utils import Settings, format_search_results_for_prompt
|
||||
from loguru import logger
|
||||
|
||||
class DeepSearchAgent:
|
||||
"""Deep Search Agent主类"""
|
||||
|
||||
def __init__(self, config: Optional[Config] = None):
|
||||
def __init__(self, config: Optional[Settings] = None):
|
||||
"""
|
||||
初始化Deep Search Agent
|
||||
|
||||
@@ -34,14 +34,14 @@ class DeepSearchAgent:
|
||||
config: 配置对象,如果不提供则自动加载
|
||||
"""
|
||||
# 加载配置
|
||||
self.config = config or load_config()
|
||||
os.environ["TAVILY_API_KEY"] = self.config.tavily_api_key or ""
|
||||
from .utils.config import settings
|
||||
self.config = config or settings
|
||||
|
||||
# 初始化LLM客户端
|
||||
self.llm_client = self._initialize_llm()
|
||||
|
||||
# 初始化搜索工具集
|
||||
self.search_agency = TavilyNewsAgency(api_key=self.config.tavily_api_key)
|
||||
self.search_agency = TavilyNewsAgency(api_key=self.config.TAVILY_API_KEY)
|
||||
|
||||
# 初始化节点
|
||||
self._initialize_nodes()
|
||||
@@ -50,18 +50,18 @@ class DeepSearchAgent:
|
||||
self.state = State()
|
||||
|
||||
# 确保输出目录存在
|
||||
os.makedirs(self.config.output_dir, exist_ok=True)
|
||||
os.makedirs(self.config.OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
print(f"Query Agent已初始化")
|
||||
print(f"使用LLM: {self.llm_client.get_model_info()}")
|
||||
print(f"搜索工具集: TavilyNewsAgency (支持6种搜索工具)")
|
||||
logger.info(f"Query Agent已初始化")
|
||||
logger.info(f"使用LLM: {self.llm_client.get_model_info()}")
|
||||
logger.info(f"搜索工具集: TavilyNewsAgency (支持6种搜索工具)")
|
||||
|
||||
def _initialize_llm(self) -> LLMClient:
|
||||
"""初始化LLM客户端"""
|
||||
return LLMClient(
|
||||
api_key=self.config.llm_api_key,
|
||||
model_name=self.config.llm_model_name,
|
||||
base_url=self.config.llm_base_url,
|
||||
api_key=self.config.QUERY_ENGINE_API_KEY,
|
||||
model_name=self.config.QUERY_ENGINE_MODEL_NAME,
|
||||
base_url=self.config.QUERY_ENGINE_BASE_URL,
|
||||
)
|
||||
|
||||
def _initialize_nodes(self):
|
||||
@@ -115,7 +115,7 @@ class DeepSearchAgent:
|
||||
Returns:
|
||||
TavilyResponse对象
|
||||
"""
|
||||
print(f" → 执行搜索工具: {tool_name}")
|
||||
logger.info(f" → 执行搜索工具: {tool_name}")
|
||||
|
||||
if tool_name == "basic_search_news":
|
||||
max_results = kwargs.get("max_results", 7)
|
||||
@@ -135,7 +135,7 @@ class DeepSearchAgent:
|
||||
raise ValueError("search_news_by_date工具需要start_date和end_date参数")
|
||||
return self.search_agency.search_news_by_date(query, start_date, end_date)
|
||||
else:
|
||||
print(f" ⚠️ 未知的搜索工具: {tool_name},使用默认基础搜索")
|
||||
logger.warning(f" ⚠️ 未知的搜索工具: {tool_name},使用默认基础搜索")
|
||||
return self.search_agency.basic_search_news(query)
|
||||
|
||||
def research(self, query: str, save_report: bool = True) -> str:
|
||||
@@ -149,9 +149,9 @@ class DeepSearchAgent:
|
||||
Returns:
|
||||
最终报告内容
|
||||
"""
|
||||
print(f"\n{'='*60}")
|
||||
print(f"开始深度研究: {query}")
|
||||
print(f"{'='*60}")
|
||||
logger.info(f"\n{'='*60}")
|
||||
logger.info(f"开始深度研究: {query}")
|
||||
logger.info(f"{'='*60}")
|
||||
|
||||
try:
|
||||
# Step 1: 生成报告结构
|
||||
@@ -167,19 +167,21 @@ class DeepSearchAgent:
|
||||
if save_report:
|
||||
self._save_report(final_report)
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print("深度研究完成!")
|
||||
print(f"{'='*60}")
|
||||
logger.info(f"\n{'='*60}")
|
||||
logger.info("深度研究完成!")
|
||||
logger.info(f"{'='*60}")
|
||||
|
||||
return final_report
|
||||
|
||||
except Exception as e:
|
||||
print(f"研究过程中发生错误: {str(e)}")
|
||||
import traceback
|
||||
error_traceback = traceback.format_exc()
|
||||
logger.error(f"研究过程中发生错误: {str(e)} \n错误堆栈: {error_traceback}")
|
||||
raise e
|
||||
|
||||
def _generate_report_structure(self, query: str):
|
||||
"""生成报告结构"""
|
||||
print(f"\n[步骤 1] 生成报告结构...")
|
||||
logger.info(f"\n[步骤 1] 生成报告结构...")
|
||||
|
||||
# 创建报告结构节点
|
||||
report_structure_node = ReportStructureNode(self.llm_client, query)
|
||||
@@ -187,17 +189,18 @@ class DeepSearchAgent:
|
||||
# 生成结构并更新状态
|
||||
self.state = report_structure_node.mutate_state(state=self.state)
|
||||
|
||||
print(f"报告结构已生成,共 {len(self.state.paragraphs)} 个段落:")
|
||||
_message = f"报告结构已生成,共 {len(self.state.paragraphs)} 个段落:"
|
||||
for i, paragraph in enumerate(self.state.paragraphs, 1):
|
||||
print(f" {i}. {paragraph.title}")
|
||||
_message += f"\n {i}. {paragraph.title}"
|
||||
logger.info(_message)
|
||||
|
||||
def _process_paragraphs(self):
|
||||
"""处理所有段落"""
|
||||
total_paragraphs = len(self.state.paragraphs)
|
||||
|
||||
for i in range(total_paragraphs):
|
||||
print(f"\n[步骤 2.{i+1}] 处理段落: {self.state.paragraphs[i].title}")
|
||||
print("-" * 50)
|
||||
logger.info(f"\n[步骤 2.{i+1}] 处理段落: {self.state.paragraphs[i].title}")
|
||||
logger.info("-" * 50)
|
||||
|
||||
# 初始搜索和总结
|
||||
self._initial_search_and_summary(i)
|
||||
@@ -209,7 +212,7 @@ class DeepSearchAgent:
|
||||
self.state.paragraphs[i].research.mark_completed()
|
||||
|
||||
progress = (i + 1) / total_paragraphs * 100
|
||||
print(f"段落处理完成 ({progress:.1f}%)")
|
||||
logger.info(f"段落处理完成 ({progress:.1f}%)")
|
||||
|
||||
def _initial_search_and_summary(self, paragraph_index: int):
|
||||
"""执行初始搜索和总结"""
|
||||
@@ -222,18 +225,18 @@ class DeepSearchAgent:
|
||||
}
|
||||
|
||||
# 生成搜索查询和工具选择
|
||||
print(" - 生成搜索查询...")
|
||||
logger.info(" - 生成搜索查询...")
|
||||
search_output = self.first_search_node.run(search_input)
|
||||
search_query = search_output["search_query"]
|
||||
search_tool = search_output.get("search_tool", "basic_search_news") # 默认工具
|
||||
reasoning = search_output["reasoning"]
|
||||
|
||||
print(f" - 搜索查询: {search_query}")
|
||||
print(f" - 选择的工具: {search_tool}")
|
||||
print(f" - 推理: {reasoning}")
|
||||
logger.info(f" - 搜索查询: {search_query}")
|
||||
logger.info(f" - 选择的工具: {search_tool}")
|
||||
logger.info(f" - 推理: {reasoning}")
|
||||
|
||||
# 执行搜索
|
||||
print(" - 执行网络搜索...")
|
||||
logger.info(" - 执行网络搜索...")
|
||||
|
||||
# 处理search_news_by_date的特殊参数
|
||||
search_kwargs = {}
|
||||
@@ -246,13 +249,13 @@ class DeepSearchAgent:
|
||||
if self._validate_date_format(start_date) and self._validate_date_format(end_date):
|
||||
search_kwargs["start_date"] = start_date
|
||||
search_kwargs["end_date"] = end_date
|
||||
print(f" - 时间范围: {start_date} 到 {end_date}")
|
||||
logger.info(f" - 时间范围: {start_date} 到 {end_date}")
|
||||
else:
|
||||
print(f" ⚠️ 日期格式错误(应为YYYY-MM-DD),改用基础搜索")
|
||||
print(f" 提供的日期: start_date={start_date}, end_date={end_date}")
|
||||
logger.info(f" ⚠️ 日期格式错误(应为YYYY-MM-DD),改用基础搜索")
|
||||
logger.info(f" 提供的日期: start_date={start_date}, end_date={end_date}")
|
||||
search_tool = "basic_search_news"
|
||||
else:
|
||||
print(f" ⚠️ search_news_by_date工具缺少时间参数,改用基础搜索")
|
||||
logger.info(f" ⚠️ search_news_by_date工具缺少时间参数,改用基础搜索")
|
||||
search_tool = "basic_search_news"
|
||||
|
||||
search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs)
|
||||
@@ -273,24 +276,24 @@ class DeepSearchAgent:
|
||||
})
|
||||
|
||||
if search_results:
|
||||
print(f" - 找到 {len(search_results)} 个搜索结果")
|
||||
_message = f" - 找到 {len(search_results)} 个搜索结果"
|
||||
for j, result in enumerate(search_results, 1):
|
||||
date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else ""
|
||||
print(f" {j}. {result['title'][:50]}...{date_info}")
|
||||
_message += f"\n {j}. {result['title'][:50]}...{date_info}"
|
||||
logger.info(_message)
|
||||
else:
|
||||
print(" - 未找到搜索结果")
|
||||
|
||||
logger.info(" - 未找到搜索结果")
|
||||
# 更新状态中的搜索历史
|
||||
paragraph.research.add_search_results(search_query, search_results)
|
||||
|
||||
# 生成初始总结
|
||||
print(" - 生成初始总结...")
|
||||
logger.info(" - 生成初始总结...")
|
||||
summary_input = {
|
||||
"title": paragraph.title,
|
||||
"content": paragraph.content,
|
||||
"search_query": search_query,
|
||||
"search_results": format_search_results_for_prompt(
|
||||
search_results, self.config.max_content_length
|
||||
search_results, self.config.SEARCH_CONTENT_MAX_LENGTH
|
||||
)
|
||||
}
|
||||
|
||||
@@ -299,14 +302,14 @@ class DeepSearchAgent:
|
||||
summary_input, self.state, paragraph_index
|
||||
)
|
||||
|
||||
print(" - 初始总结完成")
|
||||
logger.info(" - 初始总结完成")
|
||||
|
||||
def _reflection_loop(self, paragraph_index: int):
|
||||
"""执行反思循环"""
|
||||
paragraph = self.state.paragraphs[paragraph_index]
|
||||
|
||||
for reflection_i in range(self.config.max_reflections):
|
||||
print(f" - 反思 {reflection_i + 1}/{self.config.max_reflections}...")
|
||||
for reflection_i in range(self.config.MAX_REFLECTIONS):
|
||||
logger.info(f" - 反思 {reflection_i + 1}/{self.config.MAX_REFLECTIONS}...")
|
||||
|
||||
# 准备反思输入
|
||||
reflection_input = {
|
||||
@@ -321,9 +324,9 @@ class DeepSearchAgent:
|
||||
search_tool = reflection_output.get("search_tool", "basic_search_news") # 默认工具
|
||||
reasoning = reflection_output["reasoning"]
|
||||
|
||||
print(f" 反思查询: {search_query}")
|
||||
print(f" 选择的工具: {search_tool}")
|
||||
print(f" 反思推理: {reasoning}")
|
||||
logger.info(f" 反思查询: {search_query}")
|
||||
logger.info(f" 选择的工具: {search_tool}")
|
||||
logger.info(f" 反思推理: {reasoning}")
|
||||
|
||||
# 执行反思搜索
|
||||
# 处理search_news_by_date的特殊参数
|
||||
@@ -337,13 +340,13 @@ class DeepSearchAgent:
|
||||
if self._validate_date_format(start_date) and self._validate_date_format(end_date):
|
||||
search_kwargs["start_date"] = start_date
|
||||
search_kwargs["end_date"] = end_date
|
||||
print(f" 时间范围: {start_date} 到 {end_date}")
|
||||
logger.info(f" 时间范围: {start_date} 到 {end_date}")
|
||||
else:
|
||||
print(f" ⚠️ 日期格式错误(应为YYYY-MM-DD),改用基础搜索")
|
||||
print(f" 提供的日期: start_date={start_date}, end_date={end_date}")
|
||||
logger.info(f" ⚠️ 日期格式错误(应为YYYY-MM-DD),改用基础搜索")
|
||||
logger.info(f" 提供的日期: start_date={start_date}, end_date={end_date}")
|
||||
search_tool = "basic_search_news"
|
||||
else:
|
||||
print(f" ⚠️ search_news_by_date工具缺少时间参数,改用基础搜索")
|
||||
logger.info(f" ⚠️ search_news_by_date工具缺少时间参数,改用基础搜索")
|
||||
search_tool = "basic_search_news"
|
||||
|
||||
search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs)
|
||||
@@ -364,12 +367,12 @@ class DeepSearchAgent:
|
||||
})
|
||||
|
||||
if search_results:
|
||||
print(f" 找到 {len(search_results)} 个反思搜索结果")
|
||||
logger.info(f" 找到 {len(search_results)} 个反思搜索结果")
|
||||
for j, result in enumerate(search_results, 1):
|
||||
date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else ""
|
||||
print(f" {j}. {result['title'][:50]}...{date_info}")
|
||||
logger.info(f" {j}. {result['title'][:50]}...{date_info}")
|
||||
else:
|
||||
print(" 未找到反思搜索结果")
|
||||
logger.info(" 未找到反思搜索结果")
|
||||
|
||||
# 更新搜索历史
|
||||
paragraph.research.add_search_results(search_query, search_results)
|
||||
@@ -380,7 +383,7 @@ class DeepSearchAgent:
|
||||
"content": paragraph.content,
|
||||
"search_query": search_query,
|
||||
"search_results": format_search_results_for_prompt(
|
||||
search_results, self.config.max_content_length
|
||||
search_results, self.config.SEARCH_CONTENT_MAX_LENGTH
|
||||
),
|
||||
"paragraph_latest_state": paragraph.research.latest_summary
|
||||
}
|
||||
@@ -390,11 +393,11 @@ class DeepSearchAgent:
|
||||
reflection_summary_input, self.state, paragraph_index
|
||||
)
|
||||
|
||||
print(f" 反思 {reflection_i + 1} 完成")
|
||||
logger.info(f" 反思 {reflection_i + 1} 完成")
|
||||
|
||||
def _generate_final_report(self) -> str:
|
||||
"""生成最终报告"""
|
||||
print(f"\n[步骤 3] 生成最终报告...")
|
||||
logger.info(f"\n[步骤 3] 生成最终报告...")
|
||||
|
||||
# 准备报告数据
|
||||
report_data = []
|
||||
@@ -408,7 +411,7 @@ class DeepSearchAgent:
|
||||
try:
|
||||
final_report = self.report_formatting_node.run(report_data)
|
||||
except Exception as e:
|
||||
print(f"LLM格式化失败,使用备用方法: {str(e)}")
|
||||
logger.error(f"LLM格式化失败,使用备用方法: {str(e)}")
|
||||
final_report = self.report_formatting_node.format_report_manually(
|
||||
report_data, self.state.report_title
|
||||
)
|
||||
@@ -417,7 +420,7 @@ class DeepSearchAgent:
|
||||
self.state.final_report = final_report
|
||||
self.state.mark_completed()
|
||||
|
||||
print("最终报告生成完成")
|
||||
logger.info("最终报告生成完成")
|
||||
return final_report
|
||||
|
||||
def _save_report(self, report_content: str):
|
||||
@@ -428,20 +431,20 @@ class DeepSearchAgent:
|
||||
query_safe = query_safe.replace(' ', '_')[:30]
|
||||
|
||||
filename = f"deep_search_report_{query_safe}_{timestamp}.md"
|
||||
filepath = os.path.join(self.config.output_dir, filename)
|
||||
filepath = os.path.join(self.config.OUTPUT_DIR, filename)
|
||||
|
||||
# 保存报告
|
||||
with open(filepath, 'w', encoding='utf-8') as f:
|
||||
f.write(report_content)
|
||||
|
||||
print(f"报告已保存到: {filepath}")
|
||||
logger.info(f"报告已保存到: {filepath}")
|
||||
|
||||
# 保存状态(如果配置允许)
|
||||
if self.config.save_intermediate_states:
|
||||
if self.config.SAVE_INTERMEDIATE_STATES:
|
||||
state_filename = f"state_{query_safe}_{timestamp}.json"
|
||||
state_filepath = os.path.join(self.config.output_dir, state_filename)
|
||||
state_filepath = os.path.join(self.config.OUTPUT_DIR, state_filename)
|
||||
self.state.save_to_file(state_filepath)
|
||||
print(f"状态已保存到: {state_filepath}")
|
||||
logger.info(f"状态已保存到: {state_filepath}")
|
||||
|
||||
def get_progress_summary(self) -> Dict[str, Any]:
|
||||
"""获取进度摘要"""
|
||||
@@ -450,23 +453,21 @@ class DeepSearchAgent:
|
||||
def load_state(self, filepath: str):
|
||||
"""从文件加载状态"""
|
||||
self.state = State.load_from_file(filepath)
|
||||
print(f"状态已从 {filepath} 加载")
|
||||
logger.info(f"状态已从 {filepath} 加载")
|
||||
|
||||
def save_state(self, filepath: str):
|
||||
"""保存状态到文件"""
|
||||
self.state.save_to_file(filepath)
|
||||
print(f"状态已保存到 {filepath}")
|
||||
logger.info(f"状态已保存到 {filepath}")
|
||||
|
||||
|
||||
def create_agent(config_file: Optional[str] = None) -> DeepSearchAgent:
|
||||
def create_agent() -> DeepSearchAgent:
|
||||
"""
|
||||
创建Deep Search Agent实例的便捷函数
|
||||
|
||||
Args:
|
||||
config_file: 配置文件路径
|
||||
|
||||
Returns:
|
||||
DeepSearchAgent实例
|
||||
"""
|
||||
config = load_config(config_file)
|
||||
from .utils.config import Settings
|
||||
config = Settings()
|
||||
return DeepSearchAgent(config)
|
||||
|
||||
@@ -5,69 +5,74 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional
|
||||
from loguru import logger
|
||||
from ..llms.base import LLMClient
|
||||
from ..state.state import State
|
||||
|
||||
|
||||
class BaseNode(ABC):
|
||||
"""节点基类"""
|
||||
|
||||
|
||||
def __init__(self, llm_client: LLMClient, node_name: str = ""):
|
||||
"""
|
||||
初始化节点
|
||||
|
||||
|
||||
Args:
|
||||
llm_client: LLM客户端
|
||||
node_name: 节点名称
|
||||
"""
|
||||
self.llm_client = llm_client
|
||||
self.node_name = node_name or self.__class__.__name__
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def run(self, input_data: Any, **kwargs) -> Any:
|
||||
"""
|
||||
执行节点处理逻辑
|
||||
|
||||
|
||||
Args:
|
||||
input_data: 输入数据
|
||||
**kwargs: 额外参数
|
||||
|
||||
|
||||
Returns:
|
||||
处理结果
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def validate_input(self, input_data: Any) -> bool:
|
||||
"""
|
||||
验证输入数据
|
||||
|
||||
|
||||
Args:
|
||||
input_data: 输入数据
|
||||
|
||||
|
||||
Returns:
|
||||
验证是否通过
|
||||
"""
|
||||
return True
|
||||
|
||||
|
||||
def process_output(self, output: Any) -> Any:
|
||||
"""
|
||||
处理输出数据
|
||||
|
||||
|
||||
Args:
|
||||
output: 原始输出
|
||||
|
||||
|
||||
Returns:
|
||||
处理后的输出
|
||||
"""
|
||||
return output
|
||||
|
||||
|
||||
def log_info(self, message: str):
|
||||
"""记录信息日志"""
|
||||
print(f"[{self.node_name}] {message}")
|
||||
logger.info(f"[{self.node_name}] {message}")
|
||||
|
||||
def log_warning(self, message: str):
|
||||
"""记录警告日志"""
|
||||
logger.warning(f"[{self.node_name}] 警告: {message}")
|
||||
|
||||
def log_error(self, message: str):
|
||||
"""记录错误日志"""
|
||||
print(f"[{self.node_name}] 错误: {message}")
|
||||
logger.error(f"[{self.node_name}] 错误: {message}")
|
||||
|
||||
|
||||
class StateMutationNode(BaseNode):
|
||||
|
||||
@@ -7,6 +7,7 @@ import json
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from .base_node import BaseNode
|
||||
from loguru import logger
|
||||
from ..prompts import SYSTEM_PROMPT_REPORT_FORMATTING
|
||||
from ..utils.text_processing import (
|
||||
remove_reasoning_from_output,
|
||||
@@ -65,7 +66,7 @@ class ReportFormattingNode(BaseNode):
|
||||
else:
|
||||
message = json.dumps(input_data, ensure_ascii=False)
|
||||
|
||||
self.log_info("正在格式化最终报告")
|
||||
logger.info("正在格式化最终报告")
|
||||
|
||||
# 调用LLM生成Markdown格式
|
||||
response = self.llm_client.invoke(
|
||||
@@ -76,11 +77,11 @@ class ReportFormattingNode(BaseNode):
|
||||
# 处理响应
|
||||
processed_response = self.process_output(response)
|
||||
|
||||
self.log_info("成功生成格式化报告")
|
||||
logger.info("成功生成格式化报告")
|
||||
return processed_response
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"报告格式化失败: {str(e)}")
|
||||
logger.exception(f"报告格式化失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
def process_output(self, output: str) -> str:
|
||||
@@ -109,7 +110,7 @@ class ReportFormattingNode(BaseNode):
|
||||
return cleaned_output.strip()
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"处理输出失败: {str(e)}")
|
||||
logger.exception(f"处理输出失败: {str(e)}")
|
||||
return "# 报告处理失败\n\n报告格式化过程中发生错误。"
|
||||
|
||||
def format_report_manually(self, paragraphs_data: List[Dict[str, str]],
|
||||
@@ -125,7 +126,7 @@ class ReportFormattingNode(BaseNode):
|
||||
格式化的Markdown报告
|
||||
"""
|
||||
try:
|
||||
self.log_info("使用手动格式化方法")
|
||||
logger.info("使用手动格式化方法")
|
||||
|
||||
# 构建报告
|
||||
report_lines = [
|
||||
@@ -163,5 +164,5 @@ class ReportFormattingNode(BaseNode):
|
||||
return "\n".join(report_lines)
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"手动格式化失败: {str(e)}")
|
||||
logger.exception(f"手动格式化失败: {str(e)}")
|
||||
return "# 报告生成失败\n\n无法完成报告格式化。"
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
import json
|
||||
from typing import Dict, Any, List
|
||||
from json.decoder import JSONDecodeError
|
||||
from loguru import logger
|
||||
|
||||
from .base_node import StateMutationNode
|
||||
from ..state.state import State
|
||||
@@ -48,7 +49,7 @@ class ReportStructureNode(StateMutationNode):
|
||||
报告结构列表
|
||||
"""
|
||||
try:
|
||||
self.log_info(f"正在为查询生成报告结构: {self.query}")
|
||||
logger.info(f"正在为查询生成报告结构: {self.query}")
|
||||
|
||||
# 调用LLM
|
||||
response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query)
|
||||
@@ -56,11 +57,11 @@ class ReportStructureNode(StateMutationNode):
|
||||
# 处理响应
|
||||
processed_response = self.process_output(response)
|
||||
|
||||
self.log_info(f"成功生成 {len(processed_response)} 个段落结构")
|
||||
logger.info(f"成功生成 {len(processed_response)} 个段落结构")
|
||||
return processed_response
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"生成报告结构失败: {str(e)}")
|
||||
logger.exception(f"生成报告结构失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
def process_output(self, output: str) -> List[Dict[str, str]]:
|
||||
@@ -79,54 +80,54 @@ class ReportStructureNode(StateMutationNode):
|
||||
cleaned_output = clean_json_tags(cleaned_output)
|
||||
|
||||
# 记录清理后的输出用于调试
|
||||
self.log_info(f"清理后的输出: {cleaned_output}")
|
||||
logger.info(f"清理后的输出: {cleaned_output}")
|
||||
|
||||
# 解析JSON
|
||||
try:
|
||||
report_structure = json.loads(cleaned_output)
|
||||
self.log_info("JSON解析成功")
|
||||
logger.info("JSON解析成功")
|
||||
except JSONDecodeError as e:
|
||||
self.log_info(f"JSON解析失败: {str(e)}")
|
||||
logger.exception(f"JSON解析失败: {str(e)}")
|
||||
# 使用更强大的提取方法
|
||||
report_structure = extract_clean_response(cleaned_output)
|
||||
if "error" in report_structure:
|
||||
self.log_error("JSON解析失败,尝试修复...")
|
||||
logger.error("JSON解析失败,尝试修复...")
|
||||
# 尝试修复JSON
|
||||
fixed_json = fix_incomplete_json(cleaned_output)
|
||||
if fixed_json:
|
||||
try:
|
||||
report_structure = json.loads(fixed_json)
|
||||
self.log_info("JSON修复成功")
|
||||
logger.info("JSON修复成功")
|
||||
except JSONDecodeError:
|
||||
self.log_error("JSON修复失败")
|
||||
logger.error("JSON修复失败")
|
||||
# 返回默认结构
|
||||
return self._generate_default_structure()
|
||||
else:
|
||||
self.log_error("无法修复JSON,使用默认结构")
|
||||
logger.error("无法修复JSON,使用默认结构")
|
||||
return self._generate_default_structure()
|
||||
|
||||
# 验证结构
|
||||
if not isinstance(report_structure, list):
|
||||
self.log_info("报告结构不是列表,尝试转换...")
|
||||
logger.info("报告结构不是列表,尝试转换...")
|
||||
if isinstance(report_structure, dict):
|
||||
# 如果是单个对象,包装成列表
|
||||
report_structure = [report_structure]
|
||||
else:
|
||||
self.log_error("报告结构格式无效,使用默认结构")
|
||||
logger.error("报告结构格式无效,使用默认结构")
|
||||
return self._generate_default_structure()
|
||||
|
||||
# 验证每个段落
|
||||
validated_structure = []
|
||||
for i, paragraph in enumerate(report_structure):
|
||||
if not isinstance(paragraph, dict):
|
||||
self.log_warning(f"段落 {i+1} 不是字典格式,跳过")
|
||||
logger.warning(f"段落 {i+1} 不是字典格式,跳过")
|
||||
continue
|
||||
|
||||
title = paragraph.get("title", f"段落 {i+1}")
|
||||
content = paragraph.get("content", "")
|
||||
|
||||
if not title or not content:
|
||||
self.log_warning(f"段落 {i+1} 缺少标题或内容,跳过")
|
||||
logger.warning(f"段落 {i+1} 缺少标题或内容,跳过")
|
||||
continue
|
||||
|
||||
validated_structure.append({
|
||||
@@ -135,14 +136,14 @@ class ReportStructureNode(StateMutationNode):
|
||||
})
|
||||
|
||||
if not validated_structure:
|
||||
self.log_warning("没有有效的段落结构,使用默认结构")
|
||||
logger.warning("没有有效的段落结构,使用默认结构")
|
||||
return self._generate_default_structure()
|
||||
|
||||
self.log_info(f"成功验证 {len(validated_structure)} 个段落结构")
|
||||
logger.info(f"成功验证 {len(validated_structure)} 个段落结构")
|
||||
return validated_structure
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"处理输出失败: {str(e)}")
|
||||
logger.exception(f"处理输出失败: {str(e)}")
|
||||
return self._generate_default_structure()
|
||||
|
||||
def _generate_default_structure(self) -> List[Dict[str, str]]:
|
||||
@@ -152,7 +153,7 @@ class ReportStructureNode(StateMutationNode):
|
||||
Returns:
|
||||
默认的报告结构列表
|
||||
"""
|
||||
self.log_info("生成默认报告结构")
|
||||
logger.info("生成默认报告结构")
|
||||
return [
|
||||
{
|
||||
"title": "研究概述",
|
||||
@@ -195,9 +196,9 @@ class ReportStructureNode(StateMutationNode):
|
||||
content=paragraph_data["content"]
|
||||
)
|
||||
|
||||
self.log_info(f"已将 {len(report_structure)} 个段落添加到状态中")
|
||||
logger.info(f"已将 {len(report_structure)} 个段落添加到状态中")
|
||||
return state
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"状态更新失败: {str(e)}")
|
||||
logger.exception(f"状态更新失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
import json
|
||||
from typing import Dict, Any
|
||||
from json.decoder import JSONDecodeError
|
||||
from loguru import logger
|
||||
|
||||
from .base_node import BaseNode
|
||||
from ..prompts import SYSTEM_PROMPT_FIRST_SEARCH, SYSTEM_PROMPT_REFLECTION
|
||||
@@ -62,7 +63,7 @@ class FirstSearchNode(BaseNode):
|
||||
else:
|
||||
message = json.dumps(input_data, ensure_ascii=False)
|
||||
|
||||
self.log_info("正在生成首次搜索查询")
|
||||
logger.info("正在生成首次搜索查询")
|
||||
|
||||
# 调用LLM
|
||||
response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SEARCH, message)
|
||||
@@ -70,11 +71,11 @@ class FirstSearchNode(BaseNode):
|
||||
# 处理响应
|
||||
processed_response = self.process_output(response)
|
||||
|
||||
self.log_info(f"生成搜索查询: {processed_response.get('search_query', 'N/A')}")
|
||||
logger.info(f"生成搜索查询: {processed_response.get('search_query', 'N/A')}")
|
||||
return processed_response
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"生成首次搜索查询失败: {str(e)}")
|
||||
logger.exception(f"生成首次搜索查询失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
def process_output(self, output: str) -> Dict[str, str]:
|
||||
@@ -93,30 +94,30 @@ class FirstSearchNode(BaseNode):
|
||||
cleaned_output = clean_json_tags(cleaned_output)
|
||||
|
||||
# 记录清理后的输出用于调试
|
||||
self.log_info(f"清理后的输出: {cleaned_output}")
|
||||
logger.info(f"清理后的输出: {cleaned_output}")
|
||||
|
||||
# 解析JSON
|
||||
try:
|
||||
result = json.loads(cleaned_output)
|
||||
self.log_info("JSON解析成功")
|
||||
logger.info("JSON解析成功")
|
||||
except JSONDecodeError as e:
|
||||
self.log_info(f"JSON解析失败: {str(e)}")
|
||||
logger.exception(f"JSON解析失败: {str(e)}")
|
||||
# 使用更强大的提取方法
|
||||
result = extract_clean_response(cleaned_output)
|
||||
if "error" in result:
|
||||
self.log_error("JSON解析失败,尝试修复...")
|
||||
logger.error("JSON解析失败,尝试修复...")
|
||||
# 尝试修复JSON
|
||||
fixed_json = fix_incomplete_json(cleaned_output)
|
||||
if fixed_json:
|
||||
try:
|
||||
result = json.loads(fixed_json)
|
||||
self.log_info("JSON修复成功")
|
||||
logger.info("JSON修复成功")
|
||||
except JSONDecodeError:
|
||||
self.log_error("JSON修复失败")
|
||||
logger.error("JSON修复失败")
|
||||
# 返回默认查询
|
||||
return self._get_default_search_query()
|
||||
else:
|
||||
self.log_error("无法修复JSON,使用默认查询")
|
||||
logger.error("无法修复JSON,使用默认查询")
|
||||
return self._get_default_search_query()
|
||||
|
||||
# 验证和清理结果
|
||||
@@ -124,7 +125,7 @@ class FirstSearchNode(BaseNode):
|
||||
reasoning = result.get("reasoning", "")
|
||||
|
||||
if not search_query:
|
||||
self.log_warning("未找到搜索查询,使用默认查询")
|
||||
logger.warning("未找到搜索查询,使用默认查询")
|
||||
return self._get_default_search_query()
|
||||
|
||||
return {
|
||||
@@ -197,7 +198,7 @@ class ReflectionNode(BaseNode):
|
||||
else:
|
||||
message = json.dumps(input_data, ensure_ascii=False)
|
||||
|
||||
self.log_info("正在进行反思并生成新搜索查询")
|
||||
logger.info("正在进行反思并生成新搜索查询")
|
||||
|
||||
# 调用LLM
|
||||
response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION, message)
|
||||
@@ -205,11 +206,11 @@ class ReflectionNode(BaseNode):
|
||||
# 处理响应
|
||||
processed_response = self.process_output(response)
|
||||
|
||||
self.log_info(f"反思生成搜索查询: {processed_response.get('search_query', 'N/A')}")
|
||||
logger.info(f"反思生成搜索查询: {processed_response.get('search_query', 'N/A')}")
|
||||
return processed_response
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"反思生成搜索查询失败: {str(e)}")
|
||||
logger.exception(f"反思生成搜索查询失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
def process_output(self, output: str) -> Dict[str, str]:
|
||||
@@ -228,30 +229,30 @@ class ReflectionNode(BaseNode):
|
||||
cleaned_output = clean_json_tags(cleaned_output)
|
||||
|
||||
# 记录清理后的输出用于调试
|
||||
self.log_info(f"清理后的输出: {cleaned_output}")
|
||||
logger.info(f"清理后的输出: {cleaned_output}")
|
||||
|
||||
# 解析JSON
|
||||
try:
|
||||
result = json.loads(cleaned_output)
|
||||
self.log_info("JSON解析成功")
|
||||
logger.info("JSON解析成功")
|
||||
except JSONDecodeError as e:
|
||||
self.log_info(f"JSON解析失败: {str(e)}")
|
||||
logger.exception(f"JSON解析失败: {str(e)}")
|
||||
# 使用更强大的提取方法
|
||||
result = extract_clean_response(cleaned_output)
|
||||
if "error" in result:
|
||||
self.log_error("JSON解析失败,尝试修复...")
|
||||
logger.error("JSON解析失败,尝试修复...")
|
||||
# 尝试修复JSON
|
||||
fixed_json = fix_incomplete_json(cleaned_output)
|
||||
if fixed_json:
|
||||
try:
|
||||
result = json.loads(fixed_json)
|
||||
self.log_info("JSON修复成功")
|
||||
logger.info("JSON修复成功")
|
||||
except JSONDecodeError:
|
||||
self.log_error("JSON修复失败")
|
||||
logger.error("JSON修复失败")
|
||||
# 返回默认查询
|
||||
return self._get_default_reflection_query()
|
||||
else:
|
||||
self.log_error("无法修复JSON,使用默认查询")
|
||||
logger.error("无法修复JSON,使用默认查询")
|
||||
return self._get_default_reflection_query()
|
||||
|
||||
# 验证和清理结果
|
||||
@@ -259,7 +260,7 @@ class ReflectionNode(BaseNode):
|
||||
reasoning = result.get("reasoning", "")
|
||||
|
||||
if not search_query:
|
||||
self.log_warning("未找到搜索查询,使用默认查询")
|
||||
logger.warning("未找到搜索查询,使用默认查询")
|
||||
return self._get_default_reflection_query()
|
||||
|
||||
return {
|
||||
@@ -268,7 +269,7 @@ class ReflectionNode(BaseNode):
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"处理输出失败: {str(e)}")
|
||||
logger.exception(f"处理输出失败: {str(e)}")
|
||||
# 返回默认查询
|
||||
return self._get_default_reflection_query()
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
import json
|
||||
from typing import Dict, Any, List
|
||||
from json.decoder import JSONDecodeError
|
||||
from loguru import logger
|
||||
|
||||
from .base_node import StateMutationNode
|
||||
from ..state.state import State
|
||||
@@ -27,7 +28,7 @@ try:
|
||||
FORUM_READER_AVAILABLE = True
|
||||
except ImportError:
|
||||
FORUM_READER_AVAILABLE = False
|
||||
print("警告: 无法导入forum_reader模块,将跳过HOST发言读取功能")
|
||||
logger.warning("警告: 无法导入forum_reader模块,将跳过HOST发言读取功能")
|
||||
|
||||
|
||||
class FirstSummaryNode(StateMutationNode):
|
||||
@@ -84,9 +85,9 @@ class FirstSummaryNode(StateMutationNode):
|
||||
if host_speech:
|
||||
# 将HOST发言添加到输入数据中
|
||||
data['host_speech'] = host_speech
|
||||
self.log_info(f"已读取HOST发言,长度: {len(host_speech)}字符")
|
||||
logger.info(f"已读取HOST发言,长度: {len(host_speech)}字符")
|
||||
except Exception as e:
|
||||
self.log_info(f"读取HOST发言失败: {str(e)}")
|
||||
logger.exception(f"读取HOST发言失败: {str(e)}")
|
||||
|
||||
# 转换为JSON字符串
|
||||
message = json.dumps(data, ensure_ascii=False)
|
||||
@@ -96,7 +97,7 @@ class FirstSummaryNode(StateMutationNode):
|
||||
formatted_host = format_host_speech_for_prompt(data['host_speech'])
|
||||
message = formatted_host + "\n" + message
|
||||
|
||||
self.log_info("正在生成首次段落总结")
|
||||
logger.info("正在生成首次段落总结")
|
||||
|
||||
# 调用LLM生成总结
|
||||
response = self.llm_client.invoke(
|
||||
@@ -107,11 +108,11 @@ class FirstSummaryNode(StateMutationNode):
|
||||
# 处理响应
|
||||
processed_response = self.process_output(response)
|
||||
|
||||
self.log_info("成功生成首次段落总结")
|
||||
logger.info("成功生成首次段落总结")
|
||||
return processed_response
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"生成首次总结失败: {str(e)}")
|
||||
logger.exception(f"生成首次总结失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
def process_output(self, output: str) -> str:
|
||||
@@ -130,26 +131,26 @@ class FirstSummaryNode(StateMutationNode):
|
||||
cleaned_output = clean_json_tags(cleaned_output)
|
||||
|
||||
# 记录清理后的输出用于调试
|
||||
self.log_info(f"清理后的输出: {cleaned_output}")
|
||||
logger.info(f"清理后的输出: {cleaned_output}")
|
||||
|
||||
# 解析JSON
|
||||
try:
|
||||
result = json.loads(cleaned_output)
|
||||
self.log_info("JSON解析成功")
|
||||
logger.info("JSON解析成功")
|
||||
except JSONDecodeError as e:
|
||||
self.log_info(f"JSON解析失败: {str(e)}")
|
||||
logger.exception(f"JSON解析失败: {str(e)}")
|
||||
# 尝试修复JSON
|
||||
fixed_json = fix_incomplete_json(cleaned_output)
|
||||
if fixed_json:
|
||||
try:
|
||||
result = json.loads(fixed_json)
|
||||
self.log_info("JSON修复成功")
|
||||
logger.info("JSON修复成功")
|
||||
except JSONDecodeError:
|
||||
self.log_info("JSON修复失败,直接使用清理后的文本")
|
||||
logger.exception("JSON修复失败,直接使用清理后的文本")
|
||||
# 如果不是JSON格式,直接返回清理后的文本
|
||||
return cleaned_output
|
||||
else:
|
||||
self.log_info("无法修复JSON,直接使用清理后的文本")
|
||||
logger.exception("无法修复JSON,直接使用清理后的文本")
|
||||
# 如果不是JSON格式,直接返回清理后的文本
|
||||
return cleaned_output
|
||||
|
||||
@@ -163,7 +164,7 @@ class FirstSummaryNode(StateMutationNode):
|
||||
return cleaned_output
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"处理输出失败: {str(e)}")
|
||||
logger.exception(f"处理输出失败: {str(e)}")
|
||||
return "段落总结生成失败"
|
||||
|
||||
def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State:
|
||||
@@ -186,7 +187,7 @@ class FirstSummaryNode(StateMutationNode):
|
||||
# 更新状态
|
||||
if 0 <= paragraph_index < len(state.paragraphs):
|
||||
state.paragraphs[paragraph_index].research.latest_summary = summary
|
||||
self.log_info(f"已更新段落 {paragraph_index} 的首次总结")
|
||||
logger.info(f"已更新段落 {paragraph_index} 的首次总结")
|
||||
else:
|
||||
raise ValueError(f"段落索引 {paragraph_index} 超出范围")
|
||||
|
||||
@@ -194,7 +195,7 @@ class FirstSummaryNode(StateMutationNode):
|
||||
return state
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"状态更新失败: {str(e)}")
|
||||
logger.exception(f"状态更新失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
|
||||
@@ -252,9 +253,9 @@ class ReflectionSummaryNode(StateMutationNode):
|
||||
if host_speech:
|
||||
# 将HOST发言添加到输入数据中
|
||||
data['host_speech'] = host_speech
|
||||
self.log_info(f"已读取HOST发言,长度: {len(host_speech)}字符")
|
||||
logger.info(f"已读取HOST发言,长度: {len(host_speech)}字符")
|
||||
except Exception as e:
|
||||
self.log_info(f"读取HOST发言失败: {str(e)}")
|
||||
logger.exception(f"读取HOST发言失败: {str(e)}")
|
||||
|
||||
# 转换为JSON字符串
|
||||
message = json.dumps(data, ensure_ascii=False)
|
||||
@@ -264,7 +265,7 @@ class ReflectionSummaryNode(StateMutationNode):
|
||||
formatted_host = format_host_speech_for_prompt(data['host_speech'])
|
||||
message = formatted_host + "\n" + message
|
||||
|
||||
self.log_info("正在生成反思总结")
|
||||
logger.info("正在生成反思总结")
|
||||
|
||||
# 调用LLM生成总结
|
||||
response = self.llm_client.invoke(
|
||||
@@ -275,11 +276,11 @@ class ReflectionSummaryNode(StateMutationNode):
|
||||
# 处理响应
|
||||
processed_response = self.process_output(response)
|
||||
|
||||
self.log_info("成功生成反思总结")
|
||||
logger.info("成功生成反思总结")
|
||||
return processed_response
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"生成反思总结失败: {str(e)}")
|
||||
logger.exception(f"生成反思总结失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
def process_output(self, output: str) -> str:
|
||||
@@ -298,26 +299,26 @@ class ReflectionSummaryNode(StateMutationNode):
|
||||
cleaned_output = clean_json_tags(cleaned_output)
|
||||
|
||||
# 记录清理后的输出用于调试
|
||||
self.log_info(f"清理后的输出: {cleaned_output}")
|
||||
logger.info(f"清理后的输出: {cleaned_output}")
|
||||
|
||||
# 解析JSON
|
||||
try:
|
||||
result = json.loads(cleaned_output)
|
||||
self.log_info("JSON解析成功")
|
||||
logger.info("JSON解析成功")
|
||||
except JSONDecodeError as e:
|
||||
self.log_info(f"JSON解析失败: {str(e)}")
|
||||
logger.exception(f"JSON解析失败: {str(e)}")
|
||||
# 尝试修复JSON
|
||||
fixed_json = fix_incomplete_json(cleaned_output)
|
||||
if fixed_json:
|
||||
try:
|
||||
result = json.loads(fixed_json)
|
||||
self.log_info("JSON修复成功")
|
||||
logger.info("JSON修复成功")
|
||||
except JSONDecodeError:
|
||||
self.log_info("JSON修复失败,直接使用清理后的文本")
|
||||
logger.exception("JSON修复失败,直接使用清理后的文本")
|
||||
# 如果不是JSON格式,直接返回清理后的文本
|
||||
return cleaned_output
|
||||
else:
|
||||
self.log_info("无法修复JSON,直接使用清理后的文本")
|
||||
logger.exception("无法修复JSON,直接使用清理后的文本")
|
||||
# 如果不是JSON格式,直接返回清理后的文本
|
||||
return cleaned_output
|
||||
|
||||
@@ -331,7 +332,7 @@ class ReflectionSummaryNode(StateMutationNode):
|
||||
return cleaned_output
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"处理输出失败: {str(e)}")
|
||||
logger.exception(f"处理输出失败: {str(e)}")
|
||||
return "反思总结生成失败"
|
||||
|
||||
def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State:
|
||||
@@ -355,7 +356,7 @@ class ReflectionSummaryNode(StateMutationNode):
|
||||
if 0 <= paragraph_index < len(state.paragraphs):
|
||||
state.paragraphs[paragraph_index].research.latest_summary = updated_summary
|
||||
state.paragraphs[paragraph_index].research.increment_reflection()
|
||||
self.log_info(f"已更新段落 {paragraph_index} 的反思总结")
|
||||
logger.info(f"已更新段落 {paragraph_index} 的反思总结")
|
||||
else:
|
||||
raise ValueError(f"段落索引 {paragraph_index} 超出范围")
|
||||
|
||||
@@ -363,5 +364,5 @@ class ReflectionSummaryNode(StateMutationNode):
|
||||
return state
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"状态更新失败: {str(e)}")
|
||||
logger.exception(f"状态更新失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
@@ -12,7 +12,7 @@ from .text_processing import (
|
||||
format_search_results_for_prompt
|
||||
)
|
||||
|
||||
from .config import Config, load_config
|
||||
from .config import Settings
|
||||
|
||||
__all__ = [
|
||||
"clean_json_tags",
|
||||
@@ -21,6 +21,5 @@ __all__ = [
|
||||
"extract_clean_response",
|
||||
"update_state_with_search_results",
|
||||
"format_search_results_for_prompt",
|
||||
"Config",
|
||||
"load_config"
|
||||
"Settings",
|
||||
]
|
||||
|
||||
+68
-140
@@ -1,151 +1,79 @@
|
||||
"""
|
||||
Configuration management module for the Query Engine.
|
||||
Query Engine 配置管理模块
|
||||
|
||||
此模块使用 pydantic-settings 管理 Query Engine 的配置,支持从环境变量和 .env 文件自动加载。
|
||||
数据模型定义位置:
|
||||
- 本文件 - 配置模型定义
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic import Field
|
||||
from typing import Optional
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def _get_value(source, key: str, default=None, *fallback_keys: str):
|
||||
candidates = (key,) + fallback_keys
|
||||
value = None
|
||||
for candidate in candidates:
|
||||
if isinstance(source, dict):
|
||||
value = source.get(candidate)
|
||||
else:
|
||||
value = getattr(source, candidate, None)
|
||||
if value not in (None, ""):
|
||||
break
|
||||
if value in (None, ""):
|
||||
for candidate in candidates:
|
||||
env_val = os.getenv(candidate)
|
||||
if env_val not in (None, ""):
|
||||
value = env_val
|
||||
break
|
||||
return value if value not in (None, "") else default
|
||||
# 计算 .env 优先级:优先当前工作目录,其次项目根目录
|
||||
PROJECT_ROOT: Path = Path(__file__).resolve().parents[2]
|
||||
CWD_ENV: Path = Path.cwd() / ".env"
|
||||
ENV_FILE: str = str(CWD_ENV if CWD_ENV.exists() else (PROJECT_ROOT / ".env"))
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
"""Query Engine configuration."""
|
||||
|
||||
llm_api_key: Optional[str] = None
|
||||
llm_base_url: Optional[str] = None
|
||||
llm_model_name: Optional[str] = None
|
||||
llm_provider: Optional[str] = None # compatibility
|
||||
|
||||
tavily_api_key: Optional[str] = None
|
||||
|
||||
search_timeout: int = 240
|
||||
max_content_length: int = 20000
|
||||
max_reflections: int = 2
|
||||
max_paragraphs: int = 5
|
||||
max_search_results: int = 20
|
||||
|
||||
output_dir: str = "reports"
|
||||
save_intermediate_states: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.llm_provider and self.llm_model_name:
|
||||
self.llm_provider = self.llm_model_name
|
||||
|
||||
def validate(self) -> bool:
|
||||
if not self.llm_api_key:
|
||||
print("错误: Query Engine LLM API Key 未设置 (QUERY_ENGINE_API_KEY)。")
|
||||
return False
|
||||
if not self.llm_model_name:
|
||||
print("错误: Query Engine 模型名称未设置 (QUERY_ENGINE_MODEL_NAME)。")
|
||||
return False
|
||||
if not self.tavily_api_key:
|
||||
print("错误: Tavily API Key 未设置 (TAVILY_API_KEY)。")
|
||||
return False
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, config_file: str) -> "Config":
|
||||
if config_file.endswith(".py"):
|
||||
import importlib.util
|
||||
|
||||
spec = importlib.util.spec_from_file_location("config", config_file)
|
||||
config_module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(config_module)
|
||||
|
||||
return cls(
|
||||
llm_api_key=_get_value(config_module, "QUERY_ENGINE_API_KEY"),
|
||||
llm_base_url=_get_value(config_module, "QUERY_ENGINE_BASE_URL"),
|
||||
llm_model_name=_get_value(config_module, "QUERY_ENGINE_MODEL_NAME"),
|
||||
tavily_api_key=_get_value(config_module, "TAVILY_API_KEY"),
|
||||
search_timeout=int(_get_value(config_module, "SEARCH_TIMEOUT", 240)),
|
||||
max_content_length=int(_get_value(config_module, "SEARCH_CONTENT_MAX_LENGTH", 20000)),
|
||||
max_reflections=int(_get_value(config_module, "MAX_REFLECTIONS", 2)),
|
||||
max_paragraphs=int(_get_value(config_module, "MAX_PARAGRAPHS", 5)),
|
||||
max_search_results=int(_get_value(config_module, "MAX_SEARCH_RESULTS", 20)),
|
||||
output_dir=_get_value(config_module, "OUTPUT_DIR", "reports"),
|
||||
save_intermediate_states=str(
|
||||
_get_value(config_module, "SAVE_INTERMEDIATE_STATES", "true")
|
||||
).lower()
|
||||
in ("true", "1", "yes"),
|
||||
)
|
||||
|
||||
config_dict = {}
|
||||
if os.path.exists(config_file):
|
||||
with open(config_file, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith("#") and "=" in line:
|
||||
key, value = line.split("=", 1)
|
||||
config_dict[key.strip()] = value.strip()
|
||||
|
||||
return cls(
|
||||
llm_api_key=_get_value(config_dict, "QUERY_ENGINE_API_KEY"),
|
||||
llm_base_url=_get_value(config_dict, "QUERY_ENGINE_BASE_URL"),
|
||||
llm_model_name=_get_value(config_dict, "QUERY_ENGINE_MODEL_NAME"),
|
||||
tavily_api_key=_get_value(config_dict, "TAVILY_API_KEY"),
|
||||
search_timeout=int(_get_value(config_dict, "SEARCH_TIMEOUT", 240)),
|
||||
max_content_length=int(_get_value(config_dict, "SEARCH_CONTENT_MAX_LENGTH", 20000)),
|
||||
max_reflections=int(_get_value(config_dict, "MAX_REFLECTIONS", 2)),
|
||||
max_paragraphs=int(_get_value(config_dict, "MAX_PARAGRAPHS", 5)),
|
||||
max_search_results=int(_get_value(config_dict, "MAX_SEARCH_RESULTS", 20)),
|
||||
output_dir=_get_value(config_dict, "OUTPUT_DIR", "reports"),
|
||||
save_intermediate_states=str(
|
||||
_get_value(config_dict, "SAVE_INTERMEDIATE_STATES", "true")
|
||||
).lower()
|
||||
in ("true", "1", "yes"),
|
||||
)
|
||||
class Settings(BaseSettings):
|
||||
"""
|
||||
Query Engine 全局配置;支持 .env 和环境变量自动加载。
|
||||
变量名与原 config.py 大写一致,便于平滑过渡。
|
||||
"""
|
||||
|
||||
# ======================= LLM 相关 =======================
|
||||
QUERY_ENGINE_API_KEY: str = Field(..., description="Query Engine LLM API密钥,用于主LLM。您可以更改每个部分LLM使用的API,🚩只要兼容OpenAI请求格式都可以,定义好KEY、BASE_URL与MODEL_NAME即可正常使用。")
|
||||
QUERY_ENGINE_BASE_URL: Optional[str] = Field(None, description="Query Engine LLM接口BaseUrl,可自定义厂商API")
|
||||
QUERY_ENGINE_MODEL_NAME: str = Field(..., description="Query Engine LLM模型名称")
|
||||
QUERY_ENGINE_PROVIDER: Optional[str] = Field(None, description="Query Engine LLM提供商(兼容字段)")
|
||||
|
||||
# ================== 网络工具配置 ====================
|
||||
TAVILY_API_KEY: str = Field(..., description="Tavily API(申请地址:https://www.tavily.com/)API密钥,用于Tavily网络搜索")
|
||||
|
||||
# ================== 搜索参数配置 ====================
|
||||
SEARCH_TIMEOUT: int = Field(240, description="搜索超时(秒)")
|
||||
SEARCH_CONTENT_MAX_LENGTH: int = Field(20000, description="用于提示的最长内容长度")
|
||||
MAX_REFLECTIONS: int = Field(2, description="最大反思轮数")
|
||||
MAX_PARAGRAPHS: int = Field(5, description="最大段落数")
|
||||
MAX_SEARCH_RESULTS: int = Field(20, description="最大搜索结果数")
|
||||
|
||||
# ================== 输出配置 ====================
|
||||
OUTPUT_DIR: str = Field("reports", description="输出目录")
|
||||
SAVE_INTERMEDIATE_STATES: bool = Field(True, description="是否保存中间状态")
|
||||
|
||||
class Config:
|
||||
env_file = ENV_FILE
|
||||
env_prefix = ""
|
||||
case_sensitive = False
|
||||
extra = "allow"
|
||||
|
||||
|
||||
def load_config(config_file: Optional[str] = None) -> Config:
|
||||
if config_file:
|
||||
if not os.path.exists(config_file):
|
||||
raise FileNotFoundError(f"配置文件不存在: {config_file}")
|
||||
file_to_load = config_file
|
||||
else:
|
||||
for candidate in ("config.py", "config.env", ".env"):
|
||||
if os.path.exists(candidate):
|
||||
file_to_load = candidate
|
||||
print(f"已找到配置文件: {candidate}")
|
||||
break
|
||||
else:
|
||||
raise FileNotFoundError("未找到配置文件,请创建 config.py。")
|
||||
# 创建全局配置实例
|
||||
settings = Settings()
|
||||
|
||||
config = Config.from_file(file_to_load)
|
||||
if not config.validate():
|
||||
raise ValueError("配置校验失败,请检查 config.py 中的相关配置。")
|
||||
return config
|
||||
|
||||
|
||||
def print_config(config: Config):
|
||||
print("\n=== Query Engine 配置 ===")
|
||||
print(f"LLM 模型: {config.llm_model_name}")
|
||||
print(f"LLM Base URL: {config.llm_base_url or '(默认)'}")
|
||||
print(f"Tavily API Key: {'已配置' if config.tavily_api_key else '未配置'}")
|
||||
print(f"搜索超时: {config.search_timeout} 秒")
|
||||
print(f"最长内容长度: {config.max_content_length}")
|
||||
print(f"最大反思次数: {config.max_reflections}")
|
||||
print(f"最大段落数: {config.max_paragraphs}")
|
||||
print(f"最大搜索结果数: {config.max_search_results}")
|
||||
print(f"输出目录: {config.output_dir}")
|
||||
print(f"保存中间状态: {config.save_intermediate_states}")
|
||||
print(f"LLM API Key: {'已配置' if config.llm_api_key else '未配置'}")
|
||||
print("========================\n")
|
||||
def print_config(config: Settings):
|
||||
"""
|
||||
打印配置信息
|
||||
|
||||
Args:
|
||||
config: Settings配置对象
|
||||
"""
|
||||
message = ""
|
||||
message += "=== Query Engine 配置 ===\n"
|
||||
message += f"LLM 模型: {config.QUERY_ENGINE_MODEL_NAME}\n"
|
||||
message += f"LLM Base URL: {config.QUERY_ENGINE_BASE_URL or '(默认)'}\n"
|
||||
message += f"Tavily API Key: {'已配置' if config.TAVILY_API_KEY else '未配置'}\n"
|
||||
message += f"搜索超时: {config.SEARCH_TIMEOUT} 秒\n"
|
||||
message += f"最长内容长度: {config.SEARCH_CONTENT_MAX_LENGTH}\n"
|
||||
message += f"最大反思次数: {config.MAX_REFLECTIONS}\n"
|
||||
message += f"最大段落数: {config.MAX_PARAGRAPHS}\n"
|
||||
message += f"最大搜索结果数: {config.MAX_SEARCH_RESULTS}\n"
|
||||
message += f"输出目录: {config.OUTPUT_DIR}\n"
|
||||
message += f"保存中间状态: {config.SAVE_INTERMEDIATE_STATES}\n"
|
||||
message += f"LLM API Key: {'已配置' if config.QUERY_ENGINE_API_KEY else '未配置'}\n"
|
||||
message += "========================\n"
|
||||
logger.info(message)
|
||||
|
||||
Reference in New Issue
Block a user