Files
agent_jrxml/scripts/init_kb.py
T
2026-05-15 08:29:01 +08:00

121 lines
3.5 KiB
Python

"""初始化 Chroma 知识库,加载示例 JRXML 模板和错误修正案例。
用法: python scripts/init_kb.py
"""
import os
import sys
from pathlib import Path
from dotenv import load_dotenv
# Put the directory two levels above this file (the parent of scripts/) at the
# front of sys.path so the `backend` package imported further down in this
# file resolves when the script is run directly as `python scripts/init_kb.py`.
sys.path.insert(0, str(Path(__file__).parent.parent))
# Populate os.environ from a .env file via python-dotenv; the os.getenv calls
# below (LOCAL_EMBED_MODEL, CHROMA_PERSIST_DIR) can then pick values up from it.
load_dotenv()
def download_embeddings_model() -> bool:
    """Pre-download the embedding model (Qwen3-Embedding by default) from HuggingFace.

    Usage: python scripts/init_kb.py --download-model

    The model id is taken from the ``LOCAL_EMBED_MODEL`` environment variable
    (default ``Qwen/Qwen3-Embedding-0.6B``). A manual ``huggingface-cli``
    fallback command is printed first in case the download times out.

    Returns:
        True once the model has been loaded and a test query embedded;
        False if the ``langchain_huggingface`` package is not installed.
        (Previously the function returned None in both cases.)
    """
    model_name = os.getenv("LOCAL_EMBED_MODEL", "Qwen/Qwen3-Embedding-0.6B")
    print(f"正在下载嵌入模型: {model_name}")
    print("如遇网络超时,可手动执行以下命令后重试:")
    print(f" huggingface-cli download {model_name} --local-dir ./models/{model_name.replace('/', '_')}")
    print()
    try:
        from langchain_huggingface import HuggingFaceEmbeddings
    except ImportError:
        # Error diagnostics go to stderr so they are not mixed with normal
        # progress output, and the False return lets callers detect failure.
        print("错误: 请先安装 huggingface 依赖", file=sys.stderr)
        print(" pip install langchain-huggingface sentence-transformers", file=sys.stderr)
        return False
    # HuggingFaceEmbeddings fetches the model automatically on first use.
    embeddings = HuggingFaceEmbeddings(model_name=model_name)
    # Embed one query so the model is guaranteed to be fully downloaded and loadable.
    embeddings.embed_query("测试")
    print(f"嵌入模型下载完成: {model_name}")
    return True
from backend.embeddings import get_embeddings
def _load_jrxml_docs(directory: Path, doc_type: str) -> list[dict]:
    """Read every ``*.jrxml`` file directly inside *directory* as a KB document.

    Each document is ``{'content': <file text>, 'metadata': {'source': <path>,
    'type': doc_type, 'name': <file stem>}}``. Files are visited in sorted
    order because ``Path.glob`` guarantees no ordering, which would otherwise
    make the indexing order vary between runs or platforms. Matches that are
    not regular files (e.g. a directory named ``x.jrxml``) are skipped instead
    of crashing ``read_text``.
    """
    return [
        {
            'content': fpath.read_text(encoding='utf-8'),
            'metadata': {
                'source': str(fpath),
                'type': doc_type,
                'name': fpath.stem,
            },
        }
        for fpath in sorted(directory.glob('*.jrxml'))
        if fpath.is_file()
    ]


def load_templates(template_dir: Path) -> list[dict]:
    """Load the sample report templates in *template_dir* (metadata type ``full_report``)."""
    return _load_jrxml_docs(template_dir, 'full_report')


def load_corrections(corrections_dir: Path) -> list[dict]:
    """Load the error-correction cases in *corrections_dir* (metadata type ``correction_case``)."""
    return _load_jrxml_docs(corrections_dir, 'correction_case')
def main():
    """Index the sample JRXML templates and correction cases into Chroma.

    The store location comes from ``CHROMA_PERSIST_DIR`` (default
    ``./db/chroma``). Documents are read from ``data/sample_templates`` and
    ``data/corrections`` under the parent of this script's directory; a
    missing directory is simply skipped. If nothing is found, nothing is indexed.

    Every document is written under a deterministic id (``<type>:<file stem>``),
    so re-running the script refreshes existing entries instead of appending
    another copy of each document. Entries written by older runs (which used
    random ids) are not removed by this change.
    """
    persist_dir = os.getenv('CHROMA_PERSIST_DIR', './db/chroma')
    data_dir = Path(__file__).parent.parent / 'data'
    template_dir = data_dir / 'sample_templates'
    corrections_dir = data_dir / 'corrections'
    docs = []
    if template_dir.exists():
        templates = load_templates(template_dir)
        docs.extend(templates)
        print(f'{template_dir} 加载了 {len(templates)} 个模板')
    if corrections_dir.exists():
        corr = load_corrections(corrections_dir)
        docs.extend(corr)
        print(f'{corrections_dir} 加载了 {len(corr)} 个修正案例')
    if not docs:
        print('未找到文档,无需索引。')
        return
    embeddings = get_embeddings()
    from langchain_chroma import Chroma
    texts = [d['content'] for d in docs]
    metadatas = [d['metadata'] for d in docs]
    # Without explicit ids, from_texts generates a fresh uuid per document on
    # every run, so each invocation duplicated the whole knowledge base.
    # langchain_chroma writes via upsert, so a stable id overwrites the previous
    # entry. Ids are unique here: stems are unique within one directory, and the
    # type prefix separates templates from correction cases.
    ids = [f"{m['type']}:{m['name']}" for m in metadatas]
    Chroma.from_texts(
        texts=texts,
        embedding=embeddings,
        metadatas=metadatas,
        ids=ids,
        persist_directory=persist_dir,
    )
    print(f'已将 {len(docs)} 个文档索引到 Chroma,存储位置: {persist_dir}')
if __name__ == '__main__':
    import argparse

    # CLI entry point:
    #   (no flags)        build the Chroma index via main()
    #   --download-model  only pre-fetch the embedding model
    cli = argparse.ArgumentParser(description='初始化 Chroma 知识库')
    cli.add_argument(
        '--download-model',
        action='store_true',
        help='仅下载嵌入模型到本地',
    )
    options = cli.parse_args()
    entry_point = download_embeddings_model if options.download_model else main
    entry_point()