9bb011e429
- Replace st.chat_input with st-multimodal-chatinput (Ctrl+V paste, drag-drop, file button) - Extract _process_uploaded_file() shared handler (eliminates ~70 duplicated lines) - Add XLSX (openpyxl), XLS (xlrd), DOC (olefile) parsers to file_parser.py - Add backend/annotation_detector.py: circle detection (HoughCircles) + arrow detection (HoughLinesP clustering) + OCR correlation + LLM context formatting - Add annotation_result field to AgentState with session persistence - Wire annotation detection into process_input and _format_ocr_context - Add 11 new tests: 7 annotation detector + 4 multi-format parser - Update all docs: CLAUDE.md, README.md, CODE_GUIDE.md, ROADMAP.md
332 lines
10 KiB
Python
332 lines
10 KiB
Python
"""批注检测器:识别图片上的圈选(圆)和箭头,定位用户要修改的字段。

依赖 OpenCV (cv2),它已作为 PaddleOCR 的传递依赖安装。
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import math
|
|
from dataclasses import dataclass, field
|
|
from typing import Optional
|
|
|
|
import cv2
|
|
import numpy as np
|
|
|
|
|
|
@dataclass
class Annotation:
    """One detected annotation mark (a circle or an arrow) on the image.

    Attributes:
        type: ``"circle"`` or ``"arrow"``.
        bbox: bounding box dict with integer ``x``, ``y``, ``w``, ``h``.
        center: ``(cx, cy)`` pixel position of the mark.
        nearby_texts: OCR texts found near ``center``; filled in after detection.
        from_text: OCR text at the arrow's tail (left empty for circles).
        to_text: OCR text at the arrow's head (left empty for circles).
        from_pt: arrow tail point; ``None`` unless set by the arrow detector.
        to_pt: arrow head point; ``None`` unless set by the arrow detector.
    """

    type: str
    bbox: dict
    center: tuple[int, int]
    nearby_texts: list[str] = field(default_factory=list)
    from_text: str = field(default="")
    to_text: str = field(default="")
    from_pt: Optional[tuple[int, int]] = field(default=None)
    to_pt: Optional[tuple[int, int]] = field(default=None)
|
|
|
|
|
|
def detect_annotations(image_path: str, ocr_elements: list[dict]) -> dict:
    """Detect hand-drawn annotations (circles and arrows) and link them to OCR text.

    Args:
        image_path: path of the image file to analyse.
        ocr_elements: OCR elements, each shaped like
            ``{"text": str, "bbox": {x, y, w, h}, "confidence": float}``.

    Returns:
        ``{"circles": [...], "arrows": [...], "total": int}``. If the image
        cannot be read, the lists are empty, ``total`` is 0 and an ``"error"``
        message is included.
    """
    image = cv2.imread(image_path)
    if image is None:
        return {"circles": [], "arrows": [], "total": 0, "error": "无法读取图片"}

    img_h, img_w = image.shape[:2]
    circle_anns = _detect_circles(image)
    arrow_anns = _detect_arrows(image)

    # _correlate_with_ocr fills in the Annotation objects in place, so it has
    # to run before they are converted to dicts.
    _correlate_with_ocr(circle_anns + arrow_anns, ocr_elements, img_w, img_h)

    return {
        "circles": list(map(_annotation_to_dict, circle_anns)),
        "arrows": list(map(_annotation_to_dict, arrow_anns)),
        "total": len(circle_anns) + len(arrow_anns),
    }
|
|
|
|
|
|
def _annotation_to_dict(a: "Annotation") -> dict:
    """Serialize an Annotation into a plain, JSON-friendly dict.

    Always emits ``type``, ``bbox``, ``center`` and ``nearby_texts``. Arrows also
    get ``from_text`` / ``to_text``, plus ``from_pt`` / ``to_pt`` when those are set.

    Coordinates are coerced with ``int()``. Arrow coordinates originate from
    OpenCV (``HoughLinesP``) as numpy integer scalars, which ``json`` cannot
    encode. The bbox and text containers are copied rather than shared with
    the Annotation.
    """
    d = {
        "type": a.type,
        "bbox": {key: int(val) for key, val in a.bbox.items()},
        "center": [int(v) for v in a.center],
        "nearby_texts": list(a.nearby_texts),
    }
    if a.type == "arrow":
        d["from_text"] = a.from_text
        d["to_text"] = a.to_text
        if a.from_pt:
            d["from_pt"] = [int(v) for v in a.from_pt]
        if a.to_pt:
            d["to_pt"] = [int(v) for v in a.to_pt]
    return d
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# 圆圈检测
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _detect_circles(img: np.ndarray) -> list[Annotation]:
    """Detect circles that are likely hand-drawn annotation marks.

    The BGR image is reduced to one channel that averages plain grayscale with a
    red-boosted channel (net weights 1.8*R - 0.36*G - 0.3*B, clipped to uint8).
    This is presumably meant to favour red pen marks. The result is blurred and
    passed to ``cv2.HoughCircles``.

    The radius search range scales with the image:
    - minimum: the larger of 15 px and 1/40 of the shorter side;
    - maximum: the smaller of 200 px and 1/8 of the longer side.

    Args:
        img: BGR image as loaded by ``cv2.imread``.

    Returns:
        One ``Annotation(type="circle")`` per detected circle. ``center`` is the
        circle centre truncated to int. ``bbox`` is the circle's bounding square
        clipped to the image bounds.
    """
    img_h, img_w = img.shape[:2]

    blue, green, red = cv2.split(img)
    red_boost = cv2.addWeighted(red.astype(np.float32), 1.5,
                                green.astype(np.float32), -0.3, 0)
    red_boost = cv2.addWeighted(red_boost, 1.2,
                                blue.astype(np.float32), -0.3, 0)
    red_boost = np.clip(red_boost, 0, 255).astype(np.uint8)

    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    combined = cv2.addWeighted(gray, 0.5, red_boost, 0.5, 0)
    blurred = cv2.GaussianBlur(combined, (9, 9), 2)

    min_radius = max(15, min(img_w, img_h) // 40)
    max_radius = min(200, max(img_w, img_h) // 8)

    found = cv2.HoughCircles(
        blurred, cv2.HOUGH_GRADIENT, dp=1.2, minDist=min_radius * 2,
        param1=50, param2=30, minRadius=min_radius, maxRadius=max_radius,
    )
    if found is None:
        return []

    annotations: list[Annotation] = []
    for cx, cy, radius in found[0]:
        # Start from the unclipped square (origin int(c - r), side int(2r)) and
        # clip BOTH edges to the image. Clamping only the origin to 0 would
        # shift a partly off-image box right/down instead of shrinking it.
        left, top = int(cx - radius), int(cy - radius)
        side = int(radius * 2)
        x0, y0 = max(0, left), max(0, top)
        x1, y1 = min(img_w, left + side), min(img_h, top + side)
        annotations.append(Annotation(
            type="circle",
            bbox={"x": x0, "y": y0, "w": max(0, x1 - x0), "h": max(0, y1 - y0)},
            center=(int(cx), int(cy)),
        ))

    return annotations
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# 箭头检测
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _detect_arrows(img: np.ndarray) -> list[Annotation]:
    """Detect hand-drawn arrows (a straight stroke with an arrow-head end).

    Pipeline:
    1. Canny edges, then ``cv2.HoughLinesP`` line segments.
    2. Group segments into near-collinear, nearby clusters (``_cluster_segments``).
    3. Each cluster with at least two segments becomes one arrow whose ends are
       the two farthest-apart segment endpoints.
    4. ``_find_arrow_direction`` decides which end is the head.

    Args:
        img: BGR image as loaded by ``cv2.imread``.

    Returns:
        One ``Annotation(type="arrow")`` per accepted cluster. Each has
        ``from_pt`` (tail), ``to_pt`` (head), their axis-aligned ``bbox`` and
        the midpoint as ``center``. All coordinates are plain Python ``int``.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    edges = cv2.Canny(gray, 50, 150, apertureSize=3)

    lines = cv2.HoughLinesP(
        edges, rho=1, theta=np.pi / 180, threshold=40,
        minLineLength=30, maxLineGap=15,
    )
    if lines is None:
        return []

    # HoughLinesP yields numpy int32 values. Convert them to plain ints so:
    # - the coordinates stored in bbox/center/from_pt/to_pt stay
    #   JSON-serializable (the circle path already uses int());
    # - squared distances below cannot overflow int32.
    segments = [tuple(int(v) for v in seg) for seg in lines[:, 0]]

    annotations: list[Annotation] = []
    for segs in _cluster_segments(segments):
        if len(segs) < 2:
            continue
        points = [pt for x1, y1, x2, y2 in segs for pt in ((x1, y1), (x2, y2))]
        p1, p2 = _farthest_pair(points)
        from_pt, to_pt = _find_arrow_direction(edges, p1, p2)

        (fx, fy), (tx, ty) = from_pt, to_pt
        annotations.append(Annotation(
            type="arrow",
            bbox={
                "x": min(fx, tx),
                "y": min(fy, ty),
                "w": abs(tx - fx),
                "h": abs(ty - fy),
            },
            center=((fx + tx) // 2, (fy + ty) // 2),
            from_pt=from_pt,
            to_pt=to_pt,
        ))

    return annotations


def _farthest_pair(points: list[tuple[int, int]]) -> tuple[tuple[int, int], tuple[int, int]]:
    """Return the first pair (in input order) of points with the largest distance.

    If all points coincide, both returned points are ``points[0]``.
    """
    best = 0
    p1 = p2 = points[0]
    for i, (xa, ya) in enumerate(points):
        for xb, yb in points[i + 1:]:
            d2 = (xa - xb) ** 2 + (ya - yb) ** 2
            if d2 > best:
                best, p1, p2 = d2, (xa, ya), (xb, yb)
    return p1, p2
|
|
|
|
|
|
def _cluster_segments(
    segments: list[tuple],
    angle_tol: float = 0.35,
    max_gap: float = 80,
) -> list[list[tuple]]:
    """Greedily group line segments by orientation and proximity.

    Each segment not yet assigned seeds a new cluster. Every later unassigned
    segment joins that cluster when both hold:
    (a) its orientation differs from the seed's by less than ``angle_tol`` radians;
    (b) one of its endpoints lies closer than ``max_gap`` pixels to one of the
        seed's endpoints.
    Membership is checked against the seed only, not transitively.

    Orientation is compared modulo pi. A segment's endpoint order carries no
    meaning, so ``(a, b)`` and ``(b, a)`` describe the same stroke. Comparing
    directed angles would put two near-vertical fragments of one stroke that
    are listed in opposite endpoint order about pi apart, so they would never
    merge.

    Args:
        segments: ``(x1, y1, x2, y2)`` tuples.
        angle_tol: maximum orientation difference in radians (default 0.35).
        max_gap: maximum endpoint distance in pixels (default 80).

    Returns:
        A list of clusters, each a list of ``(x1, y1, x2, y2)`` tuples with the
        seed first. Every input segment appears in exactly one cluster.
    """
    clusters: list[list[tuple]] = []
    used = [False] * len(segments)

    for i, seed in enumerate(segments):
        if used[i]:
            continue
        used[i] = True
        x1, y1, x2, y2 = seed
        seed_angle = math.atan2(y2 - y1, x2 - x1)
        cluster = [(x1, y1, x2, y2)]

        for j in range(i + 1, len(segments)):
            if used[j]:
                continue
            x3, y3, x4, y4 = segments[j]
            diff = abs(seed_angle - math.atan2(y4 - y3, x4 - x3)) % math.pi
            if min(diff, math.pi - diff) >= angle_tol:
                continue
            gap = min(
                math.hypot(x3 - x2, y3 - y2),
                math.hypot(x1 - x4, y1 - y4),
                math.hypot(x3 - x1, y3 - y1),
                math.hypot(x4 - x2, y4 - y2),
            )
            if gap < max_gap:
                cluster.append((x3, y3, x4, y4))
                used[j] = True

        clusters.append(cluster)

    return clusters
|
|
|
|
|
|
def _find_arrow_direction(edges: np.ndarray, p1: tuple, p2: tuple) -> tuple[tuple, tuple]:
    """Order the two arrow endpoints as ``(tail, head)``.

    Each end is scored by the fraction of non-zero pixels in ``edges`` inside a
    window reaching 20 px on each side of it, clipped to the image. The head is
    taken to be the end where the arrow-head strokes converge, i.e. the denser
    one.

    ``p1`` is treated as the head only when its score exceeds ``p2``'s by more
    than 30%. In every other case, near-ties included, ``(p1, p2)`` is returned.
    """
    half = 20
    rows, cols = edges.shape[:2]

    def _density(pt):
        x, y = pt[0], pt[1]
        window = edges[max(0, int(y - half)):min(rows, int(y + half)),
                       max(0, int(x - half)):min(cols, int(x + half))]
        if not window.size:
            return 0
        return float(np.count_nonzero(window)) / window.size

    score_p1 = _density(p1)
    score_p2 = _density(p2)
    return (p2, p1) if score_p1 > score_p2 * 1.3 else (p1, p2)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# OCR 关联
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _correlate_with_ocr(
    annotations: "list[Annotation]",
    ocr_elements: list[dict],
    img_w: int,
    img_h: int,
) -> None:
    """Attach nearby OCR text to each annotation, mutating the annotations in place.

    An OCR element is "nearby" when its bbox centre is strictly closer to the
    annotation's ``center`` than 15% of the longer image side. Up to 5 of those
    texts are kept, closest first; ties keep OCR order. For arrows with both
    ``from_pt`` and ``to_pt`` set, ``from_text`` / ``to_text`` are filled with
    ``_closest_text`` at each endpoint.

    Args:
        annotations: detected annotations. Their ``nearby_texts`` (and, for
            arrows, ``from_text`` / ``to_text``) are overwritten.
        ocr_elements: OCR elements ``{"text": str, "bbox": {x, y, w, h}, ...}``.
            When empty or ``None``, nothing is changed.
        img_w: image width in pixels, used to scale the search radius.
        img_h: image height in pixels, used to scale the search radius.
    """
    if not ocr_elements:
        return

    # Neither the search radius nor the OCR centres depend on the annotation.
    # Compute them once instead of once per (annotation, element) pair.
    max_dist = max(img_w, img_h) * 0.15
    ocr_centres: list[tuple[str, float, float]] = []
    for elem in ocr_elements:
        bbox = elem.get("bbox", {})
        ocr_centres.append((
            elem.get("text", ""),
            bbox.get("x", 0) + bbox.get("w", 0) / 2,
            bbox.get("y", 0) + bbox.get("h", 0) / 2,
        ))

    for ann in annotations:
        ax, ay = ann.center[0], ann.center[1]

        near_texts: list[tuple[str, float]] = []
        for text, ex, ey in ocr_centres:
            dist = math.hypot(ax - ex, ay - ey)
            if dist < max_dist:
                near_texts.append((text, dist))

        # Stable sort: equally distant texts keep their OCR order.
        near_texts.sort(key=lambda item: item[1])
        ann.nearby_texts = [text for text, _ in near_texts[:5]]

        if ann.type == "arrow" and ann.from_pt and ann.to_pt:
            ann.from_text = _closest_text(ann.from_pt, ocr_elements, img_w, img_h)
            ann.to_text = _closest_text(ann.to_pt, ocr_elements, img_w, img_h)
|
|
|
|
|
|
def _closest_text(pt: tuple[int, int], ocr_elements: list[dict], img_w: int, img_h: int) -> str:
    """Return the text of the OCR element whose bbox centre is nearest to ``pt``.

    Only elements strictly closer than 12% of the longer image side qualify.
    On a distance tie the earliest element wins. Returns ``""`` when no
    element qualifies or the winner has no ``"text"``.
    """
    limit = max(img_w, img_h) * 0.12
    px, py = pt[0], pt[1]

    candidates = []
    for elem in ocr_elements:
        box = elem.get("bbox", {})
        gap = math.hypot(px - (box.get("x", 0) + box.get("w", 0) / 2),
                         py - (box.get("y", 0) + box.get("h", 0) / 2))
        if gap < limit:
            candidates.append((gap, elem))

    if not candidates:
        return ""
    _, nearest = min(candidates, key=lambda cand: cand[0])
    return nearest.get("text", "")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# LLM 上下文格式化
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def format_annotation_context(annotation_result: dict) -> str:
    """Render an annotation-detection result as a Chinese prompt block for the LLM.

    Args:
        annotation_result: a dict in the shape produced by ``detect_annotations``
            (``circles``, ``arrows`` and optionally ``total``).

    Returns:
        The formatted multi-line text. Returns ``""`` when the input is falsy,
        is not a dict, or its ``total`` is 0. ``total`` defaults to the number
        of circles plus arrows.
    """
    if not (annotation_result and isinstance(annotation_result, dict)):
        return ""

    circle_items = annotation_result.get("circles", [])
    arrow_items = annotation_result.get("arrows", [])
    if annotation_result.get("total", len(circle_items) + len(arrow_items)) == 0:
        return ""

    out = ["[图片批注检测结果]"]

    if circle_items:
        out.append(f"\n检测到 {len(circle_items)} 个圈选标记:")
        for idx, item in enumerate(circle_items, start=1):
            pos = item.get("center", [0, 0])
            texts = item.get("nearby_texts", [])
            px, py = pos[0], pos[1]
            label = ", ".join(texts) if texts else "(附近无文字)"
            out.append(f"  圈{idx}. 位置 ({px},{py}) — 圈选内容: {label}")

    if arrow_items:
        out.append(f"\n检测到 {len(arrow_items)} 个箭头标记:")
        out.extend(
            f"  箭头{idx}. 从「{item.get('from_text', '')}」→ 指向「{item.get('to_text', '')}」"
            for idx, item in enumerate(arrow_items, start=1)
        )

    out.append("\n请根据上述圈选/箭头定位用户要修改的报表字段。")
    return "\n".join(out)
|