"""OCR 单据字段提取器单元测试。 覆盖: - OcrTextElement / ExtractedField / ExtractionResult 数据结构 - 四种提取策略的独立测试 - 坐标计算正确性 - 边界情况(空元素、无匹配、部分匹配) """ import sys import os import math from pathlib import Path sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import pytest from backend.ocr_extractor import ( OcrTextElement, ExtractedField, ExtractionResult, OcrExtractor, extract_ocr_fields, extract_from_layout, ) class TestOcrTextElement: """测试 OcrTextElement 数据类。""" def test_bbox_property(self): elem = OcrTextElement( text="发票代码", x_min=100.0, y_min=50.0, x_max=250.0, y_max=80.0, ) assert elem.bbox == [100.0, 50.0, 250.0, 80.0] def test_center_coordinates(self): elem = OcrTextElement( text="测试", x_min=100.0, y_min=50.0, x_max=200.0, y_max=100.0, ) assert elem.center_x == 150.0 assert elem.center_y == 75.0 def test_width_height(self): elem = OcrTextElement( text="测试", x_min=10.0, y_min=20.0, x_max=100.0, y_max=70.0, ) assert elem.width == 90.0 assert elem.height == 50.0 def test_default_confidence(self): elem = OcrTextElement( text="测试", x_min=0, y_min=0, x_max=10, y_max=10, ) assert elem.confidence == 1.0 class TestExtractedField: """测试 ExtractedField 数据类。""" def test_field_creation(self): field = ExtractedField( field_name="发票代码", field_value="1234567890", bbox=[100, 50, 200, 80], confidence=0.95, extraction_method="exact_match", ) assert field.field_name == "发票代码" assert field.field_value == "1234567890" assert field.bbox == [100, 50, 200, 80] assert field.confidence == 0.95 assert field.extraction_method == "exact_match" class TestExtractionResult: """测试 ExtractionResult 数据类。""" def test_to_dict_basic(self): result = ExtractionResult( file_path="/test/invoice.png", image_size=(800, 600), ocr_available=True, ) d = result.to_dict() assert d["file_path"] == "/test/invoice.png" assert d["image_size"] == (800, 600) assert d["ocr_available"] is True assert d["fields"] == [] assert d["total_elements"] == 0 assert d["errors"] == [] def test_to_dict_with_fields(self): result = ExtractionResult( file_path="/test/invoice.png", image_size=(800, 600), ocr_available=True, ) result.fields.append( ExtractedField( field_name="发票代码", field_value="1234567890", bbox=[100, 50, 200, 80], confidence=0.95, extraction_method="exact_match", ) ) d = result.to_dict() assert len(d["fields"]) == 1 assert d["fields"][0]["field_name"] == "发票代码" assert d["fields"][0]["field_value"] == "1234567890" assert d["fields"][0]["bbox"] == [100, 50, 200, 80] assert d["fields"][0]["confidence"] == 0.95 assert d["fields"][0]["extraction_method"] == "exact_match" def test_to_dict_with_error(self): result = ExtractionResult( file_path="/test/missing.png", image_size=(0, 0), errors=["文件不存在: /test/missing.png"], ) d = result.to_dict() assert d["errors"] == ["文件不存在: /test/missing.png"] class TestElementGrouping: """测试元素行分组功能。""" def test_group_single_row(self): elements = [ OcrTextElement("A", 10, 50, 50, 70), OcrTextElement("B", 60, 50, 110, 70), OcrTextElement("C", 120, 50, 170, 70), ] rows = OcrExtractor._group_elements_by_rows(elements) assert len(rows) == 1 assert len(rows[0]) == 3 def test_group_multiple_rows(self): elements = [ OcrTextElement("A1", 10, 50, 50, 70), OcrTextElement("B1", 60, 50, 110, 70), OcrTextElement("A2", 10, 120, 50, 140), OcrTextElement("B2", 60, 120, 110, 140), OcrTextElement("A3", 10, 200, 50, 220), ] rows = OcrExtractor._group_elements_by_rows(elements) assert len(rows) == 3 assert len(rows[0]) == 2 assert len(rows[1]) == 2 assert len(rows[2]) == 1 def test_group_empty(self): rows = OcrExtractor._group_elements_by_rows([]) assert rows == [] def test_group_single_element(self): rows = OcrExtractor._group_elements_by_rows([ OcrTextElement("X", 10, 50, 50, 70) ]) assert len(rows) == 1 assert len(rows[0]) == 1 class TestTextSimilarity: """测试文本相似度计算。""" def test_exact_match(self): sim = OcrExtractor._text_similarity("发票代码", "发票代码") assert sim == 1.0 def test_partial_match(self): sim = OcrExtractor._text_similarity("发票代码", "代码") assert sim > 0.5 def test_no_match(self): sim = OcrExtractor._text_similarity("发票代码", "xyz") assert sim == 0.0 def test_empty_strings(self): assert OcrExtractor._text_similarity("", "abc") == 0.0 assert OcrExtractor._text_similarity("abc", "") == 0.0 assert OcrExtractor._text_similarity("", "") == 0.0 def test_substring_match(self): sim = OcrExtractor._text_similarity("代码", "发票代码") assert sim > 0.7 class TestExactKVMatch: """测试策略1: 精确键值对匹配。""" def test_colon_separator(self): extractor = OcrExtractor() elements = [ OcrTextElement("发票代码: 1234567890", 50, 50, 300, 80), ] result = extractor._exact_kv_match("发票代码", elements) assert result is not None assert result.field_value == "1234567890" assert result.confidence == 0.95 def test_chinese_colon(self): extractor = OcrExtractor() elements = [ OcrTextElement("发票号码:87654321", 50, 50, 300, 80), ] result = extractor._exact_kv_match("发票号码", elements) assert result is not None assert result.field_value == "87654321" def test_space_separator(self): extractor = OcrExtractor() elements = [ OcrTextElement("合计金额 999.00", 50, 50, 300, 80), ] result = extractor._exact_kv_match("合计金额", elements) assert result is not None assert result.field_value == "999.00" def test_equals_separator(self): extractor = OcrExtractor() elements = [ OcrTextElement("数量=5", 50, 50, 200, 80), ] result = extractor._exact_kv_match("数量", elements) assert result is not None assert result.field_value == "5" def test_field_not_found(self): extractor = OcrExtractor() elements = [ OcrTextElement("其他内容: 123", 50, 50, 200, 80), ] result = extractor._exact_kv_match("发票代码", elements) assert result is None class TestFuzzyKVMatch: """测试策略2: 模糊键值对匹配。""" def test_adjacent_same_row(self): extractor = OcrExtractor() elements = [ OcrTextElement("发票代码", 50, 50, 150, 80), OcrTextElement("1234567890", 200, 50, 350, 80), ] result = extractor._fuzzy_kv_match("发票代码", elements) assert result is not None assert result.field_value == "1234567890" assert result.confidence == 0.75 def test_adjacent_next_row(self): extractor = OcrExtractor() elements = [ OcrTextElement("发票代码", 50, 50, 150, 80), OcrTextElement("1234567890", 50, 100, 200, 130), ] result = extractor._fuzzy_kv_match("发票代码", elements) assert result is not None assert result.field_value == "1234567890" def test_field_name_not_found(self): extractor = OcrExtractor() elements = [ OcrTextElement("其他信息", 50, 50, 150, 80), OcrTextElement("1234567890", 200, 50, 350, 80), ] result = extractor._fuzzy_kv_match("发票代码", elements) assert result is None class TestRegexMatch: """测试策略3: 正则模式匹配。""" def test_invoice_code_pattern(self): extractor = OcrExtractor() elements = [ OcrTextElement("1234567890", 100, 50, 250, 80), ] result = extractor._regex_match("发票代码", elements) assert result is not None assert result.field_value == "1234567890" def test_invoice_number_pattern(self): extractor = OcrExtractor() elements = [ OcrTextElement("87654321", 100, 50, 200, 80), ] result = extractor._regex_match("发票号码", elements) assert result is not None assert result.field_value == "87654321" def test_amount_pattern(self): extractor = OcrExtractor() elements = [ OcrTextElement("1,234.56", 100, 50, 200, 80), ] result = extractor._regex_match("合计金额", elements) assert result is not None assert "1,234.56" in result.field_value def test_date_pattern(self): extractor = OcrExtractor() elements = [ OcrTextElement("2024年1月15日", 100, 50, 250, 80), ] result = extractor._regex_match("开票日期", elements) assert result is not None assert "2024" in result.field_value def test_unknown_field_no_pattern(self): extractor = OcrExtractor() elements = [ OcrTextElement("随便什么内容", 100, 50, 250, 80), ] result = extractor._regex_match("未知字段", elements) assert result is None class TestTableMatch: """测试策略4: 表格结构匹配。""" def test_simple_table(self): extractor = OcrExtractor() elements = [ OcrTextElement("名称", 50, 50, 150, 80), OcrTextElement("数量", 200, 50, 300, 80), OcrTextElement("单价", 350, 50, 450, 80), OcrTextElement("商品A", 50, 100, 150, 130), OcrTextElement("2", 200, 100, 250, 130), OcrTextElement("10.00", 350, 100, 420, 130), ] result = extractor._table_match("数量", elements) assert result is not None assert result.field_value == "2" def test_table_with_fuzzy_header(self): extractor = OcrExtractor() elements = [ OcrTextElement("品名", 50, 50, 120, 80), OcrTextElement("金额", 200, 50, 300, 80), OcrTextElement("苹果", 50, 100, 120, 130), OcrTextElement("5.00", 200, 100, 260, 130), ] result = extractor._table_match("合计金额", elements) # 金额列可能匹配到 "金额" if result: assert result.field_value == "5.00" def test_table_header_not_found(self): extractor = OcrExtractor() elements = [ OcrTextElement("A", 50, 50, 100, 80), OcrTextElement("B", 150, 50, 200, 80), ] result = extractor._table_match("发票代码", elements) assert result is None def test_table_too_few_elements(self): extractor = OcrExtractor() elements = [ OcrTextElement("名称", 50, 50, 150, 80), ] result = extractor._table_match("名称", elements) assert result is None class TestCoordinateCorrectness: """测试坐标计算正确性。""" def test_bbox_origin_top_left(self): """验证坐标系统以左上角为原点。""" elem = OcrTextElement("A", 0, 0, 50, 20) assert elem.x_min == 0 assert elem.y_min == 0 assert elem.bbox[0] == 0 assert elem.bbox[1] == 0 def test_bbox_conversion(self): """验证从 (x, y, w, h) 到 [x_min, y_min, x_max, y_max] 的转换。""" x, y, w, h = 100, 200, 300, 50 elem = OcrTextElement("test", x_min=x, y_min=y, x_max=x + w, y_max=y + h) assert elem.bbox == [x, y, x + w, y + h] assert elem.x_max - elem.x_min == w assert elem.y_max - elem.y_min == h def test_multiple_elements_coordinate_independence(self): """验证多个元素的坐标互不干扰。""" elements = [ OcrTextElement("A", 10, 20, 60, 40), OcrTextElement("B", 100, 200, 180, 240), ] assert elements[0].bbox == [10, 20, 60, 40] assert elements[1].bbox == [100, 200, 180, 240] assert elements[0].x_min != elements[1].x_min def test_center_calculation(self): """验证中心点计算。""" elem = OcrTextElement("test", 0, 0, 100, 100) assert elem.center_x == 50.0 assert elem.center_y == 50.0 class TestExtractionPipeline: """测试完整的提取流水线。""" def test_priority_order_exact_first(self): """验证策略优先级: exact_match 优先于其他策略。""" extractor = OcrExtractor() elements = [ OcrTextElement("发票代码: 1234567890", 50, 50, 300, 80), ] result = extractor._extract_field("发票代码", elements) assert result is not None assert result.extraction_method == "exact_match" def test_fallback_to_fuzzy(self): """验证精确匹配失败后回退到模糊匹配。""" extractor = OcrExtractor() elements = [ OcrTextElement("发票代码", 50, 50, 150, 80), OcrTextElement("1234567890", 200, 50, 350, 80), ] result = extractor._extract_field("发票代码", elements) assert result is not None assert result.extraction_method == "kv_pair" def test_all_fields_empty(self): """验证所有字段提取失败时返回空结果。""" extractor = OcrExtractor() elements = [ OcrTextElement("一些不相关的文本", 50, 50, 300, 80), OcrTextElement("更多随机内容", 50, 100, 300, 130), ] result = extractor._extract_field("发票代码", elements) assert result is None def test_empty_elements_list(self): """验证空元素列表时正常处理。""" extractor = OcrExtractor() result = extractor._extract_field("发票代码", []) assert result is None def test_extraction_with_confidence(self): """验证提取结果的置信度在合理范围内。""" extractor = OcrExtractor() elements = [ OcrTextElement("发票代码: 1234567890", 50, 50, 300, 80), ] result = extractor._extract_field("发票代码", elements) assert result is not None assert 0.0 < result.confidence <= 1.0 class TestExtractFromLayout: """测试从 layout 结果中提取字段。""" def test_basic_layout_extraction(self): layout_result = { "image_size": (800, 600), "template_type": "full_a4", "rows": [ { "y_center": 50, "elements": [ {"x": 50, "y": 30, "w": 150, "h": 30, "text": "发票代码:"}, {"x": 250, "y": 30, "w": 200, "h": 30, "text": "1234567890"}, ], }, { "y_center": 80, "elements": [ {"x": 50, "y": 60, "w": 150, "h": 30, "text": "合计金额:"}, {"x": 250, "y": 60, "w": 100, "h": 30, "text": "999.00"}, ], }, ], } result = extract_from_layout(layout_result, ["发票代码"]) assert result["ocr_available"] is True assert result["image_size"] == (800, 600) def test_empty_layout(self): result = extract_from_layout({}, ["发票代码"]) assert result["ocr_available"] is False assert len(result["errors"]) > 0 class TestOcrExtractorFileNotFound: """测试文件不存在的情况。""" def test_missing_file(self): extractor = OcrExtractor() result = extractor.extract("/nonexistent/file.png", ["发票代码"]) assert result["ocr_available"] is False assert len(result["errors"]) > 0 assert "文件不存在" in result["errors"][0] class TestConvenienceFunctions: """测试便捷函数。""" def test_extract_ocr_fields_missing_file(self): result = extract_ocr_fields("/nonexistent/file.png", ["发票代码"]) assert len(result["errors"]) > 0 def test_extract_from_layout_with_partial_rows(self): layout_result = { "image_size": (1200, 800), "template_type": "partial_rows", "rows": [ { "y_center": 100, "elements": [ {"x": 100, "y": 80, "w": 120, "h": 40, "text": "发票代码"}, {"x": 300, "y": 80, "w": 180, "h": 40, "text": "NO123456"}, ], }, ], } result = extract_from_layout(layout_result, ["发票代码"]) assert result["ocr_available"] is True assert len(result["fields"]) == 1 assert result["fields"][0]["field_name"] == "发票代码" assert result["fields"][0]["field_value"] == "NO123456"