WIP: baseline on fix/retry-failure-root-causes

This commit is contained in:
2026-05-24 22:38:30 +08:00
parent 2d5183d2bd
commit f25a93b539
5 changed files with 438 additions and 35 deletions
+25 -3
View File
@@ -181,6 +181,28 @@ def split_band_into_windows(band: BandInfo, max_chars: int = 4000) -> list[str]:
# ── 重组 ──────────────────────────────────────────────────────────
def _recalc_band_height(band_xml: str, margin: int = 20) -> str:
"""根据波段内所有子元素的 y + height 重新计算波段 height。"""
max_bottom = 0
for m in re.finditer(r'<reportElement\b([^>]*)/>', band_xml):
attrs = m.group(1)
ym = re.search(r'\sy\s*=\s*"(\d+)"', attrs)
hm = re.search(r'\sheight\s*=\s*"(\d+)"', attrs)
if ym and hm:
bottom = int(ym.group(1)) + int(hm.group(1))
if bottom > max_bottom:
max_bottom = bottom
if max_bottom == 0:
return band_xml
new_height = max_bottom + margin
return re.sub(
r'(<band\b[^>]*\sheight\s*=\s*)"(\d+)"',
rf'\g<1>"{new_height}"',
band_xml,
count=1,
)
def reassemble_band_windows(modified_windows: list[str]) -> str:
"""将多个窗口的修改结果重新合并为一个 band XML。
@@ -188,12 +210,12 @@ def reassemble_band_windows(modified_windows: list[str]) -> str:
中间拼接所有窗口内部的元素内容。
"""
if len(modified_windows) == 1:
return modified_windows[0]
return _recalc_band_height(modified_windows[0])
first = modified_windows[0]
band_open_end = first.find(">")
if band_open_end == -1:
return "\n".join(modified_windows)
return _recalc_band_height("\n".join(modified_windows))
band_open = first[:band_open_end + 1]
last = modified_windows[-1]
@@ -205,7 +227,7 @@ def reassemble_band_windows(modified_windows: list[str]) -> str:
if inner:
inner_parts.append(inner)
return band_open + "\n" + "\n".join(inner_parts) + "\n" + band_close
return _recalc_band_height(band_open + "\n" + "\n".join(inner_parts) + "\n" + band_close)
def reassemble_jrxml(parts: JRXMLParts, modified_bands: dict[str, str]) -> str:
+81 -16
View File
@@ -124,13 +124,8 @@ def process_input(state: AgentState) -> Dict:
try:
from backend.ocr_extractor import OcrExtractor
extractor = OcrExtractor()
default_fields = [
"发票代码", "发票号码", "开票日期", "合计金额", "校验码",
"价税合计", "总金额", "日期", "金额", "数量", "单价", "税率",
"购买方名称", "销售方名称", "货物名称", "规格型号",
"不含税金额", "税额",
]
ocr_result = extractor.extract(uploaded_path, default_fields)
# 不传预设字段名,让 OCR 自动发现文档中的所有键值对
ocr_result = extractor.extract(uploaded_path)
if ocr_result.get("ocr_available"):
state["ocr_extraction_result"] = ocr_result
_node_log.info(
@@ -483,12 +478,18 @@ def _format_row_coordinates(row: dict) -> dict:
sorted_elems = sorted(elements, key=lambda e: e.get("x", 0))
cols = []
for ci, e in enumerate(sorted_elems):
x = e.get("x", 0)
y = e.get("y", 0)
w = e.get("w", 0)
h = e.get("h", 0)
if not (x > 0 and y > 0 and w > 0 and h > 0):
continue
cols.append({
"col": ci,
"x": e.get("x", 0),
"y": e.get("y", 0),
"w": e.get("w", 0),
"h": e.get("h", 0),
"x": x,
"y": y,
"w": w,
"h": h,
"font_size": e.get("font_size", 12),
"text": e.get("text", ""),
})
@@ -529,10 +530,33 @@ def _extract_xml_fragment(text: str) -> str:
return text
def _count_zero_coordinate_elements(xml: str) -> tuple[int, int]:
"""统计坐标无效(x=0 或 y=0 或 width=0 或 height=0)的 reportElement 数量。
返回 (zero_count, total_count)。
"""
total = 0
zero = 0
for m in re.finditer(r'<reportElement\b([^>]*)/>', xml):
total += 1
attrs = m.group(1)
xm = re.search(r'\sx\s*=\s*"(\d+)"', attrs)
ym = re.search(r'\sy\s*=\s*"(\d+)"', attrs)
wm = re.search(r'\swidth\s*=\s*"(\d+)"', attrs)
hm = re.search(r'\sheight\s*=\s*"(\d+)"', attrs)
x = int(xm.group(1)) if xm else 0
y = int(ym.group(1)) if ym else 0
w = int(wm.group(1)) if wm else 0
h = int(hm.group(1)) if hm else 0
if x == 0 or y == 0 or w == 0 or h == 0:
zero += 1
return zero, total
def _programmatic_map_fields(jrxml: str, ocr_fields: list[dict]) -> str:
"""程序化字段映射:将 $F{{field_N}} 替换为 OCR 提取的真实字段名。
纯正则替换,不调 LLM。100% 确定性,零内容丢失。
未映射的 field_N 会被重命名为基于波段上下文的描述性名称。
"""
result = jrxml
for i, f in enumerate(ocr_fields):
@@ -543,13 +567,45 @@ def _programmatic_map_fields(jrxml: str, ocr_fields: list[dict]) -> str:
real_name = _sanitize_field_name(raw_name)
if real_name == placeholder:
continue
# 替换 field 声明: <ns0:field name="field_1" → <ns0:field name="customer_name"
result = re.sub(
rf'(<[\w:]*field\b[^>]*\bname\s*=\s*"){re.escape(placeholder)}(")',
rf'\g<1>{real_name}\g<2>', result,
)
# 替换所有引用: $F{{field_1}} → $F{{customer_name}}
result = result.replace(f'$F{{{placeholder}}}', f'$F{{{real_name}}}')
# 第二遍:为剩余未映射的 field_N 赋予基于波段位置的描述性名称
remaining = set()
for m in re.finditer(r'\$F\{(field_\d+)\}', result):
remaining.add(m.group(1))
if remaining:
_SECTION_TAGS = (
"title", "pageHeader", "columnHeader", "detail", "columnFooter",
"pageFooter", "summary", "background", "noData",
)
for placeholder in sorted(remaining, key=lambda x: int(re.search(r'\d+', x).group())):
n = int(re.search(r'\d+', placeholder).group())
# 查找第一个引用此字段的位置,确定波段上下文
pattern = rf'\$F\{{{re.escape(placeholder)}\}}'
m = re.search(pattern, result)
section = "data"
if m:
before = result[:m.start()]
# 从后往前找最近的 section 标签
for tag in _SECTION_TAGS:
# 找最近的未闭合 section 标签
opens = [o.start() for o in re.finditer(rf'<{tag}>', before)]
closes = [o.start() for o in re.finditer(rf'</{tag}>', before)]
last_open = opens[-1] if opens else -1
last_close = closes[-1] if closes else -1
if last_open > last_close:
section = tag
break
new_name = f"{section}_f{n}"
result = result.replace(f'$F{{{placeholder}}}', f'$F{{{new_name}}}')
result = re.sub(
rf'(<[\w:]*field\b[^>]*\bname\s*=\s*"){re.escape(placeholder)}(")',
rf'\g<1>{new_name}\g<2>', result,
)
return result
@@ -943,9 +999,18 @@ def refine_layout(state: AgentState) -> Dict:
content = response.content if hasattr(response, "content") else str(response)
fragment = _extract_xml_fragment(content)
if fragment:
band_results.append(fragment)
writer({"type": "stream", "node": "refine_layout",
"text": f"[{band.label} 窗口 {wi+1}/{len(windows)} 完成] "})
zero_count, total = _count_zero_coordinate_elements(fragment)
if total > 0 and zero_count / total > 0.3:
_node_log.warning(
"refine_layout 窗口 %s/%d 零坐标元素 %d/%d (%.0f%%),使用原文",
band.label, wi + 1, zero_count, total,
zero_count / total * 100,
)
band_results.append(win_xml)
else:
band_results.append(fragment)
writer({"type": "stream", "node": "refine_layout",
"text": f"[{band.label} 窗口 {wi+1}/{len(windows)} 完成] "})
else:
_node_log.warning("refine_layout 窗口 %s/%d 返回空,使用原文",
band.label, wi + 1)