Merge remote v4/v5 features (multimodal chat input, layered generation, annotation detection) with local v3 features (dialog file upload, XLSX support, session fix)
Key resolutions: - agent/nodes.py: Merged session_id exclusion fix with new persistable fields (ocr_extraction_result, annotation_result, layout_schema, ocr_elements) - app.py: Adopted st-multimodal-chatinput for unified paste/drop/upload, removed custom JS paste bridge - backend/file_parser.py: Kept local XLSX parser, added remote XLS/DOC parsers - CLAUDE.md + CODE_GUIDE.md: Merged documentation from both branches Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,331 @@
|
||||
"""批注检测器:识别图片上的圈选(圆)和箭头,定位用户要修改的字段。
|
||||
|
||||
依赖 OpenCV (cv2),从 PaddleOCR 传递依赖已安装。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass
|
||||
class Annotation:
|
||||
"""单个批注标记。"""
|
||||
type: str # "circle" | "arrow"
|
||||
bbox: dict # {"x": int, "y": int, "w": int, "h": int}
|
||||
center: tuple[int, int] # (cx, cy)
|
||||
nearby_texts: list[str] = field(default_factory=list)
|
||||
from_text: str = "" # 箭头出发点的文本
|
||||
to_text: str = "" # 箭头指向的文本
|
||||
from_pt: Optional[tuple[int, int]] = None
|
||||
to_pt: Optional[tuple[int, int]] = None
|
||||
|
||||
|
||||
def detect_annotations(image_path: str, ocr_elements: list[dict]) -> dict:
|
||||
"""检测图片上的手写批注(圈选 + 箭头),并与 OCR 文本关联。
|
||||
|
||||
Args:
|
||||
image_path: 图片文件路径
|
||||
ocr_elements: OCR 元素列表 [{"text": str, "bbox": {x,y,w,h}, "confidence": float}]
|
||||
|
||||
Returns:
|
||||
{"circles": [...], "arrows": [...], "total": int}
|
||||
"""
|
||||
img = cv2.imread(image_path)
|
||||
if img is None:
|
||||
return {"circles": [], "arrows": [], "total": 0, "error": "无法读取图片"}
|
||||
|
||||
h, w = img.shape[:2]
|
||||
|
||||
circles = _detect_circles(img)
|
||||
arrows = _detect_arrows(img)
|
||||
|
||||
all_annotations = circles + arrows
|
||||
_correlate_with_ocr(all_annotations, ocr_elements, w, h)
|
||||
|
||||
result: dict = {
|
||||
"circles": [_annotation_to_dict(a) for a in circles],
|
||||
"arrows": [_annotation_to_dict(a) for a in arrows],
|
||||
"total": len(all_annotations),
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
def _annotation_to_dict(a: Annotation) -> dict:
|
||||
d = {
|
||||
"type": a.type,
|
||||
"bbox": a.bbox,
|
||||
"center": list(a.center),
|
||||
"nearby_texts": a.nearby_texts,
|
||||
}
|
||||
if a.type == "arrow":
|
||||
d["from_text"] = a.from_text
|
||||
d["to_text"] = a.to_text
|
||||
if a.from_pt:
|
||||
d["from_pt"] = list(a.from_pt)
|
||||
if a.to_pt:
|
||||
d["to_pt"] = list(a.to_pt)
|
||||
return d
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 圆圈检测
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _detect_circles(img: np.ndarray) -> list[Annotation]:
|
||||
"""检测图片中可能是手绘批注的圆圈。"""
|
||||
h, w = img.shape[:2]
|
||||
b, g, r = cv2.split(img)
|
||||
red_enhanced = cv2.addWeighted(r.astype(np.float32), 1.5,
|
||||
g.astype(np.float32), -0.3, 0)
|
||||
red_enhanced = cv2.addWeighted(red_enhanced, 1.2,
|
||||
b.astype(np.float32), -0.3, 0)
|
||||
red_enhanced = np.clip(red_enhanced, 0, 255).astype(np.uint8)
|
||||
|
||||
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
combined = cv2.addWeighted(gray, 0.5, red_enhanced, 0.5, 0)
|
||||
blurred = cv2.GaussianBlur(combined, (9, 9), 2)
|
||||
|
||||
min_radius = max(15, min(w, h) // 40)
|
||||
max_radius = min(200, max(w, h) // 8)
|
||||
|
||||
circles_raw = cv2.HoughCircles(
|
||||
blurred, cv2.HOUGH_GRADIENT, dp=1.2, minDist=min_radius * 2,
|
||||
param1=50, param2=30, minRadius=min_radius, maxRadius=max_radius,
|
||||
)
|
||||
|
||||
annotations: list[Annotation] = []
|
||||
|
||||
if circles_raw is not None:
|
||||
for cx, cy, r in circles_raw[0]:
|
||||
bbox = {
|
||||
"x": max(0, int(cx - r)),
|
||||
"y": max(0, int(cy - r)),
|
||||
"w": int(r * 2),
|
||||
"h": int(r * 2),
|
||||
}
|
||||
annotations.append(Annotation(
|
||||
type="circle",
|
||||
bbox=bbox,
|
||||
center=(int(cx), int(cy)),
|
||||
))
|
||||
|
||||
return annotations
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 箭头检测
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _detect_arrows(img: np.ndarray) -> list[Annotation]:
|
||||
"""检测图片中的手绘箭头(直线段 + 端点三角形)。"""
|
||||
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
edges = cv2.Canny(gray, 50, 150, apertureSize=3)
|
||||
|
||||
lines = cv2.HoughLinesP(
|
||||
edges, rho=1, theta=np.pi / 180, threshold=40,
|
||||
minLineLength=30, maxLineGap=15,
|
||||
)
|
||||
|
||||
if lines is None:
|
||||
return []
|
||||
|
||||
segments = [(x1, y1, x2, y2) for x1, y1, x2, y2 in lines[:, 0]]
|
||||
clusters = _cluster_segments(segments)
|
||||
|
||||
annotations: list[Annotation] = []
|
||||
for segs in clusters:
|
||||
if len(segs) < 2:
|
||||
continue
|
||||
all_pts = []
|
||||
for x1, y1, x2, y2 in segs:
|
||||
all_pts.append((x1, y1))
|
||||
all_pts.append((x2, y2))
|
||||
all_pts_arr = np.array(all_pts)
|
||||
max_dist = 0
|
||||
p1 = p2 = all_pts[0]
|
||||
for i in range(len(all_pts)):
|
||||
for j in range(i + 1, len(all_pts)):
|
||||
d = (all_pts[i][0] - all_pts[j][0]) ** 2 + (all_pts[i][1] - all_pts[j][1]) ** 2
|
||||
if d > max_dist:
|
||||
max_dist = d
|
||||
p1, p2 = all_pts[i], all_pts[j]
|
||||
|
||||
from_pt, to_pt = _find_arrow_direction(edges, p1, p2)
|
||||
|
||||
x1, y1 = from_pt
|
||||
x2, y2 = to_pt
|
||||
bbox = {
|
||||
"x": min(x1, x2),
|
||||
"y": min(y1, y2),
|
||||
"w": abs(x2 - x1),
|
||||
"h": abs(y2 - y1),
|
||||
}
|
||||
cx = (x1 + x2) // 2
|
||||
cy = (y1 + y2) // 2
|
||||
|
||||
annotations.append(Annotation(
|
||||
type="arrow",
|
||||
bbox=bbox,
|
||||
center=(cx, cy),
|
||||
from_pt=from_pt,
|
||||
to_pt=to_pt,
|
||||
))
|
||||
|
||||
return annotations
|
||||
|
||||
|
||||
def _cluster_segments(segments: list[tuple]) -> list[list[tuple]]:
|
||||
"""将线段按方向和空间距离聚类。"""
|
||||
clusters: list[list[tuple]] = []
|
||||
used = [False] * len(segments)
|
||||
|
||||
for i, (x1, y1, x2, y2) in enumerate(segments):
|
||||
if used[i]:
|
||||
continue
|
||||
cluster = [(x1, y1, x2, y2)]
|
||||
used[i] = True
|
||||
angle_i = math.atan2(y2 - y1, x2 - x1)
|
||||
|
||||
for j in range(i + 1, len(segments)):
|
||||
if used[j]:
|
||||
continue
|
||||
x3, y3, x4, y4 = segments[j]
|
||||
angle_j = math.atan2(y4 - y3, x4 - x3)
|
||||
angle_diff = abs(angle_i - angle_j)
|
||||
if angle_diff > math.pi:
|
||||
angle_diff = 2 * math.pi - angle_diff
|
||||
|
||||
if angle_diff < 0.35:
|
||||
d1 = math.hypot(x3 - x2, y3 - y2)
|
||||
d2 = math.hypot(x1 - x4, y1 - y4)
|
||||
d3 = math.hypot(x3 - x1, y3 - y1)
|
||||
d4 = math.hypot(x4 - x2, y4 - y2)
|
||||
if min(d1, d2, d3, d4) < 80:
|
||||
cluster.append((x3, y3, x4, y4))
|
||||
used[j] = True
|
||||
|
||||
clusters.append(cluster)
|
||||
|
||||
return clusters
|
||||
|
||||
|
||||
def _find_arrow_direction(edges: np.ndarray, p1: tuple, p2: tuple) -> tuple[tuple, tuple]:
|
||||
"""判断箭头的方向(哪端是箭头/三角形汇聚点)。"""
|
||||
r = 20
|
||||
h, w = edges.shape[:2]
|
||||
|
||||
def edge_density(cx, cy):
|
||||
x1 = max(0, int(cx - r))
|
||||
y1 = max(0, int(cy - r))
|
||||
x2 = min(w, int(cx + r))
|
||||
y2 = min(h, int(cy + r))
|
||||
roi = edges[y1:y2, x1:x2]
|
||||
if roi.size == 0:
|
||||
return 0
|
||||
return float(np.count_nonzero(roi)) / roi.size
|
||||
|
||||
d1 = edge_density(p1[0], p1[1])
|
||||
d2 = edge_density(p2[0], p2[1])
|
||||
|
||||
if d1 > d2 * 1.3:
|
||||
return p2, p1
|
||||
if d2 > d1 * 1.3:
|
||||
return p1, p2
|
||||
return p1, p2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OCR 关联
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _correlate_with_ocr(
|
||||
annotations: list[Annotation],
|
||||
ocr_elements: list[dict],
|
||||
img_w: int,
|
||||
img_h: int,
|
||||
) -> None:
|
||||
"""将批注与附近的 OCR 文本关联。"""
|
||||
if not ocr_elements:
|
||||
return
|
||||
|
||||
for ann in annotations:
|
||||
ax = ann.center[0]
|
||||
ay = ann.center[1]
|
||||
|
||||
near_texts: list[tuple[str, float]] = []
|
||||
|
||||
for elem in ocr_elements:
|
||||
bbox = elem.get("bbox", {})
|
||||
ex = bbox.get("x", 0) + bbox.get("w", 0) / 2
|
||||
ey = bbox.get("y", 0) + bbox.get("h", 0) / 2
|
||||
dist = math.hypot(ax - ex, ay - ey)
|
||||
max_dist = max(img_w, img_h) * 0.15
|
||||
if dist < max_dist:
|
||||
near_texts.append((elem.get("text", ""), dist))
|
||||
|
||||
near_texts.sort(key=lambda x: x[1])
|
||||
ann.nearby_texts = [t for t, _ in near_texts[:5]]
|
||||
|
||||
if ann.type == "arrow" and ann.from_pt and ann.to_pt:
|
||||
ann.from_text = _closest_text(ann.from_pt, ocr_elements, img_w, img_h)
|
||||
ann.to_text = _closest_text(ann.to_pt, ocr_elements, img_w, img_h)
|
||||
|
||||
|
||||
def _closest_text(pt: tuple[int, int], ocr_elements: list[dict], img_w: int, img_h: int) -> str:
|
||||
"""找到离 pt 最近的 OCR 文本。"""
|
||||
best_text = ""
|
||||
best_dist = max(img_w, img_h) * 0.12
|
||||
for elem in ocr_elements:
|
||||
bbox = elem.get("bbox", {})
|
||||
ex = bbox.get("x", 0) + bbox.get("w", 0) / 2
|
||||
ey = bbox.get("y", 0) + bbox.get("h", 0) / 2
|
||||
dist = math.hypot(pt[0] - ex, pt[1] - ey)
|
||||
if dist < best_dist:
|
||||
best_dist = dist
|
||||
best_text = elem.get("text", "")
|
||||
return best_text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLM 上下文格式化
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def format_annotation_context(annotation_result: dict) -> str:
|
||||
"""将批注检测结果格式化为中文 LLM 提示文本。"""
|
||||
if not annotation_result or not isinstance(annotation_result, dict):
|
||||
return ""
|
||||
|
||||
circles = annotation_result.get("circles", [])
|
||||
arrows = annotation_result.get("arrows", [])
|
||||
total = annotation_result.get("total", len(circles) + len(arrows))
|
||||
|
||||
if total == 0:
|
||||
return ""
|
||||
|
||||
parts = ["[图片批注检测结果]"]
|
||||
|
||||
if circles:
|
||||
parts.append(f"\n检测到 {len(circles)} 个圈选标记:")
|
||||
for i, c in enumerate(circles):
|
||||
center = c.get("center", [0, 0])
|
||||
near = c.get("nearby_texts", [])
|
||||
parts.append(
|
||||
f" 圈{i+1}. 位置 ({center[0]},{center[1]})"
|
||||
f" — 圈选内容: {', '.join(near) if near else '(附近无文字)'}"
|
||||
)
|
||||
|
||||
if arrows:
|
||||
parts.append(f"\n检测到 {len(arrows)} 个箭头标记:")
|
||||
for i, a in enumerate(arrows):
|
||||
ft = a.get("from_text", "")
|
||||
tt = a.get("to_text", "")
|
||||
parts.append(f" 箭头{i+1}. 从「{ft}」→ 指向「{tt}」")
|
||||
|
||||
parts.append("\n请根据上述圈选/箭头定位用户要修改的报表字段。")
|
||||
return "\n".join(parts)
|
||||
+95
-38
@@ -52,6 +52,8 @@ def parse_file(file_path: str, file_type: str = "") -> dict:
|
||||
".pdf": _parse_pdf,
|
||||
".docx": _parse_docx,
|
||||
".xlsx": _parse_xlsx,
|
||||
".xls": _parse_xls,
|
||||
".doc": _parse_doc,
|
||||
}
|
||||
|
||||
parser = parsers.get(suffix)
|
||||
@@ -73,26 +75,7 @@ def _parse_image(path: Path) -> dict:
|
||||
except Exception:
|
||||
info = "[图片: 无法读取元数据]"
|
||||
|
||||
# 优先 EasyOCR(Windows 兼容性更好)
|
||||
try:
|
||||
import easyocr
|
||||
import numpy as np
|
||||
reader = easyocr.Reader(["ch_sim", "en"], gpu=False, verbose=False)
|
||||
result = reader.readtext(np.array(img))
|
||||
lines = [text.strip() for (_, text, _) in result if text.strip()]
|
||||
if lines:
|
||||
return {
|
||||
"text": f"{info}\n识别文本:\n" + "\n".join(lines),
|
||||
"file_type": "image",
|
||||
"method": "easyocr",
|
||||
"error": None,
|
||||
}
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 回退 PaddleOCR
|
||||
# 优先 PaddleOCR(精确识别)
|
||||
try:
|
||||
from paddleocr import PaddleOCR
|
||||
ocr = PaddleOCR(lang="ch")
|
||||
@@ -115,6 +98,25 @@ def _parse_image(path: Path) -> dict:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 回退 EasyOCR
|
||||
try:
|
||||
import easyocr
|
||||
import numpy as np
|
||||
reader = easyocr.Reader(["ch_sim", "en"], gpu=False, verbose=False)
|
||||
result = reader.readtext(np.array(img))
|
||||
lines = [text.strip() for (_, text, _) in result if text.strip()]
|
||||
if lines:
|
||||
return {
|
||||
"text": f"{info}\n识别文本:\n" + "\n".join(lines),
|
||||
"file_type": "image",
|
||||
"method": "easyocr",
|
||||
"error": None,
|
||||
}
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# OCR 不可用 → 返回图片元信息 + 安装提示
|
||||
return {
|
||||
"text": f"{info}\n(如需 OCR 文字识别,请安装: pip install easyocr)",
|
||||
@@ -197,36 +199,91 @@ def _parse_docx(path: Path) -> dict:
|
||||
|
||||
|
||||
def _parse_xlsx(path: Path) -> dict:
|
||||
"""提取 Excel (.xlsx) 表格内容为文本。"""
|
||||
"""提取 Excel .xlsx 文件中的文本。"""
|
||||
try:
|
||||
import openpyxl
|
||||
wb = openpyxl.load_workbook(path, read_only=True, data_only=True)
|
||||
sheets_text = []
|
||||
for sheet_name in wb.sheetnames:
|
||||
ws = wb[sheet_name]
|
||||
from openpyxl import load_workbook
|
||||
wb = load_workbook(path, read_only=True, data_only=True)
|
||||
parts = []
|
||||
for name in wb.sheetnames:
|
||||
ws = wb[name]
|
||||
rows = []
|
||||
for row in ws.iter_rows(values_only=True):
|
||||
cells = [str(c) if c is not None else "" for c in row]
|
||||
if any(c.strip() for c in cells):
|
||||
rows.append(" | ".join(cells))
|
||||
if any(c for c in cells):
|
||||
rows.append("\t".join(cells))
|
||||
if rows:
|
||||
sheets_text.append(f"--- 工作表: {sheet_name} ---\n" + "\n".join(rows))
|
||||
parts.append(f"[Sheet: {name}]\n" + "\n".join(rows))
|
||||
wb.close()
|
||||
if sheets_text:
|
||||
return {
|
||||
"text": "\n\n".join(sheets_text),
|
||||
"file_type": "xlsx",
|
||||
"method": "openpyxl",
|
||||
"error": None,
|
||||
}
|
||||
text = "\n\n".join(parts)
|
||||
return {"text": text, "file_type": "xlsx", "method": "openpyxl", "error": None}
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
return {"text": "", "file_type": "xlsx", "method": "none",
|
||||
"error": f"XLSX 解析失败: {e}"}
|
||||
return {"text": "", "file_type": "xlsx", "method": "none",
|
||||
"error": "XLSX 解析需要安装 openpyxl"}
|
||||
|
||||
|
||||
def _parse_xls(path: Path) -> dict:
|
||||
"""提取旧版 Excel .xls 文件中的文本。"""
|
||||
try:
|
||||
import xlrd
|
||||
wb = xlrd.open_workbook(path)
|
||||
parts = []
|
||||
for name in wb.sheet_names():
|
||||
ws = wb.sheet_by_name(name)
|
||||
rows = []
|
||||
for rx in range(ws.nrows):
|
||||
cells = [str(ws.cell_value(rx, cx)) if ws.cell_value(rx, cx) != "" else ""
|
||||
for cx in range(ws.ncols)]
|
||||
if any(c for c in cells):
|
||||
rows.append("\t".join(cells))
|
||||
if rows:
|
||||
parts.append(f"[Sheet: {name}]\n" + "\n".join(rows))
|
||||
text = "\n\n".join(parts)
|
||||
return {"text": text, "file_type": "xls", "method": "xlrd", "error": None}
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception as e:
|
||||
return {"text": "", "file_type": "xls", "method": "none",
|
||||
"error": f"XLS 解析失败: {e}"}
|
||||
return {"text": "", "file_type": "xls", "method": "none",
|
||||
"error": "XLS 解析需要安装 xlrd"}
|
||||
|
||||
|
||||
def _parse_doc(path: Path) -> dict:
|
||||
"""提取旧版 Word .doc 文件中的文本(尽力而为,二进制格式)。"""
|
||||
try:
|
||||
import olefile
|
||||
ole = olefile.OleFileIO(path)
|
||||
if not ole.exists("WordDocument"):
|
||||
ole.close()
|
||||
return {"text": "", "file_type": "doc", "method": "none",
|
||||
"error": "不是有效的 .doc 文件"}
|
||||
raw = ole.openstream("WordDocument").read()
|
||||
ole.close()
|
||||
# 提取可打印 UTF-16LE 字符段
|
||||
text = ""
|
||||
try:
|
||||
decoded = raw.decode("utf-16-le", errors="ignore")
|
||||
text = "".join(c for c in decoded if c.isprintable() or c in "\n\r\t")
|
||||
except Exception:
|
||||
pass
|
||||
if not text.strip():
|
||||
return {"text": "", "file_type": "doc", "method": "olefile",
|
||||
"error": "无法提取文本(.doc 为二进制格式,建议转换为 .docx)"}
|
||||
return {"text": text.strip(), "file_type": "doc", "method": "olefile", "error": None}
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception as e:
|
||||
return {"text": "", "file_type": "doc", "method": "none",
|
||||
"error": f"DOC 解析失败: {e}"}
|
||||
return {"text": "", "file_type": "doc", "method": "none",
|
||||
"error": "DOC 解析需要安装 olefile"}
|
||||
|
||||
|
||||
|
||||
def _parse_text(path: Path) -> dict:
|
||||
"""读取纯文本文件。"""
|
||||
try:
|
||||
|
||||
+174
-34
@@ -119,6 +119,146 @@ def analyze_layout(
|
||||
}
|
||||
|
||||
|
||||
def extract_layout_schema(layout_result: dict) -> dict:
|
||||
"""将 analyze_layout() 的完整 OCR 行数据压缩为高层布局 schema。
|
||||
|
||||
列检测:跨所有行对元素 X 坐标进行聚类。
|
||||
区域分类:启发式识别标题/表头/数据/表尾行。
|
||||
输出紧凑的 schema_text,供 LLM 阶段一骨架生成使用。
|
||||
"""
|
||||
rows = layout_result.get("rows", [])
|
||||
if not rows:
|
||||
return _empty_schema()
|
||||
|
||||
img_w, img_h = layout_result.get("image_size", (595, 842))
|
||||
if img_w <= 0:
|
||||
img_w = 595
|
||||
|
||||
all_elements = []
|
||||
for row in rows:
|
||||
all_elements.extend(row.get("elements", []))
|
||||
if not all_elements:
|
||||
return _empty_schema()
|
||||
|
||||
x_centers = sorted((e["x"] + e["w"] / 2) for e in all_elements)
|
||||
avg_width = sum(e["w"] for e in all_elements) / len(all_elements)
|
||||
cluster_threshold = avg_width * 0.5
|
||||
|
||||
clusters = []
|
||||
current_cluster = [x_centers[0]]
|
||||
for xc in x_centers[1:]:
|
||||
if xc - current_cluster[-1] < cluster_threshold:
|
||||
current_cluster.append(xc)
|
||||
else:
|
||||
clusters.append(current_cluster)
|
||||
current_cluster = [xc]
|
||||
if current_cluster:
|
||||
clusters.append(current_cluster)
|
||||
|
||||
columns = []
|
||||
for ci, cluster in enumerate(clusters):
|
||||
cx_min = min(cluster)
|
||||
cx_max = max(cluster)
|
||||
col_elements = [
|
||||
e for e in all_elements
|
||||
if cx_min - cluster_threshold <= (e["x"] + e["w"] / 2) <= cx_max + cluster_threshold
|
||||
]
|
||||
avg_w = sum(e["w"] for e in col_elements) / len(col_elements) if col_elements else 0
|
||||
x_start = min(e["x"] for e in col_elements)
|
||||
|
||||
col_elements_by_y = sorted(col_elements, key=lambda e: e["y"])
|
||||
header_text = col_elements_by_y[0]["text"] if col_elements_by_y else f"列{ci+1}"
|
||||
|
||||
columns.append({
|
||||
"index": ci,
|
||||
"header_text": header_text,
|
||||
"avg_width": round(avg_w, 1),
|
||||
"x_start": round(x_start, 1),
|
||||
})
|
||||
|
||||
columns.sort(key=lambda c: c["x_start"])
|
||||
|
||||
row_element_counts = [len(r.get("elements", [])) for r in rows]
|
||||
median_count = sorted(row_element_counts)[len(row_element_counts) // 2] if row_element_counts else 0
|
||||
total_rows = len(rows)
|
||||
|
||||
regions = []
|
||||
current_region = None
|
||||
|
||||
for ri in range(total_rows):
|
||||
count = row_element_counts[ri]
|
||||
if ri == 0 and count < median_count * 0.6 and total_rows > 2:
|
||||
rtype = "title"
|
||||
elif ri == 0 and total_rows <= 2:
|
||||
rtype = "header"
|
||||
elif ri == 1 and total_rows > 2:
|
||||
rtype = "header" if median_count > 0 else "data"
|
||||
elif ri >= total_rows - 2 and count < median_count * 0.7 and total_rows > 3:
|
||||
rtype = "footer"
|
||||
else:
|
||||
rtype = "data"
|
||||
|
||||
if current_region and current_region["type"] == rtype:
|
||||
current_region["row_indices"].append(ri)
|
||||
current_region["element_count"] += count
|
||||
else:
|
||||
if current_region:
|
||||
regions.append(current_region)
|
||||
current_region = {"type": rtype, "row_indices": [ri], "element_count": count}
|
||||
|
||||
if current_region:
|
||||
regions.append(current_region)
|
||||
|
||||
# schema_text
|
||||
width_ratios = [c["avg_width"] / img_w for c in columns]
|
||||
width_labels = []
|
||||
for r in width_ratios:
|
||||
if r < 0.08:
|
||||
width_labels.append("窄")
|
||||
elif r > 0.20:
|
||||
width_labels.append("宽")
|
||||
else:
|
||||
width_labels.append("中")
|
||||
|
||||
col_descs = []
|
||||
for ci, col in enumerate(columns):
|
||||
wl = width_labels[ci] if ci < len(width_labels) else "中"
|
||||
col_descs.append(f"{col['header_text']}({wl})")
|
||||
|
||||
_rn = {"title": "标题", "header": "表头", "data": "数据", "footer": "表尾"}
|
||||
region_parts = []
|
||||
for r in regions:
|
||||
label = _rn.get(r["type"], r["type"])
|
||||
region_parts.append(f"{label}({len(r['row_indices'])}行)")
|
||||
region_summary = " → ".join(region_parts)
|
||||
|
||||
schema_text = (
|
||||
f"报表布局: {len(columns)}列 x {total_rows}行, A4纵向\n"
|
||||
f"列定义: {', '.join(col_descs)}\n"
|
||||
f"区域: {region_summary}"
|
||||
)
|
||||
|
||||
return {
|
||||
"columns": columns,
|
||||
"regions": regions,
|
||||
"total_rows": total_rows,
|
||||
"total_columns": len(columns),
|
||||
"a4_dimensions": {"width": 595, "height": 842},
|
||||
"schema_text": schema_text,
|
||||
}
|
||||
|
||||
|
||||
def _empty_schema() -> dict:
|
||||
return {
|
||||
"columns": [],
|
||||
"regions": [],
|
||||
"total_rows": 0,
|
||||
"total_columns": 0,
|
||||
"a4_dimensions": {"width": 595, "height": 842},
|
||||
"schema_text": "无法解析报表布局",
|
||||
}
|
||||
|
||||
|
||||
def match_rows_to_jrxml(
|
||||
layout_result: dict,
|
||||
current_jrxml: str,
|
||||
@@ -373,40 +513,7 @@ def _load_image(path: Path) -> Optional[PIL.Image.Image]:
|
||||
def _ocr_elements(img: PIL.Image.Image, file_path: str) -> list[dict]:
|
||||
"""OCR 提取图片中的文字元素(位置+内容)。优先 EasyOCR,回退 PaddleOCR。"""
|
||||
|
||||
# 优先 EasyOCR
|
||||
try:
|
||||
import easyocr
|
||||
import numpy as np
|
||||
|
||||
reader = easyocr.Reader(["ch_sim", "en"], gpu=False, verbose=False)
|
||||
result = reader.readtext(np.array(img))
|
||||
|
||||
elements = []
|
||||
for (bbox, text, confidence) in result:
|
||||
if not text.strip():
|
||||
continue
|
||||
xs = [p[0] for p in bbox]
|
||||
ys = [p[1] for p in bbox]
|
||||
x_min, x_max = min(xs), max(xs)
|
||||
y_min, y_max = min(ys), max(ys)
|
||||
|
||||
elements.append({
|
||||
"x": round(x_min, 1),
|
||||
"y": round(y_min, 1),
|
||||
"w": round(x_max - x_min, 1),
|
||||
"h": round(y_max - y_min, 1),
|
||||
"font_size": round(y_max - y_min, 1),
|
||||
"text": text.strip(),
|
||||
})
|
||||
|
||||
elements.sort(key=lambda e: (e["y"], e["x"]))
|
||||
return elements
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 回退 PaddleOCR
|
||||
# 优先 PaddleOCR(精确识别)
|
||||
try:
|
||||
from paddleocr import PaddleOCR
|
||||
import numpy as np
|
||||
@@ -446,6 +553,39 @@ def _ocr_elements(img: PIL.Image.Image, file_path: str) -> list[dict]:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 回退 EasyOCR
|
||||
try:
|
||||
import easyocr
|
||||
import numpy as np
|
||||
|
||||
reader = easyocr.Reader(["ch_sim", "en"], gpu=False, verbose=False)
|
||||
result = reader.readtext(np.array(img))
|
||||
|
||||
elements = []
|
||||
for (bbox, text, confidence) in result:
|
||||
if not text.strip():
|
||||
continue
|
||||
xs = [p[0] for p in bbox]
|
||||
ys = [p[1] for p in bbox]
|
||||
x_min, x_max = min(xs), max(xs)
|
||||
y_min, y_max = min(ys), max(ys)
|
||||
|
||||
elements.append({
|
||||
"x": round(x_min, 1),
|
||||
"y": round(y_min, 1),
|
||||
"w": round(x_max - x_min, 1),
|
||||
"h": round(y_max - y_min, 1),
|
||||
"font_size": round(y_max - y_min, 1),
|
||||
"text": text.strip(),
|
||||
})
|
||||
|
||||
elements.sort(key=lambda e: (e["y"], e["x"]))
|
||||
return elements
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return []
|
||||
|
||||
|
||||
|
||||
@@ -284,13 +284,13 @@ class OcrExtractor:
|
||||
try:
|
||||
import numpy as np
|
||||
|
||||
easyocr_result = self._try_easyocr(np.array(img))
|
||||
if easyocr_result:
|
||||
return easyocr_result
|
||||
|
||||
paddleocr_result = self._try_paddleocr(img, file_path)
|
||||
if paddleocr_result:
|
||||
return paddleocr_result
|
||||
|
||||
easyocr_result = self._try_easyocr(np.array(img))
|
||||
if easyocr_result:
|
||||
return easyocr_result
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
Reference in New Issue
Block a user