diff --git a/.env.example b/.env.example index aba9d01..385d923 100644 --- a/.env.example +++ b/.env.example @@ -14,6 +14,9 @@ OPENAI_BASE_URL=https://api.openai.com/v1 LLM_MODEL=MiniMax-M2.7 +# 默认 max_tokens(各生成节点可覆盖为更高值) +LLM_MAX_TOKENS=8192 + # 本地大语言模型(Ollama) LOCAL_LLM_MODEL=qwen2.5-coder:7b diff --git a/CLAUDE.md b/CLAUDE.md index 453831a..fa52ea4 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -547,3 +547,47 @@ GET /api/sessions/{session_id}/kb # 获取会话绑定的 KB | `tests/test_programmatic_map_fields.py` | 20 | 字段声明替换/引用替换/中文转义/坐标保留/部分映射/空字段跳过 | 完整测试套件(385 项)无回归。 + +## 更新 (v14 — 2026-05-24) + +### max_tokens per-node + 修正循环死锁修复 + +**问题 A — max_tokens 自限**: `backend/llm.py` 硬编码 `max_tokens=8192`。MiniMax M2.7 的 reasoning token 吃光 8192 输出预算后骨架生成为空(0 个可见字符)。其他节点(correct_jrxml/modify_jrxml)输入 68K+ 字符时输出也被截断。 + +**问题 B — ns:field 命名空间前缀正则失配**: `_programmatic_map_fields()` 正则 ``,导致字段声明保持占位符但引用被替换为 OCR 字段名,校验报"used in expressions but not declared"。 + +**问题 C — 验证服务 502 修正死循环**: 验证服务(port 8001)未启动时,`validate_jrxml()` 返回 502。错误消息被当作 JRXML 校验错误送入 `explain_error → correct_jrxml`,LLM 尝试"修复"网络错误产出 HTML/markdown 等垃圾,循环 5 轮直到 retry_count 耗尽。 + +**问题 D — correct_jrxml 从未写回 current_jrxml**: 修正后的 JRXML 只写入 `conversation_history`,从不更新 `state["current_jrxml"]`,导致每轮 validate 看到同一份原始 JRXML,修正完全无效。这是 5 轮 jrxml_length 始终 4441 不变的根本原因。 + +**修复方案**: + +#### 1. per-node max_tokens(`backend/llm.py` + `agent/nodes.py`) +- `get_llm(caller, max_tokens=None)` — 新增可选 `max_tokens` 参数,透传到 `_build_raw_llm` +- `MiniMaxLLM.__init__()` — 存储 `self._max_tokens` +- `LLM_MAX_TOKENS` 环境变量覆盖默认 8192 +- 5 个生成节点 max_tokens 提升到 32768:`generate`, `generate_skeleton`, `refine_layout`, `modify_jrxml`, `correct_jrxml` +- `generate_skeleton` 空响应自动重试(max_tokens=65536) + +#### 2. ns:field 正则修复(`agent/nodes.py:548`) +- ``, `` 等所有命名空间前缀 + +#### 3. 验证服务不可用防护 +- `backend/validation.py` — 区分 ConnectError/HTTPStatusError(5xx):返回 `service_unavailable: True` +- `agent/nodes.py:validate` — 透传 `state["service_unavailable"]` +- `agent/graph.py:route_after_validate` — `service_unavailable` 时直接 `finalize`,不进入修正循环 + +#### 4. correct_jrxml 输出合法性守卫 +- 新增 JRXML 有效性检查:输出不含 ` Literal["validate", "finalize"]: intent = state.get("intent", "") if intent in ("preview_report", "export_pdf", "export_jrxml"): return "finalize" + # JRXML 为空时跳过验证/修正循环(生成失败等场景) + if not state.get("current_jrxml", "").strip(): + return "finalize" return "validate" @@ -127,6 +130,12 @@ def route_after_save(state: AgentState) -> Literal["validate", "finalize"]: def route_after_validate(state: AgentState) -> Literal["finalize", "explain_error"]: if state.get("status") == "pass": return "finalize" + # JRXML 为空时跳过 explain→correct 修正循环 + if not state.get("current_jrxml", "").strip(): + return "finalize" + # 验证服务不可用时跳过修正循环,避免对网络错误进行无效修正 + if state.get("service_unavailable"): + return "finalize" return "explain_error" @@ -256,7 +265,7 @@ def build_graph(on_node_start=None) -> StateGraph: workflow.add_conditional_edges( "save_session", route_after_save, - {"validate": "validate"}, + {"validate": "validate", "finalize": "finalize"}, ) # ---- 验证 → 修正循环 ---- diff --git a/agent/nodes.py b/agent/nodes.py index 1097708..4e9874b 100644 --- a/agent/nodes.py +++ b/agent/nodes.py @@ -543,9 +543,9 @@ def _programmatic_map_fields(jrxml: str, ocr_fields: list[dict]) -> str: real_name = _sanitize_field_name(raw_name) if real_name == placeholder: continue - # 替换 field 声明: ]*\bname\s*=\s*"){re.escape(placeholder)}(")', + rf'(<[\w:]*field\b[^>]*\bname\s*=\s*"){re.escape(placeholder)}(")', rf'\g<1>{real_name}\g<2>', result, ) # 替换所有引用: $F{{field_1}} → $F{{customer_name}} @@ -821,7 +821,7 @@ def generate(state: AgentState) -> Dict: from langgraph.config import get_stream_writer writer = get_stream_writer() - llm = get_llm(caller="generate") + llm = get_llm(caller="generate", max_tokens=32768) user_request = state.get("user_input", "") ocr_text = _format_ocr_context(state) @@ -849,7 +849,6 @@ def generate_skeleton(state: AgentState) -> Dict: from langgraph.config import get_stream_writer writer = get_stream_writer() - llm = get_llm(caller="generate_skeleton") schema = state.get("layout_schema", {}) schema_text = schema.get("schema_text", "") if isinstance(schema, dict) else "" @@ -861,10 +860,16 @@ def generate_skeleton(state: AgentState) -> Dict: user_request=user_request, template_context=_build_template_context(state), ) + llm = get_llm(caller="generate_skeleton", max_tokens=32768) + prev_jrxml = state.get("current_jrxml", "") full_text = _generate_with_continuation(llm, prompt, writer, "generate_skeleton") if not full_text.strip(): - _node_log.error("generate_skeleton LLM 返回空响应") + _node_log.warning("generate_skeleton 首次返回空响应,以更高 max_tokens 重试") + llm = get_llm(caller="generate_skeleton", max_tokens=65536) + full_text = _generate_with_continuation(llm, prompt, writer, "generate_skeleton") + if not full_text.strip(): + _node_log.error("generate_skeleton LLM 返回空响应(含重试)") return state jrxml = _extract_jrxml(full_text) if len(jrxml.strip()) < 200: @@ -1025,7 +1030,7 @@ def modify_jrxml(state: AgentState) -> Dict: from langgraph.config import get_stream_writer writer = get_stream_writer() - llm = get_llm(caller="modify_jrxml") + llm = get_llm(caller="modify_jrxml", max_tokens=32768) # 构建对话上下文:压缩摘要 + 最近对话 compressed = state.get("compressed_history", "") recent = state.get("conversation_history", [])[-6:] @@ -1278,6 +1283,7 @@ def validate(state: AgentState) -> Dict: result = validate_jrxml(jrxml) state["status"] = "pass" if result.get("valid") else "fail" state["error_msg"] = result.get("error", "") + state["service_unavailable"] = result.get("service_unavailable", False) # OCR 保真度检查:比对生成结果与原始图片的 OCR 提取内容 fidelity = _check_ocr_fidelity(jrxml, state) @@ -1378,7 +1384,7 @@ def correct_jrxml(state: AgentState) -> Dict: from langgraph.config import get_stream_writer writer = get_stream_writer() - llm = get_llm(caller="correct_jrxml") + llm = get_llm(caller="correct_jrxml", max_tokens=32768) ocr_context = _format_ocr_context(state) layout_schema = state.get("layout_schema", {}) layout_text = "" @@ -1432,6 +1438,13 @@ def correct_jrxml(state: AgentState) -> Dict: _node_log.warning(f"correct_jrxml 输出过短({len(jrxml)} 字符),回退到前一版本") jrxml = prev_jrxml + # 如果提取结果不是合法 JRXML(不含 Dict: state["retry_count"] = state.get("retry_count", 0) + 2 else: state["retry_count"] = state.get("retry_count", 0) + 1 + state["current_jrxml"] = jrxml state["conversation_history"].append( {"role": "assistant", "content": f"[自动修正,第 {state['retry_count']} 次尝试]\n{jrxml}"} ) @@ -1510,6 +1524,31 @@ def finalize(state: AgentState) -> Dict: return state +def _strip_continuation_wrapper(text: str) -> str: + """去除续写响应中的 markdown 代码块标记和自然语言解释。 + + 续写轮次的 LLM 可能会"忘记"原始 prompt 中的格式要求, + 在响应开头加解释文字、用 ``` 包裹 XML 片段。 + 此函数提取其中的纯 XML 内容,去除包装。 + """ + text = text.strip() + # 移除完整的 markdown 代码块包装: ```...``` + m = re.search(r"```(?:xml|jrxml)?\s*([\s\S]*?)```", text, re.IGNORECASE) + if m: + inner = m.group(1).strip() + if inner: + return inner + # 移除开头/结尾的独立 ``` 标记(不完整代码块) + text = re.sub(r"^```(?:xml|jrxml)?\s*", "", text) + text = re.sub(r"```\s*$", "", text) + # 移除续写响应常见的自然语言前缀 + text = re.sub( + r"^.{0,40}(继续输出|剩余|续写|补全|接上).{0,30}[::]?\s*", + "", text, flags=re.IGNORECASE + ) + return text.strip() + + def _generate_with_continuation(llm, prompt, writer, node_name, max_rounds=3) -> str: """Stream LLM generation with automatic truncation recovery. @@ -1519,6 +1558,7 @@ def _generate_with_continuation(llm, prompt, writer, node_name, max_rounds=3) -> Returns combined full text from all rounds. """ + _jrxml_end = r"\s*$" full_text = "" for round_num in range(max_rounds): @@ -1529,7 +1569,8 @@ def _generate_with_continuation(llm, prompt, writer, node_name, max_rounds=3) -> current_prompt = ( f"[系统指令] 你正在生成的 JRXML 在上一次响应中被截断。\n" f"已生成内容的最后部分(请从此处继续):\n...{tail}\n\n" - f"请从截断点继续输出剩余内容,不要重复已输出的部分。" + f"请从截断点继续输出剩余内容,不要重复已输出的部分。\n" + f"不要输出 markdown 代码块、解释或任何非 JRXML 的内容。" ) new_chunks = [] @@ -1538,10 +1579,12 @@ def _generate_with_continuation(llm, prompt, writer, node_name, max_rounds=3) -> writer({"type": "stream", "node": node_name, "text": chunk}) new_text = "".join(new_chunks) + if round_num > 0: + new_text = _strip_continuation_wrapper(new_text) full_text += new_text jrxml = _extract_jrxml(full_text) - if re.search(r"\s*$", jrxml, re.IGNORECASE): + if re.search(_jrxml_end, jrxml, re.IGNORECASE): break if not new_text.strip(): @@ -1554,17 +1597,26 @@ def _generate_with_continuation(llm, prompt, writer, node_name, max_rounds=3) -> def _extract_jrxml(text: str) -> str: - """从 LLM 响应中提取 JRXML 内容,如有 markdown 标记则去除。""" - text = text.strip() - xml_pattern = re.compile(r"```(?:xml|jrxml)?\s*([\s\S]*?)```", re.IGNORECASE) - m = xml_pattern.search(text) - if m: - content = m.group(1).strip() - if content: - return content - # markdown 代码块存在但内容为空 — 回退到直接匹配 + """从 LLM 响应中提取 JRXML 内容,如有 markdown 标记则去除。 - _jrxml_close = r"" + 处理多种情况: + 1. 完整的 markdown 代码块包裹(单轮输出) + 2. 混合文本(多轮续写:第一轮无 markdown,续写轮添加了 markdown) + 3. 纯 JRXML 无包装 + """ + text = text.strip() + # 检测并提取 markdown 代码块中的内容 + # 如果第一个代码块的内容看起来是完整 JRXML(以 或 ... + _jrxml_close = r"" jasper_tag = re.search(rf"(<\?xml[\s\S]*?{_jrxml_close})", text, re.IGNORECASE) if jasper_tag: return jasper_tag.group(1).strip() @@ -1572,8 +1624,7 @@ def _extract_jrxml(text: str) -> str: if text.startswith("= 0 and jr_close: diff --git a/backend/llm.py b/backend/llm.py index cb63932..ac1e684 100644 --- a/backend/llm.py +++ b/backend/llm.py @@ -156,8 +156,14 @@ class _LLMLoggingWrapper(_BaseLLM): raise -def _build_raw_llm(caller: str = "") -> tuple[_BaseLLM, str, str]: - """构造原始 LLM 实例,返回 (实例, model名, backend名)。""" +DEFAULT_MAX_TOKENS = int(os.getenv("LLM_MAX_TOKENS", "8192")) + + +def _build_raw_llm(caller: str = "", max_tokens: int | None = None) -> tuple[_BaseLLM, str, str]: + """构造原始 LLM 实例,返回 (实例, model名, backend名)。 + + max_tokens: 覆盖默认输出 token 数。None 使用 LLM_MAX_TOKENS 环境变量或 8192。 + """ backend = os.getenv("LLM_BACKEND", "cloud") if backend == "local": from langchain_ollama import ChatOllama @@ -183,18 +189,19 @@ def _build_raw_llm(caller: str = "") -> tuple[_BaseLLM, str, str]: base_url = os.getenv("ANTHROPIC_BASE_URL") or os.getenv("OPENAI_BASE_URL", "https://api.minimaxi.com/anthropic") model = os.getenv("LLM_MODEL", "MiniMax-M2.7") temperature = 0.1 - max_tokens = 8192 + _default_max_tokens = max_tokens if max_tokens is not None else DEFAULT_MAX_TOKENS client = Anthropic(api_key=api_key, base_url=base_url, timeout=120) class MiniMaxLLM(_BaseLLM): def __init__(self): self._last_stop_reason = None + self._max_tokens = _default_max_tokens def invoke(self, prompt: str) -> Any: resp = client.messages.create( model=model, - max_tokens=max_tokens, + max_tokens=self._max_tokens, temperature=temperature, messages=[{"role": "user", "content": [{"type": "text", "text": prompt}]}], ) @@ -208,7 +215,7 @@ def _build_raw_llm(caller: str = "") -> tuple[_BaseLLM, str, str]: self._last_stop_reason = None with client.messages.stream( model=model, - max_tokens=max_tokens, + max_tokens=self._max_tokens, temperature=temperature, messages=[{"role": "user", "content": [{"type": "text", "text": prompt}]}], ) as s: @@ -250,9 +257,12 @@ def _build_raw_llm(caller: str = "") -> tuple[_BaseLLM, str, str]: return OpenAIWrapper(), model, f"cloud/openai/{model}" -def get_llm(caller: str = "") -> _BaseLLM: - """返回带日志的 LLM 实例。caller 用于标识调用来源(如 generate、classify_intent)。""" - inner, model, backend = _build_raw_llm(caller) +def get_llm(caller: str = "", max_tokens: int | None = None) -> _BaseLLM: + """返回带日志的 LLM 实例。caller 用于标识调用来源(如 generate、classify_intent)。 + + max_tokens: 覆盖默认输出 token 数。用于骨架生成等需要大量输出的节点。 + """ + inner, model, backend = _build_raw_llm(caller, max_tokens=max_tokens) return _LLMLoggingWrapper(inner, model=model, backend=backend, caller=caller) diff --git a/backend/validation.py b/backend/validation.py index f906525..3a37f9f 100644 --- a/backend/validation.py +++ b/backend/validation.py @@ -4,6 +4,7 @@ import os import httpx from dotenv import load_dotenv +from httpx import ConnectError, HTTPStatusError from backend.logger import get_logger @@ -31,10 +32,19 @@ def validate_jrxml(jrxml_text: str) -> dict: }, ) return result - except httpx.ConnectError: + except ConnectError: error_msg = f"无法连接到验证服务 ({VALIDATION_URL})。是否正在运行?" _val_log.error("验证服务连接失败", extra={"error": error_msg, "url": VALIDATION_URL}) - return {"valid": False, "error": error_msg} + return {"valid": False, "error": error_msg, "service_unavailable": True} + except HTTPStatusError as e: + status_code = e.response.status_code + error_msg = f"验证服务返回错误 ({status_code}): {str(e)}" + _val_log.error("验证请求异常", extra={"error": str(e), "url": VALIDATION_URL, "status_code": status_code}) + return { + "valid": False, + "error": error_msg, + "service_unavailable": status_code >= 500, + } except Exception as e: error_msg = f"验证请求失败: {str(e)}" _val_log.error("验证请求异常", extra={"error": str(e), "url": VALIDATION_URL}) diff --git a/tests/test_continuation_extraction.py b/tests/test_continuation_extraction.py new file mode 100644 index 0000000..6289567 --- /dev/null +++ b/tests/test_continuation_extraction.py @@ -0,0 +1,228 @@ +"""续写 + JRXML 提取单元测试。 + +测试 _strip_continuation_wrapper、_extract_jrxml 在 +多轮续写场景下的鲁棒性,以及 _generate_with_continuation 的完成检测。 +""" + +from __future__ import annotations + +import pytest +from agent.nodes import _strip_continuation_wrapper, _extract_jrxml + +# ── 完整 JRXML ───────────────────────────────────────────────────── + +COMPLETE_JRXML = """ + + + + + <band height="50"> + <staticText> + <reportElement x="0" y="0" width="100" height="20"/> + <text><![CDATA[$F{field_1}]]></text> + </staticText> + </band> + +""" + +# 第一轮输出:完整开头但缺少 (模拟截断) +ROUND1_TRUNCATED = """ + + + + + + <band height="50"> + <staticText> + <reportElement x="0" y="0" width="100" height="20"/> + <text><![CDATA[$F{field_1}]]></text> + </staticText> + </band> + + + + + + + + + + +```""" + +# 第二轮续写变体:用 关闭(另一种常见 LLM 错误) +ROUND2_REPORT_CLOSE = """继续输出: + +``` + + + + + +```""" + +# 第二轮续写变体:只用 ``` 开头,无结尾(不完整代码块) +ROUND2_PARTIAL_MARKDOWN = """ +```xml + + + + + +```""" + + +# ── _strip_continuation_wrapper 测试 ─────────────────────────────── + +class TestStripContinuationWrapper: + def test_removes_complete_markdown_block(self): + text = '继续输出:\n\n```\ntest\n```' + result = _strip_continuation_wrapper(text) + assert result == 'test' + + def test_removes_xml_fenced_block(self): + text = '```xml\ntest\n```' + result = _strip_continuation_wrapper(text) + assert result == 'test' + + def test_removes_opening_fence_only(self): + text = '```xml\ntest' + result = _strip_continuation_wrapper(text) + assert 'test' in result + assert '```' not in result + + def test_removes_closing_fence_only(self): + text = 'test\n```' + result = _strip_continuation_wrapper(text) + assert 'test' in result + assert '```' not in result + + def test_removes_continuation_prefix_chinese(self): + text = '继续输出剩余的 JRXML 内容:\ntest' + result = _strip_continuation_wrapper(text) + assert result == 'test' + + def test_pure_xml_passes_through(self): + text = 'test' + result = _strip_continuation_wrapper(text) + assert result == 'test' + + def test_empty_becomes_empty(self): + assert _strip_continuation_wrapper('') == '' + assert _strip_continuation_wrapper(' ') == '' + + def test_empty_markdown_block_returns_empty(self): + text = '```xml\n```' + result = _strip_continuation_wrapper(text) + assert result == '' + + def test_multiple_backtick_pairs_extracts_first_valid(self): + text = '```\nfragment\n```\n```xml\ncomplete" in result + assert '$F{field_1}' in result + assert '$F{field_2}' in result + + def test_extracts_with_report_close_tag(self): + """第二轮用 而非 关闭。""" + combined = ROUND1_TRUNCATED + ROUND2_REPORT_CLOSE + result = _extract_jrxml(combined) + assert result.startswith("" in result + assert '$F{field_2}' in result + + def test_extracts_with_partial_markdown(self): + """第二轮用 ```xml 开头,``` 结尾。""" + combined = ROUND1_TRUNCATED + ROUND2_PARTIAL_MARKDOWN + result = _extract_jrxml(combined) + assert result.startswith("" in result + + def test_single_round_complete_jrxml_in_markdown(self): + """单轮输出:完整的 JRXML 在 markdown 代码块中。""" + text = '```xml\n' + COMPLETE_JRXML + '\n```' + result = _extract_jrxml(text) + assert result == COMPLETE_JRXML + + def test_single_round_pure_jrxml(self): + """单轮输出:纯 JRXML 无 markdown。""" + result = _extract_jrxml(COMPLETE_JRXML) + assert result == COMPLETE_JRXML + + def test_jrxml_with_leading_explanation(self): + """JRXML 前有自然语言解释。""" + text = '这是生成的报表模板:\n' + COMPLETE_JRXML + result = _extract_jrxml(text) + assert result == COMPLETE_JRXML + + def test_two_markdown_blocks_skips_fragment(self): + """文本中有两个 markdown 块,第一个是片段,第二个是完整 JRXML。""" + text = ( + '```\nsome fragment\n```\n' + '```xml\n' + COMPLETE_JRXML + '\n```' + ) + result = _extract_jrxml(text) + assert result == COMPLETE_JRXML + + def test_two_markdown_blocks_first_is_complete(self): + """文本中有两个 markdown 块,第一个是完整 JRXML。""" + text = ( + '```xml\n' + COMPLETE_JRXML + '\n```\n' + '```\nsome other stuff\n```' + ) + result = _extract_jrxml(text) + assert result == COMPLETE_JRXML + + def test_no_xml_passes_through(self): + """无 XML 内容的文本原样返回。""" + text = 'Hello, this has no XML at all.' + result = _extract_jrxml(text) + assert result == text + + +# ── 完成检测测试 ─────────────────────────────────────────────────── + +class TestCompletionDetection: + def test_jasperreport_close_detected(self): + """以 结尾的 JRXML 应被识别为完成。""" + import re + jrxml = COMPLETE_JRXML.strip() + _jrxml_end = r"\s*$" + assert re.search(_jrxml_end, jrxml, re.IGNORECASE) + + def test_report_close_detected(self): + """以 结尾的 JRXML 也应被识别为完成。""" + import re + jrxml = COMPLETE_JRXML.replace('', '').strip() + _jrxml_end = r"\s*$" + assert re.search(_jrxml_end, jrxml, re.IGNORECASE) + + def test_namespaced_jasperreport_close_detected(self): + """以 结尾的 JRXML 也应被识别。""" + import re + jrxml = COMPLETE_JRXML.replace('', '').strip() + _jrxml_end = r"\s*$" + assert re.search(_jrxml_end, jrxml, re.IGNORECASE) + + def test_truncated_jrxml_not_detected(self): + """截断的 JRXML(无关闭标签)不应被识别为完成。""" + import re + _jrxml_end = r"\s*$" + assert not re.search(_jrxml_end, ROUND1_TRUNCATED.strip(), re.IGNORECASE)