Files
agent_jrxml/tests/test_e2e_ocr.py
panda aa1d8a6c52 fix: logging KeyError with reserved 'filename' key, pytest return-not-none warnings
- api_server.py: rename 'filename' to 'file_name' in upload_file log extra
  dict to avoid collision with Python logging's reserved LogRecord attribute
- test_e2e_ocr.py: replace return statements with assert in test functions
  to fix PytestReturnNotNoneWarning
2026-05-21 22:28:07 +08:00

145 lines
5.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""端到端测试:OCR 字段精确提取完整流水线。
覆盖:
1. PaddleOCR 精确识别(优先)
2. EasyOCR 降级回退
3. 4种提取策略
4. 验证服务连通性
"""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from PIL import Image, ImageDraw
from backend.ocr_extractor import OcrExtractor, extract_ocr_fields
from backend.file_parser import parse_file
from backend.validation import validate_jrxml
def create_test_invoice(path: str):
    """Draw a mock Chinese invoice with known field values and save it.

    An 800x600 white RGB image is created, a fixed set of text lines is
    drawn in black, and the result is written to ``path``.

    Args:
        path: Destination file path handed to ``Image.save``.

    Returns:
        The same ``path`` that was passed in.
    """
    # (x, y) anchor and content of every text line, in drawing order.
    invoice_lines = (
        ((300, 20), "增值税普通发票"),
        ((300, 60), "发票代码: 1234567890"),
        ((300, 100), "发票号码: 87654321"),
        ((50, 160), "开票日期: 2024年1月15日"),
        ((50, 200), "购买方名称: 测试公司"),
        ((50, 240), "合计金额: 1,234.56"),
        ((50, 280), "校验码: ABC12345678"),
        ((50, 350), "名称 数量 单价"),
        ((50, 390), "商品A 2 10.00"),
        ((50, 430), "商品B 5 20.00"),
    )

    canvas = Image.new("RGB", (800, 600), color="white")
    pen = ImageDraw.Draw(canvas)
    # NOTE(review): no font is passed, so Pillow's default font is used;
    # whether it can render these CJK glyphs depends on the installed
    # Pillow build -- confirm before relying on exact OCR output.
    for anchor, content in invoice_lines:
        pen.text(anchor, content, fill="black")

    canvas.save(path)
    print(f"[OK] 测试图片已创建: {path}")
    return path
def test_ocr_extraction_pipeline():
    """End-to-end test: image -> OCR -> field extraction.

    Creates a mock invoice image, runs ``parse_file`` and
    ``extract_ocr_fields`` on it, prints a diagnostic report, exercises
    ``OcrExtractor.extract`` for a single field, and finally asserts that
    the extraction response carries an ``ocr_available`` value.

    The generated image is removed when the test finishes, whether it
    passes or fails, so the working directory is not littered.
    """
    print("\n=== 端到端OCR字段提取测试 ===\n")
    img_path = create_test_invoice("test_invoice_e2e.png")
    try:
        # Stage 1: file parsing (includes OCR).
        print("\n--- 阶段1: 文件解析(OCR) ---")
        result = parse_file(img_path)
        method = result.get("method", "N/A")
        print(f" OCR方法: {method}")
        print(f" 文件类型: {result.get('file_type', 'N/A')}")
        # `or ""` also covers an explicit None stored under "text",
        # which a plain .get default would not.
        text_preview = (result.get("text") or "")[:200]
        print(f" 文本预览: {text_preview}")

        # Stage 2: precise OCR field extraction.
        print("\n--- 阶段2: 字段精确提取 ---")
        target_fields = ["发票代码", "发票号码", "开票日期", "合计金额", "校验码"]
        extraction = extract_ocr_fields(img_path, target_fields)
        print(f" OCR可用: {extraction.get('ocr_available')}")
        print(f" 图片尺寸: {extraction.get('image_size')}")
        print(f" 元素总数: {extraction.get('total_elements')}")
        print(f" 错误: {extraction.get('errors')}")

        print("\n 提取结果:")
        all_ok = True
        for field in extraction.get("fields") or []:
            value = field["field_value"]
            if value:
                status = "PASS"
            else:
                status = "FAIL"
                all_ok = False
            # A string/float format spec applied to None raises TypeError,
            # so normalise before formatting; otherwise the report would
            # crash on exactly the fields it is meant to flag as FAIL.
            shown_value = "" if value is None else str(value)
            confidence = field["confidence"]
            shown_confidence = 0.0 if confidence is None else confidence
            print(f" [{status}] {field['field_name']!s:10s} = {shown_value:20s} "
                  f"方法={field['extraction_method']!s:12s} "
                  f"置信度={shown_confidence:.2f} "
                  f"bbox={field['bbox']}")
        if not all_ok:
            # Informational only: missing fields do not fail the test.
            print("\n 部分字段未提取到(可能因字体渲染差异)")

        # Independent check of the extractor class on a single field.
        print("\n--- 4种策略独立验证 ---")
        extractor = OcrExtractor()
        result_obj = extractor.extract(img_path, ["合计金额"])
        fields = result_obj.get("fields", [])
        if fields:
            print(f" 策略: {fields[0].get('extraction_method', 'N/A')}")
            print(f" 值: {fields[0].get('field_value', 'N/A')}")
            print(f" 坐标: {fields[0].get('bbox', 'N/A')}")

        # OCR field extraction is informational; verify we got a valid response
        assert extraction.get("ocr_available") is not None
    finally:
        # Remove the generated fixture even when an assertion or OCR call fails.
        Path(img_path).unlink(missing_ok=True)
def test_validation_service():
    """Connectivity check for the validation service.

    Passes a minimal ``<jasperReport/>`` document to ``validate_jrxml``,
    prints the outcome, and asserts that a non-None response came back.
    """
    print("\n=== 验证服务连通性测试 ===\n")
    response = validate_jrxml("<jasperReport/>")

    # The status label reflects the truthiness of the response.
    if response:
        status_label = "OK"
    else:
        status_label = "FAIL"
    print(f" 状态: {status_label}")
    print(f" 响应: {response}")

    assert response is not None
def test_ocr_fallback():
    """Check graceful degradation when the image file does not exist.

    Calls ``extract_ocr_fields`` with a non-existent path and asserts that
    the response reports OCR as unavailable and lists at least one error.
    """
    print("\n=== OCR降级测试 ===\n")

    missing_image = "/nonexistent/file.png"
    wanted_fields = ["发票代码"]
    outcome = extract_ocr_fields(missing_image, wanted_fields)

    print(f" OCR可用: {outcome.get('ocr_available')}")
    print(f" 错误: {outcome.get('errors')}")

    # The response must flag OCR as unavailable and carry an error entry.
    assert not outcome["ocr_available"]
    error_count = len(outcome["errors"])
    assert error_count > 0

    print(" [PASS] 降级行为正常(不阻断流程)")
if __name__ == "__main__":
    # Script entry point: run every check in order, collect failures, print
    # a summary, and exit with status 1 if anything failed.
    #
    # The test functions contain no return statement, so their return
    # values are not inspected here; success means "no exception raised".
    checks = [
        ("降级测试", test_ocr_fallback),
        ("流水线测试", test_ocr_extraction_pipeline),
        ("验证服务", test_validation_service),
    ]

    errors = []
    try:
        for label, check in checks:
            try:
                check()
            except Exception as exc:
                # A bare `assert` raises AssertionError with an empty
                # message; fall back to the exception type name so the
                # report never shows a blank failure line.
                detail = str(exc) or type(exc).__name__
                print(f"[FAIL] {label}: {detail}")
                errors.append(detail)
    finally:
        # Always remove the generated fixture image, even on interruption.
        Path("test_invoice_e2e.png").unlink(missing_ok=True)

    print("\n" + "=" * 50)
    if errors:
        print(f"测试完成,{len(errors)} 个错误:")
        for detail in errors:
            print(f" - {detail}")
        sys.exit(1)
    else:
        print("所有端到端测试通过!")