feat: layered precise generation for A4 report images
3-phase pipeline to solve LLM prompt overflow from too many OCR elements:
Phase 1 (generate_skeleton): compressed layout schema → skeleton JRXML
Phase 2 (refine_layout): sampled coordinates → pixel-level position tuning
Phase 3 (map_fields): OCR field names → replace $F{field_N} placeholders
Only triggered when layout_schema.total_rows > 0 on initial_generation intent.
Text requests and all other intents are unaffected (zero behavior change).
This commit is contained in:
@@ -119,6 +119,146 @@ def analyze_layout(
|
||||
}
|
||||
|
||||
|
||||
def extract_layout_schema(layout_result: dict) -> dict:
|
||||
"""将 analyze_layout() 的完整 OCR 行数据压缩为高层布局 schema。
|
||||
|
||||
列检测:跨所有行对元素 X 坐标进行聚类。
|
||||
区域分类:启发式识别标题/表头/数据/表尾行。
|
||||
输出紧凑的 schema_text,供 LLM 阶段一骨架生成使用。
|
||||
"""
|
||||
rows = layout_result.get("rows", [])
|
||||
if not rows:
|
||||
return _empty_schema()
|
||||
|
||||
img_w, img_h = layout_result.get("image_size", (595, 842))
|
||||
if img_w <= 0:
|
||||
img_w = 595
|
||||
|
||||
all_elements = []
|
||||
for row in rows:
|
||||
all_elements.extend(row.get("elements", []))
|
||||
if not all_elements:
|
||||
return _empty_schema()
|
||||
|
||||
x_centers = sorted((e["x"] + e["w"] / 2) for e in all_elements)
|
||||
avg_width = sum(e["w"] for e in all_elements) / len(all_elements)
|
||||
cluster_threshold = avg_width * 0.5
|
||||
|
||||
clusters = []
|
||||
current_cluster = [x_centers[0]]
|
||||
for xc in x_centers[1:]:
|
||||
if xc - current_cluster[-1] < cluster_threshold:
|
||||
current_cluster.append(xc)
|
||||
else:
|
||||
clusters.append(current_cluster)
|
||||
current_cluster = [xc]
|
||||
if current_cluster:
|
||||
clusters.append(current_cluster)
|
||||
|
||||
columns = []
|
||||
for ci, cluster in enumerate(clusters):
|
||||
cx_min = min(cluster)
|
||||
cx_max = max(cluster)
|
||||
col_elements = [
|
||||
e for e in all_elements
|
||||
if cx_min - cluster_threshold <= (e["x"] + e["w"] / 2) <= cx_max + cluster_threshold
|
||||
]
|
||||
avg_w = sum(e["w"] for e in col_elements) / len(col_elements) if col_elements else 0
|
||||
x_start = min(e["x"] for e in col_elements)
|
||||
|
||||
col_elements_by_y = sorted(col_elements, key=lambda e: e["y"])
|
||||
header_text = col_elements_by_y[0]["text"] if col_elements_by_y else f"列{ci+1}"
|
||||
|
||||
columns.append({
|
||||
"index": ci,
|
||||
"header_text": header_text,
|
||||
"avg_width": round(avg_w, 1),
|
||||
"x_start": round(x_start, 1),
|
||||
})
|
||||
|
||||
columns.sort(key=lambda c: c["x_start"])
|
||||
|
||||
row_element_counts = [len(r.get("elements", [])) for r in rows]
|
||||
median_count = sorted(row_element_counts)[len(row_element_counts) // 2] if row_element_counts else 0
|
||||
total_rows = len(rows)
|
||||
|
||||
regions = []
|
||||
current_region = None
|
||||
|
||||
for ri in range(total_rows):
|
||||
count = row_element_counts[ri]
|
||||
if ri == 0 and count < median_count * 0.6 and total_rows > 2:
|
||||
rtype = "title"
|
||||
elif ri == 0 and total_rows <= 2:
|
||||
rtype = "header"
|
||||
elif ri == 1 and total_rows > 2:
|
||||
rtype = "header" if median_count > 0 else "data"
|
||||
elif ri >= total_rows - 2 and count < median_count * 0.7 and total_rows > 3:
|
||||
rtype = "footer"
|
||||
else:
|
||||
rtype = "data"
|
||||
|
||||
if current_region and current_region["type"] == rtype:
|
||||
current_region["row_indices"].append(ri)
|
||||
current_region["element_count"] += count
|
||||
else:
|
||||
if current_region:
|
||||
regions.append(current_region)
|
||||
current_region = {"type": rtype, "row_indices": [ri], "element_count": count}
|
||||
|
||||
if current_region:
|
||||
regions.append(current_region)
|
||||
|
||||
# schema_text
|
||||
width_ratios = [c["avg_width"] / img_w for c in columns]
|
||||
width_labels = []
|
||||
for r in width_ratios:
|
||||
if r < 0.08:
|
||||
width_labels.append("窄")
|
||||
elif r > 0.20:
|
||||
width_labels.append("宽")
|
||||
else:
|
||||
width_labels.append("中")
|
||||
|
||||
col_descs = []
|
||||
for ci, col in enumerate(columns):
|
||||
wl = width_labels[ci] if ci < len(width_labels) else "中"
|
||||
col_descs.append(f"{col['header_text']}({wl})")
|
||||
|
||||
_rn = {"title": "标题", "header": "表头", "data": "数据", "footer": "表尾"}
|
||||
region_parts = []
|
||||
for r in regions:
|
||||
label = _rn.get(r["type"], r["type"])
|
||||
region_parts.append(f"{label}({len(r['row_indices'])}行)")
|
||||
region_summary = " → ".join(region_parts)
|
||||
|
||||
schema_text = (
|
||||
f"报表布局: {len(columns)}列 x {total_rows}行, A4纵向\n"
|
||||
f"列定义: {', '.join(col_descs)}\n"
|
||||
f"区域: {region_summary}"
|
||||
)
|
||||
|
||||
return {
|
||||
"columns": columns,
|
||||
"regions": regions,
|
||||
"total_rows": total_rows,
|
||||
"total_columns": len(columns),
|
||||
"a4_dimensions": {"width": 595, "height": 842},
|
||||
"schema_text": schema_text,
|
||||
}
|
||||
|
||||
|
||||
def _empty_schema() -> dict:
|
||||
return {
|
||||
"columns": [],
|
||||
"regions": [],
|
||||
"total_rows": 0,
|
||||
"total_columns": 0,
|
||||
"a4_dimensions": {"width": 595, "height": 842},
|
||||
"schema_text": "无法解析报表布局",
|
||||
}
|
||||
|
||||
|
||||
def match_rows_to_jrxml(
|
||||
layout_result: dict,
|
||||
current_jrxml: str,
|
||||
|
||||
Reference in New Issue
Block a user