"""KB-isolated ChromaDB semantic-search adapter.

Each knowledge base (KB) owns its own ChromaDB collection.

Callers: backend/rag_adapter.py, agent/nodes.py, api_server.py
"""
import os
import logging
from pathlib import Path
from typing import Optional

from dotenv import load_dotenv

load_dotenv()

logger = logging.getLogger(__name__)

_PROJECT_ROOT = Path(__file__).resolve().parent.parent


def _resolve(path: str) -> Path:
    """Return *path* unchanged if absolute, otherwise joined onto the project root."""
    p = Path(path)
    return p if p.is_absolute() else _PROJECT_ROOT / p


def _env_flag(name: str, default: str) -> bool:
    """Read boolean env var *name*; "true"/"1" (case-insensitive) mean True."""
    return os.getenv(name, default).lower() in ("true", "1")


class KBChromaSearcher:
    """Connect to one KB's ChromaDB collection and provide semantic search.

    The embedding model, the Chroma client and the collection are all created
    lazily on first access.
    """

    def __init__(self, chroma_path: str, collection_name: str = "kb_chunks",
                 model_name: Optional[str] = None,
                 use_gpu: Optional[bool] = None,
                 use_fp16: Optional[bool] = None):
        """
        Args:
            chroma_path: Directory of the persistent Chroma store; relative
                paths are resolved against the project root.
            collection_name: Name of the collection holding the KB chunks.
            model_name: SentenceTransformer model path or name. Defaults to
                the ``RAG_EMBED_MODEL`` env var, then a bundled local path.
            use_gpu: Use CUDA when available. Defaults to env ``RAG_USE_GPU``
                (``"true"`` if unset).
            use_fp16: Cast the model to half precision when running on CUDA.
                Defaults to env ``RAG_USE_FP16`` (``"true"`` if unset).
        """
        self.chroma_path = str(_resolve(chroma_path))
        self.collection_name = collection_name

        model_path = model_name or os.getenv(
            "RAG_EMBED_MODEL",
            "./rag/models/paraphrase-multilingual-MiniLM-L12-v2")
        resolved = _resolve(model_path)
        # Use the resolved local path only if it exists; otherwise hand the
        # raw string to SentenceTransformer (e.g. a model name).
        self.model_name = str(resolved) if resolved.exists() else model_path

        self.use_gpu = (use_gpu if use_gpu is not None
                        else _env_flag("RAG_USE_GPU", "true"))
        self.use_fp16 = (use_fp16 if use_fp16 is not None
                         else _env_flag("RAG_USE_FP16", "true"))

        self._model = None
        self._client = None
        self._collection = None

    @property
    def model(self):
        """The SentenceTransformer embedding model, loaded on first access."""
        if self._model is None:
            # Heavy imports deferred until the model is actually needed.
            import torch
            from sentence_transformers import SentenceTransformer

            device = "cuda" if (self.use_gpu and torch.cuda.is_available()) else "cpu"
            logger.info("加载嵌入模型: %s (device=%s)", self.model_name, device)
            model = SentenceTransformer(self.model_name, device=device)
            # fp16 is only applied when running on CUDA.
            if device == "cuda" and self.use_fp16:
                model = model.half()
            self._model = model
        return self._model

    @property
    def client(self):
        """The persistent ChromaDB client for ``chroma_path``, created on first access."""
        if self._client is None:
            import chromadb
            self._client = chromadb.PersistentClient(path=self.chroma_path)
        return self._client

    @property
    def collection(self):
        """The KB collection; created with cosine distance if it does not exist yet."""
        if self._collection is None:
            try:
                self._collection = self.client.get_collection(self.collection_name)
            except Exception:
                # The exception type for "not found" differs across chromadb
                # versions, so any failure falls back to creating it.
                self._collection = self.client.create_collection(
                    self.collection_name, metadata={"hnsw:space": "cosine"})
        return self._collection

    def is_ready(self) -> bool:
        """Return True if the collection already exists in the Chroma store."""
        try:
            self.client.get_collection(self.collection_name)
            return True
        except Exception:
            return False

    def search(self, query: str, k: int = 5,
               threshold: Optional[float] = None) -> list[dict]:
        """Return up to *k* chunks most similar to *query*.

        Args:
            query: Free-text query; embedded with normalized embeddings.
            k: Maximum number of hits. Values <= 0 return an empty list.
            threshold: If given, drop hits whose distance exceeds it (cosine
                distance when the collection was created by this class).

        Returns:
            A list of dicts with keys ``id``, ``content``, ``metadata``
            (never None) and ``distance``. The list is empty if the
            collection does not exist.
        """
        # Chroma raises for n_results < 1, so short-circuit instead.
        if k <= 0 or not self.is_ready():
            return []

        query_embedding = self.model.encode(
            query, normalize_embeddings=True, show_progress_bar=False)
        results = self.collection.query(
            query_embeddings=[query_embedding.tolist()],
            n_results=k,
            include=["documents", "metadatas", "distances"])

        output = []
        if not results["ids"] or not results["ids"][0]:
            return output
        # Results are per-query lists; only one query was sent, so index [0].
        for i, doc_id in enumerate(results["ids"][0]):
            dist = results["distances"][0][i]
            if threshold is not None and dist > threshold:
                continue
            output.append({
                "id": doc_id,
                "content": results["documents"][0][i],
                "metadata": results["metadatas"][0][i] or {},
                "distance": dist,
            })
        return output

    def search_templates(self, query: str, k: int = 3) -> list[dict]:
        """Return up to *k* hits that look like report templates.

        A hit qualifies if its ``chunk_type`` metadata contains "jrxml"
        (case-insensitive) or it has a truthy ``report_name``.
        """
        # Over-fetch, since some hits will be filtered out below.
        results = self.search(query, k=k * 2)
        templates = []
        for r in results:
            meta = r.get("metadata", {})
            # Chroma metadata values may be non-strings; coerce before .lower().
            chunk_type = str(meta.get("chunk_type", ""))
            if "jrxml" in chunk_type.lower() or meta.get("report_name"):
                templates.append(r)
                if len(templates) >= k:
                    break
        return templates

    def search_as_context(self, query: str, k: int = 5) -> str:
        """Search and format hits as one text block for prompt context.

        Each hit becomes a header line (type, plus report name when present)
        followed by its content. Hits are separated by ``---`` lines. Returns
        "" when there are no hits.
        """
        results = self.search(query, k=k)
        if not results:
            return ""
        parts = []
        for r in results:
            meta = r.get("metadata", {})
            header = f"[类型:{meta.get('chunk_type', 'N/A')}]"
            if meta.get("report_name"):
                header += f" [报表:{meta['report_name']}]"
            parts.append(f"{header}\n{r['content']}")
        return "\n\n---\n\n".join(parts)

    def add_chunks(self, chunks: list[dict]) -> None:
        """Embed and upsert *chunks* into the collection.

        Each chunk needs ``id`` and ``content`` keys, plus an optional
        ``metadata`` dict. Existing ids are overwritten (upsert).
        """
        if not chunks:
            return
        ids = [c["id"] for c in chunks]
        docs = [c["content"] for c in chunks]
        # Chroma rejects empty metadata dicts; None is its accepted
        # "no metadata" value for an individual record.
        metas = [c.get("metadata") or None for c in chunks]
        embeddings = self.model.encode(
            docs, normalize_embeddings=True, show_progress_bar=True)
        self.collection.upsert(
            ids=ids, documents=docs,
            metadatas=metas, embeddings=embeddings.tolist())


# Module-level cache of searchers, keyed by KB id.
_searchers: dict[str, KBChromaSearcher] = {}


def get_kb_searcher(kb_id: str) -> Optional[KBChromaSearcher]:
    """Return the cached searcher for *kb_id*, creating it on first use.

    Returns None (and caches nothing) when the KB has no Chroma path, so a
    KB that is registered later is still picked up on the next call.
    """
    # Imported lazily; presumably to avoid an import cycle with kb_manager.
    from backend.kb_manager import get_kb_chroma_path

    if kb_id in _searchers:
        return _searchers[kb_id]
    chroma_path = get_kb_chroma_path(kb_id)
    if chroma_path is None:
        return None
    searcher = KBChromaSearcher(str(chroma_path))
    _searchers[kb_id] = searcher
    return searcher


def search_kb(kb_id: str, query: str, k: int = 5) -> str:
    """Search KB *kb_id* and return formatted context text ("" if the KB is unknown)."""
    searcher = get_kb_searcher(kb_id)
    if searcher is None:
        return ""
    return searcher.search_as_context(query, k=k)


def search_templates_in_kb(kb_id: str, query: str, k: int = 3) -> list[dict]:
    """Return template-like hits from KB *kb_id* ([] if the KB is unknown)."""
    searcher = get_kb_searcher(kb_id)
    if searcher is None:
        return []
    return searcher.search_templates(query, k=k)