WIP: baseline on fix/retry-failure-root-causes
This commit is contained in:
+122
-15
@@ -136,13 +136,13 @@ class OcrExtractor:
|
||||
def extract(
|
||||
self,
|
||||
file_path: str,
|
||||
target_fields: list[str],
|
||||
target_fields: Optional[list[str]] = None,
|
||||
) -> dict:
|
||||
"""执行两阶段 OCR 字段提取。
|
||||
|
||||
Args:
|
||||
file_path: 图片文件路径(支持 png/jpg/jpeg/bmp/webp)
|
||||
target_fields: 需要提取的字段名称列表,如 ["发票代码", "发票号码", "合计金额"]
|
||||
target_fields: 需要提取的字段名称列表。为空或 None 时自动发现文档中所有键值对。
|
||||
|
||||
Returns:
|
||||
提取结果字典,格式见 ExtractionResult.to_dict()
|
||||
@@ -168,20 +168,40 @@ class OcrExtractor:
|
||||
return result.to_dict()
|
||||
|
||||
result.ocr_available = True
|
||||
for field_name in target_fields:
|
||||
extracted = self._extract_field(field_name, elements)
|
||||
if extracted:
|
||||
result.fields.append(extracted)
|
||||
else:
|
||||
result.fields.append(
|
||||
ExtractedField(
|
||||
field_name=field_name,
|
||||
field_value="",
|
||||
bbox=[],
|
||||
confidence=0.0,
|
||||
extraction_method="none",
|
||||
|
||||
if target_fields:
|
||||
# 有预设字段名:按名单查找
|
||||
for field_name in target_fields:
|
||||
extracted = self._extract_field(field_name, elements)
|
||||
if extracted:
|
||||
result.fields.append(extracted)
|
||||
else:
|
||||
result.fields.append(
|
||||
ExtractedField(
|
||||
field_name=field_name,
|
||||
field_value="",
|
||||
bbox=[],
|
||||
confidence=0.0,
|
||||
extraction_method="none",
|
||||
)
|
||||
)
|
||||
else:
|
||||
# 无预设字段名:自动发现文档中所有键值对
|
||||
discovered = self._discover_fields(elements)
|
||||
for field in discovered:
|
||||
extracted = self._extract_field(field, elements)
|
||||
if extracted:
|
||||
result.fields.append(extracted)
|
||||
else:
|
||||
result.fields.append(
|
||||
ExtractedField(
|
||||
field_name=field,
|
||||
field_value="",
|
||||
bbox=[],
|
||||
confidence=0.0,
|
||||
extraction_method="none",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
return result.to_dict()
|
||||
|
||||
@@ -396,6 +416,83 @@ class OcrExtractor:
|
||||
# 阶段2: 字段精确提取
|
||||
# ========================================================================
|
||||
|
||||
def _discover_fields(self, elements: list[OcrTextElement]) -> list[str]:
|
||||
"""自动发现文档中的字段名(无需预设字段列表)。
|
||||
|
||||
策略:
|
||||
1. 单元素内"标签: 值"模式 — 从中提取标签
|
||||
2. 同行相邻键值对 — 短文本(标签) + 长文本(值)
|
||||
3. 表头行 — 首行/第二行的文本作为列字段名
|
||||
"""
|
||||
separators = [":", ":", "=", "—"]
|
||||
discovered: set[str] = set()
|
||||
elements_sorted = sorted(elements, key=lambda e: (e.y_min, e.x_min))
|
||||
|
||||
# 策略 1: 单元素内嵌键值对
|
||||
for elem in elements:
|
||||
text = elem.text
|
||||
for sep in separators:
|
||||
if sep in text:
|
||||
parts = text.split(sep, 1)
|
||||
label = parts[0].strip()
|
||||
value = parts[1].strip()
|
||||
if label and value and len(label) <= 20 and label != value:
|
||||
discovered.add(label)
|
||||
|
||||
# 策略 2: 同行相邻键值对(标签在左,值在右)
|
||||
# 按行分组
|
||||
rows: dict[int, list[OcrTextElement]] = {}
|
||||
for elem in elements_sorted:
|
||||
row_key = int(elem.y_min)
|
||||
for existing_key in list(rows.keys()):
|
||||
if abs(int(elem.y_min) - existing_key) < 10:
|
||||
row_key = existing_key
|
||||
break
|
||||
if row_key not in rows:
|
||||
rows[row_key] = []
|
||||
rows[row_key].append(elem)
|
||||
|
||||
for row_elems in rows.values():
|
||||
row_elems.sort(key=lambda e: e.x_min)
|
||||
for i in range(len(row_elems) - 1):
|
||||
left = row_elems[i]
|
||||
right = row_elems[i + 1]
|
||||
# 左边是短文本(可能标签),右边是相邻的正常文本(可能值)
|
||||
if (len(left.text) <= 15 and len(right.text) > 0
|
||||
and abs(right.x_min - left.x_max) < left.width * 3):
|
||||
# 左边不含仅数字/金额模式(这些更可能是值)
|
||||
if not re.match(r'^[\d,.]+\s*%?$', left.text.strip()):
|
||||
discovered.add(left.text.strip())
|
||||
|
||||
# 策略 3: 表头行 — 取前两行中较短的元素作为字段名候选
|
||||
sorted_row_keys = sorted(rows.keys())
|
||||
header_rows = sorted_row_keys[:min(3, len(sorted_row_keys))]
|
||||
for row_key in header_rows:
|
||||
for elem in rows.get(row_key, []):
|
||||
text = elem.text.strip()
|
||||
if text and len(text) <= 20 and not re.match(r'^[\d,.]+\s*%?$', text):
|
||||
discovered.add(text)
|
||||
|
||||
# 去重合并:移除值文本中误识别为标签的条目
|
||||
# 排除纯数字、日期、金额等明显是值的文本
|
||||
value_patterns = [
|
||||
r'^\d{1,2}[月/-]\d{1,2}[日/-]?\d{0,4}$',
|
||||
r'^[\d,]+\.?\d*\s*%?$',
|
||||
r'^[¥¥]\s*[\d,]+\.?\d*$',
|
||||
r'^\d{3,}$',
|
||||
]
|
||||
filtered = set()
|
||||
for name in discovered:
|
||||
is_value = False
|
||||
for pat in value_patterns:
|
||||
if re.match(pat, name):
|
||||
is_value = True
|
||||
break
|
||||
if not is_value:
|
||||
filtered.add(name)
|
||||
|
||||
return sorted(filtered)
|
||||
|
||||
def _extract_field(
|
||||
self,
|
||||
field_name: str,
|
||||
@@ -558,6 +655,7 @@ class OcrExtractor:
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
PREDEFINED_PATTERNS: dict[str, str] = {
|
||||
# 发票字段
|
||||
"发票代码": r"[0-9A-Za-z]{10,12}",
|
||||
"发票号码": r"\d{8}",
|
||||
"合计金额": r"[\d,]+\.?\d*",
|
||||
@@ -571,6 +669,15 @@ class OcrExtractor:
|
||||
"数量": r"\d+\.?\d*",
|
||||
"单价": r"[\d,]+\.?\d*",
|
||||
"税率": r"\d+\.?\d*%?",
|
||||
# 车历卡/维修结算单字段
|
||||
"维修单号": r"[A-Za-z0-9\-]{6,20}",
|
||||
"车牌号": r"[京津沪渝冀豫云辽黑湘皖鲁新苏浙赣鄂桂甘晋蒙陕吉闽贵粤川青藏琼宁][A-Z][·\-]?[A-Z0-9]{5,6}",
|
||||
"联系电话": r"1[3-9]\d{9}",
|
||||
"VIN码": r"[A-HJ-NPR-Z0-9]{17}",
|
||||
"发动机号": r"[A-Z0-9]{6,12}",
|
||||
# 采购单字段
|
||||
"采购日期": r"\d{4}[年/\-]\d{1,2}[月/\-]\d{1,2}日?",
|
||||
"订单号": r"[A-Z0-9\-]{6,20}",
|
||||
}
|
||||
|
||||
def _regex_match(
|
||||
|
||||
Reference in New Issue
Block a user