feat: integrate rag_jrxml RAG submodule and fix Anthropic API key

Add rag submodule for semantic JRXML chunk retrieval, refactor
retrieve node to use RAGSearcher, and fix missing api_key in
Anthropic SDK client initialization.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-19 09:42:57 +08:00
parent 4416c20b77
commit b280c2b453
10 changed files with 248 additions and 115 deletions
+20
View File
@@ -22,6 +22,26 @@ VALIDATION_SERVICE_URL=http://localhost:8001/validate
# Chroma 持久化目录
CHROMA_PERSIST_DIR=./db/chroma
# ---- RAG / 向量知识库 (rag_jrxml 子模块) ----
# 嵌入模型
RAG_EMBED_MODEL=sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2
# JRXML 模板源目录 (rag 子模块内已含 107 个模板)
RAG_JRXML_SOURCE=./rag/jrxml_source
# 分块输出目录
RAG_CHUNKER_OUTPUT=./rag/jrxml_chunker_output
# 向量输出目录
RAG_EMBEDDINGS_DIR=./rag/embeddings
# ChromaDB 知识库路径
RAG_CHROMA_PATH=./db/chroma
# ChromaDB 集合名称
RAG_COLLECTION_NAME=jrxml_chunks
# GPU 加速
RAG_USE_GPU=true
# FP16 半精度
RAG_USE_FP16=true
# 向量化批处理大小
RAG_BATCH_SIZE=64
# 最大自动修正尝试次数
MAX_RETRY=3
+8
View File
@@ -12,6 +12,14 @@ dist/
db/chroma/
sessions/
# RAG 管线中间产物 (rag 子模块内)
rag/jrxml_chunker_output/
rag/embeddings/
rag/models/
rag/__pycache__/
rag/chroma_db/
rag/jrxml_source_chunks/
# IDE
.idea/
.vscode/
+3
View File
@@ -0,0 +1,3 @@
[submodule "rag"]
path = rag
url = http://www.1415243231.top:8418/panda/rag_jrxml.git
+4 -18
View File
@@ -10,7 +10,6 @@ from typing import Dict
from dotenv import load_dotenv
from agent.state import AgentState
from backend.embeddings import get_embeddings
from backend.llm import get_llm
from backend.validation import validate_jrxml
@@ -422,26 +421,13 @@ def _now_iso() -> str:
def retrieve(state: AgentState) -> Dict:
"""在 Chroma 中搜索相关的 JRXML 模板和组件。"""
"""在 Chroma 中搜索相关的 JRXML 模板和组件(使用 rag_jrxml 语义分块管线)"""
try:
embeddings = get_embeddings()
from langchain_chroma import Chroma
from backend.rag_adapter import search_chunks
persist_dir = os.getenv("CHROMA_PERSIST_DIR", "./db/chroma")
if not os.path.exists(persist_dir) or not os.listdir(persist_dir):
state["retrieved_context"] = ""
return state
vectorstore = Chroma(
embedding_function=embeddings,
persist_directory=persist_dir,
)
user_input = state.get("user_input", "")
docs = vectorstore.similarity_search(user_input, k=5)
context_parts = []
for d in docs:
context_parts.append(d.page_content)
state["retrieved_context"] = "\n\n---\n\n".join(context_parts)
context = search_chunks(user_input, k=5)
state["retrieved_context"] = context
except Exception:
state["retrieved_context"] = ""
return state
+24 -2
View File
@@ -1,4 +1,9 @@
"""嵌入模型工厂:支持本地 sentence-transformers 和云端 API。"""
"""嵌入模型工厂:支持本地 Sentence-Transformers 和云端 API。
调用方式:
get_embeddings() → LangChain 兼容的 embeddings 对象
get_st_model() → 原始 SentenceTransformer 实例
"""
import os
from dotenv import load_dotenv
@@ -7,6 +12,7 @@ load_dotenv()
def get_embeddings():
"""返回 LangChain 兼容的 embeddings 对象(用于 langchain_chroma 等)。"""
backend = os.getenv("EMBED_BACKEND", "local")
if backend == "cloud":
from langchain_openai import OpenAIEmbeddings
@@ -22,5 +28,21 @@ def get_embeddings():
except ImportError:
from langchain_community.embeddings import HuggingFaceEmbeddings
model = os.getenv("LOCAL_EMBED_MODEL", "Qwen/Qwen3-Embedding-0.6B")
model = os.getenv("RAG_EMBED_MODEL", os.getenv("LOCAL_EMBED_MODEL", "Qwen/Qwen3-Embedding-0.6B"))
return HuggingFaceEmbeddings(model_name=model)
def get_st_model():
    """Build and return a raw ``SentenceTransformer`` instance.

    The model name is read from ``RAG_EMBED_MODEL``, then ``LOCAL_EMBED_MODEL``,
    then falls back to ``Qwen/Qwen3-Embedding-0.6B``. ``RAG_USE_GPU`` and
    ``RAG_USE_FP16`` (both default ``"true"``) are treated as enabled when
    their lower-cased value is ``"true"`` or ``"1"``.

    Returns:
        The loaded model, placed on CUDA when GPU use is enabled and CUDA is
        available (otherwise on CPU), and converted with ``.half()`` only when
        it runs on CUDA and FP16 is enabled.
    """
    # Heavy dependencies are imported lazily, only when a model is requested.
    import torch
    from sentence_transformers import SentenceTransformer

    enabled_values = ("true", "1")
    fallback_name = os.getenv("LOCAL_EMBED_MODEL", "Qwen/Qwen3-Embedding-0.6B")
    name = os.getenv("RAG_EMBED_MODEL", fallback_name)
    want_gpu = os.getenv("RAG_USE_GPU", "true").lower() in enabled_values
    want_fp16 = os.getenv("RAG_USE_FP16", "true").lower() in enabled_values

    on_cuda = want_gpu and torch.cuda.is_available()
    st_model = SentenceTransformer(name, device="cuda" if on_cuda else "cpu")
    if on_cuda and want_fp16:
        # Half precision is only applied on the GPU path.
        return st_model.half()
    return st_model
+1 -1
View File
@@ -28,7 +28,7 @@ def get_llm():
os.environ["NO_PROXY"] = "*"
client = Anthropic(base_url=base_url, timeout=120)
client = Anthropic(api_key=api_key, base_url=base_url, timeout=120)
class MiniMaxLLM:
def invoke(self, prompt: str) -> Any:
+155
View File
@@ -0,0 +1,155 @@
"""RAG 适配层 — 查询已由 rag_jrxml 子项目构建好的 ChromaDB 向量知识库。
rag_jrxml 独立运行产出向量库后,主项目通过此模块进行语义搜索。
用法:
from backend.rag_adapter import search_chunks
context = search_chunks("如何添加饼图", k=5)
"""
import os
import logging
from pathlib import Path
from typing import Optional
from dotenv import load_dotenv
load_dotenv()
logger = logging.getLogger(__name__)
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
def _resolve(path: str) -> Path:
p = Path(path)
if not p.is_absolute():
p = _PROJECT_ROOT / p
return p
class RAGSearcher:
    """Semantic search over a pre-built ChromaDB collection of JRXML chunks.

    The embedding model, the ChromaDB client and the collection handle are
    all created lazily on first access and then cached on the instance.
    """

    def __init__(
        self,
        chroma_path: Optional[str] = None,
        collection_name: Optional[str] = None,
        model_name: Optional[str] = None,
        use_gpu: Optional[bool] = None,
        use_fp16: Optional[bool] = None,
    ):
        """Resolve settings from arguments, falling back to environment variables.

        Args:
            chroma_path: ChromaDB directory. Defaults to ``RAG_CHROMA_PATH``
                or ``./db/chroma``; relative paths go through ``_resolve``.
            collection_name: Collection to query. Defaults to
                ``RAG_COLLECTION_NAME`` or ``jrxml_chunks``.
            model_name: Embedding model path or hub name. Defaults to
                ``RAG_EMBED_MODEL`` or the bundled MiniLM path under
                ``./rag/models``.
            use_gpu: Force GPU use on/off; ``None`` reads ``RAG_USE_GPU``.
            use_fp16: Force FP16 on/off; ``None`` reads ``RAG_USE_FP16``.
        """
        self.chroma_path = _resolve(chroma_path or os.getenv("RAG_CHROMA_PATH", "./db/chroma"))
        self.collection_name = collection_name or os.getenv("RAG_COLLECTION_NAME", "jrxml_chunks")

        model_path = model_name or os.getenv(
            "RAG_EMBED_MODEL", "./rag/models/paraphrase-multilingual-MiniLM-L12-v2"
        )
        # Prefer a local directory when it exists; otherwise pass the value
        # through unchanged so it is treated as a hub model name.
        resolved = _resolve(model_path)
        self.model_name = str(resolved) if resolved.exists() else model_path

        self.use_gpu = use_gpu if use_gpu is not None else self._env_flag("RAG_USE_GPU")
        self.use_fp16 = use_fp16 if use_fp16 is not None else self._env_flag("RAG_USE_FP16")

        self._model = None
        self._client = None
        self._collection = None

    @staticmethod
    def _env_flag(name: str, default: str = "true") -> bool:
        """Return True when env var *name* is ``true`` or ``1`` (case-insensitive)."""
        return os.getenv(name, default).lower() in ("true", "1")

    # ---- lazy model loading ----

    @property
    def model(self):
        """The SentenceTransformer model, loaded on first access."""
        if self._model is None:
            import torch
            from sentence_transformers import SentenceTransformer

            device = "cuda" if (self.use_gpu and torch.cuda.is_available()) else "cpu"
            logger.info("加载嵌入模型: %s (device=%s)", self.model_name, device)
            model = SentenceTransformer(self.model_name, device=device)
            if device == "cuda" and self.use_fp16:
                model = model.half()
            self._model = model
        return self._model

    # ---- lazy ChromaDB connection ----

    @property
    def client(self):
        """The ChromaDB persistent client, created on first access."""
        if self._client is None:
            import chromadb

            self._client = chromadb.PersistentClient(path=str(self.chroma_path))
        return self._client

    @property
    def collection(self):
        """The configured collection handle, fetched on first access."""
        if self._collection is None:
            self._collection = self.client.get_collection(self.collection_name)
        return self._collection

    def is_ready(self) -> bool:
        """Return True when the configured collection can be fetched."""
        try:
            self.client.get_collection(self.collection_name)
        except Exception:
            # Any failure (missing collection, unusable store, missing
            # dependency) is reported as "not ready" rather than raised.
            return False
        return True

    # ---- semantic search ----

    def search(self, query: str, k: int = 5, threshold: Optional[float] = None) -> list[dict]:
        """Search for JRXML chunks related to *query*.

        Args:
            query: Natural-language query text.
            k: Maximum number of results requested from the collection.
            threshold: When given, results whose distance exceeds it are dropped.

        Returns:
            A list of ``{"id", "content", "metadata", "distance"}`` dicts in
            the order returned by the collection. ``metadata`` is always a
            dict (empty when the stored chunk has none). The list is empty
            when the query is blank, ``k`` is not positive, or the
            collection is unavailable.
        """
        # A blank query or non-positive k cannot produce a meaningful search;
        # return early instead of embedding noise or passing an invalid
        # n_results to the collection.
        if not query or not query.strip() or k <= 0:
            return []

        if not self.is_ready():
            logger.warning("ChromaDB 集合 '%s' 不存在,请先在 rag/ 子项目中运行管线", self.collection_name)
            return []

        query_embedding = self.model.encode(
            query, normalize_embeddings=True, show_progress_bar=False
        )
        results = self.collection.query(
            query_embeddings=[query_embedding.tolist()],
            n_results=k,
            include=["documents", "metadatas", "distances"],
        )

        if not results["ids"] or not results["ids"][0]:
            return []

        # Results are nested one level per query; only one query was sent.
        ids = results["ids"][0]
        documents = results["documents"][0]
        metadatas = results["metadatas"][0]
        distances = results["distances"][0]

        output = []
        for doc_id, content, metadata, dist in zip(ids, documents, metadatas, distances):
            if threshold is not None and dist > threshold:
                continue
            output.append({
                "id": doc_id,
                "content": content,
                # Chunks stored without metadata come back as None.
                "metadata": metadata or {},
                "distance": dist,
            })
        return output

    def search_as_context(self, query: str, k: int = 5) -> str:
        """Search and return the hits joined into one prompt-ready string.

        Each hit is rendered as a bracketed header line (chunk type, plus
        report and band names when present) followed by the chunk content;
        hits are separated by ``---`` lines. Returns ``""`` when nothing
        is found.
        """
        results = self.search(query, k=k)
        if not results:
            return ""

        parts = []
        for r in results:
            meta = r["metadata"] or {}
            header = f"[类型:{meta.get('chunk_type', 'N/A')}]"
            if meta.get("report_name"):
                header += f" [报表:{meta['report_name']}]"
            if meta.get("band_name"):
                header += f" [区域:{meta['band_name']}]"
            parts.append(f"{header}\n{r['content']}")
        return "\n\n---\n\n".join(parts)
# 全局单例,避免重复加载模型
_searcher: Optional[RAGSearcher] = None
def _get_searcher() -> RAGSearcher:
    """Return the module-level ``RAGSearcher``, creating it on first call."""
    global _searcher
    if _searcher is not None:
        return _searcher
    # First use: build the shared instance so later calls reuse it.
    _searcher = RAGSearcher()
    return _searcher
def search_chunks(query: str, k: int = 5) -> str:
    """Search the JRXML knowledge base and return the joined context text.

    Convenience wrapper that delegates to ``search_as_context`` on the
    shared searcher returned by ``_get_searcher``.
    """
    searcher = _get_searcher()
    return searcher.search_as_context(query, k=k)
Submodule
+1
Submodule rag added at 687b3a8f90
+3
View File
@@ -17,6 +17,9 @@ lxml>=5.3.0
# 嵌入模型(本地)
sentence-transformers>=3.0.0
torch>=2.0.0
huggingface_hub>=0.19.0
tqdm>=4.65.0
# 工具类
python-dotenv>=1.0.0
+28 -93
View File
@@ -1,120 +1,55 @@
"""初始化 Chroma 知识库,加载示例 JRXML 模板和错误修正案例
"""初始化 JRXML 向量知识库
用法: python scripts/init_kb.py
rag_jrxml 子项目独立运行管线(分块→向量化→导入),本脚本仅用于预下载嵌入模型。
用法:
python scripts/init_kb.py --download-model # 预下载嵌入模型
"""
import os
import sys
import argparse
from pathlib import Path
from dotenv import load_dotenv
sys.path.insert(0, str(Path(__file__).parent.parent))
load_dotenv()
def download_embeddings_model():
"""预下载 Qwen3-Embedding 模型(从 HuggingFace)。
用法: python scripts/init_kb.py --download-model
"""
model_name = os.getenv("LOCAL_EMBED_MODEL", "Qwen/Qwen3-Embedding-0.6B")
def download_model():
"""预下载嵌入模型到本地。"""
model_name = os.getenv("RAG_EMBED_MODEL", "Qwen/Qwen3-Embedding-0.6B")
print(f"正在下载嵌入模型: {model_name}")
print("如遇网络超时,可手动执行以下命令后重试:")
print(f" huggingface-cli download {model_name} --local-dir ./models/{model_name.replace('/', '_')}")
print("如遇网络超时,可设置环境变量 HF_ENDPOINT=https://hf-mirror.com 使用镜像")
print()
try:
from langchain_huggingface import HuggingFaceEmbeddings
except ImportError:
print("错误: 请先安装 huggingface 依赖")
print(" pip install langchain-huggingface sentence-transformers")
return
from sentence_transformers import SentenceTransformer
# HuggingFaceEmbeddings 会在首次调用时自动下载模型
embeddings = HuggingFaceEmbeddings(model_name=model_name)
# 调用一次以确保完全下载
embeddings.embed_query("测试")
model = SentenceTransformer(model_name)
model.encode("测试下载")
print(f"嵌入模型下载完成: {model_name}")
from backend.embeddings import get_embeddings
def load_templates(template_dir: Path) -> list[dict]:
docs = []
for fpath in template_dir.glob('*.jrxml'):
content = fpath.read_text(encoding='utf-8')
name = fpath.stem
docs.append({
'content': content,
'metadata': {
'source': str(fpath),
'type': 'full_report',
'name': name,
},
})
return docs
def load_corrections(corrections_dir: Path) -> list[dict]:
docs = []
for fpath in corrections_dir.glob('*.jrxml'):
content = fpath.read_text(encoding='utf-8')
docs.append({
'content': content,
'metadata': {
'source': str(fpath),
'type': 'correction_case',
'name': fpath.stem,
},
})
return docs
def main():
persist_dir = os.getenv('CHROMA_PERSIST_DIR', './db/chroma')
data_dir = Path(__file__).parent.parent / 'data'
template_dir = data_dir / 'sample_templates'
corrections_dir = data_dir / 'corrections'
docs = []
if template_dir.exists():
docs.extend(load_templates(template_dir))
print(f'{template_dir} 加载了 {len(docs)} 个模板')
if corrections_dir.exists():
corr = load_corrections(corrections_dir)
docs.extend(corr)
print(f'{corrections_dir} 加载了 {len(corr)} 个修正案例')
if not docs:
print('未找到文档,无需索引。')
return
embeddings = get_embeddings()
from langchain_chroma import Chroma
texts = [d['content'] for d in docs]
metadatas = [d['metadata'] for d in docs]
Chroma.from_texts(
texts=texts,
embedding=embeddings,
metadatas=metadatas,
persist_directory=persist_dir,
parser = argparse.ArgumentParser(description="JRXML 向量知识库工具")
parser.add_argument(
"--download-model", action="store_true",
help="预下载嵌入模型到本地"
)
print(f'已将 {len(docs)} 个文档索引到 Chroma,存储位置: {persist_dir}')
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description='初始化 Chroma 知识库')
parser.add_argument('--download-model', action='store_true', help='仅下载嵌入模型到本地')
args = parser.parse_args()
if args.download_model:
download_embeddings_model()
download_model()
else:
print("用法: python scripts/init_kb.py --download-model")
print()
print("知识库构建请在 rag/ 子项目中独立运行:")
print(" cd rag")
print(" python batch_chunker.py jrxml_source")
print(" python embed_chunks.py")
print(" python import_to_chroma.py")
if __name__ == "__main__":
main()