Files
agent_jrxml/agent/nodes.py
T
panda 963c5e41c8 fix: nodes.py 调用 detect_annotations 前将 bbox 从 [x_min,y_min,x_max,y_max] 转为 {x,y,w,h}
annotation_detector._correlate_with_ocr 期望 bbox 格式为 {x,y,w,h},
但 OcrTextElement.to_dict() 返回 [x_min,y_min,x_max,y_max]。
Bug3 的根因在 nodes.py 而非 layout_analyzer。
2026-05-25 22:24:29 +08:00

1724 lines
69 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""LangGraph JRXML 生成工作流的节点函数。"""
import copy
import functools
import json
import os
import re
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict
from dotenv import load_dotenv
from agent.state import AgentState
from backend.llm import get_llm
from backend.logger import get_logger, set_trace_id
from backend.validation import validate_jrxml
from prompts.loader import load_prompt
# Load variables from a .env file at import time; override=True is passed so
# values from the file take precedence over ones already in the environment.
load_dotenv(override=True)
# Logger shared by every node function and helper in this module.
_node_log = get_logger("agent")
# Tunables read from the environment, each falling back to the default shown.
# MAX_RETRY is not referenced in the part of the file visible here.
MAX_RETRY = int(os.getenv("MAX_RETRY", "5"))
# Token count above which manage_context compresses older conversation turns.
CONTEXT_MAX_TOKENS = int(os.getenv("CONTEXT_MAX_TOKENS", "6000"))
# Number of most recent history entries kept verbatim when counting tokens
# and when compressing.
CONTEXT_KEEP_RECENT = int(os.getenv("CONTEXT_KEEP_RECENT", "4"))
# Maximum number of undo snapshots retained by save_state_snapshot.
HISTORY_MAX_SNAPSHOTS = int(os.getenv("HISTORY_MAX_SNAPSHOTS", "10"))
def _state_summary(state: AgentState) -> dict:
    """Build a compact summary of ``state`` for structured log records.

    Only identifiers, flags, sizes and a 100-character preview of the user
    input are included, so full JRXML documents and histories never end up
    in the log payload.
    """
    raw_input = state.get("user_input", "")
    jrxml = state.get("current_jrxml", "")
    preview = raw_input[:100] if raw_input else ""
    return dict(
        session_id=state.get("session_id", ""),
        intent=state.get("intent", ""),
        status=state.get("status", ""),
        has_jrxml=bool(jrxml.strip()),
        jrxml_length=len(jrxml),
        retry_count=state.get("retry_count", 0),
        user_input_preview=preview,
        conversation_turns=len(state.get("conversation_history", [])),
        history_snapshots=len(state.get("history_states", [])),
        versions=len(state.get("jrxml_versions", [])),
    )
def log_node(node_name: str):
    """Decorator factory that logs entry, exit and duration of a node function.

    Args:
        node_name: Label written to every log record emitted for the node.

    Returns:
        A decorator. The wrapped function is called with the same arguments
        and its return value is passed through unchanged; any exception it
        raises is logged with the elapsed time and then re-raised.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(state: AgentState, *args, **kwargs):
            # perf_counter is monotonic, so the measured duration cannot be
            # distorted by wall-clock adjustments while the node is running.
            started = time.perf_counter()
            _node_log.info(
                f"[节点入口] {node_name}",
                extra={"node": node_name, "phase": "entry", "state": _state_summary(state)},
            )
            try:
                result = func(state, *args, **kwargs)
            except Exception as e:
                elapsed_ms = round((time.perf_counter() - started) * 1000)
                _node_log.error(
                    f"[节点异常] {node_name}: {e}",
                    extra={
                        "node": node_name,
                        "phase": "error",
                        "duration_ms": elapsed_ms,
                        "error": str(e),
                        "state": _state_summary(state),
                    },
                )
                raise
            # Logged outside the try block so that a failure while writing the
            # exit record is not misreported as a failure of the node itself.
            elapsed_ms = round((time.perf_counter() - started) * 1000)
            _node_log.info(
                f"[节点出口] {node_name}",
                extra={
                    "node": node_name,
                    "phase": "exit",
                    "duration_ms": elapsed_ms,
                    "state": _state_summary(state),
                },
            )
            return result
        return wrapper
    return decorator
# ============================================================
# 核心工作流节点
# ============================================================
@log_node("process_input")
def process_input(state: AgentState) -> Dict:
"""Record the user's input in both conversation histories and reset per-request state.

If a failure context from the previous run is pending, it is prepended to the
input and the round is flagged as failure recovery. If an uploaded image file
is present, OCR field extraction and annotation detection are attempted and
their results are stored on ``state``.

Returns the same ``state`` object after mutating it.
"""
user_input = state.get("user_input", "")
# Maintain the full conversation history (always records the original user message).
full_history = state.get("full_conversation_history", [])
full_history.append({"role": "user", "content": user_input, "ts": _now_iso()})
state["full_conversation_history"] = full_history
# Automatically inject the context of the previous failure.
pending = state.get("pending_failure_context", {})
if pending and pending.get("error_msg"):
failure_note = (
f"[系统提示] 上次生成失败,以下是失败详情,请基于此修正:\n"
f"失败原因: {pending['error_msg']}\n"
f"上次失败的输出:\n{pending.get('bad_jrxml', '(无输出)')}"
)
user_input = f"{failure_note}\n\n---\n用户新输入:\n{user_input}"
# The pending context is consumed once, then cleared.
state["pending_failure_context"] = {}
state["_failure_recovery"] = True # Marks this round as failure recovery; classify_intent then forces modify_report
# Maintain the working conversation history.
conv_history = state.get("conversation_history", [])
conv_history.append({"role": "user", "content": user_input})
state["conversation_history"] = conv_history
# Precise OCR extraction of document fields (handles an uploaded image file).
uploaded_path = state.get("uploaded_file_path", "")
if uploaded_path and Path(uploaded_path).is_file():
suffix = Path(uploaded_path).suffix.lower()
if suffix in (".png", ".jpg", ".jpeg", ".bmp", ".webp"):
try:
from backend.ocr_extractor import OcrExtractor
extractor = OcrExtractor()
# No preset field names are passed, so that OCR discovers all key-value pairs in the document on its own.
ocr_result = extractor.extract(uploaded_path)
if ocr_result.get("ocr_available"):
state["ocr_extraction_result"] = ocr_result
_node_log.info(
"OCR 字段提取完成",
extra={
"file": uploaded_path,
"elements": ocr_result.get("total_elements", 0),
"fields": len(ocr_result.get("fields", [])),
},
)
# Inject the extracted fields into the conversation context for the LLM to use.
extracted_fields = ocr_result.get("fields", [])
non_empty = [f for f in extracted_fields if f.get("field_value")]
if non_empty:
lines = ["[OCR 单据字段提取结果]"]
for f in non_empty:
# NOTE(review): the emitted line opens "(" but never closes it, and the
# direct f['...'] lookups raise KeyError on a missing key — confirm both are intended.
lines.append(
f"- {f['field_name']}: {f['field_value']}"
f"(置信度: {f['confidence']:.0%}, 方法: {f['extraction_method']}"
)
ocr_context = "\n".join(lines)
user_input = f"{ocr_context}\n\n{user_input}"
# Also update the last entry of the working conversation history.
conv_history[-1]["content"] = user_input
# Annotation detection (circled regions / arrow marks).
elements = ocr_result.get("all_elements", [])
if elements:
try:
from backend.annotation_detector import detect_annotations
elem_dicts = []
for e in elements:
# Normalise each element to a dict: use to_dict() when available, keep dicts as-is,
# otherwise wrap the string form with an empty bbox and zero confidence.
d = e.to_dict() if hasattr(e, "to_dict") else (e if isinstance(e, dict) else {"text": str(e), "bbox": [], "confidence": 0})
# annotation_detector expects bbox as {x,y,w,h}, but OcrTextElement.to_dict() returns [x_min,y_min,x_max,y_max].
b = d.get("bbox", [])
if isinstance(b, (list, tuple)) and len(b) == 4:
d["bbox"] = {"x": b[0], "y": b[1], "w": b[2] - b[0], "h": b[3] - b[1]}
elif isinstance(b, dict) and "x" not in b:
# Original note: a bbox that is already an [x,y,w,h]-style list but was treated as a dict.
# NOTE(review): this branch reads integer keys 0..3 from the dict and still subtracts
# key 0 / key 1 as if the values were min/max corners — confirm which input it targets.
d["bbox"] = {"x": b.get(0, 0), "y": b.get(1, 0), "w": b.get(2, 0) - b.get(0, 0), "h": b.get(3, 0) - b.get(1, 0)}
elem_dicts.append(d)
ann_result = detect_annotations(uploaded_path, elem_dicts)
# The result is stored only when at least one annotation is reported.
if ann_result.get("total", 0) > 0:
state["annotation_result"] = ann_result
_node_log.info(
"批注检测完成",
extra={
"circles": len(ann_result.get("circles", [])),
"arrows": len(ann_result.get("arrows", [])),
},
)
except Exception as e:
# Annotation detection is best-effort: a failure is logged and otherwise ignored.
_node_log.warning(f"批注检测失败: {e}")
except Exception as e:
# OCR failure is logged and recorded on the state as an error marker.
_node_log.warning(f"OCR 字段提取失败: {e}")
state["ocr_extraction_result"] = {"error": str(e)}
state["uploaded_file_path"] = ""
# ── OCR two-layer logging: content layer + position layer ──
_log_ocr_layers(state)
# Reset the per-request fields.
state["retry_count"] = 0
state["user_modification_request"] = user_input
return state
@log_node("save_state_snapshot")
def save_state_snapshot(state: AgentState) -> Dict:
    """Append an undo snapshot of the current state to ``history_states``.

    The conversation history is deep-copied so later turns cannot alter the
    snapshot. At most ``HISTORY_MAX_SNAPSHOTS`` of the newest snapshots are
    kept. Returns the mutated ``state``.
    """
    history = state.get("history_states", [])
    if not isinstance(history, list):
        history = []

    history.append({
        "current_jrxml": state.get("current_jrxml", ""),
        "final_jrxml": state.get("final_jrxml", ""),
        "status": state.get("status", ""),
        "conversation_history": copy.deepcopy(state.get("conversation_history", [])),
        "user_modification_request": state.get("user_modification_request", ""),
        "intent": state.get("intent", ""),
    })

    # Drop the oldest entries once the configured limit is exceeded.
    limit = HISTORY_MAX_SNAPSHOTS
    state["history_states"] = history[-limit:] if len(history) > limit else history
    return state
@log_node("classify_intent")
def classify_intent(state: AgentState) -> Dict:
    """Classify the user's input into one of eight intents using the LLM.

    The result is written to ``state["intent"]``. When the round is flagged as
    failure recovery the LLM is skipped and ``modify_report`` is used. If the
    LLM answer contains none of the known intents, or the call fails, the
    fallback is ``modify_report`` when a report already exists and
    ``initial_generation`` otherwise.

    Returns the mutated ``state``.
    """
    user_input = state.get("user_input", "")
    # The fallback is driven by a boolean rather than by comparing the prompt
    # string, so it stays correct whatever text is shown to the model.
    report_exists = bool(state.get("current_jrxml", "").strip())
    # NOTE(review): both branches were the same empty string before, which made
    # the flag meaningless; "是"/"否" is an assumption — confirm the wording the
    # intent_classify prompt expects.
    has_report = "是" if report_exists else "否"
    fallback_intent = "modify_report" if report_exists else "initial_generation"

    # Failure-recovery mode: skip LLM classification and go straight to the fix-up flow.
    if state.pop("_failure_recovery", False):
        state["intent"] = "modify_report"
        return state

    valid_intents = (
        "initial_generation", "modify_report", "preview_report",
        "export_pdf", "export_jrxml", "undo_modification",
        "consult_question", "reset_session",
    )
    try:
        llm = get_llm(caller="classify_intent")
        # Smart truncation: keep the first 200 and last 300 characters so the
        # user's actual request is not crowded out by a long JRXML in the middle.
        if len(user_input) > 500:
            ui_snippet = user_input[:200] + "\n...[已截断]...\n" + user_input[-300:]
        else:
            ui_snippet = user_input
        prompt = load_prompt("intent_classify").format(
            has_report=has_report,
            user_input=ui_snippet,
        )
        raw = llm.invoke(prompt).content.strip().lower()
        # First known intent (in list order) that appears in the answer wins.
        intent = next((vi for vi in valid_intents if vi in raw), fallback_intent)
    except Exception as e:
        # Classification is best-effort; record why the fallback was taken.
        _node_log.warning("意图分类失败,使用兜底意图 %s: %s", fallback_intent, e)
        intent = fallback_intent

    state["intent"] = intent
    return state
@log_node("handle_consult")
def handle_consult(state: AgentState) -> Dict:
    """Answer a consultation question directly with the LLM.

    No report is generated. The answer (or a fixed apology when the LLM call
    fails) is stored in ``consult_answer`` and appended to both conversation
    histories. Returns the mutated ``state``.
    """
    question = state.get("user_input", "")
    try:
        model = get_llm(caller="handle_consult")
        reply = model.invoke(load_prompt("consult").format(question=question))
        answer = reply.content.strip()
    except Exception:
        answer = "抱歉,暂时无法处理您的问题,请稍后再试。"

    state["consult_answer"] = answer
    working_entry = {"role": "assistant", "content": answer}
    state["conversation_history"].append(working_entry)
    full_entry = {"role": "assistant", "content": answer, "ts": _now_iso()}
    state["full_conversation_history"].append(full_entry)
    return state
@log_node("handle_undo")
def handle_undo(state: AgentState) -> Dict:
    """Undo the last modification by restoring the newest snapshot.

    The newest entry of ``history_states`` is removed and its fields are
    copied back onto ``state``. When there is nothing to restore, only a
    notice is appended to the working conversation history.
    Returns the mutated ``state``.
    """
    snapshots = state.get("history_states", [])
    has_snapshot = isinstance(snapshots, list) and bool(snapshots)
    if not has_snapshot:
        notice = {"role": "assistant", "content": "没有可撤销的操作。"}
        state["conversation_history"].append(notice)
        return state

    restored = snapshots.pop()
    state["history_states"] = snapshots
    for text_key in ("current_jrxml", "final_jrxml", "status"):
        state[text_key] = restored.get(text_key, "")
    state["conversation_history"] = restored.get("conversation_history", [])
    state["user_modification_request"] = restored.get("user_modification_request", "")

    # The confirmation is appended after the restore so it survives it.
    state["conversation_history"].append(
        {"role": "assistant", "content": "已撤销上一步修改,恢复到之前的状态。"}
    )
    state["full_conversation_history"].append(
        {"role": "assistant", "content": "已撤销上一步修改。", "ts": _now_iso()}
    )
    return state
@log_node("handle_reset")
def handle_reset(state: AgentState) -> Dict:
    """Reset the current session's report-related state.

    Report content, status, retrieval context, compressed history and undo
    snapshots are cleared; the working conversation history is replaced by a
    single reset notice, and a notice is appended to the full history.
    Returns the mutated ``state``.
    """
    cleared_values = (
        ("current_jrxml", ""),
        ("final_jrxml", ""),
        ("status", ""),
        ("error_msg", ""),
        ("natural_explanation", ""),
        ("user_modification_request", ""),
        ("retrieved_context", ""),
        ("retry_count", 0),
        ("compressed_history", ""),
        ("history_states", []),
        ("intent", "initial_generation"),
    )
    for key, value in cleared_values:
        state[key] = value

    state["conversation_history"] = [
        {"role": "assistant", "content": "会话已重置,请描述您要创建的新报表。"}
    ]
    state["full_conversation_history"].append(
        {"role": "assistant", "content": "会话已重置。", "ts": _now_iso()}
    )
    return state
@log_node("count_tokens")
def count_tokens(state: AgentState) -> int:
    """Count the tokens of the current context with tiktoken's gpt-4o encoder.

    The measured text is the JSON serialisation of the last
    ``CONTEXT_KEEP_RECENT`` working-history entries, the current JRXML and the
    compressed history.

    Returns:
        The token count as an ``int``. If tiktoken cannot be imported or the
        tokenizer fails, an approximation of one token per 2.5 characters is
        returned instead.
    """
    text = json.dumps({
        "history": state.get("conversation_history", [])[-CONTEXT_KEEP_RECENT:],
        "jrxml": state.get("current_jrxml", ""),
        "compressed": state.get("compressed_history", ""),
    }, ensure_ascii=False)
    try:
        import tiktoken
        return len(tiktoken.encoding_for_model("gpt-4o").encode(text))
    except Exception:
        # Fallback for mixed Chinese/English text: roughly 1 token per 2.5
        # characters. int() keeps the declared return type; floor division by
        # a float would hand callers a float.
        return int(len(text) / 2.5)
@log_node("manage_context")
def manage_context(state: AgentState) -> Dict:
    """Compress older conversation turns once the token count exceeds the limit.

    The token count is always stored in ``current_token_count``. When it is
    above ``CONTEXT_MAX_TOKENS`` and the full history holds more than
    ``CONTEXT_KEEP_RECENT`` entries, everything but the most recent entries is
    summarised (by the LLM, or by ``_simple_compress`` if the LLM call fails),
    the summary is appended to ``compressed_history`` and the working history
    is replaced by the recent entries. Returns the mutated ``state``.
    """
    tokens = count_tokens(state)
    state["current_token_count"] = tokens
    if tokens <= CONTEXT_MAX_TOKENS:
        return state

    everything = state.get("full_conversation_history", [])
    if len(everything) <= CONTEXT_KEEP_RECENT:
        return state

    # The most recent N entries stay verbatim; earlier ones are compressed.
    kept = everything[-CONTEXT_KEEP_RECENT:]
    to_compress = everything[:-CONTEXT_KEEP_RECENT]
    if not to_compress:
        return state

    serialized = json.dumps(to_compress, ensure_ascii=False, indent=2)
    try:
        model = get_llm(caller="manage_context")
        reply = model.invoke(
            load_prompt("compression").format(conversation_text=serialized)
        )
        summary = reply.content.strip()[:300]
    except Exception:
        summary = _simple_compress(to_compress)

    # Merge the new summary with any summary produced earlier.
    previous = state.get("compressed_history", "")
    state["compressed_history"] = f"{previous}\n---\n{summary}" if previous else summary
    state["conversation_history"] = list(kept)
    state["current_token_count"] = count_tokens(state)
    return state
@log_node("load_session_node")
def load_session_node(state: AgentState) -> Dict:
    """Load the persisted session state from disk at the start of a request.

    Nothing happens without a ``session_id``. The restored keys never include
    the current request's ``user_input``, ``stage`` or ``session_id``. A load
    failure is logged as a warning and the state is returned as it was.
    """
    session_id = state.get("session_id", "")
    if not session_id:
        return state

    restorable_keys = (
        "conversation_history", "full_conversation_history",
        "current_jrxml", "final_jrxml", "compressed_history",
        "session_name", "created_at", "history_states",
        "ocr_extraction_result", "uploaded_file_path",
        "annotation_result", "layout_schema", "ocr_elements",
    )
    try:
        from backend.session import load_session
        data = load_session(session_id)
        if data and data.get("agent_state"):
            saved = data["agent_state"]
            for key in restorable_keys:
                if key in saved:
                    state[key] = saved[key]
            # Name and creation time come from the session record itself and
            # take precedence over the copies inside the saved agent state.
            state["session_name"] = data.get("session_name", "")
            state["created_at"] = data.get("created_at", "")
    except Exception:
        _node_log.warning("会话加载失败,使用空状态",
                          extra={"session_id": state.get("session_id", "")})
    return state
@log_node("save_session_node")
def save_session_node(state: AgentState) -> Dict:
    """Persist the current agent state to disk.

    Nothing happens without a ``session_id``. If the session has no name yet,
    the first 50 characters of the first user message are used as its name.
    A save failure is logged with its traceback and otherwise ignored.
    Returns the mutated ``state``.
    """
    session_id = state.get("session_id", "")
    if not session_id:
        return state

    persisted_keys = (
        "session_id", "conversation_history", "full_conversation_history",
        "current_jrxml", "final_jrxml", "compressed_history",
        "status", "error_msg", "history_states",
        "ocr_extraction_result", "uploaded_file_path",
        "annotation_result", "layout_schema", "ocr_elements",
    )
    try:
        from backend.session import save_session
        persistable = {key: state[key] for key in persisted_keys if key in state}
        persistable["updated_at"] = _now_iso()

        session_name = state.get("session_name", "")
        if not session_name and state.get("conversation_history"):
            user_previews = (
                message["content"][:50]
                for message in state["conversation_history"]
                if message.get("role") == "user"
            )
            first_user = next(user_previews, "")
            if first_user:
                session_name = first_user

        save_session(session_id, persistable, session_name)
        if not state.get("session_name"):
            state["session_name"] = session_name
        state["updated_at"] = persistable["updated_at"]
    except Exception:
        _node_log.exception("会话保存失败",
                            extra={"session_id": state.get("session_id", "")})
    return state
def _simple_compress(messages: list[dict]) -> str:
"""当 LLM 不可用时,基于简单规则的压缩回退方案。"""
points = []
for m in messages:
if m.get("role") == "user":
points.append(f"用户提问:{m['content'][:100]}")
return "; ".join(points[-10:])
def _now_iso() -> str:
return datetime.now(timezone.utc).isoformat()
def _format_row_coordinates(row: dict) -> dict:
"""将单行 OCR 元素格式化为紧凑的坐标描述,供阶段二 refine_layout 使用。"""
if not isinstance(row, dict):
return {}
elements = row.get("elements", [])
if not elements:
return {"y_center": row.get("y_center", 0), "columns": []}
sorted_elems = sorted(elements, key=lambda e: e.get("x", 0))
cols = []
for ci, e in enumerate(sorted_elems):
x = e.get("x", 0)
y = e.get("y", 0)
w = e.get("w", 0)
h = e.get("h", 0)
if not (x > 0 and y > 0 and w > 0 and h > 0):
continue
cols.append({
"col": ci,
"x": x,
"y": y,
"w": w,
"h": h,
"font_size": e.get("font_size", 12),
"text": e.get("text", ""),
})
return {"y_center": row.get("y_center", 0), "columns": cols}
def _build_sampled_text(ocr_rows: list) -> str:
    """Build the sampled-coordinates JSON string from OCR row data.

    Samples the first row as ``header_row``, the second as ``first_data_row``
    and, when there are more than two rows, the final one as ``last_row``.
    Anything that is not a non-empty list yields ``"{}"``.
    """
    sampled = {}
    if isinstance(ocr_rows, list) and ocr_rows:
        row_count = len(ocr_rows)
        sampled["header_row"] = _format_row_coordinates(ocr_rows[0])
        if row_count > 1:
            sampled["first_data_row"] = _format_row_coordinates(ocr_rows[1])
        if row_count > 2:
            sampled["last_row"] = _format_row_coordinates(ocr_rows[-1])
    return json.dumps(sampled, ensure_ascii=False, indent=2)
def _extract_band_height(band_xml: str) -> int:
"""从 <band height="N"> 中提取高度值。"""
m = re.search(r'<band\b[^>]*\sheight\s*=\s*"(\d+)"', band_xml)
return int(m.group(1)) if m else 0
def _extract_xml_fragment(text: str) -> str:
"""从 LLM 响应中提取 XML 片段(去除 markdown 代码块和解释文本)。"""
text = text.strip()
# 尝试提取 markdown 代码块内的内容
m = re.search(r"```(?:xml)?\s*([\s\S]*?)```", text, re.IGNORECASE)
if m:
content = m.group(1).strip()
if content:
return content
# 尝试找到 <band ...>...</band> 片段
m = re.search(r"(<band\b[\s\S]*?</band>)", text, re.IGNORECASE)
if m:
return m.group(1).strip()
return text
def _count_zero_coordinate_elements(xml: str) -> tuple[int, int]:
"""统计坐标无效(x=0 或 y=0 或 width=0 或 height=0)的 reportElement 数量。
返回 (zero_count, total_count)。
"""
total = 0
zero = 0
for m in re.finditer(r'<reportElement\b([^>]*)/>', xml):
total += 1
attrs = m.group(1)
xm = re.search(r'\sx\s*=\s*"(\d+)"', attrs)
ym = re.search(r'\sy\s*=\s*"(\d+)"', attrs)
wm = re.search(r'\swidth\s*=\s*"(\d+)"', attrs)
hm = re.search(r'\sheight\s*=\s*"(\d+)"', attrs)
x = int(xm.group(1)) if xm else 0
y = int(ym.group(1)) if ym else 0
w = int(wm.group(1)) if wm else 0
h = int(hm.group(1)) if hm else 0
if x == 0 or y == 0 or w == 0 or h == 0:
zero += 1
return zero, total
def _programmatic_map_fields(jrxml: str, ocr_fields: list[dict]) -> str:
    """Map ``field_N`` placeholders to the field names extracted by OCR.

    The N-th OCR field (1-based) replaces ``field_N`` both in field
    declarations (``name="field_N"``) and in ``$F{field_N}`` references.
    Pure regex/string replacement, no LLM call.

    Placeholders still referenced afterwards are renamed to
    ``<section>_f<N>``, where the section is the report section that is open
    at the placeholder's first reference (``data`` if none is found).
    """
    def declaration_pattern(placeholder: str) -> str:
        return rf'(<[\w:]*field\b[^>]*\bname\s*=\s*"){re.escape(placeholder)}(")'

    def placeholder_number(placeholder: str) -> int:
        return int(re.search(r'\d+', placeholder).group())

    mapped = jrxml
    for position, field in enumerate(ocr_fields, start=1):
        placeholder = f"field_{position}"
        raw_name = field.get("field_name", "")
        if not raw_name:
            continue
        real_name = _sanitize_field_name(raw_name)
        if real_name == placeholder:
            continue
        mapped = re.sub(
            declaration_pattern(placeholder), rf'\g<1>{real_name}\g<2>', mapped,
        )
        mapped = mapped.replace(f'$F{{{placeholder}}}', f'$F{{{real_name}}}')

    # Second pass: give the placeholders that are still referenced a
    # descriptive name based on the section they appear in.
    leftovers = {hit.group(1) for hit in re.finditer(r'\$F\{(field_\d+)\}', mapped)}
    if not leftovers:
        return mapped

    section_tags = (
        "title", "pageHeader", "columnHeader", "detail", "columnFooter",
        "pageFooter", "summary", "background", "noData",
    )
    for placeholder in sorted(leftovers, key=placeholder_number):
        number = placeholder_number(placeholder)
        # Locate the first reference to decide which section it belongs to.
        first_use = re.search(rf'\$F\{{{re.escape(placeholder)}\}}', mapped)
        section = "data"
        if first_use:
            preceding = mapped[:first_use.start()]
            for tag in section_tags:
                # A section is open when its last opening tag comes after its
                # last closing tag (rfind yields -1 when the tag never occurs).
                if preceding.rfind(f'<{tag}>') > preceding.rfind(f'</{tag}>'):
                    section = tag
                    break
        new_name = f"{section}_f{number}"
        mapped = mapped.replace(f'$F{{{placeholder}}}', f'$F{{{new_name}}}')
        mapped = re.sub(
            declaration_pattern(placeholder), rf'\g<1>{new_name}\g<2>', mapped,
        )
    return mapped
def _sanitize_field_name(name: str) -> str:
"""将 OCR 字段名净化为合法的 JRXML field name(仅 ASCII 字母/数字/下划线)。
非 ASCII 字符会被替换为其 Unicode 码点,确保唯一且合法。
"""
result = []
for ch in name:
if ch.isascii() and (ch.isalnum() or ch == '_'):
result.append(ch)
elif ch.isascii():
result.append('_')
else:
# 非 ASCII 转 _uXXXX_ 格式,保留可追溯性
cp = ord(ch)
result.append(f'_u{cp:04X}_')
cleaned = ''.join(result)
cleaned = cleaned.strip('_')
if not cleaned:
return "unnamed_field"
if cleaned[0].isdigit():
cleaned = 'f_' + cleaned
# 压缩连续下划线
cleaned = re.sub(r'_{2,}', '_', cleaned)
return cleaned.lower()
def _format_ocr_context(state: AgentState) -> str:
    """Format the OCR extraction result as context text for the LLM.

    Returns an empty string when there is no OCR result, when it is not a
    dict, or when it carries an ``error``. Otherwise the text lists the
    element count, the non-empty structured fields, every text element with
    its bounding box, and the formatted annotation result if one is stored.
    """
    ocr_result = state.get("ocr_extraction_result")
    if not (ocr_result and isinstance(ocr_result, dict)):
        return ""
    if ocr_result.get("error"):
        return ""

    parts = ["[图片OCR识别结果]"]
    total = ocr_result.get("total_elements", 0)
    if total:
        parts.append(f"检测到 {total} 个文字元素")

    # Structured fields that carry a value.
    fields = ocr_result.get("fields", [])
    if fields:
        parts.append("\n提取的结构化字段:")
        parts.extend(
            f" - {field['field_name']}: {field['field_value']} "
            f"(方法={field.get('extraction_method','?')}, "
            f"置信度={field.get('confidence',0):.2f})"
            for field in fields
            if field.get("field_value")
        )

    # All raw text elements, for uses that need the full text (e.g. tables).
    elements = ocr_result.get("all_elements", [])
    if elements:
        parts.append("\n全部文本元素(含坐标):")
        for elem in elements:
            bbox = elem.get("bbox", [])
            x, y, w, h = 0, 0, 0, 0
            if isinstance(bbox, list) and len(bbox) >= 4:
                # bbox is [x_min, y_min, x_max, y_max]; report it as x, y, w, h.
                x_min, y_min, x_max, y_max = bbox[0], bbox[1], bbox[2], bbox[3]
                x, y, w, h = x_min, y_min, x_max - x_min, y_max - y_min
            parts.append(
                f" [{x},{y} {w}×{h}] {elem.get('text','')} "
                f"(置信度={elem.get('confidence',0):.2f})"
            )

    # Annotation detection result; formatting it is best-effort.
    ann_result = state.get("annotation_result")
    if ann_result and isinstance(ann_result, dict):
        try:
            from backend.annotation_detector import format_annotation_context
            ann_text = format_annotation_context(ann_result)
            if ann_text:
                parts.append("\n" + ann_text)
        except Exception:
            pass
    return "\n".join(parts)
def _log_ocr_layers(state: AgentState) -> None:
    """Log the two OCR layers separately, then which combination is available.

    The content layer covers OCR text elements and extracted fields; the
    position layer covers the layout schema (rows / columns / regions). A
    final record states whether both, only content, or only layout exist.
    """
    ocr_result = state.get("ocr_extraction_result")
    ocr_elements = state.get("ocr_elements", [])
    layout = state.get("layout_schema")

    ocr_ok = isinstance(ocr_result, dict) and not ocr_result.get("error")
    injected_ok = isinstance(ocr_elements, list) and ocr_elements
    has_layout = isinstance(layout, dict) and layout.get("total_rows", 0) > 0

    # ── Content layer: OCR text elements + extracted fields ──
    content_parts = []
    if ocr_ok:
        total = ocr_result.get("total_elements", 0)
        non_empty = [f for f in ocr_result.get("fields", []) if f.get("field_value")]
        if total or non_empty:
            content_parts.append(
                f"OCR 提取: {total} 个文本元素, {len(non_empty)} 个有效字段"
            )
    if injected_ok:
        elem_count = sum(len(row.get("elements", [])) for row in ocr_elements)
        content_parts.append(
            f"API 注入 OCR 元素: {len(ocr_elements)} 行, {elem_count} 个文本"
        )
    if content_parts:
        _node_log.info(
            "[内容层] " + " | ".join(content_parts),
            extra={"layer": "content", "phase": "ocr_extraction"},
        )

    # ── Position layer: layout schema (rows / columns / regions) ──
    if has_layout:
        region_list = layout.get("regions", [])
        _rn = {"title": "标题", "header": "表头", "data": "数据", "footer": "表尾"}
        if isinstance(region_list, list):
            region_names = [_rn.get(r["type"], r["type"]) for r in region_list]
        else:
            region_names = []
        cols = layout.get("total_columns", 0)
        rows = layout.get("total_rows", 0)
        regions_label = ", ".join(region_names) if region_names else "标题/表头/数据/表尾"
        _node_log.info(
            f"[位置层] 布局 schema: {cols}× {rows} 行, 区域: {regions_label}",
            extra={
                "layer": "position",
                "phase": "layout_analysis",
                "columns": cols,
                "rows": rows,
                "regions": region_names,
                "a4_confidence": layout.get("a4_confidence", ""),
            },
        )

    # ── Merge: summary of which layers feed the pipeline ──
    has_content = ocr_ok or injected_ok
    if has_content and has_layout:
        _node_log.info(
            "[合并] 内容层 + 位置层均已就绪 — "
            "注入 prompt: 骨架生成 → 精调布局 → 字段映射",
            extra={"layer": "merge", "pipeline": "skeleton→refine→map_fields"},
        )
    elif has_content:
        _node_log.info(
            "[合并] 仅有内容层 — 使用单阶段 generate(无布局 schema",
            extra={"layer": "merge", "pipeline": "generate_only"},
        )
    elif has_layout:
        _node_log.info(
            "[合并] 仅有位置层 — 使用布局 schema 指导生成",
            extra={"layer": "merge", "pipeline": "layout_only"},
        )
@log_node("retrieve")
def retrieve(state: AgentState) -> Dict:
    """Search the vector store and the error knowledge base for relevant context.

    Results are written to ``state["retrieved_context"]``. When ``error_msg``
    is set, matching error-fix cases are appended. When the input names a
    template and a ``kb_id`` is present, the best matching KB template is
    stored in ``kb_template_jrxml`` / ``kb_template_name`` and appended too.

    Any failure of the retrieval as a whole is logged and leaves
    ``retrieved_context`` empty; a failure of the template lookup alone is
    logged and the rest of the context is kept. Returns the mutated ``state``.
    """
    # Read before the try block: the error handler below references
    # user_input, and it must exist even if the imports themselves fail.
    user_input = state.get("user_input", "")
    kb_id = state.get("kb_id", "")
    try:
        from backend.rag_adapter import search_chunks
        from backend.error_kb import search_error_cases
        context = search_chunks(user_input, k=5, kb_id=kb_id)

        # Error knowledge base.
        error_msg = state.get("error_msg", "")
        if error_msg:
            error_context = search_error_cases(error_msg, k=2)
            if error_context:
                context = f"{context}\n\n[历史错误修正案例]\n{error_context}"

        # Template intent detection: did the user mention a template name?
        template_keywords = _detect_template_intent(user_input)
        if template_keywords and kb_id:
            try:
                from backend.kb_searcher import search_templates_in_kb
                templates = search_templates_in_kb(kb_id, template_keywords, k=1)
                if templates:
                    tmpl = templates[0]
                    state["kb_template_jrxml"] = tmpl.get("content", "")
                    state["kb_template_name"] = (
                        tmpl.get("metadata", {}).get("report_name", "")
                    )
                    context += (
                        f"\n\n[匹配到模板: {state['kb_template_name']}]\n"
                        f"{state['kb_template_jrxml']}"
                    )
            except Exception:
                _node_log.warning("模板检索失败", extra={"kb_id": kb_id})

        state["retrieved_context"] = context
    except Exception:
        _node_log.exception("RAG 检索失败", extra={"user_input": user_input[:80]})
        state["retrieved_context"] = ""
    return state
def _detect_template_intent(user_input: str) -> str:
"""检测用户输入中是否包含模板引用意图,返回提取的搜索关键词。"""
import re
patterns = [
r"根据(.+?)模板",
r"基于(.+?)模板",
r"参照(.+?)模板",
r"用(.+?)模板",
r"使用(.+?)模板",
r"(.+单)模板",
]
for pat in patterns:
m = re.search(pat, user_input)
if m:
kw = m.group(1).strip()
if len(kw) >= 2:
return kw
return ""
def _build_template_context(state: dict) -> str:
"""构建模板上下文,用于注入生成 prompt。
优先级:对话上传 > KB 检索 > KB 字段定义。
"""
parts = []
# 对话中上传的 JRXML 模板
uploaded = state.get("uploaded_template_jrxml", "")
if uploaded:
params = state.get("uploaded_template_params", [])
param_str = "\n".join(
f" - {p['name']} ({p.get('type', 'String')})" for p in params
) if params else "(参数列表未解析)"
parts.append(
"[对话上传的模板]\n"
f"以下为用户上传的 JRXML 模板,请基于此模板进行修改:\n"
f"模板参数:\n{param_str}\n"
f"```xml\n{uploaded}\n```"
)
# KB 检索到的模板
kb_tmpl = state.get("kb_template_jrxml", "")
kb_name = state.get("kb_template_name", "")
if kb_tmpl and not uploaded:
parts.append(
f"[知识库模板: {kb_name}]\n"
f"以下为从知识库检索到的 JRXML 模板,请作为结构参考:\n"
f"```xml\n{kb_tmpl}\n```"
)
# KB 字段定义
kb_fields = state.get("kb_fields", [])
if kb_fields:
field_table = "\n".join(
f"| {f['name']} | {f.get('description', '')} | {f.get('type', 'String')} |"
for f in kb_fields
)
parts.append(
"[可用数据字段]\n"
"生成 JRXML 时请使用以下字段作为 $P{{xxx}} 参数:\n"
f"| 字段名 | 含义 | 类型 |\n|---|---|---|\n{field_table}"
)
return "\n\n".join(parts)
@log_node("generate")
def generate(state: AgentState) -> Dict:
    """Generate the initial JRXML from the user's request and retrieved context.

    The LLM output is streamed chunk by chunk to the LangGraph stream writer,
    then the JRXML extracted from the full text is stored in
    ``current_jrxml`` and appended to the working conversation history.
    Returns the mutated ``state``.
    """
    from langgraph.config import get_stream_writer
    writer = get_stream_writer()
    llm = get_llm(caller="generate", max_tokens=32768)

    user_request = state.get("user_input", "")
    ocr_text = _format_ocr_context(state)
    if ocr_text:
        # OCR findings go in front of the user's own wording.
        user_request = f"{ocr_text}\n\n---\n用户需求:\n{user_request}"

    prompt = load_prompt("initial_generation").format(
        context=state.get("retrieved_context", ""),
        user_request=user_request,
        template_context=_build_template_context(state),
    )

    chunks = []
    for piece in llm.stream(prompt):
        chunks.append(piece)
        writer({"type": "stream", "node": "generate", "text": piece})

    jrxml = _extract_jrxml("".join(chunks))
    state["current_jrxml"] = jrxml
    state["conversation_history"].append({"role": "assistant", "content": jrxml})
    return state
@log_node("generate_skeleton")
def generate_skeleton(state: AgentState) -> Dict:
    """Stage one: generate a skeleton JRXML from the compressed layout schema.

    Fields in the skeleton are ``$F{field_N}`` placeholders. An empty LLM
    response is retried once with a larger ``max_tokens``; if that is empty
    too, the state is returned unchanged. An extracted result shorter than
    200 characters is replaced by the previous JRXML.
    Returns the mutated ``state``.
    """
    from langgraph.config import get_stream_writer
    writer = get_stream_writer()

    schema = state.get("layout_schema", {})
    schema_text = schema.get("schema_text", "") if isinstance(schema, dict) else ""
    prompt = load_prompt("skeleton_generation").format(
        layout_schema=schema_text,
        context=state.get("retrieved_context", ""),
        user_request=state.get("user_input", ""),
        template_context=_build_template_context(state),
    )

    llm = get_llm(caller="generate_skeleton", max_tokens=32768)
    previous_jrxml = state.get("current_jrxml", "")
    full_text = _generate_with_continuation(llm, prompt, writer, "generate_skeleton")
    if not full_text.strip():
        _node_log.warning("generate_skeleton 首次返回空响应,以更高 max_tokens 重试")
        llm = get_llm(caller="generate_skeleton", max_tokens=65536)
        full_text = _generate_with_continuation(llm, prompt, writer, "generate_skeleton")
        if not full_text.strip():
            _node_log.error("generate_skeleton LLM 返回空响应(含重试)")
            return state

    jrxml = _extract_jrxml(full_text)
    if len(jrxml.strip()) < 200:
        _node_log.warning(f"generate_skeleton 输出过短({len(jrxml)} 字符),回退到前一版本")
        jrxml = previous_jrxml

    state["current_jrxml"] = jrxml
    state["conversation_history"].append({"role": "assistant", "content": jrxml})
    return state
@log_node("refine_layout")
def refine_layout(state: AgentState) -> Dict:
"""Stage two: band-level windowed refinement — decompose the skeleton JRXML into independent bands and let the LLM adjust coordinates window by window.
Flow:
1. decompose_jrxml() splits the document into header + bands + footer
2. each band is sent to the LLM as one window (or several, if > 4000 characters)
3. the LLM only modifies x/y/width/height of the reportElement nodes inside that window
4. reassemble_jrxml() reassembles + validate_element_count() validates

Returns ``state``; ``current_jrxml`` is replaced only when validation passes.
"""
from langgraph.config import get_stream_writer
from agent.jrxml_windower import (
decompose_jrxml, split_band_into_windows, reassemble_band_windows,
reassemble_jrxml, validate_element_count,
)
writer = get_stream_writer()
llm = get_llm(caller="refine_layout")
prev_jrxml = state.get("current_jrxml", "")
# Nothing to refine without an input document.
if not prev_jrxml.strip():
_node_log.warning("refine_layout 无输入 JRXML,跳过")
return state
# Decompose the JRXML.
parts = decompose_jrxml(prev_jrxml)
if parts is None or parts.band_count == 0:
_node_log.warning("refine_layout 拆解失败或无 band,跳过")
return state
# Build the sampled coordinates.
ocr_rows = state.get("ocr_elements", [])
sampled_text = _build_sampled_text(ocr_rows)
template_ctx = _build_template_context(state)
# Windowed refinement, band by band; results are keyed by band label.
modified_bands: dict[str, str] = {}
total_windows = 0
for band in parts.bands:
# Bands without elements are passed through unchanged, without an LLM call.
if band.element_count == 0:
modified_bands[band.label] = band.band_xml
continue
windows = split_band_into_windows(band, max_chars=4000)
total_windows += len(windows)
band_results: list[str] = []
for wi, win_xml in enumerate(windows):
prompt = load_prompt("refine_layout").format(
band_name=band.section_name,
band_index=band.band_index,
band_height=_extract_band_height(win_xml),
window_index=wi + 1,
total_windows=len(windows),
xml_fragment=win_xml,
sampled_coordinates=sampled_text,
template_context=template_ctx,
)
try:
response = llm.invoke(prompt)
# Accept both response objects exposing .content and plain values.
content = response.content if hasattr(response, "content") else str(response)
fragment = _extract_xml_fragment(content)
if fragment:
# Reject the LLM output when more than 30% of its elements have a zero coordinate;
# the original window text is used instead.
zero_count, total = _count_zero_coordinate_elements(fragment)
if total > 0 and zero_count / total > 0.3:
_node_log.warning(
"refine_layout 窗口 %s/%d 零坐标元素 %d/%d (%.0f%%),使用原文",
band.label, wi + 1, zero_count, total,
zero_count / total * 100,
)
band_results.append(win_xml)
else:
band_results.append(fragment)
# Progress message pushed to the stream writer.
writer({"type": "stream", "node": "refine_layout",
"text": f"[{band.label} 窗口 {wi+1}/{len(windows)} 完成] "})
else:
_node_log.warning("refine_layout 窗口 %s/%d 返回空,使用原文",
band.label, wi + 1)
band_results.append(win_xml)
except Exception as e:
# An LLM failure for one window falls back to that window's original text.
_node_log.warning("refine_layout 窗口 %s/%d LLM 失败: %s,使用原文",
band.label, wi + 1, e)
band_results.append(win_xml)
# A single window is used as-is; several windows are merged back into one band.
if len(band_results) == 1:
modified_bands[band.label] = band_results[0]
else:
modified_bands[band.label] = reassemble_band_windows(band_results)
# Reassemble and validate.
result = reassemble_jrxml(parts, modified_bands)
validation = validate_element_count(prev_jrxml, result, "refine_layout")
_node_log.info(
"refine_layout 窗口化完成: %d bands, %d 窗口, 元素 %d%d (%.1f%%)",
parts.band_count, total_windows,
validation["original"], validation["modified"],
validation["change_pct"] * 100,
)
# On a failed validation the state is returned untouched, i.e. the skeleton version stays current.
if not validation["ok"]:
_node_log.error("refine_layout 元素丢失过多,回退到骨架版本")
return state
state["current_jrxml"] = result
state["conversation_history"].append({"role": "assistant", "content": result})
return state
@log_node("map_fields")
def map_fields(state: AgentState) -> Dict:
    """Stage three: programmatic field mapping, without calling the LLM.

    ``$F{field_N}`` placeholders in ``current_jrxml`` are replaced through
    ``_programmatic_map_fields`` using the OCR field list. Without an input
    JRXML the node is skipped; without OCR fields the placeholder names are
    kept. The mapped result replaces ``current_jrxml`` only when the element
    count validation passes. Returns the mutated ``state``.
    """
    from agent.jrxml_windower import validate_element_count

    source_jrxml = state.get("current_jrxml", "")
    if not source_jrxml.strip():
        _node_log.warning("map_fields 无输入 JRXML,跳过")
        return state

    # Pull the OCR field list out of the extraction result.
    ocr_result = state.get("ocr_extraction_result", {})
    ocr_fields: list[dict] = []
    if isinstance(ocr_result, dict) and ocr_result.get("fields"):
        ocr_fields = ocr_result["fields"]
    if not ocr_fields:
        _node_log.info("map_fields 无 OCR 字段,保留占位字段名")
        state["conversation_history"].append(
            {"role": "assistant", "content": source_jrxml}
        )
        return state

    # Programmatic replacement (main path), followed by validation.
    mapped = _programmatic_map_fields(source_jrxml, ocr_fields)
    validation = validate_element_count(source_jrxml, mapped, "map_fields")
    _node_log.info(
        "map_fields 程序化完成: %d 个字段, 元素 %d%d (%.1f%%)",
        len(ocr_fields),
        validation["original"], validation["modified"],
        validation["change_pct"] * 100,
    )
    if not validation["ok"]:
        _node_log.error("map_fields 元素丢失过多,回退到前一版本")
        return state

    state["current_jrxml"] = mapped
    state["conversation_history"].append({"role": "assistant", "content": mapped})
    return state
@log_node("modify_jrxml")
def modify_jrxml(state: AgentState) -> Dict:
    """Apply the user's modification request to the current JRXML via the LLM.

    The prompt combines the current JRXML, the compressed history summary (if
    any), the six most recent conversation entries, the modification request
    and the OCR/template context. The streamed reply is reduced to JRXML with
    ``_extract_jrxml``.

    Args:
        state: Workflow state. Reads ``current_jrxml``, ``compressed_history``,
            ``conversation_history`` and ``user_modification_request``.

    Returns:
        The same ``state`` object. On an empty LLM reply nothing is changed.
        Otherwise ``current_jrxml``, both conversation histories and
        ``retry_count`` (reset to 0) are updated. If the extracted output is
        shorter than 200 characters or does not look like JRXML, the previous
        JRXML is kept as the result.
    """
    from langgraph.config import get_stream_writer

    writer = get_stream_writer()
    llm = get_llm(caller="modify_jrxml", max_tokens=32768)

    prev_jrxml = state.get("current_jrxml", "")
    request = state.get("user_modification_request", "")
    # setdefault avoids a KeyError on append when the list was never created.
    history = state.setdefault("conversation_history", [])

    # Conversation context: compressed summary first, then the recent turns.
    compressed = state.get("compressed_history", "")
    recent = history[-6:]
    conv_parts = []
    if compressed:
        conv_parts.append(f"[早期对话摘要]\n{compressed}")
    conv_parts.append(json.dumps(recent, ensure_ascii=False, indent=2))
    conv_text = "\n\n---\n\n".join(conv_parts)

    prompt = load_prompt("modification").format(
        current_jrxml=prev_jrxml,
        conversation_history=conv_text,
        modification_request=request,
        ocr_context=_format_ocr_context(state),
        template_context=_build_template_context(state),
    )

    full_text = _generate_with_continuation(llm, prompt, writer, "modify_jrxml")
    if not full_text.strip():
        _node_log.error("modify_jrxml LLM 返回空响应,保留原版本")
        return state

    jrxml = _extract_jrxml(full_text)
    if len(jrxml.strip()) < 200:
        _node_log.warning(f"modify_jrxml 输出过短({len(jrxml)} 字符),回退到前一版本")
        jrxml = prev_jrxml
    # Same guard correct_jrxml applies: long output that contains neither an
    # XML prolog nor a <jasperReport root is not a report and must not replace
    # the current one.
    if jrxml and "<jasperReport" not in jrxml and "<?xml" not in jrxml:
        _node_log.warning(
            f"modify_jrxml 输出不是合法 JRXML({jrxml[:100]}),回退到前一版本"
        )
        jrxml = prev_jrxml

    state["current_jrxml"] = jrxml
    history.append({"role": "user", "content": request})
    history.append({"role": "assistant", "content": jrxml})
    # The full history is rebuilt as a new list rather than mutated in place.
    state["full_conversation_history"] = (
        list(state.get("full_conversation_history", [])) +
        [
            {"role": "user", "content": request, "ts": _now_iso()},
            {"role": "assistant", "content": jrxml, "ts": _now_iso()},
        ]
    )
    state["retry_count"] = 0
    return state
# ── Java renderer config ──────────────────────────────────────────────
# JVM launcher under $JAVA_HOME/bin. The executable is "java.exe" only on
# Windows; elsewhere it is plain "java". The default JAVA_HOME is the Windows
# install path the project originally shipped with.
_JAVA_BIN = os.path.join(
    os.environ.get("JAVA_HOME", "C:/Program Files/Java/jdk-21.0.11"),
    "bin",
    "java.exe" if os.name == "nt" else "java",
)
# Directory holding the renderer class and its jar dependencies; it is also
# used as the working directory of the renderer process.
_JAVA_JAR_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "lib", "java")
_JAVA_RENDERER_CLASS = "JrxmlRenderer"
# Classpath: "." (resolved against the process cwd) followed by every jar.
# All entries are joined with os.pathsep so the value is valid on every
# platform; the previous mix of a literal ";" and os.pathsep only worked on
# Windows.
_JAVA_RENDERER_CP = os.pathsep.join(
    ["."]
    + [
        os.path.join(_JAVA_JAR_DIR, jar_name)
        for jar_name in (
            "jasperreports-6.21.0.jar",
            "commons-logging-1.3.5.jar",
            "commons-collections4-4.5.0.jar",
            "commons-beanutils-1.10.1.jar",
            "commons-lang3-3.17.0.jar",
            "commons-digester-2.1.jar",
            "itext-2.1.7.jar",
            "jfreechart-1.5.5.jar",
            "ecj-3.38.0.jar",
        )
    ]
)
def _render_jrxml_to_png(jrxml: str, output_path: str, scale: float = 2.0) -> bool:
    """Render a JRXML document to PNG by running the Java ``JrxmlRenderer``.

    The JRXML is written to a uniquely named temporary file under the
    project's ``tmp`` directory, handed to the Java process together with the
    output path and scale, and removed again afterwards.

    Args:
        jrxml: JRXML source text.
        output_path: Destination path passed to the renderer.
        scale: Scale factor passed to the renderer as a string.

    Returns:
        True when the Java process exits with code 0, False on any failure
        (non-zero exit, timeout, missing JVM, I/O error).
    """
    import subprocess
    import tempfile

    tmpdir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "tmp")
    jrxml_path = None
    try:
        os.makedirs(tmpdir, exist_ok=True)
        # A unique file per call: a fixed name would let two concurrent
        # renders overwrite each other's input.
        with tempfile.NamedTemporaryFile(
            "w", encoding="utf-8", prefix="_render_input_", suffix=".jrxml",
            dir=tmpdir, delete=False,
        ) as f:
            jrxml_path = f.name
            f.write(jrxml)

        result = subprocess.run(
            [_JAVA_BIN, "-cp", _JAVA_RENDERER_CP, _JAVA_RENDERER_CLASS,
             jrxml_path, output_path, str(scale)],
            capture_output=True, text=True, timeout=120,
            # Undecodable bytes in the JVM output must not turn a successful
            # render into a failure.
            errors="replace",
            cwd=_JAVA_JAR_DIR,
        )
        if result.returncode == 0:
            _node_log.info(f"PNG rendered: {output_path} ({result.stdout.strip()})")
            return True
        _node_log.warning(f"PNG render failed: {result.stdout.strip()} {result.stderr.strip()}")
        return False
    except Exception as e:
        # Rendering is best-effort; every failure is reported as False.
        _node_log.warning(f"PNG render exception: {e}")
        return False
    finally:
        if jrxml_path is not None:
            try:
                os.remove(jrxml_path)
            except OSError:
                pass  # a leftover temp file is harmless
def _compute_pixel_similarity(rendered_png: str, reference_image: str) -> dict:
    """Compare a rendered PNG with a reference image at pixel level.

    Both images are loaded as grayscale. If their shapes differ, the rendered
    image is resized to the reference dimensions before comparison.

    Args:
        rendered_png: Path of the rendered image.
        reference_image: Path of the reference image.

    Returns:
        ``{"ssim": float, "diff_pct": float, "error": str | None}``.
        ``ssim`` is the structural similarity score and ``diff_pct`` the
        fraction of pixels whose absolute difference exceeds 30. On any
        failure (missing dependency, unreadable image, other exception) the
        result is ``ssim=0.0``, ``diff_pct=1.0`` with ``error`` set.
    """
    try:
        import cv2
        import numpy as np
        from skimage.metrics import structural_similarity as ssim

        rendered = cv2.imread(rendered_png, cv2.IMREAD_GRAYSCALE)
        reference = cv2.imread(reference_image, cv2.IMREAD_GRAYSCALE)
        if rendered is None:
            return {"ssim": 0.0, "diff_pct": 1.0, "error": f"无法读取渲染图片: {rendered_png}"}
        if reference is None:
            return {"ssim": 0.0, "diff_pct": 1.0, "error": f"无法读取参考图片: {reference_image}"}

        # cv2.resize takes (width, height), i.e. the reversed array shape.
        if rendered.shape != reference.shape:
            rendered = cv2.resize(rendered, (reference.shape[1], reference.shape[0]))

        # Coerce to builtin floats so the values stored in workflow state do
        # not depend on NumPy scalar types.
        score = float(ssim(rendered, reference, data_range=255))

        diff = cv2.absdiff(rendered, reference)
        diff_pct = float(np.count_nonzero(diff > 30)) / diff.size

        return {"ssim": round(score, 4), "diff_pct": round(diff_pct, 4), "error": None}
    except ImportError as e:
        return {"ssim": 0.0, "diff_pct": 1.0, "error": f"缺少依赖: {e}"}
    except Exception as e:
        return {"ssim": 0.0, "diff_pct": 1.0, "error": str(e)}
def _check_ocr_fidelity(jrxml: str, state: dict) -> dict:
    """Score how faithfully the generated JRXML reflects the OCR extraction.

    Three checks are performed:
      1. Element coverage: number of ``<textField>`` + ``<staticText>``
         elements relative to the number of OCR text elements (capped at 1).
      2. Field coverage: fraction of OCR field names declared as ``<field>``
         in the JRXML, either verbatim or after ``_sanitize_field_name``.
      3. Column count: distinct ``x`` positions in the detail band versus the
         column count reported by the layout schema.

    Args:
        jrxml: Generated JRXML text.
        state: Workflow state; reads ``ocr_elements``,
            ``ocr_extraction_result`` and ``layout_schema``.

    Returns:
        ``{"score", "field_coverage", "element_coverage", "issues"}``.
        ``score`` equals the element coverage; field coverage is reported for
        information only. Without any OCR data all values are 1.0 and
        ``issues`` is empty.
    """
    ocr_elements = state.get("ocr_elements", [])
    ocr_result = state.get("ocr_extraction_result", {})
    layout_schema = state.get("layout_schema", {})

    # Nothing to compare against.
    if not ocr_elements and not ocr_result:
        return {"score": 1.0, "field_coverage": 1.0, "element_coverage": 1.0, "issues": []}

    issues = []

    # 1. Element coverage. The word boundary keeps <textFieldExpression> from
    # being counted as a second <textField>.
    total_jrxml_elements = (
        len(re.findall(r"<textField\b", jrxml))
        + len(re.findall(r"<staticText\b", jrxml))
    )

    ocr_text_count = 0
    if isinstance(ocr_elements, list):
        ocr_text_count = sum(
            1
            for e in ocr_elements
            # "text" may be present but None; treat that as empty.
            if isinstance(e, dict) and str(e.get("text") or "").strip()
        )
    if ocr_text_count == 0 and isinstance(ocr_result, dict):
        ocr_text_count = ocr_result.get("total_elements", 0) or 0

    if ocr_text_count > 0:
        element_coverage = min(total_jrxml_elements / max(ocr_text_count, 1), 1.0)
        if element_coverage < 0.3:
            issues.append(
                f"元素覆盖不足:JRXML 仅有 {total_jrxml_elements} 个文本元素,"
                f"OCR 源有 {ocr_text_count} 个文本元素(覆盖率 {element_coverage:.0%})"
            )
    else:
        element_coverage = 1.0

    # 2. Field-name coverage (informational only; see the score note below).
    jrxml_fields = set(re.findall(r'<field name="([^"]+)"', jrxml))
    ocr_field_names = set()
    ocr_fields = ocr_result.get("fields", []) if isinstance(ocr_result, dict) else []
    for f in ocr_fields:
        if isinstance(f, dict):
            name = f.get("name", "") or f.get("field_name", "") or f.get("label", "")
            if name and len(name) > 1:
                ocr_field_names.add(name)

    if ocr_field_names and jrxml_fields:
        # An OCR name counts as declared when it appears verbatim or in its
        # sanitized form among the JRXML <field> names.
        still_unmatched = {
            n for n in ocr_field_names
            if n not in jrxml_fields and _sanitize_field_name(n) not in jrxml_fields
        }
        # Fraction of OCR names that are declared; always within [0, 1].
        field_coverage = (len(ocr_field_names) - len(still_unmatched)) / len(ocr_field_names)
        if still_unmatched:
            sample = sorted(still_unmatched)[:8]  # sorted for a stable message
            issues.append(f"OCR 字段未在 JRXML 中声明: {', '.join(sample)}")
    elif ocr_field_names and not jrxml_fields:
        field_coverage = 0.0
        issues.append("JRXML 中未声明任何字段,但 OCR 提取了结构化字段数据")
    else:
        field_coverage = 1.0

    # 3. Column count.
    if isinstance(layout_schema, dict):
        raw_columns = layout_schema.get("total_columns", 0) or layout_schema.get("columns", 0)
        # Defensive: accept a number, or a sequence whose length is the count.
        if isinstance(raw_columns, (list, tuple)):
            ocr_columns = len(raw_columns)
        elif isinstance(raw_columns, (int, float)):
            ocr_columns = raw_columns
        else:
            ocr_columns = 0

        # Look inside the <detail> section so the title/header bands are not
        # mistaken for the data band; fall back to the whole document when no
        # <detail> section exists.
        detail_section = re.search(r"<detail\b[^>]*>([\s\S]*?)</detail>", jrxml)
        scope = detail_section.group(1) if detail_section else jrxml
        detail_match = re.search(r"<band[^>]*height=\"(\d+)\"[^>]*>([\s\S]*?)</band>", scope)
        if detail_match and ocr_columns > 0:
            detail_content = detail_match.group(2)
            # \b prevents matching attributes that merely end in "x".
            x_positions = {int(m.group(1)) for m in re.finditer(r'\bx="(\d+)"', detail_content)}
            jrxml_columns = len(x_positions) if x_positions else 1
            if jrxml_columns < ocr_columns * 0.5:
                issues.append(
                    f"列数不足:JRXML detail band 检测到 {jrxml_columns} 列,"
                    f"OCR 布局分析有 {ocr_columns} 列"
                )

    # The overall score depends on element coverage only: a language mismatch
    # between JRXML field names and OCR field names is expected.
    score = round(element_coverage, 3)

    return {
        "score": score,
        "field_coverage": round(field_coverage, 3),
        "element_coverage": round(element_coverage, 3),
        "issues": issues,
    }
@log_node("validate")
def validate(state: AgentState) -> Dict:
    """Validate ``current_jrxml`` and record the verdict in ``state``.

    Steps, in order:
      1. Reject empty or very short (< 200 characters) content.
      2. Best-effort element re-ordering through ``normalize_jrxml``.
      3. Validation through ``validate_jrxml``; sets ``status``,
         ``error_msg`` and ``service_unavailable``.
      4. OCR fidelity check (``ocr_fidelity``); a passing result is downgraded
         to ``fail`` when the score is below 0.5 and the element coverage is
         below 0.4.
      5. When the status is still ``pass`` and the uploaded source image
         exists, render the JRXML to PNG and compare it with the image
         (``pixel_fidelity``).
      6. When validation succeeded after at least one retry, record the
         corrected case in the error knowledge base.

    Args:
        state: Workflow state.

    Returns:
        The same ``state`` object with the fields above updated.
    """
    jrxml = state.get("current_jrxml", "")
    if not jrxml:
        state["status"] = "fail"
        state["error_msg"] = "没有 JRXML 内容可供验证。"
        return state

    # Content this short cannot be a usable report.
    if len(jrxml.strip()) < 200:
        state["status"] = "fail"
        state["error_msg"] = f"JRXML 内容过短({len(jrxml.strip())} 字符),可能为不完整或空内容。"
        return state

    # Normalise element order for the XSD sequence rules. This is best-effort:
    # a failure is logged and validation continues with the original text.
    try:
        from backend.jrxml_reorder import normalize_jrxml
        jrxml = normalize_jrxml(jrxml)
        state["current_jrxml"] = jrxml
    except Exception as e:
        _node_log.warning(f"JRXML 规范化失败,使用原始内容继续验证: {e}")

    result = validate_jrxml(jrxml)
    state["status"] = "pass" if result.get("valid") else "fail"
    state["error_msg"] = result.get("error", "")
    state["service_unavailable"] = result.get("service_unavailable", False)

    # OCR fidelity: compare the generated report with the OCR extraction.
    fidelity = _check_ocr_fidelity(jrxml, state)
    state["ocr_fidelity"] = fidelity
    if fidelity["issues"]:
        if state["status"] == "pass":
            # Schema-valid but possibly unfaithful. Both thresholds must be
            # crossed so that a field-name language mismatch alone never
            # causes a failure.
            if fidelity["score"] < 0.5 and fidelity.get("element_coverage", 0) < 0.4:
                state["status"] = "fail"
                state["error_msg"] = (
                    f"[内容保真度不足] 得分 {fidelity['score']:.2f}/1.0。"
                    + " ".join(fidelity["issues"][:3])
                )
                _node_log.warning(
                    f"OCR 保真度得分 {fidelity['score']:.2f},XSD 通过但内容差异过大: "
                    + "; ".join(fidelity["issues"][:5])
                )
            else:
                _node_log.info(
                    f"OCR 保真度得分 {fidelity['score']:.2f},XSD 通过,轻微差异: "
                    + "; ".join(fidelity["issues"][:3])
                )
        else:
            _node_log.info(
                f"XSD 验证失败 + OCR 保真度得分 {fidelity['score']:.2f}: "
                + "; ".join(fidelity["issues"][:3])
            )

    # Pixel comparison: render the JRXML and compare with the uploaded image.
    source_image = state.get("uploaded_file_path", "")
    if source_image and os.path.isfile(source_image) and state["status"] == "pass":
        # The directory is created by _render_jrxml_to_png.
        tmpdir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "tmp")
        rendered_png = os.path.join(tmpdir, "_pixel_test.png")
        if _render_jrxml_to_png(jrxml, rendered_png):
            pixel_result = _compute_pixel_similarity(rendered_png, source_image)
            state["pixel_fidelity"] = pixel_result
            if pixel_result["error"]:
                _node_log.warning(f"像素对比失败: {pixel_result['error']}")
            else:
                _node_log.info(
                    f"像素对比: SSIM={pixel_result['ssim']:.4f}, "
                    f"Diff={pixel_result['diff_pct']:.2%}"
                )
                # Fail only when BOTH metrics are bad: SSIM below 0.4 and
                # more than 60% of the pixels differing.
                if pixel_result["ssim"] < 0.4 and pixel_result["diff_pct"] > 0.6:
                    state["status"] = "fail"
                    state["error_msg"] = (
                        f"[像素保真度不足] SSIM={pixel_result['ssim']:.3f}, "
                        f"差异像素占比={pixel_result['diff_pct']:.2%}。"
                        f"渲染结果与原始图片差异过大,需调整布局。"
                    )

    # A correction that made the document valid is recorded in the error
    # knowledge base. Recording is best-effort and never blocks the workflow.
    if result.get("valid") and state.get("retry_count", 0) > 0:
        case = state.get("last_error_case", {})
        if case and case.get("error_msg"):
            try:
                from backend.error_kb import record_error
                recorded = record_error(
                    error_msg=case["error_msg"],
                    bad_jrxml=case.get("bad_jrxml", ""),
                    good_jrxml=jrxml,
                    correction_prompt=case.get("correction_prompt", ""),
                    retry_count=state.get("retry_count", 0),
                )
                if recorded:
                    state.setdefault("conversation_history", []).append({
                        "role": "system",
                        "content": f"[系统] 错误案例已记录到知识库(指纹: {case['error_msg'][:40]}...)",
                    })
            except Exception as e:
                _node_log.warning(f"错误知识库写入失败(不影响主流程): {e}")

    return state
@log_node("explain_error")
def explain_error(state: AgentState) -> Dict:
    """Ask the LLM for a human-readable explanation of the validation error.

    Only the first 80 lines of the current JRXML are sent along with the error
    message. The stripped reply is stored in ``natural_explanation``.
    """
    llm = get_llm(caller="explain_error")

    # Keep the prompt small: the head of the document is used as the snippet.
    head_lines = state.get("current_jrxml", "").split("\n")[:80]

    prompt = load_prompt("explain_error").format(
        error_msg=state.get("error_msg", "未知错误"),
        jrxml_snippet="\n".join(head_lines),
    )

    reply = llm.invoke(prompt)
    state["natural_explanation"] = reply.content.strip()
    return state
@log_node("correct_jrxml")
def correct_jrxml(state: AgentState) -> Dict:
    """Try to repair a JRXML document that failed validation, using the LLM.

    The correction prompt carries the current JRXML, the error message and its
    explanation, the OCR and layout context, and fidelity warnings from the
    last validation. The pre-correction state is saved in ``last_error_case``
    so that ``validate`` can record a successful fix.

    Args:
        state: Workflow state.

    Returns:
        The same ``state`` object. ``retry_count`` grows by 1 per attempt, or
        by 2 when the output is identical to the input (ignoring whitespace).
        Output that is empty, too short or not JRXML leaves the previous JRXML
        in place.
    """
    from langgraph.config import get_stream_writer

    writer = get_stream_writer()
    llm = get_llm(caller="correct_jrxml", max_tokens=32768)

    ocr_context = _format_ocr_context(state)
    layout_schema = state.get("layout_schema", {})
    layout_text = ""
    if isinstance(layout_schema, dict):
        layout_text = layout_schema.get("schema_text", "")

    # Fidelity context: tells the LLM where the report differs from the image.
    fidelity = state.get("ocr_fidelity", {})
    fidelity_text = ""
    if fidelity and fidelity.get("score", 1.0) < 0.9:
        fidelity_text = (
            f"[内容保真度警告] 得分 {fidelity.get('score', 0):.2f}/1.0\n"
            + "\n".join(f"- {issue}" for issue in fidelity.get("issues", []))
        )

    # Pixel context. A result carrying an error (unreadable image, missing
    # dependency) has placeholder metrics of ssim=0.0 / diff=1.0 and must not
    # be presented to the LLM as a real visual mismatch.
    pixel_fidelity = state.get("pixel_fidelity", {})
    if (
        pixel_fidelity
        and not pixel_fidelity.get("error")
        and pixel_fidelity.get("ssim", 1.0) < 0.7
    ):
        fidelity_parts = [fidelity_text] if fidelity_text else []
        fidelity_parts.append(
            f"[像素保真度] SSIM={pixel_fidelity.get('ssim', 0):.4f}, "
            f"像素差异={pixel_fidelity.get('diff_pct', 0):.2%}。"
            f"渲染结果与原图差异过大,请调整元素位置、尺寸和布局。"
        )
        fidelity_text = "\n".join(fidelity_parts)

    prev_jrxml = state.get("current_jrxml", "")
    prompt = load_prompt("correction").format(
        current_jrxml=prev_jrxml,
        error_msg=state.get("error_msg", ""),
        explanation=state.get("natural_explanation", ""),
        ocr_context=ocr_context,
        layout_schema_text=layout_text,
        fidelity_context=fidelity_text,
        template_context=_build_template_context(state),
    )

    # Saved for validate, which decides whether to write the error KB entry.
    state["last_error_case"] = {
        "error_msg": state.get("error_msg", ""),
        "bad_jrxml": prev_jrxml,
        "correction_prompt": prompt,
    }

    full_text = _generate_with_continuation(llm, prompt, writer, "correct_jrxml")
    if not full_text.strip():
        _node_log.error("correct_jrxml LLM 返回空响应,保留原版本")
        state["retry_count"] = state.get("retry_count", 0) + 1
        return state

    jrxml = _extract_jrxml(full_text)
    if len(jrxml.strip()) < 200:
        _node_log.warning(f"correct_jrxml 输出过短({len(jrxml)} 字符),回退到前一版本")
        jrxml = prev_jrxml
    # Output with neither an XML prolog nor a <jasperReport root is not JRXML
    # (for example HTML); keep the previous version.
    if jrxml and "<jasperReport" not in jrxml and "<?xml" not in jrxml:
        _node_log.warning(
            f"correct_jrxml 输出不是合法 JRXML({jrxml[:100]}),回退到前一版本"
        )
        jrxml = prev_jrxml

    # No-op detection: identical output (ignoring whitespace) means the
    # correction did nothing, so the retry budget is consumed twice as fast.
    _prev_norm = re.sub(r"\s+", "", prev_jrxml) if prev_jrxml else ""
    _new_norm = re.sub(r"\s+", "", jrxml) if jrxml else ""
    if _prev_norm and _new_norm and _prev_norm == _new_norm:
        _node_log.warning(
            f"correct_jrxml 输出与输入完全相同({len(jrxml)} 字符),修正无效,加速消耗 retry"
        )
        state["retry_count"] = state.get("retry_count", 0) + 2
    else:
        state["retry_count"] = state.get("retry_count", 0) + 1

    state["current_jrxml"] = jrxml
    # setdefault avoids a KeyError when the history list was never created.
    state.setdefault("conversation_history", []).append(
        {"role": "assistant", "content": f"[自动修正,第 {state['retry_count']} 次尝试]\n{jrxml}"}
    )
    return state
@log_node("finalize")
def finalize(state: AgentState) -> Dict:
    """Close a generation round: store the result and record a version entry.

    On ``status == "pass"`` the current JRXML becomes ``final_jrxml`` and a
    version entry is appended. Otherwise ``final_jrxml`` is left alone, the
    failed document is still recorded as a version, the failure context is
    stored in ``pending_failure_context`` and a failure notice is appended to
    the conversation history.
    """
    jrxml = state.get("current_jrxml", "")
    status = state.get("status", "")

    def _push_version(entry: dict) -> None:
        # Append to the version list, replacing it when it is not a list.
        recorded = state.get("jrxml_versions", [])
        if not isinstance(recorded, list):
            recorded = []
        recorded.append(entry)
        state["jrxml_versions"] = recorded

    if status == "pass":
        state["final_jrxml"] = jrxml
        if jrxml.strip():
            intent = state.get("intent", "")
            labels = {
                "initial_generation": "初始生成",
                "modify_report": "修改",
                "correct_jrxml": f"自动修正 (第{state.get('retry_count', 1)}次)",
            }
            _push_version({
                "ts": _now_iso(),
                "jrxml": jrxml,
                "intent": intent,
                "label": labels.get(intent, intent),
                "status": status,
            })
        return state

    # Validation failed: keep the last successful final_jrxml untouched.
    retries = state.get("retry_count", 0)
    error_msg = state.get("error_msg", "未知错误")

    # The failed document is still kept as a version entry.
    if jrxml.strip():
        _push_version({
            "ts": _now_iso(),
            "jrxml": jrxml,
            "intent": state.get("intent", ""),
            "label": f"失败版本 (第{retries}次重试)",
            "status": "fail",
            "error_msg": error_msg,
        })

    # Failure context stored for a later turn.
    state["pending_failure_context"] = {
        "error_msg": error_msg,
        "bad_jrxml": state.get("current_jrxml", ""),
        "retry_count": retries,
        "ts": _now_iso(),
    }

    notice = (
        f"❌ 经过 {retries} 次重试后仍无法生成有效的 JRXML。\n"
        f"错误: {error_msg}\n\n"
        f"您可以:\n1. 继续描述修改要求,系统将自动重试修复\n2. 点击下载按钮获取当前版本(虽未通过 XSD 验证,但可能可在 Studio 中手动修复)"
    )
    state["conversation_history"].append({"role": "assistant", "content": notice})
    return state
def _strip_continuation_wrapper(text: str) -> str:
    """Remove markdown fences and chatty preambles from a continuation reply.

    In continuation rounds the LLM may ignore the original formatting rules,
    prefixing the XML fragment with an explanation or wrapping it in a
    ``` fence. This function returns the bare XML fragment.

    Args:
        text: Raw text of one continuation round.

    Returns:
        The fragment with fences and a leading natural-language preamble
        removed, stripped of surrounding whitespace.
    """
    text = text.strip()

    # A complete fenced block: return its non-empty content.
    fenced = re.search(r"```(?:xml|jrxml)?\s*([\s\S]*?)```", text, re.IGNORECASE)
    if fenced:
        inner = fenced.group(1).strip()
        if inner:
            return inner

    # A dangling opening or closing fence (incomplete block).
    text = re.sub(r"^```(?:xml|jrxml)?\s*", "", text, flags=re.IGNORECASE)
    text = re.sub(r"```\s*$", "", text)

    # A short preamble such as "continuing the output:". The match is limited
    # to characters before the first "<", so XML on the same line, or a
    # leading element whose text happens to contain one of the keywords, is
    # never removed. Both ASCII and full-width colons are accepted.
    text = re.sub(
        r"^[^<\n]{0,40}(?:继续输出|剩余|续写|补全|接上)[^<\n]{0,30}[::]?\s*",
        "",
        text,
    )
    return text.strip()
def _generate_with_continuation(llm, prompt, writer, node_name, max_rounds=3) -> str:
    """Stream an LLM generation and recover from truncated output.

    After every round the accumulated text is passed through
    ``_extract_jrxml``; if the result does not end with a closing
    ``</jasperReport>`` (or ``</report>``) tag, another round is requested
    with the last 800 characters as the anchor. Continuation rounds are
    cleaned with ``_strip_continuation_wrapper`` before being appended.

    Every streamed chunk is forwarded to ``writer`` as it arrives.

    Returns the concatenated text of all rounds.
    """
    closing_tag_at_end = r"</(?:[\w:]+:)?(?:jasperReport|report)>\s*$"
    accumulated = ""
    finished = False

    for attempt in range(max_rounds):
        if attempt:
            # Slicing already yields the whole string when it is shorter.
            tail = accumulated[-800:]
            request = (
                f"[系统指令] 你正在生成的 JRXML 在上一次响应中被截断。\n"
                f"已生成内容的最后部分(请从此处继续):\n...{tail}\n\n"
                f"请从截断点继续输出剩余内容,不要重复已输出的部分。\n"
                f"不要输出 markdown 代码块、解释或任何非 JRXML 的内容。"
            )
        else:
            request = prompt

        pieces = []
        for piece in llm.stream(request):
            pieces.append(piece)
            writer({"type": "stream", "node": node_name, "text": piece})

        produced = "".join(pieces)
        if attempt:
            produced = _strip_continuation_wrapper(produced)
        accumulated += produced

        # Done once the extracted document ends with its closing tag.
        if re.search(closing_tag_at_end, _extract_jrxml(accumulated), re.IGNORECASE):
            finished = True
            break

        # An empty round will not make progress; stop asking.
        if not produced.strip():
            _node_log.warning(f"{node_name}{attempt+1}轮续写无输出,停止")
            finished = True
            break

    if not finished:
        # All rounds were used without reaching a closing tag.
        _node_log.warning(f"{node_name}{max_rounds}轮续写仍未完整")
    return accumulated
def _extract_jrxml(text: str) -> str:
    """Extract the JRXML document from an LLM reply.

    Handled shapes:
      1. A markdown code block holding a complete document.
      2. A document embedded in other text (for example a first round without
         a fence followed by fenced continuation rounds).
      3. Bare JRXML without any wrapper.

    ``ns0:`` prefixes and the ``xmlns:ns0`` declaration are removed first.

    Args:
        text: Raw LLM reply.

    Returns:
        The extracted document. When no complete document can be located,
        the cleaned, stripped input is returned unchanged so that callers can
        apply their own checks.
    """
    text = text.strip()
    # Drop the ns0: namespace prefix and its declaration.
    text = text.replace("ns0:", "")
    text = re.sub(r'\s+xmlns:ns0="[^"]*"', "", text)

    document_starts = ("<?xml", "<jasperReport")

    # First fenced block that looks like a whole document wins; blocks that
    # hold only a fragment are skipped.
    for block_match in re.finditer(r"```(?:xml|jrxml)?\s*([\s\S]*?)```", text, re.IGNORECASE):
        content = block_match.group(1).strip()
        if content.startswith(document_starts):
            return content

    # Otherwise take the span from the document start to the first closing
    # root tag. The XML prolog is tried first; the root element is the
    # fallback for replies without a prolog, which would otherwise be
    # returned with their surrounding prose attached.
    closing_root = r"</(?:[\w:]+:)?(?:jasperReport|report)>"
    for opener in (r"<\?xml", r"<jasperReport\b"):
        embedded = re.search(rf"({opener}[\s\S]*?{closing_root})", text, re.IGNORECASE)
        if embedded:
            return embedded.group(1).strip()

    # No complete document found (for example truncated output).
    return text