fix: band-level windowed refine_layout + programmatic map_fields to prevent 91.5% content loss
Root cause: LLM receiving full 34k-char JRXML would regenerate from scratch
instead of modifying coordinates in-place, shrinking output to ~3k chars.
Solution (programmatic node control, not prompt engineering):
- New agent/jrxml_windower.py: decompose JRXML into header (never sent to
LLM) + individual bands. Split bands >4000 chars at element boundaries.
Reassemble with element count validation (>10% change = rollback).
- Rewrite refine_layout: per-band windowed LLM processing (~2-4k chars
each). LLM cannot "reimagine" the entire report.
- Rewrite map_fields: 100% programmatic regex $F{field_N} -> real name
replacement. Zero LLM calls, zero content loss.
- _sanitize_field_name: non-ASCII chars escaped to _uXXXX_ format for
valid JRXML identifiers.
- Tests: 48 new unit tests (windower 28 + map_fields 20). All passing.
Full suite 385 tests, zero regressions.
This commit is contained in:
@@ -0,0 +1,336 @@
|
||||
"""KB 解析管道 — 文件提取→字段解析→chunk 切割→向量嵌入。
|
||||
|
||||
调用者: api_server.py (upload endpoint), scripts/init_default_kb.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import shutil
|
||||
import zipfile
|
||||
import tarfile
|
||||
import tempfile
|
||||
import defusedxml.ElementTree as ET
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from backend.logger import get_logger
|
||||
from backend.file_parser import parse_file
|
||||
|
||||
load_dotenv()
|
||||
|
||||
_kb_parse_log = get_logger("kb_parser")
|
||||
|
||||
|
||||
def _find_tag(elem, tag):
|
||||
for el in elem.iter():
|
||||
local = el.tag.split("}")[-1] if "}" in el.tag else el.tag
|
||||
if local == tag:
|
||||
return el
|
||||
return None
|
||||
|
||||
|
||||
def _find_all_tags(elem, tag):
|
||||
results = []
|
||||
for el in elem.iter():
|
||||
local = el.tag.split("}")[-1] if "}" in el.tag else el.tag
|
||||
if local == tag:
|
||||
results.append(el)
|
||||
return results
|
||||
|
||||
|
||||
def parse_jrxml_fields(jrxml_path: str) -> dict:
|
||||
"""解析 JRXML 文件,提取参数和字段定义。"""
|
||||
try:
|
||||
tree = ET.parse(jrxml_path)
|
||||
root = tree.getroot()
|
||||
except ET.ParseError as e:
|
||||
return {"error": f"JRXML 解析失败: {e}", "parameters": [], "fields": [],
|
||||
"report_name": ""}
|
||||
|
||||
report_name = root.attrib.get("name", "")
|
||||
page_width = root.attrib.get("pageWidth", "")
|
||||
page_height = root.attrib.get("pageHeight", "")
|
||||
|
||||
parameters = []
|
||||
for p in _find_all_tags(root, "parameter"):
|
||||
params = {"name": p.attrib.get("name", ""),
|
||||
"type": p.attrib.get("class", "java.lang.String"),
|
||||
"description": ""}
|
||||
desc = _find_tag(p, "parameterDescription")
|
||||
if desc is not None and desc.text:
|
||||
params["description"] = desc.text.strip()
|
||||
parameters.append(params)
|
||||
|
||||
fields = []
|
||||
for f in _find_all_tags(root, "field"):
|
||||
fields.append({"name": f.attrib.get("name", ""),
|
||||
"type": f.attrib.get("class", "java.lang.String"),
|
||||
"description": ""})
|
||||
|
||||
query_text = ""
|
||||
query = _find_tag(root, "queryString")
|
||||
if query is not None and query.text:
|
||||
query_text = query.text.strip()
|
||||
|
||||
return {"report_name": report_name, "page_width": page_width,
|
||||
"page_height": page_height, "parameters": parameters,
|
||||
"fields": fields, "query": query_text, "error": None}
|
||||
|
||||
|
||||
def _extract_archive(file_path: str, dest_dir: str) -> list[str]:
|
||||
extracted = []
|
||||
resolved_dest = os.path.realpath(dest_dir)
|
||||
|
||||
if zipfile.is_zipfile(file_path):
|
||||
with zipfile.ZipFile(file_path, "r") as zf:
|
||||
for member in zf.namelist():
|
||||
member_path = os.path.realpath(os.path.join(dest_dir, member))
|
||||
if not member_path.startswith(resolved_dest + os.sep):
|
||||
continue
|
||||
zf.extract(member, dest_dir)
|
||||
if not member.endswith("/"):
|
||||
extracted.append(member_path)
|
||||
elif tarfile.is_tarfile(file_path):
|
||||
with tarfile.open(file_path, "r:*") as tf:
|
||||
for member in tf.getmembers():
|
||||
member_path = os.path.realpath(os.path.join(dest_dir, member.name))
|
||||
if not member_path.startswith(resolved_dest + os.sep):
|
||||
continue
|
||||
tf.extract(member, dest_dir)
|
||||
if not member.name.endswith("/"):
|
||||
extracted.append(member_path)
|
||||
return extracted
|
||||
|
||||
|
||||
def process_file_for_kb(kb_id: str, file_path: str,
|
||||
source_name: str = "") -> dict:
|
||||
from backend.kb_manager import get_kb_raw_dir
|
||||
raw_dir = get_kb_raw_dir(kb_id)
|
||||
if raw_dir is None:
|
||||
return {"error": "KB 不存在"}
|
||||
|
||||
fname = source_name or os.path.basename(file_path)
|
||||
dest = raw_dir / fname
|
||||
shutil.copy2(file_path, dest)
|
||||
|
||||
suffix = Path(fname).suffix.lower()
|
||||
|
||||
if suffix == ".jrxml":
|
||||
jrxml_info = parse_jrxml_fields(file_path)
|
||||
text = f"[JRXML 模板: {jrxml_info['report_name']}]\n"
|
||||
text += f"页面: {jrxml_info['page_width']}x{jrxml_info['page_height']}\n"
|
||||
text += "参数:\n" + "\n".join(
|
||||
f" {p['name']} ({p['type']})" for p in jrxml_info["parameters"])
|
||||
text += "\n字段:\n" + "\n".join(
|
||||
f" {f['name']} ({f['type']})" for f in jrxml_info["fields"])
|
||||
if jrxml_info["query"]:
|
||||
text += f"\n查询:\n{jrxml_info['query']}"
|
||||
try:
|
||||
raw_xml = Path(file_path).read_text(encoding="utf-8")
|
||||
except Exception:
|
||||
raw_xml = ""
|
||||
return {"filename": fname, "type": "jrxml", "text": text,
|
||||
"raw_xml": raw_xml, "jrxml_info": jrxml_info, "error": None}
|
||||
|
||||
if suffix in (".zip", ".tar", ".gz", ".tgz"):
|
||||
tmpdir = tempfile.mkdtemp(prefix="kb_extract_")
|
||||
try:
|
||||
extracted = _extract_archive(file_path, tmpdir)
|
||||
sub_results = []
|
||||
for ep in extracted:
|
||||
sub = process_file_for_kb(
|
||||
kb_id, ep, source_name=os.path.basename(ep))
|
||||
sub_results.append(sub)
|
||||
return {"filename": fname, "type": "archive", "text": "",
|
||||
"archive_contents": sub_results, "error": None}
|
||||
finally:
|
||||
shutil.rmtree(tmpdir, ignore_errors=True)
|
||||
|
||||
parse_result = parse_file(str(dest))
|
||||
return {"filename": fname, "type": suffix.lstrip("."),
|
||||
"text": parse_result.get("text", ""),
|
||||
"error": parse_result.get("error")}
|
||||
|
||||
|
||||
def chunk_file_results(results: list[dict], kb_name: str = "") -> list[dict]:
|
||||
chunks = []
|
||||
chunk_idx = 0
|
||||
|
||||
for r in results:
|
||||
if r.get("type") == "archive":
|
||||
for sub in r.get("archive_contents", []):
|
||||
chunks.extend(chunk_file_results([sub], kb_name))
|
||||
continue
|
||||
|
||||
fname = r.get("filename", "")
|
||||
ftype = r.get("type", "")
|
||||
text = r.get("text", "")
|
||||
if not text.strip():
|
||||
continue
|
||||
|
||||
if ftype == "jrxml" and r.get("raw_xml"):
|
||||
jinfo = r.get("jrxml_info", {})
|
||||
report_name = jinfo.get("report_name", "")
|
||||
chunks.append({
|
||||
"id": f"chunk_{chunk_idx}",
|
||||
"content": f"[JRXML 模板: {report_name}]\n{r['text']}\n\n"
|
||||
f"<xml>\n{r['raw_xml']}\n</xml>",
|
||||
"metadata": {"chunk_type": "jrxml_template",
|
||||
"source_file": fname,
|
||||
"report_name": report_name,
|
||||
"kb_name": kb_name,
|
||||
"param_count": len(jinfo.get("parameters", [])),
|
||||
"field_count": len(jinfo.get("fields", []))},
|
||||
})
|
||||
chunk_idx += 1
|
||||
continue
|
||||
|
||||
paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
|
||||
for para in paragraphs:
|
||||
if len(para) < 10:
|
||||
continue
|
||||
chunk_type = "md_section" if ftype in ("md", "") else f"{ftype}_text"
|
||||
chunks.append({
|
||||
"id": f"chunk_{chunk_idx}",
|
||||
"content": para,
|
||||
"metadata": {"chunk_type": chunk_type,
|
||||
"source_file": fname, "kb_name": kb_name},
|
||||
})
|
||||
chunk_idx += 1
|
||||
return chunks
|
||||
|
||||
|
||||
def extract_fields_with_llm(text: str, llm=None) -> list[dict]:
|
||||
if llm is None:
|
||||
return _extract_fields_from_table(text)
|
||||
prompt = (
|
||||
"请分析以下接口文档内容,提取所有字段定义。\n"
|
||||
"对每个字段,输出: 字段名, 含义, 类型, 是否必需。\n"
|
||||
"以 JSON 数组格式输出,每个元素为 {\"name\": \"...\", "
|
||||
"\"description\": \"...\", \"type\": \"...\", \"required\": false}。\n\n"
|
||||
f"{text}"
|
||||
)
|
||||
try:
|
||||
response = llm.invoke(prompt)
|
||||
content = response.content if hasattr(response, "content") else str(response)
|
||||
start = content.find("[")
|
||||
end = content.rfind("]") + 1
|
||||
if start >= 0 and end > start:
|
||||
return json.loads(content[start:end])
|
||||
except Exception as e:
|
||||
_kb_parse_log.warning("LLM 字段提取失败,使用表格回退: %s", e)
|
||||
return _extract_fields_from_table(text)
|
||||
|
||||
|
||||
def _extract_fields_from_table(text: str) -> list[dict]:
|
||||
fields = []
|
||||
lines = text.split("\n")
|
||||
header_found = False
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line.startswith("|"):
|
||||
continue
|
||||
cells = [c.strip() for c in line.split("|")[1:-1]]
|
||||
if not cells:
|
||||
continue
|
||||
if not header_found:
|
||||
if any(h in str(c) for c in cells
|
||||
for h in ["字段", "名称", "含义", "说明", "类型"]):
|
||||
header_found = True
|
||||
continue
|
||||
if all(c.replace("-", "").replace(":", "").replace(" ", "") == ""
|
||||
for c in cells):
|
||||
continue
|
||||
if len(cells) >= 2:
|
||||
name = cells[0].replace("**", "").replace("L ", "").replace("\\", "").strip()
|
||||
if not name or name in ("", "---"):
|
||||
continue
|
||||
field = {"name": name, "description": "", "type": "java.lang.String",
|
||||
"required": False}
|
||||
if len(cells) >= 2 and cells[1]:
|
||||
field["description"] = cells[1].replace("<br/>", " ").strip()
|
||||
if len(cells) >= 3 and cells[2]:
|
||||
field["required"] = cells[2].strip() in ("是", "Y", "y", "yes", "Yes", "必填")
|
||||
if len(cells) >= 4 and cells[3]:
|
||||
field["type"] = cells[3].strip()
|
||||
fields.append(field)
|
||||
return fields
|
||||
|
||||
|
||||
def build_kb_from_files(kb_id: str, file_paths: list[str],
|
||||
llm=None) -> dict:
|
||||
from backend.kb_manager import update_kb_meta, get_kb_chunks_path
|
||||
from backend.kb_searcher import get_kb_searcher
|
||||
|
||||
all_results = []
|
||||
errors = []
|
||||
for fp in file_paths:
|
||||
try:
|
||||
r = process_file_for_kb(kb_id, fp)
|
||||
all_results.append(r)
|
||||
if r.get("error"):
|
||||
errors.append({"file": os.path.basename(fp), "error": r["error"]})
|
||||
except Exception as e:
|
||||
errors.append({"file": os.path.basename(fp), "error": str(e)})
|
||||
|
||||
chunks = chunk_file_results(all_results)
|
||||
|
||||
chunks_path = get_kb_chunks_path(kb_id)
|
||||
if chunks_path:
|
||||
chunks_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(chunks_path, "w", encoding="utf-8") as f:
|
||||
json.dump(chunks, f, ensure_ascii=False, indent=2)
|
||||
|
||||
searcher = get_kb_searcher(kb_id)
|
||||
if searcher and chunks:
|
||||
try:
|
||||
searcher.add_chunks(chunks)
|
||||
except Exception as e:
|
||||
errors.append({"file": "embedding", "error": str(e)})
|
||||
|
||||
all_fields = []
|
||||
template_names = []
|
||||
for r in all_results:
|
||||
_collect_from_result(r, all_fields, template_names)
|
||||
|
||||
for r in all_results:
|
||||
if r.get("type") in ("archive", "jrxml"):
|
||||
continue
|
||||
text = r.get("text", "")
|
||||
if text.strip():
|
||||
for ef in extract_fields_with_llm(text, llm):
|
||||
if not any(f["name"] == ef["name"] for f in all_fields):
|
||||
all_fields.append(ef)
|
||||
|
||||
update_kb_meta(kb_id, {
|
||||
"fields": all_fields, "templates": template_names,
|
||||
"file_count": len(file_paths), "chunk_count": len(chunks),
|
||||
"parse_status": "ready" if not errors else "partial",
|
||||
})
|
||||
|
||||
_kb_parse_log.info("KB 构建完成", extra={
|
||||
"kb_id": kb_id, "fields": len(all_fields),
|
||||
"templates": len(template_names), "chunks": len(chunks),
|
||||
})
|
||||
|
||||
return {"status": "ready" if not errors else "partial",
|
||||
"field_count": len(all_fields), "chunk_count": len(chunks),
|
||||
"template_count": len(template_names), "errors": errors}
|
||||
|
||||
|
||||
def _collect_from_result(r: dict, all_fields: list, template_names: list) -> None:
|
||||
jinfo = r.get("jrxml_info")
|
||||
if jinfo and jinfo.get("report_name"):
|
||||
template_names.append({"name": jinfo["report_name"],
|
||||
"file": r.get("filename", "")})
|
||||
for p in jinfo.get("parameters", []):
|
||||
field = {"name": p["name"], "description": p.get("description", ""),
|
||||
"type": p.get("type", "java.lang.String"), "required": False}
|
||||
if not any(f["name"] == field["name"] for f in all_fields):
|
||||
all_fields.append(field)
|
||||
for f in jinfo.get("fields", []):
|
||||
field = {"name": f["name"], "description": f.get("description", ""),
|
||||
"type": f.get("type", "java.lang.String"), "required": False}
|
||||
if not any(fi["name"] == field["name"] for fi in all_fields):
|
||||
all_fields.append(field)
|
||||
Reference in New Issue
Block a user