"""JRXML Agent API Server — FastAPI + SSE streaming.

Replaces the Streamlit UI (app.py) with a REST + SSE backend. The LangGraph
agent pipeline is wrapped unchanged.

SSE event types:
    node_start      — a graph node started executing
    node_complete   — a graph node finished (carries a short detail string)
    stream_token    — incremental LLM output
    agent_complete  — the whole graph run finished
    agent_error     — the graph run raised an exception

While the stream is idle, a ``: keepalive`` SSE comment is sent every
``_KEEPALIVE_INTERVAL_S`` seconds so intermediaries keep the connection open.

Usage:
    python -m uvicorn api_server:app --host 0.0.0.0 --port 8000
"""

import asyncio
import base64
import json
import mimetypes
import os
import queue
import tempfile
import time
import traceback
import uuid
from collections.abc import AsyncIterator
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, FileResponse

# Load .env before the project imports below, in case those modules read
# environment variables at import time. override=True makes values from .env
# win over variables already present in the process environment.
load_dotenv(override=True)

from agent.graph import build_graph  # noqa: E402
from agent.state import AgentState  # noqa: E402
from backend.logger import get_logger, generate_trace_id, set_trace_id, get_trace_id  # noqa: E402
from backend.session import (  # noqa: E402
    create_session,
    load_session,
    save_session,
    list_all_sessions,
    delete_session,
    get_session_state,
    SESSIONS_DIR,
)
from backend.file_parser import parse_file  # noqa: E402
from backend.layout_analyzer import analyze_layout, extract_layout_schema  # noqa: E402

# ─────────────────────────────────────────────
# Constants (migrated from app.py)
# ─────────────────────────────────────────────

# Graph node name → human-readable label sent with node_start / node_complete.
NODE_LABELS = {
    "load_session": "加载会话",
    "process_input": "记录输入",
    "manage_context": "管理上下文",
    "save_state_snapshot": "保存快照",
    "classify_intent": "识别意图",
    "retrieve": "检索模板",
    "generate": "生成 JRXML",
    "modify_jrxml": "修改 JRXML",
    "validate": "验证",
    "explain_error": "分析错误",
    "correct_jrxml": "自动修正",
    "finalize": "完成",
    "handle_consult": "咨询回答",
    "handle_undo": "撤销操作",
    "handle_reset": "重置会话",
    "save_session": "保存会话",
    "generate_skeleton": "生成骨架",
    "refine_layout": "精调布局",
    "map_fields": "映射字段",
}

# Value of state["intent"] (set by classify_intent) → human-readable label.
INTENT_LABELS = {
    "initial_generation": "新建报表",
    "modify_report": "修改报表",
    "preview_report": "预览报表",
    "export_pdf": "导出 PDF",
    "export_jrxml": "下载 JRXML",
    "undo_modification": "撤销修改",
    "consult_question": "咨询问题",
    "reset_session": "重置会话",
}

# Bookkeeping nodes. Not referenced elsewhere in this module —
# NOTE(review): presumably used by clients/other modules to hide these steps.
SKIP_NODES = {"load_session", "process_input", "manage_context", "save_state_snapshot", "save_session"}

# ─────────────────────────────────────────────
# Logging & paths
# ─────────────────────────────────────────────

_api_log = get_logger("api")
# Root directory for uploaded files; overridable via the UPLOADS_DIR env var.
UPLOADS_DIR = Path(os.getenv("UPLOADS_DIR", "./uploads"))

# ─────────────────────────────────────────────
# Graph compilation (global singleton with a node_start callback)
# ─────────────────────────────────────────────

# Event queue of the request that is currently streaming. This backend is
# built as a single-user desktop app, so one global slot is used; two
# concurrent /chat streams would race on it.
_current_event_queue: Optional[queue.Queue] = None
# 1-based index of node_start events emitted during the current run.
_step_counter: int = 0


def _on_node_start(node_name: str):
    """Push a ``node_start`` event into the active request's queue, if any.

    Registered with ``build_graph`` as its ``on_node_start`` hook; presumably
    invoked from the worker thread that runs ``_graph.stream()``.
    """
    global _step_counter
    q = _current_event_queue
    if q is not None:
        _step_counter += 1
        q.put(("node_start", {
            "node": node_name,
            "label": NODE_LABELS.get(node_name, node_name),
            "step_index": _step_counter,
        }))


_graph = build_graph(on_node_start=_on_node_start)

# ─────────────────────────────────────────────
# File registry (in memory; acceptable for a desktop app)
# ─────────────────────────────────────────────

_file_registry: dict[str, dict] = {}  # file_id → {path, filename, content_type, size}


def _ensure_upload_dir(session_id: str = "") -> Path:
    """Create (if needed) and return the upload directory for *session_id*.

    An empty *session_id* maps to UPLOADS_DIR itself. The value is joined as a
    path component without validation, so callers must pass a safe name.
    """
    d = UPLOADS_DIR / session_id if session_id else UPLOADS_DIR
    d.mkdir(parents=True, exist_ok=True)
    return d

# ─────────────────────────────────────────────
# SSE helpers
# ─────────────────────────────────────────────

# How long to sleep between polls of the event queue when it is empty.
_POLL_INTERVAL_S = 0.05
# Minimum idle time before an SSE keepalive comment is emitted.
_KEEPALIVE_INTERVAL_S = 15.0


def _sse_line(event_type: str, data: dict) -> str:
    """Format one SSE message (``event:`` + JSON ``data:`` + blank line)."""
    payload = json.dumps(data, ensure_ascii=False)
    return f"event: {event_type}\ndata: {payload}\n\n"


def _extract_detail(node_name: str, node_state: dict) -> str:
    """Build the short detail text of a ``node_complete`` event.

    Returns "" for nodes without a dedicated summary and for non-dict updates
    (e.g. a node that returned None). Missing or None fields are treated as
    empty strings so a sparse update never raises.
    """
    if not isinstance(node_state, dict):
        return ""
    if node_name == "classify_intent":
        intent = node_state.get("intent") or ""
        return f"意图: {INTENT_LABELS.get(intent, intent)}"
    if node_name == "retrieve":
        ctx = node_state.get("retrieved_context", "")
        return f"找到 {len(ctx)} 字符参考模板" if ctx else "未匹配到模板"
    if node_name in ("generate", "modify_jrxml", "correct_jrxml",
                     "generate_skeleton", "refine_layout", "map_fields"):
        jrxml = node_state.get("current_jrxml") or ""
        return f"生成 {len(jrxml)} 字符 JRXML"
    if node_name == "validate":
        if node_state.get("status", "") == "pass":
            return "验证通过 ✓"
        err = node_state.get("error_msg") or ""
        return f"验证失败: {err[:80]}"
    if node_name == "explain_error":
        return (node_state.get("natural_explanation") or "")[:120]
    if node_name == "handle_consult":
        return (node_state.get("consult_answer") or "")[:150]
    return ""


def _run_graph_sync(agent_state: AgentState, event_q: queue.Queue):
    """Run ``_graph.stream()`` in a worker thread, forwarding events to *event_q*.

    graph.stream() only emits events and does not modify the passed-in
    *agent_state*, so each node's returned update is merged into it here with
    a plain ``dict.update`` (NOTE(review): channel reducers, if the state
    defines any, are not re-applied). Finishes by enqueuing
    ``("done", ...)`` on success or ``("error", {...})`` on an exception.
    """
    try:
        for event in _graph.stream(agent_state, stream_mode=["updates", "custom"]):
            event_q.put(event)
            # With several stream modes, each event is a (mode, data) tuple.
            if isinstance(event, tuple) and len(event) == 2:
                mode, data = event
                if mode == "updates" and isinstance(data, dict):
                    for node_state in data.values():
                        if isinstance(node_state, dict):
                            agent_state.update(node_state)
        event_q.put(("done", {"reason": "graph_completed"}))
    except Exception as exc:
        # Thread boundary: report the failure to the SSE stream instead of
        # losing it inside the executor.
        event_q.put(("error", {
            "error": str(exc),
            "traceback": traceback.format_exc(),
        }))


def _graph_event_to_sse(item: tuple) -> list[str]:
    """Convert one non-terminal queue item into zero or more SSE messages.

    Handles ``("node_start", payload)`` from the callback and the
    ``(mode, data)`` tuples produced by ``graph.stream()``.
    """
    if item[0] == "node_start":
        return [_sse_line("node_start", item[1])]
    mode, data = item
    if mode == "updates" and isinstance(data, dict):
        return [
            _sse_line("node_complete", {
                "node": node_name,
                "label": NODE_LABELS.get(node_name, node_name),
                "detail": _extract_detail(node_name, node_state),
            })
            for node_name, node_state in data.items()
        ]
    if mode == "custom" and isinstance(data, dict) and data.get("type") == "stream":
        return [_sse_line("stream_token", {
            "text": data.get("text", ""),
            "type": "stream",
        })]
    return []


def _release_event_queue(event_q: queue.Queue) -> None:
    """Stop routing node_start callbacks to *event_q* if it is still active."""
    global _current_event_queue
    if _current_event_queue is event_q:
        _current_event_queue = None


async def _sse_generator(agent_state: AgentState, session_id: str = "") -> AsyncIterator[str]:
    """Run the graph in a background thread and yield SSE-formatted strings.

    When the graph finishes, the (merged) *agent_state* is saved for
    *session_id* (if non-empty) and an ``agent_complete`` event is sent; on a
    graph exception an ``agent_error`` event is sent instead.

    NOTE(review): if the client disconnects mid-run, the worker thread keeps
    running but its final state is not saved.
    """
    global _current_event_queue, _step_counter
    _step_counter = 0
    t_start = time.monotonic()
    event_q: queue.Queue = queue.Queue()
    _current_event_queue = event_q

    loop = asyncio.get_running_loop()
    future = loop.run_in_executor(None, _run_graph_sync, agent_state, event_q)
    last_sent = time.monotonic()

    try:
        # Non-blocking poll of the thread-safe queue with a short sleep.
        while True:
            drained_any = False
            while True:
                try:
                    item = event_q.get_nowait()
                except queue.Empty:
                    break
                drained_any = True
                kind = item[0]

                if kind == "done":
                    _release_event_queue(event_q)
                    total_ms = round((time.monotonic() - t_start) * 1000)
                    if session_id:
                        save_session(session_id, agent_state)
                    yield _sse_line("agent_complete", {
                        "reason": "done",
                        "intent": agent_state.get("intent", ""),
                        "status": agent_state.get("status", ""),
                        "jrxml_length": len(agent_state.get("current_jrxml") or ""),
                        "error_msg": agent_state.get("error_msg", ""),
                        "natural_explanation": agent_state.get("natural_explanation", ""),
                        "retry_count": agent_state.get("retry_count", 0),
                        "total_duration_ms": total_ms,
                        "ocr_extraction_result": agent_state.get("ocr_extraction_result", {}),
                    })
                    await future
                    return

                if kind == "error":
                    _release_event_queue(event_q)
                    yield _sse_line("agent_error", item[1])
                    await future
                    return

                for message in _graph_event_to_sse(item):
                    yield message
                    last_sent = time.monotonic()

            if not drained_any:
                await asyncio.sleep(_POLL_INTERVAL_S)
                # Throttle keepalives instead of sending one per poll.
                if time.monotonic() - last_sent >= _KEEPALIVE_INTERVAL_S:
                    yield ": keepalive\n\n"
                    last_sent = time.monotonic()
    finally:
        # Also runs when the client disconnects and the generator is closed.
        _release_event_queue(event_q)

# ─────────────────────────────────────────────
# FastAPI application
# ─────────────────────────────────────────────

app = FastAPI(
    title="JRXML Agent API",
    version="5.0",
    description="JRXML 报表生成代理 — 前后端分离 API",
)

# Fully open CORS (any origin, credentials allowed). Tolerable for a local
# desktop backend; NOTE(review): restrict origins before exposing on a network.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ─────────────────────────────────────────────
# Health check & config
# ─────────────────────────────────────────────


@app.get("/api/health")
async def health():
    """Liveness probe with the API version and current UTC time."""
    return {
        "status": "ok",
        "version": "5.0",
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }


@app.get("/api/config")
async def config():
    """Expose a fixed allow-list of non-secret environment settings."""
    keys = ("LLM_PROVIDER", "OCR_ENGINE", "EMBEDDING_PROVIDER",
            "MAX_RETRY", "CONTEXT_MAX_TOKENS", "CONTEXT_KEEP_RECENT")
    return {"config": {key: os.getenv(key, "") for key in keys}}
# ─────────────────────────────────────────────
# Session CRUD
# ─────────────────────────────────────────────


@app.post("/api/sessions")
async def create_new_session():
    """Create a new empty session and return its metadata."""
    data = create_session()
    return {
        "session_id": data["session_id"],
        "session_name": data["session_name"],
        "created_at": data["created_at"],
        "updated_at": data["updated_at"],
    }


@app.get("/api/sessions")
async def list_sessions():
    """List all stored sessions as returned by ``list_all_sessions``."""
    return {"sessions": list_all_sessions()}


@app.get("/api/sessions/{session_id}")
async def get_session(session_id: str):
    """Return a session's metadata and persisted agent state; 404 if unknown."""
    data = get_session_state(session_id)
    if data is None:
        raise HTTPException(status_code=404, detail="会话不存在")
    return {
        "session_id": data.get("session_id"),
        "session_name": data.get("session_name"),
        "created_at": data.get("created_at"),
        "updated_at": data.get("updated_at"),
        "agent_state": data.get("agent_state", {}),
    }


@app.delete("/api/sessions/{session_id}")
async def remove_session(session_id: str):
    """Delete a session; 404 when ``delete_session`` reports a failure."""
    ok = delete_session(session_id)
    if not ok:
        raise HTTPException(status_code=404, detail="会话不存在或已删除")
    return {"status": "deleted", "session_id": session_id}

# ─────────────────────────────────────────────
# File upload
# ─────────────────────────────────────────────


def _is_safe_dir_name(name: str) -> bool:
    """Return True if *name* is a single, non-special path component.

    Keeps client-supplied ``session_id`` values such as ``"../x"`` or
    ``"/etc"`` from escaping UPLOADS_DIR.
    """
    return (
        name not in ("", ".", "..")
        and "/" not in name
        and "\\" not in name
        and "\x00" not in name
    )


@app.post("/api/upload")
async def upload_file(file: UploadFile = File(...), session_id: str = ""):
    """Store an uploaded file under UPLOADS_DIR[/session_id] and register it.

    The returned ``file_id`` can be passed to the chat endpoint. The registry
    lives in memory, so ids do not survive a server restart.

    Raises:
        HTTPException(400): if ``session_id`` is not a plain directory name.
    """
    if session_id and not _is_safe_dir_name(session_id):
        raise HTTPException(status_code=400, detail="非法的 session_id")

    file_id = uuid.uuid4().hex[:12]
    # Keep the original file name, minus any directory part sent by the client.
    safe_name = Path(file.filename or "upload.bin").name
    dest = _ensure_upload_dir(session_id) / f"{file_id}_{safe_name}"
    content = await file.read()
    dest.write_bytes(content)

    content_type = file.content_type or mimetypes.guess_type(safe_name)[0] or "application/octet-stream"
    _file_registry[file_id] = {
        "path": str(dest),
        "filename": safe_name,
        "content_type": content_type,
        "size": len(content),
    }
    _api_log.info("文件上传", extra={
        "file_id": file_id,
        "file_name": safe_name,
        "size": len(content),
    })
    return {
        "file_id": file_id,
        "filename": safe_name,
        "content_type": content_type,
        "size": len(content),
    }

# ─────────────────────────────────────────────
# File processing helpers
# ─────────────────────────────────────────────

# Only the first OCR rows of an image are used (presumably to bound prompt size).
_MAX_OCR_ROWS = 30


def _process_files(file_ids: list[str], session_id: str) -> dict:
    """Parse uploaded files, run layout analysis on images, build prompt text.

    Blocking (parsing / OCR); call it off the event loop.

    Returns:
        dict with keys ``full_prompt_prefix`` (str), ``uploaded_paths``
        (list[str]), ``layout_schema`` (dict), ``ocr_text`` (str, rows joined
        by newlines, cells by " | ") and ``ocr_rows`` (list of
        ``{"elements": [{"text": str}, ...]}``, rows with no visible text
        dropped). Layout/OCR values come from the last image that produced
        them.
    """
    result = {"full_prompt_prefix": "", "uploaded_paths": [], "layout_schema": {},
              "ocr_text": "", "ocr_rows": []}
    if not file_ids:
        return result

    parts = []
    for fid in file_ids:
        info = _file_registry.get(fid)
        if not info:
            _api_log.warning("文件ID未注册", extra={"file_id": fid})
            continue

        file_path = info["path"]
        result["uploaded_paths"].append(file_path)
        parsed = parse_file(file_path, Path(info["filename"]).suffix)

        if parsed.get("error"):
            parts.append(f"[文件: {info['filename']}]\n解析失败: {parsed['error']}")
            continue
        parts.append(f"[文件: {info['filename']}]\n{parsed.get('text', '')}")

        # Images additionally go through layout analysis.
        if not (info["content_type"] and info["content_type"].startswith("image/")):
            continue
        layout = analyze_layout(file_path)
        if layout.get("is_a4_template"):
            parts.append(
                f"\n[A4模板布局]\n"
                f"表格行数: {layout.get('total_rows', 0)}, "
                f"总元素: {layout.get('total_elements', 0)}, "
                f"比例: {layout.get('a4_confidence', '')}"
            )
        if layout.get("description"):
            parts.append(f"\n[布局描述]\n{layout['description']}")

        schema = extract_layout_schema(layout)
        if schema and schema.get("total_rows", 0) > 0:
            result["layout_schema"] = schema
            schema_text = schema.get("schema_text", "")
            if schema_text:
                parts.append(f"\n[布局Schema]\n{schema_text}")

        # OCR text: keep both the prompt form and the structured rows, so the
        # rows never need to be re-parsed from the joined string.
        rows = layout.get("rows") or []
        if rows:
            ocr_lines = []
            ocr_rows = []
            for row in rows[:_MAX_OCR_ROWS]:
                texts = [str(e.get("text") or "") for e in row.get("elements", [])]
                ocr_lines.append(" | ".join(texts))
                if any(t.strip() for t in texts):
                    ocr_rows.append({"elements": [{"text": t} for t in texts]})
            result["ocr_text"] = "\n".join(ocr_lines)
            result["ocr_rows"] = ocr_rows
            parts.append(f"\n[OCR 识别文本]\n{result['ocr_text']}")

    result["full_prompt_prefix"] = "\n\n".join(parts)
    return result


def _process_files_in_worker(trace_id: str, file_ids: list[str], session_id: str) -> dict:
    """Run ``_process_files`` in a worker thread with the request's trace id set."""
    set_trace_id(trace_id)
    return _process_files(file_ids, session_id)

# ─────────────────────────────────────────────
# Core: SSE chat endpoint
# ─────────────────────────────────────────────


@app.post("/api/sessions/{session_id}/chat")
async def chat(session_id: str, payload: dict):
    """Send one user message and stream the agent's progress as SSE.

    Body: ``{"text": str, "file_ids": [str, ...]}``. Either may be omitted,
    but not both. Unknown sessions are created on the fly.

    Returns:
        A ``text/event-stream`` response (see the module docstring for events).

    Raises:
        HTTPException(400): wrong body types, or both text and file_ids empty.
    """
    raw_text = payload.get("text") or ""
    file_ids = payload.get("file_ids") or []
    if (not isinstance(raw_text, str) or not isinstance(file_ids, list)
            or not all(isinstance(fid, str) for fid in file_ids)):
        raise HTTPException(status_code=400, detail="text 必须为字符串, file_ids 必须为字符串列表")
    text = raw_text.strip()
    if not text and not file_ids:
        raise HTTPException(status_code=400, detail="text 和 file_ids 均为空")

    # ── Load or create the session ──
    trace_id = generate_trace_id()
    set_trace_id(trace_id)
    data = load_session(session_id)
    if data is None:
        data = create_session(session_id=session_id)
        _api_log.info("自动创建会话", extra={"session_id": session_id, "trace_id": trace_id})

    agent_state: AgentState = data.get("agent_state") or {}
    agent_state["session_id"] = session_id

    # ── Process attachments; parsing/OCR blocks, so keep it off the event loop ──
    file_result = await asyncio.to_thread(_process_files_in_worker, trace_id, file_ids, session_id)
    full_prompt = text
    prefix = file_result["full_prompt_prefix"]
    if prefix:
        full_prompt = f"{prefix}\n\n用户问题: {text}" if text else prefix

    # ── Inject layout schema / OCR rows (used for layered, layout-exact generation) ──
    if file_result.get("layout_schema"):
        agent_state["layout_schema"] = file_result["layout_schema"]
    if file_result.get("ocr_rows"):
        agent_state["ocr_elements"] = file_result["ocr_rows"]
    if file_result.get("uploaded_paths"):
        agent_state["uploaded_file_path"] = file_result["uploaded_paths"][0]

    # ── Per-turn inputs ──
    if agent_state.get("current_jrxml"):
        # A report already exists, so this turn may be a modification request.
        agent_state["user_modification_request"] = full_prompt
    agent_state["user_input"] = full_prompt
    agent_state["retry_count"] = 0

    _api_log.info("对话请求", extra={
        "session_id": session_id,
        "trace_id": trace_id,
        "text_length": len(text),
        "file_count": len(file_ids),
        "prompt_total": len(full_prompt),
    })

    # ── Stream the graph run (the session is saved when the run completes) ──
    return StreamingResponse(
        _sse_generator(agent_state, session_id),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
            "X-Trace-Id": trace_id,
        },
    )

# ─────────────────────────────────────────────
# Downloads
# ─────────────────────────────────────────────


def _jrxml_download(jrxml: str, filename: str):
    """Build an in-memory attachment response for JRXML text.

    Uses the same Content-Disposition rules as Starlette's FileResponse
    (RFC 5987 ``filename*`` when the name needs percent-encoding), without
    leaving a temporary file behind.
    """
    from urllib.parse import quote

    from fastapi.responses import Response

    quoted = quote(filename)
    if quoted != filename:
        disposition = f"attachment; filename*=utf-8''{quoted}"
    else:
        disposition = f'attachment; filename="{filename}"'
    return Response(
        content=jrxml,
        media_type="application/xml",
        headers={"Content-Disposition": disposition},
    )


@app.get("/api/sessions/{session_id}/download/latest")
async def download_latest(session_id: str):
    """Download the session's current JRXML; 404 if the session or JRXML is missing."""
    data = load_session(session_id)
    if data is None:
        raise HTTPException(status_code=404, detail="会话不存在")
    jrxml = (data.get("agent_state") or {}).get("current_jrxml", "")
    if not jrxml:
        raise HTTPException(status_code=404, detail="该会话暂无 JRXML")
    return _jrxml_download(jrxml, f"report_{session_id}.jrxml")


@app.get("/api/sessions/{session_id}/download/{version}")
async def download_version(session_id: str, version: int):
    """Download the JRXML stored at index *version* of ``jrxml_versions``."""
    data = load_session(session_id)
    if data is None:
        raise HTTPException(status_code=404, detail="会话不存在")
    versions = (data.get("agent_state") or {}).get("jrxml_versions", [])
    if version < 0 or version >= len(versions):
        raise HTTPException(status_code=404, detail="版本不存在")
    jrxml = versions[version].get("jrxml", "")
    if not jrxml:
        raise HTTPException(status_code=404, detail="该版本内容为空")
    return _jrxml_download(jrxml, f"report_{session_id}_v{version}.jrxml")

# ─────────────────────────────────────────────
# Download uploaded files
# ─────────────────────────────────────────────


@app.get("/api/files/{file_id}")
async def download_file(file_id: str):
    """Return a previously uploaded file; 404 if unknown or gone from disk."""
    info = _file_registry.get(file_id)
    if not info or not Path(info["path"]).is_file():
        raise HTTPException(status_code=404, detail="文件未找到")
    return FileResponse(info["path"], filename=info["filename"])

# ─────────────────────────────────────────────
# Entry point
# ─────────────────────────────────────────────

if __name__ == "__main__":
    import uvicorn

    host = os.getenv("API_HOST", "0.0.0.0")
    port = int(os.getenv("API_PORT", "8000"))
    # Auto-reload is a development convenience; set API_RELOAD=0 to disable it.
    use_reload = os.getenv("API_RELOAD", "1").strip().lower() not in ("0", "false", "no")
    uvicorn.run("api_server:app", host=host, port=port, reload=use_reload)