Streaming
This commit is contained in:
+270
-29
@@ -7,9 +7,11 @@ import os
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from collections import deque, defaultdict
|
||||
from datetime import datetime
|
||||
from flask import Blueprint, request, jsonify, Response, send_file
|
||||
from typing import Dict, Any
|
||||
from queue import Queue, Empty
|
||||
from flask import Blueprint, request, jsonify, Response, send_file, stream_with_context
|
||||
from typing import Dict, Any, List, Optional
|
||||
from loguru import logger
|
||||
from .agent import ReportAgent, create_agent
|
||||
from .utils.config import settings
|
||||
@@ -23,6 +25,69 @@ report_agent = None
|
||||
current_task = None
|
||||
task_lock = threading.Lock()
|
||||
|
||||
# ====== 流式推送与任务历史管理 ======
|
||||
# 通过有界deque缓存最近的事件,方便SSE断线后快速补发
|
||||
MAX_TASK_HISTORY = 5
|
||||
STREAM_HEARTBEAT_INTERVAL = 15 # 心跳间隔秒
|
||||
stream_lock = threading.Lock()
|
||||
stream_subscribers = defaultdict(list)
|
||||
tasks_registry: Dict[str, 'ReportTask'] = {}
|
||||
|
||||
|
||||
def _register_stream(task_id: str) -> Queue:
|
||||
"""为指定任务注册一个事件队列,供SSE监听器消费。"""
|
||||
queue = Queue()
|
||||
with stream_lock:
|
||||
stream_subscribers[task_id].append(queue)
|
||||
return queue
|
||||
|
||||
|
||||
def _unregister_stream(task_id: str, queue: Queue):
|
||||
"""安全移除事件队列,避免内存泄漏。"""
|
||||
with stream_lock:
|
||||
listeners = stream_subscribers.get(task_id, [])
|
||||
if queue in listeners:
|
||||
listeners.remove(queue)
|
||||
if not listeners and task_id in stream_subscribers:
|
||||
stream_subscribers.pop(task_id, None)
|
||||
|
||||
|
||||
def _broadcast_event(task_id: str, event: Dict[str, Any]):
|
||||
"""将事件推送给所有监听者,失败时做好异常捕获。"""
|
||||
with stream_lock:
|
||||
listeners = list(stream_subscribers.get(task_id, []))
|
||||
for queue in listeners:
|
||||
try:
|
||||
queue.put(event, timeout=0.1)
|
||||
except Exception:
|
||||
logger.exception("推送流式事件失败,跳过当前监听队列")
|
||||
|
||||
|
||||
def _prune_task_history_locked():
|
||||
"""在task_lock持有期间调用,清理过多的历史任务以控制内存。"""
|
||||
if len(tasks_registry) <= MAX_TASK_HISTORY:
|
||||
return
|
||||
# 按创建时间排序,移除最旧的任务
|
||||
sorted_tasks = sorted(tasks_registry.values(), key=lambda t: t.created_at)
|
||||
for task in sorted_tasks[:-MAX_TASK_HISTORY]:
|
||||
tasks_registry.pop(task.task_id, None)
|
||||
|
||||
|
||||
def _get_task(task_id: str) -> Optional['ReportTask']:
|
||||
"""统一的任务查找方法,优先返回当前任务。"""
|
||||
with task_lock:
|
||||
if current_task and current_task.task_id == task_id:
|
||||
return current_task
|
||||
return tasks_registry.get(task_id)
|
||||
|
||||
|
||||
def _format_sse(event: Dict[str, Any]) -> str:
|
||||
"""按SSE协议格式化消息。"""
|
||||
payload = json.dumps(event, ensure_ascii=False)
|
||||
event_id = event.get('id', 0)
|
||||
event_type = event.get('type', 'message')
|
||||
return f"id: {event_id}\nevent: {event_type}\ndata: {payload}\n\n"
|
||||
|
||||
|
||||
def initialize_report_engine():
|
||||
"""初始化Report Engine"""
|
||||
@@ -63,6 +128,11 @@ class ReportTask:
|
||||
self.report_file_name = ""
|
||||
self.state_file_path = ""
|
||||
self.state_file_relative_path = ""
|
||||
# ====== 流式事件缓存与并发保护 ======
|
||||
# 使用deque保存最近的事件,结合锁保证多线程下的安全访问
|
||||
self.event_history: deque = deque(maxlen=1000)
|
||||
self._event_lock = threading.Lock()
|
||||
self.last_event_id = 0
|
||||
|
||||
def update_status(self, status: str, progress: int = None, error_message: str = ""):
|
||||
"""更新任务状态"""
|
||||
@@ -72,6 +142,17 @@ class ReportTask:
|
||||
if error_message:
|
||||
self.error_message = error_message
|
||||
self.updated_at = datetime.now()
|
||||
# 推送状态变更事件,方便前端实时刷新
|
||||
self.publish_event(
|
||||
'status',
|
||||
{
|
||||
'status': self.status,
|
||||
'progress': self.progress,
|
||||
'error_message': self.error_message,
|
||||
'hint': error_message or '',
|
||||
'task': self.to_dict(),
|
||||
}
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典格式"""
|
||||
@@ -91,6 +172,29 @@ class ReportTask:
|
||||
'state_file_path': self.state_file_relative_path or self.state_file_path
|
||||
}
|
||||
|
||||
def publish_event(self, event_type: str, payload: Dict[str, Any]) -> None:
|
||||
"""将任意事件放入缓存并广播,所有新增逻辑均配套中文说明。"""
|
||||
timestamp = datetime.utcnow().isoformat() + 'Z'
|
||||
event: Dict[str, Any] = {
|
||||
'id': 0,
|
||||
'type': event_type,
|
||||
'task_id': self.task_id,
|
||||
'timestamp': timestamp,
|
||||
'payload': payload,
|
||||
}
|
||||
with self._event_lock:
|
||||
self.last_event_id += 1
|
||||
event['id'] = self.last_event_id
|
||||
self.event_history.append(event)
|
||||
_broadcast_event(self.task_id, event)
|
||||
|
||||
def history_since(self, last_event_id: Optional[int]) -> List[Dict[str, Any]]:
|
||||
"""根据Last-Event-ID补发历史事件,确保断线重连无遗漏。"""
|
||||
with self._event_lock:
|
||||
if last_event_id is None:
|
||||
return list(self.event_history)
|
||||
return [evt for evt in self.event_history if evt['id'] > last_event_id]
|
||||
|
||||
|
||||
def check_engines_ready() -> Dict[str, Any]:
|
||||
"""检查三个子引擎是否都有新文件"""
|
||||
@@ -121,7 +225,13 @@ def run_report_generation(task: ReportTask, query: str, custom_template: str = "
|
||||
global current_task
|
||||
|
||||
try:
|
||||
# 在局部闭包内封装推送逻辑,便于传递给ReportAgent
|
||||
def stream_handler(event_type: str, payload: Dict[str, Any]):
|
||||
"""所有阶段事件都通过同一个接口分发,保证日志一致。"""
|
||||
task.publish_event(event_type, payload)
|
||||
|
||||
task.update_status("running", 10)
|
||||
task.publish_event('stage', {'message': '任务已启动,正在检查输入文件', 'stage': 'prepare'})
|
||||
|
||||
# 检查输入文件
|
||||
check_result = check_engines_ready()
|
||||
@@ -129,21 +239,54 @@ def run_report_generation(task: ReportTask, query: str, custom_template: str = "
|
||||
task.update_status("error", 0, f"输入文件未准备就绪: {check_result.get('missing_files', [])}")
|
||||
return
|
||||
|
||||
task.publish_event('stage', {
|
||||
'message': '输入文件检查通过,准备载入内容',
|
||||
'stage': 'io_ready',
|
||||
'files': check_result.get('latest_files', {})
|
||||
})
|
||||
|
||||
task.update_status("running", 30)
|
||||
|
||||
# 加载输入文件
|
||||
content = report_agent.load_input_files(check_result['latest_files'])
|
||||
task.publish_event('stage', {'message': '源数据加载完成,启动生成流程', 'stage': 'data_loaded'})
|
||||
|
||||
task.update_status("running", 50)
|
||||
|
||||
# 生成报告
|
||||
generation_result = report_agent.generate_report(
|
||||
query=query,
|
||||
reports=content['reports'],
|
||||
forum_logs=content['forum_logs'],
|
||||
custom_template=custom_template,
|
||||
save_report=True
|
||||
)
|
||||
# 生成报告(附带兜底重试,缓解瞬时网络抖动)
|
||||
for attempt in range(1, 3):
|
||||
try:
|
||||
task.publish_event('stage', {
|
||||
'message': f'正在调用ReportAgent生成报告(第{attempt}次尝试)',
|
||||
'stage': 'agent_running',
|
||||
'attempt': attempt
|
||||
})
|
||||
generation_result = report_agent.generate_report(
|
||||
query=query,
|
||||
reports=content['reports'],
|
||||
forum_logs=content['forum_logs'],
|
||||
custom_template=custom_template,
|
||||
save_report=True,
|
||||
stream_handler=stream_handler
|
||||
)
|
||||
break
|
||||
except Exception as err:
|
||||
# 将错误即时推送至前端,方便观察重试策略
|
||||
task.publish_event('warning', {
|
||||
'message': f'ReportAgent执行失败: {str(err)}',
|
||||
'stage': 'agent_running',
|
||||
'attempt': attempt
|
||||
})
|
||||
if attempt == 2:
|
||||
raise
|
||||
# 简单的指数退避,防止频繁触发限流(单位秒)
|
||||
backoff = min(5 * attempt, 15)
|
||||
task.publish_event('stage', {
|
||||
'message': f'{backoff} 秒后重试生成任务',
|
||||
'stage': 'retry_wait',
|
||||
'wait_seconds': backoff
|
||||
})
|
||||
time.sleep(backoff)
|
||||
|
||||
if isinstance(generation_result, dict):
|
||||
html_report = generation_result.get('html_content', '')
|
||||
@@ -151,6 +294,7 @@ def run_report_generation(task: ReportTask, query: str, custom_template: str = "
|
||||
html_report = generation_result
|
||||
|
||||
task.update_status("running", 90)
|
||||
task.publish_event('stage', {'message': '报告生成完毕,准备持久化', 'stage': 'persist'})
|
||||
|
||||
# 保存结果
|
||||
task.html_content = html_report
|
||||
@@ -160,11 +304,28 @@ def run_report_generation(task: ReportTask, query: str, custom_template: str = "
|
||||
task.report_file_name = generation_result.get('report_filename', '')
|
||||
task.state_file_path = generation_result.get('state_filepath', '')
|
||||
task.state_file_relative_path = generation_result.get('state_relative_path', '')
|
||||
task.publish_event('html_ready', {
|
||||
'message': 'HTML渲染完成,可刷新预览',
|
||||
'report_file': task.report_file_relative_path or task.report_file_path,
|
||||
'state_file': task.state_file_relative_path or task.state_file_path,
|
||||
'task': task.to_dict(),
|
||||
})
|
||||
task.update_status("completed", 100)
|
||||
task.publish_event('completed', {
|
||||
'message': '任务完成',
|
||||
'duration_seconds': (task.updated_at - task.created_at).total_seconds(),
|
||||
'report_file': task.report_file_relative_path or task.report_file_path,
|
||||
'task': task.to_dict(),
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"报告生成过程中发生错误: {str(e)}")
|
||||
task.update_status("error", 0, str(e))
|
||||
task.publish_event('error', {
|
||||
'message': str(e),
|
||||
'stage': 'failed',
|
||||
'task': task.to_dict(),
|
||||
})
|
||||
# 只在出错时清理任务
|
||||
with task_lock:
|
||||
if current_task and current_task.task_id == task.task_id:
|
||||
@@ -242,6 +403,19 @@ def generate_report():
|
||||
|
||||
with task_lock:
|
||||
current_task = task
|
||||
tasks_registry[task_id] = task
|
||||
_prune_task_history_locked()
|
||||
|
||||
# 通过主动推送pending事件告知前端任务已经排队
|
||||
task.publish_event(
|
||||
'status',
|
||||
{
|
||||
'status': task.status,
|
||||
'progress': task.progress,
|
||||
'message': '任务已排队,等待资源空闲',
|
||||
'task': task.to_dict(),
|
||||
}
|
||||
)
|
||||
|
||||
# 在后台线程中运行报告生成
|
||||
thread = threading.Thread(
|
||||
@@ -255,7 +429,8 @@ def generate_report():
|
||||
'success': True,
|
||||
'task_id': task_id,
|
||||
'message': '报告生成已启动',
|
||||
'task': task.to_dict()
|
||||
'task': task.to_dict(),
|
||||
'stream_url': f"/api/report/stream/{task_id}"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
@@ -270,9 +445,9 @@ def generate_report():
|
||||
def get_progress(task_id: str):
|
||||
"""获取报告生成进度"""
|
||||
try:
|
||||
if not current_task or current_task.task_id != task_id:
|
||||
# 如果任务不存在,可能是已经完成并被清理了
|
||||
# 返回一个默认的完成状态而不是404
|
||||
task = _get_task(task_id)
|
||||
if not task:
|
||||
# 如果任务不存在,可能是历史记录已被清理,回传一个完成态兜底
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'task': {
|
||||
@@ -291,7 +466,7 @@ def get_progress(task_id: str):
|
||||
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'task': current_task.to_dict()
|
||||
'task': task.to_dict()
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
@@ -302,25 +477,78 @@ def get_progress(task_id: str):
|
||||
}), 500
|
||||
|
||||
|
||||
@report_bp.route('/stream/<task_id>', methods=['GET'])
|
||||
def stream_task(task_id: str):
|
||||
"""基于SSE的实时推送接口,向前端持续广播阶段事件。"""
|
||||
task = _get_task(task_id)
|
||||
if not task:
|
||||
return jsonify({'success': False, 'error': '任务不存在'}), 404
|
||||
|
||||
last_event_header = request.headers.get('Last-Event-ID')
|
||||
try:
|
||||
last_event_id = int(last_event_header) if last_event_header else None
|
||||
except ValueError:
|
||||
last_event_id = None
|
||||
|
||||
def event_generator():
|
||||
queue = _register_stream(task_id)
|
||||
try:
|
||||
# 断线重连场景下,先补发历史事件,保证界面状态一致
|
||||
history = task.history_since(last_event_id)
|
||||
for event in history:
|
||||
yield _format_sse(event)
|
||||
|
||||
finished = task.status in ("completed", "error", "cancelled")
|
||||
while True:
|
||||
if finished:
|
||||
break
|
||||
try:
|
||||
event = queue.get(timeout=STREAM_HEARTBEAT_INTERVAL)
|
||||
yield _format_sse(event)
|
||||
if event.get('type') in ("completed", "error"):
|
||||
finished = True
|
||||
except Empty:
|
||||
heartbeat = {
|
||||
'id': f"hb-{int(time.time() * 1000)}",
|
||||
'type': 'heartbeat',
|
||||
'task_id': task_id,
|
||||
'timestamp': datetime.utcnow().isoformat() + 'Z',
|
||||
'payload': {'status': task.status}
|
||||
}
|
||||
yield _format_sse(heartbeat)
|
||||
finished = task.status in ("completed", "error", "cancelled")
|
||||
finally:
|
||||
_unregister_stream(task_id, queue)
|
||||
|
||||
response = Response(
|
||||
stream_with_context(event_generator()),
|
||||
mimetype='text/event-stream'
|
||||
)
|
||||
response.headers['Cache-Control'] = 'no-cache'
|
||||
response.headers['X-Accel-Buffering'] = 'no'
|
||||
return response
|
||||
|
||||
|
||||
@report_bp.route('/result/<task_id>', methods=['GET'])
|
||||
def get_result(task_id: str):
|
||||
"""获取报告生成结果"""
|
||||
try:
|
||||
if not current_task or current_task.task_id != task_id:
|
||||
task = _get_task(task_id)
|
||||
if not task:
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': '任务不存在'
|
||||
}), 404
|
||||
|
||||
if current_task.status != "completed":
|
||||
if task.status != "completed":
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': '报告尚未完成',
|
||||
'task': current_task.to_dict()
|
||||
'task': task.to_dict()
|
||||
}), 400
|
||||
|
||||
return Response(
|
||||
current_task.html_content,
|
||||
task.html_content,
|
||||
mimetype='text/html'
|
||||
)
|
||||
|
||||
@@ -336,23 +564,24 @@ def get_result(task_id: str):
|
||||
def get_result_json(task_id: str):
|
||||
"""获取报告生成结果(JSON格式)"""
|
||||
try:
|
||||
if not current_task or current_task.task_id != task_id:
|
||||
task = _get_task(task_id)
|
||||
if not task:
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': '任务不存在'
|
||||
}), 404
|
||||
|
||||
if current_task.status != "completed":
|
||||
if task.status != "completed":
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': '报告尚未完成',
|
||||
'task': current_task.to_dict()
|
||||
'task': task.to_dict()
|
||||
}), 400
|
||||
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'task': current_task.to_dict(),
|
||||
'html_content': current_task.html_content
|
||||
'task': task.to_dict(),
|
||||
'html_content': task.html_content
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
@@ -367,27 +596,28 @@ def get_result_json(task_id: str):
|
||||
def download_report(task_id: str):
|
||||
"""下载已生成的报告HTML文件"""
|
||||
try:
|
||||
if not current_task or current_task.task_id != task_id:
|
||||
task = _get_task(task_id)
|
||||
if not task:
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': '任务不存在'
|
||||
}), 404
|
||||
|
||||
if current_task.status != "completed" or not current_task.report_file_path:
|
||||
if task.status != "completed" or not task.report_file_path:
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': '报告尚未完成或尚未保存'
|
||||
}), 400
|
||||
|
||||
if not os.path.exists(current_task.report_file_path):
|
||||
if not os.path.exists(task.report_file_path):
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': '报告文件不存在或已被删除'
|
||||
}), 404
|
||||
|
||||
download_name = current_task.report_file_name or os.path.basename(current_task.report_file_path)
|
||||
download_name = task.report_file_name or os.path.basename(task.report_file_path)
|
||||
return send_file(
|
||||
current_task.report_file_path,
|
||||
task.report_file_path,
|
||||
mimetype='text/html',
|
||||
as_attachment=True,
|
||||
download_name=download_name
|
||||
@@ -411,7 +641,18 @@ def cancel_task(task_id: str):
|
||||
if current_task and current_task.task_id == task_id:
|
||||
if current_task.status == "running":
|
||||
current_task.update_status("cancelled", 0, "用户取消任务")
|
||||
current_task.publish_event('cancelled', {
|
||||
'message': '任务被用户主动终止',
|
||||
'task': current_task.to_dict(),
|
||||
})
|
||||
current_task = None
|
||||
task = tasks_registry.get(task_id)
|
||||
if task and task.status == 'running':
|
||||
task.update_status("cancelled", task.progress, "用户取消任务")
|
||||
task.publish_event('cancelled', {
|
||||
'message': '任务被用户主动终止',
|
||||
'task': task.to_dict(),
|
||||
})
|
||||
|
||||
return jsonify({
|
||||
'success': True,
|
||||
|
||||
Reference in New Issue
Block a user