1. 统一为使用基于pydantic的.env环境变量管理配置

2. 全项目基于loguru进行日志管理
This commit is contained in:
Doiiars
2025-11-05 14:56:49 +08:00
parent 1d2e23d8c1
commit 537d682861
50 changed files with 1404 additions and 1731 deletions
+2 -2
View File
@@ -4,9 +4,9 @@ Deep Search Agent
"""
from .agent import DeepSearchAgent, create_agent
from .utils.config import Config, load_config
from .utils.config import Settings
__version__ = "1.0.0"
__author__ = "Deep Search Agent Team"
__all__ = ["DeepSearchAgent", "create_agent", "Config", "load_config"]
__all__ = ["DeepSearchAgent", "create_agent", "Settings"]
+65 -63
View File
@@ -8,7 +8,7 @@ import os
import re
from datetime import datetime
from typing import Optional, Dict, Any, List
from loguru import logger
from .llms import LLMClient
from .nodes import (
ReportStructureNode,
@@ -20,29 +20,26 @@ from .nodes import (
)
from .state import State
from .tools import BochaMultimodalSearch, BochaResponse
from .utils import Config, load_config, format_search_results_for_prompt
from .utils import settings, Settings, format_search_results_for_prompt
class DeepSearchAgent:
"""Deep Search Agent主类"""
def __init__(self, config: Optional[Config] = None):
def __init__(self, config: Optional[Settings] = None):
"""
初始化Deep Search Agent
Args:
config: 配置对象,如果不提供则自动加载
"""
# 加载配置
self.config = config or load_config()
os.environ["BOCHA_API_KEY"] = self.config.bocha_api_key or ""
os.environ["BOCHA_WEB_SEARCH_API_KEY"] = self.config.bocha_api_key or ""
self.config = config or settings
# 初始化LLM客户端
self.llm_client = self._initialize_llm()
# 初始化搜索工具集
self.search_agency = BochaMultimodalSearch(api_key=self.config.bocha_api_key)
self.search_agency = BochaMultimodalSearch(api_key=(self.config.BOCHA_API_KEY or self.config.BOCHA_WEB_SEARCH_API_KEY))
# 初始化节点
self._initialize_nodes()
@@ -51,18 +48,18 @@ class DeepSearchAgent:
self.state = State()
# 确保输出目录存在
os.makedirs(self.config.output_dir, exist_ok=True)
os.makedirs(self.config.OUTPUT_DIR, exist_ok=True)
print(f"Meida Agent已初始化")
print(f"使用LLM: {self.llm_client.get_model_info()}")
print(f"搜索工具集: BochaMultimodalSearch (支持5种多模态搜索工具)")
logger.info(f"Meida Agent已初始化")
logger.info(f"使用LLM: {self.llm_client.get_model_info()}")
logger.info(f"搜索工具集: BochaMultimodalSearch (支持5种多模态搜索工具)")
def _initialize_llm(self) -> LLMClient:
"""初始化LLM客户端"""
return LLMClient(
api_key=self.config.llm_api_key,
model_name=self.config.llm_model_name,
base_url=self.config.llm_base_url,
api_key=(self.config.MEDIA_ENGINE_API_KEY or self.config.MINDSPIDER_API_KEY),
model_name=(self.config.MEDIA_ENGINE_MODEL_NAME or self.config.MINDSPIDER_MODEL_NAME),
base_url=(self.config.MEDIA_ENGINE_BASE_URL or self.config.MINDSPIDER_BASE_URL),
)
def _initialize_nodes(self):
@@ -115,7 +112,7 @@ class DeepSearchAgent:
Returns:
BochaResponse对象
"""
print(f" → 执行搜索工具: {tool_name}")
logger.info(f" → 执行搜索工具: {tool_name}")
if tool_name == "comprehensive_search":
max_results = kwargs.get("max_results", 10)
@@ -130,7 +127,7 @@ class DeepSearchAgent:
elif tool_name == "search_last_week":
return self.search_agency.search_last_week(query)
else:
print(f" ⚠️ 未知的搜索工具: {tool_name},使用默认综合搜索")
logger.info(f" ⚠️ 未知的搜索工具: {tool_name},使用默认综合搜索")
return self.search_agency.comprehensive_search(query)
def research(self, query: str, save_report: bool = True) -> str:
@@ -144,9 +141,9 @@ class DeepSearchAgent:
Returns:
最终报告内容
"""
print(f"\n{'='*60}")
print(f"开始深度研究: {query}")
print(f"{'='*60}")
logger.info(f"\n{'='*60}")
logger.info(f"开始深度研究: {query}")
logger.info(f"{'='*60}")
try:
# Step 1: 生成报告结构
@@ -162,19 +159,21 @@ class DeepSearchAgent:
if save_report:
self._save_report(final_report)
print(f"\n{'='*60}")
print("深度研究完成!")
print(f"{'='*60}")
logger.info(f"\n{'='*60}")
logger.info("深度研究完成!")
logger.info(f"{'='*60}")
return final_report
except Exception as e:
print(f"研究过程中发生错误: {str(e)}")
import traceback
error_traceback = traceback.format_exc()
logger.error(f"研究过程中发生错误: {str(e)} \n错误堆栈: {error_traceback}")
raise e
def _generate_report_structure(self, query: str):
"""生成报告结构"""
print(f"\n[步骤 1] 生成报告结构...")
logger.info(f"\n[步骤 1] 生成报告结构...")
# 创建报告结构节点
report_structure_node = ReportStructureNode(self.llm_client, query)
@@ -182,17 +181,18 @@ class DeepSearchAgent:
# 生成结构并更新状态
self.state = report_structure_node.mutate_state(state=self.state)
print(f"报告结构已生成,共 {len(self.state.paragraphs)} 个段落:")
_message = f"报告结构已生成,共 {len(self.state.paragraphs)} 个段落:"
for i, paragraph in enumerate(self.state.paragraphs, 1):
print(f" {i}. {paragraph.title}")
_message += f"\n {i}. {paragraph.title}"
logger.info(_message)
def _process_paragraphs(self):
"""处理所有段落"""
total_paragraphs = len(self.state.paragraphs)
for i in range(total_paragraphs):
print(f"\n[步骤 2.{i+1}] 处理段落: {self.state.paragraphs[i].title}")
print("-" * 50)
logger.info(f"\n[步骤 2.{i+1}] 处理段落: {self.state.paragraphs[i].title}")
logger.info("-" * 50)
# 初始搜索和总结
self._initial_search_and_summary(i)
@@ -204,7 +204,7 @@ class DeepSearchAgent:
self.state.paragraphs[i].research.mark_completed()
progress = (i + 1) / total_paragraphs * 100
print(f"段落处理完成 ({progress:.1f}%)")
logger.info(f"段落处理完成 ({progress:.1f}%)")
def _initial_search_and_summary(self, paragraph_index: int):
"""执行初始搜索和总结"""
@@ -217,18 +217,18 @@ class DeepSearchAgent:
}
# 生成搜索查询和工具选择
print(" - 生成搜索查询...")
logger.info(" - 生成搜索查询...")
search_output = self.first_search_node.run(search_input)
search_query = search_output["search_query"]
search_tool = search_output.get("search_tool", "comprehensive_search") # 默认工具
reasoning = search_output["reasoning"]
print(f" - 搜索查询: {search_query}")
print(f" - 选择的工具: {search_tool}")
print(f" - 推理: {reasoning}")
logger.info(f" - 搜索查询: {search_query}")
logger.info(f" - 选择的工具: {search_tool}")
logger.info(f" - 推理: {reasoning}")
# 执行搜索
print(" - 执行网络搜索...")
logger.info(" - 执行网络搜索...")
# 处理特殊参数(新的工具集不需要日期参数处理)
search_kwargs = {}
@@ -254,24 +254,25 @@ class DeepSearchAgent:
})
if search_results:
print(f" - 找到 {len(search_results)} 个搜索结果")
_message = f" - 找到 {len(search_results)} 个搜索结果"
for j, result in enumerate(search_results, 1):
date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else ""
print(f" {j}. {result['title'][:50]}...{date_info}")
_message += f"\n {j}. {result['title'][:50]}...{date_info}"
logger.info(_message)
else:
print(" - 未找到搜索结果")
logger.info(" - 未找到搜索结果")
# 更新状态中的搜索历史
paragraph.research.add_search_results(search_query, search_results)
# 生成初始总结
print(" - 生成初始总结...")
logger.info(" - 生成初始总结...")
summary_input = {
"title": paragraph.title,
"content": paragraph.content,
"search_query": search_query,
"search_results": format_search_results_for_prompt(
search_results, self.config.max_content_length
search_results, self.config.SEARCH_CONTENT_MAX_LENGTH
)
}
@@ -280,14 +281,14 @@ class DeepSearchAgent:
summary_input, self.state, paragraph_index
)
print(" - 初始总结完成")
logger.info(" - 初始总结完成")
def _reflection_loop(self, paragraph_index: int):
"""执行反思循环"""
paragraph = self.state.paragraphs[paragraph_index]
for reflection_i in range(self.config.max_reflections):
print(f" - 反思 {reflection_i + 1}/{self.config.max_reflections}...")
for reflection_i in range(self.config.MAX_REFLECTIONS):
logger.info(f" - 反思 {reflection_i + 1}/{self.config.MAX_REFLECTIONS}...")
# 准备反思输入
reflection_input = {
@@ -302,9 +303,9 @@ class DeepSearchAgent:
search_tool = reflection_output.get("search_tool", "comprehensive_search") # 默认工具
reasoning = reflection_output["reasoning"]
print(f" 反思查询: {search_query}")
print(f" 选择的工具: {search_tool}")
print(f" 反思推理: {reasoning}")
logger.info(f" 反思查询: {search_query}")
logger.info(f" 选择的工具: {search_tool}")
logger.info(f" 反思推理: {reasoning}")
# 执行反思搜索
# 处理特殊参数
@@ -331,12 +332,13 @@ class DeepSearchAgent:
})
if search_results:
print(f" 找到 {len(search_results)} 个反思搜索结果")
_message = f" 找到 {len(search_results)} 个反思搜索结果"
for j, result in enumerate(search_results, 1):
date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else ""
print(f" {j}. {result['title'][:50]}...{date_info}")
_message += f"\n {j}. {result['title'][:50]}...{date_info}"
logger.info(_message)
else:
print(" 未找到反思搜索结果")
logger.info(" 未找到反思搜索结果")
# 更新搜索历史
paragraph.research.add_search_results(search_query, search_results)
@@ -347,7 +349,7 @@ class DeepSearchAgent:
"content": paragraph.content,
"search_query": search_query,
"search_results": format_search_results_for_prompt(
search_results, self.config.max_content_length
search_results, self.config.SEARCH_CONTENT_MAX_LENGTH
),
"paragraph_latest_state": paragraph.research.latest_summary
}
@@ -357,11 +359,11 @@ class DeepSearchAgent:
reflection_summary_input, self.state, paragraph_index
)
print(f" 反思 {reflection_i + 1} 完成")
logger.info(f" 反思 {reflection_i + 1} 完成")
def _generate_final_report(self) -> str:
"""生成最终报告"""
print(f"\n[步骤 3] 生成最终报告...")
logger.info(f"\n[步骤 3] 生成最终报告...")
# 准备报告数据
report_data = []
@@ -375,7 +377,7 @@ class DeepSearchAgent:
try:
final_report = self.report_formatting_node.run(report_data)
except Exception as e:
print(f"LLM格式化失败,使用备用方法: {str(e)}")
logger.info(f"LLM格式化失败,使用备用方法: {str(e)}")
final_report = self.report_formatting_node.format_report_manually(
report_data, self.state.report_title
)
@@ -384,7 +386,7 @@ class DeepSearchAgent:
self.state.final_report = final_report
self.state.mark_completed()
print("最终报告生成完成")
logger.info("最终报告生成完成")
return final_report
def _save_report(self, report_content: str):
@@ -395,20 +397,20 @@ class DeepSearchAgent:
query_safe = query_safe.replace(' ', '_')[:30]
filename = f"deep_search_report_{query_safe}_{timestamp}.md"
filepath = os.path.join(self.config.output_dir, filename)
filepath = os.path.join(self.config.OUTPUT_DIR, filename)
# 保存报告
with open(filepath, 'w', encoding='utf-8') as f:
f.write(report_content)
print(f"报告已保存到: {filepath}")
logger.info(f"报告已保存到: {filepath}")
# 保存状态(如果配置允许)
if self.config.save_intermediate_states:
if self.config.SAVE_INTERMEDIATE_STATES:
state_filename = f"state_{query_safe}_{timestamp}.json"
state_filepath = os.path.join(self.config.output_dir, state_filename)
state_filepath = os.path.join(self.config.OUTPUT_DIR, state_filename)
self.state.save_to_file(state_filepath)
print(f"状态已保存到: {state_filepath}")
logger.info(f"状态已保存到: {state_filepath}")
def get_progress_summary(self) -> Dict[str, Any]:
"""获取进度摘要"""
@@ -417,12 +419,12 @@ class DeepSearchAgent:
def load_state(self, filepath: str):
"""从文件加载状态"""
self.state = State.load_from_file(filepath)
print(f"状态已从 {filepath} 加载")
logger.info(f"状态已从 {filepath} 加载")
def save_state(self, filepath: str):
"""保存状态到文件"""
self.state.save_to_file(filepath)
print(f"状态已保存到 {filepath}")
logger.info(f"状态已保存到 {filepath}")
def create_agent(config_file: Optional[str] = None) -> DeepSearchAgent:
@@ -435,5 +437,5 @@ def create_agent(config_file: Optional[str] = None) -> DeepSearchAgent:
Returns:
DeepSearchAgent实例
"""
config = load_config(config_file)
return DeepSearchAgent(config)
settings = Settings()
return DeepSearchAgent(settings)
+19 -14
View File
@@ -7,67 +7,72 @@ from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from ..llms.base import LLMClient
from ..state.state import State
from loguru import logger
class BaseNode(ABC):
"""节点基类"""
def __init__(self, llm_client: LLMClient, node_name: str = ""):
"""
初始化节点
Args:
llm_client: LLM客户端
node_name: 节点名称
"""
self.llm_client = llm_client
self.node_name = node_name or self.__class__.__name__
@abstractmethod
def run(self, input_data: Any, **kwargs) -> Any:
"""
执行节点处理逻辑
Args:
input_data: 输入数据
**kwargs: 额外参数
Returns:
处理结果
"""
pass
def validate_input(self, input_data: Any) -> bool:
"""
验证输入数据
Args:
input_data: 输入数据
Returns:
验证是否通过
"""
return True
def process_output(self, output: Any) -> Any:
"""
处理输出数据
Args:
output: 原始输出
Returns:
处理后的输出
"""
return output
def log_info(self, message: str):
"""记录信息日志"""
print(f"[{self.node_name}] {message}")
logger.info(f"[{self.node_name}] {message}")
def log_warning(self, message: str):
"""记录警告日志"""
logger.warning(f"[{self.node_name}] 警告: {message}")
def log_error(self, message: str):
"""记录错误日志"""
print(f"[{self.node_name}] 错误: {message}")
logger.error(f"[{self.node_name}] 错误: {message}")
class StateMutationNode(BaseNode):
+7 -6
View File
@@ -5,6 +5,7 @@
import json
from typing import List, Dict, Any
from loguru import logger
from .base_node import BaseNode
from ..prompts import SYSTEM_PROMPT_REPORT_FORMATTING
@@ -65,7 +66,7 @@ class ReportFormattingNode(BaseNode):
else:
message = json.dumps(input_data, ensure_ascii=False)
self.log_info("正在格式化最终报告")
logger.info("正在格式化最终报告")
# 调用LLM生成Markdown格式
response = self.llm_client.invoke(
@@ -76,11 +77,11 @@ class ReportFormattingNode(BaseNode):
# 处理响应
processed_response = self.process_output(response)
self.log_info("成功生成格式化报告")
logger.info("成功生成格式化报告")
return processed_response
except Exception as e:
self.log_error(f"报告格式化失败: {str(e)}")
logger.exception(f"报告格式化失败: {str(e)}")
raise e
def process_output(self, output: str) -> str:
@@ -109,7 +110,7 @@ class ReportFormattingNode(BaseNode):
return cleaned_output.strip()
except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
logger.exception(f"处理输出失败: {str(e)}")
return "# 报告处理失败\n\n报告格式化过程中发生错误。"
def format_report_manually(self, paragraphs_data: List[Dict[str, str]],
@@ -125,7 +126,7 @@ class ReportFormattingNode(BaseNode):
格式化的Markdown报告
"""
try:
self.log_info("使用手动格式化方法")
logger.info("使用手动格式化方法")
# 构建报告
report_lines = [
@@ -163,5 +164,5 @@ class ReportFormattingNode(BaseNode):
return "\n".join(report_lines)
except Exception as e:
self.log_error(f"手动格式化失败: {str(e)}")
logger.exception(f"手动格式化失败: {str(e)}")
return "# 报告生成失败\n\n无法完成报告格式化。"
+21 -20
View File
@@ -6,6 +6,7 @@
import json
from typing import Dict, Any, List
from json.decoder import JSONDecodeError
from loguru import logger
from .base_node import StateMutationNode
from ..state.state import State
@@ -48,7 +49,7 @@ class ReportStructureNode(StateMutationNode):
报告结构列表
"""
try:
self.log_info(f"正在为查询生成报告结构: {self.query}")
logger.info(f"正在为查询生成报告结构: {self.query}")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query)
@@ -56,11 +57,11 @@ class ReportStructureNode(StateMutationNode):
# 处理响应
processed_response = self.process_output(response)
self.log_info(f"成功生成 {len(processed_response)} 个段落结构")
logger.info(f"成功生成 {len(processed_response)} 个段落结构")
return processed_response
except Exception as e:
self.log_error(f"生成报告结构失败: {str(e)}")
logger.exception(f"生成报告结构失败: {str(e)}")
raise e
def process_output(self, output: str) -> List[Dict[str, str]]:
@@ -79,54 +80,54 @@ class ReportStructureNode(StateMutationNode):
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output}")
logger.info(f"清理后的输出: {cleaned_output}")
# 解析JSON
try:
report_structure = json.loads(cleaned_output)
self.log_info("JSON解析成功")
logger.info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
logger.exception(f"JSON解析失败: {str(e)}")
# 使用更强大的提取方法
report_structure = extract_clean_response(cleaned_output)
if "error" in report_structure:
self.log_error("JSON解析失败,尝试修复...")
logger.error("JSON解析失败,尝试修复...")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
report_structure = json.loads(fixed_json)
self.log_info("JSON修复成功")
logger.info("JSON修复成功")
except JSONDecodeError:
self.log_error("JSON修复失败")
logger.error("JSON修复失败")
# 返回默认结构
return self._generate_default_structure()
else:
self.log_error("无法修复JSON,使用默认结构")
logger.error("无法修复JSON,使用默认结构")
return self._generate_default_structure()
# 验证结构
if not isinstance(report_structure, list):
self.log_info("报告结构不是列表,尝试转换...")
logger.info("报告结构不是列表,尝试转换...")
if isinstance(report_structure, dict):
# 如果是单个对象,包装成列表
report_structure = [report_structure]
else:
self.log_error("报告结构格式无效,使用默认结构")
logger.error("报告结构格式无效,使用默认结构")
return self._generate_default_structure()
# 验证每个段落
validated_structure = []
for i, paragraph in enumerate(report_structure):
if not isinstance(paragraph, dict):
self.log_warning(f"段落 {i+1} 不是字典格式,跳过")
logger.warning(f"段落 {i+1} 不是字典格式,跳过")
continue
title = paragraph.get("title", f"段落 {i+1}")
content = paragraph.get("content", "")
if not title or not content:
self.log_warning(f"段落 {i+1} 缺少标题或内容,跳过")
logger.warning(f"段落 {i+1} 缺少标题或内容,跳过")
continue
validated_structure.append({
@@ -135,14 +136,14 @@ class ReportStructureNode(StateMutationNode):
})
if not validated_structure:
self.log_warning("没有有效的段落结构,使用默认结构")
logger.warning("没有有效的段落结构,使用默认结构")
return self._generate_default_structure()
self.log_info(f"成功验证 {len(validated_structure)} 个段落结构")
logger.info(f"成功验证 {len(validated_structure)} 个段落结构")
return validated_structure
except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
logger.exception(f"处理输出失败: {str(e)}")
return self._generate_default_structure()
def _generate_default_structure(self) -> List[Dict[str, str]]:
@@ -152,7 +153,7 @@ class ReportStructureNode(StateMutationNode):
Returns:
默认的报告结构列表
"""
self.log_info("生成默认报告结构")
logger.info("生成默认报告结构")
return [
{
"title": "研究概述",
@@ -195,9 +196,9 @@ class ReportStructureNode(StateMutationNode):
content=paragraph_data["content"]
)
self.log_info(f"已将 {len(report_structure)} 个段落添加到状态中")
logger.info(f"已将 {len(report_structure)} 个段落添加到状态中")
return state
except Exception as e:
self.log_error(f"状态更新失败: {str(e)}")
logger.exception(f"状态更新失败: {str(e)}")
raise e
+24 -23
View File
@@ -6,6 +6,7 @@
import json
from typing import Dict, Any
from json.decoder import JSONDecodeError
from loguru import logger
from .base_node import BaseNode
from ..prompts import SYSTEM_PROMPT_FIRST_SEARCH, SYSTEM_PROMPT_REFLECTION
@@ -62,7 +63,7 @@ class FirstSearchNode(BaseNode):
else:
message = json.dumps(input_data, ensure_ascii=False)
self.log_info("正在生成首次搜索查询")
logger.info("正在生成首次搜索查询")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SEARCH, message)
@@ -70,11 +71,11 @@ class FirstSearchNode(BaseNode):
# 处理响应
processed_response = self.process_output(response)
self.log_info(f"生成搜索查询: {processed_response.get('search_query', 'N/A')}")
logger.info(f"生成搜索查询: {processed_response.get('search_query', 'N/A')}")
return processed_response
except Exception as e:
self.log_error(f"生成首次搜索查询失败: {str(e)}")
logger.exception(f"生成首次搜索查询失败: {str(e)}")
raise e
def process_output(self, output: str) -> Dict[str, str]:
@@ -93,30 +94,30 @@ class FirstSearchNode(BaseNode):
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output}")
logger.info(f"清理后的输出: {cleaned_output}")
# 解析JSON
try:
result = json.loads(cleaned_output)
self.log_info("JSON解析成功")
logger.info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
logger.exception(f"JSON解析失败: {str(e)}")
# 使用更强大的提取方法
result = extract_clean_response(cleaned_output)
if "error" in result:
self.log_error("JSON解析失败,尝试修复...")
logger.error("JSON解析失败,尝试修复...")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
result = json.loads(fixed_json)
self.log_info("JSON修复成功")
logger.info("JSON修复成功")
except JSONDecodeError:
self.log_error("JSON修复失败")
logger.error("JSON修复失败")
# 返回默认查询
return self._get_default_search_query()
else:
self.log_error("无法修复JSON,使用默认查询")
logger.error("无法修复JSON,使用默认查询")
return self._get_default_search_query()
# 验证和清理结果
@@ -124,7 +125,7 @@ class FirstSearchNode(BaseNode):
reasoning = result.get("reasoning", "")
if not search_query:
self.log_warning("未找到搜索查询,使用默认查询")
logger.warning("未找到搜索查询,使用默认查询")
return self._get_default_search_query()
return {
@@ -197,7 +198,7 @@ class ReflectionNode(BaseNode):
else:
message = json.dumps(input_data, ensure_ascii=False)
self.log_info("正在进行反思并生成新搜索查询")
logger.info("正在进行反思并生成新搜索查询")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION, message)
@@ -205,11 +206,11 @@ class ReflectionNode(BaseNode):
# 处理响应
processed_response = self.process_output(response)
self.log_info(f"反思生成搜索查询: {processed_response.get('search_query', 'N/A')}")
logger.info(f"反思生成搜索查询: {processed_response.get('search_query', 'N/A')}")
return processed_response
except Exception as e:
self.log_error(f"反思生成搜索查询失败: {str(e)}")
logger.exception(f"反思生成搜索查询失败: {str(e)}")
raise e
def process_output(self, output: str) -> Dict[str, str]:
@@ -228,30 +229,30 @@ class ReflectionNode(BaseNode):
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output}")
logger.info(f"清理后的输出: {cleaned_output}")
# 解析JSON
try:
result = json.loads(cleaned_output)
self.log_info("JSON解析成功")
logger.info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
logger.exception(f"JSON解析失败: {str(e)}")
# 使用更强大的提取方法
result = extract_clean_response(cleaned_output)
if "error" in result:
self.log_error("JSON解析失败,尝试修复...")
logger.error("JSON解析失败,尝试修复...")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
result = json.loads(fixed_json)
self.log_info("JSON修复成功")
logger.info("JSON修复成功")
except JSONDecodeError:
self.log_error("JSON修复失败")
logger.error("JSON修复失败")
# 返回默认查询
return self._get_default_reflection_query()
else:
self.log_error("无法修复JSON,使用默认查询")
logger.error("无法修复JSON,使用默认查询")
return self._get_default_reflection_query()
# 验证和清理结果
@@ -259,7 +260,7 @@ class ReflectionNode(BaseNode):
reasoning = result.get("reasoning", "")
if not search_query:
self.log_warning("未找到搜索查询,使用默认查询")
logger.warning("未找到搜索查询,使用默认查询")
return self._get_default_reflection_query()
return {
@@ -268,7 +269,7 @@ class ReflectionNode(BaseNode):
}
except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
logger.exception(f"处理输出失败: {str(e)}")
# 返回默认查询
return self._get_default_reflection_query()
+30 -29
View File
@@ -6,6 +6,7 @@
import json
from typing import Dict, Any, List
from json.decoder import JSONDecodeError
from loguru import logger
from .base_node import StateMutationNode
from ..state.state import State
@@ -27,7 +28,7 @@ try:
FORUM_READER_AVAILABLE = True
except ImportError:
FORUM_READER_AVAILABLE = False
print("警告: 无法导入forum_reader模块,将跳过HOST发言读取功能")
logger.warning("无法导入forum_reader模块,将跳过HOST发言读取功能")
class FirstSummaryNode(StateMutationNode):
@@ -84,9 +85,9 @@ class FirstSummaryNode(StateMutationNode):
if host_speech:
# 将HOST发言添加到输入数据中
data['host_speech'] = host_speech
self.log_info(f"已读取HOST发言,长度: {len(host_speech)}字符")
logger.info(f"已读取HOST发言,长度: {len(host_speech)}字符")
except Exception as e:
self.log_info(f"读取HOST发言失败: {str(e)}")
logger.exception(f"读取HOST发言失败: {str(e)}")
# 转换为JSON字符串
message = json.dumps(data, ensure_ascii=False)
@@ -96,7 +97,7 @@ class FirstSummaryNode(StateMutationNode):
formatted_host = format_host_speech_for_prompt(data['host_speech'])
message = formatted_host + "\n" + message
self.log_info("正在生成首次段落总结")
logger.info("正在生成首次段落总结")
# 调用LLM生成总结
response = self.llm_client.invoke(
@@ -107,11 +108,11 @@ class FirstSummaryNode(StateMutationNode):
# 处理响应
processed_response = self.process_output(response)
self.log_info("成功生成首次段落总结")
logger.info("成功生成首次段落总结")
return processed_response
except Exception as e:
self.log_error(f"生成首次总结失败: {str(e)}")
logger.exception(f"生成首次总结失败: {str(e)}")
raise e
def process_output(self, output: str) -> str:
@@ -130,26 +131,26 @@ class FirstSummaryNode(StateMutationNode):
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output}")
logger.info(f"清理后的输出: {cleaned_output}")
# 解析JSON
try:
result = json.loads(cleaned_output)
self.log_info("JSON解析成功")
logger.info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
logger.exception(f"JSON解析失败: {str(e)}")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
result = json.loads(fixed_json)
self.log_info("JSON修复成功")
logger.info("JSON修复成功")
except JSONDecodeError:
self.log_info("JSON修复失败,直接使用清理后的文本")
logger.exception("JSON修复失败,直接使用清理后的文本")
# 如果不是JSON格式,直接返回清理后的文本
return cleaned_output
else:
self.log_info("无法修复JSON,直接使用清理后的文本")
logger.exception("无法修复JSON,直接使用清理后的文本")
# 如果不是JSON格式,直接返回清理后的文本
return cleaned_output
@@ -163,7 +164,7 @@ class FirstSummaryNode(StateMutationNode):
return cleaned_output
except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
logger.exception(f"处理输出失败: {str(e)}")
return "段落总结生成失败"
def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State:
@@ -186,7 +187,7 @@ class FirstSummaryNode(StateMutationNode):
# 更新状态
if 0 <= paragraph_index < len(state.paragraphs):
state.paragraphs[paragraph_index].research.latest_summary = summary
self.log_info(f"已更新段落 {paragraph_index} 的首次总结")
logger.info(f"已更新段落 {paragraph_index} 的首次总结")
else:
raise ValueError(f"段落索引 {paragraph_index} 超出范围")
@@ -194,7 +195,7 @@ class FirstSummaryNode(StateMutationNode):
return state
except Exception as e:
self.log_error(f"状态更新失败: {str(e)}")
logger.exception(f"状态更新失败: {str(e)}")
raise e
@@ -252,9 +253,9 @@ class ReflectionSummaryNode(StateMutationNode):
if host_speech:
# 将HOST发言添加到输入数据中
data['host_speech'] = host_speech
self.log_info(f"已读取HOST发言,长度: {len(host_speech)}字符")
logger.info(f"已读取HOST发言,长度: {len(host_speech)}字符")
except Exception as e:
self.log_info(f"读取HOST发言失败: {str(e)}")
logger.exception(f"读取HOST发言失败: {str(e)}")
# 转换为JSON字符串
message = json.dumps(data, ensure_ascii=False)
@@ -264,7 +265,7 @@ class ReflectionSummaryNode(StateMutationNode):
formatted_host = format_host_speech_for_prompt(data['host_speech'])
message = formatted_host + "\n" + message
self.log_info("正在生成反思总结")
logger.info("正在生成反思总结")
# 调用LLM生成总结
response = self.llm_client.invoke(
@@ -275,11 +276,11 @@ class ReflectionSummaryNode(StateMutationNode):
# 处理响应
processed_response = self.process_output(response)
self.log_info("成功生成反思总结")
logger.info("成功生成反思总结")
return processed_response
except Exception as e:
self.log_error(f"生成反思总结失败: {str(e)}")
logger.exception(f"生成反思总结失败: {str(e)}")
raise e
def process_output(self, output: str) -> str:
@@ -298,26 +299,26 @@ class ReflectionSummaryNode(StateMutationNode):
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output}")
logger.info(f"清理后的输出: {cleaned_output}")
# 解析JSON
try:
result = json.loads(cleaned_output)
self.log_info("JSON解析成功")
logger.info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
logger.exception(f"JSON解析失败: {str(e)}")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
result = json.loads(fixed_json)
self.log_info("JSON修复成功")
logger.info("JSON修复成功")
except JSONDecodeError:
self.log_info("JSON修复失败,直接使用清理后的文本")
logger.exception("JSON修复失败,直接使用清理后的文本")
# 如果不是JSON格式,直接返回清理后的文本
return cleaned_output
else:
self.log_info("无法修复JSON,直接使用清理后的文本")
logger.exception("无法修复JSON,直接使用清理后的文本")
# 如果不是JSON格式,直接返回清理后的文本
return cleaned_output
@@ -331,7 +332,7 @@ class ReflectionSummaryNode(StateMutationNode):
return cleaned_output
except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
logger.exception(f"处理输出失败: {str(e)}")
return "反思总结生成失败"
def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State:
@@ -355,7 +356,7 @@ class ReflectionSummaryNode(StateMutationNode):
if 0 <= paragraph_index < len(state.paragraphs):
state.paragraphs[paragraph_index].research.latest_summary = updated_summary
state.paragraphs[paragraph_index].research.increment_reflection()
self.log_info(f"已更新段落 {paragraph_index} 的反思总结")
logger.info(f"已更新段落 {paragraph_index} 的反思总结")
else:
raise ValueError(f"段落索引 {paragraph_index} 超出范围")
@@ -363,5 +364,5 @@ class ReflectionSummaryNode(StateMutationNode):
return state
except Exception as e:
self.log_error(f"状态更新失败: {str(e)}")
logger.exception(f"状态更新失败: {str(e)}")
raise e
+40 -37
View File
@@ -25,6 +25,9 @@ import json
import sys
from typing import List, Dict, Any, Optional, Literal
from loguru import logger
from config import settings
# 运行前请确保已安装 requests 库: pip install requests
try:
import requests
@@ -90,8 +93,8 @@ class BochaMultimodalSearch:
一个包含多种专用多模态搜索工具的客户端。
每个公共方法都设计为供 AI Agent 独立调用的工具。
"""
BASE_URL = "https://api.bochaai.com/v1/ai-search"
BOCHA_BASE_URL = settings.BOCHA_BASE_URL or "https://api.bochaai.com/v1/ai-search"
def __init__(self, api_key: Optional[str] = None):
"""
@@ -100,10 +103,10 @@ class BochaMultimodalSearch:
api_key: Bocha API密钥,若不提供则从环境变量 BOCHA_API_KEY 读取。
"""
if api_key is None:
api_key = os.getenv("BOCHA_API_KEY")
api_key = settings.BOCHA_WEB_SEARCH_API_KEY
if not api_key:
raise ValueError("Bocha API Key未找到!请设置 BOCHA_API_KEY 环境变量或在初始化时提供")
self._headers = {
'Authorization': f'Bearer {api_key}',
'Content-Type': 'application/json',
@@ -112,7 +115,7 @@ class BochaMultimodalSearch:
def _parse_search_response(self, response_dict: Dict[str, Any], query: str) -> BochaResponse:
"""从API的原始字典响应中解析出结构化的BochaResponse对象"""
final_response = BochaResponse(query=query)
final_response.conversation_id = response_dict.get('conversation_id')
@@ -125,7 +128,7 @@ class BochaMultimodalSearch:
msg_type = msg.get('type')
content_type = msg.get('content_type')
content_str = msg.get('content', '{}')
try:
content_data = json.loads(content_str)
except json.JSONDecodeError:
@@ -134,7 +137,7 @@ class BochaMultimodalSearch:
if msg_type == 'answer' and content_type == 'text':
final_response.answer = content_data
elif msg_type == 'follow_up' and content_type == 'text':
final_response.follow_ups.append(content_data)
@@ -164,7 +167,7 @@ class BochaMultimodalSearch:
card_type=content_type,
content=content_data
))
return final_response
@@ -176,23 +179,23 @@ class BochaMultimodalSearch:
"stream": False, # Agent工具通常使用非流式以获取完整结果
}
payload.update(kwargs)
try:
response = requests.post(self.BASE_URL, headers=self._headers, json=payload, timeout=30)
response = requests.post(self.BOCHA_BASE_URL, headers=self._headers, json=payload, timeout=30)
response.raise_for_status() # 如果HTTP状态码是4xx或5xx,则抛出异常
response_dict = response.json()
if response_dict.get("code") != 200:
print(f"API返回错误: {response_dict.get('msg', '未知错误')}")
logger.error(f"API返回错误: {response_dict.get('msg', '未知错误')}")
return BochaResponse(query=query)
return self._parse_search_response(response_dict, query)
except requests.exceptions.RequestException as e:
print(f"搜索时发生网络错误: {str(e)}")
logger.exception(f"搜索时发生网络错误: {str(e)}")
raise e # 让重试机制捕获并处理
except Exception as e:
print(f"处理响应时发生未知错误: {str(e)}")
logger.exception(f"处理响应时发生未知错误: {str(e)}")
raise e # 让重试机制捕获并处理
# --- Agent 可用的工具方法 ---
@@ -203,19 +206,19 @@ class BochaMultimodalSearch:
返回网页、图片、AI总结、追问建议和可能的模态卡。这是最常用的通用搜索工具。
Agent可提供搜索查询(query)和可选的最大结果数(max_results)。
"""
print(f"--- TOOL: 全面综合搜索 (query: {query}) ---")
logger.info(f"--- TOOL: 全面综合搜索 (query: {query}) ---")
return self._search_internal(
query=query,
count=max_results,
answer=True # 开启AI总结
)
def web_search_only(self, query: str, max_results: int = 15) -> BochaResponse:
"""
【工具】纯网页搜索: 只获取网页链接和摘要,不请求AI生成答案。
适用于需要快速获取原始网页信息,而不需要AI额外分析的场景。速度更快,成本更低。
"""
print(f"--- TOOL: 纯网页搜索 (query: {query}) ---")
logger.info(f"--- TOOL: 纯网页搜索 (query: {query}) ---")
return self._search_internal(
query=query,
count=max_results,
@@ -228,7 +231,7 @@ class BochaMultimodalSearch:
当Agent意图是查询天气、股票、汇率、百科定义、火车票、汽车参数等结构化信息时,应优先使用此工具。
它会返回所有信息,但Agent应重点关注结果中的 `modal_cards` 部分。
"""
print(f"--- TOOL: 结构化数据查询 (query: {query}) ---")
logger.info(f"--- TOOL: 结构化数据查询 (query: {query}) ---")
# 实现上与 comprehensive_search 相同,但通过命名和文档引导Agent的意图
return self._search_internal(
query=query,
@@ -241,7 +244,7 @@ class BochaMultimodalSearch:
【工具】搜索24小时内信息: 获取关于某个主题的最新动态。
此工具专门查找过去24小时内发布的内容。适用于追踪突发事件或最新进展。
"""
print(f"--- TOOL: 搜索24小时内信息 (query: {query}) ---")
logger.info(f"--- TOOL: 搜索24小时内信息 (query: {query}) ---")
return self._search_internal(query=query, freshness='oneDay', answer=True)
def search_last_week(self, query: str) -> BochaResponse:
@@ -249,7 +252,7 @@ class BochaMultimodalSearch:
【工具】搜索本周信息: 获取关于某个主题过去一周内的主要报道。
适用于进行周度舆情总结或回顾。
"""
print(f"--- TOOL: 搜索本周信息 (query: {query}) ---")
logger.info(f"--- TOOL: 搜索本周信息 (query: {query}) ---")
return self._search_internal(query=query, freshness='oneWeek', answer=True)
@@ -258,32 +261,32 @@ class BochaMultimodalSearch:
def print_response_summary(response: BochaResponse):
"""简化的打印函数,用于展示测试结果"""
if not response or not response.query:
print("未能获取有效响应。")
logger.error("未能获取有效响应。")
return
print(f"\n查询: '{response.query}' | 会话ID: {response.conversation_id}")
logger.info(f"\n查询: '{response.query}' | 会话ID: {response.conversation_id}")
if response.answer:
print(f"AI摘要: {response.answer[:150]}...")
print(f"找到 {len(response.webpages)} 个网页, {len(response.images)} 张图片, {len(response.modal_cards)} 个模态卡。")
logger.info(f"AI摘要: {response.answer[:150]}...")
logger.info(f"找到 {len(response.webpages)} 个网页, {len(response.images)} 张图片, {len(response.modal_cards)} 个模态卡。")
if response.modal_cards:
first_card = response.modal_cards[0]
print(f"第一个模态卡类型: {first_card.card_type}")
logger.info(f"第一个模态卡类型: {first_card.card_type}")
if response.webpages:
first_result = response.webpages[0]
print(f"第一条网页结果: {first_result.name}")
logger.info(f"第一条网页结果: {first_result.name}")
if response.follow_ups:
print(f"建议追问: {response.follow_ups}")
logger.info(f"建议追问: {response.follow_ups}")
print("-" * 60)
logger.info("-" * 60)
if __name__ == "__main__":
# 在运行前,请确保您已设置 BOCHA_API_KEY 环境变量
try:
# 初始化多模态搜索客户端,它内部包含了所有工具
search_client = BochaMultimodalSearch()
@@ -297,7 +300,7 @@ if __name__ == "__main__":
print_response_summary(response2)
# 深度解析第一个模态卡
if response2.modal_cards and response2.modal_cards[0].card_type == 'weather_china':
print("天气模态卡详情:", json.dumps(response2.modal_cards[0].content, indent=2, ensure_ascii=False))
logger.info("天气模态卡详情:", json.dumps(response2.modal_cards[0].content, indent=2, ensure_ascii=False))
# 场景3: Agent需要查询特定结构化信息 - 股票
@@ -311,11 +314,11 @@ if __name__ == "__main__":
# 场景5: Agent只需要快速获取网页信息,不需要AI总结
response5 = search_client.web_search_only(query="Python dataclasses用法")
print_response_summary(response5)
# 场景6: Agent需要回顾一周内关于某项技术的新闻
response6 = search_client.search_last_week(query="量子计算商业化")
print_response_summary(response6)
'''下面是测试程序的输出:
--- TOOL: 全面综合搜索 (query: 人工智能对未来教育的影响) ---
@@ -381,7 +384,7 @@ AI摘要: 量子计算商业化正在逐步推进。
------------------------------------------------------------'''
except ValueError as e:
print(f"初始化失败: {e}")
print("请确保 BOCHA_API_KEY 环境变量已正确设置。")
logger.exception(f"初始化失败: {e}")
logger.error("请确保 BOCHA_API_KEY 环境变量已正确设置。")
except Exception as e:
print(f"测试过程中发生未知错误: {e}")
logger.exception(f"测试过程中发生未知错误: {e}")
+4 -4
View File
@@ -12,15 +12,15 @@ from .text_processing import (
format_search_results_for_prompt
)
from .config import Config, load_config
from .config import Settings, settings
__all__ = [
"clean_json_tags",
"clean_markdown_tags",
"remove_reasoning_from_output",
"remove_reasoning_from_output",
"extract_clean_response",
"update_state_with_search_results",
"format_search_results_for_prompt",
"Config",
"load_config"
"Settings",
"settings"
]
+75 -149
View File
@@ -1,157 +1,83 @@
"""
Configuration management module for the Media Engine.
Configuration management module for the Media Engine (pydantic_settings style).
"""
import os
from dataclasses import dataclass
from pathlib import Path
from pydantic_settings import BaseSettings
from pydantic import Field
from typing import Optional
def _get_value(source, key: str, default=None, *fallback_keys: str):
candidates = (key,) + fallback_keys
value = None
for candidate in candidates:
if isinstance(source, dict):
value = source.get(candidate)
else:
value = getattr(source, candidate, None)
if value not in (None, ""):
break
if value in (None, ""):
for candidate in candidates:
env_val = os.getenv(candidate)
if env_val not in (None, ""):
value = env_val
break
return value if value not in (None, "") else default
# 计算 .env 优先级:优先当前工作目录,其次项目根目录
PROJECT_ROOT: Path = Path(__file__).resolve().parents[2]
CWD_ENV: Path = Path.cwd() / ".env"
ENV_FILE: str = str(CWD_ENV if CWD_ENV.exists() else (PROJECT_ROOT / ".env"))
class Settings(BaseSettings):
"""
全局配置支持 .env 和环境变量自动加载
变量名与原 config.py 大写一致便于平滑过渡
"""
# ====================== 数据库配置 ======================
DB_HOST: str = Field("your_db_host", description="数据库主机,例如localhost 或 127.0.0.1。我们也提供云数据库资源便捷配置,日均10w+数据,可免费申请,联系我们:670939375@qq.com NOTE:为进行数据合规性审查与服务升级,云数据库自2025年10月1日起暂停接收新的使用申请")
DB_PORT: int = Field(3306, description="数据库端口号,默认为3306")
DB_USER: str = Field("your_db_user", description="数据库用户名")
DB_PASSWORD: str = Field("your_db_password", description="数据库密码")
DB_NAME: str = Field("your_db_name", description="数据库名称")
DB_CHARSET: str = Field("utf8mb4", description="数据库字符集,推荐utf8mb4,兼容emoji")
DB_DIALECT: str = Field("mysql", description="数据库类型,例如 'mysql''postgresql'。用于支持多种数据库后端(如 SQLAlchemy,请与连接信息共同配置)")
# ======================= LLM 相关 =======================
INSIGHT_ENGINE_API_KEY: str = Field(None, description="Insight Agent(推荐Kimihttps://platform.moonshot.cn/API密钥,用于主LLM。您可以更改每个部分LLM使用的API,🚩只要兼容OpenAI请求格式都可以,定义好KEY、BASE_URL与MODEL_NAME即可正常使用。重要提醒:我们强烈推荐您先使用推荐的配置申请API,先跑通再进行您的更改!")
INSIGHT_ENGINE_BASE_URL: Optional[str] = Field("https://api.moonshot.cn/v1", description="Insight Agent LLM接口BaseUrl,可自定义厂商API")
INSIGHT_ENGINE_MODEL_NAME: str = Field("kimi-k2-0711-preview", description="Insight Agent LLM模型名称,如kimi-k2-0711-preview")
MEDIA_ENGINE_API_KEY: str = Field(None, description="Media Agent(推荐Gemini,这里我用了一个中转厂商,你也可以换成你自己的,申请地址:https://www.chataiapi.com/API密钥")
MEDIA_ENGINE_BASE_URL: Optional[str] = Field("https://www.chataiapi.com/v1", description="Media Agent LLM接口BaseUrl")
MEDIA_ENGINE_MODEL_NAME: str = Field("gemini-2.5-pro", description="Media Agent LLM模型名称,如gemini-2.5-pro")
BOCHA_WEB_SEARCH_API_KEY: Optional[str] = Field(None, description="Bocha Web Search API Key")
BOCHA_API_KEY: Optional[str] = Field(None, description="Bocha 兼容键(别名)")
SEARCH_TIMEOUT: int = Field(240, description="搜索超时(秒)")
SEARCH_CONTENT_MAX_LENGTH: int = Field(20000, description="用于提示的最长内容长度")
MAX_REFLECTIONS: int = Field(2, description="最大反思轮数")
MAX_PARAGRAPHS: int = Field(5, description="最大段落数")
MINDSPIDER_API_KEY: Optional[str] = Field(None, description="MindSpider API密钥")
MINDSPIDER_BASE_URL: Optional[str] = Field("https://api.deepseek.com", description="MindSpider LLM接口BaseUrl")
MINDSPIDER_MODEL_NAME: str = Field("deepseek-reasoner", description="MindSpider LLM模型名称,如deepseek-reasoner")
OUTPUT_DIR: str = Field("reports", description="输出目录")
SAVE_INTERMEDIATE_STATES: bool = Field(True, description="是否保存中间状态")
QUERY_ENGINE_API_KEY: str = Field(None, description="Query Agent(推荐DeepSeekhttps://www.deepseek.com/API密钥")
QUERY_ENGINE_BASE_URL: Optional[str] = Field("https://api.deepseek.com", description="Query Agent LLM接口BaseUrl")
QUERY_ENGINE_MODEL_NAME: str = Field("deepseek-reasoner", description="Query Agent LLM模型,如deepseek-reasoner")
REPORT_ENGINE_API_KEY: str = Field(None, description="Report Agent(推荐Gemini,这里我用了一个中转厂商,你也可以换成你自己的,申请地址:https://www.chataiapi.com/API密钥")
REPORT_ENGINE_BASE_URL: Optional[str] = Field("https://www.chataiapi.com/v1", description="Report Agent LLM接口BaseUrl")
REPORT_ENGINE_MODEL_NAME: str = Field("gemini-2.5-pro", description="Report Agent LLM模型,如gemini-2.5-pro")
FORUM_HOST_API_KEY: str = Field(None, description="Forum Host(Qwen3最新模型,这里我使用了硅基流动这个平台,申请地址:https://cloud.siliconflow.cn/API密钥")
FORUM_HOST_BASE_URL: Optional[str] = Field("https://api.siliconflow.cn/v1", description="Forum Host LLM BaseUrl")
FORUM_HOST_MODEL_NAME: str = Field("Qwen/Qwen3-235B-A22B-Instruct-2507", description="Forum Host LLM模型名,如Qwen/Qwen3-235B-A22B-Instruct-2507")
KEYWORD_OPTIMIZER_API_KEY: str = Field(None, description="SQL keyword Optimizer(小参数Qwen3模型,这里我使用了硅基流动这个平台,申请地址:https://cloud.siliconflow.cn/API密钥")
KEYWORD_OPTIMIZER_BASE_URL: Optional[str] = Field("https://api.siliconflow.cn/v1", description="Keyword Optimizer BaseUrl")
KEYWORD_OPTIMIZER_MODEL_NAME: str = Field("Qwen/Qwen3-30B-A3B-Instruct-2507", description="Keyword Optimizer LLM模型名称,如Qwen/Qwen3-30B-A3B-Instruct-2507")
# ================== 网络工具配置 ====================
TAVILY_API_KEY: str = Field(None, description="Tavily API(申请地址:https://www.tavily.com/API密钥,用于Tavily网络搜索")
BOCHA_BASE_URL: Optional[str] = Field("https://api.bochaai.com/v1/ai-search", description="Bocha AI 搜索BaseUrl或博查网页搜索BaseUrl")
BOCHA_WEB_SEARCH_API_KEY: str = Field(None, description="Bocha API(申请地址:https://open.bochaai.com/API密钥,用于Bocha搜索")
class Config:
env_file = ENV_FILE
env_prefix = ""
case_sensitive = False
extra = "allow"
@dataclass
class Config:
"""Media Engine configuration."""
llm_api_key: Optional[str] = None
llm_base_url: Optional[str] = None
llm_model_name: Optional[str] = None
llm_provider: Optional[str] = None # compatibility
bocha_api_key: Optional[str] = None
search_timeout: int = 240
max_content_length: int = 20000
max_reflections: int = 2
max_paragraphs: int = 5
output_dir: str = "reports"
save_intermediate_states: bool = True
def __post_init__(self):
if not self.llm_provider and self.llm_model_name:
self.llm_provider = self.llm_model_name
def validate(self) -> bool:
if not self.llm_api_key:
print("错误: Media Engine LLM API Key 未设置 (MEDIA_ENGINE_API_KEY)。")
return False
if not self.llm_model_name:
print("错误: Media Engine 模型名称未设置 (MEDIA_ENGINE_MODEL_NAME)。")
return False
if not self.bocha_api_key:
print("错误: Bocha API Key 未设置 (BOCHA_WEB_SEARCH_API_KEY)。")
return False
return True
@classmethod
def from_file(cls, config_file: str) -> "Config":
if config_file.endswith(".py"):
import importlib.util
spec = importlib.util.spec_from_file_location("config", config_file)
config_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config_module)
return cls(
llm_api_key=_get_value(config_module, "MEDIA_ENGINE_API_KEY"),
llm_base_url=_get_value(config_module, "MEDIA_ENGINE_BASE_URL"),
llm_model_name=_get_value(config_module, "MEDIA_ENGINE_MODEL_NAME"),
bocha_api_key=_get_value(
config_module,
"BOCHA_WEB_SEARCH_API_KEY",
None,
"BOCHA_API_KEY",
),
search_timeout=int(_get_value(config_module, "SEARCH_TIMEOUT", 240)),
max_content_length=int(_get_value(config_module, "SEARCH_CONTENT_MAX_LENGTH", 20000)),
max_reflections=int(_get_value(config_module, "MAX_REFLECTIONS", 2)),
max_paragraphs=int(_get_value(config_module, "MAX_PARAGRAPHS", 5)),
output_dir=_get_value(config_module, "OUTPUT_DIR", "reports"),
save_intermediate_states=str(
_get_value(config_module, "SAVE_INTERMEDIATE_STATES", "true")
).lower()
in ("true", "1", "yes"),
)
config_dict = {}
if os.path.exists(config_file):
with open(config_file, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line and not line.startswith("#") and "=" in line:
key, value = line.split("=", 1)
config_dict[key.strip()] = value.strip()
return cls(
llm_api_key=_get_value(config_dict, "MEDIA_ENGINE_API_KEY"),
llm_base_url=_get_value(config_dict, "MEDIA_ENGINE_BASE_URL"),
llm_model_name=_get_value(config_dict, "MEDIA_ENGINE_MODEL_NAME"),
bocha_api_key=_get_value(
config_dict,
"BOCHA_WEB_SEARCH_API_KEY",
None,
"BOCHA_API_KEY",
),
search_timeout=int(_get_value(config_dict, "SEARCH_TIMEOUT", 240)),
max_content_length=int(_get_value(config_dict, "SEARCH_CONTENT_MAX_LENGTH", 20000)),
max_reflections=int(_get_value(config_dict, "MAX_REFLECTIONS", 2)),
max_paragraphs=int(_get_value(config_dict, "MAX_PARAGRAPHS", 5)),
output_dir=_get_value(config_dict, "OUTPUT_DIR", "reports"),
save_intermediate_states=str(
_get_value(config_dict, "SAVE_INTERMEDIATE_STATES", "true")
).lower()
in ("true", "1", "yes"),
)
def load_config(config_file: Optional[str] = None) -> Config:
if config_file:
if not os.path.exists(config_file):
raise FileNotFoundError(f"配置文件不存在: {config_file}")
file_to_load = config_file
else:
for candidate in ("config.py", "config.env", ".env"):
if os.path.exists(candidate):
file_to_load = candidate
print(f"已找到配置文件: {candidate}")
break
else:
raise FileNotFoundError("未找到配置文件,请创建 config.py。")
config = Config.from_file(file_to_load)
if not config.validate():
raise ValueError("配置校验失败,请检查 config.py 中的相关配置。")
return config
def print_config(config: Config):
print("\n=== Media Engine 配置 ===")
print(f"LLM 模型: {config.llm_model_name}")
print(f"LLM Base URL: {config.llm_base_url or '(默认)'}")
print(f"Bocha API Key: {'已配置' if config.bocha_api_key else '未配置'}")
print(f"搜索超时: {config.search_timeout}")
print(f"最长内容长度: {config.max_content_length}")
print(f"最大反思次数: {config.max_reflections}")
print(f"最大段落数: {config.max_paragraphs}")
print(f"输出目录: {config.output_dir}")
print(f"保存中间状态: {config.save_intermediate_states}")
print(f"LLM API Key: {'已配置' if config.llm_api_key else '未配置'}")
print("========================\n")
settings = Settings()