Files
agent_jrxml/backend/layout_analyzer.py
T
panda 9bb011e429 feat: v4 multimodal chat input, multi-format support, and annotation detection
- Replace st.chat_input with st-multimodal-chatinput (Ctrl+V paste, drag-drop, file button)
- Extract _process_uploaded_file() shared handler (eliminates ~70 duplicated lines)
- Add XLSX (openpyxl), XLS (xlrd), DOC (olefile) parsers to file_parser.py
- Add backend/annotation_detector.py: circle detection (HoughCircles) + arrow detection (HoughLinesP clustering) + OCR correlation + LLM context formatting
- Add annotation_result field to AgentState with session persistence
- Wire annotation detection into process_input and _format_ocr_context
- Add 11 new tests: 7 annotation detector + 4 multi-format parser
- Update all docs: CLAUDE.md, README.md, CODE_GUIDE.md, ROADMAP.md
2026-05-20 23:43:16 +08:00

533 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Layout analyzer for A4 report-template images.

Loads an uploaded image (or the first page of a PDF) and, row by row,
identifies for every text element:
- its position (x, y, w, h) in image pixels
- its font size (estimated from the OCR bounding-box height)
- its text content

Three modes are supported:
- Full A4 template: A4-like aspect ratio and >= 2 OCR elements
  -> full layout description
- Row snippet (not a full A4 template, but has text elements): treated as a
  few rows of an A4 template -> partial layout description
- Modification matching: match the image rows against an existing JRXML to
  locate where a change should be applied

Usage:
    from backend.layout_analyzer import analyze_layout, match_rows_to_jrxml
    result = analyze_layout("row_snippet.png")
    # result["template_type"] == "partial_rows"
    match = match_rows_to_jrxml(result, current_jrxml)
    # match["matched_rows"] == [{"row_index": 0, "jrxml_section": "detail", ...}]
"""
import html
import logging
import re
import xml.etree.ElementTree as ET
from pathlib import Path
from typing import Optional

import PIL.Image
# A4 paper is 210 mm x 297 mm, so short side / long side = 210 / 297 ~= 0.707.
A4_RATIO = 210 / 297
# Aspect-ratio bands checked by analyze_layout: "exact" spans roughly +/-3%
# around A4_RATIO, "close" roughly +/-8%.
A4_RATIO_EXACT_MIN = 0.686
A4_RATIO_EXACT_MAX = 0.728
A4_RATIO_CLOSE_MIN = 0.650
A4_RATIO_CLOSE_MAX = 0.764
def analyze_layout(
    file_path: str,
    row_tolerance_ratio: float = 0.02,
) -> dict:
    """Analyse the report-template layout of an image (or the first PDF page).

    Pipeline: load the file, classify its aspect ratio against A4, OCR the
    text boxes, group them into rows and render a textual description.

    Args:
        file_path: path to a raster image or a PDF.
        row_tolerance_ratio: row-grouping tolerance as a fraction of the image
            height (forwarded to ``_group_into_rows``).

    Returns:
        A dict with these keys::

            is_a4_template  bool   A4-like ratio ("exact"/"close") and >= 2 elements
            is_partial      bool   not a full A4 template but >= 1 element
            template_type   str    "full_a4" | "partial_rows" | "unknown"
            image_size      (w, h)
            aspect_ratio    float  short side / long side, rounded to 3 places
            a4_confidence   str    "exact" | "close" | "not_a4"
            rows            list   [{"y_center", "elements": [{x, y, w, h, font_size, text}]}]
            description     str
            total_rows      int
            total_elements  int

        A missing or unreadable file yields ``_empty_result`` with an error
        message as its description.
    """
    source = Path(file_path)
    if not source.exists():
        return _empty_result("文件不存在")

    image = _load_image(source)
    if image is None:
        return _empty_result("无法加载图片")

    width, height = image.size
    short_side, long_side = sorted((width, height))
    ratio = short_side / long_side

    # Classify the aspect ratio: the first band that contains it wins.
    a4_confidence = "not_a4"
    for label, low, high in (
        ("exact", A4_RATIO_EXACT_MIN, A4_RATIO_EXACT_MAX),
        ("close", A4_RATIO_CLOSE_MIN, A4_RATIO_CLOSE_MAX),
    ):
        if low <= ratio <= high:
            a4_confidence = label
            break

    elements = _ocr_elements(image, file_path)
    # Without OCR output every derived value below collapses to "unknown".
    rows = _group_into_rows(elements, height, row_tolerance_ratio) if elements else []
    element_count = sum(len(row["elements"]) for row in rows)

    # A full template needs an A4-like ratio AND at least two elements; any
    # other image with text (even an A4-like one with a single element) is
    # treated as a row snippet.
    is_full_a4 = a4_confidence != "not_a4" and element_count >= 2
    is_partial = not is_full_a4 and element_count >= 1
    if is_full_a4:
        template_type = "full_a4"
    elif is_partial:
        template_type = "partial_rows"
    else:
        template_type = "unknown"

    return {
        "is_a4_template": is_full_a4,
        "is_partial": is_partial,
        "template_type": template_type,
        "image_size": (width, height),
        "aspect_ratio": round(ratio, 3),
        "a4_confidence": a4_confidence,
        "rows": rows,
        "description": _build_description(rows, width, height, a4_confidence, template_type),
        "total_rows": len(rows),
        "total_elements": element_count,
    }
def match_rows_to_jrxml(
    layout_result: dict,
    current_jrxml: str,
    min_confidence: float = 0.3,
) -> dict:
    """Match the image rows of an ``analyze_layout`` result to JRXML bands.

    For every row, its OCR texts are scored with ``_text_similarity`` against
    each band returned by ``_parse_jrxml_sections``; the best-scoring band is
    accepted when its score is strictly greater than *min_confidence*.

    Args:
        layout_result: result dict of ``analyze_layout``; only ``rows`` is read.
        current_jrxml: JRXML source to match against. ``None`` or blank text
            means "nothing to match" and every row is reported unmatched.
        min_confidence: minimum similarity (exclusive) for a row to count as
            matched; defaults to 0.3, the previously hard-coded threshold.

    Returns:
        {
            "matched": bool,   # at least one row matched
            "matched_rows": [{row_index, row_y_center, jrxml_section,
                              jrxml_section_type, confidence, matched_text}],
            "unmatched_rows": [{row_index, row_y_center, ocr_texts}],
            "description": str,   # human-readable summary for the prompt
        }
    """
    rows = layout_result.get("rows") or []
    if not rows or not current_jrxml or not current_jrxml.strip():
        # Nothing to compare against: report every row as unmatched, using the
        # same entry shape as the main path below.
        return {
            "matched": False,
            "matched_rows": [],
            "unmatched_rows": [_unmatched_row(i, row) for i, row in enumerate(rows)],
            "description": "无行数据或 JRXML 为空",
        }

    sections = _parse_jrxml_sections(current_jrxml)
    matched_rows = []
    unmatched_rows = []
    for row_index, row in enumerate(rows):
        ocr_texts = [e["text"] for e in row["elements"]]
        best_section = None
        best_score = 0.0
        for section in sections:
            score = _text_similarity(ocr_texts, section["text_content"])
            if score > best_score:  # strict '>': the earlier section wins ties
                best_section, best_score = section, score
        if best_section is not None and best_score > min_confidence:
            matched_rows.append({
                "row_index": row_index,
                "row_y_center": row["y_center"],
                "jrxml_section": best_section["name"],
                "jrxml_section_type": best_section["type"],
                "confidence": round(best_score, 2),
                "matched_text": best_section["text_content"][:200],
            })
        else:
            unmatched_rows.append(_unmatched_row(row_index, row))

    desc_parts = []
    if matched_rows:
        desc_parts.append(f"图片中 {len(matched_rows)} 行匹配到当前 JRXML:")
        for m in matched_rows:
            desc_parts.append(
                f" - 图片第 {m['row_index'] + 1} 行 → JRXML「{m['jrxml_section']}」"
                f"({m['jrxml_section_type']}),置信度 {m['confidence']}"
            )
    if unmatched_rows:
        desc_parts.append(f"图片中 {len(unmatched_rows)} 行未匹配到 JRXML 现有区域:")
        for u in unmatched_rows:
            texts = ", ".join(u["ocr_texts"][:3])
            desc_parts.append(f" - 图片第 {u['row_index'] + 1} 行:{texts}")

    return {
        "matched": bool(matched_rows),
        "matched_rows": matched_rows,
        "unmatched_rows": unmatched_rows,
        "description": "\n".join(desc_parts),
    }


def _unmatched_row(row_index: int, row: dict) -> dict:
    """Build the ``unmatched_rows`` entry for the image row at *row_index*."""
    return {
        "row_index": row_index,
        "row_y_center": row["y_center"],
        "ocr_texts": [e["text"] for e in row["elements"]],
    }
def analyze_and_inject(file_path: str, base_prompt: str,
                       current_jrxml: str = "") -> str:
    """Prefix *base_prompt* with a layout analysis of *file_path*.

    - "full_a4": the full layout description is prepended.
    - "partial_rows" with a non-blank *current_jrxml*: the rows are matched
      against the JRXML bands; the prompt asks to modify the matched region,
      or, when nothing matched, to find a suitable place in the JRXML.
    - "partial_rows" without JRXML: the snippet is described and the model is
      asked to produce a whole A4 report.
    - any other template type: *base_prompt* is returned unchanged.
    """
    layout = analyze_layout(file_path)
    template_type = layout.get("template_type", "unknown")

    if template_type == "full_a4":
        return f"[图片模板分析 — 完整 A4 报表]\n{layout['description']}\n\n---\n原始需求:\n{base_prompt}"
    if template_type != "partial_rows":
        return base_prompt

    # Row snippet without an existing report: ask for a complete A4 report.
    if not current_jrxml.strip():
        return (
            f"[图片模板分析 — 行片段(无现有报表,按 A4 模板处理)]\n"
            f"图片包含 {layout['total_rows']} 行,请按 A4 报表模板的需求输出整张报表。\n"
            f"{layout['description']}\n\n"
            f"---\n原始需求:\n{base_prompt}"
        )

    match = match_rows_to_jrxml(layout, current_jrxml)
    if not match["matched"]:
        return (
            f"[图片模板分析 — 行片段(未匹配到现有区域)]\n"
            f"图片包含 {layout['total_rows']} 行。\n"
            f"{layout['description']}\n\n"
            f"---\n请根据以上行结构,在 JRXML 中找到合适位置进行修改:\n{base_prompt}"
        )
    return (
        f"[图片模板分析 — 行片段修改]\n"
        f"图片包含 {layout['total_rows']} 行,视为 A4 模板的一部分。\n"
        f"{match['description']}\n\n"
        f"{layout['description']}\n\n"
        f"---\n请根据以上匹配结果,修改 JRXML 中对应区域的布局:\n{base_prompt}"
    )
# ---------------------------------------------------------------------------
# JRXML 结构解析
# ---------------------------------------------------------------------------
def _parse_jrxml_sections(jrxml: str) -> list[dict]:
    """Parse the section/band structure of a JRXML document.

    Every ``<band>`` that is a direct child of a known section element
    (title, pageHeader, detail, groupHeader, ...) becomes one entry::

        {"name": str, "type": str, "text_content": str}

    ``type`` is the namespace-stripped section tag and ``text_content`` the
    serialized band XML. ``name`` is the section tag, suffixed with
    ``[label]`` when a label is available: the band's own ``name`` attribute
    if present, otherwise (for group headers/footers) the ``name`` of the
    enclosing ``<group>`` element, so bands of different groups are
    distinguishable.

    If the text is not well-formed XML, or no band is found, the regex-based
    ``_parse_jrxml_regex`` fallback is used instead.
    """
    section_tags = {
        "title", "pageHeader", "columnHeader", "detail",
        "columnFooter", "pageFooter", "summary", "background",
        "noData", "groupHeader", "groupFooter",
    }
    sections: list[dict] = []
    try:
        root = ET.fromstring(jrxml)
    except (ET.ParseError, ValueError):
        # Malformed XML (ValueError covers e.g. unsupported encoding
        # declarations): fall through to the regex fallback below.
        root = None

    if root is not None:
        # ElementTree keeps no parent links; map child -> parent so group
        # header/footer bands can be labelled with their <group name="...">.
        parent_of = {child: parent for parent in root.iter() for child in parent}
        for section_elem in root.iter():
            stag = _tag(section_elem)
            if stag not in section_tags:
                continue
            parent = parent_of.get(section_elem)
            group_name = ""
            if parent is not None and _tag(parent) == "group":
                group_name = parent.get("name", "")
            for child in section_elem:
                if _tag(child) != "band":
                    continue
                # NOTE(review): JRXML <band> elements do not seem to define a
                # "name" attribute (confirm against the schema); it is still
                # honoured first, with the group name as the fallback label.
                label = child.get("name", "") or group_name
                sections.append({
                    "name": f"{stag}[{label}]" if label else stag,
                    "type": stag,
                    "text_content": ET.tostring(child, encoding="unicode"),
                })

    if not sections:
        sections = _parse_jrxml_regex(jrxml)
    return sections
def _tag(elem) -> str:
    """Return *elem*'s tag name with any ``{namespace}`` prefix removed."""
    _, _, local_name = elem.tag.rpartition("}")
    return local_name
def _parse_jrxml_regex(jrxml: str) -> list[dict]:
    """Regex fallback: find ``<section><band ...>...</band></section>`` spans.

    Only un-prefixed section tags written without attributes are recognised.
    Each match yields ``{"name": tag, "type": tag, "text_content": span}``,
    where *span* is the whole match including the surrounding section tags.
    """
    section_band_re = re.compile(
        r'<(title|pageHeader|columnHeader|detail|columnFooter|pageFooter|summary|background|noData|groupHeader|groupFooter)>\s*'
        r'(<band[^>]*>.*?</band>)\s*'
        r'</\1>',
        re.DOTALL,
    )
    return [
        {
            "name": hit.group(1),
            "type": hit.group(1),
            "text_content": hit.group(0),
        }
        for hit in section_band_re.finditer(jrxml)
    ]
def _text_similarity(ocr_texts: list[str], jrxml_text: str) -> float:
    """Score how well a row's OCR texts occur in a piece of JRXML (0.0-1.0).

    Each OCR text contributes:
    - 1.0 if the whole (stripped) text occurs in the JRXML, case-insensitively;
    - otherwise 0.5 times the fraction of its word tokens that occur;
    - 0.0 if it is blank.
    The result is the mean contribution over all texts, capped at 1.0.

    XML character references in *jrxml_text* (``&amp;``, ``&lt;``, ...) are
    unescaped first, because serialized JRXML escapes them and OCR text such
    as "R&D" would otherwise never match literally.

    Note: this is plain substring search over raw XML, so tag and attribute
    names can also produce (spurious) matches.
    """
    if not ocr_texts or not jrxml_text:
        return 0.0
    haystack = html.unescape(jrxml_text).lower()
    score = 0.0
    for text in ocr_texts:
        needle = text.strip().lower()
        if not needle:
            continue  # a blank text is a substring of anything; don't count it
        if needle in haystack:
            score += 1.0
            continue
        # Partial credit: share of individual words found anywhere.
        words = re.findall(r"\w+", text)
        if words:
            hits = sum(1 for w in words if w.lower() in haystack)
            score += hits / len(words) * 0.5
    return min(score / len(ocr_texts), 1.0)
# ---------------------------------------------------------------------------
# Internal helpers: image loading, OCR, row grouping and description rendering
# ---------------------------------------------------------------------------
def _load_image(path: Path) -> Optional[PIL.Image.Image]:
    """Load *path* as an RGB PIL image, or return ``None`` if that fails.

    - .png/.jpg/.jpeg/.bmp/.webp are opened with Pillow.
    - .pdf: only the first page is rasterised at 150 DPI, trying pdfplumber
      first and PyMuPDF (``fitz``) second; both are optional imports.
    - Any other suffix returns ``None``.

    Load failures are logged at DEBUG level and never raised.
    """
    log = logging.getLogger(__name__)
    suffix = path.suffix.lower()

    if suffix in (".png", ".jpg", ".jpeg", ".bmp", ".webp"):
        try:
            # convert() returns a new image, so the source file can be closed
            # right away instead of whenever the lazy image is collected.
            with PIL.Image.open(path) as source:
                return source.convert("RGB")
        except Exception:
            log.debug("Could not open image %s", path, exc_info=True)
            return None

    if suffix != ".pdf":
        return None

    try:
        import pdfplumber
        with pdfplumber.open(path) as pdf:
            if pdf.pages:
                page_image = pdf.pages[0].to_image(resolution=150)
                return page_image.original.convert("RGB")
    except Exception:
        log.debug("pdfplumber could not render %s", path, exc_info=True)

    try:
        import fitz
        # The context manager closes the document even if rendering fails.
        with fitz.open(path) as doc:
            pix = doc[0].get_pixmap(dpi=150)
            return PIL.Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
    except Exception:
        log.debug("PyMuPDF could not render %s", path, exc_info=True)

    return None
def _ocr_elements(img: PIL.Image.Image, file_path: str) -> list[dict]:
    """Extract text elements (position + content) from *img* via OCR.

    Engines are tried in order: PaddleOCR first, EasyOCR as the fallback.
    An engine that is not installed (ImportError) or that raises is skipped,
    and the failure is logged; if no engine succeeds, ``[]`` is returned. An
    engine that runs successfully but finds no text also yields ``[]``; the
    next engine is not tried in that case.

    Each element is ``{"x", "y", "w", "h", "font_size", "text"}``: the
    axis-aligned bounding box of the detected polygon in image pixels (plain
    floats, rounded to 0.1), ``font_size`` estimated as the box height, and
    the stripped text. Elements are sorted top-to-bottom, then left-to-right.

    Args:
        img: image to scan.
        file_path: original file path; unused, kept for API compatibility.
    """
    log = logging.getLogger(__name__)
    for engine_name, run_engine in (
        ("PaddleOCR", _ocr_with_paddle),
        ("EasyOCR", _ocr_with_easyocr),
    ):
        try:
            elements = run_engine(img)
        except ImportError:
            log.debug("%s is not installed; skipping", engine_name)
            continue
        except Exception:
            log.warning("%s OCR failed", engine_name, exc_info=True)
            continue
        elements.sort(key=lambda e: (e["y"], e["x"]))
        return elements
    return []


def _ocr_with_paddle(img: PIL.Image.Image) -> list[dict]:
    """Run PaddleOCR on *img*; raises ImportError if paddleocr/numpy are missing."""
    from paddleocr import PaddleOCR
    import numpy as np

    # NOTE(review): the engine is constructed on every call, which is likely
    # slow; caching it needs confirmation that a shared instance is safe
    # across concurrent sessions.
    ocr = PaddleOCR(lang="ch")
    result = ocr.ocr(np.array(img))
    elements = []
    # Assumes each entry of result[0] looks like [box, (text, score)] (or
    # [box, text]) - TODO confirm against the pinned PaddleOCR version.
    if result and result[0]:
        for line in result[0]:
            if len(line) < 2:
                continue
            box, text_info = line[0], line[1]
            text = text_info[0] if isinstance(text_info, (list, tuple)) else str(text_info)
            element = _box_to_element(box, text)
            if element is not None:
                elements.append(element)
    return elements


def _ocr_with_easyocr(img: PIL.Image.Image) -> list[dict]:
    """Run EasyOCR on *img*; raises ImportError if easyocr/numpy are missing."""
    import easyocr
    import numpy as np

    reader = easyocr.Reader(["ch_sim", "en"], gpu=False, verbose=False)
    elements = []
    for bbox, text, _confidence in reader.readtext(np.array(img)):
        element = _box_to_element(bbox, text)
        if element is not None:
            elements.append(element)
    return elements


def _box_to_element(box, text: str) -> Optional[dict]:
    """Turn an OCR polygon (sequence of (x, y) points) plus text into an element.

    Returns ``None`` for blank text. Coordinates are cast to ``float`` so
    numpy scalars from the OCR engines do not leak into the result.
    """
    stripped = text.strip()
    if not stripped:
        return None
    xs = [float(p[0]) for p in box]
    ys = [float(p[1]) for p in box]
    x_min, x_max = min(xs), max(xs)
    y_min, y_max = min(ys), max(ys)
    height = round(y_max - y_min, 1)
    return {
        "x": round(x_min, 1),
        "y": round(y_min, 1),
        "w": round(x_max - x_min, 1),
        "h": height,
        "font_size": height,  # font size is approximated by the box height
        "text": stripped,
    }
def _group_into_rows(elements: list[dict], img_height: int,
                     tolerance_ratio: float = 0.02) -> list[dict]:
    """Group OCR elements into visual rows.

    *elements* is walked in order (``_ocr_elements`` returns them sorted
    top-to-bottom). An element joins the current row when its vertical centre
    (``y + h / 2``) is strictly closer than ``img_height * tolerance_ratio`` to
    the centre of the row's FIRST element; otherwise it starts a new row.
    Every finished row is turned into a record by ``_build_row``.
    """
    if not elements:
        return []
    max_gap = img_height * tolerance_ratio
    grouped: list[dict] = []
    pending = [elements[0]]
    anchor = elements[0]["y"] + elements[0]["h"] / 2
    for element in elements[1:]:
        centre = element["y"] + element["h"] / 2
        if abs(centre - anchor) < max_gap:
            pending.append(element)
            continue
        grouped.append(_build_row(pending))
        pending = [element]
        anchor = centre
    grouped.append(_build_row(pending))
    return grouped
def _build_row(elements: list[dict]) -> dict:
    """Build a row record from the OCR elements grouped into one visual row.

    Returns ``{"y_center": float, "elements": [...]}`` with the elements
    ordered left-to-right. ``y_center`` is the mean vertical centre
    (``y + h / 2``) of the elements, rounded to 0.1 px, which is the same
    centre definition ``_group_into_rows`` groups by. The input list is not
    modified; an empty input yields ``y_center`` 0.0 and no elements.
    """
    if not elements:
        return {"y_center": 0.0, "elements": []}
    ordered = sorted(elements, key=lambda e: e["x"])
    centres = [e["y"] + e["h"] / 2 for e in ordered]
    return {"y_center": round(sum(centres) / len(centres), 1), "elements": ordered}
def _element_label(index: int) -> str:
    """Return a spreadsheet-style label for a 0-based index: a..z, aa, ab, ..."""
    label = ""
    number = index + 1
    while number:
        number, remainder = divmod(number - 1, 26)
        label = chr(ord("a") + remainder) + label
    return label


def _build_description(rows: list[dict], img_w: int, img_h: int,
                       a4_confidence: str, template_type: str) -> str:
    """Render the detected rows as a human-readable layout description.

    Args:
        rows: row records (``{"elements": [...]}``); every element needs
            ``x``, ``y``, ``w``, ``h``, ``font_size`` and ``text``.
        img_w, img_h: image size in pixels.
        a4_confidence: not used by the rendering; kept for API compatibility.
        template_type: "full_a4", "partial_rows" or anything else; selects the
            header line and the closing instruction.

    Returns:
        A multi-line description, or a one-line "no text detected" message
        when *rows* is empty. Elements within a row are labelled a, b, ...,
        z, aa, ab, ... so rows with more than 26 elements stay readable.
    """
    if not rows:
        if template_type == "partial_rows":
            return f"图片 {img_w}x{img_h}(非 A4 比例),未检测到文字元素。"
        return f"图片共 {img_w}x{img_h} 像素,未检测到文字元素。"

    lines = []
    if template_type == "full_a4":
        lines.append(f"图片为完整 A4 报表模板,共 {len(rows)} 行,像素区域 {img_w}x{img_h}")
    elif template_type == "partial_rows":
        lines.append(f"图片为报表模板行片段(非完整 A4),包含 {len(rows)} 行,"
                     f"像素区域 {img_w}x{img_h},请按 A4 模板处理:")
    else:
        lines.append(f"图片共 {img_w}x{img_h} 像素,包含 {len(rows)} 行文字:")

    for row_no, row in enumerate(rows, start=1):
        elems = row["elements"]
        lines.append(f"\n第 {row_no} 行有 {len(elems)} 个元素:")
        for j, e in enumerate(elems):
            lines.append(
                f" 元素 {_element_label(j)}:位置(x={e['x']}, y={e['y']}),"
                f"宽 {e['w']}px,高 {e['h']}px,"
                f"字体 {e['font_size']}px,"
                f"内容「{e['text']}」"
            )

    if template_type == "full_a4":
        lines.append("\n请根据以上布局生成对应的 JRXML 报表模板。")
    elif template_type == "partial_rows":
        lines.append(f"\n请将以上 {len(rows)} 行作为 A4 模板的一部分,"
                     "生成或修改对应的 JRXML 报表区域。")
    return "\n".join(lines)
def _empty_result(error: str = "") -> dict:
    """Return an ``analyze_layout``-shaped result meaning "nothing detected".

    *error* becomes the ``description``; every flag is False, all counts are
    zero and a fresh empty ``rows`` list is created on each call.
    """
    return dict(
        is_a4_template=False,
        is_partial=False,
        template_type="unknown",
        image_size=(0, 0),
        aspect_ratio=0,
        a4_confidence="not_a4",
        rows=[],
        description=error,
        total_rows=0,
        total_elements=0,
    )