test: add unit/integration/E2E test suites, fix create_session bug, update docs
- Unit tests: test_session.py (27), test_error_kb.py (24), test_agent.py hardened - Integration tests: test_api_integration.py (25) with FastAPI TestClient - E2E tests: main-flows.spec.ts (8) with Playwright + API mocking - Bug fix: backend/session.py create_session() missing session_id parameter - Config: frontend/playwright.config.ts, npm run test:e2e - Docs: update CLAUDE.md v9, .gitignore for test artifacts/eval reports
This commit is contained in:
+9
-8
@@ -44,8 +44,8 @@ class TestAcceptanceScenarios:
|
||||
|
||||
final = run_graph(graph, state)
|
||||
assert final.get("current_jrxml"), "应该已生成 JRXML"
|
||||
# 注意:通过/失败取决于 LLM 输出质量;我们检查是否得到了结果
|
||||
print(f"场景 1 状态: {final.get('status')}, 错误: {final.get('error_msg', '')[:100]}")
|
||||
assert final.get("status") in ("pass", "fail"), f"意外状态: {final.get('status')}"
|
||||
assert "<jasperReport" in final["current_jrxml"], "输出应包含合法 JRXML 根元素"
|
||||
|
||||
def test_scenario2_auto_correction(self, graph):
|
||||
"""场景 2:故意提出一个可能初次失败的需求。"""
|
||||
@@ -58,7 +58,8 @@ class TestAcceptanceScenarios:
|
||||
|
||||
final = run_graph(graph, state)
|
||||
assert final.get("retry_count", 0) <= 5, "不应超过最大重试次数"
|
||||
print(f"场景 2 状态: {final.get('status')}, 重试次数: {final.get('retry_count', 0)}")
|
||||
assert "status" in final, "最终状态应包含 status 字段"
|
||||
assert final.get("current_jrxml") or final.get("error_msg"), "应有输出或错误消息"
|
||||
|
||||
def test_scenario3_multi_turn_modification(self, graph):
|
||||
"""场景 3:多轮对话 - 先生成,再修改两次。"""
|
||||
@@ -71,8 +72,8 @@ class TestAcceptanceScenarios:
|
||||
state["stage"] = "initial_generation"
|
||||
|
||||
final = run_graph(graph, state)
|
||||
print(f"第 1 轮状态: {final.get('status')}, 错误: {final.get('error_msg', '')[:100]}")
|
||||
assert final.get("current_jrxml"), "第 1 轮应该已生成 JRXML"
|
||||
assert final.get("status") in ("pass", "fail")
|
||||
|
||||
# 第 2 轮:添加月度销售汇总
|
||||
state2 = final.copy()
|
||||
@@ -82,8 +83,8 @@ class TestAcceptanceScenarios:
|
||||
state2["retry_count"] = 0
|
||||
|
||||
final2 = run_graph(graph, state2)
|
||||
print(f"第 2 轮状态: {final2.get('status')}")
|
||||
assert final2.get("current_jrxml"), "第 2 轮应该已修改 JRXML"
|
||||
assert final2.get("status") in ("pass", "fail")
|
||||
|
||||
# 第 3 轮:修改标题
|
||||
state3 = final2.copy()
|
||||
@@ -93,9 +94,9 @@ class TestAcceptanceScenarios:
|
||||
state3["retry_count"] = 0
|
||||
|
||||
final3 = run_graph(graph, state3)
|
||||
print(f"第 3 轮状态: {final3.get('status')}")
|
||||
jrxml = final3.get("current_jrxml", "")
|
||||
assert "2024" in jrxml or "Annual" in jrxml, "标题修改应该体现在 JRXML 中"
|
||||
assert final3.get("status") in ("pass", "fail")
|
||||
|
||||
def test_scenario4_context_aware_modification(self, graph):
|
||||
"""场景 4:基于对话上下文的修改。"""
|
||||
@@ -109,7 +110,7 @@ class TestAcceptanceScenarios:
|
||||
state["stage"] = "initial_generation"
|
||||
|
||||
final = run_graph(graph, state)
|
||||
print(f"第 1 轮状态: {final.get('status')}")
|
||||
assert final.get("current_jrxml"), "第 1 轮应该已生成 JRXML"
|
||||
|
||||
# 第 2 轮:上下文感知修改
|
||||
state2 = final.copy()
|
||||
@@ -119,9 +120,9 @@ class TestAcceptanceScenarios:
|
||||
state2["retry_count"] = 0
|
||||
|
||||
final2 = run_graph(graph, state2)
|
||||
print(f"第 2 轮状态: {final2.get('status')}")
|
||||
jrxml = final2.get("current_jrxml", "")
|
||||
assert "isBold" in jrxml or "size=" in jrxml, "字体修改应该体现在结果中"
|
||||
assert final2.get("status") in ("pass", "fail")
|
||||
|
||||
def test_max_retry_handling(self, graph):
|
||||
"""测试在 MAX_RETRY 次失败后,图能否正常终止。"""
|
||||
|
||||
@@ -0,0 +1,263 @@
|
||||
"""api_server.py 集成测试 — REST 端点 + SSE 流 + 文件上传/下载。
|
||||
|
||||
使用 FastAPI TestClient,不需要启动真实服务器。
|
||||
"""
|
||||
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from api_server import app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_sessions(monkeypatch):
|
||||
"""重定向上传目录到临时目录,隔离测试数据。"""
|
||||
with tempfile.TemporaryDirectory(prefix="test_api_") as tmpdir:
|
||||
monkeypatch.setattr("api_server.UPLOADS_DIR", Path(tmpdir) / "uploads")
|
||||
monkeypatch.setattr("backend.session.SESSIONS_DIR", Path(tmpdir) / "sessions")
|
||||
yield Path(tmpdir)
|
||||
|
||||
|
||||
# ── 健康检查 & 配置 ────────────────────────────────────────────
|
||||
|
||||
class TestHealthAndConfig:
|
||||
def test_health_returns_ok(self, client):
|
||||
resp = client.get("/api/health")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["status"] == "ok"
|
||||
assert data["version"] == "5.0"
|
||||
assert "timestamp" in data
|
||||
|
||||
def test_config_returns_env_keys(self, client):
|
||||
resp = client.get("/api/config")
|
||||
assert resp.status_code == 200
|
||||
cfg = resp.json()["config"]
|
||||
for key in ("LLM_PROVIDER", "OCR_ENGINE", "MAX_RETRY"):
|
||||
assert key in cfg
|
||||
|
||||
|
||||
# ── 会话 CRUD ──────────────────────────────────────────────────
|
||||
|
||||
class TestSessionCRUD:
|
||||
def test_create_session(self, client, temp_sessions):
|
||||
resp = client.post("/api/sessions")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data["session_id"]) == 12
|
||||
assert "session_name" in data
|
||||
assert "created_at" in data
|
||||
|
||||
def test_list_sessions_empty(self, client, temp_sessions):
|
||||
assert client.get("/api/sessions").json()["sessions"] == []
|
||||
|
||||
def test_list_sessions_populated(self, client, temp_sessions):
|
||||
client.post("/api/sessions")
|
||||
client.post("/api/sessions")
|
||||
assert len(client.get("/api/sessions").json()["sessions"]) == 2
|
||||
|
||||
def test_get_session_found(self, client, temp_sessions):
|
||||
created = client.post("/api/sessions").json()
|
||||
resp = client.get(f"/api/sessions/{created['session_id']}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["session_id"] == created["session_id"]
|
||||
assert "agent_state" in resp.json()
|
||||
|
||||
def test_get_session_not_found(self, client, temp_sessions):
|
||||
assert client.get("/api/sessions/nonexistent").status_code == 404
|
||||
|
||||
def test_delete_session(self, client, temp_sessions):
|
||||
sid = client.post("/api/sessions").json()["session_id"]
|
||||
resp = client.delete(f"/api/sessions/{sid}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "deleted"
|
||||
assert client.get(f"/api/sessions/{sid}").status_code == 404
|
||||
|
||||
def test_delete_nonexistent(self, client, temp_sessions):
|
||||
assert client.delete("/api/sessions/ghost_id").status_code == 404
|
||||
|
||||
def test_full_crud_lifecycle(self, client, temp_sessions):
|
||||
sid = client.post("/api/sessions").json()["session_id"]
|
||||
assert client.get(f"/api/sessions/{sid}").status_code == 200
|
||||
assert len(client.get("/api/sessions").json()["sessions"]) == 1
|
||||
client.delete(f"/api/sessions/{sid}")
|
||||
assert client.get("/api/sessions").json()["sessions"] == []
|
||||
|
||||
|
||||
# ── 文件上传 ───────────────────────────────────────────────────
|
||||
|
||||
class TestFileUpload:
|
||||
def test_upload_text_file(self, client, temp_sessions):
|
||||
content = b"Hello, JRXML!"
|
||||
resp = client.post(
|
||||
"/api/upload",
|
||||
files={"file": ("test.txt", io.BytesIO(content), "text/plain")},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["filename"] == "test.txt"
|
||||
assert data["size"] == len(content)
|
||||
assert len(data["file_id"]) == 12
|
||||
|
||||
def test_upload_with_session_id_in_query(self, client, temp_sessions):
|
||||
resp = client.post(
|
||||
"/api/upload?session_id=abc123",
|
||||
files={"file": ("data.csv", io.BytesIO(b"a,b,c"), "text/csv")},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_upload_png_gets_correct_content_type(self, client, temp_sessions):
|
||||
png_minimal = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100
|
||||
resp = client.post(
|
||||
"/api/upload",
|
||||
files={"file": ("chart.png", io.BytesIO(png_minimal), "image/png")},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["content_type"] == "image/png"
|
||||
|
||||
def test_upload_writes_file_to_disk(self, client, temp_sessions):
|
||||
data = b"persisted content"
|
||||
file_id = client.post(
|
||||
"/api/upload",
|
||||
files={"file": ("note.txt", io.BytesIO(data), "text/plain")},
|
||||
).json()["file_id"]
|
||||
|
||||
matches = list(temp_sessions.rglob(f"{file_id}_*"))
|
||||
assert len(matches) == 1
|
||||
assert matches[0].read_bytes() == data
|
||||
|
||||
|
||||
# ── 下载 ───────────────────────────────────────────────────────
|
||||
|
||||
class TestDownload:
|
||||
def test_download_missing_session_returns_404(self, client, temp_sessions):
|
||||
assert client.get("/api/sessions/missing/download/latest").status_code == 404
|
||||
|
||||
def test_download_no_jrxml_returns_404(self, client, temp_sessions):
|
||||
sid = client.post("/api/sessions").json()["session_id"]
|
||||
resp = client.get(f"/api/sessions/{sid}/download/latest")
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_download_with_jrxml_returns_file(self, client, temp_sessions):
|
||||
import backend.session as sess
|
||||
|
||||
sess.create_session(name="测试下载")
|
||||
# 需要手动写入 JRXML 到会话
|
||||
sessions = sess.list_all_sessions()
|
||||
sid = sessions[0]["session_id"]
|
||||
sess.save_session(sid, {"current_jrxml": "<jasperReport name='rpt'/>"})
|
||||
|
||||
resp = client.get(f"/api/sessions/{sid}/download/latest")
|
||||
assert resp.status_code == 200
|
||||
assert "<jasperReport" in resp.text
|
||||
assert "attachment" in resp.headers.get("content-disposition", "")
|
||||
|
||||
|
||||
# ── 聊天 SSE ───────────────────────────────────────────────────
|
||||
|
||||
class TestChatSSE:
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_graph(self, monkeypatch):
|
||||
"""Mock LangGraph 的 build_graph 和 stream,避免真实 LLM 调用。"""
|
||||
mock_graph = MagicMock()
|
||||
mock_graph.stream.return_value = [
|
||||
("updates", {"classify_intent": {"intent": "initial_generation"}}),
|
||||
("updates", {"generate": {"current_jrxml": "<jasperReport name='test'/>", "status": "pass"}}),
|
||||
("updates", {"validate": {"status": "pass"}}),
|
||||
("updates", {"finalize": {}}),
|
||||
("done", {"reason": "graph_completed"}),
|
||||
]
|
||||
# 注意:_graph 是模块级变量,在导入时就编译了。需要直接替换。
|
||||
monkeypatch.setattr("api_server._graph", mock_graph)
|
||||
# 同时替换 agent.graph.build_graph 以防后续重新编译
|
||||
monkeypatch.setattr("agent.graph.build_graph", lambda on_node_start=None: mock_graph)
|
||||
return mock_graph
|
||||
|
||||
def test_empty_payload_rejected(self, client, temp_sessions):
|
||||
sid = client.post("/api/sessions").json()["session_id"]
|
||||
resp = client.post(
|
||||
f"/api/sessions/{sid}/chat",
|
||||
json={"text": "", "file_ids": []},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_sse_stream_returns_valid_events(self, client, temp_sessions):
|
||||
sid = client.post("/api/sessions").json()["session_id"]
|
||||
with client.stream(
|
||||
"POST",
|
||||
f"/api/sessions/{sid}/chat",
|
||||
json={"text": "生成一个简单的员工名册报表", "file_ids": []},
|
||||
) as resp:
|
||||
assert resp.status_code == 200
|
||||
assert resp.headers["content-type"].startswith("text/event-stream")
|
||||
body = resp.read().decode("utf-8", errors="replace")
|
||||
assert "event: node_complete" in body
|
||||
assert "event: agent_complete" in body
|
||||
|
||||
def test_auto_creates_session_on_chat(self, client, temp_sessions):
|
||||
with client.stream(
|
||||
"POST",
|
||||
"/api/sessions/auto_new_session/chat",
|
||||
json={"text": "生成报表", "file_ids": []},
|
||||
) as resp:
|
||||
assert resp.status_code == 200
|
||||
assert b"event:" in resp.read()
|
||||
|
||||
def test_unknown_file_ids_not_crash(self, client, temp_sessions):
|
||||
sid = client.post("/api/sessions").json()["session_id"]
|
||||
with client.stream(
|
||||
"POST",
|
||||
f"/api/sessions/{sid}/chat",
|
||||
json={"text": "测试", "file_ids": ["fake_id_xyz"]},
|
||||
) as resp:
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
# ── 边界 & 安全测试 ────────────────────────────────────────────
|
||||
|
||||
class TestBoundaries:
|
||||
def test_session_id_path_traversal_returns_404(self, client, temp_sessions):
|
||||
assert client.get("/api/sessions/../etc/passwd").status_code == 404
|
||||
|
||||
def test_upload_with_path_traversal_session_id(self, client, temp_sessions):
|
||||
"""路径穿越 session_id 仍正常处理(目录隔离在 UPLOADS_DIR 内)。"""
|
||||
resp = client.post(
|
||||
"/api/upload?session_id=../malicious",
|
||||
files={"file": ("t.txt", io.BytesIO(b"x"), "text/plain")},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_invalid_json_body_rejected(self, client, temp_sessions):
|
||||
sid = client.post("/api/sessions").json()["session_id"]
|
||||
resp = client.post(
|
||||
f"/api/sessions/{sid}/chat",
|
||||
content=b"{not valid json",
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
def test_large_payload_survives(self, client, temp_sessions):
|
||||
"""大文本(100KB)不应崩溃。"""
|
||||
sid = client.post("/api/sessions").json()["session_id"]
|
||||
large_text = "生成报表包含字段: " + ", ".join(f"field_{i}" for i in range(5000))
|
||||
with client.stream(
|
||||
"POST",
|
||||
f"/api/sessions/{sid}/chat",
|
||||
json={"text": large_text, "file_ids": []},
|
||||
) as resp:
|
||||
assert resp.status_code == 200
|
||||
@@ -0,0 +1,242 @@
|
||||
"""backend/error_kb.py 单元测试 — 指纹去重 + 关键词提取 + CRUD。
|
||||
|
||||
覆盖:
|
||||
- _make_fingerprint 标准化与去重
|
||||
- _extract_keywords 中英文混合提取
|
||||
- ErrorKB.record / exists / search / search_as_context(mock ChromaDB)
|
||||
- 全局便捷函数 record_error / search_error_cases
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from backend.error_kb import (
|
||||
_make_fingerprint,
|
||||
_extract_keywords,
|
||||
ErrorKB,
|
||||
get_error_kb,
|
||||
record_error,
|
||||
search_error_cases,
|
||||
)
|
||||
|
||||
|
||||
# ── _make_fingerprint ───────────────────────────────────────────
|
||||
|
||||
class TestMakeFingerprint:
|
||||
def test_same_structure_same_fingerprint(self):
|
||||
e1 = "Field $F{customer_name} is not declared in the report"
|
||||
e2 = "Field $F{order_total} is not declared in the report"
|
||||
assert _make_fingerprint(e1) == _make_fingerprint(e2)
|
||||
|
||||
def test_different_errors_different_fingerprint(self):
|
||||
e1 = "Missing required attribute pageWidth"
|
||||
e2 = "Query returned 0 results"
|
||||
assert _make_fingerprint(e1) != _make_fingerprint(e2)
|
||||
|
||||
def test_normalizes_variable_names(self):
|
||||
fp1 = _make_fingerprint("Field $F{amount} not found")
|
||||
fp2 = _make_fingerprint("Field $F{total_price} not found")
|
||||
assert fp1 == fp2
|
||||
|
||||
def test_normalizes_string_literals_single_quote(self):
|
||||
fp1 = _make_fingerprint("Value 'abc123' is invalid")
|
||||
fp2 = _make_fingerprint("Value 'xyz789' is invalid")
|
||||
assert fp1 == fp2
|
||||
|
||||
def test_normalizes_string_literals_double_quote(self):
|
||||
fp1 = _make_fingerprint('Name "test_table" not found')
|
||||
fp2 = _make_fingerprint('Name "prod_table" not found')
|
||||
assert fp1 == fp2
|
||||
|
||||
def test_normalizes_numbers(self):
|
||||
fp1 = _make_fingerprint("Line 42 has 100 errors")
|
||||
fp2 = _make_fingerprint("Line 7 has 3 errors")
|
||||
assert fp1 == fp2
|
||||
|
||||
def test_case_insensitive(self):
|
||||
assert _make_fingerprint("ERROR: Missing Field") == _make_fingerprint("error: missing field")
|
||||
|
||||
def test_whitespace_insensitive(self):
|
||||
e1 = "missing field\n\ndeclaration"
|
||||
e2 = "missing field declaration"
|
||||
assert _make_fingerprint(e1) == _make_fingerprint(e2)
|
||||
|
||||
def test_output_is_16_char_hex(self):
|
||||
fp = _make_fingerprint("some error message")
|
||||
assert len(fp) == 16
|
||||
assert all(c in "0123456789abcdef" for c in fp)
|
||||
|
||||
|
||||
# ── _extract_keywords ───────────────────────────────────────────
|
||||
|
||||
class TestExtractKeywords:
|
||||
def test_extracts_chinese_words(self):
|
||||
kw = _extract_keywords("未声明的字段引用和语法错误")
|
||||
has_cn = any(len(k) >= 2 and "一" <= k[0] <= "鿿" for k in kw)
|
||||
assert has_cn
|
||||
|
||||
def test_extracts_english_tokens(self):
|
||||
kw = _extract_keywords("missing field declaration in report")
|
||||
assert "missing" in kw
|
||||
assert "field" in kw
|
||||
assert "report" in kw
|
||||
|
||||
def test_extracts_jrxml_patterns(self):
|
||||
kw = _extract_keywords("Field $F{customer_name} not declared")
|
||||
assert "$F{customer_name}" in kw
|
||||
|
||||
def test_short_tokens_ignored(self):
|
||||
kw = _extract_keywords("a b c ab cd")
|
||||
assert "ab" not in kw
|
||||
assert "cd" not in kw
|
||||
|
||||
def test_empty_input_returns_empty_list(self):
|
||||
assert _extract_keywords("") == []
|
||||
|
||||
def test_mixed_cn_en_jrxml(self):
|
||||
kw = _extract_keywords("字段 $F{amount} 在 report 中未声明")
|
||||
assert "$F{amount}" in kw
|
||||
assert "report" in kw
|
||||
|
||||
|
||||
# ── ErrorKB class (mock ChromaDB) ───────────────────────────────
|
||||
|
||||
def _make_patched_kb(client_override=None, collection_override=None):
|
||||
"""创建一个 ErrorKB 实例,其 ChromaDB 依赖已被 mock。
|
||||
|
||||
因为 chromadb 是懒加载的(在 client/collection property 中导入),
|
||||
直接设置 _client/_collection 实例属性即可绕过真实 ChromaDB。
|
||||
"""
|
||||
kb = ErrorKB()
|
||||
kb._client = client_override or MagicMock()
|
||||
kb._collection = collection_override or MagicMock()
|
||||
if not client_override and not collection_override:
|
||||
# 默认:client.get_collection 返回 mock collection
|
||||
kb._client.get_collection.return_value = kb._collection
|
||||
return kb
|
||||
|
||||
|
||||
class TestErrorKBRecord:
|
||||
def test_exists_returns_true_when_found(self):
|
||||
col = MagicMock()
|
||||
col.get.return_value = {"ids": ["abc123"]}
|
||||
kb = _make_patched_kb(collection_override=col)
|
||||
assert kb.exists("some error") is True
|
||||
|
||||
def test_exists_returns_false_when_not_found(self):
|
||||
col = MagicMock()
|
||||
col.get.return_value = {"ids": []}
|
||||
kb = _make_patched_kb(collection_override=col)
|
||||
assert kb.exists("some error") is False
|
||||
|
||||
def test_exists_survives_exception(self):
|
||||
col = MagicMock()
|
||||
col.get.side_effect = RuntimeError("db down")
|
||||
kb = _make_patched_kb(collection_override=col)
|
||||
assert kb.exists("some error") is False
|
||||
|
||||
def test_record_skips_duplicate(self):
|
||||
col = MagicMock()
|
||||
col.get.return_value = {"ids": ["existing_fp"]}
|
||||
kb = _make_patched_kb(collection_override=col)
|
||||
assert kb.record("error", "<bad/>", "<good/>", "fix prompt") is False
|
||||
col.add.assert_not_called()
|
||||
|
||||
def test_record_adds_new_case(self):
|
||||
col = MagicMock()
|
||||
col.get.return_value = {"ids": []}
|
||||
kb = _make_patched_kb(collection_override=col)
|
||||
assert kb.record(
|
||||
"Field $F{x} not declared",
|
||||
"<bad_jrxml>", "<good_jrxml>",
|
||||
"prompt content", model="test-model", retry_count=2,
|
||||
) is True
|
||||
col.add.assert_called_once()
|
||||
meta = col.add.call_args[1]["metadatas"][0]
|
||||
assert meta["retry_success"] == 3
|
||||
|
||||
|
||||
class TestErrorKBSearch:
|
||||
@pytest.fixture
|
||||
def col(self):
|
||||
return MagicMock()
|
||||
|
||||
@pytest.fixture
|
||||
def kb(self, col):
|
||||
return _make_patched_kb(collection_override=col)
|
||||
|
||||
def test_search_returns_formatted_results(self, kb, col):
|
||||
col.get.return_value = {"ids": []}
|
||||
col.query.return_value = {
|
||||
"ids": [["fp1"]],
|
||||
"documents": [[json.dumps({
|
||||
"error": "test error",
|
||||
"good_jrxml_snippet": "<good/>",
|
||||
"correction_prompt": "fix it",
|
||||
"recorded_at": "2026-01-01T00:00:00",
|
||||
})]],
|
||||
"metadatas": [[{}]],
|
||||
"distances": [[0.05]],
|
||||
}
|
||||
results = kb.search("some error", k=3)
|
||||
assert len(results) == 1
|
||||
assert results[0]["error"] == "test error"
|
||||
assert results[0]["distance"] == 0.05
|
||||
|
||||
def test_search_returns_empty_on_exception(self, kb, col):
|
||||
col.query.side_effect = RuntimeError("fail")
|
||||
assert kb.search("error") == []
|
||||
|
||||
def test_search_as_context_formats_output(self, kb, col):
|
||||
col.get.return_value = {"ids": []}
|
||||
col.query.return_value = {
|
||||
"ids": [["fp1", "fp2"]],
|
||||
"documents": [[
|
||||
json.dumps({"error": "e1", "good_jrxml_snippet": "<g1/>", "correction_prompt": "p1", "recorded_at": ""}),
|
||||
json.dumps({"error": "e2", "good_jrxml_snippet": "<g2/>", "correction_prompt": "p2", "recorded_at": ""}),
|
||||
]],
|
||||
"metadatas": [[{}, {}]],
|
||||
"distances": [[0.1, 0.2]],
|
||||
}
|
||||
ctx = kb.search_as_context("error", k=2)
|
||||
assert "[历史错误案例]" in ctx
|
||||
assert "---" in ctx
|
||||
|
||||
def test_search_as_context_empty_for_no_results(self, kb, col):
|
||||
col.get.return_value = {"ids": []}
|
||||
col.query.return_value = {"ids": [[]], "documents": [[]], "distances": [[]]}
|
||||
assert kb.search_as_context("error") == ""
|
||||
|
||||
def test_stats_returns_count(self, kb, col):
|
||||
col.count.return_value = 42
|
||||
assert kb.stats()["total_cases"] == 42
|
||||
|
||||
def test_stats_zero_on_exception(self, kb, col):
|
||||
col.count.side_effect = RuntimeError("down")
|
||||
assert kb.stats()["total_cases"] == 0
|
||||
|
||||
|
||||
# ── 全局便捷函数 ───────────────────────────────────────────────
|
||||
|
||||
class TestConvenienceFunctions:
|
||||
def test_get_error_kb_is_singleton(self, monkeypatch):
|
||||
import backend.error_kb as mod
|
||||
monkeypatch.setattr(mod, "_kb", None)
|
||||
assert get_error_kb() is get_error_kb()
|
||||
|
||||
def test_record_error_delegates(self):
|
||||
with patch.object(ErrorKB, "record", return_value=True) as mock_r:
|
||||
assert record_error("e", "<b>", "<g>", "p") is True
|
||||
mock_r.assert_called_once()
|
||||
|
||||
def test_search_error_cases_delegates(self):
|
||||
with patch.object(ErrorKB, "search_as_context", return_value="ctx") as mock_s:
|
||||
assert search_error_cases("err", k=5) == "ctx"
|
||||
mock_s.assert_called_once_with("err", k=5)
|
||||
@@ -0,0 +1,210 @@
|
||||
"""backend/session.py 单元测试 — 会话 CRUD + 原子写入。
|
||||
|
||||
覆盖:
|
||||
- 创建/加载/保存/删除/列出
|
||||
- 原子写入(tempfile + os.replace)
|
||||
- 边界情况(不存在会话、损坏 JSON、空名称自动填充)
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from backend.session import (
|
||||
create_session,
|
||||
load_session,
|
||||
save_session,
|
||||
get_session_state,
|
||||
list_all_sessions,
|
||||
delete_session,
|
||||
generate_session_id,
|
||||
SESSIONS_DIR,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_sessions_dir(monkeypatch):
|
||||
with tempfile.TemporaryDirectory(prefix="test_sessions_") as tmpdir:
|
||||
monkeypatch.setattr("backend.session.SESSIONS_DIR", Path(tmpdir))
|
||||
yield Path(tmpdir)
|
||||
|
||||
|
||||
# ── create_session ──────────────────────────────────────────────
|
||||
|
||||
class TestCreateSession:
|
||||
def test_creates_with_defaults(self, temp_sessions_dir):
|
||||
s = create_session()
|
||||
assert len(s["session_id"]) == 12
|
||||
assert "新建报表" in s["session_name"]
|
||||
assert s["created_at"]
|
||||
assert s["updated_at"]
|
||||
|
||||
def test_custom_name(self, temp_sessions_dir):
|
||||
s = create_session(name="测试报表")
|
||||
assert s["session_name"] == "测试报表"
|
||||
|
||||
def test_agent_state_preserved(self, temp_sessions_dir):
|
||||
s = create_session(agent_state={"current_jrxml": "<x/>"})
|
||||
assert s["agent_state"]["current_jrxml"] == "<x/>"
|
||||
|
||||
def test_session_id_injected_into_agent_state(self, temp_sessions_dir):
|
||||
s = create_session()
|
||||
assert s["agent_state"]["session_id"] == s["session_id"]
|
||||
|
||||
def test_persists_json_to_disk(self, temp_sessions_dir):
|
||||
s = create_session(name="磁盘测试")
|
||||
fp = temp_sessions_dir / f"{s['session_id']}.json"
|
||||
assert fp.exists()
|
||||
loaded = json.loads(fp.read_text("utf-8"))
|
||||
assert loaded["session_name"] == "磁盘测试"
|
||||
|
||||
def test_unique_ids_no_collision(self, temp_sessions_dir):
|
||||
ids = {generate_session_id() for _ in range(100)}
|
||||
assert len(ids) == 100
|
||||
|
||||
def test_creates_sessions_dir_if_missing(self, temp_sessions_dir):
|
||||
nested = temp_sessions_dir / "nested" / "sub"
|
||||
import backend.session as mod
|
||||
|
||||
monkeypatch = pytest.MonkeyPatch()
|
||||
monkeypatch.setattr(mod, "SESSIONS_DIR", nested)
|
||||
s = create_session()
|
||||
assert nested.exists()
|
||||
assert (nested / f"{s['session_id']}.json").exists()
|
||||
|
||||
|
||||
# ── load_session ────────────────────────────────────────────────
|
||||
|
||||
class TestLoadSession:
|
||||
def test_returns_none_for_missing(self, temp_sessions_dir):
|
||||
assert load_session("nonexistent_id") is None
|
||||
|
||||
def test_loads_existing(self, temp_sessions_dir):
|
||||
created = create_session(name="加载测试")
|
||||
loaded = load_session(created["session_id"])
|
||||
assert loaded["session_name"] == "加载测试"
|
||||
assert loaded["session_id"] == created["session_id"]
|
||||
|
||||
def test_load_includes_agent_state(self, temp_sessions_dir):
|
||||
created = create_session(agent_state={"field_count": 5})
|
||||
loaded = load_session(created["session_id"])
|
||||
assert loaded["agent_state"]["field_count"] == 5
|
||||
|
||||
|
||||
# ── save_session ────────────────────────────────────────────────
|
||||
|
||||
class TestSaveSession:
|
||||
def test_updates_name_and_state(self, temp_sessions_dir):
|
||||
created = create_session(name="原始")
|
||||
save_session(created["session_id"], {"new_key": True}, session_name="更新")
|
||||
loaded = load_session(created["session_id"])
|
||||
assert loaded["session_name"] == "更新"
|
||||
assert loaded["agent_state"]["new_key"] is True
|
||||
|
||||
def test_preserves_created_at(self, temp_sessions_dir):
|
||||
created = create_session()
|
||||
original = created["created_at"]
|
||||
save_session(created["session_id"], {"x": 1})
|
||||
assert load_session(created["session_id"])["created_at"] == original
|
||||
|
||||
def test_updates_updated_at(self, temp_sessions_dir):
|
||||
created = create_session()
|
||||
time.sleep(0.01)
|
||||
save_session(created["session_id"], {"x": 1})
|
||||
loaded = load_session(created["session_id"])
|
||||
assert loaded["updated_at"] != created["updated_at"]
|
||||
|
||||
def test_atomic_write_produces_valid_json(self, temp_sessions_dir):
|
||||
created = create_session()
|
||||
save_session(created["session_id"], {"data": "x" * 1000})
|
||||
fp = temp_sessions_dir / f"{created['session_id']}.json"
|
||||
data = json.loads(fp.read_text("utf-8"))
|
||||
assert data["agent_state"]["data"] == "x" * 1000
|
||||
|
||||
def test_auto_generates_name_when_empty(self, temp_sessions_dir):
|
||||
created = create_session(name="")
|
||||
save_session(created["session_id"], {"x": 1})
|
||||
assert load_session(created["session_id"])["session_name"]
|
||||
|
||||
def test_keeps_existing_name_when_not_provided(self, temp_sessions_dir):
|
||||
created = create_session(name="原名")
|
||||
save_session(created["session_id"], {"x": 1})
|
||||
assert load_session(created["session_id"])["session_name"] == "原名"
|
||||
|
||||
def test_fills_missing_created_at(self, temp_sessions_dir):
|
||||
sid = "test_no_created"
|
||||
fp = temp_sessions_dir / f"{sid}.json"
|
||||
fp.write_text(
|
||||
json.dumps({"session_id": sid, "session_name": "旧数据"}), "utf-8"
|
||||
)
|
||||
save_session(sid, {"x": 1})
|
||||
assert load_session(sid)["created_at"]
|
||||
|
||||
|
||||
# ── get_session_state ───────────────────────────────────────────
|
||||
|
||||
class TestGetSessionState:
|
||||
def test_none_for_missing(self, temp_sessions_dir):
|
||||
assert get_session_state("missing") is None
|
||||
|
||||
def test_returns_all_keys(self, temp_sessions_dir):
|
||||
created = create_session(name="状态测试")
|
||||
state = get_session_state(created["session_id"])
|
||||
for key in ("session_id", "session_name", "agent_state", "created_at", "updated_at"):
|
||||
assert key in state
|
||||
|
||||
|
||||
# ── list_all_sessions ───────────────────────────────────────────
|
||||
|
||||
class TestListAllSessions:
|
||||
def test_empty_when_no_sessions(self, temp_sessions_dir):
|
||||
assert list_all_sessions() == []
|
||||
|
||||
def test_lists_all_created(self, temp_sessions_dir):
|
||||
s1 = create_session(name="A")
|
||||
s2 = create_session(name="B")
|
||||
ids = {s["session_id"] for s in list_all_sessions()}
|
||||
assert s1["session_id"] in ids
|
||||
assert s2["session_id"] in ids
|
||||
|
||||
def test_summary_excludes_agent_state(self, temp_sessions_dir):
|
||||
create_session(agent_state={"secret": True})
|
||||
result = list_all_sessions()
|
||||
assert "agent_state" not in result[0]
|
||||
|
||||
def test_sorted_by_mtime_desc(self, temp_sessions_dir):
|
||||
s1 = create_session(name="先")
|
||||
time.sleep(0.02)
|
||||
s2 = create_session(name="后")
|
||||
assert list_all_sessions()[0]["session_id"] == s2["session_id"]
|
||||
|
||||
def test_skips_corrupt_json(self, temp_sessions_dir):
|
||||
(temp_sessions_dir / "bad.json").write_text("{not json}", "utf-8")
|
||||
create_session(name="正常")
|
||||
assert len(list_all_sessions()) == 1
|
||||
|
||||
|
||||
# ── delete_session ──────────────────────────────────────────────
|
||||
|
||||
class TestDeleteSession:
|
||||
def test_returns_false_for_missing(self, temp_sessions_dir):
|
||||
assert delete_session("ghost_id") is False
|
||||
|
||||
def test_returns_true_and_removes(self, temp_sessions_dir):
|
||||
created = create_session()
|
||||
assert delete_session(created["session_id"]) is True
|
||||
assert load_session(created["session_id"]) is None
|
||||
|
||||
def test_file_is_removed_from_disk(self, temp_sessions_dir):
|
||||
created = create_session()
|
||||
fp = temp_sessions_dir / f"{created['session_id']}.json"
|
||||
assert fp.exists()
|
||||
delete_session(created["session_id"])
|
||||
assert not fp.exists()
|
||||
Reference in New Issue
Block a user