Streaming

This commit is contained in:
马一丁
2025-11-13 22:30:36 +08:00
parent 1c2f82e285
commit fa787af135
3 changed files with 402 additions and 39 deletions
+270 -29
View File
@@ -7,9 +7,11 @@ import os
import json
import threading
import time
from collections import deque, defaultdict
from datetime import datetime
from flask import Blueprint, request, jsonify, Response, send_file
from typing import Dict, Any
from queue import Queue, Empty
from flask import Blueprint, request, jsonify, Response, send_file, stream_with_context
from typing import Dict, Any, List, Optional
from loguru import logger
from .agent import ReportAgent, create_agent
from .utils.config import settings
@@ -23,6 +25,69 @@ report_agent = None
current_task = None
task_lock = threading.Lock()
# ====== 流式推送与任务历史管理 ======
# 通过有界deque缓存最近的事件,方便SSE断线后快速补发
MAX_TASK_HISTORY = 5
STREAM_HEARTBEAT_INTERVAL = 15 # 心跳间隔秒
stream_lock = threading.Lock()
stream_subscribers = defaultdict(list)
tasks_registry: Dict[str, 'ReportTask'] = {}
def _register_stream(task_id: str) -> Queue:
"""为指定任务注册一个事件队列,供SSE监听器消费。"""
queue = Queue()
with stream_lock:
stream_subscribers[task_id].append(queue)
return queue
def _unregister_stream(task_id: str, queue: Queue):
"""安全移除事件队列,避免内存泄漏。"""
with stream_lock:
listeners = stream_subscribers.get(task_id, [])
if queue in listeners:
listeners.remove(queue)
if not listeners and task_id in stream_subscribers:
stream_subscribers.pop(task_id, None)
def _broadcast_event(task_id: str, event: Dict[str, Any]):
"""将事件推送给所有监听者,失败时做好异常捕获。"""
with stream_lock:
listeners = list(stream_subscribers.get(task_id, []))
for queue in listeners:
try:
queue.put(event, timeout=0.1)
except Exception:
logger.exception("推送流式事件失败,跳过当前监听队列")
def _prune_task_history_locked():
"""在task_lock持有期间调用,清理过多的历史任务以控制内存。"""
if len(tasks_registry) <= MAX_TASK_HISTORY:
return
# 按创建时间排序,移除最旧的任务
sorted_tasks = sorted(tasks_registry.values(), key=lambda t: t.created_at)
for task in sorted_tasks[:-MAX_TASK_HISTORY]:
tasks_registry.pop(task.task_id, None)
def _get_task(task_id: str) -> Optional['ReportTask']:
"""统一的任务查找方法,优先返回当前任务。"""
with task_lock:
if current_task and current_task.task_id == task_id:
return current_task
return tasks_registry.get(task_id)
def _format_sse(event: Dict[str, Any]) -> str:
"""按SSE协议格式化消息。"""
payload = json.dumps(event, ensure_ascii=False)
event_id = event.get('id', 0)
event_type = event.get('type', 'message')
return f"id: {event_id}\nevent: {event_type}\ndata: {payload}\n\n"
def initialize_report_engine():
"""初始化Report Engine"""
@@ -63,6 +128,11 @@ class ReportTask:
self.report_file_name = ""
self.state_file_path = ""
self.state_file_relative_path = ""
# ====== 流式事件缓存与并发保护 ======
# 使用deque保存最近的事件,结合锁保证多线程下的安全访问
self.event_history: deque = deque(maxlen=1000)
self._event_lock = threading.Lock()
self.last_event_id = 0
def update_status(self, status: str, progress: int = None, error_message: str = ""):
"""更新任务状态"""
@@ -72,6 +142,17 @@ class ReportTask:
if error_message:
self.error_message = error_message
self.updated_at = datetime.now()
# 推送状态变更事件,方便前端实时刷新
self.publish_event(
'status',
{
'status': self.status,
'progress': self.progress,
'error_message': self.error_message,
'hint': error_message or '',
'task': self.to_dict(),
}
)
def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式"""
@@ -91,6 +172,29 @@ class ReportTask:
'state_file_path': self.state_file_relative_path or self.state_file_path
}
def publish_event(self, event_type: str, payload: Dict[str, Any]) -> None:
"""将任意事件放入缓存并广播,所有新增逻辑均配套中文说明。"""
timestamp = datetime.utcnow().isoformat() + 'Z'
event: Dict[str, Any] = {
'id': 0,
'type': event_type,
'task_id': self.task_id,
'timestamp': timestamp,
'payload': payload,
}
with self._event_lock:
self.last_event_id += 1
event['id'] = self.last_event_id
self.event_history.append(event)
_broadcast_event(self.task_id, event)
def history_since(self, last_event_id: Optional[int]) -> List[Dict[str, Any]]:
"""根据Last-Event-ID补发历史事件,确保断线重连无遗漏。"""
with self._event_lock:
if last_event_id is None:
return list(self.event_history)
return [evt for evt in self.event_history if evt['id'] > last_event_id]
def check_engines_ready() -> Dict[str, Any]:
"""检查三个子引擎是否都有新文件"""
@@ -121,7 +225,13 @@ def run_report_generation(task: ReportTask, query: str, custom_template: str = "
global current_task
try:
# 在局部闭包内封装推送逻辑,便于传递给ReportAgent
def stream_handler(event_type: str, payload: Dict[str, Any]):
"""所有阶段事件都通过同一个接口分发,保证日志一致。"""
task.publish_event(event_type, payload)
task.update_status("running", 10)
task.publish_event('stage', {'message': '任务已启动,正在检查输入文件', 'stage': 'prepare'})
# 检查输入文件
check_result = check_engines_ready()
@@ -129,21 +239,54 @@ def run_report_generation(task: ReportTask, query: str, custom_template: str = "
task.update_status("error", 0, f"输入文件未准备就绪: {check_result.get('missing_files', [])}")
return
task.publish_event('stage', {
'message': '输入文件检查通过,准备载入内容',
'stage': 'io_ready',
'files': check_result.get('latest_files', {})
})
task.update_status("running", 30)
# 加载输入文件
content = report_agent.load_input_files(check_result['latest_files'])
task.publish_event('stage', {'message': '源数据加载完成,启动生成流程', 'stage': 'data_loaded'})
task.update_status("running", 50)
# 生成报告
generation_result = report_agent.generate_report(
query=query,
reports=content['reports'],
forum_logs=content['forum_logs'],
custom_template=custom_template,
save_report=True
)
# 生成报告(附带兜底重试,缓解瞬时网络抖动)
for attempt in range(1, 3):
try:
task.publish_event('stage', {
'message': f'正在调用ReportAgent生成报告(第{attempt}次尝试)',
'stage': 'agent_running',
'attempt': attempt
})
generation_result = report_agent.generate_report(
query=query,
reports=content['reports'],
forum_logs=content['forum_logs'],
custom_template=custom_template,
save_report=True,
stream_handler=stream_handler
)
break
except Exception as err:
# 将错误即时推送至前端,方便观察重试策略
task.publish_event('warning', {
'message': f'ReportAgent执行失败: {str(err)}',
'stage': 'agent_running',
'attempt': attempt
})
if attempt == 2:
raise
# 简单的指数退避,防止频繁触发限流(单位秒)
backoff = min(5 * attempt, 15)
task.publish_event('stage', {
'message': f'{backoff} 秒后重试生成任务',
'stage': 'retry_wait',
'wait_seconds': backoff
})
time.sleep(backoff)
if isinstance(generation_result, dict):
html_report = generation_result.get('html_content', '')
@@ -151,6 +294,7 @@ def run_report_generation(task: ReportTask, query: str, custom_template: str = "
html_report = generation_result
task.update_status("running", 90)
task.publish_event('stage', {'message': '报告生成完毕,准备持久化', 'stage': 'persist'})
# 保存结果
task.html_content = html_report
@@ -160,11 +304,28 @@ def run_report_generation(task: ReportTask, query: str, custom_template: str = "
task.report_file_name = generation_result.get('report_filename', '')
task.state_file_path = generation_result.get('state_filepath', '')
task.state_file_relative_path = generation_result.get('state_relative_path', '')
task.publish_event('html_ready', {
'message': 'HTML渲染完成,可刷新预览',
'report_file': task.report_file_relative_path or task.report_file_path,
'state_file': task.state_file_relative_path or task.state_file_path,
'task': task.to_dict(),
})
task.update_status("completed", 100)
task.publish_event('completed', {
'message': '任务完成',
'duration_seconds': (task.updated_at - task.created_at).total_seconds(),
'report_file': task.report_file_relative_path or task.report_file_path,
'task': task.to_dict(),
})
except Exception as e:
logger.exception(f"报告生成过程中发生错误: {str(e)}")
task.update_status("error", 0, str(e))
task.publish_event('error', {
'message': str(e),
'stage': 'failed',
'task': task.to_dict(),
})
# 只在出错时清理任务
with task_lock:
if current_task and current_task.task_id == task.task_id:
@@ -242,6 +403,19 @@ def generate_report():
with task_lock:
current_task = task
tasks_registry[task_id] = task
_prune_task_history_locked()
# 通过主动推送pending事件告知前端任务已经排队
task.publish_event(
'status',
{
'status': task.status,
'progress': task.progress,
'message': '任务已排队,等待资源空闲',
'task': task.to_dict(),
}
)
# 在后台线程中运行报告生成
thread = threading.Thread(
@@ -255,7 +429,8 @@ def generate_report():
'success': True,
'task_id': task_id,
'message': '报告生成已启动',
'task': task.to_dict()
'task': task.to_dict(),
'stream_url': f"/api/report/stream/{task_id}"
})
except Exception as e:
@@ -270,9 +445,9 @@ def generate_report():
def get_progress(task_id: str):
"""获取报告生成进度"""
try:
if not current_task or current_task.task_id != task_id:
# 如果任务不存在,可能是已经完成并被清理了
# 返回一个默认的完成状态而不是404
task = _get_task(task_id)
if not task:
# 如果任务不存在,可能是历史记录已被清理,回传一个完成态兜底
return jsonify({
'success': True,
'task': {
@@ -291,7 +466,7 @@ def get_progress(task_id: str):
return jsonify({
'success': True,
'task': current_task.to_dict()
'task': task.to_dict()
})
except Exception as e:
@@ -302,25 +477,78 @@ def get_progress(task_id: str):
}), 500
@report_bp.route('/stream/<task_id>', methods=['GET'])
def stream_task(task_id: str):
"""基于SSE的实时推送接口,向前端持续广播阶段事件。"""
task = _get_task(task_id)
if not task:
return jsonify({'success': False, 'error': '任务不存在'}), 404
last_event_header = request.headers.get('Last-Event-ID')
try:
last_event_id = int(last_event_header) if last_event_header else None
except ValueError:
last_event_id = None
def event_generator():
queue = _register_stream(task_id)
try:
# 断线重连场景下,先补发历史事件,保证界面状态一致
history = task.history_since(last_event_id)
for event in history:
yield _format_sse(event)
finished = task.status in ("completed", "error", "cancelled")
while True:
if finished:
break
try:
event = queue.get(timeout=STREAM_HEARTBEAT_INTERVAL)
yield _format_sse(event)
if event.get('type') in ("completed", "error"):
finished = True
except Empty:
heartbeat = {
'id': f"hb-{int(time.time() * 1000)}",
'type': 'heartbeat',
'task_id': task_id,
'timestamp': datetime.utcnow().isoformat() + 'Z',
'payload': {'status': task.status}
}
yield _format_sse(heartbeat)
finished = task.status in ("completed", "error", "cancelled")
finally:
_unregister_stream(task_id, queue)
response = Response(
stream_with_context(event_generator()),
mimetype='text/event-stream'
)
response.headers['Cache-Control'] = 'no-cache'
response.headers['X-Accel-Buffering'] = 'no'
return response
@report_bp.route('/result/<task_id>', methods=['GET'])
def get_result(task_id: str):
"""获取报告生成结果"""
try:
if not current_task or current_task.task_id != task_id:
task = _get_task(task_id)
if not task:
return jsonify({
'success': False,
'error': '任务不存在'
}), 404
if current_task.status != "completed":
if task.status != "completed":
return jsonify({
'success': False,
'error': '报告尚未完成',
'task': current_task.to_dict()
'task': task.to_dict()
}), 400
return Response(
current_task.html_content,
task.html_content,
mimetype='text/html'
)
@@ -336,23 +564,24 @@ def get_result(task_id: str):
def get_result_json(task_id: str):
"""获取报告生成结果(JSON格式)"""
try:
if not current_task or current_task.task_id != task_id:
task = _get_task(task_id)
if not task:
return jsonify({
'success': False,
'error': '任务不存在'
}), 404
if current_task.status != "completed":
if task.status != "completed":
return jsonify({
'success': False,
'error': '报告尚未完成',
'task': current_task.to_dict()
'task': task.to_dict()
}), 400
return jsonify({
'success': True,
'task': current_task.to_dict(),
'html_content': current_task.html_content
'task': task.to_dict(),
'html_content': task.html_content
})
except Exception as e:
@@ -367,27 +596,28 @@ def get_result_json(task_id: str):
def download_report(task_id: str):
"""下载已生成的报告HTML文件"""
try:
if not current_task or current_task.task_id != task_id:
task = _get_task(task_id)
if not task:
return jsonify({
'success': False,
'error': '任务不存在'
}), 404
if current_task.status != "completed" or not current_task.report_file_path:
if task.status != "completed" or not task.report_file_path:
return jsonify({
'success': False,
'error': '报告尚未完成或尚未保存'
}), 400
if not os.path.exists(current_task.report_file_path):
if not os.path.exists(task.report_file_path):
return jsonify({
'success': False,
'error': '报告文件不存在或已被删除'
}), 404
download_name = current_task.report_file_name or os.path.basename(current_task.report_file_path)
download_name = task.report_file_name or os.path.basename(task.report_file_path)
return send_file(
current_task.report_file_path,
task.report_file_path,
mimetype='text/html',
as_attachment=True,
download_name=download_name
@@ -411,7 +641,18 @@ def cancel_task(task_id: str):
if current_task and current_task.task_id == task_id:
if current_task.status == "running":
current_task.update_status("cancelled", 0, "用户取消任务")
current_task.publish_event('cancelled', {
'message': '任务被用户主动终止',
'task': current_task.to_dict(),
})
current_task = None
task = tasks_registry.get(task_id)
if task and task.status == 'running':
task.update_status("cancelled", task.progress, "用户取消任务")
task.publish_event('cancelled', {
'message': '任务被用户主动终止',
'task': task.to_dict(),
})
return jsonify({
'success': True,