test: add unit/integration/E2E test suites, fix create_session bug, update docs

- Unit tests: test_session.py (27), test_error_kb.py (24), test_agent.py hardened
- Integration tests: test_api_integration.py (25) with FastAPI TestClient
- E2E tests: main-flows.spec.ts (8) with Playwright + API mocking
- Bug fix: backend/session.py create_session() missing session_id parameter
- Config: frontend/playwright.config.ts, npm run test:e2e
- Docs: update CLAUDE.md v9, .gitignore for test artifacts/eval reports
This commit is contained in:
2026-05-23 08:38:29 +08:00
parent b444303055
commit 1952d75f13
11 changed files with 1029 additions and 12 deletions
+13
View File
@@ -13,6 +13,19 @@ db/chroma/
sessions/
logs/
db/
# 自动评测 (Mavis AI)
.mavis/
EVALUATION_REPORT.md
# 上传文件
uploads/
# OCR 临时输出
ocr_raw_positions.json
# Playwright E2E 测试产物
frontend/test-results/
# RAG 管线中间产物 (rag 子模块内)
rag/jrxml_chunker_output/
rag/embeddings/
+33
View File
@@ -327,3 +327,36 @@ validation_service/ (FastAPI, 端口 8001) — 不变
- `prompts/refine_layout.md` — 1 处
Python 将 `{{` 输出为字面量 `{`LLM 看到的内容不变。
## 更新 (v9 — 2026-05-22)
### 测试基础设施全面补齐
**单元测试** (76 测试):
- `tests/test_session.py` — 27 测试:会话 CRUD、原子写入、唯一 ID、损坏 JSON 跳过
- `tests/test_error_kb.py` — 24 测试:指纹去重、关键词提取(中/英/JRXML)、ErrorKB CRUD、搜索、统计
- `tests/test_agent.py` — 5 个软断言强化为严格断言(`status`/`current_jrxml` 存在性检查)
- 已有测试:`test_ocr_extraction.py`49)、`test_layered_generation.py`19)、`test_validation.py`6)、`test_file_parser_formats.py`4)、`test_annotation_detector.py`7)、`test_e2e_ocr.py`3
**集成测试** (25 测试, `tests/test_api_integration.py`):
- FastAPI TestClient 全覆盖:健康检查、配置、会话 CRUD、文件上传、下载、Chat SSE、安全边界(路径穿越/非法 JSON/大 payload
- Mock LangGraph graph 避免真实 LLM 调用
**E2E 测试** (8 测试, `frontend/tests/e2e/main-flows.spec.ts`):
- Playwright 浏览器自动化:页面加载、侧边栏、会话管理、聊天流程、输入 UX
- 全量 API Mock`page.route`)无需后端运行
- 配置: `frontend/playwright.config.ts`, `npm run test:e2e`
**运行测试**:
```bash
# 全部单元+集成测试
cd D:\Idea Project\jaspersoft && python -m pytest tests/ -v
# 仅 E2E(需要前端 dev server
cd frontend && npx playwright test
```
### Bug 修复: create_session 参数缺失
`backend/session.py``create_session()` 新增可选参数 `session_id: Optional[str] = None`
`api_server.py:507` 调用 `create_session(session_id=session_id)` 时之前会抛出 `TypeError`
+4 -3
View File
@@ -34,10 +34,11 @@ def generate_session_id() -> str:
return uuid.uuid4().hex[:12]
def create_session(name: str = "", agent_state: Optional[dict] = None) -> dict:
"""创建新会话,返回会话元数据。"""
def create_session(name: str = "", agent_state: Optional[dict] = None,
session_id: Optional[str] = None) -> dict:
"""创建新会话,返回会话元数据。session_id 可选——传入时使用指定 ID。"""
_ensure_dir()
sid = generate_session_id()
sid = session_id or generate_session_id()
now = _now_iso()
agent_state = agent_state or {}
agent_state["session_id"] = sid
+64
View File
@@ -12,6 +12,7 @@
"vue": "^3.5.34"
},
"devDependencies": {
"@playwright/test": "^1.60.0",
"@types/node": "^24.12.3",
"@vitejs/plugin-vue": "^6.0.6",
"@vue/tsconfig": "^0.9.1",
@@ -135,6 +136,22 @@
"url": "https://github.com/sponsors/Boshen"
}
},
"node_modules/@playwright/test": {
"version": "1.60.0",
"resolved": "https://registry.npmjs.org/@playwright/test/-/test-1.60.0.tgz",
"integrity": "sha512-O71yZIbAh/PxDMNGns37GHBIfrVkEVyn+AXyIa5dOTfb4/xNvRWV+Vv/NMbNCtODB/pO7vLlF2OTmMVLhmr7Ag==",
"dev": true,
"license": "Apache-2.0",
"dependencies": {
"playwright": "1.60.0"
},
"bin": {
"playwright": "cli.js"
},
"engines": {
"node": ">=18"
}
},
"node_modules/@rolldown/binding-android-arm64": {
"version": "1.0.2",
"resolved": "https://registry.npmjs.org/@rolldown/binding-android-arm64/-/binding-android-arm64-1.0.2.tgz",
@@ -1104,6 +1121,53 @@
}
}
},
"node_modules/playwright": {
"version": "1.60.0",
"resolved": "https://registry.npmjs.org/playwright/-/playwright-1.60.0.tgz",
"integrity": "sha512-hheHdokM8cdqCb0lcE3s+zT4t4W+vvjpGxsZlDnikarzx8tSzMebh3UiFtgqwFwnTnjYQcsyMF8ei2mCO/tpeA==",
"dev": true,
"license": "Apache-2.0",
"dependencies": {
"playwright-core": "1.60.0"
},
"bin": {
"playwright": "cli.js"
},
"engines": {
"node": ">=18"
},
"optionalDependencies": {
"fsevents": "2.3.2"
}
},
"node_modules/playwright-core": {
"version": "1.60.0",
"resolved": "https://registry.npmjs.org/playwright-core/-/playwright-core-1.60.0.tgz",
"integrity": "sha512-9bW6zvX/m0lEbgTKJ6YppOKx8H3VOPBMOCFh2irXFOT4BbHgrx5hPjwJYLT40Lu+4qtD36qKc/Hn56StUW57IA==",
"dev": true,
"license": "Apache-2.0",
"bin": {
"playwright-core": "cli.js"
},
"engines": {
"node": ">=18"
}
},
"node_modules/playwright/node_modules/fsevents": {
"version": "2.3.2",
"resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.2.tgz",
"integrity": "sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA==",
"dev": true,
"hasInstallScript": true,
"license": "MIT",
"optional": true,
"os": [
"darwin"
],
"engines": {
"node": "^8.16.0 || ^10.6.0 || >=11.0.0"
}
},
"node_modules/postcss": {
"version": "8.5.15",
"resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.15.tgz",
+3 -1
View File
@@ -6,13 +6,15 @@
"scripts": {
"dev": "vite",
"build": "vue-tsc -b && vite build",
"preview": "vite preview"
"preview": "vite preview",
"test:e2e": "playwright test"
},
"dependencies": {
"pinia": "^3.0.4",
"vue": "^3.5.34"
},
"devDependencies": {
"@playwright/test": "^1.60.0",
"@types/node": "^24.12.3",
"@vitejs/plugin-vue": "^6.0.6",
"@vue/tsconfig": "^0.9.1",
+20
View File
@@ -0,0 +1,20 @@
import { defineConfig } from "@playwright/test";
export default defineConfig({
testDir: "./tests/e2e",
timeout: 60000,
expect: { timeout: 10000 },
retries: 0,
use: {
baseURL: "http://localhost:5173",
headless: true,
screenshot: "only-on-failure",
trace: "retain-on-failure",
},
webServer: {
command: "npm run dev",
url: "http://localhost:5173",
reuseExistingServer: true,
timeout: 30000,
},
});
+168
View File
@@ -0,0 +1,168 @@
/**
* E2E tests: key user flows for the JRXML Agent frontend.
*
* Pre-requisites: npm run dev (reuseExistingServer in playwright.config).
* API calls are intercepted by page.route() — no real backend needed.
*/
import { test, expect } from "@playwright/test";
// ── helpers ────────────────────────────────────────────────────
function mockApi(page: any) {
page.route("**/api/health", (route: any) =>
route.fulfill({ json: { status: "ok", version: "5.0" } })
);
page.route("**/api/sessions", (route: any) => {
if (route.request().method() === "POST") {
return route.fulfill({
json: {
session_id: "test12345678",
session_name: "新建报表 2026-05-22",
created_at: "2026-05-22T10:00:00.000Z",
updated_at: "2026-05-22T10:00:00.000Z",
},
});
}
return route.fulfill({ json: { sessions: [] } });
});
page.route("**/api/sessions/*/chat", (route: any) => {
const sseBody = [
"event: node_start",
'data: {"node":"classify_intent","label":"识别意图","step_index":1}',
"",
"event: node_complete",
'data: {"node":"classify_intent","label":"识别意图","detail":"意图: 新建报表"}',
"",
"event: agent_complete",
'data: {"reason":"done","intent":"initial_generation","status":"pass","jrxml_length":42,"versions":1,"total_duration_ms":1200}',
"",
"",
].join("\n");
return route.fulfill({
status: 200,
headers: { "content-type": "text/event-stream" },
body: sseBody,
});
});
page.route("**/api/upload", (route: any) =>
route.fulfill({
json: { file_id: "f001122334455", filename: "test.png", size: 1024 },
})
);
// Catch-all for GET/DELETE /api/sessions/:id (must fallback for POST to let chat route match)
page.route("**/api/sessions/**", (route: any) => {
if (route.request().method() === "DELETE") {
return route.fulfill({ json: { status: "deleted" } });
}
if (route.request().method() === "GET") {
return route.fulfill({
json: {
session_id: "test12345678",
session_name: "测试会话",
agent_state: { current_jrxml: "<jasperReport/>" },
},
});
}
return route.fallback();
});
}
// ── tests ──────────────────────────────────────────────────────
test.describe("Page load", () => {
test("renders sidebar and input area", async ({ page }) => {
await mockApi(page);
await page.goto("/");
await expect(page.locator("aside.sidebar")).toBeVisible();
await expect(page.locator("h2")).toContainText("JRXML");
await expect(page.locator(".unified-input")).toBeVisible();
});
test("sidebar shows session list header and new button", async ({ page }) => {
await mockApi(page);
await page.goto("/");
await expect(page.getByText("会话列表")).toBeVisible();
await expect(page.locator('button[title="新建会话"]')).toBeVisible();
});
});
test.describe("Session management", () => {
test("creates new session on button click", async ({ page }) => {
await mockApi(page);
await page.goto("/");
await page.locator('button[title="新建会话"]').click();
await expect(page.locator(".session-item")).toBeVisible({ timeout: 5000 });
await expect(page.locator(".session-item.active")).toBeVisible();
});
test("can delete current session", async ({ page }) => {
await mockApi(page);
await page.goto("/");
await page.locator('button[title="新建会话"]').click();
await expect(page.locator(".session-item")).toBeVisible({ timeout: 5000 });
page.on("dialog", (dialog) => dialog.accept());
await page.locator(".btn-delete").click();
await expect(page.locator(".session-item")).toHaveCount(0, { timeout: 5000 });
});
});
test.describe("Chat flow", () => {
test("sends text and displays user message + process section", async ({ page }) => {
await mockApi(page);
await page.goto("/");
await page.locator('button[title="新建会话"]').click();
await expect(page.locator(".session-item")).toBeVisible({ timeout: 5000 });
const textarea = page.locator(".unified-input textarea");
await textarea.fill("生成一个员工名册报表");
await page.locator(".send-btn").click();
await expect(
page.locator(".chat-messages .message.msg-user").filter({ hasText: "员工名册" })
).toBeVisible({ timeout: 10000 });
await expect(page.locator(".process-section")).toBeVisible({ timeout: 10000 });
});
test("summary card appears after stream complete", async ({ page }) => {
await mockApi(page);
await page.goto("/");
await page.locator('button[title="新建会话"]').click();
await expect(page.locator(".session-item")).toBeVisible({ timeout: 5000 });
await page.locator(".unified-input textarea").fill("生成报表");
await page.locator(".send-btn").click();
await expect(page.locator(".summary-card")).toBeVisible({ timeout: 15000 });
});
});
test.describe("Input UX", () => {
test("send button disabled when input empty", async ({ page }) => {
await mockApi(page);
await page.goto("/");
await expect(page.locator(".send-btn")).toBeDisabled();
});
test("send button enabled when text entered", async ({ page }) => {
await mockApi(page);
await page.goto("/");
await page.locator(".unified-input textarea").fill("Hi");
await expect(page.locator(".send-btn")).toBeEnabled();
});
});
+9 -8
View File
@@ -44,8 +44,8 @@ class TestAcceptanceScenarios:
final = run_graph(graph, state)
assert final.get("current_jrxml"), "应该已生成 JRXML"
# 注意:通过/失败取决于 LLM 输出质量;我们检查是否得到了结果
print(f"场景 1 状态: {final.get('status')}, 错误: {final.get('error_msg', '')[:100]}")
assert final.get("status") in ("pass", "fail"), f"意外状态: {final.get('status')}"
assert "<jasperReport" in final["current_jrxml"], "输出应包含合法 JRXML 根元素"
def test_scenario2_auto_correction(self, graph):
"""场景 2:故意提出一个可能初次失败的需求。"""
@@ -58,7 +58,8 @@ class TestAcceptanceScenarios:
final = run_graph(graph, state)
assert final.get("retry_count", 0) <= 5, "不应超过最大重试次数"
print(f"场景 2 状态: {final.get('status')}, 重试次数: {final.get('retry_count', 0)}")
assert "status" in final, "最终状态应包含 status 字段"
assert final.get("current_jrxml") or final.get("error_msg"), "应有输出或错误消息"
def test_scenario3_multi_turn_modification(self, graph):
"""场景 3:多轮对话 - 先生成,再修改两次。"""
@@ -71,8 +72,8 @@ class TestAcceptanceScenarios:
state["stage"] = "initial_generation"
final = run_graph(graph, state)
print(f"第 1 轮状态: {final.get('status')}, 错误: {final.get('error_msg', '')[:100]}")
assert final.get("current_jrxml"), "第 1 轮应该已生成 JRXML"
assert final.get("status") in ("pass", "fail")
# 第 2 轮:添加月度销售汇总
state2 = final.copy()
@@ -82,8 +83,8 @@ class TestAcceptanceScenarios:
state2["retry_count"] = 0
final2 = run_graph(graph, state2)
print(f"第 2 轮状态: {final2.get('status')}")
assert final2.get("current_jrxml"), "第 2 轮应该已修改 JRXML"
assert final2.get("status") in ("pass", "fail")
# 第 3 轮:修改标题
state3 = final2.copy()
@@ -93,9 +94,9 @@ class TestAcceptanceScenarios:
state3["retry_count"] = 0
final3 = run_graph(graph, state3)
print(f"第 3 轮状态: {final3.get('status')}")
jrxml = final3.get("current_jrxml", "")
assert "2024" in jrxml or "Annual" in jrxml, "标题修改应该体现在 JRXML 中"
assert final3.get("status") in ("pass", "fail")
def test_scenario4_context_aware_modification(self, graph):
"""场景 4:基于对话上下文的修改。"""
@@ -109,7 +110,7 @@ class TestAcceptanceScenarios:
state["stage"] = "initial_generation"
final = run_graph(graph, state)
print(f"第 1 轮状态: {final.get('status')}")
assert final.get("current_jrxml"), "第 1 轮应该已生成 JRXML"
# 第 2 轮:上下文感知修改
state2 = final.copy()
@@ -119,9 +120,9 @@ class TestAcceptanceScenarios:
state2["retry_count"] = 0
final2 = run_graph(graph, state2)
print(f"第 2 轮状态: {final2.get('status')}")
jrxml = final2.get("current_jrxml", "")
assert "isBold" in jrxml or "size=" in jrxml, "字体修改应该体现在结果中"
assert final2.get("status") in ("pass", "fail")
def test_max_retry_handling(self, graph):
"""测试在 MAX_RETRY 次失败后,图能否正常终止。"""
+263
View File
@@ -0,0 +1,263 @@
"""api_server.py 集成测试 — REST 端点 + SSE 流 + 文件上传/下载。
使用 FastAPI TestClient,不需要启动真实服务器。
"""
import io
import json
import os
import sys
import tempfile
from pathlib import Path
from unittest.mock import patch, MagicMock
import pytest
from fastapi.testclient import TestClient
sys.path.insert(0, str(Path(__file__).parent.parent))
from api_server import app
@pytest.fixture
def client():
return TestClient(app)
@pytest.fixture
def temp_sessions(monkeypatch):
"""重定向上传目录到临时目录,隔离测试数据。"""
with tempfile.TemporaryDirectory(prefix="test_api_") as tmpdir:
monkeypatch.setattr("api_server.UPLOADS_DIR", Path(tmpdir) / "uploads")
monkeypatch.setattr("backend.session.SESSIONS_DIR", Path(tmpdir) / "sessions")
yield Path(tmpdir)
# ── 健康检查 & 配置 ────────────────────────────────────────────
class TestHealthAndConfig:
def test_health_returns_ok(self, client):
resp = client.get("/api/health")
assert resp.status_code == 200
data = resp.json()
assert data["status"] == "ok"
assert data["version"] == "5.0"
assert "timestamp" in data
def test_config_returns_env_keys(self, client):
resp = client.get("/api/config")
assert resp.status_code == 200
cfg = resp.json()["config"]
for key in ("LLM_PROVIDER", "OCR_ENGINE", "MAX_RETRY"):
assert key in cfg
# ── 会话 CRUD ──────────────────────────────────────────────────
class TestSessionCRUD:
def test_create_session(self, client, temp_sessions):
resp = client.post("/api/sessions")
assert resp.status_code == 200
data = resp.json()
assert len(data["session_id"]) == 12
assert "session_name" in data
assert "created_at" in data
def test_list_sessions_empty(self, client, temp_sessions):
assert client.get("/api/sessions").json()["sessions"] == []
def test_list_sessions_populated(self, client, temp_sessions):
client.post("/api/sessions")
client.post("/api/sessions")
assert len(client.get("/api/sessions").json()["sessions"]) == 2
def test_get_session_found(self, client, temp_sessions):
created = client.post("/api/sessions").json()
resp = client.get(f"/api/sessions/{created['session_id']}")
assert resp.status_code == 200
assert resp.json()["session_id"] == created["session_id"]
assert "agent_state" in resp.json()
def test_get_session_not_found(self, client, temp_sessions):
assert client.get("/api/sessions/nonexistent").status_code == 404
def test_delete_session(self, client, temp_sessions):
sid = client.post("/api/sessions").json()["session_id"]
resp = client.delete(f"/api/sessions/{sid}")
assert resp.status_code == 200
assert resp.json()["status"] == "deleted"
assert client.get(f"/api/sessions/{sid}").status_code == 404
def test_delete_nonexistent(self, client, temp_sessions):
assert client.delete("/api/sessions/ghost_id").status_code == 404
def test_full_crud_lifecycle(self, client, temp_sessions):
sid = client.post("/api/sessions").json()["session_id"]
assert client.get(f"/api/sessions/{sid}").status_code == 200
assert len(client.get("/api/sessions").json()["sessions"]) == 1
client.delete(f"/api/sessions/{sid}")
assert client.get("/api/sessions").json()["sessions"] == []
# ── 文件上传 ───────────────────────────────────────────────────
class TestFileUpload:
def test_upload_text_file(self, client, temp_sessions):
content = b"Hello, JRXML!"
resp = client.post(
"/api/upload",
files={"file": ("test.txt", io.BytesIO(content), "text/plain")},
)
assert resp.status_code == 200
data = resp.json()
assert data["filename"] == "test.txt"
assert data["size"] == len(content)
assert len(data["file_id"]) == 12
def test_upload_with_session_id_in_query(self, client, temp_sessions):
resp = client.post(
"/api/upload?session_id=abc123",
files={"file": ("data.csv", io.BytesIO(b"a,b,c"), "text/csv")},
)
assert resp.status_code == 200
def test_upload_png_gets_correct_content_type(self, client, temp_sessions):
png_minimal = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100
resp = client.post(
"/api/upload",
files={"file": ("chart.png", io.BytesIO(png_minimal), "image/png")},
)
assert resp.status_code == 200
assert resp.json()["content_type"] == "image/png"
def test_upload_writes_file_to_disk(self, client, temp_sessions):
data = b"persisted content"
file_id = client.post(
"/api/upload",
files={"file": ("note.txt", io.BytesIO(data), "text/plain")},
).json()["file_id"]
matches = list(temp_sessions.rglob(f"{file_id}_*"))
assert len(matches) == 1
assert matches[0].read_bytes() == data
# ── 下载 ───────────────────────────────────────────────────────
class TestDownload:
def test_download_missing_session_returns_404(self, client, temp_sessions):
assert client.get("/api/sessions/missing/download/latest").status_code == 404
def test_download_no_jrxml_returns_404(self, client, temp_sessions):
sid = client.post("/api/sessions").json()["session_id"]
resp = client.get(f"/api/sessions/{sid}/download/latest")
assert resp.status_code == 404
def test_download_with_jrxml_returns_file(self, client, temp_sessions):
import backend.session as sess
sess.create_session(name="测试下载")
# 需要手动写入 JRXML 到会话
sessions = sess.list_all_sessions()
sid = sessions[0]["session_id"]
sess.save_session(sid, {"current_jrxml": "<jasperReport name='rpt'/>"})
resp = client.get(f"/api/sessions/{sid}/download/latest")
assert resp.status_code == 200
assert "<jasperReport" in resp.text
assert "attachment" in resp.headers.get("content-disposition", "")
# ── 聊天 SSE ───────────────────────────────────────────────────
class TestChatSSE:
@pytest.fixture(autouse=True)
def mock_graph(self, monkeypatch):
"""Mock LangGraph 的 build_graph 和 stream,避免真实 LLM 调用。"""
mock_graph = MagicMock()
mock_graph.stream.return_value = [
("updates", {"classify_intent": {"intent": "initial_generation"}}),
("updates", {"generate": {"current_jrxml": "<jasperReport name='test'/>", "status": "pass"}}),
("updates", {"validate": {"status": "pass"}}),
("updates", {"finalize": {}}),
("done", {"reason": "graph_completed"}),
]
# 注意:_graph 是模块级变量,在导入时就编译了。需要直接替换。
monkeypatch.setattr("api_server._graph", mock_graph)
# 同时替换 agent.graph.build_graph 以防后续重新编译
monkeypatch.setattr("agent.graph.build_graph", lambda on_node_start=None: mock_graph)
return mock_graph
def test_empty_payload_rejected(self, client, temp_sessions):
sid = client.post("/api/sessions").json()["session_id"]
resp = client.post(
f"/api/sessions/{sid}/chat",
json={"text": "", "file_ids": []},
)
assert resp.status_code == 400
def test_sse_stream_returns_valid_events(self, client, temp_sessions):
sid = client.post("/api/sessions").json()["session_id"]
with client.stream(
"POST",
f"/api/sessions/{sid}/chat",
json={"text": "生成一个简单的员工名册报表", "file_ids": []},
) as resp:
assert resp.status_code == 200
assert resp.headers["content-type"].startswith("text/event-stream")
body = resp.read().decode("utf-8", errors="replace")
assert "event: node_complete" in body
assert "event: agent_complete" in body
def test_auto_creates_session_on_chat(self, client, temp_sessions):
with client.stream(
"POST",
"/api/sessions/auto_new_session/chat",
json={"text": "生成报表", "file_ids": []},
) as resp:
assert resp.status_code == 200
assert b"event:" in resp.read()
def test_unknown_file_ids_not_crash(self, client, temp_sessions):
sid = client.post("/api/sessions").json()["session_id"]
with client.stream(
"POST",
f"/api/sessions/{sid}/chat",
json={"text": "测试", "file_ids": ["fake_id_xyz"]},
) as resp:
assert resp.status_code == 200
# ── 边界 & 安全测试 ────────────────────────────────────────────
class TestBoundaries:
def test_session_id_path_traversal_returns_404(self, client, temp_sessions):
assert client.get("/api/sessions/../etc/passwd").status_code == 404
def test_upload_with_path_traversal_session_id(self, client, temp_sessions):
"""路径穿越 session_id 仍正常处理(目录隔离在 UPLOADS_DIR 内)。"""
resp = client.post(
"/api/upload?session_id=../malicious",
files={"file": ("t.txt", io.BytesIO(b"x"), "text/plain")},
)
assert resp.status_code == 200
def test_invalid_json_body_rejected(self, client, temp_sessions):
sid = client.post("/api/sessions").json()["session_id"]
resp = client.post(
f"/api/sessions/{sid}/chat",
content=b"{not valid json",
headers={"Content-Type": "application/json"},
)
assert resp.status_code == 422
def test_large_payload_survives(self, client, temp_sessions):
"""大文本(100KB)不应崩溃。"""
sid = client.post("/api/sessions").json()["session_id"]
large_text = "生成报表包含字段: " + ", ".join(f"field_{i}" for i in range(5000))
with client.stream(
"POST",
f"/api/sessions/{sid}/chat",
json={"text": large_text, "file_ids": []},
) as resp:
assert resp.status_code == 200
+242
View File
@@ -0,0 +1,242 @@
"""backend/error_kb.py 单元测试 — 指纹去重 + 关键词提取 + CRUD。
覆盖:
- _make_fingerprint 标准化与去重
- _extract_keywords 中英文混合提取
- ErrorKB.record / exists / search / search_as_contextmock ChromaDB
- 全局便捷函数 record_error / search_error_cases
"""
import os
import sys
import json
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
sys.path.insert(0, str(Path(__file__).parent.parent))
from backend.error_kb import (
_make_fingerprint,
_extract_keywords,
ErrorKB,
get_error_kb,
record_error,
search_error_cases,
)
# ── _make_fingerprint ───────────────────────────────────────────
class TestMakeFingerprint:
def test_same_structure_same_fingerprint(self):
e1 = "Field $F{customer_name} is not declared in the report"
e2 = "Field $F{order_total} is not declared in the report"
assert _make_fingerprint(e1) == _make_fingerprint(e2)
def test_different_errors_different_fingerprint(self):
e1 = "Missing required attribute pageWidth"
e2 = "Query returned 0 results"
assert _make_fingerprint(e1) != _make_fingerprint(e2)
def test_normalizes_variable_names(self):
fp1 = _make_fingerprint("Field $F{amount} not found")
fp2 = _make_fingerprint("Field $F{total_price} not found")
assert fp1 == fp2
def test_normalizes_string_literals_single_quote(self):
fp1 = _make_fingerprint("Value 'abc123' is invalid")
fp2 = _make_fingerprint("Value 'xyz789' is invalid")
assert fp1 == fp2
def test_normalizes_string_literals_double_quote(self):
fp1 = _make_fingerprint('Name "test_table" not found')
fp2 = _make_fingerprint('Name "prod_table" not found')
assert fp1 == fp2
def test_normalizes_numbers(self):
fp1 = _make_fingerprint("Line 42 has 100 errors")
fp2 = _make_fingerprint("Line 7 has 3 errors")
assert fp1 == fp2
def test_case_insensitive(self):
assert _make_fingerprint("ERROR: Missing Field") == _make_fingerprint("error: missing field")
def test_whitespace_insensitive(self):
e1 = "missing field\n\ndeclaration"
e2 = "missing field declaration"
assert _make_fingerprint(e1) == _make_fingerprint(e2)
def test_output_is_16_char_hex(self):
fp = _make_fingerprint("some error message")
assert len(fp) == 16
assert all(c in "0123456789abcdef" for c in fp)
# ── _extract_keywords ───────────────────────────────────────────
class TestExtractKeywords:
def test_extracts_chinese_words(self):
kw = _extract_keywords("未声明的字段引用和语法错误")
has_cn = any(len(k) >= 2 and "" <= k[0] <= "鿿" for k in kw)
assert has_cn
def test_extracts_english_tokens(self):
kw = _extract_keywords("missing field declaration in report")
assert "missing" in kw
assert "field" in kw
assert "report" in kw
def test_extracts_jrxml_patterns(self):
kw = _extract_keywords("Field $F{customer_name} not declared")
assert "$F{customer_name}" in kw
def test_short_tokens_ignored(self):
kw = _extract_keywords("a b c ab cd")
assert "ab" not in kw
assert "cd" not in kw
def test_empty_input_returns_empty_list(self):
assert _extract_keywords("") == []
def test_mixed_cn_en_jrxml(self):
kw = _extract_keywords("字段 $F{amount} 在 report 中未声明")
assert "$F{amount}" in kw
assert "report" in kw
# ── ErrorKB class (mock ChromaDB) ───────────────────────────────
def _make_patched_kb(client_override=None, collection_override=None):
"""创建一个 ErrorKB 实例,其 ChromaDB 依赖已被 mock。
因为 chromadb 是懒加载的(在 client/collection property 中导入),
直接设置 _client/_collection 实例属性即可绕过真实 ChromaDB。
"""
kb = ErrorKB()
kb._client = client_override or MagicMock()
kb._collection = collection_override or MagicMock()
if not client_override and not collection_override:
# 默认:client.get_collection 返回 mock collection
kb._client.get_collection.return_value = kb._collection
return kb
class TestErrorKBRecord:
def test_exists_returns_true_when_found(self):
col = MagicMock()
col.get.return_value = {"ids": ["abc123"]}
kb = _make_patched_kb(collection_override=col)
assert kb.exists("some error") is True
def test_exists_returns_false_when_not_found(self):
col = MagicMock()
col.get.return_value = {"ids": []}
kb = _make_patched_kb(collection_override=col)
assert kb.exists("some error") is False
def test_exists_survives_exception(self):
col = MagicMock()
col.get.side_effect = RuntimeError("db down")
kb = _make_patched_kb(collection_override=col)
assert kb.exists("some error") is False
def test_record_skips_duplicate(self):
col = MagicMock()
col.get.return_value = {"ids": ["existing_fp"]}
kb = _make_patched_kb(collection_override=col)
assert kb.record("error", "<bad/>", "<good/>", "fix prompt") is False
col.add.assert_not_called()
def test_record_adds_new_case(self):
col = MagicMock()
col.get.return_value = {"ids": []}
kb = _make_patched_kb(collection_override=col)
assert kb.record(
"Field $F{x} not declared",
"<bad_jrxml>", "<good_jrxml>",
"prompt content", model="test-model", retry_count=2,
) is True
col.add.assert_called_once()
meta = col.add.call_args[1]["metadatas"][0]
assert meta["retry_success"] == 3
class TestErrorKBSearch:
@pytest.fixture
def col(self):
return MagicMock()
@pytest.fixture
def kb(self, col):
return _make_patched_kb(collection_override=col)
def test_search_returns_formatted_results(self, kb, col):
col.get.return_value = {"ids": []}
col.query.return_value = {
"ids": [["fp1"]],
"documents": [[json.dumps({
"error": "test error",
"good_jrxml_snippet": "<good/>",
"correction_prompt": "fix it",
"recorded_at": "2026-01-01T00:00:00",
})]],
"metadatas": [[{}]],
"distances": [[0.05]],
}
results = kb.search("some error", k=3)
assert len(results) == 1
assert results[0]["error"] == "test error"
assert results[0]["distance"] == 0.05
def test_search_returns_empty_on_exception(self, kb, col):
col.query.side_effect = RuntimeError("fail")
assert kb.search("error") == []
def test_search_as_context_formats_output(self, kb, col):
col.get.return_value = {"ids": []}
col.query.return_value = {
"ids": [["fp1", "fp2"]],
"documents": [[
json.dumps({"error": "e1", "good_jrxml_snippet": "<g1/>", "correction_prompt": "p1", "recorded_at": ""}),
json.dumps({"error": "e2", "good_jrxml_snippet": "<g2/>", "correction_prompt": "p2", "recorded_at": ""}),
]],
"metadatas": [[{}, {}]],
"distances": [[0.1, 0.2]],
}
ctx = kb.search_as_context("error", k=2)
assert "[历史错误案例]" in ctx
assert "---" in ctx
def test_search_as_context_empty_for_no_results(self, kb, col):
col.get.return_value = {"ids": []}
col.query.return_value = {"ids": [[]], "documents": [[]], "distances": [[]]}
assert kb.search_as_context("error") == ""
def test_stats_returns_count(self, kb, col):
col.count.return_value = 42
assert kb.stats()["total_cases"] == 42
def test_stats_zero_on_exception(self, kb, col):
col.count.side_effect = RuntimeError("down")
assert kb.stats()["total_cases"] == 0
# ── 全局便捷函数 ───────────────────────────────────────────────
class TestConvenienceFunctions:
def test_get_error_kb_is_singleton(self, monkeypatch):
import backend.error_kb as mod
monkeypatch.setattr(mod, "_kb", None)
assert get_error_kb() is get_error_kb()
def test_record_error_delegates(self):
with patch.object(ErrorKB, "record", return_value=True) as mock_r:
assert record_error("e", "<b>", "<g>", "p") is True
mock_r.assert_called_once()
def test_search_error_cases_delegates(self):
with patch.object(ErrorKB, "search_as_context", return_value="ctx") as mock_s:
assert search_error_cases("err", k=5) == "ctx"
mock_s.assert_called_once_with("err", k=5)
+210
View File
@@ -0,0 +1,210 @@
"""backend/session.py 单元测试 — 会话 CRUD + 原子写入。
覆盖:
- 创建/加载/保存/删除/列出
- 原子写入(tempfile + os.replace
- 边界情况(不存在会话、损坏 JSON、空名称自动填充)
"""
import json
import os
import sys
import tempfile
import time
from pathlib import Path
import pytest
sys.path.insert(0, str(Path(__file__).parent.parent))
from backend.session import (
create_session,
load_session,
save_session,
get_session_state,
list_all_sessions,
delete_session,
generate_session_id,
SESSIONS_DIR,
)
@pytest.fixture
def temp_sessions_dir(monkeypatch):
with tempfile.TemporaryDirectory(prefix="test_sessions_") as tmpdir:
monkeypatch.setattr("backend.session.SESSIONS_DIR", Path(tmpdir))
yield Path(tmpdir)
# ── create_session ──────────────────────────────────────────────
class TestCreateSession:
def test_creates_with_defaults(self, temp_sessions_dir):
s = create_session()
assert len(s["session_id"]) == 12
assert "新建报表" in s["session_name"]
assert s["created_at"]
assert s["updated_at"]
def test_custom_name(self, temp_sessions_dir):
s = create_session(name="测试报表")
assert s["session_name"] == "测试报表"
def test_agent_state_preserved(self, temp_sessions_dir):
s = create_session(agent_state={"current_jrxml": "<x/>"})
assert s["agent_state"]["current_jrxml"] == "<x/>"
def test_session_id_injected_into_agent_state(self, temp_sessions_dir):
s = create_session()
assert s["agent_state"]["session_id"] == s["session_id"]
def test_persists_json_to_disk(self, temp_sessions_dir):
s = create_session(name="磁盘测试")
fp = temp_sessions_dir / f"{s['session_id']}.json"
assert fp.exists()
loaded = json.loads(fp.read_text("utf-8"))
assert loaded["session_name"] == "磁盘测试"
def test_unique_ids_no_collision(self, temp_sessions_dir):
ids = {generate_session_id() for _ in range(100)}
assert len(ids) == 100
def test_creates_sessions_dir_if_missing(self, temp_sessions_dir):
nested = temp_sessions_dir / "nested" / "sub"
import backend.session as mod
monkeypatch = pytest.MonkeyPatch()
monkeypatch.setattr(mod, "SESSIONS_DIR", nested)
s = create_session()
assert nested.exists()
assert (nested / f"{s['session_id']}.json").exists()
# ── load_session ────────────────────────────────────────────────
class TestLoadSession:
def test_returns_none_for_missing(self, temp_sessions_dir):
assert load_session("nonexistent_id") is None
def test_loads_existing(self, temp_sessions_dir):
created = create_session(name="加载测试")
loaded = load_session(created["session_id"])
assert loaded["session_name"] == "加载测试"
assert loaded["session_id"] == created["session_id"]
def test_load_includes_agent_state(self, temp_sessions_dir):
created = create_session(agent_state={"field_count": 5})
loaded = load_session(created["session_id"])
assert loaded["agent_state"]["field_count"] == 5
# ── save_session ────────────────────────────────────────────────
class TestSaveSession:
def test_updates_name_and_state(self, temp_sessions_dir):
created = create_session(name="原始")
save_session(created["session_id"], {"new_key": True}, session_name="更新")
loaded = load_session(created["session_id"])
assert loaded["session_name"] == "更新"
assert loaded["agent_state"]["new_key"] is True
def test_preserves_created_at(self, temp_sessions_dir):
created = create_session()
original = created["created_at"]
save_session(created["session_id"], {"x": 1})
assert load_session(created["session_id"])["created_at"] == original
def test_updates_updated_at(self, temp_sessions_dir):
created = create_session()
time.sleep(0.01)
save_session(created["session_id"], {"x": 1})
loaded = load_session(created["session_id"])
assert loaded["updated_at"] != created["updated_at"]
def test_atomic_write_produces_valid_json(self, temp_sessions_dir):
created = create_session()
save_session(created["session_id"], {"data": "x" * 1000})
fp = temp_sessions_dir / f"{created['session_id']}.json"
data = json.loads(fp.read_text("utf-8"))
assert data["agent_state"]["data"] == "x" * 1000
def test_auto_generates_name_when_empty(self, temp_sessions_dir):
created = create_session(name="")
save_session(created["session_id"], {"x": 1})
assert load_session(created["session_id"])["session_name"]
def test_keeps_existing_name_when_not_provided(self, temp_sessions_dir):
created = create_session(name="原名")
save_session(created["session_id"], {"x": 1})
assert load_session(created["session_id"])["session_name"] == "原名"
def test_fills_missing_created_at(self, temp_sessions_dir):
sid = "test_no_created"
fp = temp_sessions_dir / f"{sid}.json"
fp.write_text(
json.dumps({"session_id": sid, "session_name": "旧数据"}), "utf-8"
)
save_session(sid, {"x": 1})
assert load_session(sid)["created_at"]
# ── get_session_state ───────────────────────────────────────────
class TestGetSessionState:
def test_none_for_missing(self, temp_sessions_dir):
assert get_session_state("missing") is None
def test_returns_all_keys(self, temp_sessions_dir):
created = create_session(name="状态测试")
state = get_session_state(created["session_id"])
for key in ("session_id", "session_name", "agent_state", "created_at", "updated_at"):
assert key in state
# ── list_all_sessions ───────────────────────────────────────────
class TestListAllSessions:
def test_empty_when_no_sessions(self, temp_sessions_dir):
assert list_all_sessions() == []
def test_lists_all_created(self, temp_sessions_dir):
s1 = create_session(name="A")
s2 = create_session(name="B")
ids = {s["session_id"] for s in list_all_sessions()}
assert s1["session_id"] in ids
assert s2["session_id"] in ids
def test_summary_excludes_agent_state(self, temp_sessions_dir):
create_session(agent_state={"secret": True})
result = list_all_sessions()
assert "agent_state" not in result[0]
def test_sorted_by_mtime_desc(self, temp_sessions_dir):
s1 = create_session(name="")
time.sleep(0.02)
s2 = create_session(name="")
assert list_all_sessions()[0]["session_id"] == s2["session_id"]
def test_skips_corrupt_json(self, temp_sessions_dir):
(temp_sessions_dir / "bad.json").write_text("{not json}", "utf-8")
create_session(name="正常")
assert len(list_all_sessions()) == 1
# ── delete_session ──────────────────────────────────────────────
class TestDeleteSession:
def test_returns_false_for_missing(self, temp_sessions_dir):
assert delete_session("ghost_id") is False
def test_returns_true_and_removes(self, temp_sessions_dir):
created = create_session()
assert delete_session(created["session_id"]) is True
assert load_session(created["session_id"]) is None
def test_file_is_removed_from_disk(self, temp_sessions_dir):
created = create_session()
fp = temp_sessions_dir / f"{created['session_id']}.json"
assert fp.exists()
delete_session(created["session_id"])
assert not fp.exists()