diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..96f404a --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,157 @@ +# CLAUDE.md — JRXML 生成代理 + +## 项目概述 + +一个**本地桌面应用**,通过自然语言多轮对话帮助非技术用户创建 JasperReports 模板(JRXML 文件)。核心技术栈:Streamlit UI + LangGraph 状态机 + LLM 生成/修改 + 自动验证修正循环。 + +**一句话**:用户用中文描述报表需求 → LLM 生成 JRXML → 自动验证 → 失败则自动修正(最多3次) → 返回可编译的 JRXML 文件。 + +## 启动命令 + +```bash +# 终端 1 — 验证服务(必须先启动) +python -m uvicorn validation_service.main:app --port 8001 --host 0.0.0.0 + +# 终端 2 — Streamlit UI +STREAMLIT_SERVER_HEADLESS=true streamlit run app.py --server.port 8501 +``` + +浏览器打开 `http://localhost:8501`。 + +## 当前配置(.env) + +- **LLM**: `cloud` / `anthropic` → MiniMax Anthropic 兼容 API (`MiniMax-M2.7`) + - Base URL: `https://api.minimaxi.com/anthropic` + - 认证: 通过 `OPENAI_API_KEY` 传入 Anthropic SDK(注意不是 `ANTHROPIC_API_KEY`) + - 绕过代理: 代码中设 `NO_PROXY=*` +- **嵌入模型**: `local` / `sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2` +- **向量库**: ChromaDB 持久化在 `./db/chroma` +- **验证服务**: FastAPI `localhost:8001` + +## 架构 + +``` +app.py (Streamlit UI) + │ run_agent(user_input) + │ 功能: 流式输出/节点平铺/文件上传/历史下载/预览/Ctrl+C修复 + ▼ +agent/graph.py (LangGraph 状态机) + │ 节点流程: + │ load_session → process_input → manage_context → save_state_snapshot + │ → classify_intent (8种意图路由) + │ ├─ retrieve → generate → save_session → validate → ... → finalize + │ ├─ modify_jrxml → save_session → validate → ... → finalize + │ ├─ handle_consult / handle_undo / handle_reset → finalize + │ └─ preview/export → save_session → finalize (跳过验证) + │ + │ 验证修正循环: validate ─fail─► explain_error ─► correct_jrxml ─► validate + │ ▲ │ + │ └──────── (retry < MAX_RETRY=3) ───────────────────┘ + │ + ├──► prompts/loader.py Prompt 外部化:7 个 .md 文件热重载 + ├──► backend/llm.py LLM 工厂: Anthropic SDK / OpenAI / Ollama (统一 stream/invoke) + ├──► backend/rag_adapter.py 语义搜索: ChromaDB + SentenceTransformer + ├──► backend/error_kb.py 错误知识库: 指纹去重 + ChromaDB 持久化 + ├──► backend/file_parser.py 文件解析: PDF/DOCX/图片/文本 + ├──► backend/layout_analyzer.py A4布局分析: OCR + 行分组 + JRXML行匹配 + ├──► backend/validation.py HTTP 客户端: POST /validate + ├──► backend/session.py 会话持久化: JSON 文件 CRUD + └──► validation_service/ 独立 FastAPI: 结构检查 + XSD 校验 +``` + +## 关键文件映射 + +| 文件 | 职责 | 修改频率 | +|------|------|---------| +| `app.py` | Streamlit UI 入口,聊天界面 + 侧边栏 + 下载 + 文件上传 | **高** | +| `agent/state.py` | AgentState 类型定义(~23 字段,含 jrxml_versions/last_error_case) | 低 | +| `agent/nodes.py` | 14 个工作流节点 + 流式生成 + 错误记录 | **高** | +| `agent/graph.py` | 状态图编译 + 路由函数(预览跳过验证) | 中 | +| `prompts/loader.py` | Prompt 加载器(从 .md 文件热重载) | 低 | +| `prompts/*.md` | 7 个独立 Prompt 模板 | **高** | +| `backend/llm.py` | LLM 工厂,统一 `_BaseLLM` 接口(invoke + stream) | 中 | +| `backend/rag_adapter.py` | RAGSearcher 单例,语义搜索接口 | 中 | +| `backend/error_kb.py` | ErrorKB — 错误指纹去重 + ChromaDB 持久化 + 语义检索 | 中 | +| `backend/file_parser.py` | 文件解析: PDF(pdfplumber)/DOCX(python-docx)/图片(PIL+PaddleOCR可选)/文本 | 中 | +| `backend/layout_analyzer.py` | A4模板分析: 比例检测/PaddleOCR元素提取/行分组/JRXML行匹配 | 中 | +| `backend/embeddings.py` | 嵌入模型工厂 (HuggingFace/OpenAI) | 低 | +| `backend/validation.py` | 验证服务 HTTP 客户端 | 低 | +| `backend/session.py` | 会话 JSON 文件 CRUD | 低 | +| `validation_service/main.py` | FastAPI 验证服务 | 低 | +| `scripts/init_kb.py` | 知识库初始化/模型下载 | 低 | + +## 关键约定 + +1. **LLM 调用接口**: 所有节点通过 `get_llm().invoke(prompt)` 同步调用,或用 `get_llm().stream(prompt)` 流式调用。三个后端(Anthropic/OpenAI/Ollama)通过 `_BaseLLM` 统一接口。 + +2. **流式生成**: generate/modify_jrxml/correct_jrxml 使用 `get_stream_writer()` 发射自定义事件,UI 通过 `stream_mode=["updates", "custom"]` 捕获逐字输出。 + +3. **JRXML 提取**: `_extract_jrxml()` 处理 LLM 响应 —— 去掉 markdown 代码块标记,提取 XML 内容。 + +4. **状态持久化**: 每个会话存为 `sessions/{session_id}.json`,LangGraph 节点间通过 AgentState dict 传递。 + +5. **Token 计数**: 使用 `tiktoken` (gpt-4o encoder) 估算,不管实际模型是什么。 + +6. **RAG 子模块**: `rag/` 是一个独立的 git submodule,其内部的生成产物 (`models/`, `embeddings/`, `chroma_db/`, `jrxml_source_chunks/`) 不在 git 中。 + +## Prompt 模板位置 + +所有 Prompt 在 `prompts/` 目录,`.md` 文件可直接编辑,无需重启应用: + +| 文件 | 用途 | +|------|------| +| `prompts/intent_classify.md` | 8 分类意图识别 | +| `prompts/initial_generation.md` | 首次生成 JRXML | +| `prompts/modification.md` | 修改现有 JRXML | +| `prompts/correction.md` | 自动修正错误 | +| `prompts/explain_error.md` | 错误转人话 | +| `prompts/compression.md` | 对话压缩摘要 | +| `prompts/consult.md` | 咨询解答 | + +## 新增功能 (v2) + +### 流式输出 + 节点平铺 +- LLM 生成时逐字展示 XML(不再是空白等待) +- 节点以"处理过程"折叠区展开,不相互覆盖 +- 完成后自动折叠,展示总结卡片 + +### 错误自增长知识库 +- `backend/error_kb.py` — ChromaDB 集合 `jrxml_error_cases` +- 错误指纹去重(标准化 + MD5):相同结构错误不重复录入 +- 记录内容:错误信息 + 修正前后 JRXML + 修正 prompt + 工具链 +- `retrieve` 节点自动注入历史修正案例 +- 流程:correct_jrxml 保存 last_error_case → validate 通过时自动入库 + +### 文件上传 +- 侧边栏多文件上传(可逐文件移除) +- 支持: PDF(pdfplumber+PIL) / DOCX(python-docx) / 图片(PIL+PaddleOCR可选) / 纯文本 +- 上传文本自动注入下一条消息前缀 +- 根据 `can_use_vision()` 判断是否走原生多模态(当前 MiniMax 不支持) + +### A4 模板识别 +- `backend/layout_analyzer.py` — 三种处理路径: + - **完整 A4**: 比例匹配 + OCR 元素 → 全量布局描述 + - **行片段 + 有现有报表**: 行匹配到 JRXML section → 定位修改 + - **行片段 + 无现有报表**: 按 A4 模板生成完整报表 +- PaddleOCR(可选安装)提供精确元素位置/字号 +- 行分组:Y 轴容差自动聚类;行匹配:文本相似度搜索 JRXML band + +### 会话历史下载 +- `AgentState.jrxml_versions` 追踪每次生成/修改的版本 +- 侧边栏"历史版本"折叠区,每版本独立下载按钮 + +### 预览修复 +- `route_after_save` 新增意图判断:预览/导出跳过验证直通 finalize + +### Ctrl+C 修复 +- JS 注入拦截 Streamlit 裸 `c` 键清缓存,保留 Ctrl+C 复制 + +## 已知注意点 + +- **Anthropic SDK**: 使用原始 `anthropic` 包(非 `langchain-anthropic`),因为需要直连 MiniMax 兼容端点。`backend/llm.py:31` 创建的 `Anthropic()` 必须传入 `api_key`,SDK 不会自动读 `OPENAI_API_KEY`。 +- **Windows 环境**: NO_PROXY 设为 `*` 避免代理干扰 MiniMax API。 +- **Streamlit headless**: Windows 下必须设 `STREAMLIT_SERVER_HEADLESS=true` 跳过邮箱采集提示。 +- **验证服务结构检查**: 字段引用一致性 (`$F{field}` vs `` 声明)、SQL SELECT 存在性、pageWidth/pageHeight/name 属性。 +- **XSD 校验可选**: 需要 `validation_service/schemas/jasperreport_7_0_6.xsd` 存在。 +- **rag 子模块**: 内部有独立的管线脚本(`batch_chunker.py` → `embed_chunks.py` → `import_to_chroma.py`),通常不需要在主项目中运行。 +- **PaddleOCR 可选**: A4 模板精确识别需要 `pip install paddleocr`,未安装时仅返回图片元信息。 diff --git a/ROADMAP.md b/ROADMAP.md new file mode 100644 index 0000000..8c0579d --- /dev/null +++ b/ROADMAP.md @@ -0,0 +1,92 @@ +# 改进路线图 + +## 阶段一:代码质量(低风险,快速交付) + +### 1. Prompt 拆分 ✓ +- [x] 创建 `prompts/` 目录 +- [x] 7 个 prompt 各拆为独立 `.md` 文件 +- [x] `nodes.py` 改为从文件加载 +- [x] 支持热重载(文件变更无需重启) + +### 2. 修复无效代码 ✓ +- [x] `backend/llm.py` — `get_num_tokens()` 修复为正确 API +- [x] `backend/embeddings.py` — 修复 docstring 函数名不一致 +- [x] `backend/llm.py` — 统一 LLM 接口基类 `_BaseLLM` + +--- + +## 阶段二:用户体验(核心改造) + +### 3. 流式输出 + 节点平铺 ✓ +- [x] `backend/llm.py` — LLM 工厂支持 `stream()` 统一接口 +- [x] `agent/nodes.py` — generate/modify/correct 节点使用流式 + `get_stream_writer()` +- [x] `app.py` — 使用 `stream_mode=["updates", "custom"]` 捕获流式事件 +- [x] 节点状态平铺(处理过程 expander 逐节点展示) +- [x] 流式完成后节点自动折叠 +- [x] 完成后单独展示「总结卡片」 + +### 4. 错误自增长知识库 ✓ +- [x] `backend/error_kb.py` — ErrorKB 类(ChromaDB 持久化) +- [x] 错误指纹去重(标准化 + MD5) +- [x] `correct_jrxml` — 保存修正前状态到 `last_error_case` +- [x] `validate` — 修正成功时自动记录(仅新错误,自动去重) +- [x] `retrieve` — 搜索错误知识库,注入历史修正案例 +- [x] 记录内容:错误 + 修正前后 JRXML + prompt + 工具链 + 模型 + +### 5. 文件上传支持 ✓ +- [x] `backend/file_parser.py` — 统一解析接口 + - [x] 图片 → PIL 元信息 + PaddleOCR(可选安装后自动识别) + - [x] PDF → pdfplumber / PyMuPDF 文本提取 + - [x] DOCX → python-docx 文本提取 + - [x] 纯文本 (.txt/.csv/.json/.xml) → 直接读取 +- [x] `can_use_vision()` — 根据模型名判断是否支持原生多模态 +- [x] `app.py` — 侧边栏文件上传组件(多文件,可移除) +- [x] 上传文本自动注入下一条消息前缀 + +### 6. A4 图片模板识别 ✓ +- [x] `backend/layout_analyzer.py` — 完整布局分析模块 +- [x] A4 比例判定:exact(±3%) / close(±8%) / not_a4 三档 +- [x] PaddleOCR 布局分析:逐元素提取坐标(x,y,w,h)、字号、文本 +- [x] 行分组:Y 轴容差自动聚类 +- [x] 结构化输出:`图片模板共 X 行,第 1 行有 Y 个元素,其中元素 a 长...高...字体...内容是...` +- [x] 检测门槛:≥2 个 OCR 元素 + A4 比例 → 标记为模板 +- [x] `app.py` — 上传图片/PDF 时自动触发布局分析,替换为布局描述 + +### 7. 会话历史 JRXML 下载 ✓ +- [x] `agent/state.py` — 新增 `jrxml_versions` 字段 +- [x] `agent/nodes.py` — `finalize` 节点追加版本记录 +- [x] `app.py` — 侧边栏"历史版本"折叠区,每版本独立下载按钮 + +### 8. 预览功能修复 ✓ +- [x] 根因:`preview_report` 路由到 `save_session` → `validate` 触发不必要的验证修正循环 +- [x] 修复:`route_after_save` — 预览/导出意图跳过验证直接 `finalize` + +--- + +## 阶段三:细节修复 + +### 9. Ctrl+C 修复 ✓ +- [x] `app.py` — 注入 JS 拦截裸 `c` 键,保留 Ctrl+C 复制行为 + +--- + +## 执行顺序建议 + +``` +1. Prompt 拆分 ──► 2. 无效代码修复 + │ + ▼ + 3. 流式输出 + 节点平铺 + │ + ┌─────────────┼─────────────┐ + ▼ ▼ ▼ + 4. 错误自增长 5. 文件上传 7. 下载历史 + │ │ + ▼ ▼ + 6. A4 模板识别 8. 预览修复 + │ + ▼ + 9. Ctrl+C 修复 +``` + +阶段一立即可做,无外部依赖。阶段二是主要工作量。阶段三是收尾。 diff --git a/agent/graph.py b/agent/graph.py index 400e802..f7f8725 100644 --- a/agent/graph.py +++ b/agent/graph.py @@ -71,7 +71,11 @@ def route_after_undo(state: AgentState) -> Literal["save_session"]: return "save_session" -def route_after_save(state: AgentState) -> Literal["validate"]: +def route_after_save(state: AgentState) -> Literal["validate", "finalize"]: + # 预览/导出意图跳过验证,直接完成 + intent = state.get("intent", "") + if intent in ("preview_report", "export_pdf", "export_jrxml"): + return "finalize" return "validate" @@ -222,4 +226,6 @@ def create_initial_state() -> AgentState: updated_at="", intent="", history_states=[], + jrxml_versions=[], + last_error_case={}, ) diff --git a/agent/nodes.py b/agent/nodes.py index 2e70276..e267742 100644 --- a/agent/nodes.py +++ b/agent/nodes.py @@ -12,6 +12,7 @@ from dotenv import load_dotenv from agent.state import AgentState from backend.llm import get_llm from backend.validation import validate_jrxml +from prompts.loader import load_prompt load_dotenv() @@ -20,119 +21,6 @@ CONTEXT_MAX_TOKENS = int(os.getenv("CONTEXT_MAX_TOKENS", "6000")) CONTEXT_KEEP_RECENT = int(os.getenv("CONTEXT_KEEP_RECENT", "4")) HISTORY_MAX_SNAPSHOTS = int(os.getenv("HISTORY_MAX_SNAPSHOTS", "10")) -# ============================================================ -# 意图分类提示词(约 180 tokens,控制在 200 token 以内) -# ============================================================ -INTENT_CLASSIFY_PROMPT = """你是意图分类器。根据用户输入判断意图,只输出意图名称。 - -当前有报表:{has_report} -用户输入:{user_input} - -可选意图: -- initial_generation(新建报表,或无报表时的任何需求) -- modify_report(修改当前已有报表) -- preview_report(预览/查看当前报表) -- export_pdf(导出PDF文件) -- export_jrxml(下载/导出/保存JRXML文件) -- undo_modification(撤销/回退上一步修改) -- consult_question(咨询JasperReports相关知识或使用问题) -- reset_session(清空/重置/重新开始) - -意图名称:""" - -# ============================================================ -# 咨询回答提示词 -# ============================================================ -CONSULT_PROMPT = """你是 JasperReports 专家。用简洁清晰的中文回答用户关于 JasperReports 的问题。 - -用户问题:{question} - -直接回答:""" - -# ============================================================ -# 原有提示词(不变) -# ============================================================ -INITIAL_GENERATION_PROMPT = """你是一位资深 JasperReports 工程师。根据以下参考模板和用户需求,生成一个完整、可编译的 JRXML 文件。 -JRXML 必须兼容 JasperReports 7.0.6 schema。 - -关键规则: -- 只输出 JRXML 代码,不要解释,不要 markdown 标记。 -- 报表正文中使用的每个字段必须在 部分中声明。 -- 根元素为 ,包含正确的 xmlns 属性。 -- 包含 ,在 中包含 SQL 查询。 -- 确保所有交叉引用(字段名称、band 元素)保持一致。 - -参考模板和组件: -{context} - -用户需求: -{user_request} -""" - -MODIFICATION_PROMPT = """你是一位资深 JasperReports 工程师。用户想要修改一个现有的、可编译的 JRXML 报表。精确应用请求的更改到当前 JRXML 并输出完整修改后的 JRXML。 - -关键规则: -- 只输出完整修改后的 JRXML 代码,不要解释,不要 markdown 标记。 -- 保留所有未被更改的现有结构。 -- 结果必须继续与 JasperReports 7.0.6 兼容。 -- 报表正文中使用的每个字段必须在 部分中声明。 -- 如果添加新字段,正确声明它们。 -- 确保 中有效的 SQL。 - -当前 JRXML: -{current_jrxml} - -对话历史: -{conversation_history} - -用户的修改请求: -{modification_request} -""" - -CORRECTION_PROMPT = """你是一位资深 JasperReports 工程师。你生成的 JRXML 文件编译失败。分析错误并修复 JRXML。 - -关键规则: -- 只输出完整修复后的 JRXML 代码,不要解释,不要 markdown 标记。 -- JRXML 必须与 JasperReports 7.0.6 兼容。 -- 解决下面列出的特定错误。 - -当前 JRXML(带错误): -{current_jrxml} - -编译错误: -{error_msg} - -错误的自然语言解释: -{explanation} - -立即生成修正后的 JRXML: -""" - -EXPLAIN_PROMPT = """你是一位 JasperReports 专家。用普通非技术语言解释以下 JRXML 编译错误,让业务用户能够理解。 - -错误消息: -{error_msg} - -当前 JRXML 片段(前 80 行): -{jrxml_snippet} - -用 2-3 句话解释哪里出了问题以及如何修复: -""" - -COMPRESSION_PROMPT = """你是一个信息压缩助手。以下是用户与报表生成助手之间的历史对话记录,请将其压缩为一份简洁的摘要(不超过200字)。 - -摘要必须保留以下关键信息: -- 用户提出的所有报表需求点(字段、标题、分组、汇总等) -- 用户提出的所有修改要求及其顺序 -- 当前报表的核心结构(字段列表、标题、分组方式) -- 任何特殊要求或约束条件 - -只输出摘要文本,不要添加任何解释或标记。 - -对话记录: -{conversation_text} -""" - # ============================================================ # 核心工作流节点 @@ -191,7 +79,7 @@ def classify_intent(state: AgentState) -> Dict: intent = "initial_generation" try: llm = get_llm() - prompt = INTENT_CLASSIFY_PROMPT.format( + prompt = load_prompt("intent_classify").format( has_report=has_report, user_input=user_input[:500], ) @@ -222,7 +110,7 @@ def handle_consult(state: AgentState) -> Dict: user_input = state.get("user_input", "") try: llm = get_llm() - prompt = CONSULT_PROMPT.format(question=user_input) + prompt = load_prompt("consult").format(question=user_input) resp = llm.invoke(prompt) answer = resp.content.strip() except Exception: @@ -332,7 +220,7 @@ def manage_context(state: AgentState) -> Dict: try: llm = get_llm() - prompt = COMPRESSION_PROMPT.format(conversation_text=conv_text) + prompt = load_prompt("compression").format(conversation_text=conv_text) resp = llm.invoke(prompt) new_compressed = resp.content.strip()[:300] except Exception: @@ -421,12 +309,21 @@ def _now_iso() -> str: def retrieve(state: AgentState) -> Dict: - """在 Chroma 中搜索相关的 JRXML 模板和组件(使用 rag_jrxml 语义分块管线)。""" + """在 ChromaDB + 错误知识库中搜索相关的 JRXML 模板和组件。""" try: from backend.rag_adapter import search_chunks + from backend.error_kb import search_error_cases user_input = state.get("user_input", "") context = search_chunks(user_input, k=5) + + # 如果有最近错误,同时搜索错误知识库 + error_msg = state.get("error_msg", "") + if error_msg: + error_context = search_error_cases(error_msg, k=2) + if error_context: + context = f"{context}\n\n[历史错误修正案例]\n{error_context}" + state["retrieved_context"] = context except Exception: state["retrieved_context"] = "" @@ -435,13 +332,19 @@ def retrieve(state: AgentState) -> Dict: def generate(state: AgentState) -> Dict: """根据用户需求和检索到的上下文生成初始 JRXML。""" + from langgraph.config import get_stream_writer + + writer = get_stream_writer() llm = get_llm() - prompt = INITIAL_GENERATION_PROMPT.format( + prompt = load_prompt("initial_generation").format( context=state.get("retrieved_context", ""), user_request=state.get("user_input", ""), ) - resp = llm.invoke(prompt) - jrxml = _extract_jrxml(resp.content) + full = [] + for chunk in llm.stream(prompt): + full.append(chunk) + writer({"type": "stream", "node": "generate", "text": chunk}) + jrxml = _extract_jrxml("".join(full)) state["current_jrxml"] = jrxml state["conversation_history"].append({"role": "assistant", "content": jrxml}) return state @@ -449,6 +352,9 @@ def generate(state: AgentState) -> Dict: def modify_jrxml(state: AgentState) -> Dict: """根据用户的修改请求修改现有 JRXML。""" + from langgraph.config import get_stream_writer + + writer = get_stream_writer() llm = get_llm() # 构建对话上下文:压缩摘要 + 最近对话 compressed = state.get("compressed_history", "") @@ -459,13 +365,16 @@ def modify_jrxml(state: AgentState) -> Dict: conv_parts.append(json.dumps(recent, ensure_ascii=False, indent=2)) conv_text = "\n\n---\n\n".join(conv_parts) - prompt = MODIFICATION_PROMPT.format( + prompt = load_prompt("modification").format( current_jrxml=state.get("current_jrxml", ""), conversation_history=conv_text, modification_request=state.get("user_modification_request", ""), ) - resp = llm.invoke(prompt) - jrxml = _extract_jrxml(resp.content) + full = [] + for chunk in llm.stream(prompt): + full.append(chunk) + writer({"type": "stream", "node": "modify_jrxml", "text": chunk}) + jrxml = _extract_jrxml("".join(full)) state["current_jrxml"] = jrxml state["conversation_history"].append( { @@ -496,6 +405,29 @@ def validate(state: AgentState) -> Dict: result = validate_jrxml(jrxml) state["status"] = "pass" if result.get("valid") else "fail" state["error_msg"] = result.get("error", "") + + # 修正成功后记录到错误知识库 + if result.get("valid") and state.get("retry_count", 0) > 0: + case = state.get("last_error_case", {}) + if case and case.get("error_msg"): + try: + from backend.error_kb import record_error + + recorded = record_error( + error_msg=case["error_msg"], + bad_jrxml=case.get("bad_jrxml", ""), + good_jrxml=jrxml, + correction_prompt=case.get("correction_prompt", ""), + retry_count=state.get("retry_count", 0), + ) + if recorded: + state["conversation_history"].append({ + "role": "system", + "content": f"[系统] 错误案例已记录到知识库(指纹: {case['error_msg'][:40]}...)", + }) + except Exception: + pass # 知识库写入不影响主流程 + return state @@ -506,7 +438,7 @@ def explain_error(state: AgentState) -> Dict: lines = jrxml.split("\n")[:80] snippet = "\n".join(lines) - prompt = EXPLAIN_PROMPT.format( + prompt = load_prompt("explain_error").format( error_msg=state.get("error_msg", "未知错误"), jrxml_snippet=snippet, ) @@ -517,14 +449,27 @@ def explain_error(state: AgentState) -> Dict: def correct_jrxml(state: AgentState) -> Dict: """尝试自动修正验证失败的 JRXML。""" + from langgraph.config import get_stream_writer + + writer = get_stream_writer() llm = get_llm() - prompt = CORRECTION_PROMPT.format( + prompt = load_prompt("correction").format( current_jrxml=state.get("current_jrxml", ""), error_msg=state.get("error_msg", ""), explanation=state.get("natural_explanation", ""), ) - resp = llm.invoke(prompt) - jrxml = _extract_jrxml(resp.content) + # 保存修正前状态(供 validate 判断是否写入错误知识库) + state["last_error_case"] = { + "error_msg": state.get("error_msg", ""), + "bad_jrxml": state.get("current_jrxml", ""), + "correction_prompt": prompt, + } + + full = [] + for chunk in llm.stream(prompt): + full.append(chunk) + writer({"type": "stream", "node": "correct_jrxml", "text": chunk}) + jrxml = _extract_jrxml("".join(full)) state["current_jrxml"] = jrxml state["retry_count"] = state.get("retry_count", 0) + 1 state["conversation_history"].append( @@ -534,8 +479,28 @@ def correct_jrxml(state: AgentState) -> Dict: def finalize(state: AgentState) -> Dict: - """保存最终验证通过的 JRXML 并更新对话历史。""" - state["final_jrxml"] = state.get("current_jrxml", "") + """保存最终验证通过的 JRXML 并更新对话历史 + 版本记录。""" + jrxml = state.get("current_jrxml", "") + state["final_jrxml"] = jrxml + + if jrxml.strip(): + versions = state.get("jrxml_versions", []) + if not isinstance(versions, list): + versions = [] + intent = state.get("intent", "") + label_map = { + "initial_generation": "初始生成", + "modify_report": "修改", + "correct_jrxml": f"自动修正 (第{state.get('retry_count', 1)}次)", + } + versions.append({ + "ts": _now_iso(), + "jrxml": jrxml, + "intent": intent, + "label": label_map.get(intent, intent), + "status": state.get("status", ""), + }) + state["jrxml_versions"] = versions return state diff --git a/agent/state.py b/agent/state.py index 3940852..bb2f8e3 100644 --- a/agent/state.py +++ b/agent/state.py @@ -31,3 +31,9 @@ class AgentState(TypedDict, total=False): # 需求3:意图识别 intent: str history_states: List[dict] + + # 需求4:JRXML 版本历史(用于下载历史版本) + jrxml_versions: List[dict] + + # 需求5:错误自增长(记录修正前的状态,供 validate 节点判断是否入知识库) + last_error_case: dict diff --git a/app.py b/app.py index 8fda4a6..11a9951 100644 --- a/app.py +++ b/app.py @@ -1,7 +1,15 @@ -"""Streamlit 多轮对话 UI,用于 JRXML 生成代理。""" +"""Streamlit 多轮对话 UI,用于 JRXML 生成代理。 + +支持: +- 流式输出(LLM 逐字展示) +- 节点平铺展开(每个处理阶段独立展示) +- 完成后自动折叠节点区 +- 过程总结卡片 +""" import os import sys +from pathlib import Path import streamlit as st @@ -23,7 +31,70 @@ st.set_page_config( initial_sidebar_state="expanded", ) -# ---- URL 参数:session_id ---- +# 阻止 Streamlit 裸 'c' 键清除缓存,保留 Ctrl+C 复制行为 +st.components.v1.html(""" + +""", height=0) + +# ---- 节点名称 → 中文标签 ---- +NODE_LABELS = { + "load_session": "📂 加载会话", + "process_input": "📝 记录输入", + "manage_context": "🧠 管理上下文", + "save_state_snapshot": "💾 保存快照", + "classify_intent": "🔍 识别意图", + "retrieve": "📚 检索模板", + "generate": "⚙️ 生成 JRXML", + "modify_jrxml": "🔧 修改 JRXML", + "validate": "✅ 验证", + "explain_error": "🔎 分析错误", + "correct_jrxml": "🛠 自动修正", + "finalize": "📋 完成", + "handle_consult": "💬 咨询回答", + "handle_undo": "↩ 撤销操作", + "handle_reset": "🔄 重置会话", + "save_session": "💾 保存会话", +} + +INTENT_LABELS = { + "initial_generation": "新建报表", + "modify_report": "修改报表", + "preview_report": "预览报表", + "export_pdf": "导出 PDF", + "export_jrxml": "下载 JRXML", + "undo_modification": "撤销修改", + "consult_question": "咨询问题", + "reset_session": "重置会话", +} + +SKIP_NODES = {"load_session", "process_input", "manage_context", + "save_state_snapshot", "save_session"} + + +def _render_jrxml(jrxml: str, max_lines: int = 30): + """展示 JRXML 代码(折叠、限行)。""" + lines = jrxml.strip().split("\n") + preview = "\n".join(lines[:max_lines]) + if len(lines) > max_lines: + preview += f"\n... (共 {len(lines)} 行)" + st.code(preview, language="xml") + + +# ---- URL 参数 ---- query_params = st.query_params url_session_id = query_params.get("session_id", "") @@ -35,7 +106,6 @@ if "graph" not in st.session_state: if "pending_action" not in st.session_state: st.session_state.pending_action = None -# 确定活跃的 session_id if "agent_state" not in st.session_state: if url_session_id: data = load_session(url_session_id) @@ -58,8 +128,8 @@ if "agent_state" not in st.session_state: current_session_id = st.session_state.agent_state.get("session_id", "") -def run_agent(user_input: str) -> dict: - """运行代理图,返回最终状态。""" +def run_agent(user_input: str): + """运行代理图:流式渲染节点进度 + LLM 文本。""" agent_state = st.session_state.agent_state if agent_state.get("current_jrxml") and agent_state.get("status") == "pass": @@ -68,120 +138,155 @@ def run_agent(user_input: str) -> dict: agent_state["user_input"] = user_input agent_state["retry_count"] = 0 + # ---- UI 容器 ---- + streaming_placeholder = st.empty() # 流式文本 + nodes_container = st.container() # 节点进度区 + summary_placeholder = st.empty() # 总结卡片 + + # 节点追踪 + executed_nodes: list[dict] = [] # {name, label, status, detail} + stream_text = "" + stream_active = False + current_stream_node = "" final_state = None - with st.chat_message("assistant"): - status_placeholder = st.empty() - jrxml_placeholder = st.empty() - for event in st.session_state.graph.stream(agent_state): - for node_name, node_state in event.items(): - final_state = node_state - if node_name == "classify_intent": - intent = node_state.get("intent", "") - intent_labels = { - "initial_generation": "🆕 识别为新建报表请求", - "modify_report": "✏️ 识别为修改报表请求", - "preview_report": "👁 识别为预览请求", - "export_pdf": "📄 识别为导出PDF请求", - "export_jrxml": "📥 识别为导出JRXML请求", - "undo_modification": "↩ 识别为撤销请求", - "consult_question": "💬 识别为咨询问题", - "reset_session": "🔄 识别为重置会话请求", - } - label = intent_labels.get(intent, f"🔍 意图: {intent}") - status_placeholder.info(label) - elif node_name == "generate": - status_placeholder.info("🔧 正在生成 JRXML...") - elif node_name == "modify_jrxml": - status_placeholder.info("🔧 正在根据您的请求修改 JRXML...") - elif node_name == "validate": - if node_state.get("status") == "pass": - status_placeholder.success("✅ 验证通过!") - else: - status_placeholder.warning("⚠ 验证失败,正在分析错误...") - elif node_name == "explain_error": - explanation = node_state.get("natural_explanation", "") - status_placeholder.warning(f"🔍 {explanation}") - elif node_name == "correct_jrxml": - status_placeholder.info(f"🛠 正在自动修正(尝试 {node_state.get('retry_count', 1)})...") - elif node_name == "handle_consult": - pass - elif node_name == "handle_undo": - status_placeholder.info("↩ 已撤销上一步修改") - elif node_name == "handle_reset": - status_placeholder.info("🔄 会话已重置") - elif node_name == "manage_context": - pass - elif node_name == "save_state_snapshot": - pass - elif node_name == "save_session": - pass - elif node_name == "finalize": - pass + try: + for event in st.session_state.graph.stream( + agent_state, stream_mode=["updates", "custom"] + ): + mode, data = event - if final_state: - st.session_state.agent_state = final_state - intent = final_state.get("intent", "") + if mode == "updates": + for node_name, node_state in data.items(): + label = NODE_LABELS.get(node_name, node_name) + if node_name not in SKIP_NODES: + executed_nodes.append({ + "name": node_name, + "label": label, + }) + if node_name == "classify_intent": + intent = node_state.get("intent", "") + il = INTENT_LABELS.get(intent, intent) + executed_nodes[-1]["detail"] = f"意图: {il}" + + elif node_name == "retrieve": + ctx = node_state.get("retrieved_context", "") + executed_nodes[-1]["detail"] = ( + f"找到 {len(ctx)} 字符参考模板" if ctx else "未匹配到模板" + ) + + elif node_name in ("generate", "modify_jrxml", "correct_jrxml"): + # 流式文本已在上面的 custom 事件中展示 + jrxml = node_state.get("current_jrxml", "") + executed_nodes[-1]["detail"] = f"生成 {len(jrxml)} 字符 JRXML" + + elif node_name == "validate": + status = node_state.get("status", "") + if status == "pass": + executed_nodes[-1]["detail"] = "验证通过 ✓" + else: + err = node_state.get("error_msg", "") + executed_nodes[-1]["detail"] = f"验证失败: {err[:80]}" + + elif node_name == "explain_error": + expl = node_state.get("natural_explanation", "") + executed_nodes[-1]["detail"] = expl[:120] + + elif node_name == "handle_consult": + ans = node_state.get("consult_answer", "") + executed_nodes[-1]["detail"] = ans[:150] + + final_state = node_state + + elif mode == "custom": + cd = data + if cd.get("type") == "stream": + stream_text += cd.get("text", "") + stream_active = True + current_stream_node = cd.get("node", "") + streaming_placeholder.code(stream_text, language="xml") + + except Exception as e: + st.error(f"工作流异常: {e}") + return + + # ---- 渲染节点进度区 ---- + with nodes_container: + with st.expander("处理过程", expanded=False): + for i, node in enumerate(executed_nodes): + icon = "✓" if i < len(executed_nodes) - 1 else "●" + detail_str = f" — {node['detail']}" if node.get("detail") else "" + st.caption(f"{icon} {node['label']}{detail_str}") + + # ---- 清除流式占位 ---- + if stream_active: + streaming_placeholder.empty() + + # ---- 总结卡片 ---- + if final_state: + st.session_state.agent_state = final_state + intent = final_state.get("intent", "") + status = final_state.get("status", "") + + with summary_placeholder.container(border=True): if intent == "consult_question": answer = final_state.get("consult_answer", "") + st.info(answer) st.session_state.messages.append({ - "role": "assistant", - "content": answer, - "type": "consult", + "role": "assistant", "content": answer, "type": "consult", }) - status_placeholder.empty() - st.markdown(answer) + elif intent in ("undo_modification", "reset_session"): - # 消息已在节点中添加,不需要额外输出 - status_placeholder.empty() - jrxml_placeholder.empty() + st.success("操作已完成") + # 消息已在节点中添加 + elif intent in ("preview_report", "export_pdf", "export_jrxml"): jrxml = final_state.get("current_jrxml", "") if jrxml: + st.success("✅ 当前报表") + _render_jrxml(jrxml) st.session_state.messages.append({ - "role": "assistant", - "content": jrxml, - "type": "jrxml", + "role": "assistant", "content": jrxml, "type": "jrxml", }) - status_placeholder.success("✅ 当前报表") - jrxml_placeholder.code(jrxml, language="xml") else: - status_placeholder.warning("⚠ 当前没有报表可以预览或导出。") - jrxml_placeholder.empty() - elif final_state.get("status") == "pass": + st.warning("⚠ 当前没有报表可以展示。") + + elif status == "pass": + jrxml = final_state.get("current_jrxml", "") + st.success("✅ JRXML 生成成功") + st.markdown("**生成结果:**") + _render_jrxml(jrxml) + st.caption("您可以从侧边栏下载文件,或继续对话进行修改。") st.session_state.messages.append({ - "role": "assistant", - "content": final_state.get("current_jrxml", ""), - "type": "jrxml", + "role": "assistant", "content": jrxml, "type": "jrxml", }) st.session_state.messages.append({ "role": "assistant", "content": "✅ JRXML 生成成功!您可以从侧边栏下载文件,或继续修改。", "type": "success", }) - status_placeholder.success("✅ JRXML 验证通过!") - jrxml_placeholder.code(final_state.get("current_jrxml", ""), language="xml") + else: + jrxml = final_state.get("current_jrxml", "") error_msg = final_state.get("error_msg", "未知错误") explanation = final_state.get("natural_explanation", "") retries = final_state.get("retry_count", 0) + st.error(f"❌ 经过 {retries} 次重试后仍无法生成有效的 JRXML") + st.markdown(f"**错误:** {error_msg}") + if explanation: + st.markdown(f"**原因:** {explanation}") + if jrxml: + with st.expander("查看当前 JRXML"): + _render_jrxml(jrxml, max_lines=80) + st.caption("请简化报表结构后重试。") st.session_state.messages.append({ "role": "assistant", - "content": final_state.get("current_jrxml", ""), - "type": "jrxml", - }) - st.session_state.messages.append({ - "role": "assistant", - "content": f"❌ 经过 {retries} 次重试后仍无法生成有效的 JRXML。\n\n**错误:** {error_msg}\n\n**解释:** {explanation}\n\n请重新描述您的需求或简化报表结构。", + "content": f"❌ 经过 {retries} 次重试后仍无法生成有效的 JRXML。\n\n**错误:** {error_msg}", "type": "error_explanation", }) - status_placeholder.error(f"❌ 经过 {retries} 次重试后验证失败") - jrxml_placeholder.text("") - else: - st.error("未产生结果,请重试。") - - return final_state + else: + st.error("未产生结果,请重试。") # ---- 侧边栏 ---- @@ -192,7 +297,6 @@ with st.sidebar: # 会话管理 st.markdown("### 会话管理") - sessions = list_all_sessions() session_options = {} for s in sessions: @@ -270,6 +374,83 @@ with st.sidebar: run_agent("重新来,清空当前报表") st.rerun() + st.divider() + st.markdown("### 上传文件") + st.caption("支持图片 (OCR)、PDF、Word、文本文件。内容将附加到您的下一条消息中。") + + if "uploaded_files" not in st.session_state: + st.session_state.uploaded_files = [] # [{name, text, type}] + + uploaded = st.file_uploader( + "选择文件", + type=["png", "jpg", "jpeg", "bmp", "webp", "pdf", "docx", "txt", "csv", "json", "xml"], + accept_multiple_files=True, + key="file_uploader", + label_visibility="collapsed", + ) + + if uploaded: + for uf in uploaded: + # 去重 + if any(f["name"] == uf.name for f in st.session_state.uploaded_files): + continue + import tempfile + from backend.file_parser import parse_file + from backend.layout_analyzer import analyze_layout + + suffix = Path(uf.name).suffix.lower() + with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp: + tmp.write(uf.getvalue()) + tmp_path = tmp.name + + result = parse_file(tmp_path, suffix) + + # 对图片/PDF 进行 A4 模板布局分析 + parsed_text = result["text"] + parsed_type = result["file_type"] + if suffix in (".png", ".jpg", ".jpeg", ".bmp", ".webp", ".pdf"): + layout = analyze_layout(tmp_path) + tt = layout.get("template_type", "unknown") + current_jrxml = st.session_state.agent_state.get("current_jrxml", "") + + if tt == "full_a4": + parsed_text = layout["description"] + parsed_type = "a4_template" + elif tt == "partial_rows": + parsed_type = "a4_partial" + if current_jrxml.strip(): + # 修改模式:尝试行匹配 + from backend.layout_analyzer import match_rows_to_jrxml + match = match_rows_to_jrxml(layout, current_jrxml) + parsed_text = ( + f"[行片段修改] 上传图片包含 {layout['total_rows']} 行," + f"视为 A4 报表的一部分。\n\n" + f"{match['description']}\n\n" + f"--- 行结构 ---\n{layout['description']}" + ) + else: + # 新建模式:按 A4 模板处理 + parsed_text = layout["description"] + + Path(tmp_path).unlink(missing_ok=True) + + if parsed_text: + st.session_state.uploaded_files.append({ + "name": uf.name, + "text": parsed_text, + "type": parsed_type, + }) + + if st.session_state.uploaded_files: + for i, f in enumerate(st.session_state.uploaded_files): + cols = st.columns([5, 1]) + with cols[0]: + st.caption(f"📎 {f['name']} ({f['type']}, {len(f['text'])} 字符)") + with cols[1]: + if st.button("✕", key=f"rm_uf_{i}", help="移除"): + st.session_state.uploaded_files.pop(i) + st.rerun() + st.divider() st.markdown("### 配置") llm_backend = os.getenv("LLM_BACKEND", "cloud") @@ -280,16 +461,36 @@ with st.sidebar: st.divider() st.markdown("### 下载") + final = st.session_state.agent_state.get("final_jrxml", "") + versions = st.session_state.agent_state.get("jrxml_versions", []) + if final: st.download_button( - label="📥 下载 JRXML", + label="📥 下载最新 JRXML", data=final, file_name="report.jrxml", mime="application/xml", use_container_width=True, ) + if versions: + with st.expander("📋 历史版本", expanded=False): + for i, v in enumerate(reversed(versions)): + ts = v.get("ts", "")[:16] + label = v.get("label", "版本") + status = v.get("status", "") + icon = "✅" if status == "pass" else "❌" + dl_label = f"{icon} v{len(versions)-i} — {label} ({ts})" + st.download_button( + label=dl_label, + data=v.get("jrxml", ""), + file_name=f"report_v{len(versions)-i}.jrxml", + mime="application/xml", + use_container_width=True, + key=f"dl_v{i}", + ) + # ---- 标题 ---- st.title("📝 JRXML 报表生成器") st.caption("用自然语言描述您的报表需求,我将逐步生成可用的 JRXML 模板。") @@ -297,22 +498,33 @@ st.caption("用自然语言描述您的报表需求,我将逐步生成可用 # ---- 聊天历史 ---- for msg in st.session_state.messages: with st.chat_message(msg["role"]): - if msg["role"] == "assistant" and msg.get("type") == "jrxml": + if msg.get("type") == "jrxml": with st.expander("查看生成的 JRXML", expanded=False): st.code(msg["content"], language="xml") - elif msg["role"] == "assistant" and msg.get("type") == "error_explanation": + elif msg.get("type") == "error_explanation": st.warning(msg["content"]) - elif msg["role"] == "assistant" and msg.get("type") == "success": + elif msg.get("type") == "success": st.success(msg["content"]) - elif msg["role"] == "assistant" and msg.get("type") == "consult": + elif msg.get("type") == "consult": st.info(msg["content"]) else: st.markdown(msg["content"]) # ---- 聊天输入 ---- if prompt := st.chat_input("描述您的报表需求..."): + # 拼接上传文件的文本 + uploaded_texts = [] + if st.session_state.get("uploaded_files"): + for f in st.session_state.uploaded_files: + uploaded_texts.append(f"[上传文件: {f['name']}]\n{f['text']}") + if uploaded_texts: + full_prompt = "\n\n".join(uploaded_texts) + "\n\n---\n用户需求:\n" + prompt + st.session_state.uploaded_files = [] # 用后即清 + else: + full_prompt = prompt + st.session_state.messages.append({"role": "user", "content": prompt}) with st.chat_message("user"): st.markdown(prompt) - run_agent(prompt) + run_agent(full_prompt) st.rerun() diff --git a/backend/embeddings.py b/backend/embeddings.py index 72b5226..69769a4 100644 --- a/backend/embeddings.py +++ b/backend/embeddings.py @@ -2,7 +2,7 @@ 调用方式: get_embeddings() → LangChain 兼容的 embeddings 对象 - get_st_embeddings() → 原始 SentenceTransformer 实例 + get_st_model() → 原始 SentenceTransformer 实例 """ import os diff --git a/backend/error_kb.py b/backend/error_kb.py new file mode 100644 index 0000000..245bda6 --- /dev/null +++ b/backend/error_kb.py @@ -0,0 +1,225 @@ +"""错误自增长知识库 — 记录修正成功的错误案例,用于未来参考。 + +原则: +- 仅记录"新错误"(指纹去重) +- 必须包含完整的修正方案(prompt、工具链、前后 JRXML) +- 存储于 ChromaDB,可被检索注入到生成 prompt 中 + +用法: + from backend.error_kb import ErrorKB + kb = ErrorKB() + kb.record(error_msg, bad_jrxml, good_jrxml, correction_prompt) + cases = kb.search("字段未声明", k=3) +""" + +import hashlib +import json +import os +import re +from datetime import datetime, timezone +from pathlib import Path +from typing import Optional + +from dotenv import load_dotenv + +load_dotenv() + +CHROMA_DIR = Path(os.getenv("CHROMA_PERSIST_DIR", "./db/chroma")) +COLLECTION_NAME = "jrxml_error_cases" + + +def _make_fingerprint(error_msg: str) -> str: + """生成错误指纹 — 标准化后取 hash,用于去重。 + + 标准化规则: + - 去除字段名、变量名等具体标识符(替换为占位符) + - 小写化 + - 只保留错误的结构性特征 + """ + text = error_msg.lower() + # 替换变量名 / 字段名($F{xxx}, "name", 'value' 等) + text = re.sub(r'\$f\{[^}]+\}', '$f{}', text) + text = re.sub(r"'[^']*'", "''", text) + text = re.sub(r'"[^"]*"', '""', text) + # 替换数字 + text = re.sub(r'\b\d+\b', '', text) + # 压缩空白 + text = re.sub(r'\s+', ' ', text).strip() + return hashlib.md5(text.encode()).hexdigest()[:16] + + +class ErrorKB: + """错误案例知识库 — 包装 ChromaDB 持久化。""" + + def __init__(self): + self._client = None + self._collection = None + + @property + def client(self): + if self._client is None: + import chromadb + self._client = chromadb.PersistentClient(path=str(CHROMA_DIR)) + return self._client + + @property + def collection(self): + if self._collection is None: + try: + self._collection = self.client.get_collection(COLLECTION_NAME) + except Exception: + self._collection = self.client.create_collection(COLLECTION_NAME) + return self._collection + + def exists(self, error_msg: str) -> bool: + """检查错误是否已存在于知识库中(按指纹去重)。""" + fp = _make_fingerprint(error_msg) + try: + results = self.collection.get(ids=[fp]) + return bool(results and results["ids"]) + except Exception: + return False + + def record( + self, + error_msg: str, + bad_jrxml: str, + good_jrxml: str, + correction_prompt: str, + model: str = "", + retry_count: int = 0, + ) -> bool: + """记录一个成功修正的错误案例。 + + 仅当指纹不重复时写入。返回 True 表示已记录,False 表示重复。 + """ + if self.exists(error_msg): + return False + + fp = _make_fingerprint(error_msg) + now = datetime.now(timezone.utc).isoformat() + + # 内容:结构化记录 + doc = json.dumps({ + "error": error_msg, + "bad_jrxml_snippet": bad_jrxml[:2000], + "good_jrxml_snippet": good_jrxml[:2000], + "correction_prompt": correction_prompt[:1500], + "model": model, + "retry_count": retry_count, + "recorded_at": now, + "tools": ["validation_service", "llm_correction"], + }, ensure_ascii=False) + + # 元数据:用于检索过滤 + error_keywords = _extract_keywords(error_msg) + metadata = { + "fingerprint": fp, + "error_keywords": ", ".join(error_keywords[:5]), + "recorded_at": now, + "retry_success": retry_count + 1, # 第几次修正成功的 + } + + self.collection.add( + ids=[fp], + documents=[doc], + metadatas=[metadata], + ) + return True + + def search(self, error_msg: str, k: int = 3) -> list[dict]: + """根据错误消息搜索相似的修正案例(ChromaDB 语义搜索)。 + + 返回 [{error, fix_snippet, prompt, ...}, ...] + """ + keywords = _extract_keywords(error_msg) + if not keywords: + return [] + + query_text = " ".join(keywords) + try: + results = self.collection.query( + query_texts=[query_text], + n_results=k, + include=["documents", "metadatas", "distances"], + ) + except Exception: + return [] + + output = [] + if not results["ids"] or not results["ids"][0]: + return output + + for i, doc_id in enumerate(results["ids"][0]): + dist = results["distances"][0][i] + try: + data = json.loads(results["documents"][0][i]) + output.append({ + "id": doc_id, + "error": data.get("error", ""), + "fix_snippet": data.get("good_jrxml_snippet", ""), + "prompt": data.get("correction_prompt", ""), + "recorded_at": data.get("recorded_at", ""), + "distance": dist, + }) + except json.JSONDecodeError: + continue + + return output + + def search_as_context(self, error_msg: str, k: int = 3) -> str: + """搜索并返回拼接好的错误案例上下文,可直接注入 LLM prompt。""" + results = self.search(error_msg, k=k) + if not results: + return "" + + parts = [] + for r in results: + parts.append( + f"[历史错误案例]\n" + f"错误: {r['error'][:200]}\n" + f"修正后 JRXML 片段:\n{r['fix_snippet'][:800]}\n" + ) + return "\n---\n".join(parts) + + def stats(self) -> dict: + """返回知识库统计信息。""" + try: + count = self.collection.count() + return {"total_cases": count, "collection": COLLECTION_NAME} + except Exception: + return {"total_cases": 0, "collection": COLLECTION_NAME} + + +def _extract_keywords(error_msg: str) -> list[str]: + """从错误消息中提取关键词(中文 + 英文 token)。""" + # 中文字符作为独立关键词 + chinese = re.findall(r'[一-鿿]{2,}', error_msg) + # 英文 camelCase / snake_case token + english = re.findall(r'[a-zA-Z_][a-zA-Z0-9_]{2,}', error_msg) + # JRXML 特有模式 + jrxml_patterns = re.findall(r'\$F\{[^}]*\}', error_msg) + return chinese + english + jrxml_patterns + + +# 全局单例 +_kb: Optional[ErrorKB] = None + + +def get_error_kb() -> ErrorKB: + global _kb + if _kb is None: + _kb = ErrorKB() + return _kb + + +def record_error(error_msg: str, bad_jrxml: str, good_jrxml: str, + correction_prompt: str, model: str = "", retry_count: int = 0) -> bool: + """便捷函数:记录成功修正的错误案例。""" + return get_error_kb().record(error_msg, bad_jrxml, good_jrxml, + correction_prompt, model, retry_count) + + +def search_error_cases(error_msg: str, k: int = 3) -> str: + """便捷函数:搜索历史错误案例并返回上下文字符串。""" + return get_error_kb().search_as_context(error_msg, k=k) diff --git a/backend/file_parser.py b/backend/file_parser.py new file mode 100644 index 0000000..4920df7 --- /dev/null +++ b/backend/file_parser.py @@ -0,0 +1,193 @@ +"""文件解析器:将上传文件转为文本,供 LLM 处理。 + +支持: +- 图片 (.png/.jpg/.jpeg/.bmp) → OCR 提取文本 +- PDF (.pdf) → 文本提取 +- Word (.docx) → 文本提取 +- 纯文本 (.txt/.csv/.json/.xml) → 直接读取 + +策略选择: +- 原生多模态: 模型支持图片时直接传文件(当前 MiniMax 不支持,自动退回文本转换) +- 文本转换: 所有文件转为 UTF-8 文本后注入 prompt +""" + +import os +import io +from pathlib import Path +from typing import Optional + +import PIL.Image + +MODELS_WITH_VISION = { + "gpt-4o", "gpt-4-turbo", "gpt-4-vision-preview", + "claude-3", "claude-3.5", "claude-4", + "gemini-1.5", "gemini-2", +} + + +def can_use_vision(model: str = "") -> bool: + """检查当前模型是否支持原生多模态(图片直接上传)。""" + if not model: + model = os.getenv("LLM_MODEL", "") + return any(v in model.lower() for v in MODELS_WITH_VISION) + + +def parse_file(file_path: str, file_type: str = "") -> dict: + """解析任意文件为文本。 + + 返回: {"text": str, "file_type": str, "method": str, "error": Optional[str]} + """ + path = Path(file_path) + if not path.exists(): + return {"text": "", "file_type": file_type, "method": "none", "error": "文件不存在"} + + suffix = file_type or path.suffix.lower() + + parsers = { + ".png": _parse_image, + ".jpg": _parse_image, + ".jpeg": _parse_image, + ".bmp": _parse_image, + ".webp": _parse_image, + ".pdf": _parse_pdf, + ".docx": _parse_docx, + } + + parser = parsers.get(suffix) + if parser: + return parser(path) + else: + return _parse_text(path) + + +# --------------------------------------------------------------------------- +# 各类型解析器 +# --------------------------------------------------------------------------- + +def _parse_image(path: Path) -> dict: + """OCR 提取图片中的文字。""" + try: + img = PIL.Image.open(path) + info = f"[图片: {img.size[0]}x{img.size[1]}, {img.mode}]" + except Exception: + info = "[图片: 无法读取元数据]" + + # 尝试 PaddleOCR + try: + from paddleocr import PaddleOCR + ocr = PaddleOCR(lang="ch", use_angle_cls=False, show_log=False) + result = ocr.ocr(str(path)) + lines = [] + if result and result[0]: + for line in result[0]: + text = line[1][0] if len(line) > 1 else "" + if text.strip(): + lines.append(text.strip()) + if lines: + return { + "text": f"{info}\n识别文本:\n" + "\n".join(lines), + "file_type": "image", + "method": "paddleocr", + "error": None, + } + except ImportError: + pass + except Exception: + pass + + # OCR 不可用 → 返回图片元信息 + 安装提示 + return { + "text": f"{info}\n(如需 OCR 文字识别,请安装: pip install paddleocr)", + "file_type": "image", + "method": "metadata_only", + "error": "OCR 引擎未安装,已返回图片元信息", + } + + +def _parse_pdf(path: Path) -> dict: + """提取 PDF 中的文本。""" + try: + import pdfplumber + with pdfplumber.open(path) as pdf: + pages = [] + for page in pdf.pages: + text = page.extract_text() + if text: + pages.append(text) + full = "\n\n".join(pages) + return { + "text": full, + "file_type": "pdf", + "method": "pdfplumber", + "error": None, + } + except ImportError: + pass + except Exception as e: + pass + + # Fallback: 尝试 PyMuPDF + try: + import fitz + doc = fitz.open(path) + pages = [] + for page in doc: + pages.append(page.get_text()) + doc.close() + return { + "text": "\n\n".join(pages), + "file_type": "pdf", + "method": "pymupdf", + "error": None, + } + except ImportError: + pass + except Exception: + pass + + return {"text": "", "file_type": "pdf", "method": "none", + "error": "PDF 解析需要安装 pdfplumber 或 PyMuPDF"} + + +def _parse_docx(path: Path) -> dict: + """提取 Word 文档中的文本。""" + try: + from docx import Document + doc = Document(path) + paragraphs = [p.text for p in doc.paragraphs if p.text.strip()] + # 同时提取表格内容 + for table in doc.tables: + for row in table.rows: + cells = [cell.text for cell in row.cells if cell.text.strip()] + if cells: + paragraphs.append(" | ".join(cells)) + return { + "text": "\n\n".join(paragraphs), + "file_type": "docx", + "method": "python-docx", + "error": None, + } + except ImportError: + pass + except Exception as e: + pass + + return {"text": "", "file_type": "docx", "method": "none", + "error": "DOCX 解析需要安装 python-docx"} + + +def _parse_text(path: Path) -> dict: + """读取纯文本文件。""" + try: + text = path.read_text(encoding="utf-8") + return {"text": text, "file_type": path.suffix, "method": "direct", "error": None} + except UnicodeDecodeError: + try: + text = path.read_text(encoding="gbk") + return {"text": text, "file_type": path.suffix, "method": "direct_gbk", "error": None} + except Exception: + return {"text": "", "file_type": path.suffix, "method": "none", + "error": "无法解码文件"} + except Exception: + return {"text": "", "file_type": path.suffix, "method": "none", + "error": "读取失败"} diff --git a/backend/layout_analyzer.py b/backend/layout_analyzer.py new file mode 100644 index 0000000..631aff5 --- /dev/null +++ b/backend/layout_analyzer.py @@ -0,0 +1,494 @@ +"""A4 图片模板布局分析器。 + +检测上传图片并逐行识别每个元素的: +- 位置 (x, y, w, h) +- 字体大小(基于 OCR 边界框高度估算) +- 文本内容 + +支持三种模式: +- 完整 A4 模板:比例匹配 + OCR 元素 ≥2 → 全量布局描述 +- 行片段(非 A4 但有元素):视为 A4 中的某几行 → 部分布局描述 +- 修改匹配:将图片中的行与现有 JRXML 做匹配,定位修改位置 + +用法: + from backend.layout_analyzer import analyze_layout, match_rows_to_jrxml + result = analyze_layout("row_snippet.png") + # result["template_type"] = "partial_rows" + match = match_rows_to_jrxml(result, current_jrxml) + # match["matched_rows"] = [{"row_index": 0, "jrxml_section": "detail_band", ...}] +""" + +import re +import xml.etree.ElementTree as ET +from pathlib import Path +from typing import Optional + +import PIL.Image + +# A4 标准尺寸 (mm): 210 × 297, 比例 ≈ 0.707 +A4_RATIO = 210 / 297 +A4_RATIO_EXACT_MIN, A4_RATIO_EXACT_MAX = 0.686, 0.728 +A4_RATIO_CLOSE_MIN, A4_RATIO_CLOSE_MAX = 0.650, 0.764 + + +def analyze_layout( + file_path: str, + row_tolerance_ratio: float = 0.02, +) -> dict: + """分析图片/PDF 的报表模板布局。 + + 返回: + { + "is_a4_template": bool, # 完整 A4 模板 + "is_partial": bool, # 行片段(非 A4 但有文字元素) + "template_type": str, # "full_a4" | "partial_rows" | "unknown" + "image_size": (w, h), + "aspect_ratio": float, + "a4_confidence": str, + "rows": [{y_center, elements: [{x, y, w, h, font_size, text}, ...]}, ...], + "description": str, + "total_rows": int, + "total_elements": int, + } + """ + path = Path(file_path) + if not path.exists(): + return _empty_result("文件不存在") + + img = _load_image(path) + if img is None: + return _empty_result("无法加载图片") + + w, h = img.size + ratio = min(w, h) / max(w, h) + + # A4 比例判定 + if A4_RATIO_EXACT_MIN <= ratio <= A4_RATIO_EXACT_MAX: + a4_confidence = "exact" + elif A4_RATIO_CLOSE_MIN <= ratio <= A4_RATIO_CLOSE_MAX: + a4_confidence = "close" + else: + a4_confidence = "not_a4" + + # OCR 提取 + elements = _ocr_elements(img, file_path) + + if not elements: + return { + "is_a4_template": False, + "is_partial": False, + "template_type": "unknown", + "image_size": (w, h), + "aspect_ratio": round(ratio, 3), + "a4_confidence": a4_confidence, + "rows": [], + "description": _build_description([], w, h, a4_confidence, "unknown"), + "total_rows": 0, + "total_elements": 0, + } + + # 行分组 + rows = _group_into_rows(elements, h, row_tolerance_ratio) + + total = sum(len(r["elements"]) for r in rows) + + # 模板类型判定 + is_full_a4 = a4_confidence != "not_a4" and total >= 2 + is_partial = not is_full_a4 and total >= 1 # 非 A4 但有文字 → 行片段 + + if is_full_a4: + template_type = "full_a4" + elif is_partial: + template_type = "partial_rows" + else: + template_type = "unknown" + + description = _build_description(rows, w, h, a4_confidence, template_type) + + return { + "is_a4_template": is_full_a4, + "is_partial": is_partial, + "template_type": template_type, + "image_size": (w, h), + "aspect_ratio": round(ratio, 3), + "a4_confidence": a4_confidence, + "rows": rows, + "description": description, + "total_rows": len(rows), + "total_elements": total, + } + + +def match_rows_to_jrxml( + layout_result: dict, + current_jrxml: str, +) -> dict: + """将图片中的行与现有 JRXML 中的 section/band 做匹配。 + + 匹配策略: + 1. 从图片 OCR 文本中提取关键词 + 2. 在 JRXML 中搜索这些关键词出现在哪个 band + 3. 返回匹配结果,可用于定位修改位置 + + 返回: + { + "matched": bool, + "matched_rows": [{row_index, row_y_center, jrxml_section, confidence}], + "unmatched_rows": [...], + "description": str, # 人类可读的匹配结果 + } + """ + rows = layout_result.get("rows", []) + if not rows or not current_jrxml.strip(): + return {"matched": False, "matched_rows": [], "unmatched_rows": rows, + "description": "无行数据或 JRXML 为空"} + + # 解析 JRXML 结构 + jrxml_sections = _parse_jrxml_sections(current_jrxml) + + matched_rows = [] + unmatched_rows = [] + + for ri, row in enumerate(rows): + ocr_texts = [e["text"] for e in row["elements"]] + best_section = None + best_score = 0 + + for section in jrxml_sections: + score = _text_similarity(ocr_texts, section["text_content"]) + if score > best_score: + best_score = score + best_section = section + + if best_score > 0.3 and best_section: # 最低匹配阈值 + matched_rows.append({ + "row_index": ri, + "row_y_center": row["y_center"], + "jrxml_section": best_section["name"], + "jrxml_section_type": best_section["type"], + "confidence": round(best_score, 2), + "matched_text": best_section["text_content"][:200], + }) + else: + unmatched_rows.append({ + "row_index": ri, + "row_y_center": row["y_center"], + "ocr_texts": ocr_texts, + }) + + # 生成描述 + desc_parts = [] + if matched_rows: + desc_parts.append(f"图片中 {len(matched_rows)} 行匹配到当前 JRXML:") + for m in matched_rows: + desc_parts.append( + f" - 图片第 {m['row_index']+1} 行 → JRXML「{m['jrxml_section']}」" + f"({m['jrxml_section_type']},置信度 {m['confidence']})" + ) + if unmatched_rows: + desc_parts.append(f"图片中 {len(unmatched_rows)} 行未匹配到 JRXML 现有区域:") + for u in unmatched_rows: + texts = ", ".join(u["ocr_texts"][:3]) + desc_parts.append(f" - 图片第 {u['row_index']+1} 行:{texts}") + + return { + "matched": len(matched_rows) > 0, + "matched_rows": matched_rows, + "unmatched_rows": unmatched_rows, + "description": "\n".join(desc_parts), + } + + +def analyze_and_inject(file_path: str, base_prompt: str, + current_jrxml: str = "") -> str: + """分析布局并增强 prompt。 + + - 完整 A4 模板 → 全量布局描述 + - 行片段 + 有 JRXML → 行匹配 + 修改指引 + - 行片段 + 无 JRXML → 行片段描述(视为 A4 模板的一部分) + """ + result = analyze_layout(file_path) + tt = result.get("template_type", "unknown") + + if tt == "unknown": + return base_prompt + + if tt == "full_a4": + return f"[图片模板分析 — 完整 A4 报表]\n{result['description']}\n\n---\n原始需求:\n{base_prompt}" + + if tt == "partial_rows": + if current_jrxml.strip(): + match = match_rows_to_jrxml(result, current_jrxml) + if match["matched"]: + return ( + f"[图片模板分析 — 行片段修改]\n" + f"图片包含 {result['total_rows']} 行,视为 A4 模板的一部分。\n" + f"{match['description']}\n\n" + f"{result['description']}\n\n" + f"---\n请根据以上匹配结果,修改 JRXML 中对应区域的布局:\n{base_prompt}" + ) + else: + return ( + f"[图片模板分析 — 行片段(未匹配到现有区域)]\n" + f"图片包含 {result['total_rows']} 行。\n" + f"{result['description']}\n\n" + f"---\n请根据以上行结构,在 JRXML 中找到合适位置进行修改:\n{base_prompt}" + ) + else: + return ( + f"[图片模板分析 — 行片段(无现有报表,按 A4 模板处理)]\n" + f"图片包含 {result['total_rows']} 行,请按 A4 报表模板的需求输出整张报表。\n" + f"{result['description']}\n\n" + f"---\n原始需求:\n{base_prompt}" + ) + + return base_prompt + + +# --------------------------------------------------------------------------- +# JRXML 结构解析 +# --------------------------------------------------------------------------- + +def _parse_jrxml_sections(jrxml: str) -> list[dict]: + """解析 JRXML 中的 section/band 结构。 + + 直接搜索所有 band 元素,通过上下文字符串推断其所属 section。 + """ + sections = [] + try: + root = ET.fromstring(jrxml) + section_tags = { + "title", "pageHeader", "columnHeader", "detail", + "columnFooter", "pageFooter", "summary", "background", + "noData", "groupHeader", "groupFooter", + } + + for section_elem in root.iter(): + stag = _tag(section_elem) + if stag not in section_tags: + continue + + for child in section_elem: + if _tag(child) == "band": + name = child.get("name", "") + section_name = f"{stag}[{name}]" if name else stag + text_content = ET.tostring(child, encoding="unicode") + sections.append({ + "name": section_name, + "type": stag, + "text_content": text_content, + }) + except Exception: + pass + + # Fallback: 如果 structured parsing 失败,直接把整个 JRXML 按 band 分割 + if not sections: + sections = _parse_jrxml_regex(jrxml) + + return sections + + +def _tag(elem) -> str: + """去除命名空间前缀的标签名。""" + return elem.tag.split("}")[-1] if "}" in elem.tag else elem.tag + + +def _parse_jrxml_regex(jrxml: str) -> list[dict]: + """正则回退方案:直接在文本中搜索 band 块。""" + sections = [] + band_pattern = re.compile( + r'<(title|pageHeader|columnHeader|detail|columnFooter|pageFooter|summary|background|noData|groupHeader|groupFooter)>\s*' + r'(]*>.*?)\s*' + r'', + re.DOTALL, + ) + for m in band_pattern.finditer(jrxml): + stag = m.group(1) + band_xml = m.group(0) + sections.append({ + "name": stag, + "type": stag, + "text_content": band_xml, + }) + return sections + + +def _text_similarity(ocr_texts: list[str], jrxml_text: str) -> float: + """计算 OCR 文本与 JRXML 文本的相似度(简单的词匹配)。""" + if not ocr_texts or not jrxml_text: + return 0.0 + + jrxml_lower = jrxml_text.lower() + score = 0.0 + for text in ocr_texts: + # 精确匹配 + if text.lower() in jrxml_lower: + score += 1.0 + else: + # 部分词匹配 + words = re.findall(r"\w+", text) + matched = sum(1 for w in words if w.lower() in jrxml_lower) + if words: + score += matched / len(words) * 0.5 + + return min(score / len(ocr_texts), 1.0) + + +# --------------------------------------------------------------------------- +# 内部实现(不变) +# --------------------------------------------------------------------------- + +def _load_image(path: Path) -> Optional[PIL.Image.Image]: + suffix = path.suffix.lower() + + if suffix in (".png", ".jpg", ".jpeg", ".bmp", ".webp"): + try: + return PIL.Image.open(path).convert("RGB") + except Exception: + return None + + if suffix == ".pdf": + try: + import pdfplumber + with pdfplumber.open(path) as pdf: + if pdf.pages: + pil_img = pdf.pages[0].to_image(resolution=150) + return pil_img.original.convert("RGB") + except Exception: + pass + + try: + import fitz + doc = fitz.open(path) + pix = doc[0].get_pixmap(dpi=150) + img = PIL.Image.frombytes("RGB", [pix.width, pix.height], pix.samples) + doc.close() + return img + except Exception: + pass + + return None + + +def _ocr_elements(img: PIL.Image.Image, file_path: str) -> list[dict]: + try: + from paddleocr import PaddleOCR + import numpy as np + + ocr = PaddleOCR(lang="ch", use_angle_cls=True, show_log=False) + result = ocr.ocr(np.array(img)) + + elements = [] + if result and result[0]: + for line in result[0]: + if len(line) < 2: + continue + box = line[0] + text_info = line[1] + text = text_info[0] if isinstance(text_info, (list, tuple)) else str(text_info) + if not text.strip(): + continue + + xs = [p[0] for p in box] + ys = [p[1] for p in box] + x_min, x_max = min(xs), max(xs) + y_min, y_max = min(ys), max(ys) + + elements.append({ + "x": round(x_min, 1), + "y": round(y_min, 1), + "w": round(x_max - x_min, 1), + "h": round(y_max - y_min, 1), + "font_size": round(y_max - y_min, 1), + "text": text.strip(), + }) + + elements.sort(key=lambda e: (e["y"], e["x"])) + return elements + except Exception: + pass + + return [] + + +def _group_into_rows(elements: list[dict], img_height: int, + tolerance_ratio: float = 0.02) -> list[dict]: + if not elements: + return [] + + tolerance = img_height * tolerance_ratio + rows = [] + current_row = [elements[0]] + + for elem in elements[1:]: + prev_cy = current_row[0]["y"] + current_row[0]["h"] / 2 + curr_cy = elem["y"] + elem["h"] / 2 + + if abs(curr_cy - prev_cy) < tolerance: + current_row.append(elem) + else: + rows.append(_build_row(current_row)) + current_row = [elem] + + if current_row: + rows.append(_build_row(current_row)) + + return rows + + +def _build_row(elements: list[dict]) -> dict: + elements.sort(key=lambda e: e["x"]) + ys = [e["y"] for e in elements] + return {"y_center": round(sum(ys) / len(ys), 1), "elements": elements} + + +def _build_description(rows: list[dict], img_w: int, img_h: int, + a4_confidence: str, template_type: str) -> str: + if not rows: + if template_type == "partial_rows": + return f"图片 {img_w}x{img_h}(非 A4 比例),未检测到文字元素。" + return f"图片共 {img_w}x{img_h} 像素,未检测到文字元素。" + + lines = [] + if template_type == "full_a4": + lines.append(f"图片为完整 A4 报表模板,共 {len(rows)} 行,像素区域 {img_w}x{img_h}:") + elif template_type == "partial_rows": + lines.append(f"图片为报表模板行片段(非完整 A4),包含 {len(rows)} 行," + f"像素区域 {img_w}x{img_h},请按 A4 模板处理:") + else: + lines.append(f"图片共 {img_w}x{img_h} 像素,包含 {len(rows)} 行文字:") + + for i, row in enumerate(rows): + elems = row["elements"] + lines.append(f"\n第 {i+1} 行有 {len(elems)} 个元素:") + for j, e in enumerate(elems): + letter = chr(ord("a") + j) + lines.append( + f" 元素 {letter}:位置(x={e['x']}, y={e['y']})," + f"长 {e['w']}px,高 {e['h']}px," + f"字体 {e['font_size']}px," + f"内容「{e['text']}」" + ) + + if template_type == "full_a4": + lines.append(f"\n请根据以上布局生成对应的 JRXML 报表模板。") + elif template_type == "partial_rows": + lines.append(f"\n请将以上 {len(rows)} 行作为 A4 模板的一部分," + f"生成或修改对应的 JRXML 报表区域。") + + return "\n".join(lines) + + +def _empty_result(error: str = "") -> dict: + return { + "is_a4_template": False, + "is_partial": False, + "template_type": "unknown", + "image_size": (0, 0), + "aspect_ratio": 0, + "a4_confidence": "not_a4", + "rows": [], + "description": error, + "total_rows": 0, + "total_elements": 0, + } diff --git a/backend/llm.py b/backend/llm.py index 75f38e4..41ba648 100644 --- a/backend/llm.py +++ b/backend/llm.py @@ -8,13 +8,33 @@ from dotenv import load_dotenv load_dotenv() +class _BaseLLM: + """LLM 统一接口基类 — 所有后端都提供 invoke() 和 stream()。""" + + def invoke(self, prompt: str) -> Any: + raise NotImplementedError + + def stream(self, prompt: str): + raise NotImplementedError + + def get_llm(): backend = os.getenv("LLM_BACKEND", "cloud") if backend == "local": from langchain_ollama import ChatOllama model = os.getenv("LOCAL_LLM_MODEL", "qwen2.5-coder:7b") - return ChatOllama(model=model, temperature=0.1) + raw = ChatOllama(model=model, temperature=0.1) + + class OllamaWrapper(_BaseLLM): + def invoke(self, prompt): + return raw.invoke(prompt) + + def stream(self, prompt): + for chunk in raw.stream(prompt): + yield chunk.content + + return OllamaWrapper() provider = os.getenv("LLM_PROVIDER", "openai") if provider == "anthropic": @@ -30,7 +50,7 @@ def get_llm(): client = Anthropic(api_key=api_key, base_url=base_url, timeout=120) - class MiniMaxLLM: + class MiniMaxLLM(_BaseLLM): def invoke(self, prompt: str) -> Any: resp = client.messages.create( model=model, @@ -43,20 +63,44 @@ def get_llm(): return type("Response", (), {"content": block.text})() return type("Response", (), {"content": ""})() + def stream(self, prompt: str): + with client.messages.stream( + model=model, + max_tokens=max_tokens, + temperature=temperature, + messages=[{"role": "user", "content": [{"type": "text", "text": prompt}]}], + ) as s: + for text in s.text_stream: + yield text + def get_num_tokens(self, text: str) -> int: - return client.count_tokens(text) + resp = client.messages.count_tokens( + model=model, + messages=[{"role": "user", "content": [{"type": "text", "text": text}]}], + ) + return resp.input_tokens return MiniMaxLLM() else: from langchain_openai import ChatOpenAI - return ChatOpenAI( + raw = ChatOpenAI( model=os.getenv("LLM_MODEL", "gpt-4o"), api_key=os.getenv("OPENAI_API_KEY"), base_url=os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1"), temperature=0.1, ) + class OpenAIWrapper(_BaseLLM): + def invoke(self, prompt): + return raw.invoke(prompt) + + def stream(self, prompt): + for chunk in raw.stream(prompt): + yield chunk.content + + return OpenAIWrapper() + def get_llm_for_correction(): return get_llm() \ No newline at end of file diff --git a/prompts/compression.md b/prompts/compression.md new file mode 100644 index 0000000..1968e80 --- /dev/null +++ b/prompts/compression.md @@ -0,0 +1,12 @@ +你是一个信息压缩助手。以下是用户与报表生成助手之间的历史对话记录,请将其压缩为一份简洁的摘要(不超过200字)。 + +摘要必须保留以下关键信息: +- 用户提出的所有报表需求点(字段、标题、分组、汇总等) +- 用户提出的所有修改要求及其顺序 +- 当前报表的核心结构(字段列表、标题、分组方式) +- 任何特殊要求或约束条件 + +只输出摘要文本,不要添加任何解释或标记。 + +对话记录: +{conversation_text} diff --git a/prompts/consult.md b/prompts/consult.md new file mode 100644 index 0000000..4e282f7 --- /dev/null +++ b/prompts/consult.md @@ -0,0 +1,5 @@ +你是 JasperReports 专家。用简洁清晰的中文回答用户关于 JasperReports 的问题。 + +用户问题:{question} + +直接回答: diff --git a/prompts/correction.md b/prompts/correction.md new file mode 100644 index 0000000..4409a53 --- /dev/null +++ b/prompts/correction.md @@ -0,0 +1,17 @@ +你是一位资深 JasperReports 工程师。你生成的 JRXML 文件编译失败。分析错误并修复 JRXML。 + +关键规则: +- 只输出完整修复后的 JRXML 代码,不要解释,不要 markdown 标记。 +- JRXML 必须与 JasperReports 7.0.6 兼容。 +- 解决下面列出的特定错误。 + +当前 JRXML(带错误): +{current_jrxml} + +编译错误: +{error_msg} + +错误的自然语言解释: +{explanation} + +立即生成修正后的 JRXML: diff --git a/prompts/explain_error.md b/prompts/explain_error.md new file mode 100644 index 0000000..0c94a07 --- /dev/null +++ b/prompts/explain_error.md @@ -0,0 +1,9 @@ +你是一位 JasperReports 专家。用普通非技术语言解释以下 JRXML 编译错误,让业务用户能够理解。 + +错误消息: +{error_msg} + +当前 JRXML 片段(前 80 行): +{jrxml_snippet} + +用 2-3 句话解释哪里出了问题以及如何修复: diff --git a/prompts/initial_generation.md b/prompts/initial_generation.md new file mode 100644 index 0000000..48b43e1 --- /dev/null +++ b/prompts/initial_generation.md @@ -0,0 +1,15 @@ +你是一位资深 JasperReports 工程师。根据以下参考模板和用户需求,生成一个完整、可编译的 JRXML 文件。 +JRXML 必须兼容 JasperReports 7.0.6 schema。 + +关键规则: +- 只输出 JRXML 代码,不要解释,不要 markdown 标记。 +- 报表正文中使用的每个字段必须在 部分中声明。 +- 根元素为 ,包含正确的 xmlns 属性。 +- 包含 ,在 中包含 SQL 查询。 +- 确保所有交叉引用(字段名称、band 元素)保持一致。 + +参考模板和组件: +{context} + +用户需求: +{user_request} diff --git a/prompts/intent_classify.md b/prompts/intent_classify.md new file mode 100644 index 0000000..6b72ef1 --- /dev/null +++ b/prompts/intent_classify.md @@ -0,0 +1,16 @@ +你是意图分类器。根据用户输入判断意图,只输出意图名称。 + +当前有报表:{has_report} +用户输入:{user_input} + +可选意图: +- initial_generation(新建报表,或无报表时的任何需求) +- modify_report(修改当前已有报表) +- preview_report(预览/查看当前报表) +- export_pdf(导出PDF文件) +- export_jrxml(下载/导出/保存JRXML文件) +- undo_modification(撤销/回退上一步修改) +- consult_question(咨询JasperReports相关知识或使用问题) +- reset_session(清空/重置/重新开始) + +意图名称: diff --git a/prompts/loader.py b/prompts/loader.py new file mode 100644 index 0000000..2a324d9 --- /dev/null +++ b/prompts/loader.py @@ -0,0 +1,53 @@ +"""Prompt 加载器:从 prompts/ 目录加载 .md 文件。 + +支持热重载 — 每次调用都从磁盘读取,修改 prompt 文件无需重启应用。 + +用法: + from prompts.loader import load_prompt + prompt = load_prompt("intent_classify").format(has_report="是", user_input="...") +""" + +import re +from pathlib import Path + +_PROMPTS_DIR = Path(__file__).resolve().parent + +# 文件名 → 变量名 映射 +_NAME_MAP = { + "intent_classify": "intent_classify.md", + "consult": "consult.md", + "initial_generation": "initial_generation.md", + "modification": "modification.md", + "correction": "correction.md", + "explain_error": "explain_error.md", + "compression": "compression.md", +} + + +def load_prompt(name: str) -> str: + """从 prompts/{name}.md 加载 prompt 模板(每次从磁盘读取,支持热重载)。 + + 返回的字符串包含 Python .format() 占位符,调用方负责填充。 + """ + filename = _NAME_MAP.get(name) + if not filename: + raise ValueError(f"未知 prompt: {name},可选值: {list(_NAME_MAP.keys())}") + + filepath = _PROMPTS_DIR / filename + if not filepath.exists(): + raise FileNotFoundError(f"Prompt 文件不存在: {filepath}") + + text = filepath.read_text(encoding="utf-8").strip() + + # 去掉可能存在的 markdown frontmatter(--- 包裹的元数据) + if text.startswith("---"): + end = text.find("---", 3) + if end != -1: + text = text[end + 3:].strip() + + return text + + +def list_prompts() -> list[str]: + """列出所有可用的 prompt 名称。""" + return sorted(_NAME_MAP.keys()) diff --git a/prompts/modification.md b/prompts/modification.md new file mode 100644 index 0000000..be8e6d1 --- /dev/null +++ b/prompts/modification.md @@ -0,0 +1,18 @@ +你是一位资深 JasperReports 工程师。用户想要修改一个现有的、可编译的 JRXML 报表。精确应用请求的更改到当前 JRXML 并输出完整修改后的 JRXML。 + +关键规则: +- 只输出完整修改后的 JRXML 代码,不要解释,不要 markdown 标记。 +- 保留所有未被更改的现有结构。 +- 结果必须继续与 JasperReports 7.0.6 兼容。 +- 报表正文中使用的每个字段必须在 部分中声明。 +- 如果添加新字段,正确声明它们。 +- 确保 中有效的 SQL。 + +当前 JRXML: +{current_jrxml} + +对话历史: +{conversation_history} + +用户的修改请求: +{modification_request}