9bb011e429
- Replace st.chat_input with st-multimodal-chatinput (Ctrl+V paste, drag-drop, file button) - Extract _process_uploaded_file() shared handler (eliminates ~70 duplicated lines) - Add XLSX (openpyxl), XLS (xlrd), DOC (olefile) parsers to file_parser.py - Add backend/annotation_detector.py: circle detection (HoughCircles) + arrow detection (HoughLinesP clustering) + OCR correlation + LLM context formatting - Add annotation_result field to AgentState with session persistence - Wire annotation detection into process_input and _format_ocr_context - Add 11 new tests: 7 annotation detector + 4 multi-format parser - Update all docs: CLAUDE.md, README.md, CODE_GUIDE.md, ROADMAP.md
152 lines
5.7 KiB
Python
152 lines
5.7 KiB
Python
"""测试批注检测器:圆圈检测、箭头检测、OCR 关联、格式化。"""
|
|
|
|
import tempfile
|
|
from pathlib import Path
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import pytest
|
|
|
|
|
|
def _draw_circle_image(path: str, size: tuple = (400, 300)) -> None:
    """Write a synthetic test image containing one red circle outline.

    The canvas is plain white. The circle is centred at (200, 150) with a
    radius of 50 px and a 2 px stroke.

    Args:
        path: Destination file path handed to ``cv2.imwrite``.
        size: Canvas size as ``(width, height)`` in pixels.

    Raises:
        OSError: If OpenCV reports that the image could not be written.
    """
    canvas = np.full((size[1], size[0], 3), 255, dtype=np.uint8)
    cv2.circle(canvas, (200, 150), 50, (0, 0, 255), 2)
    # cv2.imwrite reports failure through its boolean return value; fail here
    # so a broken fixture is not mistaken for a detector bug further on.
    if not cv2.imwrite(path, canvas):
        raise OSError(f"cv2.imwrite failed to write test image: {path}")
|
|
|
|
|
|
def _draw_arrow_image(path: str, size: tuple = (400, 300)) -> None:
    """Write a synthetic test image containing a hand-drawn-style red arrow.

    The arrow runs horizontally along y = 150 from x = 50 to a filled
    triangular head whose tip is at (350, 150).

    Args:
        path: Destination file path handed to ``cv2.imwrite``.
        size: Canvas size as ``(width, height)`` in pixels.

    Raises:
        OSError: If OpenCV reports that the image could not be written.
    """
    red = (0, 0, 255)
    canvas = np.full((size[1], size[0], 3), 255, dtype=np.uint8)

    # Shaft: two consecutive runs, each drawn three times with a 1 px vertical
    # jitter to imitate a hand-drawn stroke (intended to yield several
    # HoughLinesP segments).
    for x_start, x_end in ((50, 200), (200, 340)):
        for offset_y in (-1, 0, 1):
            cv2.line(canvas, (x_start, 150 + offset_y), (x_end, 150 + offset_y), red, 2)

    # Filled triangular arrow head.
    head = np.array([[350, 150], [330, 135], [330, 165]], np.int32)
    cv2.fillPoly(canvas, [head], red)

    # Extra edge lines along the two sides of the head that meet at the tip.
    cv2.line(canvas, (350, 150), (330, 135), red, 2)
    cv2.line(canvas, (350, 150), (330, 165), red, 2)

    # cv2.imwrite reports failure through its boolean return value; fail here
    # so a broken fixture is not mistaken for a detector bug further on.
    if not cv2.imwrite(path, canvas):
        raise OSError(f"cv2.imwrite failed to write test image: {path}")
|
|
|
|
|
|
def _draw_circle_and_text_image(path: str, size: tuple = (500, 400)) -> None:
    """Write a synthetic image with a red circle and a piece of "text".

    This simulates a circled annotation: a red circle outline centred at
    (250, 150) with radius 60 and a 3 px stroke, plus a black label whose
    baseline origin is (20, 160).

    Args:
        path: Destination file path handed to ``cv2.imwrite``.
        size: Canvas size as ``(width, height)`` in pixels.

    Raises:
        OSError: If OpenCV reports that the image could not be written.
    """
    canvas = np.full((size[1], size[0], 3), 255, dtype=np.uint8)
    cv2.circle(canvas, (250, 150), 60, (0, 0, 255), 3)
    # NOTE(review): OpenCV's Hershey fonts cover only a subset of ASCII, and
    # the putText documentation says unrenderable symbols are replaced by
    # question marks, so the non-ASCII part of this label is presumably drawn
    # as placeholder glyphs rather than the literal characters.
    cv2.putText(canvas, "项目A", (20, 160), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)
    # cv2.imwrite reports failure through its boolean return value; fail here
    # so a broken fixture is not mistaken for a detector bug further on.
    if not cv2.imwrite(path, canvas):
        raise OSError(f"cv2.imwrite failed to write test image: {path}")
|
|
|
|
|
|
class TestAnnotationDetector:
    """Exercise the functions exposed by annotation_detector.py."""

    @staticmethod
    def _reserve_png_path() -> str:
        """Create an empty ``.png`` temp file and return its path.

        The file is created with ``delete=False``, so it outlives the handle
        and the caller is responsible for removing it.
        """
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as handle:
            return handle.name

    def test_detect_circles_finds_circle(self):
        """A drawn circle is reported with its type and expected keys."""
        from backend.annotation_detector import detect_annotations

        image_path = self._reserve_png_path()
        try:
            _draw_circle_image(image_path)
            ocr_elements = [
                {
                    "text": "测试字段",
                    "bbox": {"x": 170, "y": 120, "w": 60, "h": 20},
                    "confidence": 0.95,
                },
            ]
            detection = detect_annotations(image_path, ocr_elements)
            assert detection["total"] >= 1
            found_circles = detection["circles"]
            assert len(found_circles) >= 1
            first_circle = found_circles[0]
            assert first_circle["type"] == "circle"
            for expected_key in ("center", "bbox", "nearby_texts"):
                assert expected_key in first_circle
        finally:
            Path(image_path).unlink(missing_ok=True)

    def test_detect_arrows_finds_arrow(self):
        """A drawn arrow is reported with its type and both end points."""
        from backend.annotation_detector import detect_annotations

        image_path = self._reserve_png_path()
        try:
            _draw_arrow_image(image_path)
            ocr_elements = [
                {
                    "text": "起点",
                    "bbox": {"x": 30, "y": 130, "w": 40, "h": 20},
                    "confidence": 0.9,
                },
                {
                    "text": "终点",
                    "bbox": {"x": 310, "y": 130, "w": 40, "h": 20},
                    "confidence": 0.9,
                },
            ]
            detection = detect_annotations(image_path, ocr_elements)
            assert detection["total"] >= 1
            found_arrows = detection["arrows"]
            assert len(found_arrows) >= 1
            first_arrow = found_arrows[0]
            assert first_arrow["type"] == "arrow"
            for expected_key in ("from_pt", "to_pt"):
                assert expected_key in first_arrow
        finally:
            Path(image_path).unlink(missing_ok=True)

    def test_correlate_with_ocr_links_nearby_texts(self):
        """Texts linked to a detected circle include the expected label.

        The membership check only runs when a circle was detected and it
        carries a non-empty ``nearby_texts`` list; otherwise nothing is
        asserted.
        """
        from backend.annotation_detector import detect_annotations

        image_path = self._reserve_png_path()
        try:
            _draw_circle_and_text_image(image_path)
            ocr_elements = [
                {
                    "text": "项目A",
                    "bbox": {"x": 20, "y": 140, "w": 80, "h": 30},
                    "confidence": 0.98,
                },
                {
                    "text": "金额",
                    "bbox": {"x": 350, "y": 200, "w": 50, "h": 20},
                    "confidence": 0.9,
                },
            ]
            detection = detect_annotations(image_path, ocr_elements)
            found_circles = detection["circles"]
            linked_texts = found_circles[0].get("nearby_texts", []) if found_circles else []
            if linked_texts:
                assert "项目A" in linked_texts
        finally:
            Path(image_path).unlink(missing_ok=True)

    def test_invalid_image_path(self):
        """A missing image yields zero annotations plus an ``error`` entry."""
        from backend.annotation_detector import detect_annotations

        detection = detect_annotations("/nonexistent/file.png", [])
        assert detection["total"] == 0
        assert "error" in detection

    def test_format_annotation_context_empty(self):
        """Empty, ``None`` and zero-annotation inputs format to ``""``."""
        from backend.annotation_detector import format_annotation_context

        blank_inputs = (
            {},
            None,
            {"circles": [], "arrows": [], "total": 0},
        )
        for blank in blank_inputs:
            assert format_annotation_context(blank) == ""

    def test_format_annotation_context_with_circles(self):
        """Circle annotations are rendered with their nearby text."""
        from backend.annotation_detector import format_annotation_context

        circle_entry = {"center": [100, 200], "nearby_texts": ["项目A", "金额"]}
        annotations = {"circles": [circle_entry], "arrows": [], "total": 1}
        rendered = format_annotation_context(annotations)
        for fragment in ("圈选标记", "项目A"):
            assert fragment in rendered

    def test_format_annotation_context_with_arrows(self):
        """Arrow annotations are rendered with their source and target text."""
        from backend.annotation_detector import format_annotation_context

        arrow_entry = {"from_text": "修理号", "to_text": "车架号"}
        annotations = {"circles": [], "arrows": [arrow_entry], "total": 1}
        rendered = format_annotation_context(annotations)
        for fragment in ("箭头标记", "修理号", "车架号"):
            assert fragment in rendered
|