1. 统一为使用基于pydantic的.env环境变量管理配置

2. 全项目基于loguru进行日志管理
This commit is contained in:
Doiiars
2025-11-05 14:56:49 +08:00
parent 1d2e23d8c1
commit 537d682861
50 changed files with 1404 additions and 1731 deletions
+1 -2
View File
@@ -5,9 +5,8 @@ Report Engine
"""
from .agent import ReportAgent, create_agent
from .utils.config import Config, load_config
__version__ = "1.0.0"
__author__ = "Report Engine Team"
__all__ = ["ReportAgent", "create_agent", "Config", "load_config"]
__all__ = ["ReportAgent", "create_agent"]
+39 -63
View File
@@ -5,7 +5,7 @@ Report Agent主类
import json
import os
import logging
from loguru import logger
from datetime import datetime
from typing import Optional, Dict, Any, List
@@ -15,7 +15,7 @@ from .nodes import (
HTMLGenerationNode
)
from .state import ReportState
from .utils import Config, load_config
from .utils.config import settings, Settings
class FileCountBaseline:
@@ -32,7 +32,7 @@ class FileCountBaseline:
with open(self.baseline_file, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
print(f"加载基准数据失败: {e}")
logger.exception(f"加载基准数据失败: {e}")
return {}
def _save_baseline(self):
@@ -42,7 +42,7 @@ class FileCountBaseline:
with open(self.baseline_file, 'w', encoding='utf-8') as f:
json.dump(self.baseline_data, f, ensure_ascii=False, indent=2)
except Exception as e:
print(f"保存基准数据失败: {e}")
logger.exception(f"保存基准数据失败: {e}")
def initialize_baseline(self, directories: Dict[str, str]) -> Dict[str, int]:
"""初始化文件数量基准"""
@@ -59,7 +59,7 @@ class FileCountBaseline:
self.baseline_data = current_counts.copy()
self._save_baseline()
print(f"文件数量基准已初始化: {current_counts}")
logger.info(f"文件数量基准已初始化: {current_counts}")
return current_counts
def check_new_files(self, directories: Dict[str, str]) -> Dict[str, Any]:
@@ -109,7 +109,7 @@ class FileCountBaseline:
class ReportAgent:
"""Report Agent主类"""
def __init__(self, config: Optional[Config] = None):
def __init__(self, config: Optional[Settings] = None):
"""
初始化Report Agent
@@ -117,7 +117,7 @@ class ReportAgent:
config: 配置对象,如果不提供则自动加载
"""
# 加载配置
self.config = config or load_config()
self.config = config or settings
# 初始化文件基准管理器
self.file_baseline = FileCountBaseline()
@@ -138,45 +138,20 @@ class ReportAgent:
self.state = ReportState()
# 确保输出目录存在
os.makedirs(self.config.output_dir, exist_ok=True)
os.makedirs(settings.OUTPUT_DIR, exist_ok=True)
self.logger.info("Report Agent已初始化")
self.logger.info(f"使用LLM: {self.llm_client.get_model_info()}")
logger.info("Report Agent已初始化")
logger.info(f"使用LLM: {self.llm_client.get_model_info()}")
def _setup_logging(self):
"""设置日志"""
# 确保日志目录存在
log_dir = os.path.dirname(self.config.log_file)
log_dir = os.path.dirname(settings.LOG_FILE)
os.makedirs(log_dir, exist_ok=True)
# 创建专用的logger,避免与其他模块冲突
self.logger = logging.getLogger('ReportEngine')
self.logger.setLevel(logging.INFO)
logger.add(settings.LOG_FILE, level="INFO")
# 清除已有的handlers
if self.logger.handlers:
self.logger.handlers.clear()
# 创建文件handler
file_handler = logging.FileHandler(self.config.log_file, encoding='utf-8')
file_handler.setLevel(logging.INFO)
# 创建控制台handler
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
# 设置格式
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)
console_handler.setFormatter(formatter)
# 添加handlers
self.logger.addHandler(file_handler)
self.logger.addHandler(console_handler)
# 防止日志向上传播
self.logger.propagate = False
def _initialize_file_baseline(self):
"""初始化文件数量基准"""
directories = {
@@ -189,16 +164,16 @@ class ReportAgent:
def _initialize_llm(self) -> LLMClient:
"""初始化LLM客户端"""
return LLMClient(
api_key=self.config.llm_api_key,
model_name=self.config.llm_model_name,
base_url=self.config.llm_base_url,
api_key=settings.REPORT_ENGINE_API_KEY,
model_name=settings.REPORT_ENGINE_MODEL_NAME,
base_url=settings.REPORT_ENGINE_BASE_URL,
)
def _initialize_nodes(self):
"""初始化处理节点"""
self.template_selection_node = TemplateSelectionNode(
self.llm_client,
self.config.template_dir
self.llm_client,
self.config.TEMPLATE_DIR
)
self.html_generation_node = HTMLGenerationNode(self.llm_client)
@@ -219,7 +194,7 @@ class ReportAgent:
"""
start_time = datetime.now()
self.logger.info(f"开始生成报告: {query}")
logger.info(f"开始生成报告: {query}")
self.logger.info(f"输入数据 - 报告数量: {len(reports)}, 论坛日志长度: {len(forum_logs)}")
try:
@@ -238,21 +213,21 @@ class ReportAgent:
generation_time = (end_time - start_time).total_seconds()
self.state.metadata.generation_time = generation_time
self.logger.info(f"报告生成完成,耗时: {generation_time:.2f}")
logger.info(f"报告生成完成,耗时: {generation_time:.2f}")
return html_report
except Exception as e:
self.logger.error(f"报告生成过程中发生错误: {str(e)}")
logger.exception(f"报告生成过程中发生错误: {str(e)}")
raise e
def _select_template(self, query: str, reports: List[Any], forum_logs: str, custom_template: str):
"""选择报告模板"""
self.logger.info("选择报告模板...")
logger.info("选择报告模板...")
# 如果用户提供了自定义模板,直接使用
if custom_template:
self.logger.info("使用用户自定义模板")
logger.info("使用用户自定义模板")
return {
'template_name': 'custom',
'template_content': custom_template,
@@ -271,12 +246,12 @@ class ReportAgent:
# 更新状态
self.state.metadata.template_used = template_result['template_name']
self.logger.info(f"选择模板: {template_result['template_name']}")
self.logger.info(f"选择理由: {template_result['selection_reason']}")
logger.info(f"选择模板: {template_result['template_name']}")
logger.info(f"选择理由: {template_result['selection_reason']}")
return template_result
except Exception as e:
self.logger.error(f"模板选择失败,使用默认模板: {str(e)}")
logger.error(f"模板选择失败,使用默认模板: {str(e)}")
# 直接使用备用模板
fallback_template = {
'template_name': '社会公共热点事件分析报告模板',
@@ -288,7 +263,7 @@ class ReportAgent:
def _generate_html_report(self, query: str, reports: List[Any], forum_logs: str, template_result: Dict[str, Any]) -> str:
"""生成HTML报告"""
self.logger.info("多轮生成HTML报告...")
logger.info("多轮生成HTML报告...")
# 准备报告内容,确保有3个报告
query_report = reports[0] if len(reports) > 0 else ""
@@ -316,7 +291,7 @@ class ReportAgent:
self.state.html_content = html_content
self.state.mark_completed()
self.logger.info("HTML报告生成完成")
logger.info("HTML报告生成完成")
return html_content
def _get_fallback_template_content(self) -> str:
@@ -376,19 +351,19 @@ class ReportAgent:
query_safe = query_safe.replace(' ', '_')[:30]
filename = f"final_report_{query_safe}_{timestamp}.html"
filepath = os.path.join(self.config.output_dir, filename)
filepath = os.path.join(settings.OUTPUT_DIR, filename)
# 保存HTML报告
with open(filepath, 'w', encoding='utf-8') as f:
f.write(html_content)
self.logger.info(f"报告已保存到: {filepath}")
logger.info(f"报告已保存到: {filepath}")
# 保存状态
state_filename = f"report_state_{query_safe}_{timestamp}.json"
state_filepath = os.path.join(self.config.output_dir, state_filename)
state_filepath = os.path.join(settings.OUTPUT_DIR, state_filename)
self.state.save_to_file(state_filepath)
self.logger.info(f"状态已保存到: {state_filepath}")
logger.info(f"状态已保存到: {state_filepath}")
def get_progress_summary(self) -> Dict[str, Any]:
"""获取进度摘要"""
@@ -397,12 +372,12 @@ class ReportAgent:
def load_state(self, filepath: str):
"""从文件加载状态"""
self.state = ReportState.load_from_file(filepath)
self.logger.info(f"状态已从 {filepath} 加载")
logger.info(f"状态已从 {filepath} 加载")
def save_state(self, filepath: str):
"""保存状态到文件"""
self.state.save_to_file(filepath)
self.logger.info(f"状态已保存到 {filepath}")
logger.info(f"状态已保存到 {filepath}")
def check_input_files(self, insight_dir: str, media_dir: str, query_dir: str, forum_log_path: str) -> Dict[str, Any]:
"""
@@ -488,9 +463,9 @@ class ReportAgent:
with open(file_paths[engine], 'r', encoding='utf-8') as f:
report_content = f.read()
content['reports'].append(report_content)
self.logger.info(f"已加载 {engine} 报告: {len(report_content)} 字符")
logger.info(f"已加载 {engine} 报告: {len(report_content)} 字符")
except Exception as e:
self.logger.error(f"加载 {engine} 报告失败: {str(e)}")
logger.exception(f"加载 {engine} 报告失败: {str(e)}")
content['reports'].append("")
# 加载论坛日志
@@ -498,9 +473,9 @@ class ReportAgent:
try:
with open(file_paths['forum'], 'r', encoding='utf-8') as f:
content['forum_logs'] = f.read()
self.logger.info(f"已加载论坛日志: {len(content['forum_logs'])} 字符")
logger.info(f"已加载论坛日志: {len(content['forum_logs'])} 字符")
except Exception as e:
self.logger.error(f"加载论坛日志失败: {str(e)}")
logger.exception(f"加载论坛日志失败: {str(e)}")
return content
@@ -515,5 +490,6 @@ def create_agent(config_file: Optional[str] = None) -> ReportAgent:
Returns:
ReportAgent实例
"""
config = load_config(config_file)
config = Settings() # 以空配置初始化,而从从环境变量初始化
return ReportAgent(config)
+59 -58
View File
@@ -10,9 +10,9 @@ import time
from datetime import datetime
from flask import Blueprint, request, jsonify, Response
from typing import Dict, Any
from loguru import logger
from .agent import ReportAgent, create_agent
from .utils.config import load_config
from .utils.config import settings
# 创建Blueprint
@@ -28,18 +28,17 @@ def initialize_report_engine():
"""初始化Report Engine"""
global report_agent
try:
config = load_config()
report_agent = create_agent()
print("Report Engine初始化成功")
logger.info("Report Engine初始化成功")
return True
except Exception as e:
print(f"Report Engine初始化失败: {str(e)}")
logger.exception(f"Report Engine初始化失败: {str(e)}")
return False
class ReportTask:
"""报告生成任务"""
def __init__(self, query: str, task_id: str, custom_template: str = ""):
self.task_id = task_id
self.query = query
@@ -51,7 +50,7 @@ class ReportTask:
self.created_at = datetime.now()
self.updated_at = datetime.now()
self.html_content = ""
def update_status(self, status: str, progress: int = None, error_message: str = ""):
"""更新任务状态"""
self.status = status
@@ -60,7 +59,7 @@ class ReportTask:
if error_message:
self.error_message = error_message
self.updated_at = datetime.now()
def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式"""
return {
@@ -79,21 +78,21 @@ def check_engines_ready() -> Dict[str, Any]:
"""检查三个子引擎是否都有新文件"""
directories = {
'insight': 'insight_engine_streamlit_reports',
'media': 'media_engine_streamlit_reports',
'media': 'media_engine_streamlit_reports',
'query': 'query_engine_streamlit_reports'
}
forum_log_path = 'logs/forum.log'
if not report_agent:
return {
'ready': False,
'error': 'Report Engine未初始化'
}
return report_agent.check_input_files(
directories['insight'],
directories['media'],
directories['media'],
directories['query'],
forum_log_path
)
@@ -102,23 +101,23 @@ def check_engines_ready() -> Dict[str, Any]:
def run_report_generation(task: ReportTask, query: str, custom_template: str = ""):
"""在后台线程中运行报告生成"""
global current_task
try:
task.update_status("running", 10)
# 检查输入文件
check_result = check_engines_ready()
if not check_result['ready']:
task.update_status("error", 0, f"输入文件未准备就绪: {check_result.get('missing_files', [])}")
return
task.update_status("running", 30)
# 加载输入文件
content = report_agent.load_input_files(check_result['latest_files'])
task.update_status("running", 50)
# 生成报告
html_report = report_agent.generate_report(
query=query,
@@ -127,13 +126,13 @@ def run_report_generation(task: ReportTask, query: str, custom_template: str = "
custom_template=custom_template,
save_report=True
)
task.update_status("running", 90)
# 保存结果
task.html_content = html_report
task.update_status("completed", 100)
except Exception as e:
task.update_status("error", 0, str(e))
# 只在出错时清理任务
@@ -147,7 +146,7 @@ def get_status():
"""获取Report Engine状态"""
try:
engines_status = check_engines_ready()
return jsonify({
'success': True,
'initialized': report_agent is not None,
@@ -167,7 +166,7 @@ def get_status():
def generate_report():
"""开始生成报告"""
global current_task
try:
# 检查是否有任务在运行
with task_lock:
@@ -177,26 +176,26 @@ def generate_report():
'error': '已有报告生成任务在运行中',
'current_task': current_task.to_dict()
}), 400
# 如果有已完成的任务,清理它
if current_task and current_task.status in ["completed", "error"]:
current_task = None
# 获取请求参数
data = request.get_json() or {}
query = data.get('query', '智能舆情分析报告')
custom_template = data.get('custom_template', '')
# 清空日志文件
clear_report_log()
# 检查Report Engine是否初始化
if not report_agent:
return jsonify({
'success': False,
'error': 'Report Engine未初始化'
}), 500
# 检查输入文件是否准备就绪
engines_status = check_engines_ready()
if not engines_status['ready']:
@@ -205,14 +204,14 @@ def generate_report():
'error': '输入文件未准备就绪',
'missing_files': engines_status.get('missing_files', [])
}), 400
# 创建新任务
task_id = f"report_{int(time.time())}"
task = ReportTask(query, task_id, custom_template)
with task_lock:
current_task = task
# 在后台线程中运行报告生成
thread = threading.Thread(
target=run_report_generation,
@@ -220,14 +219,14 @@ def generate_report():
daemon=True
)
thread.start()
return jsonify({
'success': True,
'task_id': task_id,
'message': '报告生成已启动',
'task': task.to_dict()
})
except Exception as e:
return jsonify({
'success': False,
@@ -252,13 +251,14 @@ def get_progress(task_id: str):
'has_result': True
}
})
return jsonify({
'success': True,
'task': current_task.to_dict()
})
except Exception as e:
logger.exception(f"获取报告生成进度失败: {str(e)}")
return jsonify({
'success': False,
'error': str(e)
@@ -274,20 +274,21 @@ def get_result(task_id: str):
'success': False,
'error': '任务不存在'
}), 404
if current_task.status != "completed":
return jsonify({
'success': False,
'error': '报告尚未完成',
'task': current_task.to_dict()
}), 400
return Response(
current_task.html_content,
mimetype='text/html'
)
except Exception as e:
logger.exception(f"获取报告生成结果失败: {str(e)}")
return jsonify({
'success': False,
'error': str(e)
@@ -303,20 +304,20 @@ def get_result_json(task_id: str):
'success': False,
'error': '任务不存在'
}), 404
if current_task.status != "completed":
return jsonify({
'success': False,
'error': '报告尚未完成',
'task': current_task.to_dict()
}), 400
return jsonify({
'success': True,
'task': current_task.to_dict(),
'html_content': current_task.html_content
})
except Exception as e:
return jsonify({
'success': False,
@@ -328,14 +329,14 @@ def get_result_json(task_id: str):
def cancel_task(task_id: str):
"""取消报告生成任务"""
global current_task
try:
with task_lock:
if current_task and current_task.task_id == task_id:
if current_task.status == "running":
current_task.update_status("cancelled", 0, "用户取消任务")
current_task = None
return jsonify({
'success': True,
'message': '任务已取消'
@@ -345,7 +346,7 @@ def cancel_task(task_id: str):
'success': False,
'error': '任务不存在或无法取消'
}), 404
except Exception as e:
return jsonify({
'success': False,
@@ -362,10 +363,10 @@ def get_templates():
'success': False,
'error': 'Report Engine未初始化'
}), 500
template_dir = report_agent.config.template_dir
template_dir = settings.TEMPLATE_DIR
templates = []
if os.path.exists(template_dir):
for filename in os.listdir(template_dir):
if filename.endswith('.md'):
@@ -373,7 +374,7 @@ def get_templates():
try:
with open(template_path, 'r', encoding='utf-8') as f:
content = f.read()
templates.append({
'name': filename.replace('.md', ''),
'filename': filename,
@@ -381,14 +382,14 @@ def get_templates():
'size': len(content)
})
except Exception as e:
print(f"读取模板失败 {filename}: {str(e)}")
logger.exception(f"读取模板失败 {filename}: {str(e)}")
return jsonify({
'success': True,
'templates': templates,
'template_dir': template_dir
})
except Exception as e:
return jsonify({
'success': False,
@@ -416,21 +417,19 @@ def internal_error(error):
def clear_report_log():
"""清空report.log文件"""
try:
config = load_config()
log_file = config.log_file
log_file = settings.LOG_FILE
with open(log_file, 'w', encoding='utf-8') as f:
f.write('')
print(f"已清空日志文件: {log_file}")
logger.info(f"已清空日志文件: {log_file}")
except Exception as e:
print(f"清空日志文件失败: {str(e)}")
logger.exception(f"清空日志文件失败: {str(e)}")
@report_bp.route('/log', methods=['GET'])
def get_report_log():
"""获取report.log内容"""
try:
config = load_config()
log_file = config.log_file
log_file = settings.LOG_FILE
if not os.path.exists(log_file):
return jsonify({
@@ -450,6 +449,7 @@ def get_report_log():
})
except Exception as e:
logger.exception(f"读取日志失败: {str(e)}")
return jsonify({
'success': False,
'error': f'读取日志失败: {str(e)}'
@@ -466,6 +466,7 @@ def clear_log():
'message': '日志已清空'
})
except Exception as e:
logger.exception(f"清空日志失败: {str(e)}")
return jsonify({
'success': False,
'error': f'清空日志失败: {str(e)}'
+3 -5
View File
@@ -3,12 +3,11 @@ Report Engine节点基类
定义所有处理节点的基础接口
"""
import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from ..llms.base import LLMClient
from ..state.state import ReportState
from loguru import logger
class BaseNode(ABC):
"""节点基类"""
@@ -23,7 +22,6 @@ class BaseNode(ABC):
"""
self.llm_client = llm_client
self.node_name = node_name or self.__class__.__name__
self.logger = logging.getLogger('ReportEngine')
@abstractmethod
def run(self, input_data: Any, **kwargs) -> Any:
@@ -66,12 +64,12 @@ class BaseNode(ABC):
def log_info(self, message: str):
"""记录信息日志"""
formatted_message = f"[{self.node_name}] {message}"
self.logger.info(formatted_message)
logger.info(formatted_message)
def log_error(self, message: str):
"""记录错误日志"""
formatted_message = f"[{self.node_name}] {message}"
self.logger.error(formatted_message)
logger.error(formatted_message)
class StateMutationNode(BaseNode):
+9 -8
View File
@@ -6,6 +6,7 @@ HTML生成节点
import json
from datetime import datetime
from typing import Dict, Any
from loguru import logger
from .base_node import StateMutationNode
from ..llms.base import LLMClient
@@ -42,7 +43,7 @@ class HTMLGenerationNode(StateMutationNode):
Returns:
生成的HTML内容
"""
self.log_info("开始生成HTML报告...")
logger.info("开始生成HTML报告...")
try:
# 准备LLM输入数据
@@ -64,11 +65,11 @@ class HTMLGenerationNode(StateMutationNode):
# 处理响应(简化版)
processed_response = self.process_output(response)
self.log_info("HTML报告生成完成")
logger.info("HTML报告生成完成")
return processed_response
except Exception as e:
self.log_error(f"HTML生成失败: {str(e)}")
logger.exception(f"HTML生成失败: {str(e)}")
# 返回备用HTML
return self._generate_fallback_html(input_data)
@@ -104,7 +105,7 @@ class HTMLGenerationNode(StateMutationNode):
HTML内容
"""
try:
self.log_info(f"处理LLM原始输出,长度: {len(output)} 字符")
logger.info(f"处理LLM原始输出,长度: {len(output)} 字符")
html_content = output.strip()
@@ -120,14 +121,14 @@ class HTMLGenerationNode(StateMutationNode):
# 如果内容为空,返回原始输出
if not html_content:
self.log_info("处理后内容为空,返回原始输出")
logger.info("处理后内容为空,返回原始输出")
html_content = output
self.log_info(f"HTML处理完成,最终长度: {len(html_content)} 字符")
logger.info(f"HTML处理完成,最终长度: {len(html_content)} 字符")
return html_content
except Exception as e:
self.log_error(f"处理HTML输出失败: {str(e)},返回原始输出")
logger.exception(f"处理HTML输出失败: {str(e)},返回原始输出")
return output
def _generate_fallback_html(self, input_data: Dict[str, Any]) -> str:
@@ -140,7 +141,7 @@ class HTMLGenerationNode(StateMutationNode):
Returns:
备用HTML内容
"""
self.log_info("使用备用HTML生成方法")
logger.info("使用备用HTML生成方法")
query = input_data.get('query', '智能舆情分析报告')
query_report = input_data.get('query_engine_report', '')
+15 -14
View File
@@ -6,6 +6,7 @@
import os
import json
from typing import Dict, Any, List, Optional
from loguru import logger
from .base_node import BaseNode
from ..prompts import SYSTEM_PROMPT_TEMPLATE_SELECTION
@@ -38,7 +39,7 @@ class TemplateSelectionNode(BaseNode):
Returns:
选择的模板信息
"""
self.log_info("开始模板选择...")
logger.info("开始模板选择...")
query = input_data.get('query', '')
reports = input_data.get('reports', [])
@@ -48,7 +49,7 @@ class TemplateSelectionNode(BaseNode):
available_templates = self._get_available_templates()
if not available_templates:
self.log_info("未找到预设模板,使用内置默认模板")
logger.info("未找到预设模板,使用内置默认模板")
return self._get_fallback_template()
# 使用LLM进行模板选择
@@ -57,7 +58,7 @@ class TemplateSelectionNode(BaseNode):
if llm_result:
return llm_result
except Exception as e:
self.log_error(f"LLM模板选择失败: {str(e)}")
logger.exception(f"LLM模板选择失败: {str(e)}")
# 如果LLM选择失败,使用备选方案
return self._get_fallback_template()
@@ -67,7 +68,7 @@ class TemplateSelectionNode(BaseNode):
def _llm_template_selection(self, query: str, reports: List[Any], forum_logs: str,
available_templates: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
"""使用LLM进行模板选择"""
self.log_info("尝试使用LLM进行模板选择...")
logger.info("尝试使用LLM进行模板选择...")
# 构建模板列表
template_list = "\n".join([f"- {t['name']}: {t['description']}" for t in available_templates])
@@ -118,10 +119,10 @@ class TemplateSelectionNode(BaseNode):
# 检查响应是否为空
if not response or not response.strip():
self.log_error("LLM返回空响应")
logger.error("LLM返回空响应")
return None
self.log_info(f"LLM原始响应: {response}")
logger.info(f"LLM原始响应: {response}")
# 尝试解析JSON响应
try:
@@ -133,18 +134,18 @@ class TemplateSelectionNode(BaseNode):
selected_template_name = result.get('template_name', '')
for template in available_templates:
if template['name'] == selected_template_name or selected_template_name in template['name']:
self.log_info(f"LLM选择模板: {selected_template_name}")
logger.info(f"LLM选择模板: {selected_template_name}")
return {
'template_name': template['name'],
'template_content': template['content'],
'selection_reason': result.get('selection_reason', 'LLM智能选择')
}
self.log_error(f"LLM选择的模板不存在: {selected_template_name}")
logger.error(f"LLM选择的模板不存在: {selected_template_name}")
return None
except json.JSONDecodeError as e:
self.log_error(f"JSON解析失败: {str(e)}")
logger.exception(f"JSON解析失败: {str(e)}")
# 尝试从文本响应中提取模板信息
return self._extract_template_from_text(response, available_templates)
@@ -163,7 +164,7 @@ class TemplateSelectionNode(BaseNode):
def _extract_template_from_text(self, response: str, available_templates: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
"""从文本响应中提取模板信息"""
self.log_info("尝试从文本响应中提取模板信息")
logger.info("尝试从文本响应中提取模板信息")
# 查找响应中是否包含模板名称
for template in available_templates:
@@ -175,7 +176,7 @@ class TemplateSelectionNode(BaseNode):
for variant in template_name_variants:
if variant in response:
self.log_info(f"在响应中找到模板: {template['name']}")
logger.info(f"在响应中找到模板: {template['name']}")
return {
'template_name': template['name'],
'template_content': template['content'],
@@ -189,7 +190,7 @@ class TemplateSelectionNode(BaseNode):
templates = []
if not os.path.exists(self.template_dir):
self.log_error(f"模板目录不存在: {self.template_dir}")
logger.error(f"模板目录不存在: {self.template_dir}")
return templates
# 查找所有markdown模板文件
@@ -210,7 +211,7 @@ class TemplateSelectionNode(BaseNode):
'description': description
})
except Exception as e:
self.log_error(f"读取模板文件失败 {filename}: {str(e)}")
logger.exception(f"读取模板文件失败 {filename}: {str(e)}")
return templates
@@ -235,7 +236,7 @@ class TemplateSelectionNode(BaseNode):
def _get_fallback_template(self) -> Dict[str, Any]:
"""获取备用默认模板(空模板,让LLM自行发挥)"""
self.log_info("未找到合适模板,使用空模板让LLM自行发挥")
logger.info("未找到合适模板,使用空模板让LLM自行发挥")
return {
'template_name': '自由发挥模板',
-3
View File
@@ -3,9 +3,6 @@ Report Engine工具模块
包含配置管理
"""
from .config import Config, load_config
__all__ = [
"Config",
"load_config"
]
+43 -142
View File
@@ -3,150 +3,51 @@ Configuration management module for the Report Engine.
"""
import os
from dataclasses import dataclass
from pydantic_settings import BaseSettings
from pydantic import Field
from typing import Optional
from loguru import logger
def _get_value(source, key: str, default=None, *fallback_keys: str):
candidates = (key,) + fallback_keys
value = None
for candidate in candidates:
if isinstance(source, dict):
value = source.get(candidate)
else:
value = getattr(source, candidate, None)
if value not in (None, ""):
break
if value in (None, ""):
for candidate in candidates:
env_val = os.getenv(candidate)
if env_val not in (None, ""):
value = env_val
break
return value if value not in (None, "") else default
class Settings(BaseSettings):
"""Report Engine 配置,环境变量与字段均为REPORT_ENGINE_前缀一致大写。"""
REPORT_ENGINE_API_KEY: Optional[str] = Field(None, description="Report Engine LLM API密钥")
REPORT_ENGINE_BASE_URL: Optional[str] = Field(None, description="Report Engine LLM基础URL")
REPORT_ENGINE_MODEL_NAME: Optional[str] = Field(None, description="Report Engine LLM模型名称")
REPORT_ENGINE_PROVIDER: Optional[str] = Field(None, description="模型服务商,仅兼容保留")
MAX_CONTENT_LENGTH: int = Field(200000, description="最大内容长度")
OUTPUT_DIR: str = Field("final_reports", description="主输出目录")
TEMPLATE_DIR: str = Field("ReportEngine/report_template", description="多模板目录")
API_TIMEOUT: float = Field(900.0, description="单API超时时间(秒)")
MAX_RETRY_DELAY: float = Field(180.0, description="最大重试间隔(秒)")
MAX_RETRIES: int = Field(8, description="最大重试次数")
LOG_FILE: str = Field("logs/report.log", description="日志输出文件")
ENABLE_PDF_EXPORT: bool = Field(True, description="是否允许导出PDF")
CHART_STYLE: str = Field("modern", description="图表样式:modern/classic/")
class Config:
env_file = ".env"
env_prefix = ""
case_sensitive = False
extra = "allow"
settings = Settings()
@dataclass
class Config:
"""Report Engine configuration."""
llm_api_key: Optional[str] = None
llm_base_url: Optional[str] = None
llm_model_name: Optional[str] = None
llm_provider: Optional[str] = None # compatibility
max_content_length: int = 200000
output_dir: str = "final_reports"
template_dir: str = "ReportEngine/report_template"
api_timeout: float = 900.0
max_retry_delay: float = 180.0
max_retries: int = 8
log_file: str = "logs/report.log"
enable_pdf_export: bool = True
chart_style: str = "modern"
def __post_init__(self):
if not self.llm_provider and self.llm_model_name:
self.llm_provider = self.llm_model_name
def validate(self) -> bool:
if not self.llm_api_key:
print("错误: Report Engine LLM API Key 未设置 (REPORT_ENGINE_API_KEY)。")
return False
if not self.llm_model_name:
print("错误: Report Engine 模型名称未设置 (REPORT_ENGINE_MODEL_NAME)。")
return False
return True
@classmethod
def from_file(cls, config_file: str) -> "Config":
if config_file.endswith(".py"):
import importlib.util
spec = importlib.util.spec_from_file_location("config", config_file)
config_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config_module)
return cls(
llm_api_key=_get_value(config_module, "REPORT_ENGINE_API_KEY"),
llm_base_url=_get_value(config_module, "REPORT_ENGINE_BASE_URL"),
llm_model_name=_get_value(config_module, "REPORT_ENGINE_MODEL_NAME"),
max_content_length=int(_get_value(config_module, "MAX_CONTENT_LENGTH", 200000)),
output_dir=_get_value(config_module, "REPORT_OUTPUT_DIR", "final_reports"),
template_dir=_get_value(config_module, "TEMPLATE_DIR", "ReportEngine/report_template"),
api_timeout=float(_get_value(config_module, "REPORT_API_TIMEOUT", 900.0)),
max_retry_delay=float(_get_value(config_module, "REPORT_MAX_RETRY_DELAY", 180.0)),
max_retries=int(_get_value(config_module, "REPORT_MAX_RETRIES", 8)),
log_file=_get_value(config_module, "REPORT_LOG_FILE", "logs/report.log"),
enable_pdf_export=str(
_get_value(config_module, "ENABLE_PDF_EXPORT", "true")
).lower()
in ("true", "1", "yes"),
chart_style=_get_value(config_module, "CHART_STYLE", "modern"),
)
config_dict = {}
if os.path.exists(config_file):
with open(config_file, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line and not line.startswith("#") and "=" in line:
key, value = line.split("=", 1)
config_dict[key.strip()] = value.strip()
return cls(
llm_api_key=_get_value(config_dict, "REPORT_ENGINE_API_KEY"),
llm_base_url=_get_value(config_dict, "REPORT_ENGINE_BASE_URL"),
llm_model_name=_get_value(config_dict, "REPORT_ENGINE_MODEL_NAME"),
max_content_length=int(_get_value(config_dict, "MAX_CONTENT_LENGTH", 200000)),
output_dir=_get_value(config_dict, "REPORT_OUTPUT_DIR", "final_reports"),
template_dir=_get_value(config_dict, "TEMPLATE_DIR", "ReportEngine/report_template"),
api_timeout=float(_get_value(config_dict, "REPORT_API_TIMEOUT", 900.0)),
max_retry_delay=float(_get_value(config_dict, "REPORT_MAX_RETRY_DELAY", 180.0)),
max_retries=int(_get_value(config_dict, "REPORT_MAX_RETRIES", 8)),
log_file=_get_value(config_dict, "REPORT_LOG_FILE", "logs/report.log"),
enable_pdf_export=str(
_get_value(config_dict, "ENABLE_PDF_EXPORT", "true")
).lower()
in ("true", "1", "yes"),
chart_style=_get_value(config_dict, "CHART_STYLE", "modern"),
)
def load_config(config_file: Optional[str] = None) -> Config:
if config_file:
if not os.path.exists(config_file):
raise FileNotFoundError(f"配置文件不存在: {config_file}")
file_to_load = config_file
else:
for candidate in ("config.py", "config.env", ".env"):
if os.path.exists(candidate):
file_to_load = candidate
print(f"已找到配置文件: {candidate}")
break
else:
raise FileNotFoundError("未找到配置文件,请创建 config.py。")
config = Config.from_file(file_to_load)
if not config.validate():
raise ValueError("Report Engine 配置校验失败,请检查 config.py 中的相关配置。")
return config
def print_config(config: Config):
print("\n=== Report Engine 配置 ===")
print(f"LLM 模型: {config.llm_model_name}")
print(f"LLM Base URL: {config.llm_base_url or '(默认)'}")
print(f"最大内容长度: {config.max_content_length}")
print(f"输出目录: {config.output_dir}")
print(f"模板目录: {config.template_dir}")
print(f"API 超时时间: {config.api_timeout}")
print(f"最大重试间隔: {config.max_retry_delay}")
print(f"最大重试次数: {config.max_retries}")
print(f"日志文件: {config.log_file}")
print(f"PDF 导出: {config.enable_pdf_export}")
print(f"图表样式: {config.chart_style}")
print(f"LLM API Key: {'已配置' if config.llm_api_key else '未配置'}")
print("========================\n")
def print_config(config: Settings):
message = ""
message += "\n=== Report Engine 配置 ===\n"
message += f"LLM 模型: {config.REPORT_ENGINE_MODEL_NAME}\n"
message += f"LLM Base URL: {config.REPORT_ENGINE_BASE_URL or '(默认)'}\n"
message += f"最大内容长度: {config.MAX_CONTENT_LENGTH}\n"
message += f"输出目录: {config.OUTPUT_DIR}\n"
message += f"模板目录: {config.TEMPLATE_DIR}\n"
message += f"API 超时时间: {config.API_TIMEOUT}\n"
message += f"最大重试间隔: {config.MAX_RETRY_DELAY}\n"
message += f"最大重试次数: {config.MAX_RETRIES}\n"
message += f"日志文件: {config.LOG_FILE}\n"
message += f"PDF 导出: {config.ENABLE_PDF_EXPORT}\n"
message += f"图表样式: {config.CHART_STYLE}\n"
message += f"LLM API Key: {'已配置' if config.REPORT_ENGINE_API_KEY else '未配置'}\n"
message += "=========================\n"
logger.info(message)