WIP: baseline on fix/retry-failure-root-causes

This commit is contained in:
2026-05-24 22:38:30 +08:00
parent 2d5183d2bd
commit f25a93b539
5 changed files with 438 additions and 35 deletions
+81 -16
View File
@@ -124,13 +124,8 @@ def process_input(state: AgentState) -> Dict:
try:
from backend.ocr_extractor import OcrExtractor
extractor = OcrExtractor()
default_fields = [
"发票代码", "发票号码", "开票日期", "合计金额", "校验码",
"价税合计", "总金额", "日期", "金额", "数量", "单价", "税率",
"购买方名称", "销售方名称", "货物名称", "规格型号",
"不含税金额", "税额",
]
ocr_result = extractor.extract(uploaded_path, default_fields)
# 不传预设字段名,让 OCR 自动发现文档中的所有键值对
ocr_result = extractor.extract(uploaded_path)
if ocr_result.get("ocr_available"):
state["ocr_extraction_result"] = ocr_result
_node_log.info(
@@ -483,12 +478,18 @@ def _format_row_coordinates(row: dict) -> dict:
sorted_elems = sorted(elements, key=lambda e: e.get("x", 0))
cols = []
for ci, e in enumerate(sorted_elems):
x = e.get("x", 0)
y = e.get("y", 0)
w = e.get("w", 0)
h = e.get("h", 0)
if not (x > 0 and y > 0 and w > 0 and h > 0):
continue
cols.append({
"col": ci,
"x": e.get("x", 0),
"y": e.get("y", 0),
"w": e.get("w", 0),
"h": e.get("h", 0),
"x": x,
"y": y,
"w": w,
"h": h,
"font_size": e.get("font_size", 12),
"text": e.get("text", ""),
})
@@ -529,10 +530,33 @@ def _extract_xml_fragment(text: str) -> str:
return text
def _count_zero_coordinate_elements(xml: str) -> tuple[int, int]:
"""统计坐标无效(x=0 或 y=0 或 width=0 或 height=0)的 reportElement 数量。
返回 (zero_count, total_count)。
"""
total = 0
zero = 0
for m in re.finditer(r'<reportElement\b([^>]*)/>', xml):
total += 1
attrs = m.group(1)
xm = re.search(r'\sx\s*=\s*"(\d+)"', attrs)
ym = re.search(r'\sy\s*=\s*"(\d+)"', attrs)
wm = re.search(r'\swidth\s*=\s*"(\d+)"', attrs)
hm = re.search(r'\sheight\s*=\s*"(\d+)"', attrs)
x = int(xm.group(1)) if xm else 0
y = int(ym.group(1)) if ym else 0
w = int(wm.group(1)) if wm else 0
h = int(hm.group(1)) if hm else 0
if x == 0 or y == 0 or w == 0 or h == 0:
zero += 1
return zero, total
def _programmatic_map_fields(jrxml: str, ocr_fields: list[dict]) -> str:
"""程序化字段映射:将 $F{{field_N}} 替换为 OCR 提取的真实字段名。
纯正则替换,不调 LLM。100% 确定性,零内容丢失。
未映射的 field_N 会被重命名为基于波段上下文的描述性名称。
"""
result = jrxml
for i, f in enumerate(ocr_fields):
@@ -543,13 +567,45 @@ def _programmatic_map_fields(jrxml: str, ocr_fields: list[dict]) -> str:
real_name = _sanitize_field_name(raw_name)
if real_name == placeholder:
continue
# 替换 field 声明: <ns0:field name="field_1" → <ns0:field name="customer_name"
result = re.sub(
rf'(<[\w:]*field\b[^>]*\bname\s*=\s*"){re.escape(placeholder)}(")',
rf'\g<1>{real_name}\g<2>', result,
)
# 替换所有引用: $F{{field_1}} → $F{{customer_name}}
result = result.replace(f'$F{{{placeholder}}}', f'$F{{{real_name}}}')
# 第二遍:为剩余未映射的 field_N 赋予基于波段位置的描述性名称
remaining = set()
for m in re.finditer(r'\$F\{(field_\d+)\}', result):
remaining.add(m.group(1))
if remaining:
_SECTION_TAGS = (
"title", "pageHeader", "columnHeader", "detail", "columnFooter",
"pageFooter", "summary", "background", "noData",
)
for placeholder in sorted(remaining, key=lambda x: int(re.search(r'\d+', x).group())):
n = int(re.search(r'\d+', placeholder).group())
# 查找第一个引用此字段的位置,确定波段上下文
pattern = rf'\$F\{{{re.escape(placeholder)}\}}'
m = re.search(pattern, result)
section = "data"
if m:
before = result[:m.start()]
# 从后往前找最近的 section 标签
for tag in _SECTION_TAGS:
# 找最近的未闭合 section 标签
opens = [o.start() for o in re.finditer(rf'<{tag}>', before)]
closes = [o.start() for o in re.finditer(rf'</{tag}>', before)]
last_open = opens[-1] if opens else -1
last_close = closes[-1] if closes else -1
if last_open > last_close:
section = tag
break
new_name = f"{section}_f{n}"
result = result.replace(f'$F{{{placeholder}}}', f'$F{{{new_name}}}')
result = re.sub(
rf'(<[\w:]*field\b[^>]*\bname\s*=\s*"){re.escape(placeholder)}(")',
rf'\g<1>{new_name}\g<2>', result,
)
return result
@@ -943,9 +999,18 @@ def refine_layout(state: AgentState) -> Dict:
content = response.content if hasattr(response, "content") else str(response)
fragment = _extract_xml_fragment(content)
if fragment:
band_results.append(fragment)
writer({"type": "stream", "node": "refine_layout",
"text": f"[{band.label} 窗口 {wi+1}/{len(windows)} 完成] "})
zero_count, total = _count_zero_coordinate_elements(fragment)
if total > 0 and zero_count / total > 0.3:
_node_log.warning(
"refine_layout 窗口 %s/%d 零坐标元素 %d/%d (%.0f%%),使用原文",
band.label, wi + 1, zero_count, total,
zero_count / total * 100,
)
band_results.append(win_xml)
else:
band_results.append(fragment)
writer({"type": "stream", "node": "refine_layout",
"text": f"[{band.label} 窗口 {wi+1}/{len(windows)} 完成] "})
else:
_node_log.warning("refine_layout 窗口 %s/%d 返回空,使用原文",
band.label, wi + 1)