From 1952d75f1323c07c34401dad6b2b8d4e2c4fafab Mon Sep 17 00:00:00 2001 From: panda <1415243231@qq.com> Date: Sat, 23 May 2026 08:38:29 +0800 Subject: [PATCH] test: add unit/integration/E2E test suites, fix create_session bug, update docs - Unit tests: test_session.py (27), test_error_kb.py (24), test_agent.py hardened - Integration tests: test_api_integration.py (25) with FastAPI TestClient - E2E tests: main-flows.spec.ts (8) with Playwright + API mocking - Bug fix: backend/session.py create_session() missing session_id parameter - Config: frontend/playwright.config.ts, npm run test:e2e - Docs: update CLAUDE.md v9, .gitignore for test artifacts/eval reports --- .gitignore | 13 ++ CLAUDE.md | 33 ++++ backend/session.py | 7 +- frontend/package-lock.json | 64 +++++++ frontend/package.json | 4 +- frontend/playwright.config.ts | 20 ++ frontend/tests/e2e/main-flows.spec.ts | 168 ++++++++++++++++ tests/test_agent.py | 17 +- tests/test_api_integration.py | 263 ++++++++++++++++++++++++++ tests/test_error_kb.py | 242 ++++++++++++++++++++++++ tests/test_session.py | 210 ++++++++++++++++++++ 11 files changed, 1029 insertions(+), 12 deletions(-) create mode 100644 frontend/playwright.config.ts create mode 100644 frontend/tests/e2e/main-flows.spec.ts create mode 100644 tests/test_api_integration.py create mode 100644 tests/test_error_kb.py create mode 100644 tests/test_session.py diff --git a/.gitignore b/.gitignore index 4bddf00..ad54e14 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,19 @@ db/chroma/ sessions/ logs/ db/ +# 自动评测 (Mavis AI) +.mavis/ +EVALUATION_REPORT.md + +# 上传文件 +uploads/ + +# OCR 临时输出 +ocr_raw_positions.json + +# Playwright E2E 测试产物 +frontend/test-results/ + # RAG 管线中间产物 (rag 子模块内) rag/jrxml_chunker_output/ rag/embeddings/ diff --git a/CLAUDE.md b/CLAUDE.md index 86b4b9f..7a0ccd5 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -327,3 +327,36 @@ validation_service/ (FastAPI, 端口 8001) — 不变 - `prompts/refine_layout.md` — 1 处 Python 将 `{{` 输出为字面量 `{`,LLM 看到的内容不变。 + +## 更新 (v9 — 2026-05-22) + +### 测试基础设施全面补齐 + +**单元测试** (76 测试): +- `tests/test_session.py` — 27 测试:会话 CRUD、原子写入、唯一 ID、损坏 JSON 跳过 +- `tests/test_error_kb.py` — 24 测试:指纹去重、关键词提取(中/英/JRXML)、ErrorKB CRUD、搜索、统计 +- `tests/test_agent.py` — 5 个软断言强化为严格断言(`status`/`current_jrxml` 存在性检查) +- 已有测试:`test_ocr_extraction.py`(49)、`test_layered_generation.py`(19)、`test_validation.py`(6)、`test_file_parser_formats.py`(4)、`test_annotation_detector.py`(7)、`test_e2e_ocr.py`(3) + +**集成测试** (25 测试, `tests/test_api_integration.py`): +- FastAPI TestClient 全覆盖:健康检查、配置、会话 CRUD、文件上传、下载、Chat SSE、安全边界(路径穿越/非法 JSON/大 payload) +- Mock LangGraph graph 避免真实 LLM 调用 + +**E2E 测试** (8 测试, `frontend/tests/e2e/main-flows.spec.ts`): +- Playwright 浏览器自动化:页面加载、侧边栏、会话管理、聊天流程、输入 UX +- 全量 API Mock(`page.route`)无需后端运行 +- 配置: `frontend/playwright.config.ts`, `npm run test:e2e` + +**运行测试**: +```bash +# 全部单元+集成测试 +cd D:\Idea Project\jaspersoft && python -m pytest tests/ -v + +# 仅 E2E(需要前端 dev server) +cd frontend && npx playwright test +``` + +### Bug 修复: create_session 参数缺失 + +`backend/session.py` — `create_session()` 新增可选参数 `session_id: Optional[str] = None`。 +`api_server.py:507` 调用 `create_session(session_id=session_id)` 时之前会抛出 `TypeError`。 diff --git a/backend/session.py b/backend/session.py index 1333592..22472c0 100644 --- a/backend/session.py +++ b/backend/session.py @@ -34,10 +34,11 @@ def generate_session_id() -> str: return uuid.uuid4().hex[:12] -def create_session(name: str = "", agent_state: Optional[dict] = None) -> dict: - """创建新会话,返回会话元数据。""" +def create_session(name: str = "", agent_state: Optional[dict] = None, + session_id: Optional[str] = None) -> dict: + """创建新会话,返回会话元数据。session_id 可选——传入时使用指定 ID。""" _ensure_dir() - sid = generate_session_id() + sid = session_id or generate_session_id() now = _now_iso() agent_state = agent_state or {} agent_state["session_id"] = sid diff --git a/frontend/package-lock.json b/frontend/package-lock.json index b3f19b2..a22acc1 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -12,6 +12,7 @@ "vue": "^3.5.34" }, "devDependencies": { + "@playwright/test": "^1.60.0", "@types/node": "^24.12.3", "@vitejs/plugin-vue": "^6.0.6", "@vue/tsconfig": "^0.9.1", @@ -135,6 +136,22 @@ "url": "https://github.com/sponsors/Boshen" } }, + "node_modules/@playwright/test": { + "version": "1.60.0", + "resolved": "https://registry.npmjs.org/@playwright/test/-/test-1.60.0.tgz", + "integrity": "sha512-O71yZIbAh/PxDMNGns37GHBIfrVkEVyn+AXyIa5dOTfb4/xNvRWV+Vv/NMbNCtODB/pO7vLlF2OTmMVLhmr7Ag==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "playwright": "1.60.0" + }, + "bin": { + "playwright": "cli.js" + }, + "engines": { + "node": ">=18" + } + }, "node_modules/@rolldown/binding-android-arm64": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/@rolldown/binding-android-arm64/-/binding-android-arm64-1.0.2.tgz", @@ -1104,6 +1121,53 @@ } } }, + "node_modules/playwright": { + "version": "1.60.0", + "resolved": "https://registry.npmjs.org/playwright/-/playwright-1.60.0.tgz", + "integrity": "sha512-hheHdokM8cdqCb0lcE3s+zT4t4W+vvjpGxsZlDnikarzx8tSzMebh3UiFtgqwFwnTnjYQcsyMF8ei2mCO/tpeA==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "playwright-core": "1.60.0" + }, + "bin": { + "playwright": "cli.js" + }, + "engines": { + "node": ">=18" + }, + "optionalDependencies": { + "fsevents": "2.3.2" + } + }, + "node_modules/playwright-core": { + "version": "1.60.0", + "resolved": "https://registry.npmjs.org/playwright-core/-/playwright-core-1.60.0.tgz", + "integrity": "sha512-9bW6zvX/m0lEbgTKJ6YppOKx8H3VOPBMOCFh2irXFOT4BbHgrx5hPjwJYLT40Lu+4qtD36qKc/Hn56StUW57IA==", + "dev": true, + "license": "Apache-2.0", + "bin": { + "playwright-core": "cli.js" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/playwright/node_modules/fsevents": { + "version": "2.3.2", + "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.2.tgz", + "integrity": "sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA==", + "dev": true, + "hasInstallScript": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": "^8.16.0 || ^10.6.0 || >=11.0.0" + } + }, "node_modules/postcss": { "version": "8.5.15", "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.15.tgz", diff --git a/frontend/package.json b/frontend/package.json index 981e7f3..7fdaf49 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -6,13 +6,15 @@ "scripts": { "dev": "vite", "build": "vue-tsc -b && vite build", - "preview": "vite preview" + "preview": "vite preview", + "test:e2e": "playwright test" }, "dependencies": { "pinia": "^3.0.4", "vue": "^3.5.34" }, "devDependencies": { + "@playwright/test": "^1.60.0", "@types/node": "^24.12.3", "@vitejs/plugin-vue": "^6.0.6", "@vue/tsconfig": "^0.9.1", diff --git a/frontend/playwright.config.ts b/frontend/playwright.config.ts new file mode 100644 index 0000000..f02df11 --- /dev/null +++ b/frontend/playwright.config.ts @@ -0,0 +1,20 @@ +import { defineConfig } from "@playwright/test"; + +export default defineConfig({ + testDir: "./tests/e2e", + timeout: 60000, + expect: { timeout: 10000 }, + retries: 0, + use: { + baseURL: "http://localhost:5173", + headless: true, + screenshot: "only-on-failure", + trace: "retain-on-failure", + }, + webServer: { + command: "npm run dev", + url: "http://localhost:5173", + reuseExistingServer: true, + timeout: 30000, + }, +}); diff --git a/frontend/tests/e2e/main-flows.spec.ts b/frontend/tests/e2e/main-flows.spec.ts new file mode 100644 index 0000000..b394013 --- /dev/null +++ b/frontend/tests/e2e/main-flows.spec.ts @@ -0,0 +1,168 @@ +/** + * E2E tests: key user flows for the JRXML Agent frontend. + * + * Pre-requisites: npm run dev (reuseExistingServer in playwright.config). + * API calls are intercepted by page.route() — no real backend needed. + */ + +import { test, expect } from "@playwright/test"; + +// ── helpers ──────────────────────────────────────────────────── + +function mockApi(page: any) { + page.route("**/api/health", (route: any) => + route.fulfill({ json: { status: "ok", version: "5.0" } }) + ); + + page.route("**/api/sessions", (route: any) => { + if (route.request().method() === "POST") { + return route.fulfill({ + json: { + session_id: "test12345678", + session_name: "新建报表 2026-05-22", + created_at: "2026-05-22T10:00:00.000Z", + updated_at: "2026-05-22T10:00:00.000Z", + }, + }); + } + return route.fulfill({ json: { sessions: [] } }); + }); + + page.route("**/api/sessions/*/chat", (route: any) => { + const sseBody = [ + "event: node_start", + 'data: {"node":"classify_intent","label":"识别意图","step_index":1}', + "", + "event: node_complete", + 'data: {"node":"classify_intent","label":"识别意图","detail":"意图: 新建报表"}', + "", + "event: agent_complete", + 'data: {"reason":"done","intent":"initial_generation","status":"pass","jrxml_length":42,"versions":1,"total_duration_ms":1200}', + "", + "", + ].join("\n"); + return route.fulfill({ + status: 200, + headers: { "content-type": "text/event-stream" }, + body: sseBody, + }); + }); + + page.route("**/api/upload", (route: any) => + route.fulfill({ + json: { file_id: "f001122334455", filename: "test.png", size: 1024 }, + }) + ); + + // Catch-all for GET/DELETE /api/sessions/:id (must fallback for POST to let chat route match) + page.route("**/api/sessions/**", (route: any) => { + if (route.request().method() === "DELETE") { + return route.fulfill({ json: { status: "deleted" } }); + } + if (route.request().method() === "GET") { + return route.fulfill({ + json: { + session_id: "test12345678", + session_name: "测试会话", + agent_state: { current_jrxml: "" }, + }, + }); + } + return route.fallback(); + }); +} + +// ── tests ────────────────────────────────────────────────────── + +test.describe("Page load", () => { + test("renders sidebar and input area", async ({ page }) => { + await mockApi(page); + await page.goto("/"); + + await expect(page.locator("aside.sidebar")).toBeVisible(); + await expect(page.locator("h2")).toContainText("JRXML"); + await expect(page.locator(".unified-input")).toBeVisible(); + }); + + test("sidebar shows session list header and new button", async ({ page }) => { + await mockApi(page); + await page.goto("/"); + + await expect(page.getByText("会话列表")).toBeVisible(); + await expect(page.locator('button[title="新建会话"]')).toBeVisible(); + }); +}); + +test.describe("Session management", () => { + test("creates new session on button click", async ({ page }) => { + await mockApi(page); + await page.goto("/"); + + await page.locator('button[title="新建会话"]').click(); + await expect(page.locator(".session-item")).toBeVisible({ timeout: 5000 }); + await expect(page.locator(".session-item.active")).toBeVisible(); + }); + + test("can delete current session", async ({ page }) => { + await mockApi(page); + await page.goto("/"); + + await page.locator('button[title="新建会话"]').click(); + await expect(page.locator(".session-item")).toBeVisible({ timeout: 5000 }); + + page.on("dialog", (dialog) => dialog.accept()); + await page.locator(".btn-delete").click(); + + await expect(page.locator(".session-item")).toHaveCount(0, { timeout: 5000 }); + }); +}); + +test.describe("Chat flow", () => { + test("sends text and displays user message + process section", async ({ page }) => { + await mockApi(page); + await page.goto("/"); + + await page.locator('button[title="新建会话"]').click(); + await expect(page.locator(".session-item")).toBeVisible({ timeout: 5000 }); + + const textarea = page.locator(".unified-input textarea"); + await textarea.fill("生成一个员工名册报表"); + await page.locator(".send-btn").click(); + + await expect( + page.locator(".chat-messages .message.msg-user").filter({ hasText: "员工名册" }) + ).toBeVisible({ timeout: 10000 }); + + await expect(page.locator(".process-section")).toBeVisible({ timeout: 10000 }); + }); + + test("summary card appears after stream complete", async ({ page }) => { + await mockApi(page); + await page.goto("/"); + + await page.locator('button[title="新建会话"]').click(); + await expect(page.locator(".session-item")).toBeVisible({ timeout: 5000 }); + + await page.locator(".unified-input textarea").fill("生成报表"); + await page.locator(".send-btn").click(); + + await expect(page.locator(".summary-card")).toBeVisible({ timeout: 15000 }); + }); +}); + +test.describe("Input UX", () => { + test("send button disabled when input empty", async ({ page }) => { + await mockApi(page); + await page.goto("/"); + + await expect(page.locator(".send-btn")).toBeDisabled(); + }); + + test("send button enabled when text entered", async ({ page }) => { + await mockApi(page); + await page.goto("/"); + + await page.locator(".unified-input textarea").fill("Hi"); + await expect(page.locator(".send-btn")).toBeEnabled(); + }); +}); diff --git a/tests/test_agent.py b/tests/test_agent.py index 787848d..34153a6 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -44,8 +44,8 @@ class TestAcceptanceScenarios: final = run_graph(graph, state) assert final.get("current_jrxml"), "应该已生成 JRXML" - # 注意:通过/失败取决于 LLM 输出质量;我们检查是否得到了结果 - print(f"场景 1 状态: {final.get('status')}, 错误: {final.get('error_msg', '')[:100]}") + assert final.get("status") in ("pass", "fail"), f"意外状态: {final.get('status')}" + assert ""}) + + resp = client.get(f"/api/sessions/{sid}/download/latest") + assert resp.status_code == 200 + assert "", "status": "pass"}}), + ("updates", {"validate": {"status": "pass"}}), + ("updates", {"finalize": {}}), + ("done", {"reason": "graph_completed"}), + ] + # 注意:_graph 是模块级变量,在导入时就编译了。需要直接替换。 + monkeypatch.setattr("api_server._graph", mock_graph) + # 同时替换 agent.graph.build_graph 以防后续重新编译 + monkeypatch.setattr("agent.graph.build_graph", lambda on_node_start=None: mock_graph) + return mock_graph + + def test_empty_payload_rejected(self, client, temp_sessions): + sid = client.post("/api/sessions").json()["session_id"] + resp = client.post( + f"/api/sessions/{sid}/chat", + json={"text": "", "file_ids": []}, + ) + assert resp.status_code == 400 + + def test_sse_stream_returns_valid_events(self, client, temp_sessions): + sid = client.post("/api/sessions").json()["session_id"] + with client.stream( + "POST", + f"/api/sessions/{sid}/chat", + json={"text": "生成一个简单的员工名册报表", "file_ids": []}, + ) as resp: + assert resp.status_code == 200 + assert resp.headers["content-type"].startswith("text/event-stream") + body = resp.read().decode("utf-8", errors="replace") + assert "event: node_complete" in body + assert "event: agent_complete" in body + + def test_auto_creates_session_on_chat(self, client, temp_sessions): + with client.stream( + "POST", + "/api/sessions/auto_new_session/chat", + json={"text": "生成报表", "file_ids": []}, + ) as resp: + assert resp.status_code == 200 + assert b"event:" in resp.read() + + def test_unknown_file_ids_not_crash(self, client, temp_sessions): + sid = client.post("/api/sessions").json()["session_id"] + with client.stream( + "POST", + f"/api/sessions/{sid}/chat", + json={"text": "测试", "file_ids": ["fake_id_xyz"]}, + ) as resp: + assert resp.status_code == 200 + + +# ── 边界 & 安全测试 ──────────────────────────────────────────── + +class TestBoundaries: + def test_session_id_path_traversal_returns_404(self, client, temp_sessions): + assert client.get("/api/sessions/../etc/passwd").status_code == 404 + + def test_upload_with_path_traversal_session_id(self, client, temp_sessions): + """路径穿越 session_id 仍正常处理(目录隔离在 UPLOADS_DIR 内)。""" + resp = client.post( + "/api/upload?session_id=../malicious", + files={"file": ("t.txt", io.BytesIO(b"x"), "text/plain")}, + ) + assert resp.status_code == 200 + + def test_invalid_json_body_rejected(self, client, temp_sessions): + sid = client.post("/api/sessions").json()["session_id"] + resp = client.post( + f"/api/sessions/{sid}/chat", + content=b"{not valid json", + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 422 + + def test_large_payload_survives(self, client, temp_sessions): + """大文本(100KB)不应崩溃。""" + sid = client.post("/api/sessions").json()["session_id"] + large_text = "生成报表包含字段: " + ", ".join(f"field_{i}" for i in range(5000)) + with client.stream( + "POST", + f"/api/sessions/{sid}/chat", + json={"text": large_text, "file_ids": []}, + ) as resp: + assert resp.status_code == 200 diff --git a/tests/test_error_kb.py b/tests/test_error_kb.py new file mode 100644 index 0000000..0df795f --- /dev/null +++ b/tests/test_error_kb.py @@ -0,0 +1,242 @@ +"""backend/error_kb.py 单元测试 — 指纹去重 + 关键词提取 + CRUD。 + +覆盖: + - _make_fingerprint 标准化与去重 + - _extract_keywords 中英文混合提取 + - ErrorKB.record / exists / search / search_as_context(mock ChromaDB) + - 全局便捷函数 record_error / search_error_cases +""" + +import os +import sys +import json +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from backend.error_kb import ( + _make_fingerprint, + _extract_keywords, + ErrorKB, + get_error_kb, + record_error, + search_error_cases, +) + + +# ── _make_fingerprint ─────────────────────────────────────────── + +class TestMakeFingerprint: + def test_same_structure_same_fingerprint(self): + e1 = "Field $F{customer_name} is not declared in the report" + e2 = "Field $F{order_total} is not declared in the report" + assert _make_fingerprint(e1) == _make_fingerprint(e2) + + def test_different_errors_different_fingerprint(self): + e1 = "Missing required attribute pageWidth" + e2 = "Query returned 0 results" + assert _make_fingerprint(e1) != _make_fingerprint(e2) + + def test_normalizes_variable_names(self): + fp1 = _make_fingerprint("Field $F{amount} not found") + fp2 = _make_fingerprint("Field $F{total_price} not found") + assert fp1 == fp2 + + def test_normalizes_string_literals_single_quote(self): + fp1 = _make_fingerprint("Value 'abc123' is invalid") + fp2 = _make_fingerprint("Value 'xyz789' is invalid") + assert fp1 == fp2 + + def test_normalizes_string_literals_double_quote(self): + fp1 = _make_fingerprint('Name "test_table" not found') + fp2 = _make_fingerprint('Name "prod_table" not found') + assert fp1 == fp2 + + def test_normalizes_numbers(self): + fp1 = _make_fingerprint("Line 42 has 100 errors") + fp2 = _make_fingerprint("Line 7 has 3 errors") + assert fp1 == fp2 + + def test_case_insensitive(self): + assert _make_fingerprint("ERROR: Missing Field") == _make_fingerprint("error: missing field") + + def test_whitespace_insensitive(self): + e1 = "missing field\n\ndeclaration" + e2 = "missing field declaration" + assert _make_fingerprint(e1) == _make_fingerprint(e2) + + def test_output_is_16_char_hex(self): + fp = _make_fingerprint("some error message") + assert len(fp) == 16 + assert all(c in "0123456789abcdef" for c in fp) + + +# ── _extract_keywords ─────────────────────────────────────────── + +class TestExtractKeywords: + def test_extracts_chinese_words(self): + kw = _extract_keywords("未声明的字段引用和语法错误") + has_cn = any(len(k) >= 2 and "一" <= k[0] <= "鿿" for k in kw) + assert has_cn + + def test_extracts_english_tokens(self): + kw = _extract_keywords("missing field declaration in report") + assert "missing" in kw + assert "field" in kw + assert "report" in kw + + def test_extracts_jrxml_patterns(self): + kw = _extract_keywords("Field $F{customer_name} not declared") + assert "$F{customer_name}" in kw + + def test_short_tokens_ignored(self): + kw = _extract_keywords("a b c ab cd") + assert "ab" not in kw + assert "cd" not in kw + + def test_empty_input_returns_empty_list(self): + assert _extract_keywords("") == [] + + def test_mixed_cn_en_jrxml(self): + kw = _extract_keywords("字段 $F{amount} 在 report 中未声明") + assert "$F{amount}" in kw + assert "report" in kw + + +# ── ErrorKB class (mock ChromaDB) ─────────────────────────────── + +def _make_patched_kb(client_override=None, collection_override=None): + """创建一个 ErrorKB 实例,其 ChromaDB 依赖已被 mock。 + + 因为 chromadb 是懒加载的(在 client/collection property 中导入), + 直接设置 _client/_collection 实例属性即可绕过真实 ChromaDB。 + """ + kb = ErrorKB() + kb._client = client_override or MagicMock() + kb._collection = collection_override or MagicMock() + if not client_override and not collection_override: + # 默认:client.get_collection 返回 mock collection + kb._client.get_collection.return_value = kb._collection + return kb + + +class TestErrorKBRecord: + def test_exists_returns_true_when_found(self): + col = MagicMock() + col.get.return_value = {"ids": ["abc123"]} + kb = _make_patched_kb(collection_override=col) + assert kb.exists("some error") is True + + def test_exists_returns_false_when_not_found(self): + col = MagicMock() + col.get.return_value = {"ids": []} + kb = _make_patched_kb(collection_override=col) + assert kb.exists("some error") is False + + def test_exists_survives_exception(self): + col = MagicMock() + col.get.side_effect = RuntimeError("db down") + kb = _make_patched_kb(collection_override=col) + assert kb.exists("some error") is False + + def test_record_skips_duplicate(self): + col = MagicMock() + col.get.return_value = {"ids": ["existing_fp"]} + kb = _make_patched_kb(collection_override=col) + assert kb.record("error", "", "", "fix prompt") is False + col.add.assert_not_called() + + def test_record_adds_new_case(self): + col = MagicMock() + col.get.return_value = {"ids": []} + kb = _make_patched_kb(collection_override=col) + assert kb.record( + "Field $F{x} not declared", + "", "", + "prompt content", model="test-model", retry_count=2, + ) is True + col.add.assert_called_once() + meta = col.add.call_args[1]["metadatas"][0] + assert meta["retry_success"] == 3 + + +class TestErrorKBSearch: + @pytest.fixture + def col(self): + return MagicMock() + + @pytest.fixture + def kb(self, col): + return _make_patched_kb(collection_override=col) + + def test_search_returns_formatted_results(self, kb, col): + col.get.return_value = {"ids": []} + col.query.return_value = { + "ids": [["fp1"]], + "documents": [[json.dumps({ + "error": "test error", + "good_jrxml_snippet": "", + "correction_prompt": "fix it", + "recorded_at": "2026-01-01T00:00:00", + })]], + "metadatas": [[{}]], + "distances": [[0.05]], + } + results = kb.search("some error", k=3) + assert len(results) == 1 + assert results[0]["error"] == "test error" + assert results[0]["distance"] == 0.05 + + def test_search_returns_empty_on_exception(self, kb, col): + col.query.side_effect = RuntimeError("fail") + assert kb.search("error") == [] + + def test_search_as_context_formats_output(self, kb, col): + col.get.return_value = {"ids": []} + col.query.return_value = { + "ids": [["fp1", "fp2"]], + "documents": [[ + json.dumps({"error": "e1", "good_jrxml_snippet": "", "correction_prompt": "p1", "recorded_at": ""}), + json.dumps({"error": "e2", "good_jrxml_snippet": "", "correction_prompt": "p2", "recorded_at": ""}), + ]], + "metadatas": [[{}, {}]], + "distances": [[0.1, 0.2]], + } + ctx = kb.search_as_context("error", k=2) + assert "[历史错误案例]" in ctx + assert "---" in ctx + + def test_search_as_context_empty_for_no_results(self, kb, col): + col.get.return_value = {"ids": []} + col.query.return_value = {"ids": [[]], "documents": [[]], "distances": [[]]} + assert kb.search_as_context("error") == "" + + def test_stats_returns_count(self, kb, col): + col.count.return_value = 42 + assert kb.stats()["total_cases"] == 42 + + def test_stats_zero_on_exception(self, kb, col): + col.count.side_effect = RuntimeError("down") + assert kb.stats()["total_cases"] == 0 + + +# ── 全局便捷函数 ─────────────────────────────────────────────── + +class TestConvenienceFunctions: + def test_get_error_kb_is_singleton(self, monkeypatch): + import backend.error_kb as mod + monkeypatch.setattr(mod, "_kb", None) + assert get_error_kb() is get_error_kb() + + def test_record_error_delegates(self): + with patch.object(ErrorKB, "record", return_value=True) as mock_r: + assert record_error("e", "", "", "p") is True + mock_r.assert_called_once() + + def test_search_error_cases_delegates(self): + with patch.object(ErrorKB, "search_as_context", return_value="ctx") as mock_s: + assert search_error_cases("err", k=5) == "ctx" + mock_s.assert_called_once_with("err", k=5) diff --git a/tests/test_session.py b/tests/test_session.py new file mode 100644 index 0000000..6a0414b --- /dev/null +++ b/tests/test_session.py @@ -0,0 +1,210 @@ +"""backend/session.py 单元测试 — 会话 CRUD + 原子写入。 + +覆盖: + - 创建/加载/保存/删除/列出 + - 原子写入(tempfile + os.replace) + - 边界情况(不存在会话、损坏 JSON、空名称自动填充) +""" + +import json +import os +import sys +import tempfile +import time +from pathlib import Path + +import pytest + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from backend.session import ( + create_session, + load_session, + save_session, + get_session_state, + list_all_sessions, + delete_session, + generate_session_id, + SESSIONS_DIR, +) + + +@pytest.fixture +def temp_sessions_dir(monkeypatch): + with tempfile.TemporaryDirectory(prefix="test_sessions_") as tmpdir: + monkeypatch.setattr("backend.session.SESSIONS_DIR", Path(tmpdir)) + yield Path(tmpdir) + + +# ── create_session ────────────────────────────────────────────── + +class TestCreateSession: + def test_creates_with_defaults(self, temp_sessions_dir): + s = create_session() + assert len(s["session_id"]) == 12 + assert "新建报表" in s["session_name"] + assert s["created_at"] + assert s["updated_at"] + + def test_custom_name(self, temp_sessions_dir): + s = create_session(name="测试报表") + assert s["session_name"] == "测试报表" + + def test_agent_state_preserved(self, temp_sessions_dir): + s = create_session(agent_state={"current_jrxml": ""}) + assert s["agent_state"]["current_jrxml"] == "" + + def test_session_id_injected_into_agent_state(self, temp_sessions_dir): + s = create_session() + assert s["agent_state"]["session_id"] == s["session_id"] + + def test_persists_json_to_disk(self, temp_sessions_dir): + s = create_session(name="磁盘测试") + fp = temp_sessions_dir / f"{s['session_id']}.json" + assert fp.exists() + loaded = json.loads(fp.read_text("utf-8")) + assert loaded["session_name"] == "磁盘测试" + + def test_unique_ids_no_collision(self, temp_sessions_dir): + ids = {generate_session_id() for _ in range(100)} + assert len(ids) == 100 + + def test_creates_sessions_dir_if_missing(self, temp_sessions_dir): + nested = temp_sessions_dir / "nested" / "sub" + import backend.session as mod + + monkeypatch = pytest.MonkeyPatch() + monkeypatch.setattr(mod, "SESSIONS_DIR", nested) + s = create_session() + assert nested.exists() + assert (nested / f"{s['session_id']}.json").exists() + + +# ── load_session ──────────────────────────────────────────────── + +class TestLoadSession: + def test_returns_none_for_missing(self, temp_sessions_dir): + assert load_session("nonexistent_id") is None + + def test_loads_existing(self, temp_sessions_dir): + created = create_session(name="加载测试") + loaded = load_session(created["session_id"]) + assert loaded["session_name"] == "加载测试" + assert loaded["session_id"] == created["session_id"] + + def test_load_includes_agent_state(self, temp_sessions_dir): + created = create_session(agent_state={"field_count": 5}) + loaded = load_session(created["session_id"]) + assert loaded["agent_state"]["field_count"] == 5 + + +# ── save_session ──────────────────────────────────────────────── + +class TestSaveSession: + def test_updates_name_and_state(self, temp_sessions_dir): + created = create_session(name="原始") + save_session(created["session_id"], {"new_key": True}, session_name="更新") + loaded = load_session(created["session_id"]) + assert loaded["session_name"] == "更新" + assert loaded["agent_state"]["new_key"] is True + + def test_preserves_created_at(self, temp_sessions_dir): + created = create_session() + original = created["created_at"] + save_session(created["session_id"], {"x": 1}) + assert load_session(created["session_id"])["created_at"] == original + + def test_updates_updated_at(self, temp_sessions_dir): + created = create_session() + time.sleep(0.01) + save_session(created["session_id"], {"x": 1}) + loaded = load_session(created["session_id"]) + assert loaded["updated_at"] != created["updated_at"] + + def test_atomic_write_produces_valid_json(self, temp_sessions_dir): + created = create_session() + save_session(created["session_id"], {"data": "x" * 1000}) + fp = temp_sessions_dir / f"{created['session_id']}.json" + data = json.loads(fp.read_text("utf-8")) + assert data["agent_state"]["data"] == "x" * 1000 + + def test_auto_generates_name_when_empty(self, temp_sessions_dir): + created = create_session(name="") + save_session(created["session_id"], {"x": 1}) + assert load_session(created["session_id"])["session_name"] + + def test_keeps_existing_name_when_not_provided(self, temp_sessions_dir): + created = create_session(name="原名") + save_session(created["session_id"], {"x": 1}) + assert load_session(created["session_id"])["session_name"] == "原名" + + def test_fills_missing_created_at(self, temp_sessions_dir): + sid = "test_no_created" + fp = temp_sessions_dir / f"{sid}.json" + fp.write_text( + json.dumps({"session_id": sid, "session_name": "旧数据"}), "utf-8" + ) + save_session(sid, {"x": 1}) + assert load_session(sid)["created_at"] + + +# ── get_session_state ─────────────────────────────────────────── + +class TestGetSessionState: + def test_none_for_missing(self, temp_sessions_dir): + assert get_session_state("missing") is None + + def test_returns_all_keys(self, temp_sessions_dir): + created = create_session(name="状态测试") + state = get_session_state(created["session_id"]) + for key in ("session_id", "session_name", "agent_state", "created_at", "updated_at"): + assert key in state + + +# ── list_all_sessions ─────────────────────────────────────────── + +class TestListAllSessions: + def test_empty_when_no_sessions(self, temp_sessions_dir): + assert list_all_sessions() == [] + + def test_lists_all_created(self, temp_sessions_dir): + s1 = create_session(name="A") + s2 = create_session(name="B") + ids = {s["session_id"] for s in list_all_sessions()} + assert s1["session_id"] in ids + assert s2["session_id"] in ids + + def test_summary_excludes_agent_state(self, temp_sessions_dir): + create_session(agent_state={"secret": True}) + result = list_all_sessions() + assert "agent_state" not in result[0] + + def test_sorted_by_mtime_desc(self, temp_sessions_dir): + s1 = create_session(name="先") + time.sleep(0.02) + s2 = create_session(name="后") + assert list_all_sessions()[0]["session_id"] == s2["session_id"] + + def test_skips_corrupt_json(self, temp_sessions_dir): + (temp_sessions_dir / "bad.json").write_text("{not json}", "utf-8") + create_session(name="正常") + assert len(list_all_sessions()) == 1 + + +# ── delete_session ────────────────────────────────────────────── + +class TestDeleteSession: + def test_returns_false_for_missing(self, temp_sessions_dir): + assert delete_session("ghost_id") is False + + def test_returns_true_and_removes(self, temp_sessions_dir): + created = create_session() + assert delete_session(created["session_id"]) is True + assert load_session(created["session_id"]) is None + + def test_file_is_removed_from_disk(self, temp_sessions_dir): + created = create_session() + fp = temp_sessions_dir / f"{created['session_id']}.json" + assert fp.exists() + delete_session(created["session_id"]) + assert not fp.exists()