Files
agent_jrxml/tests/test_e2e_ocr.py
T
panda 9bb011e429 feat: v4 multimodal chat input, multi-format support, and annotation detection
- Replace st.chat_input with st-multimodal-chatinput (Ctrl+V paste, drag-drop, file button)
- Extract _process_uploaded_file() shared handler (eliminates ~70 duplicated lines)
- Add XLSX (openpyxl), XLS (xlrd), DOC (olefile) parsers to file_parser.py
- Add backend/annotation_detector.py: circle detection (HoughCircles) + arrow detection (HoughLinesP clustering) + OCR correlation + LLM context formatting
- Add annotation_result field to AgentState with session persistence
- Wire annotation detection into process_input and _format_ocr_context
- Add 11 new tests: 7 annotation detector + 4 multi-format parser
- Update all docs: CLAUDE.md, README.md, CODE_GUIDE.md, ROADMAP.md
2026-05-20 23:43:16 +08:00

144 lines
4.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""End-to-end tests: full pipeline for precise OCR field extraction.

Covers:
1. PaddleOCR precise recognition (preferred)
2. EasyOCR fallback
3. Four extraction strategies
4. Validation-service connectivity

The module can be run directly as a script: the ``__main__`` block at the
bottom runs the three checks and exits with status 1 if any of them raised.
Only ``test_ocr_fallback`` asserts; the other checks report their outcome
through printed output and return values.
"""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from PIL import Image, ImageDraw
from backend.ocr_extractor import OcrExtractor, extract_ocr_fields
from backend.file_parser import parse_file
from backend.validation import validate_jrxml
def create_test_invoice(
    path: str, *, font_path: "str | None" = None, font_size: int = 20
) -> str:
    """Render a synthetic Chinese invoice image containing known field values.

    The image is an 800x600 white RGB canvas with black text at fixed
    positions. This gives the OCR stages deterministic ground truth to extract.

    Args:
        path: Destination file path. Pillow picks the image format from the
            extension.
        font_path: Optional path to a TrueType/OpenType font that contains
            CJK glyphs. When omitted, Pillow's built-in default font is used
            (the original behaviour). That font is not expected to contain
            Chinese glyphs, and older Pillow versions may raise on non-Latin-1
            text. Without a CJK font the labels such as "发票代码" are
            therefore unlikely to be rendered legibly, and label-based
            extraction will fail.
        font_size: Point size used when ``font_path`` is given. It is ignored
            otherwise.

    Returns:
        ``path`` unchanged, so the call can be chained.
    """
    font = None
    if font_path is not None:
        # Imported lazily: only needed when a custom font is requested.
        from PIL import ImageFont

        font = ImageFont.truetype(font_path, font_size)

    # (position, text) pairs. The field values are the ground truth that the
    # extraction test looks for.
    layout = [
        ((300, 20), "增值税普通发票"),
        ((300, 60), "发票代码: 1234567890"),
        ((300, 100), "发票号码: 87654321"),
        ((50, 160), "开票日期: 2024年1月15日"),
        ((50, 200), "购买方名称: 测试公司"),
        ((50, 240), "合计金额: 1,234.56"),
        ((50, 280), "校验码: ABC12345678"),
        ((50, 350), "名称 数量 单价"),
        ((50, 390), "商品A 2 10.00"),
        ((50, 430), "商品B 5 20.00"),
    ]
    img = Image.new("RGB", (800, 600), color="white")
    draw = ImageDraw.Draw(img)
    for xy, text in layout:
        # font=None makes Pillow fall back to its default font.
        draw.text(xy, text, fill="black", font=font)
    img.save(path)
    print(f"[OK] 测试图片已创建: {path}")
    return path
def _format_field_line(field: dict) -> str:
    """Render one extracted-field record as a PASS/FAIL report line.

    Values are converted with ``str()`` before width specs are applied, because
    ``format(None, "20s")`` raises TypeError and ``format(5, "20s")`` raises
    ValueError. A field that failed to extract presumably carries ``None``. A
    missing confidence is shown as 0.00.
    """
    value = field.get("field_value")
    status = "PASS" if value else "FAIL"
    shown_value = "" if value is None else str(value)
    name = field.get("field_name")
    shown_name = "" if name is None else str(name)
    method = field.get("extraction_method")
    shown_method = "" if method is None else str(method)
    confidence = field.get("confidence")
    if confidence is None:
        confidence = 0.0
    return (
        f" [{status}] {shown_name:10s} = {shown_value:20s} "
        f"方法={shown_method:12s} 置信度={confidence:.2f} "
        f"bbox={field.get('bbox')}"
    )


def test_ocr_extraction_pipeline():
    """End-to-end check: image -> file parsing (OCR) -> precise field extraction.

    The function works in four steps:

    1. Render a synthetic invoice.
    2. Run it through ``parse_file`` and print a preview of the result.
    3. Run ``extract_ocr_fields`` and print a per-field report.
    4. Make one direct ``OcrExtractor.extract`` call.

    The image is always deleted afterwards.

    Returns:
        True only if at least one field record came back and every record has
        a truthy ``field_value``. False otherwise. Missing values are reported
        rather than asserted, because the ``__main__`` runner treats them as a
        soft warning (font rendering differences can prevent extraction).
    """
    print("\n=== 端到端OCR字段提取测试 ===\n")
    img_path = create_test_invoice("test_invoice_e2e.png")
    try:
        # Stage 1: generic file parsing, which includes an OCR pass.
        print("\n--- 阶段1: 文件解析(OCR) ---")
        parsed = parse_file(img_path)
        print(f" OCR方法: {parsed.get('method', 'N/A')}")
        print(f" 文件类型: {parsed.get('file_type', 'N/A')}")
        # `or ""` also covers an explicit None stored under "text".
        text_preview = (parsed.get("text") or "")[:200]
        print(f" 文本预览: {text_preview}")

        # Stage 2: targeted extraction of the known invoice fields.
        print("\n--- 阶段2: 字段精确提取 ---")
        target_fields = ["发票代码", "发票号码", "开票日期", "合计金额", "校验码"]
        extraction = extract_ocr_fields(img_path, target_fields)
        print(f" OCR可用: {extraction.get('ocr_available')}")
        print(f" 图片尺寸: {extraction.get('image_size')}")
        print(f" 元素总数: {extraction.get('total_elements')}")
        print(f" 错误: {extraction.get('errors')}")
        print("\n 提取结果:")
        fields = extraction.get("fields") or []
        for field in fields:
            print(_format_field_line(field))
        # An empty result must not count as success; all() of [] is True.
        all_ok = bool(fields) and all(field.get("field_value") for field in fields)

        # Direct extractor call for a single field.
        # NOTE(review): despite the "4种策略" heading, this runs one extraction
        # for one field and prints whichever strategy it used. It does not
        # exercise each strategy separately.
        print("\n--- 4种策略独立验证 ---")
        extractor = OcrExtractor()
        direct = extractor.extract(img_path, ["合计金额"])
        direct_fields = direct.get("fields", [])
        if direct_fields:
            first = direct_fields[0]
            print(f" 策略: {first.get('extraction_method', 'N/A')}")
            print(f" 值: {first.get('field_value', 'N/A')}")
            print(f" 坐标: {first.get('bbox', 'N/A')}")
        return all_ok
    finally:
        # Best-effort cleanup, so the image does not leak when this function
        # is run by a test runner instead of the __main__ block.
        try:
            Path(img_path).unlink(missing_ok=True)
        except OSError as exc:
            print(f" [WARN] 未能删除测试图片 {img_path}: {exc}")
def test_validation_service():
    """Smoke-test the JRXML validation service with a minimal document.

    Prints "OK" or "FAIL" depending on whether the response is truthy, then
    prints the raw response. The function never asserts on the response and
    always returns True, so it only fails if ``validate_jrxml`` itself raises.
    """
    print("\n=== 验证服务连通性测试 ===\n")
    response = validate_jrxml("<jasperReport/>")
    if response:
        verdict = "OK"
    else:
        verdict = "FAIL"
    print(f" 状态: {verdict}")
    print(f" 响应: {response}")
    return True
def test_ocr_fallback():
    """Check graceful degradation when the input image does not exist.

    ``extract_ocr_fields`` must return a result instead of raising. That result
    has a falsy ``ocr_available`` and at least one entry in ``errors``.
    """
    print("\n=== OCR降级测试 ===\n")
    missing_image = "/nonexistent/file.png"
    outcome = extract_ocr_fields(missing_image, ["发票代码"])
    available = outcome.get("ocr_available")
    problems = outcome.get("errors")
    print(f" OCR可用: {available}")
    print(f" 错误: {problems}")
    # Degradation contract: OCR reported unavailable, with an explanation.
    assert not outcome["ocr_available"]
    assert len(outcome["errors"]) > 0
    print(" [PASS] 降级行为正常(不阻断流程)")
if __name__ == "__main__":
    # Each entry is (label, test function, soft warning). The soft warning is
    # printed when the function returns a falsy value without raising. Only
    # the pipeline test returns a meaningful pass/partial flag.
    cases = [
        ("降级测试", test_ocr_fallback, None),
        (
            "流水线测试",
            test_ocr_extraction_pipeline,
            "\n 部分字段未提取到(可能因字体渲染差异)",
        ),
        ("验证服务", test_validation_service, None),
    ]
    errors = []
    for label, case, soft_warning in cases:
        try:
            outcome = case()
        except Exception as exc:  # Top-level runner: record and keep going.
            # Include the exception type. A bare `assert` raises an
            # AssertionError with an empty message, which would otherwise
            # leave an empty, unattributed line in the summary.
            detail = f"{type(exc).__name__}: {exc}"
            print(f"[FAIL] {label}: {detail}")
            errors.append(f"{label}: {detail}")
            continue
        if soft_warning is not None and not outcome:
            print(soft_warning)

    # Best-effort removal of the image created by the pipeline test. A
    # failure here should not prevent the summary from being printed.
    try:
        Path("test_invoice_e2e.png").unlink(missing_ok=True)
    except OSError as exc:
        print(f"[WARN] 未能删除测试图片: {exc}")

    print("\n" + "=" * 50)
    if errors:
        print(f"测试完成,{len(errors)} 个错误:")
        for message in errors:
            print(f" - {message}")
        sys.exit(1)
    print("所有端到端测试通过!")