From 1952d75f1323c07c34401dad6b2b8d4e2c4fafab Mon Sep 17 00:00:00 2001
From: panda <1415243231@qq.com>
Date: Sat, 23 May 2026 08:38:29 +0800
Subject: [PATCH] test: add unit/integration/E2E test suites, fix
create_session bug, update docs
- Unit tests: test_session.py (27), test_error_kb.py (24), test_agent.py hardened
- Integration tests: test_api_integration.py (25) with FastAPI TestClient
- E2E tests: main-flows.spec.ts (8) with Playwright + API mocking
- Bug fix: backend/session.py create_session() missing session_id parameter
- Config: frontend/playwright.config.ts, npm run test:e2e
- Docs: update CLAUDE.md v9, .gitignore for test artifacts/eval reports
---
.gitignore | 13 ++
CLAUDE.md | 33 ++++
backend/session.py | 7 +-
frontend/package-lock.json | 64 +++++++
frontend/package.json | 4 +-
frontend/playwright.config.ts | 20 ++
frontend/tests/e2e/main-flows.spec.ts | 168 ++++++++++++++++
tests/test_agent.py | 17 +-
tests/test_api_integration.py | 263 ++++++++++++++++++++++++++
tests/test_error_kb.py | 242 ++++++++++++++++++++++++
tests/test_session.py | 210 ++++++++++++++++++++
11 files changed, 1029 insertions(+), 12 deletions(-)
create mode 100644 frontend/playwright.config.ts
create mode 100644 frontend/tests/e2e/main-flows.spec.ts
create mode 100644 tests/test_api_integration.py
create mode 100644 tests/test_error_kb.py
create mode 100644 tests/test_session.py
diff --git a/.gitignore b/.gitignore
index 4bddf00..ad54e14 100644
--- a/.gitignore
+++ b/.gitignore
@@ -13,6 +13,19 @@ db/chroma/
sessions/
logs/
db/
+# 自动评测 (Mavis AI)
+.mavis/
+EVALUATION_REPORT.md
+
+# 上传文件
+uploads/
+
+# OCR 临时输出
+ocr_raw_positions.json
+
+# Playwright E2E 测试产物
+frontend/test-results/
+
# RAG 管线中间产物 (rag 子模块内)
rag/jrxml_chunker_output/
rag/embeddings/
diff --git a/CLAUDE.md b/CLAUDE.md
index 86b4b9f..7a0ccd5 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -327,3 +327,36 @@ validation_service/ (FastAPI, 端口 8001) — 不变
- `prompts/refine_layout.md` — 1 处
Python 将 `{{` 输出为字面量 `{`,LLM 看到的内容不变。
+
+## 更新 (v9 — 2026-05-22)
+
+### 测试基础设施全面补齐
+
+**单元测试** (76 测试):
+- `tests/test_session.py` — 27 测试:会话 CRUD、原子写入、唯一 ID、损坏 JSON 跳过
+- `tests/test_error_kb.py` — 24 测试:指纹去重、关键词提取(中/英/JRXML)、ErrorKB CRUD、搜索、统计
+- `tests/test_agent.py` — 5 个软断言强化为严格断言(`status`/`current_jrxml` 存在性检查)
+- 已有测试:`test_ocr_extraction.py`(49)、`test_layered_generation.py`(19)、`test_validation.py`(6)、`test_file_parser_formats.py`(4)、`test_annotation_detector.py`(7)、`test_e2e_ocr.py`(3)
+
+**集成测试** (25 测试, `tests/test_api_integration.py`):
+- FastAPI TestClient 全覆盖:健康检查、配置、会话 CRUD、文件上传、下载、Chat SSE、安全边界(路径穿越/非法 JSON/大 payload)
+- Mock LangGraph graph 避免真实 LLM 调用
+
+**E2E 测试** (8 测试, `frontend/tests/e2e/main-flows.spec.ts`):
+- Playwright 浏览器自动化:页面加载、侧边栏、会话管理、聊天流程、输入 UX
+- 全量 API Mock(`page.route`)无需后端运行
+- 配置: `frontend/playwright.config.ts`, `npm run test:e2e`
+
+**运行测试**:
+```bash
+# 全部单元+集成测试
+cd D:\Idea Project\jaspersoft && python -m pytest tests/ -v
+
+# 仅 E2E(需要前端 dev server)
+cd frontend && npx playwright test
+```
+
+### Bug 修复: create_session 参数缺失
+
+`backend/session.py` — `create_session()` 新增可选参数 `session_id: Optional[str] = None`。
+`api_server.py:507` 调用 `create_session(session_id=session_id)` 时之前会抛出 `TypeError`。
diff --git a/backend/session.py b/backend/session.py
index 1333592..22472c0 100644
--- a/backend/session.py
+++ b/backend/session.py
@@ -34,10 +34,11 @@ def generate_session_id() -> str:
return uuid.uuid4().hex[:12]
-def create_session(name: str = "", agent_state: Optional[dict] = None) -> dict:
- """创建新会话,返回会话元数据。"""
+def create_session(name: str = "", agent_state: Optional[dict] = None,
+ session_id: Optional[str] = None) -> dict:
+ """创建新会话,返回会话元数据。session_id 可选——传入时使用指定 ID。"""
_ensure_dir()
- sid = generate_session_id()
+ sid = session_id or generate_session_id()
now = _now_iso()
agent_state = agent_state or {}
agent_state["session_id"] = sid
diff --git a/frontend/package-lock.json b/frontend/package-lock.json
index b3f19b2..a22acc1 100644
--- a/frontend/package-lock.json
+++ b/frontend/package-lock.json
@@ -12,6 +12,7 @@
"vue": "^3.5.34"
},
"devDependencies": {
+ "@playwright/test": "^1.60.0",
"@types/node": "^24.12.3",
"@vitejs/plugin-vue": "^6.0.6",
"@vue/tsconfig": "^0.9.1",
@@ -135,6 +136,22 @@
"url": "https://github.com/sponsors/Boshen"
}
},
+ "node_modules/@playwright/test": {
+ "version": "1.60.0",
+ "resolved": "https://registry.npmjs.org/@playwright/test/-/test-1.60.0.tgz",
+ "integrity": "sha512-O71yZIbAh/PxDMNGns37GHBIfrVkEVyn+AXyIa5dOTfb4/xNvRWV+Vv/NMbNCtODB/pO7vLlF2OTmMVLhmr7Ag==",
+ "dev": true,
+ "license": "Apache-2.0",
+ "dependencies": {
+ "playwright": "1.60.0"
+ },
+ "bin": {
+ "playwright": "cli.js"
+ },
+ "engines": {
+ "node": ">=18"
+ }
+ },
"node_modules/@rolldown/binding-android-arm64": {
"version": "1.0.2",
"resolved": "https://registry.npmjs.org/@rolldown/binding-android-arm64/-/binding-android-arm64-1.0.2.tgz",
@@ -1104,6 +1121,53 @@
}
}
},
+ "node_modules/playwright": {
+ "version": "1.60.0",
+ "resolved": "https://registry.npmjs.org/playwright/-/playwright-1.60.0.tgz",
+ "integrity": "sha512-hheHdokM8cdqCb0lcE3s+zT4t4W+vvjpGxsZlDnikarzx8tSzMebh3UiFtgqwFwnTnjYQcsyMF8ei2mCO/tpeA==",
+ "dev": true,
+ "license": "Apache-2.0",
+ "dependencies": {
+ "playwright-core": "1.60.0"
+ },
+ "bin": {
+ "playwright": "cli.js"
+ },
+ "engines": {
+ "node": ">=18"
+ },
+ "optionalDependencies": {
+ "fsevents": "2.3.2"
+ }
+ },
+ "node_modules/playwright-core": {
+ "version": "1.60.0",
+ "resolved": "https://registry.npmjs.org/playwright-core/-/playwright-core-1.60.0.tgz",
+ "integrity": "sha512-9bW6zvX/m0lEbgTKJ6YppOKx8H3VOPBMOCFh2irXFOT4BbHgrx5hPjwJYLT40Lu+4qtD36qKc/Hn56StUW57IA==",
+ "dev": true,
+ "license": "Apache-2.0",
+ "bin": {
+ "playwright-core": "cli.js"
+ },
+ "engines": {
+ "node": ">=18"
+ }
+ },
+ "node_modules/playwright/node_modules/fsevents": {
+ "version": "2.3.2",
+ "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.2.tgz",
+ "integrity": "sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA==",
+ "dev": true,
+ "hasInstallScript": true,
+ "license": "MIT",
+ "optional": true,
+ "os": [
+ "darwin"
+ ],
+ "engines": {
+ "node": "^8.16.0 || ^10.6.0 || >=11.0.0"
+ }
+ },
"node_modules/postcss": {
"version": "8.5.15",
"resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.15.tgz",
diff --git a/frontend/package.json b/frontend/package.json
index 981e7f3..7fdaf49 100644
--- a/frontend/package.json
+++ b/frontend/package.json
@@ -6,13 +6,15 @@
"scripts": {
"dev": "vite",
"build": "vue-tsc -b && vite build",
- "preview": "vite preview"
+ "preview": "vite preview",
+ "test:e2e": "playwright test"
},
"dependencies": {
"pinia": "^3.0.4",
"vue": "^3.5.34"
},
"devDependencies": {
+ "@playwright/test": "^1.60.0",
"@types/node": "^24.12.3",
"@vitejs/plugin-vue": "^6.0.6",
"@vue/tsconfig": "^0.9.1",
diff --git a/frontend/playwright.config.ts b/frontend/playwright.config.ts
new file mode 100644
index 0000000..f02df11
--- /dev/null
+++ b/frontend/playwright.config.ts
@@ -0,0 +1,20 @@
+import { defineConfig } from "@playwright/test";
+
+export default defineConfig({
+ testDir: "./tests/e2e",
+ timeout: 60000,
+ expect: { timeout: 10000 },
+ retries: 0,
+ use: {
+ baseURL: "http://localhost:5173",
+ headless: true,
+ screenshot: "only-on-failure",
+ trace: "retain-on-failure",
+ },
+ webServer: {
+ command: "npm run dev",
+ url: "http://localhost:5173",
+ reuseExistingServer: true,
+ timeout: 30000,
+ },
+});
diff --git a/frontend/tests/e2e/main-flows.spec.ts b/frontend/tests/e2e/main-flows.spec.ts
new file mode 100644
index 0000000..b394013
--- /dev/null
+++ b/frontend/tests/e2e/main-flows.spec.ts
@@ -0,0 +1,168 @@
+/**
+ * E2E tests: key user flows for the JRXML Agent frontend.
+ *
+ * Pre-requisites: npm run dev (reuseExistingServer in playwright.config).
+ * API calls are intercepted by page.route() — no real backend needed.
+ */
+
+import { test, expect } from "@playwright/test";
+
+// ── helpers ────────────────────────────────────────────────────
+
+function mockApi(page: any) {
+ page.route("**/api/health", (route: any) =>
+ route.fulfill({ json: { status: "ok", version: "5.0" } })
+ );
+
+ page.route("**/api/sessions", (route: any) => {
+ if (route.request().method() === "POST") {
+ return route.fulfill({
+ json: {
+ session_id: "test12345678",
+ session_name: "新建报表 2026-05-22",
+ created_at: "2026-05-22T10:00:00.000Z",
+ updated_at: "2026-05-22T10:00:00.000Z",
+ },
+ });
+ }
+ return route.fulfill({ json: { sessions: [] } });
+ });
+
+ page.route("**/api/sessions/*/chat", (route: any) => {
+ const sseBody = [
+ "event: node_start",
+ 'data: {"node":"classify_intent","label":"识别意图","step_index":1}',
+ "",
+ "event: node_complete",
+ 'data: {"node":"classify_intent","label":"识别意图","detail":"意图: 新建报表"}',
+ "",
+ "event: agent_complete",
+ 'data: {"reason":"done","intent":"initial_generation","status":"pass","jrxml_length":42,"versions":1,"total_duration_ms":1200}',
+ "",
+ "",
+ ].join("\n");
+ return route.fulfill({
+ status: 200,
+ headers: { "content-type": "text/event-stream" },
+ body: sseBody,
+ });
+ });
+
+ page.route("**/api/upload", (route: any) =>
+ route.fulfill({
+ json: { file_id: "f001122334455", filename: "test.png", size: 1024 },
+ })
+ );
+
+ // Catch-all for GET/DELETE /api/sessions/:id (must fallback for POST to let chat route match)
+ page.route("**/api/sessions/**", (route: any) => {
+ if (route.request().method() === "DELETE") {
+ return route.fulfill({ json: { status: "deleted" } });
+ }
+ if (route.request().method() === "GET") {
+ return route.fulfill({
+ json: {
+ session_id: "test12345678",
+ session_name: "测试会话",
+ agent_state: { current_jrxml: "" },
+ },
+ });
+ }
+ return route.fallback();
+ });
+}
+
+// ── tests ──────────────────────────────────────────────────────
+
+test.describe("Page load", () => {
+ test("renders sidebar and input area", async ({ page }) => {
+ await mockApi(page);
+ await page.goto("/");
+
+ await expect(page.locator("aside.sidebar")).toBeVisible();
+ await expect(page.locator("h2")).toContainText("JRXML");
+ await expect(page.locator(".unified-input")).toBeVisible();
+ });
+
+ test("sidebar shows session list header and new button", async ({ page }) => {
+ await mockApi(page);
+ await page.goto("/");
+
+ await expect(page.getByText("会话列表")).toBeVisible();
+ await expect(page.locator('button[title="新建会话"]')).toBeVisible();
+ });
+});
+
+test.describe("Session management", () => {
+ test("creates new session on button click", async ({ page }) => {
+ await mockApi(page);
+ await page.goto("/");
+
+ await page.locator('button[title="新建会话"]').click();
+ await expect(page.locator(".session-item")).toBeVisible({ timeout: 5000 });
+ await expect(page.locator(".session-item.active")).toBeVisible();
+ });
+
+ test("can delete current session", async ({ page }) => {
+ await mockApi(page);
+ await page.goto("/");
+
+ await page.locator('button[title="新建会话"]').click();
+ await expect(page.locator(".session-item")).toBeVisible({ timeout: 5000 });
+
+ page.on("dialog", (dialog) => dialog.accept());
+ await page.locator(".btn-delete").click();
+
+ await expect(page.locator(".session-item")).toHaveCount(0, { timeout: 5000 });
+ });
+});
+
+test.describe("Chat flow", () => {
+ test("sends text and displays user message + process section", async ({ page }) => {
+ await mockApi(page);
+ await page.goto("/");
+
+ await page.locator('button[title="新建会话"]').click();
+ await expect(page.locator(".session-item")).toBeVisible({ timeout: 5000 });
+
+ const textarea = page.locator(".unified-input textarea");
+ await textarea.fill("生成一个员工名册报表");
+ await page.locator(".send-btn").click();
+
+ await expect(
+ page.locator(".chat-messages .message.msg-user").filter({ hasText: "员工名册" })
+ ).toBeVisible({ timeout: 10000 });
+
+ await expect(page.locator(".process-section")).toBeVisible({ timeout: 10000 });
+ });
+
+ test("summary card appears after stream complete", async ({ page }) => {
+ await mockApi(page);
+ await page.goto("/");
+
+ await page.locator('button[title="新建会话"]').click();
+ await expect(page.locator(".session-item")).toBeVisible({ timeout: 5000 });
+
+ await page.locator(".unified-input textarea").fill("生成报表");
+ await page.locator(".send-btn").click();
+
+ await expect(page.locator(".summary-card")).toBeVisible({ timeout: 15000 });
+ });
+});
+
+test.describe("Input UX", () => {
+ test("send button disabled when input empty", async ({ page }) => {
+ await mockApi(page);
+ await page.goto("/");
+
+ await expect(page.locator(".send-btn")).toBeDisabled();
+ });
+
+ test("send button enabled when text entered", async ({ page }) => {
+ await mockApi(page);
+ await page.goto("/");
+
+ await page.locator(".unified-input textarea").fill("Hi");
+ await expect(page.locator(".send-btn")).toBeEnabled();
+ });
+});
diff --git a/tests/test_agent.py b/tests/test_agent.py
index 787848d..34153a6 100644
--- a/tests/test_agent.py
+++ b/tests/test_agent.py
@@ -44,8 +44,8 @@ class TestAcceptanceScenarios:
final = run_graph(graph, state)
assert final.get("current_jrxml"), "应该已生成 JRXML"
- # 注意:通过/失败取决于 LLM 输出质量;我们检查是否得到了结果
- print(f"场景 1 状态: {final.get('status')}, 错误: {final.get('error_msg', '')[:100]}")
+ assert final.get("status") in ("pass", "fail"), f"意外状态: {final.get('status')}"
+ assert ""})
+
+ resp = client.get(f"/api/sessions/{sid}/download/latest")
+ assert resp.status_code == 200
+ assert "", "status": "pass"}}),
+ ("updates", {"validate": {"status": "pass"}}),
+ ("updates", {"finalize": {}}),
+ ("done", {"reason": "graph_completed"}),
+ ]
+ # 注意:_graph 是模块级变量,在导入时就编译了。需要直接替换。
+ monkeypatch.setattr("api_server._graph", mock_graph)
+ # 同时替换 agent.graph.build_graph 以防后续重新编译
+ monkeypatch.setattr("agent.graph.build_graph", lambda on_node_start=None: mock_graph)
+ return mock_graph
+
+ def test_empty_payload_rejected(self, client, temp_sessions):
+ sid = client.post("/api/sessions").json()["session_id"]
+ resp = client.post(
+ f"/api/sessions/{sid}/chat",
+ json={"text": "", "file_ids": []},
+ )
+ assert resp.status_code == 400
+
+ def test_sse_stream_returns_valid_events(self, client, temp_sessions):
+ sid = client.post("/api/sessions").json()["session_id"]
+ with client.stream(
+ "POST",
+ f"/api/sessions/{sid}/chat",
+ json={"text": "生成一个简单的员工名册报表", "file_ids": []},
+ ) as resp:
+ assert resp.status_code == 200
+ assert resp.headers["content-type"].startswith("text/event-stream")
+ body = resp.read().decode("utf-8", errors="replace")
+ assert "event: node_complete" in body
+ assert "event: agent_complete" in body
+
+ def test_auto_creates_session_on_chat(self, client, temp_sessions):
+ with client.stream(
+ "POST",
+ "/api/sessions/auto_new_session/chat",
+ json={"text": "生成报表", "file_ids": []},
+ ) as resp:
+ assert resp.status_code == 200
+ assert b"event:" in resp.read()
+
+ def test_unknown_file_ids_not_crash(self, client, temp_sessions):
+ sid = client.post("/api/sessions").json()["session_id"]
+ with client.stream(
+ "POST",
+ f"/api/sessions/{sid}/chat",
+ json={"text": "测试", "file_ids": ["fake_id_xyz"]},
+ ) as resp:
+ assert resp.status_code == 200
+
+
+# ── 边界 & 安全测试 ────────────────────────────────────────────
+
+class TestBoundaries:
+ def test_session_id_path_traversal_returns_404(self, client, temp_sessions):
+ assert client.get("/api/sessions/../etc/passwd").status_code == 404
+
+ def test_upload_with_path_traversal_session_id(self, client, temp_sessions):
+ """路径穿越 session_id 仍正常处理(目录隔离在 UPLOADS_DIR 内)。"""
+ resp = client.post(
+ "/api/upload?session_id=../malicious",
+ files={"file": ("t.txt", io.BytesIO(b"x"), "text/plain")},
+ )
+ assert resp.status_code == 200
+
+ def test_invalid_json_body_rejected(self, client, temp_sessions):
+ sid = client.post("/api/sessions").json()["session_id"]
+ resp = client.post(
+ f"/api/sessions/{sid}/chat",
+ content=b"{not valid json",
+ headers={"Content-Type": "application/json"},
+ )
+ assert resp.status_code == 422
+
+ def test_large_payload_survives(self, client, temp_sessions):
+ """大文本(100KB)不应崩溃。"""
+ sid = client.post("/api/sessions").json()["session_id"]
+ large_text = "生成报表包含字段: " + ", ".join(f"field_{i}" for i in range(5000))
+ with client.stream(
+ "POST",
+ f"/api/sessions/{sid}/chat",
+ json={"text": large_text, "file_ids": []},
+ ) as resp:
+ assert resp.status_code == 200
diff --git a/tests/test_error_kb.py b/tests/test_error_kb.py
new file mode 100644
index 0000000..0df795f
--- /dev/null
+++ b/tests/test_error_kb.py
@@ -0,0 +1,242 @@
+"""backend/error_kb.py 单元测试 — 指纹去重 + 关键词提取 + CRUD。
+
+覆盖:
+ - _make_fingerprint 标准化与去重
+ - _extract_keywords 中英文混合提取
+ - ErrorKB.record / exists / search / search_as_context(mock ChromaDB)
+ - 全局便捷函数 record_error / search_error_cases
+"""
+
+import os
+import sys
+import json
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+from backend.error_kb import (
+ _make_fingerprint,
+ _extract_keywords,
+ ErrorKB,
+ get_error_kb,
+ record_error,
+ search_error_cases,
+)
+
+
+# ── _make_fingerprint ───────────────────────────────────────────
+
+class TestMakeFingerprint:
+ def test_same_structure_same_fingerprint(self):
+ e1 = "Field $F{customer_name} is not declared in the report"
+ e2 = "Field $F{order_total} is not declared in the report"
+ assert _make_fingerprint(e1) == _make_fingerprint(e2)
+
+ def test_different_errors_different_fingerprint(self):
+ e1 = "Missing required attribute pageWidth"
+ e2 = "Query returned 0 results"
+ assert _make_fingerprint(e1) != _make_fingerprint(e2)
+
+ def test_normalizes_variable_names(self):
+ fp1 = _make_fingerprint("Field $F{amount} not found")
+ fp2 = _make_fingerprint("Field $F{total_price} not found")
+ assert fp1 == fp2
+
+ def test_normalizes_string_literals_single_quote(self):
+ fp1 = _make_fingerprint("Value 'abc123' is invalid")
+ fp2 = _make_fingerprint("Value 'xyz789' is invalid")
+ assert fp1 == fp2
+
+ def test_normalizes_string_literals_double_quote(self):
+ fp1 = _make_fingerprint('Name "test_table" not found')
+ fp2 = _make_fingerprint('Name "prod_table" not found')
+ assert fp1 == fp2
+
+ def test_normalizes_numbers(self):
+ fp1 = _make_fingerprint("Line 42 has 100 errors")
+ fp2 = _make_fingerprint("Line 7 has 3 errors")
+ assert fp1 == fp2
+
+ def test_case_insensitive(self):
+ assert _make_fingerprint("ERROR: Missing Field") == _make_fingerprint("error: missing field")
+
+ def test_whitespace_insensitive(self):
+ e1 = "missing field\n\ndeclaration"
+ e2 = "missing field declaration"
+ assert _make_fingerprint(e1) == _make_fingerprint(e2)
+
+ def test_output_is_16_char_hex(self):
+ fp = _make_fingerprint("some error message")
+ assert len(fp) == 16
+ assert all(c in "0123456789abcdef" for c in fp)
+
+
+# ── _extract_keywords ───────────────────────────────────────────
+
+class TestExtractKeywords:
+ def test_extracts_chinese_words(self):
+ kw = _extract_keywords("未声明的字段引用和语法错误")
+ has_cn = any(len(k) >= 2 and "一" <= k[0] <= "鿿" for k in kw)
+ assert has_cn
+
+ def test_extracts_english_tokens(self):
+ kw = _extract_keywords("missing field declaration in report")
+ assert "missing" in kw
+ assert "field" in kw
+ assert "report" in kw
+
+ def test_extracts_jrxml_patterns(self):
+ kw = _extract_keywords("Field $F{customer_name} not declared")
+ assert "$F{customer_name}" in kw
+
+ def test_short_tokens_ignored(self):
+ kw = _extract_keywords("a b c ab cd")
+ assert "ab" not in kw
+ assert "cd" not in kw
+
+ def test_empty_input_returns_empty_list(self):
+ assert _extract_keywords("") == []
+
+ def test_mixed_cn_en_jrxml(self):
+ kw = _extract_keywords("字段 $F{amount} 在 report 中未声明")
+ assert "$F{amount}" in kw
+ assert "report" in kw
+
+
+# ── ErrorKB class (mock ChromaDB) ───────────────────────────────
+
+def _make_patched_kb(client_override=None, collection_override=None):
+ """创建一个 ErrorKB 实例,其 ChromaDB 依赖已被 mock。
+
+ 因为 chromadb 是懒加载的(在 client/collection property 中导入),
+ 直接设置 _client/_collection 实例属性即可绕过真实 ChromaDB。
+ """
+ kb = ErrorKB()
+ kb._client = client_override or MagicMock()
+ kb._collection = collection_override or MagicMock()
+ if not client_override and not collection_override:
+ # 默认:client.get_collection 返回 mock collection
+ kb._client.get_collection.return_value = kb._collection
+ return kb
+
+
+class TestErrorKBRecord:
+ def test_exists_returns_true_when_found(self):
+ col = MagicMock()
+ col.get.return_value = {"ids": ["abc123"]}
+ kb = _make_patched_kb(collection_override=col)
+ assert kb.exists("some error") is True
+
+ def test_exists_returns_false_when_not_found(self):
+ col = MagicMock()
+ col.get.return_value = {"ids": []}
+ kb = _make_patched_kb(collection_override=col)
+ assert kb.exists("some error") is False
+
+ def test_exists_survives_exception(self):
+ col = MagicMock()
+ col.get.side_effect = RuntimeError("db down")
+ kb = _make_patched_kb(collection_override=col)
+ assert kb.exists("some error") is False
+
+ def test_record_skips_duplicate(self):
+ col = MagicMock()
+ col.get.return_value = {"ids": ["existing_fp"]}
+ kb = _make_patched_kb(collection_override=col)
+ assert kb.record("error", "", "", "fix prompt") is False
+ col.add.assert_not_called()
+
+ def test_record_adds_new_case(self):
+ col = MagicMock()
+ col.get.return_value = {"ids": []}
+ kb = _make_patched_kb(collection_override=col)
+ assert kb.record(
+ "Field $F{x} not declared",
+ "", "",
+ "prompt content", model="test-model", retry_count=2,
+ ) is True
+ col.add.assert_called_once()
+ meta = col.add.call_args[1]["metadatas"][0]
+ assert meta["retry_success"] == 3
+
+
+class TestErrorKBSearch:
+ @pytest.fixture
+ def col(self):
+ return MagicMock()
+
+ @pytest.fixture
+ def kb(self, col):
+ return _make_patched_kb(collection_override=col)
+
+ def test_search_returns_formatted_results(self, kb, col):
+ col.get.return_value = {"ids": []}
+ col.query.return_value = {
+ "ids": [["fp1"]],
+ "documents": [[json.dumps({
+ "error": "test error",
+ "good_jrxml_snippet": "",
+ "correction_prompt": "fix it",
+ "recorded_at": "2026-01-01T00:00:00",
+ })]],
+ "metadatas": [[{}]],
+ "distances": [[0.05]],
+ }
+ results = kb.search("some error", k=3)
+ assert len(results) == 1
+ assert results[0]["error"] == "test error"
+ assert results[0]["distance"] == 0.05
+
+ def test_search_returns_empty_on_exception(self, kb, col):
+ col.query.side_effect = RuntimeError("fail")
+ assert kb.search("error") == []
+
+ def test_search_as_context_formats_output(self, kb, col):
+ col.get.return_value = {"ids": []}
+ col.query.return_value = {
+ "ids": [["fp1", "fp2"]],
+ "documents": [[
+ json.dumps({"error": "e1", "good_jrxml_snippet": "", "correction_prompt": "p1", "recorded_at": ""}),
+ json.dumps({"error": "e2", "good_jrxml_snippet": "", "correction_prompt": "p2", "recorded_at": ""}),
+ ]],
+ "metadatas": [[{}, {}]],
+ "distances": [[0.1, 0.2]],
+ }
+ ctx = kb.search_as_context("error", k=2)
+ assert "[历史错误案例]" in ctx
+ assert "---" in ctx
+
+ def test_search_as_context_empty_for_no_results(self, kb, col):
+ col.get.return_value = {"ids": []}
+ col.query.return_value = {"ids": [[]], "documents": [[]], "distances": [[]]}
+ assert kb.search_as_context("error") == ""
+
+ def test_stats_returns_count(self, kb, col):
+ col.count.return_value = 42
+ assert kb.stats()["total_cases"] == 42
+
+ def test_stats_zero_on_exception(self, kb, col):
+ col.count.side_effect = RuntimeError("down")
+ assert kb.stats()["total_cases"] == 0
+
+
+# ── 全局便捷函数 ───────────────────────────────────────────────
+
+class TestConvenienceFunctions:
+ def test_get_error_kb_is_singleton(self, monkeypatch):
+ import backend.error_kb as mod
+ monkeypatch.setattr(mod, "_kb", None)
+ assert get_error_kb() is get_error_kb()
+
+ def test_record_error_delegates(self):
+ with patch.object(ErrorKB, "record", return_value=True) as mock_r:
+ assert record_error("e", "", "", "p") is True
+ mock_r.assert_called_once()
+
+ def test_search_error_cases_delegates(self):
+ with patch.object(ErrorKB, "search_as_context", return_value="ctx") as mock_s:
+ assert search_error_cases("err", k=5) == "ctx"
+ mock_s.assert_called_once_with("err", k=5)
diff --git a/tests/test_session.py b/tests/test_session.py
new file mode 100644
index 0000000..6a0414b
--- /dev/null
+++ b/tests/test_session.py
@@ -0,0 +1,210 @@
+"""backend/session.py 单元测试 — 会话 CRUD + 原子写入。
+
+覆盖:
+ - 创建/加载/保存/删除/列出
+ - 原子写入(tempfile + os.replace)
+ - 边界情况(不存在会话、损坏 JSON、空名称自动填充)
+"""
+
+import json
+import os
+import sys
+import tempfile
+import time
+from pathlib import Path
+
+import pytest
+
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+from backend.session import (
+ create_session,
+ load_session,
+ save_session,
+ get_session_state,
+ list_all_sessions,
+ delete_session,
+ generate_session_id,
+ SESSIONS_DIR,
+)
+
+
+@pytest.fixture
+def temp_sessions_dir(monkeypatch):
+ with tempfile.TemporaryDirectory(prefix="test_sessions_") as tmpdir:
+ monkeypatch.setattr("backend.session.SESSIONS_DIR", Path(tmpdir))
+ yield Path(tmpdir)
+
+
+# ── create_session ──────────────────────────────────────────────
+
+class TestCreateSession:
+ def test_creates_with_defaults(self, temp_sessions_dir):
+ s = create_session()
+ assert len(s["session_id"]) == 12
+ assert "新建报表" in s["session_name"]
+ assert s["created_at"]
+ assert s["updated_at"]
+
+ def test_custom_name(self, temp_sessions_dir):
+ s = create_session(name="测试报表")
+ assert s["session_name"] == "测试报表"
+
+ def test_agent_state_preserved(self, temp_sessions_dir):
+ s = create_session(agent_state={"current_jrxml": ""})
+ assert s["agent_state"]["current_jrxml"] == ""
+
+ def test_session_id_injected_into_agent_state(self, temp_sessions_dir):
+ s = create_session()
+ assert s["agent_state"]["session_id"] == s["session_id"]
+
+ def test_persists_json_to_disk(self, temp_sessions_dir):
+ s = create_session(name="磁盘测试")
+ fp = temp_sessions_dir / f"{s['session_id']}.json"
+ assert fp.exists()
+ loaded = json.loads(fp.read_text("utf-8"))
+ assert loaded["session_name"] == "磁盘测试"
+
+ def test_unique_ids_no_collision(self, temp_sessions_dir):
+ ids = {generate_session_id() for _ in range(100)}
+ assert len(ids) == 100
+
+ def test_creates_sessions_dir_if_missing(self, temp_sessions_dir):
+ nested = temp_sessions_dir / "nested" / "sub"
+ import backend.session as mod
+
+ monkeypatch = pytest.MonkeyPatch()
+ monkeypatch.setattr(mod, "SESSIONS_DIR", nested)
+ s = create_session()
+ assert nested.exists()
+ assert (nested / f"{s['session_id']}.json").exists()
+
+
+# ── load_session ────────────────────────────────────────────────
+
+class TestLoadSession:
+ def test_returns_none_for_missing(self, temp_sessions_dir):
+ assert load_session("nonexistent_id") is None
+
+ def test_loads_existing(self, temp_sessions_dir):
+ created = create_session(name="加载测试")
+ loaded = load_session(created["session_id"])
+ assert loaded["session_name"] == "加载测试"
+ assert loaded["session_id"] == created["session_id"]
+
+ def test_load_includes_agent_state(self, temp_sessions_dir):
+ created = create_session(agent_state={"field_count": 5})
+ loaded = load_session(created["session_id"])
+ assert loaded["agent_state"]["field_count"] == 5
+
+
+# ── save_session ────────────────────────────────────────────────
+
+class TestSaveSession:
+ def test_updates_name_and_state(self, temp_sessions_dir):
+ created = create_session(name="原始")
+ save_session(created["session_id"], {"new_key": True}, session_name="更新")
+ loaded = load_session(created["session_id"])
+ assert loaded["session_name"] == "更新"
+ assert loaded["agent_state"]["new_key"] is True
+
+ def test_preserves_created_at(self, temp_sessions_dir):
+ created = create_session()
+ original = created["created_at"]
+ save_session(created["session_id"], {"x": 1})
+ assert load_session(created["session_id"])["created_at"] == original
+
+ def test_updates_updated_at(self, temp_sessions_dir):
+ created = create_session()
+ time.sleep(0.01)
+ save_session(created["session_id"], {"x": 1})
+ loaded = load_session(created["session_id"])
+ assert loaded["updated_at"] != created["updated_at"]
+
+ def test_atomic_write_produces_valid_json(self, temp_sessions_dir):
+ created = create_session()
+ save_session(created["session_id"], {"data": "x" * 1000})
+ fp = temp_sessions_dir / f"{created['session_id']}.json"
+ data = json.loads(fp.read_text("utf-8"))
+ assert data["agent_state"]["data"] == "x" * 1000
+
+ def test_auto_generates_name_when_empty(self, temp_sessions_dir):
+ created = create_session(name="")
+ save_session(created["session_id"], {"x": 1})
+ assert load_session(created["session_id"])["session_name"]
+
+ def test_keeps_existing_name_when_not_provided(self, temp_sessions_dir):
+ created = create_session(name="原名")
+ save_session(created["session_id"], {"x": 1})
+ assert load_session(created["session_id"])["session_name"] == "原名"
+
+ def test_fills_missing_created_at(self, temp_sessions_dir):
+ sid = "test_no_created"
+ fp = temp_sessions_dir / f"{sid}.json"
+ fp.write_text(
+ json.dumps({"session_id": sid, "session_name": "旧数据"}), "utf-8"
+ )
+ save_session(sid, {"x": 1})
+ assert load_session(sid)["created_at"]
+
+
+# ── get_session_state ───────────────────────────────────────────
+
+class TestGetSessionState:
+ def test_none_for_missing(self, temp_sessions_dir):
+ assert get_session_state("missing") is None
+
+ def test_returns_all_keys(self, temp_sessions_dir):
+ created = create_session(name="状态测试")
+ state = get_session_state(created["session_id"])
+ for key in ("session_id", "session_name", "agent_state", "created_at", "updated_at"):
+ assert key in state
+
+
+# ── list_all_sessions ───────────────────────────────────────────
+
+class TestListAllSessions:
+ def test_empty_when_no_sessions(self, temp_sessions_dir):
+ assert list_all_sessions() == []
+
+ def test_lists_all_created(self, temp_sessions_dir):
+ s1 = create_session(name="A")
+ s2 = create_session(name="B")
+ ids = {s["session_id"] for s in list_all_sessions()}
+ assert s1["session_id"] in ids
+ assert s2["session_id"] in ids
+
+ def test_summary_excludes_agent_state(self, temp_sessions_dir):
+ create_session(agent_state={"secret": True})
+ result = list_all_sessions()
+ assert "agent_state" not in result[0]
+
+ def test_sorted_by_mtime_desc(self, temp_sessions_dir):
+ s1 = create_session(name="先")
+ time.sleep(0.02)
+ s2 = create_session(name="后")
+ assert list_all_sessions()[0]["session_id"] == s2["session_id"]
+
+ def test_skips_corrupt_json(self, temp_sessions_dir):
+ (temp_sessions_dir / "bad.json").write_text("{not json}", "utf-8")
+ create_session(name="正常")
+ assert len(list_all_sessions()) == 1
+
+
+# ── delete_session ──────────────────────────────────────────────
+
+class TestDeleteSession:
+ def test_returns_false_for_missing(self, temp_sessions_dir):
+ assert delete_session("ghost_id") is False
+
+ def test_returns_true_and_removes(self, temp_sessions_dir):
+ created = create_session()
+ assert delete_session(created["session_id"]) is True
+ assert load_session(created["session_id"]) is None
+
+ def test_file_is_removed_from_disk(self, temp_sessions_dir):
+ created = create_session()
+ fp = temp_sessions_dir / f"{created['session_id']}.json"
+ assert fp.exists()
+ delete_session(created["session_id"])
+ assert not fp.exists()