fix: band-level windowed refine_layout + programmatic map_fields to prevent 91.5% content loss
Root cause: LLM receiving full 34k-char JRXML would regenerate from scratch
instead of modifying coordinates in-place, shrinking output to ~3k chars.
Solution (programmatic node control, not prompt engineering):
- New agent/jrxml_windower.py: decompose JRXML into header (never sent to
LLM) + individual bands. Split bands >4000 chars at element boundaries.
Reassemble with element count validation (>10% change = rollback).
- Rewrite refine_layout: per-band windowed LLM processing (~2-4k chars
each). LLM cannot "reimagine" the entire report.
- Rewrite map_fields: 100% programmatic regex $F{field_N} -> real name
replacement. Zero LLM calls, zero content loss.
- _sanitize_field_name: non-ASCII chars escaped to _uXXXX_ format for
valid JRXML identifiers.
- Tests: 48 new unit tests (windower 28 + map_fields 20). All passing.
Full suite 385 tests, zero regressions.
This commit is contained in:
@@ -0,0 +1,115 @@
|
||||
"""数据源模式解析模块。
|
||||
|
||||
默认使用 $P{xxx} 参数模式;用户可选择 JDBC 直连模式。
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from agent.state import AgentState
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def resolve_datasource_mode(state: AgentState) -> str:
|
||||
"""返回数据源模式: "parameter" 或 "jdbc"。
|
||||
|
||||
优先读取 state 中已设定的模式,否则根据用户输入检测。
|
||||
"""
|
||||
existing = state.get("datasource_mode", "")
|
||||
if existing in ("parameter", "jdbc"):
|
||||
return existing
|
||||
|
||||
user_input = state.get("user_input", "")
|
||||
if _detect_jdbc_intent(user_input):
|
||||
return "jdbc"
|
||||
return "parameter"
|
||||
|
||||
|
||||
def _detect_jdbc_intent(user_input: str) -> bool:
|
||||
"""检测用户是否想要 JDBC 直连数据库模式。"""
|
||||
patterns = [
|
||||
r"(直连|直连数据库|数据库直连)",
|
||||
r"(从|在)(数据库|DB|MySQL|PostgreSQL|Oracle|SQL Server)\w*",
|
||||
r"(jdbc|JDBC)",
|
||||
r"(连接|连)(数据库|DB)",
|
||||
r"(查询|select|SELECT)\s",
|
||||
]
|
||||
for pat in patterns:
|
||||
if re.search(pat, user_input):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _sanitize_url(url: str) -> str:
|
||||
"""剥离 JDBC URL 中的 user:password@ 片段,防止泄露到 LLM prompt。"""
|
||||
return re.sub(r"://[^@]*@", "://***:***@", url)
|
||||
|
||||
|
||||
def build_datasource_context(mode: str, kb_fields: list, db_config: Optional[dict] = None) -> str:
|
||||
"""构建数据源上下文字符串,注入生成 prompt。"""
|
||||
if mode == "jdbc":
|
||||
if not db_config or not db_config.get("url"):
|
||||
return (
|
||||
"[数据源模式: JDBC]\n"
|
||||
"⚠ 用户想要 JDBC 直连模式,但尚未配置数据库连接信息。\n"
|
||||
"请先生成带 $P{xxx} 参数占位符的 JRXML,并提醒用户配置 JDBC 连接。"
|
||||
)
|
||||
safe_url = _sanitize_url(db_config.get("url", ""))
|
||||
return (
|
||||
"[数据源模式: JDBC]\n"
|
||||
f"连接URL: {safe_url}\n"
|
||||
f"驱动: {db_config.get('driver', '')}\n"
|
||||
"请使用 <queryString><![CDATA[...]]></queryString> 中的 SQL 查询。"
|
||||
)
|
||||
|
||||
# parameter mode
|
||||
if kb_fields:
|
||||
field_list = "\n".join(
|
||||
f"| {f['name']} | {f.get('description', '')} | {f.get('type', 'java.lang.String')} |"
|
||||
for f in kb_fields
|
||||
)
|
||||
return (
|
||||
"[数据源模式: 参数]\n"
|
||||
"使用 $P{xxx} 参数模式,以下为可用参数:\n"
|
||||
f"| 参数名 | 含义 | 类型 |\n|---|---|---|\n{field_list}"
|
||||
)
|
||||
return "[数据源模式: 参数]\n使用 $P{xxx} 参数模式生成 JRXML。"
|
||||
|
||||
|
||||
def configure_jdbc(state: AgentState, url: str = "", driver: str = "",
|
||||
username: str = "", password: str = "") -> dict:
|
||||
"""配置 JDBC 连接并返回更新字段。
|
||||
|
||||
注意:db_config 会被存入 AgentState 并持久化到会话文件。
|
||||
生产环境应使用外部密钥管理服务,避免明文存储密码。
|
||||
"""
|
||||
return {
|
||||
"datasource_mode": "jdbc",
|
||||
"db_config": {
|
||||
"url": url,
|
||||
"driver": driver or "com.mysql.cj.jdbc.Driver",
|
||||
"username": username,
|
||||
"password": password,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def ask_db_config(state: AgentState) -> Optional[str]:
|
||||
"""如果用户选了 JDBC 模式但未配置 DB 连接,返回反问消息。"""
|
||||
mode = resolve_datasource_mode(state)
|
||||
if mode == "jdbc":
|
||||
db_config = state.get("db_config", {})
|
||||
if not db_config or not db_config.get("url"):
|
||||
return (
|
||||
"您选择了数据库直连模式,请提供以下信息:\n"
|
||||
"1. JDBC URL(如 jdbc:mysql://localhost:3306/dbname)\n"
|
||||
"2. 数据库用户名\n"
|
||||
"3. 数据库密码\n"
|
||||
"4. 驱动类名(可选,默认 com.mysql.cj.jdbc.Driver)"
|
||||
)
|
||||
return None
|
||||
@@ -0,0 +1,377 @@
|
||||
"""JRXML 窗口化拆解与重组工具。
|
||||
|
||||
用于 3 阶段生成管道的 refine_layout 和 map_fields 节点:
|
||||
- 将大段 JRXML 按 band 拆解为独立窗口
|
||||
- 每个窗口独立发送给 LLM 进行坐标精调
|
||||
- 重组所有窗口 + 校验元素完整性
|
||||
|
||||
调用者: agent/nodes.py (refine_layout, map_fields)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import defusedxml.ElementTree as ET
|
||||
|
||||
from backend.logger import get_logger
|
||||
|
||||
_windower_log = get_logger("jrxml.windower")
|
||||
|
||||
# 需要按 section 拆解的 band 容器标签
|
||||
_SECTION_TAGS = {
|
||||
"title", "pageHeader", "columnHeader", "detail", "columnFooter",
|
||||
"pageFooter", "lastPageFooter", "summary", "noData", "background",
|
||||
}
|
||||
|
||||
# 不发给 LLM 的 header 元素(原样保留)
|
||||
_HEADER_TAGS = {
|
||||
"property", "propertyExpression", "import", "template", "reportFont",
|
||||
"style", "subDataset", "scriptlet", "parameter", "queryString",
|
||||
"field", "sortField", "variable", "filterExpression", "group",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class BandInfo:
|
||||
"""单个 band 的拆解信息。"""
|
||||
section_name: str # 所属 section 名,如 "title", "detail"
|
||||
band_index: int # 在该 section 中的序号(0-based)
|
||||
band_xml: str # 完整 <band ...>...</band> 原始 XML
|
||||
element_count: int # textField + staticText 数量
|
||||
char_length: int # 字符数
|
||||
|
||||
@property
|
||||
def label(self) -> str:
|
||||
"""用于日志和 prompt 的标识。"""
|
||||
if self.band_index > 0:
|
||||
return f"{self.section_name}_band{self.band_index}"
|
||||
return self.section_name
|
||||
|
||||
|
||||
@dataclass
|
||||
class JRXMLParts:
|
||||
"""JRXML 拆解结果。"""
|
||||
declaration: str # <?xml version="1.0"?>(如有)
|
||||
root_open: str # <jasperReport ...>
|
||||
header_xml: str # fields/params/queryString 等(不发给 LLM)
|
||||
bands: list[BandInfo] # 按出现顺序
|
||||
footer: str # </jasperReport>
|
||||
|
||||
@property
|
||||
def band_count(self) -> int:
|
||||
return len(self.bands)
|
||||
|
||||
@property
|
||||
def total_elements(self) -> int:
|
||||
return sum(b.element_count for b in self.bands)
|
||||
|
||||
|
||||
# ── 拆解 ──────────────────────────────────────────────────────────
|
||||
|
||||
def decompose_jrxml(jrxml: str) -> Optional[JRXMLParts]:
|
||||
"""将 JRXML 字符串拆解为 header + bands + footer 三部分。
|
||||
|
||||
使用 defusedxml.ElementTree 进行安全解析。
|
||||
返回 None 表示解析失败。
|
||||
"""
|
||||
try:
|
||||
root = ET.fromstring(jrxml)
|
||||
except ET.ParseError as e:
|
||||
_windower_log.error("JRXML 解析失败: %s", e)
|
||||
return None
|
||||
|
||||
tag = _local_tag(root.tag)
|
||||
if tag != "jasperReport":
|
||||
_windower_log.error("根元素不是 jasperReport: %s", tag)
|
||||
return None
|
||||
|
||||
# 提取 XML 声明
|
||||
declaration = ""
|
||||
if jrxml.strip().startswith("<?xml"):
|
||||
decl_end = jrxml.find("?>")
|
||||
if decl_end != -1:
|
||||
declaration = jrxml[:decl_end + 2]
|
||||
|
||||
# 提取根元素属性来重建 root_open
|
||||
root_open = _build_root_open(jrxml, root)
|
||||
|
||||
# 分离 header 子元素和 section 子元素
|
||||
header_children = []
|
||||
section_children = [] # (section_tag, child_elem)
|
||||
|
||||
for child in root:
|
||||
child_tag = _local_tag(child.tag)
|
||||
if child_tag in _HEADER_TAGS:
|
||||
header_children.append(child)
|
||||
elif child_tag in _SECTION_TAGS:
|
||||
section_children.append((child_tag, child))
|
||||
|
||||
# 构建 header_xml:序列化所有 header 子元素
|
||||
header_parts = []
|
||||
for child in header_children:
|
||||
header_parts.append(_elem_to_string(child))
|
||||
header_xml = "\n".join(header_parts)
|
||||
|
||||
# 提取 bands:每个 section 内可能有多个 <band>
|
||||
bands = []
|
||||
for sec_tag, sec_elem in section_children:
|
||||
for bi, band_elem in enumerate(sec_elem):
|
||||
band_local = _local_tag(band_elem.tag)
|
||||
if band_local != "band":
|
||||
continue
|
||||
band_xml = _elem_to_string(band_elem)
|
||||
ec = _count_elements_in_text(band_xml)
|
||||
bands.append(BandInfo(
|
||||
section_name=sec_tag,
|
||||
band_index=bi,
|
||||
band_xml=band_xml,
|
||||
element_count=ec,
|
||||
char_length=len(band_xml),
|
||||
))
|
||||
|
||||
# 提取 footer:</jasperReport> 闭合标签
|
||||
footer = _extract_footer(jrxml)
|
||||
|
||||
parts = JRXMLParts(
|
||||
declaration=declaration,
|
||||
root_open=root_open,
|
||||
header_xml=header_xml,
|
||||
bands=bands,
|
||||
footer=footer,
|
||||
)
|
||||
_windower_log.info(
|
||||
"JRXML 拆解完成: %d bands, %d 个元素, header %d 字符",
|
||||
len(bands), parts.total_elements, len(header_xml),
|
||||
)
|
||||
return parts
|
||||
|
||||
|
||||
# ── 窗口切分 ──────────────────────────────────────────────────────
|
||||
|
||||
# 安全的元素边界:在这些闭合标签后切分
|
||||
_SAFE_SPLIT_CLOSING = re.compile(
|
||||
r"</(?:[\w:]+:)?(?:textField|staticText|line|rectangle|ellipse|image|"
|
||||
r"frame|subreport|elementGroup|break|componentElement)>\s*"
|
||||
)
|
||||
|
||||
|
||||
def split_band_into_windows(band: BandInfo, max_chars: int = 4000) -> list[str]:
|
||||
"""将一个 band 的 XML 在元素边界处切分为多个窗口。
|
||||
|
||||
每个窗口是合法的 XML 片段(完整的 <band>...</band>),
|
||||
大小不超过 max_chars。
|
||||
"""
|
||||
if band.char_length <= max_chars:
|
||||
return [band.band_xml]
|
||||
|
||||
inner = _extract_band_inner(band.band_xml)
|
||||
if not inner:
|
||||
return [band.band_xml]
|
||||
|
||||
segments = _split_at_boundaries(inner, _SAFE_SPLIT_CLOSING)
|
||||
if len(segments) <= 1:
|
||||
return [band.band_xml]
|
||||
|
||||
windows = _greedy_aggregate(segments, band.band_xml, max_chars)
|
||||
return windows
|
||||
|
||||
|
||||
# ── 重组 ──────────────────────────────────────────────────────────
|
||||
|
||||
def reassemble_band_windows(modified_windows: list[str]) -> str:
|
||||
"""将多个窗口的修改结果重新合并为一个 band XML。
|
||||
|
||||
策略:取第一个窗口的开头(band 标签)和最后一个窗口的结尾(/band 标签),
|
||||
中间拼接所有窗口内部的元素内容。
|
||||
"""
|
||||
if len(modified_windows) == 1:
|
||||
return modified_windows[0]
|
||||
|
||||
first = modified_windows[0]
|
||||
band_open_end = first.find(">")
|
||||
if band_open_end == -1:
|
||||
return "\n".join(modified_windows)
|
||||
band_open = first[:band_open_end + 1]
|
||||
|
||||
last = modified_windows[-1]
|
||||
band_close = _extract_band_close(last)
|
||||
|
||||
inner_parts = []
|
||||
for win in modified_windows:
|
||||
inner = _extract_band_inner(win)
|
||||
if inner:
|
||||
inner_parts.append(inner)
|
||||
|
||||
return band_open + "\n" + "\n".join(inner_parts) + "\n" + band_close
|
||||
|
||||
|
||||
def reassemble_jrxml(parts: JRXMLParts, modified_bands: dict[str, str]) -> str:
|
||||
"""将修改后的 bands 与 header/footer 重新组装为完整 JRXML。
|
||||
|
||||
modified_bands 的 key 格式为 "{section_name}_band{index}" 或 "{section_name}"(index=0 时)。
|
||||
"""
|
||||
result = []
|
||||
if parts.declaration:
|
||||
result.append(parts.declaration)
|
||||
result.append(parts.root_open)
|
||||
if parts.header_xml.strip():
|
||||
result.append(parts.header_xml)
|
||||
|
||||
current_section = None
|
||||
for band in parts.bands:
|
||||
if band.section_name != current_section:
|
||||
if current_section is not None:
|
||||
result.append(f"</{current_section}>")
|
||||
current_section = band.section_name
|
||||
result.append(f"<{current_section}>")
|
||||
|
||||
modified = modified_bands.get(band.label, band.band_xml)
|
||||
result.append(modified)
|
||||
|
||||
if current_section is not None:
|
||||
result.append(f"</{current_section}>")
|
||||
|
||||
result.append(parts.footer)
|
||||
return "\n".join(result)
|
||||
|
||||
|
||||
# ── 元素计数与校验 ────────────────────────────────────────────────
|
||||
|
||||
_ELEMENT_RE = re.compile(r"<(?:[\w:]+:)?(textField|staticText|field)\b", re.IGNORECASE)
|
||||
|
||||
|
||||
def count_elements(jrxml: str) -> int:
|
||||
"""正则计数 JRXML 中的 textField + staticText + field 声明。"""
|
||||
return len(_ELEMENT_RE.findall(jrxml))
|
||||
|
||||
|
||||
def validate_element_count(original: str, modified: str, stage: str) -> dict:
|
||||
"""校验修改前后的元素数变化。
|
||||
|
||||
返回:
|
||||
{"ok": bool, "original": int, "modified": int, "change_pct": float}
|
||||
变化 > 10% 时 ok=False,调用方应回退。
|
||||
"""
|
||||
orig = count_elements(original)
|
||||
mod = count_elements(modified)
|
||||
if orig == 0:
|
||||
return {"ok": True, "original": 0, "modified": mod, "change_pct": 0}
|
||||
change = abs(mod - orig) / orig
|
||||
ok = change <= 0.10
|
||||
if not ok:
|
||||
_windower_log.error(
|
||||
"%s 元素数变化过大: %d → %d (%.1f%%)",
|
||||
stage, orig, mod, change * 100,
|
||||
)
|
||||
elif change > 0.05:
|
||||
_windower_log.warning(
|
||||
"%s 元素数有差异: %d → %d (%.1f%%)",
|
||||
stage, orig, mod, change * 100,
|
||||
)
|
||||
return {"ok": ok, "original": orig, "modified": mod, "change_pct": round(change, 4)}
|
||||
|
||||
|
||||
# ── 内部工具函数 ──────────────────────────────────────────────────
|
||||
|
||||
def _local_tag(tag: str) -> str:
|
||||
"""去除 XML 命名空间前缀。"""
|
||||
return tag.split("}")[-1] if "}" in tag else tag
|
||||
|
||||
|
||||
def _elem_to_string(elem: ET.Element) -> str:
|
||||
"""将 ElementTree 元素序列化为字符串(使用 defusedxml 的 tostring)。"""
|
||||
raw = ET.tostring(elem, encoding="unicode")
|
||||
return raw.strip()
|
||||
|
||||
|
||||
def _build_root_open(jrxml: str, root: ET.Element) -> str:
|
||||
"""从原始文本重建 <jasperReport ...> 开头标签。"""
|
||||
m = re.search(r"<jasperReport\b[^>]*>", jrxml, re.IGNORECASE)
|
||||
if m:
|
||||
return m.group(0)
|
||||
attrs = []
|
||||
for k, v in root.attrib.items():
|
||||
attrs.append(f'{k}="{v}"')
|
||||
return "<jasperReport " + " ".join(attrs) + ">"
|
||||
|
||||
|
||||
def _extract_footer(jrxml: str) -> str:
|
||||
"""提取 </jasperReport> 闭合标签。"""
|
||||
m = re.search(r"</(?:[\w:]+:)?jasperReport>\s*$", jrxml, re.IGNORECASE)
|
||||
if m:
|
||||
return m.group(0).rstrip()
|
||||
return "</jasperReport>"
|
||||
|
||||
|
||||
_BAND_CLOSE_RE = re.compile(r"</(?:[\w:]+:)?band>\s*$", re.IGNORECASE)
|
||||
|
||||
def _extract_band_close(band_xml: str) -> str:
|
||||
"""提取 band 的闭合标签(兼容命名空间前缀),如 '</ns0:band>' 或 '</band>'。"""
|
||||
m = _BAND_CLOSE_RE.search(band_xml)
|
||||
return m.group(0).rstrip() if m else "</band>"
|
||||
|
||||
def _extract_band_inner(band_xml: str) -> str:
|
||||
"""提取 <band ...> 和 </ns0:band> 之间的内容(兼容命名空间前缀)。"""
|
||||
tag_end = band_xml.find(">")
|
||||
if tag_end == -1:
|
||||
return ""
|
||||
close_m = _BAND_CLOSE_RE.search(band_xml)
|
||||
if not close_m:
|
||||
return band_xml[tag_end + 1:].strip()
|
||||
return band_xml[tag_end + 1:close_m.start()].strip()
|
||||
|
||||
|
||||
def _split_at_boundaries(text: str, boundary_re: re.Pattern) -> list[str]:
|
||||
"""在正则匹配的闭合标签处切分文本。
|
||||
|
||||
返回切分后的片段列表(分隔符附加到前一个片段末尾)。
|
||||
"""
|
||||
segments = []
|
||||
last_end = 0
|
||||
for m in boundary_re.finditer(text):
|
||||
end = m.end()
|
||||
segments.append(text[last_end:end])
|
||||
last_end = end
|
||||
if last_end < len(text):
|
||||
segments.append(text[last_end:])
|
||||
elif not segments:
|
||||
segments.append(text)
|
||||
return segments
|
||||
|
||||
|
||||
def _greedy_aggregate(segments: list[str], band_xml: str, max_chars: int) -> list[str]:
|
||||
"""贪心聚合:将片段组合成不超过 max_chars 的窗口。
|
||||
|
||||
每个窗口包上 <band ...> 和 </band> 标签。
|
||||
"""
|
||||
tag_end = band_xml.find(">")
|
||||
band_open = band_xml[:tag_end + 1] if tag_end != -1 else "<band>"
|
||||
band_close = _extract_band_close(band_xml)
|
||||
overhead = len(band_open) + len(band_close) + 1 # +1 for \n
|
||||
|
||||
windows = []
|
||||
current = []
|
||||
current_len = overhead
|
||||
|
||||
for seg in segments:
|
||||
seg_len = len(seg)
|
||||
if current and current_len + seg_len > max_chars:
|
||||
windows.append(band_open + "\n" + "".join(current) + "\n" + band_close)
|
||||
current = [seg]
|
||||
current_len = overhead + seg_len
|
||||
else:
|
||||
current.append(seg)
|
||||
current_len += seg_len
|
||||
|
||||
if current:
|
||||
windows.append(band_open + "\n" + "".join(current) + "\n" + band_close)
|
||||
|
||||
return windows
|
||||
|
||||
|
||||
def _count_elements_in_text(xml_text: str) -> int:
|
||||
"""统计 XML 文本中的 textField + staticText 数量。"""
|
||||
return len(_ELEMENT_RE.findall(xml_text))
|
||||
+307
-75
@@ -418,7 +418,8 @@ def load_session_node(state: AgentState) -> Dict:
|
||||
state["session_name"] = data.get("session_name", "")
|
||||
state["created_at"] = data.get("created_at", "")
|
||||
except Exception:
|
||||
pass
|
||||
_node_log.warning("会话加载失败,使用空状态",
|
||||
extra={"session_id": state.get("session_id", "")})
|
||||
return state
|
||||
|
||||
|
||||
@@ -454,7 +455,8 @@ def save_session_node(state: AgentState) -> Dict:
|
||||
state["session_name"] = session_name
|
||||
state["updated_at"] = persistable["updated_at"]
|
||||
except Exception:
|
||||
pass
|
||||
_node_log.exception("会话保存失败",
|
||||
extra={"session_id": state.get("session_id", "")})
|
||||
return state
|
||||
|
||||
|
||||
@@ -493,6 +495,90 @@ def _format_row_coordinates(row: dict) -> dict:
|
||||
return {"y_center": row.get("y_center", 0), "columns": cols}
|
||||
|
||||
|
||||
def _build_sampled_text(ocr_rows: list) -> str:
|
||||
"""从 OCR 行数据构建采样坐标 JSON 字符串。"""
|
||||
sampled = {}
|
||||
if isinstance(ocr_rows, list) and len(ocr_rows) >= 1:
|
||||
sampled["header_row"] = _format_row_coordinates(ocr_rows[0])
|
||||
if len(ocr_rows) > 1:
|
||||
sampled["first_data_row"] = _format_row_coordinates(ocr_rows[1])
|
||||
if len(ocr_rows) > 2:
|
||||
sampled["last_row"] = _format_row_coordinates(ocr_rows[-1])
|
||||
return json.dumps(sampled, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
def _extract_band_height(band_xml: str) -> int:
|
||||
"""从 <band height="N"> 中提取高度值。"""
|
||||
m = re.search(r'<band\b[^>]*\sheight\s*=\s*"(\d+)"', band_xml)
|
||||
return int(m.group(1)) if m else 0
|
||||
|
||||
|
||||
def _extract_xml_fragment(text: str) -> str:
|
||||
"""从 LLM 响应中提取 XML 片段(去除 markdown 代码块和解释文本)。"""
|
||||
text = text.strip()
|
||||
# 尝试提取 markdown 代码块内的内容
|
||||
m = re.search(r"```(?:xml)?\s*([\s\S]*?)```", text, re.IGNORECASE)
|
||||
if m:
|
||||
content = m.group(1).strip()
|
||||
if content:
|
||||
return content
|
||||
# 尝试找到 <band ...>...</band> 片段
|
||||
m = re.search(r"(<band\b[\s\S]*?</band>)", text, re.IGNORECASE)
|
||||
if m:
|
||||
return m.group(1).strip()
|
||||
return text
|
||||
|
||||
|
||||
def _programmatic_map_fields(jrxml: str, ocr_fields: list[dict]) -> str:
|
||||
"""程序化字段映射:将 $F{{field_N}} 替换为 OCR 提取的真实字段名。
|
||||
|
||||
纯正则替换,不调 LLM。100% 确定性,零内容丢失。
|
||||
"""
|
||||
result = jrxml
|
||||
for i, f in enumerate(ocr_fields):
|
||||
placeholder = f"field_{i+1}"
|
||||
raw_name = f.get("field_name", "")
|
||||
if not raw_name:
|
||||
continue
|
||||
real_name = _sanitize_field_name(raw_name)
|
||||
if real_name == placeholder:
|
||||
continue
|
||||
# 替换 field 声明: <field name="field_1" → <field name="customer_name"
|
||||
result = re.sub(
|
||||
rf'(<field\b[^>]*\bname\s*=\s*"){re.escape(placeholder)}(")',
|
||||
rf'\g<1>{real_name}\g<2>', result,
|
||||
)
|
||||
# 替换所有引用: $F{{field_1}} → $F{{customer_name}}
|
||||
result = result.replace(f'$F{{{placeholder}}}', f'$F{{{real_name}}}')
|
||||
return result
|
||||
|
||||
|
||||
def _sanitize_field_name(name: str) -> str:
|
||||
"""将 OCR 字段名净化为合法的 JRXML field name(仅 ASCII 字母/数字/下划线)。
|
||||
|
||||
非 ASCII 字符会被替换为其 Unicode 码点,确保唯一且合法。
|
||||
"""
|
||||
result = []
|
||||
for ch in name:
|
||||
if ch.isascii() and (ch.isalnum() or ch == '_'):
|
||||
result.append(ch)
|
||||
elif ch.isascii():
|
||||
result.append('_')
|
||||
else:
|
||||
# 非 ASCII 转 _uXXXX_ 格式,保留可追溯性
|
||||
cp = ord(ch)
|
||||
result.append(f'_u{cp:04X}_')
|
||||
cleaned = ''.join(result)
|
||||
cleaned = cleaned.strip('_')
|
||||
if not cleaned:
|
||||
return "unnamed_field"
|
||||
if cleaned[0].isdigit():
|
||||
cleaned = 'f_' + cleaned
|
||||
# 压缩连续下划线
|
||||
cleaned = re.sub(r'_{2,}', '_', cleaned)
|
||||
return cleaned.lower()
|
||||
|
||||
|
||||
def _format_ocr_context(state: AgentState) -> str:
|
||||
"""将 OCR 提取结果格式化为 LLM 可用的上下文文本。"""
|
||||
ocr_result = state.get("ocr_extraction_result")
|
||||
@@ -619,27 +705,116 @@ def _log_ocr_layers(state: AgentState) -> None:
|
||||
|
||||
@log_node("retrieve")
|
||||
def retrieve(state: AgentState) -> Dict:
|
||||
"""在 ChromaDB + 错误知识库中搜索相关的 JRXML 模板和组件。"""
|
||||
"""在 ChromaDB + 错误知识库中搜索相关的 JRXML 模板和组件。
|
||||
支持按 KB 隔离搜索 + 模板意图检测。
|
||||
"""
|
||||
try:
|
||||
from backend.rag_adapter import search_chunks
|
||||
from backend.error_kb import search_error_cases
|
||||
|
||||
user_input = state.get("user_input", "")
|
||||
context = search_chunks(user_input, k=5)
|
||||
kb_id = state.get("kb_id", "")
|
||||
context = search_chunks(user_input, k=5, kb_id=kb_id)
|
||||
|
||||
# 如果有最近错误,同时搜索错误知识库
|
||||
# 错误知识库
|
||||
error_msg = state.get("error_msg", "")
|
||||
if error_msg:
|
||||
error_context = search_error_cases(error_msg, k=2)
|
||||
if error_context:
|
||||
context = f"{context}\n\n[历史错误修正案例]\n{error_context}"
|
||||
|
||||
# 模板意图检测:用户是否提到了模板名?
|
||||
template_keywords = _detect_template_intent(user_input)
|
||||
if template_keywords and kb_id:
|
||||
try:
|
||||
from backend.kb_searcher import search_templates_in_kb
|
||||
templates = search_templates_in_kb(kb_id, template_keywords, k=1)
|
||||
if templates:
|
||||
tmpl = templates[0]
|
||||
state["kb_template_jrxml"] = tmpl.get("content", "")
|
||||
state["kb_template_name"] = (
|
||||
tmpl.get("metadata", {}).get("report_name", "")
|
||||
)
|
||||
context += (
|
||||
f"\n\n[匹配到模板: {state['kb_template_name']}]\n"
|
||||
f"{state['kb_template_jrxml']}"
|
||||
)
|
||||
except Exception:
|
||||
_node_log.warning("模板检索失败", extra={"kb_id": kb_id})
|
||||
|
||||
state["retrieved_context"] = context
|
||||
except Exception:
|
||||
_node_log.exception("RAG 检索失败", extra={"user_input": user_input[:80]})
|
||||
state["retrieved_context"] = ""
|
||||
return state
|
||||
|
||||
|
||||
def _detect_template_intent(user_input: str) -> str:
|
||||
"""检测用户输入中是否包含模板引用意图,返回提取的搜索关键词。"""
|
||||
import re
|
||||
patterns = [
|
||||
r"根据(.+?)模板",
|
||||
r"基于(.+?)模板",
|
||||
r"参照(.+?)模板",
|
||||
r"用(.+?)模板",
|
||||
r"使用(.+?)模板",
|
||||
r"(.+单)模板",
|
||||
]
|
||||
for pat in patterns:
|
||||
m = re.search(pat, user_input)
|
||||
if m:
|
||||
kw = m.group(1).strip()
|
||||
if len(kw) >= 2:
|
||||
return kw
|
||||
return ""
|
||||
|
||||
|
||||
def _build_template_context(state: dict) -> str:
|
||||
"""构建模板上下文,用于注入生成 prompt。
|
||||
优先级:对话上传 > KB 检索 > KB 字段定义。
|
||||
"""
|
||||
parts = []
|
||||
|
||||
# 对话中上传的 JRXML 模板
|
||||
uploaded = state.get("uploaded_template_jrxml", "")
|
||||
if uploaded:
|
||||
params = state.get("uploaded_template_params", [])
|
||||
param_str = "\n".join(
|
||||
f" - {p['name']} ({p.get('type', 'String')})" for p in params
|
||||
) if params else "(参数列表未解析)"
|
||||
parts.append(
|
||||
"[对话上传的模板]\n"
|
||||
f"以下为用户上传的 JRXML 模板,请基于此模板进行修改:\n"
|
||||
f"模板参数:\n{param_str}\n"
|
||||
f"```xml\n{uploaded}\n```"
|
||||
)
|
||||
|
||||
# KB 检索到的模板
|
||||
kb_tmpl = state.get("kb_template_jrxml", "")
|
||||
kb_name = state.get("kb_template_name", "")
|
||||
if kb_tmpl and not uploaded:
|
||||
parts.append(
|
||||
f"[知识库模板: {kb_name}]\n"
|
||||
f"以下为从知识库检索到的 JRXML 模板,请作为结构参考:\n"
|
||||
f"```xml\n{kb_tmpl}\n```"
|
||||
)
|
||||
|
||||
# KB 字段定义
|
||||
kb_fields = state.get("kb_fields", [])
|
||||
if kb_fields:
|
||||
field_table = "\n".join(
|
||||
f"| {f['name']} | {f.get('description', '')} | {f.get('type', 'String')} |"
|
||||
for f in kb_fields
|
||||
)
|
||||
parts.append(
|
||||
"[可用数据字段]\n"
|
||||
"生成 JRXML 时请使用以下字段作为 $P{{xxx}} 参数:\n"
|
||||
f"| 字段名 | 含义 | 类型 |\n|---|---|---|\n{field_table}"
|
||||
)
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
||||
|
||||
@log_node("generate")
|
||||
def generate(state: AgentState) -> Dict:
|
||||
"""根据用户需求和检索到的上下文生成初始 JRXML。"""
|
||||
@@ -656,6 +831,7 @@ def generate(state: AgentState) -> Dict:
|
||||
prompt = load_prompt("initial_generation").format(
|
||||
context=state.get("retrieved_context", ""),
|
||||
user_request=user_request,
|
||||
template_context=_build_template_context(state),
|
||||
)
|
||||
full = []
|
||||
for chunk in llm.stream(prompt):
|
||||
@@ -683,6 +859,7 @@ def generate_skeleton(state: AgentState) -> Dict:
|
||||
layout_schema=schema_text,
|
||||
context=state.get("retrieved_context", ""),
|
||||
user_request=user_request,
|
||||
template_context=_build_template_context(state),
|
||||
)
|
||||
prev_jrxml = state.get("current_jrxml", "")
|
||||
full_text = _generate_with_continuation(llm, prompt, writer, "generate_skeleton")
|
||||
@@ -700,92 +877,145 @@ def generate_skeleton(state: AgentState) -> Dict:
|
||||
|
||||
@log_node("refine_layout")
|
||||
def refine_layout(state: AgentState) -> Dict:
|
||||
"""阶段二:使用采样坐标(表头 + 首行数据 + 最后一行)精确调整元素位置。"""
|
||||
"""阶段二:Band 级窗口化精调 — 拆解骨架 JRXML 为独立 band,逐窗口 LLM 调整坐标。
|
||||
|
||||
流程:
|
||||
1. decompose_jrxml() 拆解为 header + bands + footer
|
||||
2. 每个 band 作为一个(或多个,若 >4000 字符)窗口发给 LLM
|
||||
3. LLM 只修改该窗口内 reportElement 的 x/y/width/height
|
||||
4. reassemble_jrxml() 重组 + validate_element_count() 校验
|
||||
"""
|
||||
from langgraph.config import get_stream_writer
|
||||
from agent.jrxml_windower import (
|
||||
decompose_jrxml, split_band_into_windows, reassemble_band_windows,
|
||||
reassemble_jrxml, validate_element_count,
|
||||
)
|
||||
|
||||
writer = get_stream_writer()
|
||||
llm = get_llm(caller="refine_layout")
|
||||
|
||||
ocr_rows = state.get("ocr_elements", [])
|
||||
sampled = {}
|
||||
if isinstance(ocr_rows, list) and len(ocr_rows) >= 1:
|
||||
sampled["header_row"] = _format_row_coordinates(ocr_rows[0])
|
||||
if len(ocr_rows) > 1:
|
||||
sampled["first_data_row"] = _format_row_coordinates(ocr_rows[1])
|
||||
if len(ocr_rows) > 2:
|
||||
sampled["last_row"] = _format_row_coordinates(ocr_rows[-1])
|
||||
sampled_text = json.dumps(sampled, ensure_ascii=False, indent=2)
|
||||
|
||||
prompt = load_prompt("refine_layout").format(
|
||||
current_jrxml=state.get("current_jrxml", ""),
|
||||
sampled_coordinates=sampled_text,
|
||||
)
|
||||
prev_jrxml = state.get("current_jrxml", "")
|
||||
full_text = _generate_with_continuation(llm, prompt, writer, "refine_layout")
|
||||
if not full_text.strip():
|
||||
_node_log.error("refine_layout LLM 返回空响应,保留前一版本")
|
||||
if not prev_jrxml.strip():
|
||||
_node_log.warning("refine_layout 无输入 JRXML,跳过")
|
||||
return state
|
||||
jrxml = _extract_jrxml(full_text)
|
||||
if len(jrxml.strip()) < 200:
|
||||
_node_log.warning(f"refine_layout 输出过短({len(jrxml)} 字符),回退到前一版本")
|
||||
jrxml = prev_jrxml
|
||||
state["current_jrxml"] = jrxml
|
||||
state["conversation_history"].append({"role": "assistant", "content": jrxml})
|
||||
|
||||
# 拆解 JRXML
|
||||
parts = decompose_jrxml(prev_jrxml)
|
||||
if parts is None or parts.band_count == 0:
|
||||
_node_log.warning("refine_layout 拆解失败或无 band,跳过")
|
||||
return state
|
||||
|
||||
# 构建采样坐标
|
||||
ocr_rows = state.get("ocr_elements", [])
|
||||
sampled_text = _build_sampled_text(ocr_rows)
|
||||
template_ctx = _build_template_context(state)
|
||||
|
||||
# 逐 band 窗口化精调
|
||||
modified_bands: dict[str, str] = {}
|
||||
total_windows = 0
|
||||
for band in parts.bands:
|
||||
if band.element_count == 0:
|
||||
modified_bands[band.label] = band.band_xml
|
||||
continue
|
||||
|
||||
windows = split_band_into_windows(band, max_chars=4000)
|
||||
total_windows += len(windows)
|
||||
band_results: list[str] = []
|
||||
|
||||
for wi, win_xml in enumerate(windows):
|
||||
prompt = load_prompt("refine_layout").format(
|
||||
band_name=band.section_name,
|
||||
band_index=band.band_index,
|
||||
band_height=_extract_band_height(win_xml),
|
||||
window_index=wi + 1,
|
||||
total_windows=len(windows),
|
||||
xml_fragment=win_xml,
|
||||
sampled_coordinates=sampled_text,
|
||||
template_context=template_ctx,
|
||||
)
|
||||
try:
|
||||
response = llm.invoke(prompt)
|
||||
content = response.content if hasattr(response, "content") else str(response)
|
||||
fragment = _extract_xml_fragment(content)
|
||||
if fragment:
|
||||
band_results.append(fragment)
|
||||
writer({"type": "stream", "node": "refine_layout",
|
||||
"text": f"[{band.label} 窗口 {wi+1}/{len(windows)} 完成] "})
|
||||
else:
|
||||
_node_log.warning("refine_layout 窗口 %s/%d 返回空,使用原文",
|
||||
band.label, wi + 1)
|
||||
band_results.append(win_xml)
|
||||
except Exception as e:
|
||||
_node_log.warning("refine_layout 窗口 %s/%d LLM 失败: %s,使用原文",
|
||||
band.label, wi + 1, e)
|
||||
band_results.append(win_xml)
|
||||
|
||||
if len(band_results) == 1:
|
||||
modified_bands[band.label] = band_results[0]
|
||||
else:
|
||||
modified_bands[band.label] = reassemble_band_windows(band_results)
|
||||
|
||||
# 重组并校验
|
||||
result = reassemble_jrxml(parts, modified_bands)
|
||||
validation = validate_element_count(prev_jrxml, result, "refine_layout")
|
||||
_node_log.info(
|
||||
"refine_layout 窗口化完成: %d bands, %d 窗口, 元素 %d→%d (%.1f%%)",
|
||||
parts.band_count, total_windows,
|
||||
validation["original"], validation["modified"],
|
||||
validation["change_pct"] * 100,
|
||||
)
|
||||
|
||||
if not validation["ok"]:
|
||||
_node_log.error("refine_layout 元素丢失过多,回退到骨架版本")
|
||||
return state
|
||||
|
||||
state["current_jrxml"] = result
|
||||
state["conversation_history"].append({"role": "assistant", "content": result})
|
||||
return state
|
||||
|
||||
|
||||
@log_node("map_fields")
|
||||
def map_fields(state: AgentState) -> Dict:
|
||||
"""阶段三:将占位字段名替换为 OCR 提取的真实字段名。"""
|
||||
from langgraph.config import get_stream_writer
|
||||
"""阶段三:程序化字段映射 — 用正则将 $F{field_N} 替换为 OCR 字段名,不调 LLM。
|
||||
|
||||
writer = get_stream_writer()
|
||||
llm = get_llm(caller="map_fields")
|
||||
仅当 OCR 字段名包含中文等需要语义解释时才回退到 LLM。
|
||||
"""
|
||||
from agent.jrxml_windower import validate_element_count
|
||||
|
||||
ocr_result = state.get("ocr_extraction_result", {})
|
||||
fields_text = ""
|
||||
if isinstance(ocr_result, dict) and ocr_result.get("fields"):
|
||||
field_descs = []
|
||||
for f in ocr_result["fields"]:
|
||||
fname = f.get("field_name", "")
|
||||
fval = f.get("field_value", "")
|
||||
if fname:
|
||||
field_descs.append(f" - {fname}: {fval}")
|
||||
if field_descs:
|
||||
fields_text = "提取的字段:\n" + "\n".join(field_descs)
|
||||
if not fields_text:
|
||||
elements = ocr_result.get("elements", []) if isinstance(ocr_result, dict) else []
|
||||
if elements:
|
||||
texts = [e.get("text", "") for e in elements if e.get("text")]
|
||||
fields_text = "OCR 文本内容:\n" + "\n".join(f" - {t}" for t in texts[:50])
|
||||
|
||||
prompt = load_prompt("field_mapping").format(
|
||||
current_jrxml=state.get("current_jrxml", ""),
|
||||
ocr_fields=fields_text,
|
||||
)
|
||||
prev_jrxml = state.get("current_jrxml", "")
|
||||
full_text = _generate_with_continuation(llm, prompt, writer, "map_fields")
|
||||
# 空响应重试:有时 LLM 第一轮不输出,换个方式再试一次
|
||||
if not full_text.strip():
|
||||
_node_log.warning("map_fields 第一轮返回空响应,尝试简化 prompt 重试")
|
||||
retry_prompt = (
|
||||
"请将以下 JRXML 中的占位字段名 $F{field_1}, $F{field_2}, ... 替换为 OCR 提取的真实字段名。\n"
|
||||
"规则:根据列顺序映射——$F{field_1} 对应第1列,$F{field_2} 对应第2列,以此类推。\n"
|
||||
"同时更新 <field name=\"...\"> 声明和所有 $F{...} 引用。\n"
|
||||
"只输出完整 JRXML,不要解释。\n\n"
|
||||
f"OCR 字段:\n{fields_text}\n\n"
|
||||
f"JRXML:\n{prev_jrxml}"
|
||||
)
|
||||
full_text = _generate_with_continuation(llm, retry_prompt, writer, "map_fields")
|
||||
if not full_text.strip():
|
||||
_node_log.error("map_fields LLM 重试后仍返回空响应,保留占位字段版本")
|
||||
if not prev_jrxml.strip():
|
||||
_node_log.warning("map_fields 无输入 JRXML,跳过")
|
||||
return state
|
||||
jrxml = _extract_jrxml(full_text)
|
||||
if len(jrxml.strip()) < 200:
|
||||
_node_log.warning(f"map_fields 输出过短({len(jrxml)} 字符),回退到前一版本")
|
||||
jrxml = prev_jrxml
|
||||
state["current_jrxml"] = jrxml
|
||||
state["conversation_history"].append({"role": "assistant", "content": jrxml})
|
||||
|
||||
# 提取 OCR 字段列表
|
||||
ocr_result = state.get("ocr_extraction_result", {})
|
||||
ocr_fields: list[dict] = []
|
||||
if isinstance(ocr_result, dict) and ocr_result.get("fields"):
|
||||
ocr_fields = ocr_result["fields"]
|
||||
|
||||
if not ocr_fields:
|
||||
_node_log.info("map_fields 无 OCR 字段,保留占位字段名")
|
||||
state["conversation_history"].append({"role": "assistant", "content": prev_jrxml})
|
||||
return state
|
||||
|
||||
# 程序化替换(主路径)
|
||||
result = _programmatic_map_fields(prev_jrxml, ocr_fields)
|
||||
|
||||
# 校验
|
||||
validation = validate_element_count(prev_jrxml, result, "map_fields")
|
||||
_node_log.info(
|
||||
"map_fields 程序化完成: %d 个字段, 元素 %d→%d (%.1f%%)",
|
||||
len(ocr_fields),
|
||||
validation["original"], validation["modified"],
|
||||
validation["change_pct"] * 100,
|
||||
)
|
||||
|
||||
if not validation["ok"]:
|
||||
_node_log.error("map_fields 元素丢失过多,回退到前一版本")
|
||||
return state
|
||||
|
||||
state["current_jrxml"] = result
|
||||
state["conversation_history"].append({"role": "assistant", "content": result})
|
||||
return state
|
||||
|
||||
|
||||
@@ -810,6 +1040,7 @@ def modify_jrxml(state: AgentState) -> Dict:
|
||||
conversation_history=conv_text,
|
||||
modification_request=state.get("user_modification_request", ""),
|
||||
ocr_context=_format_ocr_context(state),
|
||||
template_context=_build_template_context(state),
|
||||
)
|
||||
prev_jrxml = state.get("current_jrxml", "")
|
||||
full_text = _generate_with_continuation(llm, prompt, writer, "modify_jrxml")
|
||||
@@ -1181,6 +1412,7 @@ def correct_jrxml(state: AgentState) -> Dict:
|
||||
ocr_context=ocr_context,
|
||||
layout_schema_text=layout_text,
|
||||
fidelity_context=fidelity_text,
|
||||
template_context=_build_template_context(state),
|
||||
)
|
||||
# 保存修正前状态(供 validate 判断是否写入错误知识库)
|
||||
state["last_error_case"] = {
|
||||
|
||||
@@ -51,3 +51,14 @@ class AgentState(TypedDict, total=False):
|
||||
# 需求9:分层精确生成
|
||||
layout_schema: dict # extract_layout_schema() 输出,列+区域结构
|
||||
ocr_elements: list # OCR 原始行数据(用于阶段二坐标采样)
|
||||
|
||||
# 需求10:多租户知识库
|
||||
kb_id: str # 当前会话绑定的知识库 ID
|
||||
kb_fields: list # KB 提取的字段定义 [{name, description, type, required}]
|
||||
kb_field_mapping: dict # OCR 字段 → KB 字段映射 {"工单号": "billNo", ...}
|
||||
uploaded_template_jrxml: str # 对话中上传的 JRXML 模板原文
|
||||
uploaded_template_params: list # 解析出的参数 [{name, type}]
|
||||
kb_template_jrxml: str # 从 KB 检索到的模板 JRXML
|
||||
kb_template_name: str # 检索到的模板名称
|
||||
datasource_mode: str # "parameter" 或 "jdbc"
|
||||
db_config: dict # JDBC 连接配置
|
||||
|
||||
Reference in New Issue
Block a user