Files
agent_jrxml/tests/test_ocr_extraction.py
panda c9f003e1b7 feat: 新增 OCR 单据字段精确提取模块
- 新增 backend/ocr_extractor.py: 两阶段提取流水线 (文档分析 + 字段提取)
- 四种提取策略: 精确KV匹配/模糊KV匹配/正则模式/表格结构匹配
- agent/state.py: 新增 ocr_extraction_result 和 uploaded_file_path 字段
- agent/nodes.py: process_input() 中自动触发 OCR 提取钩子
- app.py: 文件上传时保留图片路径, 总结卡片中展示提取结果
- .env.example: 新增 OCR_USE_GPU / OCR_CONFIDENCE_THRESHOLD 配置项
- tests/test_ocr_extraction.py: 48 个单元测试全部通过
2026-05-20 08:06:55 +08:00

544 lines
18 KiB
Python

"""OCR 单据字段提取器单元测试。
覆盖:
- OcrTextElement / ExtractedField / ExtractionResult 数据结构
- 四种提取策略的独立测试
- 坐标计算正确性
- 边界情况(空元素、无匹配、部分匹配)
"""
import sys
import os
import math
from pathlib import Path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import pytest
from backend.ocr_extractor import (
OcrTextElement,
ExtractedField,
ExtractionResult,
OcrExtractor,
extract_ocr_fields,
extract_from_layout,
)
class TestOcrTextElement:
    """Tests for the OcrTextElement data class."""

    def test_bbox_property(self):
        """bbox is reported as [x_min, y_min, x_max, y_max]."""
        element = OcrTextElement(
            text="发票代码", x_min=100.0, y_min=50.0, x_max=250.0, y_max=80.0
        )
        expected_bbox = [100.0, 50.0, 250.0, 80.0]
        assert element.bbox == expected_bbox

    def test_center_coordinates(self):
        """center_x / center_y are the midpoints of the box edges."""
        element = OcrTextElement(
            text="测试", x_min=100.0, y_min=50.0, x_max=200.0, y_max=100.0
        )
        assert element.center_x == 150.0
        assert element.center_y == 75.0

    def test_width_height(self):
        """width and height are the extents along each axis."""
        element = OcrTextElement(
            text="测试", x_min=10.0, y_min=20.0, x_max=100.0, y_max=70.0
        )
        assert element.width == 90.0
        assert element.height == 50.0

    def test_default_confidence(self):
        """confidence is 1.0 when the caller does not supply it."""
        element = OcrTextElement(text="测试", x_min=0, y_min=0, x_max=10, y_max=10)
        assert element.confidence == 1.0
class TestExtractedField:
    """Tests for the ExtractedField data class."""

    def test_field_creation(self):
        """Every constructor argument is readable back as an attribute."""
        extracted = ExtractedField(
            field_name="发票代码",
            field_value="1234567890",
            bbox=[100, 50, 200, 80],
            confidence=0.95,
            extraction_method="exact_match",
        )

        # Compare against fresh literals rather than the objects passed in.
        expected_attributes = (
            ("field_name", "发票代码"),
            ("field_value", "1234567890"),
            ("bbox", [100, 50, 200, 80]),
            ("confidence", 0.95),
            ("extraction_method", "exact_match"),
        )
        for attribute, expected in expected_attributes:
            assert getattr(extracted, attribute) == expected
class TestExtractionResult:
    """Tests for the ExtractionResult data class."""

    def test_to_dict_basic(self):
        """A freshly built result serializes with empty fields and errors."""
        payload = ExtractionResult(
            file_path="/test/invoice.png",
            image_size=(800, 600),
            ocr_available=True,
        ).to_dict()

        assert payload["file_path"] == "/test/invoice.png"
        assert payload["image_size"] == (800, 600)
        assert payload["ocr_available"] is True
        for key, empty_value in (("fields", []), ("total_elements", 0), ("errors", [])):
            assert payload[key] == empty_value

    def test_to_dict_with_fields(self):
        """Fields appended to the result appear as dicts in the output."""
        extraction = ExtractionResult(
            file_path="/test/invoice.png",
            image_size=(800, 600),
            ocr_available=True,
        )
        invoice_code = ExtractedField(
            field_name="发票代码",
            field_value="1234567890",
            bbox=[100, 50, 200, 80],
            confidence=0.95,
            extraction_method="exact_match",
        )
        extraction.fields.append(invoice_code)

        serialized_fields = extraction.to_dict()["fields"]
        assert len(serialized_fields) == 1

        entry = serialized_fields[0]
        assert entry["field_name"] == "发票代码"
        assert entry["field_value"] == "1234567890"
        assert entry["bbox"] == [100, 50, 200, 80]
        assert entry["confidence"] == 0.95
        assert entry["extraction_method"] == "exact_match"

    def test_to_dict_with_error(self):
        """Errors passed to the constructor are carried into the dict."""
        message = "文件不存在: /test/missing.png"
        extraction = ExtractionResult(
            file_path="/test/missing.png",
            image_size=(0, 0),
            errors=[message],
        )
        payload = extraction.to_dict()
        assert payload["errors"] == ["文件不存在: /test/missing.png"]
class TestElementGrouping:
    """Tests for grouping OCR elements into rows."""

    def test_group_single_row(self):
        """Three elements sharing the same vertical span form one row."""
        same_line = [
            OcrTextElement("A", 10, 50, 50, 70),
            OcrTextElement("B", 60, 50, 110, 70),
            OcrTextElement("C", 120, 50, 170, 70),
        ]
        grouped = OcrExtractor._group_elements_by_rows(same_line)
        assert [len(row) for row in grouped] == [3]

    def test_group_multiple_rows(self):
        """Elements at three distinct heights yield rows of sizes 2, 2 and 1."""
        stacked = [
            OcrTextElement("A1", 10, 50, 50, 70),
            OcrTextElement("B1", 60, 50, 110, 70),
            OcrTextElement("A2", 10, 120, 50, 140),
            OcrTextElement("B2", 60, 120, 110, 140),
            OcrTextElement("A3", 10, 200, 50, 220),
        ]
        grouped = OcrExtractor._group_elements_by_rows(stacked)
        assert [len(row) for row in grouped] == [2, 2, 1]

    def test_group_empty(self):
        """No elements in, no rows out."""
        assert OcrExtractor._group_elements_by_rows([]) == []

    def test_group_single_element(self):
        """A lone element becomes a single one-element row."""
        lone = OcrTextElement("X", 10, 50, 50, 70)
        grouped = OcrExtractor._group_elements_by_rows([lone])
        assert [len(row) for row in grouped] == [1]
class TestTextSimilarity:
    """Tests for the text similarity score."""

    def test_exact_match(self):
        """Identical strings score exactly 1.0."""
        assert OcrExtractor._text_similarity("发票代码", "发票代码") == 1.0

    def test_partial_match(self):
        """A trailing fragment of the target scores above 0.5."""
        score = OcrExtractor._text_similarity("发票代码", "代码")
        assert score > 0.5

    def test_no_match(self):
        """Strings with nothing in common score exactly 0.0."""
        assert OcrExtractor._text_similarity("发票代码", "xyz") == 0.0

    def test_empty_strings(self):
        """An empty string on either side (or both) scores 0.0."""
        for left, right in (("", "abc"), ("abc", ""), ("", "")):
            assert OcrExtractor._text_similarity(left, right) == 0.0

    def test_substring_match(self):
        """A short string contained in the longer one scores above 0.7."""
        score = OcrExtractor._text_similarity("代码", "发票代码")
        assert score > 0.7
class TestExactKVMatch:
    """Strategy 1: exact key-value matching inside a single element."""

    def test_colon_separator(self):
        """An ASCII colon separates key and value; confidence is 0.95."""
        ocr_items = [OcrTextElement("发票代码: 1234567890", 50, 50, 300, 80)]
        hit = OcrExtractor()._exact_kv_match("发票代码", ocr_items)
        assert hit is not None
        assert hit.field_value == "1234567890"
        assert hit.confidence == 0.95

    def test_chinese_colon(self):
        """A full-width colon is accepted as the separator."""
        ocr_items = [OcrTextElement("发票号码:87654321", 50, 50, 300, 80)]
        hit = OcrExtractor()._exact_kv_match("发票号码", ocr_items)
        assert hit is not None
        assert hit.field_value == "87654321"

    def test_space_separator(self):
        """A plain space is accepted as the separator."""
        ocr_items = [OcrTextElement("合计金额 999.00", 50, 50, 300, 80)]
        hit = OcrExtractor()._exact_kv_match("合计金额", ocr_items)
        assert hit is not None
        assert hit.field_value == "999.00"

    def test_equals_separator(self):
        """An equals sign is accepted as the separator."""
        ocr_items = [OcrTextElement("数量=5", 50, 50, 200, 80)]
        hit = OcrExtractor()._exact_kv_match("数量", ocr_items)
        assert hit is not None
        assert hit.field_value == "5"

    def test_field_not_found(self):
        """None is returned when no element carries the requested key."""
        ocr_items = [OcrTextElement("其他内容: 123", 50, 50, 200, 80)]
        hit = OcrExtractor()._exact_kv_match("发票代码", ocr_items)
        assert hit is None
class TestFuzzyKVMatch:
    """Strategy 2: fuzzy key-value matching across neighbouring elements."""

    def test_adjacent_same_row(self):
        """The value sits to the right of the label; confidence is 0.75."""
        label = OcrTextElement("发票代码", 50, 50, 150, 80)
        value = OcrTextElement("1234567890", 200, 50, 350, 80)
        hit = OcrExtractor()._fuzzy_kv_match("发票代码", [label, value])
        assert hit is not None
        assert hit.field_value == "1234567890"
        assert hit.confidence == 0.75

    def test_adjacent_next_row(self):
        """The value sits directly below the label."""
        label = OcrTextElement("发票代码", 50, 50, 150, 80)
        value = OcrTextElement("1234567890", 50, 100, 200, 130)
        hit = OcrExtractor()._fuzzy_kv_match("发票代码", [label, value])
        assert hit is not None
        assert hit.field_value == "1234567890"

    def test_field_name_not_found(self):
        """None is returned when no element resembles the label."""
        unrelated = OcrTextElement("其他信息", 50, 50, 150, 80)
        value = OcrTextElement("1234567890", 200, 50, 350, 80)
        hit = OcrExtractor()._fuzzy_kv_match("发票代码", [unrelated, value])
        assert hit is None
class TestRegexMatch:
    """Strategy 3: regular-expression pattern matching."""

    def test_invoice_code_pattern(self):
        """A bare ten-digit string is picked up as the invoice code."""
        candidates = [OcrTextElement("1234567890", 100, 50, 250, 80)]
        hit = OcrExtractor()._regex_match("发票代码", candidates)
        assert hit is not None
        assert hit.field_value == "1234567890"

    def test_invoice_number_pattern(self):
        """A bare eight-digit string is picked up as the invoice number."""
        candidates = [OcrTextElement("87654321", 100, 50, 200, 80)]
        hit = OcrExtractor()._regex_match("发票号码", candidates)
        assert hit is not None
        assert hit.field_value == "87654321"

    def test_amount_pattern(self):
        """An amount with a thousands separator appears in the value."""
        candidates = [OcrTextElement("1,234.56", 100, 50, 200, 80)]
        hit = OcrExtractor()._regex_match("合计金额", candidates)
        assert hit is not None
        assert "1,234.56" in hit.field_value

    def test_date_pattern(self):
        """A Chinese-formatted date yields a value containing the year."""
        candidates = [OcrTextElement("2024年1月15日", 100, 50, 250, 80)]
        hit = OcrExtractor()._regex_match("开票日期", candidates)
        assert hit is not None
        assert "2024" in hit.field_value

    def test_unknown_field_no_pattern(self):
        """None is returned for the unrecognized field name used here."""
        candidates = [OcrTextElement("随便什么内容", 100, 50, 250, 80)]
        hit = OcrExtractor()._regex_match("未知字段", candidates)
        assert hit is None
class TestTableMatch:
    """Strategy 4: table-structure matching (header row plus data row)."""

    def test_simple_table(self):
        """The cell under the matching header is returned as the value."""
        header_row = [
            OcrTextElement("名称", 50, 50, 150, 80),
            OcrTextElement("数量", 200, 50, 300, 80),
            OcrTextElement("单价", 350, 50, 450, 80),
        ]
        data_row = [
            OcrTextElement("商品A", 50, 100, 150, 130),
            OcrTextElement("2", 200, 100, 250, 130),
            OcrTextElement("10.00", 350, 100, 420, 130),
        ]
        hit = OcrExtractor()._table_match("数量", header_row + data_row)
        assert hit is not None
        assert hit.field_value == "2"

    def test_table_with_fuzzy_header(self):
        """If a header is matched at all, the value comes from its column."""
        cells = [
            OcrTextElement("品名", 50, 50, 120, 80),
            OcrTextElement("金额", 200, 50, 300, 80),
            OcrTextElement("苹果", 50, 100, 120, 130),
            OcrTextElement("5.00", 200, 100, 260, 130),
        ]
        hit = OcrExtractor()._table_match("合计金额", cells)
        # The amount column may match the "金额" header.
        # NOTE(review): a falsy result makes this test pass without asserting
        # anything -- confirm whether a match is actually required here.
        if not hit:
            return
        assert hit.field_value == "5.00"

    def test_table_header_not_found(self):
        """None is returned when no header resembles the field name."""
        cells = [
            OcrTextElement("A", 50, 50, 100, 80),
            OcrTextElement("B", 150, 50, 200, 80),
        ]
        hit = OcrExtractor()._table_match("发票代码", cells)
        assert hit is None

    def test_table_too_few_elements(self):
        """None is returned when the input is a single header-like element."""
        only_header = [OcrTextElement("名称", 50, 50, 150, 80)]
        hit = OcrExtractor()._table_match("名称", only_header)
        assert hit is None
class TestCoordinateCorrectness:
    """Tests that coordinate values are stored and derived correctly."""

    def test_bbox_origin_top_left(self):
        """Check that the coordinate system uses the top-left as origin."""
        at_origin = OcrTextElement("A", 0, 0, 50, 20)
        assert at_origin.x_min == 0
        assert at_origin.y_min == 0
        left, top = at_origin.bbox[0], at_origin.bbox[1]
        assert left == 0
        assert top == 0

    def test_bbox_conversion(self):
        """Check conversion from (x, y, w, h) to [x_min, y_min, x_max, y_max]."""
        x, y, w, h = 100, 200, 300, 50
        right, bottom = x + w, y + h
        boxed = OcrTextElement("test", x_min=x, y_min=y, x_max=right, y_max=bottom)
        assert boxed.bbox == [x, y, x + w, y + h]
        assert boxed.x_max - boxed.x_min == w
        assert boxed.y_max - boxed.y_min == h

    def test_multiple_elements_coordinate_independence(self):
        """Check that separate elements keep separate coordinates."""
        first = OcrTextElement("A", 10, 20, 60, 40)
        second = OcrTextElement("B", 100, 200, 180, 240)
        assert first.bbox == [10, 20, 60, 40]
        assert second.bbox == [100, 200, 180, 240]
        assert first.x_min != second.x_min

    def test_center_calculation(self):
        """Check the center point of a 100x100 box anchored at the origin."""
        square = OcrTextElement("test", 0, 0, 100, 100)
        assert square.center_x == 50.0
        assert square.center_y == 50.0
class TestExtractionPipeline:
    """Tests for the full per-field extraction pipeline."""

    def test_priority_order_exact_first(self):
        """Check strategy priority: exact_match wins over the others."""
        inline_pair = [OcrTextElement("发票代码: 1234567890", 50, 50, 300, 80)]
        outcome = OcrExtractor()._extract_field("发票代码", inline_pair)
        assert outcome is not None
        assert outcome.extraction_method == "exact_match"

    def test_fallback_to_fuzzy(self):
        """Check fallback to fuzzy matching when the exact match fails."""
        split_pair = [
            OcrTextElement("发票代码", 50, 50, 150, 80),
            OcrTextElement("1234567890", 200, 50, 350, 80),
        ]
        outcome = OcrExtractor()._extract_field("发票代码", split_pair)
        assert outcome is not None
        assert outcome.extraction_method == "kv_pair"

    def test_all_fields_empty(self):
        """Check that nothing is returned when every strategy fails."""
        noise = [
            OcrTextElement("一些不相关的文本", 50, 50, 300, 80),
            OcrTextElement("更多随机内容", 50, 100, 300, 130),
        ]
        outcome = OcrExtractor()._extract_field("发票代码", noise)
        assert outcome is None

    def test_empty_elements_list(self):
        """Check that an empty element list is handled without error."""
        outcome = OcrExtractor()._extract_field("发票代码", [])
        assert outcome is None

    def test_extraction_with_confidence(self):
        """Check that the reported confidence lies in the range (0, 1]."""
        inline_pair = [OcrTextElement("发票代码: 1234567890", 50, 50, 300, 80)]
        outcome = OcrExtractor()._extract_field("发票代码", inline_pair)
        assert outcome is not None
        assert 0.0 < outcome.confidence <= 1.0
class TestExtractFromLayout:
    """Tests for extracting fields from a precomputed layout result."""

    @staticmethod
    def _cell(x, y, w, h, text):
        """Build one layout element dict in the x/y/w/h/text shape."""
        return {"x": x, "y": y, "w": w, "h": h, "text": text}

    def test_basic_layout_extraction(self):
        """A two-row layout reports OCR as available and keeps the image size."""
        code_row = {
            "y_center": 50,
            "elements": [
                self._cell(50, 30, 150, 30, "发票代码:"),
                self._cell(250, 30, 200, 30, "1234567890"),
            ],
        }
        total_row = {
            "y_center": 80,
            "elements": [
                self._cell(50, 60, 150, 30, "合计金额:"),
                self._cell(250, 60, 100, 30, "999.00"),
            ],
        }
        layout_result = {
            "image_size": (800, 600),
            "template_type": "full_a4",
            "rows": [code_row, total_row],
        }

        outcome = extract_from_layout(layout_result, ["发票代码"])
        assert outcome["ocr_available"] is True
        assert outcome["image_size"] == (800, 600)

    def test_empty_layout(self):
        """An empty layout dict reports OCR as unavailable, with an error."""
        outcome = extract_from_layout({}, ["发票代码"])
        assert outcome["ocr_available"] is False
        assert len(outcome["errors"]) > 0
class TestOcrExtractorFileNotFound:
    """Tests for extraction when the input file does not exist."""

    def test_missing_file(self):
        """A missing path yields an unavailable result and a file-not-found error."""
        missing_path = "/nonexistent/file.png"
        outcome = OcrExtractor().extract(missing_path, ["发票代码"])
        assert outcome["ocr_available"] is False
        reported = outcome["errors"]
        assert len(reported) > 0
        assert "文件不存在" in reported[0]
class TestConvenienceFunctions:
    """Tests for the module-level convenience functions."""

    def test_extract_ocr_fields_missing_file(self):
        """extract_ocr_fields reports at least one error for a missing file."""
        outcome = extract_ocr_fields("/nonexistent/file.png", ["发票代码"])
        assert len(outcome["errors"]) > 0

    def test_extract_from_layout_with_partial_rows(self):
        """A single label/value row yields exactly one extracted field."""
        label_cell = {"x": 100, "y": 80, "w": 120, "h": 40, "text": "发票代码"}
        value_cell = {"x": 300, "y": 80, "w": 180, "h": 40, "text": "NO123456"}
        layout_result = {
            "image_size": (1200, 800),
            "template_type": "partial_rows",
            "rows": [{"y_center": 100, "elements": [label_cell, value_cell]}],
        }

        outcome = extract_from_layout(layout_result, ["发票代码"])
        assert outcome["ocr_available"] is True

        extracted = outcome["fields"]
        assert len(extracted) == 1
        assert extracted[0]["field_name"] == "发票代码"
        assert extracted[0]["field_value"] == "NO123456"