Files
panda 93ad5e8876 fix: address audit findings — session_id validation, streaming reset, state isolation
- Replace truncated 12-char UUID with full 32-char UUID (128-bit entropy)
- Add validate_session_id() regex check to prevent path traversal
- Add _check_session_id() guard on all 6 API endpoints
- Change _step_counter from module global to contextvars.ContextVar
- Filter None values from node_state before merging into agent_state
- Log save_session failures instead of silently swallowing them
- Add finishStreaming() in catch/finally blocks to prevent UI lockup
- Fix broken multiline docstring in chat() endpoint
2026-05-23 09:08:53 +08:00

211 lines
8.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""backend/session.py 单元测试 — 会话 CRUD + 原子写入。
覆盖:
- 创建/加载/保存/删除/列出
- 原子写入(tempfile + os.replace)
- 边界情况(不存在会话、损坏 JSON、空名称自动填充)
"""
import json
import os
import sys
import tempfile
import time
from pathlib import Path
import pytest
sys.path.insert(0, str(Path(__file__).parent.parent))
from backend.session import (
create_session,
load_session,
save_session,
get_session_state,
list_all_sessions,
delete_session,
generate_session_id,
SESSIONS_DIR,
)
@pytest.fixture
def temp_sessions_dir(monkeypatch):
    """Point ``backend.session.SESSIONS_DIR`` at a throwaway directory.

    Yields the temporary directory as a ``Path``. The directory is removed
    when the ``with`` block exits after the test, and pytest's
    ``monkeypatch`` fixture restores the original ``SESSIONS_DIR`` value.
    """
    with tempfile.TemporaryDirectory(prefix="test_sessions_") as raw_dir:
        sandbox = Path(raw_dir)
        monkeypatch.setattr("backend.session.SESSIONS_DIR", sandbox)
        yield sandbox
# ── create_session ──────────────────────────────────────────────
class TestCreateSession:
    """Tests for ``create_session``: defaults, persistence and ID generation."""

    def test_creates_with_defaults(self, temp_sessions_dir):
        s = create_session()
        # Session IDs are expected to be the full 32-character form.
        assert len(s["session_id"]) == 32
        assert "新建报表" in s["session_name"]
        assert s["created_at"]
        assert s["updated_at"]

    def test_custom_name(self, temp_sessions_dir):
        s = create_session(name="测试报表")
        assert s["session_name"] == "测试报表"

    def test_agent_state_preserved(self, temp_sessions_dir):
        s = create_session(agent_state={"current_jrxml": "<x/>"})
        assert s["agent_state"]["current_jrxml"] == "<x/>"

    def test_session_id_injected_into_agent_state(self, temp_sessions_dir):
        s = create_session()
        assert s["agent_state"]["session_id"] == s["session_id"]

    def test_persists_json_to_disk(self, temp_sessions_dir):
        s = create_session(name="磁盘测试")
        fp = temp_sessions_dir / f"{s['session_id']}.json"
        assert fp.exists()
        loaded = json.loads(fp.read_text("utf-8"))
        assert loaded["session_name"] == "磁盘测试"

    def test_unique_ids_no_collision(self, temp_sessions_dir):
        ids = {generate_session_id() for _ in range(100)}
        assert len(ids) == 100

    def test_creates_sessions_dir_if_missing(self, temp_sessions_dir, monkeypatch):
        """``create_session`` creates a missing SESSIONS_DIR, parents included."""
        nested = temp_sessions_dir / "nested" / "sub"
        # Precondition: otherwise the test would pass without proving creation.
        assert not nested.exists()
        # Use the function-scoped ``monkeypatch`` fixture (the same instance
        # ``temp_sessions_dir`` uses) so this patch is undone at teardown.
        # A bare ``pytest.MonkeyPatch()`` is never undone and leaks state.
        monkeypatch.setattr("backend.session.SESSIONS_DIR", nested)
        s = create_session()
        assert nested.is_dir()
        assert (nested / f"{s['session_id']}.json").exists()
# ── load_session ────────────────────────────────────────────────
class TestLoadSession:
    """Tests for ``load_session`` (reading a persisted session back)."""

    def test_returns_none_for_missing(self, temp_sessions_dir):
        result = load_session("nonexistent_id")
        assert result is None

    def test_loads_existing(self, temp_sessions_dir):
        original = create_session(name="加载测试")
        sid = original["session_id"]
        reloaded = load_session(sid)
        assert reloaded["session_name"] == "加载测试"
        assert reloaded["session_id"] == sid

    def test_load_includes_agent_state(self, temp_sessions_dir):
        sid = create_session(agent_state={"field_count": 5})["session_id"]
        reloaded = load_session(sid)
        assert reloaded["agent_state"]["field_count"] == 5
# ── save_session ────────────────────────────────────────────────
class TestSaveSession:
    """Tests for ``save_session``: updates, timestamps and atomic writes."""

    def test_updates_name_and_state(self, temp_sessions_dir):
        created = create_session(name="原始")
        save_session(created["session_id"], {"new_key": True}, session_name="更新")
        loaded = load_session(created["session_id"])
        assert loaded["session_name"] == "更新"
        assert loaded["agent_state"]["new_key"] is True

    def test_preserves_created_at(self, temp_sessions_dir):
        created = create_session()
        original = created["created_at"]
        save_session(created["session_id"], {"x": 1})
        assert load_session(created["session_id"])["created_at"] == original

    def test_updates_updated_at(self, temp_sessions_dir):
        created = create_session()
        # NOTE(review): a 10 ms sleep assumes ``updated_at`` has sub-10 ms
        # resolution. Confirm against the timestamp format session.py writes.
        time.sleep(0.01)
        save_session(created["session_id"], {"x": 1})
        loaded = load_session(created["session_id"])
        assert loaded["updated_at"] != created["updated_at"]

    def test_atomic_write_produces_valid_json(self, temp_sessions_dir):
        created = create_session()
        save_session(created["session_id"], {"data": "x" * 1000})
        fp = temp_sessions_dir / f"{created['session_id']}.json"
        data = json.loads(fp.read_text("utf-8"))
        assert data["agent_state"]["data"] == "x" * 1000

    def test_auto_generates_name_when_empty(self, temp_sessions_dir):
        # NOTE(review): if create_session(name="") already fills in a default
        # name, this does not isolate save_session's own fallback. Confirm.
        created = create_session(name="")
        save_session(created["session_id"], {"x": 1})
        assert load_session(created["session_id"])["session_name"]

    def test_keeps_existing_name_when_not_provided(self, temp_sessions_dir):
        created = create_session(name="原名")
        save_session(created["session_id"], {"x": 1})
        assert load_session(created["session_id"])["session_name"] == "原名"

    def test_fills_missing_created_at(self, temp_sessions_dir):
        """A legacy file without ``created_at`` gets one filled in on save."""
        # Use a generator-produced ID so the legacy file matches the real ID
        # format (32 chars), not a hand-written 24-character string that an
        # ID-format validator could reject.
        sid = generate_session_id()
        fp = temp_sessions_dir / f"{sid}.json"
        fp.write_text(
            json.dumps({"session_id": sid, "session_name": "旧数据"}), "utf-8"
        )
        save_session(sid, {"x": 1})
        assert load_session(sid)["created_at"]
# ── get_session_state ───────────────────────────────────────────
class TestGetSessionState:
    """Tests for ``get_session_state``."""

    def test_none_for_missing(self, temp_sessions_dir):
        result = get_session_state("missing")
        assert result is None

    def test_returns_all_keys(self, temp_sessions_dir):
        sid = create_session(name="状态测试")["session_id"]
        state = get_session_state(sid)
        expected = ("session_id", "session_name", "agent_state", "created_at", "updated_at")
        missing = [key for key in expected if key not in state]
        assert not missing, missing
# ── list_all_sessions ───────────────────────────────────────────
class TestListAllSessions:
    """Tests for ``list_all_sessions``: listing, summaries, ordering, corrupt files."""

    def test_empty_when_no_sessions(self, temp_sessions_dir):
        assert list_all_sessions() == []

    def test_lists_all_created(self, temp_sessions_dir):
        s1 = create_session(name="A")
        s2 = create_session(name="B")
        ids = {s["session_id"] for s in list_all_sessions()}
        assert s1["session_id"] in ids
        assert s2["session_id"] in ids

    def test_summary_excludes_agent_state(self, temp_sessions_dir):
        create_session(agent_state={"secret": True})
        result = list_all_sessions()
        # Guard the index below so an empty listing fails with a clear message.
        assert len(result) == 1
        assert "agent_state" not in result[0]

    def test_sorted_by_mtime_desc(self, temp_sessions_dir):
        """The most recently modified session is listed first."""
        older = create_session(name="")
        time.sleep(0.02)
        newer = create_session(name="")
        # A 20 ms sleep alone is unreliable on filesystems with coarse mtime
        # granularity (e.g. 1-2 s), so pin the file mtimes explicitly.
        now = time.time()
        older_fp = temp_sessions_dir / f"{older['session_id']}.json"
        newer_fp = temp_sessions_dir / f"{newer['session_id']}.json"
        os.utime(older_fp, (now - 60, now - 60))
        os.utime(newer_fp, (now, now))
        assert list_all_sessions()[0]["session_id"] == newer["session_id"]

    def test_skips_corrupt_json(self, temp_sessions_dir):
        (temp_sessions_dir / "bad.json").write_text("{not json}", "utf-8")
        create_session(name="正常")
        assert len(list_all_sessions()) == 1
# ── delete_session ──────────────────────────────────────────────
class TestDeleteSession:
    """Tests for ``delete_session``."""

    def test_returns_false_for_missing(self, temp_sessions_dir):
        outcome = delete_session("ghost_id")
        assert outcome is False

    def test_returns_true_and_removes(self, temp_sessions_dir):
        sid = create_session()["session_id"]
        assert delete_session(sid) is True
        assert load_session(sid) is None

    def test_file_is_removed_from_disk(self, temp_sessions_dir):
        sid = create_session()["session_id"]
        session_file = temp_sessions_dir / f"{sid}.json"
        assert session_file.exists()
        delete_session(sid)
        assert not session_file.exists()