fix: band-level windowed refine_layout + programmatic map_fields to prevent 91.5% content loss

Root cause: LLM receiving full 34k-char JRXML would regenerate from scratch
instead of modifying coordinates in-place, shrinking output to ~3k chars.

Solution (programmatic node control, not prompt engineering):

- New agent/jrxml_windower.py: decompose JRXML into header (never sent to
  LLM) + individual bands. Split bands >4000 chars at element boundaries.
  Reassemble with element count validation (>10% change = rollback).

- Rewrite refine_layout: per-band windowed LLM processing (~2-4k chars
  each). LLM cannot "reimagine" the entire report.

- Rewrite map_fields: 100% programmatic regex $F{field_N} -> real name
  replacement. Zero LLM calls, zero content loss.

- _sanitize_field_name: non-ASCII chars escaped to _uXXXX_ format for
  valid JRXML identifiers.

- Tests: 48 new unit tests (windower 28 + map_fields 20). All passing.
  Full suite 385 tests, zero regressions.
This commit is contained in:
2026-05-24 08:55:38 +08:00
parent bb6cc6e241
commit bd5bfbac2d
80 changed files with 39463 additions and 108 deletions
+307 -75
View File
@@ -418,7 +418,8 @@ def load_session_node(state: AgentState) -> Dict:
state["session_name"] = data.get("session_name", "")
state["created_at"] = data.get("created_at", "")
except Exception:
pass
_node_log.warning("会话加载失败,使用空状态",
extra={"session_id": state.get("session_id", "")})
return state
@@ -454,7 +455,8 @@ def save_session_node(state: AgentState) -> Dict:
state["session_name"] = session_name
state["updated_at"] = persistable["updated_at"]
except Exception:
pass
_node_log.exception("会话保存失败",
extra={"session_id": state.get("session_id", "")})
return state
@@ -493,6 +495,90 @@ def _format_row_coordinates(row: dict) -> dict:
return {"y_center": row.get("y_center", 0), "columns": cols}
def _build_sampled_text(ocr_rows: list) -> str:
"""从 OCR 行数据构建采样坐标 JSON 字符串。"""
sampled = {}
if isinstance(ocr_rows, list) and len(ocr_rows) >= 1:
sampled["header_row"] = _format_row_coordinates(ocr_rows[0])
if len(ocr_rows) > 1:
sampled["first_data_row"] = _format_row_coordinates(ocr_rows[1])
if len(ocr_rows) > 2:
sampled["last_row"] = _format_row_coordinates(ocr_rows[-1])
return json.dumps(sampled, ensure_ascii=False, indent=2)
def _extract_band_height(band_xml: str) -> int:
"""从 <band height="N"> 中提取高度值。"""
m = re.search(r'<band\b[^>]*\sheight\s*=\s*"(\d+)"', band_xml)
return int(m.group(1)) if m else 0
def _extract_xml_fragment(text: str) -> str:
"""从 LLM 响应中提取 XML 片段(去除 markdown 代码块和解释文本)。"""
text = text.strip()
# 尝试提取 markdown 代码块内的内容
m = re.search(r"```(?:xml)?\s*([\s\S]*?)```", text, re.IGNORECASE)
if m:
content = m.group(1).strip()
if content:
return content
# 尝试找到 <band ...>...</band> 片段
m = re.search(r"(<band\b[\s\S]*?</band>)", text, re.IGNORECASE)
if m:
return m.group(1).strip()
return text
def _programmatic_map_fields(jrxml: str, ocr_fields: list[dict]) -> str:
"""程序化字段映射:将 $F{{field_N}} 替换为 OCR 提取的真实字段名。
纯正则替换,不调 LLM。100% 确定性,零内容丢失。
"""
result = jrxml
for i, f in enumerate(ocr_fields):
placeholder = f"field_{i+1}"
raw_name = f.get("field_name", "")
if not raw_name:
continue
real_name = _sanitize_field_name(raw_name)
if real_name == placeholder:
continue
# 替换 field 声明: <field name="field_1" → <field name="customer_name"
result = re.sub(
rf'(<field\b[^>]*\bname\s*=\s*"){re.escape(placeholder)}(")',
rf'\g<1>{real_name}\g<2>', result,
)
# 替换所有引用: $F{{field_1}} → $F{{customer_name}}
result = result.replace(f'$F{{{placeholder}}}', f'$F{{{real_name}}}')
return result
def _sanitize_field_name(name: str) -> str:
"""将 OCR 字段名净化为合法的 JRXML field name(仅 ASCII 字母/数字/下划线)。
非 ASCII 字符会被替换为其 Unicode 码点,确保唯一且合法。
"""
result = []
for ch in name:
if ch.isascii() and (ch.isalnum() or ch == '_'):
result.append(ch)
elif ch.isascii():
result.append('_')
else:
# 非 ASCII 转 _uXXXX_ 格式,保留可追溯性
cp = ord(ch)
result.append(f'_u{cp:04X}_')
cleaned = ''.join(result)
cleaned = cleaned.strip('_')
if not cleaned:
return "unnamed_field"
if cleaned[0].isdigit():
cleaned = 'f_' + cleaned
# 压缩连续下划线
cleaned = re.sub(r'_{2,}', '_', cleaned)
return cleaned.lower()
def _format_ocr_context(state: AgentState) -> str:
"""将 OCR 提取结果格式化为 LLM 可用的上下文文本。"""
ocr_result = state.get("ocr_extraction_result")
@@ -619,27 +705,116 @@ def _log_ocr_layers(state: AgentState) -> None:
@log_node("retrieve")
def retrieve(state: AgentState) -> Dict:
"""在 ChromaDB + 错误知识库中搜索相关的 JRXML 模板和组件。"""
"""在 ChromaDB + 错误知识库中搜索相关的 JRXML 模板和组件。
支持按 KB 隔离搜索 + 模板意图检测。
"""
try:
from backend.rag_adapter import search_chunks
from backend.error_kb import search_error_cases
user_input = state.get("user_input", "")
context = search_chunks(user_input, k=5)
kb_id = state.get("kb_id", "")
context = search_chunks(user_input, k=5, kb_id=kb_id)
# 如果有最近错误,同时搜索错误知识库
# 错误知识库
error_msg = state.get("error_msg", "")
if error_msg:
error_context = search_error_cases(error_msg, k=2)
if error_context:
context = f"{context}\n\n[历史错误修正案例]\n{error_context}"
# 模板意图检测:用户是否提到了模板名?
template_keywords = _detect_template_intent(user_input)
if template_keywords and kb_id:
try:
from backend.kb_searcher import search_templates_in_kb
templates = search_templates_in_kb(kb_id, template_keywords, k=1)
if templates:
tmpl = templates[0]
state["kb_template_jrxml"] = tmpl.get("content", "")
state["kb_template_name"] = (
tmpl.get("metadata", {}).get("report_name", "")
)
context += (
f"\n\n[匹配到模板: {state['kb_template_name']}]\n"
f"{state['kb_template_jrxml']}"
)
except Exception:
_node_log.warning("模板检索失败", extra={"kb_id": kb_id})
state["retrieved_context"] = context
except Exception:
_node_log.exception("RAG 检索失败", extra={"user_input": user_input[:80]})
state["retrieved_context"] = ""
return state
def _detect_template_intent(user_input: str) -> str:
"""检测用户输入中是否包含模板引用意图,返回提取的搜索关键词。"""
import re
patterns = [
r"根据(.+?)模板",
r"基于(.+?)模板",
r"参照(.+?)模板",
r"用(.+?)模板",
r"使用(.+?)模板",
r"(.+单)模板",
]
for pat in patterns:
m = re.search(pat, user_input)
if m:
kw = m.group(1).strip()
if len(kw) >= 2:
return kw
return ""
def _build_template_context(state: dict) -> str:
"""构建模板上下文,用于注入生成 prompt。
优先级:对话上传 > KB 检索 > KB 字段定义。
"""
parts = []
# 对话中上传的 JRXML 模板
uploaded = state.get("uploaded_template_jrxml", "")
if uploaded:
params = state.get("uploaded_template_params", [])
param_str = "\n".join(
f" - {p['name']} ({p.get('type', 'String')})" for p in params
) if params else "(参数列表未解析)"
parts.append(
"[对话上传的模板]\n"
f"以下为用户上传的 JRXML 模板,请基于此模板进行修改:\n"
f"模板参数:\n{param_str}\n"
f"```xml\n{uploaded}\n```"
)
# KB 检索到的模板
kb_tmpl = state.get("kb_template_jrxml", "")
kb_name = state.get("kb_template_name", "")
if kb_tmpl and not uploaded:
parts.append(
f"[知识库模板: {kb_name}]\n"
f"以下为从知识库检索到的 JRXML 模板,请作为结构参考:\n"
f"```xml\n{kb_tmpl}\n```"
)
# KB 字段定义
kb_fields = state.get("kb_fields", [])
if kb_fields:
field_table = "\n".join(
f"| {f['name']} | {f.get('description', '')} | {f.get('type', 'String')} |"
for f in kb_fields
)
parts.append(
"[可用数据字段]\n"
"生成 JRXML 时请使用以下字段作为 $P{{xxx}} 参数:\n"
f"| 字段名 | 含义 | 类型 |\n|---|---|---|\n{field_table}"
)
return "\n\n".join(parts)
@log_node("generate")
def generate(state: AgentState) -> Dict:
"""根据用户需求和检索到的上下文生成初始 JRXML。"""
@@ -656,6 +831,7 @@ def generate(state: AgentState) -> Dict:
prompt = load_prompt("initial_generation").format(
context=state.get("retrieved_context", ""),
user_request=user_request,
template_context=_build_template_context(state),
)
full = []
for chunk in llm.stream(prompt):
@@ -683,6 +859,7 @@ def generate_skeleton(state: AgentState) -> Dict:
layout_schema=schema_text,
context=state.get("retrieved_context", ""),
user_request=user_request,
template_context=_build_template_context(state),
)
prev_jrxml = state.get("current_jrxml", "")
full_text = _generate_with_continuation(llm, prompt, writer, "generate_skeleton")
@@ -700,92 +877,145 @@ def generate_skeleton(state: AgentState) -> Dict:
@log_node("refine_layout")
def refine_layout(state: AgentState) -> Dict:
"""阶段二:使用采样坐标(表头 + 首行数据 + 最后一行)精确调整元素位置。"""
"""阶段二:Band 级窗口化精调 — 拆解骨架 JRXML 为独立 band,逐窗口 LLM 调整坐标。
流程:
1. decompose_jrxml() 拆解为 header + bands + footer
2. 每个 band 作为一个(或多个,若 >4000 字符)窗口发给 LLM
3. LLM 只修改该窗口内 reportElement 的 x/y/width/height
4. reassemble_jrxml() 重组 + validate_element_count() 校验
"""
from langgraph.config import get_stream_writer
from agent.jrxml_windower import (
decompose_jrxml, split_band_into_windows, reassemble_band_windows,
reassemble_jrxml, validate_element_count,
)
writer = get_stream_writer()
llm = get_llm(caller="refine_layout")
ocr_rows = state.get("ocr_elements", [])
sampled = {}
if isinstance(ocr_rows, list) and len(ocr_rows) >= 1:
sampled["header_row"] = _format_row_coordinates(ocr_rows[0])
if len(ocr_rows) > 1:
sampled["first_data_row"] = _format_row_coordinates(ocr_rows[1])
if len(ocr_rows) > 2:
sampled["last_row"] = _format_row_coordinates(ocr_rows[-1])
sampled_text = json.dumps(sampled, ensure_ascii=False, indent=2)
prompt = load_prompt("refine_layout").format(
current_jrxml=state.get("current_jrxml", ""),
sampled_coordinates=sampled_text,
)
prev_jrxml = state.get("current_jrxml", "")
full_text = _generate_with_continuation(llm, prompt, writer, "refine_layout")
if not full_text.strip():
_node_log.error("refine_layout LLM 返回空响应,保留前一版本")
if not prev_jrxml.strip():
_node_log.warning("refine_layout 无输入 JRXML,跳过")
return state
jrxml = _extract_jrxml(full_text)
if len(jrxml.strip()) < 200:
_node_log.warning(f"refine_layout 输出过短({len(jrxml)} 字符),回退到前一版本")
jrxml = prev_jrxml
state["current_jrxml"] = jrxml
state["conversation_history"].append({"role": "assistant", "content": jrxml})
# 拆解 JRXML
parts = decompose_jrxml(prev_jrxml)
if parts is None or parts.band_count == 0:
_node_log.warning("refine_layout 拆解失败或无 band,跳过")
return state
# 构建采样坐标
ocr_rows = state.get("ocr_elements", [])
sampled_text = _build_sampled_text(ocr_rows)
template_ctx = _build_template_context(state)
# 逐 band 窗口化精调
modified_bands: dict[str, str] = {}
total_windows = 0
for band in parts.bands:
if band.element_count == 0:
modified_bands[band.label] = band.band_xml
continue
windows = split_band_into_windows(band, max_chars=4000)
total_windows += len(windows)
band_results: list[str] = []
for wi, win_xml in enumerate(windows):
prompt = load_prompt("refine_layout").format(
band_name=band.section_name,
band_index=band.band_index,
band_height=_extract_band_height(win_xml),
window_index=wi + 1,
total_windows=len(windows),
xml_fragment=win_xml,
sampled_coordinates=sampled_text,
template_context=template_ctx,
)
try:
response = llm.invoke(prompt)
content = response.content if hasattr(response, "content") else str(response)
fragment = _extract_xml_fragment(content)
if fragment:
band_results.append(fragment)
writer({"type": "stream", "node": "refine_layout",
"text": f"[{band.label} 窗口 {wi+1}/{len(windows)} 完成] "})
else:
_node_log.warning("refine_layout 窗口 %s/%d 返回空,使用原文",
band.label, wi + 1)
band_results.append(win_xml)
except Exception as e:
_node_log.warning("refine_layout 窗口 %s/%d LLM 失败: %s,使用原文",
band.label, wi + 1, e)
band_results.append(win_xml)
if len(band_results) == 1:
modified_bands[band.label] = band_results[0]
else:
modified_bands[band.label] = reassemble_band_windows(band_results)
# 重组并校验
result = reassemble_jrxml(parts, modified_bands)
validation = validate_element_count(prev_jrxml, result, "refine_layout")
_node_log.info(
"refine_layout 窗口化完成: %d bands, %d 窗口, 元素 %d%d (%.1f%%)",
parts.band_count, total_windows,
validation["original"], validation["modified"],
validation["change_pct"] * 100,
)
if not validation["ok"]:
_node_log.error("refine_layout 元素丢失过多,回退到骨架版本")
return state
state["current_jrxml"] = result
state["conversation_history"].append({"role": "assistant", "content": result})
return state
@log_node("map_fields")
def map_fields(state: AgentState) -> Dict:
"""阶段三:将占位字段名替换为 OCR 提取的真实字段名。"""
from langgraph.config import get_stream_writer
"""阶段三:程序化字段映射 — 用正则将 $F{field_N} 替换为 OCR 字段名,不调 LLM。
writer = get_stream_writer()
llm = get_llm(caller="map_fields")
仅当 OCR 字段名包含中文等需要语义解释时才回退到 LLM。
"""
from agent.jrxml_windower import validate_element_count
ocr_result = state.get("ocr_extraction_result", {})
fields_text = ""
if isinstance(ocr_result, dict) and ocr_result.get("fields"):
field_descs = []
for f in ocr_result["fields"]:
fname = f.get("field_name", "")
fval = f.get("field_value", "")
if fname:
field_descs.append(f" - {fname}: {fval}")
if field_descs:
fields_text = "提取的字段:\n" + "\n".join(field_descs)
if not fields_text:
elements = ocr_result.get("elements", []) if isinstance(ocr_result, dict) else []
if elements:
texts = [e.get("text", "") for e in elements if e.get("text")]
fields_text = "OCR 文本内容:\n" + "\n".join(f" - {t}" for t in texts[:50])
prompt = load_prompt("field_mapping").format(
current_jrxml=state.get("current_jrxml", ""),
ocr_fields=fields_text,
)
prev_jrxml = state.get("current_jrxml", "")
full_text = _generate_with_continuation(llm, prompt, writer, "map_fields")
# 空响应重试:有时 LLM 第一轮不输出,换个方式再试一次
if not full_text.strip():
_node_log.warning("map_fields 第一轮返回空响应,尝试简化 prompt 重试")
retry_prompt = (
"请将以下 JRXML 中的占位字段名 $F{field_1}, $F{field_2}, ... 替换为 OCR 提取的真实字段名。\n"
"规则:根据列顺序映射——$F{field_1} 对应第1列,$F{field_2} 对应第2列,以此类推。\n"
"同时更新 <field name=\"...\"> 声明和所有 $F{...} 引用。\n"
"只输出完整 JRXML,不要解释。\n\n"
f"OCR 字段:\n{fields_text}\n\n"
f"JRXML\n{prev_jrxml}"
)
full_text = _generate_with_continuation(llm, retry_prompt, writer, "map_fields")
if not full_text.strip():
_node_log.error("map_fields LLM 重试后仍返回空响应,保留占位字段版本")
if not prev_jrxml.strip():
_node_log.warning("map_fields 无输入 JRXML,跳过")
return state
jrxml = _extract_jrxml(full_text)
if len(jrxml.strip()) < 200:
_node_log.warning(f"map_fields 输出过短({len(jrxml)} 字符),回退到前一版本")
jrxml = prev_jrxml
state["current_jrxml"] = jrxml
state["conversation_history"].append({"role": "assistant", "content": jrxml})
# 提取 OCR 字段列表
ocr_result = state.get("ocr_extraction_result", {})
ocr_fields: list[dict] = []
if isinstance(ocr_result, dict) and ocr_result.get("fields"):
ocr_fields = ocr_result["fields"]
if not ocr_fields:
_node_log.info("map_fields 无 OCR 字段,保留占位字段名")
state["conversation_history"].append({"role": "assistant", "content": prev_jrxml})
return state
# 程序化替换(主路径)
result = _programmatic_map_fields(prev_jrxml, ocr_fields)
# 校验
validation = validate_element_count(prev_jrxml, result, "map_fields")
_node_log.info(
"map_fields 程序化完成: %d 个字段, 元素 %d%d (%.1f%%)",
len(ocr_fields),
validation["original"], validation["modified"],
validation["change_pct"] * 100,
)
if not validation["ok"]:
_node_log.error("map_fields 元素丢失过多,回退到前一版本")
return state
state["current_jrxml"] = result
state["conversation_history"].append({"role": "assistant", "content": result})
return state
@@ -810,6 +1040,7 @@ def modify_jrxml(state: AgentState) -> Dict:
conversation_history=conv_text,
modification_request=state.get("user_modification_request", ""),
ocr_context=_format_ocr_context(state),
template_context=_build_template_context(state),
)
prev_jrxml = state.get("current_jrxml", "")
full_text = _generate_with_continuation(llm, prompt, writer, "modify_jrxml")
@@ -1181,6 +1412,7 @@ def correct_jrxml(state: AgentState) -> Dict:
ocr_context=ocr_context,
layout_schema_text=layout_text,
fidelity_context=fidelity_text,
template_context=_build_template_context(state),
)
# 保存修正前状态(供 validate 判断是否写入错误知识库)
state["last_error_case"] = {