From fa787af1358b31ad1c3c318b5c95b863a0c72883 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=A9=AC=E4=B8=80=E4=B8=81?= <1769123563@qq.com> Date: Thu, 13 Nov 2025 22:30:36 +0800 Subject: [PATCH] Streaming --- ReportEngine/agent.py | 139 ++++++++++++++- ReportEngine/flask_interface.py | 299 ++++++++++++++++++++++++++++---- ReportEngine/nodes/__init__.py | 3 +- 3 files changed, 402 insertions(+), 39 deletions(-) diff --git a/ReportEngine/agent.py b/ReportEngine/agent.py index 6a0ece7..9c21d20 100644 --- a/ReportEngine/agent.py +++ b/ReportEngine/agent.py @@ -8,7 +8,7 @@ import os from pathlib import Path from uuid import uuid4 from datetime import datetime -from typing import Optional, Dict, Any, List +from typing import Optional, Dict, Any, List, Callable from loguru import logger @@ -23,6 +23,7 @@ from .llms import LLMClient from .nodes import ( TemplateSelectionNode, ChapterGenerationNode, + ChapterJsonParseError, DocumentLayoutNode, WordBudgetNode, ) @@ -205,10 +206,11 @@ class ReportAgent: ) def generate_report(self, query: str, reports: List[Any], forum_logs: str = "", - custom_template: str = "", save_report: bool = True) -> str: + custom_template: str = "", save_report: bool = True, + stream_handler: Optional[Callable[[str, Dict[str, Any]], None]] = None) -> str: """ 生成综合报告(章节JSON → IR → HTML) - + Returns: dict: HTML内容以及保存的文件路径信息 """ @@ -220,15 +222,32 @@ class ReportAgent: self.state.mark_processing() normalized_reports = self._normalize_reports(reports) + + def emit(event_type: str, payload: Dict[str, Any]): + """面向Report Engine流通道的事件分发器,保证错误不外泄。""" + if not stream_handler: + return + try: + stream_handler(event_type, payload) + except Exception as callback_error: # pragma: no cover - 仅记录 + logger.warning(f"流式事件回调失败: {callback_error}") + logger.info(f"开始生成报告 {report_id}: {query}") logger.info(f"输入数据 - 报告数量: {len(reports)}, 论坛日志长度: {len(str(forum_logs))}") + emit('stage', {'stage': 'agent_start', 'report_id': report_id, 'query': query}) try: template_result = self._select_template(query, reports, forum_logs, custom_template) self.state.metadata.template_used = template_result.get('template_name', '') + emit('stage', { + 'stage': 'template_selected', + 'template': template_result.get('template_name'), + 'reason': template_result.get('selection_reason') + }) sections = self._slice_template(template_result.get('template_content', '')) if not sections: raise ValueError("模板无法解析出章节,请检查模板内容。") + emit('stage', {'stage': 'template_sliced', 'section_count': len(sections)}) template_text = template_result.get('template_content', '') template_overview = self._build_template_overview(template_text, sections) @@ -241,6 +260,11 @@ class ReportAgent: query, template_overview, ) + emit('stage', { + 'stage': 'layout_designed', + 'title': layout_design.get('title'), + 'toc': layout_design.get('tocTitle') + }) # 使用刚生成的设计稿对全书进行篇幅规划,约束各章字数与重点 word_plan = self.word_budget_node.run( sections, @@ -250,6 +274,10 @@ class ReportAgent: query, template_overview, ) + emit('stage', { + 'stage': 'word_plan_ready', + 'chapter_targets': len(word_plan.get('chapters', [])) + }) # 记录每个章节的目标字数/强调点,后续传给章节LLM chapter_targets = { entry.get("chapterId"): entry @@ -296,23 +324,97 @@ class ReportAgent: # 初始化章节输出目录并写入manifest,方便流式存盘 run_dir = self.chapter_storage.start_session(report_id, manifest_meta) self._persist_planning_artifacts(run_dir, layout_design, word_plan, template_overview) + emit('stage', {'stage': 'storage_ready', 'run_dir': str(run_dir)}) chapters = [] + chapter_max_attempts = max(1, self.config.CHAPTER_JSON_MAX_ATTEMPTS) for section in sections: logger.info(f"生成章节: {section.title}") - chapter = self.chapter_generation_node.run( - section, - generation_context, - run_dir - ) - chapters.append(chapter) + emit('chapter_status', { + 'chapterId': section.chapter_id, + 'title': section.title, + 'status': 'running' + }) + # 章节流式回调:把LLM返回的delta透传给SSE,便于前端实时渲染 + def chunk_callback(delta: str, meta: Dict[str, Any], section_ref: TemplateSection = section): + emit('chapter_chunk', { + 'chapterId': meta.get('chapterId') or section_ref.chapter_id, + 'title': meta.get('title') or section_ref.title, + 'delta': delta + }) + + chapter_payload: Dict[str, Any] | None = None + attempt = 1 + while attempt <= chapter_max_attempts: + try: + chapter_payload = self.chapter_generation_node.run( + section, + generation_context, + run_dir, + stream_callback=chunk_callback + ) + break + except ChapterJsonParseError as parse_error: + logger.warning( + "章节 %s JSON解析失败(第 %s/%s 次尝试): %s", + section.title, + attempt, + chapter_max_attempts, + parse_error, + ) + emit('chapter_status', { + 'chapterId': section.chapter_id, + 'title': section.title, + 'status': 'retrying' if attempt < chapter_max_attempts else 'error', + 'attempt': attempt, + 'error': str(parse_error), + }) + if attempt >= chapter_max_attempts: + raise + attempt += 1 + continue + except Exception as chapter_error: + if not self._should_retry_inappropriate_content_error(chapter_error): + raise + logger.warning( + "章节 %s 触发内容安全限制(第 %s/%s 次尝试),准备重新生成: %s", + section.title, + attempt, + chapter_max_attempts, + chapter_error, + ) + emit('chapter_status', { + 'chapterId': section.chapter_id, + 'title': section.title, + 'status': 'retrying' if attempt < chapter_max_attempts else 'error', + 'attempt': attempt, + 'error': str(chapter_error), + 'reason': 'content_filter' + }) + if attempt >= chapter_max_attempts: + raise + attempt += 1 + continue + if chapter_payload is None: + raise ChapterJsonParseError( + f"{section.title} 章节JSON在 {chapter_max_attempts} 次尝试后仍无法解析" + ) + chapters.append(chapter_payload) + emit('chapter_status', { + 'chapterId': section.chapter_id, + 'title': section.title, + 'status': 'completed', + 'attempt': attempt, + }) document_ir = self.document_composer.build_document( report_id, manifest_meta, chapters ) + emit('stage', {'stage': 'chapters_compiled', 'chapter_count': len(chapters)}) html_report = self.renderer.render(document_ir) + emit('stage', {'stage': 'html_rendered', 'html_length': len(html_report)}) self.state.html_content = html_report self.state.mark_completed() @@ -320,10 +422,12 @@ class ReportAgent: saved_files = {} if save_report: saved_files = self._save_report(html_report, document_ir, report_id) + emit('stage', {'stage': 'report_saved', 'files': saved_files}) generation_time = (datetime.now() - start_time).total_seconds() self.state.metadata.generation_time = generation_time logger.info(f"报告生成完成,耗时: {generation_time:.2f} 秒") + emit('metrics', {'generation_seconds': generation_time}) return { 'html_content': html_report, 'report_id': report_id, @@ -333,6 +437,7 @@ class ReportAgent: except Exception as e: self.state.mark_failed(str(e)) logger.exception(f"报告生成过程中发生错误: {str(e)}") + emit('error', {'stage': 'agent_failed', 'message': str(e)}) raise def _select_template(self, query: str, reports: List[Any], forum_logs: str, custom_template: str): @@ -444,6 +549,22 @@ class ReportAgent: normalized[key] = self._stringify(value) return normalized + def _should_retry_inappropriate_content_error(self, error: Exception) -> bool: + """ + 判断LLM异常是否由内容安全/不当内容导致,满足时允许重新生成整章。 + """ + message = str(error) if error else "" + if not message: + return False + normalized = message.lower() + keywords = [ + "inappropriate content", + "content violation", + "content moderation", + "model-studio/error-code", + ] + return any(keyword in normalized for keyword in keywords) + def _stringify(self, value: Any) -> str: """安全地将对象转成字符串""" if value is None: diff --git a/ReportEngine/flask_interface.py b/ReportEngine/flask_interface.py index a95816f..75679cc 100644 --- a/ReportEngine/flask_interface.py +++ b/ReportEngine/flask_interface.py @@ -7,9 +7,11 @@ import os import json import threading import time +from collections import deque, defaultdict from datetime import datetime -from flask import Blueprint, request, jsonify, Response, send_file -from typing import Dict, Any +from queue import Queue, Empty +from flask import Blueprint, request, jsonify, Response, send_file, stream_with_context +from typing import Dict, Any, List, Optional from loguru import logger from .agent import ReportAgent, create_agent from .utils.config import settings @@ -23,6 +25,69 @@ report_agent = None current_task = None task_lock = threading.Lock() +# ====== 流式推送与任务历史管理 ====== +# 通过有界deque缓存最近的事件,方便SSE断线后快速补发 +MAX_TASK_HISTORY = 5 +STREAM_HEARTBEAT_INTERVAL = 15 # 心跳间隔秒 +stream_lock = threading.Lock() +stream_subscribers = defaultdict(list) +tasks_registry: Dict[str, 'ReportTask'] = {} + + +def _register_stream(task_id: str) -> Queue: + """为指定任务注册一个事件队列,供SSE监听器消费。""" + queue = Queue() + with stream_lock: + stream_subscribers[task_id].append(queue) + return queue + + +def _unregister_stream(task_id: str, queue: Queue): + """安全移除事件队列,避免内存泄漏。""" + with stream_lock: + listeners = stream_subscribers.get(task_id, []) + if queue in listeners: + listeners.remove(queue) + if not listeners and task_id in stream_subscribers: + stream_subscribers.pop(task_id, None) + + +def _broadcast_event(task_id: str, event: Dict[str, Any]): + """将事件推送给所有监听者,失败时做好异常捕获。""" + with stream_lock: + listeners = list(stream_subscribers.get(task_id, [])) + for queue in listeners: + try: + queue.put(event, timeout=0.1) + except Exception: + logger.exception("推送流式事件失败,跳过当前监听队列") + + +def _prune_task_history_locked(): + """在task_lock持有期间调用,清理过多的历史任务以控制内存。""" + if len(tasks_registry) <= MAX_TASK_HISTORY: + return + # 按创建时间排序,移除最旧的任务 + sorted_tasks = sorted(tasks_registry.values(), key=lambda t: t.created_at) + for task in sorted_tasks[:-MAX_TASK_HISTORY]: + tasks_registry.pop(task.task_id, None) + + +def _get_task(task_id: str) -> Optional['ReportTask']: + """统一的任务查找方法,优先返回当前任务。""" + with task_lock: + if current_task and current_task.task_id == task_id: + return current_task + return tasks_registry.get(task_id) + + +def _format_sse(event: Dict[str, Any]) -> str: + """按SSE协议格式化消息。""" + payload = json.dumps(event, ensure_ascii=False) + event_id = event.get('id', 0) + event_type = event.get('type', 'message') + return f"id: {event_id}\nevent: {event_type}\ndata: {payload}\n\n" + def initialize_report_engine(): """初始化Report Engine""" @@ -63,6 +128,11 @@ class ReportTask: self.report_file_name = "" self.state_file_path = "" self.state_file_relative_path = "" + # ====== 流式事件缓存与并发保护 ====== + # 使用deque保存最近的事件,结合锁保证多线程下的安全访问 + self.event_history: deque = deque(maxlen=1000) + self._event_lock = threading.Lock() + self.last_event_id = 0 def update_status(self, status: str, progress: int = None, error_message: str = ""): """更新任务状态""" @@ -72,6 +142,17 @@ class ReportTask: if error_message: self.error_message = error_message self.updated_at = datetime.now() + # 推送状态变更事件,方便前端实时刷新 + self.publish_event( + 'status', + { + 'status': self.status, + 'progress': self.progress, + 'error_message': self.error_message, + 'hint': error_message or '', + 'task': self.to_dict(), + } + ) def to_dict(self) -> Dict[str, Any]: """转换为字典格式""" @@ -91,6 +172,29 @@ class ReportTask: 'state_file_path': self.state_file_relative_path or self.state_file_path } + def publish_event(self, event_type: str, payload: Dict[str, Any]) -> None: + """将任意事件放入缓存并广播,所有新增逻辑均配套中文说明。""" + timestamp = datetime.utcnow().isoformat() + 'Z' + event: Dict[str, Any] = { + 'id': 0, + 'type': event_type, + 'task_id': self.task_id, + 'timestamp': timestamp, + 'payload': payload, + } + with self._event_lock: + self.last_event_id += 1 + event['id'] = self.last_event_id + self.event_history.append(event) + _broadcast_event(self.task_id, event) + + def history_since(self, last_event_id: Optional[int]) -> List[Dict[str, Any]]: + """根据Last-Event-ID补发历史事件,确保断线重连无遗漏。""" + with self._event_lock: + if last_event_id is None: + return list(self.event_history) + return [evt for evt in self.event_history if evt['id'] > last_event_id] + def check_engines_ready() -> Dict[str, Any]: """检查三个子引擎是否都有新文件""" @@ -121,7 +225,13 @@ def run_report_generation(task: ReportTask, query: str, custom_template: str = " global current_task try: + # 在局部闭包内封装推送逻辑,便于传递给ReportAgent + def stream_handler(event_type: str, payload: Dict[str, Any]): + """所有阶段事件都通过同一个接口分发,保证日志一致。""" + task.publish_event(event_type, payload) + task.update_status("running", 10) + task.publish_event('stage', {'message': '任务已启动,正在检查输入文件', 'stage': 'prepare'}) # 检查输入文件 check_result = check_engines_ready() @@ -129,21 +239,54 @@ def run_report_generation(task: ReportTask, query: str, custom_template: str = " task.update_status("error", 0, f"输入文件未准备就绪: {check_result.get('missing_files', [])}") return + task.publish_event('stage', { + 'message': '输入文件检查通过,准备载入内容', + 'stage': 'io_ready', + 'files': check_result.get('latest_files', {}) + }) + task.update_status("running", 30) # 加载输入文件 content = report_agent.load_input_files(check_result['latest_files']) + task.publish_event('stage', {'message': '源数据加载完成,启动生成流程', 'stage': 'data_loaded'}) task.update_status("running", 50) - # 生成报告 - generation_result = report_agent.generate_report( - query=query, - reports=content['reports'], - forum_logs=content['forum_logs'], - custom_template=custom_template, - save_report=True - ) + # 生成报告(附带兜底重试,缓解瞬时网络抖动) + for attempt in range(1, 3): + try: + task.publish_event('stage', { + 'message': f'正在调用ReportAgent生成报告(第{attempt}次尝试)', + 'stage': 'agent_running', + 'attempt': attempt + }) + generation_result = report_agent.generate_report( + query=query, + reports=content['reports'], + forum_logs=content['forum_logs'], + custom_template=custom_template, + save_report=True, + stream_handler=stream_handler + ) + break + except Exception as err: + # 将错误即时推送至前端,方便观察重试策略 + task.publish_event('warning', { + 'message': f'ReportAgent执行失败: {str(err)}', + 'stage': 'agent_running', + 'attempt': attempt + }) + if attempt == 2: + raise + # 简单的指数退避,防止频繁触发限流(单位秒) + backoff = min(5 * attempt, 15) + task.publish_event('stage', { + 'message': f'{backoff} 秒后重试生成任务', + 'stage': 'retry_wait', + 'wait_seconds': backoff + }) + time.sleep(backoff) if isinstance(generation_result, dict): html_report = generation_result.get('html_content', '') @@ -151,6 +294,7 @@ def run_report_generation(task: ReportTask, query: str, custom_template: str = " html_report = generation_result task.update_status("running", 90) + task.publish_event('stage', {'message': '报告生成完毕,准备持久化', 'stage': 'persist'}) # 保存结果 task.html_content = html_report @@ -160,11 +304,28 @@ def run_report_generation(task: ReportTask, query: str, custom_template: str = " task.report_file_name = generation_result.get('report_filename', '') task.state_file_path = generation_result.get('state_filepath', '') task.state_file_relative_path = generation_result.get('state_relative_path', '') + task.publish_event('html_ready', { + 'message': 'HTML渲染完成,可刷新预览', + 'report_file': task.report_file_relative_path or task.report_file_path, + 'state_file': task.state_file_relative_path or task.state_file_path, + 'task': task.to_dict(), + }) task.update_status("completed", 100) + task.publish_event('completed', { + 'message': '任务完成', + 'duration_seconds': (task.updated_at - task.created_at).total_seconds(), + 'report_file': task.report_file_relative_path or task.report_file_path, + 'task': task.to_dict(), + }) except Exception as e: logger.exception(f"报告生成过程中发生错误: {str(e)}") task.update_status("error", 0, str(e)) + task.publish_event('error', { + 'message': str(e), + 'stage': 'failed', + 'task': task.to_dict(), + }) # 只在出错时清理任务 with task_lock: if current_task and current_task.task_id == task.task_id: @@ -242,6 +403,19 @@ def generate_report(): with task_lock: current_task = task + tasks_registry[task_id] = task + _prune_task_history_locked() + + # 通过主动推送pending事件告知前端任务已经排队 + task.publish_event( + 'status', + { + 'status': task.status, + 'progress': task.progress, + 'message': '任务已排队,等待资源空闲', + 'task': task.to_dict(), + } + ) # 在后台线程中运行报告生成 thread = threading.Thread( @@ -255,7 +429,8 @@ def generate_report(): 'success': True, 'task_id': task_id, 'message': '报告生成已启动', - 'task': task.to_dict() + 'task': task.to_dict(), + 'stream_url': f"/api/report/stream/{task_id}" }) except Exception as e: @@ -270,9 +445,9 @@ def generate_report(): def get_progress(task_id: str): """获取报告生成进度""" try: - if not current_task or current_task.task_id != task_id: - # 如果任务不存在,可能是已经完成并被清理了 - # 返回一个默认的完成状态而不是404 + task = _get_task(task_id) + if not task: + # 如果任务不存在,可能是历史记录已被清理,回传一个完成态兜底 return jsonify({ 'success': True, 'task': { @@ -291,7 +466,7 @@ def get_progress(task_id: str): return jsonify({ 'success': True, - 'task': current_task.to_dict() + 'task': task.to_dict() }) except Exception as e: @@ -302,25 +477,78 @@ def get_progress(task_id: str): }), 500 +@report_bp.route('/stream/', methods=['GET']) +def stream_task(task_id: str): + """基于SSE的实时推送接口,向前端持续广播阶段事件。""" + task = _get_task(task_id) + if not task: + return jsonify({'success': False, 'error': '任务不存在'}), 404 + + last_event_header = request.headers.get('Last-Event-ID') + try: + last_event_id = int(last_event_header) if last_event_header else None + except ValueError: + last_event_id = None + + def event_generator(): + queue = _register_stream(task_id) + try: + # 断线重连场景下,先补发历史事件,保证界面状态一致 + history = task.history_since(last_event_id) + for event in history: + yield _format_sse(event) + + finished = task.status in ("completed", "error", "cancelled") + while True: + if finished: + break + try: + event = queue.get(timeout=STREAM_HEARTBEAT_INTERVAL) + yield _format_sse(event) + if event.get('type') in ("completed", "error"): + finished = True + except Empty: + heartbeat = { + 'id': f"hb-{int(time.time() * 1000)}", + 'type': 'heartbeat', + 'task_id': task_id, + 'timestamp': datetime.utcnow().isoformat() + 'Z', + 'payload': {'status': task.status} + } + yield _format_sse(heartbeat) + finished = task.status in ("completed", "error", "cancelled") + finally: + _unregister_stream(task_id, queue) + + response = Response( + stream_with_context(event_generator()), + mimetype='text/event-stream' + ) + response.headers['Cache-Control'] = 'no-cache' + response.headers['X-Accel-Buffering'] = 'no' + return response + + @report_bp.route('/result/', methods=['GET']) def get_result(task_id: str): """获取报告生成结果""" try: - if not current_task or current_task.task_id != task_id: + task = _get_task(task_id) + if not task: return jsonify({ 'success': False, 'error': '任务不存在' }), 404 - if current_task.status != "completed": + if task.status != "completed": return jsonify({ 'success': False, 'error': '报告尚未完成', - 'task': current_task.to_dict() + 'task': task.to_dict() }), 400 return Response( - current_task.html_content, + task.html_content, mimetype='text/html' ) @@ -336,23 +564,24 @@ def get_result(task_id: str): def get_result_json(task_id: str): """获取报告生成结果(JSON格式)""" try: - if not current_task or current_task.task_id != task_id: + task = _get_task(task_id) + if not task: return jsonify({ 'success': False, 'error': '任务不存在' }), 404 - if current_task.status != "completed": + if task.status != "completed": return jsonify({ 'success': False, 'error': '报告尚未完成', - 'task': current_task.to_dict() + 'task': task.to_dict() }), 400 return jsonify({ 'success': True, - 'task': current_task.to_dict(), - 'html_content': current_task.html_content + 'task': task.to_dict(), + 'html_content': task.html_content }) except Exception as e: @@ -367,27 +596,28 @@ def get_result_json(task_id: str): def download_report(task_id: str): """下载已生成的报告HTML文件""" try: - if not current_task or current_task.task_id != task_id: + task = _get_task(task_id) + if not task: return jsonify({ 'success': False, 'error': '任务不存在' }), 404 - if current_task.status != "completed" or not current_task.report_file_path: + if task.status != "completed" or not task.report_file_path: return jsonify({ 'success': False, 'error': '报告尚未完成或尚未保存' }), 400 - if not os.path.exists(current_task.report_file_path): + if not os.path.exists(task.report_file_path): return jsonify({ 'success': False, 'error': '报告文件不存在或已被删除' }), 404 - download_name = current_task.report_file_name or os.path.basename(current_task.report_file_path) + download_name = task.report_file_name or os.path.basename(task.report_file_path) return send_file( - current_task.report_file_path, + task.report_file_path, mimetype='text/html', as_attachment=True, download_name=download_name @@ -411,7 +641,18 @@ def cancel_task(task_id: str): if current_task and current_task.task_id == task_id: if current_task.status == "running": current_task.update_status("cancelled", 0, "用户取消任务") + current_task.publish_event('cancelled', { + 'message': '任务被用户主动终止', + 'task': current_task.to_dict(), + }) current_task = None + task = tasks_registry.get(task_id) + if task and task.status == 'running': + task.update_status("cancelled", task.progress, "用户取消任务") + task.publish_event('cancelled', { + 'message': '任务被用户主动终止', + 'task': task.to_dict(), + }) return jsonify({ 'success': True, diff --git a/ReportEngine/nodes/__init__.py b/ReportEngine/nodes/__init__.py index 228e4d1..995c2c9 100644 --- a/ReportEngine/nodes/__init__.py +++ b/ReportEngine/nodes/__init__.py @@ -5,7 +5,7 @@ Report Engine节点处理模块 from .base_node import BaseNode, StateMutationNode from .template_selection_node import TemplateSelectionNode -from .chapter_generation_node import ChapterGenerationNode +from .chapter_generation_node import ChapterGenerationNode, ChapterJsonParseError from .document_layout_node import DocumentLayoutNode from .word_budget_node import WordBudgetNode @@ -14,6 +14,7 @@ __all__ = [ "StateMutationNode", "TemplateSelectionNode", "ChapterGenerationNode", + "ChapterJsonParseError", "DocumentLayoutNode", "WordBudgetNode", ]