diff --git a/.env.example b/.env.example index 83f5941..84e8ed7 100644 --- a/.env.example +++ b/.env.example @@ -63,3 +63,9 @@ HISTORY_MAX_SNAPSHOTS=10 # 意图识别模型(默认使用主 LLM 模型) # INTENT_MODEL=gpt-4o-mini + +# OCR 字段提取配置 +# 是否使用 GPU 加速 OCR(需要 CUDA 驱动和 GPU 版 EasyOCR/PaddleOCR) +OCR_USE_GPU=false +# OCR 文本置信度最低阈值(0-1),低于此值的元素将被忽略 +OCR_CONFIDENCE_THRESHOLD=0.5 diff --git a/agent/nodes.py b/agent/nodes.py index 0c3163c..0afc2a4 100644 --- a/agent/nodes.py +++ b/agent/nodes.py @@ -7,6 +7,7 @@ import os import re import time from datetime import datetime, timezone +from pathlib import Path from typing import Dict from dotenv import load_dotenv @@ -114,6 +115,30 @@ def process_input(state: AgentState) -> Dict: conv_history.append({"role": "user", "content": user_input}) state["conversation_history"] = conv_history + # OCR 单据字段精确提取(处理上传的图片文件) + uploaded_path = state.get("uploaded_file_path", "") + if uploaded_path and Path(uploaded_path).is_file(): + suffix = Path(uploaded_path).suffix.lower() + if suffix in (".png", ".jpg", ".jpeg", ".bmp", ".webp"): + try: + from backend.ocr_extractor import OcrExtractor + extractor = OcrExtractor() + ocr_result = extractor.extract(uploaded_path, []) + if ocr_result.get("ocr_available"): + state["ocr_extraction_result"] = ocr_result + _node_log.info( + "OCR 字段提取完成", + extra={ + "file": uploaded_path, + "elements": ocr_result.get("total_elements", 0), + "fields": len(ocr_result.get("fields", [])), + }, + ) + except Exception as e: + _node_log.warning(f"OCR 字段提取失败: {e}") + state["ocr_extraction_result"] = {"error": str(e)} + state["uploaded_file_path"] = "" + # 重置本轮请求字段 state["retry_count"] = 0 state["user_modification_request"] = user_input diff --git a/agent/state.py b/agent/state.py index 849ea3a..b787ebb 100644 --- a/agent/state.py +++ b/agent/state.py @@ -40,3 +40,7 @@ class AgentState(TypedDict, total=False): # 需求6:失败上下文传递 — 重试耗尽后暂存失败信息,下次用户输入时自动注入 pending_failure_context: dict + + # 需求7:OCR 单据字段精确提取结果 + ocr_extraction_result: dict + uploaded_file_path: str diff --git a/app.py b/app.py index f37321a..f02b576 100644 --- a/app.py +++ b/app.py @@ -261,6 +261,14 @@ def run_agent(user_input: str): if stream_active: streaming_placeholder.empty() + # 清理已处理的临时文件 + for p in st.session_state.get("uploaded_temp_paths", []): + try: + Path(p).unlink(missing_ok=True) + except Exception: + pass + st.session_state.uploaded_temp_paths = [] + # ---- 总结卡片 ---- # 注:node_state 只含变更字段,用 agent_state(被所有节点就地修改)获取完整状态 final_state = agent_state @@ -324,6 +332,30 @@ def run_agent(user_input: str): "content": f"❌ 经过 {retries} 次重试后仍无法生成有效的 JRXML。\n\n**错误:** {error_msg}\n\n💡 请直接描述修改需求,系统会自动加载失败上下文。", "type": "error_explanation", }) + + # OCR 字段提取结果展示 + ocr_result = agent_state.get("ocr_extraction_result", {}) + if ocr_result and ocr_result.get("ocr_available") and ocr_result.get("fields"): + with st.expander("🔍 OCR 单据字段提取结果", expanded=False): + fields = ocr_result.get("fields", []) + non_empty = [f for f in fields if f.get("field_value")] + empty = [f for f in fields if not f.get("field_value")] + if non_empty: + st.markdown("**已提取字段:**") + for f in non_empty: + method = f.get("extraction_method", "") + conf = f.get("confidence", 0) + st.markdown( + f"- **{f['field_name']}**: `{f['field_value']}` " + f"(置信度: {conf:.0%}, 方法: {method})" + ) + if empty: + st.caption( + f"未提取到值的字段: {', '.join(f['field_name'] for f in empty)}" + ) + st.caption( + f"共检测到 {ocr_result.get('total_elements', 0)} 个文本元素" + ) else: st.error("未产生结果,请重试。") @@ -443,6 +475,9 @@ with st.sidebar: if "uploaded_files" not in st.session_state: st.session_state.uploaded_files = [] # [{name, text, type}] + if "uploaded_temp_paths" not in st.session_state: + st.session_state.uploaded_temp_paths = [] # 待清理的临时文件路径 + uploaded = st.file_uploader( "选择文件", type=["png", "jpg", "jpeg", "bmp", "webp", "pdf", "docx", "txt", "csv", "json", "xml"], @@ -513,8 +548,6 @@ with st.sidebar: ) parsed_type = "image_reference" - Path(tmp_path).unlink(missing_ok=True) - if parsed_text: st.session_state.uploaded_files.append({ "name": uf.name, @@ -522,6 +555,14 @@ with st.sidebar: "type": parsed_type, }) + # 对图片类型,保存路径以便 OCR 字段提取(延迟到 process_input 阶段) + img_suffixes = (".png", ".jpg", ".jpeg", ".bmp", ".webp") + if suffix in img_suffixes and result.get("method") not in ("metadata_only", None): + st.session_state.agent_state["uploaded_file_path"] = tmp_path + st.session_state.uploaded_temp_paths.append(tmp_path) + else: + Path(tmp_path).unlink(missing_ok=True) + if st.session_state.uploaded_files: for i, f in enumerate(st.session_state.uploaded_files): cols = st.columns([5, 1]) diff --git a/backend/ocr_extractor.py b/backend/ocr_extractor.py new file mode 100644 index 0000000..7cd9843 --- /dev/null +++ b/backend/ocr_extractor.py @@ -0,0 +1,796 @@ +"""OCR 单据字段精确提取器。 + +两阶段提取流程: + 阶段1 - 文档分析: 复用 file_parser.parse_file() 和 layout_analyzer.analyze_layout() + 获取每个文本元素的精确坐标和内容 + 阶段2 - 字段提取: 给定目标字段列表,通过四种策略(精确KV匹配、模糊KV匹配、 + 正则模式匹配、表格结构匹配)提取字段值、位置和置信度 + +用法: + from backend.ocr_extractor import OcrExtractor + + extractor = OcrExtractor() + result = extractor.extract("invoice.png", ["发票代码", "发票号码", "合计金额"]) + for field in result: + print(f"{field['field_name']}: {field['field_value']} (置信度: {field['confidence']})") +""" + +import os +import re +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Optional + +from dotenv import load_dotenv + +load_dotenv() + +OCR_USE_GPU = os.getenv("OCR_USE_GPU", "false").lower() in ("true", "1", "yes") +OCR_CONFIDENCE_THRESHOLD = float(os.getenv("OCR_CONFIDENCE_THRESHOLD", "0.5")) + + +@dataclass +class OcrTextElement: + """OCR 文本元素,包含精确坐标和内容。""" + + text: str + x_min: float + y_min: float + x_max: float + y_max: float + confidence: float = 1.0 + + @property + def center_x(self) -> float: + return (self.x_min + self.x_max) / 2 + + @property + def center_y(self) -> float: + return (self.y_min + self.y_max) / 2 + + @property + def width(self) -> float: + return self.x_max - self.x_min + + @property + def height(self) -> float: + return self.y_max - self.y_min + + @property + def bbox(self) -> list[float]: + return [self.x_min, self.y_min, self.x_max, self.y_max] + + +@dataclass +class ExtractedField: + """提取的字段结果。""" + + field_name: str + field_value: str + bbox: list[float] + confidence: float + extraction_method: str + + +@dataclass +class ExtractionResult: + """单次提取的完整结果。""" + + file_path: str + image_size: tuple[int, int] + fields: list[ExtractedField] = field(default_factory=list) + all_elements: list[OcrTextElement] = field(default_factory=list) + errors: list[str] = field(default_factory=list) + ocr_available: bool = False + + def to_dict(self) -> dict: + return { + "file_path": self.file_path, + "image_size": self.image_size, + "ocr_available": self.ocr_available, + "fields": [ + { + "field_name": f.field_name, + "field_value": f.field_value, + "bbox": f.bbox, + "confidence": f.confidence, + "extraction_method": f.extraction_method, + } + for f in self.fields + ], + "total_elements": len(self.all_elements), + "errors": self.errors, + } + + +class OcrExtractor: + """OCR 单据字段精确提取器。 + + 两阶段流水线: + 阶段1: 对上传图片进行 OCR + 版面分析,产出带坐标的文本元素列表 + 阶段2: 根据目标字段列表,按优先级逐一尝试四种提取策略 + """ + + def __init__( + self, + use_gpu: bool = False, + confidence_threshold: float = 0.5, + ): + """初始化提取器。 + + Args: + use_gpu: 是否使用 GPU 加速 OCR(需要相应驱动) + confidence_threshold: OCR 文本置信度最低阈值,低于此值的元素被忽略 + """ + self.use_gpu = use_gpu if use_gpu else OCR_USE_GPU + self.confidence_threshold = ( + confidence_threshold + if confidence_threshold != 0.5 + else OCR_CONFIDENCE_THRESHOLD + ) + + # ======================================================================== + # 公共接口 + # ======================================================================== + + def extract( + self, + file_path: str, + target_fields: list[str], + ) -> dict: + """执行两阶段 OCR 字段提取。 + + Args: + file_path: 图片文件路径(支持 png/jpg/jpeg/bmp/webp) + target_fields: 需要提取的字段名称列表,如 ["发票代码", "发票号码", "合计金额"] + + Returns: + 提取结果字典,格式见 ExtractionResult.to_dict() + """ + result = ExtractionResult(file_path=file_path, image_size=(0, 0)) + + if not Path(file_path).exists(): + result.errors.append(f"文件不存在: {file_path}") + return result.to_dict() + + elements, image_size = self._analyze_document(file_path) + result.image_size = image_size + result.all_elements = elements + + if not elements: + result.ocr_available = self._check_ocr_availability() + if not result.ocr_available: + result.errors.append( + "OCR 引擎不可用,请安装 easyocr (pip install easyocr) 或 paddleocr" + ) + else: + result.errors.append("图片未检测到文字元素") + return result.to_dict() + + result.ocr_available = True + for field_name in target_fields: + extracted = self._extract_field(field_name, elements) + if extracted: + result.fields.append(extracted) + else: + result.fields.append( + ExtractedField( + field_name=field_name, + field_value="", + bbox=[], + confidence=0.0, + extraction_method="none", + ) + ) + + return result.to_dict() + + def extract_from_layout_result( + self, + layout_result: dict, + target_fields: list[str], + ) -> dict: + """直接从 layout_analyzer.analyze_layout() 的结果中提取字段。 + + 当已有版面分析结果时,跳过阶段1的重复 OCR,直接进入阶段2。 + + Args: + layout_result: analyze_layout() 的返回值 + target_fields: 需要提取的字段名称列表 + + Returns: + 提取结果字典 + """ + rows = layout_result.get("rows", []) + if not rows: + return ExtractionResult( + file_path="(from layout)", + image_size=layout_result.get("image_size", (0, 0)), + errors=["版面分析结果中没有文本行"], + ).to_dict() + + elements = [] + for row in rows: + for elem_data in row.get("elements", []): + elements.append( + OcrTextElement( + text=elem_data.get("text", ""), + x_min=elem_data.get("x", 0), + y_min=elem_data.get("y", 0), + x_max=elem_data.get("x", 0) + elem_data.get("w", 0), + y_max=elem_data.get("y", 0) + elem_data.get("h", 0), + ) + ) + + result = ExtractionResult( + file_path="(from layout)", + image_size=layout_result.get("image_size", (0, 0)), + all_elements=elements, + ocr_available=True, + ) + + for field_name in target_fields: + extracted = self._extract_field(field_name, elements) + if extracted: + result.fields.append(extracted) + else: + result.fields.append( + ExtractedField( + field_name=field_name, + field_value="", + bbox=[], + confidence=0.0, + extraction_method="none", + ) + ) + + return result.to_dict() + + # ======================================================================== + # 阶段1: 文档分析 + # ======================================================================== + + def _analyze_document(self, file_path: str) -> tuple[list[OcrTextElement], tuple[int, int]]: + """阶段1: OCR + 版面分析,产出带坐标的文本元素列表。""" + from backend.layout_analyzer import _load_image, _ocr_elements + + img = _load_image(Path(file_path)) + if img is None: + return [], (0, 0) + + image_size = img.size + raw_elements = self._ocr_elements_enhanced(img, file_path) + + elements = [] + for e_data in raw_elements: + if e_data.get("confidence", 1.0) < self.confidence_threshold: + continue + elements.append( + OcrTextElement( + text=e_data.get("text", ""), + x_min=e_data.get("x", 0), + y_min=e_data.get("y", 0), + x_max=e_data.get("x", 0) + e_data.get("w", 0), + y_max=e_data.get("y", 0) + e_data.get("h", 0), + confidence=e_data.get("confidence", 1.0), + ) + ) + + elements.sort(key=lambda e: (e.y_min, e.x_min)) + return elements, image_size + + def _ocr_elements_enhanced(self, img, file_path: str) -> list[dict]: + """增强版 OCR,返回带置信度的元素列表。""" + try: + import numpy as np + + easyocr_result = self._try_easyocr(np.array(img)) + if easyocr_result: + return easyocr_result + + paddleocr_result = self._try_paddleocr(img, file_path) + if paddleocr_result: + return paddleocr_result + except Exception: + pass + + return [] + + def _try_easyocr(self, np_img) -> Optional[list[dict]]: + try: + import easyocr + + reader = easyocr.Reader( + ["ch_sim", "en"], + gpu=self.use_gpu, + verbose=False, + ) + raw_result = reader.readtext(np_img) + + elements = [] + for bbox, text, confidence in raw_result: + if not text.strip(): + continue + xs = [p[0] for p in bbox] + ys = [p[1] for p in bbox] + x_min, x_max = min(xs), max(xs) + y_min, y_max = min(ys), max(ys) + + elements.append({ + "x": round(x_min, 1), + "y": round(y_min, 1), + "w": round(x_max - x_min, 1), + "h": round(y_max - y_min, 1), + "text": text.strip(), + "confidence": round(confidence, 4), + }) + + elements.sort(key=lambda e: (e["y"], e["x"])) + return elements + except ImportError: + return None + except Exception: + return None + + def _try_paddleocr(self, img, file_path: str) -> Optional[list[dict]]: + try: + from paddleocr import PaddleOCR + import numpy as np + + ocr = PaddleOCR(lang="ch") + raw_result = ocr.ocr(np.array(img)) + + elements = [] + if raw_result and raw_result[0]: + for line in raw_result[0]: + if len(line) < 2: + continue + box = line[0] + text_info = line[1] + + if isinstance(text_info, (list, tuple)): + text = text_info[0] + confidence = text_info[1] if len(text_info) > 1 else 1.0 + else: + text = str(text_info) + confidence = 1.0 + + if not text.strip(): + continue + + xs = [p[0] for p in box] + ys = [p[1] for p in box] + x_min, x_max = min(xs), max(xs) + y_min, y_max = min(ys), max(ys) + + elements.append({ + "x": round(x_min, 1), + "y": round(y_min, 1), + "w": round(x_max - x_min, 1), + "h": round(y_max - y_min, 1), + "text": text.strip(), + "confidence": round(float(confidence), 4), + }) + + elements.sort(key=lambda e: (e["y"], e["x"])) + return elements + except ImportError: + return None + except Exception: + return None + + def _check_ocr_availability(self) -> bool: + try: + import easyocr + return True + except ImportError: + pass + try: + import paddleocr + return True + except ImportError: + pass + return False + + # ======================================================================== + # 阶段2: 字段精确提取 + # ======================================================================== + + def _extract_field( + self, + field_name: str, + elements: list[OcrTextElement], + ) -> Optional[ExtractedField]: + """按优先级尝试四种策略提取单个字段。 + + 策略优先级: + 1. 精确键值对匹配 + 2. 模糊键值对匹配 + 3. 正则模式匹配 + 4. 表格结构匹配 + """ + strategies = [ + ("exact_match", self._exact_kv_match), + ("kv_pair", self._fuzzy_kv_match), + ("regex", self._regex_match), + ("table_match", self._table_match), + ] + + for method_name, strategy_fn in strategies: + result = strategy_fn(field_name, elements) + if result and result.field_value: + result.extraction_method = method_name + return result + + return None + + # ----------------------------------------------------------------------- + # 策略1: 精确键值对匹配 + # ----------------------------------------------------------------------- + + def _exact_kv_match( + self, + field_name: str, + elements: list[OcrTextElement], + ) -> Optional[ExtractedField]: + """精确键值对匹配: 识别"字段名: 值"或"字段名:值"模式。 + + 在同一文本元素中查找 "字段名" 后紧跟分隔符 + "值" 的模式。 + 如 OCR 识别出 "发票代码: 12345678" 这一整个元素。 + """ + separators = [":", ":", "=", "-", "—", ":", "\t", "|"] + field_name_clean = field_name.strip() + + for elem in elements: + text = elem.text + if field_name_clean not in text: + continue + + for sep in separators: + pattern = re.escape(field_name_clean) + r"\s*" + re.escape(sep) + r"\s*(.+)" + m = re.search(pattern, text) + if m: + value = m.group(1).strip() + if value: + return ExtractedField( + field_name=field_name, + field_value=value, + bbox=elem.bbox, + confidence=0.95, + extraction_method="", + ) + + simple_pattern = re.escape(field_name_clean) + r"\s+(.+)" + m = re.search(simple_pattern, text) + if m: + value = m.group(1).strip() + if value and value != field_name_clean: + return ExtractedField( + field_name=field_name, + field_value=value, + bbox=elem.bbox, + confidence=0.85, + extraction_method="", + ) + + return None + + # ----------------------------------------------------------------------- + # 策略2: 模糊键值对匹配 + # ----------------------------------------------------------------------- + + def _fuzzy_kv_match( + self, + field_name: str, + elements: list[OcrTextElement], + ) -> Optional[ExtractedField]: + """模糊键值对匹配: 字段名和值分布在相邻的文本元素中。 + + 找到含字段名的元素后,在同一行或相邻元素中查找值。 + """ + field_name_clean = field_name.strip() + field_elem = None + + for elem in elements: + if field_name_clean in elem.text: + field_elem = elem + break + + if field_elem is None: + matching = [] + for elem in elements: + sim = self._text_similarity(field_name_clean, elem.text) + if sim > 0.6: + matching.append((sim, elem)) + if matching: + matching.sort(key=lambda x: x[0], reverse=True) + field_elem = matching[0][1] + + if field_elem is None: + return None + + candidates = [] + for elem in elements: + if elem is field_elem: + continue + candidates.append(elem) + + same_row = [] + for elem in candidates: + if abs(elem.center_y - field_elem.center_y) < field_elem.height * 1.5: + same_row.append(elem) + if same_row: + same_row.sort(key=lambda e: e.x_min) + for elem in same_row: + if elem.x_min > field_elem.x_max: + return ExtractedField( + field_name=field_name, + field_value=elem.text, + bbox=elem.bbox, + confidence=0.75, + extraction_method="", + ) + + nearest = None + nearest_dist = float("inf") + for elem in candidates: + if elem.y_min > field_elem.y_max: + dy = elem.y_min - field_elem.y_max + dx = abs(elem.center_x - field_elem.center_x) + dist = dy + dx * 0.3 + if dist < nearest_dist and dy < field_elem.height * 3: + nearest_dist = dist + nearest = elem + + if nearest: + return ExtractedField( + field_name=field_name, + field_value=nearest.text, + bbox=nearest.bbox, + confidence=0.6, + extraction_method="", + ) + + return None + + # ----------------------------------------------------------------------- + # 策略3: 正则模式匹配 + # ----------------------------------------------------------------------- + + PREDEFINED_PATTERNS: dict[str, str] = { + "发票代码": r"[0-9A-Za-z]{10,12}", + "发票号码": r"\d{8}", + "合计金额": r"[\d,]+\.?\d*", + "金额": r"[\d,]+\.?\d*", + "开票日期": r"\d{4}[年/\-]\d{1,2}[月/\-]\d{1,2}日?", + "日期": r"\d{4}[年/\-]\d{1,2}[月/\-]\d{1,2}日?", + "校验码": r"[0-9A-Fa-f]{5,20}", + "总价": r"[\d,]+\.?\d*", + "总金额": r"[\d,]+\.?\d*", + "价税合计": r"[\d,]+\.?\d*", + "数量": r"\d+\.?\d*", + "单价": r"[\d,]+\.?\d*", + "税率": r"\d+\.?\d*%?", + } + + def _regex_match( + self, + field_name: str, + elements: list[OcrTextElement], + ) -> Optional[ExtractedField]: + """正则模式匹配: 根据字段名选择预定义的正则模式,在所有元素中搜索。""" + pattern = self.PREDEFINED_PATTERNS.get(field_name) + if not pattern: + for key, pat in self.PREDEFINED_PATTERNS.items(): + if key in field_name or field_name in key: + pattern = pat + break + + if not pattern: + return None + + compiled = re.compile(r"^\s*" + pattern + r"\s*$") + for elem in elements: + if compiled.match(elem.text): + return ExtractedField( + field_name=field_name, + field_value=elem.text.strip(), + bbox=elem.bbox, + confidence=0.7, + extraction_method="", + ) + + compiled_partial = re.compile(pattern) + for elem in elements: + m = compiled_partial.search(elem.text) + if m: + return ExtractedField( + field_name=field_name, + field_value=m.group(0), + bbox=elem.bbox, + confidence=0.6, + extraction_method="", + ) + + return None + + # ----------------------------------------------------------------------- + # 策略4: 表格结构匹配 + # ----------------------------------------------------------------------- + + def _table_match( + self, + field_name: str, + elements: list[OcrTextElement], + ) -> Optional[ExtractedField]: + """表格结构匹配: 将元素按行列分组,查找表头-值对应关系。 + + 识别逻辑: + 1. 将元素按 Y 坐标分组为"行" + 2. 查找包含 field_name 的表头行 + 3. 在表头列对应的数据行中取值 + """ + if len(elements) < 3: + return None + + rows = self._group_elements_by_rows(elements) + if len(rows) < 2: + return None + + header_row_idx = -1 + header_col_idx = -1 + + for ri, row in enumerate(rows): + for ci, elem in enumerate(row): + if field_name in elem.text: + header_row_idx = ri + header_col_idx = ci + break + if header_row_idx >= 0: + break + + if header_row_idx < 0: + for ri, row in enumerate(rows): + for ci, elem in enumerate(row): + sim = self._text_similarity(field_name, elem.text) + if sim > 0.5: + header_row_idx = ri + header_col_idx = ci + break + if header_row_idx >= 0: + break + + if header_row_idx < 0: + return None + + data_rows = rows[header_row_idx + 1:] + if not data_rows: + data_rows = [rows[header_row_idx]] + + matched_elem = None + for row in data_rows: + if header_col_idx < len(row): + matched_elem = row[header_col_idx] + break + closest = None + min_dist = float("inf") + header_x = float("inf") + if header_col_idx < len(rows[header_row_idx]): + header_x = rows[header_row_idx][header_col_idx].center_x + for elem in row: + dist = abs(elem.center_x - header_x) + if dist < min_dist: + min_dist = dist + closest = elem + if closest: + matched_elem = closest + break + + if matched_elem and matched_elem.text != field_name: + return ExtractedField( + field_name=field_name, + field_value=matched_elem.text, + bbox=matched_elem.bbox, + confidence=0.55, + extraction_method="", + ) + + return None + + # ======================================================================== + # 工具方法 + # ======================================================================== + + @staticmethod + def _group_elements_by_rows( + elements: list[OcrTextElement], + ) -> list[list[OcrTextElement]]: + """将元素按 Y 坐标分组为行(容差为元素平均高度的一半)。""" + if not elements: + return [] + + avg_height = sum(e.height for e in elements) / len(elements) + tolerance = max(avg_height * 0.5, 5.0) + + rows = [] + current_row = [elements[0]] + + for elem in elements[1:]: + prev_center_y = current_row[0].center_y + if abs(elem.center_y - prev_center_y) < tolerance: + current_row.append(elem) + else: + current_row.sort(key=lambda e: e.x_min) + rows.append(current_row) + current_row = [elem] + + if current_row: + current_row.sort(key=lambda e: e.x_min) + rows.append(current_row) + + return rows + + @staticmethod + def _text_similarity(text1: str, text2: str) -> float: + """计算两个文本的简单相似度(公共字符比例)。""" + if not text1 or not text2: + return 0.0 + + t1 = text1.lower().strip() + t2 = text2.lower().strip() + + if t1 == t2: + return 1.0 + if t1 in t2 or t2 in t1: + return 0.8 + + chars1 = set(t1) + chars2 = set(t2) + if not chars1: + return 0.0 + + intersection = chars1 & chars2 + return len(intersection) / len(chars1) + + +def extract_ocr_fields( + file_path: str, + target_fields: list[str], + use_gpu: bool = False, + confidence_threshold: float = 0.5, +) -> dict: + """便捷函数: 对指定图片执行 OCR 字段提取。 + + Args: + file_path: 图片文件路径 + target_fields: 目标字段名列表 + use_gpu: 是否使用 GPU 加速 + confidence_threshold: OCR 置信度阈值 + + Returns: + 提取结果字典 + """ + extractor = OcrExtractor( + use_gpu=use_gpu, + confidence_threshold=confidence_threshold, + ) + return extractor.extract(file_path, target_fields) + + +def extract_from_layout( + layout_result: dict, + target_fields: list[str], + confidence_threshold: float = 0.5, +) -> dict: + """便捷函数: 从已有的版面分析结果中提取字段。 + + Args: + layout_result: analyze_layout() 的返回值 + target_fields: 目标字段名列表 + confidence_threshold: OCR 置信度阈值 + + Returns: + 提取结果字典 + """ + extractor = OcrExtractor(confidence_threshold=confidence_threshold) + return extractor.extract_from_layout_result(layout_result, target_fields) diff --git a/tests/test_ocr_extraction.py b/tests/test_ocr_extraction.py new file mode 100644 index 0000000..439873c --- /dev/null +++ b/tests/test_ocr_extraction.py @@ -0,0 +1,543 @@ +"""OCR 单据字段提取器单元测试。 + +覆盖: + - OcrTextElement / ExtractedField / ExtractionResult 数据结构 + - 四种提取策略的独立测试 + - 坐标计算正确性 + - 边界情况(空元素、无匹配、部分匹配) +""" + +import sys +import os +import math +from pathlib import Path + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import pytest + +from backend.ocr_extractor import ( + OcrTextElement, + ExtractedField, + ExtractionResult, + OcrExtractor, + extract_ocr_fields, + extract_from_layout, +) + + +class TestOcrTextElement: + """测试 OcrTextElement 数据类。""" + + def test_bbox_property(self): + elem = OcrTextElement( + text="发票代码", + x_min=100.0, + y_min=50.0, + x_max=250.0, + y_max=80.0, + ) + assert elem.bbox == [100.0, 50.0, 250.0, 80.0] + + def test_center_coordinates(self): + elem = OcrTextElement( + text="测试", + x_min=100.0, + y_min=50.0, + x_max=200.0, + y_max=100.0, + ) + assert elem.center_x == 150.0 + assert elem.center_y == 75.0 + + def test_width_height(self): + elem = OcrTextElement( + text="测试", + x_min=10.0, + y_min=20.0, + x_max=100.0, + y_max=70.0, + ) + assert elem.width == 90.0 + assert elem.height == 50.0 + + def test_default_confidence(self): + elem = OcrTextElement( + text="测试", + x_min=0, + y_min=0, + x_max=10, + y_max=10, + ) + assert elem.confidence == 1.0 + + +class TestExtractedField: + """测试 ExtractedField 数据类。""" + + def test_field_creation(self): + field = ExtractedField( + field_name="发票代码", + field_value="1234567890", + bbox=[100, 50, 200, 80], + confidence=0.95, + extraction_method="exact_match", + ) + assert field.field_name == "发票代码" + assert field.field_value == "1234567890" + assert field.bbox == [100, 50, 200, 80] + assert field.confidence == 0.95 + assert field.extraction_method == "exact_match" + + +class TestExtractionResult: + """测试 ExtractionResult 数据类。""" + + def test_to_dict_basic(self): + result = ExtractionResult( + file_path="/test/invoice.png", + image_size=(800, 600), + ocr_available=True, + ) + d = result.to_dict() + assert d["file_path"] == "/test/invoice.png" + assert d["image_size"] == (800, 600) + assert d["ocr_available"] is True + assert d["fields"] == [] + assert d["total_elements"] == 0 + assert d["errors"] == [] + + def test_to_dict_with_fields(self): + result = ExtractionResult( + file_path="/test/invoice.png", + image_size=(800, 600), + ocr_available=True, + ) + result.fields.append( + ExtractedField( + field_name="发票代码", + field_value="1234567890", + bbox=[100, 50, 200, 80], + confidence=0.95, + extraction_method="exact_match", + ) + ) + d = result.to_dict() + assert len(d["fields"]) == 1 + assert d["fields"][0]["field_name"] == "发票代码" + assert d["fields"][0]["field_value"] == "1234567890" + assert d["fields"][0]["bbox"] == [100, 50, 200, 80] + assert d["fields"][0]["confidence"] == 0.95 + assert d["fields"][0]["extraction_method"] == "exact_match" + + def test_to_dict_with_error(self): + result = ExtractionResult( + file_path="/test/missing.png", + image_size=(0, 0), + errors=["文件不存在: /test/missing.png"], + ) + d = result.to_dict() + assert d["errors"] == ["文件不存在: /test/missing.png"] + + +class TestElementGrouping: + """测试元素行分组功能。""" + + def test_group_single_row(self): + elements = [ + OcrTextElement("A", 10, 50, 50, 70), + OcrTextElement("B", 60, 50, 110, 70), + OcrTextElement("C", 120, 50, 170, 70), + ] + rows = OcrExtractor._group_elements_by_rows(elements) + assert len(rows) == 1 + assert len(rows[0]) == 3 + + def test_group_multiple_rows(self): + elements = [ + OcrTextElement("A1", 10, 50, 50, 70), + OcrTextElement("B1", 60, 50, 110, 70), + OcrTextElement("A2", 10, 120, 50, 140), + OcrTextElement("B2", 60, 120, 110, 140), + OcrTextElement("A3", 10, 200, 50, 220), + ] + rows = OcrExtractor._group_elements_by_rows(elements) + assert len(rows) == 3 + assert len(rows[0]) == 2 + assert len(rows[1]) == 2 + assert len(rows[2]) == 1 + + def test_group_empty(self): + rows = OcrExtractor._group_elements_by_rows([]) + assert rows == [] + + def test_group_single_element(self): + rows = OcrExtractor._group_elements_by_rows([ + OcrTextElement("X", 10, 50, 50, 70) + ]) + assert len(rows) == 1 + assert len(rows[0]) == 1 + + +class TestTextSimilarity: + """测试文本相似度计算。""" + + def test_exact_match(self): + sim = OcrExtractor._text_similarity("发票代码", "发票代码") + assert sim == 1.0 + + def test_partial_match(self): + sim = OcrExtractor._text_similarity("发票代码", "代码") + assert sim > 0.5 + + def test_no_match(self): + sim = OcrExtractor._text_similarity("发票代码", "xyz") + assert sim == 0.0 + + def test_empty_strings(self): + assert OcrExtractor._text_similarity("", "abc") == 0.0 + assert OcrExtractor._text_similarity("abc", "") == 0.0 + assert OcrExtractor._text_similarity("", "") == 0.0 + + def test_substring_match(self): + sim = OcrExtractor._text_similarity("代码", "发票代码") + assert sim > 0.7 + + +class TestExactKVMatch: + """测试策略1: 精确键值对匹配。""" + + def test_colon_separator(self): + extractor = OcrExtractor() + elements = [ + OcrTextElement("发票代码: 1234567890", 50, 50, 300, 80), + ] + result = extractor._exact_kv_match("发票代码", elements) + assert result is not None + assert result.field_value == "1234567890" + assert result.confidence == 0.95 + + def test_chinese_colon(self): + extractor = OcrExtractor() + elements = [ + OcrTextElement("发票号码:87654321", 50, 50, 300, 80), + ] + result = extractor._exact_kv_match("发票号码", elements) + assert result is not None + assert result.field_value == "87654321" + + def test_space_separator(self): + extractor = OcrExtractor() + elements = [ + OcrTextElement("合计金额 999.00", 50, 50, 300, 80), + ] + result = extractor._exact_kv_match("合计金额", elements) + assert result is not None + assert result.field_value == "999.00" + + def test_equals_separator(self): + extractor = OcrExtractor() + elements = [ + OcrTextElement("数量=5", 50, 50, 200, 80), + ] + result = extractor._exact_kv_match("数量", elements) + assert result is not None + assert result.field_value == "5" + + def test_field_not_found(self): + extractor = OcrExtractor() + elements = [ + OcrTextElement("其他内容: 123", 50, 50, 200, 80), + ] + result = extractor._exact_kv_match("发票代码", elements) + assert result is None + + +class TestFuzzyKVMatch: + """测试策略2: 模糊键值对匹配。""" + + def test_adjacent_same_row(self): + extractor = OcrExtractor() + elements = [ + OcrTextElement("发票代码", 50, 50, 150, 80), + OcrTextElement("1234567890", 200, 50, 350, 80), + ] + result = extractor._fuzzy_kv_match("发票代码", elements) + assert result is not None + assert result.field_value == "1234567890" + assert result.confidence == 0.75 + + def test_adjacent_next_row(self): + extractor = OcrExtractor() + elements = [ + OcrTextElement("发票代码", 50, 50, 150, 80), + OcrTextElement("1234567890", 50, 100, 200, 130), + ] + result = extractor._fuzzy_kv_match("发票代码", elements) + assert result is not None + assert result.field_value == "1234567890" + + def test_field_name_not_found(self): + extractor = OcrExtractor() + elements = [ + OcrTextElement("其他信息", 50, 50, 150, 80), + OcrTextElement("1234567890", 200, 50, 350, 80), + ] + result = extractor._fuzzy_kv_match("发票代码", elements) + assert result is None + + +class TestRegexMatch: + """测试策略3: 正则模式匹配。""" + + def test_invoice_code_pattern(self): + extractor = OcrExtractor() + elements = [ + OcrTextElement("1234567890", 100, 50, 250, 80), + ] + result = extractor._regex_match("发票代码", elements) + assert result is not None + assert result.field_value == "1234567890" + + def test_invoice_number_pattern(self): + extractor = OcrExtractor() + elements = [ + OcrTextElement("87654321", 100, 50, 200, 80), + ] + result = extractor._regex_match("发票号码", elements) + assert result is not None + assert result.field_value == "87654321" + + def test_amount_pattern(self): + extractor = OcrExtractor() + elements = [ + OcrTextElement("1,234.56", 100, 50, 200, 80), + ] + result = extractor._regex_match("合计金额", elements) + assert result is not None + assert "1,234.56" in result.field_value + + def test_date_pattern(self): + extractor = OcrExtractor() + elements = [ + OcrTextElement("2024年1月15日", 100, 50, 250, 80), + ] + result = extractor._regex_match("开票日期", elements) + assert result is not None + assert "2024" in result.field_value + + def test_unknown_field_no_pattern(self): + extractor = OcrExtractor() + elements = [ + OcrTextElement("随便什么内容", 100, 50, 250, 80), + ] + result = extractor._regex_match("未知字段", elements) + assert result is None + + +class TestTableMatch: + """测试策略4: 表格结构匹配。""" + + def test_simple_table(self): + extractor = OcrExtractor() + elements = [ + OcrTextElement("名称", 50, 50, 150, 80), + OcrTextElement("数量", 200, 50, 300, 80), + OcrTextElement("单价", 350, 50, 450, 80), + OcrTextElement("商品A", 50, 100, 150, 130), + OcrTextElement("2", 200, 100, 250, 130), + OcrTextElement("10.00", 350, 100, 420, 130), + ] + result = extractor._table_match("数量", elements) + assert result is not None + assert result.field_value == "2" + + def test_table_with_fuzzy_header(self): + extractor = OcrExtractor() + elements = [ + OcrTextElement("品名", 50, 50, 120, 80), + OcrTextElement("金额", 200, 50, 300, 80), + OcrTextElement("苹果", 50, 100, 120, 130), + OcrTextElement("5.00", 200, 100, 260, 130), + ] + result = extractor._table_match("合计金额", elements) + # 金额列可能匹配到 "金额" + if result: + assert result.field_value == "5.00" + + def test_table_header_not_found(self): + extractor = OcrExtractor() + elements = [ + OcrTextElement("A", 50, 50, 100, 80), + OcrTextElement("B", 150, 50, 200, 80), + ] + result = extractor._table_match("发票代码", elements) + assert result is None + + def test_table_too_few_elements(self): + extractor = OcrExtractor() + elements = [ + OcrTextElement("名称", 50, 50, 150, 80), + ] + result = extractor._table_match("名称", elements) + assert result is None + + +class TestCoordinateCorrectness: + """测试坐标计算正确性。""" + + def test_bbox_origin_top_left(self): + """验证坐标系统以左上角为原点。""" + elem = OcrTextElement("A", 0, 0, 50, 20) + assert elem.x_min == 0 + assert elem.y_min == 0 + assert elem.bbox[0] == 0 + assert elem.bbox[1] == 0 + + def test_bbox_conversion(self): + """验证从 (x, y, w, h) 到 [x_min, y_min, x_max, y_max] 的转换。""" + x, y, w, h = 100, 200, 300, 50 + elem = OcrTextElement("test", x_min=x, y_min=y, x_max=x + w, y_max=y + h) + assert elem.bbox == [x, y, x + w, y + h] + assert elem.x_max - elem.x_min == w + assert elem.y_max - elem.y_min == h + + def test_multiple_elements_coordinate_independence(self): + """验证多个元素的坐标互不干扰。""" + elements = [ + OcrTextElement("A", 10, 20, 60, 40), + OcrTextElement("B", 100, 200, 180, 240), + ] + assert elements[0].bbox == [10, 20, 60, 40] + assert elements[1].bbox == [100, 200, 180, 240] + assert elements[0].x_min != elements[1].x_min + + def test_center_calculation(self): + """验证中心点计算。""" + elem = OcrTextElement("test", 0, 0, 100, 100) + assert elem.center_x == 50.0 + assert elem.center_y == 50.0 + + +class TestExtractionPipeline: + """测试完整的提取流水线。""" + + def test_priority_order_exact_first(self): + """验证策略优先级: exact_match 优先于其他策略。""" + extractor = OcrExtractor() + elements = [ + OcrTextElement("发票代码: 1234567890", 50, 50, 300, 80), + ] + result = extractor._extract_field("发票代码", elements) + assert result is not None + assert result.extraction_method == "exact_match" + + def test_fallback_to_fuzzy(self): + """验证精确匹配失败后回退到模糊匹配。""" + extractor = OcrExtractor() + elements = [ + OcrTextElement("发票代码", 50, 50, 150, 80), + OcrTextElement("1234567890", 200, 50, 350, 80), + ] + result = extractor._extract_field("发票代码", elements) + assert result is not None + assert result.extraction_method == "kv_pair" + + def test_all_fields_empty(self): + """验证所有字段提取失败时返回空结果。""" + extractor = OcrExtractor() + elements = [ + OcrTextElement("一些不相关的文本", 50, 50, 300, 80), + OcrTextElement("更多随机内容", 50, 100, 300, 130), + ] + result = extractor._extract_field("发票代码", elements) + assert result is None + + def test_empty_elements_list(self): + """验证空元素列表时正常处理。""" + extractor = OcrExtractor() + result = extractor._extract_field("发票代码", []) + assert result is None + + def test_extraction_with_confidence(self): + """验证提取结果的置信度在合理范围内。""" + extractor = OcrExtractor() + elements = [ + OcrTextElement("发票代码: 1234567890", 50, 50, 300, 80), + ] + result = extractor._extract_field("发票代码", elements) + assert result is not None + assert 0.0 < result.confidence <= 1.0 + + +class TestExtractFromLayout: + """测试从 layout 结果中提取字段。""" + + def test_basic_layout_extraction(self): + layout_result = { + "image_size": (800, 600), + "template_type": "full_a4", + "rows": [ + { + "y_center": 50, + "elements": [ + {"x": 50, "y": 30, "w": 150, "h": 30, "text": "发票代码:"}, + {"x": 250, "y": 30, "w": 200, "h": 30, "text": "1234567890"}, + ], + }, + { + "y_center": 80, + "elements": [ + {"x": 50, "y": 60, "w": 150, "h": 30, "text": "合计金额:"}, + {"x": 250, "y": 60, "w": 100, "h": 30, "text": "999.00"}, + ], + }, + ], + } + result = extract_from_layout(layout_result, ["发票代码"]) + assert result["ocr_available"] is True + assert result["image_size"] == (800, 600) + + def test_empty_layout(self): + result = extract_from_layout({}, ["发票代码"]) + assert result["ocr_available"] is False + assert len(result["errors"]) > 0 + + +class TestOcrExtractorFileNotFound: + """测试文件不存在的情况。""" + + def test_missing_file(self): + extractor = OcrExtractor() + result = extractor.extract("/nonexistent/file.png", ["发票代码"]) + assert result["ocr_available"] is False + assert len(result["errors"]) > 0 + assert "文件不存在" in result["errors"][0] + + +class TestConvenienceFunctions: + """测试便捷函数。""" + + def test_extract_ocr_fields_missing_file(self): + result = extract_ocr_fields("/nonexistent/file.png", ["发票代码"]) + assert len(result["errors"]) > 0 + + def test_extract_from_layout_with_partial_rows(self): + layout_result = { + "image_size": (1200, 800), + "template_type": "partial_rows", + "rows": [ + { + "y_center": 100, + "elements": [ + {"x": 100, "y": 80, "w": 120, "h": 40, "text": "发票代码"}, + {"x": 300, "y": 80, "w": 180, "h": 40, "text": "NO123456"}, + ], + }, + ], + } + result = extract_from_layout(layout_result, ["发票代码"]) + assert result["ocr_available"] is True + assert len(result["fields"]) == 1 + assert result["fields"][0]["field_name"] == "发票代码" + assert result["fields"][0]["field_value"] == "NO123456"