Files
panda bd5bfbac2d fix: band-level windowed refine_layout + programmatic map_fields to prevent 91.5% content loss
Root cause: LLM receiving full 34k-char JRXML would regenerate from scratch
instead of modifying coordinates in-place, shrinking output to ~3k chars.

Solution (programmatic node control, not prompt engineering):

- New agent/jrxml_windower.py: decompose JRXML into header (never sent to
  LLM) + individual bands. Split bands >4000 chars at element boundaries.
  Reassemble with element count validation (>10% change = rollback).

- Rewrite refine_layout: per-band windowed LLM processing (~2-4k chars
  each). LLM cannot "reimagine" the entire report.

- Rewrite map_fields: 100% programmatic regex $F{field_N} -> real name
  replacement. Zero LLM calls, zero content loss.

- _sanitize_field_name: non-ASCII chars escaped to _uXXXX_ format for
  valid JRXML identifiers.

- Tests: 48 new unit tests (windower 28 + map_fields 20). All passing.
  Full suite 385 tests, zero regressions.
2026-05-24 08:55:38 +08:00

949 lines
36 KiB
Python

"""JRXML Agent API Server — FastAPI + SSE streaming.
Replaces the Streamlit UI (app.py) with a REST + SSE backend.
The LangGraph agent pipeline is wrapped unchanged.
SSE Event Types:
node_start — 节点开始执行
node_complete — 节点执行完成(含详情)
stream_token — LLM 逐字输出
agent_complete — 全图执行完成
agent_error — 执行异常
Usage:
python -m uvicorn api_server:app --host 0.0.0.0 --port 8000
"""
import asyncio
import base64
import contextvars
import json
import mimetypes
import os
import queue
import tempfile
import time
import traceback
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, UploadFile, File, Form, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, FileResponse
load_dotenv(override=True)
from agent.graph import build_graph
from agent.state import AgentState
from backend.logger import get_logger, generate_trace_id, set_trace_id, get_trace_id
from backend.session import (
create_session,
load_session,
save_session,
list_all_sessions,
delete_session,
get_session_state,
SESSIONS_DIR,
)
from backend.file_parser import parse_file
from backend.layout_analyzer import analyze_layout, extract_layout_schema
from backend.kb_manager import (
create_user, list_users, get_user, delete_user,
create_kb, list_kbs, get_kb, update_kb_meta, delete_kb,
get_kb_raw_dir,
)
from backend.kb_parser import parse_jrxml_fields, build_kb_from_files
from backend.kb_searcher import search_kb, search_templates_in_kb
# ─────────────────────────────────────────────
# 常量(从 app.py 迁移)
# ─────────────────────────────────────────────
# Display label (Chinese, shown to the end user) for each graph node name.
# Used as the "label" field of node_start / node_complete SSE events; callers
# fall back to the raw node name when a node is missing from this table.
NODE_LABELS = {
    "load_session": "加载会话",
    "process_input": "记录输入",
    "manage_context": "管理上下文",
    "save_state_snapshot": "保存快照",
    "classify_intent": "识别意图",
    "retrieve": "检索模板",
    "generate": "生成 JRXML",
    "modify_jrxml": "修改 JRXML",
    "validate": "验证",
    "explain_error": "分析错误",
    "correct_jrxml": "自动修正",
    "finalize": "完成",
    "handle_consult": "咨询回答",
    "handle_undo": "撤销操作",
    "handle_reset": "重置会话",
    "save_session": "保存会话",
    "generate_skeleton": "生成骨架",
    "refine_layout": "精调布局",
    "map_fields": "映射字段",
}
# Display label for each intent identifier; used when summarising the result
# of the classify_intent node.
INTENT_LABELS = {
    "initial_generation": "新建报表",
    "modify_report": "修改报表",
    "preview_report": "预览报表",
    "export_pdf": "导出 PDF",
    "export_jrxml": "下载 JRXML",
    "undo_modification": "撤销修改",
    "consult_question": "咨询问题",
    "reset_session": "重置会话",
}
# Names of bookkeeping nodes.
# NOTE(review): not referenced anywhere in the visible part of this file —
# presumably meant for filtering progress events; confirm it is still needed.
SKIP_NODES = {"load_session", "process_input", "manage_context",
              "save_state_snapshot", "save_session"}
# ─────────────────────────────────────────────
# 日志 & 路径
# ─────────────────────────────────────────────
# Logger used by every endpoint and helper in this module.
_api_log = get_logger("api")
# Root directory for uploaded files; override with the UPLOADS_DIR env var.
UPLOADS_DIR = Path(os.getenv("UPLOADS_DIR", "./uploads"))
# Largest accepted upload. The check happens after the whole body is read.
MAX_UPLOAD_SIZE = 50 * 1024 * 1024  # 50 MB
def _check_session_id(session_id: str) -> None:
    """Reject an invalid session id before it is used to build a path.

    Delegates the actual check to ``backend.session.validate_session_id``.

    Raises:
        HTTPException: status 400 when the validator rejects ``session_id``.
    """
    from backend.session import validate_session_id

    if validate_session_id(session_id):
        return
    raise HTTPException(status_code=400, detail=f"Invalid session_id: {session_id!r}")
# ─────────────────────────────────────────────
# 图编译(全局单例,带 node_start 回调)
# ─────────────────────────────────────────────
# Event queue of the request currently being streamed. A single module-level
# slot is used because the server targets a single-user desktop application;
# a second concurrent chat request would replace the first one's queue.
_current_event_queue: Optional[queue.Queue] = None
# Counter of node_start events, used to number the steps of one run.
# It is a ContextVar, so its value is per execution context.
_step_counter: contextvars.ContextVar[int] = contextvars.ContextVar('_step_counter', default=0)
def _on_node_start(node_name: str):
    """Graph callback: publish a ``node_start`` event for ``node_name``.

    Does nothing when no request is currently streaming. Otherwise the step
    counter is advanced by one and the event, carrying the node name, its
    display label and the new step index, is put on the active queue.
    """
    target = _current_event_queue
    if target is None:
        return
    _step_counter.set(_step_counter.get() + 1)
    payload = {
        "node": node_name,
        "label": NODE_LABELS.get(node_name, node_name),
        "step_index": _step_counter.get(),
    }
    target.put(("node_start", payload))
# Graph instance built once at import time and shared by all requests;
# _on_node_start is registered as its node-start callback.
_graph = build_graph(on_node_start=_on_node_start)
# ─────────────────────────────────────────────
# 文件注册表(内存中,桌面应用级别可接受)
# ─────────────────────────────────────────────
# In-memory index of uploaded files: file_id -> {path, filename, content_type, size}.
# Entries live only as long as the process; nothing here is persisted.
_file_registry: dict[str, dict] = {}
def _ensure_upload_dir(session_id: str = "") -> Path:
    """Create (if missing) and return the upload directory.

    With a non-empty ``session_id`` the directory is the sub-folder of
    ``UPLOADS_DIR`` with that name; otherwise it is ``UPLOADS_DIR`` itself.
    """
    target = UPLOADS_DIR
    if session_id:
        target = UPLOADS_DIR / session_id
    target.mkdir(parents=True, exist_ok=True)
    return target
# ─────────────────────────────────────────────
# SSE 辅助
# ─────────────────────────────────────────────
def _extract_detail(node_name: str, node_state: dict) -> str:
"""从节点状态中提取详情文本(用于 node_complete 事件)。"""
if node_name == "classify_intent":
intent = node_state.get("intent", "")
return f"意图: {INTENT_LABELS.get(intent, intent)}"
elif node_name == "retrieve":
ctx = node_state.get("retrieved_context", "")
return f"找到 {len(ctx)} 字符参考模板" if ctx else "未匹配到模板"
elif node_name in ("generate", "modify_jrxml", "correct_jrxml",
"generate_skeleton", "refine_layout", "map_fields"):
jrxml = node_state.get("current_jrxml", "")
return f"生成 {len(jrxml)} 字符 JRXML"
elif node_name == "validate":
status = node_state.get("status", "")
if status == "pass":
return "验证通过 ✓"
err = node_state.get("error_msg", "")
return f"验证失败: {err[:80]}"
elif node_name == "explain_error":
expl = node_state.get("natural_explanation", "")
return expl[:120]
elif node_name == "handle_consult":
ans = node_state.get("consult_answer", "")
return ans[:150]
return ""
def _run_graph_sync(agent_state: AgentState, event_q: queue.Queue):
    """Run the graph to completion and relay everything it emits to ``event_q``.

    Every item produced by ``_graph.stream`` is forwarded unchanged. Because
    the stream only reports updates, the non-None values of each node's
    update are also folded into ``agent_state`` here. Once the stream ends the
    session is saved (a save failure is logged, not raised) and a ``done``
    marker is queued. Any other exception is reported as an ``error`` marker
    carrying the message and traceback.
    """
    try:
        for emitted in _graph.stream(agent_state, stream_mode=["updates", "custom"]):
            event_q.put(emitted)
            if not (isinstance(emitted, tuple) and len(emitted) == 2):
                continue
            stream_mode, body = emitted
            if stream_mode != "updates" or not isinstance(body, dict):
                continue
            for partial in body.values():
                if not isinstance(partial, dict):
                    continue
                non_null = {key: value for key, value in partial.items() if value is not None}
                agent_state.update(non_null)

        # Persist right away so a broken SSE connection cannot lose the result.
        sid = agent_state.get("session_id", "")
        if sid:
            try:
                save_session(sid, agent_state)
            except Exception as exc:
                _api_log.error("图运行中保存会话失败", extra={
                    "session_id": sid,
                    "error": str(exc),
                    "traceback": traceback.format_exc(),
                })
        event_q.put(("done", {"reason": "graph_completed"}))
    except Exception as exc:
        failure = {
            "error": str(exc),
            "traceback": traceback.format_exc(),
        }
        event_q.put(("error", failure))
async def _sse_generator(agent_state: AgentState, session_id: str = "") -> str:
    """Run the graph in a worker thread and yield its progress as SSE text.

    This is an async generator: each yielded value is one complete SSE
    message (or a ``: keepalive`` comment while the graph is silent).
    NOTE: the ``-> str`` annotation describes the yielded items, not the
    generator object itself.

    Args:
        agent_state: State handed to the graph; the worker updates it in place.
        session_id: When non-empty, the session is saved again just before
            ``agent_complete`` is emitted.

    Yields:
        ``node_start`` / ``node_complete`` / ``stream_token`` messages while
        the graph runs, then exactly one ``agent_complete`` or ``agent_error``.
    """
    global _current_event_queue
    _step_counter.set(0)
    t_start = time.time()
    event_q: queue.Queue = queue.Queue()
    _current_event_queue = event_q
    loop = asyncio.get_running_loop()
    # run_in_executor does not carry context variables into the worker thread.
    # Running the worker inside a copy of the current context makes the
    # counter reset above (and anything else set for this request) visible
    # there, instead of whatever a reused pool thread last held.
    worker_ctx = contextvars.copy_context()
    future = loop.run_in_executor(
        None, worker_ctx.run, _run_graph_sync, agent_state, event_q
    )
    try:
        # Poll the queue without blocking the event loop.
        while True:
            had_events = False
            # Drain everything that is currently queued.
            while True:
                try:
                    item = event_q.get_nowait()
                    had_events = True
                except queue.Empty:
                    break
                kind = item[0]
                if kind == "done":
                    total_ms = round((time.time() - t_start) * 1000)
                    if session_id:
                        try:
                            save_session(session_id, agent_state)
                        except Exception as exc:
                            # The worker already saved once; do not let a
                            # failure here suppress the completion event.
                            _api_log.error("SSE 结束时保存会话失败", extra={
                                "session_id": session_id,
                                "error": str(exc),
                                "traceback": traceback.format_exc(),
                            })
                    versions = agent_state.get("jrxml_versions", [])
                    last_ver = versions[-1] if versions else {}
                    last_failed = last_ver.get("status") == "fail"
                    yield _sse_line("agent_complete", {
                        "reason": "done",
                        "intent": agent_state.get("intent", ""),
                        "status": agent_state.get("status", ""),
                        "jrxml_length": len(agent_state.get("current_jrxml", "")),
                        "error_msg": agent_state.get("error_msg", ""),
                        "natural_explanation": agent_state.get("natural_explanation", ""),
                        "consult_answer": agent_state.get("consult_answer", ""),
                        "retry_count": agent_state.get("retry_count", 0),
                        "total_duration_ms": total_ms,
                        "ocr_extraction_result": agent_state.get("ocr_extraction_result", {}),
                        "versions": len(versions),
                        "has_failed_version": last_failed,
                        "failed_version_index": len(versions) - 1 if last_failed else -1,
                    })
                    await future
                    return
                elif kind == "error":
                    yield _sse_line("agent_error", item[1])
                    await future
                    return
                elif kind == "node_start":
                    yield _sse_line("node_start", item[1])
                else:
                    # A (mode, data) pair straight from graph.stream().
                    mode, data = item
                    if mode == "updates" and isinstance(data, dict):
                        for node_name, node_state in data.items():
                            # A node may report a non-dict update (e.g. None).
                            safe_state = node_state if isinstance(node_state, dict) else {}
                            yield _sse_line("node_complete", {
                                "node": node_name,
                                "label": NODE_LABELS.get(node_name, node_name),
                                "detail": _extract_detail(node_name, safe_state),
                            })
                    elif mode == "custom":
                        cd = data
                        if isinstance(cd, dict) and cd.get("type") == "stream":
                            yield _sse_line("stream_token", {
                                "text": cd.get("text", ""),
                                "type": "stream",
                            })
            if not had_events:
                await asyncio.sleep(0.05)
                yield ": keepalive\n\n"
    finally:
        # Also reached when the client disconnects mid-stream. Release the
        # slot only if it is still ours, so a newer request is not clobbered.
        if _current_event_queue is event_q:
            _current_event_queue = None
def _sse_line(event_type: str, data: dict) -> str:
"""构造单条 SSE 消息。"""
payload = json.dumps(data, ensure_ascii=False)
return f"event: {event_type}\ndata: {payload}\n\n"
# ─────────────────────────────────────────────
# FastAPI 应用
# ─────────────────────────────────────────────
# The FastAPI application; every route in this module is registered on it.
app = FastAPI(
    title="JRXML Agent API",
    version="5.0",
    description="JRXML 报表生成代理 — 前后端分离 API",
)
# CORS is fully open: any origin, method and header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive — confirm this is intended beyond local desktop use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ─────────────────────────────────────────────
# 健康检查 & 配置
# ─────────────────────────────────────────────
@app.get("/api/health")
async def health():
    """Liveness probe: fixed status and version plus the current UTC time."""
    now_utc = datetime.now(timezone.utc)
    return {"status": "ok", "version": "5.0", "timestamp": now_utc.isoformat()}
@app.get("/api/config")
async def config():
    """Return a fixed allow-list of environment settings.

    Only the keys listed below are exposed; an unset variable is reported as
    an empty string.
    """
    exposed_keys = ("LLM_PROVIDER", "OCR_ENGINE", "EMBEDDING_PROVIDER",
                    "MAX_RETRY", "CONTEXT_MAX_TOKENS", "CONTEXT_KEEP_RECENT")
    return {"config": {key: os.getenv(key, "") for key in exposed_keys}}
# ─────────────────────────────────────────────
# 会话 CRUD
# ─────────────────────────────────────────────
@app.post("/api/sessions")
async def create_new_session():
    """Create a session and return its identifying metadata."""
    created = create_session()
    wanted = ("session_id", "session_name", "created_at", "updated_at")
    return {field: created[field] for field in wanted}
@app.get("/api/sessions")
async def list_sessions():
    """List every known session under the ``sessions`` key."""
    all_sessions = list_all_sessions()
    return {"sessions": all_sessions}
@app.get("/api/sessions/{session_id}")
async def get_session(session_id: str):
    """Return one session's metadata together with its agent state.

    Raises:
        HTTPException: 400 for an invalid id, 404 when the session is unknown.
    """
    _check_session_id(session_id)
    record = get_session_state(session_id)
    if record is None:
        raise HTTPException(status_code=404, detail="会话不存在")
    meta_keys = ("session_id", "session_name", "created_at", "updated_at")
    response = {key: record.get(key) for key in meta_keys}
    response["agent_state"] = record.get("agent_state", {})
    return response
@app.delete("/api/sessions/{session_id}")
async def remove_session(session_id: str):
    """Delete a session.

    Raises:
        HTTPException: 400 for an invalid id, 404 when nothing was deleted.
    """
    _check_session_id(session_id)
    if delete_session(session_id):
        return {"status": "deleted", "session_id": session_id}
    raise HTTPException(status_code=404, detail="会话不存在或已删除")
# ─────────────────────────────────────────────
# 用户管理
# ─────────────────────────────────────────────
@app.post("/api/users")
async def create_new_user(payload: dict):
    """Create a user from a JSON body of the form ``{"name": str}``.

    Raises:
        HTTPException: 400 when ``name`` is missing, not a string, or blank;
            500 when ``create_user`` fails.
    """
    raw_name = payload.get("name")
    # A null or non-string name used to raise AttributeError on .strip() and
    # surface as a 500; it is a client error and is reported as one.
    if not isinstance(raw_name, str) or not raw_name.strip():
        raise HTTPException(status_code=400, detail="用户名不能为空")
    try:
        return create_user(raw_name.strip())
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/api/users")
async def list_all_users():
    """List every user under the ``users`` key."""
    everyone = list_users()
    return {"users": everyone}
@app.get("/api/users/{user_id}")
async def get_user_info(user_id: str):
    """Return the stored record of one user, or 404 when it does not exist."""
    record = get_user(user_id)
    if record is not None:
        return record
    raise HTTPException(status_code=404, detail="用户不存在")
@app.delete("/api/users/{user_id}")
async def remove_user(user_id: str):
    """Delete a user, answering 404 when ``delete_user`` reports no deletion."""
    if delete_user(user_id):
        return {"status": "deleted", "user_id": user_id}
    raise HTTPException(status_code=404, detail="用户不存在")
# ─────────────────────────────────────────────
# 知识库 CRUD
# ─────────────────────────────────────────────
@app.get("/api/users/{user_id}/kbs")
async def list_user_kbs(user_id: str):
    """List the knowledge bases of ``user_id`` under the ``kbs`` key."""
    owned = list_kbs(user_id)
    return {"kbs": owned}
@app.post("/api/users/{user_id}/kbs")
async def create_user_kb(user_id: str, payload: dict):
    """Create a knowledge base for ``user_id``.

    Body: ``{"name": str, "description": str (optional)}``.

    Raises:
        HTTPException: 400 when ``name`` is missing, not a string, or blank,
            or when ``create_kb`` raises ``ValueError``; 500 for any other
            failure.
    """
    raw_name = payload.get("name")
    # A null or non-string name used to crash on .strip() and return a 500.
    if not isinstance(raw_name, str) or not raw_name.strip():
        raise HTTPException(status_code=400, detail="知识库名称不能为空")
    # An explicit null description is treated like an omitted one.
    description = payload.get("description") or ""
    try:
        return create_kb(user_id, raw_name.strip(), description)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/api/kbs/{kb_id}")
async def get_kb_info(kb_id: str):
    """Return the stored record of one knowledge base, or 404 if unknown."""
    record = get_kb(kb_id)
    if record is not None:
        return record
    raise HTTPException(status_code=404, detail="知识库不存在")
@app.delete("/api/kbs/{kb_id}")
async def remove_kb(kb_id: str):
    """Delete a knowledge base, answering 404 when nothing was deleted."""
    if delete_kb(kb_id):
        return {"status": "deleted", "kb_id": kb_id}
    raise HTTPException(status_code=404, detail="知识库不存在")
# ─────────────────────────────────────────────
# 知识库文件上传
# ─────────────────────────────────────────────
@app.post("/api/kbs/{kb_id}/upload")
async def upload_to_kb(kb_id: str, file: UploadFile = File(...)):
    """Store an uploaded file in a knowledge base's raw directory and process it.

    Returns:
        ``{"filename", "type", "error"}`` where ``type`` and ``error`` come
        from ``process_file_for_kb``.

    Raises:
        HTTPException: 404 for an unknown knowledge base, 500 when it has no
            raw directory, 400 for an unusable file name, 413 when the file
            exceeds ``MAX_UPLOAD_SIZE``.
    """
    kb = get_kb(kb_id)
    if kb is None:
        raise HTTPException(status_code=404, detail="知识库不存在")
    raw_dir = get_kb_raw_dir(kb_id)
    if raw_dir is None:
        raise HTTPException(status_code=500, detail="知识库存储目录不可用")

    # Keep only the final path component of the client-supplied name.
    safe_name = Path(file.filename or "upload").name
    # "." collapses to "" and ".." survives .name; either would make the
    # destination a directory, so refuse them instead of failing on write.
    if safe_name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="非法文件名")

    content = await file.read()
    if len(content) > MAX_UPLOAD_SIZE:
        raise HTTPException(status_code=413, detail="文件大小超过 50MB 上限")

    raw_dir.mkdir(parents=True, exist_ok=True)
    dest = raw_dir / safe_name
    dest.write_bytes(content)

    from backend.kb_parser import process_file_for_kb
    result = process_file_for_kb(kb_id, str(dest), source_name=safe_name)
    _api_log.info("KB文件上传", extra={
        "kb_id": kb_id, "file": safe_name, "type": result.get("type"),
    })
    return {
        "filename": safe_name,
        "type": result.get("type", ""),
        "error": result.get("error"),
    }
@app.post("/api/kbs/{kb_id}/build")
async def build_kb(kb_id: str):
    """Build the knowledge base from the files already uploaded to it.

    Every regular file in the raw directory is handed to
    ``build_kb_from_files``, whose result is returned unchanged.

    Raises:
        HTTPException: 404 when the raw directory is missing, 400 when it
            holds no files.
    """
    from backend.kb_parser import build_kb_from_files as build_fn

    raw_dir = get_kb_raw_dir(kb_id)
    if raw_dir is None or not raw_dir.exists():
        raise HTTPException(status_code=404, detail="知识库无已上传文件")
    uploaded = []
    for entry in raw_dir.iterdir():
        if entry.is_file():
            uploaded.append(str(entry))
    if not uploaded:
        raise HTTPException(status_code=400, detail="知识库无文件,请先上传")
    return build_fn(kb_id, uploaded)
@app.get("/api/kbs/{kb_id}/status")
async def kb_status(kb_id: str):
    """Summarise a knowledge base: name, counts, parse status, creation time.

    Raises:
        HTTPException: 404 when the knowledge base does not exist.
    """
    record = get_kb(kb_id)
    if record is None:
        raise HTTPException(status_code=404, detail="知识库不存在")
    summary = {"kb_id": kb_id, "name": record.get("name", "")}
    summary["field_count"] = len(record.get("fields", []))
    summary["template_count"] = len(record.get("templates", []))
    summary["file_count"] = record.get("file_count", 0)
    summary["chunk_count"] = record.get("chunk_count", 0)
    summary["parse_status"] = record.get("parse_status", "empty")
    summary["created_at"] = record.get("created_at", "")
    return summary
@app.get("/api/kbs/{kb_id}/fields")
async def kb_fields(kb_id: str):
    """Return the ``fields`` and ``templates`` lists of a knowledge base.

    Raises:
        HTTPException: 404 when the knowledge base does not exist.
    """
    record = get_kb(kb_id)
    if record is None:
        raise HTTPException(status_code=404, detail="知识库不存在")
    field_list = record.get("fields", [])
    template_list = record.get("templates", [])
    return {"fields": field_list, "templates": template_list}
@app.get("/api/kbs/{kb_id}/search")
async def kb_search(kb_id: str, q: str = "", type: str = ""):
    """Search a knowledge base for ``q`` (top 5).

    With ``type == "template"`` the template search is used and its output is
    returned under ``results``; any other ``type`` runs the general search and
    returns its output under ``context``.

    Raises:
        HTTPException: 400 when ``q`` is empty.
    """
    if not q:
        raise HTTPException(status_code=400, detail="查询参数 q 不能为空")
    if type != "template":
        return {"query": q, "context": search_kb(kb_id, q, k=5)}
    return {"query": q, "results": search_templates_in_kb(kb_id, q, k=5)}
# ─────────────────────────────────────────────
# 会话-知识库绑定
# ─────────────────────────────────────────────
@app.put("/api/sessions/{session_id}/kb")
async def bind_session_kb(session_id: str, payload: dict):
    """Bind a knowledge base to a session, or unbind the current one.

    Body: ``{"kb_id": str | null}``. A missing, null or blank ``kb_id``
    removes the binding; otherwise the knowledge base id and its field list
    are copied into the session's agent state.

    Raises:
        HTTPException: 400 for an invalid session id or a non-string
            ``kb_id``; 404 when the session or the knowledge base is unknown.
    """
    _check_session_id(session_id)
    raw_kb_id = payload.get("kb_id")
    # {"kb_id": null} is the natural way to unbind; it used to crash on
    # .strip() and return a 500.
    if raw_kb_id is None:
        raw_kb_id = ""
    if not isinstance(raw_kb_id, str):
        raise HTTPException(status_code=400, detail="kb_id 必须是字符串")
    kb_id = raw_kb_id.strip()

    data = load_session(session_id)
    if data is None:
        raise HTTPException(status_code=404, detail="会话不存在")
    agent_state = data.get("agent_state", {})
    if kb_id:
        kb = get_kb(kb_id)
        if kb is None:
            raise HTTPException(status_code=404, detail="知识库不存在")
        agent_state["kb_id"] = kb_id
        agent_state["kb_fields"] = kb.get("fields", [])
    else:
        agent_state.pop("kb_id", None)
        agent_state.pop("kb_fields", None)
    save_session(session_id, agent_state)
    return {"session_id": session_id, "kb_id": kb_id or None}
@app.get("/api/sessions/{session_id}/kb")
async def get_session_kb(session_id: str):
    """Report which knowledge base, if any, is bound to a session.

    The response always carries ``kb_id`` and ``kb_fields`` from the session
    state; ``kb_name`` and ``templates`` are added when the bound knowledge
    base can still be loaded.

    Raises:
        HTTPException: 400 for an invalid id, 404 when the session is unknown.
    """
    _check_session_id(session_id)
    stored = load_session(session_id)
    if stored is None:
        raise HTTPException(status_code=404, detail="会话不存在")
    state = stored.get("agent_state", {})
    bound_id = state.get("kb_id", "")
    response = {"kb_id": bound_id, "kb_fields": state.get("kb_fields", [])}
    kb_record = get_kb(bound_id) if bound_id else None
    if kb_record:
        response["kb_name"] = kb_record.get("name", "")
        response["templates"] = kb_record.get("templates", [])
    return response
# ─────────────────────────────────────────────
# 文件上传
# ─────────────────────────────────────────────
@app.post("/api/upload")
async def upload_file(file: UploadFile = File(...), session_id: str = ""):
    """Store an uploaded file and register it under a new ``file_id``.

    The file is written as ``<file_id>_<original name>`` in the session's
    upload directory (or the root upload directory when ``session_id`` is
    empty) and recorded in the in-memory file registry.

    Returns:
        ``{"file_id", "filename", "content_type", "size"}``.

    Raises:
        HTTPException: 400 for an invalid ``session_id``; 413 when the file
            exceeds ``MAX_UPLOAD_SIZE``.
    """
    if session_id:
        _check_session_id(session_id)
    file_id = uuid.uuid4().hex[:12]
    # Keep the original file name, minus any directory part.
    safe_name = Path(file.filename or "upload.bin").name
    content = await file.read()
    if len(content) > MAX_UPLOAD_SIZE:
        raise HTTPException(status_code=413, detail="文件大小超过 50MB 上限")
    # Create the directory once, and only for uploads that are accepted.
    dest = _ensure_upload_dir(session_id) / f"{file_id}_{safe_name}"
    dest.write_bytes(content)
    content_type = (
        file.content_type
        or mimetypes.guess_type(safe_name)[0]
        or "application/octet-stream"
    )
    size = len(content)
    _file_registry[file_id] = {
        "path": str(dest),
        "filename": safe_name,
        "content_type": content_type,
        "size": size,
    }
    _api_log.info("文件上传", extra={
        "file_id": file_id, "file_name": safe_name, "size": size,
    })
    return {
        "file_id": file_id,
        "filename": safe_name,
        "content_type": content_type,
        "size": size,
    }
# ─────────────────────────────────────────────
# 文件处理辅助
# ─────────────────────────────────────────────
def _parse_jrxml_file(file_path: str) -> dict:
    """Parse an uploaded JRXML file into a template description.

    The structured parts come from ``parse_jrxml_fields``; the raw XML is read
    separately as UTF-8 and becomes ``""`` if that read fails for any reason.

    Returns:
        A dict with keys ``jrxml_text``, ``parameters``, ``fields``, ``query``,
        ``report_name``, ``page_width``, ``page_height`` and ``error`` (the
        parser's error value, ``None`` when it reported none).
    """
    parsed = parse_jrxml_fields(file_path)
    try:
        xml_source = Path(file_path).read_text(encoding="utf-8")
    except Exception:
        xml_source = ""

    template = {"jrxml_text": xml_source}
    template["parameters"] = parsed.get("parameters", [])
    template["fields"] = parsed.get("fields", [])
    for text_key in ("query", "report_name", "page_width", "page_height"):
        template[text_key] = parsed.get(text_key, "")
    template["error"] = parsed.get("error")
    return template
def _process_files(file_ids: list[str], session_id: str) -> dict:
    """Turn uploaded files into prompt text plus structured side data.

    Each registered file is handled by type: a ``.jrxml`` file is parsed as a
    template and described in the prompt; anything else goes through
    ``parse_file``, and image files are additionally run through layout
    analysis. Unregistered ids are logged and skipped.

    ``session_id`` is accepted but not used by this function.

    Returns:
        {full_prompt_prefix, uploaded_paths, layout_schema, ocr_text,
         jrxml_template: dict | None}
        When several files supply a template, layout schema or OCR text, the
        value from the last such file wins.
    """
    if not file_ids:
        return {"full_prompt_prefix": "", "uploaded_paths": [],
                "layout_schema": {}, "ocr_text": "", "jrxml_template": None}
    parts = []
    uploaded_paths = []
    layout_schema = {}
    ocr_text = ""
    jrxml_template = None
    for fid in file_ids:
        info = _file_registry.get(fid)
        if not info:
            _api_log.warning("文件ID未注册", extra={"file_id": fid})
            continue
        file_path = info["path"]
        uploaded_paths.append(file_path)
        suffix = Path(info["filename"]).suffix.lower()
        # JRXML file -> parse it as a template and describe it in the prompt.
        if suffix == ".jrxml":
            jrxml_template = _parse_jrxml_file(file_path)
            if jrxml_template.get("error"):
                parts.append(f"[JRXML 模板: {info['filename']}]\n解析失败: {jrxml_template['error']}")
            else:
                params = jrxml_template["parameters"]
                fields = jrxml_template["fields"]
                param_desc = "\n".join(
                    f" - {p['name']} ({p.get('type', 'String')})" for p in params
                ) if params else " (无参数)"
                field_desc = "\n".join(
                    f" - {f['name']} ({f.get('type', 'String')})" for f in fields
                ) if fields else " (无字段)"
                parts.append(
                    f"[上传的 JRXML 模板: {jrxml_template['report_name'] or info['filename']}]\n"
                    f"页面尺寸: {jrxml_template['page_width']}x{jrxml_template['page_height']}\n"
                    f"参数列表:\n{param_desc}\n"
                    f"字段列表:\n{field_desc}\n"
                    f"SQL查询: {jrxml_template['query'] or '(无)'}\n"
                    f"--- XML 内容 ---\n{jrxml_template['jrxml_text']}"
                )
            # JRXML files skip the generic parser and the layout analysis.
            continue
        parsed = parse_file(file_path, suffix)
        if parsed.get("error"):
            parts.append(f"[文件: {info['filename']}]\n解析失败: {parsed['error']}")
            continue
        parts.append(f"[文件: {info['filename']}]\n{parsed['text']}")
        # Image file -> layout analysis.
        if info["content_type"] and info["content_type"].startswith("image/"):
            layout = analyze_layout(file_path)
            if layout.get("is_a4_template"):
                parts.append(
                    f"\n[A4模板布局]\n"
                    f"表格行数: {layout.get('total_rows', 0)}, "
                    f"总元素: {layout.get('total_elements', 0)}, "
                    f"比例: {layout.get('a4_confidence', '')}"
                )
            if layout.get("description"):
                parts.append(f"\n[布局描述]\n{layout['description']}")
            schema = extract_layout_schema(layout)
            # Only a schema that actually contains rows replaces the current one.
            if schema and schema.get("total_rows", 0) > 0:
                layout_schema = schema
                schema_text = schema.get("schema_text", "")
                if schema_text:
                    parts.append(f"\n[布局Schema]\n{schema_text}")
            # OCR text: one line per row, element texts joined by " | ".
            ocr_elements = layout.get("rows", [])
            if ocr_elements:
                ocr_lines = []
                # Only the first 30 rows are included.
                for row in ocr_elements[:30]:
                    texts = [e.get("text", "") for e in row.get("elements", [])]
                    ocr_lines.append(" | ".join(texts))
                ocr_text = "\n".join(ocr_lines)
                if ocr_lines:
                    parts.append(f"\n[OCR 识别文本]\n{ocr_text}")
    return {
        "full_prompt_prefix": "\n\n".join(parts) if parts else "",
        "uploaded_paths": uploaded_paths,
        "layout_schema": layout_schema,
        "ocr_text": ocr_text,
        "jrxml_template": jrxml_template,
    }
# ─────────────────────────────────────────────
# 核心:SSE 聊天端点
# ─────────────────────────────────────────────
@app.post("/api/sessions/{session_id}/chat")
async def chat(session_id: str, payload: dict):
    """Send a message to the agent and stream its progress back as SSE.

    Body:
        {text: str, file_ids: [str, ...]}
        Either may be omitted or null, but not both empty.

    Returns:
        text/event-stream (SSE); the trace id is echoed in ``X-Trace-Id``.

    Raises:
        HTTPException: 400 for an invalid session id, a ``text`` that is not a
            string, a ``file_ids`` that is not a list of strings, or when both
            are empty.
    """
    _check_session_id(session_id)

    # Validate the body up front: a null/non-string text used to crash on
    # .strip(), and a string passed as file_ids was iterated char by char.
    raw_text = payload.get("text")
    if raw_text is None:
        raw_text = ""
    if not isinstance(raw_text, str):
        raise HTTPException(status_code=400, detail="text 必须是字符串")
    text = raw_text.strip()
    file_ids = payload.get("file_ids")
    if file_ids is None:
        file_ids = []
    if not isinstance(file_ids, list) or not all(isinstance(f, str) for f in file_ids):
        raise HTTPException(status_code=400, detail="file_ids 必须是字符串数组")
    if not text and not file_ids:
        raise HTTPException(status_code=400, detail="text 和 file_ids 均为空")

    # ── Load the session, creating it on first use ──
    trace_id = generate_trace_id()
    set_trace_id(trace_id)
    data = load_session(session_id)
    if data is None:
        data = create_session(session_id=session_id)
        _api_log.info("自动创建会话", extra={"session_id": session_id, "trace_id": trace_id})
    agent_state: AgentState = data.get("agent_state", {})
    agent_state["session_id"] = session_id

    # ── Process attachments and build the prompt ──
    file_result = _process_files(file_ids, session_id)
    full_prompt = text
    prefix = file_result["full_prompt_prefix"]
    if prefix:
        full_prompt = f"{prefix}\n\n用户问题: {text}" if text else prefix

    # ── Inject the layout schema (used for layered, precise generation) ──
    if file_result.get("layout_schema"):
        agent_state["layout_schema"] = file_result["layout_schema"]
    if file_result.get("ocr_text"):
        # Rebuild row/element structure from the " | "-joined OCR lines.
        ocr_rows = [{"elements": [{"text": t} for t in line.split(" | ")]}
                    for line in file_result["ocr_text"].split("\n") if line.strip()]
        if ocr_rows:
            agent_state["ocr_elements"] = ocr_rows
    if file_result.get("uploaded_paths"):
        agent_state["uploaded_file_path"] = file_result["uploaded_paths"][0]

    # ── Inject a JRXML template uploaded within the conversation ──
    jrxml_tmpl = file_result.get("jrxml_template")
    if jrxml_tmpl and not jrxml_tmpl.get("error"):
        agent_state["uploaded_template_jrxml"] = jrxml_tmpl["jrxml_text"]
        agent_state["uploaded_template_params"] = jrxml_tmpl["parameters"]

    # ── Set this turn's input ──
    if agent_state.get("current_jrxml"):
        # A report already exists, so the prompt is also a modification request.
        agent_state["user_modification_request"] = full_prompt
    agent_state["user_input"] = full_prompt
    agent_state["retry_count"] = 0
    _api_log.info("对话请求", extra={
        "session_id": session_id,
        "trace_id": trace_id,
        "text_length": len(text),
        "file_count": len(file_ids),
        "prompt_total": len(full_prompt),
    })

    # ── Return the SSE stream ──
    async def stream_and_save():
        # Attachments were already processed above; report that step first.
        if file_ids:
            yield _sse_line("node_start", {
                "node": "process_attachments",
                "label": "正在处理附件",
            })
            yield _sse_line("node_complete", {
                "node": "process_attachments",
                "label": "正在处理附件",
                "detail": f"已解析 {len(file_ids)} 个文件",
            })
        async for sse_chunk in _sse_generator(agent_state, session_id):
            yield sse_chunk

    return StreamingResponse(
        stream_and_save(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Ask nginx-style reverse proxies not to buffer the stream.
            "X-Accel-Buffering": "no",
            "X-Trace-Id": trace_id,
        },
    )
# ─────────────────────────────────────────────
# 下载
# ─────────────────────────────────────────────
def _jrxml_file_response(jrxml: str, download_name: str,
                         background_tasks: BackgroundTasks) -> FileResponse:
    """Write ``jrxml`` to a temporary file and return it as a download.

    The temporary file is removed by a background task after the response,
    or immediately if writing it fails.
    """
    tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".jrxml", delete=False,
                                      encoding="utf-8")
    try:
        # The context manager closes the handle even when write() raises.
        with tmp:
            tmp.write(jrxml)
    except Exception:
        Path(tmp.name).unlink(missing_ok=True)
        raise
    background_tasks.add_task(os.unlink, tmp.name)
    return FileResponse(
        tmp.name,
        media_type="application/xml",
        filename=download_name,
    )


@app.get("/api/sessions/{session_id}/download/latest")
async def download_latest(session_id: str, background_tasks: BackgroundTasks):
    """Download the session's current JRXML.

    Raises:
        HTTPException: 400 for an invalid id; 404 when the session is unknown
            or has no JRXML yet.
    """
    _check_session_id(session_id)
    data = load_session(session_id)
    if data is None:
        raise HTTPException(status_code=404, detail="会话不存在")
    jrxml = data.get("agent_state", {}).get("current_jrxml", "")
    if not jrxml:
        raise HTTPException(status_code=404, detail="该会话暂无 JRXML")
    return _jrxml_file_response(jrxml, f"report_{session_id}.jrxml", background_tasks)


# Must stay registered after ".../download/latest" so that "latest" is not
# matched as a (non-integer) version.
@app.get("/api/sessions/{session_id}/download/{version}")
async def download_version(session_id: str, version: int, background_tasks: BackgroundTasks):
    """Download one stored JRXML version, addressed by zero-based index.

    Raises:
        HTTPException: 400 for an invalid id; 404 when the session is unknown,
            the index is out of range, or that version has no content.
    """
    _check_session_id(session_id)
    data = load_session(session_id)
    if data is None:
        raise HTTPException(status_code=404, detail="会话不存在")
    versions = data.get("agent_state", {}).get("jrxml_versions", [])
    if version < 0 or version >= len(versions):
        raise HTTPException(status_code=404, detail="版本不存在")
    jrxml = versions[version].get("jrxml", "")
    if not jrxml:
        raise HTTPException(status_code=404, detail="该版本内容为空")
    return _jrxml_file_response(
        jrxml, f"report_{session_id}_v{version}.jrxml", background_tasks
    )
# ─────────────────────────────────────────────
# 下载上传文件
# ─────────────────────────────────────────────
@app.get("/api/files/{file_id}")
async def download_file(file_id: str):
    """Download a previously uploaded file by its ``file_id``.

    Raises:
        HTTPException: 404 when the id is not registered or the registered
            file is no longer present on disk.
    """
    info = _file_registry.get(file_id)
    # The registry can outlive the file it points at; answer 404 rather than
    # letting the response fail while it is being sent.
    if not info or not Path(info["path"]).is_file():
        raise HTTPException(status_code=404, detail="文件未找到")
    return FileResponse(info["path"], filename=info["filename"])
# ─────────────────────────────────────────────
# 启动入口
# ─────────────────────────────────────────────
if __name__ == "__main__":
    # Direct-run entry point: serve on all interfaces, port from API_PORT
    # (default 8000), without auto-reload.
    import uvicorn

    listen_port = int(os.getenv("API_PORT", "8000"))
    uvicorn.run("api_server:app", host="0.0.0.0", port=listen_port, reload=False)