Files
rag_jrxml/import_to_chroma.py
T
panda 9d78a49625 refactor: 重构项目配置管理,统一使用.env配置
- 新增config.py统一读取.env配置,移除硬编码路径和参数
- 重构collect_jrxml.py支持命令行参数和环境变量配置源目录
- 新增.env.example示例配置文件,整理所有可配置项
- 重构down_embedding_model.py、import_to_chroma.py等所有脚本使用统一配置
- 新增Windows一键部署脚本setup.bat
- 修正jrxml_banch_chunker.py的文件名拼写错误
2026-05-12 08:29:17 +08:00

183 lines
5.8 KiB
Python

"""
import_to_chroma.py
将已生成的 chunk 向量导入 Chroma 数据库
"""
import os
import json
import sys
import time
from pathlib import Path
import numpy as np
import chromadb
from tqdm import tqdm
from config import EMBEDDINGS_DIR, CHROMA_DB_PATH, CHROMA_COLLECTION_NAME
def _unique_chunk_ids(chunks):
    """Return one Chroma id per chunk, guaranteed unique within the batch.

    The id is ``str(chunk["chunk_id"])`` or the list index when the key is
    missing/null. Repeats get a ``_<n>`` suffix; the loop keeps incrementing
    until the candidate is unused, so a suffixed id can never clash with a
    literal id that already looks like ``<raw>_<n>``.
    """
    used = set()
    suffix_counters = {}
    ids = []
    for i, chunk in enumerate(chunks):
        raw = chunk.get("chunk_id")
        raw_id = str(i if raw is None else raw)
        candidate = raw_id
        n = suffix_counters.get(raw_id, 0)
        while candidate in used:
            n += 1
            candidate = f"{raw_id}_{n}"
        suffix_counters[raw_id] = n
        used.add(candidate)
        ids.append(candidate)
    return ids


def _chunk_metadata(chunk):
    """Build the Chroma metadata dict for a single chunk.

    Copies ``chunk_type`` / ``context`` when truthy, plus a fixed set of keys
    from the chunk's nested ``metadata`` dict. Chroma accepts only scalar
    metadata values (str/int/float/bool), so ``None`` values are skipped and
    any container value is stored as a JSON string instead of being rejected.
    """
    def _to_scalar(value):
        if isinstance(value, (str, int, float, bool)):
            return value
        return json.dumps(value, ensure_ascii=False)

    meta = {}
    for key in ("chunk_type", "context"):
        value = chunk.get(key)
        if value:
            meta[key] = _to_scalar(value)
    # "metadata" may be absent or explicitly null in chunks.json.
    chunk_meta = chunk.get("metadata") or {}
    for key in ("report_name", "band_name", "element_kind", "query_language"):
        value = chunk_meta.get(key)
        if value is not None:
            meta[key] = _to_scalar(value)
    return meta


def main(embeddings_dir: str = None,
         chroma_path: str = None,
         collection_name: str = None,
         batch_size: int = 1000):
    """
    Import precomputed chunk embeddings into a persistent Chroma database.

    Reads ``embeddings.npy`` and ``chunks.json`` from ``embeddings_dir``,
    drops and recreates ``collection_name`` (cosine distance) under
    ``chroma_path``, adds every chunk in batches, then runs a sanity query
    and prints a summary of the metadata fields.

    Args:
        embeddings_dir: directory containing embeddings.npy and chunks.json
            (str or Path; defaults to config.EMBEDDINGS_DIR).
        chroma_path: Chroma persistence directory (str or Path; defaults to
            config.CHROMA_DB_PATH).
        collection_name: collection name (defaults to
            config.CHROMA_COLLECTION_NAME).
        batch_size: number of records per ``collection.add`` call.

    Returns:
        The populated Chroma collection, or ``None`` when the input files
        are missing, malformed, inconsistent, or contain no chunks. All of
        these checks run before the existing collection is touched.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    from collections import Counter

    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Normalise to Path: config values may be plain strings, and the code
    # below relies on "/" and .mkdir().
    embeddings_dir = Path(EMBEDDINGS_DIR if embeddings_dir is None else embeddings_dir)
    chroma_path = Path(CHROMA_DB_PATH if chroma_path is None else chroma_path)
    if collection_name is None:
        collection_name = CHROMA_COLLECTION_NAME

    embeddings_file = embeddings_dir / "embeddings.npy"
    chunks_file = embeddings_dir / "chunks.json"
    for f in (embeddings_file, chunks_file):
        if not f.exists():
            print(f"❌ 缺少文件: {f}")
            print(" 请先运行 embed_chunks.py 生成向量")
            return None

    print(f"\n{'=' * 60}")
    print("JRXML Chunks 导入 Chroma 数据库")
    print(f"{'=' * 60}")
    print("\n📂 加载向量和 chunks...")
    embeddings = np.load(embeddings_file).astype('float32')
    with open(chunks_file, 'r', encoding='utf-8') as f:
        chunks = json.load(f)

    # Validate everything before touching Chroma: the old collection is
    # deleted below, so bailing out afterwards would leave the DB empty.
    if not isinstance(chunks, list):
        print(f"❌ chunks.json 应为 JSON 数组, 实际类型: {type(chunks).__name__}")
        return None
    if embeddings.ndim != 2:
        print(f"❌ 向量应为二维数组, 实际形状: {embeddings.shape}")
        return None
    if len(embeddings) != len(chunks):
        print(f"❌ 数量不匹配: {len(embeddings)} vs {len(chunks)}")
        return None
    if not chunks:
        print("⚠️ 没有可导入的 chunks")
        return None

    print(f" 向量维度: {embeddings.shape[1]}")
    print(f" Chunks 数量: {len(chunks)}")

    print(f"\n💾 初始化 Chroma 数据库: {chroma_path}")
    chroma_path.mkdir(parents=True, exist_ok=True)
    client = chromadb.PersistentClient(path=str(chroma_path))
    # Best effort: the collection may not exist yet, and the exception type
    # raised for a missing collection differs between chromadb versions.
    try:
        client.delete_collection(collection_name)
        print(f" 已删除旧集合 '{collection_name}'")
    except Exception:
        pass
    collection = client.create_collection(
        name=collection_name,
        metadata={"hnsw:space": "cosine"},
    )

    print("\n🛠️ 准备导入数据...")
    ids = _unique_chunk_ids(chunks)
    documents = []
    metadatas = []
    for chunk in tqdm(chunks, desc="准备数据"):
        # Chroma documents must be strings; a null description becomes "".
        description = chunk.get("human_description")
        documents.append("" if description is None else str(description))
        metadatas.append(_chunk_metadata(chunk))

    print(f"\n📥 分批导入到 Chroma (每批 {batch_size} 条)...")
    start_time = time.time()
    for start in tqdm(range(0, len(ids), batch_size), desc="导入进度"):
        end = start + batch_size  # slicing clamps at the end of the lists
        collection.add(
            ids=ids[start:end],
            documents=documents[start:end],
            metadatas=metadatas[start:end],
            # Convert per batch rather than materialising the whole matrix
            # as nested Python lists up front.
            embeddings=embeddings[start:end].tolist(),
        )
    duration = time.time() - start_time
    print(f"\n✅ 成功导入 {len(ids)} 个 chunks 到 '{collection_name}'")
    print(f" 数据库路径: {chroma_path}")
    print(f" 集合数量: {collection.count()}")
    print(f" 导入耗时: {duration:.2f}s")

    print("\n🔍 快速验证查询...")
    results = collection.query(
        query_embeddings=[embeddings[0].tolist()],
        n_results=min(3, len(ids)),
        include=["documents", "metadatas", "distances"],
    )
    distances = results.get('distances') or [[]]
    if distances and distances[0]:
        print(f" Top-3 相似度距离: {[round(d, 4) for d in distances[0]]}")
        hit_docs = results.get('documents') or [['']]
        first_doc = (hit_docs[0][0] if hit_docs and hit_docs[0] else '') or ''
        print(f" 首位结果: {first_doc[:120]}...")

    print("\n📊 元数据字段分布:")
    key_counts = Counter(key for meta in metadatas for key in meta)
    for key in sorted(key_counts):
        print(f" {key}: {key_counts[key]}")
    return collection
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="JRXML Chunks 导入 Chroma 工具")
    # All defaults are None so that main() resolves them from config in one
    # place; the help text still shows the effective configured values.
    parser.add_argument("--embeddings_dir", "-e", default=None,
                        help=f"向量文件目录 (默认: {EMBEDDINGS_DIR})")
    parser.add_argument("--chroma_path", "-c", default=None,
                        help=f"Chroma 数据库路径 (默认: {CHROMA_DB_PATH})")
    parser.add_argument("--collection_name", "-n", default=None,
                        help=f"集合名称 (默认: {CHROMA_COLLECTION_NAME})")
    args = parser.parse_args()
    collection = main(
        embeddings_dir=args.embeddings_dir,
        chroma_path=args.chroma_path,
        collection_name=args.collection_name,
    )
    # main() signals failure by returning None; turn that into a non-zero
    # exit status so calling scripts can detect that nothing was imported.
    sys.exit(0 if collection is not None else 1)