WIP: baseline on fix/retry-failure-root-causes
This commit is contained in:
+25
-3
@@ -181,6 +181,28 @@ def split_band_into_windows(band: BandInfo, max_chars: int = 4000) -> list[str]:
|
||||
|
||||
# ── 重组 ──────────────────────────────────────────────────────────
|
||||
|
||||
def _recalc_band_height(band_xml: str, margin: int = 20) -> str:
|
||||
"""根据波段内所有子元素的 y + height 重新计算波段 height。"""
|
||||
max_bottom = 0
|
||||
for m in re.finditer(r'<reportElement\b([^>]*)/>', band_xml):
|
||||
attrs = m.group(1)
|
||||
ym = re.search(r'\sy\s*=\s*"(\d+)"', attrs)
|
||||
hm = re.search(r'\sheight\s*=\s*"(\d+)"', attrs)
|
||||
if ym and hm:
|
||||
bottom = int(ym.group(1)) + int(hm.group(1))
|
||||
if bottom > max_bottom:
|
||||
max_bottom = bottom
|
||||
if max_bottom == 0:
|
||||
return band_xml
|
||||
new_height = max_bottom + margin
|
||||
return re.sub(
|
||||
r'(<band\b[^>]*\sheight\s*=\s*)"(\d+)"',
|
||||
rf'\g<1>"{new_height}"',
|
||||
band_xml,
|
||||
count=1,
|
||||
)
|
||||
|
||||
|
||||
def reassemble_band_windows(modified_windows: list[str]) -> str:
|
||||
"""将多个窗口的修改结果重新合并为一个 band XML。
|
||||
|
||||
@@ -188,12 +210,12 @@ def reassemble_band_windows(modified_windows: list[str]) -> str:
|
||||
中间拼接所有窗口内部的元素内容。
|
||||
"""
|
||||
if len(modified_windows) == 1:
|
||||
return modified_windows[0]
|
||||
return _recalc_band_height(modified_windows[0])
|
||||
|
||||
first = modified_windows[0]
|
||||
band_open_end = first.find(">")
|
||||
if band_open_end == -1:
|
||||
return "\n".join(modified_windows)
|
||||
return _recalc_band_height("\n".join(modified_windows))
|
||||
band_open = first[:band_open_end + 1]
|
||||
|
||||
last = modified_windows[-1]
|
||||
@@ -205,7 +227,7 @@ def reassemble_band_windows(modified_windows: list[str]) -> str:
|
||||
if inner:
|
||||
inner_parts.append(inner)
|
||||
|
||||
return band_open + "\n" + "\n".join(inner_parts) + "\n" + band_close
|
||||
return _recalc_band_height(band_open + "\n" + "\n".join(inner_parts) + "\n" + band_close)
|
||||
|
||||
|
||||
def reassemble_jrxml(parts: JRXMLParts, modified_bands: dict[str, str]) -> str:
|
||||
|
||||
+81
-16
@@ -124,13 +124,8 @@ def process_input(state: AgentState) -> Dict:
|
||||
try:
|
||||
from backend.ocr_extractor import OcrExtractor
|
||||
extractor = OcrExtractor()
|
||||
default_fields = [
|
||||
"发票代码", "发票号码", "开票日期", "合计金额", "校验码",
|
||||
"价税合计", "总金额", "日期", "金额", "数量", "单价", "税率",
|
||||
"购买方名称", "销售方名称", "货物名称", "规格型号",
|
||||
"不含税金额", "税额",
|
||||
]
|
||||
ocr_result = extractor.extract(uploaded_path, default_fields)
|
||||
# 不传预设字段名,让 OCR 自动发现文档中的所有键值对
|
||||
ocr_result = extractor.extract(uploaded_path)
|
||||
if ocr_result.get("ocr_available"):
|
||||
state["ocr_extraction_result"] = ocr_result
|
||||
_node_log.info(
|
||||
@@ -483,12 +478,18 @@ def _format_row_coordinates(row: dict) -> dict:
|
||||
sorted_elems = sorted(elements, key=lambda e: e.get("x", 0))
|
||||
cols = []
|
||||
for ci, e in enumerate(sorted_elems):
|
||||
x = e.get("x", 0)
|
||||
y = e.get("y", 0)
|
||||
w = e.get("w", 0)
|
||||
h = e.get("h", 0)
|
||||
if not (x > 0 and y > 0 and w > 0 and h > 0):
|
||||
continue
|
||||
cols.append({
|
||||
"col": ci,
|
||||
"x": e.get("x", 0),
|
||||
"y": e.get("y", 0),
|
||||
"w": e.get("w", 0),
|
||||
"h": e.get("h", 0),
|
||||
"x": x,
|
||||
"y": y,
|
||||
"w": w,
|
||||
"h": h,
|
||||
"font_size": e.get("font_size", 12),
|
||||
"text": e.get("text", ""),
|
||||
})
|
||||
@@ -529,10 +530,33 @@ def _extract_xml_fragment(text: str) -> str:
|
||||
return text
|
||||
|
||||
|
||||
def _count_zero_coordinate_elements(xml: str) -> tuple[int, int]:
|
||||
"""统计坐标无效(x=0 或 y=0 或 width=0 或 height=0)的 reportElement 数量。
|
||||
返回 (zero_count, total_count)。
|
||||
"""
|
||||
total = 0
|
||||
zero = 0
|
||||
for m in re.finditer(r'<reportElement\b([^>]*)/>', xml):
|
||||
total += 1
|
||||
attrs = m.group(1)
|
||||
xm = re.search(r'\sx\s*=\s*"(\d+)"', attrs)
|
||||
ym = re.search(r'\sy\s*=\s*"(\d+)"', attrs)
|
||||
wm = re.search(r'\swidth\s*=\s*"(\d+)"', attrs)
|
||||
hm = re.search(r'\sheight\s*=\s*"(\d+)"', attrs)
|
||||
x = int(xm.group(1)) if xm else 0
|
||||
y = int(ym.group(1)) if ym else 0
|
||||
w = int(wm.group(1)) if wm else 0
|
||||
h = int(hm.group(1)) if hm else 0
|
||||
if x == 0 or y == 0 or w == 0 or h == 0:
|
||||
zero += 1
|
||||
return zero, total
|
||||
|
||||
|
||||
def _programmatic_map_fields(jrxml: str, ocr_fields: list[dict]) -> str:
|
||||
"""程序化字段映射:将 $F{{field_N}} 替换为 OCR 提取的真实字段名。
|
||||
|
||||
纯正则替换,不调 LLM。100% 确定性,零内容丢失。
|
||||
未映射的 field_N 会被重命名为基于波段上下文的描述性名称。
|
||||
"""
|
||||
result = jrxml
|
||||
for i, f in enumerate(ocr_fields):
|
||||
@@ -543,13 +567,45 @@ def _programmatic_map_fields(jrxml: str, ocr_fields: list[dict]) -> str:
|
||||
real_name = _sanitize_field_name(raw_name)
|
||||
if real_name == placeholder:
|
||||
continue
|
||||
# 替换 field 声明: <ns0:field name="field_1" → <ns0:field name="customer_name"
|
||||
result = re.sub(
|
||||
rf'(<[\w:]*field\b[^>]*\bname\s*=\s*"){re.escape(placeholder)}(")',
|
||||
rf'\g<1>{real_name}\g<2>', result,
|
||||
)
|
||||
# 替换所有引用: $F{{field_1}} → $F{{customer_name}}
|
||||
result = result.replace(f'$F{{{placeholder}}}', f'$F{{{real_name}}}')
|
||||
|
||||
# 第二遍:为剩余未映射的 field_N 赋予基于波段位置的描述性名称
|
||||
remaining = set()
|
||||
for m in re.finditer(r'\$F\{(field_\d+)\}', result):
|
||||
remaining.add(m.group(1))
|
||||
if remaining:
|
||||
_SECTION_TAGS = (
|
||||
"title", "pageHeader", "columnHeader", "detail", "columnFooter",
|
||||
"pageFooter", "summary", "background", "noData",
|
||||
)
|
||||
for placeholder in sorted(remaining, key=lambda x: int(re.search(r'\d+', x).group())):
|
||||
n = int(re.search(r'\d+', placeholder).group())
|
||||
# 查找第一个引用此字段的位置,确定波段上下文
|
||||
pattern = rf'\$F\{{{re.escape(placeholder)}\}}'
|
||||
m = re.search(pattern, result)
|
||||
section = "data"
|
||||
if m:
|
||||
before = result[:m.start()]
|
||||
# 从后往前找最近的 section 标签
|
||||
for tag in _SECTION_TAGS:
|
||||
# 找最近的未闭合 section 标签
|
||||
opens = [o.start() for o in re.finditer(rf'<{tag}>', before)]
|
||||
closes = [o.start() for o in re.finditer(rf'</{tag}>', before)]
|
||||
last_open = opens[-1] if opens else -1
|
||||
last_close = closes[-1] if closes else -1
|
||||
if last_open > last_close:
|
||||
section = tag
|
||||
break
|
||||
new_name = f"{section}_f{n}"
|
||||
result = result.replace(f'$F{{{placeholder}}}', f'$F{{{new_name}}}')
|
||||
result = re.sub(
|
||||
rf'(<[\w:]*field\b[^>]*\bname\s*=\s*"){re.escape(placeholder)}(")',
|
||||
rf'\g<1>{new_name}\g<2>', result,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@@ -943,9 +999,18 @@ def refine_layout(state: AgentState) -> Dict:
|
||||
content = response.content if hasattr(response, "content") else str(response)
|
||||
fragment = _extract_xml_fragment(content)
|
||||
if fragment:
|
||||
band_results.append(fragment)
|
||||
writer({"type": "stream", "node": "refine_layout",
|
||||
"text": f"[{band.label} 窗口 {wi+1}/{len(windows)} 完成] "})
|
||||
zero_count, total = _count_zero_coordinate_elements(fragment)
|
||||
if total > 0 and zero_count / total > 0.3:
|
||||
_node_log.warning(
|
||||
"refine_layout 窗口 %s/%d 零坐标元素 %d/%d (%.0f%%),使用原文",
|
||||
band.label, wi + 1, zero_count, total,
|
||||
zero_count / total * 100,
|
||||
)
|
||||
band_results.append(win_xml)
|
||||
else:
|
||||
band_results.append(fragment)
|
||||
writer({"type": "stream", "node": "refine_layout",
|
||||
"text": f"[{band.label} 窗口 {wi+1}/{len(windows)} 完成] "})
|
||||
else:
|
||||
_node_log.warning("refine_layout 窗口 %s/%d 返回空,使用原文",
|
||||
band.label, wi + 1)
|
||||
|
||||
Reference in New Issue
Block a user