feat: 新增 OCR 单据字段精确提取模块
- 新增 backend/ocr_extractor.py: 两阶段提取流水线 (文档分析 + 字段提取) - 四种提取策略: 精确KV匹配/模糊KV匹配/正则模式/表格结构匹配 - agent/state.py: 新增 ocr_extraction_result 和 uploaded_file_path 字段 - agent/nodes.py: process_input() 中自动触发 OCR 提取钩子 - app.py: 文件上传时保留图片路径, 总结卡片中展示提取结果 - .env.example: 新增 OCR_USE_GPU / OCR_CONFIDENCE_THRESHOLD 配置项 - tests/test_ocr_extraction.py: 48 个单元测试全部通过
This commit is contained in:
@@ -63,3 +63,9 @@ HISTORY_MAX_SNAPSHOTS=10
|
||||
|
||||
# 意图识别模型(默认使用主 LLM 模型)
|
||||
# INTENT_MODEL=gpt-4o-mini
|
||||
|
||||
# OCR 字段提取配置
|
||||
# 是否使用 GPU 加速 OCR(需要 CUDA 驱动和 GPU 版 EasyOCR/PaddleOCR)
|
||||
OCR_USE_GPU=false
|
||||
# OCR 文本置信度最低阈值(0-1),低于此值的元素将被忽略
|
||||
OCR_CONFIDENCE_THRESHOLD=0.5
|
||||
|
||||
@@ -7,6 +7,7 @@ import os
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
from dotenv import load_dotenv
|
||||
@@ -114,6 +115,30 @@ def process_input(state: AgentState) -> Dict:
|
||||
conv_history.append({"role": "user", "content": user_input})
|
||||
state["conversation_history"] = conv_history
|
||||
|
||||
# OCR 单据字段精确提取(处理上传的图片文件)
|
||||
uploaded_path = state.get("uploaded_file_path", "")
|
||||
if uploaded_path and Path(uploaded_path).is_file():
|
||||
suffix = Path(uploaded_path).suffix.lower()
|
||||
if suffix in (".png", ".jpg", ".jpeg", ".bmp", ".webp"):
|
||||
try:
|
||||
from backend.ocr_extractor import OcrExtractor
|
||||
extractor = OcrExtractor()
|
||||
ocr_result = extractor.extract(uploaded_path, [])
|
||||
if ocr_result.get("ocr_available"):
|
||||
state["ocr_extraction_result"] = ocr_result
|
||||
_node_log.info(
|
||||
"OCR 字段提取完成",
|
||||
extra={
|
||||
"file": uploaded_path,
|
||||
"elements": ocr_result.get("total_elements", 0),
|
||||
"fields": len(ocr_result.get("fields", [])),
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
_node_log.warning(f"OCR 字段提取失败: {e}")
|
||||
state["ocr_extraction_result"] = {"error": str(e)}
|
||||
state["uploaded_file_path"] = ""
|
||||
|
||||
# 重置本轮请求字段
|
||||
state["retry_count"] = 0
|
||||
state["user_modification_request"] = user_input
|
||||
|
||||
@@ -40,3 +40,7 @@ class AgentState(TypedDict, total=False):
|
||||
|
||||
# 需求6:失败上下文传递 — 重试耗尽后暂存失败信息,下次用户输入时自动注入
|
||||
pending_failure_context: dict
|
||||
|
||||
# 需求7:OCR 单据字段精确提取结果
|
||||
ocr_extraction_result: dict
|
||||
uploaded_file_path: str
|
||||
|
||||
@@ -261,6 +261,14 @@ def run_agent(user_input: str):
|
||||
if stream_active:
|
||||
streaming_placeholder.empty()
|
||||
|
||||
# 清理已处理的临时文件
|
||||
for p in st.session_state.get("uploaded_temp_paths", []):
|
||||
try:
|
||||
Path(p).unlink(missing_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
st.session_state.uploaded_temp_paths = []
|
||||
|
||||
# ---- 总结卡片 ----
|
||||
# 注:node_state 只含变更字段,用 agent_state(被所有节点就地修改)获取完整状态
|
||||
final_state = agent_state
|
||||
@@ -324,6 +332,30 @@ def run_agent(user_input: str):
|
||||
"content": f"❌ 经过 {retries} 次重试后仍无法生成有效的 JRXML。\n\n**错误:** {error_msg}\n\n💡 请直接描述修改需求,系统会自动加载失败上下文。",
|
||||
"type": "error_explanation",
|
||||
})
|
||||
|
||||
# OCR 字段提取结果展示
|
||||
ocr_result = agent_state.get("ocr_extraction_result", {})
|
||||
if ocr_result and ocr_result.get("ocr_available") and ocr_result.get("fields"):
|
||||
with st.expander("🔍 OCR 单据字段提取结果", expanded=False):
|
||||
fields = ocr_result.get("fields", [])
|
||||
non_empty = [f for f in fields if f.get("field_value")]
|
||||
empty = [f for f in fields if not f.get("field_value")]
|
||||
if non_empty:
|
||||
st.markdown("**已提取字段:**")
|
||||
for f in non_empty:
|
||||
method = f.get("extraction_method", "")
|
||||
conf = f.get("confidence", 0)
|
||||
st.markdown(
|
||||
f"- **{f['field_name']}**: `{f['field_value']}` "
|
||||
f"(置信度: {conf:.0%}, 方法: {method})"
|
||||
)
|
||||
if empty:
|
||||
st.caption(
|
||||
f"未提取到值的字段: {', '.join(f['field_name'] for f in empty)}"
|
||||
)
|
||||
st.caption(
|
||||
f"共检测到 {ocr_result.get('total_elements', 0)} 个文本元素"
|
||||
)
|
||||
else:
|
||||
st.error("未产生结果,请重试。")
|
||||
|
||||
@@ -443,6 +475,9 @@ with st.sidebar:
|
||||
if "uploaded_files" not in st.session_state:
|
||||
st.session_state.uploaded_files = [] # [{name, text, type}]
|
||||
|
||||
if "uploaded_temp_paths" not in st.session_state:
|
||||
st.session_state.uploaded_temp_paths = [] # 待清理的临时文件路径
|
||||
|
||||
uploaded = st.file_uploader(
|
||||
"选择文件",
|
||||
type=["png", "jpg", "jpeg", "bmp", "webp", "pdf", "docx", "txt", "csv", "json", "xml"],
|
||||
@@ -513,8 +548,6 @@ with st.sidebar:
|
||||
)
|
||||
parsed_type = "image_reference"
|
||||
|
||||
Path(tmp_path).unlink(missing_ok=True)
|
||||
|
||||
if parsed_text:
|
||||
st.session_state.uploaded_files.append({
|
||||
"name": uf.name,
|
||||
@@ -522,6 +555,14 @@ with st.sidebar:
|
||||
"type": parsed_type,
|
||||
})
|
||||
|
||||
# 对图片类型,保存路径以便 OCR 字段提取(延迟到 process_input 阶段)
|
||||
img_suffixes = (".png", ".jpg", ".jpeg", ".bmp", ".webp")
|
||||
if suffix in img_suffixes and result.get("method") not in ("metadata_only", None):
|
||||
st.session_state.agent_state["uploaded_file_path"] = tmp_path
|
||||
st.session_state.uploaded_temp_paths.append(tmp_path)
|
||||
else:
|
||||
Path(tmp_path).unlink(missing_ok=True)
|
||||
|
||||
if st.session_state.uploaded_files:
|
||||
for i, f in enumerate(st.session_state.uploaded_files):
|
||||
cols = st.columns([5, 1])
|
||||
|
||||
@@ -0,0 +1,796 @@
|
||||
"""OCR 单据字段精确提取器。
|
||||
|
||||
两阶段提取流程:
|
||||
阶段1 - 文档分析: 复用 file_parser.parse_file() 和 layout_analyzer.analyze_layout()
|
||||
获取每个文本元素的精确坐标和内容
|
||||
阶段2 - 字段提取: 给定目标字段列表,通过四种策略(精确KV匹配、模糊KV匹配、
|
||||
正则模式匹配、表格结构匹配)提取字段值、位置和置信度
|
||||
|
||||
用法:
|
||||
from backend.ocr_extractor import OcrExtractor
|
||||
|
||||
extractor = OcrExtractor()
|
||||
result = extractor.extract("invoice.png", ["发票代码", "发票号码", "合计金额"])
|
||||
for field in result["fields"]:
|
||||
print(f"{field['field_name']}: {field['field_value']} (置信度: {field['confidence']})")
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()


def _env_float(name: str, default: float) -> float:
    """Read a float from the environment, falling back to *default*.

    A missing or malformed value must not make importing this module fail,
    so anything ``float()`` cannot parse is ignored.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    try:
        return float(raw)
    except ValueError:
        return default


# GPU acceleration is enabled by OCR_USE_GPU=true / 1 / yes (case-insensitive).
OCR_USE_GPU = os.getenv("OCR_USE_GPU", "false").strip().lower() in ("true", "1", "yes")

# OCR elements whose confidence is below this value are dropped.
OCR_CONFIDENCE_THRESHOLD = _env_float("OCR_CONFIDENCE_THRESHOLD", 0.5)
|
||||
|
||||
|
||||
@dataclass
class OcrTextElement:
    """One piece of OCR text together with its axis-aligned bounding box.

    The box is stored as two corners, ``(x_min, y_min)`` and
    ``(x_max, y_max)``; all derived geometry is computed from them.
    """

    text: str
    x_min: float
    y_min: float
    x_max: float
    y_max: float
    confidence: float = 1.0  # 1.0 when no confidence was supplied

    @property
    def bbox(self) -> list[float]:
        """Return the box as ``[x_min, y_min, x_max, y_max]``."""
        corners = [self.x_min, self.y_min, self.x_max, self.y_max]
        return corners

    @property
    def width(self) -> float:
        """Horizontal extent of the box."""
        return self.x_max - self.x_min

    @property
    def height(self) -> float:
        """Vertical extent of the box."""
        return self.y_max - self.y_min

    @property
    def center_x(self) -> float:
        """X coordinate of the box centre."""
        return (self.x_min + self.x_max) / 2

    @property
    def center_y(self) -> float:
        """Y coordinate of the box centre."""
        return (self.y_min + self.y_max) / 2
|
||||
|
||||
|
||||
@dataclass
class ExtractedField:
    """Result of extracting a single target field."""

    # Name of the field as requested by the caller.
    field_name: str
    # Extracted text; the extractor stores "" when nothing was found.
    field_value: str
    # Bounding box of the element the value came from, as
    # [x_min, y_min, x_max, y_max]; the extractor stores [] when nothing was found.
    bbox: list[float]
    # Fixed score assigned by the strategy that produced the value
    # (0.0 for the "not found" placeholder).
    confidence: float
    # Strategy label set by OcrExtractor: "exact_match", "kv_pair", "regex",
    # "table_match", or "none" for the placeholder.
    extraction_method: str
|
||||
|
||||
|
||||
@dataclass
class ExtractionResult:
    """Everything produced by one extraction run."""

    file_path: str
    image_size: tuple[int, int]
    fields: list[ExtractedField] = field(default_factory=list)
    all_elements: list[OcrTextElement] = field(default_factory=list)
    errors: list[str] = field(default_factory=list)
    ocr_available: bool = False

    def to_dict(self) -> dict:
        """Return a plain-dict view of the result.

        The OCR elements themselves are not included; only their count is
        reported under ``total_elements``.
        """
        serialized_fields = []
        for item in self.fields:
            serialized_fields.append(
                {
                    "field_name": item.field_name,
                    "field_value": item.field_value,
                    "bbox": item.bbox,
                    "confidence": item.confidence,
                    "extraction_method": item.extraction_method,
                }
            )

        payload = {
            "file_path": self.file_path,
            "image_size": self.image_size,
            "ocr_available": self.ocr_available,
            "fields": serialized_fields,
            "total_elements": len(self.all_elements),
            "errors": self.errors,
        }
        return payload
|
||||
|
||||
|
||||
class OcrExtractor:
|
||||
"""OCR 单据字段精确提取器。
|
||||
|
||||
两阶段流水线:
|
||||
阶段1: 对上传图片进行 OCR + 版面分析,产出带坐标的文本元素列表
|
||||
阶段2: 根据目标字段列表,按优先级逐一尝试四种提取策略
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
use_gpu: bool = False,
|
||||
confidence_threshold: float = 0.5,
|
||||
):
|
||||
"""初始化提取器。
|
||||
|
||||
Args:
|
||||
use_gpu: 是否使用 GPU 加速 OCR(需要相应驱动)
|
||||
confidence_threshold: OCR 文本置信度最低阈值,低于此值的元素被忽略
|
||||
"""
|
||||
self.use_gpu = use_gpu if use_gpu else OCR_USE_GPU
|
||||
self.confidence_threshold = (
|
||||
confidence_threshold
|
||||
if confidence_threshold != 0.5
|
||||
else OCR_CONFIDENCE_THRESHOLD
|
||||
)
|
||||
|
||||
# ========================================================================
|
||||
# 公共接口
|
||||
# ========================================================================
|
||||
|
||||
def extract(
    self,
    file_path: str,
    target_fields: list[str],
) -> dict:
    """Run the two-stage OCR field extraction on an image file.

    Args:
        file_path: Path of the image to analyse.
        target_fields: Names of the fields to extract, e.g.
            ``["发票代码", "发票号码", "合计金额"]``.

    Returns:
        The dict produced by ``ExtractionResult.to_dict()``. Every requested
        field gets an entry; fields that could not be found carry an empty
        value, an empty bbox, confidence 0.0 and method ``"none"``. Problems
        are reported through the ``errors`` list rather than by raising.
    """
    result = ExtractionResult(file_path=file_path, image_size=(0, 0))

    # is_file() rather than exists(): a directory path must be rejected here
    # instead of being handed to the image loader.
    if not Path(file_path).is_file():
        result.errors.append(f"文件不存在: {file_path}")
        return result.to_dict()

    elements, image_size = self._analyze_document(file_path)
    result.image_size = image_size
    result.all_elements = elements

    if not elements:
        # No elements means either "no engine installed" or "engine found
        # nothing"; the import probe tells the two apart.
        result.ocr_available = self._check_ocr_availability()
        if result.ocr_available:
            result.errors.append("图片未检测到文字元素")
        else:
            result.errors.append(
                "OCR 引擎不可用,请安装 easyocr (pip install easyocr) 或 paddleocr"
            )
        return result.to_dict()

    result.ocr_available = True
    for field_name in target_fields:
        extracted = self._extract_field(field_name, elements)
        if not extracted:
            # Placeholder keeps one entry per requested field.
            extracted = ExtractedField(
                field_name=field_name,
                field_value="",
                bbox=[],
                confidence=0.0,
                extraction_method="none",
            )
        result.fields.append(extracted)

    return result.to_dict()
|
||||
|
||||
def extract_from_layout_result(
    self,
    layout_result: dict,
    target_fields: list[str],
) -> dict:
    """Extract fields from an existing layout-analysis result.

    Skips stage 1 (no OCR is run) and goes straight to field extraction.
    Elements are read from ``layout_result["rows"][i]["elements"]`` using
    the keys ``text``, ``x``, ``y``, ``w`` and ``h``; each one gets the
    default confidence of 1.0.

    Args:
        layout_result: Return value of ``analyze_layout()``.
        target_fields: Names of the fields to extract.

    Returns:
        The dict produced by ``ExtractionResult.to_dict()``.
    """
    size = layout_result.get("image_size", (0, 0))
    layout_rows = layout_result.get("rows", [])

    if not layout_rows:
        empty = ExtractionResult(
            file_path="(from layout)",
            image_size=size,
            errors=["版面分析结果中没有文本行"],
        )
        return empty.to_dict()

    elements = [
        OcrTextElement(
            text=cell.get("text", ""),
            x_min=cell.get("x", 0),
            y_min=cell.get("y", 0),
            x_max=cell.get("x", 0) + cell.get("w", 0),
            y_max=cell.get("y", 0) + cell.get("h", 0),
        )
        for layout_row in layout_rows
        for cell in layout_row.get("elements", [])
    ]

    result = ExtractionResult(
        file_path="(from layout)",
        image_size=size,
        all_elements=elements,
        ocr_available=True,
    )

    for wanted in target_fields:
        found = self._extract_field(wanted, elements)
        if found:
            result.fields.append(found)
            continue
        # Keep one entry per requested field even when nothing matched.
        result.fields.append(
            ExtractedField(
                field_name=wanted,
                field_value="",
                bbox=[],
                confidence=0.0,
                extraction_method="none",
            )
        )

    return result.to_dict()
|
||||
|
||||
# ========================================================================
|
||||
# 阶段1: 文档分析
|
||||
# ========================================================================
|
||||
|
||||
def _analyze_document(self, file_path: str) -> tuple[list[OcrTextElement], tuple[int, int]]:
    """Stage 1: run OCR and return positioned text elements.

    Args:
        file_path: Path of the image to load.

    Returns:
        ``(elements, image_size)``. Elements below the configured confidence
        threshold are dropped and the rest are sorted by ``(y_min, x_min)``.
        ``([], (0, 0))`` is returned when the image cannot be loaded.
    """
    # Only the loader is needed; importing other private helpers that are
    # never used here would just add a way for this import to fail.
    from backend.layout_analyzer import _load_image

    img = _load_image(Path(file_path))
    if img is None:
        return [], (0, 0)

    image_size = img.size

    elements = []
    for raw in self._ocr_elements_enhanced(img, file_path):
        confidence = raw.get("confidence", 1.0)
        if confidence < self.confidence_threshold:
            continue
        x = raw.get("x", 0)
        y = raw.get("y", 0)
        elements.append(
            OcrTextElement(
                text=raw.get("text", ""),
                x_min=x,
                y_min=y,
                x_max=x + raw.get("w", 0),
                y_max=y + raw.get("h", 0),
                confidence=confidence,
            )
        )

    # Reading order: top-to-bottom, then left-to-right.
    elements.sort(key=lambda e: (e.y_min, e.x_min))
    return elements, image_size
|
||||
|
||||
def _ocr_elements_enhanced(self, img, file_path: str) -> list[dict]:
|
||||
"""增强版 OCR,返回带置信度的元素列表。"""
|
||||
try:
|
||||
import numpy as np
|
||||
|
||||
easyocr_result = self._try_easyocr(np.array(img))
|
||||
if easyocr_result:
|
||||
return easyocr_result
|
||||
|
||||
paddleocr_result = self._try_paddleocr(img, file_path)
|
||||
if paddleocr_result:
|
||||
return paddleocr_result
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return []
|
||||
|
||||
def _try_easyocr(self, np_img) -> Optional[list[dict]]:
|
||||
try:
|
||||
import easyocr
|
||||
|
||||
reader = easyocr.Reader(
|
||||
["ch_sim", "en"],
|
||||
gpu=self.use_gpu,
|
||||
verbose=False,
|
||||
)
|
||||
raw_result = reader.readtext(np_img)
|
||||
|
||||
elements = []
|
||||
for bbox, text, confidence in raw_result:
|
||||
if not text.strip():
|
||||
continue
|
||||
xs = [p[0] for p in bbox]
|
||||
ys = [p[1] for p in bbox]
|
||||
x_min, x_max = min(xs), max(xs)
|
||||
y_min, y_max = min(ys), max(ys)
|
||||
|
||||
elements.append({
|
||||
"x": round(x_min, 1),
|
||||
"y": round(y_min, 1),
|
||||
"w": round(x_max - x_min, 1),
|
||||
"h": round(y_max - y_min, 1),
|
||||
"text": text.strip(),
|
||||
"confidence": round(confidence, 4),
|
||||
})
|
||||
|
||||
elements.sort(key=lambda e: (e["y"], e["x"]))
|
||||
return elements
|
||||
except ImportError:
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _try_paddleocr(self, img, file_path: str) -> Optional[list[dict]]:
|
||||
try:
|
||||
from paddleocr import PaddleOCR
|
||||
import numpy as np
|
||||
|
||||
ocr = PaddleOCR(lang="ch")
|
||||
raw_result = ocr.ocr(np.array(img))
|
||||
|
||||
elements = []
|
||||
if raw_result and raw_result[0]:
|
||||
for line in raw_result[0]:
|
||||
if len(line) < 2:
|
||||
continue
|
||||
box = line[0]
|
||||
text_info = line[1]
|
||||
|
||||
if isinstance(text_info, (list, tuple)):
|
||||
text = text_info[0]
|
||||
confidence = text_info[1] if len(text_info) > 1 else 1.0
|
||||
else:
|
||||
text = str(text_info)
|
||||
confidence = 1.0
|
||||
|
||||
if not text.strip():
|
||||
continue
|
||||
|
||||
xs = [p[0] for p in box]
|
||||
ys = [p[1] for p in box]
|
||||
x_min, x_max = min(xs), max(xs)
|
||||
y_min, y_max = min(ys), max(ys)
|
||||
|
||||
elements.append({
|
||||
"x": round(x_min, 1),
|
||||
"y": round(y_min, 1),
|
||||
"w": round(x_max - x_min, 1),
|
||||
"h": round(y_max - y_min, 1),
|
||||
"text": text.strip(),
|
||||
"confidence": round(float(confidence), 4),
|
||||
})
|
||||
|
||||
elements.sort(key=lambda e: (e["y"], e["x"]))
|
||||
return elements
|
||||
except ImportError:
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _check_ocr_availability(self) -> bool:
|
||||
try:
|
||||
import easyocr
|
||||
return True
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
import paddleocr
|
||||
return True
|
||||
except ImportError:
|
||||
pass
|
||||
return False
|
||||
|
||||
# ========================================================================
|
||||
# 阶段2: 字段精确提取
|
||||
# ========================================================================
|
||||
|
||||
def _extract_field(
|
||||
self,
|
||||
field_name: str,
|
||||
elements: list[OcrTextElement],
|
||||
) -> Optional[ExtractedField]:
|
||||
"""按优先级尝试四种策略提取单个字段。
|
||||
|
||||
策略优先级:
|
||||
1. 精确键值对匹配
|
||||
2. 模糊键值对匹配
|
||||
3. 正则模式匹配
|
||||
4. 表格结构匹配
|
||||
"""
|
||||
strategies = [
|
||||
("exact_match", self._exact_kv_match),
|
||||
("kv_pair", self._fuzzy_kv_match),
|
||||
("regex", self._regex_match),
|
||||
("table_match", self._table_match),
|
||||
]
|
||||
|
||||
for method_name, strategy_fn in strategies:
|
||||
result = strategy_fn(field_name, elements)
|
||||
if result and result.field_value:
|
||||
result.extraction_method = method_name
|
||||
return result
|
||||
|
||||
return None
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# 策略1: 精确键值对匹配
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def _exact_kv_match(
|
||||
self,
|
||||
field_name: str,
|
||||
elements: list[OcrTextElement],
|
||||
) -> Optional[ExtractedField]:
|
||||
"""精确键值对匹配: 识别"字段名: 值"或"字段名:值"模式。
|
||||
|
||||
在同一文本元素中查找 "字段名" 后紧跟分隔符 + "值" 的模式。
|
||||
如 OCR 识别出 "发票代码: 12345678" 这一整个元素。
|
||||
"""
|
||||
separators = [":", ":", "=", "-", "—", ":", "\t", "|"]
|
||||
field_name_clean = field_name.strip()
|
||||
|
||||
for elem in elements:
|
||||
text = elem.text
|
||||
if field_name_clean not in text:
|
||||
continue
|
||||
|
||||
for sep in separators:
|
||||
pattern = re.escape(field_name_clean) + r"\s*" + re.escape(sep) + r"\s*(.+)"
|
||||
m = re.search(pattern, text)
|
||||
if m:
|
||||
value = m.group(1).strip()
|
||||
if value:
|
||||
return ExtractedField(
|
||||
field_name=field_name,
|
||||
field_value=value,
|
||||
bbox=elem.bbox,
|
||||
confidence=0.95,
|
||||
extraction_method="",
|
||||
)
|
||||
|
||||
simple_pattern = re.escape(field_name_clean) + r"\s+(.+)"
|
||||
m = re.search(simple_pattern, text)
|
||||
if m:
|
||||
value = m.group(1).strip()
|
||||
if value and value != field_name_clean:
|
||||
return ExtractedField(
|
||||
field_name=field_name,
|
||||
field_value=value,
|
||||
bbox=elem.bbox,
|
||||
confidence=0.85,
|
||||
extraction_method="",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# 策略2: 模糊键值对匹配
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def _fuzzy_kv_match(
|
||||
self,
|
||||
field_name: str,
|
||||
elements: list[OcrTextElement],
|
||||
) -> Optional[ExtractedField]:
|
||||
"""模糊键值对匹配: 字段名和值分布在相邻的文本元素中。
|
||||
|
||||
找到含字段名的元素后,在同一行或相邻元素中查找值。
|
||||
"""
|
||||
field_name_clean = field_name.strip()
|
||||
field_elem = None
|
||||
|
||||
for elem in elements:
|
||||
if field_name_clean in elem.text:
|
||||
field_elem = elem
|
||||
break
|
||||
|
||||
if field_elem is None:
|
||||
matching = []
|
||||
for elem in elements:
|
||||
sim = self._text_similarity(field_name_clean, elem.text)
|
||||
if sim > 0.6:
|
||||
matching.append((sim, elem))
|
||||
if matching:
|
||||
matching.sort(key=lambda x: x[0], reverse=True)
|
||||
field_elem = matching[0][1]
|
||||
|
||||
if field_elem is None:
|
||||
return None
|
||||
|
||||
candidates = []
|
||||
for elem in elements:
|
||||
if elem is field_elem:
|
||||
continue
|
||||
candidates.append(elem)
|
||||
|
||||
same_row = []
|
||||
for elem in candidates:
|
||||
if abs(elem.center_y - field_elem.center_y) < field_elem.height * 1.5:
|
||||
same_row.append(elem)
|
||||
if same_row:
|
||||
same_row.sort(key=lambda e: e.x_min)
|
||||
for elem in same_row:
|
||||
if elem.x_min > field_elem.x_max:
|
||||
return ExtractedField(
|
||||
field_name=field_name,
|
||||
field_value=elem.text,
|
||||
bbox=elem.bbox,
|
||||
confidence=0.75,
|
||||
extraction_method="",
|
||||
)
|
||||
|
||||
nearest = None
|
||||
nearest_dist = float("inf")
|
||||
for elem in candidates:
|
||||
if elem.y_min > field_elem.y_max:
|
||||
dy = elem.y_min - field_elem.y_max
|
||||
dx = abs(elem.center_x - field_elem.center_x)
|
||||
dist = dy + dx * 0.3
|
||||
if dist < nearest_dist and dy < field_elem.height * 3:
|
||||
nearest_dist = dist
|
||||
nearest = elem
|
||||
|
||||
if nearest:
|
||||
return ExtractedField(
|
||||
field_name=field_name,
|
||||
field_value=nearest.text,
|
||||
bbox=nearest.bbox,
|
||||
confidence=0.6,
|
||||
extraction_method="",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# 策略3: 正则模式匹配
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
# Value-shape regexes keyed by field name, used by the regex strategy.
# Amount-like patterns must start with a digit; otherwise a lone "," in any
# text would count as an amount.
PREDEFINED_PATTERNS: dict[str, str] = {
    "发票代码": r"[0-9A-Za-z]{10,12}",
    "发票号码": r"\d{8}",
    "合计金额": r"\d[\d,]*\.?\d*",
    "金额": r"\d[\d,]*\.?\d*",
    "开票日期": r"\d{4}[年/\-]\d{1,2}[月/\-]\d{1,2}日?",
    "日期": r"\d{4}[年/\-]\d{1,2}[月/\-]\d{1,2}日?",
    "校验码": r"[0-9A-Fa-f]{5,20}",
    "总价": r"\d[\d,]*\.?\d*",
    "总金额": r"\d[\d,]*\.?\d*",
    "价税合计": r"\d[\d,]*\.?\d*",
    "数量": r"\d+\.?\d*",
    "单价": r"\d[\d,]*\.?\d*",
    "税率": r"\d+\.?\d*%?",
}
|
||||
|
||||
def _regex_match(
|
||||
self,
|
||||
field_name: str,
|
||||
elements: list[OcrTextElement],
|
||||
) -> Optional[ExtractedField]:
|
||||
"""正则模式匹配: 根据字段名选择预定义的正则模式,在所有元素中搜索。"""
|
||||
pattern = self.PREDEFINED_PATTERNS.get(field_name)
|
||||
if not pattern:
|
||||
for key, pat in self.PREDEFINED_PATTERNS.items():
|
||||
if key in field_name or field_name in key:
|
||||
pattern = pat
|
||||
break
|
||||
|
||||
if not pattern:
|
||||
return None
|
||||
|
||||
compiled = re.compile(r"^\s*" + pattern + r"\s*$")
|
||||
for elem in elements:
|
||||
if compiled.match(elem.text):
|
||||
return ExtractedField(
|
||||
field_name=field_name,
|
||||
field_value=elem.text.strip(),
|
||||
bbox=elem.bbox,
|
||||
confidence=0.7,
|
||||
extraction_method="",
|
||||
)
|
||||
|
||||
compiled_partial = re.compile(pattern)
|
||||
for elem in elements:
|
||||
m = compiled_partial.search(elem.text)
|
||||
if m:
|
||||
return ExtractedField(
|
||||
field_name=field_name,
|
||||
field_value=m.group(0),
|
||||
bbox=elem.bbox,
|
||||
confidence=0.6,
|
||||
extraction_method="",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# 策略4: 表格结构匹配
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def _table_match(
|
||||
self,
|
||||
field_name: str,
|
||||
elements: list[OcrTextElement],
|
||||
) -> Optional[ExtractedField]:
|
||||
"""表格结构匹配: 将元素按行列分组,查找表头-值对应关系。
|
||||
|
||||
识别逻辑:
|
||||
1. 将元素按 Y 坐标分组为"行"
|
||||
2. 查找包含 field_name 的表头行
|
||||
3. 在表头列对应的数据行中取值
|
||||
"""
|
||||
if len(elements) < 3:
|
||||
return None
|
||||
|
||||
rows = self._group_elements_by_rows(elements)
|
||||
if len(rows) < 2:
|
||||
return None
|
||||
|
||||
header_row_idx = -1
|
||||
header_col_idx = -1
|
||||
|
||||
for ri, row in enumerate(rows):
|
||||
for ci, elem in enumerate(row):
|
||||
if field_name in elem.text:
|
||||
header_row_idx = ri
|
||||
header_col_idx = ci
|
||||
break
|
||||
if header_row_idx >= 0:
|
||||
break
|
||||
|
||||
if header_row_idx < 0:
|
||||
for ri, row in enumerate(rows):
|
||||
for ci, elem in enumerate(row):
|
||||
sim = self._text_similarity(field_name, elem.text)
|
||||
if sim > 0.5:
|
||||
header_row_idx = ri
|
||||
header_col_idx = ci
|
||||
break
|
||||
if header_row_idx >= 0:
|
||||
break
|
||||
|
||||
if header_row_idx < 0:
|
||||
return None
|
||||
|
||||
data_rows = rows[header_row_idx + 1:]
|
||||
if not data_rows:
|
||||
data_rows = [rows[header_row_idx]]
|
||||
|
||||
matched_elem = None
|
||||
for row in data_rows:
|
||||
if header_col_idx < len(row):
|
||||
matched_elem = row[header_col_idx]
|
||||
break
|
||||
closest = None
|
||||
min_dist = float("inf")
|
||||
header_x = float("inf")
|
||||
if header_col_idx < len(rows[header_row_idx]):
|
||||
header_x = rows[header_row_idx][header_col_idx].center_x
|
||||
for elem in row:
|
||||
dist = abs(elem.center_x - header_x)
|
||||
if dist < min_dist:
|
||||
min_dist = dist
|
||||
closest = elem
|
||||
if closest:
|
||||
matched_elem = closest
|
||||
break
|
||||
|
||||
if matched_elem and matched_elem.text != field_name:
|
||||
return ExtractedField(
|
||||
field_name=field_name,
|
||||
field_value=matched_elem.text,
|
||||
bbox=matched_elem.bbox,
|
||||
confidence=0.55,
|
||||
extraction_method="",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
# ========================================================================
|
||||
# 工具方法
|
||||
# ========================================================================
|
||||
|
||||
@staticmethod
|
||||
def _group_elements_by_rows(
|
||||
elements: list[OcrTextElement],
|
||||
) -> list[list[OcrTextElement]]:
|
||||
"""将元素按 Y 坐标分组为行(容差为元素平均高度的一半)。"""
|
||||
if not elements:
|
||||
return []
|
||||
|
||||
avg_height = sum(e.height for e in elements) / len(elements)
|
||||
tolerance = max(avg_height * 0.5, 5.0)
|
||||
|
||||
rows = []
|
||||
current_row = [elements[0]]
|
||||
|
||||
for elem in elements[1:]:
|
||||
prev_center_y = current_row[0].center_y
|
||||
if abs(elem.center_y - prev_center_y) < tolerance:
|
||||
current_row.append(elem)
|
||||
else:
|
||||
current_row.sort(key=lambda e: e.x_min)
|
||||
rows.append(current_row)
|
||||
current_row = [elem]
|
||||
|
||||
if current_row:
|
||||
current_row.sort(key=lambda e: e.x_min)
|
||||
rows.append(current_row)
|
||||
|
||||
return rows
|
||||
|
||||
@staticmethod
|
||||
def _text_similarity(text1: str, text2: str) -> float:
|
||||
"""计算两个文本的简单相似度(公共字符比例)。"""
|
||||
if not text1 or not text2:
|
||||
return 0.0
|
||||
|
||||
t1 = text1.lower().strip()
|
||||
t2 = text2.lower().strip()
|
||||
|
||||
if t1 == t2:
|
||||
return 1.0
|
||||
if t1 in t2 or t2 in t1:
|
||||
return 0.8
|
||||
|
||||
chars1 = set(t1)
|
||||
chars2 = set(t2)
|
||||
if not chars1:
|
||||
return 0.0
|
||||
|
||||
intersection = chars1 & chars2
|
||||
return len(intersection) / len(chars1)
|
||||
|
||||
|
||||
def extract_ocr_fields(
    file_path: str,
    target_fields: list[str],
    use_gpu: bool = False,
    confidence_threshold: float = 0.5,
) -> dict:
    """Convenience wrapper: run OCR field extraction on one image.

    Builds an ``OcrExtractor`` with the given settings and calls its
    ``extract`` method.

    Args:
        file_path: Path of the image file.
        target_fields: Names of the fields to extract.
        use_gpu: Passed to ``OcrExtractor``.
        confidence_threshold: Passed to ``OcrExtractor``.

    Returns:
        The extraction result dict.
    """
    settings = {
        "use_gpu": use_gpu,
        "confidence_threshold": confidence_threshold,
    }
    extractor = OcrExtractor(**settings)
    outcome = extractor.extract(file_path, target_fields)
    return outcome
|
||||
|
||||
|
||||
def extract_from_layout(
    layout_result: dict,
    target_fields: list[str],
    confidence_threshold: float = 0.5,
) -> dict:
    """Convenience wrapper: extract fields from a layout-analysis result.

    Args:
        layout_result: Return value of ``analyze_layout()``.
        target_fields: Names of the fields to extract.
        confidence_threshold: Passed to ``OcrExtractor``.
            NOTE(review): ``extract_from_layout_result`` does not read the
            threshold, so this value currently has no effect — confirm
            whether that is intended.

    Returns:
        The extraction result dict.
    """
    extractor = OcrExtractor(confidence_threshold=confidence_threshold)
    outcome = extractor.extract_from_layout_result(layout_result, target_fields)
    return outcome
|
||||
@@ -0,0 +1,543 @@
|
||||
"""OCR 单据字段提取器单元测试。
|
||||
|
||||
覆盖:
|
||||
- OcrTextElement / ExtractedField / ExtractionResult 数据结构
|
||||
- 四种提取策略的独立测试
|
||||
- 坐标计算正确性
|
||||
- 边界情况(空元素、无匹配、部分匹配)
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.ocr_extractor import (
|
||||
OcrTextElement,
|
||||
ExtractedField,
|
||||
ExtractionResult,
|
||||
OcrExtractor,
|
||||
extract_ocr_fields,
|
||||
extract_from_layout,
|
||||
)
|
||||
|
||||
|
||||
class TestOcrTextElement:
    """Tests for the OcrTextElement dataclass."""

    def test_bbox_property(self):
        # bbox lists the two corners in x_min, y_min, x_max, y_max order.
        box = OcrTextElement("发票代码", 100.0, 50.0, 250.0, 80.0).bbox
        assert box == [100.0, 50.0, 250.0, 80.0]

    def test_center_coordinates(self):
        corners = {"x_min": 100.0, "y_min": 50.0, "x_max": 200.0, "y_max": 100.0}
        elem = OcrTextElement(text="测试", **corners)
        assert elem.center_x == 150.0
        assert elem.center_y == 75.0

    def test_width_height(self):
        corners = {"x_min": 10.0, "y_min": 20.0, "x_max": 100.0, "y_max": 70.0}
        elem = OcrTextElement(text="测试", **corners)
        assert elem.width == 90.0
        assert elem.height == 50.0

    def test_default_confidence(self):
        # Confidence is optional and defaults to 1.0.
        elem = OcrTextElement("测试", 0, 0, 10, 10)
        assert elem.confidence == 1.0
|
||||
|
||||
|
||||
class TestExtractedField:
    """Tests for the ExtractedField dataclass."""

    def test_field_creation(self):
        # Every constructor argument must be stored unchanged.
        created = ExtractedField(
            "发票代码",
            "1234567890",
            [100, 50, 200, 80],
            0.95,
            "exact_match",
        )
        assert created.field_name == "发票代码"
        assert created.field_value == "1234567890"
        assert created.bbox == [100, 50, 200, 80]
        assert created.confidence == 0.95
        assert created.extraction_method == "exact_match"
|
||||
|
||||
|
||||
class TestExtractionResult:
    """Tests for ExtractionResult and its to_dict() serialisation."""

    def test_to_dict_basic(self):
        # A result without fields, elements or errors.
        d = ExtractionResult(
            file_path="/test/invoice.png",
            image_size=(800, 600),
            ocr_available=True,
        ).to_dict()
        assert d["file_path"] == "/test/invoice.png"
        assert d["image_size"] == (800, 600)
        assert d["ocr_available"] is True
        assert d["fields"] == []
        assert d["total_elements"] == 0
        assert d["errors"] == []

    def test_to_dict_with_fields(self):
        code_field = ExtractedField(
            field_name="发票代码",
            field_value="1234567890",
            bbox=[100, 50, 200, 80],
            confidence=0.95,
            extraction_method="exact_match",
        )
        result = ExtractionResult(
            file_path="/test/invoice.png",
            image_size=(800, 600),
            ocr_available=True,
        )
        result.fields.append(code_field)

        d = result.to_dict()
        assert len(d["fields"]) == 1
        first = d["fields"][0]
        assert first["field_name"] == "发票代码"
        assert first["field_value"] == "1234567890"
        assert first["bbox"] == [100, 50, 200, 80]
        assert first["confidence"] == 0.95
        assert first["extraction_method"] == "exact_match"

    def test_to_dict_with_error(self):
        # Errors passed to the constructor are serialised as-is.
        message = "文件不存在: /test/missing.png"
        d = ExtractionResult(
            file_path="/test/missing.png",
            image_size=(0, 0),
            errors=[message],
        ).to_dict()
        assert d["errors"] == ["文件不存在: /test/missing.png"]
|
||||
|
||||
|
||||
class TestElementGrouping:
    """Tests for OcrExtractor._group_elements_by_rows."""

    def test_group_single_row(self):
        # Three boxes sharing the same vertical span form one row.
        specs = [
            ("A", 10, 50, 50, 70),
            ("B", 60, 50, 110, 70),
            ("C", 120, 50, 170, 70),
        ]
        rows = OcrExtractor._group_elements_by_rows(
            [OcrTextElement(*spec) for spec in specs]
        )
        assert len(rows) == 1
        assert len(rows[0]) == 3

    def test_group_multiple_rows(self):
        # Rows at y=50, y=120 and y=200 with 2, 2 and 1 elements.
        specs = [
            ("A1", 10, 50, 50, 70),
            ("B1", 60, 50, 110, 70),
            ("A2", 10, 120, 50, 140),
            ("B2", 60, 120, 110, 140),
            ("A3", 10, 200, 50, 220),
        ]
        rows = OcrExtractor._group_elements_by_rows(
            [OcrTextElement(*spec) for spec in specs]
        )
        assert len(rows) == 3
        assert len(rows[0]) == 2
        assert len(rows[1]) == 2
        assert len(rows[2]) == 1

    def test_group_empty(self):
        rows = OcrExtractor._group_elements_by_rows([])
        assert rows == []

    def test_group_single_element(self):
        only = OcrTextElement("X", 10, 50, 50, 70)
        rows = OcrExtractor._group_elements_by_rows([only])
        assert len(rows) == 1
        assert len(rows[0]) == 1
|
||||
|
||||
|
||||
class TestTextSimilarity:
    """Tests for the text similarity computation."""

    def test_exact_match(self):
        """Identical strings score exactly 1.0."""
        assert OcrExtractor._text_similarity("发票代码", "发票代码") == 1.0

    def test_partial_match(self):
        """A trailing fragment of the first string scores above 0.5."""
        score = OcrExtractor._text_similarity("发票代码", "代码")
        assert score > 0.5

    def test_no_match(self):
        """Strings with nothing in common score exactly 0.0."""
        assert OcrExtractor._text_similarity("发票代码", "xyz") == 0.0

    def test_empty_strings(self):
        """Any empty operand produces a score of 0.0."""
        for left, right in (("", "abc"), ("abc", ""), ("", "")):
            assert OcrExtractor._text_similarity(left, right) == 0.0

    def test_substring_match(self):
        """A first string contained in the second scores above 0.7."""
        score = OcrExtractor._text_similarity("代码", "发票代码")
        assert score > 0.7
|
||||
|
||||
|
||||
class TestExactKVMatch:
    """Tests for strategy 1: exact key-value matching."""

    def test_colon_separator(self):
        """Key and value split by an ASCII colon are matched with 0.95 confidence."""
        matched = OcrExtractor()._exact_kv_match(
            "发票代码",
            [OcrTextElement("发票代码: 1234567890", 50, 50, 300, 80)],
        )

        assert matched is not None
        assert matched.field_value == "1234567890"
        assert matched.confidence == 0.95

    def test_chinese_colon(self):
        """Key and value split by a full-width colon are matched."""
        matched = OcrExtractor()._exact_kv_match(
            "发票号码",
            [OcrTextElement("发票号码:87654321", 50, 50, 300, 80)],
        )

        assert matched is not None
        assert matched.field_value == "87654321"

    def test_space_separator(self):
        """Key and value split by a single space are matched."""
        matched = OcrExtractor()._exact_kv_match(
            "合计金额",
            [OcrTextElement("合计金额 999.00", 50, 50, 300, 80)],
        )

        assert matched is not None
        assert matched.field_value == "999.00"

    def test_equals_separator(self):
        """Key and value split by an equals sign are matched."""
        matched = OcrExtractor()._exact_kv_match(
            "数量",
            [OcrTextElement("数量=5", 50, 50, 200, 80)],
        )

        assert matched is not None
        assert matched.field_value == "5"

    def test_field_not_found(self):
        """No match is returned when no element carries the requested key."""
        matched = OcrExtractor()._exact_kv_match(
            "发票代码",
            [OcrTextElement("其他内容: 123", 50, 50, 200, 80)],
        )

        assert matched is None
|
||||
|
||||
|
||||
class TestFuzzyKVMatch:
    """Tests for strategy 2: fuzzy key-value matching."""

    def test_adjacent_same_row(self):
        """A value element to the right of the key on the same row is picked up."""
        key_elem = OcrTextElement("发票代码", 50, 50, 150, 80)
        value_elem = OcrTextElement("1234567890", 200, 50, 350, 80)

        matched = OcrExtractor()._fuzzy_kv_match("发票代码", [key_elem, value_elem])

        assert matched is not None
        assert matched.field_value == "1234567890"
        assert matched.confidence == 0.75

    def test_adjacent_next_row(self):
        """A value element directly below the key is picked up."""
        key_elem = OcrTextElement("发票代码", 50, 50, 150, 80)
        value_elem = OcrTextElement("1234567890", 50, 100, 200, 130)

        matched = OcrExtractor()._fuzzy_kv_match("发票代码", [key_elem, value_elem])

        assert matched is not None
        assert matched.field_value == "1234567890"

    def test_field_name_not_found(self):
        """No match is returned when no element resembles the requested key."""
        unrelated_elem = OcrTextElement("其他信息", 50, 50, 150, 80)
        value_elem = OcrTextElement("1234567890", 200, 50, 350, 80)

        matched = OcrExtractor()._fuzzy_kv_match(
            "发票代码", [unrelated_elem, value_elem]
        )

        assert matched is None
|
||||
|
||||
|
||||
class TestRegexMatch:
    """Tests for strategy 3: regular-expression pattern matching."""

    def test_invoice_code_pattern(self):
        """A bare 10-digit string is accepted as the invoice code."""
        candidates = [OcrTextElement("1234567890", 100, 50, 250, 80)]

        found = OcrExtractor()._regex_match("发票代码", candidates)

        assert found is not None
        assert found.field_value == "1234567890"

    def test_invoice_number_pattern(self):
        """A bare 8-digit string is accepted as the invoice number."""
        candidates = [OcrTextElement("87654321", 100, 50, 200, 80)]

        found = OcrExtractor()._regex_match("发票号码", candidates)

        assert found is not None
        assert found.field_value == "87654321"

    def test_amount_pattern(self):
        """An amount with a thousands separator is kept in the extracted value."""
        candidates = [OcrTextElement("1,234.56", 100, 50, 200, 80)]

        found = OcrExtractor()._regex_match("合计金额", candidates)

        assert found is not None
        assert "1,234.56" in found.field_value

    def test_date_pattern(self):
        """A date written with 年/月/日 markers is found for the issue-date field."""
        candidates = [OcrTextElement("2024年1月15日", 100, 50, 250, 80)]

        found = OcrExtractor()._regex_match("开票日期", candidates)

        assert found is not None
        assert "2024" in found.field_value

    def test_unknown_field_no_pattern(self):
        """An unknown field name with unrelated text yields no match."""
        candidates = [OcrTextElement("随便什么内容", 100, 50, 250, 80)]

        found = OcrExtractor()._regex_match("未知字段", candidates)

        assert found is None
|
||||
|
||||
|
||||
class TestTableMatch:
    """Tests for strategy 4: table-structure matching."""

    def test_simple_table(self):
        """The cell below a matching header is returned as the field value."""
        header_row = [
            OcrTextElement("名称", 50, 50, 150, 80),
            OcrTextElement("数量", 200, 50, 300, 80),
            OcrTextElement("单价", 350, 50, 450, 80),
        ]
        data_row = [
            OcrTextElement("商品A", 50, 100, 150, 130),
            OcrTextElement("2", 200, 100, 250, 130),
            OcrTextElement("10.00", 350, 100, 420, 130),
        ]

        cell = OcrExtractor()._table_match("数量", header_row + data_row)

        assert cell is not None
        assert cell.field_value == "2"

    def test_table_with_fuzzy_header(self):
        """When a partially matching header is accepted, its column value is returned."""
        header_row = [
            OcrTextElement("品名", 50, 50, 120, 80),
            OcrTextElement("金额", 200, 50, 300, 80),
        ]
        data_row = [
            OcrTextElement("苹果", 50, 100, 120, 130),
            OcrTextElement("5.00", 200, 100, 260, 130),
        ]

        cell = OcrExtractor()._table_match("合计金额", header_row + data_row)

        # The "金额" header may or may not be accepted for "合计金额", so the
        # value is only verified when a match actually comes back.
        if cell:
            assert cell.field_value == "5.00"

    def test_table_header_not_found(self):
        """No match is returned when no header resembles the field name."""
        cells = [
            OcrTextElement("A", 50, 50, 100, 80),
            OcrTextElement("B", 150, 50, 200, 80),
        ]

        assert OcrExtractor()._table_match("发票代码", cells) is None

    def test_table_too_few_elements(self):
        """A single header element with no data beneath it yields no match."""
        header_only = [OcrTextElement("名称", 50, 50, 150, 80)]

        assert OcrExtractor()._table_match("名称", header_only) is None
|
||||
|
||||
|
||||
class TestCoordinateCorrectness:
    """Tests for the correctness of coordinate calculations."""

    def test_bbox_origin_top_left(self):
        """Verify that the coordinate system has its origin at the top-left corner."""
        at_origin = OcrTextElement("A", 0, 0, 50, 20)

        assert at_origin.x_min == 0
        assert at_origin.y_min == 0
        assert at_origin.bbox[0] == 0
        assert at_origin.bbox[1] == 0

    def test_bbox_conversion(self):
        """Verify the (x, y, w, h) to [x_min, y_min, x_max, y_max] conversion."""
        left, top, width, height = 100, 200, 300, 50

        box = OcrTextElement(
            "test",
            x_min=left,
            y_min=top,
            x_max=left + width,
            y_max=top + height,
        )

        assert box.bbox == [left, top, left + width, top + height]
        assert box.x_max - box.x_min == width
        assert box.y_max - box.y_min == height

    def test_multiple_elements_coordinate_independence(self):
        """Verify that coordinates of separate elements do not affect each other."""
        first = OcrTextElement("A", 10, 20, 60, 40)
        second = OcrTextElement("B", 100, 200, 180, 240)

        assert first.bbox == [10, 20, 60, 40]
        assert second.bbox == [100, 200, 180, 240]
        assert first.x_min != second.x_min

    def test_center_calculation(self):
        """Verify the center-point calculation."""
        square = OcrTextElement("test", 0, 0, 100, 100)

        assert square.center_x == 50.0
        assert square.center_y == 50.0
|
||||
|
||||
|
||||
class TestExtractionPipeline:
    """Tests for the complete extraction pipeline."""

    def test_priority_order_exact_first(self):
        """Verify strategy priority: exact_match wins over the other strategies."""
        inline_pair = [OcrTextElement("发票代码: 1234567890", 50, 50, 300, 80)]

        picked = OcrExtractor()._extract_field("发票代码", inline_pair)

        assert picked is not None
        assert picked.extraction_method == "exact_match"

    def test_fallback_to_fuzzy(self):
        """Verify the fallback to fuzzy matching after exact matching fails."""
        split_pair = [
            OcrTextElement("发票代码", 50, 50, 150, 80),
            OcrTextElement("1234567890", 200, 50, 350, 80),
        ]

        picked = OcrExtractor()._extract_field("发票代码", split_pair)

        assert picked is not None
        assert picked.extraction_method == "kv_pair"

    def test_all_fields_empty(self):
        """Verify that nothing is returned when no strategy extracts the field."""
        noise = [
            OcrTextElement("一些不相关的文本", 50, 50, 300, 80),
            OcrTextElement("更多随机内容", 50, 100, 300, 130),
        ]

        assert OcrExtractor()._extract_field("发票代码", noise) is None

    def test_empty_elements_list(self):
        """Verify that an empty element list is handled without error."""
        assert OcrExtractor()._extract_field("发票代码", []) is None

    def test_extraction_with_confidence(self):
        """Verify that the extracted confidence lies within a sensible range."""
        inline_pair = [OcrTextElement("发票代码: 1234567890", 50, 50, 300, 80)]

        picked = OcrExtractor()._extract_field("发票代码", inline_pair)

        assert picked is not None
        score = picked.confidence
        assert 0.0 < score <= 1.0
|
||||
|
||||
|
||||
class TestExtractFromLayout:
    """Tests for extracting fields from a layout result."""

    def test_basic_layout_extraction(self):
        """A two-row layout is processed and its image size is echoed back."""
        code_row = {
            "y_center": 50,
            "elements": [
                {"x": 50, "y": 30, "w": 150, "h": 30, "text": "发票代码:"},
                {"x": 250, "y": 30, "w": 200, "h": 30, "text": "1234567890"},
            ],
        }
        amount_row = {
            "y_center": 80,
            "elements": [
                {"x": 50, "y": 60, "w": 150, "h": 30, "text": "合计金额:"},
                {"x": 250, "y": 60, "w": 100, "h": 30, "text": "999.00"},
            ],
        }
        layout = {
            "image_size": (800, 600),
            "template_type": "full_a4",
            "rows": [code_row, amount_row],
        }

        outcome = extract_from_layout(layout, ["发票代码"])

        assert outcome["ocr_available"] is True
        assert outcome["image_size"] == (800, 600)

    def test_empty_layout(self):
        """An empty layout dict is reported as unavailable with at least one error."""
        outcome = extract_from_layout({}, ["发票代码"])

        assert outcome["ocr_available"] is False
        assert len(outcome["errors"]) > 0
|
||||
|
||||
|
||||
class TestOcrExtractorFileNotFound:
    """Tests for the case where the input file does not exist."""

    def test_missing_file(self):
        """A nonexistent path yields an unavailable result with a file-missing error."""
        outcome = OcrExtractor().extract("/nonexistent/file.png", ["发票代码"])

        assert outcome["ocr_available"] is False

        reported = outcome["errors"]
        assert len(reported) > 0
        assert "文件不存在" in reported[0]
|
||||
|
||||
|
||||
class TestConvenienceFunctions:
    """Tests for the module-level convenience functions."""

    def test_extract_ocr_fields_missing_file(self):
        """``extract_ocr_fields`` reports at least one error for a nonexistent path."""
        outcome = extract_ocr_fields("/nonexistent/file.png", ["发票代码"])

        assert len(outcome["errors"]) > 0

    def test_extract_from_layout_with_partial_rows(self):
        """A single-row layout with a separate key and value yields one field."""
        only_row = {
            "y_center": 100,
            "elements": [
                {"x": 100, "y": 80, "w": 120, "h": 40, "text": "发票代码"},
                {"x": 300, "y": 80, "w": 180, "h": 40, "text": "NO123456"},
            ],
        }
        layout = {
            "image_size": (1200, 800),
            "template_type": "partial_rows",
            "rows": [only_row],
        }

        outcome = extract_from_layout(layout, ["发票代码"])

        assert outcome["ocr_available"] is True

        extracted = outcome["fields"]
        assert len(extracted) == 1
        assert extracted[0]["field_name"] == "发票代码"
        assert extracted[0]["field_value"] == "NO123456"
|
||||
Reference in New Issue
Block a user