"""LangGraph JRXML 生成工作流的节点函数。""" import copy import functools import json import os import re import time from datetime import datetime, timezone from typing import Dict from dotenv import load_dotenv from agent.state import AgentState from backend.llm import get_llm from backend.logger import get_logger, set_trace_id from backend.validation import validate_jrxml from prompts.loader import load_prompt load_dotenv() _node_log = get_logger("agent") MAX_RETRY = int(os.getenv("MAX_RETRY", "3")) CONTEXT_MAX_TOKENS = int(os.getenv("CONTEXT_MAX_TOKENS", "6000")) CONTEXT_KEEP_RECENT = int(os.getenv("CONTEXT_KEEP_RECENT", "4")) HISTORY_MAX_SNAPSHOTS = int(os.getenv("HISTORY_MAX_SNAPSHOTS", "10")) def _state_summary(state: AgentState) -> dict: """提取 state 中的关键字段用于日志摘要。""" user_input = state.get("user_input", "") return { "session_id": state.get("session_id", ""), "intent": state.get("intent", ""), "status": state.get("status", ""), "has_jrxml": bool(state.get("current_jrxml", "").strip()), "jrxml_length": len(state.get("current_jrxml", "")), "retry_count": state.get("retry_count", 0), "user_input_preview": user_input[:100] if user_input else "", "conversation_turns": len(state.get("conversation_history", [])), "history_snapshots": len(state.get("history_states", [])), "versions": len(state.get("jrxml_versions", [])), } def log_node(node_name: str): """装饰器:自动记录节点入口、出口和耗时。""" def decorator(func): @functools.wraps(func) def wrapper(state: AgentState, *args, **kwargs): t0 = time.time() _node_log.info( f"[节点入口] {node_name}", extra={"node": node_name, "phase": "entry", "state": _state_summary(state)}, ) try: result = func(state, *args, **kwargs) elapsed_ms = round((time.time() - t0) * 1000) _node_log.info( f"[节点出口] {node_name}", extra={ "node": node_name, "phase": "exit", "duration_ms": elapsed_ms, "state": _state_summary(state), }, ) return result except Exception as e: elapsed_ms = round((time.time() - t0) * 1000) _node_log.error( f"[节点异常] {node_name}: {e}", extra={ "node": 
node_name, "phase": "error", "duration_ms": elapsed_ms, "error": str(e), "state": _state_summary(state), }, ) raise return wrapper return decorator # ============================================================ # 核心工作流节点 # ============================================================ @log_node("process_input") def process_input(state: AgentState) -> Dict: """记录用户输入到对话历史,重置本轮请求状态。如有上次失败上下文则自动注入。""" user_input = state.get("user_input", "") # 维护全量对话历史(始终记录原始用户消息) full_history = state.get("full_conversation_history", []) full_history.append({"role": "user", "content": user_input, "ts": _now_iso()}) state["full_conversation_history"] = full_history # 自动注入上次失败上下文 pending = state.get("pending_failure_context", {}) if pending and pending.get("error_msg"): failure_note = ( f"[系统提示] 上次生成失败,以下是失败详情,请基于此修正:\n" f"失败原因: {pending['error_msg']}\n" f"上次失败的输出:\n{pending.get('bad_jrxml', '(无输出)')}" ) user_input = f"{failure_note}\n\n---\n用户新输入:\n{user_input}" state["pending_failure_context"] = {} # 维护工作对话历史 conv_history = state.get("conversation_history", []) conv_history.append({"role": "user", "content": user_input}) state["conversation_history"] = conv_history # 重置本轮请求字段 state["retry_count"] = 0 state["user_modification_request"] = user_input return state @log_node("save_state_snapshot") def save_state_snapshot(state: AgentState) -> Dict: """保存当前状态快照到 history_states,用于撤销操作。最多保留 N 个版本。""" snapshots = state.get("history_states", []) if not isinstance(snapshots, list): snapshots = [] snapshot = { "current_jrxml": state.get("current_jrxml", ""), "final_jrxml": state.get("final_jrxml", ""), "status": state.get("status", ""), "conversation_history": copy.deepcopy(state.get("conversation_history", [])), "user_modification_request": state.get("user_modification_request", ""), "intent": state.get("intent", ""), } snapshots.append(snapshot) max_snap = HISTORY_MAX_SNAPSHOTS if len(snapshots) > max_snap: snapshots = snapshots[-max_snap:] state["history_states"] = snapshots return state 
# Intents the classifier may return (matched against the lower-cased LLM reply).
_VALID_INTENTS = (
    "initial_generation",
    "modify_report",
    "preview_report",
    "export_pdf",
    "export_jrxml",
    "undo_modification",
    "consult_question",
    "reset_session",
)


@log_node("classify_intent")
def classify_intent(state: AgentState) -> Dict:
    """Classify the user's message into one of the supported intents using the LLM.

    The reply is lower-cased and searched for the names in ``_VALID_INTENTS``;
    the name that occurs earliest in the reply wins. When no name occurs, or
    the LLM call fails, the intent falls back to ``modify_report`` if a report
    already exists and to ``initial_generation`` otherwise.
    """
    user_input = state.get("user_input", "")
    has_report = "是" if (state.get("current_jrxml") or "").strip() else "否"
    fallback = "modify_report" if has_report == "是" else "initial_generation"

    intent = fallback
    try:
        llm = get_llm(caller="classify_intent")
        prompt = load_prompt("intent_classify").format(
            has_report=has_report,
            user_input=user_input[:500],
        )
        raw = llm.invoke(prompt).content.strip().lower()
        # Earliest mention wins: a reply like "modify_report (not initial_generation)"
        # must not resolve to whichever name happens to be listed first above.
        found = [(raw.find(name), name) for name in _VALID_INTENTS if name in raw]
        if found:
            intent = min(found)[1]
    except Exception as e:
        _node_log.warning(f"[classify_intent] LLM classification failed, using fallback '{fallback}': {e}")

    state["intent"] = intent
    return state


@log_node("handle_consult")
def handle_consult(state: AgentState) -> Dict:
    """Answer a consultation question directly with the LLM, bypassing report generation.

    The answer (or a fixed apology if the LLM call fails) is stored in
    ``consult_answer`` and appended as an assistant message to both histories.
    """
    user_input = state.get("user_input", "")
    try:
        llm = get_llm(caller="handle_consult")
        prompt = load_prompt("consult").format(question=user_input)
        answer = llm.invoke(prompt).content.strip()
    except Exception as e:
        _node_log.warning(f"[handle_consult] LLM call failed: {e}")
        answer = "抱歉,暂时无法处理您的问题,请稍后再试。"

    state["consult_answer"] = answer
    state["conversation_history"].append({"role": "assistant", "content": answer})
    state["full_conversation_history"].append(
        {"role": "assistant", "content": answer, "ts": _now_iso()}
    )
    return state


@log_node("handle_undo")
def handle_undo(state: AgentState) -> Dict:
    """Undo the last change by restoring the most recent snapshot in ``history_states``.

    The restored snapshot is removed from the stack. If there is nothing to
    restore, only a notice is appended to the working conversation history.
    """
    snapshots = state.get("history_states", [])
    if not isinstance(snapshots, list) or not snapshots:
        state["conversation_history"].append(
            {"role": "assistant", "content": "没有可撤销的操作。"}
        )
        return state

    prev = snapshots.pop()
    state["history_states"] = snapshots
    state["current_jrxml"] = prev.get("current_jrxml", "")
    state["final_jrxml"] = prev.get("final_jrxml", "")
    state["status"] = prev.get("status", "")
    state["conversation_history"] = prev.get("conversation_history", [])
    state["user_modification_request"] = prev.get("user_modification_request", "")

    state["conversation_history"].append(
        {"role": "assistant", "content": "已撤销上一步修改,恢复到之前的状态。"}
    )
    state["full_conversation_history"].append(
        {"role": "assistant", "content": "已撤销上一步修改。", "ts": _now_iso()}
    )
    return state


@log_node("handle_reset")
def handle_reset(state: AgentState) -> Dict:
    """Reset the current session: clear all report-related state, keep session identity.

    The working conversation history is replaced by a single assistant
    notice; the full history only gets a reset notice appended.
    """
    state["current_jrxml"] = ""
    state["final_jrxml"] = ""
    state["status"] = ""
    state["error_msg"] = ""
    state["natural_explanation"] = ""
    state["user_modification_request"] = ""
    state["retrieved_context"] = ""
    state["retry_count"] = 0
    state["compressed_history"] = ""
    state["history_states"] = []
    state["intent"] = "initial_generation"
    state["conversation_history"] = [
        {"role": "assistant", "content": "会话已重置,请描述您要创建的新报表。"}
    ]
    state["full_conversation_history"].append(
        {"role": "assistant", "content": "会话已重置。", "ts": _now_iso()}
    )
    return state


@log_node("count_tokens")
def count_tokens(state: AgentState) -> int:
    """Estimate the token size of the current context.

    The measured payload is the last ``CONTEXT_KEEP_RECENT`` messages of the
    working conversation history, the current JRXML and the compressed
    summary, serialized as JSON. tiktoken's gpt-4o encoding is used when it
    can be loaded; otherwise 1 token is approximated as 2.5 characters
    (a rough figure for mixed Chinese/English text).

    Returns:
        The token count as an ``int``.
    """
    text = json.dumps(
        {
            "history": state.get("conversation_history", [])[-CONTEXT_KEEP_RECENT:],
            "jrxml": state.get("current_jrxml", ""),
            "compressed": state.get("compressed_history", ""),
        },
        ensure_ascii=False,
    )
    try:
        import tiktoken

        enc = tiktoken.encoding_for_model("gpt-4o")
    except Exception:
        # floor(len / 2.5) computed in integers; `len // 2.5` returned a float.
        return len(text) * 2 // 5
    # disallowed_special=() counts special-token text (e.g. "<|endoftext|>") typed by
    # the user as ordinary text instead of raising ValueError.
    return len(enc.encode(text, disallowed_special=()))


@log_node("manage_context")
def manage_context(state: AgentState) -> Dict:
    """Compress older working-history messages once the context exceeds the token budget.

    When ``count_tokens`` exceeds ``CONTEXT_MAX_TOKENS``, every message of
    ``conversation_history`` except the most recent ``CONTEXT_KEEP_RECENT``
    is summarized (by the LLM, or by ``_simple_compress`` if that fails). The
    summary is appended to ``compressed_history`` and only the recent
    messages are kept. ``full_conversation_history`` is never modified.
    """
    token_count = count_tokens(state)
    state["current_token_count"] = token_count
    if token_count <= CONTEXT_MAX_TOKENS:
        return state

    # Compress the *working* history. Compressing full_conversation_history (which is
    # never truncated) re-summarized the same old turns on every over-budget request,
    # so compressed_history accumulated duplicate summaries.
    history = state.get("conversation_history", [])
    split = max(len(history) - CONTEXT_KEEP_RECENT, 0)
    older, recent = history[:split], history[split:]
    if not older:
        return state

    conv_text = json.dumps(older, ensure_ascii=False, indent=2)
    try:
        llm = get_llm(caller="manage_context")
        prompt = load_prompt("compression").format(conversation_text=conv_text)
        # Each new summary is capped at 300 characters.
        new_compressed = llm.invoke(prompt).content.strip()[:300]
    except Exception as e:
        _node_log.warning(f"[manage_context] LLM compression failed, using rule-based fallback: {e}")
        new_compressed = _simple_compress(older)

    # Merge with any earlier summary.
    existing = state.get("compressed_history", "")
    state["compressed_history"] = f"{existing}\n---\n{new_compressed}" if existing else new_compressed
    state["conversation_history"] = list(recent)
    state["current_token_count"] = count_tokens(state)
    return state


# Keys copied from a saved session into the state by load_session_node.
_RESTORED_SESSION_KEYS = (
    "conversation_history",
    "full_conversation_history",
    "current_jrxml",
    "final_jrxml",
    "compressed_history",
    "history_states",
    # finalize() leaves this for process_input() to inject into the *next* user
    # message; if state is rebuilt from disk per request (as this node suggests),
    # it would otherwise be lost between requests.
    "pending_failure_context",
)

# Keys written to disk by save_session_node.
_PERSISTED_SESSION_KEYS = (
    "conversation_history",
    "full_conversation_history",
    "current_jrxml",
    "final_jrxml",
    "compressed_history",
    "status",
    "error_msg",
    "history_states",
    "pending_failure_context",
)


@log_node("load_session_node")
def load_session_node(state: AgentState) -> Dict:
    """Load the persisted session state from disk at the start of a request.

    Only the keys in ``_RESTORED_SESSION_KEYS`` are copied, so request-scoped
    fields such as ``user_input`` are never overwritten. ``session_name`` and
    ``created_at`` are taken from the top level of the saved record. Loading
    is best-effort: an exception is logged instead of aborting the request.
    """
    session_id = state.get("session_id", "")
    if not session_id:
        return state
    try:
        from backend.session import load_session

        data = load_session(session_id)
        if data and data.get("agent_state"):
            saved = data["agent_state"]
            for key in _RESTORED_SESSION_KEYS:
                if key in saved:
                    state[key] = saved[key]
            state["session_name"] = data.get("session_name", "")
            state["created_at"] = data.get("created_at", "")
    except Exception as e:
        _node_log.warning(f"[load_session_node] failed to load session {session_id}: {e}")
    return state


@log_node("save_session_node")
def save_session_node(state: AgentState) -> Dict:
    """Persist the current agent state to disk.

    Writes the keys in ``_PERSISTED_SESSION_KEYS`` plus an ``updated_at``
    timestamp. If the session has no name yet, the first 50 characters of the
    first user message in the working history are used. Saving is
    best-effort: an exception is logged instead of aborting the request.
    """
    session_id = state.get("session_id", "")
    if not session_id:
        return state
    try:
        from backend.session import save_session

        persistable = {key: state[key] for key in _PERSISTED_SESSION_KEYS if key in state}
        persistable["updated_at"] = _now_iso()

        session_name = state.get("session_name", "")
        if not session_name and state.get("conversation_history"):
            first_user = next(
                (m["content"][:50] for m in state["conversation_history"] if m.get("role") == "user"),
                "",
            )
            if first_user:
                session_name = first_user

        save_session(session_id, persistable, session_name)
        if not state.get("session_name"):
            state["session_name"] = session_name
        state["updated_at"] = persistable["updated_at"]
    except Exception as e:
        _node_log.warning(f"[save_session_node] failed to save session {session_id}: {e}")
    return state


def _simple_compress(messages: list[dict]) -> str:
    """Rule-based summary used when the LLM is unavailable.

    Keeps the first 100 characters of each user message, at most the last 10
    of them, joined with "; ".
    """
    points = [f"用户提问:{m['content'][:100]}" for m in messages if m.get("role") == "user"]
    return "; ".join(points[-10:])


def _now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


def _chunk_text(chunk) -> str:
    """Return the text of one streamed LLM chunk.

    Plain strings are returned as-is; objects exposing a ``content`` attribute
    (such as chat-model message chunks) yield that content.
    """
    content = getattr(chunk, "content", chunk)
    return content if isinstance(content, str) else str(content)


def _stream_llm(llm, prompt: str, writer, node: str) -> str:
    """Stream *prompt* through *llm*, forward each piece to *writer*, and return the full text.

    Each piece is emitted as ``{"type": "stream", "node": node, "text": piece}``.
    """
    parts = []
    for chunk in llm.stream(prompt):
        piece = _chunk_text(chunk)
        parts.append(piece)
        writer({"type": "stream", "node": node, "text": piece})
    return "".join(parts)


@log_node("retrieve")
def retrieve(state: AgentState) -> Dict:
    """Search the template store (and, after an error, the error knowledge base) for context.

    Stores the combined text in ``retrieved_context``. Retrieval is
    best-effort: on failure the context is set to "" and the error is logged.
    """
    try:
        from backend.rag_adapter import search_chunks
        from backend.error_kb import search_error_cases

        user_input = state.get("user_input", "")
        context = search_chunks(user_input, k=5)

        # With a recent error, also pull similar past corrections.
        error_msg = state.get("error_msg", "")
        if error_msg:
            error_context = search_error_cases(error_msg, k=2)
            if error_context:
                context = f"{context}\n\n[历史错误修正案例]\n{error_context}"
        state["retrieved_context"] = context
    except Exception as e:
        _node_log.warning(f"[retrieve] retrieval failed, continuing without context: {e}")
        state["retrieved_context"] = ""
    return state


@log_node("generate")
def generate(state: AgentState) -> Dict:
    """Generate the initial JRXML from the user's request and the retrieved context.

    Output is streamed to the graph's stream writer while it is produced. The
    extracted JRXML becomes ``current_jrxml`` and is appended as an assistant
    message to both histories (matching ``modify_jrxml``).
    """
    from langgraph.config import get_stream_writer

    writer = get_stream_writer()
    llm = get_llm(caller="generate")
    prompt = load_prompt("initial_generation").format(
        context=state.get("retrieved_context", ""),
        user_request=state.get("user_input", ""),
    )
    jrxml = _extract_jrxml(_stream_llm(llm, prompt, writer, "generate"))

    state["current_jrxml"] = jrxml
    state["conversation_history"].append({"role": "assistant", "content": jrxml})
    state["full_conversation_history"] = list(state.get("full_conversation_history", [])) + [
        {"role": "assistant", "content": jrxml, "ts": _now_iso()},
    ]
    return state


@log_node("modify_jrxml")
def modify_jrxml(state: AgentState) -> Dict:
    """Modify the existing JRXML according to ``user_modification_request``.

    The prompt includes the compressed summary (if any) and the last 6
    working-history messages. Output is streamed to the graph's stream writer.
    The extracted JRXML becomes ``current_jrxml`` and is appended as an
    assistant message to both histories; ``retry_count`` is reset to 0.
    """
    from langgraph.config import get_stream_writer

    writer = get_stream_writer()
    llm = get_llm(caller="modify_jrxml")

    # Conversation context: compressed summary + recent messages.
    compressed = state.get("compressed_history", "")
    recent = state.get("conversation_history", [])[-6:]
    conv_parts = []
    if compressed:
        conv_parts.append(f"[早期对话摘要]\n{compressed}")
    conv_parts.append(json.dumps(recent, ensure_ascii=False, indent=2))
    conv_text = "\n\n---\n\n".join(conv_parts)

    prompt = load_prompt("modification").format(
        current_jrxml=state.get("current_jrxml", ""),
        conversation_history=conv_text,
        modification_request=state.get("user_modification_request", ""),
    )
    jrxml = _extract_jrxml(_stream_llm(llm, prompt, writer, "modify_jrxml"))

    state["current_jrxml"] = jrxml
    # The user turn is recorded by process_input at the start of the request
    # (assumed to run before this node); appending it again here duplicated it in
    # both histories, and put the failure-note-augmented text into the raw full history.
    state["conversation_history"].append({"role": "assistant", "content": jrxml})
    state["full_conversation_history"] = list(state.get("full_conversation_history", [])) + [
        {"role": "assistant", "content": jrxml, "ts": _now_iso()},
    ]
    state["retry_count"] = 0
    return state


@log_node("validate")
def validate(state: AgentState) -> Dict:
    """Validate ``current_jrxml`` via ``validate_jrxml`` and set ``status`` / ``error_msg``.

    Empty or very short content fails without calling the validator. When a
    JRXML passes after at least one auto-correction (``retry_count > 0``),
    the case saved in ``last_error_case`` by ``correct_jrxml`` is recorded in
    the error knowledge base; that write is best-effort and never changes the
    validation result.
    """
    jrxml = state.get("current_jrxml", "")
    if not jrxml:
        state["status"] = "fail"
        state["error_msg"] = "没有 JRXML 内容可供验证。"
        return state

    # Content this short cannot be a real report (a minimal skeleton is ~500+ characters).
    if len(jrxml.strip()) < 200:
        state["status"] = "fail"
        state["error_msg"] = f"JRXML 内容过短({len(jrxml.strip())} 字符),可能为不完整或空内容。"
        return state

    result = validate_jrxml(jrxml)
    valid = bool(result.get("valid"))
    state["status"] = "pass" if valid else "fail"
    state["error_msg"] = result.get("error", "")

    # A correction succeeded: store the (bad, good) pair in the error knowledge base.
    if valid and state.get("retry_count", 0) > 0:
        case = state.get("last_error_case", {})
        if case and case.get("error_msg"):
            try:
                from backend.error_kb import record_error

                recorded = record_error(
                    error_msg=case["error_msg"],
                    bad_jrxml=case.get("bad_jrxml", ""),
                    good_jrxml=jrxml,
                    correction_prompt=case.get("correction_prompt", ""),
                    retry_count=state.get("retry_count", 0),
                )
                if recorded:
                    state["conversation_history"].append({
                        "role": "system",
                        "content": f"[系统] 错误案例已记录到知识库(指纹: {case['error_msg'][:40]}...)",
                    })
            except Exception as e:
                # Knowledge-base writes must not affect the main flow; log and continue.
                _node_log.warning(f"[validate] failed to record error case: {e}")
    return state


@log_node("explain_error")
def explain_error(state: AgentState) -> Dict:
    """Produce a human-readable explanation of the validation error.

    Sends the error and the first 80 lines of the JRXML to the LLM. If the
    LLM call fails, the raw error message is used as the explanation so that
    the correction step can still run.
    """
    jrxml = state.get("current_jrxml", "")
    # Only the head of the document is sent to keep the prompt small.
    snippet = "\n".join(jrxml.split("\n")[:80])
    error_msg = state.get("error_msg", "未知错误")
    try:
        llm = get_llm(caller="explain_error")
        prompt = load_prompt("explain_error").format(
            error_msg=error_msg,
            jrxml_snippet=snippet,
        )
        explanation = llm.invoke(prompt).content.strip()
    except Exception as e:
        _node_log.warning(f"[explain_error] LLM call failed, using raw error message: {e}")
        explanation = error_msg
    state["natural_explanation"] = explanation
    return state


@log_node("correct_jrxml")
def correct_jrxml(state: AgentState) -> Dict:
    """Ask the LLM to fix the JRXML that failed validation (one correction attempt).

    Before calling the LLM, the failing JRXML, its error and the correction
    prompt are saved in ``last_error_case`` so ``validate`` can record the
    case if the corrected output passes. Output is streamed to the graph's
    stream writer; ``retry_count`` is incremented.
    """
    from langgraph.config import get_stream_writer

    writer = get_stream_writer()
    llm = get_llm(caller="correct_jrxml")
    prompt = load_prompt("correction").format(
        current_jrxml=state.get("current_jrxml", ""),
        error_msg=state.get("error_msg", ""),
        explanation=state.get("natural_explanation", ""),
    )

    # Pre-correction snapshot, consumed by validate() on success.
    state["last_error_case"] = {
        "error_msg": state.get("error_msg", ""),
        "bad_jrxml": state.get("current_jrxml", ""),
        "correction_prompt": prompt,
    }

    jrxml = _extract_jrxml(_stream_llm(llm, prompt, writer, "correct_jrxml"))
    state["current_jrxml"] = jrxml
    state["retry_count"] = state.get("retry_count", 0) + 1
    state["conversation_history"].append(
        {"role": "assistant", "content": f"[自动修正,第 {state['retry_count']} 次尝试]\n{jrxml}"}
    )
    return state


# Human-readable version labels keyed by intent (unknown intents use the intent name).
_VERSION_LABELS = {
    "initial_generation": "初始生成",
    "modify_report": "修改",
}


@log_node("finalize")
def finalize(state: AgentState) -> Dict:
    """Finish the request: publish a passing JRXML, or record the failure for the next turn.

    On ``status == "pass"`` the JRXML becomes ``final_jrxml`` and, if
    non-blank, a version entry is appended to ``jrxml_versions``. Otherwise
    ``final_jrxml`` keeps the last successful version, the failure details
    are stored in ``pending_failure_context`` (which ``process_input`` injects
    into the next user message), and an error notice is appended to the
    working conversation history.
    """
    jrxml = state.get("current_jrxml", "")
    status = state.get("status", "")

    if status == "pass":
        state["final_jrxml"] = jrxml
        if jrxml.strip():
            versions = state.get("jrxml_versions", [])
            if not isinstance(versions, list):
                versions = []
            intent = state.get("intent", "")
            label = _VERSION_LABELS.get(intent, intent)
            # classify_intent never yields a "correct_jrxml" intent, so auto-corrected
            # versions are recognised from retry_count instead.
            retries = state.get("retry_count", 0)
            if retries > 0:
                correction = f"自动修正 (第{retries}次)"
                label = f"{label} - {correction}" if label else correction
            versions.append({
                "ts": _now_iso(),
                "jrxml": jrxml,
                "intent": intent,
                "label": label,
                "status": status,
            })
            state["jrxml_versions"] = versions
        return state

    # Validation failed: keep the previous final_jrxml untouched.
    retries = state.get("retry_count", 0)
    error_msg = state.get("error_msg", "未知错误")
    state["pending_failure_context"] = {
        "error_msg": error_msg,
        "bad_jrxml": state.get("current_jrxml", ""),
        "retry_count": retries,
        "ts": _now_iso(),
    }
    state["conversation_history"].append({
        "role": "assistant",
        "content": (
            f"❌ 经过 {retries} 次重试后仍无法生成有效的 JRXML。\n"
            f"错误: {error_msg}\n"
            f"请描述您想要的修改,系统会自动加载失败上下文继续修复。"
        ),
    })
    return state


# A fenced markdown code block, optionally tagged xml/jrxml.
_FENCED_BLOCK_RE = re.compile(r"```(?:xml|jrxml)?\s*([\s\S]*?)```", re.IGNORECASE)
# A complete report: from the XML prolog (or the root element when the prolog is
# omitted) up to the first closing root tag.
_JASPER_REPORT_RE = re.compile(
    r"((?:<\?xml|<jasperReport\b)[\s\S]*?</jasperReport>)", re.IGNORECASE
)


def _extract_jrxml(text: str) -> str:
    """Extract the JRXML document from an LLM response.

    Resolution order:
      1. the contents of the first non-empty fenced markdown code block;
      2. the first span from ``<?xml`` (or ``<jasperReport`` when there is no
         prolog) through ``</jasperReport>``;
      3. otherwise the stripped text unchanged (e.g. a truncated report with
         no closing tag), leaving it to validation to report the problem.
    """
    text = text.strip()

    fenced = _FENCED_BLOCK_RE.search(text)
    if fenced:
        content = fenced.group(1).strip()
        if content:
            return content
        # Fenced block present but empty: fall back to matching the markup directly.

    report = _JASPER_REPORT_RE.search(text)
    if report:
        return report.group(1).strip()
    return text