feat: FastAPI+SSE API server, JRXML auto-reorder, session integrity fixes

This commit is contained in:
2026-05-22 17:53:59 +08:00
parent 1144a86d02
commit 1e5ce9725b
32 changed files with 9189 additions and 309 deletions
+24 -2
View File
@@ -181,6 +181,13 @@ def _run_graph_sync(agent_state: AgentState, event_q: queue.Queue):
for node_state in data.values():
if isinstance(node_state, dict):
agent_state.update(node_state)
# 在 graph 完成后立即保存 session,防止 SSE 流中断导致数据丢失
sid = agent_state.get("session_id", "")
if sid:
try:
save_session(sid, agent_state)
except Exception:
pass # 静默失败,SSE 流中还有一次保存机会
event_q.put(("done", {"reason": "graph_completed"}))
except Exception as exc:
event_q.put(("error", {
@@ -218,6 +225,8 @@ async def _sse_generator(agent_state: AgentState, session_id: str = "") -> str:
total_ms = round((time.time() - t_start) * 1000)
if session_id:
save_session(session_id, agent_state)
versions = agent_state.get("jrxml_versions", [])
last_ver = versions[-1] if versions else {}
yield _sse_line("agent_complete", {
"reason": "done",
"intent": agent_state.get("intent", ""),
@@ -228,6 +237,9 @@ async def _sse_generator(agent_state: AgentState, session_id: str = "") -> str:
"retry_count": agent_state.get("retry_count", 0),
"total_duration_ms": total_ms,
"ocr_extraction_result": agent_state.get("ocr_extraction_result", {}),
"versions": len(versions),
"has_failed_version": last_ver.get("status") == "fail" if last_ver else False,
"failed_version_index": len(versions) - 1 if last_ver.get("status") == "fail" else -1,
})
await future
return
@@ -532,7 +544,17 @@ async def chat(session_id: str, payload: dict):
# ── 返回 SSE 流 ──
async def stream_and_save():
final_state = None
# 如果上传了附件,先发送处理状态
if file_ids:
yield _sse_line("node_start", {
"node": "process_attachments",
"label": "正在处理附件",
})
yield _sse_line("node_complete", {
"node": "process_attachments",
"label": "正在处理附件",
"detail": f"已解析 {len(file_ids)} 个文件",
})
async for sse_chunk in _sse_generator(agent_state, session_id):
yield sse_chunk
@@ -622,4 +644,4 @@ async def download_file(file_id: str):
if __name__ == "__main__":
import uvicorn
port = int(os.getenv("API_PORT", "8000"))
uvicorn.run("api_server:app", host="0.0.0.0", port=port, reload=True)
uvicorn.run("api_server:app", host="0.0.0.0", port=port, reload=False)