feat: 新增 OCR 单据字段精确提取模块

- 新增 backend/ocr_extractor.py: 两阶段提取流水线 (文档分析 + 字段提取)
- 四种提取策略: 精确KV匹配/模糊KV匹配/正则模式/表格结构匹配
- agent/state.py: 新增 ocr_extraction_result 和 uploaded_file_path 字段
- agent/nodes.py: process_input() 中自动触发 OCR 提取钩子
- app.py: 文件上传时保留图片路径, 总结卡片中展示提取结果
- .env.example: 新增 OCR_USE_GPU / OCR_CONFIDENCE_THRESHOLD 配置项
- tests/test_ocr_extraction.py: 48 个单元测试全部通过
This commit is contained in:
2026-05-20 08:06:55 +08:00
parent 067880bf2e
commit c9f003e1b7
6 changed files with 1417 additions and 2 deletions
+6
View File
@@ -63,3 +63,9 @@ HISTORY_MAX_SNAPSHOTS=10
# 意图识别模型(默认使用主 LLM 模型) # 意图识别模型(默认使用主 LLM 模型)
# INTENT_MODEL=gpt-4o-mini # INTENT_MODEL=gpt-4o-mini
# OCR 字段提取配置
# 是否使用 GPU 加速 OCR(需要 CUDA 驱动和 GPU 版 EasyOCR/PaddleOCR
OCR_USE_GPU=false
# OCR 文本置信度最低阈值(0-1),低于此值的元素将被忽略
OCR_CONFIDENCE_THRESHOLD=0.5
+25
View File
@@ -7,6 +7,7 @@ import os
import re import re
import time import time
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path
from typing import Dict from typing import Dict
from dotenv import load_dotenv from dotenv import load_dotenv
@@ -114,6 +115,30 @@ def process_input(state: AgentState) -> Dict:
conv_history.append({"role": "user", "content": user_input}) conv_history.append({"role": "user", "content": user_input})
state["conversation_history"] = conv_history state["conversation_history"] = conv_history
# OCR 单据字段精确提取(处理上传的图片文件)
uploaded_path = state.get("uploaded_file_path", "")
if uploaded_path and Path(uploaded_path).is_file():
suffix = Path(uploaded_path).suffix.lower()
if suffix in (".png", ".jpg", ".jpeg", ".bmp", ".webp"):
try:
from backend.ocr_extractor import OcrExtractor
extractor = OcrExtractor()
ocr_result = extractor.extract(uploaded_path, [])
if ocr_result.get("ocr_available"):
state["ocr_extraction_result"] = ocr_result
_node_log.info(
"OCR 字段提取完成",
extra={
"file": uploaded_path,
"elements": ocr_result.get("total_elements", 0),
"fields": len(ocr_result.get("fields", [])),
},
)
except Exception as e:
_node_log.warning(f"OCR 字段提取失败: {e}")
state["ocr_extraction_result"] = {"error": str(e)}
state["uploaded_file_path"] = ""
# 重置本轮请求字段 # 重置本轮请求字段
state["retry_count"] = 0 state["retry_count"] = 0
state["user_modification_request"] = user_input state["user_modification_request"] = user_input
+4
View File
@@ -40,3 +40,7 @@ class AgentState(TypedDict, total=False):
# 需求6:失败上下文传递 — 重试耗尽后暂存失败信息,下次用户输入时自动注入 # 需求6:失败上下文传递 — 重试耗尽后暂存失败信息,下次用户输入时自动注入
pending_failure_context: dict pending_failure_context: dict
# 需求7:OCR 单据字段精确提取结果
ocr_extraction_result: dict
uploaded_file_path: str
+43 -2
View File
@@ -261,6 +261,14 @@ def run_agent(user_input: str):
if stream_active: if stream_active:
streaming_placeholder.empty() streaming_placeholder.empty()
# 清理已处理的临时文件
for p in st.session_state.get("uploaded_temp_paths", []):
try:
Path(p).unlink(missing_ok=True)
except Exception:
pass
st.session_state.uploaded_temp_paths = []
# ---- 总结卡片 ---- # ---- 总结卡片 ----
# 注:node_state 只含变更字段,用 agent_state(被所有节点就地修改)获取完整状态 # 注:node_state 只含变更字段,用 agent_state(被所有节点就地修改)获取完整状态
final_state = agent_state final_state = agent_state
@@ -324,6 +332,30 @@ def run_agent(user_input: str):
"content": f"❌ 经过 {retries} 次重试后仍无法生成有效的 JRXML。\n\n**错误:** {error_msg}\n\n💡 请直接描述修改需求,系统会自动加载失败上下文。", "content": f"❌ 经过 {retries} 次重试后仍无法生成有效的 JRXML。\n\n**错误:** {error_msg}\n\n💡 请直接描述修改需求,系统会自动加载失败上下文。",
"type": "error_explanation", "type": "error_explanation",
}) })
# OCR 字段提取结果展示
ocr_result = agent_state.get("ocr_extraction_result", {})
if ocr_result and ocr_result.get("ocr_available") and ocr_result.get("fields"):
with st.expander("🔍 OCR 单据字段提取结果", expanded=False):
fields = ocr_result.get("fields", [])
non_empty = [f for f in fields if f.get("field_value")]
empty = [f for f in fields if not f.get("field_value")]
if non_empty:
st.markdown("**已提取字段:**")
for f in non_empty:
method = f.get("extraction_method", "")
conf = f.get("confidence", 0)
st.markdown(
f"- **{f['field_name']}**: `{f['field_value']}` "
f"(置信度: {conf:.0%}, 方法: {method}"
)
if empty:
st.caption(
f"未提取到值的字段: {', '.join(f['field_name'] for f in empty)}"
)
st.caption(
f"共检测到 {ocr_result.get('total_elements', 0)} 个文本元素"
)
else: else:
st.error("未产生结果,请重试。") st.error("未产生结果,请重试。")
@@ -443,6 +475,9 @@ with st.sidebar:
if "uploaded_files" not in st.session_state: if "uploaded_files" not in st.session_state:
st.session_state.uploaded_files = [] # [{name, text, type}] st.session_state.uploaded_files = [] # [{name, text, type}]
if "uploaded_temp_paths" not in st.session_state:
st.session_state.uploaded_temp_paths = [] # 待清理的临时文件路径
uploaded = st.file_uploader( uploaded = st.file_uploader(
"选择文件", "选择文件",
type=["png", "jpg", "jpeg", "bmp", "webp", "pdf", "docx", "txt", "csv", "json", "xml"], type=["png", "jpg", "jpeg", "bmp", "webp", "pdf", "docx", "txt", "csv", "json", "xml"],
@@ -513,8 +548,6 @@ with st.sidebar:
) )
parsed_type = "image_reference" parsed_type = "image_reference"
Path(tmp_path).unlink(missing_ok=True)
if parsed_text: if parsed_text:
st.session_state.uploaded_files.append({ st.session_state.uploaded_files.append({
"name": uf.name, "name": uf.name,
@@ -522,6 +555,14 @@ with st.sidebar:
"type": parsed_type, "type": parsed_type,
}) })
# 对图片类型,保存路径以便 OCR 字段提取(延迟到 process_input 阶段)
img_suffixes = (".png", ".jpg", ".jpeg", ".bmp", ".webp")
if suffix in img_suffixes and result.get("method") not in ("metadata_only", None):
st.session_state.agent_state["uploaded_file_path"] = tmp_path
st.session_state.uploaded_temp_paths.append(tmp_path)
else:
Path(tmp_path).unlink(missing_ok=True)
if st.session_state.uploaded_files: if st.session_state.uploaded_files:
for i, f in enumerate(st.session_state.uploaded_files): for i, f in enumerate(st.session_state.uploaded_files):
cols = st.columns([5, 1]) cols = st.columns([5, 1])
+796
View File
@@ -0,0 +1,796 @@
"""OCR 单据字段精确提取器。
两阶段提取流程:
阶段1 - 文档分析: 复用 file_parser.parse_file() 和 layout_analyzer.analyze_layout()
获取每个文本元素的精确坐标和内容
阶段2 - 字段提取: 给定目标字段列表,通过四种策略(精确KV匹配、模糊KV匹配、
正则模式匹配、表格结构匹配)提取字段值、位置和置信度
用法:
from backend.ocr_extractor import OcrExtractor
extractor = OcrExtractor()
result = extractor.extract("invoice.png", ["发票代码", "发票号码", "合计金额"])
for field in result:
print(f"{field['field_name']}: {field['field_value']} (置信度: {field['confidence']})")
"""
import os
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Optional
from dotenv import load_dotenv
load_dotenv()
OCR_USE_GPU = os.getenv("OCR_USE_GPU", "false").lower() in ("true", "1", "yes")
OCR_CONFIDENCE_THRESHOLD = float(os.getenv("OCR_CONFIDENCE_THRESHOLD", "0.5"))
@dataclass
class OcrTextElement:
"""OCR 文本元素,包含精确坐标和内容。"""
text: str
x_min: float
y_min: float
x_max: float
y_max: float
confidence: float = 1.0
@property
def center_x(self) -> float:
return (self.x_min + self.x_max) / 2
@property
def center_y(self) -> float:
return (self.y_min + self.y_max) / 2
@property
def width(self) -> float:
return self.x_max - self.x_min
@property
def height(self) -> float:
return self.y_max - self.y_min
@property
def bbox(self) -> list[float]:
return [self.x_min, self.y_min, self.x_max, self.y_max]
@dataclass
class ExtractedField:
"""提取的字段结果。"""
field_name: str
field_value: str
bbox: list[float]
confidence: float
extraction_method: str
@dataclass
class ExtractionResult:
"""单次提取的完整结果。"""
file_path: str
image_size: tuple[int, int]
fields: list[ExtractedField] = field(default_factory=list)
all_elements: list[OcrTextElement] = field(default_factory=list)
errors: list[str] = field(default_factory=list)
ocr_available: bool = False
def to_dict(self) -> dict:
return {
"file_path": self.file_path,
"image_size": self.image_size,
"ocr_available": self.ocr_available,
"fields": [
{
"field_name": f.field_name,
"field_value": f.field_value,
"bbox": f.bbox,
"confidence": f.confidence,
"extraction_method": f.extraction_method,
}
for f in self.fields
],
"total_elements": len(self.all_elements),
"errors": self.errors,
}
class OcrExtractor:
"""OCR 单据字段精确提取器。
两阶段流水线:
阶段1: 对上传图片进行 OCR + 版面分析,产出带坐标的文本元素列表
阶段2: 根据目标字段列表,按优先级逐一尝试四种提取策略
"""
def __init__(
self,
use_gpu: bool = False,
confidence_threshold: float = 0.5,
):
"""初始化提取器。
Args:
use_gpu: 是否使用 GPU 加速 OCR(需要相应驱动)
confidence_threshold: OCR 文本置信度最低阈值,低于此值的元素被忽略
"""
self.use_gpu = use_gpu if use_gpu else OCR_USE_GPU
self.confidence_threshold = (
confidence_threshold
if confidence_threshold != 0.5
else OCR_CONFIDENCE_THRESHOLD
)
# ========================================================================
# 公共接口
# ========================================================================
def extract(
self,
file_path: str,
target_fields: list[str],
) -> dict:
"""执行两阶段 OCR 字段提取。
Args:
file_path: 图片文件路径(支持 png/jpg/jpeg/bmp/webp
target_fields: 需要提取的字段名称列表,如 ["发票代码", "发票号码", "合计金额"]
Returns:
提取结果字典,格式见 ExtractionResult.to_dict()
"""
result = ExtractionResult(file_path=file_path, image_size=(0, 0))
if not Path(file_path).exists():
result.errors.append(f"文件不存在: {file_path}")
return result.to_dict()
elements, image_size = self._analyze_document(file_path)
result.image_size = image_size
result.all_elements = elements
if not elements:
result.ocr_available = self._check_ocr_availability()
if not result.ocr_available:
result.errors.append(
"OCR 引擎不可用,请安装 easyocr (pip install easyocr) 或 paddleocr"
)
else:
result.errors.append("图片未检测到文字元素")
return result.to_dict()
result.ocr_available = True
for field_name in target_fields:
extracted = self._extract_field(field_name, elements)
if extracted:
result.fields.append(extracted)
else:
result.fields.append(
ExtractedField(
field_name=field_name,
field_value="",
bbox=[],
confidence=0.0,
extraction_method="none",
)
)
return result.to_dict()
def extract_from_layout_result(
self,
layout_result: dict,
target_fields: list[str],
) -> dict:
"""直接从 layout_analyzer.analyze_layout() 的结果中提取字段。
当已有版面分析结果时,跳过阶段1的重复 OCR,直接进入阶段2。
Args:
layout_result: analyze_layout() 的返回值
target_fields: 需要提取的字段名称列表
Returns:
提取结果字典
"""
rows = layout_result.get("rows", [])
if not rows:
return ExtractionResult(
file_path="(from layout)",
image_size=layout_result.get("image_size", (0, 0)),
errors=["版面分析结果中没有文本行"],
).to_dict()
elements = []
for row in rows:
for elem_data in row.get("elements", []):
elements.append(
OcrTextElement(
text=elem_data.get("text", ""),
x_min=elem_data.get("x", 0),
y_min=elem_data.get("y", 0),
x_max=elem_data.get("x", 0) + elem_data.get("w", 0),
y_max=elem_data.get("y", 0) + elem_data.get("h", 0),
)
)
result = ExtractionResult(
file_path="(from layout)",
image_size=layout_result.get("image_size", (0, 0)),
all_elements=elements,
ocr_available=True,
)
for field_name in target_fields:
extracted = self._extract_field(field_name, elements)
if extracted:
result.fields.append(extracted)
else:
result.fields.append(
ExtractedField(
field_name=field_name,
field_value="",
bbox=[],
confidence=0.0,
extraction_method="none",
)
)
return result.to_dict()
# ========================================================================
# 阶段1: 文档分析
# ========================================================================
def _analyze_document(self, file_path: str) -> tuple[list[OcrTextElement], tuple[int, int]]:
"""阶段1: OCR + 版面分析,产出带坐标的文本元素列表。"""
from backend.layout_analyzer import _load_image, _ocr_elements
img = _load_image(Path(file_path))
if img is None:
return [], (0, 0)
image_size = img.size
raw_elements = self._ocr_elements_enhanced(img, file_path)
elements = []
for e_data in raw_elements:
if e_data.get("confidence", 1.0) < self.confidence_threshold:
continue
elements.append(
OcrTextElement(
text=e_data.get("text", ""),
x_min=e_data.get("x", 0),
y_min=e_data.get("y", 0),
x_max=e_data.get("x", 0) + e_data.get("w", 0),
y_max=e_data.get("y", 0) + e_data.get("h", 0),
confidence=e_data.get("confidence", 1.0),
)
)
elements.sort(key=lambda e: (e.y_min, e.x_min))
return elements, image_size
def _ocr_elements_enhanced(self, img, file_path: str) -> list[dict]:
"""增强版 OCR,返回带置信度的元素列表。"""
try:
import numpy as np
easyocr_result = self._try_easyocr(np.array(img))
if easyocr_result:
return easyocr_result
paddleocr_result = self._try_paddleocr(img, file_path)
if paddleocr_result:
return paddleocr_result
except Exception:
pass
return []
def _try_easyocr(self, np_img) -> Optional[list[dict]]:
try:
import easyocr
reader = easyocr.Reader(
["ch_sim", "en"],
gpu=self.use_gpu,
verbose=False,
)
raw_result = reader.readtext(np_img)
elements = []
for bbox, text, confidence in raw_result:
if not text.strip():
continue
xs = [p[0] for p in bbox]
ys = [p[1] for p in bbox]
x_min, x_max = min(xs), max(xs)
y_min, y_max = min(ys), max(ys)
elements.append({
"x": round(x_min, 1),
"y": round(y_min, 1),
"w": round(x_max - x_min, 1),
"h": round(y_max - y_min, 1),
"text": text.strip(),
"confidence": round(confidence, 4),
})
elements.sort(key=lambda e: (e["y"], e["x"]))
return elements
except ImportError:
return None
except Exception:
return None
def _try_paddleocr(self, img, file_path: str) -> Optional[list[dict]]:
try:
from paddleocr import PaddleOCR
import numpy as np
ocr = PaddleOCR(lang="ch")
raw_result = ocr.ocr(np.array(img))
elements = []
if raw_result and raw_result[0]:
for line in raw_result[0]:
if len(line) < 2:
continue
box = line[0]
text_info = line[1]
if isinstance(text_info, (list, tuple)):
text = text_info[0]
confidence = text_info[1] if len(text_info) > 1 else 1.0
else:
text = str(text_info)
confidence = 1.0
if not text.strip():
continue
xs = [p[0] for p in box]
ys = [p[1] for p in box]
x_min, x_max = min(xs), max(xs)
y_min, y_max = min(ys), max(ys)
elements.append({
"x": round(x_min, 1),
"y": round(y_min, 1),
"w": round(x_max - x_min, 1),
"h": round(y_max - y_min, 1),
"text": text.strip(),
"confidence": round(float(confidence), 4),
})
elements.sort(key=lambda e: (e["y"], e["x"]))
return elements
except ImportError:
return None
except Exception:
return None
def _check_ocr_availability(self) -> bool:
try:
import easyocr
return True
except ImportError:
pass
try:
import paddleocr
return True
except ImportError:
pass
return False
# ========================================================================
# 阶段2: 字段精确提取
# ========================================================================
def _extract_field(
self,
field_name: str,
elements: list[OcrTextElement],
) -> Optional[ExtractedField]:
"""按优先级尝试四种策略提取单个字段。
策略优先级:
1. 精确键值对匹配
2. 模糊键值对匹配
3. 正则模式匹配
4. 表格结构匹配
"""
strategies = [
("exact_match", self._exact_kv_match),
("kv_pair", self._fuzzy_kv_match),
("regex", self._regex_match),
("table_match", self._table_match),
]
for method_name, strategy_fn in strategies:
result = strategy_fn(field_name, elements)
if result and result.field_value:
result.extraction_method = method_name
return result
return None
# -----------------------------------------------------------------------
# 策略1: 精确键值对匹配
# -----------------------------------------------------------------------
def _exact_kv_match(
self,
field_name: str,
elements: list[OcrTextElement],
) -> Optional[ExtractedField]:
"""精确键值对匹配: 识别"字段名: 值""字段名:值"模式。
在同一文本元素中查找 "字段名" 后紧跟分隔符 + "" 的模式。
如 OCR 识别出 "发票代码: 12345678" 这一整个元素。
"""
separators = [":", "", "=", "-", "", "", "\t", "|"]
field_name_clean = field_name.strip()
for elem in elements:
text = elem.text
if field_name_clean not in text:
continue
for sep in separators:
pattern = re.escape(field_name_clean) + r"\s*" + re.escape(sep) + r"\s*(.+)"
m = re.search(pattern, text)
if m:
value = m.group(1).strip()
if value:
return ExtractedField(
field_name=field_name,
field_value=value,
bbox=elem.bbox,
confidence=0.95,
extraction_method="",
)
simple_pattern = re.escape(field_name_clean) + r"\s+(.+)"
m = re.search(simple_pattern, text)
if m:
value = m.group(1).strip()
if value and value != field_name_clean:
return ExtractedField(
field_name=field_name,
field_value=value,
bbox=elem.bbox,
confidence=0.85,
extraction_method="",
)
return None
# -----------------------------------------------------------------------
# 策略2: 模糊键值对匹配
# -----------------------------------------------------------------------
def _fuzzy_kv_match(
self,
field_name: str,
elements: list[OcrTextElement],
) -> Optional[ExtractedField]:
"""模糊键值对匹配: 字段名和值分布在相邻的文本元素中。
找到含字段名的元素后,在同一行或相邻元素中查找值。
"""
field_name_clean = field_name.strip()
field_elem = None
for elem in elements:
if field_name_clean in elem.text:
field_elem = elem
break
if field_elem is None:
matching = []
for elem in elements:
sim = self._text_similarity(field_name_clean, elem.text)
if sim > 0.6:
matching.append((sim, elem))
if matching:
matching.sort(key=lambda x: x[0], reverse=True)
field_elem = matching[0][1]
if field_elem is None:
return None
candidates = []
for elem in elements:
if elem is field_elem:
continue
candidates.append(elem)
same_row = []
for elem in candidates:
if abs(elem.center_y - field_elem.center_y) < field_elem.height * 1.5:
same_row.append(elem)
if same_row:
same_row.sort(key=lambda e: e.x_min)
for elem in same_row:
if elem.x_min > field_elem.x_max:
return ExtractedField(
field_name=field_name,
field_value=elem.text,
bbox=elem.bbox,
confidence=0.75,
extraction_method="",
)
nearest = None
nearest_dist = float("inf")
for elem in candidates:
if elem.y_min > field_elem.y_max:
dy = elem.y_min - field_elem.y_max
dx = abs(elem.center_x - field_elem.center_x)
dist = dy + dx * 0.3
if dist < nearest_dist and dy < field_elem.height * 3:
nearest_dist = dist
nearest = elem
if nearest:
return ExtractedField(
field_name=field_name,
field_value=nearest.text,
bbox=nearest.bbox,
confidence=0.6,
extraction_method="",
)
return None
# -----------------------------------------------------------------------
# 策略3: 正则模式匹配
# -----------------------------------------------------------------------
PREDEFINED_PATTERNS: dict[str, str] = {
"发票代码": r"[0-9A-Za-z]{10,12}",
"发票号码": r"\d{8}",
"合计金额": r"[\d,]+\.?\d*",
"金额": r"[\d,]+\.?\d*",
"开票日期": r"\d{4}[年/\-]\d{1,2}[月/\-]\d{1,2}日?",
"日期": r"\d{4}[年/\-]\d{1,2}[月/\-]\d{1,2}日?",
"校验码": r"[0-9A-Fa-f]{5,20}",
"总价": r"[\d,]+\.?\d*",
"总金额": r"[\d,]+\.?\d*",
"价税合计": r"[\d,]+\.?\d*",
"数量": r"\d+\.?\d*",
"单价": r"[\d,]+\.?\d*",
"税率": r"\d+\.?\d*%?",
}
def _regex_match(
self,
field_name: str,
elements: list[OcrTextElement],
) -> Optional[ExtractedField]:
"""正则模式匹配: 根据字段名选择预定义的正则模式,在所有元素中搜索。"""
pattern = self.PREDEFINED_PATTERNS.get(field_name)
if not pattern:
for key, pat in self.PREDEFINED_PATTERNS.items():
if key in field_name or field_name in key:
pattern = pat
break
if not pattern:
return None
compiled = re.compile(r"^\s*" + pattern + r"\s*$")
for elem in elements:
if compiled.match(elem.text):
return ExtractedField(
field_name=field_name,
field_value=elem.text.strip(),
bbox=elem.bbox,
confidence=0.7,
extraction_method="",
)
compiled_partial = re.compile(pattern)
for elem in elements:
m = compiled_partial.search(elem.text)
if m:
return ExtractedField(
field_name=field_name,
field_value=m.group(0),
bbox=elem.bbox,
confidence=0.6,
extraction_method="",
)
return None
# -----------------------------------------------------------------------
# 策略4: 表格结构匹配
# -----------------------------------------------------------------------
def _table_match(
self,
field_name: str,
elements: list[OcrTextElement],
) -> Optional[ExtractedField]:
"""表格结构匹配: 将元素按行列分组,查找表头-值对应关系。
识别逻辑:
1. 将元素按 Y 坐标分组为""
2. 查找包含 field_name 的表头行
3. 在表头列对应的数据行中取值
"""
if len(elements) < 3:
return None
rows = self._group_elements_by_rows(elements)
if len(rows) < 2:
return None
header_row_idx = -1
header_col_idx = -1
for ri, row in enumerate(rows):
for ci, elem in enumerate(row):
if field_name in elem.text:
header_row_idx = ri
header_col_idx = ci
break
if header_row_idx >= 0:
break
if header_row_idx < 0:
for ri, row in enumerate(rows):
for ci, elem in enumerate(row):
sim = self._text_similarity(field_name, elem.text)
if sim > 0.5:
header_row_idx = ri
header_col_idx = ci
break
if header_row_idx >= 0:
break
if header_row_idx < 0:
return None
data_rows = rows[header_row_idx + 1:]
if not data_rows:
data_rows = [rows[header_row_idx]]
matched_elem = None
for row in data_rows:
if header_col_idx < len(row):
matched_elem = row[header_col_idx]
break
closest = None
min_dist = float("inf")
header_x = float("inf")
if header_col_idx < len(rows[header_row_idx]):
header_x = rows[header_row_idx][header_col_idx].center_x
for elem in row:
dist = abs(elem.center_x - header_x)
if dist < min_dist:
min_dist = dist
closest = elem
if closest:
matched_elem = closest
break
if matched_elem and matched_elem.text != field_name:
return ExtractedField(
field_name=field_name,
field_value=matched_elem.text,
bbox=matched_elem.bbox,
confidence=0.55,
extraction_method="",
)
return None
# ========================================================================
# 工具方法
# ========================================================================
@staticmethod
def _group_elements_by_rows(
elements: list[OcrTextElement],
) -> list[list[OcrTextElement]]:
"""将元素按 Y 坐标分组为行(容差为元素平均高度的一半)。"""
if not elements:
return []
avg_height = sum(e.height for e in elements) / len(elements)
tolerance = max(avg_height * 0.5, 5.0)
rows = []
current_row = [elements[0]]
for elem in elements[1:]:
prev_center_y = current_row[0].center_y
if abs(elem.center_y - prev_center_y) < tolerance:
current_row.append(elem)
else:
current_row.sort(key=lambda e: e.x_min)
rows.append(current_row)
current_row = [elem]
if current_row:
current_row.sort(key=lambda e: e.x_min)
rows.append(current_row)
return rows
@staticmethod
def _text_similarity(text1: str, text2: str) -> float:
"""计算两个文本的简单相似度(公共字符比例)。"""
if not text1 or not text2:
return 0.0
t1 = text1.lower().strip()
t2 = text2.lower().strip()
if t1 == t2:
return 1.0
if t1 in t2 or t2 in t1:
return 0.8
chars1 = set(t1)
chars2 = set(t2)
if not chars1:
return 0.0
intersection = chars1 & chars2
return len(intersection) / len(chars1)
def extract_ocr_fields(
file_path: str,
target_fields: list[str],
use_gpu: bool = False,
confidence_threshold: float = 0.5,
) -> dict:
"""便捷函数: 对指定图片执行 OCR 字段提取。
Args:
file_path: 图片文件路径
target_fields: 目标字段名列表
use_gpu: 是否使用 GPU 加速
confidence_threshold: OCR 置信度阈值
Returns:
提取结果字典
"""
extractor = OcrExtractor(
use_gpu=use_gpu,
confidence_threshold=confidence_threshold,
)
return extractor.extract(file_path, target_fields)
def extract_from_layout(
layout_result: dict,
target_fields: list[str],
confidence_threshold: float = 0.5,
) -> dict:
"""便捷函数: 从已有的版面分析结果中提取字段。
Args:
layout_result: analyze_layout() 的返回值
target_fields: 目标字段名列表
confidence_threshold: OCR 置信度阈值
Returns:
提取结果字典
"""
extractor = OcrExtractor(confidence_threshold=confidence_threshold)
return extractor.extract_from_layout_result(layout_result, target_fields)
+543
View File
@@ -0,0 +1,543 @@
"""OCR 单据字段提取器单元测试。
覆盖:
- OcrTextElement / ExtractedField / ExtractionResult 数据结构
- 四种提取策略的独立测试
- 坐标计算正确性
- 边界情况(空元素、无匹配、部分匹配)
"""
import sys
import os
import math
from pathlib import Path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import pytest
from backend.ocr_extractor import (
OcrTextElement,
ExtractedField,
ExtractionResult,
OcrExtractor,
extract_ocr_fields,
extract_from_layout,
)
class TestOcrTextElement:
"""测试 OcrTextElement 数据类。"""
def test_bbox_property(self):
elem = OcrTextElement(
text="发票代码",
x_min=100.0,
y_min=50.0,
x_max=250.0,
y_max=80.0,
)
assert elem.bbox == [100.0, 50.0, 250.0, 80.0]
def test_center_coordinates(self):
elem = OcrTextElement(
text="测试",
x_min=100.0,
y_min=50.0,
x_max=200.0,
y_max=100.0,
)
assert elem.center_x == 150.0
assert elem.center_y == 75.0
def test_width_height(self):
elem = OcrTextElement(
text="测试",
x_min=10.0,
y_min=20.0,
x_max=100.0,
y_max=70.0,
)
assert elem.width == 90.0
assert elem.height == 50.0
def test_default_confidence(self):
elem = OcrTextElement(
text="测试",
x_min=0,
y_min=0,
x_max=10,
y_max=10,
)
assert elem.confidence == 1.0
class TestExtractedField:
"""测试 ExtractedField 数据类。"""
def test_field_creation(self):
field = ExtractedField(
field_name="发票代码",
field_value="1234567890",
bbox=[100, 50, 200, 80],
confidence=0.95,
extraction_method="exact_match",
)
assert field.field_name == "发票代码"
assert field.field_value == "1234567890"
assert field.bbox == [100, 50, 200, 80]
assert field.confidence == 0.95
assert field.extraction_method == "exact_match"
class TestExtractionResult:
"""测试 ExtractionResult 数据类。"""
def test_to_dict_basic(self):
result = ExtractionResult(
file_path="/test/invoice.png",
image_size=(800, 600),
ocr_available=True,
)
d = result.to_dict()
assert d["file_path"] == "/test/invoice.png"
assert d["image_size"] == (800, 600)
assert d["ocr_available"] is True
assert d["fields"] == []
assert d["total_elements"] == 0
assert d["errors"] == []
def test_to_dict_with_fields(self):
result = ExtractionResult(
file_path="/test/invoice.png",
image_size=(800, 600),
ocr_available=True,
)
result.fields.append(
ExtractedField(
field_name="发票代码",
field_value="1234567890",
bbox=[100, 50, 200, 80],
confidence=0.95,
extraction_method="exact_match",
)
)
d = result.to_dict()
assert len(d["fields"]) == 1
assert d["fields"][0]["field_name"] == "发票代码"
assert d["fields"][0]["field_value"] == "1234567890"
assert d["fields"][0]["bbox"] == [100, 50, 200, 80]
assert d["fields"][0]["confidence"] == 0.95
assert d["fields"][0]["extraction_method"] == "exact_match"
def test_to_dict_with_error(self):
result = ExtractionResult(
file_path="/test/missing.png",
image_size=(0, 0),
errors=["文件不存在: /test/missing.png"],
)
d = result.to_dict()
assert d["errors"] == ["文件不存在: /test/missing.png"]
class TestElementGrouping:
"""测试元素行分组功能。"""
def test_group_single_row(self):
elements = [
OcrTextElement("A", 10, 50, 50, 70),
OcrTextElement("B", 60, 50, 110, 70),
OcrTextElement("C", 120, 50, 170, 70),
]
rows = OcrExtractor._group_elements_by_rows(elements)
assert len(rows) == 1
assert len(rows[0]) == 3
def test_group_multiple_rows(self):
elements = [
OcrTextElement("A1", 10, 50, 50, 70),
OcrTextElement("B1", 60, 50, 110, 70),
OcrTextElement("A2", 10, 120, 50, 140),
OcrTextElement("B2", 60, 120, 110, 140),
OcrTextElement("A3", 10, 200, 50, 220),
]
rows = OcrExtractor._group_elements_by_rows(elements)
assert len(rows) == 3
assert len(rows[0]) == 2
assert len(rows[1]) == 2
assert len(rows[2]) == 1
def test_group_empty(self):
rows = OcrExtractor._group_elements_by_rows([])
assert rows == []
def test_group_single_element(self):
rows = OcrExtractor._group_elements_by_rows([
OcrTextElement("X", 10, 50, 50, 70)
])
assert len(rows) == 1
assert len(rows[0]) == 1
class TestTextSimilarity:
"""测试文本相似度计算。"""
def test_exact_match(self):
sim = OcrExtractor._text_similarity("发票代码", "发票代码")
assert sim == 1.0
def test_partial_match(self):
sim = OcrExtractor._text_similarity("发票代码", "代码")
assert sim > 0.5
def test_no_match(self):
sim = OcrExtractor._text_similarity("发票代码", "xyz")
assert sim == 0.0
def test_empty_strings(self):
assert OcrExtractor._text_similarity("", "abc") == 0.0
assert OcrExtractor._text_similarity("abc", "") == 0.0
assert OcrExtractor._text_similarity("", "") == 0.0
def test_substring_match(self):
sim = OcrExtractor._text_similarity("代码", "发票代码")
assert sim > 0.7
class TestExactKVMatch:
"""测试策略1: 精确键值对匹配。"""
def test_colon_separator(self):
extractor = OcrExtractor()
elements = [
OcrTextElement("发票代码: 1234567890", 50, 50, 300, 80),
]
result = extractor._exact_kv_match("发票代码", elements)
assert result is not None
assert result.field_value == "1234567890"
assert result.confidence == 0.95
def test_chinese_colon(self):
extractor = OcrExtractor()
elements = [
OcrTextElement("发票号码:87654321", 50, 50, 300, 80),
]
result = extractor._exact_kv_match("发票号码", elements)
assert result is not None
assert result.field_value == "87654321"
def test_space_separator(self):
extractor = OcrExtractor()
elements = [
OcrTextElement("合计金额 999.00", 50, 50, 300, 80),
]
result = extractor._exact_kv_match("合计金额", elements)
assert result is not None
assert result.field_value == "999.00"
def test_equals_separator(self):
extractor = OcrExtractor()
elements = [
OcrTextElement("数量=5", 50, 50, 200, 80),
]
result = extractor._exact_kv_match("数量", elements)
assert result is not None
assert result.field_value == "5"
def test_field_not_found(self):
extractor = OcrExtractor()
elements = [
OcrTextElement("其他内容: 123", 50, 50, 200, 80),
]
result = extractor._exact_kv_match("发票代码", elements)
assert result is None
class TestFuzzyKVMatch:
"""测试策略2: 模糊键值对匹配。"""
def test_adjacent_same_row(self):
extractor = OcrExtractor()
elements = [
OcrTextElement("发票代码", 50, 50, 150, 80),
OcrTextElement("1234567890", 200, 50, 350, 80),
]
result = extractor._fuzzy_kv_match("发票代码", elements)
assert result is not None
assert result.field_value == "1234567890"
assert result.confidence == 0.75
def test_adjacent_next_row(self):
extractor = OcrExtractor()
elements = [
OcrTextElement("发票代码", 50, 50, 150, 80),
OcrTextElement("1234567890", 50, 100, 200, 130),
]
result = extractor._fuzzy_kv_match("发票代码", elements)
assert result is not None
assert result.field_value == "1234567890"
def test_field_name_not_found(self):
extractor = OcrExtractor()
elements = [
OcrTextElement("其他信息", 50, 50, 150, 80),
OcrTextElement("1234567890", 200, 50, 350, 80),
]
result = extractor._fuzzy_kv_match("发票代码", elements)
assert result is None
class TestRegexMatch:
"""测试策略3: 正则模式匹配。"""
def test_invoice_code_pattern(self):
extractor = OcrExtractor()
elements = [
OcrTextElement("1234567890", 100, 50, 250, 80),
]
result = extractor._regex_match("发票代码", elements)
assert result is not None
assert result.field_value == "1234567890"
def test_invoice_number_pattern(self):
extractor = OcrExtractor()
elements = [
OcrTextElement("87654321", 100, 50, 200, 80),
]
result = extractor._regex_match("发票号码", elements)
assert result is not None
assert result.field_value == "87654321"
def test_amount_pattern(self):
extractor = OcrExtractor()
elements = [
OcrTextElement("1,234.56", 100, 50, 200, 80),
]
result = extractor._regex_match("合计金额", elements)
assert result is not None
assert "1,234.56" in result.field_value
def test_date_pattern(self):
extractor = OcrExtractor()
elements = [
OcrTextElement("2024年1月15日", 100, 50, 250, 80),
]
result = extractor._regex_match("开票日期", elements)
assert result is not None
assert "2024" in result.field_value
def test_unknown_field_no_pattern(self):
extractor = OcrExtractor()
elements = [
OcrTextElement("随便什么内容", 100, 50, 250, 80),
]
result = extractor._regex_match("未知字段", elements)
assert result is None
class TestTableMatch:
"""测试策略4: 表格结构匹配。"""
def test_simple_table(self):
extractor = OcrExtractor()
elements = [
OcrTextElement("名称", 50, 50, 150, 80),
OcrTextElement("数量", 200, 50, 300, 80),
OcrTextElement("单价", 350, 50, 450, 80),
OcrTextElement("商品A", 50, 100, 150, 130),
OcrTextElement("2", 200, 100, 250, 130),
OcrTextElement("10.00", 350, 100, 420, 130),
]
result = extractor._table_match("数量", elements)
assert result is not None
assert result.field_value == "2"
def test_table_with_fuzzy_header(self):
extractor = OcrExtractor()
elements = [
OcrTextElement("品名", 50, 50, 120, 80),
OcrTextElement("金额", 200, 50, 300, 80),
OcrTextElement("苹果", 50, 100, 120, 130),
OcrTextElement("5.00", 200, 100, 260, 130),
]
result = extractor._table_match("合计金额", elements)
# 金额列可能匹配到 "金额"
if result:
assert result.field_value == "5.00"
def test_table_header_not_found(self):
extractor = OcrExtractor()
elements = [
OcrTextElement("A", 50, 50, 100, 80),
OcrTextElement("B", 150, 50, 200, 80),
]
result = extractor._table_match("发票代码", elements)
assert result is None
def test_table_too_few_elements(self):
extractor = OcrExtractor()
elements = [
OcrTextElement("名称", 50, 50, 150, 80),
]
result = extractor._table_match("名称", elements)
assert result is None
class TestCoordinateCorrectness:
"""测试坐标计算正确性。"""
def test_bbox_origin_top_left(self):
"""验证坐标系统以左上角为原点。"""
elem = OcrTextElement("A", 0, 0, 50, 20)
assert elem.x_min == 0
assert elem.y_min == 0
assert elem.bbox[0] == 0
assert elem.bbox[1] == 0
def test_bbox_conversion(self):
"""验证从 (x, y, w, h) 到 [x_min, y_min, x_max, y_max] 的转换。"""
x, y, w, h = 100, 200, 300, 50
elem = OcrTextElement("test", x_min=x, y_min=y, x_max=x + w, y_max=y + h)
assert elem.bbox == [x, y, x + w, y + h]
assert elem.x_max - elem.x_min == w
assert elem.y_max - elem.y_min == h
def test_multiple_elements_coordinate_independence(self):
"""验证多个元素的坐标互不干扰。"""
elements = [
OcrTextElement("A", 10, 20, 60, 40),
OcrTextElement("B", 100, 200, 180, 240),
]
assert elements[0].bbox == [10, 20, 60, 40]
assert elements[1].bbox == [100, 200, 180, 240]
assert elements[0].x_min != elements[1].x_min
def test_center_calculation(self):
"""验证中心点计算。"""
elem = OcrTextElement("test", 0, 0, 100, 100)
assert elem.center_x == 50.0
assert elem.center_y == 50.0
class TestExtractionPipeline:
"""测试完整的提取流水线。"""
def test_priority_order_exact_first(self):
"""验证策略优先级: exact_match 优先于其他策略。"""
extractor = OcrExtractor()
elements = [
OcrTextElement("发票代码: 1234567890", 50, 50, 300, 80),
]
result = extractor._extract_field("发票代码", elements)
assert result is not None
assert result.extraction_method == "exact_match"
def test_fallback_to_fuzzy(self):
"""验证精确匹配失败后回退到模糊匹配。"""
extractor = OcrExtractor()
elements = [
OcrTextElement("发票代码", 50, 50, 150, 80),
OcrTextElement("1234567890", 200, 50, 350, 80),
]
result = extractor._extract_field("发票代码", elements)
assert result is not None
assert result.extraction_method == "kv_pair"
def test_all_fields_empty(self):
"""验证所有字段提取失败时返回空结果。"""
extractor = OcrExtractor()
elements = [
OcrTextElement("一些不相关的文本", 50, 50, 300, 80),
OcrTextElement("更多随机内容", 50, 100, 300, 130),
]
result = extractor._extract_field("发票代码", elements)
assert result is None
def test_empty_elements_list(self):
"""验证空元素列表时正常处理。"""
extractor = OcrExtractor()
result = extractor._extract_field("发票代码", [])
assert result is None
def test_extraction_with_confidence(self):
"""验证提取结果的置信度在合理范围内。"""
extractor = OcrExtractor()
elements = [
OcrTextElement("发票代码: 1234567890", 50, 50, 300, 80),
]
result = extractor._extract_field("发票代码", elements)
assert result is not None
assert 0.0 < result.confidence <= 1.0
class TestExtractFromLayout:
"""测试从 layout 结果中提取字段。"""
def test_basic_layout_extraction(self):
layout_result = {
"image_size": (800, 600),
"template_type": "full_a4",
"rows": [
{
"y_center": 50,
"elements": [
{"x": 50, "y": 30, "w": 150, "h": 30, "text": "发票代码:"},
{"x": 250, "y": 30, "w": 200, "h": 30, "text": "1234567890"},
],
},
{
"y_center": 80,
"elements": [
{"x": 50, "y": 60, "w": 150, "h": 30, "text": "合计金额:"},
{"x": 250, "y": 60, "w": 100, "h": 30, "text": "999.00"},
],
},
],
}
result = extract_from_layout(layout_result, ["发票代码"])
assert result["ocr_available"] is True
assert result["image_size"] == (800, 600)
def test_empty_layout(self):
result = extract_from_layout({}, ["发票代码"])
assert result["ocr_available"] is False
assert len(result["errors"]) > 0
class TestOcrExtractorFileNotFound:
"""测试文件不存在的情况。"""
def test_missing_file(self):
extractor = OcrExtractor()
result = extractor.extract("/nonexistent/file.png", ["发票代码"])
assert result["ocr_available"] is False
assert len(result["errors"]) > 0
assert "文件不存在" in result["errors"][0]
class TestConvenienceFunctions:
"""测试便捷函数。"""
def test_extract_ocr_fields_missing_file(self):
result = extract_ocr_fields("/nonexistent/file.png", ["发票代码"])
assert len(result["errors"]) > 0
def test_extract_from_layout_with_partial_rows(self):
layout_result = {
"image_size": (1200, 800),
"template_type": "partial_rows",
"rows": [
{
"y_center": 100,
"elements": [
{"x": 100, "y": 80, "w": 120, "h": 40, "text": "发票代码"},
{"x": 300, "y": 80, "w": 180, "h": 40, "text": "NO123456"},
],
},
],
}
result = extract_from_layout(layout_result, ["发票代码"])
assert result["ocr_available"] is True
assert len(result["fields"]) == 1
assert result["fields"][0]["field_name"] == "发票代码"
assert result["fields"][0]["field_value"] == "NO123456"