93ad5e8876
- Replace truncated 12-char UUID with full 32-char UUID (128-bit entropy) - Add validate_session_id() regex check to prevent path traversal - Add _check_session_id() guard on all 6 API endpoints - Change _step_counter from module global to contextvars.ContextVar - Filter None values from node_state before merging into agent_state - Log save_session failures instead of silently swallowing them - Add finishStreaming() in catch/finally blocks to prevent UI lockup - Fix broken multiline docstring in chat() endpoint
211 lines
8.1 KiB
Python
211 lines
8.1 KiB
Python
"""backend/session.py 单元测试 — 会话 CRUD + 原子写入。
|
||
|
||
覆盖:
|
||
- 创建/加载/保存/删除/列出
|
||
- 原子写入(tempfile + os.replace)
|
||
- 边界情况(不存在会话、损坏 JSON、空名称自动填充)
|
||
"""
|
||
|
||
import json
|
||
import os
|
||
import sys
|
||
import tempfile
|
||
import time
|
||
from pathlib import Path
|
||
|
||
import pytest
|
||
|
||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||
|
||
from backend.session import (
|
||
create_session,
|
||
load_session,
|
||
save_session,
|
||
get_session_state,
|
||
list_all_sessions,
|
||
delete_session,
|
||
generate_session_id,
|
||
SESSIONS_DIR,
|
||
)
|
||
|
||
|
||
@pytest.fixture
def temp_sessions_dir(monkeypatch):
    """Redirect ``backend.session.SESSIONS_DIR`` to a throwaway directory.

    Every test that requests this fixture gets a fresh, empty directory
    (name prefixed with ``test_sessions_``), so session files never land in
    the real storage location. The directory is deleted when the test ends,
    and ``monkeypatch`` restores the original ``SESSIONS_DIR`` afterwards.

    Yields:
        Path: the temporary directory currently used as ``SESSIONS_DIR``.
    """
    with tempfile.TemporaryDirectory(prefix="test_sessions_") as raw_dir:
        sessions_root = Path(raw_dir)
        # Patch by dotted path so the global inside backend.session is
        # replaced; the SESSIONS_DIR name imported into this test module
        # keeps pointing at the real directory.
        monkeypatch.setattr("backend.session.SESSIONS_DIR", sessions_root)
        yield sessions_root
|
||
|
||
|
||
# ── create_session ──────────────────────────────────────────────
|
||
|
||
class TestCreateSession:
    """Tests for ``create_session``: defaults, naming, state, persistence."""

    def test_creates_with_defaults(self, temp_sessions_dir):
        # Session IDs are full 32-character IDs, not truncated ones.
        s = create_session()
        assert len(s["session_id"]) == 32
        assert "新建报表" in s["session_name"]
        assert s["created_at"]
        assert s["updated_at"]

    def test_custom_name(self, temp_sessions_dir):
        s = create_session(name="测试报表")
        assert s["session_name"] == "测试报表"

    def test_agent_state_preserved(self, temp_sessions_dir):
        s = create_session(agent_state={"current_jrxml": "<x/>"})
        assert s["agent_state"]["current_jrxml"] == "<x/>"

    def test_session_id_injected_into_agent_state(self, temp_sessions_dir):
        s = create_session()
        assert s["agent_state"]["session_id"] == s["session_id"]

    def test_persists_json_to_disk(self, temp_sessions_dir):
        s = create_session(name="磁盘测试")
        fp = temp_sessions_dir / f"{s['session_id']}.json"
        assert fp.exists()
        loaded = json.loads(fp.read_text("utf-8"))
        assert loaded["session_name"] == "磁盘测试"

    def test_unique_ids_no_collision(self, temp_sessions_dir):
        ids = {generate_session_id() for _ in range(100)}
        assert len(ids) == 100

    def test_creates_sessions_dir_if_missing(self, temp_sessions_dir, monkeypatch):
        # Point SESSIONS_DIR at a directory that does not exist yet.
        # Use pytest's ``monkeypatch`` fixture (the same instance the
        # temp_sessions_dir fixture uses) so this override is undone at
        # teardown; a hand-built ``pytest.MonkeyPatch()`` is never undone.
        nested = temp_sessions_dir / "nested" / "sub"
        assert not nested.exists()
        monkeypatch.setattr("backend.session.SESSIONS_DIR", nested)

        s = create_session()

        assert nested.exists()
        assert (nested / f"{s['session_id']}.json").exists()
|
||
|
||
|
||
# ── load_session ────────────────────────────────────────────────
|
||
|
||
class TestLoadSession:
    """``load_session`` returns the stored record, or ``None`` when absent."""

    def test_returns_none_for_missing(self, temp_sessions_dir):
        result = load_session("nonexistent_id")
        assert result is None

    def test_loads_existing(self, temp_sessions_dir):
        original = create_session(name="加载测试")
        sid = original["session_id"]

        roundtrip = load_session(sid)

        assert roundtrip["session_name"] == "加载测试"
        assert roundtrip["session_id"] == sid

    def test_load_includes_agent_state(self, temp_sessions_dir):
        sid = create_session(agent_state={"field_count": 5})["session_id"]
        assert load_session(sid)["agent_state"]["field_count"] == 5
|
||
|
||
|
||
# ── save_session ────────────────────────────────────────────────
|
||
|
||
class TestSaveSession:
    """Tests for ``save_session``: state merge, timestamps, naming, atomic write."""

    def test_updates_name_and_state(self, temp_sessions_dir):
        created = create_session(name="原始")
        save_session(created["session_id"], {"new_key": True}, session_name="更新")
        loaded = load_session(created["session_id"])
        assert loaded["session_name"] == "更新"
        assert loaded["agent_state"]["new_key"] is True

    def test_preserves_created_at(self, temp_sessions_dir):
        created = create_session()
        original = created["created_at"]
        save_session(created["session_id"], {"x": 1})
        assert load_session(created["session_id"])["created_at"] == original

    def test_updates_updated_at(self, temp_sessions_dir):
        created = create_session()
        # Brief pause so the re-saved timestamp can differ from the original.
        time.sleep(0.01)
        save_session(created["session_id"], {"x": 1})
        loaded = load_session(created["session_id"])
        assert loaded["updated_at"] != created["updated_at"]

    def test_atomic_write_produces_valid_json(self, temp_sessions_dir):
        created = create_session()
        save_session(created["session_id"], {"data": "x" * 1000})
        fp = temp_sessions_dir / f"{created['session_id']}.json"
        data = json.loads(fp.read_text("utf-8"))
        assert data["agent_state"]["data"] == "x" * 1000

    def test_auto_generates_name_when_empty(self, temp_sessions_dir):
        created = create_session(name="")
        save_session(created["session_id"], {"x": 1})
        assert load_session(created["session_id"])["session_name"]

    def test_keeps_existing_name_when_not_provided(self, temp_sessions_dir):
        created = create_session(name="原名")
        save_session(created["session_id"], {"x": 1})
        assert load_session(created["session_id"])["session_name"] == "原名"

    def test_fills_missing_created_at(self, temp_sessions_dir):
        # Simulate an older session file written without ``created_at``.
        # Use a freshly generated ID so the file name matches the current
        # 32-character ID format instead of a hand-typed 24-character string
        # that session-ID validation may reject.
        sid = generate_session_id()
        fp = temp_sessions_dir / f"{sid}.json"
        fp.write_text(
            json.dumps({"session_id": sid, "session_name": "旧数据"}), "utf-8"
        )
        save_session(sid, {"x": 1})
        assert load_session(sid)["created_at"]
|
||
|
||
|
||
# ── get_session_state ───────────────────────────────────────────
|
||
|
||
class TestGetSessionState:
    """``get_session_state`` exposes the complete session record."""

    def test_none_for_missing(self, temp_sessions_dir):
        outcome = get_session_state("missing")
        assert outcome is None

    def test_returns_all_keys(self, temp_sessions_dir):
        sid = create_session(name="状态测试")["session_id"]
        state = get_session_state(sid)

        expected_keys = (
            "session_id",
            "session_name",
            "agent_state",
            "created_at",
            "updated_at",
        )
        absent = [key for key in expected_keys if key not in state]
        assert not absent
|
||
|
||
|
||
# ── list_all_sessions ───────────────────────────────────────────
|
||
|
||
class TestListAllSessions:
    """Tests for ``list_all_sessions``: contents, summary shape, ordering."""

    def test_empty_when_no_sessions(self, temp_sessions_dir):
        assert list_all_sessions() == []

    def test_lists_all_created(self, temp_sessions_dir):
        s1 = create_session(name="A")
        s2 = create_session(name="B")
        ids = {s["session_id"] for s in list_all_sessions()}
        assert s1["session_id"] in ids
        assert s2["session_id"] in ids

    def test_summary_excludes_agent_state(self, temp_sessions_dir):
        create_session(agent_state={"secret": True})
        result = list_all_sessions()
        # Make sure there is exactly one entry before indexing into it.
        assert len(result) == 1
        assert "agent_state" not in result[0]

    def test_sorted_by_mtime_desc(self, temp_sessions_dir):
        s1 = create_session(name="先")
        time.sleep(0.02)
        s2 = create_session(name="后")

        # Pin file mtimes explicitly. On filesystems with coarse (1-2 s)
        # mtime resolution, the two files could otherwise share a
        # timestamp and the ordering would be nondeterministic.
        now = time.time()
        older = now - 10
        os.utime(temp_sessions_dir / f"{s1['session_id']}.json", (older, older))
        os.utime(temp_sessions_dir / f"{s2['session_id']}.json", (now, now))

        ordered_ids = [s["session_id"] for s in list_all_sessions()]
        assert ordered_ids == [s2["session_id"], s1["session_id"]]

    def test_skips_corrupt_json(self, temp_sessions_dir):
        (temp_sessions_dir / "bad.json").write_text("{not json}", "utf-8")
        create_session(name="正常")
        assert len(list_all_sessions()) == 1
|
||
|
||
|
||
# ── delete_session ──────────────────────────────────────────────
|
||
|
||
class TestDeleteSession:
    """``delete_session`` reports whether a session file was removed."""

    def test_returns_false_for_missing(self, temp_sessions_dir):
        outcome = delete_session("ghost_id")
        assert outcome is False

    def test_returns_true_and_removes(self, temp_sessions_dir):
        sid = create_session()["session_id"]
        assert delete_session(sid) is True
        assert load_session(sid) is None

    def test_file_is_removed_from_disk(self, temp_sessions_dir):
        sid = create_session()["session_id"]
        session_file = temp_sessions_dir / f"{sid}.json"
        assert session_file.exists()

        delete_session(sid)

        assert not session_file.exists()
|