Files
agent_jrxml/tests/test_annotation_detector.py
T
panda 9bb011e429 feat: v4 multimodal chat input, multi-format support, and annotation detection
- Replace st.chat_input with st-multimodal-chatinput (Ctrl+V paste, drag-drop, file button)
- Extract _process_uploaded_file() shared handler (eliminates ~70 duplicated lines)
- Add XLSX (openpyxl), XLS (xlrd), DOC (olefile) parsers to file_parser.py
- Add backend/annotation_detector.py: circle detection (HoughCircles) + arrow detection (HoughLinesP clustering) + OCR correlation + LLM context formatting
- Add annotation_result field to AgentState with session persistence
- Wire annotation detection into process_input and _format_ocr_context
- Add 11 new tests: 7 annotation detector + 4 multi-format parser
- Update all docs: CLAUDE.md, README.md, CODE_GUIDE.md, ROADMAP.md
2026-05-20 23:43:16 +08:00

152 lines
5.7 KiB
Python

"""测试批注检测器:圆圈检测、箭头检测、OCR 关联、格式化。"""
import tempfile
from pathlib import Path
import cv2
import numpy as np
import pytest
def _draw_circle_image(path: str, size: tuple = (400, 300)) -> None:
    """Write a synthetic test image containing a single red circle outline.

    Args:
        path: Destination file path handed to ``cv2.imwrite``.
        size: Canvas size as ``(width, height)`` in pixels.

    The circle geometry is fixed (centre ``(200, 150)``, radius 50,
    2 px stroke) and does not scale with ``size``.
    """
    # NumPy image arrays are indexed (rows, cols, channels), i.e. height first.
    canvas = np.full((size[1], size[0], 3), 255, dtype=np.uint8)
    red_bgr = (0, 0, 255)  # OpenCV stores colours in BGR order
    cv2.circle(canvas, (200, 150), 50, red_bgr, 2)
    cv2.imwrite(path, canvas)
def _draw_arrow_image(path: str, size: tuple = (400, 300)) -> None:
    """Write a synthetic test image containing a hand-drawn-style red arrow.

    The horizontal shaft is built from two consecutive stretches, each drawn
    three times with a one-pixel vertical offset, to imitate a hand-drawn
    stroke. The head is a filled triangle with its two leading edges
    stroked again on top.

    Args:
        path: Destination file path handed to ``cv2.imwrite``.
        size: Canvas size as ``(width, height)`` in pixels.
    """
    red_bgr = (0, 0, 255)  # OpenCV stores colours in BGR order
    canvas = np.full((size[1], size[0], 3), 255, dtype=np.uint8)

    # Shaft: slightly offset strokes so that line detection sees several
    # segments rather than one clean line.
    for x_start, x_end in ((50, 200), (200, 340)):
        for dy in (-1, 0, 1):
            cv2.line(canvas, (x_start, 150 + dy), (x_end, 150 + dy), red_bgr, 2)

    # Arrow head: filled triangle whose tip points to the right.
    tip, upper, lower = (350, 150), (330, 135), (330, 165)
    head = np.array([tip, upper, lower], np.int32)
    cv2.fillPoly(canvas, [head], red_bgr)

    # Extra edge strokes along the two sides that meet at the tip.
    for corner in (upper, lower):
        cv2.line(canvas, tip, corner, red_bgr, 2)

    cv2.imwrite(path, canvas)
def _draw_circle_and_text_image(path: str, size: tuple = (500, 400)) -> None:
    """Write a synthetic image with a red circle and a text label.

    Simulates a "circled item" annotation: a red circle outline (centre
    ``(250, 150)``, radius 60, 3 px stroke) plus a black label drawn with
    ``cv2.putText`` at ``(20, 160)``.

    Args:
        path: Destination file path handed to ``cv2.imwrite``.
        size: Canvas size as ``(width, height)`` in pixels.
    """
    black_bgr = (0, 0, 0)
    red_bgr = (0, 0, 255)  # OpenCV stores colours in BGR order
    canvas = np.full((size[1], size[0], 3), 255, dtype=np.uint8)
    cv2.circle(canvas, (250, 150), 60, red_bgr, 3)
    # NOTE(review): the Hershey font may lack glyphs for the non-ASCII
    # characters in this label — confirm whether the rendered pixels matter.
    cv2.putText(canvas, "项目A", (20, 160), cv2.FONT_HERSHEY_SIMPLEX, 1, black_bgr, 2)
    cv2.imwrite(path, canvas)
class TestAnnotationDetector:
    """Exercise ``detect_annotations`` and ``format_annotation_context``.

    Covers circle detection, arrow detection, correlation with OCR
    elements, the unreadable-image path, and context formatting.
    """

    @staticmethod
    def _reserve_png_path() -> str:
        """Create an empty temporary ``.png`` file and return its path.

        The file is created with ``delete=False`` so it outlives the closed
        handle; the calling test removes it in its ``finally`` clause.
        """
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as handle:
            return handle.name

    def test_detect_circles_finds_circle(self):
        """A drawn circle is reported with the expected keys."""
        from backend.annotation_detector import detect_annotations

        image_path = self._reserve_png_path()
        try:
            _draw_circle_image(image_path)
            ocr_boxes = [
                {
                    "text": "测试字段",
                    "bbox": {"x": 170, "y": 120, "w": 60, "h": 20},
                    "confidence": 0.95,
                },
            ]
            detection = detect_annotations(image_path, ocr_boxes)

            assert detection["total"] >= 1
            found = detection["circles"]
            assert len(found) >= 1

            first = found[0]
            assert first["type"] == "circle"
            for key in ("center", "bbox", "nearby_texts"):
                assert key in first
        finally:
            Path(image_path).unlink(missing_ok=True)

    def test_detect_arrows_finds_arrow(self):
        """A drawn arrow is reported with both of its endpoints."""
        from backend.annotation_detector import detect_annotations

        image_path = self._reserve_png_path()
        try:
            _draw_arrow_image(image_path)
            ocr_boxes = [
                {
                    "text": "起点",
                    "bbox": {"x": 30, "y": 130, "w": 40, "h": 20},
                    "confidence": 0.9,
                },
                {
                    "text": "终点",
                    "bbox": {"x": 310, "y": 130, "w": 40, "h": 20},
                    "confidence": 0.9,
                },
            ]
            detection = detect_annotations(image_path, ocr_boxes)

            assert detection["total"] >= 1
            found = detection["arrows"]
            assert len(found) >= 1

            first = found[0]
            assert first["type"] == "arrow"
            for key in ("from_pt", "to_pt"):
                assert key in first
        finally:
            Path(image_path).unlink(missing_ok=True)

    def test_correlate_with_ocr_links_nearby_texts(self):
        """Texts attached to a detected circle include the expected label."""
        from backend.annotation_detector import detect_annotations

        image_path = self._reserve_png_path()
        try:
            _draw_circle_and_text_image(image_path)
            ocr_boxes = [
                {
                    "text": "项目A",
                    "bbox": {"x": 20, "y": 140, "w": 80, "h": 30},
                    "confidence": 0.98,
                },
                {
                    "text": "金额",
                    "bbox": {"x": 350, "y": 200, "w": 50, "h": 20},
                    "confidence": 0.9,
                },
            ]
            detection = detect_annotations(image_path, ocr_boxes)

            found = detection["circles"]
            nearby = found[0].get("nearby_texts", []) if found else []
            # Lenient on purpose: the assertion only runs when a circle was
            # detected and it carries nearby texts; otherwise the test
            # passes without checking anything.
            if nearby:
                assert "项目A" in nearby
        finally:
            Path(image_path).unlink(missing_ok=True)

    def test_invalid_image_path(self):
        """A missing image yields zero annotations plus an ``error`` entry."""
        from backend.annotation_detector import detect_annotations

        outcome = detect_annotations("/nonexistent/file.png", [])
        assert outcome["total"] == 0
        assert "error" in outcome

    def test_format_annotation_context_empty(self):
        """Empty, ``None`` and zero-annotation inputs format to ``""``."""
        from backend.annotation_detector import format_annotation_context

        blank_inputs = ({}, None, {"circles": [], "arrows": [], "total": 0})
        for blank in blank_inputs:
            assert format_annotation_context(blank) == ""

    def test_format_annotation_context_with_circles(self):
        """Circle annotations appear in the formatted context text."""
        from backend.annotation_detector import format_annotation_context

        circle_entry = {"center": [100, 200], "nearby_texts": ["项目A", "金额"]}
        annotations = {"circles": [circle_entry], "arrows": [], "total": 1}

        rendered = format_annotation_context(annotations)
        for fragment in ("圈选标记", "项目A"):
            assert fragment in rendered

    def test_format_annotation_context_with_arrows(self):
        """Arrow annotations appear in the formatted context text."""
        from backend.annotation_detector import format_annotation_context

        arrow_entry = {"from_text": "修理号", "to_text": "车架号"}
        annotations = {"circles": [], "arrows": [arrow_entry], "total": 1}

        rendered = format_annotation_context(annotations)
        for fragment in ("箭头标记", "修理号", "车架号"):
            assert fragment in rendered