Files
agent_jrxml/backend/ocr_extractor.py

912 lines
30 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""OCR 单据字段精确提取器。
两阶段提取流程:
阶段1 - 文档分析: 复用 file_parser.parse_file() 和 layout_analyzer.analyze_layout()
获取每个文本元素的精确坐标和内容
阶段2 - 字段提取: 给定目标字段列表,通过四种策略(精确KV匹配、模糊KV匹配、
正则模式匹配、表格结构匹配)提取字段值、位置和置信度
用法:
from backend.ocr_extractor import OcrExtractor
extractor = OcrExtractor()
result = extractor.extract("invoice.png", ["发票代码", "发票号码", "合计金额"])
for field in result:
print(f"{field['field_name']}: {field['field_value']} (置信度: {field['confidence']})")
"""
import logging
import os
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Optional

from dotenv import load_dotenv
# Let python-dotenv populate the environment before the settings below are read.
load_dotenv()
# GPU switch: on when OCR_USE_GPU is "true", "1" or "yes" (case-insensitive);
# the default string "false" leaves it off.
OCR_USE_GPU = os.getenv("OCR_USE_GPU", "false").lower() in ("true", "1", "yes")
# Default minimum OCR confidence, parsed as a float (falls back to "0.5").
# NOTE(review): float() raises ValueError at import time if the variable is
# set to something that is not a number.
OCR_CONFIDENCE_THRESHOLD = float(os.getenv("OCR_CONFIDENCE_THRESHOLD", "0.5"))
@dataclass
class OcrTextElement:
    """One piece of recognised text together with its axis-aligned box.

    The box is stored as its two corners (``x_min, y_min`` and
    ``x_max, y_max``); ``confidence`` defaults to 1.0 when the producer
    does not supply a score. All geometric helpers are derived from the
    four corner coordinates.
    """

    text: str
    x_min: float
    y_min: float
    x_max: float
    y_max: float
    confidence: float = 1.0

    @property
    def center_x(self) -> float:
        """Horizontal midpoint of the box."""
        horizontal_sum = self.x_min + self.x_max
        return horizontal_sum / 2

    @property
    def center_y(self) -> float:
        """Vertical midpoint of the box."""
        vertical_sum = self.y_min + self.y_max
        return vertical_sum / 2

    @property
    def width(self) -> float:
        """Horizontal extent (``x_max - x_min``)."""
        right, left = self.x_max, self.x_min
        return right - left

    @property
    def height(self) -> float:
        """Vertical extent (``y_max - y_min``)."""
        bottom, top = self.y_max, self.y_min
        return bottom - top

    @property
    def bbox(self) -> list[float]:
        """Corner coordinates as ``[x_min, y_min, x_max, y_max]``."""
        corners = (self.x_min, self.y_min, self.x_max, self.y_max)
        return list(corners)
@dataclass
class ExtractedField:
    """Outcome of extracting one field.

    Attributes:
        field_name: The field name the caller asked for.
        field_value: The extracted text.
        bbox: Box of the element the value was taken from.
        confidence: Score attached by the code that built this record.
        extraction_method: Label of the strategy that produced the value.
    """

    field_name: str
    field_value: str
    bbox: list[float]
    confidence: float
    extraction_method: str
@dataclass
class ExtractionResult:
    """Everything produced by a single extraction run.

    Holds the source path, the image size, the per-field results, every
    OCR element that was considered, any error messages, and whether an
    OCR engine was available.
    """

    file_path: str
    image_size: tuple[int, int]
    fields: list[ExtractedField] = field(default_factory=list)
    all_elements: list[OcrTextElement] = field(default_factory=list)
    errors: list[str] = field(default_factory=list)
    ocr_available: bool = False

    def to_dict(self) -> dict:
        """Flatten the result into plain dicts and lists.

        Returns:
            A dict with the keys ``file_path``, ``image_size``,
            ``ocr_available``, ``fields``, ``all_elements``,
            ``total_elements`` and ``errors`` (in that order).
        """
        field_rows = []
        for item in self.fields:
            row = {}
            row["field_name"] = item.field_name
            row["field_value"] = item.field_value
            row["bbox"] = item.bbox
            row["confidence"] = item.confidence
            row["extraction_method"] = item.extraction_method
            field_rows.append(row)

        element_rows = []
        for elem in self.all_elements:
            element_rows.append(
                dict(text=elem.text, bbox=elem.bbox, confidence=elem.confidence)
            )

        payload = {}
        payload["file_path"] = self.file_path
        payload["image_size"] = self.image_size
        payload["ocr_available"] = self.ocr_available
        payload["fields"] = field_rows
        payload["all_elements"] = element_rows
        payload["total_elements"] = len(self.all_elements)
        payload["errors"] = self.errors
        return payload
class OcrExtractor:
    """Locate named fields in a scanned document: value, box and score.

    The work happens in two stages:
        Stage 1: OCR the image into text elements that carry coordinates.
        Stage 2: for each requested field name, try four strategies in
            priority order (exact key-value, fuzzy key-value, regex,
            table) and keep the first one that yields a non-empty value.
    """

    _logger = logging.getLogger(__name__)

    # Label/value separators, written as escapes so that look-alike
    # characters cannot be silently lost or confused with ASCII.
    # NOTE(review): the original literals contained empty strings at the
    # escaped positions (an empty separator makes str.split raise and makes
    # the exact-match regex accept anything). They are presumed to have been
    # the full-width colon / equals sign and dash variants - confirm the
    # intended set.
    _DISCOVERY_SEPARATORS: tuple[str, ...] = (":", "\uff1a", "=", "\uff1d")
    _KV_SEPARATORS: tuple[str, ...] = (
        ":", "\uff1a", "=", "-", "\u2014", "\u2013", "\t", "|",
    )

    # Text made only of digits, commas and dots (optionally a trailing %)
    # is treated as a value, never as a label.
    _NUMERIC_TEXT = re.compile(r'^[\d,.]+\s*%?$')

    # Field name -> regex describing what a value of that field looks like.
    PREDEFINED_PATTERNS: dict[str, str] = {
        # Invoice fields
        "发票代码": r"[0-9A-Za-z]{10,12}",
        "发票号码": r"\d{8}",
        "合计金额": r"[\d,]+\.?\d*",
        "金额": r"[\d,]+\.?\d*",
        "开票日期": r"\d{4}[年/\-]\d{1,2}[月/\-]\d{1,2}日?",
        "日期": r"\d{4}[年/\-]\d{1,2}[月/\-]\d{1,2}日?",
        "校验码": r"[0-9A-Fa-f]{5,20}",
        "总价": r"[\d,]+\.?\d*",
        "总金额": r"[\d,]+\.?\d*",
        "价税合计": r"[\d,]+\.?\d*",
        "数量": r"\d+\.?\d*",
        "单价": r"[\d,]+\.?\d*",
        "税率": r"\d+\.?\d*%?",
        # Vehicle service record / repair settlement fields
        "维修单号": r"[A-Za-z0-9\-]{6,20}",
        "车牌号": r"[京津沪渝冀豫云辽黑湘皖鲁新苏浙赣鄂桂甘晋蒙陕吉闽贵粤川青藏琼宁][A-Z][·\-]?[A-Z0-9]{5,6}",
        "联系电话": r"1[3-9]\d{9}",
        "VIN码": r"[A-HJ-NPR-Z0-9]{17}",
        "发动机号": r"[A-Z0-9]{6,12}",
        # Purchase order fields
        "采购日期": r"\d{4}[年/\-]\d{1,2}[月/\-]\d{1,2}日?",
        "订单号": r"[A-Z0-9\-]{6,20}",
    }

    def __init__(
        self,
        use_gpu: bool = False,
        confidence_threshold: float = 0.5,
    ):
        """Create an extractor.

        Args:
            use_gpu: Request GPU acceleration. A falsy value defers to the
                module-level ``OCR_USE_GPU`` setting.
            confidence_threshold: Minimum OCR confidence; elements scoring
                lower are ignored. The literal default 0.5 defers to the
                module-level ``OCR_CONFIDENCE_THRESHOLD`` setting.
        """
        self.use_gpu = use_gpu if use_gpu else OCR_USE_GPU
        self.confidence_threshold = (
            confidence_threshold
            if confidence_threshold != 0.5
            else OCR_CONFIDENCE_THRESHOLD
        )

    # ====================================================================
    # Public interface
    # ====================================================================
    def extract(
        self,
        file_path: str,
        target_fields: Optional[list[str]] = None,
    ) -> dict:
        """Run both stages on an image file.

        Args:
            file_path: Path of the image to read.
            target_fields: Field names to look for. When empty or None the
                field names are discovered from the document itself.

        Returns:
            The result dictionary produced by ``ExtractionResult.to_dict()``.
            Problems (missing file, no OCR engine, no text) are reported
            through its ``errors`` list rather than by raising.
        """
        result = ExtractionResult(file_path=file_path, image_size=(0, 0))
        if not Path(file_path).exists():
            result.errors.append(f"文件不存在: {file_path}")
            return result.to_dict()

        elements, image_size = self._analyze_document(file_path)
        result.image_size = image_size
        result.all_elements = elements

        if not elements:
            # Distinguish "no engine installed" from "engine found nothing".
            result.ocr_available = self._check_ocr_availability()
            if not result.ocr_available:
                result.errors.append(
                    "OCR 引擎不可用,请安装 easyocr (pip install easyocr) 或 paddleocr"
                )
            else:
                result.errors.append("图片未检测到文字元素")
            return result.to_dict()

        result.ocr_available = True
        names = target_fields if target_fields else self._discover_fields(elements)
        result.fields.extend(self._extract_all(names, elements))
        return result.to_dict()

    def extract_from_layout_result(
        self,
        layout_result: dict,
        target_fields: list[str],
    ) -> dict:
        """Extract fields from an existing layout-analysis result.

        Skips stage 1: the text elements are rebuilt from the ``rows`` ->
        ``elements`` entries of ``layout_result`` (keys ``text``, ``x``,
        ``y``, ``w``, ``h``) and stage 2 runs on them directly.

        Args:
            layout_result: Layout-analysis output; read via ``rows`` and
                ``image_size``.
            target_fields: Field names to look for.

        Returns:
            The result dictionary; it carries an error entry and no fields
            when ``layout_result`` has no rows.
        """
        rows = layout_result.get("rows", [])
        if not rows:
            return ExtractionResult(
                file_path="(from layout)",
                image_size=layout_result.get("image_size", (0, 0)),
                errors=["版面分析结果中没有文本行"],
            ).to_dict()

        elements = []
        for row in rows:
            for elem_data in row.get("elements", []):
                elements.append(
                    OcrTextElement(
                        text=elem_data.get("text", ""),
                        x_min=elem_data.get("x", 0),
                        y_min=elem_data.get("y", 0),
                        x_max=elem_data.get("x", 0) + elem_data.get("w", 0),
                        y_max=elem_data.get("y", 0) + elem_data.get("h", 0),
                    )
                )

        result = ExtractionResult(
            file_path="(from layout)",
            image_size=layout_result.get("image_size", (0, 0)),
            all_elements=elements,
            ocr_available=True,
        )
        result.fields.extend(self._extract_all(target_fields, elements))
        return result.to_dict()

    def _extract_all(
        self,
        field_names: list[str],
        elements: list[OcrTextElement],
    ) -> list[ExtractedField]:
        """Extract every name in order, substituting a placeholder on a miss."""
        return [
            self._extract_field(name, elements) or self._missing_field(name)
            for name in field_names
        ]

    @staticmethod
    def _missing_field(field_name: str) -> ExtractedField:
        """Build the placeholder returned for a field that was not found."""
        return ExtractedField(
            field_name=field_name,
            field_value="",
            bbox=[],
            confidence=0.0,
            extraction_method="none",
        )

    # ====================================================================
    # Stage 1: document analysis
    # ====================================================================
    def _analyze_document(self, file_path: str) -> tuple[list[OcrTextElement], tuple[int, int]]:
        """Load the image, OCR it and return ``(elements, image_size)``.

        Elements whose confidence is below ``self.confidence_threshold``
        are dropped; the rest are sorted top-to-bottom, then left-to-right.
        Returns ``([], (0, 0))`` when the image cannot be loaded.
        """
        from backend.layout_analyzer import _load_image

        img = _load_image(Path(file_path))
        if img is None:
            return [], (0, 0)
        image_size = img.size

        elements = []
        for e_data in self._ocr_elements_enhanced(img, file_path):
            if e_data.get("confidence", 1.0) < self.confidence_threshold:
                continue
            elements.append(
                OcrTextElement(
                    text=e_data.get("text", ""),
                    x_min=e_data.get("x", 0),
                    y_min=e_data.get("y", 0),
                    x_max=e_data.get("x", 0) + e_data.get("w", 0),
                    y_max=e_data.get("y", 0) + e_data.get("h", 0),
                    confidence=e_data.get("confidence", 1.0),
                )
            )
        elements.sort(key=lambda e: (e.y_min, e.x_min))
        return elements, image_size

    def _ocr_elements_enhanced(self, img, file_path: str) -> list[dict]:
        """Run the available OCR back-ends and return element dicts.

        PaddleOCR is tried first and easyocr second; the first non-empty
        result wins. Returns ``[]`` when neither produces anything.
        """
        try:
            import numpy as np

            paddleocr_result = self._try_paddleocr(img, file_path)
            if paddleocr_result:
                return paddleocr_result
            easyocr_result = self._try_easyocr(np.array(img))
            if easyocr_result:
                return easyocr_result
        except Exception:
            # Best effort by design: callers get an empty list, but the
            # failure is recorded instead of vanishing.
            self._logger.exception("OCR failed for %s", file_path)
        return []

    @staticmethod
    def _box_to_element(box, text: str, confidence) -> dict:
        """Convert a polygon of (x, y) points plus text into an element dict."""
        xs = [p[0] for p in box]
        ys = [p[1] for p in box]
        x_min, x_max = min(xs), max(xs)
        y_min, y_max = min(ys), max(ys)
        return {
            "x": round(x_min, 1),
            "y": round(y_min, 1),
            "w": round(x_max - x_min, 1),
            "h": round(y_max - y_min, 1),
            "text": text.strip(),
            "confidence": round(confidence, 4),
        }

    def _try_easyocr(self, np_img) -> Optional[list[dict]]:
        """OCR ``np_img`` with easyocr.

        Returns:
            Element dicts sorted by (y, x), or None when easyocr is not
            installed or fails.
        """
        try:
            import easyocr

            reader = easyocr.Reader(
                ["ch_sim", "en"],
                gpu=self.use_gpu,
                verbose=False,
            )
            elements = [
                self._box_to_element(bbox, text, confidence)
                for bbox, text, confidence in reader.readtext(np_img)
                if text.strip()
            ]
            elements.sort(key=lambda e: (e["y"], e["x"]))
            return elements
        except ImportError:
            return None
        except Exception:
            self._logger.warning("easyocr recognition failed", exc_info=True)
            return None

    def _try_paddleocr(self, img, file_path: str) -> Optional[list[dict]]:
        """OCR ``img`` with PaddleOCR.

        Only the first entry of the raw result is read; each line is
        expected to be ``(box, text_info)`` where ``text_info`` is either
        ``(text, confidence)`` or something convertible with ``str``.

        Returns:
            Element dicts sorted by (y, x), or None when PaddleOCR is not
            installed or fails.
        """
        try:
            from paddleocr import PaddleOCR
            import numpy as np

            ocr = PaddleOCR(lang="ch")
            raw_result = ocr.ocr(np.array(img))
            elements = []
            if raw_result and raw_result[0]:
                for line in raw_result[0]:
                    if len(line) < 2:
                        continue
                    box, text_info = line[0], line[1]
                    if isinstance(text_info, (list, tuple)):
                        text = text_info[0]
                        confidence = text_info[1] if len(text_info) > 1 else 1.0
                    else:
                        text = str(text_info)
                        confidence = 1.0
                    if not text.strip():
                        continue
                    elements.append(
                        self._box_to_element(box, text, float(confidence))
                    )
            elements.sort(key=lambda e: (e["y"], e["x"]))
            return elements
        except ImportError:
            return None
        except Exception:
            self._logger.warning(
                "PaddleOCR recognition failed for %s", file_path, exc_info=True
            )
            return None

    def _check_ocr_availability(self) -> bool:
        """Return True when easyocr or paddleocr can be imported."""
        try:
            import easyocr  # noqa: F401
            return True
        except ImportError:
            pass
        try:
            import paddleocr  # noqa: F401
            return True
        except ImportError:
            pass
        return False

    # ====================================================================
    # Stage 2: field extraction
    # ====================================================================
    def _discover_fields(self, elements: list[OcrTextElement]) -> list[str]:
        """Guess the field names present in a document.

        Three heuristics contribute candidates:
            1. "label<sep>value" inside a single element - the label part.
            2. Two neighbouring elements on the same row - the left one,
               when it is short and not purely numeric.
            3. Short, non-numeric texts in the first three rows.
        Candidates that look like dates, numbers or amounts are discarded.

        Returns:
            The surviving candidate names, sorted.
        """
        discovered: set[str] = set()
        ordered = sorted(elements, key=lambda e: (e.y_min, e.x_min))

        # Heuristic 1: key-value pair embedded in one element.
        for elem in elements:
            text = elem.text
            for sep in self._DISCOVERY_SEPARATORS:
                if sep not in text:
                    continue
                label, _, value = text.partition(sep)
                label, value = label.strip(), value.strip()
                if label and value and len(label) <= 20 and label != value:
                    discovered.add(label)

        # Group elements into rows: an element joins the first existing row
        # whose integer y key is less than 10 units from its own y_min.
        rows: dict[int, list[OcrTextElement]] = {}
        for elem in ordered:
            row_key = int(elem.y_min)
            for existing_key in rows:
                if abs(int(elem.y_min) - existing_key) < 10:
                    row_key = existing_key
                    break
            rows.setdefault(row_key, []).append(elem)

        # Heuristic 2: label on the left, value immediately to its right.
        for row_elems in rows.values():
            row_elems.sort(key=lambda e: e.x_min)
            for left, right in zip(row_elems, row_elems[1:]):
                is_adjacent = abs(right.x_min - left.x_max) < left.width * 3
                if len(left.text) <= 15 and len(right.text) > 0 and is_adjacent:
                    label = left.text.strip()
                    # Blank or purely numeric text is not a label.
                    if label and not self._NUMERIC_TEXT.match(label):
                        discovered.add(label)

        # Heuristic 3: short texts in the first three rows.
        for row_key in sorted(rows)[:3]:
            for elem in rows[row_key]:
                text = elem.text.strip()
                if text and len(text) <= 20 and not self._NUMERIC_TEXT.match(text):
                    discovered.add(text)

        # Drop candidates that are obviously values (dates, numbers, amounts).
        value_patterns = [
            r'^\d{1,2}[月/-]\d{1,2}[日/-]?\d{0,4}$',
            r'^[\d,]+\.?\d*\s*%?$',
            r'^[\u00a5\uffe5]\s*[\d,]+\.?\d*$',
            r'^\d{3,}$',
        ]
        return sorted(
            name
            for name in discovered
            if not any(re.match(pat, name) for pat in value_patterns)
        )

    def _extract_field(
        self,
        field_name: str,
        elements: list[OcrTextElement],
    ) -> Optional[ExtractedField]:
        """Try the four strategies in priority order for one field.

        Order: exact key-value, fuzzy key-value, regex, table. The first
        strategy that returns a non-empty value wins and its label is
        written to ``extraction_method``.

        Returns:
            The extracted field, or None when every strategy fails or the
            field name is blank.
        """
        # A blank name is a substring of every text and would match anything.
        if not field_name.strip():
            return None

        strategies = [
            ("exact_match", self._exact_kv_match),
            ("kv_pair", self._fuzzy_kv_match),
            ("regex", self._regex_match),
            ("table_match", self._table_match),
        ]
        for method_name, strategy_fn in strategies:
            result = strategy_fn(field_name, elements)
            if result and result.field_value:
                result.extraction_method = method_name
                return result
        return None

    # --------------------------------------------------------------------
    # Strategy 1: exact key-value match
    # --------------------------------------------------------------------
    def _exact_kv_match(
        self,
        field_name: str,
        elements: list[OcrTextElement],
    ) -> Optional[ExtractedField]:
        """Find "name<separator>value" inside a single element.

        For every element containing the field name, each separator is
        tried in turn (confidence 0.95); failing that, "name<whitespace>
        value" is accepted (confidence 0.85).
        """
        field_name_clean = field_name.strip()
        escaped_name = re.escape(field_name_clean)
        for elem in elements:
            text = elem.text
            if field_name_clean not in text:
                continue
            for sep in self._KV_SEPARATORS:
                pattern = escaped_name + r"\s*" + re.escape(sep) + r"\s*(.+)"
                m = re.search(pattern, text)
                if m:
                    value = m.group(1).strip()
                    if value:
                        return ExtractedField(
                            field_name=field_name,
                            field_value=value,
                            bbox=elem.bbox,
                            confidence=0.95,
                            extraction_method="",
                        )
            m = re.search(escaped_name + r"\s+(.+)", text)
            if m:
                value = m.group(1).strip()
                if value and value != field_name_clean:
                    return ExtractedField(
                        field_name=field_name,
                        field_value=value,
                        bbox=elem.bbox,
                        confidence=0.85,
                        extraction_method="",
                    )
        return None

    # --------------------------------------------------------------------
    # Strategy 2: fuzzy key-value match
    # --------------------------------------------------------------------
    def _fuzzy_kv_match(
        self,
        field_name: str,
        elements: list[OcrTextElement],
    ) -> Optional[ExtractedField]:
        """Find the label in one element and the value in a neighbour.

        The label element is the first one containing the field name, or
        otherwise the one with the highest similarity above 0.6. The value
        is the nearest element to its right on the same row (confidence
        0.75), or else the closest element below it within three label
        heights (confidence 0.6).
        """
        field_name_clean = field_name.strip()

        field_elem = next(
            (e for e in elements if field_name_clean in e.text), None
        )
        if field_elem is None:
            scored = [
                (self._text_similarity(field_name_clean, e.text), e)
                for e in elements
            ]
            scored = [pair for pair in scored if pair[0] > 0.6]
            if scored:
                # max() keeps the first of equally similar elements.
                field_elem = max(scored, key=lambda pair: pair[0])[1]
        if field_elem is None:
            return None

        candidates = [e for e in elements if e is not field_elem]

        same_row = [
            e for e in candidates
            if abs(e.center_y - field_elem.center_y) < field_elem.height * 1.5
        ]
        same_row.sort(key=lambda e: e.x_min)
        for elem in same_row:
            if elem.x_min > field_elem.x_max:
                return ExtractedField(
                    field_name=field_name,
                    field_value=elem.text,
                    bbox=elem.bbox,
                    confidence=0.75,
                    extraction_method="",
                )

        nearest = None
        nearest_dist = float("inf")
        for elem in candidates:
            if elem.y_min > field_elem.y_max:
                dy = elem.y_min - field_elem.y_max
                dx = abs(elem.center_x - field_elem.center_x)
                # Vertical distance dominates; horizontal offset is damped.
                dist = dy + dx * 0.3
                if dist < nearest_dist and dy < field_elem.height * 3:
                    nearest_dist = dist
                    nearest = elem
        if nearest:
            return ExtractedField(
                field_name=field_name,
                field_value=nearest.text,
                bbox=nearest.bbox,
                confidence=0.6,
                extraction_method="",
            )
        return None

    # --------------------------------------------------------------------
    # Strategy 3: regex pattern match
    # --------------------------------------------------------------------
    def _regex_match(
        self,
        field_name: str,
        elements: list[OcrTextElement],
    ) -> Optional[ExtractedField]:
        """Search all elements with the predefined pattern for the field.

        The pattern is looked up by exact name first, then by the first
        predefined key that contains, or is contained in, the field name.
        A whole-element match scores 0.7; a partial match scores 0.6 and
        returns only the matched substring.
        """
        pattern = self.PREDEFINED_PATTERNS.get(field_name)
        if not pattern:
            for key, pat in self.PREDEFINED_PATTERNS.items():
                if key in field_name or field_name in key:
                    pattern = pat
                    break
        if not pattern:
            return None

        compiled = re.compile(r"^\s*" + pattern + r"\s*$")
        for elem in elements:
            if compiled.match(elem.text):
                return ExtractedField(
                    field_name=field_name,
                    field_value=elem.text.strip(),
                    bbox=elem.bbox,
                    confidence=0.7,
                    extraction_method="",
                )

        compiled_partial = re.compile(pattern)
        for elem in elements:
            m = compiled_partial.search(elem.text)
            if m:
                return ExtractedField(
                    field_name=field_name,
                    field_value=m.group(0),
                    bbox=elem.bbox,
                    confidence=0.6,
                    extraction_method="",
                )
        return None

    # --------------------------------------------------------------------
    # Strategy 4: table structure match
    # --------------------------------------------------------------------
    @staticmethod
    def _locate_cell(rows, predicate) -> Optional[tuple[int, int]]:
        """Return (row, column) of the first cell whose text satisfies predicate."""
        for ri, row in enumerate(rows):
            for ci, elem in enumerate(row):
                if predicate(elem.text):
                    return ri, ci
        return None

    def _table_match(
        self,
        field_name: str,
        elements: list[OcrTextElement],
    ) -> Optional[ExtractedField]:
        """Treat the elements as a table and read the cell under a header.

        Steps:
            1. Group the elements into rows by Y coordinate.
            2. Find the header cell: the first cell containing the field
               name, else the first cell with similarity above 0.5.
            3. Take the cell at the same column index in the following
               row, or the horizontally closest cell when that row is
               shorter.

        Needs at least three elements and two rows. Confidence is 0.55.
        """
        if len(elements) < 3:
            return None
        rows = self._group_elements_by_rows(elements)
        if len(rows) < 2:
            return None

        header_pos = self._locate_cell(rows, lambda text: field_name in text)
        if header_pos is None:
            header_pos = self._locate_cell(
                rows,
                lambda text: self._text_similarity(field_name, text) > 0.5,
            )
        if header_pos is None:
            return None
        header_row_idx, header_col_idx = header_pos

        data_rows = rows[header_row_idx + 1:]
        if not data_rows:
            # Header sits in the last row: fall back to that row itself.
            data_rows = [rows[header_row_idx]]

        matched_elem = None
        # With finite coordinates the loop stops at the first data row,
        # because either branch below finds a cell there.
        for row in data_rows:
            if header_col_idx < len(row):
                matched_elem = row[header_col_idx]
                break
            closest = None
            min_dist = float("inf")
            header_x = float("inf")
            if header_col_idx < len(rows[header_row_idx]):
                header_x = rows[header_row_idx][header_col_idx].center_x
            for elem in row:
                dist = abs(elem.center_x - header_x)
                if dist < min_dist:
                    min_dist = dist
                    closest = elem
            if closest:
                matched_elem = closest
                break

        if matched_elem and matched_elem.text != field_name:
            return ExtractedField(
                field_name=field_name,
                field_value=matched_elem.text,
                bbox=matched_elem.bbox,
                confidence=0.55,
                extraction_method="",
            )
        return None

    # ====================================================================
    # Utilities
    # ====================================================================
    @staticmethod
    def _group_elements_by_rows(
        elements: list[OcrTextElement],
    ) -> list[list[OcrTextElement]]:
        """Split elements into rows by vertical centre.

        Elements are consumed in the given order; an element stays in the
        current row while its centre is within the tolerance (half the
        average element height, at least 5.0) of the row's first element.
        Each finished row is sorted left-to-right.
        """
        if not elements:
            return []
        avg_height = sum(e.height for e in elements) / len(elements)
        tolerance = max(avg_height * 0.5, 5.0)

        rows = []
        current_row = [elements[0]]
        for elem in elements[1:]:
            anchor_y = current_row[0].center_y
            if abs(elem.center_y - anchor_y) < tolerance:
                current_row.append(elem)
            else:
                current_row.sort(key=lambda e: e.x_min)
                rows.append(current_row)
                current_row = [elem]
        if current_row:
            current_row.sort(key=lambda e: e.x_min)
            rows.append(current_row)
        return rows

    @staticmethod
    def _text_similarity(text1: str, text2: str) -> float:
        """Score how alike two texts are, from 0.0 to 1.0.

        Comparison is case-insensitive and ignores surrounding whitespace:
        1.0 for equal texts, 0.8 when one contains the other, otherwise
        the share of ``text1``'s distinct characters that also occur in
        ``text2``. Empty or whitespace-only input scores 0.0.
        """
        if not text1 or not text2:
            return 0.0
        t1 = text1.lower().strip()
        t2 = text2.lower().strip()
        # Check again after stripping: "" is a substring of everything, so
        # whitespace-only text would otherwise score 0.8 against any label.
        if not t1 or not t2:
            return 0.0
        if t1 == t2:
            return 1.0
        if t1 in t2 or t2 in t1:
            return 0.8
        chars1 = set(t1)
        return len(chars1 & set(t2)) / len(chars1)
def extract_ocr_fields(
    file_path: str,
    target_fields: list[str],
    use_gpu: bool = False,
    confidence_threshold: float = 0.5,
) -> dict:
    """Convenience wrapper: OCR one image and extract the given fields.

    A fresh ``OcrExtractor`` is built for the call and discarded afterwards.

    Args:
        file_path: Path of the image file.
        target_fields: Names of the fields to extract.
        use_gpu: Passed to the extractor as its GPU switch.
        confidence_threshold: Passed to the extractor as its OCR
            confidence threshold.

    Returns:
        The result dictionary returned by ``OcrExtractor.extract``.
    """
    return OcrExtractor(
        use_gpu=use_gpu,
        confidence_threshold=confidence_threshold,
    ).extract(file_path, target_fields)
def extract_from_layout(
    layout_result: dict,
    target_fields: list[str],
    confidence_threshold: float = 0.5,
) -> dict:
    """Convenience wrapper: extract fields from an existing layout result.

    A fresh ``OcrExtractor`` is built for the call and discarded afterwards.

    Args:
        layout_result: The layout-analysis result to read elements from.
        target_fields: Names of the fields to extract.
        confidence_threshold: Passed to the extractor as its OCR
            confidence threshold.

    Returns:
        The result dictionary returned by
        ``OcrExtractor.extract_from_layout_result``.
    """
    one_shot = OcrExtractor(confidence_threshold=confidence_threshold)
    outcome = one_shot.extract_from_layout_result(layout_result, target_fields)
    return outcome