bd5bfbac2d
Root cause: LLM receiving full 34k-char JRXML would regenerate from scratch
instead of modifying coordinates in-place, shrinking output to ~3k chars.
Solution (programmatic node control, not prompt engineering):
- New agent/jrxml_windower.py: decompose JRXML into header (never sent to
LLM) + individual bands. Split bands >4000 chars at element boundaries.
Reassemble with element count validation (>10% change = rollback).
- Rewrite refine_layout: per-band windowed LLM processing (~2-4k chars
each). LLM cannot "reimagine" the entire report.
- Rewrite map_fields: 100% programmatic regex $F{field_N} -> real name
replacement. Zero LLM calls, zero content loss.
- _sanitize_field_name: non-ASCII chars escaped to _uXXXX_ format for
valid JRXML identifiers.
- Tests: 48 new unit tests (windower 28 + map_fields 20). All passing.
Full suite 385 tests, zero regressions.
949 lines
36 KiB
Python
949 lines
36 KiB
Python
"""JRXML Agent API Server — FastAPI + SSE streaming.
|
|
|
|
Replaces the Streamlit UI (app.py) with a REST + SSE backend.
The LangGraph agent pipeline is wrapped unchanged.

SSE Event Types:
    node_start      — a graph node started executing
    node_complete   — a graph node finished (with a detail string)
    stream_token    — incremental LLM output
    agent_complete  — the whole graph run finished
    agent_error     — the run raised an exception

Usage:
    python -m uvicorn api_server:app --host 0.0.0.0 --port 8000
|
|
"""
|
|
|
|
import asyncio
import base64
import contextvars
import json
import mimetypes
import os
import queue
import tempfile
import time
import traceback
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import AsyncIterator, Optional

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, UploadFile, File, Form, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, FileResponse
|
|
|
|
load_dotenv(override=True)
|
|
|
|
from agent.graph import build_graph
|
|
from agent.state import AgentState
|
|
from backend.logger import get_logger, generate_trace_id, set_trace_id, get_trace_id
|
|
from backend.session import (
|
|
create_session,
|
|
load_session,
|
|
save_session,
|
|
list_all_sessions,
|
|
delete_session,
|
|
get_session_state,
|
|
SESSIONS_DIR,
|
|
)
|
|
from backend.file_parser import parse_file
|
|
from backend.layout_analyzer import analyze_layout, extract_layout_schema
|
|
from backend.kb_manager import (
|
|
create_user, list_users, get_user, delete_user,
|
|
create_kb, list_kbs, get_kb, update_kb_meta, delete_kb,
|
|
get_kb_raw_dir,
|
|
)
|
|
from backend.kb_parser import parse_jrxml_fields, build_kb_from_files
|
|
from backend.kb_searcher import search_kb, search_templates_in_kb
|
|
|
|
# ─────────────────────────────────────────────
|
|
# 常量(从 app.py 迁移)
|
|
# ─────────────────────────────────────────────
|
|
|
|
# Display label for each graph node, keyed by node name. Used as the "label"
# field of node_start / node_complete events; unknown names fall back to the
# node name itself at the lookup sites.
NODE_LABELS = {
    "load_session": "加载会话",
    "process_input": "记录输入",
    "manage_context": "管理上下文",
    "save_state_snapshot": "保存快照",
    "classify_intent": "识别意图",
    "retrieve": "检索模板",
    "generate": "生成 JRXML",
    "modify_jrxml": "修改 JRXML",
    "validate": "验证",
    "explain_error": "分析错误",
    "correct_jrxml": "自动修正",
    "finalize": "完成",
    "handle_consult": "咨询回答",
    "handle_undo": "撤销操作",
    "handle_reset": "重置会话",
    "save_session": "保存会话",
    "generate_skeleton": "生成骨架",
    "refine_layout": "精调布局",
    "map_fields": "映射字段",
}

# Display label for each intent name; used when describing the result of the
# classify_intent node.
INTENT_LABELS = {
    "initial_generation": "新建报表",
    "modify_report": "修改报表",
    "preview_report": "预览报表",
    "export_pdf": "导出 PDF",
    "export_jrxml": "下载 JRXML",
    "undo_modification": "撤销修改",
    "consult_question": "咨询问题",
    "reset_session": "重置会话",
}

# Names of bookkeeping nodes.
# NOTE(review): not referenced anywhere in the visible part of this file —
# presumably meant to hide these nodes from the progress display; confirm.
SKIP_NODES = {"load_session", "process_input", "manage_context",
              "save_state_snapshot", "save_session"}
|
|
|
|
# ─────────────────────────────────────────────
|
|
# 日志 & 路径
|
|
# ─────────────────────────────────────────────
|
|
|
|
# Logger used by this module, obtained under the name "api".
_api_log = get_logger("api")
# Root directory for uploaded files; overridable through the UPLOADS_DIR
# environment variable.
UPLOADS_DIR = Path(os.getenv("UPLOADS_DIR", "./uploads"))
# Largest upload accepted by the upload endpoints, in bytes.
MAX_UPLOAD_SIZE = 50 * 1024 * 1024  # 50 MB
|
|
|
|
def _check_session_id(session_id: str) -> None:
    """Reject an invalid ``session_id`` with HTTP 400 (path-traversal guard).

    The actual check is delegated to ``backend.session.validate_session_id``.

    Raises:
        HTTPException: status 400 when the id does not validate.
    """
    from backend.session import validate_session_id

    if validate_session_id(session_id):
        return
    raise HTTPException(status_code=400, detail=f"Invalid session_id: {session_id!r}")
|
|
|
|
# ─────────────────────────────────────────────
|
|
# 图编译(全局单例,带 node_start 回调)
|
|
# ─────────────────────────────────────────────
|
|
|
|
# Event queue of the request currently being served (single-user desktop app).
# It is a plain module global, so only one streaming request can own it at a time.
_current_event_queue: Optional[queue.Queue] = None
# Running count of node_start events; its value is sent as "step_index".
_step_counter: contextvars.ContextVar[int] = contextvars.ContextVar('_step_counter', default=0)
|
|
|
|
|
|
def _on_node_start(node_name: str):
    """Global node_start callback: push a ``node_start`` event onto the active queue.

    Does nothing when no request queue is currently registered. Otherwise the
    step counter is incremented and the event carries the node name, its
    display label (falling back to the name) and the new step index.
    """
    target = _current_event_queue
    if target is None:
        return

    step = _step_counter.get() + 1
    _step_counter.set(step)
    payload = {
        "node": node_name,
        "label": NODE_LABELS.get(node_name, node_name),
        "step_index": step,
    }
    target.put(("node_start", payload))
|
|
|
|
|
|
# The agent graph, built once when this module is imported, with
# _on_node_start registered as its node-start callback.
_graph = build_graph(on_node_start=_on_node_start)

# ─────────────────────────────────────────────
# File registry (in memory; acceptable for a desktop-scale application)
# ─────────────────────────────────────────────

# Entries are added by the /api/upload handler and are not persisted.
_file_registry: dict[str, dict] = {}  # file_id → {path, filename, content_type, size}
|
|
|
|
|
|
def _ensure_upload_dir(session_id: str = "") -> Path:
    """Return the upload directory, creating it (and any parents) if missing.

    With a non-empty ``session_id`` the directory is that session's
    sub-directory of ``UPLOADS_DIR``; otherwise it is ``UPLOADS_DIR`` itself.
    """
    target = UPLOADS_DIR
    if session_id:
        target = target / session_id
    target.mkdir(parents=True, exist_ok=True)
    return target
|
|
|
|
|
|
# ─────────────────────────────────────────────
|
|
# SSE 辅助
|
|
# ─────────────────────────────────────────────
|
|
|
|
def _extract_detail(node_name: str, node_state: dict) -> str:
|
|
"""从节点状态中提取详情文本(用于 node_complete 事件)。"""
|
|
if node_name == "classify_intent":
|
|
intent = node_state.get("intent", "")
|
|
return f"意图: {INTENT_LABELS.get(intent, intent)}"
|
|
elif node_name == "retrieve":
|
|
ctx = node_state.get("retrieved_context", "")
|
|
return f"找到 {len(ctx)} 字符参考模板" if ctx else "未匹配到模板"
|
|
elif node_name in ("generate", "modify_jrxml", "correct_jrxml",
|
|
"generate_skeleton", "refine_layout", "map_fields"):
|
|
jrxml = node_state.get("current_jrxml", "")
|
|
return f"生成 {len(jrxml)} 字符 JRXML"
|
|
elif node_name == "validate":
|
|
status = node_state.get("status", "")
|
|
if status == "pass":
|
|
return "验证通过 ✓"
|
|
err = node_state.get("error_msg", "")
|
|
return f"验证失败: {err[:80]}"
|
|
elif node_name == "explain_error":
|
|
expl = node_state.get("natural_explanation", "")
|
|
return expl[:120]
|
|
elif node_name == "handle_consult":
|
|
ans = node_state.get("consult_answer", "")
|
|
return ans[:150]
|
|
return ""
|
|
|
|
|
|
def _run_graph_sync(agent_state: AgentState, event_q: queue.Queue):
    """Run ``graph.stream()`` (meant for a background thread) and forward every event to ``event_q``.

    ``graph.stream()`` only yields events and does not modify the state that
    was passed in, so each node's returned update is merged into
    ``agent_state`` by hand (values that are ``None`` are skipped).

    After the stream ends the session is saved immediately, so that an
    interrupted SSE connection does not lose data; a save failure is logged
    and does not stop the run. The queue always receives a final
    ``("done", ...)`` or ``("error", ...)`` item.
    """
    try:
        for event in _graph.stream(agent_state, stream_mode=["updates", "custom"]):
            event_q.put(event)

            # Only ("updates", {node: update}) events carry state to merge.
            if not (isinstance(event, tuple) and len(event) == 2):
                continue
            mode, data = event
            if mode != "updates" or not isinstance(data, dict):
                continue
            for update in data.values():
                if not isinstance(update, dict):
                    continue
                for key, value in update.items():
                    if value is not None:
                        agent_state[key] = value

        # Persist as soon as the graph is done, before the consumer is told.
        session_key = agent_state.get("session_id", "")
        if session_key:
            try:
                save_session(session_key, agent_state)
            except Exception as exc:
                failure = {
                    "session_id": session_key,
                    "error": str(exc),
                    "traceback": traceback.format_exc(),
                }
                _api_log.error("图运行中保存会话失败", extra=failure)

        event_q.put(("done", {"reason": "graph_completed"}))
    except Exception as exc:
        problem = {"error": str(exc), "traceback": traceback.format_exc()}
        event_q.put(("error", problem))
|
|
|
|
|
|
async def _sse_generator(agent_state: AgentState, session_id: str = "") -> AsyncIterator[str]:
    """Run the graph in a worker thread and yield its events as SSE strings.

    The graph runs through ``_run_graph_sync`` in the default executor and
    communicates through a ``queue.Queue`` that is polled here, so the event
    loop is never blocked while waiting for the next event.

    Args:
        agent_state: State handed to the graph; read again when the run ends
            to build the ``agent_complete`` summary.
        session_id: When non-empty, the session is saved once more on completion.

    Yields:
        SSE messages (``node_start``, ``node_complete``, ``stream_token``,
        then ``agent_complete`` or ``agent_error``) and ``: keepalive``
        comments while the queue is idle.
    """
    global _current_event_queue

    _step_counter.set(0)
    t_start = time.time()
    event_q: queue.Queue = queue.Queue()
    _current_event_queue = event_q

    loop = asyncio.get_running_loop()
    # run_in_executor() does not carry context variables into the worker
    # thread, so the reset above would never be seen there. Running the graph
    # inside a copy of the current context makes each request start at 0.
    # NOTE(review): assumes the node-start callback runs in the context of the
    # thread that drives graph.stream() — confirm against build_graph.
    ctx = contextvars.copy_context()
    future = loop.run_in_executor(None, ctx.run, _run_graph_sync, agent_state, event_q)

    try:
        while True:
            # Drain everything that is currently queued.
            had_events = False
            while True:
                try:
                    item = event_q.get_nowait()
                except queue.Empty:
                    break
                had_events = True

                kind = item[0]
                if kind == "done":
                    total_ms = round((time.time() - t_start) * 1000)
                    if session_id:
                        # Best effort: a failed save must not prevent the
                        # client from receiving agent_complete.
                        try:
                            save_session(session_id, agent_state)
                        except Exception as exc:
                            _api_log.error("SSE 结束时保存会话失败", extra={
                                "session_id": session_id,
                                "error": str(exc),
                                "traceback": traceback.format_exc(),
                            })
                    versions = agent_state.get("jrxml_versions", [])
                    last_ver = versions[-1] if versions else {}
                    last_failed = last_ver.get("status") == "fail"
                    yield _sse_line("agent_complete", {
                        "reason": "done",
                        "intent": agent_state.get("intent", ""),
                        "status": agent_state.get("status", ""),
                        "jrxml_length": len(agent_state.get("current_jrxml", "")),
                        "error_msg": agent_state.get("error_msg", ""),
                        "natural_explanation": agent_state.get("natural_explanation", ""),
                        "consult_answer": agent_state.get("consult_answer", ""),
                        "retry_count": agent_state.get("retry_count", 0),
                        "total_duration_ms": total_ms,
                        "ocr_extraction_result": agent_state.get("ocr_extraction_result", {}),
                        "versions": len(versions),
                        "has_failed_version": last_failed if last_ver else False,
                        "failed_version_index": len(versions) - 1 if last_failed else -1,
                    })
                    await future
                    return

                if kind == "error":
                    yield _sse_line("agent_error", item[1])
                    await future
                    return

                if kind == "node_start":
                    yield _sse_line("node_start", item[1])
                    continue

                # Anything else is a (mode, data) pair from graph.stream().
                mode, data = item
                if mode == "updates" and isinstance(data, dict):
                    for node_name, node_state in data.items():
                        # A node may report a non-dict update (e.g. None);
                        # it still completes, just without a detail text.
                        detail = (_extract_detail(node_name, node_state)
                                  if isinstance(node_state, dict) else "")
                        yield _sse_line("node_complete", {
                            "node": node_name,
                            "label": NODE_LABELS.get(node_name, node_name),
                            "detail": detail,
                        })
                elif mode == "custom":
                    if isinstance(data, dict) and data.get("type") == "stream":
                        yield _sse_line("stream_token", {
                            "text": data.get("text", ""),
                            "type": "stream",
                        })

            if not had_events:
                # Short sleep = non-blocking poll interval.
                await asyncio.sleep(0.05)
                yield ": keepalive\n\n"
    finally:
        # Also reached when the client disconnects and the generator is closed
        # early. Only release the slot if it still belongs to this request.
        if _current_event_queue is event_q:
            _current_event_queue = None
|
|
|
|
|
|
def _sse_line(event_type: str, data: dict) -> str:
|
|
"""构造单条 SSE 消息。"""
|
|
payload = json.dumps(data, ensure_ascii=False)
|
|
return f"event: {event_type}\ndata: {payload}\n\n"
|
|
|
|
|
|
# ─────────────────────────────────────────────
|
|
# FastAPI 应用
|
|
# ─────────────────────────────────────────────
|
|
|
|
# The application object that every route decorator in this file attaches to.
app = FastAPI(
    title="JRXML Agent API",
    version="5.0",
    description="JRXML 报表生成代理 — 前后端分离 API",
)

# CORS is fully open: any origin, method and header, with credentials allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True is
# usually unintended — confirm, and list explicit origins if cookies or
# Authorization headers are expected to be honoured.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
# ─────────────────────────────────────────────
|
|
# 健康检查 & 配置
|
|
# ─────────────────────────────────────────────
|
|
|
|
@app.get("/api/health")
async def health():
    """Liveness probe: report status, API version and the current UTC time (ISO 8601)."""
    now_utc = datetime.now(timezone.utc)
    return dict(status="ok", version="5.0", timestamp=now_utc.isoformat())
|
|
|
|
|
|
@app.get("/api/config")
async def config():
    """Return a fixed allow-list of environment settings (empty string when unset)."""
    exposed_keys = (
        "LLM_PROVIDER", "OCR_ENGINE", "EMBEDDING_PROVIDER",
        "MAX_RETRY", "CONTEXT_MAX_TOKENS", "CONTEXT_KEEP_RECENT",
    )
    return {"config": {name: os.getenv(name, "") for name in exposed_keys}}
|
|
|
|
|
|
# ─────────────────────────────────────────────
|
|
# 会话 CRUD
|
|
# ─────────────────────────────────────────────
|
|
|
|
@app.post("/api/sessions")
async def create_new_session():
    """Create a session and return its id, name and timestamps."""
    created = create_session()
    wanted = ("session_id", "session_name", "created_at", "updated_at")
    return {field: created[field] for field in wanted}
|
|
|
|
|
|
@app.get("/api/sessions")
async def list_sessions():
    """Return every known session under the ``sessions`` key."""
    sessions = list_all_sessions()
    return {"sessions": sessions}
|
|
|
|
|
|
@app.get("/api/sessions/{session_id}")
async def get_session(session_id: str):
    """Return one session's metadata together with its agent state.

    Raises:
        HTTPException: 400 for an invalid id, 404 when the session is unknown.
    """
    _check_session_id(session_id)
    record = get_session_state(session_id)
    if record is None:
        raise HTTPException(status_code=404, detail="会话不存在")

    meta_keys = ("session_id", "session_name", "created_at", "updated_at")
    response = {key: record.get(key) for key in meta_keys}
    response["agent_state"] = record.get("agent_state", {})
    return response
|
|
|
|
|
|
@app.delete("/api/sessions/{session_id}")
async def remove_session(session_id: str):
    """Delete a session.

    Raises:
        HTTPException: 400 for an invalid id, 404 when nothing was deleted.
    """
    _check_session_id(session_id)
    if delete_session(session_id):
        return {"status": "deleted", "session_id": session_id}
    raise HTTPException(status_code=404, detail="会话不存在或已删除")
|
|
|
|
|
|
# ─────────────────────────────────────────────
|
|
# 用户管理
|
|
# ─────────────────────────────────────────────
|
|
|
|
@app.post("/api/users")
async def create_new_user(payload: dict):
    """Create a user from ``{"name": ...}``.

    Returns:
        Whatever ``create_user`` returns for the new user.

    Raises:
        HTTPException: 400 when the name is missing, blank or not a string;
            500 (with the error text) when ``create_user`` fails.
    """
    raw_name = payload.get("name", "")
    # A JSON null or a non-string name would otherwise reach .strip() and
    # surface as an unhandled 500 instead of a client error.
    name = raw_name.strip() if isinstance(raw_name, str) else ""
    if not name:
        raise HTTPException(status_code=400, detail="用户名不能为空")

    try:
        return create_user(name)
    except Exception as e:
        # Chain the cause so the original traceback stays visible in logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@app.get("/api/users")
async def list_all_users():
    """Return every user under the ``users`` key."""
    users = list_users()
    return {"users": users}
|
|
|
|
|
|
@app.get("/api/users/{user_id}")
async def get_user_info(user_id: str):
    """Return one user, or 404 when ``get_user`` finds none."""
    found = get_user(user_id)
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail="用户不存在")
|
|
|
|
|
|
@app.delete("/api/users/{user_id}")
async def remove_user(user_id: str):
    """Delete a user, or answer 404 when nothing was deleted."""
    if delete_user(user_id):
        return {"status": "deleted", "user_id": user_id}
    raise HTTPException(status_code=404, detail="用户不存在")
|
|
|
|
|
|
# ─────────────────────────────────────────────
|
|
# 知识库 CRUD
|
|
# ─────────────────────────────────────────────
|
|
|
|
@app.get("/api/users/{user_id}/kbs")
async def list_user_kbs(user_id: str):
    """Return the knowledge bases of one user under the ``kbs`` key."""
    owned = list_kbs(user_id)
    return {"kbs": owned}
|
|
|
|
|
|
@app.post("/api/users/{user_id}/kbs")
async def create_user_kb(user_id: str, payload: dict):
    """Create a knowledge base for a user from ``{"name": ..., "description": ...}``.

    Returns:
        Whatever ``create_kb`` returns for the new knowledge base.

    Raises:
        HTTPException: 400 when the name is missing, blank or not a string,
            or when ``create_kb`` rejects the input with ``ValueError``;
            500 (with the error text) for any other failure.
    """
    raw_name = payload.get("name", "")
    # A JSON null or a non-string name would otherwise reach .strip() and
    # surface as an unhandled 500 instead of a client error.
    name = raw_name.strip() if isinstance(raw_name, str) else ""
    description = payload.get("description", "")
    if not name:
        raise HTTPException(status_code=400, detail="知识库名称不能为空")

    try:
        return create_kb(user_id, name, description)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@app.get("/api/kbs/{kb_id}")
async def get_kb_info(kb_id: str):
    """Return one knowledge base, or 404 when ``get_kb`` finds none."""
    found = get_kb(kb_id)
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail="知识库不存在")
|
|
|
|
|
|
@app.delete("/api/kbs/{kb_id}")
async def remove_kb(kb_id: str):
    """Delete a knowledge base, or answer 404 when nothing was deleted."""
    if delete_kb(kb_id):
        return {"status": "deleted", "kb_id": kb_id}
    raise HTTPException(status_code=404, detail="知识库不存在")
|
|
|
|
|
|
# ─────────────────────────────────────────────
|
|
# 知识库文件上传
|
|
# ─────────────────────────────────────────────
|
|
|
|
@app.post("/api/kbs/{kb_id}/upload")
async def upload_to_kb(kb_id: str, file: UploadFile = File(...)):
    """Store an uploaded file in a knowledge base's raw directory and process it.

    Returns:
        ``{"filename", "type", "error"}`` taken from the processing result.

    Raises:
        HTTPException: 404 when the knowledge base is unknown, 500 when it
            has no storage directory, 413 when the file exceeds
            ``MAX_UPLOAD_SIZE``.
    """
    kb = get_kb(kb_id)
    if kb is None:
        raise HTTPException(status_code=404, detail="知识库不存在")

    raw_dir = get_kb_raw_dir(kb_id)
    if raw_dir is None:
        raise HTTPException(status_code=500, detail="知识库存储目录不可用")

    # Keep only the final path component. A name that reduces to nothing or
    # to ".." would make `dest` point at a directory, so fall back to the
    # default name in that case.
    safe_name = Path(file.filename or "upload").name
    if safe_name in ("", ".", ".."):
        safe_name = "upload"

    # Read at most one byte past the limit: enough to detect an oversized
    # upload without pulling an arbitrarily large body into memory.
    content = await file.read(MAX_UPLOAD_SIZE + 1)
    if len(content) > MAX_UPLOAD_SIZE:
        raise HTTPException(status_code=413, detail="文件大小超过 50MB 上限")

    raw_dir.mkdir(parents=True, exist_ok=True)
    dest = raw_dir / safe_name
    dest.write_bytes(content)

    from backend.kb_parser import process_file_for_kb
    result = process_file_for_kb(kb_id, str(dest), source_name=safe_name)

    _api_log.info("KB文件上传", extra={
        "kb_id": kb_id, "file": safe_name, "type": result.get("type"),
    })

    return {
        "filename": safe_name,
        "type": result.get("type", ""),
        "error": result.get("error"),
    }
|
|
|
|
|
|
@app.post("/api/kbs/{kb_id}/build")
async def build_kb(kb_id: str):
    """Build a knowledge base: run the chunk → embed pipeline over its uploaded files.

    Raises:
        HTTPException: 404 when the raw directory is missing, 400 when it
            contains no files.
    """
    from backend.kb_parser import build_kb_from_files as build_fn

    raw_dir = get_kb_raw_dir(kb_id)
    has_dir = raw_dir is not None and raw_dir.exists()
    if not has_dir:
        raise HTTPException(status_code=404, detail="知识库无已上传文件")

    files = [str(entry) for entry in raw_dir.iterdir() if entry.is_file()]
    if not files:
        raise HTTPException(status_code=400, detail="知识库无文件,请先上传")

    return build_fn(kb_id, files)
|
|
|
|
|
|
@app.get("/api/kbs/{kb_id}/status")
async def kb_status(kb_id: str):
    """Summarise a knowledge base: counts, parse status and creation time.

    Raises:
        HTTPException: 404 when the knowledge base is unknown.
    """
    kb = get_kb(kb_id)
    if kb is None:
        raise HTTPException(status_code=404, detail="知识库不存在")

    summary = {"kb_id": kb_id, "name": kb.get("name", "")}
    # Sizes of the stored collections.
    for out_key, source_key in (("field_count", "fields"), ("template_count", "templates")):
        summary[out_key] = len(kb.get(source_key, []))
    # Values copied through with their defaults.
    for key, default in (("file_count", 0), ("chunk_count", 0),
                         ("parse_status", "empty"), ("created_at", "")):
        summary[key] = kb.get(key, default)
    return summary
|
|
|
|
|
|
@app.get("/api/kbs/{kb_id}/fields")
async def kb_fields(kb_id: str):
    """Return the ``fields`` and ``templates`` lists of a knowledge base (404 if unknown)."""
    kb = get_kb(kb_id)
    if kb is None:
        raise HTTPException(status_code=404, detail="知识库不存在")
    fields = kb.get("fields", [])
    templates = kb.get("templates", [])
    return {"fields": fields, "templates": templates}
|
|
|
|
|
|
@app.get("/api/kbs/{kb_id}/search")
async def kb_search(kb_id: str, q: str = "", type: str = ""):
    """Search a knowledge base.

    With ``type == "template"`` the template search is used and the answer
    carries ``results``; any other ``type`` uses the general search and the
    answer carries ``context``. Both ask for five hits.

    Raises:
        HTTPException: 400 when ``q`` is empty.
    """
    if not q:
        raise HTTPException(status_code=400, detail="查询参数 q 不能为空")

    if type != "template":
        return {"query": q, "context": search_kb(kb_id, q, k=5)}
    return {"query": q, "results": search_templates_in_kb(kb_id, q, k=5)}
|
|
|
|
|
|
# ─────────────────────────────────────────────
|
|
# 会话-知识库绑定
|
|
# ─────────────────────────────────────────────
|
|
|
|
@app.put("/api/sessions/{session_id}/kb")
async def bind_session_kb(session_id: str, payload: dict):
    """Bind a knowledge base to a session, or unbind it.

    A non-empty ``kb_id`` stores the id and that knowledge base's fields in
    the session's agent state; an empty, missing or null ``kb_id`` removes both.

    Returns:
        ``{"session_id", "kb_id"}`` where ``kb_id`` is ``None`` after unbinding.

    Raises:
        HTTPException: 400 for an invalid session id or a non-string
            ``kb_id``; 404 when the session or the knowledge base is unknown.
    """
    _check_session_id(session_id)

    raw_kb_id = payload.get("kb_id")
    if raw_kb_id is None:
        # JSON null is treated like an absent key: unbind.
        raw_kb_id = ""
    if not isinstance(raw_kb_id, str):
        # Previously reached .strip() and failed with an unhandled 500.
        raise HTTPException(status_code=400, detail="kb_id 必须是字符串")
    kb_id = raw_kb_id.strip()

    data = load_session(session_id)
    if data is None:
        raise HTTPException(status_code=404, detail="会话不存在")

    agent_state = data.get("agent_state", {})
    if kb_id:
        kb = get_kb(kb_id)
        if kb is None:
            raise HTTPException(status_code=404, detail="知识库不存在")
        agent_state["kb_id"] = kb_id
        agent_state["kb_fields"] = kb.get("fields", [])
    else:
        agent_state.pop("kb_id", None)
        agent_state.pop("kb_fields", None)

    save_session(session_id, agent_state)
    return {"session_id": session_id, "kb_id": kb_id or None}
|
|
|
|
|
|
@app.get("/api/sessions/{session_id}/kb")
async def get_session_kb(session_id: str):
    """Return the knowledge base bound to a session.

    The answer always has ``kb_id`` and ``kb_fields``; ``kb_name`` and
    ``templates`` are added when the bound knowledge base can still be loaded.

    Raises:
        HTTPException: 400 for an invalid id, 404 when the session is unknown.
    """
    _check_session_id(session_id)
    data = load_session(session_id)
    if data is None:
        raise HTTPException(status_code=404, detail="会话不存在")

    state = data.get("agent_state", {})
    bound_id = state.get("kb_id", "")
    answer = {"kb_id": bound_id, "kb_fields": state.get("kb_fields", [])}

    bound_kb = get_kb(bound_id) if bound_id else None
    if bound_kb:
        answer["kb_name"] = bound_kb.get("name", "")
        answer["templates"] = bound_kb.get("templates", [])
    return answer
|
|
|
|
|
|
# ─────────────────────────────────────────────
|
|
# 文件上传
|
|
# ─────────────────────────────────────────────
|
|
|
|
@app.post("/api/upload")
async def upload_file(file: UploadFile = File(...), session_id: str = ""):
    """Store an uploaded file and register it under a fresh ``file_id``.

    The file is written to the (optionally per-session) upload directory as
    ``<file_id>_<original name>`` and recorded in the in-memory registry.

    Returns:
        ``{"file_id", "filename", "content_type", "size"}``.

    Raises:
        HTTPException: 400 for an invalid ``session_id``, 413 when the file
            exceeds ``MAX_UPLOAD_SIZE``.
    """
    if session_id:
        _check_session_id(session_id)
    file_id = uuid.uuid4().hex[:12]

    # Keep the client's file name, reduced to its final path component.
    safe_name = Path(file.filename or "upload.bin").name

    # Read at most one byte past the limit: enough to detect an oversized
    # upload without pulling an arbitrarily large body into memory.
    content = await file.read(MAX_UPLOAD_SIZE + 1)
    if len(content) > MAX_UPLOAD_SIZE:
        raise HTTPException(status_code=413, detail="文件大小超过 50MB 上限")

    # The directory is created once, and only for uploads that are accepted.
    dest = _ensure_upload_dir(session_id) / f"{file_id}_{safe_name}"
    dest.write_bytes(content)

    content_type = file.content_type or mimetypes.guess_type(safe_name)[0] or "application/octet-stream"
    size = len(content)

    _file_registry[file_id] = {
        "path": str(dest),
        "filename": safe_name,
        "content_type": content_type,
        "size": size,
    }

    _api_log.info("文件上传", extra={
        "file_id": file_id, "file_name": safe_name, "size": size,
    })

    return {
        "file_id": file_id,
        "filename": safe_name,
        "content_type": content_type,
        "size": size,
    }
|
|
|
|
|
|
# ─────────────────────────────────────────────
|
|
# 文件处理辅助
|
|
# ─────────────────────────────────────────────
|
|
|
|
def _parse_jrxml_file(file_path: str) -> dict:
    """Parse an uploaded JRXML file into template parameters, fields and raw text.

    Returns:
        {jrxml_text, parameters: [{name, type}], fields: [{name, type}],
         query: str, report_name: str, page_width: str, page_height: str,
         error}
        ``jrxml_text`` is ``""`` when the file cannot be read as UTF-8 text;
        ``error`` is whatever ``parse_jrxml_fields`` reported (or ``None``).
    """
    jrxml_info = parse_jrxml_fields(file_path)
    try:
        raw_xml = Path(file_path).read_text(encoding="utf-8")
    except (OSError, ValueError) as exc:
        # Only the failures read_text() can produce (I/O errors, undecodable
        # bytes, invalid path) are tolerated, and they are logged rather than
        # hidden: the caller still gets the parsed metadata, with empty text.
        _api_log.warning("JRXML 文件读取失败", extra={
            "file_path": file_path, "error": str(exc),
        })
        raw_xml = ""

    return {
        "jrxml_text": raw_xml,
        "parameters": jrxml_info.get("parameters", []),
        "fields": jrxml_info.get("fields", []),
        "query": jrxml_info.get("query", ""),
        "report_name": jrxml_info.get("report_name", ""),
        "page_width": jrxml_info.get("page_width", ""),
        "page_height": jrxml_info.get("page_height", ""),
        "error": jrxml_info.get("error"),
    }
|
|
|
|
|
|
def _process_files(file_ids: list[str], session_id: str) -> dict:
    """Process uploaded files: parse → layout analysis → extract schema text.

    JRXML files are additionally parsed into a template context that the
    caller injects into agent_state.

    Args:
        file_ids: Ids previously registered in ``_file_registry``; unknown ids
            are logged and skipped.
        session_id: NOTE(review): accepted but not used in this function.

    Returns:
        {full_prompt_prefix, uploaded_paths, layout_schema, ocr_text,
         jrxml_template: dict | None}
        ``layout_schema``, ``ocr_text`` and ``jrxml_template`` hold the values
        from the last file that produced them.
    """
    # NOTE(review): the block nesting below was reconstructed from a copy of
    # the source that had lost its indentation — confirm against the original,
    # in particular where the "description" check sits inside the image branch.
    if not file_ids:
        return {"full_prompt_prefix": "", "uploaded_paths": [],
                "layout_schema": {}, "ocr_text": "", "jrxml_template": None}

    parts = []
    uploaded_paths = []
    layout_schema = {}
    ocr_text = ""
    jrxml_template = None

    for fid in file_ids:
        info = _file_registry.get(fid)
        if not info:
            _api_log.warning("文件ID未注册", extra={"file_id": fid})
            continue

        file_path = info["path"]
        uploaded_paths.append(file_path)
        suffix = Path(info["filename"]).suffix.lower()

        # JRXML file → parse as a template
        if suffix == ".jrxml":
            jrxml_template = _parse_jrxml_file(file_path)
            if jrxml_template.get("error"):
                parts.append(f"[JRXML 模板: {info['filename']}]\n解析失败: {jrxml_template['error']}")
            else:
                params = jrxml_template["parameters"]
                fields = jrxml_template["fields"]
                param_desc = "\n".join(
                    f"  - {p['name']} ({p.get('type', 'String')})" for p in params
                ) if params else "  (无参数)"
                field_desc = "\n".join(
                    f"  - {f['name']} ({f.get('type', 'String')})" for f in fields
                ) if fields else "  (无字段)"
                parts.append(
                    f"[上传的 JRXML 模板: {jrxml_template['report_name'] or info['filename']}]\n"
                    f"页面尺寸: {jrxml_template['page_width']}x{jrxml_template['page_height']}\n"
                    f"参数列表:\n{param_desc}\n"
                    f"字段列表:\n{field_desc}\n"
                    f"SQL查询: {jrxml_template['query'] or '(无)'}\n"
                    f"--- XML 内容 ---\n{jrxml_template['jrxml_text']}"
                )
            # JRXML files do not go through the generic parser below.
            continue

        parsed = parse_file(file_path, suffix)
        if parsed.get("error"):
            parts.append(f"[文件: {info['filename']}]\n解析失败: {parsed['error']}")
            continue

        parts.append(f"[文件: {info['filename']}]\n{parsed['text']}")

        # Image file → layout analysis
        if info["content_type"] and info["content_type"].startswith("image/"):
            layout = analyze_layout(file_path)
            if layout.get("is_a4_template"):
                parts.append(
                    f"\n[A4模板布局]\n"
                    f"表格行数: {layout.get('total_rows', 0)}, "
                    f"总元素: {layout.get('total_elements', 0)}, "
                    f"比例: {layout.get('a4_confidence', '')}"
                )
            if layout.get("description"):
                parts.append(f"\n[布局描述]\n{layout['description']}")

            schema = extract_layout_schema(layout)
            if schema and schema.get("total_rows", 0) > 0:
                layout_schema = schema
                schema_text = schema.get("schema_text", "")
                if schema_text:
                    parts.append(f"\n[布局Schema]\n{schema_text}")

            # OCR element text: one line per row, at most the first 30 rows
            ocr_elements = layout.get("rows", [])
            if ocr_elements:
                ocr_lines = []
                for row in ocr_elements[:30]:
                    texts = [e.get("text", "") for e in row.get("elements", [])]
                    ocr_lines.append(" | ".join(texts))
                ocr_text = "\n".join(ocr_lines)
                if ocr_lines:
                    parts.append(f"\n[OCR 识别文本]\n{ocr_text}")

    return {
        "full_prompt_prefix": "\n\n".join(parts) if parts else "",
        "uploaded_paths": uploaded_paths,
        "layout_schema": layout_schema,
        "ocr_text": ocr_text,
        "jrxml_template": jrxml_template,
    }
|
|
|
|
|
|
# ─────────────────────────────────────────────
|
|
# 核心:SSE 聊天端点
|
|
# ─────────────────────────────────────────────
|
|
|
|
@app.post("/api/sessions/{session_id}/chat")
async def chat(session_id: str, payload: dict):
    """Send a message and receive the agent run as an SSE stream.

    Body:
        {text: str, file_ids: [str, ...]}

    Returns:
        text/event-stream (SSE); the trace id is echoed in ``X-Trace-Id``.

    Raises:
        HTTPException: 400 for an invalid session id, a malformed body
            (non-string ``text``, ``file_ids`` that is not a list of
            strings), or when both ``text`` and ``file_ids`` are empty.
    """
    _check_session_id(session_id)

    # JSON null is treated like an absent key; any other non-string value
    # used to reach .strip() and fail with an unhandled 500.
    raw_text = payload.get("text")
    if raw_text is None:
        raw_text = ""
    if not isinstance(raw_text, str):
        raise HTTPException(status_code=400, detail="text 必须是字符串")
    text = raw_text.strip()

    # A bare string would be iterated character by character, and unhashable
    # items would crash the registry lookup, so insist on a list of strings.
    file_ids = payload.get("file_ids")
    if file_ids is None:
        file_ids = []
    if not isinstance(file_ids, list) or not all(isinstance(fid, str) for fid in file_ids):
        raise HTTPException(status_code=400, detail="file_ids 必须是字符串数组")

    if not text and not file_ids:
        raise HTTPException(status_code=400, detail="text 和 file_ids 均为空")

    # ── Load the session, creating it on first use ──
    trace_id = generate_trace_id()
    set_trace_id(trace_id)

    data = load_session(session_id)
    if data is None:
        data = create_session(session_id=session_id)
        _api_log.info("自动创建会话", extra={"session_id": session_id, "trace_id": trace_id})

    agent_state: AgentState = data.get("agent_state", {})
    agent_state["session_id"] = session_id

    # ── Process attachments ──
    file_result = _process_files(file_ids, session_id)
    full_prompt = text
    if file_result["full_prompt_prefix"]:
        full_prompt = f"{file_result['full_prompt_prefix']}\n\n用户问题: {text}" if text else file_result["full_prompt_prefix"]

    # ── Inject the layout schema (used for layered, precise generation) ──
    if file_result.get("layout_schema"):
        agent_state["layout_schema"] = file_result["layout_schema"]
    if file_result.get("ocr_text"):
        # Rebuild row/element dicts from the " | "-joined OCR text.
        ocr_rows = [{"elements": [{"text": t} for t in line.split(" | ")]}
                    for line in file_result["ocr_text"].split("\n") if line.strip()]
        if ocr_rows:
            agent_state["ocr_elements"] = ocr_rows
    if file_result.get("uploaded_paths"):
        agent_state["uploaded_file_path"] = file_result["uploaded_paths"][0]

    # ── Inject a JRXML template uploaded in this turn ──
    jrxml_tmpl = file_result.get("jrxml_template")
    if jrxml_tmpl and not jrxml_tmpl.get("error"):
        agent_state["uploaded_template_jrxml"] = jrxml_tmpl["jrxml_text"]
        agent_state["uploaded_template_params"] = jrxml_tmpl["parameters"]

    # ── Set this turn's input ──
    if agent_state.get("current_jrxml"):
        # A report already exists, so the prompt is also a modification request.
        agent_state["user_modification_request"] = full_prompt

    agent_state["user_input"] = full_prompt
    agent_state["retry_count"] = 0

    _api_log.info("对话请求", extra={
        "session_id": session_id,
        "trace_id": trace_id,
        "text_length": len(text),
        "file_count": len(file_ids),
        "prompt_total": len(full_prompt),
    })

    # Count the files that were actually found in the registry, not the ids
    # that were merely requested.
    processed_count = len(file_result["uploaded_paths"])

    # ── Return the SSE stream ──
    async def stream_and_save():
        # When attachments were sent, report their processing first.
        if file_ids:
            yield _sse_line("node_start", {
                "node": "process_attachments",
                "label": "正在处理附件",
            })
            yield _sse_line("node_complete", {
                "node": "process_attachments",
                "label": "正在处理附件",
                "detail": f"已解析 {processed_count} 个文件",
            })
        async for sse_chunk in _sse_generator(agent_state, session_id):
            yield sse_chunk

    return StreamingResponse(
        stream_and_save(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
            "X-Trace-Id": trace_id,
        },
    )
|
|
|
|
|
|
# ─────────────────────────────────────────────
|
|
# 下载
|
|
# ─────────────────────────────────────────────
|
|
|
|
@app.get("/api/sessions/{session_id}/download/latest")
async def download_latest(session_id: str, background_tasks: BackgroundTasks):
    """Download the session's current JRXML as a file.

    The text is written to a temporary ``.jrxml`` file that is scheduled for
    deletion as a background task.

    Raises:
        HTTPException: 400 for an invalid id, 404 when the session is unknown
            or has no JRXML yet.
    """
    _check_session_id(session_id)
    data = load_session(session_id)
    if data is None:
        raise HTTPException(status_code=404, detail="会话不存在")

    agent_state = data.get("agent_state", {})
    jrxml = agent_state.get("current_jrxml", "")
    if not jrxml:
        raise HTTPException(status_code=404, detail="该会话暂无 JRXML")

    tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".jrxml", delete=False,
                                      encoding="utf-8")
    try:
        # The context manager closes the handle even when the write fails.
        with tmp:
            tmp.write(jrxml)
    except Exception:
        # delete=False leaves the file behind, so remove it ourselves.
        os.unlink(tmp.name)
        raise
    background_tasks.add_task(os.unlink, tmp.name)

    return FileResponse(
        tmp.name,
        media_type="application/xml",
        filename=f"report_{session_id}.jrxml",
    )
|
|
|
|
|
|
@app.get("/api/sessions/{session_id}/download/{version}")
async def download_version(session_id: str, version: int, background_tasks: BackgroundTasks):
    """Download one stored JRXML version (0-based index) as a file.

    The text is written to a temporary ``.jrxml`` file that is scheduled for
    deletion as a background task.

    Raises:
        HTTPException: 400 for an invalid id, 404 when the session is
            unknown, the index is out of range, or the version is empty.
    """
    _check_session_id(session_id)
    data = load_session(session_id)
    if data is None:
        raise HTTPException(status_code=404, detail="会话不存在")

    versions = data.get("agent_state", {}).get("jrxml_versions", [])
    if version < 0 or version >= len(versions):
        raise HTTPException(status_code=404, detail="版本不存在")

    jrxml = versions[version].get("jrxml", "")
    if not jrxml:
        raise HTTPException(status_code=404, detail="该版本内容为空")

    tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".jrxml", delete=False,
                                      encoding="utf-8")
    try:
        # The context manager closes the handle even when the write fails.
        with tmp:
            tmp.write(jrxml)
    except Exception:
        # delete=False leaves the file behind, so remove it ourselves.
        os.unlink(tmp.name)
        raise
    background_tasks.add_task(os.unlink, tmp.name)

    return FileResponse(
        tmp.name,
        media_type="application/xml",
        filename=f"report_{session_id}_v{version}.jrxml",
    )
|
|
|
|
|
|
# ─────────────────────────────────────────────
|
|
# 下载上传文件
|
|
# ─────────────────────────────────────────────
|
|
|
|
@app.get("/api/files/{file_id}")
async def download_file(file_id: str):
    """Serve a previously uploaded file under its original name (404 if not registered)."""
    entry = _file_registry.get(file_id)
    if entry:
        return FileResponse(entry["path"], filename=entry["filename"])
    raise HTTPException(status_code=404, detail="文件未找到")
|
|
|
|
|
|
# ─────────────────────────────────────────────
|
|
# 启动入口
|
|
# ─────────────────────────────────────────────
|
|
|
|
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on the port named by the
    # API_PORT environment variable (default 8000), without auto-reload.
    # NOTE(review): host "0.0.0.0" listens on every interface — confirm that
    # is intended for a desktop application.
    import uvicorn
    port = int(os.getenv("API_PORT", "8000"))
    uvicorn.run("api_server:app", host="0.0.0.0", port=port, reload=False)