"""OCR 单据字段精确提取器。 两阶段提取流程: 阶段1 - 文档分析: 复用 file_parser.parse_file() 和 layout_analyzer.analyze_layout() 获取每个文本元素的精确坐标和内容 阶段2 - 字段提取: 给定目标字段列表,通过四种策略(精确KV匹配、模糊KV匹配、 正则模式匹配、表格结构匹配)提取字段值、位置和置信度 用法: from backend.ocr_extractor import OcrExtractor extractor = OcrExtractor() result = extractor.extract("invoice.png", ["发票代码", "发票号码", "合计金额"]) for field in result: print(f"{field['field_name']}: {field['field_value']} (置信度: {field['confidence']})") """ import os import re from dataclasses import dataclass, field from pathlib import Path from typing import Any, Optional from dotenv import load_dotenv load_dotenv() OCR_USE_GPU = os.getenv("OCR_USE_GPU", "false").lower() in ("true", "1", "yes") OCR_CONFIDENCE_THRESHOLD = float(os.getenv("OCR_CONFIDENCE_THRESHOLD", "0.5")) @dataclass class OcrTextElement: """OCR 文本元素,包含精确坐标和内容。""" text: str x_min: float y_min: float x_max: float y_max: float confidence: float = 1.0 @property def center_x(self) -> float: return (self.x_min + self.x_max) / 2 @property def center_y(self) -> float: return (self.y_min + self.y_max) / 2 @property def width(self) -> float: return self.x_max - self.x_min @property def height(self) -> float: return self.y_max - self.y_min @property def bbox(self) -> list[float]: return [self.x_min, self.y_min, self.x_max, self.y_max] @dataclass class ExtractedField: """提取的字段结果。""" field_name: str field_value: str bbox: list[float] confidence: float extraction_method: str @dataclass class ExtractionResult: """单次提取的完整结果。""" file_path: str image_size: tuple[int, int] fields: list[ExtractedField] = field(default_factory=list) all_elements: list[OcrTextElement] = field(default_factory=list) errors: list[str] = field(default_factory=list) ocr_available: bool = False def to_dict(self) -> dict: return { "file_path": self.file_path, "image_size": self.image_size, "ocr_available": self.ocr_available, "fields": [ { "field_name": f.field_name, "field_value": f.field_value, "bbox": f.bbox, "confidence": f.confidence, "extraction_method": f.extraction_method, } for f in self.fields ], "total_elements": len(self.all_elements), "errors": self.errors, } class OcrExtractor: """OCR 单据字段精确提取器。 两阶段流水线: 阶段1: 对上传图片进行 OCR + 版面分析,产出带坐标的文本元素列表 阶段2: 根据目标字段列表,按优先级逐一尝试四种提取策略 """ def __init__( self, use_gpu: bool = False, confidence_threshold: float = 0.5, ): """初始化提取器。 Args: use_gpu: 是否使用 GPU 加速 OCR(需要相应驱动) confidence_threshold: OCR 文本置信度最低阈值,低于此值的元素被忽略 """ self.use_gpu = use_gpu if use_gpu else OCR_USE_GPU self.confidence_threshold = ( confidence_threshold if confidence_threshold != 0.5 else OCR_CONFIDENCE_THRESHOLD ) # ======================================================================== # 公共接口 # ======================================================================== def extract( self, file_path: str, target_fields: list[str], ) -> dict: """执行两阶段 OCR 字段提取。 Args: file_path: 图片文件路径(支持 png/jpg/jpeg/bmp/webp) target_fields: 需要提取的字段名称列表,如 ["发票代码", "发票号码", "合计金额"] Returns: 提取结果字典,格式见 ExtractionResult.to_dict() """ result = ExtractionResult(file_path=file_path, image_size=(0, 0)) if not Path(file_path).exists(): result.errors.append(f"文件不存在: {file_path}") return result.to_dict() elements, image_size = self._analyze_document(file_path) result.image_size = image_size result.all_elements = elements if not elements: result.ocr_available = self._check_ocr_availability() if not result.ocr_available: result.errors.append( "OCR 引擎不可用,请安装 easyocr (pip install easyocr) 或 paddleocr" ) else: result.errors.append("图片未检测到文字元素") return result.to_dict() result.ocr_available = True for field_name in target_fields: extracted = self._extract_field(field_name, elements) if extracted: result.fields.append(extracted) else: result.fields.append( ExtractedField( field_name=field_name, field_value="", bbox=[], confidence=0.0, extraction_method="none", ) ) return result.to_dict() def extract_from_layout_result( self, layout_result: dict, target_fields: list[str], ) -> dict: """直接从 layout_analyzer.analyze_layout() 的结果中提取字段。 当已有版面分析结果时,跳过阶段1的重复 OCR,直接进入阶段2。 Args: layout_result: analyze_layout() 的返回值 target_fields: 需要提取的字段名称列表 Returns: 提取结果字典 """ rows = layout_result.get("rows", []) if not rows: return ExtractionResult( file_path="(from layout)", image_size=layout_result.get("image_size", (0, 0)), errors=["版面分析结果中没有文本行"], ).to_dict() elements = [] for row in rows: for elem_data in row.get("elements", []): elements.append( OcrTextElement( text=elem_data.get("text", ""), x_min=elem_data.get("x", 0), y_min=elem_data.get("y", 0), x_max=elem_data.get("x", 0) + elem_data.get("w", 0), y_max=elem_data.get("y", 0) + elem_data.get("h", 0), ) ) result = ExtractionResult( file_path="(from layout)", image_size=layout_result.get("image_size", (0, 0)), all_elements=elements, ocr_available=True, ) for field_name in target_fields: extracted = self._extract_field(field_name, elements) if extracted: result.fields.append(extracted) else: result.fields.append( ExtractedField( field_name=field_name, field_value="", bbox=[], confidence=0.0, extraction_method="none", ) ) return result.to_dict() # ======================================================================== # 阶段1: 文档分析 # ======================================================================== def _analyze_document(self, file_path: str) -> tuple[list[OcrTextElement], tuple[int, int]]: """阶段1: OCR + 版面分析,产出带坐标的文本元素列表。""" from backend.layout_analyzer import _load_image, _ocr_elements img = _load_image(Path(file_path)) if img is None: return [], (0, 0) image_size = img.size raw_elements = self._ocr_elements_enhanced(img, file_path) elements = [] for e_data in raw_elements: if e_data.get("confidence", 1.0) < self.confidence_threshold: continue elements.append( OcrTextElement( text=e_data.get("text", ""), x_min=e_data.get("x", 0), y_min=e_data.get("y", 0), x_max=e_data.get("x", 0) + e_data.get("w", 0), y_max=e_data.get("y", 0) + e_data.get("h", 0), confidence=e_data.get("confidence", 1.0), ) ) elements.sort(key=lambda e: (e.y_min, e.x_min)) return elements, image_size def _ocr_elements_enhanced(self, img, file_path: str) -> list[dict]: """增强版 OCR,返回带置信度的元素列表。""" try: import numpy as np paddleocr_result = self._try_paddleocr(img, file_path) if paddleocr_result: return paddleocr_result easyocr_result = self._try_easyocr(np.array(img)) if easyocr_result: return easyocr_result except Exception: pass return [] def _try_easyocr(self, np_img) -> Optional[list[dict]]: try: import easyocr reader = easyocr.Reader( ["ch_sim", "en"], gpu=self.use_gpu, verbose=False, ) raw_result = reader.readtext(np_img) elements = [] for bbox, text, confidence in raw_result: if not text.strip(): continue xs = [p[0] for p in bbox] ys = [p[1] for p in bbox] x_min, x_max = min(xs), max(xs) y_min, y_max = min(ys), max(ys) elements.append({ "x": round(x_min, 1), "y": round(y_min, 1), "w": round(x_max - x_min, 1), "h": round(y_max - y_min, 1), "text": text.strip(), "confidence": round(confidence, 4), }) elements.sort(key=lambda e: (e["y"], e["x"])) return elements except ImportError: return None except Exception: return None def _try_paddleocr(self, img, file_path: str) -> Optional[list[dict]]: try: from paddleocr import PaddleOCR import numpy as np ocr = PaddleOCR(lang="ch") raw_result = ocr.ocr(np.array(img)) elements = [] if raw_result and raw_result[0]: for line in raw_result[0]: if len(line) < 2: continue box = line[0] text_info = line[1] if isinstance(text_info, (list, tuple)): text = text_info[0] confidence = text_info[1] if len(text_info) > 1 else 1.0 else: text = str(text_info) confidence = 1.0 if not text.strip(): continue xs = [p[0] for p in box] ys = [p[1] for p in box] x_min, x_max = min(xs), max(xs) y_min, y_max = min(ys), max(ys) elements.append({ "x": round(x_min, 1), "y": round(y_min, 1), "w": round(x_max - x_min, 1), "h": round(y_max - y_min, 1), "text": text.strip(), "confidence": round(float(confidence), 4), }) elements.sort(key=lambda e: (e["y"], e["x"])) return elements except ImportError: return None except Exception: return None def _check_ocr_availability(self) -> bool: try: import easyocr return True except ImportError: pass try: import paddleocr return True except ImportError: pass return False # ======================================================================== # 阶段2: 字段精确提取 # ======================================================================== def _extract_field( self, field_name: str, elements: list[OcrTextElement], ) -> Optional[ExtractedField]: """按优先级尝试四种策略提取单个字段。 策略优先级: 1. 精确键值对匹配 2. 模糊键值对匹配 3. 正则模式匹配 4. 表格结构匹配 """ strategies = [ ("exact_match", self._exact_kv_match), ("kv_pair", self._fuzzy_kv_match), ("regex", self._regex_match), ("table_match", self._table_match), ] for method_name, strategy_fn in strategies: result = strategy_fn(field_name, elements) if result and result.field_value: result.extraction_method = method_name return result return None # ----------------------------------------------------------------------- # 策略1: 精确键值对匹配 # ----------------------------------------------------------------------- def _exact_kv_match( self, field_name: str, elements: list[OcrTextElement], ) -> Optional[ExtractedField]: """精确键值对匹配: 识别"字段名: 值"或"字段名:值"模式。 在同一文本元素中查找 "字段名" 后紧跟分隔符 + "值" 的模式。 如 OCR 识别出 "发票代码: 12345678" 这一整个元素。 """ separators = [":", ":", "=", "-", "—", ":", "\t", "|"] field_name_clean = field_name.strip() for elem in elements: text = elem.text if field_name_clean not in text: continue for sep in separators: pattern = re.escape(field_name_clean) + r"\s*" + re.escape(sep) + r"\s*(.+)" m = re.search(pattern, text) if m: value = m.group(1).strip() if value: return ExtractedField( field_name=field_name, field_value=value, bbox=elem.bbox, confidence=0.95, extraction_method="", ) simple_pattern = re.escape(field_name_clean) + r"\s+(.+)" m = re.search(simple_pattern, text) if m: value = m.group(1).strip() if value and value != field_name_clean: return ExtractedField( field_name=field_name, field_value=value, bbox=elem.bbox, confidence=0.85, extraction_method="", ) return None # ----------------------------------------------------------------------- # 策略2: 模糊键值对匹配 # ----------------------------------------------------------------------- def _fuzzy_kv_match( self, field_name: str, elements: list[OcrTextElement], ) -> Optional[ExtractedField]: """模糊键值对匹配: 字段名和值分布在相邻的文本元素中。 找到含字段名的元素后,在同一行或相邻元素中查找值。 """ field_name_clean = field_name.strip() field_elem = None for elem in elements: if field_name_clean in elem.text: field_elem = elem break if field_elem is None: matching = [] for elem in elements: sim = self._text_similarity(field_name_clean, elem.text) if sim > 0.6: matching.append((sim, elem)) if matching: matching.sort(key=lambda x: x[0], reverse=True) field_elem = matching[0][1] if field_elem is None: return None candidates = [] for elem in elements: if elem is field_elem: continue candidates.append(elem) same_row = [] for elem in candidates: if abs(elem.center_y - field_elem.center_y) < field_elem.height * 1.5: same_row.append(elem) if same_row: same_row.sort(key=lambda e: e.x_min) for elem in same_row: if elem.x_min > field_elem.x_max: return ExtractedField( field_name=field_name, field_value=elem.text, bbox=elem.bbox, confidence=0.75, extraction_method="", ) nearest = None nearest_dist = float("inf") for elem in candidates: if elem.y_min > field_elem.y_max: dy = elem.y_min - field_elem.y_max dx = abs(elem.center_x - field_elem.center_x) dist = dy + dx * 0.3 if dist < nearest_dist and dy < field_elem.height * 3: nearest_dist = dist nearest = elem if nearest: return ExtractedField( field_name=field_name, field_value=nearest.text, bbox=nearest.bbox, confidence=0.6, extraction_method="", ) return None # ----------------------------------------------------------------------- # 策略3: 正则模式匹配 # ----------------------------------------------------------------------- PREDEFINED_PATTERNS: dict[str, str] = { "发票代码": r"[0-9A-Za-z]{10,12}", "发票号码": r"\d{8}", "合计金额": r"[\d,]+\.?\d*", "金额": r"[\d,]+\.?\d*", "开票日期": r"\d{4}[年/\-]\d{1,2}[月/\-]\d{1,2}日?", "日期": r"\d{4}[年/\-]\d{1,2}[月/\-]\d{1,2}日?", "校验码": r"[0-9A-Fa-f]{5,20}", "总价": r"[\d,]+\.?\d*", "总金额": r"[\d,]+\.?\d*", "价税合计": r"[\d,]+\.?\d*", "数量": r"\d+\.?\d*", "单价": r"[\d,]+\.?\d*", "税率": r"\d+\.?\d*%?", } def _regex_match( self, field_name: str, elements: list[OcrTextElement], ) -> Optional[ExtractedField]: """正则模式匹配: 根据字段名选择预定义的正则模式,在所有元素中搜索。""" pattern = self.PREDEFINED_PATTERNS.get(field_name) if not pattern: for key, pat in self.PREDEFINED_PATTERNS.items(): if key in field_name or field_name in key: pattern = pat break if not pattern: return None compiled = re.compile(r"^\s*" + pattern + r"\s*$") for elem in elements: if compiled.match(elem.text): return ExtractedField( field_name=field_name, field_value=elem.text.strip(), bbox=elem.bbox, confidence=0.7, extraction_method="", ) compiled_partial = re.compile(pattern) for elem in elements: m = compiled_partial.search(elem.text) if m: return ExtractedField( field_name=field_name, field_value=m.group(0), bbox=elem.bbox, confidence=0.6, extraction_method="", ) return None # ----------------------------------------------------------------------- # 策略4: 表格结构匹配 # ----------------------------------------------------------------------- def _table_match( self, field_name: str, elements: list[OcrTextElement], ) -> Optional[ExtractedField]: """表格结构匹配: 将元素按行列分组,查找表头-值对应关系。 识别逻辑: 1. 将元素按 Y 坐标分组为"行" 2. 查找包含 field_name 的表头行 3. 在表头列对应的数据行中取值 """ if len(elements) < 3: return None rows = self._group_elements_by_rows(elements) if len(rows) < 2: return None header_row_idx = -1 header_col_idx = -1 for ri, row in enumerate(rows): for ci, elem in enumerate(row): if field_name in elem.text: header_row_idx = ri header_col_idx = ci break if header_row_idx >= 0: break if header_row_idx < 0: for ri, row in enumerate(rows): for ci, elem in enumerate(row): sim = self._text_similarity(field_name, elem.text) if sim > 0.5: header_row_idx = ri header_col_idx = ci break if header_row_idx >= 0: break if header_row_idx < 0: return None data_rows = rows[header_row_idx + 1:] if not data_rows: data_rows = [rows[header_row_idx]] matched_elem = None for row in data_rows: if header_col_idx < len(row): matched_elem = row[header_col_idx] break closest = None min_dist = float("inf") header_x = float("inf") if header_col_idx < len(rows[header_row_idx]): header_x = rows[header_row_idx][header_col_idx].center_x for elem in row: dist = abs(elem.center_x - header_x) if dist < min_dist: min_dist = dist closest = elem if closest: matched_elem = closest break if matched_elem and matched_elem.text != field_name: return ExtractedField( field_name=field_name, field_value=matched_elem.text, bbox=matched_elem.bbox, confidence=0.55, extraction_method="", ) return None # ======================================================================== # 工具方法 # ======================================================================== @staticmethod def _group_elements_by_rows( elements: list[OcrTextElement], ) -> list[list[OcrTextElement]]: """将元素按 Y 坐标分组为行(容差为元素平均高度的一半)。""" if not elements: return [] avg_height = sum(e.height for e in elements) / len(elements) tolerance = max(avg_height * 0.5, 5.0) rows = [] current_row = [elements[0]] for elem in elements[1:]: prev_center_y = current_row[0].center_y if abs(elem.center_y - prev_center_y) < tolerance: current_row.append(elem) else: current_row.sort(key=lambda e: e.x_min) rows.append(current_row) current_row = [elem] if current_row: current_row.sort(key=lambda e: e.x_min) rows.append(current_row) return rows @staticmethod def _text_similarity(text1: str, text2: str) -> float: """计算两个文本的简单相似度(公共字符比例)。""" if not text1 or not text2: return 0.0 t1 = text1.lower().strip() t2 = text2.lower().strip() if t1 == t2: return 1.0 if t1 in t2 or t2 in t1: return 0.8 chars1 = set(t1) chars2 = set(t2) if not chars1: return 0.0 intersection = chars1 & chars2 return len(intersection) / len(chars1) def extract_ocr_fields( file_path: str, target_fields: list[str], use_gpu: bool = False, confidence_threshold: float = 0.5, ) -> dict: """便捷函数: 对指定图片执行 OCR 字段提取。 Args: file_path: 图片文件路径 target_fields: 目标字段名列表 use_gpu: 是否使用 GPU 加速 confidence_threshold: OCR 置信度阈值 Returns: 提取结果字典 """ extractor = OcrExtractor( use_gpu=use_gpu, confidence_threshold=confidence_threshold, ) return extractor.extract(file_path, target_fields) def extract_from_layout( layout_result: dict, target_fields: list[str], confidence_threshold: float = 0.5, ) -> dict: """便捷函数: 从已有的版面分析结果中提取字段。 Args: layout_result: analyze_layout() 的返回值 target_fields: 目标字段名列表 confidence_threshold: OCR 置信度阈值 Returns: 提取结果字典 """ extractor = OcrExtractor(confidence_threshold=confidence_threshold) return extractor.extract_from_layout_result(layout_result, target_fields)