"""RAG 适配层 — 查询已由 rag_jrxml 子项目构建好的 ChromaDB 向量知识库。 rag_jrxml 独立运行产出向量库后,主项目通过此模块进行语义搜索。 用法: from backend.rag_adapter import search_chunks context = search_chunks("如何添加饼图", k=5) """ import os import logging from pathlib import Path from typing import Optional from dotenv import load_dotenv load_dotenv() logger = logging.getLogger(__name__) _PROJECT_ROOT = Path(__file__).resolve().parent.parent def _resolve(path: str) -> Path: p = Path(path) if not p.is_absolute(): p = _PROJECT_ROOT / p return p class RAGSearcher: """连接预构建的 ChromaDB,提供语义搜索。""" def __init__( self, chroma_path: Optional[str] = None, collection_name: Optional[str] = None, model_name: Optional[str] = None, use_gpu: Optional[bool] = None, use_fp16: Optional[bool] = None, ): self.chroma_path = _resolve(chroma_path or os.getenv("RAG_CHROMA_PATH", "./db/chroma")) self.collection_name = collection_name or os.getenv("RAG_COLLECTION_NAME", "jrxml_chunks") model_path = model_name or os.getenv("RAG_EMBED_MODEL", "./rag/models/paraphrase-multilingual-MiniLM-L12-v2") # 如果本地路径存在则使用本地,否则当 Hub 模型名使用 resolved = _resolve(model_path) self.model_name = str(resolved) if resolved.exists() else model_path self.use_gpu = use_gpu if use_gpu is not None else os.getenv("RAG_USE_GPU", "true").lower() in ("true", "1") self.use_fp16 = use_fp16 if use_fp16 is not None else os.getenv("RAG_USE_FP16", "true").lower() in ("true", "1") self._model = None self._client = None self._collection = None # ---- 模型懒加载 ---- @property def model(self): if self._model is None: import torch from sentence_transformers import SentenceTransformer device = "cuda" if (self.use_gpu and torch.cuda.is_available()) else "cpu" logger.info("加载嵌入模型: %s (device=%s)", self.model_name, device) model = SentenceTransformer(self.model_name, device=device) if device == "cuda" and self.use_fp16: model = model.half() self._model = model return self._model # ---- ChromaDB 懒连接 ---- @property def client(self): if self._client is None: import chromadb self._client = chromadb.PersistentClient(path=str(self.chroma_path)) return self._client @property def collection(self): if self._collection is None: self._collection = self.client.get_collection(self.collection_name) return self._collection def is_ready(self) -> bool: try: self.client.get_collection(self.collection_name) return True except Exception: return False # ---- 语义搜索 ---- def search(self, query: str, k: int = 5, threshold: Optional[float] = None) -> list[dict]: """搜索相关 JRXML chunks,返回 [{id, content, metadata, distance}, ...].""" if not self.is_ready(): logger.warning("ChromaDB 集合 '%s' 不存在,请先在 rag/ 子项目中运行管线", self.collection_name) return [] query_embedding = self.model.encode( query, normalize_embeddings=True, show_progress_bar=False ) results = self.collection.query( query_embeddings=[query_embedding.tolist()], n_results=k, include=["documents", "metadatas", "distances"], ) output = [] if not results["ids"] or not results["ids"][0]: return output for i, doc_id in enumerate(results["ids"][0]): dist = results["distances"][0][i] if threshold is not None and dist > threshold: continue output.append({ "id": doc_id, "content": results["documents"][0][i], "metadata": results["metadatas"][0][i], "distance": dist, }) return output def search_as_context(self, query: str, k: int = 5) -> str: """搜索并返回拼接好的上下文字符串,可直接注入 LLM prompt。""" results = self.search(query, k=k) if not results: return "" parts = [] for r in results: meta = r["metadata"] header = f"[类型:{meta.get('chunk_type', 'N/A')}]" if meta.get("report_name"): header += f" [报表:{meta['report_name']}]" if meta.get("band_name"): header += f" [区域:{meta['band_name']}]" parts.append(f"{header}\n{r['content']}") return "\n\n---\n\n".join(parts) # 全局单例,避免重复加载模型 _searcher: Optional[RAGSearcher] = None def _get_searcher() -> RAGSearcher: global _searcher if _searcher is None: _searcher = RAGSearcher() return _searcher def search_chunks(query: str, k: int = 5, kb_id: str = "") -> str: """搜索知识库并返回拼接后的上下文文本。 若指定 kb_id,使用该 KB 专属 ChromaDB;否则使用全局默认库。 """ if kb_id: from backend.kb_searcher import search_kb return search_kb(kb_id, query, k=k) return _get_searcher().search_as_context(query, k=k)