fix: address audit findings — session_id validation, streaming reset, state isolation
- Replace truncated 12-char UUID with full 32-char UUID (128-bit entropy) - Add validate_session_id() regex check to prevent path traversal - Add _check_session_id() guard on all 6 API endpoints - Change _step_counter from module global to contextvars.ContextVar - Filter None values from node_state before merging into agent_state - Log save_session failures instead of silently swallowing them - Add finishStreaming() in catch/finally blocks to prevent UI lockup - Fix broken multiline docstring in chat() endpoint
This commit is contained in:
+27
-10
@@ -16,6 +16,7 @@ Usage:
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import contextvars
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
@@ -97,25 +98,30 @@ SKIP_NODES = {"load_session", "process_input", "manage_context",
|
||||
_api_log = get_logger("api")
|
||||
UPLOADS_DIR = Path(os.getenv("UPLOADS_DIR", "./uploads"))
|
||||
|
||||
def _check_session_id(session_id: str) -> None:
|
||||
"""校验 session_id 合法性(防路径穿越),非法时抛出 HTTPException(400)。"""
|
||||
from backend.session import validate_session_id
|
||||
if not validate_session_id(session_id):
|
||||
raise HTTPException(status_code=400, detail=f"Invalid session_id: {session_id!r}")
|
||||
|
||||
# ─────────────────────────────────────────────
|
||||
# 图编译(全局单例,带 node_start 回调)
|
||||
# ─────────────────────────────────────────────
|
||||
|
||||
# 当前请求的事件队列(单个用户桌面应用,无并发问题)
|
||||
# 当前请求的事件队列(单个用户桌面应用)
|
||||
_current_event_queue: Optional[queue.Queue] = None
|
||||
_step_counter: int = 0
|
||||
_step_counter: contextvars.ContextVar[int] = contextvars.ContextVar('_step_counter', default=0)
|
||||
|
||||
|
||||
def _on_node_start(node_name: str):
|
||||
"""全局 node_start 回调 — 将事件推入当前请求的事件队列。"""
|
||||
global _step_counter
|
||||
q = _current_event_queue
|
||||
if q is not None:
|
||||
_step_counter += 1
|
||||
_step_counter.set(_step_counter.get() + 1)
|
||||
q.put(("node_start", {
|
||||
"node": node_name,
|
||||
"label": NODE_LABELS.get(node_name, node_name),
|
||||
"step_index": _step_counter,
|
||||
"step_index": _step_counter.get(),
|
||||
}))
|
||||
|
||||
|
||||
@@ -180,14 +186,18 @@ def _run_graph_sync(agent_state: AgentState, event_q: queue.Queue):
|
||||
if mode == "updates" and isinstance(data, dict):
|
||||
for node_state in data.values():
|
||||
if isinstance(node_state, dict):
|
||||
agent_state.update(node_state)
|
||||
agent_state.update({k: v for k, v in node_state.items() if v is not None})
|
||||
# 在 graph 完成后立即保存 session,防止 SSE 流中断导致数据丢失
|
||||
sid = agent_state.get("session_id", "")
|
||||
if sid:
|
||||
try:
|
||||
save_session(sid, agent_state)
|
||||
except Exception:
|
||||
pass # 静默失败,SSE 流中还有一次保存机会
|
||||
except Exception as exc:
|
||||
_api_log.error("图运行中保存会话失败", extra={
|
||||
"session_id": sid,
|
||||
"error": str(exc),
|
||||
"traceback": traceback.format_exc(),
|
||||
})
|
||||
event_q.put(("done", {"reason": "graph_completed"}))
|
||||
except Exception as exc:
|
||||
event_q.put(("error", {
|
||||
@@ -198,9 +208,9 @@ def _run_graph_sync(agent_state: AgentState, event_q: queue.Queue):
|
||||
|
||||
async def _sse_generator(agent_state: AgentState, session_id: str = "") -> str:
|
||||
"""SSE 事件生成器 —— 在后台线程运行图,异步产出 SSE 字符串。"""
|
||||
global _current_event_queue, _step_counter
|
||||
global _current_event_queue
|
||||
|
||||
_step_counter = 0
|
||||
_step_counter.set(0)
|
||||
t_start = time.time()
|
||||
event_q: queue.Queue = queue.Queue()
|
||||
_current_event_queue = event_q
|
||||
@@ -347,6 +357,7 @@ async def list_sessions():
|
||||
|
||||
@app.get("/api/sessions/{session_id}")
|
||||
async def get_session(session_id: str):
|
||||
_check_session_id(session_id)
|
||||
data = get_session_state(session_id)
|
||||
if data is None:
|
||||
raise HTTPException(status_code=404, detail="会话不存在")
|
||||
@@ -361,6 +372,7 @@ async def get_session(session_id: str):
|
||||
|
||||
@app.delete("/api/sessions/{session_id}")
|
||||
async def remove_session(session_id: str):
|
||||
_check_session_id(session_id)
|
||||
ok = delete_session(session_id)
|
||||
if not ok:
|
||||
raise HTTPException(status_code=404, detail="会话不存在或已删除")
|
||||
@@ -373,6 +385,8 @@ async def remove_session(session_id: str):
|
||||
|
||||
@app.post("/api/upload")
|
||||
async def upload_file(file: UploadFile = File(...), session_id: str = ""):
|
||||
if session_id:
|
||||
_check_session_id(session_id)
|
||||
file_id = uuid.uuid4().hex[:12]
|
||||
_ensure_upload_dir(session_id)
|
||||
|
||||
@@ -492,6 +506,7 @@ async def chat(session_id: str, payload: dict):
|
||||
Returns:
|
||||
text/event-stream (SSE)
|
||||
"""
|
||||
_check_session_id(session_id)
|
||||
text = payload.get("text", "").strip()
|
||||
file_ids = payload.get("file_ids", [])
|
||||
|
||||
@@ -577,6 +592,7 @@ async def chat(session_id: str, payload: dict):
|
||||
@app.get("/api/sessions/{session_id}/download/latest")
|
||||
async def download_latest(session_id: str):
|
||||
"""下载最新 JRXML 文件。"""
|
||||
_check_session_id(session_id)
|
||||
data = load_session(session_id)
|
||||
if data is None:
|
||||
raise HTTPException(status_code=404, detail="会话不存在")
|
||||
@@ -601,6 +617,7 @@ async def download_latest(session_id: str):
|
||||
@app.get("/api/sessions/{session_id}/download/{version}")
|
||||
async def download_version(session_id: str, version: int):
|
||||
"""下载指定版本的 JRXML 文件。"""
|
||||
_check_session_id(session_id)
|
||||
data = load_session(session_id)
|
||||
if data is None:
|
||||
raise HTTPException(status_code=404, detail="会话不存在")
|
||||
|
||||
Reference in New Issue
Block a user