1. 统一为使用基于pydantic的.env环境变量管理配置

2. 全项目基于loguru进行日志管理
This commit is contained in:
Doiiars
2025-11-05 14:56:49 +08:00
parent 1d2e23d8c1
commit 537d682861
50 changed files with 1404 additions and 1731 deletions
+2 -2
View File
@@ -4,9 +4,9 @@ Deep Search Agent
"""
from .agent import DeepSearchAgent, create_agent
from .utils.config import Config, load_config
from .utils.config import settings, Settings
# Package metadata.
__version__ = "1.0.0"
__author__ = "Deep Search Agent Team"

# Public API of the package. Only one ``__all__`` assignment is kept: an
# earlier assignment listing ``Config`` / ``load_config`` was immediately
# overwritten by this one and advertised names that are no longer exported.
__all__ = ["DeepSearchAgent", "create_agent", "settings", "Settings"]
+114 -119
View File
@@ -8,6 +8,7 @@ import os
import re
from datetime import datetime
from typing import Optional, Dict, Any, List, Union
from loguru import logger
from .llms import LLMClient
from .nodes import (
@@ -20,32 +21,25 @@ from .nodes import (
)
from .state import State
from .tools import MediaCrawlerDB, DBResponse, keyword_optimizer, multilingual_sentiment_analyzer
from .utils import Config, load_config, format_search_results_for_prompt
from .utils.config import settings, Settings
from .utils import format_search_results_for_prompt
class DeepSearchAgent:
"""Deep Search Agent主类"""
def __init__(self, config: Optional[Config] = None):
def __init__(self, config: Optional[Settings] = None):
"""
初始化Deep Search Agent
Args:
config: 配置对象,如果不提供则自动加载
config: 可选配置对象(不填则用全局settings)
"""
# 加载配置
self.config = config or load_config()
self.config = config or settings
# 初始化LLM客户端
self.llm_client = self._initialize_llm()
# 设置数据库环境变量
os.environ["DB_HOST"] = self.config.db_host or ""
os.environ["DB_USER"] = self.config.db_user or ""
os.environ["DB_PASSWORD"] = self.config.db_password or ""
os.environ["DB_NAME"] = self.config.db_name or ""
os.environ["DB_PORT"] = str(self.config.db_port)
os.environ["DB_CHARSET"] = self.config.db_charset
# 初始化搜索工具集
self.search_agency = MediaCrawlerDB()
@@ -60,19 +54,19 @@ class DeepSearchAgent:
self.state = State()
# 确保输出目录存在
os.makedirs(self.config.output_dir, exist_ok=True)
os.makedirs(self.config.OUTPUT_DIR, exist_ok=True)
print(f"Insight Agent已初始化")
print(f"使用LLM: {self.llm_client.get_model_info()}")
print(f"搜索工具集: MediaCrawlerDB (支持5种本地数据库查询工具)")
print(f"情感分析: WeiboMultilingualSentiment (支持22种语言的情感分析)")
logger.info(f"Insight Agent已初始化")
logger.info(f"使用LLM: {self.llm_client.get_model_info()}")
logger.info(f"搜索工具集: MediaCrawlerDB (支持5种本地数据库查询工具)")
logger.info(f"情感分析: WeiboMultilingualSentiment (支持22种语言的情感分析)")
def _initialize_llm(self) -> LLMClient:
    """Build the LLM client used by this agent.

    Reads the ``INSIGHT_ENGINE_*`` fields from ``self.config`` and passes
    them to ``LLMClient``.

    Returns:
        The ``LLMClient`` created from the configured API key, model name
        and base URL.
    """
    # Each keyword may appear only once in a call. The lowercase ``llm_*``
    # attributes were previously passed alongside these fields, repeating
    # every keyword (a SyntaxError); only the ``INSIGHT_ENGINE_*`` fields
    # are kept.
    return LLMClient(
        api_key=self.config.INSIGHT_ENGINE_API_KEY,
        model_name=self.config.INSIGHT_ENGINE_MODEL_NAME,
        base_url=self.config.INSIGHT_ENGINE_BASE_URL,
    )
def _initialize_nodes(self):
@@ -127,7 +121,7 @@ class DeepSearchAgent:
Returns:
DBResponse对象(可能包含情感分析结果)
"""
print(f" → 执行数据库查询工具: {tool_name}")
logger.info(f" → 执行数据库查询工具: {tool_name}")
# 对于热点内容搜索,不需要关键词优化(因为不需要query参数)
if tool_name == "search_hot_content":
@@ -138,12 +132,12 @@ class DeepSearchAgent:
# 检查是否需要进行情感分析
enable_sentiment = kwargs.get("enable_sentiment", True)
if enable_sentiment and response.results and len(response.results) > 0:
print(f" 🎭 开始对热点内容进行情感分析...")
logger.info(f" 🎭 开始对热点内容进行情感分析...")
sentiment_analysis = self._perform_sentiment_analysis(response.results)
if sentiment_analysis:
# 将情感分析结果添加到响应的parameters中
response.parameters["sentiment_analysis"] = sentiment_analysis
print(f" ✅ 情感分析完成")
logger.info(f" ✅ 情感分析完成")
return response
@@ -170,32 +164,32 @@ class DeepSearchAgent:
context=f"使用{tool_name}工具进行查询"
)
print(f" 🔍 原始查询: '{query}'")
print(f" ✨ 优化后关键词: {optimized_response.optimized_keywords}")
logger.info(f" 🔍 原始查询: '{query}'")
logger.info(f" ✨ 优化后关键词: {optimized_response.optimized_keywords}")
# 使用优化后的关键词进行多次查询并整合结果
all_results = []
total_count = 0
for keyword in optimized_response.optimized_keywords:
print(f" 查询关键词: '{keyword}'")
logger.info(f" 查询关键词: '{keyword}'")
try:
if tool_name == "search_topic_globally":
# 使用配置文件中的默认值,忽略agent提供的limit_per_table参数
limit_per_table = self.config.default_search_topic_globally_limit_per_table
limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE
response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=limit_per_table)
elif tool_name == "search_topic_by_date":
start_date = kwargs.get("start_date")
end_date = kwargs.get("end_date")
# 使用配置文件中的默认值,忽略agent提供的limit_per_table参数
limit_per_table = self.config.default_search_topic_by_date_limit_per_table
limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE
if not start_date or not end_date:
raise ValueError("search_topic_by_date工具需要start_date和end_date参数")
response = self.search_agency.search_topic_by_date(topic=keyword, start_date=start_date, end_date=end_date, limit_per_table=limit_per_table)
elif tool_name == "get_comments_for_topic":
# 使用配置文件中的默认值,按关键词数量分配,但保证最小值
limit = self.config.default_get_comments_for_topic_limit // len(optimized_response.optimized_keywords)
limit = self.config.DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT // len(optimized_response.optimized_keywords)
limit = max(limit, 50)
response = self.search_agency.get_comments_for_topic(topic=keyword, limit=limit)
elif tool_name == "search_topic_on_platform":
@@ -203,30 +197,30 @@ class DeepSearchAgent:
start_date = kwargs.get("start_date")
end_date = kwargs.get("end_date")
# 使用配置文件中的默认值,按关键词数量分配,但保证最小值
limit = self.config.default_search_topic_on_platform_limit // len(optimized_response.optimized_keywords)
limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT // len(optimized_response.optimized_keywords)
limit = max(limit, 30)
if not platform:
raise ValueError("search_topic_on_platform工具需要platform参数")
response = self.search_agency.search_topic_on_platform(platform=platform, topic=keyword, start_date=start_date, end_date=end_date, limit=limit)
else:
print(f" 未知的搜索工具: {tool_name},使用默认全局搜索")
response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=self.config.default_search_topic_globally_limit_per_table)
logger.info(f" 未知的搜索工具: {tool_name},使用默认全局搜索")
response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE)
# 收集结果
if response.results:
print(f" 找到 {len(response.results)} 条结果")
logger.info(f" 找到 {len(response.results)} 条结果")
all_results.extend(response.results)
total_count += len(response.results)
else:
print(f" 未找到结果")
logger.info(f" 未找到结果")
except Exception as e:
print(f" 查询'{keyword}'时出错: {str(e)}")
logger.error(f" 查询'{keyword}'时出错: {str(e)}")
continue
# 去重和整合结果
unique_results = self._deduplicate_results(all_results)
print(f" 总计找到 {total_count} 条结果,去重后 {len(unique_results)}")
logger.info(f" 总计找到 {total_count} 条结果,去重后 {len(unique_results)}")
# 构建整合后的响应
integrated_response = DBResponse(
@@ -244,12 +238,12 @@ class DeepSearchAgent:
# 检查是否需要进行情感分析
enable_sentiment = kwargs.get("enable_sentiment", True)
if enable_sentiment and unique_results and len(unique_results) > 0:
print(f" 🎭 开始对搜索结果进行情感分析...")
logger.info(f" 🎭 开始对搜索结果进行情感分析...")
sentiment_analysis = self._perform_sentiment_analysis(unique_results)
if sentiment_analysis:
# 将情感分析结果添加到响应的parameters中
integrated_response.parameters["sentiment_analysis"] = sentiment_analysis
print(f" ✅ 情感分析完成")
logger.info(f" ✅ 情感分析完成")
return integrated_response
@@ -282,11 +276,11 @@ class DeepSearchAgent:
try:
# 初始化情感分析器(如果尚未初始化且未被禁用)
if not self.sentiment_analyzer.is_initialized and not self.sentiment_analyzer.is_disabled:
print(" 初始化情感分析模型...")
logger.info(" 初始化情感分析模型...")
if not self.sentiment_analyzer.initialize():
print(" 情感分析模型初始化失败,将直接透传原始文本")
logger.info(" 情感分析模型初始化失败,将直接透传原始文本")
elif self.sentiment_analyzer.is_disabled:
print(" 情感分析功能已禁用,直接透传原始文本")
logger.info(" 情感分析功能已禁用,直接透传原始文本")
# 将查询结果转换为字典格式
results_dict = []
@@ -310,7 +304,7 @@ class DeepSearchAgent:
return sentiment_analysis.get("sentiment_analysis")
except Exception as e:
print(f" ❌ 情感分析过程中发生错误: {str(e)}")
logger.exception(f" ❌ 情感分析过程中发生错误: {str(e)}")
return None
def analyze_sentiment_only(self, texts: Union[str, List[str]]) -> Dict[str, Any]:
@@ -323,16 +317,16 @@ class DeepSearchAgent:
Returns:
情感分析结果
"""
print(f" → 执行独立情感分析")
logger.info(f" → 执行独立情感分析")
try:
# 初始化情感分析器(如果尚未初始化且未被禁用)
if not self.sentiment_analyzer.is_initialized and not self.sentiment_analyzer.is_disabled:
print(" 初始化情感分析模型...")
logger.info(" 初始化情感分析模型...")
if not self.sentiment_analyzer.initialize():
print(" 情感分析模型初始化失败,将直接透传原始文本")
logger.info(" 情感分析模型初始化失败,将直接透传原始文本")
elif self.sentiment_analyzer.is_disabled:
print(" 情感分析功能已禁用,直接透传原始文本")
logger.warning(" 情感分析功能已禁用,直接透传原始文本")
# 执行分析
if isinstance(texts, str):
@@ -368,7 +362,7 @@ class DeepSearchAgent:
return response
except Exception as e:
print(f" ❌ 情感分析过程中发生错误: {str(e)}")
logger.exception(f" ❌ 情感分析过程中发生错误: {str(e)}")
return {
"success": False,
"error": str(e),
@@ -386,9 +380,9 @@ class DeepSearchAgent:
Returns:
最终报告内容
"""
print(f"\n{'='*60}")
print(f"开始深度研究: {query}")
print(f"{'='*60}")
logger.info(f"\n{'='*60}")
logger.info(f"开始深度研究: {query}")
logger.info(f"{'='*60}")
try:
# Step 1: 生成报告结构
@@ -403,20 +397,18 @@ class DeepSearchAgent:
# Step 4: 保存报告
if save_report:
self._save_report(final_report)
print(f"\n{'='*60}")
print("深度研究完成!")
print(f"{'='*60}")
logger.info("深度研究完成!")
return final_report
except Exception as e:
print(f"研究过程中发生错误: {str(e)}")
logger.exception(f"研究过程中发生错误: {str(e)}")
raise e
def _generate_report_structure(self, query: str):
"""生成报告结构"""
print(f"\n[步骤 1] 生成报告结构...")
logger.info(f"\n[步骤 1] 生成报告结构...")
# 创建报告结构节点
report_structure_node = ReportStructureNode(self.llm_client, query)
@@ -424,17 +416,18 @@ class DeepSearchAgent:
# 生成结构并更新状态
self.state = report_structure_node.mutate_state(state=self.state)
print(f"报告结构已生成,共 {len(self.state.paragraphs)} 个段落:")
_message = f"报告结构已生成,共 {len(self.state.paragraphs)} 个段落:"
for i, paragraph in enumerate(self.state.paragraphs, 1):
print(f" {i}. {paragraph.title}")
_message += f"\n {i}. {paragraph.title}"
logger.info(_message)
def _process_paragraphs(self):
"""处理所有段落"""
total_paragraphs = len(self.state.paragraphs)
for i in range(total_paragraphs):
print(f"\n[步骤 2.{i+1}] 处理段落: {self.state.paragraphs[i].title}")
print("-" * 50)
logger.info(f"\n[步骤 2.{i+1}] 处理段落: {self.state.paragraphs[i].title}")
logger.info("-" * 50)
# 初始搜索和总结
self._initial_search_and_summary(i)
@@ -446,7 +439,7 @@ class DeepSearchAgent:
self.state.paragraphs[i].research.mark_completed()
progress = (i + 1) / total_paragraphs * 100
print(f"段落处理完成 ({progress:.1f}%)")
logger.info(f"段落处理完成 ({progress:.1f}%)")
def _initial_search_and_summary(self, paragraph_index: int):
"""执行初始搜索和总结"""
@@ -459,18 +452,18 @@ class DeepSearchAgent:
}
# 生成搜索查询和工具选择
print(" - 生成搜索查询...")
logger.info(" - 生成搜索查询...")
search_output = self.first_search_node.run(search_input)
search_query = search_output["search_query"]
search_tool = search_output.get("search_tool", "search_topic_globally") # 默认工具
reasoning = search_output["reasoning"]
print(f" - 搜索查询: {search_query}")
print(f" - 选择的工具: {search_tool}")
print(f" - 推理: {reasoning}")
logger.info(f" - 搜索查询: {search_query}")
logger.info(f" - 选择的工具: {search_tool}")
logger.info(f" - 推理: {reasoning}")
# 执行搜索
print(" - 执行数据库查询...")
logger.info(" - 执行数据库查询...")
# 处理特殊参数
search_kwargs = {}
@@ -485,13 +478,13 @@ class DeepSearchAgent:
if self._validate_date_format(start_date) and self._validate_date_format(end_date):
search_kwargs["start_date"] = start_date
search_kwargs["end_date"] = end_date
print(f" - 时间范围: {start_date}{end_date}")
logger.info(f" - 时间范围: {start_date}{end_date}")
else:
print(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索")
print(f" 提供的日期: start_date={start_date}, end_date={end_date}")
logger.info(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索")
logger.info(f" 提供的日期: start_date={start_date}, end_date={end_date}")
search_tool = "search_topic_globally"
elif search_tool == "search_topic_by_date":
print(f" search_topic_by_date工具缺少时间参数,改用全局搜索")
logger.info(f" search_topic_by_date工具缺少时间参数,改用全局搜索")
search_tool = "search_topic_globally"
# 处理需要平台参数的工具
@@ -499,28 +492,28 @@ class DeepSearchAgent:
platform = search_output.get("platform")
if platform:
search_kwargs["platform"] = platform
print(f" - 指定平台: {platform}")
logger.info(f" - 指定平台: {platform}")
else:
print(f" search_topic_on_platform工具缺少平台参数,改用全局搜索")
logger.warning(f" search_topic_on_platform工具缺少平台参数,改用全局搜索")
search_tool = "search_topic_globally"
# 处理限制参数,使用配置文件中的默认值而不是agent提供的参数
if search_tool == "search_hot_content":
time_period = search_output.get("time_period", "week")
limit = self.config.default_search_hot_content_limit
limit = self.config.DEFAULT_SEARCH_HOT_CONTENT_LIMIT
search_kwargs["time_period"] = time_period
search_kwargs["limit"] = limit
elif search_tool in ["search_topic_globally", "search_topic_by_date"]:
if search_tool == "search_topic_globally":
limit_per_table = self.config.default_search_topic_globally_limit_per_table
limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE
else: # search_topic_by_date
limit_per_table = self.config.default_search_topic_by_date_limit_per_table
limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE
search_kwargs["limit_per_table"] = limit_per_table
elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]:
if search_tool == "get_comments_for_topic":
limit = self.config.default_get_comments_for_topic_limit
limit = self.config.DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT
else: # search_topic_on_platform
limit = self.config.default_search_topic_on_platform_limit
limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT
search_kwargs["limit"] = limit
search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs)
@@ -529,8 +522,8 @@ class DeepSearchAgent:
search_results = []
if search_response and search_response.results:
# 使用配置文件控制传递给LLM的结果数量,0表示不限制
if self.config.max_search_results_for_llm > 0:
max_results = min(len(search_response.results), self.config.max_search_results_for_llm)
if self.config.MAX_SEARCH_RESULTS_FOR_LLM > 0:
max_results = min(len(search_response.results), self.config.MAX_SEARCH_RESULTS_FOR_LLM)
else:
max_results = len(search_response.results) # 不限制,传递所有结果
for result in search_response.results[:max_results]:
@@ -548,24 +541,25 @@ class DeepSearchAgent:
})
if search_results:
print(f" - 找到 {len(search_results)} 个搜索结果")
_message = f" - 找到 {len(search_results)} 个搜索结果"
for j, result in enumerate(search_results, 1):
date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else ""
print(f" {j}. {result['title'][:50]}...{date_info}")
_message += f"\n {j}. {result['title'][:50]}...{date_info}"
logger.info(_message)
else:
print(" - 未找到搜索结果")
logger.info(" - 未找到搜索结果")
# 更新状态中的搜索历史
paragraph.research.add_search_results(search_query, search_results)
# 生成初始总结
print(" - 生成初始总结...")
logger.info(" - 生成初始总结...")
summary_input = {
"title": paragraph.title,
"content": paragraph.content,
"search_query": search_query,
"search_results": format_search_results_for_prompt(
search_results, self.config.max_content_length
search_results, self.config.MAX_CONTENT_LENGTH
)
}
@@ -574,14 +568,14 @@ class DeepSearchAgent:
summary_input, self.state, paragraph_index
)
print(" - 初始总结完成")
logger.info(" - 初始总结完成")
def _reflection_loop(self, paragraph_index: int):
"""执行反思循环"""
paragraph = self.state.paragraphs[paragraph_index]
for reflection_i in range(self.config.max_reflections):
print(f" - 反思 {reflection_i + 1}/{self.config.max_reflections}...")
for reflection_i in range(self.config.MAX_REFLECTIONS):
logger.info(f" - 反思 {reflection_i + 1}/{self.config.MAX_REFLECTIONS}...")
# 准备反思输入
reflection_input = {
@@ -596,9 +590,9 @@ class DeepSearchAgent:
search_tool = reflection_output.get("search_tool", "search_topic_globally") # 默认工具
reasoning = reflection_output["reasoning"]
print(f" 反思查询: {search_query}")
print(f" 选择的工具: {search_tool}")
print(f" 反思推理: {reasoning}")
logger.info(f" 反思查询: {search_query}")
logger.info(f" 选择的工具: {search_tool}")
logger.info(f" 反思推理: {reasoning}")
# 执行反思搜索
# 处理特殊参数
@@ -614,13 +608,13 @@ class DeepSearchAgent:
if self._validate_date_format(start_date) and self._validate_date_format(end_date):
search_kwargs["start_date"] = start_date
search_kwargs["end_date"] = end_date
print(f" 时间范围: {start_date}{end_date}")
logger.info(f" 时间范围: {start_date}{end_date}")
else:
print(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索")
print(f" 提供的日期: start_date={start_date}, end_date={end_date}")
logger.info(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索")
logger.info(f" 提供的日期: start_date={start_date}, end_date={end_date}")
search_tool = "search_topic_globally"
elif search_tool == "search_topic_by_date":
print(f" search_topic_by_date工具缺少时间参数,改用全局搜索")
logger.warning(f" search_topic_by_date工具缺少时间参数,改用全局搜索")
search_tool = "search_topic_globally"
# 处理需要平台参数的工具
@@ -628,31 +622,31 @@ class DeepSearchAgent:
platform = reflection_output.get("platform")
if platform:
search_kwargs["platform"] = platform
print(f" 指定平台: {platform}")
logger.info(f" 指定平台: {platform}")
else:
print(f" search_topic_on_platform工具缺少平台参数,改用全局搜索")
logger.warning(f" search_topic_on_platform工具缺少平台参数,改用全局搜索")
search_tool = "search_topic_globally"
# 处理限制参数
if search_tool == "search_hot_content":
time_period = reflection_output.get("time_period", "week")
# 使用配置文件中的默认值,不允许agent控制limit参数
limit = self.config.default_search_hot_content_limit
limit = self.config.DEFAULT_SEARCH_HOT_CONTENT_LIMIT
search_kwargs["time_period"] = time_period
search_kwargs["limit"] = limit
elif search_tool in ["search_topic_globally", "search_topic_by_date"]:
# 使用配置文件中的默认值,不允许agent控制limit_per_table参数
if search_tool == "search_topic_globally":
limit_per_table = self.config.default_search_topic_globally_limit_per_table
limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE
else: # search_topic_by_date
limit_per_table = self.config.default_search_topic_by_date_limit_per_table
limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE
search_kwargs["limit_per_table"] = limit_per_table
elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]:
# 使用配置文件中的默认值,不允许agent控制limit参数
if search_tool == "get_comments_for_topic":
limit = self.config.default_get_comments_for_topic_limit
limit = self.config.DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT
else: # search_topic_on_platform
limit = self.config.default_search_topic_on_platform_limit
limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT
search_kwargs["limit"] = limit
search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs)
@@ -661,8 +655,8 @@ class DeepSearchAgent:
search_results = []
if search_response and search_response.results:
# 使用配置文件控制传递给LLM的结果数量,0表示不限制
if self.config.max_search_results_for_llm > 0:
max_results = min(len(search_response.results), self.config.max_search_results_for_llm)
if self.config.MAX_SEARCH_RESULTS_FOR_LLM > 0:
max_results = min(len(search_response.results), self.config.MAX_SEARCH_RESULTS_FOR_LLM)
else:
max_results = len(search_response.results) # 不限制,传递所有结果
for result in search_response.results[:max_results]:
@@ -680,12 +674,13 @@ class DeepSearchAgent:
})
if search_results:
print(f" 找到 {len(search_results)} 个反思搜索结果")
_message = f" 找到 {len(search_results)} 个反思搜索结果"
for j, result in enumerate(search_results, 1):
date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else ""
print(f" {j}. {result['title'][:50]}...{date_info}")
_message += f"\n {j}. {result['title'][:50]}...{date_info}"
logger.info(_message)
else:
print(" 未找到反思搜索结果")
logger.info(" 未找到反思搜索结果")
# 更新搜索历史
paragraph.research.add_search_results(search_query, search_results)
@@ -696,7 +691,7 @@ class DeepSearchAgent:
"content": paragraph.content,
"search_query": search_query,
"search_results": format_search_results_for_prompt(
search_results, self.config.max_content_length
search_results, self.config.MAX_CONTENT_LENGTH
),
"paragraph_latest_state": paragraph.research.latest_summary
}
@@ -706,11 +701,11 @@ class DeepSearchAgent:
reflection_summary_input, self.state, paragraph_index
)
print(f" 反思 {reflection_i + 1} 完成")
logger.info(f" 反思 {reflection_i + 1} 完成")
def _generate_final_report(self) -> str:
"""生成最终报告"""
print(f"\n[步骤 3] 生成最终报告...")
logger.info(f"\n[步骤 3] 生成最终报告...")
# 准备报告数据
report_data = []
@@ -724,7 +719,7 @@ class DeepSearchAgent:
try:
final_report = self.report_formatting_node.run(report_data)
except Exception as e:
print(f"LLM格式化失败,使用备用方法: {str(e)}")
logger.exception(f"LLM格式化失败,使用备用方法: {str(e)}")
final_report = self.report_formatting_node.format_report_manually(
report_data, self.state.report_title
)
@@ -733,7 +728,7 @@ class DeepSearchAgent:
self.state.final_report = final_report
self.state.mark_completed()
print("最终报告生成完成")
logger.info("最终报告生成完成")
return final_report
def _save_report(self, report_content: str):
@@ -744,20 +739,20 @@ class DeepSearchAgent:
query_safe = query_safe.replace(' ', '_')[:30]
filename = f"deep_search_report_{query_safe}_{timestamp}.md"
filepath = os.path.join(self.config.output_dir, filename)
filepath = os.path.join(self.config.OUTPUT_DIR, filename)
# 保存报告
with open(filepath, 'w', encoding='utf-8') as f:
f.write(report_content)
print(f"报告已保存到: {filepath}")
logger.info(f"报告已保存到: {filepath}")
# 保存状态(如果配置允许)
if self.config.save_intermediate_states:
if self.config.SAVE_INTERMEDIATE_STATES:
state_filename = f"state_{query_safe}_{timestamp}.json"
state_filepath = os.path.join(self.config.output_dir, state_filename)
state_filepath = os.path.join(self.config.OUTPUT_DIR, state_filename)
self.state.save_to_file(state_filepath)
print(f"状态已保存到: {state_filepath}")
logger.info(f"状态已保存到: {state_filepath}")
def get_progress_summary(self) -> Dict[str, Any]:
"""获取进度摘要"""
@@ -766,12 +761,12 @@ class DeepSearchAgent:
def load_state(self, filepath: str):
    """Load the agent state from a file.

    Replaces ``self.state`` with the object returned by
    ``State.load_from_file`` and logs the source path.

    Args:
        filepath: Path of the state file to load.
    """
    self.state = State.load_from_file(filepath)
    # Report once through loguru; emitting the same text with print() as
    # well would duplicate the message on the console.
    logger.info(f"状态已从 {filepath} 加载")
def save_state(self, filepath: str):
    """Save the current agent state to a file.

    Delegates to ``self.state.save_to_file`` and logs the destination path.

    Args:
        filepath: Path the state is written to.
    """
    self.state.save_to_file(filepath)
    # Report once through loguru; emitting the same text with print() as
    # well would duplicate the message on the console.
    logger.info(f"状态已保存到 {filepath}")
def create_agent(config_file: Optional[str] = None) -> DeepSearchAgent:
@@ -784,5 +779,5 @@ def create_agent(config_file: Optional[str] = None) -> DeepSearchAgent:
Returns:
DeepSearchAgent实例
"""
config = load_config(config_file)
config = settings
return DeepSearchAgent(config)
+2 -2
View File
@@ -31,9 +31,9 @@ class LLMClient:
def __init__(self, api_key: str, model_name: str, base_url: Optional[str] = None):
if not api_key:
raise ValueError("Insight Engine LLM API key is required.")
raise ValueError("Insight Engine INSIGHT_ENGINE_API_KEY is required.")
if not model_name:
raise ValueError("Insight Engine model name is required.")
raise ValueError("Insight Engine INSIGHT_ENGINE_MODEL_NAME is required.")
self.api_key = api_key
self.base_url = base_url
+7 -2
View File
@@ -5,6 +5,7 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from loguru import logger
from ..llms.base import LLMClient
from ..state.state import State
@@ -63,11 +64,15 @@ class BaseNode(ABC):
def log_info(self, message: str):
    """Log an informational message prefixed with this node's name.

    Args:
        message: Text to log.
    """
    # Emit through loguru only, like log_warning; keeping a print() next to
    # it would output every message twice.
    logger.info(f"[{self.node_name}] {message}")
def log_warning(self, message: str):
    """Emit ``message`` at WARNING level, tagged with this node's name.

    Args:
        message: Text to log.
    """
    tagged = f"[{self.node_name}] 警告: {message}"
    logger.warning(tagged)
def log_error(self, message: str):
    """Log an error message prefixed with this node's name.

    Args:
        message: Text to log.
    """
    # Emit through loguru only, like log_warning; keeping a print() next to
    # it would output every message twice.
    logger.error(f"[{self.node_name}] 错误: {message}")
class StateMutationNode(BaseNode):
+13 -7
View File
@@ -5,6 +5,7 @@
import json
from typing import List, Dict, Any
from loguru import logger
from .base_node import BaseNode
from ..prompts import SYSTEM_PROMPT_REPORT_FORMATTING
@@ -14,6 +15,8 @@ from ..utils.text_processing import (
)
class ReportFormattingNode(BaseNode):
"""格式化最终报告的节点"""
@@ -65,19 +68,22 @@ class ReportFormattingNode(BaseNode):
else:
message = json.dumps(input_data, ensure_ascii=False)
self.log_info("正在格式化最终报告")
logger.info("正在格式化最终报告")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_FORMATTING, message)
response = self.llm_client.invoke(
SYSTEM_PROMPT_REPORT_FORMATTING,
message,
)
# 处理响应
processed_response = self.process_output(response)
self.log_info("成功生成格式化报告")
logger.info("成功生成格式化报告")
return processed_response
except Exception as e:
self.log_error(f"报告格式化失败: {str(e)}")
logger.exception(f"报告格式化失败: {str(e)}")
raise e
def process_output(self, output: str) -> str:
@@ -106,7 +112,7 @@ class ReportFormattingNode(BaseNode):
return cleaned_output.strip()
except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
logger.exception(f"处理输出失败: {str(e)}")
return "# 报告处理失败\n\n报告格式化过程中发生错误。"
def format_report_manually(self, paragraphs_data: List[Dict[str, str]],
@@ -122,7 +128,7 @@ class ReportFormattingNode(BaseNode):
格式化的Markdown报告
"""
try:
self.log_info("使用手动格式化方法")
logger.info("使用手动格式化方法")
# 构建报告
report_lines = [
@@ -160,5 +166,5 @@ class ReportFormattingNode(BaseNode):
return "\n".join(report_lines)
except Exception as e:
self.log_error(f"手动格式化失败: {str(e)}")
logger.exception(f"手动格式化失败: {str(e)}")
return "# 报告生成失败\n\n无法完成报告格式化。"
+21 -20
View File
@@ -6,6 +6,7 @@
import json
from typing import Dict, Any, List
from json.decoder import JSONDecodeError
from loguru import logger
from .base_node import StateMutationNode
from ..state.state import State
@@ -48,7 +49,7 @@ class ReportStructureNode(StateMutationNode):
报告结构列表
"""
try:
self.log_info(f"正在为查询生成报告结构: {self.query}")
logger.info(f"正在为查询生成报告结构: {self.query}")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query)
@@ -56,11 +57,11 @@ class ReportStructureNode(StateMutationNode):
# 处理响应
processed_response = self.process_output(response)
self.log_info(f"成功生成 {len(processed_response)} 个段落结构")
logger.info(f"成功生成 {len(processed_response)} 个段落结构")
return processed_response
except Exception as e:
self.log_error(f"生成报告结构失败: {str(e)}")
logger.exception(f"生成报告结构失败: {str(e)}")
raise e
def process_output(self, output: str) -> List[Dict[str, str]]:
@@ -79,54 +80,54 @@ class ReportStructureNode(StateMutationNode):
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output}")
logger.info(f"清理后的输出: {cleaned_output}")
# 解析JSON
try:
report_structure = json.loads(cleaned_output)
self.log_info("JSON解析成功")
logger.info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
logger.exception(f"JSON解析失败: {str(e)}")
# 使用更强大的提取方法
report_structure = extract_clean_response(cleaned_output)
if "error" in report_structure:
self.log_error("JSON解析失败,尝试修复...")
logger.exception("JSON解析失败,尝试修复...")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
report_structure = json.loads(fixed_json)
self.log_info("JSON修复成功")
logger.info("JSON修复成功")
except JSONDecodeError:
self.log_error("JSON修复失败")
logger.exception("JSON修复失败")
# 返回默认结构
return self._generate_default_structure()
else:
self.log_error("无法修复JSON,使用默认结构")
logger.exception("无法修复JSON,使用默认结构")
return self._generate_default_structure()
# 验证结构
if not isinstance(report_structure, list):
self.log_info("报告结构不是列表,尝试转换...")
logger.info("报告结构不是列表,尝试转换...")
if isinstance(report_structure, dict):
# 如果是单个对象,包装成列表
report_structure = [report_structure]
else:
self.log_error("报告结构格式无效,使用默认结构")
logger.exception("报告结构格式无效,使用默认结构")
return self._generate_default_structure()
# 验证每个段落
validated_structure = []
for i, paragraph in enumerate(report_structure):
if not isinstance(paragraph, dict):
self.log_warning(f"段落 {i+1} 不是字典格式,跳过")
logger.warning(f"段落 {i+1} 不是字典格式,跳过")
continue
title = paragraph.get("title", f"段落 {i+1}")
content = paragraph.get("content", "")
if not title or not content:
self.log_warning(f"段落 {i+1} 缺少标题或内容,跳过")
logger.warning(f"段落 {i+1} 缺少标题或内容,跳过")
continue
validated_structure.append({
@@ -135,14 +136,14 @@ class ReportStructureNode(StateMutationNode):
})
if not validated_structure:
self.log_warning("没有有效的段落结构,使用默认结构")
logger.warning("没有有效的段落结构,使用默认结构")
return self._generate_default_structure()
self.log_info(f"成功验证 {len(validated_structure)} 个段落结构")
logger.info(f"成功验证 {len(validated_structure)} 个段落结构")
return validated_structure
except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
logger.exception(f"处理输出失败: {str(e)}")
return self._generate_default_structure()
def _generate_default_structure(self) -> List[Dict[str, str]]:
@@ -152,7 +153,7 @@ class ReportStructureNode(StateMutationNode):
Returns:
默认的报告结构列表
"""
self.log_info("生成默认报告结构")
logger.info("生成默认报告结构")
return [
{
"title": "研究概述",
@@ -195,9 +196,9 @@ class ReportStructureNode(StateMutationNode):
content=paragraph_data["content"]
)
self.log_info(f"已将 {len(report_structure)} 个段落添加到状态中")
logger.info(f"已将 {len(report_structure)} 个段落添加到状态中")
return state
except Exception as e:
self.log_error(f"状态更新失败: {str(e)}")
logger.exception(f"状态更新失败: {str(e)}")
raise e
+24 -23
View File
@@ -6,6 +6,7 @@
import json
from typing import Dict, Any
from json.decoder import JSONDecodeError
from loguru import logger
from .base_node import BaseNode
from ..prompts import SYSTEM_PROMPT_FIRST_SEARCH, SYSTEM_PROMPT_REFLECTION
@@ -62,7 +63,7 @@ class FirstSearchNode(BaseNode):
else:
message = json.dumps(input_data, ensure_ascii=False)
self.log_info("正在生成首次搜索查询")
logger.info("正在生成首次搜索查询")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SEARCH, message)
@@ -70,11 +71,11 @@ class FirstSearchNode(BaseNode):
# 处理响应
processed_response = self.process_output(response)
self.log_info(f"生成搜索查询: {processed_response.get('search_query', 'N/A')}")
logger.info(f"生成搜索查询: {processed_response.get('search_query', 'N/A')}")
return processed_response
except Exception as e:
self.log_error(f"生成首次搜索查询失败: {str(e)}")
logger.exception(f"生成首次搜索查询失败: {str(e)}")
raise e
def process_output(self, output: str) -> Dict[str, str]:
@@ -93,30 +94,30 @@ class FirstSearchNode(BaseNode):
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output}")
logger.info(f"清理后的输出: {cleaned_output}")
# 解析JSON
try:
result = json.loads(cleaned_output)
self.log_info("JSON解析成功")
logger.info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
logger.exception(f"JSON解析失败: {str(e)}")
# 使用更强大的提取方法
result = extract_clean_response(cleaned_output)
if "error" in result:
self.log_error("JSON解析失败,尝试修复...")
logger.error("JSON解析失败,尝试修复...")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
result = json.loads(fixed_json)
self.log_info("JSON修复成功")
logger.info("JSON修复成功")
except JSONDecodeError:
self.log_error("JSON修复失败")
logger.error("JSON修复失败")
# 返回默认查询
return self._get_default_search_query()
else:
self.log_error("无法修复JSON,使用默认查询")
logger.error("无法修复JSON,使用默认查询")
return self._get_default_search_query()
# 验证和清理结果
@@ -124,7 +125,7 @@ class FirstSearchNode(BaseNode):
reasoning = result.get("reasoning", "")
if not search_query:
self.log_warning("未找到搜索查询,使用默认查询")
logger.warning("未找到搜索查询,使用默认查询")
return self._get_default_search_query()
return {
@@ -197,7 +198,7 @@ class ReflectionNode(BaseNode):
else:
message = json.dumps(input_data, ensure_ascii=False)
self.log_info("正在进行反思并生成新搜索查询")
logger.info("正在进行反思并生成新搜索查询")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION, message)
@@ -205,11 +206,11 @@ class ReflectionNode(BaseNode):
# 处理响应
processed_response = self.process_output(response)
self.log_info(f"反思生成搜索查询: {processed_response.get('search_query', 'N/A')}")
logger.info(f"反思生成搜索查询: {processed_response.get('search_query', 'N/A')}")
return processed_response
except Exception as e:
self.log_error(f"反思生成搜索查询失败: {str(e)}")
logger.exception(f"反思生成搜索查询失败: {str(e)}")
raise e
def process_output(self, output: str) -> Dict[str, str]:
@@ -228,30 +229,30 @@ class ReflectionNode(BaseNode):
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output}")
logger.info(f"清理后的输出: {cleaned_output}")
# 解析JSON
try:
result = json.loads(cleaned_output)
self.log_info("JSON解析成功")
logger.info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
logger.exception(f"JSON解析失败: {str(e)}")
# 使用更强大的提取方法
result = extract_clean_response(cleaned_output)
if "error" in result:
self.log_error("JSON解析失败,尝试修复...")
logger.error("JSON解析失败,尝试修复...")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
result = json.loads(fixed_json)
self.log_info("JSON修复成功")
logger.info("JSON修复成功")
except JSONDecodeError:
self.log_error("JSON修复失败")
logger.error("JSON修复失败")
# 返回默认查询
return self._get_default_reflection_query()
else:
self.log_error("无法修复JSON,使用默认查询")
logger.error("无法修复JSON,使用默认查询")
return self._get_default_reflection_query()
# 验证和清理结果
@@ -259,7 +260,7 @@ class ReflectionNode(BaseNode):
reasoning = result.get("reasoning", "")
if not search_query:
self.log_warning("未找到搜索查询,使用默认查询")
logger.warning("未找到搜索查询,使用默认查询")
return self._get_default_reflection_query()
return {
@@ -268,7 +269,7 @@ class ReflectionNode(BaseNode):
}
except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
logger.exception(f"处理输出失败: {str(e)}")
# 返回默认查询
return self._get_default_reflection_query()
+30 -29
View File
@@ -6,6 +6,7 @@
import json
from typing import Dict, Any, List
from json.decoder import JSONDecodeError
from loguru import logger
from .base_node import StateMutationNode
from ..state.state import State
@@ -27,7 +28,7 @@ try:
FORUM_READER_AVAILABLE = True
except ImportError:
FORUM_READER_AVAILABLE = False
print("警告: 无法导入forum_reader模块,将跳过HOST发言读取功能")
logger.warning("无法导入forum_reader模块,将跳过HOST发言读取功能")
class FirstSummaryNode(StateMutationNode):
@@ -84,9 +85,9 @@ class FirstSummaryNode(StateMutationNode):
if host_speech:
# 将HOST发言添加到输入数据中
data['host_speech'] = host_speech
self.log_info(f"已读取HOST发言,长度: {len(host_speech)}字符")
logger.info(f"已读取HOST发言,长度: {len(host_speech)}字符")
except Exception as e:
self.log_info(f"读取HOST发言失败: {str(e)}")
logger.exception(f"读取HOST发言失败: {str(e)}")
# 转换为JSON字符串
message = json.dumps(data, ensure_ascii=False)
@@ -96,7 +97,7 @@ class FirstSummaryNode(StateMutationNode):
formatted_host = format_host_speech_for_prompt(data['host_speech'])
message = formatted_host + "\n" + message
self.log_info("正在生成首次段落总结")
logger.info("正在生成首次段落总结")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SUMMARY, message)
@@ -104,11 +105,11 @@ class FirstSummaryNode(StateMutationNode):
# 处理响应
processed_response = self.process_output(response)
self.log_info("成功生成首次段落总结")
logger.info("成功生成首次段落总结")
return processed_response
except Exception as e:
self.log_error(f"生成首次总结失败: {str(e)}")
logger.exception(f"生成首次总结失败: {str(e)}")
raise e
def process_output(self, output: str) -> str:
@@ -127,26 +128,26 @@ class FirstSummaryNode(StateMutationNode):
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output}")
logger.info(f"清理后的输出: {cleaned_output}")
# 解析JSON
try:
result = json.loads(cleaned_output)
self.log_info("JSON解析成功")
logger.info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
logger.exception(f"JSON解析失败: {str(e)}")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
result = json.loads(fixed_json)
self.log_info("JSON修复成功")
logger.info("JSON修复成功")
except JSONDecodeError:
self.log_info("JSON修复失败,直接使用清理后的文本")
logger.exception("JSON修复失败,直接使用清理后的文本")
# 如果不是JSON格式,直接返回清理后的文本
return cleaned_output
else:
self.log_info("无法修复JSON,直接使用清理后的文本")
logger.exception("无法修复JSON,直接使用清理后的文本")
# 如果不是JSON格式,直接返回清理后的文本
return cleaned_output
@@ -160,7 +161,7 @@ class FirstSummaryNode(StateMutationNode):
return cleaned_output
except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
logger.exception(f"处理输出失败: {str(e)}")
return "段落总结生成失败"
def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State:
@@ -183,7 +184,7 @@ class FirstSummaryNode(StateMutationNode):
# 更新状态
if 0 <= paragraph_index < len(state.paragraphs):
state.paragraphs[paragraph_index].research.latest_summary = summary
self.log_info(f"已更新段落 {paragraph_index} 的首次总结")
logger.info(f"已更新段落 {paragraph_index} 的首次总结")
else:
raise ValueError(f"段落索引 {paragraph_index} 超出范围")
@@ -191,7 +192,7 @@ class FirstSummaryNode(StateMutationNode):
return state
except Exception as e:
self.log_error(f"状态更新失败: {str(e)}")
logger.exception(f"状态更新失败: {str(e)}")
raise e
@@ -249,9 +250,9 @@ class ReflectionSummaryNode(StateMutationNode):
if host_speech:
# 将HOST发言添加到输入数据中
data['host_speech'] = host_speech
self.log_info(f"已读取HOST发言,长度: {len(host_speech)}字符")
logger.info(f"已读取HOST发言,长度: {len(host_speech)}字符")
except Exception as e:
self.log_info(f"读取HOST发言失败: {str(e)}")
logger.exception(f"读取HOST发言失败: {str(e)}")
# 转换为JSON字符串
message = json.dumps(data, ensure_ascii=False)
@@ -261,7 +262,7 @@ class ReflectionSummaryNode(StateMutationNode):
formatted_host = format_host_speech_for_prompt(data['host_speech'])
message = formatted_host + "\n" + message
self.log_info("正在生成反思总结")
logger.info("正在生成反思总结")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION_SUMMARY, message)
@@ -269,11 +270,11 @@ class ReflectionSummaryNode(StateMutationNode):
# 处理响应
processed_response = self.process_output(response)
self.log_info("成功生成反思总结")
logger.info("成功生成反思总结")
return processed_response
except Exception as e:
self.log_error(f"生成反思总结失败: {str(e)}")
logger.exception(f"生成反思总结失败: {str(e)}")
raise e
def process_output(self, output: str) -> str:
@@ -292,26 +293,26 @@ class ReflectionSummaryNode(StateMutationNode):
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output}")
logger.info(f"清理后的输出: {cleaned_output}")
# 解析JSON
try:
result = json.loads(cleaned_output)
self.log_info("JSON解析成功")
logger.info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
logger.exception(f"JSON解析失败: {str(e)}")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
result = json.loads(fixed_json)
self.log_info("JSON修复成功")
logger.info("JSON修复成功")
except JSONDecodeError:
self.log_info("JSON修复失败,直接使用清理后的文本")
logger.info("JSON修复失败,直接使用清理后的文本")
# 如果不是JSON格式,直接返回清理后的文本
return cleaned_output
else:
self.log_info("无法修复JSON,直接使用清理后的文本")
logger.info("无法修复JSON,直接使用清理后的文本")
# 如果不是JSON格式,直接返回清理后的文本
return cleaned_output
@@ -325,7 +326,7 @@ class ReflectionSummaryNode(StateMutationNode):
return cleaned_output
except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
logger.exception(f"处理输出失败: {str(e)}")
return "反思总结生成失败"
def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State:
@@ -349,7 +350,7 @@ class ReflectionSummaryNode(StateMutationNode):
if 0 <= paragraph_index < len(state.paragraphs):
state.paragraphs[paragraph_index].research.latest_summary = updated_summary
state.paragraphs[paragraph_index].research.increment_reflection()
self.log_info(f"已更新段落 {paragraph_index} 的反思总结")
logger.info(f"已更新段落 {paragraph_index} 的反思总结")
else:
raise ValueError(f"段落索引 {paragraph_index} 超出范围")
@@ -357,5 +358,5 @@ class ReflectionSummaryNode(StateMutationNode):
return state
except Exception as e:
self.log_error(f"状态更新失败: {str(e)}")
logger.exception(f"状态更新失败: {str(e)}")
raise e
+16 -11
View File
@@ -12,7 +12,8 @@ from dataclasses import dataclass
# 添加项目根目录到Python路径以导入config
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from config import KEYWORD_OPTIMIZER_API_KEY, KEYWORD_OPTIMIZER_BASE_URL, KEYWORD_OPTIMIZER_MODEL_NAME
from config import settings
from loguru import logger
# 添加utils目录到Python路径
current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -46,18 +47,18 @@ class KeywordOptimizer:
api_key: 硅基流动API密钥如果不提供则从配置文件读取
base_url: 接口基础地址默认使用配置文件提供的SiliconFlow地址
"""
self.api_key = api_key or KEYWORD_OPTIMIZER_API_KEY
self.api_key = api_key or settings.KEYWORD_OPTIMIZER_API_KEY
if not self.api_key:
raise ValueError("未找到硅基流动API密钥,请在config.py中设置KEYWORD_OPTIMIZER_API_KEY")
self.base_url = base_url or KEYWORD_OPTIMIZER_BASE_URL
self.base_url = base_url or settings.KEYWORD_OPTIMIZER_BASE_URL
self.client = OpenAI(
api_key=self.api_key,
base_url=self.base_url
)
self.model = model_name or KEYWORD_OPTIMIZER_MODEL_NAME
self.model = model_name or settings.KEYWORD_OPTIMIZER_MODEL_NAME
def optimize_keywords(self, original_query: str, context: str = "") -> KeywordOptimizationResponse:
"""
@@ -70,7 +71,7 @@ class KeywordOptimizer:
Returns:
KeywordOptimizationResponse: 优化后的关键词列表
"""
print(f"🔍 关键词优化中间件: 处理查询 '{original_query}'")
logger.info(f"🔍 关键词优化中间件: 处理查询 '{original_query}'")
try:
# 构建优化prompt
@@ -97,9 +98,13 @@ class KeywordOptimizer:
# 验证关键词质量
validated_keywords = self._validate_keywords(keywords)
print(f"✅ 优化成功: {len(validated_keywords)}个关键词")
for i, keyword in enumerate(validated_keywords, 1):
print(f" {i}. '{keyword}'")
logger.info(
f"✅ 优化成功: {len(validated_keywords)}个关键词" +
("" if not validated_keywords else "\n" +
"\n".join([f" {i}. '{k}'" for i, k in enumerate(validated_keywords, 1)]))
)
return KeywordOptimizationResponse(
original_query=original_query,
@@ -109,7 +114,7 @@ class KeywordOptimizer:
)
except Exception as e:
print(f"⚠️ 解析响应失败,使用备用方案: {str(e)}")
logger.exception(f"⚠️ 解析响应失败,使用备用方案: {str(e)}")
# 备用方案:从原始查询中提取关键词
fallback_keywords = self._fallback_keyword_extraction(original_query)
return KeywordOptimizationResponse(
@@ -119,7 +124,7 @@ class KeywordOptimizer:
success=True
)
else:
print(f"❌ API调用失败: {response['error']}")
logger.error(f"❌ API调用失败: {response['error']}")
# 使用备用方案
fallback_keywords = self._fallback_keyword_extraction(original_query)
return KeywordOptimizationResponse(
@@ -131,7 +136,7 @@ class KeywordOptimizer:
)
except Exception as e:
print(f"❌ 关键词优化失败: {str(e)}")
logger.error(f"❌ 关键词优化失败: {str(e)}")
# 最终备用方案
fallback_keywords = self._fallback_keyword_extraction(original_query)
return KeywordOptimizationResponse(
+80 -73
View File
@@ -25,10 +25,11 @@ V3.0 核心更新:
import os
import json
import pymysql
import pymysql.cursors
from loguru import logger
import asyncio
from typing import List, Dict, Any, Optional, Literal
from dataclasses import dataclass, field
from ..utils.db import fetch_all
from datetime import datetime, timedelta, date
# --- 1. 数据结构定义 ---
@@ -69,36 +70,28 @@ class MediaCrawlerDB:
def __init__(self):
"""
初始化客户端连接信息从环境变量自动读取:
- DB_HOST, DB_USER, DB_PASSWORD, DB_NAME
- DB_PORT (可选, 默认 3306)
- DB_CHARSET (可选, 默认 utf8mb4)
初始化客户端
"""
self.db_config = {
'host': os.getenv("DB_HOST"),
'user': os.getenv("DB_USER"),
'password': os.getenv("DB_PASSWORD"),
'db': os.getenv("DB_NAME"),
'port': int(os.getenv("DB_PORT", 3306)),
'charset': os.getenv("DB_CHARSET", "utf8mb4"),
'cursorclass': pymysql.cursors.DictCursor
}
required = ['host', 'user', 'password', 'db']
if missing := [k for k in required if not self.db_config[k]]:
raise ValueError(f"数据库配置缺失! 请设置环境变量或在代码中提供: {', '.join([f'DB_{k.upper()}' for k in missing])}")
pass
def _execute_query(self, query: str, params: tuple = None) -> List[Dict[str, Any]]:
    """Run *query* through the async ``fetch_all`` helper and return its result.

    Args:
        query: SQL text to execute.
        params: Bind parameters, forwarded unchanged to ``fetch_all``.
            Callers in this file also pass a dict of named parameters.

    Returns:
        Whatever ``fetch_all`` produces (presumably a list of row dicts —
        TODO confirm against ``utils.db``), or ``[]`` if anything raises.
    """
    try:
        # Reuse the current event loop when one exists and is still open;
        # otherwise create a fresh loop and install it as the current one.
        try:
            loop = asyncio.get_event_loop()
            if loop.is_closed():
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        # Drive the coroutine to completion synchronously.
        # NOTE(review): run_until_complete raises RuntimeError when the loop
        # is already running; that error is swallowed below and yields [].
        return loop.run_until_complete(fetch_all(query, params))
    except Exception as e:
        logger.exception(f"数据库查询时发生错误: {e}")
        return []
@staticmethod
def _to_datetime(ts: Any) -> Optional[datetime]:
@@ -149,7 +142,7 @@ class MediaCrawlerDB:
DBResponse: 包含按综合热度排序后的内容列表
"""
params_for_log = {'time_period': time_period, 'limit': limit}
print(f"--- TOOL: 查找热点内容 (params: {params_for_log}) ---")
logger.info(f"--- TOOL: 查找热点内容 (params: {params_for_log}) ---")
now = datetime.now()
start_time = now - timedelta(days={'24h': 1, 'week': 7}.get(time_period, 365))
@@ -202,22 +195,28 @@ class MediaCrawlerDB:
DBResponse: 包含所有匹配结果的聚合列表
"""
params_for_log = {'topic': topic, 'limit_per_table': limit_per_table}
print(f"--- TOOL: 全局话题搜索 (params: {params_for_log}) ---")
logger.info(f"--- TOOL: 全局话题搜索 (params: {params_for_log}) ---")
search_term, all_results = f"%{topic}%", []
search_configs = { 'bilibili_video': {'fields': ['title', 'desc', 'source_keyword'], 'type': 'video'}, 'bilibili_video_comment': {'fields': ['content'], 'type': 'comment'}, 'douyin_aweme': {'fields': ['title', 'desc', 'source_keyword'], 'type': 'video'}, 'douyin_aweme_comment': {'fields': ['content'], 'type': 'comment'}, 'kuaishou_video': {'fields': ['title', 'desc', 'source_keyword'], 'type': 'video'}, 'kuaishou_video_comment': {'fields': ['content'], 'type': 'comment'}, 'weibo_note': {'fields': ['content', 'source_keyword'], 'type': 'note'}, 'weibo_note_comment': {'fields': ['content'], 'type': 'comment'}, 'xhs_note': {'fields': ['title', 'desc', 'tag_list', 'source_keyword'], 'type': 'note'}, 'xhs_note_comment': {'fields': ['content'], 'type': 'comment'}, 'zhihu_content': {'fields': ['title', 'desc', 'content_text', 'source_keyword'], 'type': 'content'}, 'zhihu_comment': {'fields': ['content'], 'type': 'comment'}, 'tieba_note': {'fields': ['title', 'desc', 'source_keyword'], 'type': 'note'}, 'tieba_comment': {'fields': ['content'], 'type': 'comment'}, 'daily_news': {'fields': ['title'], 'type': 'news'}, }
for table, config in search_configs.items():
where_clause = " OR ".join([f"`{field}` LIKE %s" for field in config['fields']])
query = f"SELECT * FROM `{table}` WHERE {where_clause} ORDER BY id DESC LIMIT %s"
params = (search_term,) * len(config['fields']) + (limit_per_table,)
raw_results = self._execute_query(query, params)
param_dict = {}
where_clauses = []
for idx, field in enumerate(config['fields']):
pname = f"term_{idx}"
where_clauses.append(f'"{field}" LIKE :{pname}')
param_dict[pname] = search_term
param_dict['limit'] = limit_per_table
where_clause = " OR ".join(where_clauses)
query = f'SELECT * FROM "{table}" WHERE {where_clause} ORDER BY id DESC LIMIT :limit'
raw_results = self._execute_query(query, param_dict)
for row in raw_results:
content = (row.get('title') or row.get('content') or row.get('desc') or row.get('content_text', ''))
time_key = row.get('create_time') or row.get('time') or row.get('created_time') or row.get('publish_time') or row.get('crawl_date')
all_results.append(QueryResult(
platform=table.split('_')[0], content_type=config['type'],
title_or_content=content[:500] if content else '',
title_or_content=content if content else '',
author_nickname=row.get('nickname') or row.get('user_nickname') or row.get('user_name'),
url=row.get('video_url') or row.get('note_url') or row.get('content_url') or row.get('url') or row.get('aweme_url'),
publish_time=self._to_datetime(time_key),
@@ -241,7 +240,7 @@ class MediaCrawlerDB:
DBResponse: 包含在指定日期范围内找到的结果的聚合列表
"""
params_for_log = {'topic': topic, 'start_date': start_date, 'end_date': end_date, 'limit_per_table': limit_per_table}
print(f"--- TOOL: 按日期搜索话题 (params: {params_for_log}) ---")
logger.info(f"--- TOOL: 按日期搜索话题 (params: {params_for_log}) ---")
try:
start_dt, end_dt = datetime.strptime(start_date, '%Y-%m-%d'), datetime.strptime(end_date, '%Y-%m-%d') + timedelta(days=1)
@@ -257,25 +256,25 @@ class MediaCrawlerDB:
}
for table, config in search_configs.items():
topic_clause = " OR ".join([f"`{field}` LIKE %s" for field in config['fields']])
time_col, time_type = config['time_col'], config['time_type']
if time_type == 'sec': time_params = (int(start_dt.timestamp()), int(end_dt.timestamp()))
elif time_type == 'ms': time_params = (int(start_dt.timestamp() * 1000), int(end_dt.timestamp() * 1000))
elif time_type in ['str', 'date_str']: time_params = (start_dt.strftime('%Y-%m-%d'), end_dt.strftime('%Y-%m-%d'))
else: time_params = (str(int(start_dt.timestamp())), str(int(end_dt.timestamp())))
time_clause = f"`{time_col}` >= %s AND `{time_col}` < %s"
if table == 'zhihu_content': time_clause = f"CAST(`{time_col}` AS UNSIGNED) >= %s AND CAST(`{time_col}` AS UNSIGNED) < %s"
query = f"SELECT * FROM `{table}` WHERE ({topic_clause}) AND ({time_clause}) ORDER BY id DESC LIMIT %s"
params = (search_term,) * len(config['fields']) + time_params + (limit_per_table,)
raw_results = self._execute_query(query, params)
param_dict = {}
where_clauses = []
for idx, field in enumerate(config['fields']):
pname = f"term_{idx}"
where_clauses.append(f'"{field}" LIKE :{pname}')
param_dict[pname] = search_term
param_dict['limit'] = limit_per_table
where_clause = ' OR '.join(where_clauses)
query = f'SELECT * FROM "{table}" WHERE {where_clause} ORDER BY id DESC LIMIT :limit'
raw_results = self._execute_query(query, param_dict)
for row in raw_results:
content = (row.get('title') or row.get('content') or row.get('desc') or row.get('content_text', ''))
time_key = row.get('create_time') or row.get('time') or row.get('created_time') or row.get('publish_time') or row.get('crawl_date')
all_results.append(QueryResult(
platform=table.split('_')[0], content_type=config['type'],
title_or_content=content[:500] if content else '',
author_nickname=row.get('nickname') or row.get('user_nickname'),
title_or_content=content if content else '',
author_nickname=row.get('nickname') or row.get('user_nickname') or row.get('user_name'),
url=row.get('video_url') or row.get('note_url') or row.get('content_url') or row.get('url') or row.get('aweme_url'),
publish_time=self._to_datetime(row.get(config['time_col'])),
publish_time=self._to_datetime(time_key),
engagement=self._extract_engagement(row),
source_keyword=row.get('source_keyword'),
source_table=table
@@ -294,7 +293,7 @@ class MediaCrawlerDB:
DBResponse: 包含匹配的评论列表
"""
params_for_log = {'topic': topic, 'limit': limit}
print(f"--- TOOL: 获取话题评论 (params: {params_for_log}) ---")
logger.info(f"--- TOOL: 获取话题评论 (params: {params_for_log}) ---")
search_term = f"%{topic}%"
comment_tables = ['bilibili_video_comment', 'douyin_aweme_comment', 'kuaishou_video_comment', 'weibo_note_comment', 'xhs_note_comment', 'zhihu_comment', 'tieba_comment']
@@ -341,7 +340,7 @@ class MediaCrawlerDB:
DBResponse: 包含在该平台找到的结果列表
"""
params_for_log = {'platform': platform, 'topic': topic, 'start_date': start_date, 'end_date': end_date, 'limit': limit}
print(f"--- TOOL: 平台定向搜索 (params: {params_for_log}) ---")
logger.info(f"--- TOOL: 平台定向搜索 (params: {params_for_log}) ---")
all_configs = { 'bilibili': [{'table': 'bilibili_video', 'fields': ['title', 'desc', 'source_keyword'], 'type': 'video', 'time_col': 'create_time', 'time_type': 'sec'}, {'table': 'bilibili_video_comment', 'fields': ['content'], 'type': 'comment'}], 'douyin': [{'table': 'douyin_aweme', 'fields': ['title', 'desc', 'source_keyword'], 'type': 'video', 'time_col': 'create_time', 'time_type': 'ms'}, {'table': 'douyin_aweme_comment', 'fields': ['content'], 'type': 'comment'}], 'kuaishou': [{'table': 'kuaishou_video', 'fields': ['title', 'desc', 'source_keyword'], 'type': 'video', 'time_col': 'create_time', 'time_type': 'ms'}, {'table': 'kuaishou_video_comment', 'fields': ['content'], 'type': 'comment'}], 'weibo': [{'table': 'weibo_note', 'fields': ['content', 'source_keyword'], 'type': 'note', 'time_col': 'create_date_time', 'time_type': 'str'}, {'table': 'weibo_note_comment', 'fields': ['content'], 'type': 'comment'}], 'xhs': [{'table': 'xhs_note', 'fields': ['title', 'desc', 'tag_list', 'source_keyword'], 'type': 'note', 'time_col': 'time', 'time_type': 'ms'}, {'table': 'xhs_note_comment', 'fields': ['content'], 'type': 'comment'}], 'zhihu': [{'table': 'zhihu_content', 'fields': ['title', 'desc', 'content_text', 'source_keyword'], 'type': 'content', 'time_col': 'created_time', 'time_type': 'sec_str'}, {'table': 'zhihu_comment', 'fields': ['content'], 'type': 'comment'}], 'tieba': [{'table': 'tieba_note', 'fields': ['title', 'desc', 'source_keyword'], 'type': 'note', 'time_col': 'publish_time', 'time_type': 'str'}, {'table': 'tieba_comment', 'fields': ['content'], 'type': 'comment'}] }
@@ -386,7 +385,7 @@ class MediaCrawlerDB:
for row in raw_results:
content = (row.get('title') or row.get('content') or row.get('desc') or row.get('content_text', ''))
time_key = config.get('time_col') and row.get(config.get('time_col'))
all_results.append(QueryResult(platform=platform, content_type=config['type'], title_or_content=content[:500] if content else '', author_nickname=row.get('nickname') or row.get('user_nickname'), url=row.get('video_url') or row.get('note_url') or row.get('content_url') or row.get('url') or row.get('aweme_url'), publish_time=self._to_datetime(time_key), engagement=self._extract_engagement(row), source_keyword=row.get('source_keyword'), source_table=table))
all_results.append(QueryResult(platform=platform, content_type=config['type'], title_or_content=content if content else '', author_nickname=row.get('nickname') or row.get('user_nickname'), url=row.get('video_url') or row.get('note_url') or row.get('content_url') or row.get('url') or row.get('aweme_url'), publish_time=self._to_datetime(time_key), engagement=self._extract_engagement(row), source_keyword=row.get('source_keyword'), source_table=table))
return DBResponse("search_topic_on_platform", params_for_log, results=all_results, results_count=len(all_results))
@@ -394,33 +393,41 @@ class MediaCrawlerDB:
def print_response_summary(response: DBResponse):
    """Log a short, human-readable summary of a tool response.

    When ``response.error_message`` is set, only that error is logged.
    Otherwise the tool name, parameters and result count are logged,
    followed by a single multi-line message previewing up to five results.

    Args:
        response: The response object returned by one of the query tools.
    """
    if response.error_message:
        # A failed tool call is an error condition, so log it at error level.
        logger.error(f"工具 '{response.tool_name}' 执行出错: {response.error_message}")
        return

    params_str = ", ".join(f"{k}='{v}'" for k, v in response.parameters.items())
    logger.info(f"查询: 工具='{response.tool_name}', 参数=[{params_str}]")
    logger.info(f"找到 {response.results_count} 条相关记录。")

    # Collect the preview into one message so it is emitted as a single record.
    output_lines = ["==== 查询结果预览(最多前5条) ===="]
    if response.results:
        for idx, res in enumerate(response.results[:5], 1):
            # Flatten newlines and truncate long content to 70 characters.
            content_preview = (res.title_or_content.replace('\n', ' ')[:70] + '...') if res.title_or_content and len(res.title_or_content) > 70 else (res.title_or_content or '')
            author_str = res.author_nickname or "N/A"
            publish_time_str = res.publish_time.strftime('%Y-%m-%d %H:%M') if res.publish_time else "N/A"
            hotness_str = f", hotness: {res.hotness_score:.2f}" if getattr(res, "hotness_score", 0) > 0 else ""
            engagement_dict = getattr(res, "engagement", {}) or {}
            # Only metrics with truthy values are shown.
            engagement_str = ", ".join(f"{k}: {v}" for k, v in engagement_dict.items() if v)
            output_lines.append(
                f"{idx}. [{res.platform.upper()}/{res.content_type}] {content_preview}\n"
                f" 作者: {author_str} | 时间: {publish_time_str}"
                f"{hotness_str} | 源关键词: '{res.source_keyword or 'N/A'}'\n"
                f" 链接: {res.url or 'N/A'}\n"
                f" 互动数据: {{{engagement_str}}}"
            )
    else:
        output_lines.append("暂无相关内容。")
    output_lines.append("=" * 60)
    logger.info('\n'.join(output_lines))
if __name__ == "__main__":
try:
db_agent_tools = MediaCrawlerDB()
print("数据库工具初始化成功,开始执行测试场景...\n")
logger.info("数据库工具初始化成功,开始执行测试场景...\n")
# 场景1: (新) 查找过去一周综合热度最高的内容 (不再需要sort_by)
response1 = db_agent_tools.search_hot_content(time_period='week', limit=5)
@@ -443,7 +450,7 @@ if __name__ == "__main__":
print_response_summary(response5)
except ValueError as e:
print(f"初始化失败: {e}")
print("请确保相关的数据库环境变量已正确设置, 或在代码中直接提供连接信息。")
logger.exception(f"初始化失败: {e}")
logger.exception("请确保相关的数据库环境变量已正确设置, 或在代码中直接提供连接信息。")
except Exception as e:
print(f"测试过程中发生未知错误: {e}")
logger.exception(f"测试过程中发生未知错误: {e}")
-4
View File
@@ -12,8 +12,6 @@ from .text_processing import (
format_search_results_for_prompt
)
from .config import Config, load_config
__all__ = [
"clean_json_tags",
"clean_markdown_tags",
@@ -21,6 +19,4 @@ __all__ = [
"extract_clean_response",
"update_state_with_search_results",
"format_search_results_for_prompt",
"Config",
"load_config"
]
+34 -212
View File
@@ -6,218 +6,40 @@ Handles environment variables and config file parameters.
import os
from dataclasses import dataclass
from typing import Optional
from pydantic_settings import BaseSettings
from pydantic import Field
from loguru import logger
class Settings(BaseSettings):
    """Insight Engine settings declared as pydantic-settings fields.

    The nested ``Config`` class names ``.env`` as the env file, uses no
    prefix, matches names case-insensitively and allows extra keys.
    """

    # LLM configuration
    INSIGHT_ENGINE_API_KEY: Optional[str] = Field(None, description="Insight Engine LLM API密钥")
    INSIGHT_ENGINE_BASE_URL: Optional[str] = Field(None, description="Insight Engine LLM base url,可选")
    INSIGHT_ENGINE_MODEL_NAME: Optional[str] = Field(None, description="Insight Engine LLM模型名称")
    INSIGHT_ENGINE_PROVIDER: Optional[str] = Field(None, description="Insight Engine模型提供者,不再建议使用")

    # Database configuration
    DB_HOST: Optional[str] = Field(None, description="数据库主机")
    DB_USER: Optional[str] = Field(None, description="数据库用户名")
    DB_PASSWORD: Optional[str] = Field(None, description="数据库密码")
    DB_NAME: Optional[str] = Field(None, description="数据库名称")
    DB_PORT: int = Field(3306, description="数据库端口")
    DB_CHARSET: str = Field("utf8mb4", description="数据库字符集")
    DB_DIALECT: Optional[str] = Field("mysql", description="数据库方言,如mysql、postgresql等,SQLAlchemy后端选择")

    # Agent behaviour limits
    MAX_REFLECTIONS: int = Field(3, description="最大反思次数")
    MAX_PARAGRAPHS: int = Field(6, description="最大段落数")
    SEARCH_TIMEOUT: int = Field(240, description="单次搜索请求超时")
    MAX_CONTENT_LENGTH: int = Field(500000, description="搜索最大内容长度")

    # Per-tool result limits
    DEFAULT_SEARCH_HOT_CONTENT_LIMIT: int = Field(100, description="热榜内容默认最大数")
    DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE: int = Field(50, description="按表全局话题最大数")
    DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE: int = Field(100, description="按日期话题最大数")
    DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT: int = Field(500, description="单话题评论最大数")
    DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT: int = Field(200, description="平台搜索话题最大数")
    MAX_SEARCH_RESULTS_FOR_LLM: int = Field(0, description="供LLM用搜索结果最大数")
    MAX_HIGH_CONFIDENCE_SENTIMENT_RESULTS: int = Field(0, description="高置信度情感分析最大数")

    # Output configuration
    OUTPUT_DIR: str = Field("reports", description="输出路径")
    SAVE_INTERMEDIATE_STATES: bool = Field(True, description="是否保存中间状态")

    class Config:
        env_file = ".env"
        env_prefix = ""
        case_sensitive = False
        extra = "allow"


def _get_value(source, key: str, default=None):
    """
    Helper to fetch a configuration value with environment fallback.

    Looks *key* up in *source* (dict item or object attribute). When that
    gives ``None`` the process environment is consulted. An empty string or
    ``None`` result is replaced by *default*.
    """
    if isinstance(source, dict):
        value = source.get(key)
    else:
        value = getattr(source, key, None)

    if value is None:
        value = os.getenv(key, default)
    return value if value not in ("", None) else default
@dataclass
class Config:
    """Insight Engine configuration.

    Groups LLM credentials, database connection settings, agent limits,
    per-tool result limits and output options. Use :meth:`from_file` to
    build an instance from a ``.py`` module or a ``KEY=VALUE`` text file.
    """

    # LLM configuration
    llm_api_key: Optional[str] = None
    llm_base_url: Optional[str] = None
    llm_model_name: Optional[str] = None
    llm_provider: Optional[str] = None  # kept for backward compatibility

    # Database configuration
    db_host: Optional[str] = None
    db_user: Optional[str] = None
    db_password: Optional[str] = None
    db_name: Optional[str] = None
    db_port: int = 3306
    db_charset: str = "utf8mb4"

    # Model behaviour configuration
    max_reflections: int = 3
    max_paragraphs: int = 6
    search_timeout: int = 240
    max_content_length: int = 500000

    # Search result limits
    default_search_hot_content_limit: int = 100
    default_search_topic_globally_limit_per_table: int = 50
    default_search_topic_by_date_limit_per_table: int = 100
    default_get_comments_for_topic_limit: int = 500
    default_search_topic_on_platform_limit: int = 200
    max_search_results_for_llm: int = 0
    max_high_confidence_sentiment_results: int = 0

    # Output configuration
    output_dir: str = "reports"
    save_intermediate_states: bool = True

    def __post_init__(self):
        """Default ``llm_provider`` to the model name when it is unset."""
        if not self.llm_provider and self.llm_model_name:
            # Provider is no longer used, but keep the attribute for compatibility.
            self.llm_provider = self.llm_model_name

    def validate(self) -> bool:
        """Validate configuration.

        Returns:
            ``True`` when the LLM API key, the model name and all four
            database connection fields are set. Otherwise prints the first
            problem found and returns ``False``.
        """
        if not self.llm_api_key:
            print("错误: Insight Engine LLM API Key 未设置 (INSIGHT_ENGINE_API_KEY)。")
            return False
        if not self.llm_model_name:
            print("错误: Insight Engine 模型名称未设置 (INSIGHT_ENGINE_MODEL_NAME)。")
            return False
        if not all([self.db_host, self.db_user, self.db_password, self.db_name]):
            print("错误: 数据库连接信息不完整,请检查 config.py 中的 DB_* 配置。")
            return False
        return True

    @classmethod
    def _from_source(cls, source) -> "Config":
        """Build a Config from a loaded module or a key/value dict.

        Each field is read through ``_get_value``; numeric fields are passed
        through ``int`` and the boolean flag is parsed from its string form.
        """

        def as_int(key: str, default: int) -> int:
            return int(_get_value(source, key, default))

        return cls(
            llm_api_key=_get_value(source, "INSIGHT_ENGINE_API_KEY"),
            llm_base_url=_get_value(source, "INSIGHT_ENGINE_BASE_URL"),
            llm_model_name=_get_value(source, "INSIGHT_ENGINE_MODEL_NAME"),
            db_host=_get_value(source, "DB_HOST"),
            db_user=_get_value(source, "DB_USER"),
            db_password=_get_value(source, "DB_PASSWORD"),
            db_name=_get_value(source, "DB_NAME"),
            db_port=as_int("DB_PORT", 3306),
            db_charset=_get_value(source, "DB_CHARSET", "utf8mb4"),
            max_reflections=as_int("MAX_REFLECTIONS", 3),
            max_paragraphs=as_int("MAX_PARAGRAPHS", 6),
            search_timeout=as_int("SEARCH_TIMEOUT", 240),
            # The source key differs from the field name here.
            max_content_length=as_int("SEARCH_CONTENT_MAX_LENGTH", 500000),
            default_search_hot_content_limit=as_int("DEFAULT_SEARCH_HOT_CONTENT_LIMIT", 100),
            default_search_topic_globally_limit_per_table=as_int(
                "DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE", 50
            ),
            default_search_topic_by_date_limit_per_table=as_int(
                "DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE", 100
            ),
            default_get_comments_for_topic_limit=as_int("DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT", 500),
            default_search_topic_on_platform_limit=as_int("DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT", 200),
            max_search_results_for_llm=as_int("MAX_SEARCH_RESULTS_FOR_LLM", 0),
            max_high_confidence_sentiment_results=as_int("MAX_HIGH_CONFIDENCE_SENTIMENT_RESULTS", 0),
            output_dir=_get_value(source, "OUTPUT_DIR", "reports"),
            save_intermediate_states=str(
                _get_value(source, "SAVE_INTERMEDIATE_STATES", "true")
            ).lower()
            in ("true", "1", "yes"),
        )

    @classmethod
    def from_file(cls, config_file: str) -> "Config":
        """Create configuration from file.

        Args:
            config_file: Path to a ``.py`` module (executed and read as
                attributes) or to a ``KEY=VALUE`` text file.

        Returns:
            A ``Config`` populated from the file.
        """
        if config_file.endswith(".py"):
            import importlib.util

            spec = importlib.util.spec_from_file_location("config", config_file)
            config_module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(config_module)
            return cls._from_source(config_module)

        # .env style configuration: a missing file simply yields no entries.
        config_dict = {}
        if os.path.exists(config_file):
            with open(config_file, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    # Skip blanks, comments and lines without an assignment.
                    if line and not line.startswith("#") and "=" in line:
                        key, value = line.split("=", 1)
                        config_dict[key.strip()] = value.strip()
        return cls._from_source(config_dict)
def load_config(config_file: Optional[str] = None) -> Config:
    """
    Load configuration.

    Args:
        config_file: Explicit path to load. When omitted, the first existing
            file among ``config.py``, ``config.env`` and ``.env`` in the
            current directory is used.

    Returns:
        The validated ``Config``.

    Raises:
        FileNotFoundError: The given path does not exist, or no candidate
            file was found.
        ValueError: The loaded configuration fails ``validate()``.
    """
    if not config_file:
        candidates = ("config.py", "config.env", ".env")
        # Lazily probe candidates in order and stop at the first hit.
        file_to_load = next((name for name in candidates if os.path.exists(name)), None)
        if file_to_load is None:
            raise FileNotFoundError("未找到配置文件,请创建 config.py。")
        print(f"已找到配置文件: {file_to_load}")
    elif os.path.exists(config_file):
        file_to_load = config_file
    else:
        raise FileNotFoundError(f"配置文件不存在: {config_file}")

    loaded = Config.from_file(file_to_load)
    if loaded.validate():
        return loaded
    raise ValueError("配置校验失败,请检查 config.py 中的相关配置。")
def print_config(config: Config):
    """Print configuration (sensitive values masked).

    The API key and database credentials are never printed; only whether
    they are configured is shown.
    """

    def report_lines():
        # Yielded lazily so each value is read just before its line is printed.
        yield "\n=== Insight Engine 配置 ==="
        yield f"LLM 模型: {config.llm_model_name}"
        yield f"LLM Base URL: {config.llm_base_url or '(默认)'}"
        yield f"搜索超时: {config.search_timeout}"
        yield f"最长内容长度: {config.max_content_length}"
        yield f"最大反思次数: {config.max_reflections}"
        yield f"最大段落数: {config.max_paragraphs}"
        yield f"输出目录: {config.output_dir}"
        yield f"保存中间状态: {config.save_intermediate_states}"
        yield f"LLM API Key: {'已配置' if config.llm_api_key else '未配置'}"
        yield f"数据库连接: {'已配置' if all([config.db_host, config.db_user, config.db_password, config.db_name]) else '未配置'}"
        yield "========================\n"

    for text in report_lines():
        print(text)
settings = Settings()