6467fd4ae5
- OCR: EasyOCR (primary, ch_sim+en) with PaddleOCR fallback for Windows compatibility - Validation: _check_minimum_content() rejects empty-shell JRXML (no band/textField) - Retry: MAX_RETRY 3→5, exhaustion records pending_failure_context for next-turn auto-injection - Finalize: only saves jrxml_versions on pass, preserves last good final_jrxml on fail - Extract JRXML: improved empty markdown block handling and XML fragment fallback - UI: real-time node progress via placeholder updates, initial "analyzing" feedback - UI: use agent_state (full) instead of node_state (partial) for summary card routing - UI: unknown template_type now gives LLM meaningful image context instead of metadata - Docs: updated CLAUDE.md and CODE_GUIDE.md to reflect all v3 changes Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
571 lines
20 KiB
Python
571 lines
20 KiB
Python
"""LangGraph JRXML 生成工作流的节点函数。"""
|
||
|
||
import copy
|
||
import json
|
||
import os
|
||
import re
|
||
from datetime import datetime, timezone
|
||
from typing import Dict
|
||
|
||
from dotenv import load_dotenv
|
||
|
||
from agent.state import AgentState
|
||
from backend.llm import get_llm
|
||
from backend.validation import validate_jrxml
|
||
from prompts.loader import load_prompt
|
||
|
||
load_dotenv()


def _int_from_env(name: str, default: str) -> int:
    """Return environment variable *name* parsed as ``int`` (``default`` when unset)."""
    return int(os.getenv(name, default))


# Presumably the retry budget for the auto-correction loop; it is not read in
# this module. NOTE(review): the accompanying commit message describes
# MAX_RETRY as raised from 3 to 5, but the in-code default is still "3" —
# confirm which default is intended.
MAX_RETRY = _int_from_env("MAX_RETRY", "3")
# Token count above which manage_context() compresses older turns.
CONTEXT_MAX_TOKENS = _int_from_env("CONTEXT_MAX_TOKENS", "6000")
# Number of most recent messages kept verbatim by count_tokens()/manage_context().
CONTEXT_KEEP_RECENT = _int_from_env("CONTEXT_KEEP_RECENT", "4")
# Maximum number of undo snapshots kept by save_state_snapshot().
HISTORY_MAX_SNAPSHOTS = _int_from_env("HISTORY_MAX_SNAPSHOTS", "10")
|
||
|
||
|
||
# ============================================================
|
||
# 核心工作流节点
|
||
# ============================================================
|
||
|
||
def process_input(state: AgentState) -> Dict:
    """Record this turn's user message and reset per-request fields.

    The raw user message (with a UTC timestamp) is always appended to
    ``full_conversation_history``. If ``pending_failure_context`` carries an
    ``error_msg`` from a previously exhausted retry loop, a failure note is
    prepended to the text that goes into the working ``conversation_history``
    and ``user_modification_request``. The pending context is then cleared so
    it is injected exactly once.

    Args:
        state: The mutable agent state for the current request.

    Returns:
        The same ``state`` object, updated in place.
    """
    user_input = state.get("user_input", "")

    # The full history always stores the user's original, unmodified message.
    full_history = state.get("full_conversation_history", [])
    full_history.append({"role": "user", "content": user_input, "ts": _now_iso()})
    state["full_conversation_history"] = full_history

    # Inject the previous failure (if any) into this turn's working text.
    pending = state.get("pending_failure_context") or {}
    if pending.get("error_msg"):
        # finalize() always stores a "bad_jrxml" key, possibly as "", so a
        # .get() default would never apply; fall back on any falsy value.
        bad_jrxml = pending.get("bad_jrxml") or "(无输出)"
        failure_note = (
            f"[系统提示] 上次生成失败,以下是失败详情,请基于此修正:\n"
            f"失败原因: {pending['error_msg']}\n"
            f"上次失败的输出:\n{bad_jrxml}"
        )
        user_input = f"{failure_note}\n\n---\n用户新输入:\n{user_input}"
        state["pending_failure_context"] = {}

    # The working history gets the (possibly failure-annotated) text.
    conv_history = state.get("conversation_history", [])
    conv_history.append({"role": "user", "content": user_input})
    state["conversation_history"] = conv_history

    # Reset per-request fields.
    state["retry_count"] = 0
    state["user_modification_request"] = user_input

    return state
|
||
|
||
|
||
def save_state_snapshot(state: AgentState) -> Dict:
    """Push an undo snapshot of the report state onto ``history_states``.

    A snapshot holds ``current_jrxml``, ``final_jrxml``, ``status``, a deep
    copy of ``conversation_history``, ``user_modification_request`` and
    ``intent``. At most ``HISTORY_MAX_SNAPSHOTS`` snapshots are kept, oldest
    dropped first. A limit of 0 or less disables snapshots entirely.

    Args:
        state: The mutable agent state.

    Returns:
        The same ``state`` object with ``history_states`` updated.
    """
    max_snap = HISTORY_MAX_SNAPSHOTS
    if max_snap <= 0:
        # A plain `snapshots[-max_snap:]` slice would keep *everything* when
        # max_snap == 0 (`lst[-0:]` is the whole list), so handle it here.
        state["history_states"] = []
        return state

    snapshots = state.get("history_states", [])
    if not isinstance(snapshots, list):
        snapshots = []

    snapshots.append({
        "current_jrxml": state.get("current_jrxml", ""),
        "final_jrxml": state.get("final_jrxml", ""),
        "status": state.get("status", ""),
        # Deep copy: later turns mutate the live history list in place.
        "conversation_history": copy.deepcopy(state.get("conversation_history", [])),
        "user_modification_request": state.get("user_modification_request", ""),
        "intent": state.get("intent", ""),
    })

    if len(snapshots) > max_snap:
        snapshots = snapshots[-max_snap:]

    state["history_states"] = snapshots
    return state
|
||
|
||
|
||
def classify_intent(state: AgentState) -> Dict:
    """Classify the user's message into one of 8 intents using the LLM.

    The prompt receives whether a report currently exists ("是"/"否") and the
    first 500 characters of ``user_input``. The model's reply is lower-cased
    and matched against the known intent labels:

    * an exact match wins;
    * otherwise the label mentioned *earliest* in the reply wins, so a reply
      like "modify_report (not initial_generation)" is not decided by the
      order of the label list;
    * otherwise, or if the LLM call fails, fall back to ``modify_report``
      when a report exists and ``initial_generation`` when it does not.

    Args:
        state: The mutable agent state.

    Returns:
        The same ``state`` object with ``intent`` set.
    """
    user_input = state.get("user_input", "")
    has_report = "是" if state.get("current_jrxml", "").strip() else "否"
    fallback = "modify_report" if has_report == "是" else "initial_generation"

    valid_intents = (
        "initial_generation", "modify_report", "preview_report",
        "export_pdf", "export_jrxml", "undo_modification",
        "consult_question", "reset_session",
    )

    try:
        llm = get_llm()
        prompt = load_prompt("intent_classify").format(
            has_report=has_report,
            user_input=user_input[:500],
        )
        raw = llm.invoke(prompt).content.strip().lower()
    except Exception:
        # Classification is best-effort; never block the request on it.
        state["intent"] = fallback
        return state

    if raw in valid_intents:
        intent = raw
    else:
        mentions = [(raw.find(name), name) for name in valid_intents if name in raw]
        intent = min(mentions)[1] if mentions else fallback

    state["intent"] = intent
    return state
|
||
|
||
|
||
def handle_consult(state: AgentState) -> Dict:
    """Answer a consultation question directly, without generating a report.

    The LLM is asked with the ``consult`` prompt. Any failure while calling it
    yields a fixed apology instead. The answer is stored in ``consult_answer``
    and appended to both the working and the timestamped full history.

    Args:
        state: The mutable agent state.

    Returns:
        The same ``state`` object.
    """
    question = state.get("user_input", "")
    try:
        model = get_llm()
        consult_prompt = load_prompt("consult").format(question=question)
        reply = model.invoke(consult_prompt).content.strip()
    except Exception:
        reply = "抱歉,暂时无法处理您的问题,请稍后再试。"

    state["consult_answer"] = reply
    state["conversation_history"].append({"role": "assistant", "content": reply})
    state["full_conversation_history"].append(
        {"role": "assistant", "content": reply, "ts": _now_iso()}
    )
    return state
|
||
|
||
|
||
def handle_undo(state: AgentState) -> Dict:
    """Undo the last change by restoring the newest snapshot.

    Pops the last entry of ``history_states`` and restores ``current_jrxml``,
    ``final_jrxml``, ``status``, ``conversation_history`` and
    ``user_modification_request`` from it. When there is nothing to undo, a
    notice is appended to both the working history and the timestamped full
    history, the same way the other handlers record their replies.

    Args:
        state: The mutable agent state.

    Returns:
        The same ``state`` object.
    """
    snapshots = state.get("history_states", [])
    if not isinstance(snapshots, list) or not snapshots:
        notice = "没有可撤销的操作。"
        state["conversation_history"].append({"role": "assistant", "content": notice})
        # Keep the audit trail complete: log the no-op in the full history too.
        state.setdefault("full_conversation_history", []).append(
            {"role": "assistant", "content": notice, "ts": _now_iso()}
        )
        return state

    prev = snapshots.pop()
    state["history_states"] = snapshots

    state["current_jrxml"] = prev.get("current_jrxml", "")
    state["final_jrxml"] = prev.get("final_jrxml", "")
    state["status"] = prev.get("status", "")
    state["conversation_history"] = prev.get("conversation_history", [])
    state["user_modification_request"] = prev.get("user_modification_request", "")

    state["conversation_history"].append(
        {"role": "assistant", "content": "已撤销上一步修改,恢复到之前的状态。"}
    )
    state["full_conversation_history"].append(
        {"role": "assistant", "content": "已撤销上一步修改。", "ts": _now_iso()}
    )
    return state
|
||
|
||
|
||
def handle_reset(state: AgentState) -> Dict:
    """Reset the report-related state of the session, keeping session info.

    Clears the JRXML, status/error fields, retrieval and compression context,
    undo snapshots, and any pending failure context. Sets ``intent`` to
    ``"initial_generation"`` and replaces the working history with a single
    reset notice. The full conversation history is kept and the reset is
    logged to it.

    Args:
        state: The mutable agent state.

    Returns:
        The same ``state`` object.
    """
    state["current_jrxml"] = ""
    state["final_jrxml"] = ""
    state["status"] = ""
    state["error_msg"] = ""
    state["natural_explanation"] = ""
    state["user_modification_request"] = ""
    state["retrieved_context"] = ""
    state["retry_count"] = 0
    state["compressed_history"] = ""
    state["history_states"] = []
    state["intent"] = "initial_generation"
    # Without this, process_input() would inject the *previous* report's
    # failure details into the first message of the fresh session.
    state["pending_failure_context"] = {}
    # Drop the stale correction case that belonged to the old report.
    state["last_error_case"] = {}

    state["conversation_history"] = [
        {"role": "assistant", "content": "会话已重置,请描述您要创建的新报表。"}
    ]
    state["full_conversation_history"].append(
        {"role": "assistant", "content": "会话已重置。", "ts": _now_iso()}
    )
    return state
|
||
|
||
|
||
def count_tokens(state: AgentState) -> int:
    """Estimate the token size of the context sent to the LLM.

    The measured text is a JSON dump of the last ``CONTEXT_KEEP_RECENT``
    working-history messages, ``current_jrxml`` and ``compressed_history``.
    The count uses tiktoken's gpt-4o encoder when available. Otherwise it is
    approximated as ``len(text) / 2.5``, a rough ratio for mixed
    Chinese/English text.

    Args:
        state: The agent state to measure.

    Returns:
        The token count as an ``int`` (also on the fallback path).
    """
    text = json.dumps({
        "history": state.get("conversation_history", [])[-CONTEXT_KEEP_RECENT:],
        "jrxml": state.get("current_jrxml", ""),
        "compressed": state.get("compressed_history", ""),
    }, ensure_ascii=False)

    try:
        import tiktoken

        enc = tiktoken.encoding_for_model("gpt-4o")
    except Exception:
        # `//` with a float divisor yields a float; coerce to honour `-> int`.
        return int(len(text) / 2.5)

    # User-supplied text may contain special-token strings such as
    # "<|endoftext|>"; by default tiktoken raises on those. Count them as text.
    return len(enc.encode(text, disallowed_special=()))
|
||
|
||
|
||
def manage_context(state: AgentState) -> Dict:
    """Compress older conversation turns once the context exceeds the budget.

    The context is measured with count_tokens() and stored in
    ``current_token_count``. Compression happens only when that count exceeds
    ``CONTEXT_MAX_TOKENS`` and ``full_conversation_history`` has more than
    ``CONTEXT_KEEP_RECENT`` messages. Then everything before the last
    ``CONTEXT_KEEP_RECENT`` messages is summarized into
    ``compressed_history``: by the LLM, truncated to 300 characters, or by
    _simple_compress() if the LLM call fails. The working
    ``conversation_history`` is replaced by those recent messages.

    Args:
        state: The mutable agent state.

    Returns:
        The same ``state`` object.
    """
    token_count = count_tokens(state)
    state["current_token_count"] = token_count

    if token_count <= CONTEXT_MAX_TOKENS:
        return state

    full_history = state.get("full_conversation_history", [])
    if len(full_history) <= CONTEXT_KEEP_RECENT:
        return state

    # Keep the newest N messages verbatim; everything earlier is summarized.
    recent = full_history[-CONTEXT_KEEP_RECENT:]
    older = full_history[:-CONTEXT_KEEP_RECENT]
    if not older:
        return state

    conv_text = json.dumps(older, ensure_ascii=False, indent=2)
    try:
        llm = get_llm()
        prompt = load_prompt("compression").format(conversation_text=conv_text)
        new_compressed = llm.invoke(prompt).content.strip()[:300]
    except Exception:
        new_compressed = _simple_compress(older)

    # `older` always starts at the first message of full_conversation_history
    # (nothing in this module trims that list), so the new summary already
    # covers every turn an earlier summary covered. Appending it to the old
    # summary would re-summarize the same early turns on every call and grow
    # compressed_history without bound. An empty summary keeps the old one.
    # NOTE(review): revisit if full_conversation_history is ever trimmed
    # elsewhere.
    if new_compressed:
        state["compressed_history"] = new_compressed

    state["conversation_history"] = list(recent)
    state["current_token_count"] = count_tokens(state)
    return state
|
||
|
||
|
||
def load_session_node(state: AgentState) -> Dict:
    """Restore persisted session fields from disk at the start of a request.

    Does nothing without a ``session_id``. Otherwise the session is loaded
    via ``backend.session.load_session``. Each field in the restorable list
    below that is present in its ``agent_state`` is copied into ``state``.
    Request-specific fields such as ``user_input`` are not in the list and
    are never overwritten. ``session_name`` and ``created_at`` are then taken
    from the session record itself.

    Loading is best-effort: any error leaves ``state`` as it was, or
    partially restored if the error happened mid-copy.

    Args:
        state: The mutable agent state.

    Returns:
        The same ``state`` object.
    """
    session_id = state.get("session_id", "")
    if not session_id:
        return state

    restorable = (
        "conversation_history", "full_conversation_history",
        "current_jrxml", "final_jrxml", "compressed_history",
        "session_name", "created_at", "history_states",
        # Set on the in-memory state by finalize(); restored when the saved
        # record contains them so next-turn failure injection and the version
        # list can survive a reload from disk.
        "pending_failure_context", "jrxml_versions",
    )

    try:
        from backend.session import load_session

        data = load_session(session_id)
        if data and data.get("agent_state"):
            saved = data["agent_state"]
            for key in restorable:
                if key in saved:
                    state[key] = saved[key]
            state["session_name"] = data.get("session_name", "")
            state["created_at"] = data.get("created_at", "")
    except Exception:
        # Best-effort: a missing or unreadable session must not abort the request.
        pass
    return state
|
||
|
||
|
||
def save_session_node(state: AgentState) -> Dict:
    """Persist the current agent state to disk via ``backend.session``.

    Does nothing without a ``session_id``. Saves the report/history fields
    listed below plus an ``updated_at`` timestamp. If the session has no name
    yet, the first 50 characters of the first user message become its name.
    That message is taken from the full history when available, since the
    working history may hold a failure-annotated or already-trimmed version.

    Saving is best-effort: errors are swallowed so persistence problems never
    fail the request.

    Args:
        state: The mutable agent state.

    Returns:
        The same ``state`` object (``session_name``/``updated_at`` set on success).
    """
    session_id = state.get("session_id", "")
    if not session_id:
        return state

    persisted_keys = (
        "conversation_history", "full_conversation_history",
        "current_jrxml", "final_jrxml", "compressed_history",
        "status", "error_msg", "history_states",
        # Needed so next-turn failure injection and the version list survive
        # a reload from disk.
        "pending_failure_context", "jrxml_versions",
    )

    try:
        from backend.session import save_session

        persistable = {key: state[key] for key in persisted_keys if key in state}
        persistable["updated_at"] = _now_iso()

        session_name = state.get("session_name", "")
        if not session_name:
            for history_key in ("full_conversation_history", "conversation_history"):
                first_user = next(
                    (str(m.get("content", ""))[:50]
                     for m in (state.get(history_key) or [])
                     if isinstance(m, dict) and m.get("role") == "user"),
                    "",
                )
                if first_user:
                    session_name = first_user
                    break

        save_session(session_id, persistable, session_name)
        if not state.get("session_name"):
            state["session_name"] = session_name
        state["updated_at"] = persistable["updated_at"]
    except Exception:
        # Best-effort persistence; never fail the request on a save error.
        pass
    return state
|
||
|
||
|
||
def _simple_compress(messages: list[dict]) -> str:
|
||
"""当 LLM 不可用时,基于简单规则的压缩回退方案。"""
|
||
points = []
|
||
for m in messages:
|
||
if m.get("role") == "user":
|
||
points.append(f"用户提问:{m['content'][:100]}")
|
||
return "; ".join(points[-10:])
|
||
|
||
|
||
def _now_iso() -> str:
|
||
return datetime.now(timezone.utc).isoformat()
|
||
|
||
|
||
def retrieve(state: AgentState) -> Dict:
    """Look up relevant JRXML templates/components and past error fixes.

    Queries ``search_chunks`` with ``user_input`` (k=5). When the state still
    carries an ``error_msg``, the error knowledge base is also searched (k=2),
    and any hits are appended under a "[历史错误修正案例]" header. Any failure
    leaves ``retrieved_context`` empty.

    Args:
        state: The mutable agent state.

    Returns:
        The same ``state`` object with ``retrieved_context`` set.
    """
    try:
        from backend.rag_adapter import search_chunks
        from backend.error_kb import search_error_cases

        query = state.get("user_input", "")
        context = search_chunks(query, k=5)

        last_error = state.get("error_msg", "")
        error_context = search_error_cases(last_error, k=2) if last_error else ""
        if error_context:
            context = f"{context}\n\n[历史错误修正案例]\n{error_context}"

        state["retrieved_context"] = context
    except Exception:
        state["retrieved_context"] = ""
    return state
|
||
|
||
|
||
def generate(state: AgentState) -> Dict:
    """Generate a first-draft JRXML from the user's request and retrieved context.

    Model output is streamed chunk by chunk to the LangGraph stream writer
    (node tag ``"generate"``). The JRXML is extracted from the assembled text,
    stored in ``current_jrxml``, and appended to the working history as the
    assistant reply.

    Args:
        state: The mutable agent state.

    Returns:
        The same ``state`` object.
    """
    from langgraph.config import get_stream_writer

    emit = get_stream_writer()
    model = get_llm()
    generation_prompt = load_prompt("initial_generation").format(
        context=state.get("retrieved_context", ""),
        user_request=state.get("user_input", ""),
    )

    pieces = []
    for piece in model.stream(generation_prompt):
        pieces.append(piece)
        emit({"type": "stream", "node": "generate", "text": piece})

    report = _extract_jrxml("".join(pieces))
    state["current_jrxml"] = report
    state["conversation_history"].append({"role": "assistant", "content": report})
    return state
|
||
|
||
|
||
def _latest_user_content(history: list) -> str:
    """Return the ``content`` of the newest ``role == "user"`` entry, or ``""``."""
    for message in reversed(history):
        if isinstance(message, dict) and message.get("role") == "user":
            return message.get("content", "")
    return ""


def modify_jrxml(state: AgentState) -> Dict:
    """Apply the user's modification request to the current JRXML.

    The prompt is built from the current JRXML, the compressed summary (if
    any) plus the last 6 working-history messages, and
    ``user_modification_request``. Model output is streamed to the LangGraph
    stream writer (node tag ``"modify_jrxml"``). The extracted JRXML is stored
    in ``current_jrxml``, the exchange is logged to both histories, and
    ``retry_count`` is reset to 0.

    Args:
        state: The mutable agent state.

    Returns:
        The same ``state`` object.
    """
    from langgraph.config import get_stream_writer

    writer = get_stream_writer()
    llm = get_llm()
    request = state.get("user_modification_request", "")

    # Conversation context: compressed summary + most recent messages.
    compressed = state.get("compressed_history", "")
    recent = state.get("conversation_history", [])[-6:]
    conv_parts = []
    if compressed:
        conv_parts.append(f"[早期对话摘要]\n{compressed}")
    conv_parts.append(json.dumps(recent, ensure_ascii=False, indent=2))
    conv_text = "\n\n---\n\n".join(conv_parts)

    prompt = load_prompt("modification").format(
        current_jrxml=state.get("current_jrxml", ""),
        conversation_history=conv_text,
        modification_request=request,
    )

    full = []
    for chunk in llm.stream(prompt):
        full.append(chunk)
        writer({"type": "stream", "node": "modify_jrxml", "text": chunk})
    jrxml = _extract_jrxml("".join(full))
    state["current_jrxml"] = jrxml

    # process_input() already logs this turn's user message: the raw text in
    # the full history, and possibly failure-annotated text in the working
    # history. Add it here only when it is not already the newest user entry,
    # so the same request is not recorded twice.
    this_turn = {text for text in (request, state.get("user_input", "")) if text}

    conv_history = state["conversation_history"]
    if _latest_user_content(conv_history) not in this_turn:
        conv_history.append({"role": "user", "content": request})
    conv_history.append({"role": "assistant", "content": jrxml})

    full_history = list(state.get("full_conversation_history", []))
    if _latest_user_content(full_history) not in this_turn:
        full_history.append({"role": "user", "content": request, "ts": _now_iso()})
    full_history.append({"role": "assistant", "content": jrxml, "ts": _now_iso()})
    state["full_conversation_history"] = full_history

    state["retry_count"] = 0
    return state
|
||
|
||
|
||
def validate(state: AgentState) -> Dict:
    """Validate ``current_jrxml`` with the validation service.

    Empty input and input shorter than 200 characters after stripping are
    rejected locally without calling the service. Otherwise ``status``
    becomes ``"pass"``/``"fail"`` from the service's ``valid`` flag and
    ``error_msg`` takes its ``error`` text.

    When a JRXML passes after at least one automatic correction
    (``retry_count`` > 0), the correction case saved by correct_jrxml() in
    ``last_error_case`` is recorded to the error knowledge base. A successful
    write adds a system note to the working history. Knowledge-base failures
    are ignored.

    Args:
        state: The mutable agent state.

    Returns:
        The same ``state`` object.
    """
    candidate = state.get("current_jrxml", "")

    # Cheap local rejection before calling the service.
    failure = None
    if not candidate:
        failure = "没有 JRXML 内容可供验证。"
    else:
        trimmed_len = len(candidate.strip())
        # The smallest real report skeleton is well over this length.
        if trimmed_len < 200:
            failure = f"JRXML 内容过短({trimmed_len} 字符),可能为不完整或空内容。"
    if failure is not None:
        state["status"] = "fail"
        state["error_msg"] = failure
        return state

    outcome = validate_jrxml(candidate)
    is_valid = bool(outcome.get("valid"))
    state["status"] = "pass" if is_valid else "fail"
    state["error_msg"] = outcome.get("error", "")

    # Only a fix that passed after a correction is worth recording.
    if not is_valid or state.get("retry_count", 0) <= 0:
        return state

    case = state.get("last_error_case", {})
    if not case or not case.get("error_msg"):
        return state

    try:
        from backend.error_kb import record_error

        stored = record_error(
            error_msg=case["error_msg"],
            bad_jrxml=case.get("bad_jrxml", ""),
            good_jrxml=candidate,
            correction_prompt=case.get("correction_prompt", ""),
            retry_count=state.get("retry_count", 0),
        )
        if stored:
            state["conversation_history"].append({
                "role": "system",
                "content": f"[系统] 错误案例已记录到知识库(指纹: {case['error_msg'][:40]}...)",
            })
    except Exception:
        pass  # Knowledge-base writes must not affect the main flow.

    return state
|
||
|
||
|
||
def explain_error(state: AgentState) -> Dict:
    """Produce a human-readable explanation of the latest validation error.

    Sends ``error_msg`` ("未知错误" when the key is missing) and the first 80
    lines of ``current_jrxml`` to the LLM. The stripped reply is stored in
    ``natural_explanation``.

    Args:
        state: The mutable agent state.

    Returns:
        The same ``state`` object.
    """
    model = get_llm()
    head_lines = state.get("current_jrxml", "").split("\n")[:80]

    explain_prompt = load_prompt("explain_error").format(
        error_msg=state.get("error_msg", "未知错误"),
        jrxml_snippet="\n".join(head_lines),
    )

    state["natural_explanation"] = model.invoke(explain_prompt).content.strip()
    return state
|
||
|
||
|
||
def correct_jrxml(state: AgentState) -> Dict:
    """Ask the LLM to repair the JRXML that just failed validation.

    First the pre-correction inputs (error message, failing JRXML, correction
    prompt) are saved in ``last_error_case``, so validate() can write a
    knowledge-base entry if the fix passes. Model output is then streamed to
    the LangGraph stream writer (node tag ``"correct_jrxml"``). Finally the
    extracted JRXML is stored in ``current_jrxml``, ``retry_count`` is
    incremented, and the attempt is logged to the working history.

    Args:
        state: The mutable agent state.

    Returns:
        The same ``state`` object.
    """
    from langgraph.config import get_stream_writer

    emit = get_stream_writer()
    model = get_llm()

    failing_jrxml = state.get("current_jrxml", "")
    failure_reason = state.get("error_msg", "")
    fix_prompt = load_prompt("correction").format(
        current_jrxml=failing_jrxml,
        error_msg=failure_reason,
        explanation=state.get("natural_explanation", ""),
    )

    state["last_error_case"] = {
        "error_msg": failure_reason,
        "bad_jrxml": failing_jrxml,
        "correction_prompt": fix_prompt,
    }

    pieces = []
    for piece in model.stream(fix_prompt):
        pieces.append(piece)
        emit({"type": "stream", "node": "correct_jrxml", "text": piece})

    repaired = _extract_jrxml("".join(pieces))
    state["current_jrxml"] = repaired
    attempt = state.get("retry_count", 0) + 1
    state["retry_count"] = attempt
    state["conversation_history"].append(
        {"role": "assistant", "content": f"[自动修正,第 {attempt} 次尝试]\n{repaired}"}
    )
    return state
|
||
|
||
|
||
def finalize(state: AgentState) -> Dict:
    """Finish the turn: record a passing JRXML, or record the failure.

    When ``status == "pass"``:
        ``current_jrxml`` is copied into ``final_jrxml``. If it is non-blank,
        an entry (timestamp, JRXML, intent, readable label, status) is
        appended to ``jrxml_versions``. If the JRXML only passed after
        ``retry_count`` > 0 automatic corrections, the label says so.

    Otherwise:
        ``final_jrxml`` is left untouched, so the last good version survives.
        ``pending_failure_context`` is stored so process_input() can inject it
        into the next user turn, and a failure message is appended to the
        working history.

    Args:
        state: The mutable agent state.

    Returns:
        The same ``state`` object.
    """
    jrxml = state.get("current_jrxml", "")
    status = state.get("status", "")
    retries = state.get("retry_count", 0)

    if status == "pass":
        state["final_jrxml"] = jrxml
        if jrxml.strip():
            versions = state.get("jrxml_versions", [])
            if not isinstance(versions, list):
                versions = []
            intent = state.get("intent", "")
            # In this module `intent` only ever holds classify_intent() labels
            # (never "correct_jrxml"), so automatic corrections are detected
            # through retry_count instead of the intent.
            label_map = {
                "initial_generation": "初始生成",
                "modify_report": "修改",
            }
            label = label_map.get(intent, intent)
            if retries:
                correction = f"自动修正 (第{retries}次)"
                label = f"{label} + {correction}" if label else correction
            versions.append({
                "ts": _now_iso(),
                "jrxml": jrxml,
                "intent": intent,
                "label": label,
                "status": status,
            })
            state["jrxml_versions"] = versions
    else:
        # Validation failed: keep the previous successful final_jrxml.
        # `or` also covers an empty error string from the validator.
        error_msg = state.get("error_msg") or "未知错误"
        # Remember the failure so the next user message can carry it along.
        state["pending_failure_context"] = {
            "error_msg": error_msg,
            "bad_jrxml": jrxml,
            "retry_count": retries,
            "ts": _now_iso(),
        }
        state["conversation_history"].append({
            "role": "assistant",
            "content": (
                f"❌ 经过 {retries} 次重试后仍无法生成有效的 JRXML。\n"
                f"错误: {error_msg}\n"
                f"请描述您想要的修改,系统会自动加载失败上下文继续修复。"
            ),
        })
    return state
|
||
|
||
|
||
def _extract_jrxml(text: str) -> str:
|
||
"""从 LLM 响应中提取 JRXML 内容,如有 markdown 标记则去除。"""
|
||
text = text.strip()
|
||
xml_pattern = re.compile(r"```(?:xml|jrxml)?\s*([\s\S]*?)```", re.IGNORECASE)
|
||
m = xml_pattern.search(text)
|
||
if m:
|
||
content = m.group(1).strip()
|
||
if content:
|
||
return content
|
||
# markdown 代码块存在但内容为空 — 回退到直接匹配
|
||
|
||
jasper_tag = re.search(r"(<\?xml[\s\S]*?</jasperReport>)", text, re.IGNORECASE)
|
||
if jasper_tag:
|
||
return jasper_tag.group(1).strip()
|
||
|
||
if text.startswith("<?xml") or text.startswith("<jasperReport"):
|
||
return text
|
||
|
||
# 最终回退:如果文本中包含 XML 片段但没有被捕获到,尝试直接提取
|
||
# 这处理 LLM 在代码块外用自然语言"包裹"JRXML 的情况
|
||
xml_start = text.find("<?xml")
|
||
jr_end = text.lower().rfind("</jasperreport>")
|
||
if xml_start >= 0 and jr_end > xml_start:
|
||
return text[xml_start:jr_end + len("</jasperreport>")].strip()
|
||
|
||
return text
|