feat: v4 multimodal chat input, multi-format support, and annotation detection
- Replace st.chat_input with st-multimodal-chatinput (Ctrl+V paste, drag-drop, file button) - Extract _process_uploaded_file() shared handler (eliminates ~70 duplicated lines) - Add XLSX (openpyxl), XLS (xlrd), DOC (olefile) parsers to file_parser.py - Add backend/annotation_detector.py: circle detection (HoughCircles) + arrow detection (HoughLinesP clustering) + OCR correlation + LLM context formatting - Add annotation_result field to AgentState with session persistence - Wire annotation detection into process_input and _format_ocr_context - Add 11 new tests: 7 annotation detector + 4 multi-format parser - Update all docs: CLAUDE.md, README.md, CODE_GUIDE.md, ROADMAP.md
This commit is contained in:
@@ -0,0 +1,151 @@
|
||||
"""测试批注检测器:圆圈检测、箭头检测、OCR 关联、格式化。"""
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
def _draw_circle_image(path: str, size: tuple = (400, 300)) -> None:
|
||||
"""生成包含红色圆圈的合成测试图片。"""
|
||||
img = np.ones((size[1], size[0], 3), dtype=np.uint8) * 255
|
||||
cv2.circle(img, (200, 150), 50, (0, 0, 255), 2)
|
||||
cv2.imwrite(path, img)
|
||||
|
||||
|
||||
def _draw_arrow_image(path: str, size: tuple = (400, 300)) -> None:
|
||||
"""生成包含手绘风格箭头的合成测试图片(多段线模拟手绘)。"""
|
||||
img = np.ones((size[1], size[0], 3), dtype=np.uint8) * 255
|
||||
# 多段略微偏移的线段模拟手绘箭杆(产生多个 HoughLinesP 段)
|
||||
for offset_y in (-1, 0, 1):
|
||||
cv2.line(img, (50, 150 + offset_y), (200, 150 + offset_y), (0, 0, 255), 2)
|
||||
for offset_y in (-1, 0, 1):
|
||||
cv2.line(img, (200, 150 + offset_y), (340, 150 + offset_y), (0, 0, 255), 2)
|
||||
# 箭头三角形
|
||||
pts = np.array([[350, 150], [330, 135], [330, 165]], np.int32)
|
||||
cv2.fillPoly(img, [pts], (0, 0, 255))
|
||||
# 额外三角形边缘线
|
||||
cv2.line(img, (350, 150), (330, 135), (0, 0, 255), 2)
|
||||
cv2.line(img, (350, 150), (330, 165), (0, 0, 255), 2)
|
||||
cv2.imwrite(path, img)
|
||||
|
||||
|
||||
def _draw_circle_and_text_image(path: str, size: tuple = (500, 400)) -> None:
|
||||
"""生成包含红色圆圈和"文本"的合成图片(模拟圈选批注)。"""
|
||||
img = np.ones((size[1], size[0], 3), dtype=np.uint8) * 255
|
||||
cv2.circle(img, (250, 150), 60, (0, 0, 255), 3)
|
||||
cv2.putText(img, "项目A", (20, 160), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)
|
||||
cv2.imwrite(path, img)
|
||||
|
||||
|
||||
class TestAnnotationDetector:
|
||||
"""测试 annotation_detector.py 各功能。"""
|
||||
|
||||
def test_detect_circles_finds_circle(self):
|
||||
from backend.annotation_detector import detect_annotations
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
|
||||
path = tmp.name
|
||||
try:
|
||||
_draw_circle_image(path)
|
||||
ocr_elements = [
|
||||
{"text": "测试字段", "bbox": {"x": 170, "y": 120, "w": 60, "h": 20}, "confidence": 0.95},
|
||||
]
|
||||
result = detect_annotations(path, ocr_elements)
|
||||
assert result["total"] >= 1
|
||||
circles = result["circles"]
|
||||
assert len(circles) >= 1
|
||||
c = circles[0]
|
||||
assert c["type"] == "circle"
|
||||
assert "center" in c
|
||||
assert "bbox" in c
|
||||
assert "nearby_texts" in c
|
||||
finally:
|
||||
Path(path).unlink(missing_ok=True)
|
||||
|
||||
def test_detect_arrows_finds_arrow(self):
|
||||
from backend.annotation_detector import detect_annotations
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
|
||||
path = tmp.name
|
||||
try:
|
||||
_draw_arrow_image(path)
|
||||
ocr_elements = [
|
||||
{"text": "起点", "bbox": {"x": 30, "y": 130, "w": 40, "h": 20}, "confidence": 0.9},
|
||||
{"text": "终点", "bbox": {"x": 310, "y": 130, "w": 40, "h": 20}, "confidence": 0.9},
|
||||
]
|
||||
result = detect_annotations(path, ocr_elements)
|
||||
assert result["total"] >= 1
|
||||
arrows = result["arrows"]
|
||||
assert len(arrows) >= 1
|
||||
a = arrows[0]
|
||||
assert a["type"] == "arrow"
|
||||
assert "from_pt" in a
|
||||
assert "to_pt" in a
|
||||
finally:
|
||||
Path(path).unlink(missing_ok=True)
|
||||
|
||||
def test_correlate_with_ocr_links_nearby_texts(self):
|
||||
from backend.annotation_detector import detect_annotations
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
|
||||
path = tmp.name
|
||||
try:
|
||||
_draw_circle_and_text_image(path)
|
||||
ocr_elements = [
|
||||
{"text": "项目A", "bbox": {"x": 20, "y": 140, "w": 80, "h": 30}, "confidence": 0.98},
|
||||
{"text": "金额", "bbox": {"x": 350, "y": 200, "w": 50, "h": 20}, "confidence": 0.9},
|
||||
]
|
||||
result = detect_annotations(path, ocr_elements)
|
||||
circles = result["circles"]
|
||||
if circles:
|
||||
near = circles[0].get("nearby_texts", [])
|
||||
if near:
|
||||
assert "项目A" in near
|
||||
finally:
|
||||
Path(path).unlink(missing_ok=True)
|
||||
|
||||
def test_invalid_image_path(self):
|
||||
from backend.annotation_detector import detect_annotations
|
||||
|
||||
result = detect_annotations("/nonexistent/file.png", [])
|
||||
assert result["total"] == 0
|
||||
assert "error" in result
|
||||
|
||||
def test_format_annotation_context_empty(self):
|
||||
from backend.annotation_detector import format_annotation_context
|
||||
|
||||
assert format_annotation_context({}) == ""
|
||||
assert format_annotation_context(None) == ""
|
||||
assert format_annotation_context({"circles": [], "arrows": [], "total": 0}) == ""
|
||||
|
||||
def test_format_annotation_context_with_circles(self):
|
||||
from backend.annotation_detector import format_annotation_context
|
||||
|
||||
ann = {
|
||||
"circles": [
|
||||
{"center": [100, 200], "nearby_texts": ["项目A", "金额"]},
|
||||
],
|
||||
"arrows": [],
|
||||
"total": 1,
|
||||
}
|
||||
text = format_annotation_context(ann)
|
||||
assert "圈选标记" in text
|
||||
assert "项目A" in text
|
||||
|
||||
def test_format_annotation_context_with_arrows(self):
|
||||
from backend.annotation_detector import format_annotation_context
|
||||
|
||||
ann = {
|
||||
"circles": [],
|
||||
"arrows": [
|
||||
{"from_text": "修理号", "to_text": "车架号"},
|
||||
],
|
||||
"total": 1,
|
||||
}
|
||||
text = format_annotation_context(ann)
|
||||
assert "箭头标记" in text
|
||||
assert "修理号" in text
|
||||
assert "车架号" in text
|
||||
@@ -0,0 +1,143 @@
|
||||
"""端到端测试:OCR 字段精确提取完整流水线。
|
||||
|
||||
覆盖:
|
||||
1. PaddleOCR 精确识别(优先)
|
||||
2. EasyOCR 降级回退
|
||||
3. 4种提取策略
|
||||
4. 验证服务连通性
|
||||
"""
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
from backend.ocr_extractor import OcrExtractor, extract_ocr_fields
|
||||
from backend.file_parser import parse_file
|
||||
from backend.validation import validate_jrxml
|
||||
|
||||
|
||||
def create_test_invoice(path: str):
|
||||
"""创建一张模拟中文发票图片,包含已知字段。"""
|
||||
img = Image.new("RGB", (800, 600), color="white")
|
||||
draw = ImageDraw.Draw(img)
|
||||
|
||||
draw.text((300, 20), "增值税普通发票", fill="black")
|
||||
draw.text((300, 60), "发票代码: 1234567890", fill="black")
|
||||
draw.text((300, 100), "发票号码: 87654321", fill="black")
|
||||
draw.text((50, 160), "开票日期: 2024年1月15日", fill="black")
|
||||
draw.text((50, 200), "购买方名称: 测试公司", fill="black")
|
||||
draw.text((50, 240), "合计金额: 1,234.56", fill="black")
|
||||
draw.text((50, 280), "校验码: ABC12345678", fill="black")
|
||||
|
||||
draw.text((50, 350), "名称 数量 单价", fill="black")
|
||||
draw.text((50, 390), "商品A 2 10.00", fill="black")
|
||||
draw.text((50, 430), "商品B 5 20.00", fill="black")
|
||||
|
||||
img.save(path)
|
||||
print(f"[OK] 测试图片已创建: {path}")
|
||||
return path
|
||||
|
||||
|
||||
def test_ocr_extraction_pipeline():
|
||||
"""端到端测试:图片 -> OCR -> 字段提取。"""
|
||||
print("\n=== 端到端OCR字段提取测试 ===\n")
|
||||
|
||||
img_path = create_test_invoice("test_invoice_e2e.png")
|
||||
|
||||
# 阶段1: 文件解析(含OCR)
|
||||
print("\n--- 阶段1: 文件解析(OCR) ---")
|
||||
result = parse_file(img_path)
|
||||
method = result.get("method", "N/A")
|
||||
print(f" OCR方法: {method}")
|
||||
print(f" 文件类型: {result.get('file_type', 'N/A')}")
|
||||
text_preview = result.get("text", "")[:200]
|
||||
print(f" 文本预览: {text_preview}")
|
||||
|
||||
# 阶段2: OCR精确提取
|
||||
print("\n--- 阶段2: 字段精确提取 ---")
|
||||
target_fields = ["发票代码", "发票号码", "开票日期", "合计金额", "校验码"]
|
||||
extraction = extract_ocr_fields(img_path, target_fields)
|
||||
|
||||
print(f" OCR可用: {extraction.get('ocr_available')}")
|
||||
print(f" 图片尺寸: {extraction.get('image_size')}")
|
||||
print(f" 元素总数: {extraction.get('total_elements')}")
|
||||
print(f" 错误: {extraction.get('errors')}")
|
||||
|
||||
print("\n 提取结果:")
|
||||
all_ok = True
|
||||
for field in extraction.get("fields", []):
|
||||
status = "PASS" if field["field_value"] else "FAIL"
|
||||
if not field["field_value"]:
|
||||
all_ok = False
|
||||
print(f" [{status}] {field['field_name']:10s} = {field['field_value']:20s} "
|
||||
f"方法={field['extraction_method']:12s} 置信度={field['confidence']:.2f} "
|
||||
f"bbox={field['bbox']}")
|
||||
|
||||
# 策略独立验证
|
||||
print("\n--- 4种策略独立验证 ---")
|
||||
extractor = OcrExtractor()
|
||||
result_obj = extractor.extract(img_path, ["合计金额"])
|
||||
fields = result_obj.get("fields", [])
|
||||
if fields:
|
||||
print(f" 策略: {fields[0].get('extraction_method', 'N/A')}")
|
||||
print(f" 值: {fields[0].get('field_value', 'N/A')}")
|
||||
print(f" 坐标: {fields[0].get('bbox', 'N/A')}")
|
||||
|
||||
return all_ok
|
||||
|
||||
|
||||
def test_validation_service():
|
||||
"""测试验证服务连通性。"""
|
||||
print("\n=== 验证服务连通性测试 ===\n")
|
||||
result = validate_jrxml("<jasperReport/>")
|
||||
print(f" 状态: {'OK' if result else 'FAIL'}")
|
||||
print(f" 响应: {result}")
|
||||
return True
|
||||
|
||||
|
||||
def test_ocr_fallback():
|
||||
"""测试OCR回退:无图片时优雅降级。"""
|
||||
print("\n=== OCR降级测试 ===\n")
|
||||
result = extract_ocr_fields("/nonexistent/file.png", ["发票代码"])
|
||||
print(f" OCR可用: {result.get('ocr_available')}")
|
||||
print(f" 错误: {result.get('errors')}")
|
||||
assert not result["ocr_available"]
|
||||
assert len(result["errors"]) > 0
|
||||
print(" [PASS] 降级行为正常(不阻断流程)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
errors = []
|
||||
|
||||
try:
|
||||
test_ocr_fallback()
|
||||
except Exception as e:
|
||||
print(f"[FAIL] 降级测试: {e}")
|
||||
errors.append(str(e))
|
||||
|
||||
try:
|
||||
ok = test_ocr_extraction_pipeline()
|
||||
if not ok:
|
||||
print("\n 部分字段未提取到(可能因字体渲染差异)")
|
||||
except Exception as e:
|
||||
print(f"[FAIL] 流水线测试: {e}")
|
||||
errors.append(str(e))
|
||||
|
||||
try:
|
||||
test_validation_service()
|
||||
except Exception as e:
|
||||
print(f"[FAIL] 验证服务: {e}")
|
||||
errors.append(str(e))
|
||||
|
||||
Path("test_invoice_e2e.png").unlink(missing_ok=True)
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
if errors:
|
||||
print(f"测试完成,{len(errors)} 个错误:")
|
||||
for e in errors:
|
||||
print(f" - {e}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print("所有端到端测试通过!")
|
||||
@@ -0,0 +1,90 @@
|
||||
"""测试多格式文件解析器:XLSX, XLS, DOC。"""
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _make_xlsx(path: str) -> None:
|
||||
"""生成最小 .xlsx 测试文件。"""
|
||||
from openpyxl import Workbook
|
||||
wb = Workbook()
|
||||
ws = wb.active
|
||||
ws.title = "Sheet1"
|
||||
ws["A1"] = "名称"
|
||||
ws["B1"] = "金额"
|
||||
ws["A2"] = "项目A"
|
||||
ws["B2"] = 100
|
||||
ws["A3"] = "项目B"
|
||||
ws["B3"] = 200
|
||||
wb.save(path)
|
||||
|
||||
|
||||
def _make_xls(path: str) -> None:
|
||||
"""生成最小 .xls 测试文件。"""
|
||||
from xlwt import Workbook
|
||||
wb = Workbook()
|
||||
ws = wb.add_sheet("Sheet1")
|
||||
ws.write(0, 0, "名称")
|
||||
ws.write(0, 1, "金额")
|
||||
ws.write(1, 0, "项目A")
|
||||
ws.write(1, 1, 100)
|
||||
ws.write(2, 0, "项目B")
|
||||
ws.write(2, 1, 200)
|
||||
wb.save(path)
|
||||
|
||||
|
||||
class TestMultiFormatParsers:
|
||||
"""测试 file_parser.py 的多格式解析器。"""
|
||||
|
||||
def test_parse_xlsx(self):
|
||||
from backend.file_parser import parse_file
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".xlsx", delete=False) as tmp:
|
||||
path = tmp.name
|
||||
try:
|
||||
_make_xlsx(path)
|
||||
result = parse_file(path, ".xlsx")
|
||||
assert result["file_type"] == "xlsx"
|
||||
assert result["method"] == "openpyxl"
|
||||
assert result["error"] is None
|
||||
assert "Sheet1" in result["text"]
|
||||
assert "项目A" in result["text"]
|
||||
assert "100" in result["text"]
|
||||
finally:
|
||||
Path(path).unlink(missing_ok=True)
|
||||
|
||||
def test_parse_xls(self):
|
||||
from backend.file_parser import parse_file
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".xls", delete=False) as tmp:
|
||||
path = tmp.name
|
||||
try:
|
||||
_make_xls(path)
|
||||
result = parse_file(path, ".xls")
|
||||
assert result["file_type"] == "xls"
|
||||
assert result["method"] == "xlrd"
|
||||
assert result["error"] is None
|
||||
assert "Sheet1" in result["text"]
|
||||
assert "项目A" in result["text"]
|
||||
assert "100.0" in result["text"]
|
||||
finally:
|
||||
Path(path).unlink(missing_ok=True)
|
||||
|
||||
def test_parse_doc_nonexistent(self):
|
||||
"""测试 .doc 文件不存在时的错误处理。"""
|
||||
from backend.file_parser import parse_file
|
||||
|
||||
result = parse_file("/nonexistent/file.doc", ".doc")
|
||||
assert result["file_type"] == ".doc"
|
||||
assert result["method"] == "none"
|
||||
assert result.get("error") is not None
|
||||
|
||||
def test_dispatch_adds_new_formats(self):
|
||||
"""验证新格式已在 parse_file 调度表中注册。"""
|
||||
from backend.file_parser import parse_file
|
||||
|
||||
for ext in [".xlsx", ".xls", ".doc"]:
|
||||
result = parse_file("/tmp/test" + ext, ext)
|
||||
assert result["file_type"] in (ext, "xlsx", "xls", "doc")
|
||||
Reference in New Issue
Block a user