Files
agent_jrxml/agent/jrxml_windower.py
T

400 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""JRXML 窗口化拆解与重组工具。
用于 3 阶段生成管道的 refine_layout 和 map_fields 节点:
- 将大段 JRXML 按 band 拆解为独立窗口
- 每个窗口独立发送给 LLM 进行坐标精调
- 重组所有窗口 + 校验元素完整性
调用者: agent/nodes.py (refine_layout, map_fields)
"""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Optional
import defusedxml.ElementTree as ET
from backend.logger import get_logger
# Module logger, obtained from the project's backend.logger.get_logger under
# the "jrxml.windower" name.
_windower_log = get_logger("jrxml.windower")
# Section containers whose <band> children are decomposed into windows.
# A section may hold one or more <band> elements.
_SECTION_TAGS = {
    "background",
    "title",
    "pageHeader",
    "columnHeader",
    "detail",
    "columnFooter",
    "pageFooter",
    "lastPageFooter",
    "summary",
    "noData",
}
# Report-level (non-band) elements that are never sent to the LLM; they are
# serialized into the header and emitted ahead of the sections on reassembly.
_HEADER_TAGS = {
    # report metadata and resources
    "property", "propertyExpression", "import", "template", "reportFont",
    "style", "subDataset", "scriptlet",
    # data definition
    "parameter", "queryString", "field", "sortField", "variable",
    "filterExpression",
    # group definitions are passed through whole, including anything nested
    "group",
}
@dataclass
class BandInfo:
    """Decomposition record for a single <band> of a JRXML section."""

    section_name: str  # owning section tag, e.g. "title" or "detail"
    band_index: int  # position of this band within its section (0-based)
    band_xml: str  # the complete serialized <band ...>...</band> text
    element_count: int  # element tally computed by decompose_jrxml
    char_length: int  # number of characters in band_xml

    @property
    def label(self) -> str:
        """Identifier used in logs and prompts, and as the reassembly key.

        The first band of a section is labelled with the bare section name;
        later bands get a ``_band<N>`` suffix (e.g. ``detail_band1``).
        """
        suffix = f"_band{self.band_index}" if self.band_index > 0 else ""
        return self.section_name + suffix
@dataclass
class JRXMLParts:
    """A JRXML document split into the pieces that reassembly stitches back."""

    declaration: str  # leading <?xml ...?> declaration, or "" when absent
    root_open: str  # the <jasperReport ...> start tag
    header_xml: str  # serialized non-band elements (fields, params, queryString, ...); not sent to the LLM
    bands: list[BandInfo]  # bands in document order
    footer: str  # the closing </jasperReport> tag

    @property
    def band_count(self) -> int:
        """Number of decomposed bands."""
        return len(self.bands)

    @property
    def total_elements(self) -> int:
        """Sum of ``element_count`` over every band."""
        total = 0
        for band in self.bands:
            total += band.element_count
        return total
# ── 拆解 ──────────────────────────────────────────────────────────
def decompose_jrxml(jrxml: str) -> Optional[JRXMLParts]:
    """Split a JRXML string into declaration, header, bands and footer.

    The document is parsed with defusedxml.ElementTree. Root children whose
    tag is in ``_HEADER_TAGS`` are re-serialized (in document order) into
    ``header_xml``. Every ``<band>`` directly inside a ``_SECTION_TAGS``
    container becomes one :class:`BandInfo`. Root children in neither set,
    and non-band children of a section, are not represented in the result.
    A warning is logged for each so the loss is visible.

    Args:
        jrxml: Complete JRXML document text.

    Returns:
        The decomposition, or ``None`` in three cases: the text does not
        parse, it is rejected by defusedxml (e.g. entity declarations), or
        the root element is not ``jasperReport``.
    """
    try:
        root = ET.fromstring(jrxml)
    except (ET.ParseError, ValueError) as e:
        # defusedxml reports forbidden constructs with DefusedXmlException,
        # which subclasses ValueError rather than ParseError.
        _windower_log.error("JRXML 解析失败: %s", e)
        return None

    tag = _local_tag(root.tag)
    if tag != "jasperReport":
        _windower_log.error("根元素不是 jasperReport: %s", tag)
        return None

    # Keep the leading <?xml ...?> declaration verbatim, if present.
    declaration = ""
    if jrxml.strip().startswith("<?xml"):
        decl_end = jrxml.find("?>")
        if decl_end != -1:
            declaration = jrxml[:decl_end + 2]

    # Reuse the original root start tag (namespace declarations etc.).
    root_open = _build_root_open(jrxml, root)

    header_parts: list[str] = []
    bands: list[BandInfo] = []
    for child in root:
        child_tag = _local_tag(child.tag)
        if child_tag in _HEADER_TAGS:
            header_parts.append(_elem_to_string(child))
        elif child_tag in _SECTION_TAGS:
            # Number only <band> children so labels stay dense
            # ("detail", "detail_band1", ...) even if a section holds
            # other nodes.
            band_index = 0
            for band_elem in child:
                band_tag = _local_tag(band_elem.tag)
                if band_tag != "band":
                    _windower_log.warning(
                        "section %s 中的非 band 子元素 <%s> 不会被保留",
                        child_tag, band_tag,
                    )
                    continue
                band_xml = _elem_to_string(band_elem)
                bands.append(BandInfo(
                    section_name=child_tag,
                    band_index=band_index,
                    band_xml=band_xml,
                    element_count=_count_elements_in_text(band_xml),
                    char_length=len(band_xml),
                ))
                band_index += 1
        else:
            _windower_log.warning("未识别的顶层元素 <%s> 不会被保留", child_tag)
    header_xml = "\n".join(header_parts)

    # Closing </jasperReport> tag (namespace prefix preserved if present).
    footer = _extract_footer(jrxml)

    parts = JRXMLParts(
        declaration=declaration,
        root_open=root_open,
        header_xml=header_xml,
        bands=bands,
        footer=footer,
    )
    _windower_log.info(
        "JRXML 拆解完成: %d bands, %d 个元素, header %d 字符",
        len(bands), parts.total_elements, len(header_xml),
    )
    return parts
# ── 窗口切分 ──────────────────────────────────────────────────────
# Safe element boundaries: a window may be cut right after one of these
# closing tags. An optional namespace prefix (e.g. "ns0:") is accepted, and
# trailing whitespace is consumed so it stays with the preceding piece.
_SAFE_SPLIT_CLOSING = re.compile(
    r"</(?:[\w:]+:)?(?:textField|staticText|line|rectangle|ellipse|image|"
    r"frame|subreport|elementGroup|break|componentElement)>\s*"
)
def split_band_into_windows(band: BandInfo, max_chars: int = 4000) -> list[str]:
    """Cut one band's XML into windows of roughly ``max_chars`` characters.

    Cuts happen only after the closing tags matched by ``_SAFE_SPLIT_CLOSING``,
    and each window is wrapped in the band's own start and end tags.

    The band is returned whole, as a single window, in three cases: it
    already fits in ``max_chars``, it has no inner content, or splitting
    yields at most one segment. A window holding one oversized segment may
    still exceed ``max_chars``.
    """
    if band.char_length > max_chars:
        body = _extract_band_inner(band.band_xml)
        if body:
            pieces = _split_at_boundaries(body, _SAFE_SPLIT_CLOSING)
            if len(pieces) > 1:
                return _greedy_aggregate(pieces, band.band_xml, max_chars)
    return [band.band_xml]
# ── 重组 ──────────────────────────────────────────────────────────
# Tokenizer for _recalc_band_height. Comments, CDATA sections and PIs /
# declarations match as opaque tokens (group 2 is None). Any other match is a
# tag: group 1 is "/" for end tags, group 2 is the (possibly prefixed) name,
# and group 3 is the raw attribute text, which ends in "/" for empty elements.
_RECALC_TAG_RE = re.compile(
    r"<!--.*?-->|<!\[CDATA\[.*?\]\]>|<[?!][^>]*>|<(/?)([\w.:-]+)([^>]*)>",
    re.DOTALL,
)
_RECALC_Y_RE = re.compile(r"""\sy\s*=\s*["'](-?\d+)["']""")
_RECALC_HEIGHT_RE = re.compile(r"""\sheight\s*=\s*["'](\d+)["']""")
# Band start tag (optionally prefixed, e.g. "<ns0:band") up to its height
# value; group 2 captures the quote character so it can be reused.
_RECALC_BAND_HEIGHT_RE = re.compile(
    r"""(<(?:[\w.-]+:)?band\b[^>]*?\sheight\s*=\s*)(["'])\d+\2"""
)
def _band_level_bottoms(band_xml: str) -> list[int]:
    """Return ``y + height`` for each reportElement laid out on the band itself.

    A reportElement positions the element that encloses it (textField,
    frame, ...). It is counted only when every tag between the innermost open
    ``band`` and that owner is an ``elementGroup``. Children of a ``frame``
    or of component cells are positioned relative to their container, not
    the band, so they must not drive the band height. Elements without a
    numeric ``y`` or ``height`` are ignored.
    """
    bottoms: list[int] = []
    open_tags: list[str] = []
    for m in _RECALC_TAG_RE.finditer(band_xml):
        qname = m.group(2)
        if qname is None:
            continue  # comment, CDATA section or processing instruction
        local = qname.rpartition(":")[2]
        attrs = m.group(3)
        if m.group(1):
            # End tag: unwind to the matching start tag; ignore stray ones.
            if local in open_tags:
                while open_tags.pop() != local:
                    pass
            continue
        if local == "reportElement":
            if "band" in open_tags:
                first = len(open_tags) - open_tags[::-1].index("band")
            else:
                first = 0
            # open_tags[-1] is the owner; anything in between must be a group.
            if all(t == "elementGroup" for t in open_tags[first:-1]):
                ym = _RECALC_Y_RE.search(attrs)
                hm = _RECALC_HEIGHT_RE.search(attrs)
                if ym and hm:
                    bottoms.append(int(ym.group(1)) + int(hm.group(1)))
        if not attrs.rstrip().endswith("/"):
            open_tags.append(local)
    return bottoms
def _recalc_band_height(band_xml: str, margin: int = 20) -> str:
    """Set the band's ``height`` to its lowest element bottom plus ``margin``.

    Only reportElements positioned directly on the band are considered (see
    ``_band_level_bottoms``). Namespace-prefixed markup as produced by
    ElementTree serialization (``<ns0:band>``, ``<ns0:reportElement .../>``)
    is handled, as are reportElements that have child elements.

    Args:
        band_xml: Serialized ``<band>`` element, possibly LLM-edited.
        margin: Extra space added below the lowest element.

    Returns:
        ``band_xml`` with the height of its first band start tag replaced.
        It is returned unchanged in two cases: no band-level element has both
        ``y`` and ``height``, or the band tag carries no ``height``.
    """
    max_bottom = max(_band_level_bottoms(band_xml), default=0)
    if max_bottom <= 0:
        return band_xml
    new_height = max_bottom + margin
    return _RECALC_BAND_HEIGHT_RE.sub(
        lambda m: f"{m.group(1)}{m.group(2)}{new_height}{m.group(2)}",
        band_xml,
        count=1,
    )
def reassemble_band_windows(modified_windows: list[str]) -> str:
    """Merge the (LLM-edited) windows of one band back into a single band.

    A single window is used as-is. Otherwise three pieces are joined, one
    per line:

    - the start tag, cut from the first window up to its first '>';
    - the stripped inner content of every window that has any;
    - the closing tag, taken from the last window.

    If the first window contains no '>', the windows are simply
    newline-joined. In every case the band height is then recomputed by
    ``_recalc_band_height``.
    """
    if len(modified_windows) == 1:
        merged = modified_windows[0]
    else:
        head = modified_windows[0]
        cut = head.find(">")
        if cut == -1:
            merged = "\n".join(modified_windows)
        else:
            inners = [chunk for chunk in map(_extract_band_inner, modified_windows) if chunk]
            merged = "\n".join((
                head[:cut + 1],
                "\n".join(inners),
                _extract_band_close(modified_windows[-1]),
            ))
    return _recalc_band_height(merged)
def reassemble_jrxml(parts: JRXMLParts, modified_bands: dict[str, str]) -> str:
"""将修改后的 bands 与 header/footer 重新组装为完整 JRXML。
modified_bands 的 key 格式为 "{section_name}_band{index}""{section_name}"index=0 时)。
"""
result = []
if parts.declaration:
result.append(parts.declaration)
result.append(parts.root_open)
if parts.header_xml.strip():
result.append(parts.header_xml)
current_section = None
for band in parts.bands:
if band.section_name != current_section:
if current_section is not None:
result.append(f"</{current_section}>")
current_section = band.section_name
result.append(f"<{current_section}>")
modified = modified_bands.get(band.label, band.band_xml)
result.append(modified)
if current_section is not None:
result.append(f"</{current_section}>")
result.append(parts.footer)
return "\n".join(result)
# ── 元素计数与校验 ────────────────────────────────────────────────
# Start tags of textField, staticText and field, optionally namespace-prefixed
# (e.g. "<ns0:textField"), matched case-insensitively. The trailing \b keeps
# tags such as <textFieldExpression> from matching.
_ELEMENT_RE = re.compile(r"<(?:[\w:]+:)?(textField|staticText|field)\b", re.IGNORECASE)
def count_elements(jrxml: str) -> int:
    """Count textField, staticText and field start tags in ``jrxml``.

    Counting is purely textual (regex based), so the input does not need to
    be well-formed XML.
    """
    return sum(1 for _ in _ELEMENT_RE.finditer(jrxml))
def validate_element_count(
    original: str,
    modified: str,
    stage: str,
    *,
    max_change: float = 0.10,
    warn_change: float = 0.05,
) -> dict:
    """Compare element counts before and after a processing stage.

    Args:
        original: JRXML (or fragment) before the stage.
        modified: JRXML (or fragment) after the stage.
        stage: Stage name, used as the prefix of log messages.
        max_change: Largest tolerated relative change. Above it ``ok`` is
            False and an error is logged. Defaults to 10 %.
        warn_change: Relative change above which a warning is logged while
            ``ok`` stays True. Defaults to 5 %.

    Returns:
        ``{"ok": bool, "original": int, "modified": int, "change_pct": float}``.
        ``change_pct`` is the relative change ``|modified - original| /
        original``, expressed as a fraction (not a percentage) and rounded to
        4 decimals. When the original has no elements, the check always
        passes with ``change_pct`` 0.0. Callers should roll back when ``ok``
        is False.
    """
    orig = count_elements(original)
    mod = count_elements(modified)
    if orig == 0:
        # No baseline to compare against.
        return {"ok": True, "original": 0, "modified": mod, "change_pct": 0.0}
    change = abs(mod - orig) / orig
    ok = change <= max_change
    if not ok:
        _windower_log.error(
            "%s 元素数变化过大: %d -> %d (%.1f%%)",
            stage, orig, mod, change * 100,
        )
    elif change > warn_change:
        _windower_log.warning(
            "%s 元素数有差异: %d -> %d (%.1f%%)",
            stage, orig, mod, change * 100,
        )
    return {"ok": ok, "original": orig, "modified": mod, "change_pct": round(change, 4)}
# ── 内部工具函数 ──────────────────────────────────────────────────
def _local_tag(tag: str) -> str:
    """Return the local part of an ElementTree tag, dropping any ``{uri}`` prefix."""
    if "}" not in tag:
        return tag
    return tag.rsplit("}", 1)[1]
def _elem_to_string(elem: ET.Element) -> str:
    """Serialize ``elem`` to text with ElementTree's ``tostring``.

    ``tostring`` also writes the element's tail text (e.g. the indentation
    that followed it in the source document), so the result is stripped of
    surrounding whitespace.
    """
    text = ET.tostring(elem, encoding="unicode")
    return text.strip()
def _build_root_open(jrxml: str, root: ET.Element) -> str:
    """Return the ``<jasperReport ...>`` start tag to use on reassembly.

    When an unprefixed ``<jasperReport`` start tag is present in ``jrxml``,
    it is copied verbatim so namespace declarations and attribute formatting
    survive. Otherwise the tag is rebuilt from the parsed ``root`` so the
    result is well-formed:

    - attribute values are XML-escaped;
    - ElementTree ``{uri}name`` attribute keys get a declared prefix;
    - a namespaced root tag gets a default ``xmlns`` declaration.
    """
    from xml.sax.saxutils import quoteattr

    m = re.search(r"<jasperReport\b[^>]*>", jrxml, re.IGNORECASE)
    if m:
        return m.group(0)

    attrs: list[str] = []
    if root.tag.startswith("{"):
        root_uri = root.tag[1:].partition("}")[0]
        attrs.append("xmlns=" + quoteattr(root_uri))
    prefixes: dict[str, str] = {}
    for key, value in root.attrib.items():
        if key.startswith("{"):
            # ElementTree exposes namespaced attributes as "{uri}local",
            # which is not a legal XML attribute name; map it to a prefix.
            uri, _, local = key[1:].partition("}")
            prefix = prefixes.get(uri)
            if prefix is None:
                if uri == "http://www.w3.org/2001/XMLSchema-instance":
                    prefix = "xsi"
                else:
                    prefix = f"ns{len(prefixes) + 1}"
                prefixes[uri] = prefix
                attrs.append(f"xmlns:{prefix}=" + quoteattr(uri))
            key = f"{prefix}:{local}"
        attrs.append(f"{key}=" + quoteattr(value))
    return "<jasperReport " + " ".join(attrs) + ">"
def _extract_footer(jrxml: str) -> str:
    """Return the document's closing root tag, keeping any namespace prefix.

    Falls back to a plain ``</jasperReport>`` when the text does not end
    (ignoring trailing whitespace) with a closing jasperReport tag.
    """
    found = re.search(r"</(?:[\w:]+:)?jasperReport>\s*$", jrxml, re.IGNORECASE)
    return found.group(0).rstrip() if found else "</jasperReport>"
# Closing band tag at the very end of the text (optional namespace prefix,
# trailing whitespace allowed), e.g. "</band>" or "</ns0:band>".
_BAND_CLOSE_RE = re.compile(r"</(?:[\w:]+:)?band>\s*$", re.IGNORECASE)
def _extract_band_close(band_xml: str) -> str:
    """Return the band's closing tag as written, e.g. '</ns0:band>' or '</band>'.

    Defaults to '</band>' when ``band_xml`` does not end with one.
    """
    match = _BAND_CLOSE_RE.search(band_xml)
    if match is None:
        return "</band>"
    return match.group(0).rstrip()
def _extract_band_inner(band_xml: str) -> str:
    """Return the stripped content between a band's start and end tags.

    The start tag is taken to end at the first '>' of ``band_xml``. When no
    closing band tag (prefixed or not) ends the text, everything after the
    start tag is returned. Returns '' when there is no '>' at all.
    """
    open_end = band_xml.find(">")
    if open_end < 0:
        return ""
    closing = _BAND_CLOSE_RE.search(band_xml)
    stop = closing.start() if closing else len(band_xml)
    return band_xml[open_end + 1:stop].strip()
# Tokens used by _split_at_boundaries to track element nesting depth.
# Comments, CDATA and PIs / declarations match with group 2 None. For tags,
# group 1 is "/" on end tags and group 2 (the attribute text) ends with "/"
# on empty-element tags.
_SPLIT_TOKEN_RE = re.compile(
    r"<!--.*?-->|<!\[CDATA\[.*?\]\]>|<[?!][^>]*>|<(/?)[\w.:-]+([^>]*)>",
    re.DOTALL,
)
def _split_at_boundaries(text: str, boundary_re: re.Pattern) -> list[str]:
    """Split ``text`` right after boundary matches that close a top-level element.

    A boundary match (typically a closing tag plus trailing whitespace) only
    ends a segment when every element opened so far in ``text`` has been
    closed again. In other words, it must close a direct child of the band
    whose inner XML is being split. Closing tags of elements nested in a
    ``<frame>``, a component, etc. are skipped, so no segment cuts a
    container in half.

    The matched text stays attached to the end of the preceding segment.
    Any remainder after the last accepted boundary becomes the final segment.
    Text with no accepted boundary comes back as a single segment
    (``[""]`` for empty text). Concatenating the segments reproduces
    ``text``.
    """
    segments: list[str] = []
    depth = 0
    scanned = 0
    last_end = 0
    for m in boundary_re.finditer(text):
        end = m.end()
        # Advance the nesting depth over every tag up to the end of this match.
        for tok in _SPLIT_TOKEN_RE.finditer(text, scanned, end):
            if tok.group(2) is None:
                continue  # comment / CDATA / processing instruction
            if tok.group(1):
                depth = max(depth - 1, 0)  # tolerate stray end tags
            elif not tok.group(2).rstrip().endswith("/"):
                depth += 1
        scanned = end
        if depth > 0:
            continue  # boundary sits inside a still-open container
        segments.append(text[last_end:end])
        last_end = end
    if last_end < len(text):
        segments.append(text[last_end:])
    elif not segments:
        segments.append(text)
    return segments
def _greedy_aggregate(segments: list[str], band_xml: str, max_chars: int) -> list[str]:
    """Greedily pack consecutive segments into windows of at most ``max_chars``.

    Each window is the band's start tag, a newline, the concatenated
    segments, a newline and the band's end tag. The start tag is taken up to
    the first '>' of ``band_xml`` (default ``<band>``), and the end tag comes
    from ``_extract_band_close``. A segment that alone exceeds the budget
    still gets its own (oversized) window. Segment order is preserved.
    """
    tag_end = band_xml.find(">")
    band_open = band_xml[:tag_end + 1] if tag_end != -1 else "<band>"
    band_close = _extract_band_close(band_xml)
    # Two "\n" separators surround the segment text of every window.
    overhead = len(band_open) + len(band_close) + 2

    def wrap(chunk: list[str]) -> str:
        return band_open + "\n" + "".join(chunk) + "\n" + band_close

    windows: list[str] = []
    current: list[str] = []
    current_len = overhead
    for seg in segments:
        if current and current_len + len(seg) > max_chars:
            windows.append(wrap(current))
            current = []
            current_len = overhead
        current.append(seg)
        current_len += len(seg)
    if current:
        windows.append(wrap(current))
    return windows
def _count_elements_in_text(xml_text: str) -> int:
    """Count element start tags in ``xml_text`` using ``_ELEMENT_RE``.

    The pattern matches textField and staticText start tags, and also
    ``field`` start tags.
    """
    return sum(1 for _ in _ELEMENT_RE.finditer(xml_text))