fix: band-level windowed refine_layout + programmatic map_fields to prevent 91.5% content loss
Root cause: LLM receiving full 34k-char JRXML would regenerate from scratch
instead of modifying coordinates in-place, shrinking output to ~3k chars.
Solution (programmatic node control, not prompt engineering):
- New agent/jrxml_windower.py: decompose JRXML into header (never sent to
LLM) + individual bands. Split bands >4000 chars at element boundaries.
Reassemble with element count validation (>10% change = rollback).
- Rewrite refine_layout: per-band windowed LLM processing (~2-4k chars
each). LLM cannot "reimagine" the entire report.
- Rewrite map_fields: 100% programmatic regex $F{field_N} -> real name
replacement. Zero LLM calls, zero content loss.
- _sanitize_field_name: non-ASCII chars escaped to _uXXXX_ format for
valid JRXML identifiers.
- Tests: 48 new unit tests (windower 28 + map_fields 20). All passing.
Full suite 385 tests, zero regressions.
This commit is contained in:
+292
-8
@@ -30,7 +30,7 @@ from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import FastAPI, HTTPException, UploadFile, File
|
||||
from fastapi import FastAPI, HTTPException, UploadFile, File, Form, BackgroundTasks
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse, FileResponse
|
||||
|
||||
@@ -50,6 +50,13 @@ from backend.session import (
|
||||
)
|
||||
from backend.file_parser import parse_file
|
||||
from backend.layout_analyzer import analyze_layout, extract_layout_schema
|
||||
from backend.kb_manager import (
|
||||
create_user, list_users, get_user, delete_user,
|
||||
create_kb, list_kbs, get_kb, update_kb_meta, delete_kb,
|
||||
get_kb_raw_dir,
|
||||
)
|
||||
from backend.kb_parser import parse_jrxml_fields, build_kb_from_files
|
||||
from backend.kb_searcher import search_kb, search_templates_in_kb
|
||||
|
||||
# ─────────────────────────────────────────────
|
||||
# 常量(从 app.py 迁移)
|
||||
@@ -97,6 +104,7 @@ SKIP_NODES = {"load_session", "process_input", "manage_context",
|
||||
|
||||
_api_log = get_logger("api")
|
||||
UPLOADS_DIR = Path(os.getenv("UPLOADS_DIR", "./uploads"))
|
||||
MAX_UPLOAD_SIZE = 50 * 1024 * 1024 # 50 MB
|
||||
|
||||
def _check_session_id(session_id: str) -> None:
|
||||
"""校验 session_id 合法性(防路径穿越),非法时抛出 HTTPException(400)。"""
|
||||
@@ -380,6 +388,218 @@ async def remove_session(session_id: str):
|
||||
return {"status": "deleted", "session_id": session_id}
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────
|
||||
# 用户管理
|
||||
# ─────────────────────────────────────────────
|
||||
|
||||
@app.post("/api/users")
|
||||
async def create_new_user(payload: dict):
|
||||
name = payload.get("name", "").strip()
|
||||
if not name:
|
||||
raise HTTPException(status_code=400, detail="用户名不能为空")
|
||||
try:
|
||||
user = create_user(name)
|
||||
return user
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/api/users")
|
||||
async def list_all_users():
|
||||
return {"users": list_users()}
|
||||
|
||||
|
||||
@app.get("/api/users/{user_id}")
|
||||
async def get_user_info(user_id: str):
|
||||
user = get_user(user_id)
|
||||
if user is None:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
return user
|
||||
|
||||
|
||||
@app.delete("/api/users/{user_id}")
|
||||
async def remove_user(user_id: str):
|
||||
ok = delete_user(user_id)
|
||||
if not ok:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
return {"status": "deleted", "user_id": user_id}
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────
|
||||
# 知识库 CRUD
|
||||
# ─────────────────────────────────────────────
|
||||
|
||||
@app.get("/api/users/{user_id}/kbs")
|
||||
async def list_user_kbs(user_id: str):
|
||||
return {"kbs": list_kbs(user_id)}
|
||||
|
||||
|
||||
@app.post("/api/users/{user_id}/kbs")
|
||||
async def create_user_kb(user_id: str, payload: dict):
|
||||
name = payload.get("name", "").strip()
|
||||
description = payload.get("description", "")
|
||||
if not name:
|
||||
raise HTTPException(status_code=400, detail="知识库名称不能为空")
|
||||
try:
|
||||
kb = create_kb(user_id, name, description)
|
||||
return kb
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/api/kbs/{kb_id}")
|
||||
async def get_kb_info(kb_id: str):
|
||||
kb = get_kb(kb_id)
|
||||
if kb is None:
|
||||
raise HTTPException(status_code=404, detail="知识库不存在")
|
||||
return kb
|
||||
|
||||
|
||||
@app.delete("/api/kbs/{kb_id}")
|
||||
async def remove_kb(kb_id: str):
|
||||
ok = delete_kb(kb_id)
|
||||
if not ok:
|
||||
raise HTTPException(status_code=404, detail="知识库不存在")
|
||||
return {"status": "deleted", "kb_id": kb_id}
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────
|
||||
# 知识库文件上传
|
||||
# ─────────────────────────────────────────────
|
||||
|
||||
@app.post("/api/kbs/{kb_id}/upload")
|
||||
async def upload_to_kb(kb_id: str, file: UploadFile = File(...)):
|
||||
kb = get_kb(kb_id)
|
||||
if kb is None:
|
||||
raise HTTPException(status_code=404, detail="知识库不存在")
|
||||
|
||||
raw_dir = get_kb_raw_dir(kb_id)
|
||||
if raw_dir is None:
|
||||
raise HTTPException(status_code=500, detail="知识库存储目录不可用")
|
||||
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
safe_name = Path(file.filename or "upload").name
|
||||
dest = raw_dir / safe_name
|
||||
|
||||
content = await file.read()
|
||||
if len(content) > MAX_UPLOAD_SIZE:
|
||||
raise HTTPException(status_code=413, detail="文件大小超过 50MB 上限")
|
||||
|
||||
dest.write_bytes(content)
|
||||
|
||||
from backend.kb_parser import process_file_for_kb
|
||||
result = process_file_for_kb(kb_id, str(dest), source_name=safe_name)
|
||||
|
||||
_api_log.info("KB文件上传", extra={
|
||||
"kb_id": kb_id, "file": safe_name, "type": result.get("type"),
|
||||
})
|
||||
|
||||
return {
|
||||
"filename": safe_name,
|
||||
"type": result.get("type", ""),
|
||||
"error": result.get("error"),
|
||||
}
|
||||
|
||||
|
||||
@app.post("/api/kbs/{kb_id}/build")
|
||||
async def build_kb(kb_id: str):
|
||||
"""构建知识库:对已上传的文件执行 chunk → embed 管线。"""
|
||||
from backend.kb_parser import build_kb_from_files as build_fn
|
||||
raw_dir = get_kb_raw_dir(kb_id)
|
||||
if raw_dir is None or not raw_dir.exists():
|
||||
raise HTTPException(status_code=404, detail="知识库无已上传文件")
|
||||
|
||||
files = [str(p) for p in raw_dir.iterdir() if p.is_file()]
|
||||
if not files:
|
||||
raise HTTPException(status_code=400, detail="知识库无文件,请先上传")
|
||||
|
||||
result = build_fn(kb_id, files)
|
||||
return result
|
||||
|
||||
|
||||
@app.get("/api/kbs/{kb_id}/status")
|
||||
async def kb_status(kb_id: str):
|
||||
kb = get_kb(kb_id)
|
||||
if kb is None:
|
||||
raise HTTPException(status_code=404, detail="知识库不存在")
|
||||
return {
|
||||
"kb_id": kb_id,
|
||||
"name": kb.get("name", ""),
|
||||
"field_count": len(kb.get("fields", [])),
|
||||
"template_count": len(kb.get("templates", [])),
|
||||
"file_count": kb.get("file_count", 0),
|
||||
"chunk_count": kb.get("chunk_count", 0),
|
||||
"parse_status": kb.get("parse_status", "empty"),
|
||||
"created_at": kb.get("created_at", ""),
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/kbs/{kb_id}/fields")
|
||||
async def kb_fields(kb_id: str):
|
||||
kb = get_kb(kb_id)
|
||||
if kb is None:
|
||||
raise HTTPException(status_code=404, detail="知识库不存在")
|
||||
return {"fields": kb.get("fields", []), "templates": kb.get("templates", [])}
|
||||
|
||||
|
||||
@app.get("/api/kbs/{kb_id}/search")
|
||||
async def kb_search(kb_id: str, q: str = "", type: str = ""):
|
||||
if not q:
|
||||
raise HTTPException(status_code=400, detail="查询参数 q 不能为空")
|
||||
if type == "template":
|
||||
results = search_templates_in_kb(kb_id, q, k=5)
|
||||
else:
|
||||
ctx = search_kb(kb_id, q, k=5)
|
||||
return {"query": q, "context": ctx}
|
||||
return {"query": q, "results": results}
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────
|
||||
# 会话-知识库绑定
|
||||
# ─────────────────────────────────────────────
|
||||
|
||||
@app.put("/api/sessions/{session_id}/kb")
|
||||
async def bind_session_kb(session_id: str, payload: dict):
|
||||
_check_session_id(session_id)
|
||||
kb_id = payload.get("kb_id", "").strip()
|
||||
data = load_session(session_id)
|
||||
if data is None:
|
||||
raise HTTPException(status_code=404, detail="会话不存在")
|
||||
|
||||
agent_state = data.get("agent_state", {})
|
||||
if kb_id:
|
||||
kb = get_kb(kb_id)
|
||||
if kb is None:
|
||||
raise HTTPException(status_code=404, detail="知识库不存在")
|
||||
agent_state["kb_id"] = kb_id
|
||||
agent_state["kb_fields"] = kb.get("fields", [])
|
||||
else:
|
||||
agent_state.pop("kb_id", None)
|
||||
agent_state.pop("kb_fields", None)
|
||||
|
||||
save_session(session_id, agent_state)
|
||||
return {"session_id": session_id, "kb_id": kb_id or None}
|
||||
|
||||
|
||||
@app.get("/api/sessions/{session_id}/kb")
|
||||
async def get_session_kb(session_id: str):
|
||||
_check_session_id(session_id)
|
||||
data = load_session(session_id)
|
||||
if data is None:
|
||||
raise HTTPException(status_code=404, detail="会话不存在")
|
||||
agent_state = data.get("agent_state", {})
|
||||
kb_id = agent_state.get("kb_id", "")
|
||||
result = {"kb_id": kb_id, "kb_fields": agent_state.get("kb_fields", [])}
|
||||
if kb_id:
|
||||
kb = get_kb(kb_id)
|
||||
if kb:
|
||||
result["kb_name"] = kb.get("name", "")
|
||||
result["templates"] = kb.get("templates", [])
|
||||
return result
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────
|
||||
# 文件上传
|
||||
# ─────────────────────────────────────────────
|
||||
@@ -396,6 +616,9 @@ async def upload_file(file: UploadFile = File(...), session_id: str = ""):
|
||||
dest = _ensure_upload_dir(session_id) / f"{file_id}_{safe_name}"
|
||||
|
||||
content = await file.read()
|
||||
if len(content) > MAX_UPLOAD_SIZE:
|
||||
raise HTTPException(status_code=413, detail="文件大小超过 50MB 上限")
|
||||
|
||||
dest.write_bytes(content)
|
||||
|
||||
content_type = file.content_type or mimetypes.guess_type(safe_name)[0] or "application/octet-stream"
|
||||
@@ -423,20 +646,47 @@ async def upload_file(file: UploadFile = File(...), session_id: str = ""):
|
||||
# 文件处理辅助
|
||||
# ─────────────────────────────────────────────
|
||||
|
||||
def _process_files(file_ids: list[str], session_id: str) -> dict:
|
||||
"""处理上传的文件:解析 → 布局分析 → 提取 schema 文本。
|
||||
def _parse_jrxml_file(file_path: str) -> dict:
|
||||
"""解析上传的 JRXML 文件,提取模板参数和字段。
|
||||
|
||||
Returns:
|
||||
{full_prompt_prefix, uploaded_paths, layout_schema, ocr_text}
|
||||
{jrxml_text, parameters: [{name, type}], fields: [{name, type}],
|
||||
query: str, report_name: str, page_width: str, page_height: str}
|
||||
"""
|
||||
jrxml_info = parse_jrxml_fields(file_path)
|
||||
try:
|
||||
raw_xml = Path(file_path).read_text(encoding="utf-8")
|
||||
except Exception:
|
||||
raw_xml = ""
|
||||
return {
|
||||
"jrxml_text": raw_xml,
|
||||
"parameters": jrxml_info.get("parameters", []),
|
||||
"fields": jrxml_info.get("fields", []),
|
||||
"query": jrxml_info.get("query", ""),
|
||||
"report_name": jrxml_info.get("report_name", ""),
|
||||
"page_width": jrxml_info.get("page_width", ""),
|
||||
"page_height": jrxml_info.get("page_height", ""),
|
||||
"error": jrxml_info.get("error"),
|
||||
}
|
||||
|
||||
|
||||
def _process_files(file_ids: list[str], session_id: str) -> dict:
|
||||
"""处理上传的文件:解析 → 布局分析 → 提取 schema 文本。
|
||||
JRXML 文件额外解析为模板上下文注入 agent_state。
|
||||
|
||||
Returns:
|
||||
{full_prompt_prefix, uploaded_paths, layout_schema, ocr_text,
|
||||
jrxml_template: dict | None}
|
||||
"""
|
||||
if not file_ids:
|
||||
return {"full_prompt_prefix": "", "uploaded_paths": [],
|
||||
"layout_schema": {}, "ocr_text": ""}
|
||||
"layout_schema": {}, "ocr_text": "", "jrxml_template": None}
|
||||
|
||||
parts = []
|
||||
uploaded_paths = []
|
||||
layout_schema = {}
|
||||
ocr_text = ""
|
||||
jrxml_template = None
|
||||
|
||||
for fid in file_ids:
|
||||
info = _file_registry.get(fid)
|
||||
@@ -446,8 +696,33 @@ def _process_files(file_ids: list[str], session_id: str) -> dict:
|
||||
|
||||
file_path = info["path"]
|
||||
uploaded_paths.append(file_path)
|
||||
suffix = Path(info["filename"]).suffix.lower()
|
||||
|
||||
parsed = parse_file(file_path, Path(info["filename"]).suffix)
|
||||
# JRXML 文件 → 解析为模板
|
||||
if suffix == ".jrxml":
|
||||
jrxml_template = _parse_jrxml_file(file_path)
|
||||
if jrxml_template.get("error"):
|
||||
parts.append(f"[JRXML 模板: {info['filename']}]\n解析失败: {jrxml_template['error']}")
|
||||
else:
|
||||
params = jrxml_template["parameters"]
|
||||
fields = jrxml_template["fields"]
|
||||
param_desc = "\n".join(
|
||||
f" - {p['name']} ({p.get('type', 'String')})" for p in params
|
||||
) if params else " (无参数)"
|
||||
field_desc = "\n".join(
|
||||
f" - {f['name']} ({f.get('type', 'String')})" for f in fields
|
||||
) if fields else " (无字段)"
|
||||
parts.append(
|
||||
f"[上传的 JRXML 模板: {jrxml_template['report_name'] or info['filename']}]\n"
|
||||
f"页面尺寸: {jrxml_template['page_width']}x{jrxml_template['page_height']}\n"
|
||||
f"参数列表:\n{param_desc}\n"
|
||||
f"字段列表:\n{field_desc}\n"
|
||||
f"SQL查询: {jrxml_template['query'] or '(无)'}\n"
|
||||
f"--- XML 内容 ---\n{jrxml_template['jrxml_text']}"
|
||||
)
|
||||
continue
|
||||
|
||||
parsed = parse_file(file_path, suffix)
|
||||
if parsed.get("error"):
|
||||
parts.append(f"[文件: {info['filename']}]\n解析失败: {parsed['error']}")
|
||||
continue
|
||||
@@ -490,6 +765,7 @@ def _process_files(file_ids: list[str], session_id: str) -> dict:
|
||||
"uploaded_paths": uploaded_paths,
|
||||
"layout_schema": layout_schema,
|
||||
"ocr_text": ocr_text,
|
||||
"jrxml_template": jrxml_template,
|
||||
}
|
||||
|
||||
|
||||
@@ -543,6 +819,12 @@ async def chat(session_id: str, payload: dict):
|
||||
if file_result.get("uploaded_paths"):
|
||||
agent_state["uploaded_file_path"] = file_result["uploaded_paths"][0]
|
||||
|
||||
# ── 注入 JRXML 模板(对话中上传的模板)──
|
||||
jrxml_tmpl = file_result.get("jrxml_template")
|
||||
if jrxml_tmpl and not jrxml_tmpl.get("error"):
|
||||
agent_state["uploaded_template_jrxml"] = jrxml_tmpl["jrxml_text"]
|
||||
agent_state["uploaded_template_params"] = jrxml_tmpl["parameters"]
|
||||
|
||||
# ── 设置本轮输入 ──
|
||||
if agent_state.get("current_jrxml"):
|
||||
agent_state["user_modification_request"] = full_prompt
|
||||
@@ -591,7 +873,7 @@ async def chat(session_id: str, payload: dict):
|
||||
# ─────────────────────────────────────────────
|
||||
|
||||
@app.get("/api/sessions/{session_id}/download/latest")
|
||||
async def download_latest(session_id: str):
|
||||
async def download_latest(session_id: str, background_tasks: BackgroundTasks):
|
||||
"""下载最新 JRXML 文件。"""
|
||||
_check_session_id(session_id)
|
||||
data = load_session(session_id)
|
||||
@@ -607,6 +889,7 @@ async def download_latest(session_id: str):
|
||||
encoding="utf-8")
|
||||
tmp.write(jrxml)
|
||||
tmp.close()
|
||||
background_tasks.add_task(os.unlink, tmp.name)
|
||||
|
||||
return FileResponse(
|
||||
tmp.name,
|
||||
@@ -616,7 +899,7 @@ async def download_latest(session_id: str):
|
||||
|
||||
|
||||
@app.get("/api/sessions/{session_id}/download/{version}")
|
||||
async def download_version(session_id: str, version: int):
|
||||
async def download_version(session_id: str, version: int, background_tasks: BackgroundTasks):
|
||||
"""下载指定版本的 JRXML 文件。"""
|
||||
_check_session_id(session_id)
|
||||
data = load_session(session_id)
|
||||
@@ -635,6 +918,7 @@ async def download_version(session_id: str, version: int):
|
||||
encoding="utf-8")
|
||||
tmp.write(jrxml)
|
||||
tmp.close()
|
||||
background_tasks.add_task(os.unlink, tmp.name)
|
||||
|
||||
return FileResponse(
|
||||
tmp.name,
|
||||
|
||||
Reference in New Issue
Block a user