9bb011e429
- Replace st.chat_input with st-multimodal-chatinput (Ctrl+V paste, drag-drop, file button) - Extract _process_uploaded_file() shared handler (eliminates ~70 duplicated lines) - Add XLSX (openpyxl), XLS (xlrd), DOC (olefile) parsers to file_parser.py - Add backend/annotation_detector.py: circle detection (HoughCircles) + arrow detection (HoughLinesP clustering) + OCR correlation + LLM context formatting - Add annotation_result field to AgentState with session persistence - Wire annotation detection into process_input and _format_ocr_context - Add 11 new tests: 7 annotation detector + 4 multi-format parser - Update all docs: CLAUDE.md, README.md, CODE_GUIDE.md, ROADMAP.md
144 lines
4.9 KiB
Python
144 lines
4.9 KiB
Python
"""端到端测试:OCR 字段精确提取完整流水线。
|
||
|
||
覆盖:
|
||
1. PaddleOCR 精确识别(优先)
|
||
2. EasyOCR 降级回退
|
||
3. 4种提取策略
|
||
4. 验证服务连通性
|
||
"""
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||
|
||
from PIL import Image, ImageDraw
|
||
|
||
from backend.ocr_extractor import OcrExtractor, extract_ocr_fields
|
||
from backend.file_parser import parse_file
|
||
from backend.validation import validate_jrxml
|
||
|
||
|
||
def create_test_invoice(path: str):
|
||
"""创建一张模拟中文发票图片,包含已知字段。"""
|
||
img = Image.new("RGB", (800, 600), color="white")
|
||
draw = ImageDraw.Draw(img)
|
||
|
||
draw.text((300, 20), "增值税普通发票", fill="black")
|
||
draw.text((300, 60), "发票代码: 1234567890", fill="black")
|
||
draw.text((300, 100), "发票号码: 87654321", fill="black")
|
||
draw.text((50, 160), "开票日期: 2024年1月15日", fill="black")
|
||
draw.text((50, 200), "购买方名称: 测试公司", fill="black")
|
||
draw.text((50, 240), "合计金额: 1,234.56", fill="black")
|
||
draw.text((50, 280), "校验码: ABC12345678", fill="black")
|
||
|
||
draw.text((50, 350), "名称 数量 单价", fill="black")
|
||
draw.text((50, 390), "商品A 2 10.00", fill="black")
|
||
draw.text((50, 430), "商品B 5 20.00", fill="black")
|
||
|
||
img.save(path)
|
||
print(f"[OK] 测试图片已创建: {path}")
|
||
return path
|
||
|
||
|
||
def test_ocr_extraction_pipeline():
|
||
"""端到端测试:图片 -> OCR -> 字段提取。"""
|
||
print("\n=== 端到端OCR字段提取测试 ===\n")
|
||
|
||
img_path = create_test_invoice("test_invoice_e2e.png")
|
||
|
||
# 阶段1: 文件解析(含OCR)
|
||
print("\n--- 阶段1: 文件解析(OCR) ---")
|
||
result = parse_file(img_path)
|
||
method = result.get("method", "N/A")
|
||
print(f" OCR方法: {method}")
|
||
print(f" 文件类型: {result.get('file_type', 'N/A')}")
|
||
text_preview = result.get("text", "")[:200]
|
||
print(f" 文本预览: {text_preview}")
|
||
|
||
# 阶段2: OCR精确提取
|
||
print("\n--- 阶段2: 字段精确提取 ---")
|
||
target_fields = ["发票代码", "发票号码", "开票日期", "合计金额", "校验码"]
|
||
extraction = extract_ocr_fields(img_path, target_fields)
|
||
|
||
print(f" OCR可用: {extraction.get('ocr_available')}")
|
||
print(f" 图片尺寸: {extraction.get('image_size')}")
|
||
print(f" 元素总数: {extraction.get('total_elements')}")
|
||
print(f" 错误: {extraction.get('errors')}")
|
||
|
||
print("\n 提取结果:")
|
||
all_ok = True
|
||
for field in extraction.get("fields", []):
|
||
status = "PASS" if field["field_value"] else "FAIL"
|
||
if not field["field_value"]:
|
||
all_ok = False
|
||
print(f" [{status}] {field['field_name']:10s} = {field['field_value']:20s} "
|
||
f"方法={field['extraction_method']:12s} 置信度={field['confidence']:.2f} "
|
||
f"bbox={field['bbox']}")
|
||
|
||
# 策略独立验证
|
||
print("\n--- 4种策略独立验证 ---")
|
||
extractor = OcrExtractor()
|
||
result_obj = extractor.extract(img_path, ["合计金额"])
|
||
fields = result_obj.get("fields", [])
|
||
if fields:
|
||
print(f" 策略: {fields[0].get('extraction_method', 'N/A')}")
|
||
print(f" 值: {fields[0].get('field_value', 'N/A')}")
|
||
print(f" 坐标: {fields[0].get('bbox', 'N/A')}")
|
||
|
||
return all_ok
|
||
|
||
|
||
def test_validation_service():
|
||
"""测试验证服务连通性。"""
|
||
print("\n=== 验证服务连通性测试 ===\n")
|
||
result = validate_jrxml("<jasperReport/>")
|
||
print(f" 状态: {'OK' if result else 'FAIL'}")
|
||
print(f" 响应: {result}")
|
||
return True
|
||
|
||
|
||
def test_ocr_fallback():
|
||
"""测试OCR回退:无图片时优雅降级。"""
|
||
print("\n=== OCR降级测试 ===\n")
|
||
result = extract_ocr_fields("/nonexistent/file.png", ["发票代码"])
|
||
print(f" OCR可用: {result.get('ocr_available')}")
|
||
print(f" 错误: {result.get('errors')}")
|
||
assert not result["ocr_available"]
|
||
assert len(result["errors"]) > 0
|
||
print(" [PASS] 降级行为正常(不阻断流程)")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
errors = []
|
||
|
||
try:
|
||
test_ocr_fallback()
|
||
except Exception as e:
|
||
print(f"[FAIL] 降级测试: {e}")
|
||
errors.append(str(e))
|
||
|
||
try:
|
||
ok = test_ocr_extraction_pipeline()
|
||
if not ok:
|
||
print("\n 部分字段未提取到(可能因字体渲染差异)")
|
||
except Exception as e:
|
||
print(f"[FAIL] 流水线测试: {e}")
|
||
errors.append(str(e))
|
||
|
||
try:
|
||
test_validation_service()
|
||
except Exception as e:
|
||
print(f"[FAIL] 验证服务: {e}")
|
||
errors.append(str(e))
|
||
|
||
Path("test_invoice_e2e.png").unlink(missing_ok=True)
|
||
|
||
print("\n" + "=" * 50)
|
||
if errors:
|
||
print(f"测试完成,{len(errors)} 个错误:")
|
||
for e in errors:
|
||
print(f" - {e}")
|
||
sys.exit(1)
|
||
else:
|
||
print("所有端到端测试通过!")
|