1 Commits

2 changed files with 90 additions and 41 deletions
+4 -9
View File
@@ -35,7 +35,6 @@ class _LLMLoggingWrapper(_BaseLLM):
def invoke(self, prompt: str) -> Any: def invoke(self, prompt: str) -> Any:
t0 = time.time() t0 = time.time()
prompt_len = len(prompt) prompt_len = len(prompt)
prompt_preview = prompt[:500]
_llm_log.debug( _llm_log.debug(
"LLM invoke 请求", "LLM invoke 请求",
extra={ extra={
@@ -44,8 +43,7 @@ class _LLMLoggingWrapper(_BaseLLM):
"backend": self._backend, "backend": self._backend,
"caller": self._caller, "caller": self._caller,
"prompt_length": prompt_len, "prompt_length": prompt_len,
"prompt_preview": prompt_preview, "prompt_preview": prompt[:500],
"prompt": prompt[:10000],
}, },
) )
try: try:
@@ -64,7 +62,6 @@ class _LLMLoggingWrapper(_BaseLLM):
"duration_ms": elapsed, "duration_ms": elapsed,
"response_length": resp_len, "response_length": resp_len,
"response_preview": resp_preview, "response_preview": resp_preview,
"response": content[:10000],
}, },
) )
return result return result
@@ -79,7 +76,7 @@ class _LLMLoggingWrapper(_BaseLLM):
"caller": self._caller, "caller": self._caller,
"duration_ms": elapsed, "duration_ms": elapsed,
"error": str(e), "error": str(e),
"prompt": prompt[:10000], "prompt_preview": prompt[:500],
}, },
) )
raise raise
@@ -96,8 +93,7 @@ class _LLMLoggingWrapper(_BaseLLM):
"backend": self._backend, "backend": self._backend,
"caller": self._caller, "caller": self._caller,
"prompt_length": prompt_len, "prompt_length": prompt_len,
"prompt_preview": prompt_preview, "prompt_preview": prompt[:500],
"prompt": prompt[:10000],
}, },
) )
full = [] full = []
@@ -135,7 +131,6 @@ class _LLMLoggingWrapper(_BaseLLM):
"duration_ms": elapsed, "duration_ms": elapsed,
"response_length": resp_len, "response_length": resp_len,
"response_preview": resp_preview, "response_preview": resp_preview,
"response": resp_text[:10000],
"stop_reason": stop_reason, "stop_reason": stop_reason,
}, },
) )
@@ -150,7 +145,7 @@ class _LLMLoggingWrapper(_BaseLLM):
"caller": self._caller, "caller": self._caller,
"duration_ms": elapsed, "duration_ms": elapsed,
"error": str(e), "error": str(e),
"prompt": prompt[:10000], "prompt_preview": prompt[:500],
}, },
) )
raise raise
+86 -32
View File
@@ -6,11 +6,50 @@
import json import json
import os import os
import re import re
import threading
import uuid import uuid
import tempfile import tempfile
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional, Any
# Per-session-file locks to prevent concurrent writes from corrupting JSON
_session_locks: dict[str, threading.Lock] = {}
_locks_lock = threading.Lock()
def _get_lock(session_id: str) -> threading.Lock:
with _locks_lock:
if session_id not in _session_locks:
_session_locks[session_id] = threading.Lock()
return _session_locks[session_id]
class _SafeEncoder(json.JSONEncoder):
"""处理 numpy / lxml / 等非标准类型的 JSON 序列化"""
def default(self, o: Any) -> Any:
try:
# numpy 标量
import numpy as np
if isinstance(o, np.integer):
return int(o)
if isinstance(o, np.floating):
return float(o)
if isinstance(o, np.ndarray):
return o.tolist()
if isinstance(o, np.bool_):
return bool(o)
except ImportError:
pass
# lxml intc / 其他 C 类型
try:
return int(o)
except Exception:
pass
# bytes
if isinstance(o, bytes):
return o.decode("utf-8", errors="replace")
return super().default(o)
from dotenv import load_dotenv from dotenv import load_dotenv
@@ -59,8 +98,21 @@ def create_session(name: str = "", agent_state: Optional[dict] = None,
"kb_id": agent_state.get("kb_id", "") if agent_state else "", "kb_id": agent_state.get("kb_id", "") if agent_state else "",
"agent_state": agent_state, "agent_state": agent_state,
} }
with open(_session_path(sid), "w", encoding="utf-8") as f: fp = _session_path(sid)
json.dump(data, f, ensure_ascii=False, indent=2) tmp = tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False,
dir=SESSIONS_DIR, encoding="utf-8",
)
try:
json.dump(data, tmp, ensure_ascii=False, indent=2, cls=_SafeEncoder)
tmp.flush()
os.fsync(tmp.fileno())
tmp.close()
os.replace(tmp.name, str(fp))
except Exception:
tmp.close()
Path(tmp.name).unlink(missing_ok=True)
raise
_session_log.info("创建会话", extra={"session_id": sid, "session_name": data["session_name"]}) _session_log.info("创建会话", extra={"session_id": sid, "session_name": data["session_name"]})
return data return data
@@ -79,39 +131,41 @@ def load_session(session_id: str) -> Optional[dict]:
def save_session(session_id: str, agent_state: dict, session_name: str = ""): def save_session(session_id: str, agent_state: dict, session_name: str = ""):
"""将会话状态原子保存至磁盘(temp file + rename,避免崩溃时截断)""" """线程安全地原子保存会话状态到磁盘"""
_ensure_dir() _ensure_dir()
fp = _session_path(session_id) fp = _session_path(session_id)
data = {} lock = _get_lock(session_id)
if fp.exists(): with lock:
with open(fp, "r", encoding="utf-8") as f: data = {}
data = json.load(f) if fp.exists():
with open(fp, "r", encoding="utf-8") as f:
data = json.load(f)
data["session_id"] = session_id data["session_id"] = session_id
if session_name: if session_name:
data["session_name"] = session_name data["session_name"] = session_name
if not data.get("session_name"): if not data.get("session_name"):
data["session_name"] = f"报表 {data.get('created_at', _now_iso())[:10]}" data["session_name"] = f"报表 {data.get('created_at', _now_iso())[:10]}"
data["updated_at"] = _now_iso() data["updated_at"] = _now_iso()
if not data.get("created_at"): if not data.get("created_at"):
data["created_at"] = data["updated_at"] data["created_at"] = data["updated_at"]
data["agent_state"] = agent_state data["agent_state"] = agent_state
# 原子写入:先写临时文件,再 replace,避免崩溃时截断 JSON # 原子写入:先写临时文件,再 replace,避免崩溃时截断 JSON
tmp = tempfile.NamedTemporaryFile( tmp = tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False, mode="w", suffix=".json", delete=False,
dir=SESSIONS_DIR, encoding="utf-8", dir=SESSIONS_DIR, encoding="utf-8",
) )
try: try:
json.dump(data, tmp, ensure_ascii=False, indent=2) json.dump(data, tmp, ensure_ascii=False, indent=2, cls=_SafeEncoder)
tmp.flush() tmp.flush()
os.fsync(tmp.fileno()) os.fsync(tmp.fileno())
tmp.close() tmp.close()
os.replace(tmp.name, str(fp)) os.replace(tmp.name, str(fp))
except Exception: except Exception:
tmp.close() tmp.close()
Path(tmp.name).unlink(missing_ok=True) Path(tmp.name).unlink(missing_ok=True)
raise raise
def get_session_state(session_id: str) -> Optional[dict]: def get_session_state(session_id: str) -> Optional[dict]: