1. 统一为使用基于pydantic的.env环境变量管理配置
2. 全项目基于loguru进行日志管理
This commit is contained in:
@@ -4,9 +4,9 @@ Deep Search Agent
|
||||
"""
|
||||
|
||||
from .agent import DeepSearchAgent, create_agent
|
||||
from .utils.config import Config, load_config
|
||||
from .utils.config import settings, Settings
|
||||
|
||||
__version__ = "1.0.0"
|
||||
__author__ = "Deep Search Agent Team"
|
||||
|
||||
__all__ = ["DeepSearchAgent", "create_agent", "Config", "load_config"]
|
||||
__all__ = ["DeepSearchAgent", "create_agent", "settings", "Settings"]
|
||||
|
||||
+114
-119
@@ -8,6 +8,7 @@ import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List, Union
|
||||
from loguru import logger
|
||||
|
||||
from .llms import LLMClient
|
||||
from .nodes import (
|
||||
@@ -20,32 +21,25 @@ from .nodes import (
|
||||
)
|
||||
from .state import State
|
||||
from .tools import MediaCrawlerDB, DBResponse, keyword_optimizer, multilingual_sentiment_analyzer
|
||||
from .utils import Config, load_config, format_search_results_for_prompt
|
||||
from .utils.config import settings, Settings
|
||||
from .utils import format_search_results_for_prompt
|
||||
|
||||
|
||||
class DeepSearchAgent:
|
||||
"""Deep Search Agent主类"""
|
||||
|
||||
def __init__(self, config: Optional[Config] = None):
|
||||
def __init__(self, config: Optional[Settings] = None):
|
||||
"""
|
||||
初始化Deep Search Agent
|
||||
|
||||
Args:
|
||||
config: 配置对象,如果不提供则自动加载
|
||||
config: 可选配置对象(不填则用全局settings)
|
||||
"""
|
||||
# 加载配置
|
||||
self.config = config or load_config()
|
||||
self.config = config or settings
|
||||
|
||||
# 初始化LLM客户端
|
||||
self.llm_client = self._initialize_llm()
|
||||
|
||||
# 设置数据库环境变量
|
||||
os.environ["DB_HOST"] = self.config.db_host or ""
|
||||
os.environ["DB_USER"] = self.config.db_user or ""
|
||||
os.environ["DB_PASSWORD"] = self.config.db_password or ""
|
||||
os.environ["DB_NAME"] = self.config.db_name or ""
|
||||
os.environ["DB_PORT"] = str(self.config.db_port)
|
||||
os.environ["DB_CHARSET"] = self.config.db_charset
|
||||
|
||||
# 初始化搜索工具集
|
||||
self.search_agency = MediaCrawlerDB()
|
||||
@@ -60,19 +54,19 @@ class DeepSearchAgent:
|
||||
self.state = State()
|
||||
|
||||
# 确保输出目录存在
|
||||
os.makedirs(self.config.output_dir, exist_ok=True)
|
||||
os.makedirs(self.config.OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
print(f"Insight Agent已初始化")
|
||||
print(f"使用LLM: {self.llm_client.get_model_info()}")
|
||||
print(f"搜索工具集: MediaCrawlerDB (支持5种本地数据库查询工具)")
|
||||
print(f"情感分析: WeiboMultilingualSentiment (支持22种语言的情感分析)")
|
||||
logger.info(f"Insight Agent已初始化")
|
||||
logger.info(f"使用LLM: {self.llm_client.get_model_info()}")
|
||||
logger.info(f"搜索工具集: MediaCrawlerDB (支持5种本地数据库查询工具)")
|
||||
logger.info(f"情感分析: WeiboMultilingualSentiment (支持22种语言的情感分析)")
|
||||
|
||||
def _initialize_llm(self) -> LLMClient:
|
||||
"""初始化LLM客户端"""
|
||||
return LLMClient(
|
||||
api_key=self.config.llm_api_key,
|
||||
model_name=self.config.llm_model_name,
|
||||
base_url=self.config.llm_base_url,
|
||||
api_key=self.config.INSIGHT_ENGINE_API_KEY,
|
||||
model_name=self.config.INSIGHT_ENGINE_MODEL_NAME,
|
||||
base_url=self.config.INSIGHT_ENGINE_BASE_URL,
|
||||
)
|
||||
|
||||
def _initialize_nodes(self):
|
||||
@@ -127,7 +121,7 @@ class DeepSearchAgent:
|
||||
Returns:
|
||||
DBResponse对象(可能包含情感分析结果)
|
||||
"""
|
||||
print(f" → 执行数据库查询工具: {tool_name}")
|
||||
logger.info(f" → 执行数据库查询工具: {tool_name}")
|
||||
|
||||
# 对于热点内容搜索,不需要关键词优化(因为不需要query参数)
|
||||
if tool_name == "search_hot_content":
|
||||
@@ -138,12 +132,12 @@ class DeepSearchAgent:
|
||||
# 检查是否需要进行情感分析
|
||||
enable_sentiment = kwargs.get("enable_sentiment", True)
|
||||
if enable_sentiment and response.results and len(response.results) > 0:
|
||||
print(f" 🎭 开始对热点内容进行情感分析...")
|
||||
logger.info(f" 🎭 开始对热点内容进行情感分析...")
|
||||
sentiment_analysis = self._perform_sentiment_analysis(response.results)
|
||||
if sentiment_analysis:
|
||||
# 将情感分析结果添加到响应的parameters中
|
||||
response.parameters["sentiment_analysis"] = sentiment_analysis
|
||||
print(f" ✅ 情感分析完成")
|
||||
logger.info(f" ✅ 情感分析完成")
|
||||
|
||||
return response
|
||||
|
||||
@@ -170,32 +164,32 @@ class DeepSearchAgent:
|
||||
context=f"使用{tool_name}工具进行查询"
|
||||
)
|
||||
|
||||
print(f" 🔍 原始查询: '{query}'")
|
||||
print(f" ✨ 优化后关键词: {optimized_response.optimized_keywords}")
|
||||
logger.info(f" 🔍 原始查询: '{query}'")
|
||||
logger.info(f" ✨ 优化后关键词: {optimized_response.optimized_keywords}")
|
||||
|
||||
# 使用优化后的关键词进行多次查询并整合结果
|
||||
all_results = []
|
||||
total_count = 0
|
||||
|
||||
for keyword in optimized_response.optimized_keywords:
|
||||
print(f" 查询关键词: '{keyword}'")
|
||||
logger.info(f" 查询关键词: '{keyword}'")
|
||||
|
||||
try:
|
||||
if tool_name == "search_topic_globally":
|
||||
# 使用配置文件中的默认值,忽略agent提供的limit_per_table参数
|
||||
limit_per_table = self.config.default_search_topic_globally_limit_per_table
|
||||
limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE
|
||||
response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=limit_per_table)
|
||||
elif tool_name == "search_topic_by_date":
|
||||
start_date = kwargs.get("start_date")
|
||||
end_date = kwargs.get("end_date")
|
||||
# 使用配置文件中的默认值,忽略agent提供的limit_per_table参数
|
||||
limit_per_table = self.config.default_search_topic_by_date_limit_per_table
|
||||
limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE
|
||||
if not start_date or not end_date:
|
||||
raise ValueError("search_topic_by_date工具需要start_date和end_date参数")
|
||||
response = self.search_agency.search_topic_by_date(topic=keyword, start_date=start_date, end_date=end_date, limit_per_table=limit_per_table)
|
||||
elif tool_name == "get_comments_for_topic":
|
||||
# 使用配置文件中的默认值,按关键词数量分配,但保证最小值
|
||||
limit = self.config.default_get_comments_for_topic_limit // len(optimized_response.optimized_keywords)
|
||||
limit = self.config.DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT // len(optimized_response.optimized_keywords)
|
||||
limit = max(limit, 50)
|
||||
response = self.search_agency.get_comments_for_topic(topic=keyword, limit=limit)
|
||||
elif tool_name == "search_topic_on_platform":
|
||||
@@ -203,30 +197,30 @@ class DeepSearchAgent:
|
||||
start_date = kwargs.get("start_date")
|
||||
end_date = kwargs.get("end_date")
|
||||
# 使用配置文件中的默认值,按关键词数量分配,但保证最小值
|
||||
limit = self.config.default_search_topic_on_platform_limit // len(optimized_response.optimized_keywords)
|
||||
limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT // len(optimized_response.optimized_keywords)
|
||||
limit = max(limit, 30)
|
||||
if not platform:
|
||||
raise ValueError("search_topic_on_platform工具需要platform参数")
|
||||
response = self.search_agency.search_topic_on_platform(platform=platform, topic=keyword, start_date=start_date, end_date=end_date, limit=limit)
|
||||
else:
|
||||
print(f" 未知的搜索工具: {tool_name},使用默认全局搜索")
|
||||
response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=self.config.default_search_topic_globally_limit_per_table)
|
||||
logger.info(f" 未知的搜索工具: {tool_name},使用默认全局搜索")
|
||||
response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE)
|
||||
|
||||
# 收集结果
|
||||
if response.results:
|
||||
print(f" 找到 {len(response.results)} 条结果")
|
||||
logger.info(f" 找到 {len(response.results)} 条结果")
|
||||
all_results.extend(response.results)
|
||||
total_count += len(response.results)
|
||||
else:
|
||||
print(f" 未找到结果")
|
||||
logger.info(f" 未找到结果")
|
||||
|
||||
except Exception as e:
|
||||
print(f" 查询'{keyword}'时出错: {str(e)}")
|
||||
logger.error(f" 查询'{keyword}'时出错: {str(e)}")
|
||||
continue
|
||||
|
||||
# 去重和整合结果
|
||||
unique_results = self._deduplicate_results(all_results)
|
||||
print(f" 总计找到 {total_count} 条结果,去重后 {len(unique_results)} 条")
|
||||
logger.info(f" 总计找到 {total_count} 条结果,去重后 {len(unique_results)} 条")
|
||||
|
||||
# 构建整合后的响应
|
||||
integrated_response = DBResponse(
|
||||
@@ -244,12 +238,12 @@ class DeepSearchAgent:
|
||||
# 检查是否需要进行情感分析
|
||||
enable_sentiment = kwargs.get("enable_sentiment", True)
|
||||
if enable_sentiment and unique_results and len(unique_results) > 0:
|
||||
print(f" 🎭 开始对搜索结果进行情感分析...")
|
||||
logger.info(f" 🎭 开始对搜索结果进行情感分析...")
|
||||
sentiment_analysis = self._perform_sentiment_analysis(unique_results)
|
||||
if sentiment_analysis:
|
||||
# 将情感分析结果添加到响应的parameters中
|
||||
integrated_response.parameters["sentiment_analysis"] = sentiment_analysis
|
||||
print(f" ✅ 情感分析完成")
|
||||
logger.info(f" ✅ 情感分析完成")
|
||||
|
||||
return integrated_response
|
||||
|
||||
@@ -282,11 +276,11 @@ class DeepSearchAgent:
|
||||
try:
|
||||
# 初始化情感分析器(如果尚未初始化且未被禁用)
|
||||
if not self.sentiment_analyzer.is_initialized and not self.sentiment_analyzer.is_disabled:
|
||||
print(" 初始化情感分析模型...")
|
||||
logger.info(" 初始化情感分析模型...")
|
||||
if not self.sentiment_analyzer.initialize():
|
||||
print(" 情感分析模型初始化失败,将直接透传原始文本")
|
||||
logger.info(" 情感分析模型初始化失败,将直接透传原始文本")
|
||||
elif self.sentiment_analyzer.is_disabled:
|
||||
print(" 情感分析功能已禁用,直接透传原始文本")
|
||||
logger.info(" 情感分析功能已禁用,直接透传原始文本")
|
||||
|
||||
# 将查询结果转换为字典格式
|
||||
results_dict = []
|
||||
@@ -310,7 +304,7 @@ class DeepSearchAgent:
|
||||
return sentiment_analysis.get("sentiment_analysis")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ 情感分析过程中发生错误: {str(e)}")
|
||||
logger.exception(f" ❌ 情感分析过程中发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
def analyze_sentiment_only(self, texts: Union[str, List[str]]) -> Dict[str, Any]:
|
||||
@@ -323,16 +317,16 @@ class DeepSearchAgent:
|
||||
Returns:
|
||||
情感分析结果
|
||||
"""
|
||||
print(f" → 执行独立情感分析")
|
||||
logger.info(f" → 执行独立情感分析")
|
||||
|
||||
try:
|
||||
# 初始化情感分析器(如果尚未初始化且未被禁用)
|
||||
if not self.sentiment_analyzer.is_initialized and not self.sentiment_analyzer.is_disabled:
|
||||
print(" 初始化情感分析模型...")
|
||||
logger.info(" 初始化情感分析模型...")
|
||||
if not self.sentiment_analyzer.initialize():
|
||||
print(" 情感分析模型初始化失败,将直接透传原始文本")
|
||||
logger.info(" 情感分析模型初始化失败,将直接透传原始文本")
|
||||
elif self.sentiment_analyzer.is_disabled:
|
||||
print(" 情感分析功能已禁用,直接透传原始文本")
|
||||
logger.warning(" 情感分析功能已禁用,直接透传原始文本")
|
||||
|
||||
# 执行分析
|
||||
if isinstance(texts, str):
|
||||
@@ -368,7 +362,7 @@ class DeepSearchAgent:
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ 情感分析过程中发生错误: {str(e)}")
|
||||
logger.exception(f" ❌ 情感分析过程中发生错误: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
@@ -386,9 +380,9 @@ class DeepSearchAgent:
|
||||
Returns:
|
||||
最终报告内容
|
||||
"""
|
||||
print(f"\n{'='*60}")
|
||||
print(f"开始深度研究: {query}")
|
||||
print(f"{'='*60}")
|
||||
logger.info(f"\n{'='*60}")
|
||||
logger.info(f"开始深度研究: {query}")
|
||||
logger.info(f"{'='*60}")
|
||||
|
||||
try:
|
||||
# Step 1: 生成报告结构
|
||||
@@ -403,20 +397,18 @@ class DeepSearchAgent:
|
||||
# Step 4: 保存报告
|
||||
if save_report:
|
||||
self._save_report(final_report)
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print("深度研究完成!")
|
||||
print(f"{'='*60}")
|
||||
|
||||
logger.info("深度研究完成!")
|
||||
|
||||
return final_report
|
||||
|
||||
except Exception as e:
|
||||
print(f"研究过程中发生错误: {str(e)}")
|
||||
logger.exception(f"研究过程中发生错误: {str(e)}")
|
||||
raise e
|
||||
|
||||
def _generate_report_structure(self, query: str):
|
||||
"""生成报告结构"""
|
||||
print(f"\n[步骤 1] 生成报告结构...")
|
||||
logger.info(f"\n[步骤 1] 生成报告结构...")
|
||||
|
||||
# 创建报告结构节点
|
||||
report_structure_node = ReportStructureNode(self.llm_client, query)
|
||||
@@ -424,17 +416,18 @@ class DeepSearchAgent:
|
||||
# 生成结构并更新状态
|
||||
self.state = report_structure_node.mutate_state(state=self.state)
|
||||
|
||||
print(f"报告结构已生成,共 {len(self.state.paragraphs)} 个段落:")
|
||||
_message = f"报告结构已生成,共 {len(self.state.paragraphs)} 个段落:"
|
||||
for i, paragraph in enumerate(self.state.paragraphs, 1):
|
||||
print(f" {i}. {paragraph.title}")
|
||||
_message += f"\n {i}. {paragraph.title}"
|
||||
logger.info(_message)
|
||||
|
||||
def _process_paragraphs(self):
|
||||
"""处理所有段落"""
|
||||
total_paragraphs = len(self.state.paragraphs)
|
||||
|
||||
for i in range(total_paragraphs):
|
||||
print(f"\n[步骤 2.{i+1}] 处理段落: {self.state.paragraphs[i].title}")
|
||||
print("-" * 50)
|
||||
logger.info(f"\n[步骤 2.{i+1}] 处理段落: {self.state.paragraphs[i].title}")
|
||||
logger.info("-" * 50)
|
||||
|
||||
# 初始搜索和总结
|
||||
self._initial_search_and_summary(i)
|
||||
@@ -446,7 +439,7 @@ class DeepSearchAgent:
|
||||
self.state.paragraphs[i].research.mark_completed()
|
||||
|
||||
progress = (i + 1) / total_paragraphs * 100
|
||||
print(f"段落处理完成 ({progress:.1f}%)")
|
||||
logger.info(f"段落处理完成 ({progress:.1f}%)")
|
||||
|
||||
def _initial_search_and_summary(self, paragraph_index: int):
|
||||
"""执行初始搜索和总结"""
|
||||
@@ -459,18 +452,18 @@ class DeepSearchAgent:
|
||||
}
|
||||
|
||||
# 生成搜索查询和工具选择
|
||||
print(" - 生成搜索查询...")
|
||||
logger.info(" - 生成搜索查询...")
|
||||
search_output = self.first_search_node.run(search_input)
|
||||
search_query = search_output["search_query"]
|
||||
search_tool = search_output.get("search_tool", "search_topic_globally") # 默认工具
|
||||
reasoning = search_output["reasoning"]
|
||||
|
||||
print(f" - 搜索查询: {search_query}")
|
||||
print(f" - 选择的工具: {search_tool}")
|
||||
print(f" - 推理: {reasoning}")
|
||||
logger.info(f" - 搜索查询: {search_query}")
|
||||
logger.info(f" - 选择的工具: {search_tool}")
|
||||
logger.info(f" - 推理: {reasoning}")
|
||||
|
||||
# 执行搜索
|
||||
print(" - 执行数据库查询...")
|
||||
logger.info(" - 执行数据库查询...")
|
||||
|
||||
# 处理特殊参数
|
||||
search_kwargs = {}
|
||||
@@ -485,13 +478,13 @@ class DeepSearchAgent:
|
||||
if self._validate_date_format(start_date) and self._validate_date_format(end_date):
|
||||
search_kwargs["start_date"] = start_date
|
||||
search_kwargs["end_date"] = end_date
|
||||
print(f" - 时间范围: {start_date} 到 {end_date}")
|
||||
logger.info(f" - 时间范围: {start_date} 到 {end_date}")
|
||||
else:
|
||||
print(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索")
|
||||
print(f" 提供的日期: start_date={start_date}, end_date={end_date}")
|
||||
logger.info(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索")
|
||||
logger.info(f" 提供的日期: start_date={start_date}, end_date={end_date}")
|
||||
search_tool = "search_topic_globally"
|
||||
elif search_tool == "search_topic_by_date":
|
||||
print(f" search_topic_by_date工具缺少时间参数,改用全局搜索")
|
||||
logger.info(f" search_topic_by_date工具缺少时间参数,改用全局搜索")
|
||||
search_tool = "search_topic_globally"
|
||||
|
||||
# 处理需要平台参数的工具
|
||||
@@ -499,28 +492,28 @@ class DeepSearchAgent:
|
||||
platform = search_output.get("platform")
|
||||
if platform:
|
||||
search_kwargs["platform"] = platform
|
||||
print(f" - 指定平台: {platform}")
|
||||
logger.info(f" - 指定平台: {platform}")
|
||||
else:
|
||||
print(f" search_topic_on_platform工具缺少平台参数,改用全局搜索")
|
||||
logger.warning(f" search_topic_on_platform工具缺少平台参数,改用全局搜索")
|
||||
search_tool = "search_topic_globally"
|
||||
|
||||
# 处理限制参数,使用配置文件中的默认值而不是agent提供的参数
|
||||
if search_tool == "search_hot_content":
|
||||
time_period = search_output.get("time_period", "week")
|
||||
limit = self.config.default_search_hot_content_limit
|
||||
limit = self.config.DEFAULT_SEARCH_HOT_CONTENT_LIMIT
|
||||
search_kwargs["time_period"] = time_period
|
||||
search_kwargs["limit"] = limit
|
||||
elif search_tool in ["search_topic_globally", "search_topic_by_date"]:
|
||||
if search_tool == "search_topic_globally":
|
||||
limit_per_table = self.config.default_search_topic_globally_limit_per_table
|
||||
limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE
|
||||
else: # search_topic_by_date
|
||||
limit_per_table = self.config.default_search_topic_by_date_limit_per_table
|
||||
limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE
|
||||
search_kwargs["limit_per_table"] = limit_per_table
|
||||
elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]:
|
||||
if search_tool == "get_comments_for_topic":
|
||||
limit = self.config.default_get_comments_for_topic_limit
|
||||
limit = self.config.DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT
|
||||
else: # search_topic_on_platform
|
||||
limit = self.config.default_search_topic_on_platform_limit
|
||||
limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT
|
||||
search_kwargs["limit"] = limit
|
||||
|
||||
search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs)
|
||||
@@ -529,8 +522,8 @@ class DeepSearchAgent:
|
||||
search_results = []
|
||||
if search_response and search_response.results:
|
||||
# 使用配置文件控制传递给LLM的结果数量,0表示不限制
|
||||
if self.config.max_search_results_for_llm > 0:
|
||||
max_results = min(len(search_response.results), self.config.max_search_results_for_llm)
|
||||
if self.config.MAX_SEARCH_RESULTS_FOR_LLM > 0:
|
||||
max_results = min(len(search_response.results), self.config.MAX_SEARCH_RESULTS_FOR_LLM)
|
||||
else:
|
||||
max_results = len(search_response.results) # 不限制,传递所有结果
|
||||
for result in search_response.results[:max_results]:
|
||||
@@ -548,24 +541,25 @@ class DeepSearchAgent:
|
||||
})
|
||||
|
||||
if search_results:
|
||||
print(f" - 找到 {len(search_results)} 个搜索结果")
|
||||
_message = f" - 找到 {len(search_results)} 个搜索结果"
|
||||
for j, result in enumerate(search_results, 1):
|
||||
date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else ""
|
||||
print(f" {j}. {result['title'][:50]}...{date_info}")
|
||||
_message += f"\n {j}. {result['title'][:50]}...{date_info}"
|
||||
logger.info(_message)
|
||||
else:
|
||||
print(" - 未找到搜索结果")
|
||||
logger.info(" - 未找到搜索结果")
|
||||
|
||||
# 更新状态中的搜索历史
|
||||
paragraph.research.add_search_results(search_query, search_results)
|
||||
|
||||
# 生成初始总结
|
||||
print(" - 生成初始总结...")
|
||||
logger.info(" - 生成初始总结...")
|
||||
summary_input = {
|
||||
"title": paragraph.title,
|
||||
"content": paragraph.content,
|
||||
"search_query": search_query,
|
||||
"search_results": format_search_results_for_prompt(
|
||||
search_results, self.config.max_content_length
|
||||
search_results, self.config.MAX_CONTENT_LENGTH
|
||||
)
|
||||
}
|
||||
|
||||
@@ -574,14 +568,14 @@ class DeepSearchAgent:
|
||||
summary_input, self.state, paragraph_index
|
||||
)
|
||||
|
||||
print(" - 初始总结完成")
|
||||
logger.info(" - 初始总结完成")
|
||||
|
||||
def _reflection_loop(self, paragraph_index: int):
|
||||
"""执行反思循环"""
|
||||
paragraph = self.state.paragraphs[paragraph_index]
|
||||
|
||||
for reflection_i in range(self.config.max_reflections):
|
||||
print(f" - 反思 {reflection_i + 1}/{self.config.max_reflections}...")
|
||||
for reflection_i in range(self.config.MAX_REFLECTIONS):
|
||||
logger.info(f" - 反思 {reflection_i + 1}/{self.config.MAX_REFLECTIONS}...")
|
||||
|
||||
# 准备反思输入
|
||||
reflection_input = {
|
||||
@@ -596,9 +590,9 @@ class DeepSearchAgent:
|
||||
search_tool = reflection_output.get("search_tool", "search_topic_globally") # 默认工具
|
||||
reasoning = reflection_output["reasoning"]
|
||||
|
||||
print(f" 反思查询: {search_query}")
|
||||
print(f" 选择的工具: {search_tool}")
|
||||
print(f" 反思推理: {reasoning}")
|
||||
logger.info(f" 反思查询: {search_query}")
|
||||
logger.info(f" 选择的工具: {search_tool}")
|
||||
logger.info(f" 反思推理: {reasoning}")
|
||||
|
||||
# 执行反思搜索
|
||||
# 处理特殊参数
|
||||
@@ -614,13 +608,13 @@ class DeepSearchAgent:
|
||||
if self._validate_date_format(start_date) and self._validate_date_format(end_date):
|
||||
search_kwargs["start_date"] = start_date
|
||||
search_kwargs["end_date"] = end_date
|
||||
print(f" 时间范围: {start_date} 到 {end_date}")
|
||||
logger.info(f" 时间范围: {start_date} 到 {end_date}")
|
||||
else:
|
||||
print(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索")
|
||||
print(f" 提供的日期: start_date={start_date}, end_date={end_date}")
|
||||
logger.info(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索")
|
||||
logger.info(f" 提供的日期: start_date={start_date}, end_date={end_date}")
|
||||
search_tool = "search_topic_globally"
|
||||
elif search_tool == "search_topic_by_date":
|
||||
print(f" search_topic_by_date工具缺少时间参数,改用全局搜索")
|
||||
logger.warning(f" search_topic_by_date工具缺少时间参数,改用全局搜索")
|
||||
search_tool = "search_topic_globally"
|
||||
|
||||
# 处理需要平台参数的工具
|
||||
@@ -628,31 +622,31 @@ class DeepSearchAgent:
|
||||
platform = reflection_output.get("platform")
|
||||
if platform:
|
||||
search_kwargs["platform"] = platform
|
||||
print(f" 指定平台: {platform}")
|
||||
logger.info(f" 指定平台: {platform}")
|
||||
else:
|
||||
print(f" search_topic_on_platform工具缺少平台参数,改用全局搜索")
|
||||
logger.warning(f" search_topic_on_platform工具缺少平台参数,改用全局搜索")
|
||||
search_tool = "search_topic_globally"
|
||||
|
||||
# 处理限制参数
|
||||
if search_tool == "search_hot_content":
|
||||
time_period = reflection_output.get("time_period", "week")
|
||||
# 使用配置文件中的默认值,不允许agent控制limit参数
|
||||
limit = self.config.default_search_hot_content_limit
|
||||
limit = self.config.DEFAULT_SEARCH_HOT_CONTENT_LIMIT
|
||||
search_kwargs["time_period"] = time_period
|
||||
search_kwargs["limit"] = limit
|
||||
elif search_tool in ["search_topic_globally", "search_topic_by_date"]:
|
||||
# 使用配置文件中的默认值,不允许agent控制limit_per_table参数
|
||||
if search_tool == "search_topic_globally":
|
||||
limit_per_table = self.config.default_search_topic_globally_limit_per_table
|
||||
limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE
|
||||
else: # search_topic_by_date
|
||||
limit_per_table = self.config.default_search_topic_by_date_limit_per_table
|
||||
limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE
|
||||
search_kwargs["limit_per_table"] = limit_per_table
|
||||
elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]:
|
||||
# 使用配置文件中的默认值,不允许agent控制limit参数
|
||||
if search_tool == "get_comments_for_topic":
|
||||
limit = self.config.default_get_comments_for_topic_limit
|
||||
limit = self.config.DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT
|
||||
else: # search_topic_on_platform
|
||||
limit = self.config.default_search_topic_on_platform_limit
|
||||
limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT
|
||||
search_kwargs["limit"] = limit
|
||||
|
||||
search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs)
|
||||
@@ -661,8 +655,8 @@ class DeepSearchAgent:
|
||||
search_results = []
|
||||
if search_response and search_response.results:
|
||||
# 使用配置文件控制传递给LLM的结果数量,0表示不限制
|
||||
if self.config.max_search_results_for_llm > 0:
|
||||
max_results = min(len(search_response.results), self.config.max_search_results_for_llm)
|
||||
if self.config.MAX_SEARCH_RESULTS_FOR_LLM > 0:
|
||||
max_results = min(len(search_response.results), self.config.MAX_SEARCH_RESULTS_FOR_LLM)
|
||||
else:
|
||||
max_results = len(search_response.results) # 不限制,传递所有结果
|
||||
for result in search_response.results[:max_results]:
|
||||
@@ -680,12 +674,13 @@ class DeepSearchAgent:
|
||||
})
|
||||
|
||||
if search_results:
|
||||
print(f" 找到 {len(search_results)} 个反思搜索结果")
|
||||
_message = f" 找到 {len(search_results)} 个反思搜索结果"
|
||||
for j, result in enumerate(search_results, 1):
|
||||
date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else ""
|
||||
print(f" {j}. {result['title'][:50]}...{date_info}")
|
||||
_message += f"\n {j}. {result['title'][:50]}...{date_info}"
|
||||
logger.info(_message)
|
||||
else:
|
||||
print(" 未找到反思搜索结果")
|
||||
logger.info(" 未找到反思搜索结果")
|
||||
|
||||
# 更新搜索历史
|
||||
paragraph.research.add_search_results(search_query, search_results)
|
||||
@@ -696,7 +691,7 @@ class DeepSearchAgent:
|
||||
"content": paragraph.content,
|
||||
"search_query": search_query,
|
||||
"search_results": format_search_results_for_prompt(
|
||||
search_results, self.config.max_content_length
|
||||
search_results, self.config.MAX_CONTENT_LENGTH
|
||||
),
|
||||
"paragraph_latest_state": paragraph.research.latest_summary
|
||||
}
|
||||
@@ -706,11 +701,11 @@ class DeepSearchAgent:
|
||||
reflection_summary_input, self.state, paragraph_index
|
||||
)
|
||||
|
||||
print(f" 反思 {reflection_i + 1} 完成")
|
||||
logger.info(f" 反思 {reflection_i + 1} 完成")
|
||||
|
||||
def _generate_final_report(self) -> str:
|
||||
"""生成最终报告"""
|
||||
print(f"\n[步骤 3] 生成最终报告...")
|
||||
logger.info(f"\n[步骤 3] 生成最终报告...")
|
||||
|
||||
# 准备报告数据
|
||||
report_data = []
|
||||
@@ -724,7 +719,7 @@ class DeepSearchAgent:
|
||||
try:
|
||||
final_report = self.report_formatting_node.run(report_data)
|
||||
except Exception as e:
|
||||
print(f"LLM格式化失败,使用备用方法: {str(e)}")
|
||||
logger.exception(f"LLM格式化失败,使用备用方法: {str(e)}")
|
||||
final_report = self.report_formatting_node.format_report_manually(
|
||||
report_data, self.state.report_title
|
||||
)
|
||||
@@ -733,7 +728,7 @@ class DeepSearchAgent:
|
||||
self.state.final_report = final_report
|
||||
self.state.mark_completed()
|
||||
|
||||
print("最终报告生成完成")
|
||||
logger.info("最终报告生成完成")
|
||||
return final_report
|
||||
|
||||
def _save_report(self, report_content: str):
|
||||
@@ -744,20 +739,20 @@ class DeepSearchAgent:
|
||||
query_safe = query_safe.replace(' ', '_')[:30]
|
||||
|
||||
filename = f"deep_search_report_{query_safe}_{timestamp}.md"
|
||||
filepath = os.path.join(self.config.output_dir, filename)
|
||||
filepath = os.path.join(self.config.OUTPUT_DIR, filename)
|
||||
|
||||
# 保存报告
|
||||
with open(filepath, 'w', encoding='utf-8') as f:
|
||||
f.write(report_content)
|
||||
|
||||
print(f"报告已保存到: {filepath}")
|
||||
logger.info(f"报告已保存到: {filepath}")
|
||||
|
||||
# 保存状态(如果配置允许)
|
||||
if self.config.save_intermediate_states:
|
||||
if self.config.SAVE_INTERMEDIATE_STATES:
|
||||
state_filename = f"state_{query_safe}_{timestamp}.json"
|
||||
state_filepath = os.path.join(self.config.output_dir, state_filename)
|
||||
state_filepath = os.path.join(self.config.OUTPUT_DIR, state_filename)
|
||||
self.state.save_to_file(state_filepath)
|
||||
print(f"状态已保存到: {state_filepath}")
|
||||
logger.info(f"状态已保存到: {state_filepath}")
|
||||
|
||||
def get_progress_summary(self) -> Dict[str, Any]:
|
||||
"""获取进度摘要"""
|
||||
@@ -766,12 +761,12 @@ class DeepSearchAgent:
|
||||
def load_state(self, filepath: str):
|
||||
"""从文件加载状态"""
|
||||
self.state = State.load_from_file(filepath)
|
||||
print(f"状态已从 {filepath} 加载")
|
||||
logger.info(f"状态已从 {filepath} 加载")
|
||||
|
||||
def save_state(self, filepath: str):
|
||||
"""保存状态到文件"""
|
||||
self.state.save_to_file(filepath)
|
||||
print(f"状态已保存到 {filepath}")
|
||||
logger.info(f"状态已保存到 {filepath}")
|
||||
|
||||
|
||||
def create_agent(config_file: Optional[str] = None) -> DeepSearchAgent:
|
||||
@@ -784,5 +779,5 @@ def create_agent(config_file: Optional[str] = None) -> DeepSearchAgent:
|
||||
Returns:
|
||||
DeepSearchAgent实例
|
||||
"""
|
||||
config = load_config(config_file)
|
||||
config = settings
|
||||
return DeepSearchAgent(config)
|
||||
|
||||
@@ -31,9 +31,9 @@ class LLMClient:
|
||||
|
||||
def __init__(self, api_key: str, model_name: str, base_url: Optional[str] = None):
|
||||
if not api_key:
|
||||
raise ValueError("Insight Engine LLM API key is required.")
|
||||
raise ValueError("Insight Engine INSIGHT_ENGINE_API_KEY is required.")
|
||||
if not model_name:
|
||||
raise ValueError("Insight Engine model name is required.")
|
||||
raise ValueError("Insight Engine INSIGHT_ENGINE_MODEL_NAME is required.")
|
||||
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional
|
||||
from loguru import logger
|
||||
from ..llms.base import LLMClient
|
||||
from ..state.state import State
|
||||
|
||||
@@ -63,11 +64,15 @@ class BaseNode(ABC):
|
||||
|
||||
def log_info(self, message: str):
|
||||
"""记录信息日志"""
|
||||
print(f"[{self.node_name}] {message}")
|
||||
logger.info(f"[{self.node_name}] {message}")
|
||||
|
||||
def log_warning(self, message: str):
|
||||
"""记录警告日志"""
|
||||
logger.warning(f"[{self.node_name}] 警告: {message}")
|
||||
|
||||
def log_error(self, message: str):
|
||||
"""记录错误日志"""
|
||||
print(f"[{self.node_name}] 错误: {message}")
|
||||
logger.error(f"[{self.node_name}] 错误: {message}")
|
||||
|
||||
|
||||
class StateMutationNode(BaseNode):
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
import json
|
||||
from typing import List, Dict, Any
|
||||
from loguru import logger
|
||||
|
||||
from .base_node import BaseNode
|
||||
from ..prompts import SYSTEM_PROMPT_REPORT_FORMATTING
|
||||
@@ -14,6 +15,8 @@ from ..utils.text_processing import (
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
class ReportFormattingNode(BaseNode):
|
||||
"""格式化最终报告的节点"""
|
||||
|
||||
@@ -65,19 +68,22 @@ class ReportFormattingNode(BaseNode):
|
||||
else:
|
||||
message = json.dumps(input_data, ensure_ascii=False)
|
||||
|
||||
self.log_info("正在格式化最终报告")
|
||||
logger.info("正在格式化最终报告")
|
||||
|
||||
# 调用LLM
|
||||
response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_FORMATTING, message)
|
||||
response = self.llm_client.invoke(
|
||||
SYSTEM_PROMPT_REPORT_FORMATTING,
|
||||
message,
|
||||
)
|
||||
|
||||
# 处理响应
|
||||
processed_response = self.process_output(response)
|
||||
|
||||
self.log_info("成功生成格式化报告")
|
||||
logger.info("成功生成格式化报告")
|
||||
return processed_response
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"报告格式化失败: {str(e)}")
|
||||
logger.exception(f"报告格式化失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
def process_output(self, output: str) -> str:
|
||||
@@ -106,7 +112,7 @@ class ReportFormattingNode(BaseNode):
|
||||
return cleaned_output.strip()
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"处理输出失败: {str(e)}")
|
||||
logger.exception(f"处理输出失败: {str(e)}")
|
||||
return "# 报告处理失败\n\n报告格式化过程中发生错误。"
|
||||
|
||||
def format_report_manually(self, paragraphs_data: List[Dict[str, str]],
|
||||
@@ -122,7 +128,7 @@ class ReportFormattingNode(BaseNode):
|
||||
格式化的Markdown报告
|
||||
"""
|
||||
try:
|
||||
self.log_info("使用手动格式化方法")
|
||||
logger.info("使用手动格式化方法")
|
||||
|
||||
# 构建报告
|
||||
report_lines = [
|
||||
@@ -160,5 +166,5 @@ class ReportFormattingNode(BaseNode):
|
||||
return "\n".join(report_lines)
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"手动格式化失败: {str(e)}")
|
||||
logger.exception(f"手动格式化失败: {str(e)}")
|
||||
return "# 报告生成失败\n\n无法完成报告格式化。"
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
import json
|
||||
from typing import Dict, Any, List
|
||||
from json.decoder import JSONDecodeError
|
||||
from loguru import logger
|
||||
|
||||
from .base_node import StateMutationNode
|
||||
from ..state.state import State
|
||||
@@ -48,7 +49,7 @@ class ReportStructureNode(StateMutationNode):
|
||||
报告结构列表
|
||||
"""
|
||||
try:
|
||||
self.log_info(f"正在为查询生成报告结构: {self.query}")
|
||||
logger.info(f"正在为查询生成报告结构: {self.query}")
|
||||
|
||||
# 调用LLM
|
||||
response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query)
|
||||
@@ -56,11 +57,11 @@ class ReportStructureNode(StateMutationNode):
|
||||
# 处理响应
|
||||
processed_response = self.process_output(response)
|
||||
|
||||
self.log_info(f"成功生成 {len(processed_response)} 个段落结构")
|
||||
logger.info(f"成功生成 {len(processed_response)} 个段落结构")
|
||||
return processed_response
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"生成报告结构失败: {str(e)}")
|
||||
logger.exception(f"生成报告结构失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
def process_output(self, output: str) -> List[Dict[str, str]]:
|
||||
@@ -79,54 +80,54 @@ class ReportStructureNode(StateMutationNode):
|
||||
cleaned_output = clean_json_tags(cleaned_output)
|
||||
|
||||
# 记录清理后的输出用于调试
|
||||
self.log_info(f"清理后的输出: {cleaned_output}")
|
||||
logger.info(f"清理后的输出: {cleaned_output}")
|
||||
|
||||
# 解析JSON
|
||||
try:
|
||||
report_structure = json.loads(cleaned_output)
|
||||
self.log_info("JSON解析成功")
|
||||
logger.info("JSON解析成功")
|
||||
except JSONDecodeError as e:
|
||||
self.log_info(f"JSON解析失败: {str(e)}")
|
||||
logger.exception(f"JSON解析失败: {str(e)}")
|
||||
# 使用更强大的提取方法
|
||||
report_structure = extract_clean_response(cleaned_output)
|
||||
if "error" in report_structure:
|
||||
self.log_error("JSON解析失败,尝试修复...")
|
||||
logger.exception("JSON解析失败,尝试修复...")
|
||||
# 尝试修复JSON
|
||||
fixed_json = fix_incomplete_json(cleaned_output)
|
||||
if fixed_json:
|
||||
try:
|
||||
report_structure = json.loads(fixed_json)
|
||||
self.log_info("JSON修复成功")
|
||||
logger.info("JSON修复成功")
|
||||
except JSONDecodeError:
|
||||
self.log_error("JSON修复失败")
|
||||
logger.exception("JSON修复失败")
|
||||
# 返回默认结构
|
||||
return self._generate_default_structure()
|
||||
else:
|
||||
self.log_error("无法修复JSON,使用默认结构")
|
||||
logger.exception("无法修复JSON,使用默认结构")
|
||||
return self._generate_default_structure()
|
||||
|
||||
# 验证结构
|
||||
if not isinstance(report_structure, list):
|
||||
self.log_info("报告结构不是列表,尝试转换...")
|
||||
logger.info("报告结构不是列表,尝试转换...")
|
||||
if isinstance(report_structure, dict):
|
||||
# 如果是单个对象,包装成列表
|
||||
report_structure = [report_structure]
|
||||
else:
|
||||
self.log_error("报告结构格式无效,使用默认结构")
|
||||
logger.exception("报告结构格式无效,使用默认结构")
|
||||
return self._generate_default_structure()
|
||||
|
||||
# 验证每个段落
|
||||
validated_structure = []
|
||||
for i, paragraph in enumerate(report_structure):
|
||||
if not isinstance(paragraph, dict):
|
||||
self.log_warning(f"段落 {i+1} 不是字典格式,跳过")
|
||||
logger.warning(f"段落 {i+1} 不是字典格式,跳过")
|
||||
continue
|
||||
|
||||
title = paragraph.get("title", f"段落 {i+1}")
|
||||
content = paragraph.get("content", "")
|
||||
|
||||
if not title or not content:
|
||||
self.log_warning(f"段落 {i+1} 缺少标题或内容,跳过")
|
||||
logger.warning(f"段落 {i+1} 缺少标题或内容,跳过")
|
||||
continue
|
||||
|
||||
validated_structure.append({
|
||||
@@ -135,14 +136,14 @@ class ReportStructureNode(StateMutationNode):
|
||||
})
|
||||
|
||||
if not validated_structure:
|
||||
self.log_warning("没有有效的段落结构,使用默认结构")
|
||||
logger.warning("没有有效的段落结构,使用默认结构")
|
||||
return self._generate_default_structure()
|
||||
|
||||
self.log_info(f"成功验证 {len(validated_structure)} 个段落结构")
|
||||
logger.info(f"成功验证 {len(validated_structure)} 个段落结构")
|
||||
return validated_structure
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"处理输出失败: {str(e)}")
|
||||
logger.exception(f"处理输出失败: {str(e)}")
|
||||
return self._generate_default_structure()
|
||||
|
||||
def _generate_default_structure(self) -> List[Dict[str, str]]:
|
||||
@@ -152,7 +153,7 @@ class ReportStructureNode(StateMutationNode):
|
||||
Returns:
|
||||
默认的报告结构列表
|
||||
"""
|
||||
self.log_info("生成默认报告结构")
|
||||
logger.info("生成默认报告结构")
|
||||
return [
|
||||
{
|
||||
"title": "研究概述",
|
||||
@@ -195,9 +196,9 @@ class ReportStructureNode(StateMutationNode):
|
||||
content=paragraph_data["content"]
|
||||
)
|
||||
|
||||
self.log_info(f"已将 {len(report_structure)} 个段落添加到状态中")
|
||||
logger.info(f"已将 {len(report_structure)} 个段落添加到状态中")
|
||||
return state
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"状态更新失败: {str(e)}")
|
||||
logger.exception(f"状态更新失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
import json
|
||||
from typing import Dict, Any
|
||||
from json.decoder import JSONDecodeError
|
||||
from loguru import logger
|
||||
|
||||
from .base_node import BaseNode
|
||||
from ..prompts import SYSTEM_PROMPT_FIRST_SEARCH, SYSTEM_PROMPT_REFLECTION
|
||||
@@ -62,7 +63,7 @@ class FirstSearchNode(BaseNode):
|
||||
else:
|
||||
message = json.dumps(input_data, ensure_ascii=False)
|
||||
|
||||
self.log_info("正在生成首次搜索查询")
|
||||
logger.info("正在生成首次搜索查询")
|
||||
|
||||
# 调用LLM
|
||||
response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SEARCH, message)
|
||||
@@ -70,11 +71,11 @@ class FirstSearchNode(BaseNode):
|
||||
# 处理响应
|
||||
processed_response = self.process_output(response)
|
||||
|
||||
self.log_info(f"生成搜索查询: {processed_response.get('search_query', 'N/A')}")
|
||||
logger.info(f"生成搜索查询: {processed_response.get('search_query', 'N/A')}")
|
||||
return processed_response
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"生成首次搜索查询失败: {str(e)}")
|
||||
logger.exception(f"生成首次搜索查询失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
def process_output(self, output: str) -> Dict[str, str]:
|
||||
@@ -93,30 +94,30 @@ class FirstSearchNode(BaseNode):
|
||||
cleaned_output = clean_json_tags(cleaned_output)
|
||||
|
||||
# 记录清理后的输出用于调试
|
||||
self.log_info(f"清理后的输出: {cleaned_output}")
|
||||
logger.info(f"清理后的输出: {cleaned_output}")
|
||||
|
||||
# 解析JSON
|
||||
try:
|
||||
result = json.loads(cleaned_output)
|
||||
self.log_info("JSON解析成功")
|
||||
logger.info("JSON解析成功")
|
||||
except JSONDecodeError as e:
|
||||
self.log_info(f"JSON解析失败: {str(e)}")
|
||||
logger.exception(f"JSON解析失败: {str(e)}")
|
||||
# 使用更强大的提取方法
|
||||
result = extract_clean_response(cleaned_output)
|
||||
if "error" in result:
|
||||
self.log_error("JSON解析失败,尝试修复...")
|
||||
logger.error("JSON解析失败,尝试修复...")
|
||||
# 尝试修复JSON
|
||||
fixed_json = fix_incomplete_json(cleaned_output)
|
||||
if fixed_json:
|
||||
try:
|
||||
result = json.loads(fixed_json)
|
||||
self.log_info("JSON修复成功")
|
||||
logger.info("JSON修复成功")
|
||||
except JSONDecodeError:
|
||||
self.log_error("JSON修复失败")
|
||||
logger.error("JSON修复失败")
|
||||
# 返回默认查询
|
||||
return self._get_default_search_query()
|
||||
else:
|
||||
self.log_error("无法修复JSON,使用默认查询")
|
||||
logger.error("无法修复JSON,使用默认查询")
|
||||
return self._get_default_search_query()
|
||||
|
||||
# 验证和清理结果
|
||||
@@ -124,7 +125,7 @@ class FirstSearchNode(BaseNode):
|
||||
reasoning = result.get("reasoning", "")
|
||||
|
||||
if not search_query:
|
||||
self.log_warning("未找到搜索查询,使用默认查询")
|
||||
logger.warning("未找到搜索查询,使用默认查询")
|
||||
return self._get_default_search_query()
|
||||
|
||||
return {
|
||||
@@ -197,7 +198,7 @@ class ReflectionNode(BaseNode):
|
||||
else:
|
||||
message = json.dumps(input_data, ensure_ascii=False)
|
||||
|
||||
self.log_info("正在进行反思并生成新搜索查询")
|
||||
logger.info("正在进行反思并生成新搜索查询")
|
||||
|
||||
# 调用LLM
|
||||
response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION, message)
|
||||
@@ -205,11 +206,11 @@ class ReflectionNode(BaseNode):
|
||||
# 处理响应
|
||||
processed_response = self.process_output(response)
|
||||
|
||||
self.log_info(f"反思生成搜索查询: {processed_response.get('search_query', 'N/A')}")
|
||||
logger.info(f"反思生成搜索查询: {processed_response.get('search_query', 'N/A')}")
|
||||
return processed_response
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"反思生成搜索查询失败: {str(e)}")
|
||||
logger.exception(f"反思生成搜索查询失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
def process_output(self, output: str) -> Dict[str, str]:
|
||||
@@ -228,30 +229,30 @@ class ReflectionNode(BaseNode):
|
||||
cleaned_output = clean_json_tags(cleaned_output)
|
||||
|
||||
# 记录清理后的输出用于调试
|
||||
self.log_info(f"清理后的输出: {cleaned_output}")
|
||||
logger.info(f"清理后的输出: {cleaned_output}")
|
||||
|
||||
# 解析JSON
|
||||
try:
|
||||
result = json.loads(cleaned_output)
|
||||
self.log_info("JSON解析成功")
|
||||
logger.info("JSON解析成功")
|
||||
except JSONDecodeError as e:
|
||||
self.log_info(f"JSON解析失败: {str(e)}")
|
||||
logger.exception(f"JSON解析失败: {str(e)}")
|
||||
# 使用更强大的提取方法
|
||||
result = extract_clean_response(cleaned_output)
|
||||
if "error" in result:
|
||||
self.log_error("JSON解析失败,尝试修复...")
|
||||
logger.error("JSON解析失败,尝试修复...")
|
||||
# 尝试修复JSON
|
||||
fixed_json = fix_incomplete_json(cleaned_output)
|
||||
if fixed_json:
|
||||
try:
|
||||
result = json.loads(fixed_json)
|
||||
self.log_info("JSON修复成功")
|
||||
logger.info("JSON修复成功")
|
||||
except JSONDecodeError:
|
||||
self.log_error("JSON修复失败")
|
||||
logger.error("JSON修复失败")
|
||||
# 返回默认查询
|
||||
return self._get_default_reflection_query()
|
||||
else:
|
||||
self.log_error("无法修复JSON,使用默认查询")
|
||||
logger.error("无法修复JSON,使用默认查询")
|
||||
return self._get_default_reflection_query()
|
||||
|
||||
# 验证和清理结果
|
||||
@@ -259,7 +260,7 @@ class ReflectionNode(BaseNode):
|
||||
reasoning = result.get("reasoning", "")
|
||||
|
||||
if not search_query:
|
||||
self.log_warning("未找到搜索查询,使用默认查询")
|
||||
logger.warning("未找到搜索查询,使用默认查询")
|
||||
return self._get_default_reflection_query()
|
||||
|
||||
return {
|
||||
@@ -268,7 +269,7 @@ class ReflectionNode(BaseNode):
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"处理输出失败: {str(e)}")
|
||||
logger.exception(f"处理输出失败: {str(e)}")
|
||||
# 返回默认查询
|
||||
return self._get_default_reflection_query()
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
import json
|
||||
from typing import Dict, Any, List
|
||||
from json.decoder import JSONDecodeError
|
||||
from loguru import logger
|
||||
|
||||
from .base_node import StateMutationNode
|
||||
from ..state.state import State
|
||||
@@ -27,7 +28,7 @@ try:
|
||||
FORUM_READER_AVAILABLE = True
|
||||
except ImportError:
|
||||
FORUM_READER_AVAILABLE = False
|
||||
print("警告: 无法导入forum_reader模块,将跳过HOST发言读取功能")
|
||||
logger.warning("无法导入forum_reader模块,将跳过HOST发言读取功能")
|
||||
|
||||
|
||||
class FirstSummaryNode(StateMutationNode):
|
||||
@@ -84,9 +85,9 @@ class FirstSummaryNode(StateMutationNode):
|
||||
if host_speech:
|
||||
# 将HOST发言添加到输入数据中
|
||||
data['host_speech'] = host_speech
|
||||
self.log_info(f"已读取HOST发言,长度: {len(host_speech)}字符")
|
||||
logger.info(f"已读取HOST发言,长度: {len(host_speech)}字符")
|
||||
except Exception as e:
|
||||
self.log_info(f"读取HOST发言失败: {str(e)}")
|
||||
logger.exception(f"读取HOST发言失败: {str(e)}")
|
||||
|
||||
# 转换为JSON字符串
|
||||
message = json.dumps(data, ensure_ascii=False)
|
||||
@@ -96,7 +97,7 @@ class FirstSummaryNode(StateMutationNode):
|
||||
formatted_host = format_host_speech_for_prompt(data['host_speech'])
|
||||
message = formatted_host + "\n" + message
|
||||
|
||||
self.log_info("正在生成首次段落总结")
|
||||
logger.info("正在生成首次段落总结")
|
||||
|
||||
# 调用LLM
|
||||
response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SUMMARY, message)
|
||||
@@ -104,11 +105,11 @@ class FirstSummaryNode(StateMutationNode):
|
||||
# 处理响应
|
||||
processed_response = self.process_output(response)
|
||||
|
||||
self.log_info("成功生成首次段落总结")
|
||||
logger.info("成功生成首次段落总结")
|
||||
return processed_response
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"生成首次总结失败: {str(e)}")
|
||||
logger.exception(f"生成首次总结失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
def process_output(self, output: str) -> str:
|
||||
@@ -127,26 +128,26 @@ class FirstSummaryNode(StateMutationNode):
|
||||
cleaned_output = clean_json_tags(cleaned_output)
|
||||
|
||||
# 记录清理后的输出用于调试
|
||||
self.log_info(f"清理后的输出: {cleaned_output}")
|
||||
logger.info(f"清理后的输出: {cleaned_output}")
|
||||
|
||||
# 解析JSON
|
||||
try:
|
||||
result = json.loads(cleaned_output)
|
||||
self.log_info("JSON解析成功")
|
||||
logger.info("JSON解析成功")
|
||||
except JSONDecodeError as e:
|
||||
self.log_info(f"JSON解析失败: {str(e)}")
|
||||
logger.exception(f"JSON解析失败: {str(e)}")
|
||||
# 尝试修复JSON
|
||||
fixed_json = fix_incomplete_json(cleaned_output)
|
||||
if fixed_json:
|
||||
try:
|
||||
result = json.loads(fixed_json)
|
||||
self.log_info("JSON修复成功")
|
||||
logger.info("JSON修复成功")
|
||||
except JSONDecodeError:
|
||||
self.log_info("JSON修复失败,直接使用清理后的文本")
|
||||
logger.exception("JSON修复失败,直接使用清理后的文本")
|
||||
# 如果不是JSON格式,直接返回清理后的文本
|
||||
return cleaned_output
|
||||
else:
|
||||
self.log_info("无法修复JSON,直接使用清理后的文本")
|
||||
logger.exception("无法修复JSON,直接使用清理后的文本")
|
||||
# 如果不是JSON格式,直接返回清理后的文本
|
||||
return cleaned_output
|
||||
|
||||
@@ -160,7 +161,7 @@ class FirstSummaryNode(StateMutationNode):
|
||||
return cleaned_output
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"处理输出失败: {str(e)}")
|
||||
logger.exception(f"处理输出失败: {str(e)}")
|
||||
return "段落总结生成失败"
|
||||
|
||||
def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State:
|
||||
@@ -183,7 +184,7 @@ class FirstSummaryNode(StateMutationNode):
|
||||
# 更新状态
|
||||
if 0 <= paragraph_index < len(state.paragraphs):
|
||||
state.paragraphs[paragraph_index].research.latest_summary = summary
|
||||
self.log_info(f"已更新段落 {paragraph_index} 的首次总结")
|
||||
logger.info(f"已更新段落 {paragraph_index} 的首次总结")
|
||||
else:
|
||||
raise ValueError(f"段落索引 {paragraph_index} 超出范围")
|
||||
|
||||
@@ -191,7 +192,7 @@ class FirstSummaryNode(StateMutationNode):
|
||||
return state
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"状态更新失败: {str(e)}")
|
||||
logger.exception(f"状态更新失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
|
||||
@@ -249,9 +250,9 @@ class ReflectionSummaryNode(StateMutationNode):
|
||||
if host_speech:
|
||||
# 将HOST发言添加到输入数据中
|
||||
data['host_speech'] = host_speech
|
||||
self.log_info(f"已读取HOST发言,长度: {len(host_speech)}字符")
|
||||
logger.info(f"已读取HOST发言,长度: {len(host_speech)}字符")
|
||||
except Exception as e:
|
||||
self.log_info(f"读取HOST发言失败: {str(e)}")
|
||||
logger.exception(f"读取HOST发言失败: {str(e)}")
|
||||
|
||||
# 转换为JSON字符串
|
||||
message = json.dumps(data, ensure_ascii=False)
|
||||
@@ -261,7 +262,7 @@ class ReflectionSummaryNode(StateMutationNode):
|
||||
formatted_host = format_host_speech_for_prompt(data['host_speech'])
|
||||
message = formatted_host + "\n" + message
|
||||
|
||||
self.log_info("正在生成反思总结")
|
||||
logger.info("正在生成反思总结")
|
||||
|
||||
# 调用LLM
|
||||
response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION_SUMMARY, message)
|
||||
@@ -269,11 +270,11 @@ class ReflectionSummaryNode(StateMutationNode):
|
||||
# 处理响应
|
||||
processed_response = self.process_output(response)
|
||||
|
||||
self.log_info("成功生成反思总结")
|
||||
logger.info("成功生成反思总结")
|
||||
return processed_response
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"生成反思总结失败: {str(e)}")
|
||||
logger.exception(f"生成反思总结失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
def process_output(self, output: str) -> str:
|
||||
@@ -292,26 +293,26 @@ class ReflectionSummaryNode(StateMutationNode):
|
||||
cleaned_output = clean_json_tags(cleaned_output)
|
||||
|
||||
# 记录清理后的输出用于调试
|
||||
self.log_info(f"清理后的输出: {cleaned_output}")
|
||||
logger.info(f"清理后的输出: {cleaned_output}")
|
||||
|
||||
# 解析JSON
|
||||
try:
|
||||
result = json.loads(cleaned_output)
|
||||
self.log_info("JSON解析成功")
|
||||
logger.info("JSON解析成功")
|
||||
except JSONDecodeError as e:
|
||||
self.log_info(f"JSON解析失败: {str(e)}")
|
||||
logger.exception(f"JSON解析失败: {str(e)}")
|
||||
# 尝试修复JSON
|
||||
fixed_json = fix_incomplete_json(cleaned_output)
|
||||
if fixed_json:
|
||||
try:
|
||||
result = json.loads(fixed_json)
|
||||
self.log_info("JSON修复成功")
|
||||
logger.info("JSON修复成功")
|
||||
except JSONDecodeError:
|
||||
self.log_info("JSON修复失败,直接使用清理后的文本")
|
||||
logger.info("JSON修复失败,直接使用清理后的文本")
|
||||
# 如果不是JSON格式,直接返回清理后的文本
|
||||
return cleaned_output
|
||||
else:
|
||||
self.log_info("无法修复JSON,直接使用清理后的文本")
|
||||
logger.info("无法修复JSON,直接使用清理后的文本")
|
||||
# 如果不是JSON格式,直接返回清理后的文本
|
||||
return cleaned_output
|
||||
|
||||
@@ -325,7 +326,7 @@ class ReflectionSummaryNode(StateMutationNode):
|
||||
return cleaned_output
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"处理输出失败: {str(e)}")
|
||||
logger.exception(f"处理输出失败: {str(e)}")
|
||||
return "反思总结生成失败"
|
||||
|
||||
def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State:
|
||||
@@ -349,7 +350,7 @@ class ReflectionSummaryNode(StateMutationNode):
|
||||
if 0 <= paragraph_index < len(state.paragraphs):
|
||||
state.paragraphs[paragraph_index].research.latest_summary = updated_summary
|
||||
state.paragraphs[paragraph_index].research.increment_reflection()
|
||||
self.log_info(f"已更新段落 {paragraph_index} 的反思总结")
|
||||
logger.info(f"已更新段落 {paragraph_index} 的反思总结")
|
||||
else:
|
||||
raise ValueError(f"段落索引 {paragraph_index} 超出范围")
|
||||
|
||||
@@ -357,5 +358,5 @@ class ReflectionSummaryNode(StateMutationNode):
|
||||
return state
|
||||
|
||||
except Exception as e:
|
||||
self.log_error(f"状态更新失败: {str(e)}")
|
||||
logger.exception(f"状态更新失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
@@ -12,7 +12,8 @@ from dataclasses import dataclass
|
||||
|
||||
# 添加项目根目录到Python路径以导入config
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
from config import KEYWORD_OPTIMIZER_API_KEY, KEYWORD_OPTIMIZER_BASE_URL, KEYWORD_OPTIMIZER_MODEL_NAME
|
||||
from config import settings
|
||||
from loguru import logger
|
||||
|
||||
# 添加utils目录到Python路径
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -46,18 +47,18 @@ class KeywordOptimizer:
|
||||
api_key: 硅基流动API密钥,如果不提供则从配置文件读取
|
||||
base_url: 接口基础地址,默认使用配置文件提供的SiliconFlow地址
|
||||
"""
|
||||
self.api_key = api_key or KEYWORD_OPTIMIZER_API_KEY
|
||||
self.api_key = api_key or settings.KEYWORD_OPTIMIZER_API_KEY
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("未找到硅基流动API密钥,请在config.py中设置KEYWORD_OPTIMIZER_API_KEY")
|
||||
|
||||
self.base_url = base_url or KEYWORD_OPTIMIZER_BASE_URL
|
||||
self.base_url = base_url or settings.KEYWORD_OPTIMIZER_BASE_URL
|
||||
|
||||
self.client = OpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url
|
||||
)
|
||||
self.model = model_name or KEYWORD_OPTIMIZER_MODEL_NAME
|
||||
self.model = model_name or settings.KEYWORD_OPTIMIZER_MODEL_NAME
|
||||
|
||||
def optimize_keywords(self, original_query: str, context: str = "") -> KeywordOptimizationResponse:
|
||||
"""
|
||||
@@ -70,7 +71,7 @@ class KeywordOptimizer:
|
||||
Returns:
|
||||
KeywordOptimizationResponse: 优化后的关键词列表
|
||||
"""
|
||||
print(f"🔍 关键词优化中间件: 处理查询 '{original_query}'")
|
||||
logger.info(f"🔍 关键词优化中间件: 处理查询 '{original_query}'")
|
||||
|
||||
try:
|
||||
# 构建优化prompt
|
||||
@@ -97,9 +98,13 @@ class KeywordOptimizer:
|
||||
# 验证关键词质量
|
||||
validated_keywords = self._validate_keywords(keywords)
|
||||
|
||||
print(f"✅ 优化成功: {len(validated_keywords)}个关键词")
|
||||
for i, keyword in enumerate(validated_keywords, 1):
|
||||
print(f" {i}. '{keyword}'")
|
||||
logger.info(
|
||||
f"✅ 优化成功: {len(validated_keywords)}个关键词" +
|
||||
("" if not validated_keywords else "\n" +
|
||||
"\n".join([f" {i}. '{k}'" for i, k in enumerate(validated_keywords, 1)]))
|
||||
)
|
||||
|
||||
|
||||
|
||||
return KeywordOptimizationResponse(
|
||||
original_query=original_query,
|
||||
@@ -109,7 +114,7 @@ class KeywordOptimizer:
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"⚠️ 解析响应失败,使用备用方案: {str(e)}")
|
||||
logger.exception(f"⚠️ 解析响应失败,使用备用方案: {str(e)}")
|
||||
# 备用方案:从原始查询中提取关键词
|
||||
fallback_keywords = self._fallback_keyword_extraction(original_query)
|
||||
return KeywordOptimizationResponse(
|
||||
@@ -119,7 +124,7 @@ class KeywordOptimizer:
|
||||
success=True
|
||||
)
|
||||
else:
|
||||
print(f"❌ API调用失败: {response['error']}")
|
||||
logger.error(f"❌ API调用失败: {response['error']}")
|
||||
# 使用备用方案
|
||||
fallback_keywords = self._fallback_keyword_extraction(original_query)
|
||||
return KeywordOptimizationResponse(
|
||||
@@ -131,7 +136,7 @@ class KeywordOptimizer:
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 关键词优化失败: {str(e)}")
|
||||
logger.error(f"❌ 关键词优化失败: {str(e)}")
|
||||
# 最终备用方案
|
||||
fallback_keywords = self._fallback_keyword_extraction(original_query)
|
||||
return KeywordOptimizationResponse(
|
||||
|
||||
@@ -25,10 +25,11 @@ V3.0 核心更新:
|
||||
|
||||
import os
|
||||
import json
|
||||
import pymysql
|
||||
import pymysql.cursors
|
||||
from loguru import logger
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Optional, Literal
|
||||
from dataclasses import dataclass, field
|
||||
from ..utils.db import fetch_all
|
||||
from datetime import datetime, timedelta, date
|
||||
|
||||
# --- 1. 数据结构定义 ---
|
||||
@@ -69,36 +70,28 @@ class MediaCrawlerDB:
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
初始化客户端。连接信息从环境变量自动读取:
|
||||
- DB_HOST, DB_USER, DB_PASSWORD, DB_NAME
|
||||
- DB_PORT (可选, 默认 3306)
|
||||
- DB_CHARSET (可选, 默认 utf8mb4)
|
||||
初始化客户端。
|
||||
"""
|
||||
self.db_config = {
|
||||
'host': os.getenv("DB_HOST"),
|
||||
'user': os.getenv("DB_USER"),
|
||||
'password': os.getenv("DB_PASSWORD"),
|
||||
'db': os.getenv("DB_NAME"),
|
||||
'port': int(os.getenv("DB_PORT", 3306)),
|
||||
'charset': os.getenv("DB_CHARSET", "utf8mb4"),
|
||||
'cursorclass': pymysql.cursors.DictCursor
|
||||
}
|
||||
required = ['host', 'user', 'password', 'db']
|
||||
if missing := [k for k in required if not self.db_config[k]]:
|
||||
raise ValueError(f"数据库配置缺失! 请设置环境变量或在代码中提供: {', '.join([f'DB_{k.upper()}' for k in missing])}")
|
||||
|
||||
pass
|
||||
|
||||
def _execute_query(self, query: str, params: tuple = None) -> List[Dict[str, Any]]:
|
||||
conn = None
|
||||
try:
|
||||
conn = pymysql.connect(**self.db_config)
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute(query, params or ())
|
||||
return cursor.fetchall()
|
||||
except pymysql.Error as e:
|
||||
print(f"数据库查询时发生错误: {e}")
|
||||
# 获取或创建event loop
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# 直接运行协程
|
||||
return loop.run_until_complete(fetch_all(query, params))
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"数据库查询时发生错误: {e}")
|
||||
return []
|
||||
finally:
|
||||
if conn: conn.close()
|
||||
|
||||
@staticmethod
|
||||
def _to_datetime(ts: Any) -> Optional[datetime]:
|
||||
@@ -149,7 +142,7 @@ class MediaCrawlerDB:
|
||||
DBResponse: 包含按综合热度排序后的内容列表。
|
||||
"""
|
||||
params_for_log = {'time_period': time_period, 'limit': limit}
|
||||
print(f"--- TOOL: 查找热点内容 (params: {params_for_log}) ---")
|
||||
logger.info(f"--- TOOL: 查找热点内容 (params: {params_for_log}) ---")
|
||||
|
||||
now = datetime.now()
|
||||
start_time = now - timedelta(days={'24h': 1, 'week': 7}.get(time_period, 365))
|
||||
@@ -202,22 +195,28 @@ class MediaCrawlerDB:
|
||||
DBResponse: 包含所有匹配结果的聚合列表。
|
||||
"""
|
||||
params_for_log = {'topic': topic, 'limit_per_table': limit_per_table}
|
||||
print(f"--- TOOL: 全局话题搜索 (params: {params_for_log}) ---")
|
||||
logger.info(f"--- TOOL: 全局话题搜索 (params: {params_for_log}) ---")
|
||||
|
||||
search_term, all_results = f"%{topic}%", []
|
||||
search_configs = { 'bilibili_video': {'fields': ['title', 'desc', 'source_keyword'], 'type': 'video'}, 'bilibili_video_comment': {'fields': ['content'], 'type': 'comment'}, 'douyin_aweme': {'fields': ['title', 'desc', 'source_keyword'], 'type': 'video'}, 'douyin_aweme_comment': {'fields': ['content'], 'type': 'comment'}, 'kuaishou_video': {'fields': ['title', 'desc', 'source_keyword'], 'type': 'video'}, 'kuaishou_video_comment': {'fields': ['content'], 'type': 'comment'}, 'weibo_note': {'fields': ['content', 'source_keyword'], 'type': 'note'}, 'weibo_note_comment': {'fields': ['content'], 'type': 'comment'}, 'xhs_note': {'fields': ['title', 'desc', 'tag_list', 'source_keyword'], 'type': 'note'}, 'xhs_note_comment': {'fields': ['content'], 'type': 'comment'}, 'zhihu_content': {'fields': ['title', 'desc', 'content_text', 'source_keyword'], 'type': 'content'}, 'zhihu_comment': {'fields': ['content'], 'type': 'comment'}, 'tieba_note': {'fields': ['title', 'desc', 'source_keyword'], 'type': 'note'}, 'tieba_comment': {'fields': ['content'], 'type': 'comment'}, 'daily_news': {'fields': ['title'], 'type': 'news'}, }
|
||||
|
||||
for table, config in search_configs.items():
|
||||
where_clause = " OR ".join([f"`{field}` LIKE %s" for field in config['fields']])
|
||||
query = f"SELECT * FROM `{table}` WHERE {where_clause} ORDER BY id DESC LIMIT %s"
|
||||
params = (search_term,) * len(config['fields']) + (limit_per_table,)
|
||||
raw_results = self._execute_query(query, params)
|
||||
param_dict = {}
|
||||
where_clauses = []
|
||||
for idx, field in enumerate(config['fields']):
|
||||
pname = f"term_{idx}"
|
||||
where_clauses.append(f'"{field}" LIKE :{pname}')
|
||||
param_dict[pname] = search_term
|
||||
param_dict['limit'] = limit_per_table
|
||||
where_clause = " OR ".join(where_clauses)
|
||||
query = f'SELECT * FROM "{table}" WHERE {where_clause} ORDER BY id DESC LIMIT :limit'
|
||||
raw_results = self._execute_query(query, param_dict)
|
||||
for row in raw_results:
|
||||
content = (row.get('title') or row.get('content') or row.get('desc') or row.get('content_text', ''))
|
||||
time_key = row.get('create_time') or row.get('time') or row.get('created_time') or row.get('publish_time') or row.get('crawl_date')
|
||||
all_results.append(QueryResult(
|
||||
platform=table.split('_')[0], content_type=config['type'],
|
||||
title_or_content=content[:500] if content else '',
|
||||
title_or_content=content if content else '',
|
||||
author_nickname=row.get('nickname') or row.get('user_nickname') or row.get('user_name'),
|
||||
url=row.get('video_url') or row.get('note_url') or row.get('content_url') or row.get('url') or row.get('aweme_url'),
|
||||
publish_time=self._to_datetime(time_key),
|
||||
@@ -241,7 +240,7 @@ class MediaCrawlerDB:
|
||||
DBResponse: 包含在指定日期范围内找到的结果的聚合列表。
|
||||
"""
|
||||
params_for_log = {'topic': topic, 'start_date': start_date, 'end_date': end_date, 'limit_per_table': limit_per_table}
|
||||
print(f"--- TOOL: 按日期搜索话题 (params: {params_for_log}) ---")
|
||||
logger.info(f"--- TOOL: 按日期搜索话题 (params: {params_for_log}) ---")
|
||||
|
||||
try:
|
||||
start_dt, end_dt = datetime.strptime(start_date, '%Y-%m-%d'), datetime.strptime(end_date, '%Y-%m-%d') + timedelta(days=1)
|
||||
@@ -257,25 +256,25 @@ class MediaCrawlerDB:
|
||||
}
|
||||
|
||||
for table, config in search_configs.items():
|
||||
topic_clause = " OR ".join([f"`{field}` LIKE %s" for field in config['fields']])
|
||||
time_col, time_type = config['time_col'], config['time_type']
|
||||
if time_type == 'sec': time_params = (int(start_dt.timestamp()), int(end_dt.timestamp()))
|
||||
elif time_type == 'ms': time_params = (int(start_dt.timestamp() * 1000), int(end_dt.timestamp() * 1000))
|
||||
elif time_type in ['str', 'date_str']: time_params = (start_dt.strftime('%Y-%m-%d'), end_dt.strftime('%Y-%m-%d'))
|
||||
else: time_params = (str(int(start_dt.timestamp())), str(int(end_dt.timestamp())))
|
||||
time_clause = f"`{time_col}` >= %s AND `{time_col}` < %s"
|
||||
if table == 'zhihu_content': time_clause = f"CAST(`{time_col}` AS UNSIGNED) >= %s AND CAST(`{time_col}` AS UNSIGNED) < %s"
|
||||
query = f"SELECT * FROM `{table}` WHERE ({topic_clause}) AND ({time_clause}) ORDER BY id DESC LIMIT %s"
|
||||
params = (search_term,) * len(config['fields']) + time_params + (limit_per_table,)
|
||||
raw_results = self._execute_query(query, params)
|
||||
param_dict = {}
|
||||
where_clauses = []
|
||||
for idx, field in enumerate(config['fields']):
|
||||
pname = f"term_{idx}"
|
||||
where_clauses.append(f'"{field}" LIKE :{pname}')
|
||||
param_dict[pname] = search_term
|
||||
param_dict['limit'] = limit_per_table
|
||||
where_clause = ' OR '.join(where_clauses)
|
||||
query = f'SELECT * FROM "{table}" WHERE {where_clause} ORDER BY id DESC LIMIT :limit'
|
||||
raw_results = self._execute_query(query, param_dict)
|
||||
for row in raw_results:
|
||||
content = (row.get('title') or row.get('content') or row.get('desc') or row.get('content_text', ''))
|
||||
time_key = row.get('create_time') or row.get('time') or row.get('created_time') or row.get('publish_time') or row.get('crawl_date')
|
||||
all_results.append(QueryResult(
|
||||
platform=table.split('_')[0], content_type=config['type'],
|
||||
title_or_content=content[:500] if content else '',
|
||||
author_nickname=row.get('nickname') or row.get('user_nickname'),
|
||||
title_or_content=content if content else '',
|
||||
author_nickname=row.get('nickname') or row.get('user_nickname') or row.get('user_name'),
|
||||
url=row.get('video_url') or row.get('note_url') or row.get('content_url') or row.get('url') or row.get('aweme_url'),
|
||||
publish_time=self._to_datetime(row.get(config['time_col'])),
|
||||
publish_time=self._to_datetime(time_key),
|
||||
engagement=self._extract_engagement(row),
|
||||
source_keyword=row.get('source_keyword'),
|
||||
source_table=table
|
||||
@@ -294,7 +293,7 @@ class MediaCrawlerDB:
|
||||
DBResponse: 包含匹配的评论列表。
|
||||
"""
|
||||
params_for_log = {'topic': topic, 'limit': limit}
|
||||
print(f"--- TOOL: 获取话题评论 (params: {params_for_log}) ---")
|
||||
logger.info(f"--- TOOL: 获取话题评论 (params: {params_for_log}) ---")
|
||||
|
||||
search_term = f"%{topic}%"
|
||||
comment_tables = ['bilibili_video_comment', 'douyin_aweme_comment', 'kuaishou_video_comment', 'weibo_note_comment', 'xhs_note_comment', 'zhihu_comment', 'tieba_comment']
|
||||
@@ -341,7 +340,7 @@ class MediaCrawlerDB:
|
||||
DBResponse: 包含在该平台找到的结果列表。
|
||||
"""
|
||||
params_for_log = {'platform': platform, 'topic': topic, 'start_date': start_date, 'end_date': end_date, 'limit': limit}
|
||||
print(f"--- TOOL: 平台定向搜索 (params: {params_for_log}) ---")
|
||||
logger.info(f"--- TOOL: 平台定向搜索 (params: {params_for_log}) ---")
|
||||
|
||||
all_configs = { 'bilibili': [{'table': 'bilibili_video', 'fields': ['title', 'desc', 'source_keyword'], 'type': 'video', 'time_col': 'create_time', 'time_type': 'sec'}, {'table': 'bilibili_video_comment', 'fields': ['content'], 'type': 'comment'}], 'douyin': [{'table': 'douyin_aweme', 'fields': ['title', 'desc', 'source_keyword'], 'type': 'video', 'time_col': 'create_time', 'time_type': 'ms'}, {'table': 'douyin_aweme_comment', 'fields': ['content'], 'type': 'comment'}], 'kuaishou': [{'table': 'kuaishou_video', 'fields': ['title', 'desc', 'source_keyword'], 'type': 'video', 'time_col': 'create_time', 'time_type': 'ms'}, {'table': 'kuaishou_video_comment', 'fields': ['content'], 'type': 'comment'}], 'weibo': [{'table': 'weibo_note', 'fields': ['content', 'source_keyword'], 'type': 'note', 'time_col': 'create_date_time', 'time_type': 'str'}, {'table': 'weibo_note_comment', 'fields': ['content'], 'type': 'comment'}], 'xhs': [{'table': 'xhs_note', 'fields': ['title', 'desc', 'tag_list', 'source_keyword'], 'type': 'note', 'time_col': 'time', 'time_type': 'ms'}, {'table': 'xhs_note_comment', 'fields': ['content'], 'type': 'comment'}], 'zhihu': [{'table': 'zhihu_content', 'fields': ['title', 'desc', 'content_text', 'source_keyword'], 'type': 'content', 'time_col': 'created_time', 'time_type': 'sec_str'}, {'table': 'zhihu_comment', 'fields': ['content'], 'type': 'comment'}], 'tieba': [{'table': 'tieba_note', 'fields': ['title', 'desc', 'source_keyword'], 'type': 'note', 'time_col': 'publish_time', 'time_type': 'str'}, {'table': 'tieba_comment', 'fields': ['content'], 'type': 'comment'}] }
|
||||
|
||||
@@ -386,7 +385,7 @@ class MediaCrawlerDB:
|
||||
for row in raw_results:
|
||||
content = (row.get('title') or row.get('content') or row.get('desc') or row.get('content_text', ''))
|
||||
time_key = config.get('time_col') and row.get(config.get('time_col'))
|
||||
all_results.append(QueryResult(platform=platform, content_type=config['type'], title_or_content=content[:500] if content else '', author_nickname=row.get('nickname') or row.get('user_nickname'), url=row.get('video_url') or row.get('note_url') or row.get('content_url') or row.get('url') or row.get('aweme_url'), publish_time=self._to_datetime(time_key), engagement=self._extract_engagement(row), source_keyword=row.get('source_keyword'), source_table=table))
|
||||
all_results.append(QueryResult(platform=platform, content_type=config['type'], title_or_content=content if content else '', author_nickname=row.get('nickname') or row.get('user_nickname'), url=row.get('video_url') or row.get('note_url') or row.get('content_url') or row.get('url') or row.get('aweme_url'), publish_time=self._to_datetime(time_key), engagement=self._extract_engagement(row), source_keyword=row.get('source_keyword'), source_table=table))
|
||||
|
||||
return DBResponse("search_topic_on_platform", params_for_log, results=all_results, results_count=len(all_results))
|
||||
|
||||
@@ -394,33 +393,41 @@ class MediaCrawlerDB:
|
||||
def print_response_summary(response: DBResponse):
|
||||
"""简化的打印函数,用于展示测试结果"""
|
||||
if response.error_message:
|
||||
print(f"工具 '{response.tool_name}' 执行出错: {response.error_message}")
|
||||
print("-" * 80)
|
||||
logger.info(f"工具 '{response.tool_name}' 执行出错: {response.error_message}")
|
||||
return
|
||||
|
||||
params_str = ", ".join(f"{k}='{v}'" for k, v in response.parameters.items())
|
||||
print(f"查询: 工具='{response.tool_name}', 参数=[{params_str}]")
|
||||
print(f"找到 {response.results_count} 条相关记录。")
|
||||
logger.info(f"查询: 工具='{response.tool_name}', 参数=[{params_str}]")
|
||||
logger.info(f"找到 {response.results_count} 条相关记录。")
|
||||
|
||||
if response.results:
|
||||
print("--- 前5条结果示例 ---")
|
||||
for i, res in enumerate(response.results[:5]):
|
||||
engagement_str = ", ".join(f"{k}: {v}" for k, v in res.engagement.items() if v)
|
||||
content_preview = (res.title_or_content.replace('\n', ' ')[:70] + '...') if res.title_or_content and len(res.title_or_content) > 70 else res.title_or_content
|
||||
hotness_str = f", hotness: {res.hotness_score:.2f}" if res.hotness_score > 0 else ""
|
||||
print(
|
||||
f"{i+1}. [{res.platform.upper()}/{res.content_type}] {content_preview}\n"
|
||||
f" by: {res.author_nickname}, at: {res.publish_time.strftime('%Y-%m-%d %H:%M') if res.publish_time else 'N/A'}"
|
||||
f", src_kw: '{res.source_keyword or 'N/A'}'{hotness_str}"
|
||||
f", engagement: {{{engagement_str}}}"
|
||||
# 统一为一个消息输出
|
||||
output_lines = []
|
||||
output_lines.append("==== 查询结果预览(最多前5条) ====")
|
||||
if response.results and len(response.results) > 0:
|
||||
for idx, res in enumerate(response.results[:5], 1):
|
||||
content_preview = (res.title_or_content.replace('\n', ' ')[:70] + '...') if res.title_or_content and len(res.title_or_content) > 70 else (res.title_or_content or '')
|
||||
author_str = res.author_nickname or "N/A"
|
||||
publish_time_str = res.publish_time.strftime('%Y-%m-%d %H:%M') if res.publish_time else "N/A"
|
||||
hotness_str = f", hotness: {res.hotness_score:.2f}" if getattr(res, "hotness_score", 0) > 0 else ""
|
||||
engagement_dict = getattr(res, "engagement", {}) or {}
|
||||
engagement_str = ", ".join(f"{k}: {v}" for k, v in engagement_dict.items() if v)
|
||||
output_lines.append(
|
||||
f"{idx}. [{res.platform.upper()}/{res.content_type}] {content_preview}\n"
|
||||
f" 作者: {author_str} | 时间: {publish_time_str}"
|
||||
f"{hotness_str} | 源关键词: '{res.source_keyword or 'N/A'}'\n"
|
||||
f" 链接: {res.url or 'N/A'}\n"
|
||||
f" 互动数据: {{{engagement_str}}}"
|
||||
)
|
||||
print("-" * 80)
|
||||
else:
|
||||
output_lines.append("暂无相关内容。")
|
||||
output_lines.append("=" * 60)
|
||||
logger.info('\n'.join(output_lines))
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
try:
|
||||
db_agent_tools = MediaCrawlerDB()
|
||||
print("数据库工具初始化成功,开始执行测试场景...\n")
|
||||
logger.info("数据库工具初始化成功,开始执行测试场景...\n")
|
||||
|
||||
# 场景1: (新) 查找过去一周综合热度最高的内容 (不再需要sort_by)
|
||||
response1 = db_agent_tools.search_hot_content(time_period='week', limit=5)
|
||||
@@ -443,7 +450,7 @@ if __name__ == "__main__":
|
||||
print_response_summary(response5)
|
||||
|
||||
except ValueError as e:
|
||||
print(f"初始化失败: {e}")
|
||||
print("请确保相关的数据库环境变量已正确设置, 或在代码中直接提供连接信息。")
|
||||
logger.exception(f"初始化失败: {e}")
|
||||
logger.exception("请确保相关的数据库环境变量已正确设置, 或在代码中直接提供连接信息。")
|
||||
except Exception as e:
|
||||
print(f"测试过程中发生未知错误: {e}")
|
||||
logger.exception(f"测试过程中发生未知错误: {e}")
|
||||
@@ -12,8 +12,6 @@ from .text_processing import (
|
||||
format_search_results_for_prompt
|
||||
)
|
||||
|
||||
from .config import Config, load_config
|
||||
|
||||
__all__ = [
|
||||
"clean_json_tags",
|
||||
"clean_markdown_tags",
|
||||
@@ -21,6 +19,4 @@ __all__ = [
|
||||
"extract_clean_response",
|
||||
"update_state_with_search_results",
|
||||
"format_search_results_for_prompt",
|
||||
"Config",
|
||||
"load_config"
|
||||
]
|
||||
|
||||
+34
-212
@@ -6,218 +6,40 @@ Handles environment variables and config file parameters.
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic import Field
|
||||
from loguru import logger
|
||||
|
||||
class Settings(BaseSettings):
|
||||
INSIGHT_ENGINE_API_KEY: Optional[str] = Field(None, description="Insight Engine LLM API密钥")
|
||||
INSIGHT_ENGINE_BASE_URL: Optional[str] = Field(None, description="Insight Engine LLM base url,可选")
|
||||
INSIGHT_ENGINE_MODEL_NAME: Optional[str] = Field(None, description="Insight Engine LLM模型名称")
|
||||
INSIGHT_ENGINE_PROVIDER: Optional[str] = Field(None, description="Insight Engine模型提供者,不再建议使用")
|
||||
DB_HOST: Optional[str] = Field(None, description="数据库主机")
|
||||
DB_USER: Optional[str] = Field(None, description="数据库用户名")
|
||||
DB_PASSWORD: Optional[str] = Field(None, description="数据库密码")
|
||||
DB_NAME: Optional[str] = Field(None, description="数据库名称")
|
||||
DB_PORT: int = Field(3306, description="数据库端口")
|
||||
DB_CHARSET: str = Field("utf8mb4", description="数据库字符集")
|
||||
DB_DIALECT: Optional[str] = Field("mysql", description="数据库方言,如mysql、postgresql等,SQLAlchemy后端选择")
|
||||
MAX_REFLECTIONS: int = Field(3, description="最大反思次数")
|
||||
MAX_PARAGRAPHS: int = Field(6, description="最大段落数")
|
||||
SEARCH_TIMEOUT: int = Field(240, description="单次搜索请求超时")
|
||||
MAX_CONTENT_LENGTH: int = Field(500000, description="搜索最大内容长度")
|
||||
DEFAULT_SEARCH_HOT_CONTENT_LIMIT: int = Field(100, description="热榜内容默认最大数")
|
||||
DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE: int = Field(50, description="按表全局话题最大数")
|
||||
DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE: int = Field(100, description="按日期话题最大数")
|
||||
DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT: int = Field(500, description="单话题评论最大数")
|
||||
DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT: int = Field(200, description="平台搜索话题最大数")
|
||||
MAX_SEARCH_RESULTS_FOR_LLM: int = Field(0, description="供LLM用搜索结果最大数")
|
||||
MAX_HIGH_CONFIDENCE_SENTIMENT_RESULTS: int = Field(0, description="高置信度情感分析最大数")
|
||||
OUTPUT_DIR: str = Field("reports", description="输出路径")
|
||||
SAVE_INTERMEDIATE_STATES: bool = Field(True, description="是否保存中间状态")
|
||||
|
||||
def _get_value(source, key: str, default=None):
|
||||
"""
|
||||
Helper to fetch a configuration value with environment fallback.
|
||||
"""
|
||||
value = None
|
||||
if isinstance(source, dict):
|
||||
value = source.get(key)
|
||||
else:
|
||||
value = getattr(source, key, None)
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
env_prefix = ""
|
||||
case_sensitive = False
|
||||
extra = "allow"
|
||||
|
||||
if value is None:
|
||||
value = os.getenv(key, default)
|
||||
return value if value not in ("", None) else default
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
"""Insight Engine configuration."""
|
||||
|
||||
# LLM configuration
|
||||
llm_api_key: Optional[str] = None
|
||||
llm_base_url: Optional[str] = None
|
||||
llm_model_name: Optional[str] = None
|
||||
llm_provider: Optional[str] = None # kept for backward compatibility
|
||||
|
||||
# Database configuration
|
||||
db_host: Optional[str] = None
|
||||
db_user: Optional[str] = None
|
||||
db_password: Optional[str] = None
|
||||
db_name: Optional[str] = None
|
||||
db_port: int = 3306
|
||||
db_charset: str = "utf8mb4"
|
||||
|
||||
# Model behaviour configuration
|
||||
max_reflections: int = 3
|
||||
max_paragraphs: int = 6
|
||||
search_timeout: int = 240
|
||||
max_content_length: int = 500000
|
||||
|
||||
# Search result limits
|
||||
default_search_hot_content_limit: int = 100
|
||||
default_search_topic_globally_limit_per_table: int = 50
|
||||
default_search_topic_by_date_limit_per_table: int = 100
|
||||
default_get_comments_for_topic_limit: int = 500
|
||||
default_search_topic_on_platform_limit: int = 200
|
||||
max_search_results_for_llm: int = 0
|
||||
max_high_confidence_sentiment_results: int = 0
|
||||
|
||||
# Output configuration
|
||||
output_dir: str = "reports"
|
||||
save_intermediate_states: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.llm_provider and self.llm_model_name:
|
||||
# Provider is no longer used, but keep the attribute for compatibility.
|
||||
self.llm_provider = self.llm_model_name
|
||||
|
||||
def validate(self) -> bool:
|
||||
"""Validate configuration."""
|
||||
if not self.llm_api_key:
|
||||
print("错误: Insight Engine LLM API Key 未设置 (INSIGHT_ENGINE_API_KEY)。")
|
||||
return False
|
||||
|
||||
if not self.llm_model_name:
|
||||
print("错误: Insight Engine 模型名称未设置 (INSIGHT_ENGINE_MODEL_NAME)。")
|
||||
return False
|
||||
|
||||
if not all([self.db_host, self.db_user, self.db_password, self.db_name]):
|
||||
print("错误: 数据库连接信息不完整,请检查 config.py 中的 DB_* 配置。")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, config_file: str) -> "Config":
|
||||
"""Create configuration from file."""
|
||||
if config_file.endswith(".py"):
|
||||
import importlib.util
|
||||
|
||||
spec = importlib.util.spec_from_file_location("config", config_file)
|
||||
config_module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(config_module)
|
||||
|
||||
return cls(
|
||||
llm_api_key=_get_value(config_module, "INSIGHT_ENGINE_API_KEY"),
|
||||
llm_base_url=_get_value(config_module, "INSIGHT_ENGINE_BASE_URL"),
|
||||
llm_model_name=_get_value(config_module, "INSIGHT_ENGINE_MODEL_NAME"),
|
||||
db_host=_get_value(config_module, "DB_HOST"),
|
||||
db_user=_get_value(config_module, "DB_USER"),
|
||||
db_password=_get_value(config_module, "DB_PASSWORD"),
|
||||
db_name=_get_value(config_module, "DB_NAME"),
|
||||
db_port=int(_get_value(config_module, "DB_PORT", 3306)),
|
||||
db_charset=_get_value(config_module, "DB_CHARSET", "utf8mb4"),
|
||||
max_reflections=int(_get_value(config_module, "MAX_REFLECTIONS", 3)),
|
||||
max_paragraphs=int(_get_value(config_module, "MAX_PARAGRAPHS", 6)),
|
||||
search_timeout=int(_get_value(config_module, "SEARCH_TIMEOUT", 240)),
|
||||
max_content_length=int(_get_value(config_module, "SEARCH_CONTENT_MAX_LENGTH", 500000)),
|
||||
default_search_hot_content_limit=int(
|
||||
_get_value(config_module, "DEFAULT_SEARCH_HOT_CONTENT_LIMIT", 100)
|
||||
),
|
||||
default_search_topic_globally_limit_per_table=int(
|
||||
_get_value(config_module, "DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE", 50)
|
||||
),
|
||||
default_search_topic_by_date_limit_per_table=int(
|
||||
_get_value(config_module, "DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE", 100)
|
||||
),
|
||||
default_get_comments_for_topic_limit=int(
|
||||
_get_value(config_module, "DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT", 500)
|
||||
),
|
||||
default_search_topic_on_platform_limit=int(
|
||||
_get_value(config_module, "DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT", 200)
|
||||
),
|
||||
max_search_results_for_llm=int(_get_value(config_module, "MAX_SEARCH_RESULTS_FOR_LLM", 0)),
|
||||
max_high_confidence_sentiment_results=int(
|
||||
_get_value(config_module, "MAX_HIGH_CONFIDENCE_SENTIMENT_RESULTS", 0)
|
||||
),
|
||||
output_dir=_get_value(config_module, "OUTPUT_DIR", "reports"),
|
||||
save_intermediate_states=str(
|
||||
_get_value(config_module, "SAVE_INTERMEDIATE_STATES", "true")
|
||||
).lower()
|
||||
in ("true", "1", "yes"),
|
||||
)
|
||||
|
||||
# .env style configuration
|
||||
config_dict = {}
|
||||
if os.path.exists(config_file):
|
||||
with open(config_file, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith("#") and "=" in line:
|
||||
key, value = line.split("=", 1)
|
||||
config_dict[key.strip()] = value.strip()
|
||||
|
||||
return cls(
|
||||
llm_api_key=_get_value(config_dict, "INSIGHT_ENGINE_API_KEY"),
|
||||
llm_base_url=_get_value(config_dict, "INSIGHT_ENGINE_BASE_URL"),
|
||||
llm_model_name=_get_value(config_dict, "INSIGHT_ENGINE_MODEL_NAME"),
|
||||
db_host=_get_value(config_dict, "DB_HOST"),
|
||||
db_user=_get_value(config_dict, "DB_USER"),
|
||||
db_password=_get_value(config_dict, "DB_PASSWORD"),
|
||||
db_name=_get_value(config_dict, "DB_NAME"),
|
||||
db_port=int(_get_value(config_dict, "DB_PORT", 3306)),
|
||||
db_charset=_get_value(config_dict, "DB_CHARSET", "utf8mb4"),
|
||||
max_reflections=int(_get_value(config_dict, "MAX_REFLECTIONS", 3)),
|
||||
max_paragraphs=int(_get_value(config_dict, "MAX_PARAGRAPHS", 6)),
|
||||
search_timeout=int(_get_value(config_dict, "SEARCH_TIMEOUT", 240)),
|
||||
max_content_length=int(_get_value(config_dict, "SEARCH_CONTENT_MAX_LENGTH", 500000)),
|
||||
default_search_hot_content_limit=int(
|
||||
_get_value(config_dict, "DEFAULT_SEARCH_HOT_CONTENT_LIMIT", 100)
|
||||
),
|
||||
default_search_topic_globally_limit_per_table=int(
|
||||
_get_value(config_dict, "DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE", 50)
|
||||
),
|
||||
default_search_topic_by_date_limit_per_table=int(
|
||||
_get_value(config_dict, "DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE", 100)
|
||||
),
|
||||
default_get_comments_for_topic_limit=int(
|
||||
_get_value(config_dict, "DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT", 500)
|
||||
),
|
||||
default_search_topic_on_platform_limit=int(
|
||||
_get_value(config_dict, "DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT", 200)
|
||||
),
|
||||
max_search_results_for_llm=int(_get_value(config_dict, "MAX_SEARCH_RESULTS_FOR_LLM", 0)),
|
||||
max_high_confidence_sentiment_results=int(
|
||||
_get_value(config_dict, "MAX_HIGH_CONFIDENCE_SENTIMENT_RESULTS", 0)
|
||||
),
|
||||
output_dir=_get_value(config_dict, "OUTPUT_DIR", "reports"),
|
||||
save_intermediate_states=str(
|
||||
_get_value(config_dict, "SAVE_INTERMEDIATE_STATES", "true")
|
||||
).lower()
|
||||
in ("true", "1", "yes"),
|
||||
)
|
||||
|
||||
|
||||
def load_config(config_file: Optional[str] = None) -> Config:
|
||||
"""
|
||||
Load configuration.
|
||||
"""
|
||||
if config_file:
|
||||
if not os.path.exists(config_file):
|
||||
raise FileNotFoundError(f"配置文件不存在: {config_file}")
|
||||
file_to_load = config_file
|
||||
else:
|
||||
for candidate in ("config.py", "config.env", ".env"):
|
||||
if os.path.exists(candidate):
|
||||
file_to_load = candidate
|
||||
print(f"已找到配置文件: {candidate}")
|
||||
break
|
||||
else:
|
||||
raise FileNotFoundError("未找到配置文件,请创建 config.py。")
|
||||
|
||||
config = Config.from_file(file_to_load)
|
||||
|
||||
if not config.validate():
|
||||
raise ValueError("配置校验失败,请检查 config.py 中的相关配置。")
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def print_config(config: Config):
|
||||
"""Print configuration (sensitive values masked)."""
|
||||
print("\n=== Insight Engine 配置 ===")
|
||||
print(f"LLM 模型: {config.llm_model_name}")
|
||||
print(f"LLM Base URL: {config.llm_base_url or '(默认)'}")
|
||||
print(f"搜索超时: {config.search_timeout} 秒")
|
||||
print(f"最长内容长度: {config.max_content_length}")
|
||||
print(f"最大反思次数: {config.max_reflections}")
|
||||
print(f"最大段落数: {config.max_paragraphs}")
|
||||
print(f"输出目录: {config.output_dir}")
|
||||
print(f"保存中间状态: {config.save_intermediate_states}")
|
||||
print(f"LLM API Key: {'已配置' if config.llm_api_key else '未配置'}")
|
||||
print(f"数据库连接: {'已配置' if all([config.db_host, config.db_user, config.db_password, config.db_name]) else '未配置'}")
|
||||
print("========================\n")
|
||||
settings = Settings()
|
||||
Reference in New Issue
Block a user