""" Report Agent主类 整合所有模块,实现完整的报告生成流程 """ import json import os import logging from datetime import datetime from typing import Optional, Dict, Any, List from .llms import GeminiLLM, BaseLLM from .nodes import ( TemplateSelectionNode, HTMLGenerationNode ) from .state import ReportState from .utils import Config, load_config class FileCountBaseline: """文件数量基准管理器""" def __init__(self): self.baseline_file = 'logs/report_baseline.json' self.baseline_data = self._load_baseline() def _load_baseline(self) -> Dict[str, int]: """加载基准数据""" try: if os.path.exists(self.baseline_file): with open(self.baseline_file, 'r', encoding='utf-8') as f: return json.load(f) except Exception as e: print(f"加载基准数据失败: {e}") return {} def _save_baseline(self): """保存基准数据""" try: os.makedirs(os.path.dirname(self.baseline_file), exist_ok=True) with open(self.baseline_file, 'w', encoding='utf-8') as f: json.dump(self.baseline_data, f, ensure_ascii=False, indent=2) except Exception as e: print(f"保存基准数据失败: {e}") def initialize_baseline(self, directories: Dict[str, str]) -> Dict[str, int]: """初始化文件数量基准""" current_counts = {} for engine, directory in directories.items(): if os.path.exists(directory): md_files = [f for f in os.listdir(directory) if f.endswith('.md')] current_counts[engine] = len(md_files) else: current_counts[engine] = 0 # 保存基准数据 self.baseline_data = current_counts.copy() self._save_baseline() print(f"文件数量基准已初始化: {current_counts}") return current_counts def check_new_files(self, directories: Dict[str, str]) -> Dict[str, Any]: """检查是否有新文件""" current_counts = {} new_files_found = {} all_have_new = True for engine, directory in directories.items(): if os.path.exists(directory): md_files = [f for f in os.listdir(directory) if f.endswith('.md')] current_counts[engine] = len(md_files) baseline_count = self.baseline_data.get(engine, 0) if current_counts[engine] > baseline_count: new_files_found[engine] = current_counts[engine] - baseline_count else: new_files_found[engine] = 0 all_have_new = False else: current_counts[engine] = 0 new_files_found[engine] = 0 all_have_new = False return { 'ready': all_have_new, 'baseline_counts': self.baseline_data, 'current_counts': current_counts, 'new_files_found': new_files_found, 'missing_engines': [engine for engine, count in new_files_found.items() if count == 0] } def get_latest_files(self, directories: Dict[str, str]) -> Dict[str, str]: """获取每个目录的最新文件""" latest_files = {} for engine, directory in directories.items(): if os.path.exists(directory): md_files = [f for f in os.listdir(directory) if f.endswith('.md')] if md_files: latest_file = max(md_files, key=lambda x: os.path.getmtime(os.path.join(directory, x))) latest_files[engine] = os.path.join(directory, latest_file) return latest_files class ReportAgent: """Report Agent主类""" def __init__(self, config: Optional[Config] = None): """ 初始化Report Agent Args: config: 配置对象,如果不提供则自动加载 """ # 加载配置 self.config = config or load_config() # 初始化文件基准管理器 self.file_baseline = FileCountBaseline() # 初始化日志 self._setup_logging() # 初始化LLM客户端 self.llm_client = self._initialize_llm() # 初始化节点 self._initialize_nodes() # 初始化文件数量基准 self._initialize_file_baseline() # 状态 self.state = ReportState() # 确保输出目录存在 os.makedirs(self.config.output_dir, exist_ok=True) self.logger.info("Report Agent已初始化") self.logger.info(f"使用LLM: {self.llm_client.get_model_info()}") def _setup_logging(self): """设置日志""" # 确保日志目录存在 log_dir = os.path.dirname(self.config.log_file) os.makedirs(log_dir, exist_ok=True) # 创建专用的logger,避免与其他模块冲突 self.logger = logging.getLogger('ReportEngine') self.logger.setLevel(logging.INFO) # 清除已有的handlers if self.logger.handlers: self.logger.handlers.clear() # 创建文件handler file_handler = logging.FileHandler(self.config.log_file, encoding='utf-8') file_handler.setLevel(logging.INFO) # 创建控制台handler console_handler = logging.StreamHandler() console_handler.setLevel(logging.INFO) # 设置格式 formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') file_handler.setFormatter(formatter) console_handler.setFormatter(formatter) # 添加handlers self.logger.addHandler(file_handler) self.logger.addHandler(console_handler) # 防止日志向上传播 self.logger.propagate = False def _initialize_file_baseline(self): """初始化文件数量基准""" directories = { 'insight': 'insight_engine_streamlit_reports', 'media': 'media_engine_streamlit_reports', 'query': 'query_engine_streamlit_reports' } self.file_baseline.initialize_baseline(directories) def _initialize_llm(self) -> BaseLLM: """初始化LLM客户端""" if self.config.default_llm_provider == "gemini": return GeminiLLM( api_key=self.config.gemini_api_key, model_name=self.config.gemini_model, config=self.config # 传入配置对象以支持动态超时设置 ) else: raise ValueError(f"不支持的LLM提供商: {self.config.default_llm_provider}") def _initialize_nodes(self): """初始化处理节点""" self.template_selection_node = TemplateSelectionNode( self.llm_client, self.config.template_dir ) self.html_generation_node = HTMLGenerationNode(self.llm_client) def generate_report(self, query: str, reports: List[Any], forum_logs: str = "", custom_template: str = "", save_report: bool = True) -> str: """ 生成综合报告 Args: query: 原始查询 reports: 三个子agent的报告列表(按顺序:QueryEngine, MediaEngine, InsightEngine) forum_logs: 论坛日志内容 custom_template: 用户自定义模板(可选) save_report: 是否保存报告到文件 Returns: 最终HTML报告内容 """ start_time = datetime.now() self.logger.info(f"开始生成报告: {query}") self.logger.info(f"输入数据 - 报告数量: {len(reports)}, 论坛日志长度: {len(forum_logs)}") try: # Step 1: 模板选择 template_result = self._select_template(query, reports, forum_logs, custom_template) # Step 2: 直接生成HTML报告 html_report = self._generate_html_report(query, reports, forum_logs, template_result) # Step 3: 保存报告 if save_report: self._save_report(html_report) # 更新生成时间 end_time = datetime.now() generation_time = (end_time - start_time).total_seconds() self.state.metadata.generation_time = generation_time self.logger.info(f"报告生成完成,耗时: {generation_time:.2f} 秒") return html_report except Exception as e: self.logger.error(f"报告生成过程中发生错误: {str(e)}") raise e def _select_template(self, query: str, reports: List[Any], forum_logs: str, custom_template: str): """选择报告模板""" self.logger.info("选择报告模板...") # 如果用户提供了自定义模板,直接使用 if custom_template: self.logger.info("使用用户自定义模板") return { 'template_name': 'custom', 'template_content': custom_template, 'selection_reason': '用户指定的自定义模板' } template_input = { 'query': query, 'reports': reports, 'forum_logs': forum_logs } try: template_result = self.template_selection_node.run(template_input) # 更新状态 self.state.metadata.template_used = template_result['template_name'] self.logger.info(f"选择模板: {template_result['template_name']}") self.logger.info(f"选择理由: {template_result['selection_reason']}") return template_result except Exception as e: self.logger.error(f"模板选择失败,使用默认模板: {str(e)}") # 直接使用备用模板 fallback_template = { 'template_name': '社会公共热点事件分析报告模板', 'template_content': self._get_fallback_template_content(), 'selection_reason': '模板选择失败,使用默认社会热点事件分析模板' } self.state.metadata.template_used = fallback_template['template_name'] return fallback_template def _generate_html_report(self, query: str, reports: List[Any], forum_logs: str, template_result: Dict[str, Any]) -> str: """生成HTML报告""" self.logger.info("多轮生成HTML报告...") # 准备报告内容,确保有3个报告 query_report = reports[0] if len(reports) > 0 else "" media_report = reports[1] if len(reports) > 1 else "" insight_report = reports[2] if len(reports) > 2 else "" # 转换为字符串格式 query_report = str(query_report) if query_report else "" media_report = str(media_report) if media_report else "" insight_report = str(insight_report) if insight_report else "" html_input = { 'query': query, 'query_engine_report': query_report, 'media_engine_report': media_report, 'insight_engine_report': insight_report, 'forum_logs': forum_logs, 'selected_template': template_result.get('template_content', '') } # 使用HTML生成节点生成报告 html_content = self.html_generation_node.run(html_input) # 更新状态 self.state.html_content = html_content self.state.mark_completed() self.logger.info("HTML报告生成完成") return html_content def _get_fallback_template_content(self) -> str: """获取备用模板内容""" return """# 社会公共热点事件分析报告 ## 执行摘要 本报告针对当前社会热点事件进行综合分析,整合了多方信息源的观点和数据。 ## 事件概况 ### 基本信息 - 事件性质:{event_nature} - 发生时间:{event_time} - 涉及范围:{event_scope} ## 舆情态势分析 ### 整体趋势 {sentiment_analysis} ### 主要观点分布 {opinion_distribution} ## 媒体报道分析 ### 主流媒体态度 {media_analysis} ### 报道重点 {report_focus} ## 社会影响评估 ### 直接影响 {direct_impact} ### 潜在影响 {potential_impact} ## 应对建议 ### 即时措施 {immediate_actions} ### 长期策略 {long_term_strategy} ## 结论与展望 {conclusion} --- *报告类型:社会公共热点事件分析* *生成时间:{generation_time}* """ def _save_report(self, html_content: str): """保存报告到文件""" # 生成文件名 timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") query_safe = "".join(c for c in self.state.metadata.query if c.isalnum() or c in (' ', '-', '_')).rstrip() query_safe = query_safe.replace(' ', '_')[:30] filename = f"final_report_{query_safe}_{timestamp}.html" filepath = os.path.join(self.config.output_dir, filename) # 保存HTML报告 with open(filepath, 'w', encoding='utf-8') as f: f.write(html_content) self.logger.info(f"报告已保存到: {filepath}") # 保存状态 state_filename = f"report_state_{query_safe}_{timestamp}.json" state_filepath = os.path.join(self.config.output_dir, state_filename) self.state.save_to_file(state_filepath) self.logger.info(f"状态已保存到: {state_filepath}") def get_progress_summary(self) -> Dict[str, Any]: """获取进度摘要""" return self.state.to_dict() def load_state(self, filepath: str): """从文件加载状态""" self.state = ReportState.load_from_file(filepath) self.logger.info(f"状态已从 {filepath} 加载") def save_state(self, filepath: str): """保存状态到文件""" self.state.save_to_file(filepath) self.logger.info(f"状态已保存到 {filepath}") def check_input_files(self, insight_dir: str, media_dir: str, query_dir: str, forum_log_path: str) -> Dict[str, Any]: """ 检查输入文件是否准备就绪(基于文件数量增加) Args: insight_dir: InsightEngine报告目录 media_dir: MediaEngine报告目录 query_dir: QueryEngine报告目录 forum_log_path: 论坛日志文件路径 Returns: 检查结果字典 """ # 检查各个报告目录的文件数量变化 directories = { 'insight': insight_dir, 'media': media_dir, 'query': query_dir } # 使用文件基准管理器检查新文件 check_result = self.file_baseline.check_new_files(directories) # 检查论坛日志 forum_ready = os.path.exists(forum_log_path) # 构建返回结果 result = { 'ready': check_result['ready'] and forum_ready, 'baseline_counts': check_result['baseline_counts'], 'current_counts': check_result['current_counts'], 'new_files_found': check_result['new_files_found'], 'missing_files': [], 'files_found': [], 'latest_files': {} } # 构建详细信息 for engine, new_count in check_result['new_files_found'].items(): current_count = check_result['current_counts'][engine] baseline_count = check_result['baseline_counts'].get(engine, 0) if new_count > 0: result['files_found'].append(f"{engine}: {current_count}个文件 (新增{new_count}个)") else: result['missing_files'].append(f"{engine}: {current_count}个文件 (基准{baseline_count}个,无新增)") # 检查论坛日志 if forum_ready: result['files_found'].append(f"forum: {os.path.basename(forum_log_path)}") else: result['missing_files'].append("forum: 日志文件不存在") # 获取最新文件路径(用于实际报告生成) if result['ready']: result['latest_files'] = self.file_baseline.get_latest_files(directories) if forum_ready: result['latest_files']['forum'] = forum_log_path return result def load_input_files(self, file_paths: Dict[str, str]) -> Dict[str, Any]: """ 加载输入文件内容 Args: file_paths: 文件路径字典 Returns: 加载的内容字典 """ content = { 'reports': [], 'forum_logs': '' } # 加载报告文件 engines = ['query', 'media', 'insight'] for engine in engines: if engine in file_paths: try: with open(file_paths[engine], 'r', encoding='utf-8') as f: report_content = f.read() content['reports'].append(report_content) self.logger.info(f"已加载 {engine} 报告: {len(report_content)} 字符") except Exception as e: self.logger.error(f"加载 {engine} 报告失败: {str(e)}") content['reports'].append("") # 加载论坛日志 if 'forum' in file_paths: try: with open(file_paths['forum'], 'r', encoding='utf-8') as f: content['forum_logs'] = f.read() self.logger.info(f"已加载论坛日志: {len(content['forum_logs'])} 字符") except Exception as e: self.logger.error(f"加载论坛日志失败: {str(e)}") return content def create_agent(config_file: Optional[str] = None) -> ReportAgent: """ 创建Report Agent实例的便捷函数 Args: config_file: 配置文件路径 Returns: ReportAgent实例 """ config = load_config(config_file) return ReportAgent(config)