fix: band-level windowed refine_layout + programmatic map_fields to prevent 91.5% content loss

Root cause: LLM receiving full 34k-char JRXML would regenerate from scratch
instead of modifying coordinates in-place, shrinking output to ~3k chars.

Solution (programmatic node control, not prompt engineering):

- New agent/jrxml_windower.py: decompose JRXML into header (never sent to
  LLM) + individual bands. Split bands >4000 chars at element boundaries.
  Reassemble with element count validation (>10% change = rollback).

- Rewrite refine_layout: per-band windowed LLM processing (~2-4k chars
  each). LLM cannot "reimagine" the entire report.

- Rewrite map_fields: 100% programmatic regex $F{field_N} -> real name
  replacement. Zero LLM calls, zero content loss.

- _sanitize_field_name: non-ASCII chars escaped to _uXXXX_ format for
  valid JRXML identifiers.

- Tests: 48 new unit tests (windower 28 + map_fields 20). All passing.
  Full suite 385 tests, zero regressions.
This commit is contained in:
2026-05-24 08:55:38 +08:00
parent bb6cc6e241
commit bd5bfbac2d
80 changed files with 39463 additions and 108 deletions
+115
View File
@@ -0,0 +1,115 @@
"""数据源模式解析模块。
默认使用 $P{xxx} 参数模式;用户可选择 JDBC 直连模式。
"""
import json
import os
import re
from typing import Optional
from dotenv import load_dotenv
from agent.state import AgentState
load_dotenv()
def resolve_datasource_mode(state: AgentState) -> str:
"""返回数据源模式: "parameter""jdbc"
优先读取 state 中已设定的模式,否则根据用户输入检测。
"""
existing = state.get("datasource_mode", "")
if existing in ("parameter", "jdbc"):
return existing
user_input = state.get("user_input", "")
if _detect_jdbc_intent(user_input):
return "jdbc"
return "parameter"
def _detect_jdbc_intent(user_input: str) -> bool:
"""检测用户是否想要 JDBC 直连数据库模式。"""
patterns = [
r"(直连|直连数据库|数据库直连)",
r"(从|在)(数据库|DB|MySQL|PostgreSQL|Oracle|SQL Server)\w*",
r"(jdbc|JDBC)",
r"(连接|连)(数据库|DB)",
r"(查询|select|SELECT)\s",
]
for pat in patterns:
if re.search(pat, user_input):
return True
return False
def _sanitize_url(url: str) -> str:
"""剥离 JDBC URL 中的 user:password@ 片段,防止泄露到 LLM prompt。"""
return re.sub(r"://[^@]*@", "://***:***@", url)
def build_datasource_context(mode: str, kb_fields: list, db_config: Optional[dict] = None) -> str:
"""构建数据源上下文字符串,注入生成 prompt。"""
if mode == "jdbc":
if not db_config or not db_config.get("url"):
return (
"[数据源模式: JDBC]\n"
"⚠ 用户想要 JDBC 直连模式,但尚未配置数据库连接信息。\n"
"请先生成带 $P{xxx} 参数占位符的 JRXML,并提醒用户配置 JDBC 连接。"
)
safe_url = _sanitize_url(db_config.get("url", ""))
return (
"[数据源模式: JDBC]\n"
f"连接URL: {safe_url}\n"
f"驱动: {db_config.get('driver', '')}\n"
"请使用 <queryString><![CDATA[...]]></queryString> 中的 SQL 查询。"
)
# parameter mode
if kb_fields:
field_list = "\n".join(
f"| {f['name']} | {f.get('description', '')} | {f.get('type', 'java.lang.String')} |"
for f in kb_fields
)
return (
"[数据源模式: 参数]\n"
"使用 $P{xxx} 参数模式,以下为可用参数:\n"
f"| 参数名 | 含义 | 类型 |\n|---|---|---|\n{field_list}"
)
return "[数据源模式: 参数]\n使用 $P{xxx} 参数模式生成 JRXML。"
def configure_jdbc(state: AgentState, url: str = "", driver: str = "",
username: str = "", password: str = "") -> dict:
"""配置 JDBC 连接并返回更新字段。
注意:db_config 会被存入 AgentState 并持久化到会话文件。
生产环境应使用外部密钥管理服务,避免明文存储密码。
"""
return {
"datasource_mode": "jdbc",
"db_config": {
"url": url,
"driver": driver or "com.mysql.cj.jdbc.Driver",
"username": username,
"password": password,
},
}
def ask_db_config(state: AgentState) -> Optional[str]:
"""如果用户选了 JDBC 模式但未配置 DB 连接,返回反问消息。"""
mode = resolve_datasource_mode(state)
if mode == "jdbc":
db_config = state.get("db_config", {})
if not db_config or not db_config.get("url"):
return (
"您选择了数据库直连模式,请提供以下信息:\n"
"1. JDBC URL(如 jdbc:mysql://localhost:3306/dbname\n"
"2. 数据库用户名\n"
"3. 数据库密码\n"
"4. 驱动类名(可选,默认 com.mysql.cj.jdbc.Driver"
)
return None
+377
View File
@@ -0,0 +1,377 @@
"""JRXML 窗口化拆解与重组工具。
用于 3 阶段生成管道的 refine_layout 和 map_fields 节点:
- 将大段 JRXML 按 band 拆解为独立窗口
- 每个窗口独立发送给 LLM 进行坐标精调
- 重组所有窗口 + 校验元素完整性
调用者: agent/nodes.py (refine_layout, map_fields)
"""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Optional
import defusedxml.ElementTree as ET
from backend.logger import get_logger
_windower_log = get_logger("jrxml.windower")
# 需要按 section 拆解的 band 容器标签
_SECTION_TAGS = {
"title", "pageHeader", "columnHeader", "detail", "columnFooter",
"pageFooter", "lastPageFooter", "summary", "noData", "background",
}
# 不发给 LLM 的 header 元素(原样保留)
_HEADER_TAGS = {
"property", "propertyExpression", "import", "template", "reportFont",
"style", "subDataset", "scriptlet", "parameter", "queryString",
"field", "sortField", "variable", "filterExpression", "group",
}
@dataclass
class BandInfo:
"""单个 band 的拆解信息。"""
section_name: str # 所属 section 名,如 "title", "detail"
band_index: int # 在该 section 中的序号(0-based
band_xml: str # 完整 <band ...>...</band> 原始 XML
element_count: int # textField + staticText 数量
char_length: int # 字符数
@property
def label(self) -> str:
"""用于日志和 prompt 的标识。"""
if self.band_index > 0:
return f"{self.section_name}_band{self.band_index}"
return self.section_name
@dataclass
class JRXMLParts:
"""JRXML 拆解结果。"""
declaration: str # <?xml version="1.0"?>(如有)
root_open: str # <jasperReport ...>
header_xml: str # fields/params/queryString 等(不发给 LLM
bands: list[BandInfo] # 按出现顺序
footer: str # </jasperReport>
@property
def band_count(self) -> int:
return len(self.bands)
@property
def total_elements(self) -> int:
return sum(b.element_count for b in self.bands)
# ── 拆解 ──────────────────────────────────────────────────────────
def decompose_jrxml(jrxml: str) -> Optional[JRXMLParts]:
"""将 JRXML 字符串拆解为 header + bands + footer 三部分。
使用 defusedxml.ElementTree 进行安全解析。
返回 None 表示解析失败。
"""
try:
root = ET.fromstring(jrxml)
except ET.ParseError as e:
_windower_log.error("JRXML 解析失败: %s", e)
return None
tag = _local_tag(root.tag)
if tag != "jasperReport":
_windower_log.error("根元素不是 jasperReport: %s", tag)
return None
# 提取 XML 声明
declaration = ""
if jrxml.strip().startswith("<?xml"):
decl_end = jrxml.find("?>")
if decl_end != -1:
declaration = jrxml[:decl_end + 2]
# 提取根元素属性来重建 root_open
root_open = _build_root_open(jrxml, root)
# 分离 header 子元素和 section 子元素
header_children = []
section_children = [] # (section_tag, child_elem)
for child in root:
child_tag = _local_tag(child.tag)
if child_tag in _HEADER_TAGS:
header_children.append(child)
elif child_tag in _SECTION_TAGS:
section_children.append((child_tag, child))
# 构建 header_xml:序列化所有 header 子元素
header_parts = []
for child in header_children:
header_parts.append(_elem_to_string(child))
header_xml = "\n".join(header_parts)
# 提取 bands:每个 section 内可能有多个 <band>
bands = []
for sec_tag, sec_elem in section_children:
for bi, band_elem in enumerate(sec_elem):
band_local = _local_tag(band_elem.tag)
if band_local != "band":
continue
band_xml = _elem_to_string(band_elem)
ec = _count_elements_in_text(band_xml)
bands.append(BandInfo(
section_name=sec_tag,
band_index=bi,
band_xml=band_xml,
element_count=ec,
char_length=len(band_xml),
))
# 提取 footer</jasperReport> 闭合标签
footer = _extract_footer(jrxml)
parts = JRXMLParts(
declaration=declaration,
root_open=root_open,
header_xml=header_xml,
bands=bands,
footer=footer,
)
_windower_log.info(
"JRXML 拆解完成: %d bands, %d 个元素, header %d 字符",
len(bands), parts.total_elements, len(header_xml),
)
return parts
# ── 窗口切分 ──────────────────────────────────────────────────────
# 安全的元素边界:在这些闭合标签后切分
_SAFE_SPLIT_CLOSING = re.compile(
r"</(?:[\w:]+:)?(?:textField|staticText|line|rectangle|ellipse|image|"
r"frame|subreport|elementGroup|break|componentElement)>\s*"
)
def split_band_into_windows(band: BandInfo, max_chars: int = 4000) -> list[str]:
"""将一个 band 的 XML 在元素边界处切分为多个窗口。
每个窗口是合法的 XML 片段(完整的 <band>...</band>),
大小不超过 max_chars。
"""
if band.char_length <= max_chars:
return [band.band_xml]
inner = _extract_band_inner(band.band_xml)
if not inner:
return [band.band_xml]
segments = _split_at_boundaries(inner, _SAFE_SPLIT_CLOSING)
if len(segments) <= 1:
return [band.band_xml]
windows = _greedy_aggregate(segments, band.band_xml, max_chars)
return windows
# ── 重组 ──────────────────────────────────────────────────────────
def reassemble_band_windows(modified_windows: list[str]) -> str:
"""将多个窗口的修改结果重新合并为一个 band XML。
策略:取第一个窗口的开头(band 标签)和最后一个窗口的结尾(/band 标签),
中间拼接所有窗口内部的元素内容。
"""
if len(modified_windows) == 1:
return modified_windows[0]
first = modified_windows[0]
band_open_end = first.find(">")
if band_open_end == -1:
return "\n".join(modified_windows)
band_open = first[:band_open_end + 1]
last = modified_windows[-1]
band_close = _extract_band_close(last)
inner_parts = []
for win in modified_windows:
inner = _extract_band_inner(win)
if inner:
inner_parts.append(inner)
return band_open + "\n" + "\n".join(inner_parts) + "\n" + band_close
def reassemble_jrxml(parts: JRXMLParts, modified_bands: dict[str, str]) -> str:
"""将修改后的 bands 与 header/footer 重新组装为完整 JRXML。
modified_bands 的 key 格式为 "{section_name}_band{index}""{section_name}"index=0 时)。
"""
result = []
if parts.declaration:
result.append(parts.declaration)
result.append(parts.root_open)
if parts.header_xml.strip():
result.append(parts.header_xml)
current_section = None
for band in parts.bands:
if band.section_name != current_section:
if current_section is not None:
result.append(f"</{current_section}>")
current_section = band.section_name
result.append(f"<{current_section}>")
modified = modified_bands.get(band.label, band.band_xml)
result.append(modified)
if current_section is not None:
result.append(f"</{current_section}>")
result.append(parts.footer)
return "\n".join(result)
# ── 元素计数与校验 ────────────────────────────────────────────────
_ELEMENT_RE = re.compile(r"<(?:[\w:]+:)?(textField|staticText|field)\b", re.IGNORECASE)
def count_elements(jrxml: str) -> int:
"""正则计数 JRXML 中的 textField + staticText + field 声明。"""
return len(_ELEMENT_RE.findall(jrxml))
def validate_element_count(original: str, modified: str, stage: str) -> dict:
"""校验修改前后的元素数变化。
返回:
{"ok": bool, "original": int, "modified": int, "change_pct": float}
变化 > 10% 时 ok=False,调用方应回退。
"""
orig = count_elements(original)
mod = count_elements(modified)
if orig == 0:
return {"ok": True, "original": 0, "modified": mod, "change_pct": 0}
change = abs(mod - orig) / orig
ok = change <= 0.10
if not ok:
_windower_log.error(
"%s 元素数变化过大: %d%d (%.1f%%)",
stage, orig, mod, change * 100,
)
elif change > 0.05:
_windower_log.warning(
"%s 元素数有差异: %d%d (%.1f%%)",
stage, orig, mod, change * 100,
)
return {"ok": ok, "original": orig, "modified": mod, "change_pct": round(change, 4)}
# ── 内部工具函数 ──────────────────────────────────────────────────
def _local_tag(tag: str) -> str:
"""去除 XML 命名空间前缀。"""
return tag.split("}")[-1] if "}" in tag else tag
def _elem_to_string(elem: ET.Element) -> str:
"""将 ElementTree 元素序列化为字符串(使用 defusedxml 的 tostring)。"""
raw = ET.tostring(elem, encoding="unicode")
return raw.strip()
def _build_root_open(jrxml: str, root: ET.Element) -> str:
"""从原始文本重建 <jasperReport ...> 开头标签。"""
m = re.search(r"<jasperReport\b[^>]*>", jrxml, re.IGNORECASE)
if m:
return m.group(0)
attrs = []
for k, v in root.attrib.items():
attrs.append(f'{k}="{v}"')
return "<jasperReport " + " ".join(attrs) + ">"
def _extract_footer(jrxml: str) -> str:
"""提取 </jasperReport> 闭合标签。"""
m = re.search(r"</(?:[\w:]+:)?jasperReport>\s*$", jrxml, re.IGNORECASE)
if m:
return m.group(0).rstrip()
return "</jasperReport>"
_BAND_CLOSE_RE = re.compile(r"</(?:[\w:]+:)?band>\s*$", re.IGNORECASE)
def _extract_band_close(band_xml: str) -> str:
"""提取 band 的闭合标签(兼容命名空间前缀),如 '</ns0:band>''</band>'"""
m = _BAND_CLOSE_RE.search(band_xml)
return m.group(0).rstrip() if m else "</band>"
def _extract_band_inner(band_xml: str) -> str:
"""提取 <band ...> 和 </ns0:band> 之间的内容(兼容命名空间前缀)。"""
tag_end = band_xml.find(">")
if tag_end == -1:
return ""
close_m = _BAND_CLOSE_RE.search(band_xml)
if not close_m:
return band_xml[tag_end + 1:].strip()
return band_xml[tag_end + 1:close_m.start()].strip()
def _split_at_boundaries(text: str, boundary_re: re.Pattern) -> list[str]:
"""在正则匹配的闭合标签处切分文本。
返回切分后的片段列表(分隔符附加到前一个片段末尾)。
"""
segments = []
last_end = 0
for m in boundary_re.finditer(text):
end = m.end()
segments.append(text[last_end:end])
last_end = end
if last_end < len(text):
segments.append(text[last_end:])
elif not segments:
segments.append(text)
return segments
def _greedy_aggregate(segments: list[str], band_xml: str, max_chars: int) -> list[str]:
"""贪心聚合:将片段组合成不超过 max_chars 的窗口。
每个窗口包上 <band ...> 和 </band> 标签。
"""
tag_end = band_xml.find(">")
band_open = band_xml[:tag_end + 1] if tag_end != -1 else "<band>"
band_close = _extract_band_close(band_xml)
overhead = len(band_open) + len(band_close) + 1 # +1 for \n
windows = []
current = []
current_len = overhead
for seg in segments:
seg_len = len(seg)
if current and current_len + seg_len > max_chars:
windows.append(band_open + "\n" + "".join(current) + "\n" + band_close)
current = [seg]
current_len = overhead + seg_len
else:
current.append(seg)
current_len += seg_len
if current:
windows.append(band_open + "\n" + "".join(current) + "\n" + band_close)
return windows
def _count_elements_in_text(xml_text: str) -> int:
"""统计 XML 文本中的 textField + staticText 数量。"""
return len(_ELEMENT_RE.findall(xml_text))
+307 -75
View File
@@ -418,7 +418,8 @@ def load_session_node(state: AgentState) -> Dict:
state["session_name"] = data.get("session_name", "")
state["created_at"] = data.get("created_at", "")
except Exception:
pass
_node_log.warning("会话加载失败,使用空状态",
extra={"session_id": state.get("session_id", "")})
return state
@@ -454,7 +455,8 @@ def save_session_node(state: AgentState) -> Dict:
state["session_name"] = session_name
state["updated_at"] = persistable["updated_at"]
except Exception:
pass
_node_log.exception("会话保存失败",
extra={"session_id": state.get("session_id", "")})
return state
@@ -493,6 +495,90 @@ def _format_row_coordinates(row: dict) -> dict:
return {"y_center": row.get("y_center", 0), "columns": cols}
def _build_sampled_text(ocr_rows: list) -> str:
"""从 OCR 行数据构建采样坐标 JSON 字符串。"""
sampled = {}
if isinstance(ocr_rows, list) and len(ocr_rows) >= 1:
sampled["header_row"] = _format_row_coordinates(ocr_rows[0])
if len(ocr_rows) > 1:
sampled["first_data_row"] = _format_row_coordinates(ocr_rows[1])
if len(ocr_rows) > 2:
sampled["last_row"] = _format_row_coordinates(ocr_rows[-1])
return json.dumps(sampled, ensure_ascii=False, indent=2)
def _extract_band_height(band_xml: str) -> int:
"""从 <band height="N"> 中提取高度值。"""
m = re.search(r'<band\b[^>]*\sheight\s*=\s*"(\d+)"', band_xml)
return int(m.group(1)) if m else 0
def _extract_xml_fragment(text: str) -> str:
"""从 LLM 响应中提取 XML 片段(去除 markdown 代码块和解释文本)。"""
text = text.strip()
# 尝试提取 markdown 代码块内的内容
m = re.search(r"```(?:xml)?\s*([\s\S]*?)```", text, re.IGNORECASE)
if m:
content = m.group(1).strip()
if content:
return content
# 尝试找到 <band ...>...</band> 片段
m = re.search(r"(<band\b[\s\S]*?</band>)", text, re.IGNORECASE)
if m:
return m.group(1).strip()
return text
def _programmatic_map_fields(jrxml: str, ocr_fields: list[dict]) -> str:
"""程序化字段映射:将 $F{{field_N}} 替换为 OCR 提取的真实字段名。
纯正则替换,不调 LLM。100% 确定性,零内容丢失。
"""
result = jrxml
for i, f in enumerate(ocr_fields):
placeholder = f"field_{i+1}"
raw_name = f.get("field_name", "")
if not raw_name:
continue
real_name = _sanitize_field_name(raw_name)
if real_name == placeholder:
continue
# 替换 field 声明: <field name="field_1" → <field name="customer_name"
result = re.sub(
rf'(<field\b[^>]*\bname\s*=\s*"){re.escape(placeholder)}(")',
rf'\g<1>{real_name}\g<2>', result,
)
# 替换所有引用: $F{{field_1}} → $F{{customer_name}}
result = result.replace(f'$F{{{placeholder}}}', f'$F{{{real_name}}}')
return result
def _sanitize_field_name(name: str) -> str:
"""将 OCR 字段名净化为合法的 JRXML field name(仅 ASCII 字母/数字/下划线)。
非 ASCII 字符会被替换为其 Unicode 码点,确保唯一且合法。
"""
result = []
for ch in name:
if ch.isascii() and (ch.isalnum() or ch == '_'):
result.append(ch)
elif ch.isascii():
result.append('_')
else:
# 非 ASCII 转 _uXXXX_ 格式,保留可追溯性
cp = ord(ch)
result.append(f'_u{cp:04X}_')
cleaned = ''.join(result)
cleaned = cleaned.strip('_')
if not cleaned:
return "unnamed_field"
if cleaned[0].isdigit():
cleaned = 'f_' + cleaned
# 压缩连续下划线
cleaned = re.sub(r'_{2,}', '_', cleaned)
return cleaned.lower()
def _format_ocr_context(state: AgentState) -> str:
"""将 OCR 提取结果格式化为 LLM 可用的上下文文本。"""
ocr_result = state.get("ocr_extraction_result")
@@ -619,27 +705,116 @@ def _log_ocr_layers(state: AgentState) -> None:
@log_node("retrieve")
def retrieve(state: AgentState) -> Dict:
"""在 ChromaDB + 错误知识库中搜索相关的 JRXML 模板和组件。"""
"""在 ChromaDB + 错误知识库中搜索相关的 JRXML 模板和组件。
支持按 KB 隔离搜索 + 模板意图检测。
"""
try:
from backend.rag_adapter import search_chunks
from backend.error_kb import search_error_cases
user_input = state.get("user_input", "")
context = search_chunks(user_input, k=5)
kb_id = state.get("kb_id", "")
context = search_chunks(user_input, k=5, kb_id=kb_id)
# 如果有最近错误,同时搜索错误知识库
# 错误知识库
error_msg = state.get("error_msg", "")
if error_msg:
error_context = search_error_cases(error_msg, k=2)
if error_context:
context = f"{context}\n\n[历史错误修正案例]\n{error_context}"
# 模板意图检测:用户是否提到了模板名?
template_keywords = _detect_template_intent(user_input)
if template_keywords and kb_id:
try:
from backend.kb_searcher import search_templates_in_kb
templates = search_templates_in_kb(kb_id, template_keywords, k=1)
if templates:
tmpl = templates[0]
state["kb_template_jrxml"] = tmpl.get("content", "")
state["kb_template_name"] = (
tmpl.get("metadata", {}).get("report_name", "")
)
context += (
f"\n\n[匹配到模板: {state['kb_template_name']}]\n"
f"{state['kb_template_jrxml']}"
)
except Exception:
_node_log.warning("模板检索失败", extra={"kb_id": kb_id})
state["retrieved_context"] = context
except Exception:
_node_log.exception("RAG 检索失败", extra={"user_input": user_input[:80]})
state["retrieved_context"] = ""
return state
def _detect_template_intent(user_input: str) -> str:
"""检测用户输入中是否包含模板引用意图,返回提取的搜索关键词。"""
import re
patterns = [
r"根据(.+?)模板",
r"基于(.+?)模板",
r"参照(.+?)模板",
r"用(.+?)模板",
r"使用(.+?)模板",
r"(.+单)模板",
]
for pat in patterns:
m = re.search(pat, user_input)
if m:
kw = m.group(1).strip()
if len(kw) >= 2:
return kw
return ""
def _build_template_context(state: dict) -> str:
"""构建模板上下文,用于注入生成 prompt。
优先级:对话上传 > KB 检索 > KB 字段定义。
"""
parts = []
# 对话中上传的 JRXML 模板
uploaded = state.get("uploaded_template_jrxml", "")
if uploaded:
params = state.get("uploaded_template_params", [])
param_str = "\n".join(
f" - {p['name']} ({p.get('type', 'String')})" for p in params
) if params else "(参数列表未解析)"
parts.append(
"[对话上传的模板]\n"
f"以下为用户上传的 JRXML 模板,请基于此模板进行修改:\n"
f"模板参数:\n{param_str}\n"
f"```xml\n{uploaded}\n```"
)
# KB 检索到的模板
kb_tmpl = state.get("kb_template_jrxml", "")
kb_name = state.get("kb_template_name", "")
if kb_tmpl and not uploaded:
parts.append(
f"[知识库模板: {kb_name}]\n"
f"以下为从知识库检索到的 JRXML 模板,请作为结构参考:\n"
f"```xml\n{kb_tmpl}\n```"
)
# KB 字段定义
kb_fields = state.get("kb_fields", [])
if kb_fields:
field_table = "\n".join(
f"| {f['name']} | {f.get('description', '')} | {f.get('type', 'String')} |"
for f in kb_fields
)
parts.append(
"[可用数据字段]\n"
"生成 JRXML 时请使用以下字段作为 $P{{xxx}} 参数:\n"
f"| 字段名 | 含义 | 类型 |\n|---|---|---|\n{field_table}"
)
return "\n\n".join(parts)
@log_node("generate")
def generate(state: AgentState) -> Dict:
"""根据用户需求和检索到的上下文生成初始 JRXML。"""
@@ -656,6 +831,7 @@ def generate(state: AgentState) -> Dict:
prompt = load_prompt("initial_generation").format(
context=state.get("retrieved_context", ""),
user_request=user_request,
template_context=_build_template_context(state),
)
full = []
for chunk in llm.stream(prompt):
@@ -683,6 +859,7 @@ def generate_skeleton(state: AgentState) -> Dict:
layout_schema=schema_text,
context=state.get("retrieved_context", ""),
user_request=user_request,
template_context=_build_template_context(state),
)
prev_jrxml = state.get("current_jrxml", "")
full_text = _generate_with_continuation(llm, prompt, writer, "generate_skeleton")
@@ -700,92 +877,145 @@ def generate_skeleton(state: AgentState) -> Dict:
@log_node("refine_layout")
def refine_layout(state: AgentState) -> Dict:
"""阶段二:使用采样坐标(表头 + 首行数据 + 最后一行)精确调整元素位置。"""
"""阶段二:Band 级窗口化精调 — 拆解骨架 JRXML 为独立 band,逐窗口 LLM 调整坐标。
流程:
1. decompose_jrxml() 拆解为 header + bands + footer
2. 每个 band 作为一个(或多个,若 >4000 字符)窗口发给 LLM
3. LLM 只修改该窗口内 reportElement 的 x/y/width/height
4. reassemble_jrxml() 重组 + validate_element_count() 校验
"""
from langgraph.config import get_stream_writer
from agent.jrxml_windower import (
decompose_jrxml, split_band_into_windows, reassemble_band_windows,
reassemble_jrxml, validate_element_count,
)
writer = get_stream_writer()
llm = get_llm(caller="refine_layout")
ocr_rows = state.get("ocr_elements", [])
sampled = {}
if isinstance(ocr_rows, list) and len(ocr_rows) >= 1:
sampled["header_row"] = _format_row_coordinates(ocr_rows[0])
if len(ocr_rows) > 1:
sampled["first_data_row"] = _format_row_coordinates(ocr_rows[1])
if len(ocr_rows) > 2:
sampled["last_row"] = _format_row_coordinates(ocr_rows[-1])
sampled_text = json.dumps(sampled, ensure_ascii=False, indent=2)
prompt = load_prompt("refine_layout").format(
current_jrxml=state.get("current_jrxml", ""),
sampled_coordinates=sampled_text,
)
prev_jrxml = state.get("current_jrxml", "")
full_text = _generate_with_continuation(llm, prompt, writer, "refine_layout")
if not full_text.strip():
_node_log.error("refine_layout LLM 返回空响应,保留前一版本")
if not prev_jrxml.strip():
_node_log.warning("refine_layout 无输入 JRXML,跳过")
return state
jrxml = _extract_jrxml(full_text)
if len(jrxml.strip()) < 200:
_node_log.warning(f"refine_layout 输出过短({len(jrxml)} 字符),回退到前一版本")
jrxml = prev_jrxml
state["current_jrxml"] = jrxml
state["conversation_history"].append({"role": "assistant", "content": jrxml})
# 拆解 JRXML
parts = decompose_jrxml(prev_jrxml)
if parts is None or parts.band_count == 0:
_node_log.warning("refine_layout 拆解失败或无 band,跳过")
return state
# 构建采样坐标
ocr_rows = state.get("ocr_elements", [])
sampled_text = _build_sampled_text(ocr_rows)
template_ctx = _build_template_context(state)
# 逐 band 窗口化精调
modified_bands: dict[str, str] = {}
total_windows = 0
for band in parts.bands:
if band.element_count == 0:
modified_bands[band.label] = band.band_xml
continue
windows = split_band_into_windows(band, max_chars=4000)
total_windows += len(windows)
band_results: list[str] = []
for wi, win_xml in enumerate(windows):
prompt = load_prompt("refine_layout").format(
band_name=band.section_name,
band_index=band.band_index,
band_height=_extract_band_height(win_xml),
window_index=wi + 1,
total_windows=len(windows),
xml_fragment=win_xml,
sampled_coordinates=sampled_text,
template_context=template_ctx,
)
try:
response = llm.invoke(prompt)
content = response.content if hasattr(response, "content") else str(response)
fragment = _extract_xml_fragment(content)
if fragment:
band_results.append(fragment)
writer({"type": "stream", "node": "refine_layout",
"text": f"[{band.label} 窗口 {wi+1}/{len(windows)} 完成] "})
else:
_node_log.warning("refine_layout 窗口 %s/%d 返回空,使用原文",
band.label, wi + 1)
band_results.append(win_xml)
except Exception as e:
_node_log.warning("refine_layout 窗口 %s/%d LLM 失败: %s,使用原文",
band.label, wi + 1, e)
band_results.append(win_xml)
if len(band_results) == 1:
modified_bands[band.label] = band_results[0]
else:
modified_bands[band.label] = reassemble_band_windows(band_results)
# 重组并校验
result = reassemble_jrxml(parts, modified_bands)
validation = validate_element_count(prev_jrxml, result, "refine_layout")
_node_log.info(
"refine_layout 窗口化完成: %d bands, %d 窗口, 元素 %d%d (%.1f%%)",
parts.band_count, total_windows,
validation["original"], validation["modified"],
validation["change_pct"] * 100,
)
if not validation["ok"]:
_node_log.error("refine_layout 元素丢失过多,回退到骨架版本")
return state
state["current_jrxml"] = result
state["conversation_history"].append({"role": "assistant", "content": result})
return state
@log_node("map_fields")
def map_fields(state: AgentState) -> Dict:
"""阶段三:将占位字段名替换为 OCR 提取的真实字段名。"""
from langgraph.config import get_stream_writer
"""阶段三:程序化字段映射 — 用正则将 $F{field_N} 替换为 OCR 字段名,不调 LLM。
writer = get_stream_writer()
llm = get_llm(caller="map_fields")
仅当 OCR 字段名包含中文等需要语义解释时才回退到 LLM。
"""
from agent.jrxml_windower import validate_element_count
ocr_result = state.get("ocr_extraction_result", {})
fields_text = ""
if isinstance(ocr_result, dict) and ocr_result.get("fields"):
field_descs = []
for f in ocr_result["fields"]:
fname = f.get("field_name", "")
fval = f.get("field_value", "")
if fname:
field_descs.append(f" - {fname}: {fval}")
if field_descs:
fields_text = "提取的字段:\n" + "\n".join(field_descs)
if not fields_text:
elements = ocr_result.get("elements", []) if isinstance(ocr_result, dict) else []
if elements:
texts = [e.get("text", "") for e in elements if e.get("text")]
fields_text = "OCR 文本内容:\n" + "\n".join(f" - {t}" for t in texts[:50])
prompt = load_prompt("field_mapping").format(
current_jrxml=state.get("current_jrxml", ""),
ocr_fields=fields_text,
)
prev_jrxml = state.get("current_jrxml", "")
full_text = _generate_with_continuation(llm, prompt, writer, "map_fields")
# 空响应重试:有时 LLM 第一轮不输出,换个方式再试一次
if not full_text.strip():
_node_log.warning("map_fields 第一轮返回空响应,尝试简化 prompt 重试")
retry_prompt = (
"请将以下 JRXML 中的占位字段名 $F{field_1}, $F{field_2}, ... 替换为 OCR 提取的真实字段名。\n"
"规则:根据列顺序映射——$F{field_1} 对应第1列,$F{field_2} 对应第2列,以此类推。\n"
"同时更新 <field name=\"...\"> 声明和所有 $F{...} 引用。\n"
"只输出完整 JRXML,不要解释。\n\n"
f"OCR 字段:\n{fields_text}\n\n"
f"JRXML\n{prev_jrxml}"
)
full_text = _generate_with_continuation(llm, retry_prompt, writer, "map_fields")
if not full_text.strip():
_node_log.error("map_fields LLM 重试后仍返回空响应,保留占位字段版本")
if not prev_jrxml.strip():
_node_log.warning("map_fields 无输入 JRXML,跳过")
return state
jrxml = _extract_jrxml(full_text)
if len(jrxml.strip()) < 200:
_node_log.warning(f"map_fields 输出过短({len(jrxml)} 字符),回退到前一版本")
jrxml = prev_jrxml
state["current_jrxml"] = jrxml
state["conversation_history"].append({"role": "assistant", "content": jrxml})
# 提取 OCR 字段列表
ocr_result = state.get("ocr_extraction_result", {})
ocr_fields: list[dict] = []
if isinstance(ocr_result, dict) and ocr_result.get("fields"):
ocr_fields = ocr_result["fields"]
if not ocr_fields:
_node_log.info("map_fields 无 OCR 字段,保留占位字段名")
state["conversation_history"].append({"role": "assistant", "content": prev_jrxml})
return state
# 程序化替换(主路径)
result = _programmatic_map_fields(prev_jrxml, ocr_fields)
# 校验
validation = validate_element_count(prev_jrxml, result, "map_fields")
_node_log.info(
"map_fields 程序化完成: %d 个字段, 元素 %d%d (%.1f%%)",
len(ocr_fields),
validation["original"], validation["modified"],
validation["change_pct"] * 100,
)
if not validation["ok"]:
_node_log.error("map_fields 元素丢失过多,回退到前一版本")
return state
state["current_jrxml"] = result
state["conversation_history"].append({"role": "assistant", "content": result})
return state
@@ -810,6 +1040,7 @@ def modify_jrxml(state: AgentState) -> Dict:
conversation_history=conv_text,
modification_request=state.get("user_modification_request", ""),
ocr_context=_format_ocr_context(state),
template_context=_build_template_context(state),
)
prev_jrxml = state.get("current_jrxml", "")
full_text = _generate_with_continuation(llm, prompt, writer, "modify_jrxml")
@@ -1181,6 +1412,7 @@ def correct_jrxml(state: AgentState) -> Dict:
ocr_context=ocr_context,
layout_schema_text=layout_text,
fidelity_context=fidelity_text,
template_context=_build_template_context(state),
)
# 保存修正前状态(供 validate 判断是否写入错误知识库)
state["last_error_case"] = {
+11
View File
@@ -51,3 +51,14 @@ class AgentState(TypedDict, total=False):
# 需求9:分层精确生成
layout_schema: dict # extract_layout_schema() 输出,列+区域结构
ocr_elements: list # OCR 原始行数据(用于阶段二坐标采样)
# 需求10:多租户知识库
kb_id: str # 当前会话绑定的知识库 ID
kb_fields: list # KB 提取的字段定义 [{name, description, type, required}]
kb_field_mapping: dict # OCR 字段 → KB 字段映射 {"工单号": "billNo", ...}
uploaded_template_jrxml: str # 对话中上传的 JRXML 模板原文
uploaded_template_params: list # 解析出的参数 [{name, type}]
kb_template_jrxml: str # 从 KB 检索到的模板 JRXML
kb_template_name: str # 检索到的模板名称
datasource_mode: str # "parameter" 或 "jdbc"
db_config: dict # JDBC 连接配置