fix: per-node max_tokens + validation 502 guard + correct_jrxml output validity

- backend/llm.py: per-node max_tokens via get_llm(max_tokens=N), LLM_MAX_TOKENS env var (default 8192)
- agent/nodes.py: 5 generation nodes use max_tokens=32768, generate_skeleton retries at 65536
- agent/nodes.py: fix ns:field regex (<field → <[\w:]*field) to handle namespace prefixes
- agent/nodes.py: fix correct_jrxml never writing back to state["current_jrxml"]
- agent/nodes.py: correct_jrxml rejects non-JRXML output (no <jasperReport tag)
- agent/nodes.py: _strip_continuation_wrapper strips markdown/prefixes from continuation rounds
- agent/nodes.py: _extract_jrxml iterates multiple markdown code blocks, skips fragments
- agent/graph.py: route_after_validate skips correction loop when service_unavailable
- agent/graph.py: route_after_save skips validation for empty JRXML
- backend/validation.py: returns service_unavailable: True for ConnectError and HTTP 5xx
- Docs: CLAUDE.md v14 changelog, README.md LLM_MAX_TOKENS, .env.example LLM_MAX_TOKENS
@@ -14,6 +14,9 @@ OPENAI_BASE_URL=https://api.openai.com/v1
 
 LLM_MODEL=MiniMax-M2.7
 
+# Default max_tokens (individual generation nodes may override it with a higher value)
+LLM_MAX_TOKENS=8192
+
 # Local LLM (Ollama)
 LOCAL_LLM_MODEL=qwen2.5-coder:7b
 
@@ -547,3 +547,47 @@ GET /api/sessions/{session_id}/kb # Get the KB bound to the session
 | `tests/test_programmatic_map_fields.py` | 20 | Field declaration replacement / reference replacement / Chinese escaping / coordinate preservation / partial mapping / empty-field skipping |
 
 Full test suite (385 tests): no regressions.
+
+## Update (v14, 2026-05-24)
+
+### Per-node max_tokens + correction-loop deadlock fix
+
+**Problem A: self-imposed max_tokens limit**: `backend/llm.py` hard-coded `max_tokens=8192`. MiniMax M2.7's reasoning tokens consumed the whole 8192-token output budget, so skeleton generation came back empty (0 visible characters). Other nodes (correct_jrxml/modify_jrxml) also had their output truncated when the input exceeded 68K characters.
+
+**Problem B: regex mismatch on the ns:field namespace prefix**: the `<field\b` regex in `_programmatic_map_fields()` does not match `<ns0:field name="field_1">`, so field declarations kept their placeholders while references were replaced with OCR field names, and validation reported "used in expressions but not declared".
+
+**Problem C: endless correction loop on a validation-service 502**: when the validation service (port 8001) is not running, `validate_jrxml()` returns a 502. The error message was treated as a JRXML validation error and fed into `explain_error → correct_jrxml`; the LLM tried to "fix" a network error and produced garbage such as HTML/markdown, looping for 5 rounds until retry_count was exhausted.
+
+**Problem D: correct_jrxml never wrote back to current_jrxml**: the corrected JRXML was only written to `conversation_history` and never updated `state["current_jrxml"]`, so every validate round saw the same original JRXML and the correction had no effect. This is the root cause of jrxml_length staying at 4441 across all 5 rounds.
+
+**Fixes**:
+
+#### 1. Per-node max_tokens (`backend/llm.py` + `agent/nodes.py`)
+- `get_llm(caller, max_tokens=None)`: new optional `max_tokens` parameter, passed through to `_build_raw_llm`
+- `MiniMaxLLM.__init__()`: stores `self._max_tokens`
+- `LLM_MAX_TOKENS` environment variable overrides the default of 8192
+- 5 generation nodes raise max_tokens to 32768: `generate`, `generate_skeleton`, `refine_layout`, `modify_jrxml`, `correct_jrxml`
+- `generate_skeleton` automatically retries on an empty response (max_tokens=65536)
+
+#### 2. ns:field regex fix (`agent/nodes.py:548`)
+- `<field\b` → `<[\w:]*field\b`, which matches `<ns0:field>`, `<field>`, and any other namespace prefix
+
+#### 3. Guard against an unavailable validation service
+- `backend/validation.py`: distinguishes ConnectError/HTTPStatusError (5xx) and returns `service_unavailable: True`
+- `agent/nodes.py:validate`: passes `state["service_unavailable"]` through
+- `agent/graph.py:route_after_validate`: goes straight to `finalize` when `service_unavailable` is set, without entering the correction loop
+
+#### 4. correct_jrxml output validity guard
+- New JRXML validity check: when the output contains neither `<jasperReport` nor `<?xml`, fall back to the previous version
+- **Bug fix**: `state["current_jrxml"] = jrxml` writes the corrected result back
+
+#### 5. Continuation output extraction improvements
+- `_strip_continuation_wrapper()`: strips the markdown code fences and natural-language prefixes that the LLM re-adds in continuation responses
+- `_extract_jrxml()`: checks multiple markdown code blocks one by one and skips non-JRXML fragments
+- `_generate_with_continuation()`: applies `_strip_continuation_wrapper` automatically on continuation rounds
+
+#### New environment variable
+
+| Variable | Description | Default |
+|------|------|--------|
+| `LLM_MAX_TOKENS` | Default max_tokens (overridable per node) | 8192 |
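The ns:field regex change described in the changelog above can be checked in isolation. The sketch below is only an illustration: `rename_field_declaration` is a hypothetical helper, not the project's `_programmatic_map_fields()`, and it applies just the old and new declaration patterns quoted in the changelog to a namespace-prefixed element.

```python
import re

# Old and new declaration patterns as quoted in the changelog; {name} is filled in below.
OLD_PATTERN = r'(<field\b[^>]*\bname\s*=\s*"){name}(")'
NEW_PATTERN = r'(<[\w:]*field\b[^>]*\bname\s*=\s*"){name}(")'


def rename_field_declaration(jrxml: str, placeholder: str, real_name: str, pattern: str) -> str:
    """Hypothetical helper: rename one <field name="..."> declaration."""
    regex = pattern.format(name=re.escape(placeholder))
    # A function replacement sidesteps backslash-escape handling of real_name.
    return re.sub(regex, lambda m: m.group(1) + real_name + m.group(2), jrxml)


prefixed = '<ns0:field name="field_1" class="java.lang.String"/>'
plain = '<field name="field_1" class="java.lang.String"/>'

old_prefixed = rename_field_declaration(prefixed, "field_1", "customer_name", OLD_PATTERN)
new_prefixed = rename_field_declaration(prefixed, "field_1", "customer_name", NEW_PATTERN)
new_plain = rename_field_declaration(plain, "field_1", "customer_name", NEW_PATTERN)
```

With the old pattern the prefixed declaration is left untouched, which is the mismatch that produced "used in expressions but not declared". Note that `[\w:]*` is permissive: it would also accept an element whose name merely ends in `field`.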
@@ -176,6 +176,7 @@ jrxml-agent/
 | ANTHROPIC_API_KEY | Anthropic-compatible API key (preferred) | - |
 | ANTHROPIC_BASE_URL | Anthropic-compatible base URL | https://api.minimaxi.com/anthropic |
 | LLM_MODEL | Model name | MiniMax-M2.7 |
+| LLM_MAX_TOKENS | Default max_tokens (overridable per node) | 8192 |
 | LOCAL_LLM_MODEL | Ollama model | qwen2.5-coder:7b |
 | EMBED_BACKEND | local or cloud | local |
 | LOCAL_EMBED_MODEL | Embedding model | Qwen/Qwen3-Embedding-0.6B |
+10
-1
@@ -120,6 +120,9 @@ def route_after_save(state: AgentState) -> Literal["validate", "finalize"]:
     intent = state.get("intent", "")
     if intent in ("preview_report", "export_pdf", "export_jrxml"):
         return "finalize"
+    # Skip the validate/correct loop when the JRXML is empty (e.g. generation failed)
+    if not state.get("current_jrxml", "").strip():
+        return "finalize"
     return "validate"
 
 
@@ -127,6 +130,12 @@ def route_after_save(state: AgentState) -> Literal["validate", "finalize"]:
 def route_after_validate(state: AgentState) -> Literal["finalize", "explain_error"]:
     if state.get("status") == "pass":
         return "finalize"
+    # Skip the explain→correct loop when the JRXML is empty
+    if not state.get("current_jrxml", "").strip():
+        return "finalize"
+    # Skip the correction loop when the validation service is unavailable, to avoid pointless corrections of a network error
+    if state.get("service_unavailable"):
+        return "finalize"
     return "explain_error"
 
 
@@ -256,7 +265,7 @@ def build_graph(on_node_start=None) -> StateGraph:
     workflow.add_conditional_edges(
         "save_session",
         route_after_save,
-        {"validate": "validate"},
+        {"validate": "validate", "finalize": "finalize"},
     )
 
     # ---- validate → correct loop ----
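The routing guard can be exercised without LangGraph. The following is a minimal standalone sketch under stated assumptions: `SketchState` is a hypothetical stand-in for the project's `AgentState` and carries only the three keys the router reads.

```python
from typing import Literal, TypedDict


class SketchState(TypedDict, total=False):
    """Hypothetical stand-in for AgentState (assumed fields only)."""
    status: str
    current_jrxml: str
    service_unavailable: bool


def route_after_validate(state: SketchState) -> Literal["finalize", "explain_error"]:
    if state.get("status") == "pass":
        return "finalize"
    # Nothing to correct when generation produced no JRXML.
    if not state.get("current_jrxml", "").strip():
        return "finalize"
    # An infrastructure failure is not a JRXML error, so the LLM is not asked to fix it.
    if state.get("service_unavailable"):
        return "finalize"
    return "explain_error"


routes = [
    route_after_validate({"status": "pass"}),
    route_after_validate({"status": "fail", "current_jrxml": "  "}),
    route_after_validate(
        {"status": "fail", "current_jrxml": "<jasperReport/>", "service_unavailable": True}
    ),
    route_after_validate({"status": "fail", "current_jrxml": "<jasperReport/>"}),
]
```

Only the last case, a real validation failure on non-empty JRXML with the service reachable, enters the explain/correct loop.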
+72
-21
@@ -543,9 +543,9 @@ def _programmatic_map_fields(jrxml: str, ocr_fields: list[dict]) -> str:
         real_name = _sanitize_field_name(raw_name)
         if real_name == placeholder:
             continue
-        # Replace the field declaration: <field name="field_1" → <field name="customer_name"
+        # Replace the field declaration: <ns0:field name="field_1" → <ns0:field name="customer_name"
         result = re.sub(
-            rf'(<field\b[^>]*\bname\s*=\s*"){re.escape(placeholder)}(")',
+            rf'(<[\w:]*field\b[^>]*\bname\s*=\s*"){re.escape(placeholder)}(")',
             rf'\g<1>{real_name}\g<2>', result,
         )
         # Replace every reference: $F{{field_1}} → $F{{customer_name}}
@@ -821,7 +821,7 @@ def generate(state: AgentState) -> Dict:
     from langgraph.config import get_stream_writer
 
     writer = get_stream_writer()
-    llm = get_llm(caller="generate")
+    llm = get_llm(caller="generate", max_tokens=32768)
 
     user_request = state.get("user_input", "")
     ocr_text = _format_ocr_context(state)
@@ -849,7 +849,6 @@ def generate_skeleton(state: AgentState) -> Dict:
     from langgraph.config import get_stream_writer
 
     writer = get_stream_writer()
-    llm = get_llm(caller="generate_skeleton")
 
     schema = state.get("layout_schema", {})
     schema_text = schema.get("schema_text", "") if isinstance(schema, dict) else ""
@@ -861,10 +860,16 @@ def generate_skeleton(state: AgentState) -> Dict:
         user_request=user_request,
         template_context=_build_template_context(state),
     )
+    llm = get_llm(caller="generate_skeleton", max_tokens=32768)
+
     prev_jrxml = state.get("current_jrxml", "")
     full_text = _generate_with_continuation(llm, prompt, writer, "generate_skeleton")
     if not full_text.strip():
-        _node_log.error("generate_skeleton LLM 返回空响应")
+        _node_log.warning("generate_skeleton 首次返回空响应,以更高 max_tokens 重试")
+        llm = get_llm(caller="generate_skeleton", max_tokens=65536)
+        full_text = _generate_with_continuation(llm, prompt, writer, "generate_skeleton")
+        if not full_text.strip():
+            _node_log.error("generate_skeleton LLM 返回空响应(含重试)")
         return state
     jrxml = _extract_jrxml(full_text)
     if len(jrxml.strip()) < 200:
@@ -1025,7 +1030,7 @@ def modify_jrxml(state: AgentState) -> Dict:
     from langgraph.config import get_stream_writer
 
     writer = get_stream_writer()
-    llm = get_llm(caller="modify_jrxml")
+    llm = get_llm(caller="modify_jrxml", max_tokens=32768)
     # Build the conversation context: compressed summary + recent turns
     compressed = state.get("compressed_history", "")
     recent = state.get("conversation_history", [])[-6:]
@@ -1278,6 +1283,7 @@ def validate(state: AgentState) -> Dict:
     result = validate_jrxml(jrxml)
     state["status"] = "pass" if result.get("valid") else "fail"
     state["error_msg"] = result.get("error", "")
+    state["service_unavailable"] = result.get("service_unavailable", False)
 
     # OCR fidelity check: compare the generated result with the OCR content extracted from the original image
     fidelity = _check_ocr_fidelity(jrxml, state)
@@ -1378,7 +1384,7 @@ def correct_jrxml(state: AgentState) -> Dict:
     from langgraph.config import get_stream_writer
 
     writer = get_stream_writer()
-    llm = get_llm(caller="correct_jrxml")
+    llm = get_llm(caller="correct_jrxml", max_tokens=32768)
     ocr_context = _format_ocr_context(state)
     layout_schema = state.get("layout_schema", {})
     layout_text = ""
@@ -1432,6 +1438,13 @@ def correct_jrxml(state: AgentState) -> Dict:
         _node_log.warning(f"correct_jrxml 输出过短({len(jrxml)} 字符),回退到前一版本")
         jrxml = prev_jrxml
 
+    # If the extracted result is not valid JRXML (no <jasperReport), the LLM returned garbage such as HTML
+    if jrxml and "<jasperReport" not in jrxml and "<?xml" not in jrxml:
+        _node_log.warning(
+            f"correct_jrxml 输出不是合法 JRXML({jrxml[:100]}),回退到前一版本"
+        )
+        jrxml = prev_jrxml
+
     # Dedup check: if the output is identical to the input (ignoring whitespace), the correction had no effect
     _prev_norm = re.sub(r"\s+", "", prev_jrxml) if prev_jrxml else ""
     _new_norm = re.sub(r"\s+", "", jrxml) if jrxml else ""
@@ -1442,6 +1455,7 @@ def correct_jrxml(state: AgentState) -> Dict:
         state["retry_count"] = state.get("retry_count", 0) + 2
     else:
         state["retry_count"] = state.get("retry_count", 0) + 1
+    state["current_jrxml"] = jrxml
     state["conversation_history"].append(
         {"role": "assistant", "content": f"[自动修正,第 {state['retry_count']} 次尝试]\n{jrxml}"}
     )
@@ -1510,6 +1524,31 @@ def finalize(state: AgentState) -> Dict:
     return state
 
 
+def _strip_continuation_wrapper(text: str) -> str:
+    """Strip markdown code-fence markers and natural-language explanations from a continuation response.
+
+    On continuation rounds the LLM may "forget" the format requirements of the original prompt,
+    prefixing the response with explanatory text and wrapping the XML fragment in ```.
+    This function extracts the bare XML content and drops the wrapper.
+    """
+    text = text.strip()
+    # Remove a complete markdown code-block wrapper: ```...```
+    m = re.search(r"```(?:xml|jrxml)?\s*([\s\S]*?)```", text, re.IGNORECASE)
+    if m:
+        inner = m.group(1).strip()
+        if inner:
+            return inner
+    # Remove standalone ``` markers at the start/end (incomplete code block)
+    text = re.sub(r"^```(?:xml|jrxml)?\s*", "", text)
+    text = re.sub(r"```\s*$", "", text)
+    # Remove natural-language prefixes commonly seen in continuation responses
+    text = re.sub(
+        r"^.{0,40}(继续输出|剩余|续写|补全|接上).{0,30}[::]?\s*",
+        "", text, flags=re.IGNORECASE
+    )
+    return text.strip()
+
+
 def _generate_with_continuation(llm, prompt, writer, node_name, max_rounds=3) -> str:
     """Stream LLM generation with automatic truncation recovery.
 
@@ -1519,6 +1558,7 @@ def _generate_with_continuation(llm, prompt, writer, node_name, max_rounds=3) ->
 
     Returns combined full text from all rounds.
     """
+    _jrxml_end = r"</(?:[\w:]+:)?(?:jasperReport|report)>\s*$"
     full_text = ""
 
     for round_num in range(max_rounds):
@@ -1529,7 +1569,8 @@ def _generate_with_continuation(llm, prompt, writer, node_name, max_rounds=3) ->
             current_prompt = (
                 f"[系统指令] 你正在生成的 JRXML 在上一次响应中被截断。\n"
                 f"已生成内容的最后部分(请从此处继续):\n...{tail}\n\n"
-                f"请从截断点继续输出剩余内容,不要重复已输出的部分。"
+                f"请从截断点继续输出剩余内容,不要重复已输出的部分。\n"
+                f"不要输出 markdown 代码块、解释或任何非 JRXML 的内容。"
             )
 
         new_chunks = []
@@ -1538,10 +1579,12 @@ def _generate_with_continuation(llm, prompt, writer, node_name, max_rounds=3) ->
             writer({"type": "stream", "node": node_name, "text": chunk})
 
         new_text = "".join(new_chunks)
+        if round_num > 0:
+            new_text = _strip_continuation_wrapper(new_text)
         full_text += new_text
 
         jrxml = _extract_jrxml(full_text)
-        if re.search(r"</(?:[\w:]+:)?jasperReport>\s*$", jrxml, re.IGNORECASE):
+        if re.search(_jrxml_end, jrxml, re.IGNORECASE):
             break
 
         if not new_text.strip():
@@ -1554,17 +1597,26 @@ def _generate_with_continuation(llm, prompt, writer, node_name, max_rounds=3) ->
 
 
 def _extract_jrxml(text: str) -> str:
-    """Extract the JRXML content from an LLM response, stripping markdown markers if present."""
-    text = text.strip()
-    xml_pattern = re.compile(r"```(?:xml|jrxml)?\s*([\s\S]*?)```", re.IGNORECASE)
-    m = xml_pattern.search(text)
-    if m:
-        content = m.group(1).strip()
-        if content:
-            return content
-        # The markdown code block exists but is empty: fall back to direct matching
-
-    _jrxml_close = r"</(?:[\w:]+:)?jasperReport>"
+    """Extract the JRXML content from an LLM response, stripping markdown markers if present.
+
+    Handles several cases:
+    1. A complete markdown code-block wrapper (single-round output)
+    2. Mixed text (multi-round continuation: no markdown in the first round, markdown added in a continuation round)
+    3. Bare JRXML with no wrapper
+    """
+    text = text.strip()
+    # Detect and extract the content of markdown code blocks.
+    # If a code block's content looks like complete JRXML (starts with <?xml or <jasperReport),
+    # return it; otherwise skip that block and fall back to the other extraction methods.
+    xml_pattern = re.compile(r"```(?:xml|jrxml)?\s*([\s\S]*?)```", re.IGNORECASE)
+    for m in xml_pattern.finditer(text):
+        content = m.group(1).strip()
+        if content and (content.startswith("<?xml") or content.startswith("<jasperReport")):
+            return content
+        # Not a complete JRXML document: skip it and keep searching later code blocks
+
+    # Match <?xml ... </jasperReport> or ... </report> directly
+    _jrxml_close = r"</(?:[\w:]+:)?(?:jasperReport|report)>"
     jasper_tag = re.search(rf"(<\?xml[\s\S]*?{_jrxml_close})", text, re.IGNORECASE)
     if jasper_tag:
         return jasper_tag.group(1).strip()
@@ -1572,8 +1624,7 @@ def _extract_jrxml(text: str) -> str:
     if text.startswith("<?xml") or text.startswith("<jasperReport"):
         return text
 
-    # Final fallback: if the text contains an XML fragment that was not captured, try to extract it directly
-    # This handles the case where the LLM "wraps" the JRXML in natural language outside a code block
+    # Final fallback: try to locate the XML start and end within the text
     xml_start = text.find("<?xml")
     jr_close = re.search(_jrxml_close, text, re.IGNORECASE)
     if xml_start >= 0 and jr_close:
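The multi-block behavior of the new `_extract_jrxml` loop can be illustrated on its own. This is a minimal sketch, not the project's function: `first_complete_jrxml_block` is a hypothetical name, it reuses the fence pattern shown in the hunk above, and it covers only the fenced-block step (none of the later fallbacks).

```python
import re
from typing import Optional

# Same fence pattern as in the hunk above.
FENCE = re.compile(r"```(?:xml|jrxml)?\s*([\s\S]*?)```", re.IGNORECASE)


def first_complete_jrxml_block(text: str) -> Optional[str]:
    """Return the first fenced block that starts like a JRXML document, else None."""
    for match in FENCE.finditer(text):
        content = match.group(1).strip()
        if content.startswith(("<?xml", "<jasperReport")):
            return content
        # Fragment or unrelated block: keep scanning later blocks.
    return None


reply = (
    "```\nsome fragment\n```\n"
    '```xml\n<?xml version="1.0"?>\n<jasperReport name="t"></jasperReport>\n```'
)
picked = first_complete_jrxml_block(reply)
fragment_only = first_complete_jrxml_block("```\nonly a fragment\n```")
```

The previous version used `search()` and returned the first non-empty block, so a leading fragment would have been returned in place of the full document.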
+18
-8
@@ -156,8 +156,14 @@ class _LLMLoggingWrapper(_BaseLLM):
             raise
 
 
-def _build_raw_llm(caller: str = "") -> tuple[_BaseLLM, str, str]:
-    """Build the raw LLM instance; returns (instance, model name, backend name)."""
+DEFAULT_MAX_TOKENS = int(os.getenv("LLM_MAX_TOKENS", "8192"))
+
+
+def _build_raw_llm(caller: str = "", max_tokens: int | None = None) -> tuple[_BaseLLM, str, str]:
+    """Build the raw LLM instance; returns (instance, model name, backend name).
+
+    max_tokens: overrides the default output token count. None falls back to the LLM_MAX_TOKENS environment variable, or 8192.
+    """
     backend = os.getenv("LLM_BACKEND", "cloud")
     if backend == "local":
         from langchain_ollama import ChatOllama
@@ -183,18 +189,19 @@ def _build_raw_llm(caller: str = "") -> tuple[_BaseLLM, str, str]:
         base_url = os.getenv("ANTHROPIC_BASE_URL") or os.getenv("OPENAI_BASE_URL", "https://api.minimaxi.com/anthropic")
         model = os.getenv("LLM_MODEL", "MiniMax-M2.7")
         temperature = 0.1
-        max_tokens = 8192
+        _default_max_tokens = max_tokens if max_tokens is not None else DEFAULT_MAX_TOKENS
 
         client = Anthropic(api_key=api_key, base_url=base_url, timeout=120)
 
         class MiniMaxLLM(_BaseLLM):
             def __init__(self):
                 self._last_stop_reason = None
+                self._max_tokens = _default_max_tokens
 
             def invoke(self, prompt: str) -> Any:
                 resp = client.messages.create(
                     model=model,
-                    max_tokens=max_tokens,
+                    max_tokens=self._max_tokens,
                     temperature=temperature,
                     messages=[{"role": "user", "content": [{"type": "text", "text": prompt}]}],
                 )
@@ -208,7 +215,7 @@ def _build_raw_llm(caller: str = "") -> tuple[_BaseLLM, str, str]:
                 self._last_stop_reason = None
                 with client.messages.stream(
                     model=model,
-                    max_tokens=max_tokens,
+                    max_tokens=self._max_tokens,
                     temperature=temperature,
                     messages=[{"role": "user", "content": [{"type": "text", "text": prompt}]}],
                 ) as s:
@@ -250,9 +257,12 @@ def _build_raw_llm(caller: str = "") -> tuple[_BaseLLM, str, str]:
         return OpenAIWrapper(), model, f"cloud/openai/{model}"
 
 
-def get_llm(caller: str = "") -> _BaseLLM:
-    """Return a logging-wrapped LLM instance. caller identifies the call site (e.g. generate, classify_intent)."""
-    inner, model, backend = _build_raw_llm(caller)
+def get_llm(caller: str = "", max_tokens: int | None = None) -> _BaseLLM:
+    """Return a logging-wrapped LLM instance. caller identifies the call site (e.g. generate, classify_intent).
+
+    max_tokens: overrides the default output token count. Intended for nodes that need a lot of output, such as skeleton generation.
+    """
+    inner, model, backend = _build_raw_llm(caller, max_tokens=max_tokens)
     return _LLMLoggingWrapper(inner, model=model, backend=backend, caller=caller)
 
 
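The precedence introduced here (explicit per-node argument, then `LLM_MAX_TOKENS`, then 8192) can be sketched as a single function. `resolve_max_tokens` is a hypothetical helper used only for illustration; unlike the module-level `DEFAULT_MAX_TOKENS` in the diff, which is read once at import time, this sketch reads the environment on every call so the effect is easy to observe.

```python
import os
from typing import Optional


def resolve_max_tokens(override: Optional[int] = None, env_var: str = "LLM_MAX_TOKENS") -> int:
    """Hypothetical helper: explicit override first, then the env var, then 8192."""
    if override is not None:
        return override
    return int(os.getenv(env_var, "8192"))


os.environ["LLM_MAX_TOKENS"] = "16384"
from_env = resolve_max_tokens()        # no override: the env var wins
per_node = resolve_max_tokens(32768)   # a per-node override beats the env var
os.environ.pop("LLM_MAX_TOKENS")
fallback = resolve_max_tokens()        # neither set: built-in default
```

Because the committed code evaluates the environment variable at import time, changing `LLM_MAX_TOKENS` after `backend/llm.py` has been imported would not affect nodes that rely on the default.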
+12
-2
@@ -4,6 +4,7 @@ import os
 
 import httpx
 from dotenv import load_dotenv
+from httpx import ConnectError, HTTPStatusError
 
 from backend.logger import get_logger
 
@@ -31,10 +32,19 @@ def validate_jrxml(jrxml_text: str) -> dict:
             },
         )
         return result
-    except httpx.ConnectError:
+    except ConnectError:
         error_msg = f"无法连接到验证服务 ({VALIDATION_URL})。是否正在运行?"
         _val_log.error("验证服务连接失败", extra={"error": error_msg, "url": VALIDATION_URL})
-        return {"valid": False, "error": error_msg}
+        return {"valid": False, "error": error_msg, "service_unavailable": True}
+    except HTTPStatusError as e:
+        status_code = e.response.status_code
+        error_msg = f"验证服务返回错误 ({status_code}): {str(e)}"
+        _val_log.error("验证请求异常", extra={"error": str(e), "url": VALIDATION_URL, "status_code": status_code})
+        return {
+            "valid": False,
+            "error": error_msg,
+            "service_unavailable": status_code >= 500,
+        }
     except Exception as e:
         error_msg = f"验证请求失败: {str(e)}"
         _val_log.error("验证请求异常", extra={"error": str(e), "url": VALIDATION_URL})
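The classification rule in this hunk reduces to a small pure function, sketched below without httpx so it runs anywhere. `classify_validation_failure` is a hypothetical name, and `status_code=None` is an assumed convention standing in for "the connection could not be established". In httpx, `HTTPStatusError` is raised by `Response.raise_for_status()`, so the new `except` branch only fires if the request code calls it; that part lies outside the hunk shown.

```python
from typing import Optional


def classify_validation_failure(status_code: Optional[int], error_msg: str) -> dict:
    """Hypothetical helper mirroring the result shape returned by validate_jrxml."""
    # No response at all, or a 5xx response, means the service itself is at fault.
    unavailable = status_code is None or status_code >= 500
    return {"valid": False, "error": error_msg, "service_unavailable": unavailable}


connect_failed = classify_validation_failure(None, "connection refused")
bad_gateway = classify_validation_failure(502, "Bad Gateway")
client_error = classify_validation_failure(422, "Unprocessable Entity")
```

A 4xx response keeps `service_unavailable` false, so such failures still flow into the explain/correct loop as before.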
@@ -0,0 +1,228 @@
+"""Unit tests for continuation handling and JRXML extraction.
+
+Tests the robustness of _strip_continuation_wrapper and _extract_jrxml in
+multi-round continuation scenarios, plus the completion detection used by _generate_with_continuation.
+"""
+
+from __future__ import annotations
+
+import pytest
+from agent.nodes import _strip_continuation_wrapper, _extract_jrxml
+
+# ── Complete JRXML ─────────────────────────────────────────────────
+
+COMPLETE_JRXML = """<?xml version="1.0" encoding="UTF-8"?>
+<jasperReport name="test" pageWidth="595" pageHeight="842">
+  <field name="field_1" class="java.lang.String"/>
+  <queryString><![CDATA[SELECT * FROM t]]></queryString>
+  <title>
+    <band height="50">
+      <staticText>
+        <reportElement x="0" y="0" width="100" height="20"/>
+        <text><![CDATA[$F{field_1}]]></text>
+      </staticText>
+    </band>
+  </title>
+</jasperReport>"""
+
+# Round 1 output: complete opening but missing </jasperReport> (simulated truncation)
+ROUND1_TRUNCATED = """<?xml version="1.0" encoding="UTF-8"?>
+<jasperReport name="test" pageWidth="595" pageHeight="842">
+  <field name="field_1" class="java.lang.String"/>
+  <field name="field_2" class="java.lang.String"/>
+  <queryString><![CDATA[SELECT * FROM t]]></queryString>
+  <title>
+    <band height="50">
+      <staticText>
+        <reportElement x="0" y="0" width="100" height="20"/>
+        <text><![CDATA[$F{field_1}]]></text>
+      </staticText>
+    </band>
+  </title>
+  <detail>
+    <band height="30">
+      <textField>
+        <reportElement x="0" y="0" width="100" height="20"/>
+        <textFieldExpression><![CDATA[$F{field_1}]]></"""
+
+# Round 2 continuation: wrapped in markdown + wrong closing tag (real LLM behavior)
+ROUND2_MARKDOWN_CONTINUATION = """继续输出剩余的 JRXML 内容:
+
+```
+<textFieldExpression><![CDATA[$F{field_2}]]></textFieldExpression>
+</textField>
+</band>
+</detail>
+</jasperReport>
+```"""
+
+# Round 2 continuation variant: closed with </report> (another common LLM mistake)
+ROUND2_REPORT_CLOSE = """继续输出:
+
+```
+<textFieldExpression><![CDATA[$F{field_2}]]></textFieldExpression>
+</textField>
+</band>
+</detail>
+</report>
+```"""
+
+# Round 2 continuation variant: only an opening ```, no closing fence (incomplete code block)
+ROUND2_PARTIAL_MARKDOWN = """
+```xml
+<textFieldExpression><![CDATA[$F{field_2}]]></textFieldExpression>
+</textField>
+</band>
+</detail>
+</jasperReport>
+```"""
+
+
+# ── _strip_continuation_wrapper tests ──────────────────────────────
+
+class TestStripContinuationWrapper:
+    def test_removes_complete_markdown_block(self):
+        text = '继续输出:\n\n```\n<band>test</band>\n```'
+        result = _strip_continuation_wrapper(text)
+        assert result == '<band>test</band>'
+
+    def test_removes_xml_fenced_block(self):
+        text = '```xml\n<band>test</band>\n```'
+        result = _strip_continuation_wrapper(text)
+        assert result == '<band>test</band>'
+
+    def test_removes_opening_fence_only(self):
+        text = '```xml\n<band>test</band>'
+        result = _strip_continuation_wrapper(text)
+        assert '<band>test</band>' in result
+        assert '```' not in result
+
+    def test_removes_closing_fence_only(self):
+        text = '<band>test</band>\n```'
+        result = _strip_continuation_wrapper(text)
+        assert '<band>test</band>' in result
+        assert '```' not in result
+
+    def test_removes_continuation_prefix_chinese(self):
+        text = '继续输出剩余的 JRXML 内容:\n<band>test</band>'
+        result = _strip_continuation_wrapper(text)
+        assert result == '<band>test</band>'
+
+    def test_pure_xml_passes_through(self):
+        text = '<band>test</band>'
+        result = _strip_continuation_wrapper(text)
+        assert result == '<band>test</band>'
+
+    def test_empty_becomes_empty(self):
+        assert _strip_continuation_wrapper('') == ''
+        assert _strip_continuation_wrapper(' ') == ''
+
+    def test_empty_markdown_block_returns_empty(self):
+        text = '```xml\n```'
+        result = _strip_continuation_wrapper(text)
+        assert result == ''
+
+    def test_multiple_backtick_pairs_extracts_first_valid(self):
+        text = '```\nfragment\n```\n```xml\ncomplete<?xml ...\n```'
+        result = _strip_continuation_wrapper(text)
+        assert result == 'fragment'
+
+
+# ── _extract_jrxml multi-round continuation tests ──────────────────
+
+class TestExtractJrxmlMultiRound:
+    def test_extracts_from_mixed_multi_round_output(self):
+        """Mixed text: round 1 without markdown + round 2 with markdown."""
+        combined = ROUND1_TRUNCATED + ROUND2_MARKDOWN_CONTINUATION
+        result = _extract_jrxml(combined)
+        assert result.startswith("<?xml")
+        assert "</jasperReport>" in result
+        assert '$F{field_1}' in result
+        assert '$F{field_2}' in result
+
+    def test_extracts_with_report_close_tag(self):
+        """Round 2 closes with </report> instead of </jasperReport>."""
+        combined = ROUND1_TRUNCATED + ROUND2_REPORT_CLOSE
+        result = _extract_jrxml(combined)
+        assert result.startswith("<?xml")
+        assert "</report>" in result
+        assert '$F{field_2}' in result
+
+    def test_extracts_with_partial_markdown(self):
+        """Round 2 opens with ```xml and ends with ```."""
+        combined = ROUND1_TRUNCATED + ROUND2_PARTIAL_MARKDOWN
+        result = _extract_jrxml(combined)
+        assert result.startswith("<?xml")
+        assert "</jasperReport>" in result
+
+    def test_single_round_complete_jrxml_in_markdown(self):
+        """Single-round output: complete JRXML inside a markdown code block."""
+        text = '```xml\n' + COMPLETE_JRXML + '\n```'
+        result = _extract_jrxml(text)
+        assert result == COMPLETE_JRXML
+
+    def test_single_round_pure_jrxml(self):
+        """Single-round output: bare JRXML without markdown."""
+        result = _extract_jrxml(COMPLETE_JRXML)
+        assert result == COMPLETE_JRXML
+
+    def test_jrxml_with_leading_explanation(self):
+        """Natural-language explanation precedes the JRXML."""
+        text = '这是生成的报表模板:\n' + COMPLETE_JRXML
+        result = _extract_jrxml(text)
+        assert result == COMPLETE_JRXML
+
+    def test_two_markdown_blocks_skips_fragment(self):
+        """Two markdown blocks in the text; the first is a fragment, the second is complete JRXML."""
+        text = (
+            '```\nsome fragment\n```\n'
+            '```xml\n' + COMPLETE_JRXML + '\n```'
+        )
+        result = _extract_jrxml(text)
+        assert result == COMPLETE_JRXML
+
+    def test_two_markdown_blocks_first_is_complete(self):
+        """Two markdown blocks in the text; the first is complete JRXML."""
+        text = (
+            '```xml\n' + COMPLETE_JRXML + '\n```\n'
+            '```\nsome other stuff\n```'
+        )
+        result = _extract_jrxml(text)
+        assert result == COMPLETE_JRXML
+
+    def test_no_xml_passes_through(self):
+        """Text with no XML content is returned unchanged."""
+        text = 'Hello, this has no XML at all.'
+        result = _extract_jrxml(text)
+        assert result == text
+
+
+# ── Completion detection tests ─────────────────────────────────────
+
+class TestCompletionDetection:
+    def test_jasperreport_close_detected(self):
+        """JRXML ending with </jasperReport> should be detected as complete."""
+        import re
+        jrxml = COMPLETE_JRXML.strip()
+        _jrxml_end = r"</(?:[\w:]+:)?(?:jasperReport|report)>\s*$"
+        assert re.search(_jrxml_end, jrxml, re.IGNORECASE)
+
+    def test_report_close_detected(self):
+        """JRXML ending with </report> should also be detected as complete."""
+        import re
+        jrxml = COMPLETE_JRXML.replace('</jasperReport>', '</report>').strip()
+        _jrxml_end = r"</(?:[\w:]+:)?(?:jasperReport|report)>\s*$"
+        assert re.search(_jrxml_end, jrxml, re.IGNORECASE)
+
+    def test_namespaced_jasperreport_close_detected(self):
+        """JRXML ending with </ns0:jasperReport> should also be detected."""
+        import re
+        jrxml = COMPLETE_JRXML.replace('</jasperReport>', '</ns0:jasperReport>').strip()
+        _jrxml_end = r"</(?:[\w:]+:)?(?:jasperReport|report)>\s*$"
+        assert re.search(_jrxml_end, jrxml, re.IGNORECASE)
+
+    def test_truncated_jrxml_not_detected(self):
+        """Truncated JRXML (no closing tag) should not be detected as complete."""
+        import re
+        _jrxml_end = r"</(?:[\w:]+:)?(?:jasperReport|report)>\s*$"
+        assert not re.search(_jrxml_end, ROUND1_TRUNCATED.strip(), re.IGNORECASE)