"""Node functions for the LangGraph JRXML generation workflow."""
import copy
import functools
import json
import os
import re
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Callable, Dict

from dotenv import load_dotenv

from agent.state import AgentState
from backend.llm import get_llm
from backend.logger import get_logger, set_trace_id
from backend.validation import validate_jrxml
from prompts.loader import load_prompt

load_dotenv()

_node_log = get_logger("agent")

# Not referenced in this module; presumably read by the graph's retry routing — confirm.
MAX_RETRY = int(os.getenv("MAX_RETRY", "3"))
CONTEXT_MAX_TOKENS = int(os.getenv("CONTEXT_MAX_TOKENS", "6000"))
CONTEXT_KEEP_RECENT = int(os.getenv("CONTEXT_KEEP_RECENT", "4"))
HISTORY_MAX_SNAPSHOTS = int(os.getenv("HISTORY_MAX_SNAPSHOTS", "10"))

# Uploaded files with these suffixes are sent through OCR field extraction.
_IMAGE_SUFFIXES = (".png", ".jpg", ".jpeg", ".bmp", ".webp")

# Field names requested from the OCR extractor for every uploaded image.
_DEFAULT_OCR_FIELDS = (
    "发票代码", "发票号码", "开票日期", "合计金额", "校验码", "价税合计",
    "总金额", "日期", "金额", "数量", "单价", "税率",
    "购买方名称", "销售方名称", "货物名称", "规格型号", "不含税金额", "税额",
)

# Intents the classifier may return. Order matters: the first name found as a
# substring of the LLM reply wins.
_VALID_INTENTS = (
    "initial_generation",
    "modify_report",
    "preview_report",
    "export_pdf",
    "export_jrxml",
    "undo_modification",
    "consult_question",
    "reset_session",
)

# Keys copied from the persisted agent_state back into the live state.
# (user_input / stage / session_id are deliberately absent: they belong to
# the current request.)
_SESSION_KEYS_RESTORED = (
    "conversation_history", "full_conversation_history", "current_jrxml",
    "final_jrxml", "compressed_history", "session_name", "created_at",
    "history_states", "ocr_extraction_result", "uploaded_file_path",
    "annotation_result", "layout_schema", "ocr_elements",
)

# Keys written to disk by save_session_node.
_SESSION_KEYS_PERSISTED = (
    "session_id", "conversation_history", "full_conversation_history",
    "current_jrxml", "final_jrxml", "compressed_history", "status",
    "error_msg", "history_states", "ocr_extraction_result",
    "uploaded_file_path", "annotation_result", "layout_schema", "ocr_elements",
)

# Human-readable version labels keyed by intent (auto-corrected versions get
# their own label in _record_version).
_VERSION_LABELS = {
    "initial_generation": "初始生成",
    "modify_report": "修改",
}

_JR_CLOSE_TAG = "</jasperReport>"
_CODE_FENCE_RE = re.compile(r"```(?:xml|jrxml)?\s*([\s\S]*?)```", re.IGNORECASE)
_XML_DOCUMENT_RE = re.compile(r"(<\?xml[\s\S]*?</jasperReport>)", re.IGNORECASE)


def _state_summary(state: AgentState) -> dict:
    """Return a compact summary of the key state fields for log records."""
    user_input = state.get("user_input") or ""
    jrxml = state.get("current_jrxml") or ""
    return {
        "session_id": state.get("session_id", ""),
        "intent": state.get("intent", ""),
        "status": state.get("status", ""),
        "has_jrxml": bool(jrxml.strip()),
        "jrxml_length": len(jrxml),
        "retry_count": state.get("retry_count", 0),
        "user_input_preview": user_input[:100],
        "conversation_turns": len(state.get("conversation_history") or []),
        "history_snapshots": len(state.get("history_states") or []),
        "versions": len(state.get("jrxml_versions") or []),
    }


def log_node(node_name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Build a decorator that logs a node's entry, exit (with duration) and errors.

    Args:
        node_name: Name written to the ``node`` field of every log record.

    Returns:
        A decorator. The wrapped function is called unchanged; any exception it
        raises is logged with its duration and then re-raised.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(state: AgentState, *args, **kwargs):
            t0 = time.perf_counter()
            _node_log.info(
                f"[节点入口] {node_name}",
                extra={"node": node_name, "phase": "entry", "state": _state_summary(state)},
            )
            try:
                result = func(state, *args, **kwargs)
            except Exception as e:
                elapsed_ms = round((time.perf_counter() - t0) * 1000)
                _node_log.error(
                    f"[节点异常] {node_name}: {e}",
                    extra={
                        "node": node_name,
                        "phase": "error",
                        "duration_ms": elapsed_ms,
                        "error": str(e),
                        "state": _state_summary(state),
                    },
                )
                raise
            elapsed_ms = round((time.perf_counter() - t0) * 1000)
            # Nodes mutate ``state`` in place, so this summary reflects the node's output.
            _node_log.info(
                f"[节点出口] {node_name}",
                extra={
                    "node": node_name,
                    "phase": "exit",
                    "duration_ms": elapsed_ms,
                    "state": _state_summary(state),
                },
            )
            return result
        return wrapper
    return decorator


# ============================================================
# Core workflow nodes
# ============================================================

@log_node("process_input")
def process_input(state: AgentState) -> Dict:
    """Record the user's message and reset per-request fields.

    Steps:
      * append the raw message to ``full_conversation_history``;
      * if ``pending_failure_context`` holds an error from the previous run,
        prepend a failure note to the working message and clear the context;
      * append the working message to ``conversation_history``;
      * for an uploaded image, run OCR field extraction and prepend the
        extracted fields to the working message;
      * reset ``retry_count`` and store the working message as
        ``user_modification_request``.
    """
    user_input = state.get("user_input", "")

    # The full history always keeps the user's original, unaugmented message.
    full_history = state.get("full_conversation_history", [])
    full_history.append({"role": "user", "content": user_input, "ts": _now_iso()})
    state["full_conversation_history"] = full_history

    # Auto-inject the previous failure so the model can repair it.
    pending = state.get("pending_failure_context", {})
    if pending and pending.get("error_msg"):
        failure_note = (
            f"[系统提示] 上次生成失败,以下是失败详情,请基于此修正:\n"
            f"失败原因: {pending['error_msg']}\n"
            f"上次失败的输出:\n{pending.get('bad_jrxml', '(无输出)')}"
        )
        user_input = f"{failure_note}\n\n---\n用户新输入:\n{user_input}"
        state["pending_failure_context"] = {}

    conv_history = state.get("conversation_history", [])
    conv_history.append({"role": "user", "content": user_input})
    state["conversation_history"] = conv_history

    uploaded_path = state.get("uploaded_file_path", "")
    if uploaded_path and Path(uploaded_path).is_file():
        if Path(uploaded_path).suffix.lower() in _IMAGE_SUFFIXES:
            ocr_context = _run_upload_ocr(state, uploaded_path)
            if ocr_context:
                user_input = f"{ocr_context}\n\n{user_input}"
                # Keep the working history in sync with the augmented message.
                conv_history[-1]["content"] = user_input
            # The upload has been consumed; don't re-run OCR on later turns.
            state["uploaded_file_path"] = ""

    state["retry_count"] = 0
    state["user_modification_request"] = user_input
    return state


def _run_upload_ocr(state: AgentState, uploaded_path: str) -> str:
    """Run OCR field extraction (plus annotation detection) on an uploaded image.

    On success, stores the result in ``state["ocr_extraction_result"]``. On
    failure, stores ``{"error": ...}`` there instead.

    Returns:
        A text block listing the non-empty extracted fields, or "" if there is
        nothing to inject.
    """
    try:
        from backend.ocr_extractor import OcrExtractor

        ocr_result = OcrExtractor().extract(uploaded_path, list(_DEFAULT_OCR_FIELDS))
    except Exception as e:
        _node_log.warning(f"OCR 字段提取失败: {e}")
        state["ocr_extraction_result"] = {"error": str(e)}
        return ""

    if not isinstance(ocr_result, dict) or not ocr_result.get("ocr_available"):
        return ""

    state["ocr_extraction_result"] = ocr_result
    _node_log.info(
        "OCR 字段提取完成",
        extra={
            "file": uploaded_path,
            "elements": ocr_result.get("total_elements", 0),
            "fields": len(ocr_result.get("fields", [])),
        },
    )
    _detect_upload_annotations(state, uploaded_path, ocr_result.get("elements", []))
    return _format_extracted_fields(ocr_result.get("fields") or [])


def _format_extracted_fields(fields: list) -> str:
    """Render OCR fields that have a value as a bullet list for the LLM ("" if none)."""
    non_empty = [f for f in fields if isinstance(f, dict) and f.get("field_value")]
    if not non_empty:
        return ""
    lines = ["[OCR 单据字段提取结果]"]
    for f in non_empty:
        # .get() + _safe_float: a malformed entry must not abort the whole upload.
        lines.append(
            f"- {f.get('field_name', '')}: {f['field_value']}"
            f"(置信度: {_safe_float(f.get('confidence')):.0%}, "
            f"方法: {f.get('extraction_method', '')})"
        )
    return "\n".join(lines)


def _detect_upload_annotations(state: AgentState, uploaded_path: str, elements: list) -> None:
    """Detect circle/arrow annotations and store them in ``state["annotation_result"]``.

    Best-effort: failures are logged and ignored.
    """
    if not elements:
        return
    try:
        from backend.annotation_detector import detect_annotations

        ann_result = detect_annotations(uploaded_path, elements)
        if ann_result.get("total", 0) > 0:
            state["annotation_result"] = ann_result
            _node_log.info(
                "批注检测完成",
                extra={
                    "circles": len(ann_result.get("circles", [])),
                    "arrows": len(ann_result.get("arrows", [])),
                },
            )
    except Exception as e:
        _node_log.warning(f"批注检测失败: {e}")


@log_node("save_state_snapshot")
def save_state_snapshot(state: AgentState) -> Dict:
    """Push a snapshot of the current state onto ``history_states`` for undo.

    At most ``HISTORY_MAX_SNAPSHOTS`` snapshots are kept (oldest dropped first).
    """
    snapshots = state.get("history_states", [])
    if not isinstance(snapshots, list):
        snapshots = []
    snapshots.append({
        "current_jrxml": state.get("current_jrxml", ""),
        "final_jrxml": state.get("final_jrxml", ""),
        "status": state.get("status", ""),
        # Deep copy: the live history list keeps being appended to.
        "conversation_history": copy.deepcopy(state.get("conversation_history", [])),
        "user_modification_request": state.get("user_modification_request", ""),
        "intent": state.get("intent", ""),
    })
    if len(snapshots) > HISTORY_MAX_SNAPSHOTS:
        snapshots = snapshots[-HISTORY_MAX_SNAPSHOTS:]
    state["history_states"] = snapshots
    return state


def _fallback_intent(has_report: str) -> str:
    """Default intent when classification is inconclusive: modify if a report exists."""
    return "modify_report" if has_report == "是" else "initial_generation"


@log_node("classify_intent")
def classify_intent(state: AgentState) -> Dict:
    """Classify the user's message into one of the eight workflow intents via the LLM.

    The first entry of ``_VALID_INTENTS`` found in the LLM reply is used. If
    the reply names none, or the LLM call fails, the intent falls back to
    ``modify_report`` when a report exists and ``initial_generation``
    otherwise.
    """
    user_input = state.get("user_input", "")
    has_report = "是" if (state.get("current_jrxml") or "").strip() else "否"
    fallback = _fallback_intent(has_report)
    try:
        llm = get_llm(caller="classify_intent")
        prompt = load_prompt("intent_classify").format(
            has_report=has_report,
            user_input=user_input[:500],
        )
        raw = llm.invoke(prompt).content.strip().lower()
    except Exception as e:
        _node_log.warning(f"意图分类失败,使用兜底意图: {e}")
        intent = fallback
    else:
        intent = next((vi for vi in _VALID_INTENTS if vi in raw), fallback)
    state["intent"] = intent
    return state


def _append_assistant_reply(state: AgentState, content: str, full_content: str = "") -> None:
    """Append an assistant turn to the working and the full conversation histories.

    ``full_content`` (when non-empty) replaces ``content`` in the full history.
    """
    state.setdefault("conversation_history", []).append(
        {"role": "assistant", "content": content}
    )
    state.setdefault("full_conversation_history", []).append(
        {"role": "assistant", "content": full_content or content, "ts": _now_iso()}
    )


@log_node("handle_consult")
def handle_consult(state: AgentState) -> Dict:
    """Answer a consultation question directly with the LLM (no report generation)."""
    user_input = state.get("user_input", "")
    try:
        llm = get_llm(caller="handle_consult")
        prompt = load_prompt("consult").format(question=user_input)
        answer = llm.invoke(prompt).content.strip()
    except Exception as e:
        _node_log.warning(f"咨询回答失败: {e}")
        answer = "抱歉,暂时无法处理您的问题,请稍后再试。"
    state["consult_answer"] = answer
    _append_assistant_reply(state, answer)
    return state


@log_node("handle_undo")
def handle_undo(state: AgentState) -> Dict:
    """Undo the last step by restoring (and popping) the newest ``history_states`` snapshot."""
    snapshots = state.get("history_states", [])
    if not isinstance(snapshots, list) or not snapshots:
        state.setdefault("conversation_history", []).append(
            {"role": "assistant", "content": "没有可撤销的操作。"}
        )
        return state

    prev = snapshots.pop()
    state["history_states"] = snapshots
    for key in ("current_jrxml", "final_jrxml", "status"):
        state[key] = prev.get(key, "")
    state["conversation_history"] = prev.get("conversation_history", [])
    state["user_modification_request"] = prev.get("user_modification_request", "")
    _append_assistant_reply(state, "已撤销上一步修改,恢复到之前的状态。", "已撤销上一步修改。")
    return state


@log_node("handle_reset")
def handle_reset(state: AgentState) -> Dict:
    """Reset the report-related state while keeping session identity and full history."""
    state["current_jrxml"] = ""
    state["final_jrxml"] = ""
    state["status"] = ""
    state["error_msg"] = ""
    state["natural_explanation"] = ""
    state["user_modification_request"] = ""
    state["retrieved_context"] = ""
    state["retry_count"] = 0
    state["compressed_history"] = ""
    state["history_states"] = []
    state["intent"] = "initial_generation"
    # Without these, process_input would inject the pre-reset failure into the
    # first message of the new report.
    state["pending_failure_context"] = {}
    state["last_error_case"] = {}
    # Stale OCR/layout data would otherwise leak into the new report's prompts.
    state["ocr_extraction_result"] = {}
    state["annotation_result"] = {}
    state["layout_schema"] = {}
    state["ocr_elements"] = []
    state["conversation_history"] = []
    _append_assistant_reply(state, "会话已重置,请描述您要创建的新报表。", "会话已重置。")
    return state


@log_node("count_tokens")
def count_tokens(state: AgentState) -> int:
    """Estimate the token size of the current context.

    The context counted is: the last ``CONTEXT_KEEP_RECENT`` working messages,
    the current JRXML and the compressed history. Tokens are counted with
    tiktoken's gpt-4o encoding. If tiktoken is unavailable, the estimate is
    roughly 1 token per 2.5 characters (mixed Chinese/English text).
    """
    text = json.dumps({
        "history": state.get("conversation_history", [])[-CONTEXT_KEEP_RECENT:],
        "jrxml": state.get("current_jrxml", ""),
        "compressed": state.get("compressed_history", ""),
    }, ensure_ascii=False)
    try:
        import tiktoken

        enc = tiktoken.encoding_for_model("gpt-4o")
    except Exception:
        # ``//`` with a float divisor yields a float; cast to honour ``-> int``.
        return int(len(text) // 2.5)
    return len(enc.encode(text))


@log_node("manage_context")
def manage_context(state: AgentState) -> Dict:
    """Compress older conversation turns once the context exceeds ``CONTEXT_MAX_TOKENS``.

    The newest ``CONTEXT_KEEP_RECENT`` entries of the full history become the
    working history. Everything older is summarised (by the LLM, or
    ``_simple_compress`` as a fallback) and appended to
    ``compressed_history``.
    """
    token_count = count_tokens(state)
    state["current_token_count"] = token_count
    if token_count <= CONTEXT_MAX_TOKENS:
        return state

    full_history = state.get("full_conversation_history", [])
    if len(full_history) <= CONTEXT_KEEP_RECENT:
        return state

    recent = full_history[-CONTEXT_KEEP_RECENT:]
    older = full_history[:-CONTEXT_KEEP_RECENT]
    if not older:
        return state

    # NOTE(review): ``older`` always spans the whole full-history prefix, so
    # turns summarised on a previous call are summarised again and the new
    # summary is appended to the old one. Confirm whether a watermark was intended.
    conv_text = json.dumps(older, ensure_ascii=False, indent=2)
    try:
        llm = get_llm(caller="manage_context")
        prompt = load_prompt("compression").format(conversation_text=conv_text)
        new_compressed = llm.invoke(prompt).content.strip()[:300]
    except Exception as e:
        _node_log.warning(f"上下文压缩失败,使用规则压缩: {e}")
        new_compressed = _simple_compress(older)

    existing = state.get("compressed_history", "")
    state["compressed_history"] = (
        f"{existing}\n---\n{new_compressed}" if existing else new_compressed
    )
    state["conversation_history"] = list(recent)
    state["current_token_count"] = count_tokens(state)
    return state


@log_node("load_session_node")
def load_session_node(state: AgentState) -> Dict:
    """Restore persisted session fields from disk at the start of a request (best-effort)."""
    session_id = state.get("session_id", "")
    if not session_id:
        return state
    try:
        from backend.session import load_session

        data = load_session(session_id)
        if data and data.get("agent_state"):
            saved = data["agent_state"]
            # NOTE(review): this also overwrites an uploaded_file_path supplied
            # with the current request by the persisted value (process_input
            # clears it after OCR). Confirm the upload path is set after this node.
            for key in _SESSION_KEYS_RESTORED:
                if key in saved:
                    state[key] = saved[key]
            state["session_name"] = data.get("session_name", "")
            state["created_at"] = data.get("created_at", "")
    except Exception as e:
        _node_log.warning(f"会话加载失败: {e}")
    return state


@log_node("save_session_node")
def save_session_node(state: AgentState) -> Dict:
    """Persist selected agent-state fields to disk (best-effort).

    If the session has no name yet, it is named after the first 50 characters
    of the first user message.
    """
    session_id = state.get("session_id", "")
    if not session_id:
        return state
    try:
        from backend.session import save_session

        persistable = {key: state[key] for key in _SESSION_KEYS_PERSISTED if key in state}
        persistable["updated_at"] = _now_iso()

        session_name = state.get("session_name", "")
        if not session_name and state.get("conversation_history"):
            session_name = next(
                (m.get("content", "")[:50] for m in state["conversation_history"]
                 if m.get("role") == "user"),
                "",
            )

        save_session(session_id, persistable, session_name)
        if not state.get("session_name"):
            state["session_name"] = session_name
        state["updated_at"] = persistable["updated_at"]
    except Exception as e:
        _node_log.warning(f"会话保存失败: {e}")
    return state


def _simple_compress(messages: list[dict]) -> str:
    """Rule-based fallback summary: the last 10 user messages, 100 chars each."""
    points = [
        f"用户提问:{m.get('content', '')[:100]}"
        for m in messages
        if m.get("role") == "user"
    ]
    return "; ".join(points[-10:])


def _now_iso() -> str:
    """Current UTC time as an ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


def _safe_float(value: Any) -> float:
    """Coerce ``value`` to float; missing or non-numeric values become 0.0."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return 0.0


def _format_row_coordinates(row: dict) -> dict:
    """Format one OCR row as compact left-to-right column coordinates for refine_layout.

    Returns ``{}`` for a non-dict row. Otherwise returns a dict with
    ``y_center`` and a ``columns`` list sorted by ``x``.
    """
    if not isinstance(row, dict):
        return {}
    y_center = row.get("y_center", 0)
    elements = row.get("elements", [])
    if not elements:
        return {"y_center": y_center, "columns": []}
    ordered = sorted(elements, key=lambda e: e.get("x", 0))
    columns = [
        {
            "col": idx,
            "x": e.get("x", 0),
            "y": e.get("y", 0),
            "w": e.get("w", 0),
            "h": e.get("h", 0),
            "font_size": e.get("font_size", 12),
            "text": e.get("text", ""),
        }
        for idx, e in enumerate(ordered)
    ]
    return {"y_center": y_center, "columns": columns}


def _format_ocr_context(state: AgentState) -> str:
    """Render the stored OCR result (fields, raw elements, annotations) as LLM context.

    Returns "" when no OCR result is stored or the stored result is an error.
    """
    ocr_result = state.get("ocr_extraction_result")
    if not ocr_result or not isinstance(ocr_result, dict) or ocr_result.get("error"):
        return ""

    parts = ["[图片OCR识别结果]"]
    total = ocr_result.get("total_elements", 0)
    if total:
        parts.append(f"检测到 {total} 个文字元素")

    fields = [f for f in ocr_result.get("fields") or [] if isinstance(f, dict)]
    if fields:
        parts.append("\n提取的结构化字段:")
        for f in fields:
            if f.get("field_value"):
                parts.append(
                    f" - {f.get('field_name', '')}: {f['field_value']} "
                    f"(方法={f.get('extraction_method','?')}, "
                    f"置信度={_safe_float(f.get('confidence', 0)):.2f})"
                )

    # All raw text elements, for scenarios (e.g. table matching) needing full text.
    elements = [e for e in ocr_result.get("elements") or [] if isinstance(e, dict)]
    if elements:
        parts.append("\n全部文本元素(含坐标):")
        for e in elements:
            bbox = e.get("bbox") or {}
            x, y, w, h = (bbox.get(k, 0) for k in ("x", "y", "w", "h"))
            parts.append(
                f" [{x},{y} {w}×{h}] {e.get('text', '')} "
                f"(置信度={_safe_float(e.get('confidence', 0)):.2f})"
            )

    ann_result = state.get("annotation_result")
    if ann_result and isinstance(ann_result, dict):
        try:
            from backend.annotation_detector import format_annotation_context

            ann_text = format_annotation_context(ann_result)
        except Exception as e:
            _node_log.warning(f"批注上下文格式化失败: {e}")
        else:
            if ann_text:
                parts.append("\n" + ann_text)

    return "\n".join(parts)


def _stream_jrxml(llm: Any, prompt: str, node_name: str) -> str:
    """Stream ``llm`` on ``prompt``, forwarding chunks to the LangGraph stream writer.

    Each chunk is emitted as ``{"type": "stream", "node": node_name,
    "text": chunk}``. The concatenated reply is passed through
    ``_extract_jrxml``.
    """
    from langgraph.config import get_stream_writer

    writer = get_stream_writer()
    chunks = []
    for chunk in llm.stream(prompt):
        chunks.append(chunk)
        writer({"type": "stream", "node": node_name, "text": chunk})
    return _extract_jrxml("".join(chunks))


def _store_generated_jrxml(state: AgentState, jrxml: str) -> Dict:
    """Set ``current_jrxml``, record it as an assistant turn, and return ``state``."""
    state["current_jrxml"] = jrxml
    state["conversation_history"].append({"role": "assistant", "content": jrxml})
    return state


@log_node("retrieve")
def retrieve(state: AgentState) -> Dict:
    """Search the template store and the error knowledge base for relevant context.

    When ``error_msg`` is set, matching historical correction cases are
    appended. Any failure yields an empty ``retrieved_context``.
    """
    try:
        from backend.rag_adapter import search_chunks
        from backend.error_kb import search_error_cases

        context = search_chunks(state.get("user_input", ""), k=5)
        error_msg = state.get("error_msg", "")
        if error_msg:
            error_context = search_error_cases(error_msg, k=2)
            if error_context:
                context = f"{context}\n\n[历史错误修正案例]\n{error_context}"
    except Exception as e:
        _node_log.warning(f"检索失败: {e}")
        context = ""
    state["retrieved_context"] = context
    return state


@log_node("generate")
def generate(state: AgentState) -> Dict:
    """Generate the initial JRXML from the user request, retrieved context and OCR text."""
    llm = get_llm(caller="generate")
    user_request = state.get("user_input", "")
    ocr_text = _format_ocr_context(state)
    if ocr_text:
        user_request = f"{ocr_text}\n\n---\n用户需求:\n{user_request}"
    prompt = load_prompt("initial_generation").format(
        context=state.get("retrieved_context", ""),
        user_request=user_request,
    )
    return _store_generated_jrxml(state, _stream_jrxml(llm, prompt, "generate"))


@log_node("generate_skeleton")
def generate_skeleton(state: AgentState) -> Dict:
    """Stage 1: generate a skeleton JRXML (``$F{field_N}`` placeholders) from the layout schema."""
    llm = get_llm(caller="generate_skeleton")
    schema = state.get("layout_schema", {})
    schema_text = schema.get("schema_text", "") if isinstance(schema, dict) else ""
    prompt = load_prompt("skeleton_generation").format(
        layout_schema=schema_text,
        context=state.get("retrieved_context", ""),
        user_request=state.get("user_input", ""),
    )
    return _store_generated_jrxml(state, _stream_jrxml(llm, prompt, "generate_skeleton"))


@log_node("refine_layout")
def refine_layout(state: AgentState) -> Dict:
    """Stage 2: refine element positions using sampled OCR row coordinates.

    Samples the header row, the first data row and the last row when present.
    """
    llm = get_llm(caller="refine_layout")
    ocr_rows = state.get("ocr_elements", [])
    sampled = {}
    if isinstance(ocr_rows, list) and ocr_rows:
        sampled["header_row"] = _format_row_coordinates(ocr_rows[0])
        if len(ocr_rows) > 1:
            sampled["first_data_row"] = _format_row_coordinates(ocr_rows[1])
        if len(ocr_rows) > 2:
            sampled["last_row"] = _format_row_coordinates(ocr_rows[-1])
    prompt = load_prompt("refine_layout").format(
        current_jrxml=state.get("current_jrxml", ""),
        sampled_coordinates=json.dumps(sampled, ensure_ascii=False, indent=2),
    )
    return _store_generated_jrxml(state, _stream_jrxml(llm, prompt, "refine_layout"))


@log_node("map_fields")
def map_fields(state: AgentState) -> Dict:
    """Stage 3: replace placeholder field names with the real OCR-extracted names.

    Extracted fields are preferred. If there are none, up to 50 raw OCR texts
    are supplied instead.
    """
    llm = get_llm(caller="map_fields")
    ocr_result = state.get("ocr_extraction_result", {})
    if not isinstance(ocr_result, dict):
        ocr_result = {}

    fields_text = ""
    field_descs = [
        f" - {f.get('field_name', '')}: {f.get('field_value', '')}"
        for f in ocr_result.get("fields") or []
        if isinstance(f, dict) and f.get("field_name", "")
    ]
    if field_descs:
        fields_text = "提取的字段:\n" + "\n".join(field_descs)
    if not fields_text:
        texts = [
            e.get("text", "")
            for e in ocr_result.get("elements") or []
            if isinstance(e, dict) and e.get("text")
        ]
        if texts:
            fields_text = "OCR 文本内容:\n" + "\n".join(f" - {t}" for t in texts[:50])

    prompt = load_prompt("field_mapping").format(
        current_jrxml=state.get("current_jrxml", ""),
        ocr_fields=fields_text,
    )
    return _store_generated_jrxml(state, _stream_jrxml(llm, prompt, "map_fields"))


@log_node("modify_jrxml")
def modify_jrxml(state: AgentState) -> Dict:
    """Apply the user's modification request to the existing JRXML.

    The prompt includes the compressed summary, the last 6 working messages
    and the OCR context. The request/result pair is then appended to both
    histories and ``retry_count`` is reset.
    """
    llm = get_llm(caller="modify_jrxml")
    compressed = state.get("compressed_history", "")
    recent = state.get("conversation_history", [])[-6:]
    conv_parts = [f"[早期对话摘要]\n{compressed}"] if compressed else []
    conv_parts.append(json.dumps(recent, ensure_ascii=False, indent=2))

    prompt = load_prompt("modification").format(
        current_jrxml=state.get("current_jrxml", ""),
        conversation_history="\n\n---\n\n".join(conv_parts),
        modification_request=state.get("user_modification_request", ""),
        ocr_context=_format_ocr_context(state),
    )
    jrxml = _stream_jrxml(llm, prompt, "modify_jrxml")
    state["current_jrxml"] = jrxml

    request = state.get("user_modification_request", "")
    # NOTE(review): process_input already appends this turn's user message to
    # both histories; if both nodes run in one turn the user turn is duplicated.
    state["conversation_history"].append({"role": "user", "content": request})
    state["conversation_history"].append({"role": "assistant", "content": jrxml})
    state["full_conversation_history"] = list(state.get("full_conversation_history", [])) + [
        {"role": "user", "content": request, "ts": _now_iso()},
        {"role": "assistant", "content": jrxml, "ts": _now_iso()},
    ]
    state["retry_count"] = 0
    return state


@log_node("validate")
def validate(state: AgentState) -> Dict:
    """Validate ``current_jrxml`` via ``validate_jrxml`` and set ``status`` / ``error_msg``.

    Content shorter than 200 non-whitespace-trimmed characters fails without
    calling the validator. When a retried (auto-corrected) JRXML passes, the
    bad→good pair from ``last_error_case`` is recorded in the error knowledge
    base on a best-effort basis.
    """
    jrxml = state.get("current_jrxml", "")
    if not jrxml:
        state["status"] = "fail"
        state["error_msg"] = "没有 JRXML 内容可供验证。"
        return state

    # Anything this short cannot be a real report (a minimal skeleton is ~500+ chars).
    if len(jrxml.strip()) < 200:
        state["status"] = "fail"
        state["error_msg"] = f"JRXML 内容过短({len(jrxml.strip())} 字符),可能为不完整或空内容。"
        return state

    result = validate_jrxml(jrxml)
    state["status"] = "pass" if result.get("valid") else "fail"
    state["error_msg"] = result.get("error", "")

    if result.get("valid") and state.get("retry_count", 0) > 0:
        case = state.get("last_error_case", {})
        if case and case.get("error_msg"):
            try:
                from backend.error_kb import record_error

                recorded = record_error(
                    error_msg=case["error_msg"],
                    bad_jrxml=case.get("bad_jrxml", ""),
                    good_jrxml=jrxml,
                    correction_prompt=case.get("correction_prompt", ""),
                    retry_count=state.get("retry_count", 0),
                )
                if recorded:
                    state["conversation_history"].append({
                        "role": "system",
                        "content": f"[系统] 错误案例已记录到知识库(指纹: {case['error_msg'][:40]}...)",
                    })
            except Exception as e:
                # Knowledge-base writes must never break the main flow.
                _node_log.warning(f"错误知识库写入失败: {e}")
    return state


@log_node("explain_error")
def explain_error(state: AgentState) -> Dict:
    """Ask the LLM for a readable explanation of the validation error.

    The prompt includes the first 80 lines of the JRXML. The result is stored
    in ``natural_explanation``.
    """
    llm = get_llm(caller="explain_error")
    snippet = "\n".join(state.get("current_jrxml", "").split("\n")[:80])
    prompt = load_prompt("explain_error").format(
        error_msg=state.get("error_msg", "未知错误"),
        jrxml_snippet=snippet,
    )
    state["natural_explanation"] = llm.invoke(prompt).content.strip()
    return state


@log_node("correct_jrxml")
def correct_jrxml(state: AgentState) -> Dict:
    """Attempt an automatic fix of the JRXML that failed validation.

    Increments ``retry_count``. The pre-fix JRXML, the error and the prompt
    are saved in ``last_error_case`` so that ``validate`` can record a
    successful fix.
    """
    llm = get_llm(caller="correct_jrxml")
    prompt = load_prompt("correction").format(
        current_jrxml=state.get("current_jrxml", ""),
        error_msg=state.get("error_msg", ""),
        explanation=state.get("natural_explanation", ""),
    )
    state["last_error_case"] = {
        "error_msg": state.get("error_msg", ""),
        "bad_jrxml": state.get("current_jrxml", ""),
        "correction_prompt": prompt,
    }
    jrxml = _stream_jrxml(llm, prompt, "correct_jrxml")
    state["current_jrxml"] = jrxml
    state["retry_count"] = state.get("retry_count", 0) + 1
    state["conversation_history"].append(
        {"role": "assistant", "content": f"[自动修正,第 {state['retry_count']} 次尝试]\n{jrxml}"}
    )
    return state


def _record_version(state: AgentState, jrxml: str, status: str) -> None:
    """Append a labelled JRXML snapshot to ``jrxml_versions``.

    Versions produced through auto-correction (``retry_count > 0``) are
    labelled as such. Otherwise the label comes from ``_VERSION_LABELS``,
    falling back to the raw intent.
    """
    versions = state.get("jrxml_versions", [])
    if not isinstance(versions, list):
        versions = []
    intent = state.get("intent", "")
    retries = state.get("retry_count") or 0
    # ``intent`` never equals "correct_jrxml", so correction is detected via retry_count.
    label = f"自动修正 (第{retries}次)" if retries > 0 else _VERSION_LABELS.get(intent, intent)
    versions.append({
        "ts": _now_iso(),
        "jrxml": jrxml,
        "intent": intent,
        "label": label,
        "status": status,
    })
    state["jrxml_versions"] = versions


@log_node("finalize")
def finalize(state: AgentState) -> Dict:
    """Publish the validated JRXML, or record the failure for the next turn.

    On ``status == "pass"``: copies ``current_jrxml`` into ``final_jrxml`` and
    records a version when it is non-blank. Otherwise: leaves ``final_jrxml``
    untouched, stores ``pending_failure_context`` (auto-injected by
    process_input on the next message) and posts a failure message.
    """
    jrxml = state.get("current_jrxml", "")
    status = state.get("status", "")

    if status == "pass":
        state["final_jrxml"] = jrxml
        if jrxml.strip():
            _record_version(state, jrxml, status)
        return state

    retries = state.get("retry_count", 0)
    error_msg = state.get("error_msg") or "未知错误"
    state["pending_failure_context"] = {
        "error_msg": error_msg,
        "bad_jrxml": jrxml,
        "retry_count": retries,
        "ts": _now_iso(),
    }
    state["conversation_history"].append({
        "role": "assistant",
        "content": (
            f"❌ 经过 {retries} 次重试后仍无法生成有效的 JRXML。\n"
            f"错误: {error_msg}\n"
            f"请描述您想要的修改,系统会自动加载失败上下文继续修复。"
        ),
    })
    return state


def _extract_jrxml(text: str) -> str:
    """Extract the JRXML document from an LLM reply.

    Resolution order:
      1. the first non-empty Markdown code fence (xml / jrxml / untagged);
      2. an ``<?xml ... </jasperReport>`` span anywhere in the text;
      3. the text itself if it already starts with ``<?xml`` or ``<jasperReport``;
      4. the span from ``<?xml`` (or ``<jasperReport``) through the last
         ``</jasperReport>``.
    Otherwise the stripped text is returned unchanged.
    """
    text = text.strip()

    fence = _CODE_FENCE_RE.search(text)
    if fence:
        content = fence.group(1).strip()
        if content:
            return content

    # No fence, or an empty one: fall back to matching the XML document directly.
    document = _XML_DOCUMENT_RE.search(text)
    if document:
        return document.group(1).strip()

    if text.startswith(("<?xml", "<jasperReport")):
        return text

    start = text.find("<?xml")
    if start < 0:
        start = text.find("<jasperReport")
    end = text.rfind(_JR_CLOSE_TAG)
    if start >= 0 and end > start:
        return text[start:end + len(_JR_CLOSE_TAG)].strip()
    return text