Merge pull request #125 from DoiiarX/env-file-and-loguru

统一为使用基于pydantic的.env环境变量管理配置、全项目基于loguru进行日志管理、前端添加数据库类型选项
This commit is contained in:
Doiiars
2025-11-05 15:16:47 +08:00
committed by GitHub
50 changed files with 1404 additions and 1731 deletions
+2
View File
@@ -63,5 +63,7 @@ KEYWORD_OPTIMIZER_MODEL_NAME=
# ================== 网络工具配置 ====================
# Tavily API密钥,用于Tavily网络搜索。注册地址:https://www.tavily.com/
TAVILY_API_KEY=
# Bocha Web/AI Search BASEURL,用于Bocha搜索。注册地址:https://open.bochaai.com/
BOCHA_BASE_URL=
# Bocha Web Search API密钥,用于Bocha搜索。注册地址:https://open.bochaai.com/
BOCHA_WEB_SEARCH_API_KEY=
+28 -27
View File
@@ -11,13 +11,14 @@ import re
import json
from typing import Dict, Optional, List
from threading import Lock
from loguru import logger
# Import the optional forum-host module. When the import fails,
# HOST_AVAILABLE is set to False and a single warning is logged (the message
# says the engine will run in monitor-only mode).
try:
    from .llm_host import generate_host_speech
    HOST_AVAILABLE = True
except ImportError:
    # Report through loguru only; the earlier print() duplicated this message.
    logger.warning("ForumEngine: 论坛主持人模块未找到,将以纯监控模式运行")
    HOST_AVAILABLE = False
class LogMonitor:
@@ -76,7 +77,7 @@ class LogMonitor:
pass # 先创建空文件
self.write_to_forum_log(f"=== ForumEngine 监控开始 - {start_time} ===", "SYSTEM")
print(f"ForumEngine: forum.log 已清空并初始化")
logger.info(f"ForumEngine: forum.log 已清空并初始化")
# 重置JSON捕获状态
self.capturing_json = {}
@@ -88,7 +89,7 @@ class LogMonitor:
self.is_host_generating = False
except Exception as e:
print(f"ForumEngine: 清空forum.log失败: {e}")
logger.exception(f"ForumEngine: 清空forum.log失败: {e}")
def write_to_forum_log(self, content: str, source: str = None):
"""写入内容到forum.log(线程安全)"""
@@ -105,7 +106,7 @@ class LogMonitor:
f.write(f"[{timestamp}] {content_one_line}\n")
f.flush()
except Exception as e:
print(f"ForumEngine: 写入forum.log失败: {e}")
logger.exception(f"ForumEngine: 写入forum.log失败: {e}")
def is_target_log_line(self, line: str) -> bool:
"""检查是否是目标日志行(SummaryNode"""
@@ -241,7 +242,7 @@ class LogMonitor:
return f"清理后的输出: {json.dumps(json_obj, ensure_ascii=False, indent=2)}"
except Exception as e:
print(f"ForumEngine: 格式化JSON时出错: {e}")
logger.exception(f"ForumEngine: 格式化JSON时出错: {e}")
return f"清理后的输出: {json.dumps(json_obj, ensure_ascii=False, indent=2)}"
def extract_node_content(self, line: str) -> Optional[str]:
@@ -331,7 +332,7 @@ class LogMonitor:
new_lines = [line.strip() for line in new_lines if line.strip()]
except Exception as e:
print(f"ForumEngine: 读取{app_name}日志失败: {e}")
logger.exception(f"ForumEngine: 读取{app_name}日志失败: {e}")
return new_lines
@@ -406,7 +407,7 @@ class LogMonitor:
self.is_host_generating = False
return
print("ForumEngine: 正在生成主持人发言...")
logger.info("ForumEngine: 正在生成主持人发言...")
# 调用主持人生成发言(传入最近5条)
host_speech = generate_host_speech(recent_speeches)
@@ -414,18 +415,18 @@ class LogMonitor:
if host_speech:
# 写入主持人发言到forum.log
self.write_to_forum_log(host_speech, "HOST")
print(f"ForumEngine: 主持人发言已记录")
logger.info(f"ForumEngine: 主持人发言已记录")
# 清空已处理的5条发言
self.agent_speeches_buffer = self.agent_speeches_buffer[5:]
else:
print("ForumEngine: 主持人发言生成失败")
logger.error("ForumEngine: 主持人发言生成失败")
# 重置生成标志
self.is_host_generating = False
except Exception as e:
print(f"ForumEngine: 触发主持人发言时出错: {e}")
logger.exception(f"ForumEngine: 触发主持人发言时出错: {e}")
self.is_host_generating = False
def _clean_content_tags(self, content: str, app_name: str) -> str:
@@ -453,7 +454,7 @@ class LogMonitor:
def monitor_logs(self):
"""智能监控日志文件"""
print("ForumEngine: 论坛创建中...")
logger.info("ForumEngine: 论坛创建中...")
# 初始化文件行数和位置 - 记录当前状态作为基线
for app_name, log_file in self.monitored_logs.items():
@@ -461,7 +462,7 @@ class LogMonitor:
self.file_positions[app_name] = self.get_file_size(log_file)
self.capturing_json[app_name] = False
self.json_buffer[app_name] = []
# print(f"ForumEngine: {app_name} 基线行数: {self.file_line_counts[app_name]}")
# logger.info(f"ForumEngine: {app_name} 基线行数: {self.file_line_counts[app_name]}")
while self.is_monitoring:
try:
@@ -484,7 +485,7 @@ class LogMonitor:
if not self.is_searching:
for line in new_lines:
if line.strip() and 'FirstSummaryNode' in line:
print(f"ForumEngine: 在{app_name}中检测到第一次论坛发表内容")
logger.info(f"ForumEngine: 在{app_name}中检测到第一次论坛发表内容")
self.is_searching = True
self.search_inactive_count = 0
# 清空forum.log开始新会话
@@ -500,7 +501,7 @@ class LogMonitor:
# 将app_name转换为大写作为标签(如 insight -> INSIGHT
source_tag = app_name.upper()
self.write_to_forum_log(content, source_tag)
# print(f"ForumEngine: 捕获 - {content}")
# logger.info(f"ForumEngine: 捕获 - {content}")
captured_any = True
# 将发言添加到缓冲区(格式化为完整的日志行)
@@ -515,7 +516,7 @@ class LogMonitor:
elif current_lines < previous_lines:
any_shrink = True
# print(f"ForumEngine: 检测到 {app_name} 日志缩短,将重置基线")
# logger.info(f"ForumEngine: 检测到 {app_name} 日志缩短,将重置基线")
# 重置文件位置到新的文件末尾
self.file_positions[app_name] = self.get_file_size(log_file)
# 重置JSON捕获状态
@@ -529,7 +530,7 @@ class LogMonitor:
if self.is_searching:
if any_shrink:
# log变短,结束当前搜索会话,重置为等待状态
# print("ForumEngine: 日志缩短,结束当前搜索会话,回到等待状态")
# logger.info("ForumEngine: 日志缩短,结束当前搜索会话,回到等待状态")
self.is_searching = False
self.search_inactive_count = 0
# 重置主持人相关状态
@@ -538,12 +539,12 @@ class LogMonitor:
# 写入结束标记
end_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
self.write_to_forum_log(f"=== ForumEngine 论坛结束 - {end_time} ===", "SYSTEM")
# print("ForumEngine: 已重置基线,等待下次FirstSummaryNode触发")
# logger.info("ForumEngine: 已重置基线,等待下次FirstSummaryNode触发")
elif not any_growth and not captured_any:
# 没有增长也没有捕获内容,增加非活跃计数
self.search_inactive_count += 1
if self.search_inactive_count >= 900: # 15分钟无活动才结束
print("ForumEngine: 长时间无活动,结束论坛")
logger.info("ForumEngine: 长时间无活动,结束论坛")
self.is_searching = False
self.search_inactive_count = 0
# 重置主持人相关状态
@@ -559,17 +560,17 @@ class LogMonitor:
time.sleep(1)
except Exception as e:
print(f"ForumEngine: 论坛记录中出错: {e}")
logger.exception(f"ForumEngine: 论坛记录中出错: {e}")
import traceback
traceback.print_exc()
time.sleep(2)
print("ForumEngine: 停止论坛日志文件")
logger.info("ForumEngine: 停止论坛日志文件")
def start_monitoring(self):
"""开始智能监控"""
if self.is_monitoring:
print("ForumEngine: 论坛已经在运行中")
logger.info("ForumEngine: 论坛已经在运行中")
return False
try:
@@ -578,18 +579,18 @@ class LogMonitor:
self.monitor_thread = threading.Thread(target=self.monitor_logs, daemon=True)
self.monitor_thread.start()
print("ForumEngine: 论坛已启动")
logger.info("ForumEngine: 论坛已启动")
return True
except Exception as e:
print(f"ForumEngine: 启动论坛失败: {e}")
logger.exception(f"ForumEngine: 启动论坛失败: {e}")
self.is_monitoring = False
return False
def stop_monitoring(self):
"""停止监控"""
if not self.is_monitoring:
print("ForumEngine: 论坛未运行")
logger.info("ForumEngine: 论坛未运行")
return
try:
@@ -602,10 +603,10 @@ class LogMonitor:
end_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
self.write_to_forum_log(f"=== ForumEngine 论坛结束 - {end_time} ===", "SYSTEM")
print("ForumEngine: 论坛已停止")
logger.info("ForumEngine: 论坛已停止")
except Exception as e:
print(f"ForumEngine: 停止论坛失败: {e}")
logger.exception(f"ForumEngine: 停止论坛失败: {e}")
def get_forum_log_content(self) -> List[str]:
"""获取forum.log的内容"""
@@ -617,7 +618,7 @@ class LogMonitor:
return [line.rstrip('\n\r') for line in f.readlines()]
except Exception as e:
print(f"ForumEngine: 读取forum.log失败: {e}")
logger.exception(f"ForumEngine: 读取forum.log失败: {e}")
return []
def fix_json_string(self, json_text: str) -> str:
+2 -2
View File
@@ -4,9 +4,9 @@ Deep Search Agent
"""
from .agent import DeepSearchAgent, create_agent
from .utils.config import Config, load_config
from .utils.config import settings, Settings
__version__ = "1.0.0"
__author__ = "Deep Search Agent Team"

# Names exported by `from <package> import *`. Assigned exactly once: the
# earlier list naming `Config` / `load_config` was immediately overwritten
# and therefore dead.
__all__ = ["DeepSearchAgent", "create_agent", "settings", "Settings"]
+113 -118
View File
@@ -8,6 +8,7 @@ import os
import re
from datetime import datetime
from typing import Optional, Dict, Any, List, Union
from loguru import logger
from .llms import LLMClient
from .nodes import (
@@ -20,32 +21,25 @@ from .nodes import (
)
from .state import State
from .tools import MediaCrawlerDB, DBResponse, keyword_optimizer, multilingual_sentiment_analyzer
from .utils import Config, load_config, format_search_results_for_prompt
from .utils.config import settings, Settings
from .utils import format_search_results_for_prompt
class DeepSearchAgent:
"""Deep Search Agent主类"""
def __init__(self, config: Optional[Config] = None):
def __init__(self, config: Optional[Settings] = None):
"""
初始化Deep Search Agent
Args:
config: 配置对象,如果不提供则自动加载
config: 可选配置对象(不填则用全局settings
"""
# 加载配置
self.config = config or load_config()
self.config = config or settings
# 初始化LLM客户端
self.llm_client = self._initialize_llm()
# 设置数据库环境变量
os.environ["DB_HOST"] = self.config.db_host or ""
os.environ["DB_USER"] = self.config.db_user or ""
os.environ["DB_PASSWORD"] = self.config.db_password or ""
os.environ["DB_NAME"] = self.config.db_name or ""
os.environ["DB_PORT"] = str(self.config.db_port)
os.environ["DB_CHARSET"] = self.config.db_charset
# 初始化搜索工具集
self.search_agency = MediaCrawlerDB()
@@ -60,19 +54,19 @@ class DeepSearchAgent:
self.state = State()
# 确保输出目录存在
os.makedirs(self.config.output_dir, exist_ok=True)
os.makedirs(self.config.OUTPUT_DIR, exist_ok=True)
print(f"Insight Agent已初始化")
print(f"使用LLM: {self.llm_client.get_model_info()}")
print(f"搜索工具集: MediaCrawlerDB (支持5种本地数据库查询工具)")
print(f"情感分析: WeiboMultilingualSentiment (支持22种语言的情感分析)")
logger.info(f"Insight Agent已初始化")
logger.info(f"使用LLM: {self.llm_client.get_model_info()}")
logger.info(f"搜索工具集: MediaCrawlerDB (支持5种本地数据库查询工具)")
logger.info(f"情感分析: WeiboMultilingualSentiment (支持22种语言的情感分析)")
def _initialize_llm(self) -> LLMClient:
    """Create the LLM client from the Insight Engine settings.

    Reads ``INSIGHT_ENGINE_API_KEY``, ``INSIGHT_ENGINE_MODEL_NAME`` and
    ``INSIGHT_ENGINE_BASE_URL`` from ``self.config``. Each keyword is passed
    exactly once; also passing the old lowercase ``llm_*`` attributes would
    repeat the keywords, which is a syntax error.

    Returns:
        The constructed ``LLMClient``.

    Raises:
        ValueError: presumably raised by ``LLMClient`` when the API key or
            model name is empty -- confirm against ``LLMClient.__init__``.
    """
    return LLMClient(
        api_key=self.config.INSIGHT_ENGINE_API_KEY,
        model_name=self.config.INSIGHT_ENGINE_MODEL_NAME,
        base_url=self.config.INSIGHT_ENGINE_BASE_URL,
    )
def _initialize_nodes(self):
@@ -127,7 +121,7 @@ class DeepSearchAgent:
Returns:
DBResponse对象(可能包含情感分析结果)
"""
print(f" → 执行数据库查询工具: {tool_name}")
logger.info(f" → 执行数据库查询工具: {tool_name}")
# 对于热点内容搜索,不需要关键词优化(因为不需要query参数)
if tool_name == "search_hot_content":
@@ -138,12 +132,12 @@ class DeepSearchAgent:
# 检查是否需要进行情感分析
enable_sentiment = kwargs.get("enable_sentiment", True)
if enable_sentiment and response.results and len(response.results) > 0:
print(f" 🎭 开始对热点内容进行情感分析...")
logger.info(f" 🎭 开始对热点内容进行情感分析...")
sentiment_analysis = self._perform_sentiment_analysis(response.results)
if sentiment_analysis:
# 将情感分析结果添加到响应的parameters中
response.parameters["sentiment_analysis"] = sentiment_analysis
print(f" ✅ 情感分析完成")
logger.info(f" ✅ 情感分析完成")
return response
@@ -170,32 +164,32 @@ class DeepSearchAgent:
context=f"使用{tool_name}工具进行查询"
)
print(f" 🔍 原始查询: '{query}'")
print(f" ✨ 优化后关键词: {optimized_response.optimized_keywords}")
logger.info(f" 🔍 原始查询: '{query}'")
logger.info(f" ✨ 优化后关键词: {optimized_response.optimized_keywords}")
# 使用优化后的关键词进行多次查询并整合结果
all_results = []
total_count = 0
for keyword in optimized_response.optimized_keywords:
print(f" 查询关键词: '{keyword}'")
logger.info(f" 查询关键词: '{keyword}'")
try:
if tool_name == "search_topic_globally":
# 使用配置文件中的默认值,忽略agent提供的limit_per_table参数
limit_per_table = self.config.default_search_topic_globally_limit_per_table
limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE
response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=limit_per_table)
elif tool_name == "search_topic_by_date":
start_date = kwargs.get("start_date")
end_date = kwargs.get("end_date")
# 使用配置文件中的默认值,忽略agent提供的limit_per_table参数
limit_per_table = self.config.default_search_topic_by_date_limit_per_table
limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE
if not start_date or not end_date:
raise ValueError("search_topic_by_date工具需要start_date和end_date参数")
response = self.search_agency.search_topic_by_date(topic=keyword, start_date=start_date, end_date=end_date, limit_per_table=limit_per_table)
elif tool_name == "get_comments_for_topic":
# 使用配置文件中的默认值,按关键词数量分配,但保证最小值
limit = self.config.default_get_comments_for_topic_limit // len(optimized_response.optimized_keywords)
limit = self.config.DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT // len(optimized_response.optimized_keywords)
limit = max(limit, 50)
response = self.search_agency.get_comments_for_topic(topic=keyword, limit=limit)
elif tool_name == "search_topic_on_platform":
@@ -203,30 +197,30 @@ class DeepSearchAgent:
start_date = kwargs.get("start_date")
end_date = kwargs.get("end_date")
# 使用配置文件中的默认值,按关键词数量分配,但保证最小值
limit = self.config.default_search_topic_on_platform_limit // len(optimized_response.optimized_keywords)
limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT // len(optimized_response.optimized_keywords)
limit = max(limit, 30)
if not platform:
raise ValueError("search_topic_on_platform工具需要platform参数")
response = self.search_agency.search_topic_on_platform(platform=platform, topic=keyword, start_date=start_date, end_date=end_date, limit=limit)
else:
print(f" 未知的搜索工具: {tool_name},使用默认全局搜索")
response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=self.config.default_search_topic_globally_limit_per_table)
logger.info(f" 未知的搜索工具: {tool_name},使用默认全局搜索")
response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE)
# 收集结果
if response.results:
print(f" 找到 {len(response.results)} 条结果")
logger.info(f" 找到 {len(response.results)} 条结果")
all_results.extend(response.results)
total_count += len(response.results)
else:
print(f" 未找到结果")
logger.info(f" 未找到结果")
except Exception as e:
print(f" 查询'{keyword}'时出错: {str(e)}")
logger.error(f" 查询'{keyword}'时出错: {str(e)}")
continue
# 去重和整合结果
unique_results = self._deduplicate_results(all_results)
print(f" 总计找到 {total_count} 条结果,去重后 {len(unique_results)}")
logger.info(f" 总计找到 {total_count} 条结果,去重后 {len(unique_results)}")
# 构建整合后的响应
integrated_response = DBResponse(
@@ -244,12 +238,12 @@ class DeepSearchAgent:
# 检查是否需要进行情感分析
enable_sentiment = kwargs.get("enable_sentiment", True)
if enable_sentiment and unique_results and len(unique_results) > 0:
print(f" 🎭 开始对搜索结果进行情感分析...")
logger.info(f" 🎭 开始对搜索结果进行情感分析...")
sentiment_analysis = self._perform_sentiment_analysis(unique_results)
if sentiment_analysis:
# 将情感分析结果添加到响应的parameters中
integrated_response.parameters["sentiment_analysis"] = sentiment_analysis
print(f" ✅ 情感分析完成")
logger.info(f" ✅ 情感分析完成")
return integrated_response
@@ -282,11 +276,11 @@ class DeepSearchAgent:
try:
# 初始化情感分析器(如果尚未初始化且未被禁用)
if not self.sentiment_analyzer.is_initialized and not self.sentiment_analyzer.is_disabled:
print(" 初始化情感分析模型...")
logger.info(" 初始化情感分析模型...")
if not self.sentiment_analyzer.initialize():
print(" 情感分析模型初始化失败,将直接透传原始文本")
logger.info(" 情感分析模型初始化失败,将直接透传原始文本")
elif self.sentiment_analyzer.is_disabled:
print(" 情感分析功能已禁用,直接透传原始文本")
logger.info(" 情感分析功能已禁用,直接透传原始文本")
# 将查询结果转换为字典格式
results_dict = []
@@ -310,7 +304,7 @@ class DeepSearchAgent:
return sentiment_analysis.get("sentiment_analysis")
except Exception as e:
print(f" ❌ 情感分析过程中发生错误: {str(e)}")
logger.exception(f" ❌ 情感分析过程中发生错误: {str(e)}")
return None
def analyze_sentiment_only(self, texts: Union[str, List[str]]) -> Dict[str, Any]:
@@ -323,16 +317,16 @@ class DeepSearchAgent:
Returns:
情感分析结果
"""
print(f" → 执行独立情感分析")
logger.info(f" → 执行独立情感分析")
try:
# 初始化情感分析器(如果尚未初始化且未被禁用)
if not self.sentiment_analyzer.is_initialized and not self.sentiment_analyzer.is_disabled:
print(" 初始化情感分析模型...")
logger.info(" 初始化情感分析模型...")
if not self.sentiment_analyzer.initialize():
print(" 情感分析模型初始化失败,将直接透传原始文本")
logger.info(" 情感分析模型初始化失败,将直接透传原始文本")
elif self.sentiment_analyzer.is_disabled:
print(" 情感分析功能已禁用,直接透传原始文本")
logger.warning(" 情感分析功能已禁用,直接透传原始文本")
# 执行分析
if isinstance(texts, str):
@@ -368,7 +362,7 @@ class DeepSearchAgent:
return response
except Exception as e:
print(f" ❌ 情感分析过程中发生错误: {str(e)}")
logger.exception(f" ❌ 情感分析过程中发生错误: {str(e)}")
return {
"success": False,
"error": str(e),
@@ -386,9 +380,9 @@ class DeepSearchAgent:
Returns:
最终报告内容
"""
print(f"\n{'='*60}")
print(f"开始深度研究: {query}")
print(f"{'='*60}")
logger.info(f"\n{'='*60}")
logger.info(f"开始深度研究: {query}")
logger.info(f"{'='*60}")
try:
# Step 1: 生成报告结构
@@ -404,19 +398,17 @@ class DeepSearchAgent:
if save_report:
self._save_report(final_report)
print(f"\n{'='*60}")
print("深度研究完成!")
print(f"{'='*60}")
logger.info("深度研究完成!")
return final_report
except Exception as e:
print(f"研究过程中发生错误: {str(e)}")
logger.exception(f"研究过程中发生错误: {str(e)}")
raise e
def _generate_report_structure(self, query: str):
"""生成报告结构"""
print(f"\n[步骤 1] 生成报告结构...")
logger.info(f"\n[步骤 1] 生成报告结构...")
# 创建报告结构节点
report_structure_node = ReportStructureNode(self.llm_client, query)
@@ -424,17 +416,18 @@ class DeepSearchAgent:
# 生成结构并更新状态
self.state = report_structure_node.mutate_state(state=self.state)
print(f"报告结构已生成,共 {len(self.state.paragraphs)} 个段落:")
_message = f"报告结构已生成,共 {len(self.state.paragraphs)} 个段落:"
for i, paragraph in enumerate(self.state.paragraphs, 1):
print(f" {i}. {paragraph.title}")
_message += f"\n {i}. {paragraph.title}"
logger.info(_message)
def _process_paragraphs(self):
"""处理所有段落"""
total_paragraphs = len(self.state.paragraphs)
for i in range(total_paragraphs):
print(f"\n[步骤 2.{i+1}] 处理段落: {self.state.paragraphs[i].title}")
print("-" * 50)
logger.info(f"\n[步骤 2.{i+1}] 处理段落: {self.state.paragraphs[i].title}")
logger.info("-" * 50)
# 初始搜索和总结
self._initial_search_and_summary(i)
@@ -446,7 +439,7 @@ class DeepSearchAgent:
self.state.paragraphs[i].research.mark_completed()
progress = (i + 1) / total_paragraphs * 100
print(f"段落处理完成 ({progress:.1f}%)")
logger.info(f"段落处理完成 ({progress:.1f}%)")
def _initial_search_and_summary(self, paragraph_index: int):
"""执行初始搜索和总结"""
@@ -459,18 +452,18 @@ class DeepSearchAgent:
}
# 生成搜索查询和工具选择
print(" - 生成搜索查询...")
logger.info(" - 生成搜索查询...")
search_output = self.first_search_node.run(search_input)
search_query = search_output["search_query"]
search_tool = search_output.get("search_tool", "search_topic_globally") # 默认工具
reasoning = search_output["reasoning"]
print(f" - 搜索查询: {search_query}")
print(f" - 选择的工具: {search_tool}")
print(f" - 推理: {reasoning}")
logger.info(f" - 搜索查询: {search_query}")
logger.info(f" - 选择的工具: {search_tool}")
logger.info(f" - 推理: {reasoning}")
# 执行搜索
print(" - 执行数据库查询...")
logger.info(" - 执行数据库查询...")
# 处理特殊参数
search_kwargs = {}
@@ -485,13 +478,13 @@ class DeepSearchAgent:
if self._validate_date_format(start_date) and self._validate_date_format(end_date):
search_kwargs["start_date"] = start_date
search_kwargs["end_date"] = end_date
print(f" - 时间范围: {start_date}{end_date}")
logger.info(f" - 时间范围: {start_date}{end_date}")
else:
print(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索")
print(f" 提供的日期: start_date={start_date}, end_date={end_date}")
logger.info(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索")
logger.info(f" 提供的日期: start_date={start_date}, end_date={end_date}")
search_tool = "search_topic_globally"
elif search_tool == "search_topic_by_date":
print(f" search_topic_by_date工具缺少时间参数,改用全局搜索")
logger.info(f" search_topic_by_date工具缺少时间参数,改用全局搜索")
search_tool = "search_topic_globally"
# 处理需要平台参数的工具
@@ -499,28 +492,28 @@ class DeepSearchAgent:
platform = search_output.get("platform")
if platform:
search_kwargs["platform"] = platform
print(f" - 指定平台: {platform}")
logger.info(f" - 指定平台: {platform}")
else:
print(f" search_topic_on_platform工具缺少平台参数,改用全局搜索")
logger.warning(f" search_topic_on_platform工具缺少平台参数,改用全局搜索")
search_tool = "search_topic_globally"
# 处理限制参数,使用配置文件中的默认值而不是agent提供的参数
if search_tool == "search_hot_content":
time_period = search_output.get("time_period", "week")
limit = self.config.default_search_hot_content_limit
limit = self.config.DEFAULT_SEARCH_HOT_CONTENT_LIMIT
search_kwargs["time_period"] = time_period
search_kwargs["limit"] = limit
elif search_tool in ["search_topic_globally", "search_topic_by_date"]:
if search_tool == "search_topic_globally":
limit_per_table = self.config.default_search_topic_globally_limit_per_table
limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE
else: # search_topic_by_date
limit_per_table = self.config.default_search_topic_by_date_limit_per_table
limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE
search_kwargs["limit_per_table"] = limit_per_table
elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]:
if search_tool == "get_comments_for_topic":
limit = self.config.default_get_comments_for_topic_limit
limit = self.config.DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT
else: # search_topic_on_platform
limit = self.config.default_search_topic_on_platform_limit
limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT
search_kwargs["limit"] = limit
search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs)
@@ -529,8 +522,8 @@ class DeepSearchAgent:
search_results = []
if search_response and search_response.results:
# 使用配置文件控制传递给LLM的结果数量,0表示不限制
if self.config.max_search_results_for_llm > 0:
max_results = min(len(search_response.results), self.config.max_search_results_for_llm)
if self.config.MAX_SEARCH_RESULTS_FOR_LLM > 0:
max_results = min(len(search_response.results), self.config.MAX_SEARCH_RESULTS_FOR_LLM)
else:
max_results = len(search_response.results) # 不限制,传递所有结果
for result in search_response.results[:max_results]:
@@ -548,24 +541,25 @@ class DeepSearchAgent:
})
if search_results:
print(f" - 找到 {len(search_results)} 个搜索结果")
_message = f" - 找到 {len(search_results)} 个搜索结果"
for j, result in enumerate(search_results, 1):
date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else ""
print(f" {j}. {result['title'][:50]}...{date_info}")
_message += f"\n {j}. {result['title'][:50]}...{date_info}"
logger.info(_message)
else:
print(" - 未找到搜索结果")
logger.info(" - 未找到搜索结果")
# 更新状态中的搜索历史
paragraph.research.add_search_results(search_query, search_results)
# 生成初始总结
print(" - 生成初始总结...")
logger.info(" - 生成初始总结...")
summary_input = {
"title": paragraph.title,
"content": paragraph.content,
"search_query": search_query,
"search_results": format_search_results_for_prompt(
search_results, self.config.max_content_length
search_results, self.config.MAX_CONTENT_LENGTH
)
}
@@ -574,14 +568,14 @@ class DeepSearchAgent:
summary_input, self.state, paragraph_index
)
print(" - 初始总结完成")
logger.info(" - 初始总结完成")
def _reflection_loop(self, paragraph_index: int):
"""执行反思循环"""
paragraph = self.state.paragraphs[paragraph_index]
for reflection_i in range(self.config.max_reflections):
print(f" - 反思 {reflection_i + 1}/{self.config.max_reflections}...")
for reflection_i in range(self.config.MAX_REFLECTIONS):
logger.info(f" - 反思 {reflection_i + 1}/{self.config.MAX_REFLECTIONS}...")
# 准备反思输入
reflection_input = {
@@ -596,9 +590,9 @@ class DeepSearchAgent:
search_tool = reflection_output.get("search_tool", "search_topic_globally") # 默认工具
reasoning = reflection_output["reasoning"]
print(f" 反思查询: {search_query}")
print(f" 选择的工具: {search_tool}")
print(f" 反思推理: {reasoning}")
logger.info(f" 反思查询: {search_query}")
logger.info(f" 选择的工具: {search_tool}")
logger.info(f" 反思推理: {reasoning}")
# 执行反思搜索
# 处理特殊参数
@@ -614,13 +608,13 @@ class DeepSearchAgent:
if self._validate_date_format(start_date) and self._validate_date_format(end_date):
search_kwargs["start_date"] = start_date
search_kwargs["end_date"] = end_date
print(f" 时间范围: {start_date}{end_date}")
logger.info(f" 时间范围: {start_date}{end_date}")
else:
print(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索")
print(f" 提供的日期: start_date={start_date}, end_date={end_date}")
logger.info(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索")
logger.info(f" 提供的日期: start_date={start_date}, end_date={end_date}")
search_tool = "search_topic_globally"
elif search_tool == "search_topic_by_date":
print(f" search_topic_by_date工具缺少时间参数,改用全局搜索")
logger.warning(f" search_topic_by_date工具缺少时间参数,改用全局搜索")
search_tool = "search_topic_globally"
# 处理需要平台参数的工具
@@ -628,31 +622,31 @@ class DeepSearchAgent:
platform = reflection_output.get("platform")
if platform:
search_kwargs["platform"] = platform
print(f" 指定平台: {platform}")
logger.info(f" 指定平台: {platform}")
else:
print(f" search_topic_on_platform工具缺少平台参数,改用全局搜索")
logger.warning(f" search_topic_on_platform工具缺少平台参数,改用全局搜索")
search_tool = "search_topic_globally"
# 处理限制参数
if search_tool == "search_hot_content":
time_period = reflection_output.get("time_period", "week")
# 使用配置文件中的默认值,不允许agent控制limit参数
limit = self.config.default_search_hot_content_limit
limit = self.config.DEFAULT_SEARCH_HOT_CONTENT_LIMIT
search_kwargs["time_period"] = time_period
search_kwargs["limit"] = limit
elif search_tool in ["search_topic_globally", "search_topic_by_date"]:
# 使用配置文件中的默认值,不允许agent控制limit_per_table参数
if search_tool == "search_topic_globally":
limit_per_table = self.config.default_search_topic_globally_limit_per_table
limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE
else: # search_topic_by_date
limit_per_table = self.config.default_search_topic_by_date_limit_per_table
limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE
search_kwargs["limit_per_table"] = limit_per_table
elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]:
# 使用配置文件中的默认值,不允许agent控制limit参数
if search_tool == "get_comments_for_topic":
limit = self.config.default_get_comments_for_topic_limit
limit = self.config.DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT
else: # search_topic_on_platform
limit = self.config.default_search_topic_on_platform_limit
limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT
search_kwargs["limit"] = limit
search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs)
@@ -661,8 +655,8 @@ class DeepSearchAgent:
search_results = []
if search_response and search_response.results:
# 使用配置文件控制传递给LLM的结果数量,0表示不限制
if self.config.max_search_results_for_llm > 0:
max_results = min(len(search_response.results), self.config.max_search_results_for_llm)
if self.config.MAX_SEARCH_RESULTS_FOR_LLM > 0:
max_results = min(len(search_response.results), self.config.MAX_SEARCH_RESULTS_FOR_LLM)
else:
max_results = len(search_response.results) # 不限制,传递所有结果
for result in search_response.results[:max_results]:
@@ -680,12 +674,13 @@ class DeepSearchAgent:
})
if search_results:
print(f" 找到 {len(search_results)} 个反思搜索结果")
_message = f" 找到 {len(search_results)} 个反思搜索结果"
for j, result in enumerate(search_results, 1):
date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else ""
print(f" {j}. {result['title'][:50]}...{date_info}")
_message += f"\n {j}. {result['title'][:50]}...{date_info}"
logger.info(_message)
else:
print(" 未找到反思搜索结果")
logger.info(" 未找到反思搜索结果")
# 更新搜索历史
paragraph.research.add_search_results(search_query, search_results)
@@ -696,7 +691,7 @@ class DeepSearchAgent:
"content": paragraph.content,
"search_query": search_query,
"search_results": format_search_results_for_prompt(
search_results, self.config.max_content_length
search_results, self.config.MAX_CONTENT_LENGTH
),
"paragraph_latest_state": paragraph.research.latest_summary
}
@@ -706,11 +701,11 @@ class DeepSearchAgent:
reflection_summary_input, self.state, paragraph_index
)
print(f" 反思 {reflection_i + 1} 完成")
logger.info(f" 反思 {reflection_i + 1} 完成")
def _generate_final_report(self) -> str:
"""生成最终报告"""
print(f"\n[步骤 3] 生成最终报告...")
logger.info(f"\n[步骤 3] 生成最终报告...")
# 准备报告数据
report_data = []
@@ -724,7 +719,7 @@ class DeepSearchAgent:
try:
final_report = self.report_formatting_node.run(report_data)
except Exception as e:
print(f"LLM格式化失败,使用备用方法: {str(e)}")
logger.exception(f"LLM格式化失败,使用备用方法: {str(e)}")
final_report = self.report_formatting_node.format_report_manually(
report_data, self.state.report_title
)
@@ -733,7 +728,7 @@ class DeepSearchAgent:
self.state.final_report = final_report
self.state.mark_completed()
print("最终报告生成完成")
logger.info("最终报告生成完成")
return final_report
def _save_report(self, report_content: str):
@@ -744,20 +739,20 @@ class DeepSearchAgent:
query_safe = query_safe.replace(' ', '_')[:30]
filename = f"deep_search_report_{query_safe}_{timestamp}.md"
filepath = os.path.join(self.config.output_dir, filename)
filepath = os.path.join(self.config.OUTPUT_DIR, filename)
# 保存报告
with open(filepath, 'w', encoding='utf-8') as f:
f.write(report_content)
print(f"报告已保存到: {filepath}")
logger.info(f"报告已保存到: {filepath}")
# 保存状态(如果配置允许)
if self.config.save_intermediate_states:
if self.config.SAVE_INTERMEDIATE_STATES:
state_filename = f"state_{query_safe}_{timestamp}.json"
state_filepath = os.path.join(self.config.output_dir, state_filename)
state_filepath = os.path.join(self.config.OUTPUT_DIR, state_filename)
self.state.save_to_file(state_filepath)
print(f"状态已保存到: {state_filepath}")
logger.info(f"状态已保存到: {state_filepath}")
def get_progress_summary(self) -> Dict[str, Any]:
"""获取进度摘要"""
@@ -766,12 +761,12 @@ class DeepSearchAgent:
def load_state(self, filepath: str):
    """Replace the agent's state with the one stored in ``filepath``.

    Args:
        filepath: Path handed to ``State.load_from_file``.
    """
    self.state = State.load_from_file(filepath)
    # Logged once via loguru; the old print() emitted the same line again.
    logger.info(f"状态已从 {filepath} 加载")
def save_state(self, filepath: str):
    """Write the agent's current state to ``filepath``.

    Args:
        filepath: Path handed to ``self.state.save_to_file``.
    """
    self.state.save_to_file(filepath)
    # Logged once via loguru; the old print() emitted the same line again.
    logger.info(f"状态已保存到 {filepath}")
def create_agent(config_file: Optional[str] = None) -> DeepSearchAgent:
@@ -784,5 +779,5 @@ def create_agent(config_file: Optional[str] = None) -> DeepSearchAgent:
Returns:
DeepSearchAgent实例
"""
config = load_config(config_file)
config = settings
return DeepSearchAgent(config)
+2 -2
View File
@@ -31,9 +31,9 @@ class LLMClient:
def __init__(self, api_key: str, model_name: str, base_url: Optional[str] = None):
if not api_key:
raise ValueError("Insight Engine LLM API key is required.")
raise ValueError("Insight Engine INSIGHT_ENGINE_API_KEY is required.")
if not model_name:
raise ValueError("Insight Engine model name is required.")
raise ValueError("Insight Engine INSIGHT_ENGINE_MODEL_NAME is required.")
self.api_key = api_key
self.base_url = base_url
+7 -2
View File
@@ -5,6 +5,7 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from loguru import logger
from ..llms.base import LLMClient
from ..state.state import State
@@ -63,11 +64,15 @@ class BaseNode(ABC):
def log_info(self, message: str):
    """Log ``message`` at INFO level, prefixed with this node's name.

    Args:
        message: Text to log.
    """
    # Single emission through loguru; the old print() duplicated the line.
    logger.info(f"[{self.node_name}] {message}")
def log_warning(self, message: str):
    """Log ``message`` at WARNING level, tagged with this node's name.

    Args:
        message: Text to log.
    """
    tagged = f"[{self.node_name}] 警告: {message}"
    logger.warning(tagged)
def log_error(self, message: str):
    """Log ``message`` at ERROR level, prefixed with this node's name.

    Args:
        message: Text to log.
    """
    # Single emission through loguru; the old print() duplicated the line.
    logger.error(f"[{self.node_name}] 错误: {message}")
class StateMutationNode(BaseNode):
+13 -7
View File
@@ -5,6 +5,7 @@
import json
from typing import List, Dict, Any
from loguru import logger
from .base_node import BaseNode
from ..prompts import SYSTEM_PROMPT_REPORT_FORMATTING
@@ -14,6 +15,8 @@ from ..utils.text_processing import (
)
class ReportFormattingNode(BaseNode):
"""格式化最终报告的节点"""
@@ -65,19 +68,22 @@ class ReportFormattingNode(BaseNode):
else:
message = json.dumps(input_data, ensure_ascii=False)
self.log_info("正在格式化最终报告")
logger.info("正在格式化最终报告")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_FORMATTING, message)
response = self.llm_client.invoke(
SYSTEM_PROMPT_REPORT_FORMATTING,
message,
)
# 处理响应
processed_response = self.process_output(response)
self.log_info("成功生成格式化报告")
logger.info("成功生成格式化报告")
return processed_response
except Exception as e:
self.log_error(f"报告格式化失败: {str(e)}")
logger.exception(f"报告格式化失败: {str(e)}")
raise e
def process_output(self, output: str) -> str:
@@ -106,7 +112,7 @@ class ReportFormattingNode(BaseNode):
return cleaned_output.strip()
except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
logger.exception(f"处理输出失败: {str(e)}")
return "# 报告处理失败\n\n报告格式化过程中发生错误。"
def format_report_manually(self, paragraphs_data: List[Dict[str, str]],
@@ -122,7 +128,7 @@ class ReportFormattingNode(BaseNode):
格式化的Markdown报告
"""
try:
self.log_info("使用手动格式化方法")
logger.info("使用手动格式化方法")
# 构建报告
report_lines = [
@@ -160,5 +166,5 @@ class ReportFormattingNode(BaseNode):
return "\n".join(report_lines)
except Exception as e:
self.log_error(f"手动格式化失败: {str(e)}")
logger.exception(f"手动格式化失败: {str(e)}")
return "# 报告生成失败\n\n无法完成报告格式化。"
+21 -20
View File
@@ -6,6 +6,7 @@
import json
from typing import Dict, Any, List
from json.decoder import JSONDecodeError
from loguru import logger
from .base_node import StateMutationNode
from ..state.state import State
@@ -48,7 +49,7 @@ class ReportStructureNode(StateMutationNode):
报告结构列表
"""
try:
self.log_info(f"正在为查询生成报告结构: {self.query}")
logger.info(f"正在为查询生成报告结构: {self.query}")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query)
@@ -56,11 +57,11 @@ class ReportStructureNode(StateMutationNode):
# 处理响应
processed_response = self.process_output(response)
self.log_info(f"成功生成 {len(processed_response)} 个段落结构")
logger.info(f"成功生成 {len(processed_response)} 个段落结构")
return processed_response
except Exception as e:
self.log_error(f"生成报告结构失败: {str(e)}")
logger.exception(f"生成报告结构失败: {str(e)}")
raise e
def process_output(self, output: str) -> List[Dict[str, str]]:
@@ -79,54 +80,54 @@ class ReportStructureNode(StateMutationNode):
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output}")
logger.info(f"清理后的输出: {cleaned_output}")
# 解析JSON
try:
report_structure = json.loads(cleaned_output)
self.log_info("JSON解析成功")
logger.info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
logger.exception(f"JSON解析失败: {str(e)}")
# 使用更强大的提取方法
report_structure = extract_clean_response(cleaned_output)
if "error" in report_structure:
self.log_error("JSON解析失败,尝试修复...")
logger.exception("JSON解析失败,尝试修复...")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
report_structure = json.loads(fixed_json)
self.log_info("JSON修复成功")
logger.info("JSON修复成功")
except JSONDecodeError:
self.log_error("JSON修复失败")
logger.exception("JSON修复失败")
# 返回默认结构
return self._generate_default_structure()
else:
self.log_error("无法修复JSON,使用默认结构")
logger.exception("无法修复JSON,使用默认结构")
return self._generate_default_structure()
# 验证结构
if not isinstance(report_structure, list):
self.log_info("报告结构不是列表,尝试转换...")
logger.info("报告结构不是列表,尝试转换...")
if isinstance(report_structure, dict):
# 如果是单个对象,包装成列表
report_structure = [report_structure]
else:
self.log_error("报告结构格式无效,使用默认结构")
logger.exception("报告结构格式无效,使用默认结构")
return self._generate_default_structure()
# 验证每个段落
validated_structure = []
for i, paragraph in enumerate(report_structure):
if not isinstance(paragraph, dict):
self.log_warning(f"段落 {i+1} 不是字典格式,跳过")
logger.warning(f"段落 {i+1} 不是字典格式,跳过")
continue
title = paragraph.get("title", f"段落 {i+1}")
content = paragraph.get("content", "")
if not title or not content:
self.log_warning(f"段落 {i+1} 缺少标题或内容,跳过")
logger.warning(f"段落 {i+1} 缺少标题或内容,跳过")
continue
validated_structure.append({
@@ -135,14 +136,14 @@ class ReportStructureNode(StateMutationNode):
})
if not validated_structure:
self.log_warning("没有有效的段落结构,使用默认结构")
logger.warning("没有有效的段落结构,使用默认结构")
return self._generate_default_structure()
self.log_info(f"成功验证 {len(validated_structure)} 个段落结构")
logger.info(f"成功验证 {len(validated_structure)} 个段落结构")
return validated_structure
except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
logger.exception(f"处理输出失败: {str(e)}")
return self._generate_default_structure()
def _generate_default_structure(self) -> List[Dict[str, str]]:
@@ -152,7 +153,7 @@ class ReportStructureNode(StateMutationNode):
Returns:
默认的报告结构列表
"""
self.log_info("生成默认报告结构")
logger.info("生成默认报告结构")
return [
{
"title": "研究概述",
@@ -195,9 +196,9 @@ class ReportStructureNode(StateMutationNode):
content=paragraph_data["content"]
)
self.log_info(f"已将 {len(report_structure)} 个段落添加到状态中")
logger.info(f"已将 {len(report_structure)} 个段落添加到状态中")
return state
except Exception as e:
self.log_error(f"状态更新失败: {str(e)}")
logger.exception(f"状态更新失败: {str(e)}")
raise e
+24 -23
View File
@@ -6,6 +6,7 @@
import json
from typing import Dict, Any
from json.decoder import JSONDecodeError
from loguru import logger
from .base_node import BaseNode
from ..prompts import SYSTEM_PROMPT_FIRST_SEARCH, SYSTEM_PROMPT_REFLECTION
@@ -62,7 +63,7 @@ class FirstSearchNode(BaseNode):
else:
message = json.dumps(input_data, ensure_ascii=False)
self.log_info("正在生成首次搜索查询")
logger.info("正在生成首次搜索查询")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SEARCH, message)
@@ -70,11 +71,11 @@ class FirstSearchNode(BaseNode):
# 处理响应
processed_response = self.process_output(response)
self.log_info(f"生成搜索查询: {processed_response.get('search_query', 'N/A')}")
logger.info(f"生成搜索查询: {processed_response.get('search_query', 'N/A')}")
return processed_response
except Exception as e:
self.log_error(f"生成首次搜索查询失败: {str(e)}")
logger.exception(f"生成首次搜索查询失败: {str(e)}")
raise e
def process_output(self, output: str) -> Dict[str, str]:
@@ -93,30 +94,30 @@ class FirstSearchNode(BaseNode):
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output}")
logger.info(f"清理后的输出: {cleaned_output}")
# 解析JSON
try:
result = json.loads(cleaned_output)
self.log_info("JSON解析成功")
logger.info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
logger.exception(f"JSON解析失败: {str(e)}")
# 使用更强大的提取方法
result = extract_clean_response(cleaned_output)
if "error" in result:
self.log_error("JSON解析失败,尝试修复...")
logger.error("JSON解析失败,尝试修复...")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
result = json.loads(fixed_json)
self.log_info("JSON修复成功")
logger.info("JSON修复成功")
except JSONDecodeError:
self.log_error("JSON修复失败")
logger.error("JSON修复失败")
# 返回默认查询
return self._get_default_search_query()
else:
self.log_error("无法修复JSON,使用默认查询")
logger.error("无法修复JSON,使用默认查询")
return self._get_default_search_query()
# 验证和清理结果
@@ -124,7 +125,7 @@ class FirstSearchNode(BaseNode):
reasoning = result.get("reasoning", "")
if not search_query:
self.log_warning("未找到搜索查询,使用默认查询")
logger.warning("未找到搜索查询,使用默认查询")
return self._get_default_search_query()
return {
@@ -197,7 +198,7 @@ class ReflectionNode(BaseNode):
else:
message = json.dumps(input_data, ensure_ascii=False)
self.log_info("正在进行反思并生成新搜索查询")
logger.info("正在进行反思并生成新搜索查询")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION, message)
@@ -205,11 +206,11 @@ class ReflectionNode(BaseNode):
# 处理响应
processed_response = self.process_output(response)
self.log_info(f"反思生成搜索查询: {processed_response.get('search_query', 'N/A')}")
logger.info(f"反思生成搜索查询: {processed_response.get('search_query', 'N/A')}")
return processed_response
except Exception as e:
self.log_error(f"反思生成搜索查询失败: {str(e)}")
logger.exception(f"反思生成搜索查询失败: {str(e)}")
raise e
def process_output(self, output: str) -> Dict[str, str]:
@@ -228,30 +229,30 @@ class ReflectionNode(BaseNode):
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output}")
logger.info(f"清理后的输出: {cleaned_output}")
# 解析JSON
try:
result = json.loads(cleaned_output)
self.log_info("JSON解析成功")
logger.info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
logger.exception(f"JSON解析失败: {str(e)}")
# 使用更强大的提取方法
result = extract_clean_response(cleaned_output)
if "error" in result:
self.log_error("JSON解析失败,尝试修复...")
logger.error("JSON解析失败,尝试修复...")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
result = json.loads(fixed_json)
self.log_info("JSON修复成功")
logger.info("JSON修复成功")
except JSONDecodeError:
self.log_error("JSON修复失败")
logger.error("JSON修复失败")
# 返回默认查询
return self._get_default_reflection_query()
else:
self.log_error("无法修复JSON,使用默认查询")
logger.error("无法修复JSON,使用默认查询")
return self._get_default_reflection_query()
# 验证和清理结果
@@ -259,7 +260,7 @@ class ReflectionNode(BaseNode):
reasoning = result.get("reasoning", "")
if not search_query:
self.log_warning("未找到搜索查询,使用默认查询")
logger.warning("未找到搜索查询,使用默认查询")
return self._get_default_reflection_query()
return {
@@ -268,7 +269,7 @@ class ReflectionNode(BaseNode):
}
except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
logger.exception(f"处理输出失败: {str(e)}")
# 返回默认查询
return self._get_default_reflection_query()
+30 -29
View File
@@ -6,6 +6,7 @@
import json
from typing import Dict, Any, List
from json.decoder import JSONDecodeError
from loguru import logger
from .base_node import StateMutationNode
from ..state.state import State
@@ -27,7 +28,7 @@ try:
FORUM_READER_AVAILABLE = True
except ImportError:
FORUM_READER_AVAILABLE = False
print("警告: 无法导入forum_reader模块,将跳过HOST发言读取功能")
logger.warning("无法导入forum_reader模块,将跳过HOST发言读取功能")
class FirstSummaryNode(StateMutationNode):
@@ -84,9 +85,9 @@ class FirstSummaryNode(StateMutationNode):
if host_speech:
# 将HOST发言添加到输入数据中
data['host_speech'] = host_speech
self.log_info(f"已读取HOST发言,长度: {len(host_speech)}字符")
logger.info(f"已读取HOST发言,长度: {len(host_speech)}字符")
except Exception as e:
self.log_info(f"读取HOST发言失败: {str(e)}")
logger.exception(f"读取HOST发言失败: {str(e)}")
# 转换为JSON字符串
message = json.dumps(data, ensure_ascii=False)
@@ -96,7 +97,7 @@ class FirstSummaryNode(StateMutationNode):
formatted_host = format_host_speech_for_prompt(data['host_speech'])
message = formatted_host + "\n" + message
self.log_info("正在生成首次段落总结")
logger.info("正在生成首次段落总结")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SUMMARY, message)
@@ -104,11 +105,11 @@ class FirstSummaryNode(StateMutationNode):
# 处理响应
processed_response = self.process_output(response)
self.log_info("成功生成首次段落总结")
logger.info("成功生成首次段落总结")
return processed_response
except Exception as e:
self.log_error(f"生成首次总结失败: {str(e)}")
logger.exception(f"生成首次总结失败: {str(e)}")
raise e
def process_output(self, output: str) -> str:
@@ -127,26 +128,26 @@ class FirstSummaryNode(StateMutationNode):
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output}")
logger.info(f"清理后的输出: {cleaned_output}")
# 解析JSON
try:
result = json.loads(cleaned_output)
self.log_info("JSON解析成功")
logger.info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
logger.exception(f"JSON解析失败: {str(e)}")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
result = json.loads(fixed_json)
self.log_info("JSON修复成功")
logger.info("JSON修复成功")
except JSONDecodeError:
self.log_info("JSON修复失败,直接使用清理后的文本")
logger.exception("JSON修复失败,直接使用清理后的文本")
# 如果不是JSON格式,直接返回清理后的文本
return cleaned_output
else:
self.log_info("无法修复JSON,直接使用清理后的文本")
logger.exception("无法修复JSON,直接使用清理后的文本")
# 如果不是JSON格式,直接返回清理后的文本
return cleaned_output
@@ -160,7 +161,7 @@ class FirstSummaryNode(StateMutationNode):
return cleaned_output
except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
logger.exception(f"处理输出失败: {str(e)}")
return "段落总结生成失败"
def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State:
@@ -183,7 +184,7 @@ class FirstSummaryNode(StateMutationNode):
# 更新状态
if 0 <= paragraph_index < len(state.paragraphs):
state.paragraphs[paragraph_index].research.latest_summary = summary
self.log_info(f"已更新段落 {paragraph_index} 的首次总结")
logger.info(f"已更新段落 {paragraph_index} 的首次总结")
else:
raise ValueError(f"段落索引 {paragraph_index} 超出范围")
@@ -191,7 +192,7 @@ class FirstSummaryNode(StateMutationNode):
return state
except Exception as e:
self.log_error(f"状态更新失败: {str(e)}")
logger.exception(f"状态更新失败: {str(e)}")
raise e
@@ -249,9 +250,9 @@ class ReflectionSummaryNode(StateMutationNode):
if host_speech:
# 将HOST发言添加到输入数据中
data['host_speech'] = host_speech
self.log_info(f"已读取HOST发言,长度: {len(host_speech)}字符")
logger.info(f"已读取HOST发言,长度: {len(host_speech)}字符")
except Exception as e:
self.log_info(f"读取HOST发言失败: {str(e)}")
logger.exception(f"读取HOST发言失败: {str(e)}")
# 转换为JSON字符串
message = json.dumps(data, ensure_ascii=False)
@@ -261,7 +262,7 @@ class ReflectionSummaryNode(StateMutationNode):
formatted_host = format_host_speech_for_prompt(data['host_speech'])
message = formatted_host + "\n" + message
self.log_info("正在生成反思总结")
logger.info("正在生成反思总结")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION_SUMMARY, message)
@@ -269,11 +270,11 @@ class ReflectionSummaryNode(StateMutationNode):
# 处理响应
processed_response = self.process_output(response)
self.log_info("成功生成反思总结")
logger.info("成功生成反思总结")
return processed_response
except Exception as e:
self.log_error(f"生成反思总结失败: {str(e)}")
logger.exception(f"生成反思总结失败: {str(e)}")
raise e
def process_output(self, output: str) -> str:
@@ -292,26 +293,26 @@ class ReflectionSummaryNode(StateMutationNode):
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output}")
logger.info(f"清理后的输出: {cleaned_output}")
# 解析JSON
try:
result = json.loads(cleaned_output)
self.log_info("JSON解析成功")
logger.info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
logger.exception(f"JSON解析失败: {str(e)}")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
result = json.loads(fixed_json)
self.log_info("JSON修复成功")
logger.info("JSON修复成功")
except JSONDecodeError:
self.log_info("JSON修复失败,直接使用清理后的文本")
logger.info("JSON修复失败,直接使用清理后的文本")
# 如果不是JSON格式,直接返回清理后的文本
return cleaned_output
else:
self.log_info("无法修复JSON,直接使用清理后的文本")
logger.info("无法修复JSON,直接使用清理后的文本")
# 如果不是JSON格式,直接返回清理后的文本
return cleaned_output
@@ -325,7 +326,7 @@ class ReflectionSummaryNode(StateMutationNode):
return cleaned_output
except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
logger.exception(f"处理输出失败: {str(e)}")
return "反思总结生成失败"
def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State:
@@ -349,7 +350,7 @@ class ReflectionSummaryNode(StateMutationNode):
if 0 <= paragraph_index < len(state.paragraphs):
state.paragraphs[paragraph_index].research.latest_summary = updated_summary
state.paragraphs[paragraph_index].research.increment_reflection()
self.log_info(f"已更新段落 {paragraph_index} 的反思总结")
logger.info(f"已更新段落 {paragraph_index} 的反思总结")
else:
raise ValueError(f"段落索引 {paragraph_index} 超出范围")
@@ -357,5 +358,5 @@ class ReflectionSummaryNode(StateMutationNode):
return state
except Exception as e:
self.log_error(f"状态更新失败: {str(e)}")
logger.exception(f"状态更新失败: {str(e)}")
raise e
+16 -11
View File
@@ -12,7 +12,8 @@ from dataclasses import dataclass
# 添加项目根目录到Python路径以导入config
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from config import KEYWORD_OPTIMIZER_API_KEY, KEYWORD_OPTIMIZER_BASE_URL, KEYWORD_OPTIMIZER_MODEL_NAME
from config import settings
from loguru import logger
# 添加utils目录到Python路径
current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -46,18 +47,18 @@ class KeywordOptimizer:
api_key: 硅基流动API密钥,如果不提供则从配置文件读取
base_url: 接口基础地址,默认使用配置文件提供的SiliconFlow地址
"""
self.api_key = api_key or KEYWORD_OPTIMIZER_API_KEY
self.api_key = api_key or settings.KEYWORD_OPTIMIZER_API_KEY
if not self.api_key:
raise ValueError("未找到硅基流动API密钥,请在config.py中设置KEYWORD_OPTIMIZER_API_KEY")
self.base_url = base_url or KEYWORD_OPTIMIZER_BASE_URL
self.base_url = base_url or settings.KEYWORD_OPTIMIZER_BASE_URL
self.client = OpenAI(
api_key=self.api_key,
base_url=self.base_url
)
self.model = model_name or KEYWORD_OPTIMIZER_MODEL_NAME
self.model = model_name or settings.KEYWORD_OPTIMIZER_MODEL_NAME
def optimize_keywords(self, original_query: str, context: str = "") -> KeywordOptimizationResponse:
"""
@@ -70,7 +71,7 @@ class KeywordOptimizer:
Returns:
KeywordOptimizationResponse: 优化后的关键词列表
"""
print(f"🔍 关键词优化中间件: 处理查询 '{original_query}'")
logger.info(f"🔍 关键词优化中间件: 处理查询 '{original_query}'")
try:
# 构建优化prompt
@@ -97,9 +98,13 @@ class KeywordOptimizer:
# 验证关键词质量
validated_keywords = self._validate_keywords(keywords)
print(f"✅ 优化成功: {len(validated_keywords)}个关键词")
for i, keyword in enumerate(validated_keywords, 1):
print(f" {i}. '{keyword}'")
logger.info(
f"✅ 优化成功: {len(validated_keywords)}个关键词" +
("" if not validated_keywords else "\n" +
"\n".join([f" {i}. '{k}'" for i, k in enumerate(validated_keywords, 1)]))
)
return KeywordOptimizationResponse(
original_query=original_query,
@@ -109,7 +114,7 @@ class KeywordOptimizer:
)
except Exception as e:
print(f"⚠️ 解析响应失败,使用备用方案: {str(e)}")
logger.exception(f"⚠️ 解析响应失败,使用备用方案: {str(e)}")
# 备用方案:从原始查询中提取关键词
fallback_keywords = self._fallback_keyword_extraction(original_query)
return KeywordOptimizationResponse(
@@ -119,7 +124,7 @@ class KeywordOptimizer:
success=True
)
else:
print(f"❌ API调用失败: {response['error']}")
logger.error(f"❌ API调用失败: {response['error']}")
# 使用备用方案
fallback_keywords = self._fallback_keyword_extraction(original_query)
return KeywordOptimizationResponse(
@@ -131,7 +136,7 @@ class KeywordOptimizer:
)
except Exception as e:
print(f"❌ 关键词优化失败: {str(e)}")
logger.error(f"❌ 关键词优化失败: {str(e)}")
# 最终备用方案
fallback_keywords = self._fallback_keyword_extraction(original_query)
return KeywordOptimizationResponse(
+79 -72
View File
@@ -25,10 +25,11 @@ V3.0 核心更新:
import os
import json
import pymysql
import pymysql.cursors
from loguru import logger
import asyncio
from typing import List, Dict, Any, Optional, Literal
from dataclasses import dataclass, field
from ..utils.db import fetch_all
from datetime import datetime, timedelta, date
# --- 1. 数据结构定义 ---
@@ -69,36 +70,28 @@ class MediaCrawlerDB:
def __init__(self):
    """Create the database tool client.

    The initializer is intentionally a no-op: it does not read any DB_*
    environment variables and does not build a ``db_config`` mapping.
    Queries issued by this class are handed to the ``fetch_all`` helper
    imported from ``..utils.db``.
    """
def _execute_query(self, query: str, params: Optional[Any] = None) -> List[Dict[str, Any]]:
    """Run a SQL statement through the async ``fetch_all`` helper, synchronously.

    Args:
        query: SQL text to execute.
        params: Bind parameters, forwarded unchanged to ``fetch_all``. Callers
            in this file pass a dict of named parameters; ``None`` means no
            parameters.

    Returns:
        Whatever ``fetch_all`` returns (presumably a list of row dicts —
        confirm against ``utils.db``), or an empty list when anything fails.
        Failures are logged with a traceback and never raised to the caller.
    """
    try:
        # Reuse the current event loop when it exists and is still open;
        # otherwise install a fresh one.
        try:
            loop = asyncio.get_event_loop()
            needs_new_loop = loop.is_closed()
        except RuntimeError:
            # No event loop is available in this context.
            needs_new_loop = True
        if needs_new_loop:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        # NOTE(review): run_until_complete() raises RuntimeError when called
        # from inside an already-running loop; that case lands in the handler
        # below and yields [].
        return loop.run_until_complete(fetch_all(query, params))
    except Exception as e:
        logger.exception(f"数据库查询时发生错误: {e}")
        return []
@staticmethod
def _to_datetime(ts: Any) -> Optional[datetime]:
@@ -149,7 +142,7 @@ class MediaCrawlerDB:
DBResponse: 包含按综合热度排序后的内容列表。
"""
params_for_log = {'time_period': time_period, 'limit': limit}
print(f"--- TOOL: 查找热点内容 (params: {params_for_log}) ---")
logger.info(f"--- TOOL: 查找热点内容 (params: {params_for_log}) ---")
now = datetime.now()
start_time = now - timedelta(days={'24h': 1, 'week': 7}.get(time_period, 365))
@@ -202,22 +195,28 @@ class MediaCrawlerDB:
DBResponse: 包含所有匹配结果的聚合列表。
"""
params_for_log = {'topic': topic, 'limit_per_table': limit_per_table}
print(f"--- TOOL: 全局话题搜索 (params: {params_for_log}) ---")
logger.info(f"--- TOOL: 全局话题搜索 (params: {params_for_log}) ---")
search_term, all_results = f"%{topic}%", []
search_configs = { 'bilibili_video': {'fields': ['title', 'desc', 'source_keyword'], 'type': 'video'}, 'bilibili_video_comment': {'fields': ['content'], 'type': 'comment'}, 'douyin_aweme': {'fields': ['title', 'desc', 'source_keyword'], 'type': 'video'}, 'douyin_aweme_comment': {'fields': ['content'], 'type': 'comment'}, 'kuaishou_video': {'fields': ['title', 'desc', 'source_keyword'], 'type': 'video'}, 'kuaishou_video_comment': {'fields': ['content'], 'type': 'comment'}, 'weibo_note': {'fields': ['content', 'source_keyword'], 'type': 'note'}, 'weibo_note_comment': {'fields': ['content'], 'type': 'comment'}, 'xhs_note': {'fields': ['title', 'desc', 'tag_list', 'source_keyword'], 'type': 'note'}, 'xhs_note_comment': {'fields': ['content'], 'type': 'comment'}, 'zhihu_content': {'fields': ['title', 'desc', 'content_text', 'source_keyword'], 'type': 'content'}, 'zhihu_comment': {'fields': ['content'], 'type': 'comment'}, 'tieba_note': {'fields': ['title', 'desc', 'source_keyword'], 'type': 'note'}, 'tieba_comment': {'fields': ['content'], 'type': 'comment'}, 'daily_news': {'fields': ['title'], 'type': 'news'}, }
for table, config in search_configs.items():
where_clause = " OR ".join([f"`{field}` LIKE %s" for field in config['fields']])
query = f"SELECT * FROM `{table}` WHERE {where_clause} ORDER BY id DESC LIMIT %s"
params = (search_term,) * len(config['fields']) + (limit_per_table,)
raw_results = self._execute_query(query, params)
param_dict = {}
where_clauses = []
for idx, field in enumerate(config['fields']):
pname = f"term_{idx}"
where_clauses.append(f'"{field}" LIKE :{pname}')
param_dict[pname] = search_term
param_dict['limit'] = limit_per_table
where_clause = " OR ".join(where_clauses)
query = f'SELECT * FROM "{table}" WHERE {where_clause} ORDER BY id DESC LIMIT :limit'
raw_results = self._execute_query(query, param_dict)
for row in raw_results:
content = (row.get('title') or row.get('content') or row.get('desc') or row.get('content_text', ''))
time_key = row.get('create_time') or row.get('time') or row.get('created_time') or row.get('publish_time') or row.get('crawl_date')
all_results.append(QueryResult(
platform=table.split('_')[0], content_type=config['type'],
title_or_content=content[:500] if content else '',
title_or_content=content if content else '',
author_nickname=row.get('nickname') or row.get('user_nickname') or row.get('user_name'),
url=row.get('video_url') or row.get('note_url') or row.get('content_url') or row.get('url') or row.get('aweme_url'),
publish_time=self._to_datetime(time_key),
@@ -241,7 +240,7 @@ class MediaCrawlerDB:
DBResponse: 包含在指定日期范围内找到的结果的聚合列表。
"""
params_for_log = {'topic': topic, 'start_date': start_date, 'end_date': end_date, 'limit_per_table': limit_per_table}
print(f"--- TOOL: 按日期搜索话题 (params: {params_for_log}) ---")
logger.info(f"--- TOOL: 按日期搜索话题 (params: {params_for_log}) ---")
try:
start_dt, end_dt = datetime.strptime(start_date, '%Y-%m-%d'), datetime.strptime(end_date, '%Y-%m-%d') + timedelta(days=1)
@@ -257,25 +256,25 @@ class MediaCrawlerDB:
}
for table, config in search_configs.items():
topic_clause = " OR ".join([f"`{field}` LIKE %s" for field in config['fields']])
time_col, time_type = config['time_col'], config['time_type']
if time_type == 'sec': time_params = (int(start_dt.timestamp()), int(end_dt.timestamp()))
elif time_type == 'ms': time_params = (int(start_dt.timestamp() * 1000), int(end_dt.timestamp() * 1000))
elif time_type in ['str', 'date_str']: time_params = (start_dt.strftime('%Y-%m-%d'), end_dt.strftime('%Y-%m-%d'))
else: time_params = (str(int(start_dt.timestamp())), str(int(end_dt.timestamp())))
time_clause = f"`{time_col}` >= %s AND `{time_col}` < %s"
if table == 'zhihu_content': time_clause = f"CAST(`{time_col}` AS UNSIGNED) >= %s AND CAST(`{time_col}` AS UNSIGNED) < %s"
query = f"SELECT * FROM `{table}` WHERE ({topic_clause}) AND ({time_clause}) ORDER BY id DESC LIMIT %s"
params = (search_term,) * len(config['fields']) + time_params + (limit_per_table,)
raw_results = self._execute_query(query, params)
param_dict = {}
where_clauses = []
for idx, field in enumerate(config['fields']):
pname = f"term_{idx}"
where_clauses.append(f'"{field}" LIKE :{pname}')
param_dict[pname] = search_term
param_dict['limit'] = limit_per_table
where_clause = ' OR '.join(where_clauses)
query = f'SELECT * FROM "{table}" WHERE {where_clause} ORDER BY id DESC LIMIT :limit'
raw_results = self._execute_query(query, param_dict)
for row in raw_results:
content = (row.get('title') or row.get('content') or row.get('desc') or row.get('content_text', ''))
time_key = row.get('create_time') or row.get('time') or row.get('created_time') or row.get('publish_time') or row.get('crawl_date')
all_results.append(QueryResult(
platform=table.split('_')[0], content_type=config['type'],
title_or_content=content[:500] if content else '',
author_nickname=row.get('nickname') or row.get('user_nickname'),
title_or_content=content if content else '',
author_nickname=row.get('nickname') or row.get('user_nickname') or row.get('user_name'),
url=row.get('video_url') or row.get('note_url') or row.get('content_url') or row.get('url') or row.get('aweme_url'),
publish_time=self._to_datetime(row.get(config['time_col'])),
publish_time=self._to_datetime(time_key),
engagement=self._extract_engagement(row),
source_keyword=row.get('source_keyword'),
source_table=table
@@ -294,7 +293,7 @@ class MediaCrawlerDB:
DBResponse: 包含匹配的评论列表。
"""
params_for_log = {'topic': topic, 'limit': limit}
print(f"--- TOOL: 获取话题评论 (params: {params_for_log}) ---")
logger.info(f"--- TOOL: 获取话题评论 (params: {params_for_log}) ---")
search_term = f"%{topic}%"
comment_tables = ['bilibili_video_comment', 'douyin_aweme_comment', 'kuaishou_video_comment', 'weibo_note_comment', 'xhs_note_comment', 'zhihu_comment', 'tieba_comment']
@@ -341,7 +340,7 @@ class MediaCrawlerDB:
DBResponse: 包含在该平台找到的结果列表。
"""
params_for_log = {'platform': platform, 'topic': topic, 'start_date': start_date, 'end_date': end_date, 'limit': limit}
print(f"--- TOOL: 平台定向搜索 (params: {params_for_log}) ---")
logger.info(f"--- TOOL: 平台定向搜索 (params: {params_for_log}) ---")
all_configs = { 'bilibili': [{'table': 'bilibili_video', 'fields': ['title', 'desc', 'source_keyword'], 'type': 'video', 'time_col': 'create_time', 'time_type': 'sec'}, {'table': 'bilibili_video_comment', 'fields': ['content'], 'type': 'comment'}], 'douyin': [{'table': 'douyin_aweme', 'fields': ['title', 'desc', 'source_keyword'], 'type': 'video', 'time_col': 'create_time', 'time_type': 'ms'}, {'table': 'douyin_aweme_comment', 'fields': ['content'], 'type': 'comment'}], 'kuaishou': [{'table': 'kuaishou_video', 'fields': ['title', 'desc', 'source_keyword'], 'type': 'video', 'time_col': 'create_time', 'time_type': 'ms'}, {'table': 'kuaishou_video_comment', 'fields': ['content'], 'type': 'comment'}], 'weibo': [{'table': 'weibo_note', 'fields': ['content', 'source_keyword'], 'type': 'note', 'time_col': 'create_date_time', 'time_type': 'str'}, {'table': 'weibo_note_comment', 'fields': ['content'], 'type': 'comment'}], 'xhs': [{'table': 'xhs_note', 'fields': ['title', 'desc', 'tag_list', 'source_keyword'], 'type': 'note', 'time_col': 'time', 'time_type': 'ms'}, {'table': 'xhs_note_comment', 'fields': ['content'], 'type': 'comment'}], 'zhihu': [{'table': 'zhihu_content', 'fields': ['title', 'desc', 'content_text', 'source_keyword'], 'type': 'content', 'time_col': 'created_time', 'time_type': 'sec_str'}, {'table': 'zhihu_comment', 'fields': ['content'], 'type': 'comment'}], 'tieba': [{'table': 'tieba_note', 'fields': ['title', 'desc', 'source_keyword'], 'type': 'note', 'time_col': 'publish_time', 'time_type': 'str'}, {'table': 'tieba_comment', 'fields': ['content'], 'type': 'comment'}] }
@@ -386,7 +385,7 @@ class MediaCrawlerDB:
for row in raw_results:
content = (row.get('title') or row.get('content') or row.get('desc') or row.get('content_text', ''))
time_key = config.get('time_col') and row.get(config.get('time_col'))
all_results.append(QueryResult(platform=platform, content_type=config['type'], title_or_content=content[:500] if content else '', author_nickname=row.get('nickname') or row.get('user_nickname'), url=row.get('video_url') or row.get('note_url') or row.get('content_url') or row.get('url') or row.get('aweme_url'), publish_time=self._to_datetime(time_key), engagement=self._extract_engagement(row), source_keyword=row.get('source_keyword'), source_table=table))
all_results.append(QueryResult(platform=platform, content_type=config['type'], title_or_content=content if content else '', author_nickname=row.get('nickname') or row.get('user_nickname'), url=row.get('video_url') or row.get('note_url') or row.get('content_url') or row.get('url') or row.get('aweme_url'), publish_time=self._to_datetime(time_key), engagement=self._extract_engagement(row), source_keyword=row.get('source_keyword'), source_table=table))
return DBResponse("search_topic_on_platform", params_for_log, results=all_results, results_count=len(all_results))
@@ -394,33 +393,41 @@ class MediaCrawlerDB:
def print_response_summary(response: DBResponse):
    """Log a short, human-readable summary of a tool response.

    When the response carries an error message, that message is logged at
    ERROR level and nothing else is emitted. Otherwise the tool name, its
    parameters and the hit count are logged, followed by one multi-line
    message previewing at most the first five results.

    Args:
        response: The response to summarise. Reads ``error_message``,
            ``tool_name``, ``parameters``, ``results_count`` and ``results``.
    """
    if response.error_message:
        # A failed tool call is an error, not routine information.
        logger.error(f"工具 '{response.tool_name}' 执行出错: {response.error_message}")
        return

    params_str = ", ".join(f"{k}='{v}'" for k, v in response.parameters.items())
    logger.info(f"查询: 工具='{response.tool_name}', 参数=[{params_str}]")
    logger.info(f"找到 {response.results_count} 条相关记录。")

    # Collect the preview into one multi-line message so it is logged with a
    # single call.
    output_lines = ["==== 查询结果预览(最多前5条) ===="]
    if response.results:
        for idx, res in enumerate(response.results[:5], 1):
            # Flatten newlines before truncating so every preview stays on
            # one line, including content shorter than the 70-char limit.
            flat_text = (res.title_or_content or '').replace('\n', ' ')
            content_preview = flat_text[:70] + '...' if len(flat_text) > 70 else flat_text
            author_str = res.author_nickname or "N/A"
            publish_time_str = res.publish_time.strftime('%Y-%m-%d %H:%M') if res.publish_time else "N/A"
            # Treat a missing or None score as 0 so the comparison cannot raise.
            hotness = getattr(res, "hotness_score", 0) or 0
            hotness_str = f", hotness: {hotness:.2f}" if hotness > 0 else ""
            engagement_dict = getattr(res, "engagement", {}) or {}
            # Only non-zero / non-empty engagement counters are shown.
            engagement_str = ", ".join(f"{k}: {v}" for k, v in engagement_dict.items() if v)
            output_lines.append(
                f"{idx}. [{res.platform.upper()}/{res.content_type}] {content_preview}\n"
                f"   作者: {author_str} | 时间: {publish_time_str}"
                f"{hotness_str} | 源关键词: '{res.source_keyword or 'N/A'}'\n"
                f"   链接: {res.url or 'N/A'}\n"
                f"   互动数据: {{{engagement_str}}}"
            )
    else:
        output_lines.append("暂无相关内容。")
    output_lines.append("=" * 60)
    logger.info('\n'.join(output_lines))
if __name__ == "__main__":
try:
db_agent_tools = MediaCrawlerDB()
print("数据库工具初始化成功,开始执行测试场景...\n")
logger.info("数据库工具初始化成功,开始执行测试场景...\n")
# 场景1: (新) 查找过去一周综合热度最高的内容 (不再需要sort_by)
response1 = db_agent_tools.search_hot_content(time_period='week', limit=5)
@@ -443,7 +450,7 @@ if __name__ == "__main__":
print_response_summary(response5)
except ValueError as e:
print(f"初始化失败: {e}")
print("请确保相关的数据库环境变量已正确设置, 或在代码中直接提供连接信息。")
logger.exception(f"初始化失败: {e}")
logger.exception("请确保相关的数据库环境变量已正确设置, 或在代码中直接提供连接信息。")
except Exception as e:
print(f"测试过程中发生未知错误: {e}")
logger.exception(f"测试过程中发生未知错误: {e}")
-4
View File
@@ -12,8 +12,6 @@ from .text_processing import (
format_search_results_for_prompt
)
from .config import Config, load_config
__all__ = [
"clean_json_tags",
"clean_markdown_tags",
@@ -21,6 +19,4 @@ __all__ = [
"extract_clean_response",
"update_state_with_search_results",
"format_search_results_for_prompt",
"Config",
"load_config"
]
+33 -211
View File
@@ -6,218 +6,40 @@ Handles environment variables and config file parameters.
import os
from dataclasses import dataclass
from typing import Optional
from pydantic_settings import BaseSettings
from pydantic import Field
from loguru import logger
class Settings(BaseSettings):
# Insight Engine settings model built on pydantic-settings' BaseSettings.
# Every field below declares a default, so Settings() can be constructed
# without passing any values explicitly.

# --- LLM endpoint (all optional, default None) ---
INSIGHT_ENGINE_API_KEY: Optional[str] = Field(None, description="Insight Engine LLM API密钥")
INSIGHT_ENGINE_BASE_URL: Optional[str] = Field(None, description="Insight Engine LLM base url,可选")
INSIGHT_ENGINE_MODEL_NAME: Optional[str] = Field(None, description="Insight Engine LLM模型名称")
# The description marks the provider field as no longer recommended.
INSIGHT_ENGINE_PROVIDER: Optional[str] = Field(None, description="Insight Engine模型提供者,不再建议使用")
# --- Database connection ---
DB_HOST: Optional[str] = Field(None, description="数据库主机")
DB_USER: Optional[str] = Field(None, description="数据库用户名")
DB_PASSWORD: Optional[str] = Field(None, description="数据库密码")
DB_NAME: Optional[str] = Field(None, description="数据库名称")
DB_PORT: int = Field(3306, description="数据库端口")
DB_CHARSET: str = Field("utf8mb4", description="数据库字符集")
# Database dialect name (e.g. mysql, postgresql); defaults to "mysql".
DB_DIALECT: Optional[str] = Field("mysql", description="数据库方言,如mysql、postgresql等,SQLAlchemy后端选择")
# --- Research loop limits ---
MAX_REFLECTIONS: int = Field(3, description="最大反思次数")
MAX_PARAGRAPHS: int = Field(6, description="最大段落数")
# Per-search request timeout; presumably seconds -- TODO confirm the unit.
SEARCH_TIMEOUT: int = Field(240, description="单次搜索请求超时")
MAX_CONTENT_LENGTH: int = Field(500000, description="搜索最大内容长度")
# --- Default result-count limits for the individual search tools ---
DEFAULT_SEARCH_HOT_CONTENT_LIMIT: int = Field(100, description="热榜内容默认最大数")
DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE: int = Field(50, description="按表全局话题最大数")
DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE: int = Field(100, description="按日期话题最大数")
DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT: int = Field(500, description="单话题评论最大数")
DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT: int = Field(200, description="平台搜索话题最大数")
# NOTE(review): both limits below default to 0; whether 0 means
# "no limit" is not visible here -- confirm against the consumers.
MAX_SEARCH_RESULTS_FOR_LLM: int = Field(0, description="供LLM用搜索结果最大数")
MAX_HIGH_CONFIDENCE_SENTIMENT_RESULTS: int = Field(0, description="高置信度情感分析最大数")
# --- Report output ---
OUTPUT_DIR: str = Field("reports", description="输出路径")
SAVE_INTERMEDIATE_STATES: bool = Field(True, description="是否保存中间状态")
def _get_value(source, key: str, default=None):
    """
    Fetch a configuration value, falling back to the process environment.

    Lookup order:
        1. ``source[key]`` when ``source`` is a dict, otherwise
           ``getattr(source, key)``.
        2. ``os.getenv(key)`` when the source holds no usable value.
        3. ``default``.

    ``None`` and the empty string both count as "not set" at every step.
    Treating ``""`` as unset *before* the environment lookup matters for
    ``.env``-style files: a blank ``KEY=`` line must not hide a value that
    is exported in the environment.

    Args:
        source: A dict, or any object whose attributes hold the settings
            (for example an imported config module).
        key: Setting name; also used as the environment variable name.
        default: Value returned when neither the source nor the
            environment provides a non-empty value.

    Returns:
        The resolved value. Values taken from the environment are strings;
        values taken from ``source`` keep their original type.
    """
    if isinstance(source, dict):
        value = source.get(key)
    else:
        value = getattr(source, key, None)

    # An empty string is as good as missing: try the environment next.
    if value in ("", None):
        value = os.getenv(key, default)

    return default if value in ("", None) else value
@dataclass
class Config:
"""Insight Engine configuration."""
env_file = ".env"
env_prefix = ""
case_sensitive = False
extra = "allow"
# LLM configuration
llm_api_key: Optional[str] = None
llm_base_url: Optional[str] = None
llm_model_name: Optional[str] = None
llm_provider: Optional[str] = None # kept for backward compatibility
# Database configuration
db_host: Optional[str] = None
db_user: Optional[str] = None
db_password: Optional[str] = None
db_name: Optional[str] = None
db_port: int = 3306
db_charset: str = "utf8mb4"
# Model behaviour configuration
max_reflections: int = 3
max_paragraphs: int = 6
search_timeout: int = 240
max_content_length: int = 500000
# Search result limits
default_search_hot_content_limit: int = 100
default_search_topic_globally_limit_per_table: int = 50
default_search_topic_by_date_limit_per_table: int = 100
default_get_comments_for_topic_limit: int = 500
default_search_topic_on_platform_limit: int = 200
max_search_results_for_llm: int = 0
max_high_confidence_sentiment_results: int = 0
# Output configuration
output_dir: str = "reports"
save_intermediate_states: bool = True
def __post_init__(self):
if not self.llm_provider and self.llm_model_name:
# Provider is no longer used, but keep the attribute for compatibility.
self.llm_provider = self.llm_model_name
def validate(self) -> bool:
"""Validate configuration."""
if not self.llm_api_key:
print("错误: Insight Engine LLM API Key 未设置 (INSIGHT_ENGINE_API_KEY)。")
return False
if not self.llm_model_name:
print("错误: Insight Engine 模型名称未设置 (INSIGHT_ENGINE_MODEL_NAME)。")
return False
if not all([self.db_host, self.db_user, self.db_password, self.db_name]):
print("错误: 数据库连接信息不完整,请检查 config.py 中的 DB_* 配置。")
return False
return True
@classmethod
def from_file(cls, config_file: str) -> "Config":
    """Create a configuration from a Python module or a ``.env``-style file.

    Args:
        config_file: Path to the configuration file. A path ending in
            ``.py`` is executed as a module and its attributes are read.
            Any other path is parsed as ``KEY=VALUE`` lines; if that file
            does not exist an empty mapping is used, so every value comes
            from the environment or the built-in defaults.

    Returns:
        A populated ``Config``. Every key is resolved through
        ``_get_value``.

    Raises:
        ValueError: If an integer setting cannot be converted with ``int``.
    """
    if config_file.endswith(".py"):
        import importlib.util

        spec = importlib.util.spec_from_file_location("config", config_file)
        config_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(config_module)
        return cls._from_source(config_module)

    # .env style configuration
    return cls._from_source(cls._read_env_file(config_file))

@classmethod
def _from_source(cls, source) -> "Config":
    """Build a ``Config`` from a module or dict, one ``_get_value`` lookup per key."""

    def text(key, default=None):
        return _get_value(source, key, default)

    def integer(key, default):
        return int(_get_value(source, key, default))

    def flag(key, default):
        return str(_get_value(source, key, default)).lower() in ("true", "1", "yes")

    return cls(
        llm_api_key=text("INSIGHT_ENGINE_API_KEY"),
        llm_base_url=text("INSIGHT_ENGINE_BASE_URL"),
        llm_model_name=text("INSIGHT_ENGINE_MODEL_NAME"),
        db_host=text("DB_HOST"),
        db_user=text("DB_USER"),
        db_password=text("DB_PASSWORD"),
        db_name=text("DB_NAME"),
        db_port=integer("DB_PORT", 3306),
        db_charset=text("DB_CHARSET", "utf8mb4"),
        max_reflections=integer("MAX_REFLECTIONS", 3),
        max_paragraphs=integer("MAX_PARAGRAPHS", 6),
        search_timeout=integer("SEARCH_TIMEOUT", 240),
        max_content_length=integer("SEARCH_CONTENT_MAX_LENGTH", 500000),
        default_search_hot_content_limit=integer(
            "DEFAULT_SEARCH_HOT_CONTENT_LIMIT", 100
        ),
        default_search_topic_globally_limit_per_table=integer(
            "DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE", 50
        ),
        default_search_topic_by_date_limit_per_table=integer(
            "DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE", 100
        ),
        default_get_comments_for_topic_limit=integer(
            "DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT", 500
        ),
        default_search_topic_on_platform_limit=integer(
            "DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT", 200
        ),
        max_search_results_for_llm=integer("MAX_SEARCH_RESULTS_FOR_LLM", 0),
        max_high_confidence_sentiment_results=integer(
            "MAX_HIGH_CONFIDENCE_SENTIMENT_RESULTS", 0
        ),
        output_dir=text("OUTPUT_DIR", "reports"),
        save_intermediate_states=flag("SAVE_INTERMEDIATE_STATES", "true"),
    )

@staticmethod
def _read_env_file(path: str) -> dict:
    """Parse ``KEY=VALUE`` lines from *path*; a missing file gives an empty dict."""
    values = {}
    if not os.path.exists(path):
        return values
    with open(path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            line = raw_line.strip()
            # Skip blanks, comment lines, and anything that is not an assignment.
            if not line or line.startswith("#") or "=" not in line:
                continue
            key, value = line.split("=", 1)
            value = value.strip()
            # KEY="value" / KEY='value': the quotes are file syntax, not
            # part of the setting, so drop one matching pair.
            if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
                value = value[1:-1]
            values[key.strip()] = value
    return values
def load_config(config_file: Optional[str] = None) -> Config:
    """
    Load and validate the Insight Engine configuration.

    Args:
        config_file: Explicit path to a configuration file. When omitted,
            the first existing file among ``config.py``, ``config.env``
            and ``.env`` in the current working directory is used.

    Returns:
        The validated ``Config`` instance.

    Raises:
        FileNotFoundError: If ``config_file`` does not exist, or if no
            candidate file is found when it is omitted.
        ValueError: If the loaded configuration fails ``Config.validate``.
    """
    if config_file:
        if not os.path.exists(config_file):
            raise FileNotFoundError(f"配置文件不存在: {config_file}")
        chosen_path = config_file
    else:
        search_order = ("config.py", "config.env", ".env")
        chosen_path = next(
            (name for name in search_order if os.path.exists(name)), None
        )
        if chosen_path is None:
            raise FileNotFoundError("未找到配置文件,请创建 config.py。")
        print(f"已找到配置文件: {chosen_path}")

    loaded = Config.from_file(chosen_path)
    if loaded.validate():
        return loaded
    raise ValueError("配置校验失败,请检查 config.py 中的相关配置。")
def print_config(config: Config):
    """Print a summary of *config* to stdout with sensitive values masked.

    The API key and the database credentials are never printed; only
    whether they are configured is reported.
    """
    db_configured = all(
        [config.db_host, config.db_user, config.db_password, config.db_name]
    )
    key_status = '已配置' if config.llm_api_key else '未配置'
    db_status = '已配置' if db_configured else '未配置'

    report_rows = (
        "\n=== Insight Engine 配置 ===",
        f"LLM 模型: {config.llm_model_name}",
        f"LLM Base URL: {config.llm_base_url or '(默认)'}",
        f"搜索超时: {config.search_timeout}",
        f"最长内容长度: {config.max_content_length}",
        f"最大反思次数: {config.max_reflections}",
        f"最大段落数: {config.max_paragraphs}",
        f"输出目录: {config.output_dir}",
        f"保存中间状态: {config.save_intermediate_states}",
        f"LLM API Key: {key_status}",
        f"数据库连接: {db_status}",
        "========================\n",
    )
    for row in report_rows:
        print(row)
settings = Settings()
+2 -2
View File
@@ -4,9 +4,9 @@ Deep Search Agent
"""
from .agent import DeepSearchAgent, create_agent
from .utils.config import Config, load_config
from .utils.config import Settings
__version__ = "1.0.0"
__author__ = "Deep Search Agent Team"
__all__ = ["DeepSearchAgent", "create_agent", "Config", "load_config"]
__all__ = ["DeepSearchAgent", "create_agent", "Settings"]
+65 -63
View File
@@ -8,7 +8,7 @@ import os
import re
from datetime import datetime
from typing import Optional, Dict, Any, List
from loguru import logger
from .llms import LLMClient
from .nodes import (
ReportStructureNode,
@@ -20,29 +20,26 @@ from .nodes import (
)
from .state import State
from .tools import BochaMultimodalSearch, BochaResponse
from .utils import Config, load_config, format_search_results_for_prompt
from .utils import settings, Settings, format_search_results_for_prompt
class DeepSearchAgent:
"""Deep Search Agent主类"""
def __init__(self, config: Optional[Config] = None):
def __init__(self, config: Optional[Settings] = None):
"""
初始化Deep Search Agent
Args:
config: 配置对象,如果不提供则自动加载
"""
# 加载配置
self.config = config or load_config()
os.environ["BOCHA_API_KEY"] = self.config.bocha_api_key or ""
os.environ["BOCHA_WEB_SEARCH_API_KEY"] = self.config.bocha_api_key or ""
self.config = config or settings
# 初始化LLM客户端
self.llm_client = self._initialize_llm()
# 初始化搜索工具集
self.search_agency = BochaMultimodalSearch(api_key=self.config.bocha_api_key)
self.search_agency = BochaMultimodalSearch(api_key=(self.config.BOCHA_API_KEY or self.config.BOCHA_WEB_SEARCH_API_KEY))
# 初始化节点
self._initialize_nodes()
@@ -51,18 +48,18 @@ class DeepSearchAgent:
self.state = State()
# 确保输出目录存在
os.makedirs(self.config.output_dir, exist_ok=True)
os.makedirs(self.config.OUTPUT_DIR, exist_ok=True)
print(f"Meida Agent已初始化")
print(f"使用LLM: {self.llm_client.get_model_info()}")
print(f"搜索工具集: BochaMultimodalSearch (支持5种多模态搜索工具)")
logger.info(f"Meida Agent已初始化")
logger.info(f"使用LLM: {self.llm_client.get_model_info()}")
logger.info(f"搜索工具集: BochaMultimodalSearch (支持5种多模态搜索工具)")
def _initialize_llm(self) -> LLMClient:
"""初始化LLM客户端"""
return LLMClient(
api_key=self.config.llm_api_key,
model_name=self.config.llm_model_name,
base_url=self.config.llm_base_url,
api_key=(self.config.MEDIA_ENGINE_API_KEY or self.config.MINDSPIDER_API_KEY),
model_name=(self.config.MEDIA_ENGINE_MODEL_NAME or self.config.MINDSPIDER_MODEL_NAME),
base_url=(self.config.MEDIA_ENGINE_BASE_URL or self.config.MINDSPIDER_BASE_URL),
)
def _initialize_nodes(self):
@@ -115,7 +112,7 @@ class DeepSearchAgent:
Returns:
BochaResponse对象
"""
print(f" → 执行搜索工具: {tool_name}")
logger.info(f" → 执行搜索工具: {tool_name}")
if tool_name == "comprehensive_search":
max_results = kwargs.get("max_results", 10)
@@ -130,7 +127,7 @@ class DeepSearchAgent:
elif tool_name == "search_last_week":
return self.search_agency.search_last_week(query)
else:
print(f" ⚠️ 未知的搜索工具: {tool_name},使用默认综合搜索")
logger.info(f" ⚠️ 未知的搜索工具: {tool_name},使用默认综合搜索")
return self.search_agency.comprehensive_search(query)
def research(self, query: str, save_report: bool = True) -> str:
@@ -144,9 +141,9 @@ class DeepSearchAgent:
Returns:
最终报告内容
"""
print(f"\n{'='*60}")
print(f"开始深度研究: {query}")
print(f"{'='*60}")
logger.info(f"\n{'='*60}")
logger.info(f"开始深度研究: {query}")
logger.info(f"{'='*60}")
try:
# Step 1: 生成报告结构
@@ -162,19 +159,21 @@ class DeepSearchAgent:
if save_report:
self._save_report(final_report)
print(f"\n{'='*60}")
print("深度研究完成!")
print(f"{'='*60}")
logger.info(f"\n{'='*60}")
logger.info("深度研究完成!")
logger.info(f"{'='*60}")
return final_report
except Exception as e:
print(f"研究过程中发生错误: {str(e)}")
import traceback
error_traceback = traceback.format_exc()
logger.error(f"研究过程中发生错误: {str(e)} \n错误堆栈: {error_traceback}")
raise e
def _generate_report_structure(self, query: str):
"""生成报告结构"""
print(f"\n[步骤 1] 生成报告结构...")
logger.info(f"\n[步骤 1] 生成报告结构...")
# 创建报告结构节点
report_structure_node = ReportStructureNode(self.llm_client, query)
@@ -182,17 +181,18 @@ class DeepSearchAgent:
# 生成结构并更新状态
self.state = report_structure_node.mutate_state(state=self.state)
print(f"报告结构已生成,共 {len(self.state.paragraphs)} 个段落:")
_message = f"报告结构已生成,共 {len(self.state.paragraphs)} 个段落:"
for i, paragraph in enumerate(self.state.paragraphs, 1):
print(f" {i}. {paragraph.title}")
_message += f"\n {i}. {paragraph.title}"
logger.info(_message)
def _process_paragraphs(self):
"""处理所有段落"""
total_paragraphs = len(self.state.paragraphs)
for i in range(total_paragraphs):
print(f"\n[步骤 2.{i+1}] 处理段落: {self.state.paragraphs[i].title}")
print("-" * 50)
logger.info(f"\n[步骤 2.{i+1}] 处理段落: {self.state.paragraphs[i].title}")
logger.info("-" * 50)
# 初始搜索和总结
self._initial_search_and_summary(i)
@@ -204,7 +204,7 @@ class DeepSearchAgent:
self.state.paragraphs[i].research.mark_completed()
progress = (i + 1) / total_paragraphs * 100
print(f"段落处理完成 ({progress:.1f}%)")
logger.info(f"段落处理完成 ({progress:.1f}%)")
def _initial_search_and_summary(self, paragraph_index: int):
"""执行初始搜索和总结"""
@@ -217,18 +217,18 @@ class DeepSearchAgent:
}
# 生成搜索查询和工具选择
print(" - 生成搜索查询...")
logger.info(" - 生成搜索查询...")
search_output = self.first_search_node.run(search_input)
search_query = search_output["search_query"]
search_tool = search_output.get("search_tool", "comprehensive_search") # 默认工具
reasoning = search_output["reasoning"]
print(f" - 搜索查询: {search_query}")
print(f" - 选择的工具: {search_tool}")
print(f" - 推理: {reasoning}")
logger.info(f" - 搜索查询: {search_query}")
logger.info(f" - 选择的工具: {search_tool}")
logger.info(f" - 推理: {reasoning}")
# 执行搜索
print(" - 执行网络搜索...")
logger.info(" - 执行网络搜索...")
# 处理特殊参数(新的工具集不需要日期参数处理)
search_kwargs = {}
@@ -254,24 +254,25 @@ class DeepSearchAgent:
})
if search_results:
print(f" - 找到 {len(search_results)} 个搜索结果")
_message = f" - 找到 {len(search_results)} 个搜索结果"
for j, result in enumerate(search_results, 1):
date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else ""
print(f" {j}. {result['title'][:50]}...{date_info}")
_message += f"\n {j}. {result['title'][:50]}...{date_info}"
logger.info(_message)
else:
print(" - 未找到搜索结果")
logger.info(" - 未找到搜索结果")
# 更新状态中的搜索历史
paragraph.research.add_search_results(search_query, search_results)
# 生成初始总结
print(" - 生成初始总结...")
logger.info(" - 生成初始总结...")
summary_input = {
"title": paragraph.title,
"content": paragraph.content,
"search_query": search_query,
"search_results": format_search_results_for_prompt(
search_results, self.config.max_content_length
search_results, self.config.SEARCH_CONTENT_MAX_LENGTH
)
}
@@ -280,14 +281,14 @@ class DeepSearchAgent:
summary_input, self.state, paragraph_index
)
print(" - 初始总结完成")
logger.info(" - 初始总结完成")
def _reflection_loop(self, paragraph_index: int):
"""执行反思循环"""
paragraph = self.state.paragraphs[paragraph_index]
for reflection_i in range(self.config.max_reflections):
print(f" - 反思 {reflection_i + 1}/{self.config.max_reflections}...")
for reflection_i in range(self.config.MAX_REFLECTIONS):
logger.info(f" - 反思 {reflection_i + 1}/{self.config.MAX_REFLECTIONS}...")
# 准备反思输入
reflection_input = {
@@ -302,9 +303,9 @@ class DeepSearchAgent:
search_tool = reflection_output.get("search_tool", "comprehensive_search") # 默认工具
reasoning = reflection_output["reasoning"]
print(f" 反思查询: {search_query}")
print(f" 选择的工具: {search_tool}")
print(f" 反思推理: {reasoning}")
logger.info(f" 反思查询: {search_query}")
logger.info(f" 选择的工具: {search_tool}")
logger.info(f" 反思推理: {reasoning}")
# 执行反思搜索
# 处理特殊参数
@@ -331,12 +332,13 @@ class DeepSearchAgent:
})
if search_results:
print(f" 找到 {len(search_results)} 个反思搜索结果")
_message = f" 找到 {len(search_results)} 个反思搜索结果"
for j, result in enumerate(search_results, 1):
date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else ""
print(f" {j}. {result['title'][:50]}...{date_info}")
_message += f"\n {j}. {result['title'][:50]}...{date_info}"
logger.info(_message)
else:
print(" 未找到反思搜索结果")
logger.info(" 未找到反思搜索结果")
# 更新搜索历史
paragraph.research.add_search_results(search_query, search_results)
@@ -347,7 +349,7 @@ class DeepSearchAgent:
"content": paragraph.content,
"search_query": search_query,
"search_results": format_search_results_for_prompt(
search_results, self.config.max_content_length
search_results, self.config.SEARCH_CONTENT_MAX_LENGTH
),
"paragraph_latest_state": paragraph.research.latest_summary
}
@@ -357,11 +359,11 @@ class DeepSearchAgent:
reflection_summary_input, self.state, paragraph_index
)
print(f" 反思 {reflection_i + 1} 完成")
logger.info(f" 反思 {reflection_i + 1} 完成")
def _generate_final_report(self) -> str:
"""生成最终报告"""
print(f"\n[步骤 3] 生成最终报告...")
logger.info(f"\n[步骤 3] 生成最终报告...")
# 准备报告数据
report_data = []
@@ -375,7 +377,7 @@ class DeepSearchAgent:
try:
final_report = self.report_formatting_node.run(report_data)
except Exception as e:
print(f"LLM格式化失败,使用备用方法: {str(e)}")
logger.info(f"LLM格式化失败,使用备用方法: {str(e)}")
final_report = self.report_formatting_node.format_report_manually(
report_data, self.state.report_title
)
@@ -384,7 +386,7 @@ class DeepSearchAgent:
self.state.final_report = final_report
self.state.mark_completed()
print("最终报告生成完成")
logger.info("最终报告生成完成")
return final_report
def _save_report(self, report_content: str):
@@ -395,20 +397,20 @@ class DeepSearchAgent:
query_safe = query_safe.replace(' ', '_')[:30]
filename = f"deep_search_report_{query_safe}_{timestamp}.md"
filepath = os.path.join(self.config.output_dir, filename)
filepath = os.path.join(self.config.OUTPUT_DIR, filename)
# 保存报告
with open(filepath, 'w', encoding='utf-8') as f:
f.write(report_content)
print(f"报告已保存到: {filepath}")
logger.info(f"报告已保存到: {filepath}")
# 保存状态(如果配置允许)
if self.config.save_intermediate_states:
if self.config.SAVE_INTERMEDIATE_STATES:
state_filename = f"state_{query_safe}_{timestamp}.json"
state_filepath = os.path.join(self.config.output_dir, state_filename)
state_filepath = os.path.join(self.config.OUTPUT_DIR, state_filename)
self.state.save_to_file(state_filepath)
print(f"状态已保存到: {state_filepath}")
logger.info(f"状态已保存到: {state_filepath}")
def get_progress_summary(self) -> Dict[str, Any]:
"""获取进度摘要"""
@@ -417,12 +419,12 @@ class DeepSearchAgent:
def load_state(self, filepath: str):
"""从文件加载状态"""
self.state = State.load_from_file(filepath)
print(f"状态已从 {filepath} 加载")
logger.info(f"状态已从 {filepath} 加载")
def save_state(self, filepath: str):
"""保存状态到文件"""
self.state.save_to_file(filepath)
print(f"状态已保存到 {filepath}")
logger.info(f"状态已保存到 {filepath}")
def create_agent(config_file: Optional[str] = None) -> DeepSearchAgent:
@@ -435,5 +437,5 @@ def create_agent(config_file: Optional[str] = None) -> DeepSearchAgent:
Returns:
DeepSearchAgent实例
"""
config = load_config(config_file)
return DeepSearchAgent(config)
settings = Settings()
return DeepSearchAgent(settings)
+7 -2
View File
@@ -7,6 +7,7 @@ from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from ..llms.base import LLMClient
from ..state.state import State
from loguru import logger
class BaseNode(ABC):
@@ -63,11 +64,15 @@ class BaseNode(ABC):
def log_info(self, message: str):
"""记录信息日志"""
print(f"[{self.node_name}] {message}")
logger.info(f"[{self.node_name}] {message}")
def log_warning(self, message: str):
"""Log *message* at warning level, prefixed with this node's name."""
logger.warning(f"[{self.node_name}] 警告: {message}")
def log_error(self, message: str):
"""记录错误日志"""
print(f"[{self.node_name}] 错误: {message}")
logger.error(f"[{self.node_name}] 错误: {message}")
class StateMutationNode(BaseNode):
+7 -6
View File
@@ -5,6 +5,7 @@
import json
from typing import List, Dict, Any
from loguru import logger
from .base_node import BaseNode
from ..prompts import SYSTEM_PROMPT_REPORT_FORMATTING
@@ -65,7 +66,7 @@ class ReportFormattingNode(BaseNode):
else:
message = json.dumps(input_data, ensure_ascii=False)
self.log_info("正在格式化最终报告")
logger.info("正在格式化最终报告")
# 调用LLM生成Markdown格式
response = self.llm_client.invoke(
@@ -76,11 +77,11 @@ class ReportFormattingNode(BaseNode):
# 处理响应
processed_response = self.process_output(response)
self.log_info("成功生成格式化报告")
logger.info("成功生成格式化报告")
return processed_response
except Exception as e:
self.log_error(f"报告格式化失败: {str(e)}")
logger.exception(f"报告格式化失败: {str(e)}")
raise e
def process_output(self, output: str) -> str:
@@ -109,7 +110,7 @@ class ReportFormattingNode(BaseNode):
return cleaned_output.strip()
except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
logger.exception(f"处理输出失败: {str(e)}")
return "# 报告处理失败\n\n报告格式化过程中发生错误。"
def format_report_manually(self, paragraphs_data: List[Dict[str, str]],
@@ -125,7 +126,7 @@ class ReportFormattingNode(BaseNode):
格式化的Markdown报告
"""
try:
self.log_info("使用手动格式化方法")
logger.info("使用手动格式化方法")
# 构建报告
report_lines = [
@@ -163,5 +164,5 @@ class ReportFormattingNode(BaseNode):
return "\n".join(report_lines)
except Exception as e:
self.log_error(f"手动格式化失败: {str(e)}")
logger.exception(f"手动格式化失败: {str(e)}")
return "# 报告生成失败\n\n无法完成报告格式化。"
+21 -20
View File
@@ -6,6 +6,7 @@
import json
from typing import Dict, Any, List
from json.decoder import JSONDecodeError
from loguru import logger
from .base_node import StateMutationNode
from ..state.state import State
@@ -48,7 +49,7 @@ class ReportStructureNode(StateMutationNode):
报告结构列表
"""
try:
self.log_info(f"正在为查询生成报告结构: {self.query}")
logger.info(f"正在为查询生成报告结构: {self.query}")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query)
@@ -56,11 +57,11 @@ class ReportStructureNode(StateMutationNode):
# 处理响应
processed_response = self.process_output(response)
self.log_info(f"成功生成 {len(processed_response)} 个段落结构")
logger.info(f"成功生成 {len(processed_response)} 个段落结构")
return processed_response
except Exception as e:
self.log_error(f"生成报告结构失败: {str(e)}")
logger.exception(f"生成报告结构失败: {str(e)}")
raise e
def process_output(self, output: str) -> List[Dict[str, str]]:
@@ -79,54 +80,54 @@ class ReportStructureNode(StateMutationNode):
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output}")
logger.info(f"清理后的输出: {cleaned_output}")
# 解析JSON
try:
report_structure = json.loads(cleaned_output)
self.log_info("JSON解析成功")
logger.info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
logger.exception(f"JSON解析失败: {str(e)}")
# 使用更强大的提取方法
report_structure = extract_clean_response(cleaned_output)
if "error" in report_structure:
self.log_error("JSON解析失败,尝试修复...")
logger.error("JSON解析失败,尝试修复...")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
report_structure = json.loads(fixed_json)
self.log_info("JSON修复成功")
logger.info("JSON修复成功")
except JSONDecodeError:
self.log_error("JSON修复失败")
logger.error("JSON修复失败")
# 返回默认结构
return self._generate_default_structure()
else:
self.log_error("无法修复JSON,使用默认结构")
logger.error("无法修复JSON,使用默认结构")
return self._generate_default_structure()
# 验证结构
if not isinstance(report_structure, list):
self.log_info("报告结构不是列表,尝试转换...")
logger.info("报告结构不是列表,尝试转换...")
if isinstance(report_structure, dict):
# 如果是单个对象,包装成列表
report_structure = [report_structure]
else:
self.log_error("报告结构格式无效,使用默认结构")
logger.error("报告结构格式无效,使用默认结构")
return self._generate_default_structure()
# 验证每个段落
validated_structure = []
for i, paragraph in enumerate(report_structure):
if not isinstance(paragraph, dict):
self.log_warning(f"段落 {i+1} 不是字典格式,跳过")
logger.warning(f"段落 {i+1} 不是字典格式,跳过")
continue
title = paragraph.get("title", f"段落 {i+1}")
content = paragraph.get("content", "")
if not title or not content:
self.log_warning(f"段落 {i+1} 缺少标题或内容,跳过")
logger.warning(f"段落 {i+1} 缺少标题或内容,跳过")
continue
validated_structure.append({
@@ -135,14 +136,14 @@ class ReportStructureNode(StateMutationNode):
})
if not validated_structure:
self.log_warning("没有有效的段落结构,使用默认结构")
logger.warning("没有有效的段落结构,使用默认结构")
return self._generate_default_structure()
self.log_info(f"成功验证 {len(validated_structure)} 个段落结构")
logger.info(f"成功验证 {len(validated_structure)} 个段落结构")
return validated_structure
except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
logger.exception(f"处理输出失败: {str(e)}")
return self._generate_default_structure()
def _generate_default_structure(self) -> List[Dict[str, str]]:
@@ -152,7 +153,7 @@ class ReportStructureNode(StateMutationNode):
Returns:
默认的报告结构列表
"""
self.log_info("生成默认报告结构")
logger.info("生成默认报告结构")
return [
{
"title": "研究概述",
@@ -195,9 +196,9 @@ class ReportStructureNode(StateMutationNode):
content=paragraph_data["content"]
)
self.log_info(f"已将 {len(report_structure)} 个段落添加到状态中")
logger.info(f"已将 {len(report_structure)} 个段落添加到状态中")
return state
except Exception as e:
self.log_error(f"状态更新失败: {str(e)}")
logger.exception(f"状态更新失败: {str(e)}")
raise e
+24 -23
View File
@@ -6,6 +6,7 @@
import json
from typing import Dict, Any
from json.decoder import JSONDecodeError
from loguru import logger
from .base_node import BaseNode
from ..prompts import SYSTEM_PROMPT_FIRST_SEARCH, SYSTEM_PROMPT_REFLECTION
@@ -62,7 +63,7 @@ class FirstSearchNode(BaseNode):
else:
message = json.dumps(input_data, ensure_ascii=False)
self.log_info("正在生成首次搜索查询")
logger.info("正在生成首次搜索查询")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SEARCH, message)
@@ -70,11 +71,11 @@ class FirstSearchNode(BaseNode):
# 处理响应
processed_response = self.process_output(response)
self.log_info(f"生成搜索查询: {processed_response.get('search_query', 'N/A')}")
logger.info(f"生成搜索查询: {processed_response.get('search_query', 'N/A')}")
return processed_response
except Exception as e:
self.log_error(f"生成首次搜索查询失败: {str(e)}")
logger.exception(f"生成首次搜索查询失败: {str(e)}")
raise e
def process_output(self, output: str) -> Dict[str, str]:
@@ -93,30 +94,30 @@ class FirstSearchNode(BaseNode):
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output}")
logger.info(f"清理后的输出: {cleaned_output}")
# 解析JSON
try:
result = json.loads(cleaned_output)
self.log_info("JSON解析成功")
logger.info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
logger.exception(f"JSON解析失败: {str(e)}")
# 使用更强大的提取方法
result = extract_clean_response(cleaned_output)
if "error" in result:
self.log_error("JSON解析失败,尝试修复...")
logger.error("JSON解析失败,尝试修复...")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
result = json.loads(fixed_json)
self.log_info("JSON修复成功")
logger.info("JSON修复成功")
except JSONDecodeError:
self.log_error("JSON修复失败")
logger.error("JSON修复失败")
# 返回默认查询
return self._get_default_search_query()
else:
self.log_error("无法修复JSON,使用默认查询")
logger.error("无法修复JSON,使用默认查询")
return self._get_default_search_query()
# 验证和清理结果
@@ -124,7 +125,7 @@ class FirstSearchNode(BaseNode):
reasoning = result.get("reasoning", "")
if not search_query:
self.log_warning("未找到搜索查询,使用默认查询")
logger.warning("未找到搜索查询,使用默认查询")
return self._get_default_search_query()
return {
@@ -197,7 +198,7 @@ class ReflectionNode(BaseNode):
else:
message = json.dumps(input_data, ensure_ascii=False)
self.log_info("正在进行反思并生成新搜索查询")
logger.info("正在进行反思并生成新搜索查询")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION, message)
@@ -205,11 +206,11 @@ class ReflectionNode(BaseNode):
# 处理响应
processed_response = self.process_output(response)
self.log_info(f"反思生成搜索查询: {processed_response.get('search_query', 'N/A')}")
logger.info(f"反思生成搜索查询: {processed_response.get('search_query', 'N/A')}")
return processed_response
except Exception as e:
self.log_error(f"反思生成搜索查询失败: {str(e)}")
logger.exception(f"反思生成搜索查询失败: {str(e)}")
raise e
def process_output(self, output: str) -> Dict[str, str]:
@@ -228,30 +229,30 @@ class ReflectionNode(BaseNode):
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output}")
logger.info(f"清理后的输出: {cleaned_output}")
# 解析JSON
try:
result = json.loads(cleaned_output)
self.log_info("JSON解析成功")
logger.info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
logger.exception(f"JSON解析失败: {str(e)}")
# 使用更强大的提取方法
result = extract_clean_response(cleaned_output)
if "error" in result:
self.log_error("JSON解析失败,尝试修复...")
logger.error("JSON解析失败,尝试修复...")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
result = json.loads(fixed_json)
self.log_info("JSON修复成功")
logger.info("JSON修复成功")
except JSONDecodeError:
self.log_error("JSON修复失败")
logger.error("JSON修复失败")
# 返回默认查询
return self._get_default_reflection_query()
else:
self.log_error("无法修复JSON,使用默认查询")
logger.error("无法修复JSON,使用默认查询")
return self._get_default_reflection_query()
# 验证和清理结果
@@ -259,7 +260,7 @@ class ReflectionNode(BaseNode):
reasoning = result.get("reasoning", "")
if not search_query:
self.log_warning("未找到搜索查询,使用默认查询")
logger.warning("未找到搜索查询,使用默认查询")
return self._get_default_reflection_query()
return {
@@ -268,7 +269,7 @@ class ReflectionNode(BaseNode):
}
except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
logger.exception(f"处理输出失败: {str(e)}")
# 返回默认查询
return self._get_default_reflection_query()
+30 -29
View File
@@ -6,6 +6,7 @@
import json
from typing import Dict, Any, List
from json.decoder import JSONDecodeError
from loguru import logger
from .base_node import StateMutationNode
from ..state.state import State
@@ -27,7 +28,7 @@ try:
FORUM_READER_AVAILABLE = True
except ImportError:
FORUM_READER_AVAILABLE = False
print("警告: 无法导入forum_reader模块,将跳过HOST发言读取功能")
logger.warning("无法导入forum_reader模块,将跳过HOST发言读取功能")
class FirstSummaryNode(StateMutationNode):
@@ -84,9 +85,9 @@ class FirstSummaryNode(StateMutationNode):
if host_speech:
# 将HOST发言添加到输入数据中
data['host_speech'] = host_speech
self.log_info(f"已读取HOST发言,长度: {len(host_speech)}字符")
logger.info(f"已读取HOST发言,长度: {len(host_speech)}字符")
except Exception as e:
self.log_info(f"读取HOST发言失败: {str(e)}")
logger.exception(f"读取HOST发言失败: {str(e)}")
# 转换为JSON字符串
message = json.dumps(data, ensure_ascii=False)
@@ -96,7 +97,7 @@ class FirstSummaryNode(StateMutationNode):
formatted_host = format_host_speech_for_prompt(data['host_speech'])
message = formatted_host + "\n" + message
self.log_info("正在生成首次段落总结")
logger.info("正在生成首次段落总结")
# 调用LLM生成总结
response = self.llm_client.invoke(
@@ -107,11 +108,11 @@ class FirstSummaryNode(StateMutationNode):
# 处理响应
processed_response = self.process_output(response)
self.log_info("成功生成首次段落总结")
logger.info("成功生成首次段落总结")
return processed_response
except Exception as e:
self.log_error(f"生成首次总结失败: {str(e)}")
logger.exception(f"生成首次总结失败: {str(e)}")
raise e
def process_output(self, output: str) -> str:
@@ -130,26 +131,26 @@ class FirstSummaryNode(StateMutationNode):
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output}")
logger.info(f"清理后的输出: {cleaned_output}")
# 解析JSON
try:
result = json.loads(cleaned_output)
self.log_info("JSON解析成功")
logger.info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
logger.exception(f"JSON解析失败: {str(e)}")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
result = json.loads(fixed_json)
self.log_info("JSON修复成功")
logger.info("JSON修复成功")
except JSONDecodeError:
self.log_info("JSON修复失败,直接使用清理后的文本")
logger.exception("JSON修复失败,直接使用清理后的文本")
# 如果不是JSON格式,直接返回清理后的文本
return cleaned_output
else:
self.log_info("无法修复JSON,直接使用清理后的文本")
logger.exception("无法修复JSON,直接使用清理后的文本")
# 如果不是JSON格式,直接返回清理后的文本
return cleaned_output
@@ -163,7 +164,7 @@ class FirstSummaryNode(StateMutationNode):
return cleaned_output
except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
logger.exception(f"处理输出失败: {str(e)}")
return "段落总结生成失败"
def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State:
@@ -186,7 +187,7 @@ class FirstSummaryNode(StateMutationNode):
# 更新状态
if 0 <= paragraph_index < len(state.paragraphs):
state.paragraphs[paragraph_index].research.latest_summary = summary
self.log_info(f"已更新段落 {paragraph_index} 的首次总结")
logger.info(f"已更新段落 {paragraph_index} 的首次总结")
else:
raise ValueError(f"段落索引 {paragraph_index} 超出范围")
@@ -194,7 +195,7 @@ class FirstSummaryNode(StateMutationNode):
return state
except Exception as e:
self.log_error(f"状态更新失败: {str(e)}")
logger.exception(f"状态更新失败: {str(e)}")
raise e
@@ -252,9 +253,9 @@ class ReflectionSummaryNode(StateMutationNode):
if host_speech:
# 将HOST发言添加到输入数据中
data['host_speech'] = host_speech
self.log_info(f"已读取HOST发言,长度: {len(host_speech)}字符")
logger.info(f"已读取HOST发言,长度: {len(host_speech)}字符")
except Exception as e:
self.log_info(f"读取HOST发言失败: {str(e)}")
logger.exception(f"读取HOST发言失败: {str(e)}")
# 转换为JSON字符串
message = json.dumps(data, ensure_ascii=False)
@@ -264,7 +265,7 @@ class ReflectionSummaryNode(StateMutationNode):
formatted_host = format_host_speech_for_prompt(data['host_speech'])
message = formatted_host + "\n" + message
self.log_info("正在生成反思总结")
logger.info("正在生成反思总结")
# 调用LLM生成总结
response = self.llm_client.invoke(
@@ -275,11 +276,11 @@ class ReflectionSummaryNode(StateMutationNode):
# 处理响应
processed_response = self.process_output(response)
self.log_info("成功生成反思总结")
logger.info("成功生成反思总结")
return processed_response
except Exception as e:
self.log_error(f"生成反思总结失败: {str(e)}")
logger.exception(f"生成反思总结失败: {str(e)}")
raise e
def process_output(self, output: str) -> str:
@@ -298,26 +299,26 @@ class ReflectionSummaryNode(StateMutationNode):
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output}")
logger.info(f"清理后的输出: {cleaned_output}")
# 解析JSON
try:
result = json.loads(cleaned_output)
self.log_info("JSON解析成功")
logger.info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
logger.exception(f"JSON解析失败: {str(e)}")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
result = json.loads(fixed_json)
self.log_info("JSON修复成功")
logger.info("JSON修复成功")
except JSONDecodeError:
self.log_info("JSON修复失败,直接使用清理后的文本")
logger.exception("JSON修复失败,直接使用清理后的文本")
# 如果不是JSON格式,直接返回清理后的文本
return cleaned_output
else:
self.log_info("无法修复JSON,直接使用清理后的文本")
logger.exception("无法修复JSON,直接使用清理后的文本")
# 如果不是JSON格式,直接返回清理后的文本
return cleaned_output
@@ -331,7 +332,7 @@ class ReflectionSummaryNode(StateMutationNode):
return cleaned_output
except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
logger.exception(f"处理输出失败: {str(e)}")
return "反思总结生成失败"
def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State:
@@ -355,7 +356,7 @@ class ReflectionSummaryNode(StateMutationNode):
if 0 <= paragraph_index < len(state.paragraphs):
state.paragraphs[paragraph_index].research.latest_summary = updated_summary
state.paragraphs[paragraph_index].research.increment_reflection()
self.log_info(f"已更新段落 {paragraph_index} 的反思总结")
logger.info(f"已更新段落 {paragraph_index} 的反思总结")
else:
raise ValueError(f"段落索引 {paragraph_index} 超出范围")
@@ -363,5 +364,5 @@ class ReflectionSummaryNode(StateMutationNode):
return state
except Exception as e:
self.log_error(f"状态更新失败: {str(e)}")
logger.exception(f"状态更新失败: {str(e)}")
raise e
+26 -23
View File
@@ -25,6 +25,9 @@ import json
import sys
from typing import List, Dict, Any, Optional, Literal
from loguru import logger
from config import settings
# 运行前请确保已安装 requests 库: pip install requests
try:
import requests
@@ -91,7 +94,7 @@ class BochaMultimodalSearch:
每个公共方法都设计为供 AI Agent 独立调用的工具
"""
BASE_URL = "https://api.bochaai.com/v1/ai-search"
BOCHA_BASE_URL = settings.BOCHA_BASE_URL or "https://api.bochaai.com/v1/ai-search"
def __init__(self, api_key: Optional[str] = None):
"""
@@ -100,7 +103,7 @@ class BochaMultimodalSearch:
api_key: Bocha API key. If omitted, it is read from settings.BOCHA_WEB_SEARCH_API_KEY.
"""
if api_key is None:
api_key = os.getenv("BOCHA_API_KEY")
api_key = settings.BOCHA_WEB_SEARCH_API_KEY
if not api_key:
# The fallback key is read from settings.BOCHA_WEB_SEARCH_API_KEY, so the hint must name that variable.
raise ValueError("Bocha API Key未找到!请设置 BOCHA_WEB_SEARCH_API_KEY 环境变量或在初始化时提供")
@@ -178,21 +181,21 @@ class BochaMultimodalSearch:
payload.update(kwargs)
try:
response = requests.post(self.BASE_URL, headers=self._headers, json=payload, timeout=30)
response = requests.post(self.BOCHA_BASE_URL, headers=self._headers, json=payload, timeout=30)
response.raise_for_status() # 如果HTTP状态码是4xx或5xx,则抛出异常
response_dict = response.json()
if response_dict.get("code") != 200:
print(f"API返回错误: {response_dict.get('msg', '未知错误')}")
logger.error(f"API返回错误: {response_dict.get('msg', '未知错误')}")
return BochaResponse(query=query)
return self._parse_search_response(response_dict, query)
except requests.exceptions.RequestException as e:
print(f"搜索时发生网络错误: {str(e)}")
logger.exception(f"搜索时发生网络错误: {str(e)}")
raise e # 让重试机制捕获并处理
except Exception as e:
print(f"处理响应时发生未知错误: {str(e)}")
logger.exception(f"处理响应时发生未知错误: {str(e)}")
raise e # 让重试机制捕获并处理
# --- Agent 可用的工具方法 ---
@@ -203,7 +206,7 @@ class BochaMultimodalSearch:
返回网页图片AI总结追问建议和可能的模态卡这是最常用的通用搜索工具
Agent可提供搜索查询(query)和可选的最大结果数(max_results)
"""
print(f"--- TOOL: 全面综合搜索 (query: {query}) ---")
logger.info(f"--- TOOL: 全面综合搜索 (query: {query}) ---")
return self._search_internal(
query=query,
count=max_results,
@@ -215,7 +218,7 @@ class BochaMultimodalSearch:
工具纯网页搜索: 只获取网页链接和摘要不请求AI生成答案
适用于需要快速获取原始网页信息而不需要AI额外分析的场景速度更快成本更低
"""
print(f"--- TOOL: 纯网页搜索 (query: {query}) ---")
logger.info(f"--- TOOL: 纯网页搜索 (query: {query}) ---")
return self._search_internal(
query=query,
count=max_results,
@@ -228,7 +231,7 @@ class BochaMultimodalSearch:
当Agent意图是查询天气股票汇率百科定义火车票汽车参数等结构化信息时应优先使用此工具
它会返回所有信息但Agent应重点关注结果中的 `modal_cards` 部分
"""
print(f"--- TOOL: 结构化数据查询 (query: {query}) ---")
logger.info(f"--- TOOL: 结构化数据查询 (query: {query}) ---")
# 实现上与 comprehensive_search 相同,但通过命名和文档引导Agent的意图
return self._search_internal(
query=query,
@@ -241,7 +244,7 @@ class BochaMultimodalSearch:
工具搜索24小时内信息: 获取关于某个主题的最新动态
此工具专门查找过去24小时内发布的内容适用于追踪突发事件或最新进展
"""
print(f"--- TOOL: 搜索24小时内信息 (query: {query}) ---")
logger.info(f"--- TOOL: 搜索24小时内信息 (query: {query}) ---")
return self._search_internal(query=query, freshness='oneDay', answer=True)
def search_last_week(self, query: str) -> BochaResponse:
@@ -249,7 +252,7 @@ class BochaMultimodalSearch:
工具搜索本周信息: 获取关于某个主题过去一周内的主要报道
适用于进行周度舆情总结或回顾
"""
print(f"--- TOOL: 搜索本周信息 (query: {query}) ---")
logger.info(f"--- TOOL: 搜索本周信息 (query: {query}) ---")
return self._search_internal(query=query, freshness='oneWeek', answer=True)
@@ -258,27 +261,27 @@ class BochaMultimodalSearch:
def print_response_summary(response: BochaResponse):
"""简化的打印函数,用于展示测试结果"""
if not response or not response.query:
print("未能获取有效响应。")
logger.error("未能获取有效响应。")
return
print(f"\n查询: '{response.query}' | 会话ID: {response.conversation_id}")
logger.info(f"\n查询: '{response.query}' | 会话ID: {response.conversation_id}")
if response.answer:
print(f"AI摘要: {response.answer[:150]}...")
logger.info(f"AI摘要: {response.answer[:150]}...")
print(f"找到 {len(response.webpages)} 个网页, {len(response.images)} 张图片, {len(response.modal_cards)} 个模态卡。")
logger.info(f"找到 {len(response.webpages)} 个网页, {len(response.images)} 张图片, {len(response.modal_cards)} 个模态卡。")
if response.modal_cards:
first_card = response.modal_cards[0]
print(f"第一个模态卡类型: {first_card.card_type}")
logger.info(f"第一个模态卡类型: {first_card.card_type}")
if response.webpages:
first_result = response.webpages[0]
print(f"第一条网页结果: {first_result.name}")
logger.info(f"第一条网页结果: {first_result.name}")
if response.follow_ups:
print(f"建议追问: {response.follow_ups}")
logger.info(f"建议追问: {response.follow_ups}")
print("-" * 60)
logger.info("-" * 60)
if __name__ == "__main__":
@@ -297,7 +300,7 @@ if __name__ == "__main__":
print_response_summary(response2)
# 深度解析第一个模态卡
if response2.modal_cards and response2.modal_cards[0].card_type == 'weather_china':
print("天气模态卡详情:", json.dumps(response2.modal_cards[0].content, indent=2, ensure_ascii=False))
# loguru formats the message with str.format(); without a "{}" placeholder the
# positional JSON argument would be silently dropped from the log output.
logger.info("天气模态卡详情: {}", json.dumps(response2.modal_cards[0].content, indent=2, ensure_ascii=False))
# 场景3: Agent需要查询特定结构化信息 - 股票
@@ -381,7 +384,7 @@ AI摘要: 量子计算商业化正在逐步推进。
------------------------------------------------------------'''
except ValueError as e:
print(f"初始化失败: {e}")
print("请确保 BOCHA_API_KEY 环境变量已正确设置。")
logger.exception(f"初始化失败: {e}")
# The client reads its key from settings.BOCHA_WEB_SEARCH_API_KEY; point the user at that variable.
logger.error("请确保 BOCHA_WEB_SEARCH_API_KEY 环境变量已正确设置。")
except Exception as e:
print(f"测试过程中发生未知错误: {e}")
logger.exception(f"测试过程中发生未知错误: {e}")
+3 -3
View File
@@ -12,7 +12,7 @@ from .text_processing import (
format_search_results_for_prompt
)
from .config import Config, load_config
from .config import Settings, settings
__all__ = [
"clean_json_tags",
@@ -21,6 +21,6 @@ __all__ = [
"extract_clean_response",
"update_state_with_search_results",
"format_search_results_for_prompt",
"Config",
"load_config"
"Settings",
"settings"
]
+72 -146
View File
@@ -1,157 +1,83 @@
"""
Configuration management module for the Media Engine.
Configuration management module for the Media Engine (pydantic_settings style).
"""
import os
from dataclasses import dataclass
from pathlib import Path
from pydantic_settings import BaseSettings
from pydantic import Field
from typing import Optional
def _get_value(source, key: str, default=None, *fallback_keys: str):
candidates = (key,) + fallback_keys
value = None
for candidate in candidates:
if isinstance(source, dict):
value = source.get(candidate)
else:
value = getattr(source, candidate, None)
if value not in (None, ""):
break
if value in (None, ""):
for candidate in candidates:
env_val = os.getenv(candidate)
if env_val not in (None, ""):
value = env_val
break
return value if value not in (None, "") else default
# 计算 .env 优先级:优先当前工作目录,其次项目根目录
PROJECT_ROOT: Path = Path(__file__).resolve().parents[2]
CWD_ENV: Path = Path.cwd() / ".env"
ENV_FILE: str = str(CWD_ENV if CWD_ENV.exists() else (PROJECT_ROOT / ".env"))
class Settings(BaseSettings):
"""
全局配置支持 .env 和环境变量自动加载
变量名与原 config.py 大写一致便于平滑过渡
"""
# ====================== 数据库配置 ======================
DB_HOST: str = Field("your_db_host", description="数据库主机,例如localhost 或 127.0.0.1。我们也提供云数据库资源便捷配置,日均10w+数据,可免费申请,联系我们:670939375@qq.com NOTE:为进行数据合规性审查与服务升级,云数据库自2025年10月1日起暂停接收新的使用申请")
DB_PORT: int = Field(3306, description="数据库端口号,默认为3306")
DB_USER: str = Field("your_db_user", description="数据库用户名")
DB_PASSWORD: str = Field("your_db_password", description="数据库密码")
DB_NAME: str = Field("your_db_name", description="数据库名称")
DB_CHARSET: str = Field("utf8mb4", description="数据库字符集,推荐utf8mb4,兼容emoji")
DB_DIALECT: str = Field("mysql", description="数据库类型,例如 'mysql''postgresql'。用于支持多种数据库后端(如 SQLAlchemy,请与连接信息共同配置)")
# ======================= LLM 相关 =======================
INSIGHT_ENGINE_API_KEY: str = Field(None, description="Insight Agent(推荐Kimihttps://platform.moonshot.cn/API密钥,用于主LLM。您可以更改每个部分LLM使用的API,🚩只要兼容OpenAI请求格式都可以,定义好KEY、BASE_URL与MODEL_NAME即可正常使用。重要提醒:我们强烈推荐您先使用推荐的配置申请API,先跑通再进行您的更改!")
INSIGHT_ENGINE_BASE_URL: Optional[str] = Field("https://api.moonshot.cn/v1", description="Insight Agent LLM接口BaseUrl,可自定义厂商API")
INSIGHT_ENGINE_MODEL_NAME: str = Field("kimi-k2-0711-preview", description="Insight Agent LLM模型名称,如kimi-k2-0711-preview")
MEDIA_ENGINE_API_KEY: str = Field(None, description="Media Agent(推荐Gemini,这里我用了一个中转厂商,你也可以换成你自己的,申请地址:https://www.chataiapi.com/API密钥")
MEDIA_ENGINE_BASE_URL: Optional[str] = Field("https://www.chataiapi.com/v1", description="Media Agent LLM接口BaseUrl")
MEDIA_ENGINE_MODEL_NAME: str = Field("gemini-2.5-pro", description="Media Agent LLM模型名称,如gemini-2.5-pro")
BOCHA_WEB_SEARCH_API_KEY: Optional[str] = Field(None, description="Bocha Web Search API Key")
BOCHA_API_KEY: Optional[str] = Field(None, description="Bocha 兼容键(别名)")
SEARCH_TIMEOUT: int = Field(240, description="搜索超时(秒)")
SEARCH_CONTENT_MAX_LENGTH: int = Field(20000, description="用于提示的最长内容长度")
MAX_REFLECTIONS: int = Field(2, description="最大反思轮数")
MAX_PARAGRAPHS: int = Field(5, description="最大段落数")
MINDSPIDER_API_KEY: Optional[str] = Field(None, description="MindSpider API密钥")
MINDSPIDER_BASE_URL: Optional[str] = Field("https://api.deepseek.com", description="MindSpider LLM接口BaseUrl")
MINDSPIDER_MODEL_NAME: str = Field("deepseek-reasoner", description="MindSpider LLM模型名称,如deepseek-reasoner")
OUTPUT_DIR: str = Field("reports", description="输出目录")
SAVE_INTERMEDIATE_STATES: bool = Field(True, description="是否保存中间状态")
@dataclass
QUERY_ENGINE_API_KEY: str = Field(None, description="Query Agent(推荐DeepSeekhttps://www.deepseek.com/API密钥")
QUERY_ENGINE_BASE_URL: Optional[str] = Field("https://api.deepseek.com", description="Query Agent LLM接口BaseUrl")
QUERY_ENGINE_MODEL_NAME: str = Field("deepseek-reasoner", description="Query Agent LLM模型,如deepseek-reasoner")
REPORT_ENGINE_API_KEY: str = Field(None, description="Report Agent(推荐Gemini,这里我用了一个中转厂商,你也可以换成你自己的,申请地址:https://www.chataiapi.com/API密钥")
REPORT_ENGINE_BASE_URL: Optional[str] = Field("https://www.chataiapi.com/v1", description="Report Agent LLM接口BaseUrl")
REPORT_ENGINE_MODEL_NAME: str = Field("gemini-2.5-pro", description="Report Agent LLM模型,如gemini-2.5-pro")
FORUM_HOST_API_KEY: str = Field(None, description="Forum Host(Qwen3最新模型,这里我使用了硅基流动这个平台,申请地址:https://cloud.siliconflow.cn/API密钥")
FORUM_HOST_BASE_URL: Optional[str] = Field("https://api.siliconflow.cn/v1", description="Forum Host LLM BaseUrl")
FORUM_HOST_MODEL_NAME: str = Field("Qwen/Qwen3-235B-A22B-Instruct-2507", description="Forum Host LLM模型名,如Qwen/Qwen3-235B-A22B-Instruct-2507")
KEYWORD_OPTIMIZER_API_KEY: str = Field(None, description="SQL keyword Optimizer(小参数Qwen3模型,这里我使用了硅基流动这个平台,申请地址:https://cloud.siliconflow.cn/API密钥")
KEYWORD_OPTIMIZER_BASE_URL: Optional[str] = Field("https://api.siliconflow.cn/v1", description="Keyword Optimizer BaseUrl")
KEYWORD_OPTIMIZER_MODEL_NAME: str = Field("Qwen/Qwen3-30B-A3B-Instruct-2507", description="Keyword Optimizer LLM模型名称,如Qwen/Qwen3-30B-A3B-Instruct-2507")
# ================== 网络工具配置 ====================
TAVILY_API_KEY: str = Field(None, description="Tavily API(申请地址:https://www.tavily.com/API密钥,用于Tavily网络搜索")
BOCHA_BASE_URL: Optional[str] = Field("https://api.bochaai.com/v1/ai-search", description="Bocha AI 搜索BaseUrl或博查网页搜索BaseUrl")
BOCHA_WEB_SEARCH_API_KEY: str = Field(None, description="Bocha API(申请地址:https://open.bochaai.com/API密钥,用于Bocha搜索")
class Config:
"""Media Engine configuration."""
llm_api_key: Optional[str] = None
llm_base_url: Optional[str] = None
llm_model_name: Optional[str] = None
llm_provider: Optional[str] = None # compatibility
bocha_api_key: Optional[str] = None
search_timeout: int = 240
max_content_length: int = 20000
max_reflections: int = 2
max_paragraphs: int = 5
output_dir: str = "reports"
save_intermediate_states: bool = True
def __post_init__(self):
if not self.llm_provider and self.llm_model_name:
self.llm_provider = self.llm_model_name
def validate(self) -> bool:
if not self.llm_api_key:
print("错误: Media Engine LLM API Key 未设置 (MEDIA_ENGINE_API_KEY)。")
return False
if not self.llm_model_name:
print("错误: Media Engine 模型名称未设置 (MEDIA_ENGINE_MODEL_NAME)。")
return False
if not self.bocha_api_key:
print("错误: Bocha API Key 未设置 (BOCHA_WEB_SEARCH_API_KEY)。")
return False
return True
@classmethod
def from_file(cls, config_file: str) -> "Config":
if config_file.endswith(".py"):
import importlib.util
spec = importlib.util.spec_from_file_location("config", config_file)
config_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config_module)
return cls(
llm_api_key=_get_value(config_module, "MEDIA_ENGINE_API_KEY"),
llm_base_url=_get_value(config_module, "MEDIA_ENGINE_BASE_URL"),
llm_model_name=_get_value(config_module, "MEDIA_ENGINE_MODEL_NAME"),
bocha_api_key=_get_value(
config_module,
"BOCHA_WEB_SEARCH_API_KEY",
None,
"BOCHA_API_KEY",
),
search_timeout=int(_get_value(config_module, "SEARCH_TIMEOUT", 240)),
max_content_length=int(_get_value(config_module, "SEARCH_CONTENT_MAX_LENGTH", 20000)),
max_reflections=int(_get_value(config_module, "MAX_REFLECTIONS", 2)),
max_paragraphs=int(_get_value(config_module, "MAX_PARAGRAPHS", 5)),
output_dir=_get_value(config_module, "OUTPUT_DIR", "reports"),
save_intermediate_states=str(
_get_value(config_module, "SAVE_INTERMEDIATE_STATES", "true")
).lower()
in ("true", "1", "yes"),
)
config_dict = {}
if os.path.exists(config_file):
with open(config_file, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line and not line.startswith("#") and "=" in line:
key, value = line.split("=", 1)
config_dict[key.strip()] = value.strip()
return cls(
llm_api_key=_get_value(config_dict, "MEDIA_ENGINE_API_KEY"),
llm_base_url=_get_value(config_dict, "MEDIA_ENGINE_BASE_URL"),
llm_model_name=_get_value(config_dict, "MEDIA_ENGINE_MODEL_NAME"),
bocha_api_key=_get_value(
config_dict,
"BOCHA_WEB_SEARCH_API_KEY",
None,
"BOCHA_API_KEY",
),
search_timeout=int(_get_value(config_dict, "SEARCH_TIMEOUT", 240)),
max_content_length=int(_get_value(config_dict, "SEARCH_CONTENT_MAX_LENGTH", 20000)),
max_reflections=int(_get_value(config_dict, "MAX_REFLECTIONS", 2)),
max_paragraphs=int(_get_value(config_dict, "MAX_PARAGRAPHS", 5)),
output_dir=_get_value(config_dict, "OUTPUT_DIR", "reports"),
save_intermediate_states=str(
_get_value(config_dict, "SAVE_INTERMEDIATE_STATES", "true")
).lower()
in ("true", "1", "yes"),
)
env_file = ENV_FILE
env_prefix = ""
case_sensitive = False
extra = "allow"
def load_config(config_file: Optional[str] = None) -> Config:
if config_file:
if not os.path.exists(config_file):
raise FileNotFoundError(f"配置文件不存在: {config_file}")
file_to_load = config_file
else:
for candidate in ("config.py", "config.env", ".env"):
if os.path.exists(candidate):
file_to_load = candidate
print(f"已找到配置文件: {candidate}")
break
else:
raise FileNotFoundError("未找到配置文件,请创建 config.py。")
config = Config.from_file(file_to_load)
if not config.validate():
raise ValueError("配置校验失败,请检查 config.py 中的相关配置。")
return config
def print_config(config: Config):
print("\n=== Media Engine 配置 ===")
print(f"LLM 模型: {config.llm_model_name}")
print(f"LLM Base URL: {config.llm_base_url or '(默认)'}")
print(f"Bocha API Key: {'已配置' if config.bocha_api_key else '未配置'}")
print(f"搜索超时: {config.search_timeout}")
print(f"最长内容长度: {config.max_content_length}")
print(f"最大反思次数: {config.max_reflections}")
print(f"最大段落数: {config.max_paragraphs}")
print(f"输出目录: {config.output_dir}")
print(f"保存中间状态: {config.save_intermediate_states}")
print(f"LLM API Key: {'已配置' if config.llm_api_key else '未配置'}")
print("========================\n")
settings = Settings()
+2 -2
View File
@@ -4,9 +4,9 @@ Deep Search Agent
"""
from .agent import DeepSearchAgent, create_agent
from .utils.config import Config, load_config
from .utils.config import Settings
__version__ = "1.0.0"
__author__ = "Deep Search Agent Team"
__all__ = ["DeepSearchAgent", "create_agent", "Config", "load_config"]
__all__ = ["DeepSearchAgent", "create_agent", "Settings"]
+74 -73
View File
@@ -20,13 +20,13 @@ from .nodes import (
)
from .state import State
from .tools import TavilyNewsAgency, TavilyResponse
from .utils import Config, load_config, format_search_results_for_prompt
from .utils import Settings, format_search_results_for_prompt
from loguru import logger
class DeepSearchAgent:
"""Deep Search Agent主类"""
def __init__(self, config: Optional[Config] = None):
def __init__(self, config: Optional[Settings] = None):
"""
初始化Deep Search Agent
@@ -34,14 +34,14 @@ class DeepSearchAgent:
config: 配置对象如果不提供则自动加载
"""
# 加载配置
self.config = config or load_config()
os.environ["TAVILY_API_KEY"] = self.config.tavily_api_key or ""
from .utils.config import settings
self.config = config or settings
# 初始化LLM客户端
self.llm_client = self._initialize_llm()
# 初始化搜索工具集
self.search_agency = TavilyNewsAgency(api_key=self.config.tavily_api_key)
self.search_agency = TavilyNewsAgency(api_key=self.config.TAVILY_API_KEY)
# 初始化节点
self._initialize_nodes()
@@ -50,18 +50,18 @@ class DeepSearchAgent:
self.state = State()
# 确保输出目录存在
os.makedirs(self.config.output_dir, exist_ok=True)
os.makedirs(self.config.OUTPUT_DIR, exist_ok=True)
print(f"Query Agent已初始化")
print(f"使用LLM: {self.llm_client.get_model_info()}")
print(f"搜索工具集: TavilyNewsAgency (支持6种搜索工具)")
logger.info(f"Query Agent已初始化")
logger.info(f"使用LLM: {self.llm_client.get_model_info()}")
logger.info(f"搜索工具集: TavilyNewsAgency (支持6种搜索工具)")
def _initialize_llm(self) -> LLMClient:
"""初始化LLM客户端"""
return LLMClient(
api_key=self.config.llm_api_key,
model_name=self.config.llm_model_name,
base_url=self.config.llm_base_url,
api_key=self.config.QUERY_ENGINE_API_KEY,
model_name=self.config.QUERY_ENGINE_MODEL_NAME,
base_url=self.config.QUERY_ENGINE_BASE_URL,
)
def _initialize_nodes(self):
@@ -115,7 +115,7 @@ class DeepSearchAgent:
Returns:
TavilyResponse对象
"""
print(f" → 执行搜索工具: {tool_name}")
logger.info(f" → 执行搜索工具: {tool_name}")
if tool_name == "basic_search_news":
max_results = kwargs.get("max_results", 7)
@@ -135,7 +135,7 @@ class DeepSearchAgent:
raise ValueError("search_news_by_date工具需要start_date和end_date参数")
return self.search_agency.search_news_by_date(query, start_date, end_date)
else:
print(f" ⚠️ 未知的搜索工具: {tool_name},使用默认基础搜索")
logger.warning(f" ⚠️ 未知的搜索工具: {tool_name},使用默认基础搜索")
return self.search_agency.basic_search_news(query)
def research(self, query: str, save_report: bool = True) -> str:
@@ -149,9 +149,9 @@ class DeepSearchAgent:
Returns:
最终报告内容
"""
print(f"\n{'='*60}")
print(f"开始深度研究: {query}")
print(f"{'='*60}")
logger.info(f"\n{'='*60}")
logger.info(f"开始深度研究: {query}")
logger.info(f"{'='*60}")
try:
# Step 1: 生成报告结构
@@ -167,19 +167,21 @@ class DeepSearchAgent:
if save_report:
self._save_report(final_report)
print(f"\n{'='*60}")
print("深度研究完成!")
print(f"{'='*60}")
logger.info(f"\n{'='*60}")
logger.info("深度研究完成!")
logger.info(f"{'='*60}")
return final_report
except Exception as e:
print(f"研究过程中发生错误: {str(e)}")
import traceback
error_traceback = traceback.format_exc()
logger.error(f"研究过程中发生错误: {str(e)} \n错误堆栈: {error_traceback}")
raise e
def _generate_report_structure(self, query: str):
"""生成报告结构"""
print(f"\n[步骤 1] 生成报告结构...")
logger.info(f"\n[步骤 1] 生成报告结构...")
# 创建报告结构节点
report_structure_node = ReportStructureNode(self.llm_client, query)
@@ -187,17 +189,18 @@ class DeepSearchAgent:
# 生成结构并更新状态
self.state = report_structure_node.mutate_state(state=self.state)
print(f"报告结构已生成,共 {len(self.state.paragraphs)} 个段落:")
_message = f"报告结构已生成,共 {len(self.state.paragraphs)} 个段落:"
for i, paragraph in enumerate(self.state.paragraphs, 1):
print(f" {i}. {paragraph.title}")
_message += f"\n {i}. {paragraph.title}"
logger.info(_message)
def _process_paragraphs(self):
"""处理所有段落"""
total_paragraphs = len(self.state.paragraphs)
for i in range(total_paragraphs):
print(f"\n[步骤 2.{i+1}] 处理段落: {self.state.paragraphs[i].title}")
print("-" * 50)
logger.info(f"\n[步骤 2.{i+1}] 处理段落: {self.state.paragraphs[i].title}")
logger.info("-" * 50)
# 初始搜索和总结
self._initial_search_and_summary(i)
@@ -209,7 +212,7 @@ class DeepSearchAgent:
self.state.paragraphs[i].research.mark_completed()
progress = (i + 1) / total_paragraphs * 100
print(f"段落处理完成 ({progress:.1f}%)")
logger.info(f"段落处理完成 ({progress:.1f}%)")
def _initial_search_and_summary(self, paragraph_index: int):
"""执行初始搜索和总结"""
@@ -222,18 +225,18 @@ class DeepSearchAgent:
}
# 生成搜索查询和工具选择
print(" - 生成搜索查询...")
logger.info(" - 生成搜索查询...")
search_output = self.first_search_node.run(search_input)
search_query = search_output["search_query"]
search_tool = search_output.get("search_tool", "basic_search_news") # 默认工具
reasoning = search_output["reasoning"]
print(f" - 搜索查询: {search_query}")
print(f" - 选择的工具: {search_tool}")
print(f" - 推理: {reasoning}")
logger.info(f" - 搜索查询: {search_query}")
logger.info(f" - 选择的工具: {search_tool}")
logger.info(f" - 推理: {reasoning}")
# 执行搜索
print(" - 执行网络搜索...")
logger.info(" - 执行网络搜索...")
# 处理search_news_by_date的特殊参数
search_kwargs = {}
@@ -246,13 +249,13 @@ class DeepSearchAgent:
if self._validate_date_format(start_date) and self._validate_date_format(end_date):
search_kwargs["start_date"] = start_date
search_kwargs["end_date"] = end_date
print(f" - 时间范围: {start_date}{end_date}")
logger.info(f" - 时间范围: {start_date}{end_date}")
else:
print(f" ⚠️ 日期格式错误(应为YYYY-MM-DD),改用基础搜索")
print(f" 提供的日期: start_date={start_date}, end_date={end_date}")
# Falling back to the basic search is a degraded path: report it at WARNING level,
# consistent with the unknown-tool fallback in execute_search_tool.
logger.warning(f" ⚠️ 日期格式错误(应为YYYY-MM-DD),改用基础搜索")
logger.warning(f" 提供的日期: start_date={start_date}, end_date={end_date}")
search_tool = "basic_search_news"
else:
print(f" ⚠️ search_news_by_date工具缺少时间参数,改用基础搜索")
# Missing date parameters force a fallback to the basic search; log it as a warning.
logger.warning(f" ⚠️ search_news_by_date工具缺少时间参数,改用基础搜索")
search_tool = "basic_search_news"
search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs)
@@ -273,24 +276,24 @@ class DeepSearchAgent:
})
if search_results:
print(f" - 找到 {len(search_results)} 个搜索结果")
_message = f" - 找到 {len(search_results)} 个搜索结果"
for j, result in enumerate(search_results, 1):
date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else ""
print(f" {j}. {result['title'][:50]}...{date_info}")
_message += f"\n {j}. {result['title'][:50]}...{date_info}"
logger.info(_message)
else:
print(" - 未找到搜索结果")
logger.info(" - 未找到搜索结果")
# 更新状态中的搜索历史
paragraph.research.add_search_results(search_query, search_results)
# 生成初始总结
print(" - 生成初始总结...")
logger.info(" - 生成初始总结...")
summary_input = {
"title": paragraph.title,
"content": paragraph.content,
"search_query": search_query,
"search_results": format_search_results_for_prompt(
search_results, self.config.max_content_length
search_results, self.config.SEARCH_CONTENT_MAX_LENGTH
)
}
@@ -299,14 +302,14 @@ class DeepSearchAgent:
summary_input, self.state, paragraph_index
)
print(" - 初始总结完成")
logger.info(" - 初始总结完成")
def _reflection_loop(self, paragraph_index: int):
"""执行反思循环"""
paragraph = self.state.paragraphs[paragraph_index]
for reflection_i in range(self.config.max_reflections):
print(f" - 反思 {reflection_i + 1}/{self.config.max_reflections}...")
for reflection_i in range(self.config.MAX_REFLECTIONS):
logger.info(f" - 反思 {reflection_i + 1}/{self.config.MAX_REFLECTIONS}...")
# 准备反思输入
reflection_input = {
@@ -321,9 +324,9 @@ class DeepSearchAgent:
search_tool = reflection_output.get("search_tool", "basic_search_news") # 默认工具
reasoning = reflection_output["reasoning"]
print(f" 反思查询: {search_query}")
print(f" 选择的工具: {search_tool}")
print(f" 反思推理: {reasoning}")
logger.info(f" 反思查询: {search_query}")
logger.info(f" 选择的工具: {search_tool}")
logger.info(f" 反思推理: {reasoning}")
# 执行反思搜索
# 处理search_news_by_date的特殊参数
@@ -337,13 +340,13 @@ class DeepSearchAgent:
if self._validate_date_format(start_date) and self._validate_date_format(end_date):
search_kwargs["start_date"] = start_date
search_kwargs["end_date"] = end_date
print(f" 时间范围: {start_date}{end_date}")
logger.info(f" 时间范围: {start_date}{end_date}")
else:
print(f" ⚠️ 日期格式错误(应为YYYY-MM-DD),改用基础搜索")
print(f" 提供的日期: start_date={start_date}, end_date={end_date}")
# Falling back to the basic search is a degraded path: report it at WARNING level,
# consistent with the unknown-tool fallback in execute_search_tool.
logger.warning(f" ⚠️ 日期格式错误(应为YYYY-MM-DD),改用基础搜索")
logger.warning(f" 提供的日期: start_date={start_date}, end_date={end_date}")
search_tool = "basic_search_news"
else:
print(f" ⚠️ search_news_by_date工具缺少时间参数,改用基础搜索")
# Missing date parameters force a fallback to the basic search; log it as a warning.
logger.warning(f" ⚠️ search_news_by_date工具缺少时间参数,改用基础搜索")
search_tool = "basic_search_news"
search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs)
@@ -364,12 +367,12 @@ class DeepSearchAgent:
})
if search_results:
print(f" 找到 {len(search_results)} 个反思搜索结果")
logger.info(f" 找到 {len(search_results)} 个反思搜索结果")
for j, result in enumerate(search_results, 1):
date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else ""
print(f" {j}. {result['title'][:50]}...{date_info}")
logger.info(f" {j}. {result['title'][:50]}...{date_info}")
else:
print(" 未找到反思搜索结果")
logger.info(" 未找到反思搜索结果")
# 更新搜索历史
paragraph.research.add_search_results(search_query, search_results)
@@ -380,7 +383,7 @@ class DeepSearchAgent:
"content": paragraph.content,
"search_query": search_query,
"search_results": format_search_results_for_prompt(
search_results, self.config.max_content_length
search_results, self.config.SEARCH_CONTENT_MAX_LENGTH
),
"paragraph_latest_state": paragraph.research.latest_summary
}
@@ -390,11 +393,11 @@ class DeepSearchAgent:
reflection_summary_input, self.state, paragraph_index
)
print(f" 反思 {reflection_i + 1} 完成")
logger.info(f" 反思 {reflection_i + 1} 完成")
def _generate_final_report(self) -> str:
"""生成最终报告"""
print(f"\n[步骤 3] 生成最终报告...")
logger.info(f"\n[步骤 3] 生成最终报告...")
# 准备报告数据
report_data = []
@@ -408,7 +411,7 @@ class DeepSearchAgent:
try:
final_report = self.report_formatting_node.run(report_data)
except Exception as e:
print(f"LLM格式化失败,使用备用方法: {str(e)}")
logger.error(f"LLM格式化失败,使用备用方法: {str(e)}")
final_report = self.report_formatting_node.format_report_manually(
report_data, self.state.report_title
)
@@ -417,7 +420,7 @@ class DeepSearchAgent:
self.state.final_report = final_report
self.state.mark_completed()
print("最终报告生成完成")
logger.info("最终报告生成完成")
return final_report
def _save_report(self, report_content: str):
@@ -428,20 +431,20 @@ class DeepSearchAgent:
query_safe = query_safe.replace(' ', '_')[:30]
filename = f"deep_search_report_{query_safe}_{timestamp}.md"
filepath = os.path.join(self.config.output_dir, filename)
filepath = os.path.join(self.config.OUTPUT_DIR, filename)
# 保存报告
with open(filepath, 'w', encoding='utf-8') as f:
f.write(report_content)
print(f"报告已保存到: {filepath}")
logger.info(f"报告已保存到: {filepath}")
# 保存状态(如果配置允许)
if self.config.save_intermediate_states:
if self.config.SAVE_INTERMEDIATE_STATES:
state_filename = f"state_{query_safe}_{timestamp}.json"
state_filepath = os.path.join(self.config.output_dir, state_filename)
state_filepath = os.path.join(self.config.OUTPUT_DIR, state_filename)
self.state.save_to_file(state_filepath)
print(f"状态已保存到: {state_filepath}")
logger.info(f"状态已保存到: {state_filepath}")
def get_progress_summary(self) -> Dict[str, Any]:
"""获取进度摘要"""
@@ -450,23 +453,21 @@ class DeepSearchAgent:
def load_state(self, filepath: str):
"""从文件加载状态"""
self.state = State.load_from_file(filepath)
print(f"状态已从 {filepath} 加载")
logger.info(f"状态已从 {filepath} 加载")
def save_state(self, filepath: str):
"""保存状态到文件"""
self.state.save_to_file(filepath)
print(f"状态已保存到 {filepath}")
logger.info(f"状态已保存到 {filepath}")
def create_agent(config_file: Optional[str] = None) -> DeepSearchAgent:
def create_agent() -> DeepSearchAgent:
"""
创建Deep Search Agent实例的便捷函数
Args:
config_file: 配置文件路径
Returns:
DeepSearchAgent实例
"""
config = load_config(config_file)
from .utils.config import Settings
config = Settings()
return DeepSearchAgent(config)
+7 -2
View File
@@ -5,6 +5,7 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from loguru import logger
from ..llms.base import LLMClient
from ..state.state import State
@@ -63,11 +64,15 @@ class BaseNode(ABC):
def log_info(self, message: str):
"""记录信息日志"""
print(f"[{self.node_name}] {message}")
logger.info(f"[{self.node_name}] {message}")
def log_warning(self, message: str):
"""记录警告日志"""
logger.warning(f"[{self.node_name}] 警告: {message}")
def log_error(self, message: str):
    """Log an error message tagged with this node's name.

    Args:
        message: Text to log.
    """
    # Single emission through loguru; the legacy print() duplicated the line.
    logger.error(f"[{self.node_name}] 错误: {message}")
class StateMutationNode(BaseNode):
+7 -6
View File
@@ -7,6 +7,7 @@ import json
from typing import List, Dict, Any
from .base_node import BaseNode
from loguru import logger
from ..prompts import SYSTEM_PROMPT_REPORT_FORMATTING
from ..utils.text_processing import (
remove_reasoning_from_output,
@@ -65,7 +66,7 @@ class ReportFormattingNode(BaseNode):
else:
message = json.dumps(input_data, ensure_ascii=False)
self.log_info("正在格式化最终报告")
logger.info("正在格式化最终报告")
# 调用LLM生成Markdown格式
response = self.llm_client.invoke(
@@ -76,11 +77,11 @@ class ReportFormattingNode(BaseNode):
# 处理响应
processed_response = self.process_output(response)
self.log_info("成功生成格式化报告")
logger.info("成功生成格式化报告")
return processed_response
except Exception as e:
self.log_error(f"报告格式化失败: {str(e)}")
logger.exception(f"报告格式化失败: {str(e)}")
raise e
def process_output(self, output: str) -> str:
@@ -109,7 +110,7 @@ class ReportFormattingNode(BaseNode):
return cleaned_output.strip()
except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
logger.exception(f"处理输出失败: {str(e)}")
return "# 报告处理失败\n\n报告格式化过程中发生错误。"
def format_report_manually(self, paragraphs_data: List[Dict[str, str]],
@@ -125,7 +126,7 @@ class ReportFormattingNode(BaseNode):
格式化的Markdown报告
"""
try:
self.log_info("使用手动格式化方法")
logger.info("使用手动格式化方法")
# 构建报告
report_lines = [
@@ -163,5 +164,5 @@ class ReportFormattingNode(BaseNode):
return "\n".join(report_lines)
except Exception as e:
self.log_error(f"手动格式化失败: {str(e)}")
logger.exception(f"手动格式化失败: {str(e)}")
return "# 报告生成失败\n\n无法完成报告格式化。"
+21 -20
View File
@@ -6,6 +6,7 @@
import json
from typing import Dict, Any, List
from json.decoder import JSONDecodeError
from loguru import logger
from .base_node import StateMutationNode
from ..state.state import State
@@ -48,7 +49,7 @@ class ReportStructureNode(StateMutationNode):
报告结构列表
"""
try:
self.log_info(f"正在为查询生成报告结构: {self.query}")
logger.info(f"正在为查询生成报告结构: {self.query}")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query)
@@ -56,11 +57,11 @@ class ReportStructureNode(StateMutationNode):
# 处理响应
processed_response = self.process_output(response)
self.log_info(f"成功生成 {len(processed_response)} 个段落结构")
logger.info(f"成功生成 {len(processed_response)} 个段落结构")
return processed_response
except Exception as e:
self.log_error(f"生成报告结构失败: {str(e)}")
logger.exception(f"生成报告结构失败: {str(e)}")
raise e
def process_output(self, output: str) -> List[Dict[str, str]]:
@@ -79,54 +80,54 @@ class ReportStructureNode(StateMutationNode):
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output}")
logger.info(f"清理后的输出: {cleaned_output}")
# 解析JSON
try:
report_structure = json.loads(cleaned_output)
self.log_info("JSON解析成功")
logger.info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
logger.exception(f"JSON解析失败: {str(e)}")
# 使用更强大的提取方法
report_structure = extract_clean_response(cleaned_output)
if "error" in report_structure:
self.log_error("JSON解析失败,尝试修复...")
logger.error("JSON解析失败,尝试修复...")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
report_structure = json.loads(fixed_json)
self.log_info("JSON修复成功")
logger.info("JSON修复成功")
except JSONDecodeError:
self.log_error("JSON修复失败")
logger.error("JSON修复失败")
# 返回默认结构
return self._generate_default_structure()
else:
self.log_error("无法修复JSON,使用默认结构")
logger.error("无法修复JSON,使用默认结构")
return self._generate_default_structure()
# 验证结构
if not isinstance(report_structure, list):
self.log_info("报告结构不是列表,尝试转换...")
logger.info("报告结构不是列表,尝试转换...")
if isinstance(report_structure, dict):
# 如果是单个对象,包装成列表
report_structure = [report_structure]
else:
self.log_error("报告结构格式无效,使用默认结构")
logger.error("报告结构格式无效,使用默认结构")
return self._generate_default_structure()
# 验证每个段落
validated_structure = []
for i, paragraph in enumerate(report_structure):
if not isinstance(paragraph, dict):
self.log_warning(f"段落 {i+1} 不是字典格式,跳过")
logger.warning(f"段落 {i+1} 不是字典格式,跳过")
continue
title = paragraph.get("title", f"段落 {i+1}")
content = paragraph.get("content", "")
if not title or not content:
self.log_warning(f"段落 {i+1} 缺少标题或内容,跳过")
logger.warning(f"段落 {i+1} 缺少标题或内容,跳过")
continue
validated_structure.append({
@@ -135,14 +136,14 @@ class ReportStructureNode(StateMutationNode):
})
if not validated_structure:
self.log_warning("没有有效的段落结构,使用默认结构")
logger.warning("没有有效的段落结构,使用默认结构")
return self._generate_default_structure()
self.log_info(f"成功验证 {len(validated_structure)} 个段落结构")
logger.info(f"成功验证 {len(validated_structure)} 个段落结构")
return validated_structure
except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
logger.exception(f"处理输出失败: {str(e)}")
return self._generate_default_structure()
def _generate_default_structure(self) -> List[Dict[str, str]]:
@@ -152,7 +153,7 @@ class ReportStructureNode(StateMutationNode):
Returns:
默认的报告结构列表
"""
self.log_info("生成默认报告结构")
logger.info("生成默认报告结构")
return [
{
"title": "研究概述",
@@ -195,9 +196,9 @@ class ReportStructureNode(StateMutationNode):
content=paragraph_data["content"]
)
self.log_info(f"已将 {len(report_structure)} 个段落添加到状态中")
logger.info(f"已将 {len(report_structure)} 个段落添加到状态中")
return state
except Exception as e:
self.log_error(f"状态更新失败: {str(e)}")
logger.exception(f"状态更新失败: {str(e)}")
raise e
+24 -23
View File
@@ -6,6 +6,7 @@
import json
from typing import Dict, Any
from json.decoder import JSONDecodeError
from loguru import logger
from .base_node import BaseNode
from ..prompts import SYSTEM_PROMPT_FIRST_SEARCH, SYSTEM_PROMPT_REFLECTION
@@ -62,7 +63,7 @@ class FirstSearchNode(BaseNode):
else:
message = json.dumps(input_data, ensure_ascii=False)
self.log_info("正在生成首次搜索查询")
logger.info("正在生成首次搜索查询")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SEARCH, message)
@@ -70,11 +71,11 @@ class FirstSearchNode(BaseNode):
# 处理响应
processed_response = self.process_output(response)
self.log_info(f"生成搜索查询: {processed_response.get('search_query', 'N/A')}")
logger.info(f"生成搜索查询: {processed_response.get('search_query', 'N/A')}")
return processed_response
except Exception as e:
self.log_error(f"生成首次搜索查询失败: {str(e)}")
logger.exception(f"生成首次搜索查询失败: {str(e)}")
raise e
def process_output(self, output: str) -> Dict[str, str]:
@@ -93,30 +94,30 @@ class FirstSearchNode(BaseNode):
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output}")
logger.info(f"清理后的输出: {cleaned_output}")
# 解析JSON
try:
result = json.loads(cleaned_output)
self.log_info("JSON解析成功")
logger.info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
logger.exception(f"JSON解析失败: {str(e)}")
# 使用更强大的提取方法
result = extract_clean_response(cleaned_output)
if "error" in result:
self.log_error("JSON解析失败,尝试修复...")
logger.error("JSON解析失败,尝试修复...")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
result = json.loads(fixed_json)
self.log_info("JSON修复成功")
logger.info("JSON修复成功")
except JSONDecodeError:
self.log_error("JSON修复失败")
logger.error("JSON修复失败")
# 返回默认查询
return self._get_default_search_query()
else:
self.log_error("无法修复JSON,使用默认查询")
logger.error("无法修复JSON,使用默认查询")
return self._get_default_search_query()
# 验证和清理结果
@@ -124,7 +125,7 @@ class FirstSearchNode(BaseNode):
reasoning = result.get("reasoning", "")
if not search_query:
self.log_warning("未找到搜索查询,使用默认查询")
logger.warning("未找到搜索查询,使用默认查询")
return self._get_default_search_query()
return {
@@ -197,7 +198,7 @@ class ReflectionNode(BaseNode):
else:
message = json.dumps(input_data, ensure_ascii=False)
self.log_info("正在进行反思并生成新搜索查询")
logger.info("正在进行反思并生成新搜索查询")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION, message)
@@ -205,11 +206,11 @@ class ReflectionNode(BaseNode):
# 处理响应
processed_response = self.process_output(response)
self.log_info(f"反思生成搜索查询: {processed_response.get('search_query', 'N/A')}")
logger.info(f"反思生成搜索查询: {processed_response.get('search_query', 'N/A')}")
return processed_response
except Exception as e:
self.log_error(f"反思生成搜索查询失败: {str(e)}")
logger.exception(f"反思生成搜索查询失败: {str(e)}")
raise e
def process_output(self, output: str) -> Dict[str, str]:
@@ -228,30 +229,30 @@ class ReflectionNode(BaseNode):
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output}")
logger.info(f"清理后的输出: {cleaned_output}")
# 解析JSON
try:
result = json.loads(cleaned_output)
self.log_info("JSON解析成功")
logger.info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
logger.exception(f"JSON解析失败: {str(e)}")
# 使用更强大的提取方法
result = extract_clean_response(cleaned_output)
if "error" in result:
self.log_error("JSON解析失败,尝试修复...")
logger.error("JSON解析失败,尝试修复...")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
result = json.loads(fixed_json)
self.log_info("JSON修复成功")
logger.info("JSON修复成功")
except JSONDecodeError:
self.log_error("JSON修复失败")
logger.error("JSON修复失败")
# 返回默认查询
return self._get_default_reflection_query()
else:
self.log_error("无法修复JSON,使用默认查询")
logger.error("无法修复JSON,使用默认查询")
return self._get_default_reflection_query()
# 验证和清理结果
@@ -259,7 +260,7 @@ class ReflectionNode(BaseNode):
reasoning = result.get("reasoning", "")
if not search_query:
self.log_warning("未找到搜索查询,使用默认查询")
logger.warning("未找到搜索查询,使用默认查询")
return self._get_default_reflection_query()
return {
@@ -268,7 +269,7 @@ class ReflectionNode(BaseNode):
}
except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
logger.exception(f"处理输出失败: {str(e)}")
# 返回默认查询
return self._get_default_reflection_query()
+30 -29
View File
@@ -6,6 +6,7 @@
import json
from typing import Dict, Any, List
from json.decoder import JSONDecodeError
from loguru import logger
from .base_node import StateMutationNode
from ..state.state import State
@@ -27,7 +28,7 @@ try:
FORUM_READER_AVAILABLE = True
except ImportError:
FORUM_READER_AVAILABLE = False
print("警告: 无法导入forum_reader模块,将跳过HOST发言读取功能")
logger.warning("警告: 无法导入forum_reader模块,将跳过HOST发言读取功能")
class FirstSummaryNode(StateMutationNode):
@@ -84,9 +85,9 @@ class FirstSummaryNode(StateMutationNode):
if host_speech:
# 将HOST发言添加到输入数据中
data['host_speech'] = host_speech
self.log_info(f"已读取HOST发言,长度: {len(host_speech)}字符")
logger.info(f"已读取HOST发言,长度: {len(host_speech)}字符")
except Exception as e:
self.log_info(f"读取HOST发言失败: {str(e)}")
logger.exception(f"读取HOST发言失败: {str(e)}")
# 转换为JSON字符串
message = json.dumps(data, ensure_ascii=False)
@@ -96,7 +97,7 @@ class FirstSummaryNode(StateMutationNode):
formatted_host = format_host_speech_for_prompt(data['host_speech'])
message = formatted_host + "\n" + message
self.log_info("正在生成首次段落总结")
logger.info("正在生成首次段落总结")
# 调用LLM生成总结
response = self.llm_client.invoke(
@@ -107,11 +108,11 @@ class FirstSummaryNode(StateMutationNode):
# 处理响应
processed_response = self.process_output(response)
self.log_info("成功生成首次段落总结")
logger.info("成功生成首次段落总结")
return processed_response
except Exception as e:
self.log_error(f"生成首次总结失败: {str(e)}")
logger.exception(f"生成首次总结失败: {str(e)}")
raise e
def process_output(self, output: str) -> str:
@@ -130,26 +131,26 @@ class FirstSummaryNode(StateMutationNode):
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output}")
logger.info(f"清理后的输出: {cleaned_output}")
# 解析JSON
try:
result = json.loads(cleaned_output)
self.log_info("JSON解析成功")
logger.info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
logger.exception(f"JSON解析失败: {str(e)}")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
result = json.loads(fixed_json)
self.log_info("JSON修复成功")
logger.info("JSON修复成功")
except JSONDecodeError:
self.log_info("JSON修复失败,直接使用清理后的文本")
logger.exception("JSON修复失败,直接使用清理后的文本")
# 如果不是JSON格式,直接返回清理后的文本
return cleaned_output
else:
self.log_info("无法修复JSON,直接使用清理后的文本")
logger.exception("无法修复JSON,直接使用清理后的文本")
# 如果不是JSON格式,直接返回清理后的文本
return cleaned_output
@@ -163,7 +164,7 @@ class FirstSummaryNode(StateMutationNode):
return cleaned_output
except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
logger.exception(f"处理输出失败: {str(e)}")
return "段落总结生成失败"
def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State:
@@ -186,7 +187,7 @@ class FirstSummaryNode(StateMutationNode):
# 更新状态
if 0 <= paragraph_index < len(state.paragraphs):
state.paragraphs[paragraph_index].research.latest_summary = summary
self.log_info(f"已更新段落 {paragraph_index} 的首次总结")
logger.info(f"已更新段落 {paragraph_index} 的首次总结")
else:
raise ValueError(f"段落索引 {paragraph_index} 超出范围")
@@ -194,7 +195,7 @@ class FirstSummaryNode(StateMutationNode):
return state
except Exception as e:
self.log_error(f"状态更新失败: {str(e)}")
logger.exception(f"状态更新失败: {str(e)}")
raise e
@@ -252,9 +253,9 @@ class ReflectionSummaryNode(StateMutationNode):
if host_speech:
# 将HOST发言添加到输入数据中
data['host_speech'] = host_speech
self.log_info(f"已读取HOST发言,长度: {len(host_speech)}字符")
logger.info(f"已读取HOST发言,长度: {len(host_speech)}字符")
except Exception as e:
self.log_info(f"读取HOST发言失败: {str(e)}")
logger.exception(f"读取HOST发言失败: {str(e)}")
# 转换为JSON字符串
message = json.dumps(data, ensure_ascii=False)
@@ -264,7 +265,7 @@ class ReflectionSummaryNode(StateMutationNode):
formatted_host = format_host_speech_for_prompt(data['host_speech'])
message = formatted_host + "\n" + message
self.log_info("正在生成反思总结")
logger.info("正在生成反思总结")
# 调用LLM生成总结
response = self.llm_client.invoke(
@@ -275,11 +276,11 @@ class ReflectionSummaryNode(StateMutationNode):
# 处理响应
processed_response = self.process_output(response)
self.log_info("成功生成反思总结")
logger.info("成功生成反思总结")
return processed_response
except Exception as e:
self.log_error(f"生成反思总结失败: {str(e)}")
logger.exception(f"生成反思总结失败: {str(e)}")
raise e
def process_output(self, output: str) -> str:
@@ -298,26 +299,26 @@ class ReflectionSummaryNode(StateMutationNode):
cleaned_output = clean_json_tags(cleaned_output)
# 记录清理后的输出用于调试
self.log_info(f"清理后的输出: {cleaned_output}")
logger.info(f"清理后的输出: {cleaned_output}")
# 解析JSON
try:
result = json.loads(cleaned_output)
self.log_info("JSON解析成功")
logger.info("JSON解析成功")
except JSONDecodeError as e:
self.log_info(f"JSON解析失败: {str(e)}")
logger.exception(f"JSON解析失败: {str(e)}")
# 尝试修复JSON
fixed_json = fix_incomplete_json(cleaned_output)
if fixed_json:
try:
result = json.loads(fixed_json)
self.log_info("JSON修复成功")
logger.info("JSON修复成功")
except JSONDecodeError:
self.log_info("JSON修复失败,直接使用清理后的文本")
logger.exception("JSON修复失败,直接使用清理后的文本")
# 如果不是JSON格式,直接返回清理后的文本
return cleaned_output
else:
self.log_info("无法修复JSON,直接使用清理后的文本")
logger.exception("无法修复JSON,直接使用清理后的文本")
# 如果不是JSON格式,直接返回清理后的文本
return cleaned_output
@@ -331,7 +332,7 @@ class ReflectionSummaryNode(StateMutationNode):
return cleaned_output
except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
logger.exception(f"处理输出失败: {str(e)}")
return "反思总结生成失败"
def mutate_state(self, input_data: Any, state: State, paragraph_index: int, **kwargs) -> State:
@@ -355,7 +356,7 @@ class ReflectionSummaryNode(StateMutationNode):
if 0 <= paragraph_index < len(state.paragraphs):
state.paragraphs[paragraph_index].research.latest_summary = updated_summary
state.paragraphs[paragraph_index].research.increment_reflection()
self.log_info(f"已更新段落 {paragraph_index} 的反思总结")
logger.info(f"已更新段落 {paragraph_index} 的反思总结")
else:
raise ValueError(f"段落索引 {paragraph_index} 超出范围")
@@ -363,5 +364,5 @@ class ReflectionSummaryNode(StateMutationNode):
return state
except Exception as e:
self.log_error(f"状态更新失败: {str(e)}")
logger.exception(f"状态更新失败: {str(e)}")
raise e
+2 -3
View File
@@ -12,7 +12,7 @@ from .text_processing import (
format_search_results_for_prompt
)
from .config import Config, load_config
from .config import Settings
__all__ = [
"clean_json_tags",
@@ -21,6 +21,5 @@ __all__ = [
"extract_clean_response",
"update_state_with_search_results",
"format_search_results_for_prompt",
"Config",
"load_config"
"Settings",
]
+66 -138
View File
@@ -1,151 +1,79 @@
"""
Configuration management module for the Query Engine.
Query Engine 配置管理模块
此模块使用 pydantic-settings 管理 Query Engine 的配置支持从环境变量和 .env 文件自动加载
数据模型定义位置
- 本文件 - 配置模型定义
"""
import os
from dataclasses import dataclass
from pathlib import Path
from pydantic_settings import BaseSettings
from pydantic import Field
from typing import Optional
from loguru import logger
def _get_value(source, key: str, default=None, *fallback_keys: str):
candidates = (key,) + fallback_keys
value = None
for candidate in candidates:
if isinstance(source, dict):
value = source.get(candidate)
else:
value = getattr(source, candidate, None)
if value not in (None, ""):
break
if value in (None, ""):
for candidate in candidates:
env_val = os.getenv(candidate)
if env_val not in (None, ""):
value = env_val
break
return value if value not in (None, "") else default
# 计算 .env 优先级:优先当前工作目录,其次项目根目录
PROJECT_ROOT: Path = Path(__file__).resolve().parents[2]
CWD_ENV: Path = Path.cwd() / ".env"
ENV_FILE: str = str(CWD_ENV if CWD_ENV.exists() else (PROJECT_ROOT / ".env"))
@dataclass
class Settings(BaseSettings):
"""
Query Engine 全局配置支持 .env 和环境变量自动加载
变量名与原 config.py 大写一致便于平滑过渡
"""
# ======================= LLM 相关 =======================
QUERY_ENGINE_API_KEY: str = Field(..., description="Query Engine LLM API密钥,用于主LLM。您可以更改每个部分LLM使用的API,🚩只要兼容OpenAI请求格式都可以,定义好KEY、BASE_URL与MODEL_NAME即可正常使用。")
QUERY_ENGINE_BASE_URL: Optional[str] = Field(None, description="Query Engine LLM接口BaseUrl,可自定义厂商API")
QUERY_ENGINE_MODEL_NAME: str = Field(..., description="Query Engine LLM模型名称")
QUERY_ENGINE_PROVIDER: Optional[str] = Field(None, description="Query Engine LLM提供商(兼容字段)")
# ================== 网络工具配置 ====================
TAVILY_API_KEY: str = Field(..., description="Tavily API(申请地址:https://www.tavily.com/API密钥,用于Tavily网络搜索")
# ================== 搜索参数配置 ====================
SEARCH_TIMEOUT: int = Field(240, description="搜索超时(秒)")
SEARCH_CONTENT_MAX_LENGTH: int = Field(20000, description="用于提示的最长内容长度")
MAX_REFLECTIONS: int = Field(2, description="最大反思轮数")
MAX_PARAGRAPHS: int = Field(5, description="最大段落数")
MAX_SEARCH_RESULTS: int = Field(20, description="最大搜索结果数")
# ================== 输出配置 ====================
OUTPUT_DIR: str = Field("reports", description="输出目录")
SAVE_INTERMEDIATE_STATES: bool = Field(True, description="是否保存中间状态")
class Config:
"""Query Engine configuration."""
llm_api_key: Optional[str] = None
llm_base_url: Optional[str] = None
llm_model_name: Optional[str] = None
llm_provider: Optional[str] = None # compatibility
tavily_api_key: Optional[str] = None
search_timeout: int = 240
max_content_length: int = 20000
max_reflections: int = 2
max_paragraphs: int = 5
max_search_results: int = 20
output_dir: str = "reports"
save_intermediate_states: bool = True
def __post_init__(self):
if not self.llm_provider and self.llm_model_name:
self.llm_provider = self.llm_model_name
def validate(self) -> bool:
if not self.llm_api_key:
print("错误: Query Engine LLM API Key 未设置 (QUERY_ENGINE_API_KEY)。")
return False
if not self.llm_model_name:
print("错误: Query Engine 模型名称未设置 (QUERY_ENGINE_MODEL_NAME)。")
return False
if not self.tavily_api_key:
print("错误: Tavily API Key 未设置 (TAVILY_API_KEY)。")
return False
return True
@classmethod
def from_file(cls, config_file: str) -> "Config":
if config_file.endswith(".py"):
import importlib.util
spec = importlib.util.spec_from_file_location("config", config_file)
config_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config_module)
return cls(
llm_api_key=_get_value(config_module, "QUERY_ENGINE_API_KEY"),
llm_base_url=_get_value(config_module, "QUERY_ENGINE_BASE_URL"),
llm_model_name=_get_value(config_module, "QUERY_ENGINE_MODEL_NAME"),
tavily_api_key=_get_value(config_module, "TAVILY_API_KEY"),
search_timeout=int(_get_value(config_module, "SEARCH_TIMEOUT", 240)),
max_content_length=int(_get_value(config_module, "SEARCH_CONTENT_MAX_LENGTH", 20000)),
max_reflections=int(_get_value(config_module, "MAX_REFLECTIONS", 2)),
max_paragraphs=int(_get_value(config_module, "MAX_PARAGRAPHS", 5)),
max_search_results=int(_get_value(config_module, "MAX_SEARCH_RESULTS", 20)),
output_dir=_get_value(config_module, "OUTPUT_DIR", "reports"),
save_intermediate_states=str(
_get_value(config_module, "SAVE_INTERMEDIATE_STATES", "true")
).lower()
in ("true", "1", "yes"),
)
config_dict = {}
if os.path.exists(config_file):
with open(config_file, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line and not line.startswith("#") and "=" in line:
key, value = line.split("=", 1)
config_dict[key.strip()] = value.strip()
return cls(
llm_api_key=_get_value(config_dict, "QUERY_ENGINE_API_KEY"),
llm_base_url=_get_value(config_dict, "QUERY_ENGINE_BASE_URL"),
llm_model_name=_get_value(config_dict, "QUERY_ENGINE_MODEL_NAME"),
tavily_api_key=_get_value(config_dict, "TAVILY_API_KEY"),
search_timeout=int(_get_value(config_dict, "SEARCH_TIMEOUT", 240)),
max_content_length=int(_get_value(config_dict, "SEARCH_CONTENT_MAX_LENGTH", 20000)),
max_reflections=int(_get_value(config_dict, "MAX_REFLECTIONS", 2)),
max_paragraphs=int(_get_value(config_dict, "MAX_PARAGRAPHS", 5)),
max_search_results=int(_get_value(config_dict, "MAX_SEARCH_RESULTS", 20)),
output_dir=_get_value(config_dict, "OUTPUT_DIR", "reports"),
save_intermediate_states=str(
_get_value(config_dict, "SAVE_INTERMEDIATE_STATES", "true")
).lower()
in ("true", "1", "yes"),
)
env_file = ENV_FILE
env_prefix = ""
case_sensitive = False
extra = "allow"
def load_config(config_file: Optional[str] = None) -> Config:
if config_file:
if not os.path.exists(config_file):
raise FileNotFoundError(f"配置文件不存在: {config_file}")
file_to_load = config_file
else:
for candidate in ("config.py", "config.env", ".env"):
if os.path.exists(candidate):
file_to_load = candidate
print(f"已找到配置文件: {candidate}")
break
else:
raise FileNotFoundError("未找到配置文件,请创建 config.py。")
# 创建全局配置实例
settings = Settings()
config = Config.from_file(file_to_load)
if not config.validate():
raise ValueError("配置校验失败,请检查 config.py 中的相关配置。")
return config
def print_config(config: Settings):
    """Log a human-readable summary of the Query Engine configuration.

    API keys are never echoed: only whether each one is set is reported.

    Args:
        config: Settings object whose upper-case fields are summarised.
    """
    lines = [
        "=== Query Engine 配置 ===",
        f"LLM 模型: {config.QUERY_ENGINE_MODEL_NAME}",
        f"LLM Base URL: {config.QUERY_ENGINE_BASE_URL or '(默认)'}",
        f"Tavily API Key: {'已配置' if config.TAVILY_API_KEY else '未配置'}",
        f"搜索超时: {config.SEARCH_TIMEOUT} 秒",
        # Label restored to "最长内容长度"; the leading character had been lost.
        f"最长内容长度: {config.SEARCH_CONTENT_MAX_LENGTH}",
        f"最大反思次数: {config.MAX_REFLECTIONS}",
        f"最大段落数: {config.MAX_PARAGRAPHS}",
        f"最大搜索结果数: {config.MAX_SEARCH_RESULTS}",
        f"输出目录: {config.OUTPUT_DIR}",
        f"保存中间状态: {config.SAVE_INTERMEDIATE_STATES}",
        f"LLM API Key: {'已配置' if config.QUERY_ENGINE_API_KEY else '未配置'}",
        "========================",
    ]
    # One log record for the whole summary; the trailing newline matches the
    # message the line-by-line concatenation produced.
    logger.info("\n".join(lines) + "\n")
+27 -5
View File
@@ -191,7 +191,9 @@ Weibo_PublicOpinion_AnalysisSystem/
- **Database**: MySQL (optional, you can choose our cloud database service)
- **Memory**: 2GB+ recommended
### 1. Create Conda Environment
### 1. Create Environment
#### If Using Conda
```bash
# Create conda environment
@@ -199,11 +201,21 @@ conda create -n your_conda_name python=3.11
conda activate your_conda_name
```
#### If Using uv
```bash
# Create uv environment
uv venv --python 3.11 # Create Python 3.11 environment
```
### 2. Install Dependencies
```bash
# Basic dependency installation
pip install -r requirements.txt
# uv version command (faster installation)
uv pip install -r requirements.txt
# If you do not want to use the local sentiment analysis model (which has low computational requirements and defaults to the CPU version), you can comment out the 'Machine Learning' section in this file before executing the command.
```
@@ -218,9 +230,9 @@ playwright install chromium
#### 4.1 Configure API Keys
Copy the `config.py.example` file to `config.py`
Copy the `.env.example` file in the project root directory to `.env`
Edit the `config.py` file and fill in your API keys (you can also choose your own models and search proxies; see the config file for details):
Edit the `.env` file and fill in your API keys (you can also choose your own models and search proxies; see the `.env.example` file in the project root directory or the `config.py` file for details):
```python
# MySQL Database Configuration
@@ -246,7 +258,8 @@ INSIGHT_ENGINE_MODEL_NAME = "kimi-k2-0711-preview"
**Option 1: Use Local Database**
You can refer to `MindSpider\config.py.example` for the configuration template, copy this file and rename it to `config.py`.
> ~~The MindSpider crawler system and the public opinion system are independent of each other, so you need to configure `MindSpider\config.py`. Copy the `config.py.example` file in the `MindSpider` folder and rename it to `config.py`.~~
> Configuration is now driven by environment variables. Copy the `.env.example` file in the project root directory to `.env` and fill in all required settings there.
```bash
# Local MySQL database initialization
@@ -279,6 +292,15 @@ conda activate your_conda_name
python app.py
```
uv version startup command:
```bash
# In the project root directory, activate the uv environment (Windows)
.venv\Scripts\activate
# On macOS/Linux use instead: source .venv/bin/activate
# Start main application
python app.py
```
> Note 1: After a run is terminated, the Streamlit app might not shut down correctly and may still be occupying the port. If this occurs, find the process that is holding the port and kill it.
> Note 2: Data scraping needs to be performed as a separate operation. Please refer to the instructions in section 5.3.
@@ -327,7 +349,7 @@ python main.py --broad-topic --date 2024-01-20
python main.py --deep-sentiment --platforms xhs dy wb
```
## ⚙️ Advanced Configuration
## ⚙️ Advanced Configuration (Deprecated: configuration is now unified in the `.env` file in the project root directory, and sub-agents automatically inherit it)
### Modify Key Parameters
+33 -7
View File
@@ -193,7 +193,9 @@ Weibo_PublicOpinion_AnalysisSystem/
- **数据库**: MySQL(可选择我们的云数据库服务)
- **内存**: 建议2GB以上
### 1. 创建Conda环境
### 1. 创建环境
#### 如果使用Conda
```bash
# 创建conda环境
@@ -201,11 +203,21 @@ conda create -n your_conda_name python=3.11
conda activate your_conda_name
```
#### 如果使用uv
```bash
# 创建uv环境
uv venv --python 3.11 # 创建Python 3.11环境
```
### 2. 安装依赖包
```bash
# 基础依赖安装
pip install -r requirements.txt
# uv版本命令(更快速安装)
uv pip install -r requirements.txt
# 如果不想使用本地情感分析模型(算力需求很小,默认安装cpu版本),可以将该文件中的“机器学习”部分注释掉再执行指令
```
@@ -220,9 +232,9 @@ playwright install chromium
#### 4.1 配置API密钥
复制一份 `config.py.example` 文件,命名为 `config.py`
复制一份项目根目录下的 `.env.example` 文件,命名为 `.env`
编辑 `config.py` 文件,填入您的API密钥(您也可以选择自己的模型、搜索代理,详情见config文件内):
编辑 `.env` 文件,填入您的API密钥(您也可以选择自己的模型、搜索代理,详情见根目录.env.example文件内或根目录config.py中的说明):
```python
# MySQL数据库配置
@@ -248,12 +260,14 @@ INSIGHT_ENGINE_MODEL_NAME = "kimi-k2-0711-preview"
**选择1:使用本地数据库**
> MindSpider爬虫系统跟舆情系统是各自独立的,所以需要再去`MindSpider\config.py`配置一下,复制`MindSpider`文件夹下的 `config.py.example` 文件,命名为 `config.py`
> ~~MindSpider爬虫系统跟舆情系统是各自独立的,所以需要再去`MindSpider\config.py`配置一下,复制`MindSpider`文件夹下的 `config.py.example` 文件,命名为 `config.py`~~
> 现已更改为基于环境变量配置,请复制项目根目录 `.env.example` 文件为 `.env` 文件,并在其中填写各项配置
```bash
# 本地MySQL数据库初始化
cd MindSpider
python schema/init_database.py
# 项目初始化
python main.py --setup
```
**选择2:使用云数据库服务(推荐)**
@@ -281,6 +295,15 @@ conda activate your_conda_name
python app.py
```
uv 版本启动命令
```bash
# 在项目根目录下,激活uv环境(Windows)
.venv\Scripts\activate
# macOS/Linux 请改用:source .venv/bin/activate
# 启动主应用即可
python app.py
```
> 注1:一次运行终止后,streamlit app可能结束异常仍然占用端口,此时搜索占用端口的进程kill掉即可
> 注2:数据爬取需要单独操作,见5.3指引
@@ -319,6 +342,9 @@ cd MindSpider
# 项目初始化
python main.py --setup
# 运行话题提取(获取热点新闻和关键词)
python main.py --broad-topic
# 运行完整爬虫流程
python main.py --complete --date 2024-01-20
@@ -329,7 +355,7 @@ python main.py --broad-topic --date 2024-01-20
python main.py --deep-sentiment --platforms xhs dy wb
```
## ⚙️ 高级配置
## ⚙️ 高级配置(已过时,已经统一为项目根目录.env文件管理,其他子agent自动继承根目录配置)
### 修改关键参数
+1 -2
View File
@@ -5,9 +5,8 @@ Report Engine
"""
from .agent import ReportAgent, create_agent
from .utils.config import Config, load_config
__version__ = "1.0.0"
__author__ = "Report Engine Team"
__all__ = ["ReportAgent", "create_agent", "Config", "load_config"]
__all__ = ["ReportAgent", "create_agent"]
+38 -62
View File
@@ -5,7 +5,7 @@ Report Agent主类
import json
import os
import logging
from loguru import logger
from datetime import datetime
from typing import Optional, Dict, Any, List
@@ -15,7 +15,7 @@ from .nodes import (
HTMLGenerationNode
)
from .state import ReportState
from .utils import Config, load_config
from .utils.config import settings, Settings
class FileCountBaseline:
@@ -32,7 +32,7 @@ class FileCountBaseline:
with open(self.baseline_file, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
print(f"加载基准数据失败: {e}")
logger.exception(f"加载基准数据失败: {e}")
return {}
def _save_baseline(self):
@@ -42,7 +42,7 @@ class FileCountBaseline:
with open(self.baseline_file, 'w', encoding='utf-8') as f:
json.dump(self.baseline_data, f, ensure_ascii=False, indent=2)
except Exception as e:
print(f"保存基准数据失败: {e}")
logger.exception(f"保存基准数据失败: {e}")
def initialize_baseline(self, directories: Dict[str, str]) -> Dict[str, int]:
"""初始化文件数量基准"""
@@ -59,7 +59,7 @@ class FileCountBaseline:
self.baseline_data = current_counts.copy()
self._save_baseline()
print(f"文件数量基准已初始化: {current_counts}")
logger.info(f"文件数量基准已初始化: {current_counts}")
return current_counts
def check_new_files(self, directories: Dict[str, str]) -> Dict[str, Any]:
@@ -109,7 +109,7 @@ class FileCountBaseline:
class ReportAgent:
"""Report Agent主类"""
def __init__(self, config: Optional[Config] = None):
def __init__(self, config: Optional[Settings] = None):
"""
初始化Report Agent
@@ -117,7 +117,7 @@ class ReportAgent:
config: 配置对象如果不提供则自动加载
"""
# 加载配置
self.config = config or load_config()
self.config = config or settings
# 初始化文件基准管理器
self.file_baseline = FileCountBaseline()
@@ -138,44 +138,19 @@ class ReportAgent:
self.state = ReportState()
# 确保输出目录存在
os.makedirs(self.config.output_dir, exist_ok=True)
os.makedirs(settings.OUTPUT_DIR, exist_ok=True)
self.logger.info("Report Agent已初始化")
self.logger.info(f"使用LLM: {self.llm_client.get_model_info()}")
logger.info("Report Agent已初始化")
logger.info(f"使用LLM: {self.llm_client.get_model_info()}")
def _setup_logging(self):
    """Attach a loguru file sink for the Report Engine log file.

    The log path is read from ``self.config.LOG_FILE`` so that a settings
    object injected through the constructor is honoured instead of always
    falling back to the module-level ``settings``.
    """
    log_file = self.config.LOG_FILE

    # Make sure the log directory exists. os.path.dirname() returns "" for a
    # bare file name, and os.makedirs("") raises, so only create a directory
    # when the path actually has one.
    log_dir = os.path.dirname(log_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    # NOTE(review): each call adds another sink to the shared loguru logger,
    # so building several ReportAgent instances in one process presumably
    # duplicates file output - confirm whether that can happen in practice.
    logger.add(log_file, level="INFO")
def _initialize_file_baseline(self):
"""初始化文件数量基准"""
@@ -189,16 +164,16 @@ class ReportAgent:
def _initialize_llm(self) -> LLMClient:
    """Build the LLM client from the agent's configuration.

    The credentials are read from ``self.config`` (the object chosen in the
    constructor) so that an explicitly supplied configuration is honoured,
    in the same way the template directory is read when the nodes are
    initialised.

    Returns:
        An ``LLMClient`` created with the configured API key, model name
        and base URL.
    """
    return LLMClient(
        api_key=self.config.REPORT_ENGINE_API_KEY,
        model_name=self.config.REPORT_ENGINE_MODEL_NAME,
        base_url=self.config.REPORT_ENGINE_BASE_URL,
    )
def _initialize_nodes(self):
"""初始化处理节点"""
self.template_selection_node = TemplateSelectionNode(
self.llm_client,
self.config.template_dir
self.config.TEMPLATE_DIR
)
self.html_generation_node = HTMLGenerationNode(self.llm_client)
@@ -219,7 +194,7 @@ class ReportAgent:
"""
start_time = datetime.now()
self.logger.info(f"开始生成报告: {query}")
logger.info(f"开始生成报告: {query}")
self.logger.info(f"输入数据 - 报告数量: {len(reports)}, 论坛日志长度: {len(forum_logs)}")
try:
@@ -238,21 +213,21 @@ class ReportAgent:
generation_time = (end_time - start_time).total_seconds()
self.state.metadata.generation_time = generation_time
self.logger.info(f"报告生成完成,耗时: {generation_time:.2f}")
logger.info(f"报告生成完成,耗时: {generation_time:.2f}")
return html_report
except Exception as e:
self.logger.error(f"报告生成过程中发生错误: {str(e)}")
logger.exception(f"报告生成过程中发生错误: {str(e)}")
raise e
def _select_template(self, query: str, reports: List[Any], forum_logs: str, custom_template: str):
"""选择报告模板"""
self.logger.info("选择报告模板...")
logger.info("选择报告模板...")
# 如果用户提供了自定义模板,直接使用
if custom_template:
self.logger.info("使用用户自定义模板")
logger.info("使用用户自定义模板")
return {
'template_name': 'custom',
'template_content': custom_template,
@@ -271,12 +246,12 @@ class ReportAgent:
# 更新状态
self.state.metadata.template_used = template_result['template_name']
self.logger.info(f"选择模板: {template_result['template_name']}")
self.logger.info(f"选择理由: {template_result['selection_reason']}")
logger.info(f"选择模板: {template_result['template_name']}")
logger.info(f"选择理由: {template_result['selection_reason']}")
return template_result
except Exception as e:
self.logger.error(f"模板选择失败,使用默认模板: {str(e)}")
logger.error(f"模板选择失败,使用默认模板: {str(e)}")
# 直接使用备用模板
fallback_template = {
'template_name': '社会公共热点事件分析报告模板',
@@ -288,7 +263,7 @@ class ReportAgent:
def _generate_html_report(self, query: str, reports: List[Any], forum_logs: str, template_result: Dict[str, Any]) -> str:
"""生成HTML报告"""
self.logger.info("多轮生成HTML报告...")
logger.info("多轮生成HTML报告...")
# 准备报告内容,确保有3个报告
query_report = reports[0] if len(reports) > 0 else ""
@@ -316,7 +291,7 @@ class ReportAgent:
self.state.html_content = html_content
self.state.mark_completed()
self.logger.info("HTML报告生成完成")
logger.info("HTML报告生成完成")
return html_content
def _get_fallback_template_content(self) -> str:
@@ -376,19 +351,19 @@ class ReportAgent:
query_safe = query_safe.replace(' ', '_')[:30]
filename = f"final_report_{query_safe}_{timestamp}.html"
filepath = os.path.join(self.config.output_dir, filename)
filepath = os.path.join(settings.OUTPUT_DIR, filename)
# 保存HTML报告
with open(filepath, 'w', encoding='utf-8') as f:
f.write(html_content)
self.logger.info(f"报告已保存到: {filepath}")
logger.info(f"报告已保存到: {filepath}")
# 保存状态
state_filename = f"report_state_{query_safe}_{timestamp}.json"
state_filepath = os.path.join(self.config.output_dir, state_filename)
state_filepath = os.path.join(settings.OUTPUT_DIR, state_filename)
self.state.save_to_file(state_filepath)
self.logger.info(f"状态已保存到: {state_filepath}")
logger.info(f"状态已保存到: {state_filepath}")
def get_progress_summary(self) -> Dict[str, Any]:
"""获取进度摘要"""
@@ -397,12 +372,12 @@ class ReportAgent:
def load_state(self, filepath: str):
    """Replace the agent's state with the one stored at *filepath*.

    Args:
        filepath: Path of the state file handed to
            ``ReportState.load_from_file``.
    """
    self.state = ReportState.load_from_file(filepath)
    logger.info(f"状态已从 {filepath} 加载")
def save_state(self, filepath: str):
    """Write the agent's current state to *filepath*.

    Args:
        filepath: Destination path handed to ``self.state.save_to_file``.
    """
    self.state.save_to_file(filepath)
    logger.info(f"状态已保存到 {filepath}")
def check_input_files(self, insight_dir: str, media_dir: str, query_dir: str, forum_log_path: str) -> Dict[str, Any]:
"""
@@ -488,9 +463,9 @@ class ReportAgent:
with open(file_paths[engine], 'r', encoding='utf-8') as f:
report_content = f.read()
content['reports'].append(report_content)
self.logger.info(f"已加载 {engine} 报告: {len(report_content)} 字符")
logger.info(f"已加载 {engine} 报告: {len(report_content)} 字符")
except Exception as e:
self.logger.error(f"加载 {engine} 报告失败: {str(e)}")
logger.exception(f"加载 {engine} 报告失败: {str(e)}")
content['reports'].append("")
# 加载论坛日志
@@ -498,9 +473,9 @@ class ReportAgent:
try:
with open(file_paths['forum'], 'r', encoding='utf-8') as f:
content['forum_logs'] = f.read()
self.logger.info(f"已加载论坛日志: {len(content['forum_logs'])} 字符")
logger.info(f"已加载论坛日志: {len(content['forum_logs'])} 字符")
except Exception as e:
self.logger.error(f"加载论坛日志失败: {str(e)}")
logger.exception(f"加载论坛日志失败: {str(e)}")
return content
@@ -515,5 +490,6 @@ def create_agent(config_file: Optional[str] = None) -> ReportAgent:
Returns:
ReportAgent实例
"""
config = load_config(config_file)
config = Settings() # 以空配置初始化,而从从环境变量初始化
return ReportAgent(config)
+14 -13
View File
@@ -10,9 +10,9 @@ import time
from datetime import datetime
from flask import Blueprint, request, jsonify, Response
from typing import Dict, Any
from loguru import logger
from .agent import ReportAgent, create_agent
from .utils.config import load_config
from .utils.config import settings
# 创建Blueprint
@@ -28,12 +28,11 @@ def initialize_report_engine():
"""初始化Report Engine"""
global report_agent
try:
config = load_config()
report_agent = create_agent()
print("Report Engine初始化成功")
logger.info("Report Engine初始化成功")
return True
except Exception as e:
print(f"Report Engine初始化失败: {str(e)}")
logger.exception(f"Report Engine初始化失败: {str(e)}")
return False
@@ -259,6 +258,7 @@ def get_progress(task_id: str):
})
except Exception as e:
logger.exception(f"获取报告生成进度失败: {str(e)}")
return jsonify({
'success': False,
'error': str(e)
@@ -288,6 +288,7 @@ def get_result(task_id: str):
)
except Exception as e:
logger.exception(f"获取报告生成结果失败: {str(e)}")
return jsonify({
'success': False,
'error': str(e)
@@ -363,7 +364,7 @@ def get_templates():
'error': 'Report Engine未初始化'
}), 500
template_dir = report_agent.config.template_dir
template_dir = settings.TEMPLATE_DIR
templates = []
if os.path.exists(template_dir):
@@ -381,7 +382,7 @@ def get_templates():
'size': len(content)
})
except Exception as e:
print(f"读取模板失败 {filename}: {str(e)}")
logger.exception(f"读取模板失败 {filename}: {str(e)}")
return jsonify({
'success': True,
@@ -416,21 +417,19 @@ def internal_error(error):
def clear_report_log():
    """Truncate the Report Engine log file (``settings.LOG_FILE``).

    Best effort: a failure is logged together with its traceback and is
    not re-raised.
    """
    try:
        log_file = settings.LOG_FILE
        # Opening in "w" mode truncates the file; the empty write keeps the
        # intent explicit.
        with open(log_file, 'w', encoding='utf-8') as f:
            f.write('')
        logger.info(f"已清空日志文件: {log_file}")
    except Exception as e:
        logger.exception(f"清空日志文件失败: {str(e)}")
@report_bp.route('/log', methods=['GET'])
def get_report_log():
"""获取report.log内容"""
try:
config = load_config()
log_file = config.log_file
log_file = settings.LOG_FILE
if not os.path.exists(log_file):
return jsonify({
@@ -450,6 +449,7 @@ def get_report_log():
})
except Exception as e:
logger.exception(f"读取日志失败: {str(e)}")
return jsonify({
'success': False,
'error': f'读取日志失败: {str(e)}'
@@ -466,6 +466,7 @@ def clear_log():
'message': '日志已清空'
})
except Exception as e:
logger.exception(f"清空日志失败: {str(e)}")
return jsonify({
'success': False,
'error': f'清空日志失败: {str(e)}'
+3 -5
View File
@@ -3,12 +3,11 @@ Report Engine节点基类
定义所有处理节点的基础接口
"""
import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from ..llms.base import LLMClient
from ..state.state import ReportState
from loguru import logger
class BaseNode(ABC):
"""节点基类"""
@@ -23,7 +22,6 @@ class BaseNode(ABC):
"""
self.llm_client = llm_client
self.node_name = node_name or self.__class__.__name__
self.logger = logging.getLogger('ReportEngine')
@abstractmethod
def run(self, input_data: Any, **kwargs) -> Any:
@@ -66,12 +64,12 @@ class BaseNode(ABC):
def log_info(self, message: str):
    """Log *message* at INFO level, prefixed with this node's name in brackets."""
    formatted_message = f"[{self.node_name}] {message}"
    logger.info(formatted_message)
def log_error(self, message: str):
    """Log *message* at ERROR level, prefixed with this node's name in brackets."""
    formatted_message = f"[{self.node_name}] {message}"
    logger.error(formatted_message)
class StateMutationNode(BaseNode):
+9 -8
View File
@@ -6,6 +6,7 @@ HTML生成节点
import json
from datetime import datetime
from typing import Dict, Any
from loguru import logger
from .base_node import StateMutationNode
from ..llms.base import LLMClient
@@ -42,7 +43,7 @@ class HTMLGenerationNode(StateMutationNode):
Returns:
生成的HTML内容
"""
self.log_info("开始生成HTML报告...")
logger.info("开始生成HTML报告...")
try:
# 准备LLM输入数据
@@ -64,11 +65,11 @@ class HTMLGenerationNode(StateMutationNode):
# 处理响应(简化版)
processed_response = self.process_output(response)
self.log_info("HTML报告生成完成")
logger.info("HTML报告生成完成")
return processed_response
except Exception as e:
self.log_error(f"HTML生成失败: {str(e)}")
logger.exception(f"HTML生成失败: {str(e)}")
# 返回备用HTML
return self._generate_fallback_html(input_data)
@@ -104,7 +105,7 @@ class HTMLGenerationNode(StateMutationNode):
HTML内容
"""
try:
self.log_info(f"处理LLM原始输出,长度: {len(output)} 字符")
logger.info(f"处理LLM原始输出,长度: {len(output)} 字符")
html_content = output.strip()
@@ -120,14 +121,14 @@ class HTMLGenerationNode(StateMutationNode):
# 如果内容为空,返回原始输出
if not html_content:
self.log_info("处理后内容为空,返回原始输出")
logger.info("处理后内容为空,返回原始输出")
html_content = output
self.log_info(f"HTML处理完成,最终长度: {len(html_content)} 字符")
logger.info(f"HTML处理完成,最终长度: {len(html_content)} 字符")
return html_content
except Exception as e:
self.log_error(f"处理HTML输出失败: {str(e)},返回原始输出")
logger.exception(f"处理HTML输出失败: {str(e)},返回原始输出")
return output
def _generate_fallback_html(self, input_data: Dict[str, Any]) -> str:
@@ -140,7 +141,7 @@ class HTMLGenerationNode(StateMutationNode):
Returns:
备用HTML内容
"""
self.log_info("使用备用HTML生成方法")
logger.info("使用备用HTML生成方法")
query = input_data.get('query', '智能舆情分析报告')
query_report = input_data.get('query_engine_report', '')
+15 -14
View File
@@ -6,6 +6,7 @@
import os
import json
from typing import Dict, Any, List, Optional
from loguru import logger
from .base_node import BaseNode
from ..prompts import SYSTEM_PROMPT_TEMPLATE_SELECTION
@@ -38,7 +39,7 @@ class TemplateSelectionNode(BaseNode):
Returns:
选择的模板信息
"""
self.log_info("开始模板选择...")
logger.info("开始模板选择...")
query = input_data.get('query', '')
reports = input_data.get('reports', [])
@@ -48,7 +49,7 @@ class TemplateSelectionNode(BaseNode):
available_templates = self._get_available_templates()
if not available_templates:
self.log_info("未找到预设模板,使用内置默认模板")
logger.info("未找到预设模板,使用内置默认模板")
return self._get_fallback_template()
# 使用LLM进行模板选择
@@ -57,7 +58,7 @@ class TemplateSelectionNode(BaseNode):
if llm_result:
return llm_result
except Exception as e:
self.log_error(f"LLM模板选择失败: {str(e)}")
logger.exception(f"LLM模板选择失败: {str(e)}")
# 如果LLM选择失败,使用备选方案
return self._get_fallback_template()
@@ -67,7 +68,7 @@ class TemplateSelectionNode(BaseNode):
def _llm_template_selection(self, query: str, reports: List[Any], forum_logs: str,
available_templates: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
"""使用LLM进行模板选择"""
self.log_info("尝试使用LLM进行模板选择...")
logger.info("尝试使用LLM进行模板选择...")
# 构建模板列表
template_list = "\n".join([f"- {t['name']}: {t['description']}" for t in available_templates])
@@ -118,10 +119,10 @@ class TemplateSelectionNode(BaseNode):
# 检查响应是否为空
if not response or not response.strip():
self.log_error("LLM返回空响应")
logger.error("LLM返回空响应")
return None
self.log_info(f"LLM原始响应: {response}")
logger.info(f"LLM原始响应: {response}")
# 尝试解析JSON响应
try:
@@ -133,18 +134,18 @@ class TemplateSelectionNode(BaseNode):
selected_template_name = result.get('template_name', '')
for template in available_templates:
if template['name'] == selected_template_name or selected_template_name in template['name']:
self.log_info(f"LLM选择模板: {selected_template_name}")
logger.info(f"LLM选择模板: {selected_template_name}")
return {
'template_name': template['name'],
'template_content': template['content'],
'selection_reason': result.get('selection_reason', 'LLM智能选择')
}
self.log_error(f"LLM选择的模板不存在: {selected_template_name}")
logger.error(f"LLM选择的模板不存在: {selected_template_name}")
return None
except json.JSONDecodeError as e:
self.log_error(f"JSON解析失败: {str(e)}")
logger.exception(f"JSON解析失败: {str(e)}")
# 尝试从文本响应中提取模板信息
return self._extract_template_from_text(response, available_templates)
@@ -163,7 +164,7 @@ class TemplateSelectionNode(BaseNode):
def _extract_template_from_text(self, response: str, available_templates: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
"""从文本响应中提取模板信息"""
self.log_info("尝试从文本响应中提取模板信息")
logger.info("尝试从文本响应中提取模板信息")
# 查找响应中是否包含模板名称
for template in available_templates:
@@ -175,7 +176,7 @@ class TemplateSelectionNode(BaseNode):
for variant in template_name_variants:
if variant in response:
self.log_info(f"在响应中找到模板: {template['name']}")
logger.info(f"在响应中找到模板: {template['name']}")
return {
'template_name': template['name'],
'template_content': template['content'],
@@ -189,7 +190,7 @@ class TemplateSelectionNode(BaseNode):
templates = []
if not os.path.exists(self.template_dir):
self.log_error(f"模板目录不存在: {self.template_dir}")
logger.error(f"模板目录不存在: {self.template_dir}")
return templates
# 查找所有markdown模板文件
@@ -210,7 +211,7 @@ class TemplateSelectionNode(BaseNode):
'description': description
})
except Exception as e:
self.log_error(f"读取模板文件失败 {filename}: {str(e)}")
logger.exception(f"读取模板文件失败 {filename}: {str(e)}")
return templates
@@ -235,7 +236,7 @@ class TemplateSelectionNode(BaseNode):
def _get_fallback_template(self) -> Dict[str, Any]:
"""获取备用默认模板(空模板,让LLM自行发挥)"""
self.log_info("未找到合适模板,使用空模板让LLM自行发挥")
logger.info("未找到合适模板,使用空模板让LLM自行发挥")
return {
'template_name': '自由发挥模板',
-3
View File
@@ -3,9 +3,6 @@ Report Engine工具模块
包含配置管理
"""
from .config import Config, load_config
__all__ = [
"Config",
"load_config"
]
+40 -139
View File
@@ -3,150 +3,51 @@ Configuration management module for the Report Engine.
"""
import os
from dataclasses import dataclass
from pydantic_settings import BaseSettings
from pydantic import Field
from typing import Optional
from loguru import logger
def _get_value(source, key: str, default=None, *fallback_keys: str):
candidates = (key,) + fallback_keys
value = None
for candidate in candidates:
if isinstance(source, dict):
value = source.get(candidate)
else:
value = getattr(source, candidate, None)
if value not in (None, ""):
break
if value in (None, ""):
for candidate in candidates:
env_val = os.getenv(candidate)
if env_val not in (None, ""):
value = env_val
break
return value if value not in (None, "") else default
class Settings(BaseSettings):
    """Report Engine configuration backed by pydantic-settings.

    Field names are upper-case and double as the environment variable
    names: the nested ``Config`` sets an empty ``env_prefix``, reads a
    ``.env`` file, and disables case-sensitive matching.
    """

    # LLM connection
    REPORT_ENGINE_API_KEY: Optional[str] = Field(None, description="Report Engine LLM API密钥")
    REPORT_ENGINE_BASE_URL: Optional[str] = Field(None, description="Report Engine LLM基础URL")
    REPORT_ENGINE_MODEL_NAME: Optional[str] = Field(None, description="Report Engine LLM模型名称")
    REPORT_ENGINE_PROVIDER: Optional[str] = Field(None, description="模型服务商,仅兼容保留")

    # Report generation
    MAX_CONTENT_LENGTH: int = Field(200000, description="最大内容长度")
    OUTPUT_DIR: str = Field("final_reports", description="主输出目录")
    TEMPLATE_DIR: str = Field("ReportEngine/report_template", description="多模板目录")

    # API call behaviour
    API_TIMEOUT: float = Field(900.0, description="单API超时时间(秒)")
    MAX_RETRY_DELAY: float = Field(180.0, description="最大重试间隔(秒)")
    MAX_RETRIES: int = Field(8, description="最大重试次数")

    # Output and presentation
    LOG_FILE: str = Field("logs/report.log", description="日志输出文件")
    ENABLE_PDF_EXPORT: bool = Field(True, description="是否允许导出PDF")
    CHART_STYLE: str = Field("modern", description="图表样式:modern/classic/")

    class Config:
        # pydantic settings options, not a dataclass: it must stay undecorated.
        env_file = ".env"
        env_prefix = ""
        case_sensitive = False
        # Unknown keys (e.g. settings of other engines in the same .env)
        # are accepted instead of raising a validation error.
        extra = "allow"
llm_api_key: Optional[str] = None
llm_base_url: Optional[str] = None
llm_model_name: Optional[str] = None
llm_provider: Optional[str] = None # compatibility
max_content_length: int = 200000
output_dir: str = "final_reports"
template_dir: str = "ReportEngine/report_template"
api_timeout: float = 900.0
max_retry_delay: float = 180.0
max_retries: int = 8
log_file: str = "logs/report.log"
enable_pdf_export: bool = True
chart_style: str = "modern"
def __post_init__(self):
if not self.llm_provider and self.llm_model_name:
self.llm_provider = self.llm_model_name
def validate(self) -> bool:
if not self.llm_api_key:
print("错误: Report Engine LLM API Key 未设置 (REPORT_ENGINE_API_KEY)。")
return False
if not self.llm_model_name:
print("错误: Report Engine 模型名称未设置 (REPORT_ENGINE_MODEL_NAME)。")
return False
return True
@classmethod
def from_file(cls, config_file: str) -> "Config":
if config_file.endswith(".py"):
import importlib.util
spec = importlib.util.spec_from_file_location("config", config_file)
config_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config_module)
return cls(
llm_api_key=_get_value(config_module, "REPORT_ENGINE_API_KEY"),
llm_base_url=_get_value(config_module, "REPORT_ENGINE_BASE_URL"),
llm_model_name=_get_value(config_module, "REPORT_ENGINE_MODEL_NAME"),
max_content_length=int(_get_value(config_module, "MAX_CONTENT_LENGTH", 200000)),
output_dir=_get_value(config_module, "REPORT_OUTPUT_DIR", "final_reports"),
template_dir=_get_value(config_module, "TEMPLATE_DIR", "ReportEngine/report_template"),
api_timeout=float(_get_value(config_module, "REPORT_API_TIMEOUT", 900.0)),
max_retry_delay=float(_get_value(config_module, "REPORT_MAX_RETRY_DELAY", 180.0)),
max_retries=int(_get_value(config_module, "REPORT_MAX_RETRIES", 8)),
log_file=_get_value(config_module, "REPORT_LOG_FILE", "logs/report.log"),
enable_pdf_export=str(
_get_value(config_module, "ENABLE_PDF_EXPORT", "true")
).lower()
in ("true", "1", "yes"),
chart_style=_get_value(config_module, "CHART_STYLE", "modern"),
)
config_dict = {}
if os.path.exists(config_file):
with open(config_file, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line and not line.startswith("#") and "=" in line:
key, value = line.split("=", 1)
config_dict[key.strip()] = value.strip()
return cls(
llm_api_key=_get_value(config_dict, "REPORT_ENGINE_API_KEY"),
llm_base_url=_get_value(config_dict, "REPORT_ENGINE_BASE_URL"),
llm_model_name=_get_value(config_dict, "REPORT_ENGINE_MODEL_NAME"),
max_content_length=int(_get_value(config_dict, "MAX_CONTENT_LENGTH", 200000)),
output_dir=_get_value(config_dict, "REPORT_OUTPUT_DIR", "final_reports"),
template_dir=_get_value(config_dict, "TEMPLATE_DIR", "ReportEngine/report_template"),
api_timeout=float(_get_value(config_dict, "REPORT_API_TIMEOUT", 900.0)),
max_retry_delay=float(_get_value(config_dict, "REPORT_MAX_RETRY_DELAY", 180.0)),
max_retries=int(_get_value(config_dict, "REPORT_MAX_RETRIES", 8)),
log_file=_get_value(config_dict, "REPORT_LOG_FILE", "logs/report.log"),
enable_pdf_export=str(
_get_value(config_dict, "ENABLE_PDF_EXPORT", "true")
).lower()
in ("true", "1", "yes"),
chart_style=_get_value(config_dict, "CHART_STYLE", "modern"),
)
settings = Settings()
def load_config(config_file: Optional[str] = None) -> Config:
    """Load and validate the Report Engine configuration.

    Args:
        config_file: Path of the configuration file. When omitted, the
            first existing file among ``config.py``, ``config.env`` and
            ``.env`` in the working directory is used.

    Returns:
        The validated ``Config`` built by ``Config.from_file``.

    Raises:
        FileNotFoundError: The given file does not exist, or no candidate
            file was found.
        ValueError: ``Config.validate`` rejected the loaded configuration.
    """
    if not config_file:
        discovered = next(
            (name for name in ("config.py", "config.env", ".env") if os.path.exists(name)),
            None,
        )
        if discovered is None:
            raise FileNotFoundError("未找到配置文件,请创建 config.py。")
        print(f"已找到配置文件: {discovered}")
        chosen = discovered
    elif os.path.exists(config_file):
        chosen = config_file
    else:
        raise FileNotFoundError(f"配置文件不存在: {config_file}")

    loaded = Config.from_file(chosen)
    if loaded.validate():
        return loaded
    raise ValueError("Report Engine 配置校验失败,请检查 config.py 中的相关配置。")
def print_config(config: Config):
    """Print a human-readable summary of *config* to standard output.

    The API key itself is never printed; only whether it is set.
    """
    key_state = '已配置' if config.llm_api_key else '未配置'
    report_lines = (
        "\n=== Report Engine 配置 ===",
        f"LLM 模型: {config.llm_model_name}",
        f"LLM Base URL: {config.llm_base_url or '(默认)'}",
        f"最大内容长度: {config.max_content_length}",
        f"输出目录: {config.output_dir}",
        f"模板目录: {config.template_dir}",
        f"API 超时时间: {config.api_timeout}",
        f"最大重试间隔: {config.max_retry_delay}",
        f"最大重试次数: {config.max_retries}",
        f"日志文件: {config.log_file}",
        f"PDF 导出: {config.enable_pdf_export}",
        f"图表样式: {config.chart_style}",
        f"LLM API Key: {key_state}",
        "========================\n",
    )
    for report_line in report_lines:
        print(report_line)
def print_config(config: Settings):
    """Log a human-readable summary of *config* as a single INFO record.

    The API key itself is never logged; only whether it is set.

    Args:
        config: The settings object whose fields are summarised.
    """
    key_state = '已配置' if config.REPORT_ENGINE_API_KEY else '未配置'
    # The leading and trailing empty strings reproduce the blank line before
    # the banner and the newline after the closing rule.
    lines = [
        "",
        "=== Report Engine 配置 ===",
        f"LLM 模型: {config.REPORT_ENGINE_MODEL_NAME}",
        f"LLM Base URL: {config.REPORT_ENGINE_BASE_URL or '(默认)'}",
        f"最大内容长度: {config.MAX_CONTENT_LENGTH}",
        f"输出目录: {config.OUTPUT_DIR}",
        f"模板目录: {config.TEMPLATE_DIR}",
        f"API 超时时间: {config.API_TIMEOUT}",
        f"最大重试间隔: {config.MAX_RETRY_DELAY}",
        f"最大重试次数: {config.MAX_RETRIES}",
        f"日志文件: {config.LOG_FILE}",
        f"PDF 导出: {config.ENABLE_PDF_EXPORT}",
        f"图表样式: {config.CHART_STYLE}",
        f"LLM API Key: {key_state}",
        "=========================",
        "",
    ]
    logger.info("\n".join(lines))
+34 -37
View File
@@ -9,6 +9,7 @@ import streamlit as st
from datetime import datetime
import json
import locale
from loguru import logger
# 设置UTF-8编码环境
os.environ['PYTHONIOENCODING'] = 'utf-8'
@@ -26,18 +27,8 @@ except locale.Error:
# 添加src目录到Python路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from InsightEngine import DeepSearchAgent, Config
from config import (
INSIGHT_ENGINE_API_KEY,
INSIGHT_ENGINE_BASE_URL,
INSIGHT_ENGINE_MODEL_NAME,
DB_HOST,
DB_USER,
DB_PASSWORD,
DB_NAME,
DB_PORT,
DB_CHARSET,
)
from InsightEngine import DeepSearchAgent, Settings
from config import settings
def main():
@@ -66,7 +57,7 @@ def main():
# ----- 配置被硬编码 -----
# 强制使用 Kimi
model_name = INSIGHT_ENGINE_MODEL_NAME or "kimi-k2-0711-preview"
model_name = settings.INSIGHT_ENGINE_MODEL_NAME or "kimi-k2-0711-preview"
# 默认高级配置
max_reflections = 2
max_content_length = 500000 # Kimi支持长文本
@@ -100,42 +91,45 @@ def main():
if start_research:
if not query.strip():
st.error("请输入研究查询")
logger.error("请输入研究查询")
return
# 检查配置中的LLM密钥
if not INSIGHT_ENGINE_API_KEY:
st.error("请在您的配置文件(config.py)中设置INSIGHT_ENGINE_API_KEY")
if not settings.INSIGHT_ENGINE_API_KEY:
st.error("请在您的环境变量中设置INSIGHT_ENGINE_API_KEY")
logger.error("请在您的环境变量中设置INSIGHT_ENGINE_API_KEY")
return
# 自动使用配置文件中的API密钥和数据库配置
db_host = DB_HOST
db_user = DB_USER
db_password = DB_PASSWORD
db_name = DB_NAME
db_port = DB_PORT
db_charset = DB_CHARSET
db_host = settings.DB_HOST
db_user = settings.DB_USER
db_password = settings.DB_PASSWORD
db_name = settings.DB_NAME
db_port = settings.DB_PORT
db_charset = settings.DB_CHARSET
# 创建配置
config = Config(
llm_api_key=INSIGHT_ENGINE_API_KEY,
llm_base_url=INSIGHT_ENGINE_BASE_URL,
llm_model_name=model_name,
db_host=db_host,
db_user=db_user,
db_password=db_password,
db_name=db_name,
db_port=db_port,
db_charset=db_charset,
max_reflections=max_reflections,
max_content_length=max_content_length,
output_dir="insight_engine_streamlit_reports"
# 创建Settings配置(字段必须用大写,以适配Settings类)
config = Settings(
INSIGHT_ENGINE_API_KEY=settings.INSIGHT_ENGINE_API_KEY,
INSIGHT_ENGINE_BASE_URL=settings.INSIGHT_ENGINE_BASE_URL,
INSIGHT_ENGINE_MODEL_NAME=model_name,
DB_HOST=db_host,
DB_USER=db_user,
DB_PASSWORD=db_password,
DB_NAME=db_name,
DB_PORT=db_port,
DB_CHARSET=db_charset,
DB_DIALECT=settings.DB_DIALECT,
MAX_REFLECTIONS=max_reflections,
MAX_CONTENT_LENGTH=max_content_length,
OUTPUT_DIR="insight_engine_streamlit_reports"
)
# 执行研究
execute_research(query, config)
def execute_research(query: str, config: Config):
def execute_research(query: str, config: Settings):
"""执行研究"""
try:
# 创建进度条
@@ -187,7 +181,10 @@ def execute_research(query: str, config: Config):
display_results(agent, final_report)
except Exception as e:
st.error(f"研究过程中发生错误: {str(e)}")
import traceback
error_traceback = traceback.format_exc()
st.error(f"研究过程中发生错误: {str(e)} \n错误堆栈: {error_traceback}")
logger.exception(f"研究过程中发生错误: {str(e)}")
def display_results(agent: DeepSearchAgent, final_report: str):
+30 -26
View File
@@ -9,6 +9,7 @@ import streamlit as st
from datetime import datetime
import json
import locale
from loguru import logger
# 设置UTF-8编码环境
os.environ['PYTHONIOENCODING'] = 'utf-8'
@@ -26,13 +27,8 @@ except locale.Error:
# 添加src目录到Python路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from MediaEngine import DeepSearchAgent, Config
from config import (
MEDIA_ENGINE_API_KEY,
MEDIA_ENGINE_BASE_URL,
MEDIA_ENGINE_MODEL_NAME,
BOCHA_WEB_SEARCH_API_KEY,
)
from MediaEngine import DeepSearchAgent, Settings
from config import settings
def main():
@@ -62,7 +58,7 @@ def main():
# ----- 配置被硬编码 -----
# 强制使用 Gemini
model_name = MEDIA_ENGINE_MODEL_NAME or "gemini-2.5-pro"
model_name = settings.MEDIA_ENGINE_MODEL_NAME or "gemini-2.5-pro"
# 默认高级配置
max_reflections = 2
max_content_length = 20000
@@ -96,36 +92,39 @@ def main():
if start_research:
if not query.strip():
st.error("请输入研究查询")
logger.error("请输入研究查询")
return
# 由于强制使用Gemini,检查相关的API密钥
if not MEDIA_ENGINE_API_KEY:
st.error("请在您的配置文件(config.py)中设置MEDIA_ENGINE_API_KEY")
if not settings.MEDIA_ENGINE_API_KEY:
st.error("请在您的环境变量中设置MEDIA_ENGINE_API_KEY")
logger.error("请在您的环境变量中设置MEDIA_ENGINE_API_KEY")
return
if not BOCHA_WEB_SEARCH_API_KEY:
st.error("请在您的配置文件(config.py)中设置BOCHA_WEB_SEARCH_API_KEY")
if not settings.BOCHA_WEB_SEARCH_API_KEY:
st.error("请在您的环境变量中设置BOCHA_WEB_SEARCH_API_KEY")
logger.error("请在您的环境变量中设置BOCHA_WEB_SEARCH_API_KEY")
return
# 自动使用配置文件中的API密钥
engine_key = MEDIA_ENGINE_API_KEY
bocha_key = BOCHA_WEB_SEARCH_API_KEY
engine_key = settings.MEDIA_ENGINE_API_KEY
bocha_key = settings.BOCHA_WEB_SEARCH_API_KEY
# 创建配置
config = Config(
llm_api_key=engine_key,
llm_base_url=MEDIA_ENGINE_BASE_URL,
llm_model_name=model_name,
bocha_api_key=bocha_key,
max_reflections=max_reflections,
max_content_length=max_content_length,
output_dir="media_engine_streamlit_reports"
# 构建 Settingspydantic_settings风格,优先大写环境变量)
config = Settings(
MEDIA_ENGINE_API_KEY=engine_key,
MEDIA_ENGINE_BASE_URL=settings.MEDIA_ENGINE_BASE_URL,
MEDIA_ENGINE_MODEL_NAME=model_name,
BOCHA_WEB_SEARCH_API_KEY=bocha_key,
MAX_REFLECTIONS=max_reflections,
SEARCH_CONTENT_MAX_LENGTH=max_content_length,
OUTPUT_DIR="media_engine_streamlit_reports",
)
# 执行研究
execute_research(query, config)
def execute_research(query: str, config: Config):
def execute_research(query: str, config: Settings):
"""执行研究"""
try:
# 创建进度条
@@ -163,21 +162,26 @@ def execute_research(query: str, config: Config):
# 生成最终报告
status_text.text("正在生成最终报告...")
logger.info("正在生成最终报告...")
final_report = agent._generate_final_report()
progress_bar.progress(90)
# 保存报告
status_text.text("正在保存报告...")
logger.info("正在保存报告...")
agent._save_report(final_report)
progress_bar.progress(100)
status_text.text("研究完成!")
logger.info("研究完成!")
# 显示结果
display_results(agent, final_report)
except Exception as e:
st.error(f"研究过程中发生错误: {str(e)}")
import traceback
error_traceback = traceback.format_exc()
st.error(f"研究过程中发生错误: {str(e)} \n错误堆栈: {error_traceback}")
logger.exception(f"研究过程中发生错误: {str(e)}")
def display_results(agent: DeepSearchAgent, final_report: str):
+23 -19
View File
@@ -9,6 +9,7 @@ import streamlit as st
from datetime import datetime
import json
import locale
from loguru import logger
# 设置UTF-8编码环境
os.environ['PYTHONIOENCODING'] = 'utf-8'
@@ -26,8 +27,8 @@ except locale.Error:
# 添加src目录到Python路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from QueryEngine import DeepSearchAgent, Config
from config import QUERY_ENGINE_API_KEY, QUERY_ENGINE_BASE_URL, QUERY_ENGINE_MODEL_NAME, TAVILY_API_KEY
from QueryEngine import DeepSearchAgent, Settings
from config import settings
def main():
@@ -56,7 +57,7 @@ def main():
# ----- 配置被硬编码 -----
# 强制使用 DeepSeek
model_name = QUERY_ENGINE_MODEL_NAME or "deepseek-chat"
model_name = settings.QUERY_ENGINE_MODEL_NAME or "deepseek-chat"
# 默认高级配置
max_reflections = 2
max_content_length = 20000
@@ -93,33 +94,33 @@ def main():
return
# 由于强制使用DeepSeek,检查相关的API密钥
if not QUERY_ENGINE_API_KEY:
st.error("请在您的配置文件(config.py)中设置QUERY_ENGINE_API_KEY")
if not settings.QUERY_ENGINE_API_KEY:
st.error("请在您的环境变量中设置QUERY_ENGINE_API_KEY")
return
if not TAVILY_API_KEY:
st.error("请在您的配置文件(config.py)中设置TAVILY_API_KEY")
if not settings.TAVILY_API_KEY:
st.error("请在您的环境变量中设置TAVILY_API_KEY")
return
# 自动使用配置文件中的API密钥
engine_key = QUERY_ENGINE_API_KEY
tavily_key = TAVILY_API_KEY
engine_key = settings.QUERY_ENGINE_API_KEY
tavily_key = settings.TAVILY_API_KEY
# 创建配置
config = Config(
llm_api_key=engine_key,
llm_base_url=QUERY_ENGINE_BASE_URL,
llm_model_name=model_name,
tavily_api_key=tavily_key,
max_reflections=max_reflections,
max_content_length=max_content_length,
output_dir="query_engine_streamlit_reports"
config = Settings(
QUERY_ENGINE_API_KEY=engine_key,
QUERY_ENGINE_BASE_URL=settings.QUERY_ENGINE_BASE_URL,
QUERY_ENGINE_MODEL_NAME=model_name,
TAVILY_API_KEY=tavily_key,
MAX_REFLECTIONS=max_reflections,
SEARCH_CONTENT_MAX_LENGTH=max_content_length,
OUTPUT_DIR="query_engine_streamlit_reports"
)
# 执行研究
execute_research(query, config)
def execute_research(query: str, config: Config):
def execute_research(query: str, config: Settings):
"""执行研究"""
try:
# 创建进度条
@@ -171,7 +172,10 @@ def execute_research(query: str, config: Config):
display_results(agent, final_report)
except Exception as e:
st.error(f"研究过程中发生错误: {str(e)}")
import traceback
error_traceback = traceback.format_exc()
st.error(f"研究过程中发生错误: {str(e)} \n错误堆栈: {error_traceback}")
logger.exception(f"研究过程中发生错误: {str(e)}")
def display_results(agent: DeepSearchAgent, final_report: str):
+91 -50
View File
@@ -15,7 +15,7 @@ from flask_socketio import SocketIO, emit
import signal
import atexit
import requests
import logging
from loguru import logger
import importlib
import re
from pathlib import Path
@@ -25,7 +25,7 @@ try:
from ReportEngine.flask_interface import report_bp, initialize_report_engine
REPORT_ENGINE_AVAILABLE = True
except ImportError as e:
print(f"ReportEngine导入失败: {e}")
logger.error(f"ReportEngine导入失败: {e}")
REPORT_ENGINE_AVAILABLE = False
app = Flask(__name__)
@@ -35,9 +35,9 @@ socketio = SocketIO(app, cors_allowed_origins="*")
# 注册ReportEngine Blueprint
if REPORT_ENGINE_AVAILABLE:
app.register_blueprint(report_bp, url_prefix='/api/report')
print("ReportEngine接口已注册")
logger.info("ReportEngine接口已注册")
else:
print("ReportEngine不可用,跳过接口注册")
logger.info("ReportEngine不可用,跳过接口注册")
# 设置UTF-8编码环境
os.environ['PYTHONIOENCODING'] = 'utf-8'
@@ -50,6 +50,7 @@ LOG_DIR.mkdir(exist_ok=True)
CONFIG_MODULE_NAME = 'config'
CONFIG_FILE_PATH = Path(__file__).resolve().parent / 'config.py'
CONFIG_KEYS = [
'DB_DIALECT',
'DB_HOST',
'DB_PORT',
'DB_USER',
@@ -95,19 +96,34 @@ def _load_config_module():
def read_config_values():
    """Return the current configuration values that are exposed to the frontend.

    The config module is reloaded (or imported on first use) so that its
    ``settings`` object is rebuilt before the values are read.

    Returns:
        A dict mapping every name in ``CONFIG_KEYS`` to its value converted
        to ``str``; missing or ``None`` values become ``''``. An empty dict
        is returned when the module has no ``settings`` attribute or when
        anything raises while reading.
    """
    try:
        # Re-import the config module to pick up the latest Settings instance.
        importlib.invalidate_caches()
        if CONFIG_MODULE_NAME in sys.modules:
            importlib.reload(sys.modules[CONFIG_MODULE_NAME])
        else:
            importlib.import_module(CONFIG_MODULE_NAME)

        config_module = sys.modules[CONFIG_MODULE_NAME]
        if not hasattr(config_module, 'settings'):
            logger.error("config 模块中没有找到 settings 实例")
            return {}

        current_settings = config_module.settings
        values = {}
        for key in CONFIG_KEYS:
            value = getattr(current_settings, key, None)
            # Convert to string for uniform handling on the frontend.
            values[key] = '' if value is None else str(value)
        return values
    except Exception as exc:
        # Top-level boundary for the web UI: log with traceback and degrade
        # to "no values" instead of failing the request.
        logger.exception(f"读取配置失败: {exc}")
        return {}
def _serialize_config_value(value):
@@ -125,35 +141,58 @@ def _serialize_config_value(value):
def _format_env_value(raw_value):
    """Render one configuration value as the right-hand side of ``KEY=value``."""
    if raw_value is None or raw_value == '':
        return ''
    # bool is a subclass of int, so it has to be tested before (int, float);
    # otherwise this branch can never be reached.
    if isinstance(raw_value, bool):
        return 'True' if raw_value else 'False'
    if isinstance(raw_value, (int, float)):
        return str(raw_value)

    value_str = str(raw_value)
    # Whitespace, '#' (comment marker) and quote characters are only safe
    # inside a double-quoted value.
    if not any(ch.isspace() or ch in '#"\'' for ch in value_str):
        return value_str

    escaped = (
        value_str
        .replace('\\', '\\\\')
        .replace('"', '\\"')
        # Keep every entry on a single physical line: the writer below parses
        # the file line by line, and double-quoted values decode \n and \r.
        .replace('\n', '\\n')
        .replace('\r', '\\r')
    )
    return f'"{escaped}"'


def write_config_values(updates):
    """Persist configuration updates to the .env file (Pydantic Settings source).

    Existing ``KEY=value`` lines are rewritten in place, so comments, blank
    lines and ordering in the file are preserved. Keys that are not yet
    present are appended at the end.

    Args:
        updates: Mapping of configuration key to new value. ``None`` and
            ``''`` are written as an empty value.

    Side effects:
        Writes the .env file (creating it and its parent directory when
        needed) and then calls ``_load_config_module()`` so the application
        observes the new values.
    """
    # Prefer a .env in the current working directory and fall back to the one
    # next to this file. The original note says this is meant to mirror the
    # lookup done in config.py.
    project_root = Path(__file__).resolve().parent
    cwd_env = Path.cwd() / ".env"
    env_file_path = cwd_env if cwd_env.exists() else (project_root / ".env")

    env_lines = []
    env_key_indices = {}  # key -> index of its line in env_lines
    if env_file_path.exists():
        env_lines = env_file_path.read_text(encoding='utf-8').splitlines()
        for index, line in enumerate(env_lines):
            stripped = line.strip()
            # Skip blanks, comments and anything that is not an assignment.
            if not stripped or stripped.startswith('#') or '=' not in stripped:
                continue
            env_key_indices[stripped.partition('=')[0].strip()] = index

    for key, raw_value in updates.items():
        entry = f'{key}={_format_env_value(raw_value)}'
        if key in env_key_indices:
            env_lines[env_key_indices[key]] = entry
        else:
            env_lines.append(entry)

    env_file_path.parent.mkdir(parents=True, exist_ok=True)
    env_file_path.write_text('\n'.join(env_lines) + '\n', encoding='utf-8')

    # Reload the config module so the rest of the app observes the new values
    # when possible (per the original note, this re-reads the .env file).
    _load_config_module()
@@ -268,14 +307,14 @@ def init_forum_log():
with open(forum_log_file, 'w', encoding='utf-8') as f:
start_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
f.write(f"=== ForumEngine 系统初始化 - {start_time} ===\n")
print(f"ForumEngine: forum.log 已初始化")
logger.info(f"ForumEngine: forum.log 已初始化")
else:
with open(forum_log_file, 'w', encoding='utf-8') as f:
start_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
f.write(f"=== ForumEngine 系统初始化 - {start_time} ===\n")
print(f"ForumEngine: forum.log 已初始化")
logger.info(f"ForumEngine: forum.log 已初始化")
except Exception as e:
print(f"ForumEngine: 初始化forum.log失败: {e}")
logger.exception(f"ForumEngine: 初始化forum.log失败: {e}")
# 初始化forum.log
init_forum_log()
@@ -285,23 +324,23 @@ def start_forum_engine():
"""启动ForumEngine论坛"""
try:
from ForumEngine.monitor import start_forum_monitoring
print("ForumEngine: 启动论坛...")
logger.info("ForumEngine: 启动论坛...")
success = start_forum_monitoring()
if not success:
print("ForumEngine: 论坛启动失败")
logger.info("ForumEngine: 论坛启动失败")
except Exception as e:
print(f"ForumEngine: 启动论坛失败: {e}")
logger.exception(f"ForumEngine: 启动论坛失败: {e}")
# 停止ForumEngine智能监控
def stop_forum_engine():
    """Stop the ForumEngine forum.

    Imports ``stop_forum_monitoring`` lazily from ``ForumEngine.monitor`` and
    calls it. Any exception, including a failed import, is logged with its
    traceback and swallowed, so this function never raises.
    """
    try:
        from ForumEngine.monitor import stop_forum_monitoring
        # Progress is reported through the logger only; the former print()
        # calls emitted every message a second time.
        logger.info("ForumEngine: 停止论坛...")
        stop_forum_monitoring()
        logger.info("ForumEngine: 论坛已停止")
    except Exception as e:
        logger.exception(f"ForumEngine: 停止论坛失败: {e}")
def parse_forum_log_line(line):
"""解析forum.log行内容,提取对话信息"""
@@ -396,7 +435,7 @@ def monitor_forum_log():
time.sleep(1) # 每秒检查一次
except Exception as e:
print(f"Forum日志监听错误: {e}")
logger.error(f"Forum日志监听错误: {e}")
time.sleep(5)
# 启动Forum日志监听线程
@@ -433,7 +472,7 @@ def write_log_to_file(app_name, line):
f.write(line + '\n')
f.flush()
except Exception as e:
print(f"Error writing log for {app_name}: {e}")
logger.error(f"Error writing log for {app_name}: {e}")
def read_log_from_file(app_name, tail_lines=None):
"""从文件读取日志"""
@@ -450,7 +489,7 @@ def read_log_from_file(app_name, tail_lines=None):
return lines[-tail_lines:]
return lines
except Exception as e:
print(f"Error reading log for {app_name}: {e}")
logger.exception(f"Error reading log for {app_name}: {e}")
return []
def read_process_output(process, app_name):
@@ -519,8 +558,7 @@ def read_process_output(process, app_name):
})
except Exception as e:
error_msg = f"Error reading output for {app_name}: {e}"
print(error_msg)
logger.exception(f"Error reading output for {app_name}: {e}")
write_log_to_file(app_name, f"[{datetime.now().strftime('%H:%M:%S')}] {error_msg}")
break
@@ -675,7 +713,7 @@ def cleanup_processes():
try:
stop_forum_engine()
except Exception: # pragma: no cover
logging.exception("停止ForumEngine失败")
logger.exception("停止ForumEngine失败")
_set_system_state(started=False, starting=False)
# 注册清理函数
@@ -863,7 +901,7 @@ def search():
return jsonify({'success': False, 'message': '搜索查询不能为空'})
# ForumEngine论坛已经在后台运行,会自动检测搜索活动
# print("ForumEngine: 搜索请求已收到,论坛将自动检测日志变化")
# logger.info("ForumEngine: 搜索请求已收到,论坛将自动检测日志变化")
# 检查哪些应用正在运行
check_app_status()
@@ -993,12 +1031,15 @@ def handle_status_request():
})
if __name__ == '__main__':
    # Bind address and port, named once so the log line and the server call
    # cannot drift apart.
    HOST = '0.0.0.0'
    PORT = 5000

    logger.info("等待配置确认,系统将在前端指令后启动组件...")
    # Logged before socketio.run() because that call is expected to block
    # until the server stops.
    logger.info(f"Flask服务器已启动,访问地址: http://{HOST}:{PORT}")

    try:
        # Start the server exactly once; the literal-argument call left over
        # from the previous version is dropped.
        socketio.run(app, host=HOST, port=PORT, debug=False)
    except KeyboardInterrupt:
        logger.info("\n正在关闭应用...")
        cleanup_processes()
-58
View File
@@ -1,58 +0,0 @@
# -*- coding: utf-8 -*-
"""
Weiyu configuration file
"""
# ============================== Database configuration ==============================
# Set these values to connect to your MySQL instance.
DB_HOST = "your_db_host" # e.g. "localhost" or "127.0.0.1"
DB_PORT = 3306
DB_USER = "your_db_user"
DB_PASSWORD = "your_db_password"
DB_NAME = "your_db_name"
DB_CHARSET = "utf8mb4"
# We also offer a ready-to-use cloud database (100k+ records per day on average); access is free on request. Contact: 670939375@qq.com
# NOTE: To carry out a data-compliance review and a service upgrade, the cloud database has stopped accepting new access requests since 2025-10-01.
# ============================== LLM configuration ==============================
# You can change the API used by each component's LLM. 🚩 Any provider compatible with the OpenAI request format works: define KEY, BASE_URL and MODEL_NAME and it is ready to use.
# Important: we strongly recommend applying for the recommended APIs first and getting the system running end to end before making your own changes!
# Insight Agent (Kimi recommended; apply at https://platform.moonshot.cn/)
INSIGHT_ENGINE_API_KEY = "your_api_key"
INSIGHT_ENGINE_BASE_URL = "https://api.moonshot.cn/v1"
INSIGHT_ENGINE_MODEL_NAME = "kimi-k2-0711-preview"
# Media Agent (Gemini recommended; a relay provider is used here and you may substitute your own; apply at https://www.chataiapi.com/)
MEDIA_ENGINE_API_KEY = "your_api_key"
MEDIA_ENGINE_BASE_URL = "https://www.chataiapi.com/v1"
MEDIA_ENGINE_MODEL_NAME = "gemini-2.5-pro"
# Query Agent (DeepSeek recommended; apply at https://www.deepseek.com/)
QUERY_ENGINE_API_KEY = "your_api_key"
QUERY_ENGINE_BASE_URL = "https://api.deepseek.com"
QUERY_ENGINE_MODEL_NAME = "deepseek-reasoner"
# Report Agent (Gemini recommended; a relay provider is used here and you may substitute your own)
REPORT_ENGINE_API_KEY = "your_api_key"
REPORT_ENGINE_BASE_URL = "https://www.chataiapi.com/v1"
REPORT_ENGINE_MODEL_NAME = "gemini-2.5-pro"
# Forum Host (latest Qwen3 model, served here through the SiliconFlow platform; apply at https://cloud.siliconflow.cn/)
FORUM_HOST_API_KEY = "your_api_key"
FORUM_HOST_BASE_URL = "https://api.siliconflow.cn/v1"
FORUM_HOST_MODEL_NAME = "Qwen/Qwen3-235B-A22B-Instruct-2507"
# SQL keyword Optimizer (small-parameter Qwen3 model, served here through the SiliconFlow platform; apply at https://cloud.siliconflow.cn/)
KEYWORD_OPTIMIZER_API_KEY = "your_api_key"
KEYWORD_OPTIMIZER_BASE_URL = "https://api.siliconflow.cn/v1"
KEYWORD_OPTIMIZER_MODEL_NAME = "Qwen/Qwen3-30B-A3B-Instruct-2507"
# ============================== Web tool configuration ==============================
# Tavily API (apply at https://www.tavily.com/)
TAVILY_API_KEY = "your_api_key"
# Bocha API (apply at https://open.bochaai.com/)
BOCHA_WEB_SEARCH_API_KEY = "your_api_key"
+2
View File
@@ -8,6 +8,8 @@ services:
image: bettafish:latest
container_name: bettafish
restart: unless-stopped
env_file:
- .env
environment:
- PYTHONUNBUFFERED=1
ports:
+54 -12
View File
@@ -456,6 +456,15 @@
border-color: #333333;
}
/* Select-type config fields: replace the native dropdown arrow with a custom one. */
.config-field-input[data-field-type="select"] {
cursor: pointer;
/* Remove the browser's default control styling so only the arrow below is shown. */
appearance: none;
/* Inline 12x12 SVG: a downward-pointing triangle filled with #333. */
background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='12' height='12' viewBox='0 0 12 12'%3E%3Cpath fill='%23333' d='M6 9L1 4h10z'/%3E%3C/svg%3E");
background-repeat: no-repeat;
/* Place the arrow 12px from the right edge, vertically centred. */
background-position: right 12px center;
/* Reserve room on the right so the selected text does not run under the arrow. */
padding-right: 36px;
}
.config-modal-footer {
display: flex;
justify-content: space-between;
@@ -1090,7 +1099,7 @@
<div class="config-modal-overlay" id="configModal">
<div class="config-modal">
<div class="config-modal-header">
<div class="config-modal-title">LLM 配置 - 与Config文件双向同步</div>
<div class="config-modal-title">LLM 配置 - 与.env文件双向同步</div>
<div class="config-modal-actions">
<button class="config-secondary-button" id="refreshConfigButton">刷新</button>
<button class="config-close-button" id="closeConfigModal" aria-label="关闭配置窗口">×</button>
@@ -1141,6 +1150,7 @@
title: '数据库连接',
subtitle: '用于连接业务数据库的基本配置',
fields: [
{ key: 'DB_DIALECT', label: '数据库类型', type: 'select', options: ['mysql', 'postgresql'] },
{ key: 'DB_HOST', label: '主机地址' },
{ key: 'DB_PORT', label: '端口' },
{ key: 'DB_USER', label: '用户名' },
@@ -1478,27 +1488,59 @@
const fieldsHtml = group.fields.map(field => {
const value = values[field.key] !== undefined ? values[field.key] : '';
const safeValue = escapeHtml(String(value || ''));
const inputType = field.type === 'password' ? 'password' : (field.type || 'text');
const inputElement = `
<input
type="${inputType}"
let control;
if (field.type === 'select' && field.options) {
// 下拉选择框
const optionsHtml = field.options.map(option => {
const selected = option === value ? 'selected' : '';
const safeOption = escapeHtml(String(option));
return `<option value="${safeOption}" ${selected}>${safeOption}</option>`;
}).join('');
control = `
<select
class="config-field-input"
data-config-key="${field.key}"
data-field-type="${field.type || 'text'}"
data-field-type="select"
>
${optionsHtml}
</select>
`;
} else if (field.type === 'password') {
// 密码输入框
const inputElement = `
<input
type="password"
class="config-field-input"
data-config-key="${field.key}"
data-field-type="password"
value="${safeValue}"
placeholder="填写${field.label}"
autocomplete="${field.type === 'password' ? 'off' : 'on'}"
autocomplete="off"
>
`;
const control = field.type === 'password'
? `
control = `
<div class="config-password-wrapper">
${inputElement}
<button type="button" class="config-password-toggle" data-target="${field.key}">显示</button>
</div>
`
: inputElement;
`;
} else {
// 普通文本输入框
const inputType = field.type || 'text';
control = `
<input
type="${inputType}"
class="config-field-input"
data-config-key="${field.key}"
data-field-type="${inputType}"
value="${safeValue}"
placeholder="填写${field.label}"
autocomplete="on"
>
`;
}
return `
<label class="config-field">