Streaming
This commit is contained in:
+130
-9
@@ -8,7 +8,7 @@ import os
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
from typing import Optional, Dict, Any, List, Callable
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -23,6 +23,7 @@ from .llms import LLMClient
|
||||
from .nodes import (
|
||||
TemplateSelectionNode,
|
||||
ChapterGenerationNode,
|
||||
ChapterJsonParseError,
|
||||
DocumentLayoutNode,
|
||||
WordBudgetNode,
|
||||
)
|
||||
@@ -205,10 +206,11 @@ class ReportAgent:
|
||||
)
|
||||
|
||||
def generate_report(self, query: str, reports: List[Any], forum_logs: str = "",
|
||||
custom_template: str = "", save_report: bool = True) -> str:
|
||||
custom_template: str = "", save_report: bool = True,
|
||||
stream_handler: Optional[Callable[[str, Dict[str, Any]], None]] = None) -> str:
|
||||
"""
|
||||
生成综合报告(章节JSON → IR → HTML)
|
||||
|
||||
|
||||
Returns:
|
||||
dict: HTML内容以及保存的文件路径信息
|
||||
"""
|
||||
@@ -220,15 +222,32 @@ class ReportAgent:
|
||||
self.state.mark_processing()
|
||||
|
||||
normalized_reports = self._normalize_reports(reports)
|
||||
|
||||
def emit(event_type: str, payload: Dict[str, Any]):
|
||||
"""面向Report Engine流通道的事件分发器,保证错误不外泄。"""
|
||||
if not stream_handler:
|
||||
return
|
||||
try:
|
||||
stream_handler(event_type, payload)
|
||||
except Exception as callback_error: # pragma: no cover - 仅记录
|
||||
logger.warning(f"流式事件回调失败: {callback_error}")
|
||||
|
||||
logger.info(f"开始生成报告 {report_id}: {query}")
|
||||
logger.info(f"输入数据 - 报告数量: {len(reports)}, 论坛日志长度: {len(str(forum_logs))}")
|
||||
emit('stage', {'stage': 'agent_start', 'report_id': report_id, 'query': query})
|
||||
|
||||
try:
|
||||
template_result = self._select_template(query, reports, forum_logs, custom_template)
|
||||
self.state.metadata.template_used = template_result.get('template_name', '')
|
||||
emit('stage', {
|
||||
'stage': 'template_selected',
|
||||
'template': template_result.get('template_name'),
|
||||
'reason': template_result.get('selection_reason')
|
||||
})
|
||||
sections = self._slice_template(template_result.get('template_content', ''))
|
||||
if not sections:
|
||||
raise ValueError("模板无法解析出章节,请检查模板内容。")
|
||||
emit('stage', {'stage': 'template_sliced', 'section_count': len(sections)})
|
||||
|
||||
template_text = template_result.get('template_content', '')
|
||||
template_overview = self._build_template_overview(template_text, sections)
|
||||
@@ -241,6 +260,11 @@ class ReportAgent:
|
||||
query,
|
||||
template_overview,
|
||||
)
|
||||
emit('stage', {
|
||||
'stage': 'layout_designed',
|
||||
'title': layout_design.get('title'),
|
||||
'toc': layout_design.get('tocTitle')
|
||||
})
|
||||
# 使用刚生成的设计稿对全书进行篇幅规划,约束各章字数与重点
|
||||
word_plan = self.word_budget_node.run(
|
||||
sections,
|
||||
@@ -250,6 +274,10 @@ class ReportAgent:
|
||||
query,
|
||||
template_overview,
|
||||
)
|
||||
emit('stage', {
|
||||
'stage': 'word_plan_ready',
|
||||
'chapter_targets': len(word_plan.get('chapters', []))
|
||||
})
|
||||
# 记录每个章节的目标字数/强调点,后续传给章节LLM
|
||||
chapter_targets = {
|
||||
entry.get("chapterId"): entry
|
||||
@@ -296,23 +324,97 @@ class ReportAgent:
|
||||
# 初始化章节输出目录并写入manifest,方便流式存盘
|
||||
run_dir = self.chapter_storage.start_session(report_id, manifest_meta)
|
||||
self._persist_planning_artifacts(run_dir, layout_design, word_plan, template_overview)
|
||||
emit('stage', {'stage': 'storage_ready', 'run_dir': str(run_dir)})
|
||||
|
||||
chapters = []
|
||||
chapter_max_attempts = max(1, self.config.CHAPTER_JSON_MAX_ATTEMPTS)
|
||||
for section in sections:
|
||||
logger.info(f"生成章节: {section.title}")
|
||||
chapter = self.chapter_generation_node.run(
|
||||
section,
|
||||
generation_context,
|
||||
run_dir
|
||||
)
|
||||
chapters.append(chapter)
|
||||
emit('chapter_status', {
|
||||
'chapterId': section.chapter_id,
|
||||
'title': section.title,
|
||||
'status': 'running'
|
||||
})
|
||||
# 章节流式回调:把LLM返回的delta透传给SSE,便于前端实时渲染
|
||||
def chunk_callback(delta: str, meta: Dict[str, Any], section_ref: TemplateSection = section):
|
||||
emit('chapter_chunk', {
|
||||
'chapterId': meta.get('chapterId') or section_ref.chapter_id,
|
||||
'title': meta.get('title') or section_ref.title,
|
||||
'delta': delta
|
||||
})
|
||||
|
||||
chapter_payload: Dict[str, Any] | None = None
|
||||
attempt = 1
|
||||
while attempt <= chapter_max_attempts:
|
||||
try:
|
||||
chapter_payload = self.chapter_generation_node.run(
|
||||
section,
|
||||
generation_context,
|
||||
run_dir,
|
||||
stream_callback=chunk_callback
|
||||
)
|
||||
break
|
||||
except ChapterJsonParseError as parse_error:
|
||||
logger.warning(
|
||||
"章节 %s JSON解析失败(第 %s/%s 次尝试): %s",
|
||||
section.title,
|
||||
attempt,
|
||||
chapter_max_attempts,
|
||||
parse_error,
|
||||
)
|
||||
emit('chapter_status', {
|
||||
'chapterId': section.chapter_id,
|
||||
'title': section.title,
|
||||
'status': 'retrying' if attempt < chapter_max_attempts else 'error',
|
||||
'attempt': attempt,
|
||||
'error': str(parse_error),
|
||||
})
|
||||
if attempt >= chapter_max_attempts:
|
||||
raise
|
||||
attempt += 1
|
||||
continue
|
||||
except Exception as chapter_error:
|
||||
if not self._should_retry_inappropriate_content_error(chapter_error):
|
||||
raise
|
||||
logger.warning(
|
||||
"章节 %s 触发内容安全限制(第 %s/%s 次尝试),准备重新生成: %s",
|
||||
section.title,
|
||||
attempt,
|
||||
chapter_max_attempts,
|
||||
chapter_error,
|
||||
)
|
||||
emit('chapter_status', {
|
||||
'chapterId': section.chapter_id,
|
||||
'title': section.title,
|
||||
'status': 'retrying' if attempt < chapter_max_attempts else 'error',
|
||||
'attempt': attempt,
|
||||
'error': str(chapter_error),
|
||||
'reason': 'content_filter'
|
||||
})
|
||||
if attempt >= chapter_max_attempts:
|
||||
raise
|
||||
attempt += 1
|
||||
continue
|
||||
if chapter_payload is None:
|
||||
raise ChapterJsonParseError(
|
||||
f"{section.title} 章节JSON在 {chapter_max_attempts} 次尝试后仍无法解析"
|
||||
)
|
||||
chapters.append(chapter_payload)
|
||||
emit('chapter_status', {
|
||||
'chapterId': section.chapter_id,
|
||||
'title': section.title,
|
||||
'status': 'completed',
|
||||
'attempt': attempt,
|
||||
})
|
||||
|
||||
document_ir = self.document_composer.build_document(
|
||||
report_id,
|
||||
manifest_meta,
|
||||
chapters
|
||||
)
|
||||
emit('stage', {'stage': 'chapters_compiled', 'chapter_count': len(chapters)})
|
||||
html_report = self.renderer.render(document_ir)
|
||||
emit('stage', {'stage': 'html_rendered', 'html_length': len(html_report)})
|
||||
|
||||
self.state.html_content = html_report
|
||||
self.state.mark_completed()
|
||||
@@ -320,10 +422,12 @@ class ReportAgent:
|
||||
saved_files = {}
|
||||
if save_report:
|
||||
saved_files = self._save_report(html_report, document_ir, report_id)
|
||||
emit('stage', {'stage': 'report_saved', 'files': saved_files})
|
||||
|
||||
generation_time = (datetime.now() - start_time).total_seconds()
|
||||
self.state.metadata.generation_time = generation_time
|
||||
logger.info(f"报告生成完成,耗时: {generation_time:.2f} 秒")
|
||||
emit('metrics', {'generation_seconds': generation_time})
|
||||
return {
|
||||
'html_content': html_report,
|
||||
'report_id': report_id,
|
||||
@@ -333,6 +437,7 @@ class ReportAgent:
|
||||
except Exception as e:
|
||||
self.state.mark_failed(str(e))
|
||||
logger.exception(f"报告生成过程中发生错误: {str(e)}")
|
||||
emit('error', {'stage': 'agent_failed', 'message': str(e)})
|
||||
raise
|
||||
|
||||
def _select_template(self, query: str, reports: List[Any], forum_logs: str, custom_template: str):
|
||||
@@ -444,6 +549,22 @@ class ReportAgent:
|
||||
normalized[key] = self._stringify(value)
|
||||
return normalized
|
||||
|
||||
def _should_retry_inappropriate_content_error(self, error: Exception) -> bool:
|
||||
"""
|
||||
判断LLM异常是否由内容安全/不当内容导致,满足时允许重新生成整章。
|
||||
"""
|
||||
message = str(error) if error else ""
|
||||
if not message:
|
||||
return False
|
||||
normalized = message.lower()
|
||||
keywords = [
|
||||
"inappropriate content",
|
||||
"content violation",
|
||||
"content moderation",
|
||||
"model-studio/error-code",
|
||||
]
|
||||
return any(keyword in normalized for keyword in keywords)
|
||||
|
||||
def _stringify(self, value: Any) -> str:
|
||||
"""安全地将对象转成字符串"""
|
||||
if value is None:
|
||||
|
||||
+270
-29
@@ -7,9 +7,11 @@ import os
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from collections import deque, defaultdict
|
||||
from datetime import datetime
|
||||
from flask import Blueprint, request, jsonify, Response, send_file
|
||||
from typing import Dict, Any
|
||||
from queue import Queue, Empty
|
||||
from flask import Blueprint, request, jsonify, Response, send_file, stream_with_context
|
||||
from typing import Dict, Any, List, Optional
|
||||
from loguru import logger
|
||||
from .agent import ReportAgent, create_agent
|
||||
from .utils.config import settings
|
||||
@@ -23,6 +25,69 @@ report_agent = None
|
||||
current_task = None
|
||||
task_lock = threading.Lock()
|
||||
|
||||
# ====== 流式推送与任务历史管理 ======
|
||||
# 通过有界deque缓存最近的事件,方便SSE断线后快速补发
|
||||
MAX_TASK_HISTORY = 5
|
||||
STREAM_HEARTBEAT_INTERVAL = 15 # 心跳间隔秒
|
||||
stream_lock = threading.Lock()
|
||||
stream_subscribers = defaultdict(list)
|
||||
tasks_registry: Dict[str, 'ReportTask'] = {}
|
||||
|
||||
|
||||
def _register_stream(task_id: str) -> Queue:
    """Create a fresh event queue for ``task_id`` and add it to the subscriber list.

    Args:
        task_id: Identifier of the task whose events the caller wants to consume.

    Returns:
        The newly created queue; the caller is expected to read events from it.
    """
    subscriber_queue = Queue()
    with stream_lock:
        # stream_subscribers is a defaultdict(list), so indexing creates the slot on demand.
        stream_subscribers[task_id].append(subscriber_queue)
    return subscriber_queue
|
||||
|
||||
|
||||
def _unregister_stream(task_id: str, queue: Queue):
    """Detach ``queue`` from the subscribers of ``task_id``.

    When the last subscriber of a task is removed, the task's entry is dropped
    from the subscriber table as well so that it does not accumulate.

    Args:
        task_id: Identifier of the task the queue was registered for.
        queue: The queue previously returned by the registration helper.
    """
    with stream_lock:
        subscribers = stream_subscribers.get(task_id, [])
        try:
            subscribers.remove(queue)
        except ValueError:
            # The queue was never registered (or already removed): nothing to do.
            pass
        if task_id in stream_subscribers and not subscribers:
            stream_subscribers.pop(task_id, None)
|
||||
|
||||
|
||||
def _broadcast_event(task_id: str, event: Dict[str, Any]):
    """Deliver ``event`` to every queue currently subscribed to ``task_id``.

    A snapshot of the subscriber list is taken under the lock; a failure while
    enqueuing to one subscriber is logged and does not stop delivery to the rest.

    Args:
        task_id: Identifier of the task the event belongs to.
        event: The event dictionary to enqueue for each subscriber.
    """
    with stream_lock:
        targets = tuple(stream_subscribers.get(task_id, ()))
    for target in targets:
        try:
            target.put(event, timeout=0.1)
        except Exception:
            logger.exception("推送流式事件失败,跳过当前监听队列")
|
||||
|
||||
|
||||
def _prune_task_history_locked():
    """Trim the task registry down to the newest ``MAX_TASK_HISTORY`` entries.

    Per the original docstring, callers are expected to hold ``task_lock``
    while calling this; the function itself takes no lock.
    """
    if not len(tasks_registry) > MAX_TASK_HISTORY:
        return
    # Oldest first, so everything before the last MAX_TASK_HISTORY items is stale.
    oldest_first = sorted(tasks_registry.values(), key=lambda item: item.created_at)
    stale_tasks = oldest_first[:-MAX_TASK_HISTORY]
    for stale in stale_tasks:
        tasks_registry.pop(stale.task_id, None)
|
||||
|
||||
|
||||
def _get_task(task_id: str) -> Optional['ReportTask']:
    """Look up a task by id, preferring the currently active task.

    Args:
        task_id: Identifier of the task to find.

    Returns:
        The active task when its id matches, otherwise the registry entry for
        ``task_id``, or None when the id is unknown.
    """
    with task_lock:
        active = current_task
        if active and active.task_id == task_id:
            return active
        return tasks_registry.get(task_id)
|
||||
|
||||
|
||||
def _format_sse(event: Dict[str, Any]) -> str:
|
||||
"""按SSE协议格式化消息。"""
|
||||
payload = json.dumps(event, ensure_ascii=False)
|
||||
event_id = event.get('id', 0)
|
||||
event_type = event.get('type', 'message')
|
||||
return f"id: {event_id}\nevent: {event_type}\ndata: {payload}\n\n"
|
||||
|
||||
|
||||
def initialize_report_engine():
|
||||
"""初始化Report Engine"""
|
||||
@@ -63,6 +128,11 @@ class ReportTask:
|
||||
self.report_file_name = ""
|
||||
self.state_file_path = ""
|
||||
self.state_file_relative_path = ""
|
||||
# ====== 流式事件缓存与并发保护 ======
|
||||
# 使用deque保存最近的事件,结合锁保证多线程下的安全访问
|
||||
self.event_history: deque = deque(maxlen=1000)
|
||||
self._event_lock = threading.Lock()
|
||||
self.last_event_id = 0
|
||||
|
||||
def update_status(self, status: str, progress: int = None, error_message: str = ""):
|
||||
"""更新任务状态"""
|
||||
@@ -72,6 +142,17 @@ class ReportTask:
|
||||
if error_message:
|
||||
self.error_message = error_message
|
||||
self.updated_at = datetime.now()
|
||||
# 推送状态变更事件,方便前端实时刷新
|
||||
self.publish_event(
|
||||
'status',
|
||||
{
|
||||
'status': self.status,
|
||||
'progress': self.progress,
|
||||
'error_message': self.error_message,
|
||||
'hint': error_message or '',
|
||||
'task': self.to_dict(),
|
||||
}
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典格式"""
|
||||
@@ -91,6 +172,29 @@ class ReportTask:
|
||||
'state_file_path': self.state_file_relative_path or self.state_file_path
|
||||
}
|
||||
|
||||
def publish_event(self, event_type: str, payload: Dict[str, Any]) -> None:
    """Record an event in this task's history and broadcast it to subscribers.

    Each event receives the next sequential id (assigned under the event
    lock), a UTC timestamp suffixed with ``'Z'``, and this task's id.

    Args:
        event_type: Value stored under the event's ``'type'`` key.
        payload: Arbitrary data stored under the event's ``'payload'`` key.
    """
    event: Dict[str, Any] = {
        'id': 0,  # placeholder, replaced below once the sequence number is known
        'type': event_type,
        'task_id': self.task_id,
        'timestamp': datetime.utcnow().isoformat() + 'Z',
        'payload': payload,
    }
    with self._event_lock:
        next_id = self.last_event_id + 1
        self.last_event_id = next_id
        event['id'] = next_id
        self.event_history.append(event)
    # NOTE(review): the source lost its indentation; the broadcast is assumed to
    # run after the lock is released — confirm against the repository.
    _broadcast_event(self.task_id, event)
|
||||
|
||||
def history_since(self, last_event_id: Optional[int]) -> List[Dict[str, Any]]:
    """Return the buffered events newer than ``last_event_id``.

    Args:
        last_event_id: Id of the last event the client has seen, or None to
            request everything still held in the history buffer.

    Returns:
        A new list of event dictionaries in buffer order: the whole buffer
        when ``last_event_id`` is None, otherwise only events whose ``'id'``
        is strictly greater than ``last_event_id``.
    """
    with self._event_lock:
        snapshot = list(self.event_history)
    if last_event_id is None:
        return snapshot
    return [item for item in snapshot if item['id'] > last_event_id]
|
||||
|
||||
|
||||
def check_engines_ready() -> Dict[str, Any]:
|
||||
"""检查三个子引擎是否都有新文件"""
|
||||
@@ -121,7 +225,13 @@ def run_report_generation(task: ReportTask, query: str, custom_template: str = "
|
||||
global current_task
|
||||
|
||||
try:
|
||||
# 在局部闭包内封装推送逻辑,便于传递给ReportAgent
|
||||
def stream_handler(event_type: str, payload: Dict[str, Any]):
|
||||
"""所有阶段事件都通过同一个接口分发,保证日志一致。"""
|
||||
task.publish_event(event_type, payload)
|
||||
|
||||
task.update_status("running", 10)
|
||||
task.publish_event('stage', {'message': '任务已启动,正在检查输入文件', 'stage': 'prepare'})
|
||||
|
||||
# 检查输入文件
|
||||
check_result = check_engines_ready()
|
||||
@@ -129,21 +239,54 @@ def run_report_generation(task: ReportTask, query: str, custom_template: str = "
|
||||
task.update_status("error", 0, f"输入文件未准备就绪: {check_result.get('missing_files', [])}")
|
||||
return
|
||||
|
||||
task.publish_event('stage', {
|
||||
'message': '输入文件检查通过,准备载入内容',
|
||||
'stage': 'io_ready',
|
||||
'files': check_result.get('latest_files', {})
|
||||
})
|
||||
|
||||
task.update_status("running", 30)
|
||||
|
||||
# 加载输入文件
|
||||
content = report_agent.load_input_files(check_result['latest_files'])
|
||||
task.publish_event('stage', {'message': '源数据加载完成,启动生成流程', 'stage': 'data_loaded'})
|
||||
|
||||
task.update_status("running", 50)
|
||||
|
||||
# 生成报告
|
||||
generation_result = report_agent.generate_report(
|
||||
query=query,
|
||||
reports=content['reports'],
|
||||
forum_logs=content['forum_logs'],
|
||||
custom_template=custom_template,
|
||||
save_report=True
|
||||
)
|
||||
# 生成报告(附带兜底重试,缓解瞬时网络抖动)
|
||||
for attempt in range(1, 3):
|
||||
try:
|
||||
task.publish_event('stage', {
|
||||
'message': f'正在调用ReportAgent生成报告(第{attempt}次尝试)',
|
||||
'stage': 'agent_running',
|
||||
'attempt': attempt
|
||||
})
|
||||
generation_result = report_agent.generate_report(
|
||||
query=query,
|
||||
reports=content['reports'],
|
||||
forum_logs=content['forum_logs'],
|
||||
custom_template=custom_template,
|
||||
save_report=True,
|
||||
stream_handler=stream_handler
|
||||
)
|
||||
break
|
||||
except Exception as err:
|
||||
# 将错误即时推送至前端,方便观察重试策略
|
||||
task.publish_event('warning', {
|
||||
'message': f'ReportAgent执行失败: {str(err)}',
|
||||
'stage': 'agent_running',
|
||||
'attempt': attempt
|
||||
})
|
||||
if attempt == 2:
|
||||
raise
|
||||
# 简单的指数退避,防止频繁触发限流(单位秒)
|
||||
backoff = min(5 * attempt, 15)
|
||||
task.publish_event('stage', {
|
||||
'message': f'{backoff} 秒后重试生成任务',
|
||||
'stage': 'retry_wait',
|
||||
'wait_seconds': backoff
|
||||
})
|
||||
time.sleep(backoff)
|
||||
|
||||
if isinstance(generation_result, dict):
|
||||
html_report = generation_result.get('html_content', '')
|
||||
@@ -151,6 +294,7 @@ def run_report_generation(task: ReportTask, query: str, custom_template: str = "
|
||||
html_report = generation_result
|
||||
|
||||
task.update_status("running", 90)
|
||||
task.publish_event('stage', {'message': '报告生成完毕,准备持久化', 'stage': 'persist'})
|
||||
|
||||
# 保存结果
|
||||
task.html_content = html_report
|
||||
@@ -160,11 +304,28 @@ def run_report_generation(task: ReportTask, query: str, custom_template: str = "
|
||||
task.report_file_name = generation_result.get('report_filename', '')
|
||||
task.state_file_path = generation_result.get('state_filepath', '')
|
||||
task.state_file_relative_path = generation_result.get('state_relative_path', '')
|
||||
task.publish_event('html_ready', {
|
||||
'message': 'HTML渲染完成,可刷新预览',
|
||||
'report_file': task.report_file_relative_path or task.report_file_path,
|
||||
'state_file': task.state_file_relative_path or task.state_file_path,
|
||||
'task': task.to_dict(),
|
||||
})
|
||||
task.update_status("completed", 100)
|
||||
task.publish_event('completed', {
|
||||
'message': '任务完成',
|
||||
'duration_seconds': (task.updated_at - task.created_at).total_seconds(),
|
||||
'report_file': task.report_file_relative_path or task.report_file_path,
|
||||
'task': task.to_dict(),
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"报告生成过程中发生错误: {str(e)}")
|
||||
task.update_status("error", 0, str(e))
|
||||
task.publish_event('error', {
|
||||
'message': str(e),
|
||||
'stage': 'failed',
|
||||
'task': task.to_dict(),
|
||||
})
|
||||
# 只在出错时清理任务
|
||||
with task_lock:
|
||||
if current_task and current_task.task_id == task.task_id:
|
||||
@@ -242,6 +403,19 @@ def generate_report():
|
||||
|
||||
with task_lock:
|
||||
current_task = task
|
||||
tasks_registry[task_id] = task
|
||||
_prune_task_history_locked()
|
||||
|
||||
# 通过主动推送pending事件告知前端任务已经排队
|
||||
task.publish_event(
|
||||
'status',
|
||||
{
|
||||
'status': task.status,
|
||||
'progress': task.progress,
|
||||
'message': '任务已排队,等待资源空闲',
|
||||
'task': task.to_dict(),
|
||||
}
|
||||
)
|
||||
|
||||
# 在后台线程中运行报告生成
|
||||
thread = threading.Thread(
|
||||
@@ -255,7 +429,8 @@ def generate_report():
|
||||
'success': True,
|
||||
'task_id': task_id,
|
||||
'message': '报告生成已启动',
|
||||
'task': task.to_dict()
|
||||
'task': task.to_dict(),
|
||||
'stream_url': f"/api/report/stream/{task_id}"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
@@ -270,9 +445,9 @@ def generate_report():
|
||||
def get_progress(task_id: str):
|
||||
"""获取报告生成进度"""
|
||||
try:
|
||||
if not current_task or current_task.task_id != task_id:
|
||||
# 如果任务不存在,可能是已经完成并被清理了
|
||||
# 返回一个默认的完成状态而不是404
|
||||
task = _get_task(task_id)
|
||||
if not task:
|
||||
# 如果任务不存在,可能是历史记录已被清理,回传一个完成态兜底
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'task': {
|
||||
@@ -291,7 +466,7 @@ def get_progress(task_id: str):
|
||||
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'task': current_task.to_dict()
|
||||
'task': task.to_dict()
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
@@ -302,25 +477,78 @@ def get_progress(task_id: str):
|
||||
}), 500
|
||||
|
||||
|
||||
@report_bp.route('/stream/<task_id>', methods=['GET'])
def stream_task(task_id: str):
    """SSE endpoint that streams a task's events to the client.

    On connect, events newer than the ``Last-Event-ID`` request header are
    replayed from the task's history (all buffered events when the header is
    absent or not an integer). Live events are then forwarded from a
    per-connection queue, with a heartbeat event sent whenever no event
    arrives within ``STREAM_HEARTBEAT_INTERVAL`` seconds.

    Fixes over the previous implementation:
      * The queue is registered before the history snapshot is taken, so an
        event published in between is present in both; such events are now
        skipped when they come off the queue instead of being sent twice.
      * A ``cancelled`` event now ends the stream right away instead of
        waiting for the next heartbeat timeout.
      * An ``error`` event only ends the stream once the task status is
        terminal. The agent publishes its own ``error`` event before the
        runner retries, which used to close the stream mid-task.

    Returns:
        A ``text/event-stream`` response, or a 404 JSON response when the
        task id is unknown.
    """
    task = _get_task(task_id)
    if not task:
        return jsonify({'success': False, 'error': '任务不存在'}), 404

    last_event_header = request.headers.get('Last-Event-ID')
    try:
        last_event_id = int(last_event_header) if last_event_header else None
    except ValueError:
        # A non-numeric header cannot be matched to an event id: replay everything.
        last_event_id = None

    terminal_statuses = ("completed", "error", "cancelled")
    terminal_event_types = ("completed", "error", "cancelled")

    def event_generator():
        queue = _register_stream(task_id)
        # Highest event id already delivered through the history replay.
        replayed_up_to = 0
        try:
            # Replay missed events first so a reconnecting client catches up.
            for event in task.history_since(last_event_id):
                yield _format_sse(event)
                replayed_up_to = max(replayed_up_to, event['id'])

            finished = task.status in terminal_statuses
            while not finished:
                try:
                    event = queue.get(timeout=STREAM_HEARTBEAT_INTERVAL)
                except Empty:
                    heartbeat = {
                        'id': f"hb-{int(time.time() * 1000)}",
                        'type': 'heartbeat',
                        'task_id': task_id,
                        'timestamp': datetime.utcnow().isoformat() + 'Z',
                        'payload': {'status': task.status}
                    }
                    yield _format_sse(heartbeat)
                    finished = task.status in terminal_statuses
                    continue

                event_id = event.get('id')
                if isinstance(event_id, int) and event_id <= replayed_up_to:
                    # Already sent during the replay above.
                    continue

                yield _format_sse(event)
                # Require a terminal task status as well, so that an intermediate
                # 'error' event emitted before a retry does not close the stream.
                if (event.get('type') in terminal_event_types
                        and task.status in terminal_statuses):
                    finished = True
        finally:
            # Always detach the queue, including when the client disconnects.
            _unregister_stream(task_id, queue)

    response = Response(
        stream_with_context(event_generator()),
        mimetype='text/event-stream'
    )
    response.headers['Cache-Control'] = 'no-cache'
    response.headers['X-Accel-Buffering'] = 'no'
    return response
|
||||
|
||||
|
||||
@report_bp.route('/result/<task_id>', methods=['GET'])
|
||||
def get_result(task_id: str):
|
||||
"""获取报告生成结果"""
|
||||
try:
|
||||
if not current_task or current_task.task_id != task_id:
|
||||
task = _get_task(task_id)
|
||||
if not task:
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': '任务不存在'
|
||||
}), 404
|
||||
|
||||
if current_task.status != "completed":
|
||||
if task.status != "completed":
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': '报告尚未完成',
|
||||
'task': current_task.to_dict()
|
||||
'task': task.to_dict()
|
||||
}), 400
|
||||
|
||||
return Response(
|
||||
current_task.html_content,
|
||||
task.html_content,
|
||||
mimetype='text/html'
|
||||
)
|
||||
|
||||
@@ -336,23 +564,24 @@ def get_result(task_id: str):
|
||||
def get_result_json(task_id: str):
|
||||
"""获取报告生成结果(JSON格式)"""
|
||||
try:
|
||||
if not current_task or current_task.task_id != task_id:
|
||||
task = _get_task(task_id)
|
||||
if not task:
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': '任务不存在'
|
||||
}), 404
|
||||
|
||||
if current_task.status != "completed":
|
||||
if task.status != "completed":
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': '报告尚未完成',
|
||||
'task': current_task.to_dict()
|
||||
'task': task.to_dict()
|
||||
}), 400
|
||||
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'task': current_task.to_dict(),
|
||||
'html_content': current_task.html_content
|
||||
'task': task.to_dict(),
|
||||
'html_content': task.html_content
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
@@ -367,27 +596,28 @@ def get_result_json(task_id: str):
|
||||
def download_report(task_id: str):
|
||||
"""下载已生成的报告HTML文件"""
|
||||
try:
|
||||
if not current_task or current_task.task_id != task_id:
|
||||
task = _get_task(task_id)
|
||||
if not task:
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': '任务不存在'
|
||||
}), 404
|
||||
|
||||
if current_task.status != "completed" or not current_task.report_file_path:
|
||||
if task.status != "completed" or not task.report_file_path:
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': '报告尚未完成或尚未保存'
|
||||
}), 400
|
||||
|
||||
if not os.path.exists(current_task.report_file_path):
|
||||
if not os.path.exists(task.report_file_path):
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': '报告文件不存在或已被删除'
|
||||
}), 404
|
||||
|
||||
download_name = current_task.report_file_name or os.path.basename(current_task.report_file_path)
|
||||
download_name = task.report_file_name or os.path.basename(task.report_file_path)
|
||||
return send_file(
|
||||
current_task.report_file_path,
|
||||
task.report_file_path,
|
||||
mimetype='text/html',
|
||||
as_attachment=True,
|
||||
download_name=download_name
|
||||
@@ -411,7 +641,18 @@ def cancel_task(task_id: str):
|
||||
if current_task and current_task.task_id == task_id:
|
||||
if current_task.status == "running":
|
||||
current_task.update_status("cancelled", 0, "用户取消任务")
|
||||
current_task.publish_event('cancelled', {
|
||||
'message': '任务被用户主动终止',
|
||||
'task': current_task.to_dict(),
|
||||
})
|
||||
current_task = None
|
||||
task = tasks_registry.get(task_id)
|
||||
if task and task.status == 'running':
|
||||
task.update_status("cancelled", task.progress, "用户取消任务")
|
||||
task.publish_event('cancelled', {
|
||||
'message': '任务被用户主动终止',
|
||||
'task': task.to_dict(),
|
||||
})
|
||||
|
||||
return jsonify({
|
||||
'success': True,
|
||||
|
||||
@@ -5,7 +5,7 @@ Report Engine节点处理模块
|
||||
|
||||
from .base_node import BaseNode, StateMutationNode
|
||||
from .template_selection_node import TemplateSelectionNode
|
||||
from .chapter_generation_node import ChapterGenerationNode
|
||||
from .chapter_generation_node import ChapterGenerationNode, ChapterJsonParseError
|
||||
from .document_layout_node import DocumentLayoutNode
|
||||
from .word_budget_node import WordBudgetNode
|
||||
|
||||
@@ -14,6 +14,7 @@ __all__ = [
|
||||
"StateMutationNode",
|
||||
"TemplateSelectionNode",
|
||||
"ChapterGenerationNode",
|
||||
"ChapterJsonParseError",
|
||||
"DocumentLayoutNode",
|
||||
"WordBudgetNode",
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user