Files
agent_jrxml/agent/nodes.py
T
panda 4b43c5d3e4 feat: LangGraph工作流核心 — Agent状态/节点/图 + 验证服务 + 知识库
agent/
  state.py: AgentState TypedDict(20字段含意图/压缩/会话/撤销)
  nodes.py: 17个节点函数(生成/修改/验证/纠错/意图分类/压缩/撤销/重置)
  graph.py: 17节点状态图,8意图路由分发

验证服务 validation_service/
  main.py: FastAPI服务,lxml XSD验证 + 结构化检查(字段引用/SQL/尺寸)

数据 data/
  sample_templates/: 4个JRXML示例模板
  corrections/: 3个错误修正案例

脚本 scripts/
  init_kb.py: Chroma知识库初始化
2026-05-14 23:21:10 +08:00

572 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""LangGraph JRXML 生成工作流的节点函数。"""
import copy
import json
import logging
import os
import re
from datetime import datetime, timezone
from typing import Dict

from dotenv import load_dotenv

from agent.state import AgentState
from backend.embeddings import get_embeddings
from backend.llm import get_llm
from backend.validation import validate_jrxml
# Load settings from a .env file (python-dotenv) before the values below are read.
load_dotenv()
# Retry limit, read from MAX_RETRY (default 3). It is not referenced in the
# visible part of this module -- presumably consumed by the graph's correction
# loop; TODO confirm.
MAX_RETRY = int(os.getenv("MAX_RETRY", "3"))
# Token threshold above which manage_context() compresses older conversation turns.
CONTEXT_MAX_TOKENS = int(os.getenv("CONTEXT_MAX_TOKENS", "6000"))
# Number of most recent history entries kept verbatim by count_tokens() / manage_context().
CONTEXT_KEEP_RECENT = int(os.getenv("CONTEXT_KEEP_RECENT", "4"))
# Maximum number of undo snapshots retained by save_state_snapshot().
HISTORY_MAX_SNAPSHOTS = int(os.getenv("HISTORY_MAX_SNAPSHOTS", "10"))
# ============================================================
# Intent-classification prompt (about 180 tokens; keep it under 200 tokens)
# ============================================================
# Template placeholders: {has_report} and {user_input}; filled in by classify_intent().
# The model is asked to answer with exactly one of the eight intent names listed in the text.
INTENT_CLASSIFY_PROMPT = """你是意图分类器。根据用户输入判断意图,只输出意图名称。
当前有报表:{has_report}
用户输入:{user_input}
可选意图:
- initial_generation(新建报表,或无报表时的任何需求)
- modify_report(修改当前已有报表)
- preview_report(预览/查看当前报表)
- export_pdf(导出PDF文件)
- export_jrxml(下载/导出/保存JRXML文件)
- undo_modification(撤销/回退上一步修改)
- consult_question(咨询JasperReports相关知识或使用问题)
- reset_session(清空/重置/重新开始)
意图名称:"""
# ============================================================
# Consultation-answer prompt
# ============================================================
# Template placeholder: {question}; filled in by handle_consult().
CONSULT_PROMPT = """你是 JasperReports 专家。用简洁清晰的中文回答用户关于 JasperReports 的问题。
用户问题:{question}
直接回答:"""
# ============================================================
# Pre-existing prompts (unchanged)
# ============================================================
# Template placeholders: {context} (retrieved templates/components) and
# {user_request}; filled in by generate().
INITIAL_GENERATION_PROMPT = """你是一位资深 JasperReports 工程师。根据以下参考模板和用户需求,生成一个完整、可编译的 JRXML 文件。
JRXML 必须兼容 JasperReports 7.0.6 schema。
关键规则:
- 只输出 JRXML 代码,不要解释,不要 markdown 标记。
- 报表正文中使用的每个字段必须在 <field name="..."> 部分中声明。
- 根元素为 <jasperReport>,包含正确的 xmlns 属性。
- 包含 <queryString>,在 <![CDATA[...]]> 中包含 SQL 查询。
- 确保所有交叉引用(字段名称、band 元素)保持一致。
参考模板和组件:
{context}
用户需求:
{user_request}
"""
# Prompt used by modify_jrxml() to apply a requested change to an existing report.
# Template placeholders: {current_jrxml}, {conversation_history}, {modification_request}.
MODIFICATION_PROMPT = """你是一位资深 JasperReports 工程师。用户想要修改一个现有的、可编译的 JRXML 报表。精确应用请求的更改到当前 JRXML 并输出完整修改后的 JRXML。
关键规则:
- 只输出完整修改后的 JRXML 代码,不要解释,不要 markdown 标记。
- 保留所有未被更改的现有结构。
- 结果必须继续与 JasperReports 7.0.6 兼容。
- 报表正文中使用的每个字段必须在 <field> 部分中声明。
- 如果添加新字段,正确声明它们。
- 确保 <queryString> 是 <![CDATA[...]]> 中有效的 SQL。
当前 JRXML
{current_jrxml}
对话历史:
{conversation_history}
用户的修改请求:
{modification_request}
"""
# Prompt used by correct_jrxml() to repair a report that failed validation.
# Template placeholders: {current_jrxml}, {error_msg}, {explanation}.
CORRECTION_PROMPT = """你是一位资深 JasperReports 工程师。你生成的 JRXML 文件编译失败。分析错误并修复 JRXML。
关键规则:
- 只输出完整修复后的 JRXML 代码,不要解释,不要 markdown 标记。
- JRXML 必须与 JasperReports 7.0.6 兼容。
- 解决下面列出的特定错误。
当前 JRXML(带错误):
{current_jrxml}
编译错误:
{error_msg}
错误的自然语言解释:
{explanation}
立即生成修正后的 JRXML
"""
# Prompt used by explain_error() to turn a validation error into a plain-language explanation.
# Template placeholders: {error_msg} and {jrxml_snippet} (explain_error() passes
# the first 80 lines of the current JRXML).
EXPLAIN_PROMPT = """你是一位 JasperReports 专家。用普通非技术语言解释以下 JRXML 编译错误,让业务用户能够理解。
错误消息:
{error_msg}
当前 JRXML 片段(前 80 行):
{jrxml_snippet}
用 2-3 句话解释哪里出了问题以及如何修复:
"""
# Prompt used by manage_context() to summarise older conversation turns.
# Template placeholder: {conversation_text} (manage_context() passes the older
# history entries serialised as JSON).
COMPRESSION_PROMPT = """你是一个信息压缩助手。以下是用户与报表生成助手之间的历史对话记录,请将其压缩为一份简洁的摘要(不超过200字)。
摘要必须保留以下关键信息:
- 用户提出的所有报表需求点(字段、标题、分组、汇总等)
- 用户提出的所有修改要求及其顺序
- 当前报表的核心结构(字段列表、标题、分组方式)
- 任何特殊要求或约束条件
只输出摘要文本,不要添加任何解释或标记。
对话记录:
{conversation_text}
"""
# ============================================================
# Core workflow nodes
# ============================================================
def process_input(state: AgentState) -> Dict:
    """Record the user's input in both histories and reset per-request fields.

    A ``user`` entry is appended to ``full_conversation_history`` (with a
    ``ts`` timestamp) and to ``conversation_history`` (without one); either
    list is created when it is missing. ``retry_count`` is then set to 0 and
    the input is copied into ``user_modification_request``.

    Args:
        state: Agent state; ``user_input`` is read (defaults to "").

    Returns:
        The same ``state`` object, mutated in place.
    """
    message = state.get("user_input", "")

    # Full (never truncated) history keeps a timestamp per entry.
    state.setdefault("full_conversation_history", []).append(
        {"role": "user", "content": message, "ts": _now_iso()}
    )
    # Working history is what gets sent to the model.
    state.setdefault("conversation_history", []).append(
        {"role": "user", "content": message}
    )

    # Reset the fields that are scoped to a single request.
    state.update(retry_count=0, user_modification_request=message)
    return state
def save_state_snapshot(state: AgentState) -> Dict:
    """Push a snapshot of the current report state onto ``history_states``.

    The snapshot holds the current/final JRXML, status, a deep copy of
    ``conversation_history``, the modification request and the intent. When
    the list grows beyond ``HISTORY_MAX_SNAPSHOTS`` only the newest entries
    are kept. A non-list ``history_states`` value is replaced by a new list.

    Args:
        state: Agent state to snapshot.

    Returns:
        The same ``state`` object with ``history_states`` updated.
    """
    history = state.get("history_states", [])
    if not isinstance(history, list):
        history = []

    entry = {
        key: state.get(key, "")
        for key in ("current_jrxml", "final_jrxml", "status")
    }
    # Deep copy so later appends to the live history do not alter the snapshot.
    entry["conversation_history"] = copy.deepcopy(
        state.get("conversation_history", [])
    )
    entry["user_modification_request"] = state.get("user_modification_request", "")
    entry["intent"] = state.get("intent", "")
    history.append(entry)

    limit = HISTORY_MAX_SNAPSHOTS
    state["history_states"] = history[-limit:] if len(history) > limit else history
    return state
def classify_intent(state: AgentState) -> Dict:
    """Classify the user's input into one of eight intents via the LLM.

    The result is stored in ``state["intent"]``. The model's reply is
    lower-cased and searched for a known intent name; the first match wins.
    If the model is unavailable, raises, or answers with none of the known
    names, the fallback is ``modify_report`` when a report already exists
    and ``initial_generation`` otherwise.

    Args:
        state: Agent state; ``user_input`` and ``current_jrxml`` are read.

    Returns:
        The same ``state`` object with ``intent`` set.
    """
    user_input = state.get("user_input", "")
    # Track "a report exists" as a real boolean. Previously both branches of
    # the yes/no expression were the empty string, so the prompt never told
    # the model whether a report exists and the fallback always picked
    # "modify_report".
    report_exists = bool((state.get("current_jrxml") or "").strip())
    fallback = "modify_report" if report_exists else "initial_generation"

    valid_intents = (
        "initial_generation", "modify_report", "preview_report",
        "export_pdf", "export_jrxml", "undo_modification",
        "consult_question", "reset_session",
    )
    try:
        prompt = INTENT_CLASSIFY_PROMPT.format(
            has_report="是" if report_exists else "否",
            user_input=user_input[:500],  # cap the prompt size
        )
        raw = get_llm().invoke(prompt).content.strip().lower()
        intent = next((name for name in valid_intents if name in raw), fallback)
    except Exception:
        # Best effort: classification must never break the workflow.
        intent = fallback

    state["intent"] = intent
    return state
def handle_consult(state: AgentState) -> Dict:
    """Answer a consultation question directly with the LLM.

    No report is generated. The answer (or a fixed apology when the LLM call
    fails) is stored in ``consult_answer`` and appended as an ``assistant``
    entry to both ``conversation_history`` and ``full_conversation_history``.

    Args:
        state: Agent state; ``user_input`` is read and both history lists
            must already exist.

    Returns:
        The same ``state`` object, mutated in place.
    """
    question = state.get("user_input", "")
    try:
        reply = (
            get_llm()
            .invoke(CONSULT_PROMPT.format(question=question))
            .content.strip()
        )
    except Exception:
        reply = "抱歉,暂时无法处理您的问题,请稍后再试。"

    state["consult_answer"] = reply
    state["conversation_history"].append({"role": "assistant", "content": reply})
    state["full_conversation_history"].append(
        {"role": "assistant", "content": reply, "ts": _now_iso()}
    )
    return state
def handle_undo(state: AgentState) -> Dict:
    """Undo the last modification by restoring the newest snapshot.

    When ``history_states`` is a non-empty list, its last entry is popped and
    its JRXML, status, conversation history and modification request are
    written back into ``state``; a confirmation message is appended to both
    histories. Otherwise only a "nothing to undo" message is appended to
    ``conversation_history``.

    Args:
        state: Agent state; ``conversation_history`` must already exist.

    Returns:
        The same ``state`` object, mutated in place.
    """
    stack = state.get("history_states", [])
    if isinstance(stack, list) and stack:
        restored = stack.pop()
        state["history_states"] = stack
        for key in ("current_jrxml", "final_jrxml", "status"):
            state[key] = restored.get(key, "")
        state["conversation_history"] = restored.get("conversation_history", [])
        state["user_modification_request"] = restored.get(
            "user_modification_request", ""
        )
        state["conversation_history"].append(
            {"role": "assistant", "content": "已撤销上一步修改,恢复到之前的状态。"}
        )
        state["full_conversation_history"].append(
            {"role": "assistant", "content": "已撤销上一步修改。", "ts": _now_iso()}
        )
        return state

    # Nothing to restore.
    state["conversation_history"].append(
        {"role": "assistant", "content": "没有可撤销的操作。"}
    )
    return state
def handle_reset(state: AgentState) -> Dict:
    """Reset the report-related state of the current session.

    Clears the JRXML, status, error, retrieval, retry, compression and
    snapshot fields, sets ``intent`` to ``initial_generation`` and replaces
    ``conversation_history`` with a single assistant message. The full
    history is kept and receives a "session reset" entry.

    Args:
        state: Agent state; ``full_conversation_history`` must already exist.

    Returns:
        The same ``state`` object, mutated in place.
    """
    state.update(
        {
            "current_jrxml": "",
            "final_jrxml": "",
            "status": "",
            "error_msg": "",
            "natural_explanation": "",
            "user_modification_request": "",
            "retrieved_context": "",
            "retry_count": 0,
            "compressed_history": "",
            "history_states": [],
            "intent": "initial_generation",
            "conversation_history": [
                {"role": "assistant", "content": "会话已重置,请描述您要创建的新报表。"}
            ],
        }
    )
    state["full_conversation_history"].append(
        {"role": "assistant", "content": "会话已重置。", "ts": _now_iso()}
    )
    return state
def count_tokens(state: AgentState) -> int:
    """Estimate the token count of the current working context.

    The measured text is a JSON dump of the last ``CONTEXT_KEEP_RECENT``
    entries of ``conversation_history``, the current JRXML and the
    compressed history. Tokens are counted with tiktoken's ``gpt-4o``
    encoding; if tiktoken cannot be imported or the encoding cannot be
    loaded, the count is approximated as one token per 2.5 characters.

    Args:
        state: Agent state to measure.

    Returns:
        The token count as an ``int`` on both the exact and fallback paths.
    """
    # Build the payload once; both paths measure the same text.
    text = json.dumps(
        {
            "history": state.get("conversation_history", [])[-CONTEXT_KEEP_RECENT:],
            "jrxml": state.get("current_jrxml", ""),
            "compressed": state.get("compressed_history", ""),
        },
        ensure_ascii=False,
    )
    try:
        import tiktoken
        enc = tiktoken.encoding_for_model("gpt-4o")
    except Exception:
        # Fallback for mixed Chinese/English text: ~1 token per 2.5 characters.
        # Integer arithmetic (n * 2 // 5 == floor(n / 2.5)) keeps the result an
        # int; `n // 2.5` would return a float.
        return len(text) * 2 // 5
    return len(enc.encode(text))
def manage_context(state: AgentState) -> Dict:
    """Compress older conversation turns once the context exceeds the token budget.

    ``current_token_count`` is always refreshed. When it exceeds
    ``CONTEXT_MAX_TOKENS`` and ``full_conversation_history`` has more than
    ``CONTEXT_KEEP_RECENT`` entries, the older entries are summarised by the
    LLM (truncated to 300 characters) or, if that fails, by
    ``_simple_compress``. The summary is appended to ``compressed_history``,
    ``conversation_history`` is replaced by the most recent entries of the
    full history, and the token count is recomputed.

    NOTE(review): the entries being summarised always span the whole full
    history before the recent window, so an existing summary appears to cover
    turns that are summarised again -- confirm that appending is intended.

    Args:
        state: Agent state.

    Returns:
        The same ``state`` object, mutated in place.
    """
    tokens = count_tokens(state)
    state["current_token_count"] = tokens
    if tokens <= CONTEXT_MAX_TOKENS:
        return state

    transcript = state.get("full_conversation_history", [])
    if len(transcript) <= CONTEXT_KEEP_RECENT:
        return state

    # Keep the latest entries verbatim; everything before them is summarised.
    tail = transcript[-CONTEXT_KEEP_RECENT:]
    head = transcript[:-CONTEXT_KEEP_RECENT]
    if not head:
        return state

    serialized = json.dumps(head, ensure_ascii=False, indent=2)
    try:
        answer = get_llm().invoke(
            COMPRESSION_PROMPT.format(conversation_text=serialized)
        )
        summary = answer.content.strip()[:300]
    except Exception:
        summary = _simple_compress(head)

    # Merge with any summary produced by an earlier compression.
    previous = state.get("compressed_history", "")
    state["compressed_history"] = (
        f"{previous}\n---\n{summary}" if previous else summary
    )
    state["conversation_history"] = list(tail)
    state["current_token_count"] = count_tokens(state)
    return state
def load_session_node(state: AgentState) -> Dict:
    """Load a persisted session from disk at the start of a request.

    Does nothing when ``session_id`` is empty. Otherwise the stored
    ``agent_state`` (if any) is copied into ``state`` for a fixed set of
    keys, and ``session_name`` / ``created_at`` are taken from the session
    record. Loading is best effort: any failure is logged and the state is
    returned as it is.

    Args:
        state: Agent state; ``session_id`` is read.

    Returns:
        The same ``state`` object, possibly enriched with persisted fields.
    """
    session_id = state.get("session_id", "")
    if not session_id:
        return state

    # Only these fields are restored; the current request's user_input and
    # stage are not in the list, so they are never overwritten.
    restorable = (
        "conversation_history", "full_conversation_history",
        "current_jrxml", "final_jrxml", "compressed_history",
        "session_name", "created_at", "history_states",
    )
    try:
        from backend.session import load_session
        data = load_session(session_id)
        saved = data.get("agent_state") if data else None
        if saved:
            for key in restorable:
                if key in saved:
                    state[key] = saved[key]
            # The session record's own name/creation time take precedence.
            state["session_name"] = data.get("session_name", "")
            state["created_at"] = data.get("created_at", "")
    except Exception:
        # Keep the request alive, but leave a trace instead of failing silently.
        logging.getLogger(__name__).exception(
            "Failed to load session %s", session_id
        )
    return state
def save_session_node(state: AgentState) -> Dict:
    """Persist the current agent state to disk.

    Does nothing when ``session_id`` is empty. Otherwise a fixed set of state
    keys plus an ``updated_at`` timestamp is passed to
    ``backend.session.save_session``. When the session has no name yet, the
    first 50 characters of the first user message in
    ``conversation_history`` are used. Saving is best effort: any failure is
    logged and the state is returned as it is.

    Args:
        state: Agent state; ``session_id`` is read.

    Returns:
        The same ``state`` object; on success ``updated_at`` (and
        ``session_name`` if it was unset) are written back.
    """
    session_id = state.get("session_id", "")
    if not session_id:
        return state

    persisted_keys = (
        "conversation_history", "full_conversation_history",
        "current_jrxml", "final_jrxml", "compressed_history",
        "status", "error_msg", "history_states",
    )
    try:
        from backend.session import save_session
        persistable = {key: state[key] for key in persisted_keys if key in state}
        persistable["updated_at"] = _now_iso()

        session_name = state.get("session_name", "")
        if not session_name and state.get("conversation_history"):
            # Derive a name from the first user message.
            first_user = next(
                (m["content"][:50] for m in state["conversation_history"]
                 if m.get("role") == "user"),
                "",
            )
            if first_user:
                session_name = first_user

        save_session(session_id, persistable, session_name)
        if not state.get("session_name"):
            state["session_name"] = session_name
        state["updated_at"] = persistable["updated_at"]
    except Exception:
        # Keep the request alive, but leave a trace instead of failing silently.
        logging.getLogger(__name__).exception(
            "Failed to save session %s", session_id
        )
    return state
def _simple_compress(messages: list[dict]) -> str:
    """Rule-based compression used when the LLM is unavailable.

    Keeps only ``user`` messages, truncates each to 100 characters, and joins
    the last ten of them with "; ".

    Args:
        messages: History entries; user entries must have a ``content`` key.

    Returns:
        The joined summary, or "" when there are no user messages.
    """
    user_points = [
        f"用户提问:{msg['content'][:100]}"
        for msg in messages
        if msg.get("role") == "user"
    ]
    return "; ".join(user_points[-10:])
def _now_iso() -> str:
    """Return the current UTC time as a timezone-aware ISO 8601 string."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()
def retrieve(state: AgentState) -> Dict:
    """Search the Chroma store for JRXML templates relevant to the user input.

    The persist directory comes from ``CHROMA_PERSIST_DIR`` (default
    ``./db/chroma``). The page contents of the top five matches are joined
    with a ``---`` separator and stored in ``retrieved_context``. When the
    directory is missing or empty, or anything fails, the context is "".

    Args:
        state: Agent state; ``user_input`` is used as the query.

    Returns:
        The same ``state`` object with ``retrieved_context`` set.
    """
    context = ""
    try:
        embedder = get_embeddings()
        from langchain_chroma import Chroma
        store_dir = os.getenv("CHROMA_PERSIST_DIR", "./db/chroma")
        # Only query when the knowledge base has actually been initialised.
        if os.path.exists(store_dir) and os.listdir(store_dir):
            store = Chroma(
                embedding_function=embedder,
                persist_directory=store_dir,
            )
            hits = store.similarity_search(state.get("user_input", ""), k=5)
            context = "\n\n---\n\n".join([hit.page_content for hit in hits])
    except Exception:
        context = ""
    state["retrieved_context"] = context
    return state
def generate(state: AgentState) -> Dict:
    """Generate the initial JRXML from the user request and retrieved context.

    The extracted JRXML is stored in ``current_jrxml`` and appended to
    ``conversation_history`` as an ``assistant`` entry. LLM errors are not
    caught here.

    Args:
        state: Agent state; ``retrieved_context`` and ``user_input`` are read
            and ``conversation_history`` must already exist.

    Returns:
        The same ``state`` object, mutated in place.
    """
    model = get_llm()
    filled_prompt = INITIAL_GENERATION_PROMPT.format(
        context=state.get("retrieved_context", ""),
        user_request=state.get("user_input", ""),
    )
    report_xml = _extract_jrxml(model.invoke(filled_prompt).content)

    state["current_jrxml"] = report_xml
    state["conversation_history"].append(
        {"role": "assistant", "content": report_xml}
    )
    return state
def modify_jrxml(state: AgentState) -> Dict:
    """Apply the user's modification request to the existing JRXML.

    The prompt receives the current JRXML, the request, and a conversation
    context made of the compressed summary (if any) plus the last six entries
    of ``conversation_history``. The extracted result replaces
    ``current_jrxml``; a user entry and an assistant entry are appended to
    both histories, and ``retry_count`` is reset to 0. LLM errors are not
    caught here.

    NOTE(review): process_input() also appends the user's message to both
    histories; if both nodes run in the same request the user turn is
    recorded twice -- confirm against the graph wiring.

    Args:
        state: Agent state; ``conversation_history`` must already exist.

    Returns:
        The same ``state`` object, mutated in place.
    """
    model = get_llm()
    request_text = state.get("user_modification_request", "")

    # Conversation context: compressed summary first, then recent turns.
    summary = state.get("compressed_history", "")
    recent_turns = state.get("conversation_history", [])[-6:]
    sections = [f"[早期对话摘要]\n{summary}"] if summary else []
    sections.append(json.dumps(recent_turns, ensure_ascii=False, indent=2))

    filled_prompt = MODIFICATION_PROMPT.format(
        current_jrxml=state.get("current_jrxml", ""),
        conversation_history="\n\n---\n\n".join(sections),
        modification_request=request_text,
    )
    new_xml = _extract_jrxml(model.invoke(filled_prompt).content)
    state["current_jrxml"] = new_xml

    state["conversation_history"].extend(
        [
            {"role": "user", "content": request_text},
            {"role": "assistant", "content": new_xml},
        ]
    )
    # Build a new list rather than appending to the existing full history.
    state["full_conversation_history"] = [
        *state.get("full_conversation_history", []),
        {"role": "user", "content": request_text, "ts": _now_iso()},
        {"role": "assistant", "content": new_xml, "ts": _now_iso()},
    ]
    state["retry_count"] = 0
    return state
def validate(state: AgentState) -> Dict:
    """Validate the current JRXML through ``validate_jrxml``.

    Sets ``status`` to ``"pass"`` or ``"fail"`` and ``error_msg`` to the
    reported error (or ""). An empty JRXML fails immediately without calling
    the validator.

    Args:
        state: Agent state; ``current_jrxml`` is read.

    Returns:
        The same ``state`` object with ``status`` and ``error_msg`` set.
    """
    report_xml = state.get("current_jrxml", "")
    if report_xml:
        outcome = validate_jrxml(report_xml)
        verdict = "pass" if outcome.get("valid") else "fail"
        message = outcome.get("error", "")
    else:
        verdict = "fail"
        message = "没有 JRXML 内容可供验证。"

    state["status"] = verdict
    state["error_msg"] = message
    return state
def explain_error(state: AgentState) -> Dict:
    """Ask the LLM for a readable explanation of the validation error.

    The prompt contains ``error_msg`` (or a placeholder when it is missing)
    and the first 80 lines of the current JRXML. The stripped reply is stored
    in ``natural_explanation``. LLM errors are not caught here.

    Args:
        state: Agent state; ``current_jrxml`` and ``error_msg`` are read.

    Returns:
        The same ``state`` object with ``natural_explanation`` set.
    """
    model = get_llm()
    head_lines = state.get("current_jrxml", "").split("\n")[:80]
    filled_prompt = EXPLAIN_PROMPT.format(
        error_msg=state.get("error_msg", "未知错误"),
        jrxml_snippet="\n".join(head_lines),
    )
    reply = model.invoke(filled_prompt)
    state["natural_explanation"] = reply.content.strip()
    return state
def correct_jrxml(state: AgentState) -> Dict:
    """Ask the LLM to repair a JRXML that failed validation.

    The corrected JRXML replaces ``current_jrxml``, ``retry_count`` is
    incremented, and an ``assistant`` entry noting the attempt number is
    appended to ``conversation_history``. LLM errors are not caught here.

    Args:
        state: Agent state; ``current_jrxml``, ``error_msg`` and
            ``natural_explanation`` are read and ``conversation_history``
            must already exist.

    Returns:
        The same ``state`` object, mutated in place.
    """
    model = get_llm()
    filled_prompt = CORRECTION_PROMPT.format(
        current_jrxml=state.get("current_jrxml", ""),
        error_msg=state.get("error_msg", ""),
        explanation=state.get("natural_explanation", ""),
    )
    fixed_xml = _extract_jrxml(model.invoke(filled_prompt).content)
    attempt = state.get("retry_count", 0) + 1

    state["current_jrxml"] = fixed_xml
    state["retry_count"] = attempt
    state["conversation_history"].append(
        {"role": "assistant", "content": f"[自动修正,第 {attempt} 次尝试]\n{fixed_xml}"}
    )
    return state
def finalize(state: AgentState) -> Dict:
    """Copy ``current_jrxml`` into ``final_jrxml`` ("" when it is missing).

    Args:
        state: Agent state.

    Returns:
        The same ``state`` object with ``final_jrxml`` set.
    """
    state.update(final_jrxml=state.get("current_jrxml", ""))
    return state
def _extract_jrxml(text: str) -> str:
    """Extract the JRXML document from an LLM response.

    Strategies are tried in order:

    1. The content of the first markdown code fence (optionally tagged
       ``xml`` or ``jrxml``).
    2. The span from an ``<?xml`` declaration to the first
       ``</jasperReport>``.
    3. The span from ``<jasperReport`` to the first ``</jasperReport>``, for
       responses that omit the XML declaration but wrap the report in prose.
    4. The stripped text unchanged.

    Args:
        text: Raw model output.

    Returns:
        The extracted JRXML, stripped of surrounding whitespace.
    """
    text = text.strip()

    fenced = re.search(r"```(?:xml|jrxml)?\s*([\s\S]*?)```", text, re.IGNORECASE)
    if fenced:
        return fenced.group(1).strip()

    with_prolog = re.search(r"<\?xml[\s\S]*?</jasperReport>", text, re.IGNORECASE)
    if with_prolog:
        return with_prolog.group(0).strip()

    # No XML declaration: still cut away any prose around the root element.
    bare_root = re.search(
        r"<jasperReport\b[\s\S]*?</jasperReport>", text, re.IGNORECASE
    )
    if bare_root:
        return bare_root.group(0).strip()

    return text