Files

215 lines
6.5 KiB
Python

"""多会话持久化管理模块。
每个会话对应一个独立的 JSON 文件存储在 ./sessions/ 目录下。
"""
import json
import os
import re
import threading
import uuid
import tempfile
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional, Any
# Registry of per-session locks, keyed by session_id. save_session() holds the
# matching lock around its read-modify-write of the session's JSON file so that
# concurrent writers in this process cannot interleave.
_session_locks: dict[str, threading.Lock] = {}
# Guards the registry itself: _get_lock() holds it while looking up or
# inserting an entry in _session_locks.
_locks_lock = threading.Lock()
def _get_lock(session_id: str) -> threading.Lock:
    """Return the lock dedicated to ``session_id``, creating it on first use.

    The lookup and the insertion both happen while ``_locks_lock`` is held,
    so every caller asking for the same id receives the same lock object.
    """
    with _locks_lock:
        existing = _session_locks.get(session_id)
        if existing is not None:
            return existing
        fresh = threading.Lock()
        _session_locks[session_id] = fresh
        return fresh
class _SafeEncoder(json.JSONEncoder):
    """JSON encoder that tolerates numpy values, bytes and int-like objects.

    ``default`` is only invoked by ``json`` for objects it cannot serialise
    natively, so plain ``str``/``int``/``float``/``list``/``dict`` values never
    reach it.
    """

    def default(self, o: Any) -> Any:
        """Convert ``o`` into something ``json`` can serialise.

        Conversion order:
          1. numpy scalars and arrays (only when numpy is importable),
          2. ``bytes`` / ``bytearray`` -> text (UTF-8, undecodable bytes replaced),
          3. anything that ``int()`` accepts (e.g. C integer wrapper types),
          4. otherwise defer to ``JSONEncoder.default``, which raises ``TypeError``.
        """
        # numpy is optional; skip its branch entirely when it is not installed.
        try:
            import numpy as np
        except ImportError:
            np = None

        if np is not None:
            if isinstance(o, np.integer):
                return int(o)
            if isinstance(o, np.floating):
                return float(o)
            if isinstance(o, np.ndarray):
                return o.tolist()
            if isinstance(o, np.bool_):
                return bool(o)

        # Bytes must be handled BEFORE the int() fallback: int(b"42") succeeds,
        # which would silently turn digit-only byte strings into numbers.
        if isinstance(o, (bytes, bytearray)):
            return o.decode("utf-8", errors="replace")

        # Best-effort fallback for integer-like C types (the original comment
        # mentions lxml's intc). The broad except is deliberate: any failure
        # here just means "not int-like" and we fall through to the default.
        # NOTE(review): int() truncates non-integral values (e.g. Decimal("1.5")
        # becomes 1) — confirm no such types are stored in session state.
        try:
            return int(o)
        except Exception:
            pass

        return super().default(o)
from dotenv import load_dotenv
from backend.logger import get_logger
# Must run before SESSIONS_DIR is computed below: presumably this loads a .env
# file into the process environment (python-dotenv) — confirm.
load_dotenv()
# Logger used for session create/delete events.
_session_log = get_logger("session")
# Directory that holds one "<session_id>.json" file per session. Taken from the
# SESSIONS_DIR environment variable, falling back to "./sessions".
SESSIONS_DIR = Path(os.getenv("SESSIONS_DIR", "./sessions"))
def _ensure_dir():
    """Create the sessions directory, including missing parents, if absent."""
    os.makedirs(SESSIONS_DIR, exist_ok=True)
# A session id is 12 or more hexadecimal characters and nothing else.
_VALID_SESSION_ID_RE = re.compile(r'^[a-fA-F0-9]{12,}$')


def validate_session_id(session_id: str) -> bool:
    """Return True when ``session_id`` consists solely of 12+ hex characters.

    This is the guard against path traversal: the id is later embedded in a
    file name, so nothing but hex digits may get through.

    ``fullmatch`` is used instead of ``match`` because ``$`` also matches just
    before a trailing newline, which would let ``"abc...\\n"`` pass. Non-string
    input (e.g. ``None``) is reported as invalid rather than raising.
    """
    if not isinstance(session_id, str):
        return False
    return _VALID_SESSION_ID_RE.fullmatch(session_id) is not None
def _session_path(session_id: str) -> Path:
    """Map ``session_id`` to its JSON file inside ``SESSIONS_DIR``.

    Raises:
        ValueError: if ``session_id`` fails ``validate_session_id``.
    """
    if validate_session_id(session_id):
        file_name = f"{session_id}.json"
        return SESSIONS_DIR / file_name
    raise ValueError(f"Invalid session_id: {session_id!r}")
def generate_session_id() -> str:
    """Return a new random session id: the 32-character hex form of a UUID4."""
    new_uuid = uuid.uuid4()
    return new_uuid.hex
def create_session(name: str = "", agent_state: Optional[dict] = None,
                   session_id: Optional[str] = None) -> dict:
    """Create a new session file and return the stored record.

    Args:
        name: Display name; when empty a default "新建报表 <YYYY-MM-DD>" is used.
        agent_state: Initial agent state. A non-empty dict passed in is updated
            in place with a ``"session_id"`` key and is the same object stored
            under ``"agent_state"`` in the returned record.
        session_id: Optional explicit id; a fresh one is generated when omitted.

    Returns:
        The record written to disk, with keys ``session_id``, ``session_name``,
        ``created_at``, ``updated_at``, ``kb_id`` and ``agent_state``.

    Raises:
        ValueError: if ``session_id`` is not a valid session id.
    """
    _ensure_dir()
    sid = session_id or generate_session_id()
    # Resolve (and thereby validate) the path first, so an invalid id is
    # rejected before the caller's agent_state dict is modified.
    fp = _session_path(sid)
    now = _now_iso()
    agent_state = agent_state or {}
    agent_state["session_id"] = sid
    data = {
        "session_id": sid,
        "session_name": name or f"新建报表 {now[:10]}",
        "created_at": now,
        "updated_at": now,
        "kb_id": agent_state.get("kb_id", ""),
        "agent_state": agent_state,
    }
    # Atomic write: dump to a temp file in the same directory, then replace.
    # The temp file deliberately does NOT end in ".json" so that a concurrent
    # list_all_sessions() glob never picks up a half-written or orphaned file.
    tmp = tempfile.NamedTemporaryFile(
        mode="w", suffix=".tmp", delete=False,
        dir=SESSIONS_DIR, encoding="utf-8",
    )
    try:
        with tmp:  # closes the handle before the replace
            json.dump(data, tmp, ensure_ascii=False, indent=2, cls=_SafeEncoder)
            tmp.flush()
            os.fsync(tmp.fileno())
        os.replace(tmp.name, str(fp))
    except BaseException:
        # Also clean up on KeyboardInterrupt/SystemExit so no temp file is left.
        Path(tmp.name).unlink(missing_ok=True)
        raise
    _session_log.info("创建会话", extra={"session_id": sid, "session_name": data["session_name"]})
    return data
def load_session(session_id: str) -> Optional[dict]:
    """Load the stored record for ``session_id``.

    Returns:
        The parsed JSON content of the session file, or ``None`` when the id is
        invalid or no file exists for it.

    Raises:
        json.JSONDecodeError: if the file exists but does not contain valid JSON.
    """
    _ensure_dir()
    try:
        fp = _session_path(session_id)
    except ValueError:
        return None
    # Open directly instead of checking exists() first: the file may be deleted
    # by another thread between the check and the open.
    try:
        with open(fp, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return None
def save_session(session_id: str, agent_state: dict, session_name: str = ""):
    """Persist ``agent_state`` for ``session_id`` with an atomic file replace.

    The existing record (if any) is read, updated and rewritten while holding
    the per-session lock, so concurrent saves of the same session from this
    process do not interleave.

    Args:
        session_id: Id of the session to save.
        agent_state: State stored under the ``"agent_state"`` key.
        session_name: New display name; when empty the stored name is kept, or
            a default "报表 <YYYY-MM-DD>" is generated if none is stored.

    Raises:
        ValueError: if ``session_id`` is not a valid session id.
        json.JSONDecodeError: if the existing file is not valid JSON.
    """
    _ensure_dir()
    fp = _session_path(session_id)
    with _get_lock(session_id):
        # EAFP: a missing file simply means this is the first save.
        try:
            with open(fp, "r", encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError:
            data = {}

        # One timestamp for the whole save keeps the fields mutually consistent.
        now = _now_iso()
        data["session_id"] = session_id
        if session_name:
            data["session_name"] = session_name
        if not data.get("session_name"):
            # "or now" also covers a stored-but-empty created_at, which would
            # otherwise yield a name with no date in it.
            date_part = (data.get("created_at") or now)[:10]
            data["session_name"] = f"报表 {date_part}"
        data["updated_at"] = now
        if not data.get("created_at"):
            data["created_at"] = now
        data["agent_state"] = agent_state

        # Atomic write: dump to a temp file, then replace, so a crash cannot
        # leave a truncated JSON file. The temp suffix is not ".json" so that
        # list_all_sessions() never sees in-flight or orphaned temp files.
        tmp = tempfile.NamedTemporaryFile(
            mode="w", suffix=".tmp", delete=False,
            dir=SESSIONS_DIR, encoding="utf-8",
        )
        try:
            with tmp:  # closes the handle before the replace
                json.dump(data, tmp, ensure_ascii=False, indent=2, cls=_SafeEncoder)
                tmp.flush()
                os.fsync(tmp.fileno())
            os.replace(tmp.name, str(fp))
        except BaseException:
            # Also clean up on KeyboardInterrupt/SystemExit.
            Path(tmp.name).unlink(missing_ok=True)
            raise
def get_session_state(session_id: str) -> Optional[dict]:
    """Return the full stored record for ``session_id`` (REST API entry point).

    This delegates to ``load_session``. Records written by this module carry
    ``session_id``, ``session_name``, ``created_at``, ``updated_at`` and
    ``agent_state``. ``None`` is returned when the session is not found.
    """
    record = load_session(session_id)
    return record
def list_all_sessions() -> list[dict]:
    """List stored sessions as summaries, most recently modified first.

    Each summary holds ``session_id``, ``session_name``, ``created_at`` and
    ``updated_at``; the full ``agent_state`` is not included.

    Files that cannot be used are skipped rather than failing the listing:
    files whose stem is not a valid session id (such as temp files from
    in-flight atomic writes), files that vanish while listing, and files that
    are unreadable or do not contain a JSON object.
    """
    _ensure_dir()

    # Collect mtimes up front with per-file error handling; using getmtime as a
    # sort key would raise if a file is deleted between glob() and stat().
    candidates = []
    for fp in SESSIONS_DIR.glob("*.json"):
        if not validate_session_id(fp.stem):
            continue
        try:
            candidates.append((fp.stat().st_mtime, fp))
        except OSError:
            continue
    candidates.sort(key=lambda item: item[0], reverse=True)

    sessions = []
    for _mtime, fp in candidates:
        try:
            with open(fp, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            continue
        if not isinstance(data, dict):
            continue
        sessions.append({
            "session_id": data.get("session_id", fp.stem),
            "session_name": data.get("session_name", fp.stem),
            "created_at": data.get("created_at", ""),
            "updated_at": data.get("updated_at", ""),
        })
    return sessions
def delete_session(session_id: str) -> bool:
    """Delete the file that stores ``session_id``.

    Returns:
        ``True`` when a file was removed; ``False`` when the id is invalid or
        no such file exists.
    """
    _ensure_dir()
    try:
        fp = _session_path(session_id)
    except ValueError:
        return False
    # Unlink directly instead of checking exists() first: a concurrent delete
    # between the check and the unlink would otherwise raise.
    try:
        fp.unlink()
    except FileNotFoundError:
        return False
    _session_log.info("删除会话", extra={"session_id": session_id})
    return True
def _now_iso() -> str:
    """Return the current UTC time as a timezone-aware ISO 8601 string."""
    utc_now = datetime.now(tz=timezone.utc)
    return utc_now.isoformat()