4e14334030
- backend/llm.py: per-node max_tokens via get_llm(max_tokens=N), LLM_MAX_TOKENS env var (default 8192) - agent/nodes.py: 5 generation nodes use max_tokens=32768, generate_skeleton retries at 65536 - agent/nodes.py: fix ns:field regex (<field → <[\w:]*field) to handle namespace prefixes - agent/nodes.py: fix correct_jrxml never writing back to state["current_jrxml"] - agent/nodes.py: correct_jrxml rejects non-JRXML output (no <jasperReport tag) - agent/nodes.py: _strip_continuation_wrapper strips markdown/prefixes from continuation rounds - agent/nodes.py: _extract_jrxml iterates multiple markdown code blocks, skips fragments - agent/graph.py: route_after_validate skips correction loop when service_unavailable - agent/graph.py: route_after_save skips validation for empty JRXML - backend/validation.py: returns service_unavailable: True for ConnectError and HTTP 5xx - Docs: CLAUDE.md v14 changelog, README.md LLM_MAX_TOKENS, .env.example LLM_MAX_TOKENS
1634 lines
64 KiB
Python
1634 lines
64 KiB
Python
"""LangGraph JRXML 生成工作流的节点函数。"""
|
||
|
||
import copy
|
||
import functools
|
||
import json
|
||
import os
|
||
import re
|
||
import time
|
||
from datetime import datetime, timezone
|
||
from pathlib import Path
|
||
from typing import Dict
|
||
|
||
from dotenv import load_dotenv
|
||
|
||
from agent.state import AgentState
|
||
from backend.llm import get_llm
|
||
from backend.logger import get_logger, set_trace_id
|
||
from backend.validation import validate_jrxml
|
||
from prompts.loader import load_prompt
|
||
|
||
# Load settings from a .env file before the constants below read the
# environment; override=True is passed so .env values take precedence.
load_dotenv(override=True)

# Logger shared by every node function in this module.
_node_log = get_logger("agent")

# Tunables read from the environment once, at import time.
# NOTE(review): MAX_RETRY is not used in the visible part of this file —
# presumably consumed by the correction loop; confirm against callers.
MAX_RETRY = int(os.getenv("MAX_RETRY", "5"))
# Token threshold above which manage_context() compresses older turns.
CONTEXT_MAX_TOKENS = int(os.getenv("CONTEXT_MAX_TOKENS", "6000"))
# Number of most recent history entries kept verbatim.
CONTEXT_KEEP_RECENT = int(os.getenv("CONTEXT_KEEP_RECENT", "4"))
# Maximum number of undo snapshots retained by save_state_snapshot().
HISTORY_MAX_SNAPSHOTS = int(os.getenv("HISTORY_MAX_SNAPSHOTS", "10"))
|
||
|
||
|
||
def _state_summary(state: AgentState) -> dict:
|
||
"""提取 state 中的关键字段用于日志摘要。"""
|
||
user_input = state.get("user_input", "")
|
||
return {
|
||
"session_id": state.get("session_id", ""),
|
||
"intent": state.get("intent", ""),
|
||
"status": state.get("status", ""),
|
||
"has_jrxml": bool(state.get("current_jrxml", "").strip()),
|
||
"jrxml_length": len(state.get("current_jrxml", "")),
|
||
"retry_count": state.get("retry_count", 0),
|
||
"user_input_preview": user_input[:100] if user_input else "",
|
||
"conversation_turns": len(state.get("conversation_history", [])),
|
||
"history_snapshots": len(state.get("history_states", [])),
|
||
"versions": len(state.get("jrxml_versions", [])),
|
||
}
|
||
|
||
|
||
def log_node(node_name: str):
    """Decorator factory that logs entry, exit, duration and errors of a node.

    Args:
        node_name: Label used in the log messages and in the ``node`` field
            of the structured ``extra`` payload.

    Returns:
        A decorator. The wrapped function is called with the same arguments
        and its return value is passed through unchanged. Any exception is
        logged with the elapsed time and then re-raised.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(state: AgentState, *args, **kwargs):
            # perf_counter() is monotonic, so the measured duration cannot
            # jump or go negative when the system clock is adjusted.
            t0 = time.perf_counter()
            _node_log.info(
                f"[节点入口] {node_name}",
                extra={"node": node_name, "phase": "entry", "state": _state_summary(state)},
            )
            try:
                result = func(state, *args, **kwargs)
                elapsed_ms = round((time.perf_counter() - t0) * 1000)
                _node_log.info(
                    f"[节点出口] {node_name}",
                    extra={
                        "node": node_name,
                        "phase": "exit",
                        "duration_ms": elapsed_ms,
                        "state": _state_summary(state),
                    },
                )
                return result
            except Exception as e:
                elapsed_ms = round((time.perf_counter() - t0) * 1000)
                _node_log.error(
                    f"[节点异常] {node_name}: {e}",
                    extra={
                        "node": node_name,
                        "phase": "error",
                        "duration_ms": elapsed_ms,
                        "error": str(e),
                        "state": _state_summary(state),
                    },
                )
                # Logging only: the caller still sees the original exception.
                raise
        return wrapper
    return decorator
|
||
|
||
|
||
# ============================================================
|
||
# 核心工作流节点
|
||
# ============================================================
|
||
|
||
@log_node("process_input")
def process_input(state: AgentState) -> Dict:
    """Record the user input in both histories and reset per-request state.

    Steps, in order:
      1. Append the raw user message to ``full_conversation_history``.
      2. If ``pending_failure_context`` carries an ``error_msg``, prepend a
         failure note to the working input, clear the pending context and
         set ``_failure_recovery``.
      3. Append the (possibly augmented) input to ``conversation_history``.
      4. If ``uploaded_file_path`` points to an existing image file, run OCR
         field extraction and annotation detection, prepending any extracted
         fields to the working input.
      5. Log the OCR layers, reset ``retry_count`` and store the working
         input as ``user_modification_request``.

    Returns:
        The same ``state`` object, mutated in place.
    """
    user_input = state.get("user_input", "")

    # Maintain the full conversation history (always stores the raw user message).
    full_history = state.get("full_conversation_history", [])
    full_history.append({"role": "user", "content": user_input, "ts": _now_iso()})
    state["full_conversation_history"] = full_history

    # Automatically inject the context of the previous failed generation.
    pending = state.get("pending_failure_context", {})
    if pending and pending.get("error_msg"):
        failure_note = (
            f"[系统提示] 上次生成失败,以下是失败详情,请基于此修正:\n"
            f"失败原因: {pending['error_msg']}\n"
            f"上次失败的输出:\n{pending.get('bad_jrxml', '(无输出)')}"
        )
        user_input = f"{failure_note}\n\n---\n用户新输入:\n{user_input}"
        state["pending_failure_context"] = {}
        state["_failure_recovery"] = True  # marks this turn as failure recovery; classify_intent then forces modify_report

    # Maintain the working conversation history.
    conv_history = state.get("conversation_history", [])
    conv_history.append({"role": "user", "content": user_input})
    state["conversation_history"] = conv_history

    # Precise OCR field extraction for an uploaded image file.
    uploaded_path = state.get("uploaded_file_path", "")
    if uploaded_path and Path(uploaded_path).is_file():
        suffix = Path(uploaded_path).suffix.lower()
        if suffix in (".png", ".jpg", ".jpeg", ".bmp", ".webp"):
            try:
                # Imported lazily, only when an image is actually processed.
                from backend.ocr_extractor import OcrExtractor
                extractor = OcrExtractor()
                default_fields = [
                    "发票代码", "发票号码", "开票日期", "合计金额", "校验码",
                    "价税合计", "总金额", "日期", "金额", "数量", "单价", "税率",
                    "购买方名称", "销售方名称", "货物名称", "规格型号",
                    "不含税金额", "税额",
                ]
                ocr_result = extractor.extract(uploaded_path, default_fields)
                if ocr_result.get("ocr_available"):
                    state["ocr_extraction_result"] = ocr_result
                    _node_log.info(
                        "OCR 字段提取完成",
                        extra={
                            "file": uploaded_path,
                            "elements": ocr_result.get("total_elements", 0),
                            "fields": len(ocr_result.get("fields", [])),
                        },
                    )
                    # Inject the extracted fields into the conversation context for the LLM.
                    extracted_fields = ocr_result.get("fields", [])
                    non_empty = [f for f in extracted_fields if f.get("field_value")]
                    if non_empty:
                        lines = ["[OCR 单据字段提取结果]"]
                        for f in non_empty:
                            lines.append(
                                f"- {f['field_name']}: {f['field_value']}"
                                f"(置信度: {f['confidence']:.0%}, 方法: {f['extraction_method']})"
                            )
                        ocr_context = "\n".join(lines)
                        user_input = f"{ocr_context}\n\n{user_input}"
                        # Also update the last entry of the working history.
                        conv_history[-1]["content"] = user_input
                    # Annotation detection (circles / arrows drawn on the image).
                    elements = ocr_result.get("elements", [])
                    if elements:
                        try:
                            from backend.annotation_detector import detect_annotations
                            ann_result = detect_annotations(uploaded_path, elements)
                            if ann_result.get("total", 0) > 0:
                                state["annotation_result"] = ann_result
                                _node_log.info(
                                    "批注检测完成",
                                    extra={
                                        "circles": len(ann_result.get("circles", [])),
                                        "arrows": len(ann_result.get("arrows", [])),
                                    },
                                )
                        except Exception as e:
                            # Best effort: a detection failure only produces a warning.
                            _node_log.warning(f"批注检测失败: {e}")
            except Exception as e:
                # Best effort: record the error in state and continue the request.
                _node_log.warning(f"OCR 字段提取失败: {e}")
                state["ocr_extraction_result"] = {"error": str(e)}
        # NOTE(review): nesting reconstructed from a whitespace-stripped source;
        # the path is presumably cleared once the upload has been examined — confirm.
        state["uploaded_file_path"] = ""

    # Two-layer OCR logging: content layer + position layer.
    _log_ocr_layers(state)

    # Reset the per-request fields.
    state["retry_count"] = 0
    state["user_modification_request"] = user_input

    return state
|
||
|
||
|
||
@log_node("save_state_snapshot")
def save_state_snapshot(state: AgentState) -> Dict:
    """Push a snapshot of the current state onto ``history_states`` for undo.

    Only the newest ``HISTORY_MAX_SNAPSHOTS`` snapshots are kept. The
    conversation history is deep-copied so that later edits do not leak
    into the stored snapshot.
    """
    history = state.get("history_states", [])
    if not isinstance(history, list):
        history = []

    history.append({
        "current_jrxml": state.get("current_jrxml", ""),
        "final_jrxml": state.get("final_jrxml", ""),
        "status": state.get("status", ""),
        "conversation_history": copy.deepcopy(state.get("conversation_history", [])),
        "user_modification_request": state.get("user_modification_request", ""),
        "intent": state.get("intent", ""),
    })

    # Drop the oldest entries once the cap is exceeded.
    limit = HISTORY_MAX_SNAPSHOTS
    state["history_states"] = history[-limit:] if len(history) > limit else history
    return state
|
||
|
||
|
||
@log_node("classify_intent")
def classify_intent(state: AgentState) -> Dict:
    """Classify the user input into one of eight intents with the LLM.

    A pending ``_failure_recovery`` flag skips the LLM and forces
    ``modify_report``. If the LLM call fails or its answer contains no
    known intent, the result is ``modify_report`` when a report already
    exists and ``initial_generation`` otherwise.
    """
    user_input = state.get("user_input", "")
    has_report = "是" if state.get("current_jrxml", "").strip() else "否"

    # Failure-recovery mode: skip LLM classification and go straight to correction.
    if state.pop("_failure_recovery", False):
        state["intent"] = "modify_report"
        return state

    fallback = "modify_report" if has_report == "是" else "initial_generation"
    known_intents = (
        "initial_generation", "modify_report", "preview_report",
        "export_pdf", "export_jrxml", "undo_modification",
        "consult_question", "reset_session",
    )

    try:
        llm = get_llm(caller="classify_intent")
        # Keep the first 200 and last 300 characters so a long pasted JRXML
        # in the middle cannot crowd out the user's actual request.
        if len(user_input) <= 500:
            snippet = user_input
        else:
            snippet = user_input[:200] + "\n...[已截断]...\n" + user_input[-300:]
        prompt = load_prompt("intent_classify").format(
            has_report=has_report,
            user_input=snippet,
        )
        reply = llm.invoke(prompt).content.strip().lower()
        # The first known intent mentioned in the reply wins.
        intent = next((name for name in known_intents if name in reply), fallback)
    except Exception:
        intent = fallback

    state["intent"] = intent
    return state
|
||
|
||
|
||
@log_node("handle_consult")
def handle_consult(state: AgentState) -> Dict:
    """Answer a consultation question directly with the LLM.

    The report-generation flow is not involved. When the LLM call fails,
    a fixed apology message is used as the answer. The answer is stored in
    ``consult_answer`` and appended to both conversation histories.
    """
    question = state.get("user_input", "")
    try:
        llm = get_llm(caller="handle_consult")
        reply = llm.invoke(load_prompt("consult").format(question=question))
        answer = reply.content.strip()
    except Exception:
        answer = "抱歉,暂时无法处理您的问题,请稍后再试。"

    state["consult_answer"] = answer
    working_entry = {"role": "assistant", "content": answer}
    state["conversation_history"].append(working_entry)
    full_entry = {"role": "assistant", "content": answer, "ts": _now_iso()}
    state["full_conversation_history"].append(full_entry)
    return state
|
||
|
||
|
||
@log_node("handle_undo")
def handle_undo(state: AgentState) -> Dict:
    """Undo the last modification by restoring the newest snapshot.

    With no snapshot available, only a notice is appended to the working
    conversation history. Otherwise the snapshot is popped, its fields are
    written back to ``state``, and a confirmation is added to both histories.
    """
    history = state.get("history_states", [])
    if not (isinstance(history, list) and history):
        state["conversation_history"].append(
            {"role": "assistant", "content": "没有可撤销的操作。"}
        )
        return state

    restored = history.pop()
    state["history_states"] = history

    # Fields copied back from the snapshot, each with its fallback value.
    restore_plan = (
        ("current_jrxml", ""),
        ("final_jrxml", ""),
        ("status", ""),
        ("conversation_history", []),
        ("user_modification_request", ""),
    )
    for key, default in restore_plan:
        state[key] = restored.get(key, default)

    state["conversation_history"].append(
        {"role": "assistant", "content": "已撤销上一步修改,恢复到之前的状态。"}
    )
    state["full_conversation_history"].append(
        {"role": "assistant", "content": "已撤销上一步修改。", "ts": _now_iso()}
    )
    return state
|
||
|
||
|
||
@log_node("handle_reset")
def handle_reset(state: AgentState) -> Dict:
    """Reset the current session.

    Report-related fields, snapshots and the working conversation history
    are cleared. The full conversation history is kept and receives a
    reset notice.
    """
    text_fields = (
        "current_jrxml", "final_jrxml", "status", "error_msg",
        "natural_explanation", "user_modification_request",
        "retrieved_context",
    )
    for key in text_fields:
        state[key] = ""
    state["retry_count"] = 0
    state["compressed_history"] = ""
    state["history_states"] = []
    state["intent"] = "initial_generation"

    # The working history restarts with a single assistant prompt.
    state["conversation_history"] = [
        {"role": "assistant", "content": "会话已重置,请描述您要创建的新报表。"}
    ]
    state["full_conversation_history"].append(
        {"role": "assistant", "content": "会话已重置。", "ts": _now_iso()}
    )
    return state
|
||
|
||
|
||
@log_node("count_tokens")
def count_tokens(state: AgentState) -> int:
    """Count the tokens of the current context.

    The measured text is the JSON serialization of the last
    ``CONTEXT_KEEP_RECENT`` conversation entries, the current JRXML and the
    compressed history.

    Returns:
        The token count from tiktoken's ``gpt-4o`` encoding. If tiktoken
        cannot be imported or the encoding cannot be loaded, an integer
        estimate of one token per 2.5 characters is returned instead.
    """
    text = json.dumps({
        "history": state.get("conversation_history", [])[-CONTEXT_KEEP_RECENT:],
        "jrxml": state.get("current_jrxml", ""),
        "compressed": state.get("compressed_history", ""),
    }, ensure_ascii=False)

    try:
        import tiktoken
        enc = tiktoken.encoding_for_model("gpt-4o")
    except Exception:
        # Fallback for mixed Chinese/English text: roughly 1 token per 2.5
        # characters. Integer arithmetic keeps the result an int, matching
        # the annotation (``len(text) // 2.5`` would produce a float).
        return len(text) * 2 // 5

    return len(enc.encode(text))
|
||
|
||
|
||
@log_node("manage_context")
def manage_context(state: AgentState) -> Dict:
    """Compress older conversation turns once the token budget is exceeded.

    The token count is always stored in ``current_token_count``. When it is
    above ``CONTEXT_MAX_TOKENS`` and the full history holds more than
    ``CONTEXT_KEEP_RECENT`` entries, the older entries are summarized (by
    the LLM, or by ``_simple_compress`` if the LLM call fails), the summary
    is appended to ``compressed_history``, and the working history is
    replaced by the recent entries.
    """
    tokens = count_tokens(state)
    state["current_token_count"] = tokens
    if tokens <= CONTEXT_MAX_TOKENS:
        return state

    transcript = state.get("full_conversation_history", [])
    if len(transcript) <= CONTEXT_KEEP_RECENT:
        return state

    # The most recent entries stay verbatim; everything earlier is compressed.
    tail = transcript[-CONTEXT_KEEP_RECENT:]
    head = transcript[:-CONTEXT_KEEP_RECENT]
    if not head:
        return state

    serialized = json.dumps(head, ensure_ascii=False, indent=2)
    try:
        llm = get_llm(caller="manage_context")
        prompt = load_prompt("compression").format(conversation_text=serialized)
        summary = llm.invoke(prompt).content.strip()[:300]
    except Exception:
        summary = _simple_compress(head)

    # Merge the new summary with any existing one.
    previous = state.get("compressed_history", "")
    state["compressed_history"] = f"{previous}\n---\n{summary}" if previous else summary

    state["conversation_history"] = list(tail)
    state["current_token_count"] = count_tokens(state)
    return state
|
||
|
||
|
||
@log_node("load_session_node")
def load_session_node(state: AgentState) -> Dict:
    """Load the persisted session state from disk at the start of a request.

    Nothing happens without a ``session_id``. A load failure is logged as a
    warning and the request continues with the state as it is.
    """
    session_id = state.get("session_id", "")
    if not session_id:
        return state

    # Only these fields are restored; the current request's user_input,
    # stage and session_id are not in the list and so are never overwritten.
    restorable = (
        "conversation_history", "full_conversation_history",
        "current_jrxml", "final_jrxml", "compressed_history",
        "session_name", "created_at", "history_states",
        "ocr_extraction_result", "uploaded_file_path",
        "annotation_result", "layout_schema", "ocr_elements",
    )
    protected = ("user_input", "stage", "session_id")

    try:
        from backend.session import load_session
        data = load_session(session_id)
        if data and data.get("agent_state"):
            saved = data["agent_state"]
            for key in restorable:
                if key not in protected and key in saved:
                    state[key] = saved[key]
            # The top-level session metadata takes precedence over the copy
            # stored inside agent_state.
            state["session_name"] = data.get("session_name", "")
            state["created_at"] = data.get("created_at", "")
    except Exception:
        _node_log.warning("会话加载失败,使用空状态",
                          extra={"session_id": state.get("session_id", "")})
    return state
|
||
|
||
|
||
@log_node("save_session_node")
def save_session_node(state: AgentState) -> Dict:
    """Persist the current agent state to disk.

    Nothing happens without a ``session_id``. When the session has no name
    yet, the first 50 characters of the first user message are used. A save
    failure is logged with its traceback and does not interrupt the request.
    """
    session_id = state.get("session_id", "")
    if not session_id:
        return state

    persisted_keys = (
        "session_id", "conversation_history", "full_conversation_history",
        "current_jrxml", "final_jrxml", "compressed_history",
        "status", "error_msg", "history_states",
        "ocr_extraction_result", "uploaded_file_path",
        "annotation_result", "layout_schema", "ocr_elements",
    )

    try:
        from backend.session import save_session
        persistable = {key: state[key] for key in persisted_keys if key in state}
        persistable["updated_at"] = _now_iso()

        session_name = state.get("session_name", "")
        if not session_name and state.get("conversation_history"):
            # Derive a name from the first user message, if there is one.
            user_previews = (
                msg["content"][:50]
                for msg in state["conversation_history"]
                if msg.get("role") == "user"
            )
            first_user = next(user_previews, "")
            if first_user:
                session_name = first_user

        save_session(session_id, persistable, session_name)
        if not state.get("session_name"):
            state["session_name"] = session_name
        state["updated_at"] = persistable["updated_at"]
    except Exception:
        _node_log.exception("会话保存失败",
                            extra={"session_id": state.get("session_id", "")})
    return state
|
||
|
||
|
||
def _simple_compress(messages: list[dict]) -> str:
|
||
"""当 LLM 不可用时,基于简单规则的压缩回退方案。"""
|
||
points = []
|
||
for m in messages:
|
||
if m.get("role") == "user":
|
||
points.append(f"用户提问:{m['content'][:100]}")
|
||
return "; ".join(points[-10:])
|
||
|
||
|
||
def _now_iso() -> str:
|
||
return datetime.now(timezone.utc).isoformat()
|
||
|
||
|
||
def _format_row_coordinates(row: dict) -> dict:
|
||
"""将单行 OCR 元素格式化为紧凑的坐标描述,供阶段二 refine_layout 使用。"""
|
||
if not isinstance(row, dict):
|
||
return {}
|
||
elements = row.get("elements", [])
|
||
if not elements:
|
||
return {"y_center": row.get("y_center", 0), "columns": []}
|
||
sorted_elems = sorted(elements, key=lambda e: e.get("x", 0))
|
||
cols = []
|
||
for ci, e in enumerate(sorted_elems):
|
||
cols.append({
|
||
"col": ci,
|
||
"x": e.get("x", 0),
|
||
"y": e.get("y", 0),
|
||
"w": e.get("w", 0),
|
||
"h": e.get("h", 0),
|
||
"font_size": e.get("font_size", 12),
|
||
"text": e.get("text", ""),
|
||
})
|
||
return {"y_center": row.get("y_center", 0), "columns": cols}
|
||
|
||
|
||
def _build_sampled_text(ocr_rows: list) -> str:
    """Build a JSON string with sampled coordinates from OCR rows.

    The sample holds the first row as ``header_row``, the second as
    ``first_data_row`` (when present) and, with three or more rows, the
    final row as ``last_row``. An empty or non-list input yields ``"{}"``.
    """
    sample = {}
    if isinstance(ocr_rows, list) and len(ocr_rows) >= 1:
        row_count = len(ocr_rows)
        sample["header_row"] = _format_row_coordinates(ocr_rows[0])
        if row_count > 1:
            sample["first_data_row"] = _format_row_coordinates(ocr_rows[1])
        if row_count > 2:
            sample["last_row"] = _format_row_coordinates(ocr_rows[-1])
    return json.dumps(sample, ensure_ascii=False, indent=2)
|
||
|
||
|
||
def _extract_band_height(band_xml: str) -> int:
|
||
"""从 <band height="N"> 中提取高度值。"""
|
||
m = re.search(r'<band\b[^>]*\sheight\s*=\s*"(\d+)"', band_xml)
|
||
return int(m.group(1)) if m else 0
|
||
|
||
|
||
def _extract_xml_fragment(text: str) -> str:
|
||
"""从 LLM 响应中提取 XML 片段(去除 markdown 代码块和解释文本)。"""
|
||
text = text.strip()
|
||
# 尝试提取 markdown 代码块内的内容
|
||
m = re.search(r"```(?:xml)?\s*([\s\S]*?)```", text, re.IGNORECASE)
|
||
if m:
|
||
content = m.group(1).strip()
|
||
if content:
|
||
return content
|
||
# 尝试找到 <band ...>...</band> 片段
|
||
m = re.search(r"(<band\b[\s\S]*?</band>)", text, re.IGNORECASE)
|
||
if m:
|
||
return m.group(1).strip()
|
||
return text
|
||
|
||
|
||
def _programmatic_map_fields(jrxml: str, ocr_fields: list[dict]) -> str:
    """Replace ``field_N`` placeholders with sanitized OCR field names.

    The N-th entry of ``ocr_fields`` (1-based) maps to ``field_N``. Both the
    field declaration (with or without a namespace prefix on the tag) and
    every ``$F{field_N}`` reference are rewritten. Replacement is purely
    textual; no LLM is involved. Entries without a ``field_name``, or whose
    sanitized name equals the placeholder, are skipped.
    """
    mapped = jrxml
    for position, field in enumerate(ocr_fields, start=1):
        placeholder = f"field_{position}"
        source_name = field.get("field_name", "")
        if not source_name:
            continue
        target_name = _sanitize_field_name(source_name)
        if target_name == placeholder:
            continue

        # Declaration: <ns0:field name="field_1" -> <ns0:field name="customer_name"
        declaration = rf'(<[\w:]*field\b[^>]*\bname\s*=\s*"){re.escape(placeholder)}(")'
        mapped = re.sub(declaration, rf'\g<1>{target_name}\g<2>', mapped)

        # References: $F{field_1} -> $F{customer_name}
        old_ref = f'$F{{{placeholder}}}'
        new_ref = f'$F{{{target_name}}}'
        mapped = mapped.replace(old_ref, new_ref)
    return mapped
|
||
|
||
|
||
def _sanitize_field_name(name: str) -> str:
|
||
"""将 OCR 字段名净化为合法的 JRXML field name(仅 ASCII 字母/数字/下划线)。
|
||
|
||
非 ASCII 字符会被替换为其 Unicode 码点,确保唯一且合法。
|
||
"""
|
||
result = []
|
||
for ch in name:
|
||
if ch.isascii() and (ch.isalnum() or ch == '_'):
|
||
result.append(ch)
|
||
elif ch.isascii():
|
||
result.append('_')
|
||
else:
|
||
# 非 ASCII 转 _uXXXX_ 格式,保留可追溯性
|
||
cp = ord(ch)
|
||
result.append(f'_u{cp:04X}_')
|
||
cleaned = ''.join(result)
|
||
cleaned = cleaned.strip('_')
|
||
if not cleaned:
|
||
return "unnamed_field"
|
||
if cleaned[0].isdigit():
|
||
cleaned = 'f_' + cleaned
|
||
# 压缩连续下划线
|
||
cleaned = re.sub(r'_{2,}', '_', cleaned)
|
||
return cleaned.lower()
|
||
|
||
|
||
def _format_ocr_context(state: AgentState) -> str:
|
||
"""将 OCR 提取结果格式化为 LLM 可用的上下文文本。"""
|
||
ocr_result = state.get("ocr_extraction_result")
|
||
if not ocr_result or not isinstance(ocr_result, dict):
|
||
return ""
|
||
if ocr_result.get("error"):
|
||
return ""
|
||
|
||
parts = []
|
||
parts.append("[图片OCR识别结果]")
|
||
|
||
total = ocr_result.get("total_elements", 0)
|
||
if total:
|
||
parts.append(f"检测到 {total} 个文字元素")
|
||
|
||
# 提取到的字段
|
||
fields = ocr_result.get("fields", [])
|
||
if fields:
|
||
parts.append("\n提取的结构化字段:")
|
||
for f in fields:
|
||
if f.get("field_value"):
|
||
parts.append(
|
||
f" - {f['field_name']}: {f['field_value']} "
|
||
f"(方法={f.get('extraction_method','?')}, "
|
||
f"置信度={f.get('confidence',0):.2f})"
|
||
)
|
||
|
||
# 所有原始文本(用于表格匹配等需要全文的场景)
|
||
elements = ocr_result.get("elements", [])
|
||
if elements:
|
||
parts.append("\n全部文本元素(含坐标):")
|
||
for e in elements:
|
||
bbox = e.get("bbox", {})
|
||
x, y, w, h = bbox.get("x", 0), bbox.get("y", 0), bbox.get("w", 0), bbox.get("h", 0)
|
||
parts.append(
|
||
f" [{x},{y} {w}×{h}] {e['text']} "
|
||
f"(置信度={e.get('confidence',0):.2f})"
|
||
)
|
||
|
||
# 批注检测结果
|
||
ann_result = state.get("annotation_result")
|
||
if ann_result and isinstance(ann_result, dict):
|
||
try:
|
||
from backend.annotation_detector import format_annotation_context
|
||
ann_text = format_annotation_context(ann_result)
|
||
if ann_text:
|
||
parts.append("\n" + ann_text)
|
||
except Exception:
|
||
pass
|
||
|
||
return "\n".join(parts)
|
||
|
||
|
||
def _log_ocr_layers(state: AgentState) -> None:
    """Log the two OCR layers separately, then how they combine.

    Content layer: OCR text elements and extracted fields. Position layer:
    the layout schema (rows, columns, regions). A final record states which
    generation pipeline follows from the layers that are available.
    """
    ocr_result = state.get("ocr_extraction_result")
    ocr_elements = state.get("ocr_elements", [])
    layout = state.get("layout_schema")

    ocr_usable = isinstance(ocr_result, dict) and not ocr_result.get("error")
    rows_present = isinstance(ocr_elements, list) and ocr_elements

    # ── Content layer: OCR text elements + extracted fields ──
    content_notes = []
    if ocr_usable:
        total = ocr_result.get("total_elements", 0)
        valued = [fld for fld in ocr_result.get("fields", []) if fld.get("field_value")]
        if total or valued:
            content_notes.append(
                f"OCR 提取: {total} 个文本元素, {len(valued)} 个有效字段"
            )
    if rows_present:
        text_total = sum(len(row.get("elements", [])) for row in ocr_elements)
        content_notes.append(
            f"API 注入 OCR 元素: {len(ocr_elements)} 行, {text_total} 个文本"
        )

    if content_notes:
        _node_log.info(
            "[内容层] " + " | ".join(content_notes),
            extra={"layer": "content", "phase": "ocr_extraction"},
        )

    # ── Position layer: layout schema (rows / columns / regions) ──
    has_layout = isinstance(layout, dict) and layout.get("total_rows", 0) > 0
    if has_layout:
        region_list = layout.get("regions", [])
        labels = {"title": "标题", "header": "表头", "data": "数据", "footer": "表尾"}
        if isinstance(region_list, list):
            region_names = [labels.get(region["type"], region["type"]) for region in region_list]
        else:
            region_names = []
        cols = layout.get("total_columns", 0)
        rows = layout.get("total_rows", 0)
        regions_label = ", ".join(region_names) if region_names else "标题/表头/数据/表尾"
        _node_log.info(
            f"[位置层] 布局 schema: {cols} 列 × {rows} 行, 区域: {regions_label}",
            extra={
                "layer": "position",
                "phase": "layout_analysis",
                "columns": cols,
                "rows": rows,
                "regions": region_names,
                "a4_confidence": layout.get("a4_confidence", ""),
            },
        )

    # ── Merge: summary of the pipeline implied by the available layers ──
    has_content = bool(ocr_usable or rows_present)
    if has_content and has_layout:
        _node_log.info(
            "[合并] 内容层 + 位置层均已就绪 — "
            "注入 prompt: 骨架生成 → 精调布局 → 字段映射",
            extra={"layer": "merge", "pipeline": "skeleton→refine→map_fields"},
        )
    elif has_content:
        _node_log.info(
            "[合并] 仅有内容层 — 使用单阶段 generate(无布局 schema)",
            extra={"layer": "merge", "pipeline": "generate_only"},
        )
    elif has_layout:
        _node_log.info(
            "[合并] 仅有位置层 — 使用布局 schema 指导生成",
            extra={"layer": "merge", "pipeline": "layout_only"},
        )
|
||
|
||
|
||
@log_node("retrieve")
def retrieve(state: AgentState) -> Dict:
    """Search the vector store and the error knowledge base for context.

    The user input is searched with ``search_chunks`` (scoped by ``kb_id``).
    If ``error_msg`` is set, matching error-correction cases are appended.
    If the input names a template and a ``kb_id`` is present, the best
    matching template is stored in ``kb_template_jrxml`` /
    ``kb_template_name`` and appended to the context.

    The result is written to ``retrieved_context``. Any failure of the
    retrieval as a whole is logged and leaves ``retrieved_context`` empty;
    a failure of the template lookup alone only produces a warning.
    """
    # Bound before the try block so the except handler can always reference
    # user_input, even when the imports below are what failed.
    user_input = state.get("user_input", "")
    kb_id = state.get("kb_id", "")

    try:
        from backend.rag_adapter import search_chunks
        from backend.error_kb import search_error_cases

        context = search_chunks(user_input, k=5, kb_id=kb_id)

        # Error knowledge base.
        error_msg = state.get("error_msg", "")
        if error_msg:
            error_context = search_error_cases(error_msg, k=2)
            if error_context:
                context = f"{context}\n\n[历史错误修正案例]\n{error_context}"

        # Template intent detection: did the user mention a template name?
        template_keywords = _detect_template_intent(user_input)
        if template_keywords and kb_id:
            try:
                from backend.kb_searcher import search_templates_in_kb
                templates = search_templates_in_kb(kb_id, template_keywords, k=1)
                if templates:
                    tmpl = templates[0]
                    state["kb_template_jrxml"] = tmpl.get("content", "")
                    state["kb_template_name"] = (
                        tmpl.get("metadata", {}).get("report_name", "")
                    )
                    context += (
                        f"\n\n[匹配到模板: {state['kb_template_name']}]\n"
                        f"{state['kb_template_jrxml']}"
                    )
            except Exception:
                _node_log.warning("模板检索失败", extra={"kb_id": kb_id})

        state["retrieved_context"] = context
    except Exception:
        _node_log.exception("RAG 检索失败", extra={"user_input": user_input[:80]})
        state["retrieved_context"] = ""
    return state
|
||
|
||
|
||
def _detect_template_intent(user_input: str) -> str:
|
||
"""检测用户输入中是否包含模板引用意图,返回提取的搜索关键词。"""
|
||
import re
|
||
patterns = [
|
||
r"根据(.+?)模板",
|
||
r"基于(.+?)模板",
|
||
r"参照(.+?)模板",
|
||
r"用(.+?)模板",
|
||
r"使用(.+?)模板",
|
||
r"(.+单)模板",
|
||
]
|
||
for pat in patterns:
|
||
m = re.search(pat, user_input)
|
||
if m:
|
||
kw = m.group(1).strip()
|
||
if len(kw) >= 2:
|
||
return kw
|
||
return ""
|
||
|
||
|
||
def _build_template_context(state: dict) -> str:
|
||
"""构建模板上下文,用于注入生成 prompt。
|
||
优先级:对话上传 > KB 检索 > KB 字段定义。
|
||
"""
|
||
parts = []
|
||
|
||
# 对话中上传的 JRXML 模板
|
||
uploaded = state.get("uploaded_template_jrxml", "")
|
||
if uploaded:
|
||
params = state.get("uploaded_template_params", [])
|
||
param_str = "\n".join(
|
||
f" - {p['name']} ({p.get('type', 'String')})" for p in params
|
||
) if params else "(参数列表未解析)"
|
||
parts.append(
|
||
"[对话上传的模板]\n"
|
||
f"以下为用户上传的 JRXML 模板,请基于此模板进行修改:\n"
|
||
f"模板参数:\n{param_str}\n"
|
||
f"```xml\n{uploaded}\n```"
|
||
)
|
||
|
||
# KB 检索到的模板
|
||
kb_tmpl = state.get("kb_template_jrxml", "")
|
||
kb_name = state.get("kb_template_name", "")
|
||
if kb_tmpl and not uploaded:
|
||
parts.append(
|
||
f"[知识库模板: {kb_name}]\n"
|
||
f"以下为从知识库检索到的 JRXML 模板,请作为结构参考:\n"
|
||
f"```xml\n{kb_tmpl}\n```"
|
||
)
|
||
|
||
# KB 字段定义
|
||
kb_fields = state.get("kb_fields", [])
|
||
if kb_fields:
|
||
field_table = "\n".join(
|
||
f"| {f['name']} | {f.get('description', '')} | {f.get('type', 'String')} |"
|
||
for f in kb_fields
|
||
)
|
||
parts.append(
|
||
"[可用数据字段]\n"
|
||
"生成 JRXML 时请使用以下字段作为 $P{{xxx}} 参数:\n"
|
||
f"| 字段名 | 含义 | 类型 |\n|---|---|---|\n{field_table}"
|
||
)
|
||
|
||
return "\n\n".join(parts)
|
||
|
||
|
||
@log_node("generate")
def generate(state: AgentState) -> Dict:
    """Generate the initial JRXML from the user request and retrieved context.

    The LLM output is streamed chunk by chunk to the stream writer. The
    JRXML extracted from the joined output is stored in ``current_jrxml``
    and appended to the conversation history.
    """
    from langgraph.config import get_stream_writer

    writer = get_stream_writer()
    llm = get_llm(caller="generate", max_tokens=32768)

    request_text = state.get("user_input", "")
    ocr_text = _format_ocr_context(state)
    if ocr_text:
        # OCR findings go in front of the user's own words.
        request_text = f"{ocr_text}\n\n---\n用户需求:\n{request_text}"

    template = load_prompt("initial_generation")
    prompt = template.format(
        context=state.get("retrieved_context", ""),
        user_request=request_text,
        template_context=_build_template_context(state),
    )

    pieces = []
    for piece in llm.stream(prompt):
        pieces.append(piece)
        writer({"type": "stream", "node": "generate", "text": piece})

    jrxml = _extract_jrxml("".join(pieces))
    state["current_jrxml"] = jrxml
    state["conversation_history"].append({"role": "assistant", "content": jrxml})
    return state
|
||
|
||
|
||
@log_node("generate_skeleton")
def generate_skeleton(state: AgentState) -> Dict:
    """Stage one: generate a skeleton JRXML from the compressed layout schema.

    An empty LLM response triggers one retry with ``max_tokens=65536``; if
    that is empty too, the state is returned unchanged. An extracted JRXML
    shorter than 200 characters is discarded in favour of the JRXML that was
    in the state before this node ran.
    """
    from langgraph.config import get_stream_writer

    writer = get_stream_writer()

    schema = state.get("layout_schema", {})
    schema_text = schema.get("schema_text", "") if isinstance(schema, dict) else ""

    template = load_prompt("skeleton_generation")
    prompt = template.format(
        layout_schema=schema_text,
        context=state.get("retrieved_context", ""),
        user_request=state.get("user_input", ""),
        template_context=_build_template_context(state),
    )
    llm = get_llm(caller="generate_skeleton", max_tokens=32768)

    previous_jrxml = state.get("current_jrxml", "")
    raw_output = _generate_with_continuation(llm, prompt, writer, "generate_skeleton")
    if not raw_output.strip():
        _node_log.warning("generate_skeleton 首次返回空响应,以更高 max_tokens 重试")
        llm = get_llm(caller="generate_skeleton", max_tokens=65536)
        raw_output = _generate_with_continuation(llm, prompt, writer, "generate_skeleton")
        if not raw_output.strip():
            _node_log.error("generate_skeleton LLM 返回空响应(含重试)")
            return state

    jrxml = _extract_jrxml(raw_output)
    if len(jrxml.strip()) < 200:
        _node_log.warning(f"generate_skeleton 输出过短({len(jrxml)} 字符),回退到前一版本")
        jrxml = previous_jrxml

    state["current_jrxml"] = jrxml
    state["conversation_history"].append({"role": "assistant", "content": jrxml})
    return state
|
||
|
||
|
||
@log_node("refine_layout")
def refine_layout(state: AgentState) -> Dict:
    """Stage two: band-level windowed refinement of the skeleton JRXML.

    Flow:
      1. decompose_jrxml() splits the JRXML into header + bands + footer.
      2. Each band is sent to the LLM as one window, or several when it
         exceeds 4000 characters.
      3. The prompt asks the LLM to adjust only the x/y/width/height of the
         reportElement entries inside that window.
      4. reassemble_jrxml() puts the document back together and
         validate_element_count() checks that elements were not lost.

    Any window whose LLM call fails or returns nothing keeps its original
    XML. If validation fails, the state is returned without replacing
    ``current_jrxml``.
    """
    from langgraph.config import get_stream_writer
    from agent.jrxml_windower import (
        decompose_jrxml, split_band_into_windows, reassemble_band_windows,
        reassemble_jrxml, validate_element_count,
    )

    writer = get_stream_writer()
    llm = get_llm(caller="refine_layout")

    prev_jrxml = state.get("current_jrxml", "")
    if not prev_jrxml.strip():
        _node_log.warning("refine_layout 无输入 JRXML,跳过")
        return state

    # Decompose the JRXML.
    parts = decompose_jrxml(prev_jrxml)
    if parts is None or parts.band_count == 0:
        _node_log.warning("refine_layout 拆解失败或无 band,跳过")
        return state

    # Build the sampled coordinates (shared by every window prompt).
    ocr_rows = state.get("ocr_elements", [])
    sampled_text = _build_sampled_text(ocr_rows)
    template_ctx = _build_template_context(state)

    # Refine band by band, window by window.
    modified_bands: dict[str, str] = {}
    total_windows = 0
    for band in parts.bands:
        # Bands without elements have nothing to reposition; keep them as is.
        if band.element_count == 0:
            modified_bands[band.label] = band.band_xml
            continue

        windows = split_band_into_windows(band, max_chars=4000)
        total_windows += len(windows)
        band_results: list[str] = []

        for wi, win_xml in enumerate(windows):
            prompt = load_prompt("refine_layout").format(
                band_name=band.section_name,
                band_index=band.band_index,
                band_height=_extract_band_height(win_xml),
                window_index=wi + 1,
                total_windows=len(windows),
                xml_fragment=win_xml,
                sampled_coordinates=sampled_text,
                template_context=template_ctx,
            )
            try:
                response = llm.invoke(prompt)
                # Accept both message-like objects and plain values.
                content = response.content if hasattr(response, "content") else str(response)
                fragment = _extract_xml_fragment(content)
                if fragment:
                    band_results.append(fragment)
                    writer({"type": "stream", "node": "refine_layout",
                            "text": f"[{band.label} 窗口 {wi+1}/{len(windows)} 完成] "})
                else:
                    _node_log.warning("refine_layout 窗口 %s/%d 返回空,使用原文",
                                      band.label, wi + 1)
                    band_results.append(win_xml)
            except Exception as e:
                # Best effort: fall back to the unmodified window.
                _node_log.warning("refine_layout 窗口 %s/%d LLM 失败: %s,使用原文",
                                  band.label, wi + 1, e)
                band_results.append(win_xml)

        if len(band_results) == 1:
            modified_bands[band.label] = band_results[0]
        else:
            modified_bands[band.label] = reassemble_band_windows(band_results)

    # Reassemble and validate.
    result = reassemble_jrxml(parts, modified_bands)
    validation = validate_element_count(prev_jrxml, result, "refine_layout")
    _node_log.info(
        "refine_layout 窗口化完成: %d bands, %d 窗口, 元素 %d→%d (%.1f%%)",
        parts.band_count, total_windows,
        validation["original"], validation["modified"],
        validation["change_pct"] * 100,
    )

    if not validation["ok"]:
        # Keep the skeleton version: current_jrxml is left untouched.
        _node_log.error("refine_layout 元素丢失过多,回退到骨架版本")
        return state

    state["current_jrxml"] = result
    state["conversation_history"].append({"role": "assistant", "content": result})
    return state
|
||
|
||
|
||
@log_node("map_fields")
def map_fields(state: AgentState) -> Dict:
    """Stage 3: programmatic field mapping.

    Delegates to ``_programmatic_map_fields`` to rewrite the placeholder
    field names in the current JRXML using the OCR field list. No LLM call
    is made in this function. The result is accepted only if
    ``validate_element_count`` reports it as OK; otherwise the previous
    JRXML is kept.

    Args:
        state: Agent state; reads ``current_jrxml`` and
            ``ocr_extraction_result``.

    Returns:
        The same state object, possibly with ``current_jrxml`` replaced and
        ``conversation_history`` extended.
    """
    from agent.jrxml_windower import validate_element_count

    source_jrxml = state.get("current_jrxml", "")
    if not source_jrxml.strip():
        _node_log.warning("map_fields 无输入 JRXML,跳过")
        return state

    # Pull the OCR field list, tolerating a missing or malformed result.
    extraction = state.get("ocr_extraction_result", {})
    ocr_fields: list[dict] = []
    if isinstance(extraction, dict):
        candidate_fields = extraction.get("fields")
        if candidate_fields:
            ocr_fields = candidate_fields

    if not ocr_fields:
        # Nothing to map: keep the placeholder names but still log the turn.
        _node_log.info("map_fields 无 OCR 字段,保留占位字段名")
        state["conversation_history"].append(
            {"role": "assistant", "content": source_jrxml}
        )
        return state

    # Main path: purely programmatic replacement.
    mapped_jrxml = _programmatic_map_fields(source_jrxml, ocr_fields)

    # Guard against accidental element loss.
    check = validate_element_count(source_jrxml, mapped_jrxml, "map_fields")
    _node_log.info(
        "map_fields 程序化完成: %d 个字段, 元素 %d→%d (%.1f%%)",
        len(ocr_fields),
        check["original"], check["modified"],
        check["change_pct"] * 100,
    )

    if check["ok"]:
        state["current_jrxml"] = mapped_jrxml
        state["conversation_history"].append(
            {"role": "assistant", "content": mapped_jrxml}
        )
        return state

    _node_log.error("map_fields 元素丢失过多,回退到前一版本")
    return state
|
||
|
||
|
||
@log_node("modify_jrxml")
def modify_jrxml(state: AgentState) -> Dict:
    """Apply the user's modification request to the existing JRXML.

    Builds a prompt from the current JRXML, the compressed history summary
    plus the six most recent conversation entries, the modification request
    and the OCR/template context, then streams the LLM answer with
    truncation recovery. An empty answer leaves the state untouched; an
    extracted result shorter than 200 characters is replaced by the
    previous JRXML.

    Args:
        state: Agent state; reads ``current_jrxml``,
            ``compressed_history``, ``conversation_history`` and
            ``user_modification_request``.

    Returns:
        The same state object with ``current_jrxml``, both history lists
        and ``retry_count`` updated.
    """
    from langgraph.config import get_stream_writer

    writer = get_stream_writer()
    llm = get_llm(caller="modify_jrxml", max_tokens=32768)

    # Conversation context: compressed summary first, then the recent turns.
    summary = state.get("compressed_history", "")
    latest_turns = state.get("conversation_history", [])[-6:]
    context_sections = [f"[早期对话摘要]\n{summary}"] if summary else []
    context_sections.append(json.dumps(latest_turns, ensure_ascii=False, indent=2))
    conv_text = "\n\n---\n\n".join(context_sections)

    request_text = state.get("user_modification_request", "")
    prev_jrxml = state.get("current_jrxml", "")

    prompt = load_prompt("modification").format(
        current_jrxml=prev_jrxml,
        conversation_history=conv_text,
        modification_request=request_text,
        ocr_context=_format_ocr_context(state),
        template_context=_build_template_context(state),
    )

    full_text = _generate_with_continuation(llm, prompt, writer, "modify_jrxml")
    if not full_text.strip():
        _node_log.error("modify_jrxml LLM 返回空响应,保留原版本")
        return state

    jrxml = _extract_jrxml(full_text)
    if len(jrxml.strip()) < 200:
        # Too short to be a real report: keep the previous version instead.
        _node_log.warning(f"modify_jrxml 输出过短({len(jrxml)} 字符),回退到前一版本")
        jrxml = prev_jrxml

    state["current_jrxml"] = jrxml

    # Record the exchange in the working history ...
    working_history = state["conversation_history"]
    working_history.append({"role": "user", "content": request_text})
    working_history.append({"role": "assistant", "content": jrxml})

    # ... and, with timestamps, in the full (uncompressed) history.
    timestamped_turns = [
        {"role": "user", "content": request_text, "ts": _now_iso()},
        {"role": "assistant", "content": jrxml, "ts": _now_iso()},
    ]
    state["full_conversation_history"] = (
        list(state.get("full_conversation_history", [])) + timestamped_turns
    )

    # A fresh modification restarts the correction budget.
    state["retry_count"] = 0
    return state
|
||
|
||
|
||
# ── Java renderer config ──────────────────────────────────────────────
# Java launcher under JAVA_HOME; the executable name depends on the platform.
_JAVA_BIN = os.path.join(
    os.environ.get("JAVA_HOME", "C:/Program Files/Java/jdk-21.0.11"),
    "bin",
    "java.exe" if os.name == "nt" else "java",
)
# Directory holding the renderer class and its jar dependencies.
_JAVA_JAR_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "lib", "java")
_JAVA_RENDERER_JARS = (
    "jasperreports-6.21.0.jar",
    "commons-logging-1.3.5.jar",
    "commons-collections4-4.5.0.jar",
    "commons-beanutils-1.10.1.jar",
    "commons-lang3-3.17.0.jar",
    "commons-digester-2.1.jar",
    "itext-2.1.7.jar",
    "jfreechart-1.5.5.jar",
    "ecj-3.38.0.jar",
)
_JAVA_RENDERER_CLASS = "JrxmlRenderer"
# Classpath: "." (the subprocess runs with cwd=_JAVA_JAR_DIR) followed by
# every jar, joined with the platform separator (";" on Windows, ":" on
# POSIX) so the same configuration works on both.
_JAVA_RENDERER_CP = os.pathsep.join(
    ["."] + [os.path.join(_JAVA_JAR_DIR, jar) for jar in _JAVA_RENDERER_JARS]
)
|
||
|
||
|
||
def _render_jrxml_to_png(jrxml: str, output_path: str, scale: float = 2.0) -> bool:
    """Render a JRXML document to PNG by invoking the Java ``JrxmlRenderer``.

    The JRXML is written to ``<project>/tmp/_render_input.jrxml`` and the
    renderer is run as a subprocess (120 s timeout, cwd = jar directory).

    Args:
        jrxml: JRXML document text.
        output_path: Destination PNG path passed to the renderer.
        scale: Scale factor passed to the renderer.

    Returns:
        True when the renderer exits with code 0, False on any failure,
        including failure to create the temp directory or input file.
    """
    import subprocess

    tmpdir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "tmp")
    # NOTE(review): the input file name is fixed, so concurrent renders would
    # share it — confirm whether this node can run in parallel.
    jrxml_path = os.path.join(tmpdir, "_render_input.jrxml")

    try:
        # File-system errors are handled like renderer errors so callers
        # always get the documented boolean instead of an exception.
        os.makedirs(tmpdir, exist_ok=True)
        with open(jrxml_path, "w", encoding="utf-8") as f:
            f.write(jrxml)

        result = subprocess.run(
            [_JAVA_BIN, "-cp", _JAVA_RENDERER_CP, _JAVA_RENDERER_CLASS,
             jrxml_path, output_path, str(scale)],
            capture_output=True, text=True, timeout=120,
            cwd=_JAVA_JAR_DIR,
        )
    except Exception as e:
        _node_log.warning(f"PNG render exception: {e}")
        return False

    if result.returncode == 0:
        _node_log.info(f"PNG rendered: {output_path} ({result.stdout.strip()})")
        return True

    _node_log.warning(f"PNG render failed: {result.stdout.strip()} {result.stderr.strip()}")
    return False
|
||
|
||
|
||
def _compute_pixel_similarity(rendered_png: str, reference_image: str) -> dict:
    """Compare a rendered PNG with a reference image at pixel level.

    Both images are loaded as grayscale; the rendered image is resized to
    the reference dimensions when they differ. SSIM is the primary metric;
    ``diff_pct`` is the fraction of pixels whose absolute difference
    exceeds 30.

    Args:
        rendered_png: Path of the rendered image.
        reference_image: Path of the reference image.

    Returns:
        ``{"ssim": float, "diff_pct": float, "error": str | None}``. On any
        failure (missing dependency, unreadable image, other exception) the
        result is ``ssim=0.0``, ``diff_pct=1.0`` with ``error`` set.
    """

    def _failed(reason: str) -> dict:
        # Uniform worst-case result used by every failure path.
        return {"ssim": 0.0, "diff_pct": 1.0, "error": reason}

    try:
        import cv2
        import numpy as np

        candidate = cv2.imread(rendered_png, cv2.IMREAD_GRAYSCALE)
        target = cv2.imread(reference_image, cv2.IMREAD_GRAYSCALE)

        if candidate is None:
            return _failed(f"无法读取渲染图片: {rendered_png}")
        if target is None:
            return _failed(f"无法读取参考图片: {reference_image}")

        # Bring the rendered image to the reference size before comparing.
        if candidate.shape != target.shape:
            ref_height, ref_width = target.shape[0], target.shape[1]
            candidate = cv2.resize(candidate, (ref_width, ref_height))

        # Structural similarity.
        from skimage.metrics import structural_similarity as ssim
        score = ssim(candidate, target, data_range=255)

        # Share of pixels that differ noticeably.
        delta = cv2.absdiff(candidate, target)
        changed_pixels = np.count_nonzero(delta > 30)
        diff_pct = float(changed_pixels) / delta.size

        return {"ssim": round(score, 4), "diff_pct": round(diff_pct, 4), "error": None}
    except ImportError as e:
        return _failed(f"缺少依赖: {e}")
    except Exception as e:
        return _failed(str(e))
|
||
|
||
|
||
def _check_ocr_fidelity(jrxml: str, state: dict) -> dict:
    """Score how faithfully the generated JRXML reflects the OCR extraction.

    Checks:
      1. Element coverage: number of ``textField`` + ``staticText`` elements
         relative to the number of OCR text elements (capped at 1.0).
      2. Field coverage: share of OCR field names declared as ``<field>``.
      3. Column structure: distinct ``x`` positions inside the ``<detail>``
         section compared with the OCR column count (adds an issue only; it
         does not affect the score).

    Args:
        jrxml: JRXML document text.
        state: Agent state; reads ``ocr_elements``,
            ``ocr_extraction_result`` and ``layout_schema``.

    Returns:
        ``{"score", "field_coverage", "element_coverage", "issues"}`` where
        ``score`` is the mean of the two coverage values. With no OCR data
        at all, a perfect score and no issues are returned.
    """
    ocr_elements = state.get("ocr_elements", [])
    ocr_result = state.get("ocr_extraction_result", {})
    layout_schema = state.get("layout_schema", {})

    # Nothing to compare against.
    if not ocr_elements and not ocr_result:
        return {"score": 1.0, "field_coverage": 1.0, "element_coverage": 1.0, "issues": []}

    issues = []
    ns_prefix = r"(?:[\w.-]+:)?"  # optional XML namespace prefix

    # 1. Element count. The trailing \b keeps <textFieldExpression> from being
    #    counted as a second <textField>.
    total_jrxml_elements = len(
        re.findall(rf"<{ns_prefix}(?:textField|staticText)\b", jrxml)
    )

    ocr_text_count = 0
    if isinstance(ocr_elements, list):
        ocr_text_count = sum(
            1
            for e in ocr_elements
            # str(... or "") tolerates None / non-string text values.
            if isinstance(e, dict) and str(e.get("text") or "").strip()
        )
    if ocr_text_count == 0 and isinstance(ocr_result, dict):
        ocr_text_count = ocr_result.get("total_elements", 0) or 0

    if ocr_text_count > 0:
        element_coverage = min(total_jrxml_elements / max(ocr_text_count, 1), 1.0)
        if element_coverage < 0.3:
            issues.append(
                f"元素覆盖不足:JRXML 仅有 {total_jrxml_elements} 个文本元素,"
                f"OCR 源有 {ocr_text_count} 个文本元素(覆盖率 {element_coverage:.0%})"
            )
    else:
        element_coverage = 1.0

    # 2. Field-name coverage (namespace-prefixed <ns:field> is accepted too).
    jrxml_fields = set(re.findall(rf'<{ns_prefix}field\s+name="([^"]+)"', jrxml))
    ocr_field_names = set()
    ocr_fields = ocr_result.get("fields", []) if isinstance(ocr_result, dict) else []
    for f in ocr_fields or []:
        if isinstance(f, dict):
            name = f.get("name", "") or f.get("field_name", "") or f.get("label", "")
            if isinstance(name, str) and len(name) > 1:
                ocr_field_names.add(name)

    if ocr_field_names and jrxml_fields:
        matched = jrxml_fields & ocr_field_names
        field_coverage = len(matched) / max(len(ocr_field_names), 1)
        unmatched = ocr_field_names - jrxml_fields
        if unmatched:
            # Sorted so the reported sample is deterministic.
            sample = sorted(unmatched)[:8]
            issues.append(f"OCR 字段未在 JRXML 中声明: {', '.join(sample)}")
    elif ocr_field_names and not jrxml_fields:
        field_coverage = 0.0
        issues.append("JRXML 中未声明任何字段,但 OCR 提取了结构化字段数据")
    else:
        field_coverage = 1.0

    # 3. Column count, estimated from distinct x coordinates.
    if isinstance(layout_schema, dict):
        ocr_columns = layout_schema.get("total_columns", 0) or layout_schema.get("columns", 0)
        if isinstance(ocr_columns, (list, tuple)):
            ocr_columns = len(ocr_columns)
        if isinstance(ocr_columns, bool) or not isinstance(ocr_columns, (int, float)):
            ocr_columns = 0

        # Prefer the real <detail> section; the first <band> in a report is
        # usually the title. Fall back to the first band only when no detail
        # section is present.
        detail_section = re.search(
            rf"<{ns_prefix}detail\b[^>]*>([\s\S]*?)</{ns_prefix}detail>", jrxml
        )
        if detail_section:
            column_scope = detail_section.group(1)
        else:
            first_band = re.search(
                r"<band[^>]*height=\"(\d+)\"[^>]*>([\s\S]*?)</band>", jrxml
            )
            column_scope = first_band.group(2) if first_band else None

        if column_scope is not None and ocr_columns > 0:
            # \b avoids matching attributes that merely end in "x".
            x_positions = {
                int(m.group(1)) for m in re.finditer(r'\bx="(\d+)"', column_scope)
            }
            jrxml_columns = len(x_positions) if x_positions else 1
            if jrxml_columns < ocr_columns * 0.5:
                issues.append(
                    f"列数不足:JRXML detail band 检测到 {jrxml_columns} 列,"
                    f"OCR 布局分析有 {ocr_columns} 列"
                )

    # Overall score: equal weight for field and element coverage.
    score = round(field_coverage * 0.5 + element_coverage * 0.5, 3)
    return {
        "score": score,
        "field_coverage": round(field_coverage, 3),
        "element_coverage": round(element_coverage, 3),
        "issues": issues,
    }
|
||
|
||
|
||
@log_node("validate")
def validate(state: AgentState) -> Dict:
    """Validate the current JRXML and record fidelity metrics in the state.

    Steps:
      1. Reject empty or very short (< 200 characters) content.
      2. Best-effort normalisation of element order via ``normalize_jrxml``.
      3. Validation through ``validate_jrxml``; sets ``status``,
         ``error_msg`` and ``service_unavailable``.
      4. OCR fidelity check; a passing document with issues and a score
         below 0.5 is downgraded to ``fail``.
      5. When still passing and the uploaded source file exists: render to
         PNG and compare with the source image; downgrade to ``fail`` when
         SSIM < 0.4 and more than 60% of pixels differ.
      6. After a successful correction (``retry_count`` > 0), record the
         error case in the error knowledge base (best-effort).

    Args:
        state: Agent state; reads ``current_jrxml``, ``uploaded_file_path``,
            ``retry_count`` and ``last_error_case``.

    Returns:
        The same state object with the validation outcome stored in it.
    """
    jrxml = state.get("current_jrxml", "")
    if not jrxml:
        state["status"] = "fail"
        state["error_msg"] = "没有 JRXML 内容可供验证。"
        return state

    # Content this short cannot be a valid report.
    if len(jrxml.strip()) < 200:
        state["status"] = "fail"
        state["error_msg"] = f"JRXML 内容过短({len(jrxml.strip())} 字符),可能为不完整或空内容。"
        return state

    # Normalise element order to satisfy the XSD sequence requirements.
    # Failure here must not block validation, but it is logged rather than
    # silently ignored.
    try:
        from backend.jrxml_reorder import normalize_jrxml
        jrxml = normalize_jrxml(jrxml)
        state["current_jrxml"] = jrxml
    except Exception as exc:
        _node_log.warning("validate JRXML 规范化失败,继续使用原内容: %s", exc)

    result = validate_jrxml(jrxml)
    state["status"] = "pass" if result.get("valid") else "fail"
    state["error_msg"] = result.get("error", "")
    state["service_unavailable"] = result.get("service_unavailable", False)

    # OCR fidelity: compare the generated report with the OCR extraction.
    fidelity = _check_ocr_fidelity(jrxml, state)
    state["ocr_fidelity"] = fidelity
    if fidelity["issues"]:
        if state["status"] == "pass":
            # Schema-valid but content deviates too much -> downgrade to fail.
            if fidelity["score"] < 0.5:
                state["status"] = "fail"
                state["error_msg"] = (
                    f"[内容保真度不足] 得分 {fidelity['score']:.2f}/1.0。"
                    + " ".join(fidelity["issues"][:3])
                )
                _node_log.warning(
                    f"OCR 保真度得分 {fidelity['score']:.2f},XSD 通过但内容差异过大: "
                    + "; ".join(fidelity["issues"][:5])
                )
            else:
                _node_log.info(
                    f"OCR 保真度得分 {fidelity['score']:.2f},XSD 通过,轻微差异: "
                    + "; ".join(fidelity["issues"][:3])
                )
        else:
            _node_log.info(
                f"XSD 验证失败 + OCR 保真度得分 {fidelity['score']:.2f}: "
                + "; ".join(fidelity["issues"][:3])
            )

    # ── Pixel-level comparison: render to PNG and compare with the upload ──
    source_image = state.get("uploaded_file_path", "")
    if source_image and os.path.isfile(source_image) and state["status"] == "pass":
        tmpdir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "tmp")
        rendered_png = os.path.join(tmpdir, "_pixel_test.png")
        if _render_jrxml_to_png(jrxml, rendered_png):
            pixel_result = _compute_pixel_similarity(rendered_png, source_image)
            state["pixel_fidelity"] = pixel_result
            if pixel_result["error"]:
                _node_log.warning(f"像素对比失败: {pixel_result['error']}")
            else:
                _node_log.info(
                    f"像素对比: SSIM={pixel_result['ssim']:.4f}, "
                    f"Diff={pixel_result['diff_pct']:.2%}"
                )
                # Fail only when BOTH metrics are bad: SSIM < 0.4 and more
                # than 60% of pixels differ. Skipped when the comparison
                # itself errored, so a missing dependency never fails a
                # schema-valid report.
                if pixel_result["ssim"] < 0.4 and pixel_result["diff_pct"] > 0.6:
                    state["status"] = "fail"
                    state["error_msg"] = (
                        f"[像素保真度不足] SSIM={pixel_result['ssim']:.3f}, "
                        f"差异像素占比={pixel_result['diff_pct']:.2%}。"
                        f"渲染结果与原始图片差异过大,需调整布局。"
                    )

    # After a successful correction, store the case in the error knowledge base.
    if result.get("valid") and state.get("retry_count", 0) > 0:
        case = state.get("last_error_case", {})
        if case and case.get("error_msg"):
            try:
                from backend.error_kb import record_error

                recorded = record_error(
                    error_msg=case["error_msg"],
                    bad_jrxml=case.get("bad_jrxml", ""),
                    good_jrxml=jrxml,
                    correction_prompt=case.get("correction_prompt", ""),
                    retry_count=state.get("retry_count", 0),
                )
                if recorded:
                    state["conversation_history"].append({
                        "role": "system",
                        "content": f"[系统] 错误案例已记录到知识库(指纹: {case['error_msg'][:40]}...)",
                    })
            except Exception as exc:
                # Knowledge-base writes must never break the main flow.
                _node_log.warning("validate 错误知识库写入失败: %s", exc)

    return state
|
||
|
||
|
||
@log_node("explain_error")
def explain_error(state: AgentState) -> Dict:
    """Generate a human-readable explanation of the validation error.

    Sends the stored ``error_msg`` together with the first 80 lines of the
    current JRXML to the LLM and stores the stripped answer in
    ``natural_explanation``.

    Args:
        state: Agent state; reads ``current_jrxml`` and ``error_msg``.

    Returns:
        The same state object with ``natural_explanation`` set.
    """
    llm = get_llm(caller="explain_error")
    jrxml = state.get("current_jrxml", "")
    # Only the head of the document is sent, to keep the prompt small.
    snippet = "\n".join(jrxml.split("\n")[:80])

    prompt = load_prompt("explain_error").format(
        error_msg=state.get("error_msg", "未知错误"),
        jrxml_snippet=snippet,
    )
    resp = llm.invoke(prompt)
    # Same response handling as refine_layout: accept either a message
    # object exposing ``.content`` or a plain value.
    explanation = resp.content if hasattr(resp, "content") else str(resp)
    state["natural_explanation"] = explanation.strip()
    return state
|
||
|
||
|
||
@log_node("correct_jrxml")
def correct_jrxml(state: AgentState) -> Dict:
    """Ask the LLM to repair a JRXML document that failed validation.

    The correction prompt combines the current JRXML, the error message and
    its explanation, OCR/layout context and, when the stored scores are low,
    OCR-fidelity and pixel-fidelity warnings. The pre-correction state is
    saved in ``last_error_case`` for the error knowledge base.

    The previous JRXML is kept when the LLM answer is empty, shorter than
    200 characters, or contains neither ``<jasperReport`` nor ``<?xml``.
    ``retry_count`` grows by 1 per attempt, or by 2 when the output is
    identical to the input ignoring whitespace.

    Args:
        state: Agent state; reads ``current_jrxml``, ``error_msg``,
            ``natural_explanation``, ``layout_schema``, ``ocr_fidelity``,
            ``pixel_fidelity`` and ``retry_count``.

    Returns:
        The same state object with the corrected JRXML and updated counters.
    """
    from langgraph.config import get_stream_writer

    writer = get_stream_writer()
    llm = get_llm(caller="correct_jrxml", max_tokens=32768)
    ocr_context = _format_ocr_context(state)

    schema = state.get("layout_schema", {})
    layout_text = schema.get("schema_text", "") if isinstance(schema, dict) else ""

    # Fidelity context: tell the LLM how the report differs from the image.
    warning_blocks = []
    ocr_fidelity = state.get("ocr_fidelity", {})
    if ocr_fidelity and ocr_fidelity.get("score", 1.0) < 0.9:
        bullet_lines = "\n".join(f"- {issue}" for issue in ocr_fidelity.get("issues", []))
        warning_blocks.append(
            f"[内容保真度警告] 得分 {ocr_fidelity.get('score', 0):.2f}/1.0\n" + bullet_lines
        )
    fidelity_text = warning_blocks[0] if warning_blocks else ""

    # Pixel-level comparison context.
    pixel_fidelity = state.get("pixel_fidelity", {})
    if pixel_fidelity and pixel_fidelity.get("ssim", 1.0) < 0.7:
        # An empty OCR warning is dropped rather than joined as a blank line.
        combined = [fidelity_text] if fidelity_text else []
        combined.append(
            f"[像素保真度] SSIM={pixel_fidelity.get('ssim', 0):.4f}, "
            f"像素差异={pixel_fidelity.get('diff_pct', 0):.2%}。"
            f"渲染结果与原图差异过大,请调整元素位置、尺寸和布局。"
        )
        fidelity_text = "\n".join(combined)

    prev_jrxml = state.get("current_jrxml", "")
    failing_error = state.get("error_msg", "")

    prompt = load_prompt("correction").format(
        current_jrxml=prev_jrxml,
        error_msg=failing_error,
        explanation=state.get("natural_explanation", ""),
        ocr_context=ocr_context,
        layout_schema_text=layout_text,
        fidelity_context=fidelity_text,
        template_context=_build_template_context(state),
    )

    # Snapshot of the failing state, used by validate for the knowledge base.
    state["last_error_case"] = {
        "error_msg": failing_error,
        "bad_jrxml": prev_jrxml,
        "correction_prompt": prompt,
    }

    full_text = _generate_with_continuation(llm, prompt, writer, "correct_jrxml")
    if not full_text.strip():
        _node_log.error("correct_jrxml LLM 返回空响应,保留原版本")
        state["retry_count"] = state.get("retry_count", 0) + 1
        return state

    jrxml = _extract_jrxml(full_text)
    if len(jrxml.strip()) < 200:
        _node_log.warning(f"correct_jrxml 输出过短({len(jrxml)} 字符),回退到前一版本")
        jrxml = prev_jrxml

    # Output with neither marker is treated as junk (e.g. HTML) and discarded.
    looks_like_xml = "<jasperReport" in jrxml or "<?xml" in jrxml
    if jrxml and not looks_like_xml:
        _node_log.warning(
            f"correct_jrxml 输出不是合法 JRXML({jrxml[:100]}),回退到前一版本"
        )
        jrxml = prev_jrxml

    # No-op detection: identical output (ignoring whitespace) burns the retry
    # budget twice as fast.
    before_compact = re.sub(r"\s+", "", prev_jrxml) if prev_jrxml else ""
    after_compact = re.sub(r"\s+", "", jrxml) if jrxml else ""
    attempts_so_far = state.get("retry_count", 0)
    if before_compact and after_compact and before_compact == after_compact:
        _node_log.warning(
            f"correct_jrxml 输出与输入完全相同({len(jrxml)} 字符),修正无效,加速消耗 retry"
        )
        state["retry_count"] = attempts_so_far + 2
    else:
        state["retry_count"] = attempts_so_far + 1

    state["current_jrxml"] = jrxml
    state["conversation_history"].append(
        {"role": "assistant", "content": f"[自动修正,第 {state['retry_count']} 次尝试]\n{jrxml}"}
    )
    return state
|
||
|
||
|
||
@log_node("finalize")
def finalize(state: AgentState) -> Dict:
    """Store the final outcome and append a version record.

    On ``status == "pass"`` the current JRXML becomes ``final_jrxml`` and,
    if non-blank, is appended to ``jrxml_versions``. Otherwise
    ``final_jrxml`` is left untouched, a non-blank failed version is
    appended to ``jrxml_versions``, ``pending_failure_context`` is recorded
    and a failure notice is added to the conversation history.

    Args:
        state: Agent state; reads ``current_jrxml``, ``status``, ``intent``,
            ``retry_count``, ``error_msg`` and ``jrxml_versions``.

    Returns:
        The same state object, updated in place.
    """
    jrxml = state.get("current_jrxml", "")
    status = state.get("status", "")

    def _version_list() -> list:
        # Reuse the stored list when it is one; otherwise start a new list.
        stored = state.get("jrxml_versions", [])
        return stored if isinstance(stored, list) else []

    if status == "pass":
        state["final_jrxml"] = jrxml
        if not jrxml.strip():
            return state

        versions = _version_list()
        intent = state.get("intent", "")
        labels = {
            "initial_generation": "初始生成",
            "modify_report": "修改",
            "correct_jrxml": f"自动修正 (第{state.get('retry_count', 1)}次)",
        }
        passed_entry = {
            "ts": _now_iso(),
            "jrxml": jrxml,
            "intent": intent,
            "label": labels.get(intent, intent),
            "status": status,
        }
        versions.append(passed_entry)
        state["jrxml_versions"] = versions
        return state

    # Validation failed: final_jrxml keeps the last successful version.
    retries = state.get("retry_count", 0)
    error_msg = state.get("error_msg", "未知错误")

    # Keep the failed version so the user can still download it.
    if jrxml.strip():
        versions = _version_list()
        failed_entry = {
            "ts": _now_iso(),
            "jrxml": jrxml,
            "intent": state.get("intent", ""),
            "label": f"失败版本 (第{retries}次重试)",
            "status": "fail",
            "error_msg": error_msg,
        }
        versions.append(failed_entry)
        state["jrxml_versions"] = versions

    # Failure context, injected automatically with the next user input.
    state["pending_failure_context"] = {
        "error_msg": error_msg,
        "bad_jrxml": state.get("current_jrxml", ""),
        "retry_count": retries,
        "ts": _now_iso(),
    }

    failure_notice = (
        f"❌ 经过 {retries} 次重试后仍无法生成有效的 JRXML。\n"
        f"错误: {error_msg}\n\n"
        f"您可以:\n1. 继续描述修改要求,系统将自动重试修复\n2. 点击下载按钮获取当前版本(虽未通过 XSD 验证,但可能可在 Studio 中手动修复)"
    )
    state["conversation_history"].append({
        "role": "assistant",
        "content": failure_notice,
    })
    return state
|
||
|
||
|
||
def _strip_continuation_wrapper(text: str) -> str:
    """Strip markdown fences and natural-language lead-ins from a continuation.

    In continuation rounds the LLM may ignore the original formatting rules
    and prepend an explanation or wrap the XML fragment in a code fence.
    This returns only the XML payload.

    Handling, in order:
      1. A complete fenced block: its non-empty inner content is returned.
      2. A dangling opening and/or closing fence is removed.
      3. A short lead-in on the first line containing one of the
         continuation keywords is removed. The lead-in never extends past
         the first ``<``, so XML content is never consumed.

    Args:
        text: Raw text of one continuation round.

    Returns:
        The cleaned, whitespace-stripped text.
    """
    text = text.strip()

    # Complete markdown code block: ```...```
    m = re.search(r"```(?:xml|jrxml)?\s*([\s\S]*?)```", text, re.IGNORECASE)
    if m:
        inner = m.group(1).strip()
        if inner:
            return inner

    # Stand-alone fence markers at the start/end (incomplete code block).
    text = re.sub(r"^```(?:xml|jrxml)?\s*", "", text, flags=re.IGNORECASE)
    text = re.sub(r"```\s*$", "", text)

    # Natural-language lead-in. [^<\n] (instead of ".") stops the match at the
    # first tag, so neither the XML after the lead-in nor XML whose own text
    # contains a keyword is stripped.
    text = re.sub(
        r"^[^<\n]{0,40}(继续输出|剩余|续写|补全|接上)[^<\n]{0,30}[::]?\s*",
        "", text, flags=re.IGNORECASE
    )
    return text.strip()
|
||
|
||
|
||
def _generate_with_continuation(llm, prompt, writer, node_name, max_rounds=3) -> str:
    """Stream an LLM generation and recover from truncated output.

    After every round the JRXML extracted from the accumulated text is
    checked for a closing ``</jasperReport>`` (or ``</report>``) tag at its
    end. When it is missing, a continuation request is sent that quotes the
    last 800 characters produced so far as the anchor. Continuation rounds
    are cleaned with ``_strip_continuation_wrapper`` before being appended.

    Args:
        llm: Object exposing ``stream(prompt)`` that yields text chunks.
        prompt: Prompt for the first round.
        writer: Callable receiving one stream event dict per chunk.
        node_name: Node name used in stream events and log messages.
        max_rounds: Maximum number of rounds, the first one included.

    Returns:
        The concatenated text of all rounds.
    """
    closing_tag_at_end = r"</(?:[\w:]+:)?(?:jasperReport|report)>\s*$"
    collected = ""

    for attempt in range(max_rounds):
        is_continuation = attempt > 0
        if is_continuation:
            # Slicing already yields the whole string when it is shorter.
            tail = collected[-800:]
            request = (
                f"[系统指令] 你正在生成的 JRXML 在上一次响应中被截断。\n"
                f"已生成内容的最后部分(请从此处继续):\n...{tail}\n\n"
                f"请从截断点继续输出剩余内容,不要重复已输出的部分。\n"
                f"不要输出 markdown 代码块、解释或任何非 JRXML 的内容。"
            )
        else:
            request = prompt

        # Forward every chunk to the stream writer while collecting it.
        pieces = []
        for piece in llm.stream(request):
            pieces.append(piece)
            writer({"type": "stream", "node": node_name, "text": piece})

        round_text = "".join(pieces)
        if is_continuation:
            round_text = _strip_continuation_wrapper(round_text)
        collected += round_text

        # Done as soon as the extracted document is properly closed.
        if re.search(closing_tag_at_end, _extract_jrxml(collected), re.IGNORECASE):
            break

        # An empty round will not improve on a retry; stop early.
        if not round_text.strip():
            _node_log.warning(f"{node_name} 第{attempt+1}轮续写无输出,停止")
            break
    else:
        # Reached only when every round ran without hitting a break.
        _node_log.warning(f"{node_name} 经{max_rounds}轮续写仍未完整")

    return collected
|
||
|
||
|
||
def _extract_jrxml(text: str) -> str:
    """Extract the JRXML document from an LLM response.

    Handling, in order:
      1. Fenced markdown blocks: the first block whose content starts with
         ``<?xml`` or ``<jasperReport`` is returned; fragments are skipped.
      2. A document starting at ``<?xml`` and ending at the first closing
         ``</jasperReport>`` / ``</report>`` tag after it.
      3. A document without XML prolog: from the first ``<jasperReport``
         opening tag (optionally namespace-prefixed) to the first closing
         tag after it. Surrounding prose is dropped.
      4. Otherwise the stripped text is returned unchanged (for example
         truncated output that has no closing tag yet).

    Args:
        text: Raw LLM response, possibly accumulated over several rounds.

    Returns:
        The extracted JRXML, or the stripped input when no complete
        document can be located.
    """
    text = text.strip()

    # 1. Markdown code blocks.
    fence_pattern = re.compile(r"```(?:xml|jrxml)?\s*([\s\S]*?)```", re.IGNORECASE)
    for m in fence_pattern.finditer(text):
        content = m.group(1).strip()
        if content.startswith(("<?xml", "<jasperReport")):
            return content
        # Not a complete JRXML document: keep looking at later blocks.

    _jrxml_close = r"</(?:[\w:]+:)?(?:jasperReport|report)>"

    # 2. <?xml ... </jasperReport>
    with_prolog = re.search(rf"(<\?xml[\s\S]*?{_jrxml_close})", text, re.IGNORECASE)
    if with_prolog:
        return with_prolog.group(1).strip()

    # 3. <jasperReport ...> ... </jasperReport> without a prolog. The closing
    #    tag is searched only AFTER the opening tag, so a stray closing tag
    #    earlier in the text can never produce an empty slice.
    root_open = re.search(r"<(?:[\w.-]+:)?jasperReport\b", text, re.IGNORECASE)
    if root_open:
        closing = re.compile(_jrxml_close, re.IGNORECASE).search(text, root_open.start())
        if closing:
            return text[root_open.start():closing.end()].strip()

    # 4. Nothing complete found: hand back the text as-is.
    return text