b280c2b453
Add rag submodule for semantic JRXML chunk retrieval, refactor retrieve node to use RAGSearcher, and fix missing api_key in Anthropic SDK client initialization. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
156 lines
5.2 KiB
Python
156 lines
5.2 KiB
Python
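"""RAG adapter: query the ChromaDB vector knowledge base built by the rag_jrxml sub-project.

The rag_jrxml sub-project runs on its own and produces the vector store; the
main project then runs semantic searches against it through this module.

Usage::

    from backend.rag_adapter import search_chunks

    context = search_chunks("如何添加饼图", k=5)  # "how do I add a pie chart?"
"""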
|
|
|
|
import os
|
|
import logging
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
from dotenv import load_dotenv
|
|
|
|
# Load variables from a .env file into the process environment so the RAG_*
# settings that RAGSearcher reads via os.getenv can be configured there.
# NOTE: this is an import-time side effect of this module.
load_dotenv()

logger = logging.getLogger(__name__)

# Two directories above this file. Presumably this is the repository root, since
# the usage example imports this module as backend.rag_adapter (confirm).
# _resolve() anchors relative paths here instead of at the CWD.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
|
|
|
|
|
def _resolve(path: str) -> Path:
    """Convert *path* to a Path, anchoring relative paths at the project root.

    Absolute paths come back unchanged (as Path objects). Relative paths are
    joined onto ``_PROJECT_ROOT`` rather than the current working directory,
    so results do not depend on where the process was started.
    """
    candidate = Path(path)
    return candidate if candidate.is_absolute() else _PROJECT_ROOT / candidate
|
|
|
|
|
|
class RAGSearcher:
    """Semantic search over a pre-built ChromaDB collection of JRXML chunks.

    This class only reads the collection; building it is the job of the
    separate RAG sub-project. The embedding model and the ChromaDB client are
    both created lazily on first use, so constructing an instance is cheap.

    Each constructor argument left as ``None`` falls back to an environment
    variable:

    - ``chroma_path``     -> ``RAG_CHROMA_PATH``     (default ``./db/chroma``)
    - ``collection_name`` -> ``RAG_COLLECTION_NAME`` (default ``jrxml_chunks``)
    - ``model_name``      -> ``RAG_EMBED_MODEL``
      (default ``./rag/models/paraphrase-multilingual-MiniLM-L12-v2``)
    - ``use_gpu``         -> ``RAG_USE_GPU``         (default ``true``)
    - ``use_fp16``        -> ``RAG_USE_FP16``        (default ``true``)

    Relative paths are resolved against the project root.
    """

    # Case-insensitive values that count as "true" for boolean env flags.
    _TRUTHY = frozenset({"true", "1", "yes", "on"})

    def __init__(
        self,
        chroma_path: Optional[str] = None,
        collection_name: Optional[str] = None,
        model_name: Optional[str] = None,
        use_gpu: Optional[bool] = None,
        use_fp16: Optional[bool] = None,
    ):
        self.chroma_path = _resolve(chroma_path or os.getenv("RAG_CHROMA_PATH", "./db/chroma"))
        self.collection_name = collection_name or os.getenv("RAG_COLLECTION_NAME", "jrxml_chunks")

        model_path = model_name or os.getenv(
            "RAG_EMBED_MODEL", "./rag/models/paraphrase-multilingual-MiniLM-L12-v2"
        )
        # Use the local directory when it exists; otherwise pass the value on
        # unchanged so sentence-transformers treats it as a Hub model name.
        resolved = _resolve(model_path)
        self.model_name = str(resolved) if resolved.exists() else model_path

        self.use_gpu = use_gpu if use_gpu is not None else self._env_flag("RAG_USE_GPU", "true")
        self.use_fp16 = use_fp16 if use_fp16 is not None else self._env_flag("RAG_USE_FP16", "true")

        self._model = None
        self._client = None
        self._collection = None

    @classmethod
    def _env_flag(cls, name: str, default: str) -> bool:
        """Read a boolean env var; ``true``/``1``/``yes``/``on`` (any case) mean True."""
        return os.getenv(name, default).strip().lower() in cls._TRUTHY

    # ---- Lazy model loading ----
    @property
    def model(self):
        """The SentenceTransformer embedding model, loaded on first access.

        Uses CUDA when ``use_gpu`` is set and torch reports a GPU, otherwise
        the CPU. On CUDA the model is cast to half precision if ``use_fp16``.
        """
        if self._model is None:
            import torch
            from sentence_transformers import SentenceTransformer

            device = "cuda" if (self.use_gpu and torch.cuda.is_available()) else "cpu"
            logger.info("加载嵌入模型: %s (device=%s)", self.model_name, device)
            model = SentenceTransformer(self.model_name, device=device)
            if device == "cuda" and self.use_fp16:
                model = model.half()
            self._model = model
        return self._model

    # ---- Lazy ChromaDB connection ----
    @property
    def client(self):
        """Persistent ChromaDB client for ``chroma_path``, created on first access."""
        if self._client is None:
            import chromadb

            self._client = chromadb.PersistentClient(path=str(self.chroma_path))
        return self._client

    @property
    def collection(self):
        """Handle to the configured collection, fetched once and then cached."""
        if self._collection is None:
            self._collection = self.client.get_collection(self.collection_name)
        return self._collection

    def is_ready(self) -> bool:
        """Return True if the configured collection can be opened right now.

        The collection is re-fetched from the client on every call; the cached
        handle is not used. Any failure yields False, including a missing
        collection, an unopenable store, or ``chromadb`` not being importable.
        The cause is logged at DEBUG level rather than silently discarded.
        """
        try:
            self.client.get_collection(self.collection_name)
        except Exception:
            logger.debug("无法打开 ChromaDB 集合 '%s'", self.collection_name, exc_info=True)
            return False
        return True

    # ---- Semantic search ----
    def search(self, query: str, k: int = 5, threshold: Optional[float] = None) -> list[dict]:
        """Return up to *k* JRXML chunks semantically closest to *query*.

        Args:
            query: Natural-language search text. It is embedded with
                normalized embeddings.
            k: Maximum number of hits. A value <= 0 returns ``[]`` without
                querying, because ChromaDB raises for a non-positive
                ``n_results``.
            threshold: If given, hits whose distance is greater than this are
                dropped.

        Returns:
            A list of ``{"id", "content", "metadata", "distance"}`` dicts in
            the order ChromaDB returned them. ``metadata`` may be ``None`` for
            chunks stored without metadata. The list is empty when the
            collection is unavailable or empty.
        """
        if k <= 0:
            return []

        if not self.is_ready():
            logger.warning("ChromaDB 集合 '%s' 不存在,请先在 rag/ 子项目中运行管线", self.collection_name)
            return []

        # Never request more neighbours than the collection holds. Depending on
        # the chromadb version, doing so either logs a warning or raises.
        n_results = min(k, self.collection.count())
        if n_results == 0:
            return []

        query_embedding = self.model.encode(
            query, normalize_embeddings=True, show_progress_bar=False
        )

        results = self.collection.query(
            query_embeddings=[query_embedding.tolist()],
            n_results=n_results,
            include=["documents", "metadatas", "distances"],
        )

        if not results["ids"] or not results["ids"][0]:
            return []

        # One query embedding was sent, so every field holds exactly one row.
        ids = results["ids"][0]
        documents = results["documents"][0]
        metadatas = results["metadatas"][0]
        distances = results["distances"][0]

        return [
            {"id": doc_id, "content": doc, "metadata": meta, "distance": dist}
            for doc_id, doc, meta, dist in zip(ids, documents, metadatas, distances)
            if threshold is None or dist <= threshold
        ]

    def search_as_context(self, query: str, k: int = 5, threshold: Optional[float] = None) -> str:
        """Search, then join the hits into one string ready to inject into an LLM prompt.

        Each hit is rendered as a header line followed by the chunk text. The
        header carries tags for chunk type (``N/A`` if absent), report name and
        band name, the last two only when present. Hits are separated by a
        ``---`` line surrounded by blank lines. Returns ``""`` when there are
        no hits. *k* and *threshold* are passed through to :meth:`search`.
        """
        results = self.search(query, k=k, threshold=threshold)
        if not results:
            return ""

        parts = []
        for r in results:
            # Chunks stored without metadata come back with metadata=None.
            meta = r["metadata"] or {}
            header = f"[类型:{meta.get('chunk_type', 'N/A')}]"
            if meta.get("report_name"):
                header += f" [报表:{meta['report_name']}]"
            if meta.get("band_name"):
                header += f" [区域:{meta['band_name']}]"
            parts.append(f"{header}\n{r['content']}")
        return "\n\n---\n\n".join(parts)
|
|
|
|
|
|
# Module-wide shared searcher. RAGSearcher caches its embedding model per
# instance, so reusing one instance avoids loading the model repeatedly.
_searcher: Optional[RAGSearcher] = None


def _get_searcher() -> RAGSearcher:
    """Return the shared RAGSearcher, constructing it with defaults on first call."""
    global _searcher
    if _searcher is not None:
        return _searcher
    _searcher = RAGSearcher()
    return _searcher
|
|
|
|
|
|
def search_chunks(query: str, k: int = 5) -> str:
    """Search the JRXML knowledge base and return the joined context text.

    This is a convenience wrapper around the shared searcher's
    ``search_as_context``, asking for at most *k* hits. Returns an empty string
    when nothing is found.
    """
    searcher = _get_searcher()
    return searcher.search_as_context(query, k=k)
|