feat: add Java JRXML-to-PNG rendering pipeline with pixel-level SSIM comparison

- lib/java/: Java renderer (JrxmlRenderer) using JasperReports 6.21.0
  - JrxmlDebug for diagnostics, JrxmlGen for format reference
  - download_jars.sh for one-time dependency setup
- agent/nodes.py: _render_jrxml_to_png() and _compute_pixel_similarity()
  - Pixel comparison integrates into validate node (SSIM < 0.4 fails)
  - Pixel fidelity context injected into correct_jrxml for targeted fixes
- tests/test_pixel_comparison.py: 15 unit tests (render, SSIM, integration)
- .gitignore: exclude lib/java/*.jar, lib/java/*.class, tmp/
- CLAUDE.md: v11 changelog documenting the rendering pipeline
- All non-LLM tests pass (97/97)
This commit is contained in:
2026-05-23 15:09:55 +08:00
parent 9de75d2f25
commit bb6cc6e241
16 changed files with 837 additions and 8 deletions
+266 -2
View File
@@ -839,6 +839,188 @@ def modify_jrxml(state: AgentState) -> Dict:
return state
# ── Java renderer config ──────────────────────────────────────────────
# Location of the `java` launcher. On Windows the historical default JDK path
# is kept; elsewhere JAVA_HOME is honoured and, when it is unset, the bare
# command name is used so the OS resolves it through PATH.
if os.name == "nt":
    _JAVA_BIN = os.path.join(
        os.environ.get("JAVA_HOME", "C:/Program Files/Java/jdk-21.0.11"),
        "bin", "java.exe",
    )
elif os.environ.get("JAVA_HOME"):
    _JAVA_BIN = os.path.join(os.environ["JAVA_HOME"], "bin", "java")
else:
    _JAVA_BIN = "java"

# Directory holding the renderer jars: <parent of this file's directory>/lib/java.
_JAVA_JAR_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "lib", "java")
_JAVA_RENDERER_CLASS = "JrxmlRenderer"

# Jars placed on the classpath, in order.
_JAVA_RENDERER_JARS = (
    "jasperreports-6.21.0.jar",
    "commons-logging-1.3.5.jar",
    "commons-collections4-4.5.0.jar",
    "commons-beanutils-1.10.1.jar",
    "commons-lang3-3.17.0.jar",
    "commons-digester-2.1.jar",
    "itext-2.1.7.jar",
    "jfreechart-1.5.5.jar",
    "ecj-3.38.0.jar",
)

# "." comes first so classes in the java process's working directory are found.
# Every entry is joined with os.pathsep (";" on Windows, ":" elsewhere) so the
# classpath is well-formed on all platforms.
_JAVA_RENDERER_CP = os.pathsep.join(
    ["."] + [os.path.join(_JAVA_JAR_DIR, jar) for jar in _JAVA_RENDERER_JARS]
)
def _render_jrxml_to_png(jrxml: str, output_path: str, scale: float = 2.0) -> bool:
    """Render a JRXML document to a PNG by running the Java ``JrxmlRenderer``.

    The JRXML text is written to a uniquely named temporary file under the
    project's ``tmp`` directory, the Java class is invoked with
    ``<input> <output> <scale>`` as arguments, and the temporary input file is
    removed afterwards.

    Args:
        jrxml: JRXML source text to render.
        output_path: Destination PNG path. It is made absolute before being
            handed to Java, because the java process runs with a different
            working directory than this Python process.
        scale: Scale factor forwarded to the renderer as its third argument.

    Returns:
        True when the java process exits with code 0; False on a non-zero exit
        code or on any exception (I/O failure, launch failure, timeout).
    """
    import subprocess
    import tempfile

    tmpdir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "tmp")
    jrxml_path = None
    try:
        os.makedirs(tmpdir, exist_ok=True)
        # A unique file per call keeps concurrent renders from overwriting
        # each other's input.
        with tempfile.NamedTemporaryFile(
            "w", encoding="utf-8", suffix=".jrxml", prefix="_render_input_",
            dir=tmpdir, delete=False,
        ) as f:
            f.write(jrxml)
            jrxml_path = os.path.abspath(f.name)

        result = subprocess.run(
            [_JAVA_BIN, "-cp", _JAVA_RENDERER_CP, _JAVA_RENDERER_CLASS,
             jrxml_path, os.path.abspath(output_path), str(scale)],
            capture_output=True, text=True, timeout=120,
            # Undecodable bytes in the java output must not turn a successful
            # render into a failure.
            errors="replace",
            cwd=_JAVA_JAR_DIR,
        )
    except Exception as e:
        # Deliberate best-effort boundary: callers only branch on the bool.
        _node_log.warning("PNG render exception: %s", e)
        return False
    finally:
        if jrxml_path is not None:
            try:
                os.remove(jrxml_path)
            except OSError as cleanup_error:
                _node_log.warning(
                    "Could not remove temporary JRXML %s: %s", jrxml_path, cleanup_error
                )

    if result.returncode == 0:
        _node_log.info("PNG rendered: %s (%s)", output_path, result.stdout.strip())
        return True

    _node_log.warning(
        "PNG render failed: %s %s", result.stdout.strip(), result.stderr.strip()
    )
    return False
def _compute_pixel_similarity(rendered_png: str, reference_image: str) -> dict:
    """Compare a rendered PNG with a reference image at pixel level.

    Both files are loaded as grayscale. If their shapes differ, the rendered
    image is resized to the reference's dimensions. Two metrics are reported:
    the SSIM (structural similarity) score, and the fraction of pixels whose
    absolute grayscale difference exceeds 30.

    Returns:
        ``{"ssim": float, "diff_pct": float, "error": str | None}``. On any
        failure (unreadable image, missing dependency, other exception) the
        result is ``ssim=0.0``, ``diff_pct=1.0`` with ``error`` describing why.
    """

    def _failure(reason: str) -> dict:
        # Worst-case metrics paired with the reason the comparison failed.
        return {"ssim": 0.0, "diff_pct": 1.0, "error": reason}

    try:
        import cv2
        import numpy as np

        candidate = cv2.imread(rendered_png, cv2.IMREAD_GRAYSCALE)
        target = cv2.imread(reference_image, cv2.IMREAD_GRAYSCALE)
        if candidate is None:
            return _failure(f"无法读取渲染图片: {rendered_png}")
        if target is None:
            return _failure(f"无法读取参考图片: {reference_image}")

        # Bring the rendered image to the reference's size before comparing;
        # cv2.resize takes (width, height).
        if candidate.shape != target.shape:
            target_height, target_width = target.shape[0], target.shape[1]
            candidate = cv2.resize(candidate, (target_width, target_height))

        # Imported only once both images are known to be readable.
        from skimage.metrics import structural_similarity

        similarity = structural_similarity(candidate, target, data_range=255)

        # Share of pixels whose grayscale difference is above the threshold.
        delta = cv2.absdiff(candidate, target)
        changed_ratio = float(np.count_nonzero(delta > 30)) / delta.size

        return {
            "ssim": round(similarity, 4),
            "diff_pct": round(changed_ratio, 4),
            "error": None,
        }
    except ImportError as exc:
        return _failure(f"缺少依赖: {exc}")
    except Exception as exc:
        return _failure(str(exc))
def _check_ocr_fidelity(jrxml: str, state: dict) -> dict:
    """Score how faithfully a generated JRXML reflects the OCR extraction.

    Three checks are run against data read from ``state``:

    1. Element count: the number of ``<textField>`` plus ``<staticText>``
       elements in the JRXML versus the number of non-blank OCR text elements
       (``state["ocr_elements"]``, falling back to
       ``state["ocr_extraction_result"]["total_elements"]``).
    2. Field coverage: the share of OCR field names
       (``ocr_extraction_result["fields"]``) declared as ``<field name="...">``.
    3. Column structure: distinct ``x`` offsets inside the detail band versus
       the column count in ``state["layout_schema"]``. This check only adds an
       issue message; it does not influence the score.

    Args:
        jrxml: Generated JRXML source text.
        state: Agent state mapping; only the keys named above are read.

    Returns:
        ``{"score", "field_coverage", "element_coverage", "issues"}`` where
        ``score`` is the equally weighted mean of the two coverage values and
        ``issues`` is a list of human-readable findings. When no OCR data is
        present all three numbers are 1.0 and ``issues`` is empty.
    """
    ocr_elements = state.get("ocr_elements", [])
    ocr_result = state.get("ocr_extraction_result", {})
    layout_schema = state.get("layout_schema", {})

    # Nothing to compare against: treat as fully faithful.
    if not ocr_elements and not ocr_result:
        return {"score": 1.0, "field_coverage": 1.0, "element_coverage": 1.0, "issues": []}

    issues = []

    # 1. Element count.
    # The trailing \b keeps <textFieldExpression> from being counted as a
    # second <textField>.
    total_jrxml_elements = len(re.findall(r"<(?:textField|staticText)\b", jrxml))

    ocr_text_count = 0
    if isinstance(ocr_elements, list):
        # Entries whose "text" is missing, None or not a string are ignored
        # instead of raising.
        ocr_text_count = sum(
            1
            for element in ocr_elements
            if isinstance(element, dict)
            and isinstance(element.get("text"), str)
            and element["text"].strip()
        )
    if ocr_text_count == 0 and isinstance(ocr_result, dict):
        reported_total = ocr_result.get("total_elements", 0)
        if isinstance(reported_total, (int, float)):
            ocr_text_count = reported_total

    if ocr_text_count > 0:
        element_coverage = min(total_jrxml_elements / max(ocr_text_count, 1), 1.0)
        if element_coverage < 0.3:
            issues.append(
                f"元素覆盖不足:JRXML 仅有 {total_jrxml_elements} 个文本元素,"
                f"OCR 源有 {ocr_text_count} 个文本元素(覆盖率 {element_coverage:.0%})"
            )
    else:
        element_coverage = 1.0

    # 2. Field-name coverage.
    jrxml_fields = set(re.findall(r'<field name="(\w+)"', jrxml))
    ocr_fields = ocr_result.get("fields", []) if isinstance(ocr_result, dict) else []
    if not isinstance(ocr_fields, list):
        ocr_fields = []

    ocr_field_names = set()
    for field in ocr_fields:
        if not isinstance(field, dict):
            continue
        name = field.get("name", "") or field.get("field_name", "") or field.get("label", "")
        # Single-character names are skipped, as are non-string values.
        if isinstance(name, str) and len(name) > 1:
            ocr_field_names.add(name)

    if ocr_field_names and jrxml_fields:
        matched = jrxml_fields & ocr_field_names
        field_coverage = len(matched) / max(len(ocr_field_names), 1)
        unmatched = ocr_field_names - jrxml_fields
        if unmatched:
            # Sorted so the reported sample is stable between runs.
            sample = sorted(unmatched)[:8]
            issues.append(f"OCR 字段未在 JRXML 中声明: {', '.join(sample)}")
    elif ocr_field_names and not jrxml_fields:
        field_coverage = 0.0
        issues.append("JRXML 中未声明任何字段,但 OCR 提取了结构化字段数据")
    else:
        field_coverage = 1.0

    # 3. Column count.
    if isinstance(layout_schema, dict):
        ocr_columns = layout_schema.get("total_columns", 0) or layout_schema.get("columns", 0)
        if not isinstance(ocr_columns, (int, float)):
            ocr_columns = 0

        # Look inside <detail> first so a title/header band that appears
        # earlier in the document is not mistaken for the detail band; fall
        # back to the whole document when there is no <detail> section.
        detail_section = re.search(r"<detail\b[^>]*>([\s\S]*?)</detail>", jrxml)
        search_scope = detail_section.group(1) if detail_section else jrxml
        detail_match = re.search(
            r"<band[^>]*height=\"(\d+)\"[^>]*>([\s\S]*?)</band>", search_scope
        )
        if detail_match and ocr_columns > 0:
            detail_content = detail_match.group(2)
            # Each distinct x offset is taken as one column.
            x_positions = {
                int(m.group(1)) for m in re.finditer(r'\bx="(\d+)"', detail_content)
            }
            jrxml_columns = len(x_positions) if x_positions else 1
            if jrxml_columns < ocr_columns * 0.5:
                issues.append(
                    f"列数不足:JRXML detail band 检测到 {jrxml_columns} 列,"
                    f"OCR 布局分析有 {ocr_columns} 列"
                )

    # Overall score: equal weight for field coverage and element coverage.
    score = round(field_coverage * 0.5 + element_coverage * 0.5, 3)

    return {
        "score": score,
        "field_coverage": round(field_coverage, 3),
        "element_coverage": round(element_coverage, 3),
        "issues": issues,
    }
@log_node("validate")
def validate(state: AgentState) -> Dict:
"""根据 FastAPI 验证服务验证当前 JRXML。"""
@@ -866,6 +1048,57 @@ def validate(state: AgentState) -> Dict:
state["status"] = "pass" if result.get("valid") else "fail"
state["error_msg"] = result.get("error", "")
# OCR 保真度检查:比对生成结果与原始图片的 OCR 提取内容
fidelity = _check_ocr_fidelity(jrxml, state)
state["ocr_fidelity"] = fidelity
if fidelity["issues"]:
if state["status"] == "pass":
# XSD 通过但内容保真度不足 → 降级为 fail
if fidelity["score"] < 0.5:
state["status"] = "fail"
state["error_msg"] = (
f"[内容保真度不足] 得分 {fidelity['score']:.2f}/1.0。"
+ " ".join(fidelity["issues"][:3])
)
_node_log.warning(
f"OCR 保真度得分 {fidelity['score']:.2f}XSD 通过但内容差异过大: "
+ "; ".join(fidelity["issues"][:5])
)
else:
_node_log.info(
f"OCR 保真度得分 {fidelity['score']:.2f}XSD 通过,轻微差异: "
+ "; ".join(fidelity["issues"][:3])
)
else:
_node_log.info(
f"XSD 验证失败 + OCR 保真度得分 {fidelity['score']:.2f}: "
+ "; ".join(fidelity["issues"][:3])
)
# ── 像素级对比:将 JRXML 渲染为 PNG,与原始上传图片进行 SSIM 比较 ──
source_image = state.get("uploaded_file_path", "")
if source_image and os.path.isfile(source_image) and state["status"] == "pass":
tmpdir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "tmp")
rendered_png = os.path.join(tmpdir, "_pixel_test.png")
if _render_jrxml_to_png(jrxml, rendered_png):
pixel_result = _compute_pixel_similarity(rendered_png, source_image)
state["pixel_fidelity"] = pixel_result
if pixel_result["error"]:
_node_log.warning(f"像素对比失败: {pixel_result['error']}")
else:
_node_log.info(
f"像素对比: SSIM={pixel_result['ssim']:.4f}, "
f"Diff={pixel_result['diff_pct']:.2%}"
)
# SSIM < 0.4 或 diff > 60% → 质量不合格
if pixel_result["ssim"] < 0.4 and pixel_result["diff_pct"] > 0.6:
state["status"] = "fail"
state["error_msg"] = (
f"[像素保真度不足] SSIM={pixel_result['ssim']:.3f}, "
f"差异像素占比={pixel_result['diff_pct']:.2%}"
f"渲染结果与原始图片差异过大,需调整布局。"
)
# 修正成功后记录到错误知识库
if result.get("valid") and state.get("retry_count", 0) > 0:
case = state.get("last_error_case", {})
@@ -920,12 +1153,34 @@ def correct_jrxml(state: AgentState) -> Dict:
layout_text = ""
if isinstance(layout_schema, dict):
layout_text = layout_schema.get("schema_text", "")
# 构建保真度上下文(告诉 LLM 图片与模板的差异)
fidelity = state.get("ocr_fidelity", {})
fidelity_text = ""
if fidelity and fidelity.get("score", 1.0) < 0.9:
fidelity_text = (
f"[内容保真度警告] 得分 {fidelity.get('score', 0):.2f}/1.0\n"
+ "\n".join(f"- {issue}" for issue in fidelity.get("issues", []))
)
# 像素级对比上下文
pixel_fidelity = state.get("pixel_fidelity", {})
if pixel_fidelity and pixel_fidelity.get("ssim", 1.0) < 0.7:
fidelity_parts = [fidelity_text] if fidelity_text else []
fidelity_parts.append(
f"[像素保真度] SSIM={pixel_fidelity.get('ssim', 0):.4f}, "
f"像素差异={pixel_fidelity.get('diff_pct', 0):.2%}"
f"渲染结果与原图差异过大,请调整元素位置、尺寸和布局。"
)
fidelity_text = "\n".join(fidelity_parts)
prompt = load_prompt("correction").format(
current_jrxml=state.get("current_jrxml", ""),
error_msg=state.get("error_msg", ""),
explanation=state.get("natural_explanation", ""),
ocr_context=ocr_context,
layout_schema_text=layout_text,
fidelity_context=fidelity_text,
)
# 保存修正前状态(供 validate 判断是否写入错误知识库)
state["last_error_case"] = {
@@ -944,8 +1199,17 @@ def correct_jrxml(state: AgentState) -> Dict:
if len(jrxml.strip()) < 200:
_node_log.warning(f"correct_jrxml 输出过短({len(jrxml)} 字符),回退到前一版本")
jrxml = prev_jrxml
state["current_jrxml"] = jrxml
state["retry_count"] = state.get("retry_count", 0) + 1
# 去重检测:如果输出与输入完全相同(忽略空白差异),说明修正无效
_prev_norm = re.sub(r"\s+", "", prev_jrxml) if prev_jrxml else ""
_new_norm = re.sub(r"\s+", "", jrxml) if jrxml else ""
if _prev_norm and _new_norm and _prev_norm == _new_norm:
_node_log.warning(
f"correct_jrxml 输出与输入完全相同({len(jrxml)} 字符),修正无效,加速消耗 retry"
)
state["retry_count"] = state.get("retry_count", 0) + 2
else:
state["retry_count"] = state.get("retry_count", 0) + 1
state["conversation_history"].append(
{"role": "assistant", "content": f"[自动修正,第 {state['retry_count']} 次尝试]\n{jrxml}"}
)