feat: integrate RAG rag_jrxml submodule and fix Anthropic API key
Add rag submodule for semantic JRXML chunk retrieval, refactor retrieve node to use RAGSearcher, and fix missing api_key in Anthropic SDK client initialization. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -22,6 +22,26 @@ VALIDATION_SERVICE_URL=http://localhost:8001/validate
|
|||||||
# Chroma 持久化目录
|
# Chroma 持久化目录
|
||||||
CHROMA_PERSIST_DIR=./db/chroma
|
CHROMA_PERSIST_DIR=./db/chroma
|
||||||
|
|
||||||
|
# ---- RAG / 向量知识库 (rag_jrxml 子模块) ----
|
||||||
|
# 嵌入模型
|
||||||
|
RAG_EMBED_MODEL=sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2
|
||||||
|
# JRXML 模板源目录 (rag 子模块内已含 107 个模板)
|
||||||
|
RAG_JRXML_SOURCE=./rag/jrxml_source
|
||||||
|
# 分块输出目录
|
||||||
|
RAG_CHUNKER_OUTPUT=./rag/jrxml_chunker_output
|
||||||
|
# 向量输出目录
|
||||||
|
RAG_EMBEDDINGS_DIR=./rag/embeddings
|
||||||
|
# ChromaDB 知识库路径
|
||||||
|
RAG_CHROMA_PATH=./db/chroma
|
||||||
|
# ChromaDB 集合名称
|
||||||
|
RAG_COLLECTION_NAME=jrxml_chunks
|
||||||
|
# GPU 加速
|
||||||
|
RAG_USE_GPU=true
|
||||||
|
# FP16 半精度
|
||||||
|
RAG_USE_FP16=true
|
||||||
|
# 向量化批处理大小
|
||||||
|
RAG_BATCH_SIZE=64
|
||||||
|
|
||||||
# 最大自动修正尝试次数
|
# 最大自动修正尝试次数
|
||||||
MAX_RETRY=3
|
MAX_RETRY=3
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,14 @@ dist/
|
|||||||
db/chroma/
|
db/chroma/
|
||||||
sessions/
|
sessions/
|
||||||
|
|
||||||
|
# RAG 管线中间产物 (rag 子模块内)
|
||||||
|
rag/jrxml_chunker_output/
|
||||||
|
rag/embeddings/
|
||||||
|
rag/models/
|
||||||
|
rag/__pycache__/
|
||||||
|
rag/chroma_db/
|
||||||
|
rag/jrxml_source_chunks/
|
||||||
|
|
||||||
# IDE
|
# IDE
|
||||||
.idea/
|
.idea/
|
||||||
.vscode/
|
.vscode/
|
||||||
|
|||||||
@@ -0,0 +1,3 @@
|
|||||||
|
[submodule "rag"]
|
||||||
|
path = rag
|
||||||
|
url = http://www.1415243231.top:8418/panda/rag_jrxml.git
|
||||||
+4
-18
@@ -10,7 +10,6 @@ from typing import Dict
|
|||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from agent.state import AgentState
|
from agent.state import AgentState
|
||||||
from backend.embeddings import get_embeddings
|
|
||||||
from backend.llm import get_llm
|
from backend.llm import get_llm
|
||||||
from backend.validation import validate_jrxml
|
from backend.validation import validate_jrxml
|
||||||
|
|
||||||
@@ -422,26 +421,13 @@ def _now_iso() -> str:
|
|||||||
|
|
||||||
|
|
||||||
def retrieve(state: AgentState) -> Dict:
|
def retrieve(state: AgentState) -> Dict:
|
||||||
"""在 Chroma 中搜索相关的 JRXML 模板和组件。"""
|
"""在 Chroma 中搜索相关的 JRXML 模板和组件(使用 rag_jrxml 语义分块管线)。"""
|
||||||
try:
|
try:
|
||||||
embeddings = get_embeddings()
|
from backend.rag_adapter import search_chunks
|
||||||
from langchain_chroma import Chroma
|
|
||||||
|
|
||||||
persist_dir = os.getenv("CHROMA_PERSIST_DIR", "./db/chroma")
|
|
||||||
if not os.path.exists(persist_dir) or not os.listdir(persist_dir):
|
|
||||||
state["retrieved_context"] = ""
|
|
||||||
return state
|
|
||||||
|
|
||||||
vectorstore = Chroma(
|
|
||||||
embedding_function=embeddings,
|
|
||||||
persist_directory=persist_dir,
|
|
||||||
)
|
|
||||||
user_input = state.get("user_input", "")
|
user_input = state.get("user_input", "")
|
||||||
docs = vectorstore.similarity_search(user_input, k=5)
|
context = search_chunks(user_input, k=5)
|
||||||
context_parts = []
|
state["retrieved_context"] = context
|
||||||
for d in docs:
|
|
||||||
context_parts.append(d.page_content)
|
|
||||||
state["retrieved_context"] = "\n\n---\n\n".join(context_parts)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
state["retrieved_context"] = ""
|
state["retrieved_context"] = ""
|
||||||
return state
|
return state
|
||||||
|
|||||||
+24
-2
@@ -1,4 +1,9 @@
|
|||||||
"""嵌入模型工厂:支持本地 sentence-transformers 和云端 API。"""
|
"""嵌入模型工厂:支持本地 Sentence-Transformers 和云端 API。
|
||||||
|
|
||||||
|
调用方式:
|
||||||
|
get_embeddings() → LangChain 兼容的 embeddings 对象
|
||||||
|
get_st_embeddings() → 原始 SentenceTransformer 实例
|
||||||
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
@@ -7,6 +12,7 @@ load_dotenv()
|
|||||||
|
|
||||||
|
|
||||||
def get_embeddings():
|
def get_embeddings():
|
||||||
|
"""返回 LangChain 兼容的 embeddings 对象(用于 langchain_chroma 等)。"""
|
||||||
backend = os.getenv("EMBED_BACKEND", "local")
|
backend = os.getenv("EMBED_BACKEND", "local")
|
||||||
if backend == "cloud":
|
if backend == "cloud":
|
||||||
from langchain_openai import OpenAIEmbeddings
|
from langchain_openai import OpenAIEmbeddings
|
||||||
@@ -22,5 +28,21 @@ def get_embeddings():
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
from langchain_community.embeddings import HuggingFaceEmbeddings
|
from langchain_community.embeddings import HuggingFaceEmbeddings
|
||||||
|
|
||||||
model = os.getenv("LOCAL_EMBED_MODEL", "Qwen/Qwen3-Embedding-0.6B")
|
model = os.getenv("RAG_EMBED_MODEL", os.getenv("LOCAL_EMBED_MODEL", "Qwen/Qwen3-Embedding-0.6B"))
|
||||||
return HuggingFaceEmbeddings(model_name=model)
|
return HuggingFaceEmbeddings(model_name=model)
|
||||||
|
|
||||||
|
|
||||||
|
def get_st_model():
|
||||||
|
"""返回原始 SentenceTransformer 实例(与 rag_jrxml 子模块使用方式一致)。"""
|
||||||
|
import torch
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
model_name = os.getenv("RAG_EMBED_MODEL", os.getenv("LOCAL_EMBED_MODEL", "Qwen/Qwen3-Embedding-0.6B"))
|
||||||
|
use_gpu = os.getenv("RAG_USE_GPU", "true").lower() in ("true", "1")
|
||||||
|
use_fp16 = os.getenv("RAG_USE_FP16", "true").lower() in ("true", "1")
|
||||||
|
|
||||||
|
device = "cuda" if (use_gpu and torch.cuda.is_available()) else "cpu"
|
||||||
|
model = SentenceTransformer(model_name, device=device)
|
||||||
|
if device == "cuda" and use_fp16:
|
||||||
|
model = model.half()
|
||||||
|
return model
|
||||||
|
|||||||
+1
-1
@@ -28,7 +28,7 @@ def get_llm():
|
|||||||
|
|
||||||
os.environ["NO_PROXY"] = "*"
|
os.environ["NO_PROXY"] = "*"
|
||||||
|
|
||||||
client = Anthropic(base_url=base_url, timeout=120)
|
client = Anthropic(api_key=api_key, base_url=base_url, timeout=120)
|
||||||
|
|
||||||
class MiniMaxLLM:
|
class MiniMaxLLM:
|
||||||
def invoke(self, prompt: str) -> Any:
|
def invoke(self, prompt: str) -> Any:
|
||||||
|
|||||||
@@ -0,0 +1,155 @@
|
|||||||
|
"""RAG 适配层 — 查询已由 rag_jrxml 子项目构建好的 ChromaDB 向量知识库。
|
||||||
|
|
||||||
|
rag_jrxml 独立运行产出向量库后,主项目通过此模块进行语义搜索。
|
||||||
|
|
||||||
|
用法:
|
||||||
|
from backend.rag_adapter import search_chunks
|
||||||
|
context = search_chunks("如何添加饼图", k=5)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve(path: str) -> Path:
|
||||||
|
p = Path(path)
|
||||||
|
if not p.is_absolute():
|
||||||
|
p = _PROJECT_ROOT / p
|
||||||
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
class RAGSearcher:
|
||||||
|
"""连接预构建的 ChromaDB,提供语义搜索。"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
chroma_path: Optional[str] = None,
|
||||||
|
collection_name: Optional[str] = None,
|
||||||
|
model_name: Optional[str] = None,
|
||||||
|
use_gpu: Optional[bool] = None,
|
||||||
|
use_fp16: Optional[bool] = None,
|
||||||
|
):
|
||||||
|
self.chroma_path = _resolve(chroma_path or os.getenv("RAG_CHROMA_PATH", "./db/chroma"))
|
||||||
|
self.collection_name = collection_name or os.getenv("RAG_COLLECTION_NAME", "jrxml_chunks")
|
||||||
|
model_path = model_name or os.getenv("RAG_EMBED_MODEL", "./rag/models/paraphrase-multilingual-MiniLM-L12-v2")
|
||||||
|
# 如果本地路径存在则使用本地,否则当 Hub 模型名使用
|
||||||
|
resolved = _resolve(model_path)
|
||||||
|
self.model_name = str(resolved) if resolved.exists() else model_path
|
||||||
|
self.use_gpu = use_gpu if use_gpu is not None else os.getenv("RAG_USE_GPU", "true").lower() in ("true", "1")
|
||||||
|
self.use_fp16 = use_fp16 if use_fp16 is not None else os.getenv("RAG_USE_FP16", "true").lower() in ("true", "1")
|
||||||
|
|
||||||
|
self._model = None
|
||||||
|
self._client = None
|
||||||
|
self._collection = None
|
||||||
|
|
||||||
|
# ---- 模型懒加载 ----
|
||||||
|
@property
|
||||||
|
def model(self):
|
||||||
|
if self._model is None:
|
||||||
|
import torch
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
device = "cuda" if (self.use_gpu and torch.cuda.is_available()) else "cpu"
|
||||||
|
logger.info("加载嵌入模型: %s (device=%s)", self.model_name, device)
|
||||||
|
model = SentenceTransformer(self.model_name, device=device)
|
||||||
|
if device == "cuda" and self.use_fp16:
|
||||||
|
model = model.half()
|
||||||
|
self._model = model
|
||||||
|
return self._model
|
||||||
|
|
||||||
|
# ---- ChromaDB 懒连接 ----
|
||||||
|
@property
|
||||||
|
def client(self):
|
||||||
|
if self._client is None:
|
||||||
|
import chromadb
|
||||||
|
self._client = chromadb.PersistentClient(path=str(self.chroma_path))
|
||||||
|
return self._client
|
||||||
|
|
||||||
|
@property
|
||||||
|
def collection(self):
|
||||||
|
if self._collection is None:
|
||||||
|
self._collection = self.client.get_collection(self.collection_name)
|
||||||
|
return self._collection
|
||||||
|
|
||||||
|
def is_ready(self) -> bool:
|
||||||
|
try:
|
||||||
|
self.client.get_collection(self.collection_name)
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# ---- 语义搜索 ----
|
||||||
|
def search(self, query: str, k: int = 5, threshold: Optional[float] = None) -> list[dict]:
|
||||||
|
"""搜索相关 JRXML chunks,返回 [{id, content, metadata, distance}, ...]."""
|
||||||
|
if not self.is_ready():
|
||||||
|
logger.warning("ChromaDB 集合 '%s' 不存在,请先在 rag/ 子项目中运行管线", self.collection_name)
|
||||||
|
return []
|
||||||
|
|
||||||
|
query_embedding = self.model.encode(
|
||||||
|
query, normalize_embeddings=True, show_progress_bar=False
|
||||||
|
)
|
||||||
|
|
||||||
|
results = self.collection.query(
|
||||||
|
query_embeddings=[query_embedding.tolist()],
|
||||||
|
n_results=k,
|
||||||
|
include=["documents", "metadatas", "distances"],
|
||||||
|
)
|
||||||
|
|
||||||
|
output = []
|
||||||
|
if not results["ids"] or not results["ids"][0]:
|
||||||
|
return output
|
||||||
|
|
||||||
|
for i, doc_id in enumerate(results["ids"][0]):
|
||||||
|
dist = results["distances"][0][i]
|
||||||
|
if threshold is not None and dist > threshold:
|
||||||
|
continue
|
||||||
|
output.append({
|
||||||
|
"id": doc_id,
|
||||||
|
"content": results["documents"][0][i],
|
||||||
|
"metadata": results["metadatas"][0][i],
|
||||||
|
"distance": dist,
|
||||||
|
})
|
||||||
|
return output
|
||||||
|
|
||||||
|
def search_as_context(self, query: str, k: int = 5) -> str:
|
||||||
|
"""搜索并返回拼接好的上下文字符串,可直接注入 LLM prompt。"""
|
||||||
|
results = self.search(query, k=k)
|
||||||
|
if not results:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
parts = []
|
||||||
|
for r in results:
|
||||||
|
meta = r["metadata"]
|
||||||
|
header = f"[类型:{meta.get('chunk_type', 'N/A')}]"
|
||||||
|
if meta.get("report_name"):
|
||||||
|
header += f" [报表:{meta['report_name']}]"
|
||||||
|
if meta.get("band_name"):
|
||||||
|
header += f" [区域:{meta['band_name']}]"
|
||||||
|
parts.append(f"{header}\n{r['content']}")
|
||||||
|
return "\n\n---\n\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
# 全局单例,避免重复加载模型
|
||||||
|
_searcher: Optional[RAGSearcher] = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_searcher() -> RAGSearcher:
|
||||||
|
global _searcher
|
||||||
|
if _searcher is None:
|
||||||
|
_searcher = RAGSearcher()
|
||||||
|
return _searcher
|
||||||
|
|
||||||
|
|
||||||
|
def search_chunks(query: str, k: int = 5) -> str:
|
||||||
|
"""搜索 JRXML 知识库并返回拼接后的上下文文本(便捷函数)。"""
|
||||||
|
return _get_searcher().search_as_context(query, k=k)
|
||||||
Submodule
+1
Submodule rag added at 687b3a8f90
@@ -17,6 +17,9 @@ lxml>=5.3.0
|
|||||||
|
|
||||||
# 嵌入模型(本地)
|
# 嵌入模型(本地)
|
||||||
sentence-transformers>=3.0.0
|
sentence-transformers>=3.0.0
|
||||||
|
torch>=2.0.0
|
||||||
|
huggingface_hub>=0.19.0
|
||||||
|
tqdm>=4.65.0
|
||||||
|
|
||||||
# 工具类
|
# 工具类
|
||||||
python-dotenv>=1.0.0
|
python-dotenv>=1.0.0
|
||||||
|
|||||||
+29
-94
@@ -1,120 +1,55 @@
|
|||||||
"""初始化 Chroma 知识库,加载示例 JRXML 模板和错误修正案例。
|
"""初始化 JRXML 向量知识库。
|
||||||
|
|
||||||
用法: python scripts/init_kb.py
|
rag_jrxml 子项目独立运行管线(分块→向量化→导入),本脚本仅用于预下载嵌入模型。
|
||||||
|
|
||||||
|
用法:
|
||||||
|
python scripts/init_kb.py --download-model # 预下载嵌入模型
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import argparse
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
def download_embeddings_model():
|
def download_model():
|
||||||
"""预下载 Qwen3-Embedding 模型(从 HuggingFace)。
|
"""预下载嵌入模型到本地。"""
|
||||||
|
model_name = os.getenv("RAG_EMBED_MODEL", "Qwen/Qwen3-Embedding-0.6B")
|
||||||
用法: python scripts/init_kb.py --download-model
|
|
||||||
"""
|
|
||||||
model_name = os.getenv("LOCAL_EMBED_MODEL", "Qwen/Qwen3-Embedding-0.6B")
|
|
||||||
print(f"正在下载嵌入模型: {model_name}")
|
print(f"正在下载嵌入模型: {model_name}")
|
||||||
print("如遇网络超时,可手动执行以下命令后重试:")
|
print("如遇网络超时,可设置环境变量 HF_ENDPOINT=https://hf-mirror.com 使用镜像")
|
||||||
print(f" huggingface-cli download {model_name} --local-dir ./models/{model_name.replace('/', '_')}")
|
|
||||||
print()
|
print()
|
||||||
|
|
||||||
try:
|
from sentence_transformers import SentenceTransformer
|
||||||
from langchain_huggingface import HuggingFaceEmbeddings
|
|
||||||
except ImportError:
|
|
||||||
print("错误: 请先安装 huggingface 依赖")
|
|
||||||
print(" pip install langchain-huggingface sentence-transformers")
|
|
||||||
return
|
|
||||||
|
|
||||||
# HuggingFaceEmbeddings 会在首次调用时自动下载模型
|
model = SentenceTransformer(model_name)
|
||||||
embeddings = HuggingFaceEmbeddings(model_name=model_name)
|
model.encode("测试下载")
|
||||||
# 调用一次以确保完全下载
|
|
||||||
embeddings.embed_query("测试")
|
|
||||||
print(f"嵌入模型下载完成: {model_name}")
|
print(f"嵌入模型下载完成: {model_name}")
|
||||||
|
|
||||||
from backend.embeddings import get_embeddings
|
|
||||||
|
|
||||||
|
|
||||||
def load_templates(template_dir: Path) -> list[dict]:
|
|
||||||
docs = []
|
|
||||||
for fpath in template_dir.glob('*.jrxml'):
|
|
||||||
content = fpath.read_text(encoding='utf-8')
|
|
||||||
name = fpath.stem
|
|
||||||
docs.append({
|
|
||||||
'content': content,
|
|
||||||
'metadata': {
|
|
||||||
'source': str(fpath),
|
|
||||||
'type': 'full_report',
|
|
||||||
'name': name,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return docs
|
|
||||||
|
|
||||||
|
|
||||||
def load_corrections(corrections_dir: Path) -> list[dict]:
|
|
||||||
docs = []
|
|
||||||
for fpath in corrections_dir.glob('*.jrxml'):
|
|
||||||
content = fpath.read_text(encoding='utf-8')
|
|
||||||
docs.append({
|
|
||||||
'content': content,
|
|
||||||
'metadata': {
|
|
||||||
'source': str(fpath),
|
|
||||||
'type': 'correction_case',
|
|
||||||
'name': fpath.stem,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return docs
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
persist_dir = os.getenv('CHROMA_PERSIST_DIR', './db/chroma')
|
parser = argparse.ArgumentParser(description="JRXML 向量知识库工具")
|
||||||
data_dir = Path(__file__).parent.parent / 'data'
|
parser.add_argument(
|
||||||
|
"--download-model", action="store_true",
|
||||||
template_dir = data_dir / 'sample_templates'
|
help="预下载嵌入模型到本地"
|
||||||
corrections_dir = data_dir / 'corrections'
|
|
||||||
|
|
||||||
docs = []
|
|
||||||
if template_dir.exists():
|
|
||||||
docs.extend(load_templates(template_dir))
|
|
||||||
print(f'从 {template_dir} 加载了 {len(docs)} 个模板')
|
|
||||||
|
|
||||||
if corrections_dir.exists():
|
|
||||||
corr = load_corrections(corrections_dir)
|
|
||||||
docs.extend(corr)
|
|
||||||
print(f'从 {corrections_dir} 加载了 {len(corr)} 个修正案例')
|
|
||||||
|
|
||||||
if not docs:
|
|
||||||
print('未找到文档,无需索引。')
|
|
||||||
return
|
|
||||||
|
|
||||||
embeddings = get_embeddings()
|
|
||||||
from langchain_chroma import Chroma
|
|
||||||
|
|
||||||
texts = [d['content'] for d in docs]
|
|
||||||
metadatas = [d['metadata'] for d in docs]
|
|
||||||
|
|
||||||
Chroma.from_texts(
|
|
||||||
texts=texts,
|
|
||||||
embedding=embeddings,
|
|
||||||
metadatas=metadatas,
|
|
||||||
persist_directory=persist_dir,
|
|
||||||
)
|
)
|
||||||
print(f'已将 {len(docs)} 个文档索引到 Chroma,存储位置: {persist_dir}')
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
import argparse
|
|
||||||
parser = argparse.ArgumentParser(description='初始化 Chroma 知识库')
|
|
||||||
parser.add_argument('--download-model', action='store_true', help='仅下载嵌入模型到本地')
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.download_model:
|
if args.download_model:
|
||||||
download_embeddings_model()
|
download_model()
|
||||||
else:
|
else:
|
||||||
main()
|
print("用法: python scripts/init_kb.py --download-model")
|
||||||
|
print()
|
||||||
|
print("知识库构建请在 rag/ 子项目中独立运行:")
|
||||||
|
print(" cd rag")
|
||||||
|
print(" python batch_chunker.py jrxml_source")
|
||||||
|
print(" python embed_chunks.py")
|
||||||
|
print(" python import_to_chroma.py")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|||||||
Reference in New Issue
Block a user