"""批注检测器:识别图片上的圈选(圆)和箭头,定位用户要修改的字段。 依赖 OpenCV (cv2),从 PaddleOCR 传递依赖已安装。 """ from __future__ import annotations import math from dataclasses import dataclass, field from typing import Optional import cv2 import numpy as np @dataclass class Annotation: """单个批注标记。""" type: str # "circle" | "arrow" bbox: dict # {"x": int, "y": int, "w": int, "h": int} center: tuple[int, int] # (cx, cy) nearby_texts: list[str] = field(default_factory=list) from_text: str = "" # 箭头出发点的文本 to_text: str = "" # 箭头指向的文本 from_pt: Optional[tuple[int, int]] = None to_pt: Optional[tuple[int, int]] = None def detect_annotations(image_path: str, ocr_elements: list[dict]) -> dict: """检测图片上的手写批注(圈选 + 箭头),并与 OCR 文本关联。 Args: image_path: 图片文件路径 ocr_elements: OCR 元素列表 [{"text": str, "bbox": {x,y,w,h}, "confidence": float}] Returns: {"circles": [...], "arrows": [...], "total": int} """ img = cv2.imread(image_path) if img is None: return {"circles": [], "arrows": [], "total": 0, "error": "无法读取图片"} h, w = img.shape[:2] circles = _detect_circles(img) arrows = _detect_arrows(img) all_annotations = circles + arrows _correlate_with_ocr(all_annotations, ocr_elements, w, h) result: dict = { "circles": [_annotation_to_dict(a) for a in circles], "arrows": [_annotation_to_dict(a) for a in arrows], "total": len(all_annotations), } return result def _annotation_to_dict(a: Annotation) -> dict: d = { "type": a.type, "bbox": a.bbox, "center": list(a.center), "nearby_texts": a.nearby_texts, } if a.type == "arrow": d["from_text"] = a.from_text d["to_text"] = a.to_text if a.from_pt: d["from_pt"] = list(a.from_pt) if a.to_pt: d["to_pt"] = list(a.to_pt) return d # --------------------------------------------------------------------------- # 圆圈检测 # --------------------------------------------------------------------------- def _detect_circles(img: np.ndarray) -> list[Annotation]: """检测图片中可能是手绘批注的圆圈。""" h, w = img.shape[:2] b, g, r = cv2.split(img) red_enhanced = cv2.addWeighted(r.astype(np.float32), 1.5, g.astype(np.float32), -0.3, 0) red_enhanced = cv2.addWeighted(red_enhanced, 1.2, b.astype(np.float32), -0.3, 0) red_enhanced = np.clip(red_enhanced, 0, 255).astype(np.uint8) gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) combined = cv2.addWeighted(gray, 0.5, red_enhanced, 0.5, 0) blurred = cv2.GaussianBlur(combined, (9, 9), 2) min_radius = max(15, min(w, h) // 40) max_radius = min(200, max(w, h) // 8) circles_raw = cv2.HoughCircles( blurred, cv2.HOUGH_GRADIENT, dp=1.2, minDist=min_radius * 2, param1=50, param2=30, minRadius=min_radius, maxRadius=max_radius, ) annotations: list[Annotation] = [] if circles_raw is not None: for cx, cy, r in circles_raw[0]: bbox = { "x": max(0, int(cx - r)), "y": max(0, int(cy - r)), "w": int(r * 2), "h": int(r * 2), } annotations.append(Annotation( type="circle", bbox=bbox, center=(int(cx), int(cy)), )) return annotations # --------------------------------------------------------------------------- # 箭头检测 # --------------------------------------------------------------------------- def _detect_arrows(img: np.ndarray) -> list[Annotation]: """检测图片中的手绘箭头(直线段 + 端点三角形)。""" gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) edges = cv2.Canny(gray, 50, 150, apertureSize=3) lines = cv2.HoughLinesP( edges, rho=1, theta=np.pi / 180, threshold=40, minLineLength=30, maxLineGap=15, ) if lines is None: return [] segments = [(x1, y1, x2, y2) for x1, y1, x2, y2 in lines[:, 0]] clusters = _cluster_segments(segments) annotations: list[Annotation] = [] for segs in clusters: if len(segs) < 2: continue all_pts = [] for x1, y1, x2, y2 in segs: all_pts.append((x1, y1)) all_pts.append((x2, y2)) all_pts_arr = np.array(all_pts) max_dist = 0 p1 = p2 = all_pts[0] for i in range(len(all_pts)): for j in range(i + 1, len(all_pts)): d = (all_pts[i][0] - all_pts[j][0]) ** 2 + (all_pts[i][1] - all_pts[j][1]) ** 2 if d > max_dist: max_dist = d p1, p2 = all_pts[i], all_pts[j] from_pt, to_pt = _find_arrow_direction(edges, p1, p2) x1, y1 = from_pt x2, y2 = to_pt bbox = { "x": min(x1, x2), "y": min(y1, y2), "w": abs(x2 - x1), "h": abs(y2 - y1), } cx = (x1 + x2) // 2 cy = (y1 + y2) // 2 annotations.append(Annotation( type="arrow", bbox=bbox, center=(cx, cy), from_pt=from_pt, to_pt=to_pt, )) return annotations def _cluster_segments(segments: list[tuple]) -> list[list[tuple]]: """将线段按方向和空间距离聚类。""" clusters: list[list[tuple]] = [] used = [False] * len(segments) for i, (x1, y1, x2, y2) in enumerate(segments): if used[i]: continue cluster = [(x1, y1, x2, y2)] used[i] = True angle_i = math.atan2(y2 - y1, x2 - x1) for j in range(i + 1, len(segments)): if used[j]: continue x3, y3, x4, y4 = segments[j] angle_j = math.atan2(y4 - y3, x4 - x3) angle_diff = abs(angle_i - angle_j) if angle_diff > math.pi: angle_diff = 2 * math.pi - angle_diff if angle_diff < 0.35: d1 = math.hypot(x3 - x2, y3 - y2) d2 = math.hypot(x1 - x4, y1 - y4) d3 = math.hypot(x3 - x1, y3 - y1) d4 = math.hypot(x4 - x2, y4 - y2) if min(d1, d2, d3, d4) < 80: cluster.append((x3, y3, x4, y4)) used[j] = True clusters.append(cluster) return clusters def _find_arrow_direction(edges: np.ndarray, p1: tuple, p2: tuple) -> tuple[tuple, tuple]: """判断箭头的方向(哪端是箭头/三角形汇聚点)。""" r = 20 h, w = edges.shape[:2] def edge_density(cx, cy): x1 = max(0, int(cx - r)) y1 = max(0, int(cy - r)) x2 = min(w, int(cx + r)) y2 = min(h, int(cy + r)) roi = edges[y1:y2, x1:x2] if roi.size == 0: return 0 return float(np.count_nonzero(roi)) / roi.size d1 = edge_density(p1[0], p1[1]) d2 = edge_density(p2[0], p2[1]) if d1 > d2 * 1.3: return p2, p1 if d2 > d1 * 1.3: return p1, p2 return p1, p2 # --------------------------------------------------------------------------- # OCR 关联 # --------------------------------------------------------------------------- def _correlate_with_ocr( annotations: list[Annotation], ocr_elements: list[dict], img_w: int, img_h: int, ) -> None: """将批注与附近的 OCR 文本关联。""" if not ocr_elements: return for ann in annotations: ax = ann.center[0] ay = ann.center[1] near_texts: list[tuple[str, float]] = [] for elem in ocr_elements: bbox = elem.get("bbox", {}) ex = bbox.get("x", 0) + bbox.get("w", 0) / 2 ey = bbox.get("y", 0) + bbox.get("h", 0) / 2 dist = math.hypot(ax - ex, ay - ey) max_dist = max(img_w, img_h) * 0.15 if dist < max_dist: near_texts.append((elem.get("text", ""), dist)) near_texts.sort(key=lambda x: x[1]) ann.nearby_texts = [t for t, _ in near_texts[:5]] if ann.type == "arrow" and ann.from_pt and ann.to_pt: ann.from_text = _closest_text(ann.from_pt, ocr_elements, img_w, img_h) ann.to_text = _closest_text(ann.to_pt, ocr_elements, img_w, img_h) def _closest_text(pt: tuple[int, int], ocr_elements: list[dict], img_w: int, img_h: int) -> str: """找到离 pt 最近的 OCR 文本。""" best_text = "" best_dist = max(img_w, img_h) * 0.12 for elem in ocr_elements: bbox = elem.get("bbox", {}) ex = bbox.get("x", 0) + bbox.get("w", 0) / 2 ey = bbox.get("y", 0) + bbox.get("h", 0) / 2 dist = math.hypot(pt[0] - ex, pt[1] - ey) if dist < best_dist: best_dist = dist best_text = elem.get("text", "") return best_text # --------------------------------------------------------------------------- # LLM 上下文格式化 # --------------------------------------------------------------------------- def format_annotation_context(annotation_result: dict) -> str: """将批注检测结果格式化为中文 LLM 提示文本。""" if not annotation_result or not isinstance(annotation_result, dict): return "" circles = annotation_result.get("circles", []) arrows = annotation_result.get("arrows", []) total = annotation_result.get("total", len(circles) + len(arrows)) if total == 0: return "" parts = ["[图片批注检测结果]"] if circles: parts.append(f"\n检测到 {len(circles)} 个圈选标记:") for i, c in enumerate(circles): center = c.get("center", [0, 0]) near = c.get("nearby_texts", []) parts.append( f" 圈{i+1}. 位置 ({center[0]},{center[1]})" f" — 圈选内容: {', '.join(near) if near else '(附近无文字)'}" ) if arrows: parts.append(f"\n检测到 {len(arrows)} 个箭头标记:") for i, a in enumerate(arrows): ft = a.get("from_text", "") tt = a.get("to_text", "") parts.append(f" 箭头{i+1}. 从「{ft}」→ 指向「{tt}」") parts.append("\n请根据上述圈选/箭头定位用户要修改的报表字段。") return "\n".join(parts)