feat: v4 multimodal chat input, multi-format support, and annotation detection

- Replace st.chat_input with st-multimodal-chatinput (Ctrl+V paste, drag-drop, file button)
- Extract _process_uploaded_file() shared handler (eliminates ~70 duplicated lines)
- Add XLSX (openpyxl), XLS (xlrd), DOC (olefile) parsers to file_parser.py
- Add backend/annotation_detector.py: circle detection (HoughCircles) + arrow detection (HoughLinesP clustering) + OCR correlation + LLM context formatting
- Add annotation_result field to AgentState with session persistence
- Wire annotation detection into process_input and _format_ocr_context
- Add 11 new tests: 7 annotation detector + 4 multi-format parser
- Update all docs: CLAUDE.md, README.md, CODE_GUIDE.md, ROADMAP.md
This commit is contained in:
2026-05-20 23:43:16 +08:00
parent c9f003e1b7
commit 9bb011e429
16 changed files with 1257 additions and 164 deletions
+151
View File
@@ -0,0 +1,151 @@
"""测试批注检测器:圆圈检测、箭头检测、OCR 关联、格式化。"""
import tempfile
from pathlib import Path
import cv2
import numpy as np
import pytest
def _draw_circle_image(path: str, size: tuple = (400, 300)) -> None:
"""生成包含红色圆圈的合成测试图片。"""
img = np.ones((size[1], size[0], 3), dtype=np.uint8) * 255
cv2.circle(img, (200, 150), 50, (0, 0, 255), 2)
cv2.imwrite(path, img)
def _draw_arrow_image(path: str, size: tuple = (400, 300)) -> None:
"""生成包含手绘风格箭头的合成测试图片(多段线模拟手绘)。"""
img = np.ones((size[1], size[0], 3), dtype=np.uint8) * 255
# 多段略微偏移的线段模拟手绘箭杆(产生多个 HoughLinesP 段)
for offset_y in (-1, 0, 1):
cv2.line(img, (50, 150 + offset_y), (200, 150 + offset_y), (0, 0, 255), 2)
for offset_y in (-1, 0, 1):
cv2.line(img, (200, 150 + offset_y), (340, 150 + offset_y), (0, 0, 255), 2)
# 箭头三角形
pts = np.array([[350, 150], [330, 135], [330, 165]], np.int32)
cv2.fillPoly(img, [pts], (0, 0, 255))
# 额外三角形边缘线
cv2.line(img, (350, 150), (330, 135), (0, 0, 255), 2)
cv2.line(img, (350, 150), (330, 165), (0, 0, 255), 2)
cv2.imwrite(path, img)
def _draw_circle_and_text_image(path: str, size: tuple = (500, 400)) -> None:
"""生成包含红色圆圈和"文本"的合成图片(模拟圈选批注)。"""
img = np.ones((size[1], size[0], 3), dtype=np.uint8) * 255
cv2.circle(img, (250, 150), 60, (0, 0, 255), 3)
cv2.putText(img, "项目A", (20, 160), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)
cv2.imwrite(path, img)
class TestAnnotationDetector:
"""测试 annotation_detector.py 各功能。"""
def test_detect_circles_finds_circle(self):
from backend.annotation_detector import detect_annotations
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
path = tmp.name
try:
_draw_circle_image(path)
ocr_elements = [
{"text": "测试字段", "bbox": {"x": 170, "y": 120, "w": 60, "h": 20}, "confidence": 0.95},
]
result = detect_annotations(path, ocr_elements)
assert result["total"] >= 1
circles = result["circles"]
assert len(circles) >= 1
c = circles[0]
assert c["type"] == "circle"
assert "center" in c
assert "bbox" in c
assert "nearby_texts" in c
finally:
Path(path).unlink(missing_ok=True)
def test_detect_arrows_finds_arrow(self):
from backend.annotation_detector import detect_annotations
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
path = tmp.name
try:
_draw_arrow_image(path)
ocr_elements = [
{"text": "起点", "bbox": {"x": 30, "y": 130, "w": 40, "h": 20}, "confidence": 0.9},
{"text": "终点", "bbox": {"x": 310, "y": 130, "w": 40, "h": 20}, "confidence": 0.9},
]
result = detect_annotations(path, ocr_elements)
assert result["total"] >= 1
arrows = result["arrows"]
assert len(arrows) >= 1
a = arrows[0]
assert a["type"] == "arrow"
assert "from_pt" in a
assert "to_pt" in a
finally:
Path(path).unlink(missing_ok=True)
def test_correlate_with_ocr_links_nearby_texts(self):
from backend.annotation_detector import detect_annotations
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
path = tmp.name
try:
_draw_circle_and_text_image(path)
ocr_elements = [
{"text": "项目A", "bbox": {"x": 20, "y": 140, "w": 80, "h": 30}, "confidence": 0.98},
{"text": "金额", "bbox": {"x": 350, "y": 200, "w": 50, "h": 20}, "confidence": 0.9},
]
result = detect_annotations(path, ocr_elements)
circles = result["circles"]
if circles:
near = circles[0].get("nearby_texts", [])
if near:
assert "项目A" in near
finally:
Path(path).unlink(missing_ok=True)
def test_invalid_image_path(self):
from backend.annotation_detector import detect_annotations
result = detect_annotations("/nonexistent/file.png", [])
assert result["total"] == 0
assert "error" in result
def test_format_annotation_context_empty(self):
from backend.annotation_detector import format_annotation_context
assert format_annotation_context({}) == ""
assert format_annotation_context(None) == ""
assert format_annotation_context({"circles": [], "arrows": [], "total": 0}) == ""
def test_format_annotation_context_with_circles(self):
from backend.annotation_detector import format_annotation_context
ann = {
"circles": [
{"center": [100, 200], "nearby_texts": ["项目A", "金额"]},
],
"arrows": [],
"total": 1,
}
text = format_annotation_context(ann)
assert "圈选标记" in text
assert "项目A" in text
def test_format_annotation_context_with_arrows(self):
from backend.annotation_detector import format_annotation_context
ann = {
"circles": [],
"arrows": [
{"from_text": "修理号", "to_text": "车架号"},
],
"total": 1,
}
text = format_annotation_context(ann)
assert "箭头标记" in text
assert "修理号" in text
assert "车架号" in text
+143
View File
@@ -0,0 +1,143 @@
"""端到端测试:OCR 字段精确提取完整流水线。
覆盖:
1. PaddleOCR 精确识别(优先)
2. EasyOCR 降级回退
3. 4种提取策略
4. 验证服务连通性
"""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from PIL import Image, ImageDraw
from backend.ocr_extractor import OcrExtractor, extract_ocr_fields
from backend.file_parser import parse_file
from backend.validation import validate_jrxml
def create_test_invoice(path: str):
"""创建一张模拟中文发票图片,包含已知字段。"""
img = Image.new("RGB", (800, 600), color="white")
draw = ImageDraw.Draw(img)
draw.text((300, 20), "增值税普通发票", fill="black")
draw.text((300, 60), "发票代码: 1234567890", fill="black")
draw.text((300, 100), "发票号码: 87654321", fill="black")
draw.text((50, 160), "开票日期: 2024年1月15日", fill="black")
draw.text((50, 200), "购买方名称: 测试公司", fill="black")
draw.text((50, 240), "合计金额: 1,234.56", fill="black")
draw.text((50, 280), "校验码: ABC12345678", fill="black")
draw.text((50, 350), "名称 数量 单价", fill="black")
draw.text((50, 390), "商品A 2 10.00", fill="black")
draw.text((50, 430), "商品B 5 20.00", fill="black")
img.save(path)
print(f"[OK] 测试图片已创建: {path}")
return path
def test_ocr_extraction_pipeline():
"""端到端测试:图片 -> OCR -> 字段提取。"""
print("\n=== 端到端OCR字段提取测试 ===\n")
img_path = create_test_invoice("test_invoice_e2e.png")
# 阶段1: 文件解析(含OCR
print("\n--- 阶段1: 文件解析(OCR) ---")
result = parse_file(img_path)
method = result.get("method", "N/A")
print(f" OCR方法: {method}")
print(f" 文件类型: {result.get('file_type', 'N/A')}")
text_preview = result.get("text", "")[:200]
print(f" 文本预览: {text_preview}")
# 阶段2: OCR精确提取
print("\n--- 阶段2: 字段精确提取 ---")
target_fields = ["发票代码", "发票号码", "开票日期", "合计金额", "校验码"]
extraction = extract_ocr_fields(img_path, target_fields)
print(f" OCR可用: {extraction.get('ocr_available')}")
print(f" 图片尺寸: {extraction.get('image_size')}")
print(f" 元素总数: {extraction.get('total_elements')}")
print(f" 错误: {extraction.get('errors')}")
print("\n 提取结果:")
all_ok = True
for field in extraction.get("fields", []):
status = "PASS" if field["field_value"] else "FAIL"
if not field["field_value"]:
all_ok = False
print(f" [{status}] {field['field_name']:10s} = {field['field_value']:20s} "
f"方法={field['extraction_method']:12s} 置信度={field['confidence']:.2f} "
f"bbox={field['bbox']}")
# 策略独立验证
print("\n--- 4种策略独立验证 ---")
extractor = OcrExtractor()
result_obj = extractor.extract(img_path, ["合计金额"])
fields = result_obj.get("fields", [])
if fields:
print(f" 策略: {fields[0].get('extraction_method', 'N/A')}")
print(f" 值: {fields[0].get('field_value', 'N/A')}")
print(f" 坐标: {fields[0].get('bbox', 'N/A')}")
return all_ok
def test_validation_service():
"""测试验证服务连通性。"""
print("\n=== 验证服务连通性测试 ===\n")
result = validate_jrxml("<jasperReport/>")
print(f" 状态: {'OK' if result else 'FAIL'}")
print(f" 响应: {result}")
return True
def test_ocr_fallback():
"""测试OCR回退:无图片时优雅降级。"""
print("\n=== OCR降级测试 ===\n")
result = extract_ocr_fields("/nonexistent/file.png", ["发票代码"])
print(f" OCR可用: {result.get('ocr_available')}")
print(f" 错误: {result.get('errors')}")
assert not result["ocr_available"]
assert len(result["errors"]) > 0
print(" [PASS] 降级行为正常(不阻断流程)")
if __name__ == "__main__":
errors = []
try:
test_ocr_fallback()
except Exception as e:
print(f"[FAIL] 降级测试: {e}")
errors.append(str(e))
try:
ok = test_ocr_extraction_pipeline()
if not ok:
print("\n 部分字段未提取到(可能因字体渲染差异)")
except Exception as e:
print(f"[FAIL] 流水线测试: {e}")
errors.append(str(e))
try:
test_validation_service()
except Exception as e:
print(f"[FAIL] 验证服务: {e}")
errors.append(str(e))
Path("test_invoice_e2e.png").unlink(missing_ok=True)
print("\n" + "=" * 50)
if errors:
print(f"测试完成,{len(errors)} 个错误:")
for e in errors:
print(f" - {e}")
sys.exit(1)
else:
print("所有端到端测试通过!")
+90
View File
@@ -0,0 +1,90 @@
"""测试多格式文件解析器:XLSX, XLS, DOC。"""
import tempfile
from pathlib import Path
import pytest
def _make_xlsx(path: str) -> None:
"""生成最小 .xlsx 测试文件。"""
from openpyxl import Workbook
wb = Workbook()
ws = wb.active
ws.title = "Sheet1"
ws["A1"] = "名称"
ws["B1"] = "金额"
ws["A2"] = "项目A"
ws["B2"] = 100
ws["A3"] = "项目B"
ws["B3"] = 200
wb.save(path)
def _make_xls(path: str) -> None:
"""生成最小 .xls 测试文件。"""
from xlwt import Workbook
wb = Workbook()
ws = wb.add_sheet("Sheet1")
ws.write(0, 0, "名称")
ws.write(0, 1, "金额")
ws.write(1, 0, "项目A")
ws.write(1, 1, 100)
ws.write(2, 0, "项目B")
ws.write(2, 1, 200)
wb.save(path)
class TestMultiFormatParsers:
"""测试 file_parser.py 的多格式解析器。"""
def test_parse_xlsx(self):
from backend.file_parser import parse_file
with tempfile.NamedTemporaryFile(suffix=".xlsx", delete=False) as tmp:
path = tmp.name
try:
_make_xlsx(path)
result = parse_file(path, ".xlsx")
assert result["file_type"] == "xlsx"
assert result["method"] == "openpyxl"
assert result["error"] is None
assert "Sheet1" in result["text"]
assert "项目A" in result["text"]
assert "100" in result["text"]
finally:
Path(path).unlink(missing_ok=True)
def test_parse_xls(self):
from backend.file_parser import parse_file
with tempfile.NamedTemporaryFile(suffix=".xls", delete=False) as tmp:
path = tmp.name
try:
_make_xls(path)
result = parse_file(path, ".xls")
assert result["file_type"] == "xls"
assert result["method"] == "xlrd"
assert result["error"] is None
assert "Sheet1" in result["text"]
assert "项目A" in result["text"]
assert "100.0" in result["text"]
finally:
Path(path).unlink(missing_ok=True)
def test_parse_doc_nonexistent(self):
"""测试 .doc 文件不存在时的错误处理。"""
from backend.file_parser import parse_file
result = parse_file("/nonexistent/file.doc", ".doc")
assert result["file_type"] == ".doc"
assert result["method"] == "none"
assert result.get("error") is not None
def test_dispatch_adds_new_formats(self):
"""验证新格式已在 parse_file 调度表中注册。"""
from backend.file_parser import parse_file
for ext in [".xlsx", ".xls", ".doc"]:
result = parse_file("/tmp/test" + ext, ext)
assert result["file_type"] in (ext, "xlsx", "xls", "doc")