"""backend/session.py 单元测试 — 会话 CRUD + 原子写入。
覆盖:
- 创建/加载/保存/删除/列出
- 原子写入(tempfile + os.replace)
- 边界情况(不存在会话、损坏 JSON、空名称自动填充)
"""
import json
import os
import sys
import tempfile
import time
from pathlib import Path
import pytest
sys.path.insert(0, str(Path(__file__).parent.parent))
from backend.session import (
create_session,
load_session,
save_session,
get_session_state,
list_all_sessions,
delete_session,
generate_session_id,
SESSIONS_DIR,
)
@pytest.fixture
def temp_sessions_dir(monkeypatch):
    """Point ``backend.session.SESSIONS_DIR`` at a throwaway directory.

    Yields the temporary directory as a ``Path``. The directory is removed
    when the ``TemporaryDirectory`` context exits after the test finishes.
    """
    scratch = tempfile.TemporaryDirectory(prefix="test_sessions_")
    with scratch as raw_path:
        sessions_root = Path(raw_path)
        monkeypatch.setattr("backend.session.SESSIONS_DIR", sessions_root)
        yield sessions_root
# ── create_session ──────────────────────────────────────────────
class TestCreateSession:
    """Tests for ``create_session`` and ``generate_session_id``."""

    def test_creates_with_defaults(self, temp_sessions_dir):
        """A session created without arguments has an id, a default name and timestamps."""
        s = create_session()
        assert len(s["session_id"]) == 12
        assert "新建报表" in s["session_name"]
        assert s["created_at"]
        assert s["updated_at"]

    def test_custom_name(self, temp_sessions_dir):
        """An explicit ``name`` is stored verbatim as ``session_name``."""
        s = create_session(name="测试报表")
        assert s["session_name"] == "测试报表"

    def test_agent_state_preserved(self, temp_sessions_dir):
        """Keys passed in ``agent_state`` are kept, including empty-string values."""
        s = create_session(agent_state={"current_jrxml": ""})
        assert s["agent_state"]["current_jrxml"] == ""

    def test_session_id_injected_into_agent_state(self, temp_sessions_dir):
        """The generated session id is mirrored into ``agent_state``."""
        s = create_session()
        assert s["agent_state"]["session_id"] == s["session_id"]

    def test_persists_json_to_disk(self, temp_sessions_dir):
        """Creating a session writes ``<session_id>.json`` into the sessions directory."""
        s = create_session(name="磁盘测试")
        fp = temp_sessions_dir / f"{s['session_id']}.json"
        assert fp.exists()
        loaded = json.loads(fp.read_text("utf-8"))
        assert loaded["session_name"] == "磁盘测试"

    def test_unique_ids_no_collision(self, temp_sessions_dir):
        """One hundred generated ids are all distinct."""
        ids = {generate_session_id() for _ in range(100)}
        assert len(ids) == 100

    def test_creates_sessions_dir_if_missing(self, temp_sessions_dir, monkeypatch):
        """``create_session`` creates the sessions directory when it does not exist yet."""
        nested = temp_sessions_dir / "nested" / "sub"
        # Use the pytest-managed ``monkeypatch`` fixture rather than a bare
        # ``pytest.MonkeyPatch()``: the fixture undoes the override at teardown,
        # whereas a hand-built instance is never undone unless ``undo()`` is called.
        monkeypatch.setattr("backend.session.SESSIONS_DIR", nested)
        s = create_session()
        assert nested.exists()
        assert (nested / f"{s['session_id']}.json").exists()
# ── load_session ────────────────────────────────────────────────
class TestLoadSession:
    """Read-path checks for ``load_session``."""

    def test_returns_none_for_missing(self, temp_sessions_dir):
        """An id with no stored session yields ``None``."""
        result = load_session("nonexistent_id")
        assert result is None

    def test_loads_existing(self, temp_sessions_dir):
        """A freshly created session can be loaded back by its id."""
        original = create_session(name="加载测试")
        fetched = load_session(original["session_id"])
        assert fetched["session_name"] == "加载测试"
        assert fetched["session_id"] == original["session_id"]

    def test_load_includes_agent_state(self, temp_sessions_dir):
        """The loaded record carries the ``agent_state`` given at creation."""
        original = create_session(agent_state={"field_count": 5})
        fetched = load_session(original["session_id"])
        agent_state = fetched["agent_state"]
        assert agent_state["field_count"] == 5
# ── save_session ────────────────────────────────────────────────
class TestSaveSession:
    """Write-path checks for ``save_session``."""

    def test_updates_name_and_state(self, temp_sessions_dir):
        """Saving with a new name and state replaces both on reload."""
        session = create_session(name="原始")
        save_session(session["session_id"], {"new_key": True}, session_name="更新")
        reloaded = load_session(session["session_id"])
        assert reloaded["session_name"] == "更新"
        assert reloaded["agent_state"]["new_key"] is True

    def test_preserves_created_at(self, temp_sessions_dir):
        """``created_at`` is unchanged by a subsequent save."""
        session = create_session()
        created_stamp = session["created_at"]
        save_session(session["session_id"], {"x": 1})
        reloaded = load_session(session["session_id"])
        assert reloaded["created_at"] == created_stamp

    def test_updates_updated_at(self, temp_sessions_dir):
        """``updated_at`` differs after a save that happens slightly later."""
        session = create_session()
        # Brief pause so the timestamp written by the save can differ.
        time.sleep(0.01)
        save_session(session["session_id"], {"x": 1})
        reloaded = load_session(session["session_id"])
        assert reloaded["updated_at"] != session["updated_at"]

    def test_atomic_write_produces_valid_json(self, temp_sessions_dir):
        """The file on disk after a save parses as JSON and holds the saved state."""
        session = create_session()
        save_session(session["session_id"], {"data": "x" * 1000})
        session_file = temp_sessions_dir / f"{session['session_id']}.json"
        raw_text = session_file.read_text("utf-8")
        on_disk = json.loads(raw_text)
        assert on_disk["agent_state"]["data"] == "x" * 1000

    def test_auto_generates_name_when_empty(self, temp_sessions_dir):
        """A session created with an empty name has a non-empty name after saving."""
        session = create_session(name="")
        save_session(session["session_id"], {"x": 1})
        reloaded = load_session(session["session_id"])
        assert reloaded["session_name"]

    def test_keeps_existing_name_when_not_provided(self, temp_sessions_dir):
        """Saving without ``session_name`` leaves the stored name untouched."""
        session = create_session(name="原名")
        save_session(session["session_id"], {"x": 1})
        reloaded = load_session(session["session_id"])
        assert reloaded["session_name"] == "原名"

    def test_fills_missing_created_at(self, temp_sessions_dir):
        """Saving over a record that lacks ``created_at`` fills the field in."""
        sid = "test_no_created"
        # Hand-write a record without ``created_at`` to simulate older data.
        legacy_record = json.dumps({"session_id": sid, "session_name": "旧数据"})
        legacy_file = temp_sessions_dir / f"{sid}.json"
        legacy_file.write_text(legacy_record, "utf-8")
        save_session(sid, {"x": 1})
        reloaded = load_session(sid)
        assert reloaded["created_at"]
# ── get_session_state ───────────────────────────────────────────
class TestGetSessionState:
    """Checks for ``get_session_state``."""

    def test_none_for_missing(self, temp_sessions_dir):
        """An unknown session id yields ``None``."""
        result = get_session_state("missing")
        assert result is None

    def test_returns_all_keys(self, temp_sessions_dir):
        """The returned state exposes every top-level session key."""
        session = create_session(name="状态测试")
        snapshot = get_session_state(session["session_id"])
        required = (
            "session_id",
            "session_name",
            "agent_state",
            "created_at",
            "updated_at",
        )
        for field in required:
            assert field in snapshot
# ── list_all_sessions ───────────────────────────────────────────
class TestListAllSessions:
    """Tests for ``list_all_sessions``."""

    def test_empty_when_no_sessions(self, temp_sessions_dir):
        """An empty sessions directory produces an empty list."""
        assert list_all_sessions() == []

    def test_lists_all_created(self, temp_sessions_dir):
        """Every created session appears in the listing."""
        s1 = create_session(name="A")
        s2 = create_session(name="B")
        ids = {s["session_id"] for s in list_all_sessions()}
        assert s1["session_id"] in ids
        assert s2["session_id"] in ids

    def test_summary_excludes_agent_state(self, temp_sessions_dir):
        """Listing entries are summaries that do not contain ``agent_state``."""
        create_session(agent_state={"secret": True})
        result = list_all_sessions()
        # Fail with a clear assertion (not IndexError) if nothing was listed.
        assert result
        assert all("agent_state" not in summary for summary in result)

    def test_sorted_by_mtime_desc(self, temp_sessions_dir):
        """The most recently written session is listed first."""
        s1 = create_session(name="先")
        # Kept from the original test in case ordering is derived from
        # timestamps stored in the payload rather than from file mtime.
        time.sleep(0.02)
        s2 = create_session(name="后")
        # Pin the file mtimes explicitly so the expected order does not depend
        # on the filesystem's timestamp granularity.
        now = time.time()
        older = temp_sessions_dir / f"{s1['session_id']}.json"
        newer = temp_sessions_dir / f"{s2['session_id']}.json"
        os.utime(older, (now - 10, now - 10))
        os.utime(newer, (now, now))
        assert list_all_sessions()[0]["session_id"] == s2["session_id"]

    def test_skips_corrupt_json(self, temp_sessions_dir):
        """A file that is not valid JSON is left out of the listing."""
        (temp_sessions_dir / "bad.json").write_text("{not json}", "utf-8")
        create_session(name="正常")
        assert len(list_all_sessions()) == 1
# ── delete_session ──────────────────────────────────────────────
class TestDeleteSession:
    """Checks for ``delete_session``."""

    def test_returns_false_for_missing(self, temp_sessions_dir):
        """Deleting an unknown id reports ``False``."""
        outcome = delete_session("ghost_id")
        assert outcome is False

    def test_returns_true_and_removes(self, temp_sessions_dir):
        """Deleting an existing session reports ``True`` and it can no longer be loaded."""
        session = create_session()
        outcome = delete_session(session["session_id"])
        assert outcome is True
        assert load_session(session["session_id"]) is None

    def test_file_is_removed_from_disk(self, temp_sessions_dir):
        """The session's JSON file is gone from disk after deletion."""
        session = create_session()
        session_file = temp_sessions_dir / f"{session['session_id']}.json"
        assert session_file.exists()
        delete_session(session["session_id"])
        assert not session_file.exists()