"""
import_to_chroma.py

Import pre-computed chunk embeddings (embeddings.npy + chunks.json) into a
persistent Chroma database.
"""

import os
import json
import sys
import time
from collections import Counter
from pathlib import Path
from typing import Optional, Union

import numpy as np
import chromadb
from tqdm import tqdm

from config import EMBEDDINGS_DIR, CHROMA_DB_PATH, CHROMA_COLLECTION_NAME

# Keys copied from a chunk's nested "metadata" dict into the Chroma metadata.
_PASSTHROUGH_META_KEYS = ("report_name", "band_name", "element_kind", "query_language")

# Value types stored unchanged as Chroma metadata; anything else is JSON-encoded.
_SCALAR_META_TYPES = (str, int, float, bool)


def _unique_id(raw_id: str, used: set, dup_counts: dict) -> str:
    """Return an id derived from *raw_id* that is not yet in *used*, and record it.

    The first occurrence of *raw_id* is kept as-is; later occurrences get a
    ``_<n>`` suffix (n = 1, 2, ...). Every candidate is checked against all
    ids emitted so far, so a generated ``"x_1"`` cannot collide with a literal
    chunk id ``"x_1"`` (Chroma rejects duplicate ids within a collection).
    """
    candidate = raw_id
    if candidate in used:
        n = dup_counts.get(raw_id, 0)
        while candidate in used:
            n += 1
            candidate = f"{raw_id}_{n}"
        dup_counts[raw_id] = n
    used.add(candidate)
    return candidate


def _clean_meta_value(value):
    """Coerce *value* into a Chroma-compatible metadata value.

    Returns None (meaning "skip this key") for None; str/int/float/bool pass
    through; any other value (e.g. list or dict) is JSON-encoded to a string.
    """
    if value is None:
        return None
    if isinstance(value, _SCALAR_META_TYPES):
        return value
    return json.dumps(value, ensure_ascii=False)


def _build_metadata(chunk: dict) -> dict:
    """Build the metadata dict stored alongside one chunk in Chroma."""
    meta = {}
    # chunk_type / context are only stored when truthy.
    for key in ("chunk_type", "context"):
        raw = chunk.get(key)
        if raw:
            meta[key] = _clean_meta_value(raw)
    # "metadata" may be missing or explicitly null in chunks.json.
    chunk_meta = chunk.get("metadata") or {}
    for key in _PASSTHROUGH_META_KEYS:
        if key in chunk_meta:
            value = _clean_meta_value(chunk_meta[key])
            if value is not None:
                meta[key] = value
    return meta


def main(embeddings_dir: Optional[Union[str, Path]] = None,
         chroma_path: Optional[Union[str, Path]] = None,
         collection_name: Optional[str] = None,
         batch_size: int = 1000):
    """
    Read embeddings and chunks from *embeddings_dir* and import them into a
    persistent Chroma collection, replacing any existing collection of the
    same name.

    Args:
        embeddings_dir: directory containing embeddings.npy and chunks.json
            (defaults to config.EMBEDDINGS_DIR).
        chroma_path: Chroma persistence directory (defaults to config.CHROMA_DB_PATH).
        collection_name: collection name (defaults to config.CHROMA_COLLECTION_NAME).
        batch_size: number of records sent per ``collection.add`` call.

    Returns:
        The populated Chroma collection, or None if the input files are
        missing, inconsistent, or empty (in which case the database is not
        modified).

    Raises:
        ValueError: if batch_size is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # Path() accepts both str and Path, so config values of either type work.
    embeddings_dir = Path(EMBEDDINGS_DIR if embeddings_dir is None else embeddings_dir)
    chroma_path = Path(CHROMA_DB_PATH if chroma_path is None else chroma_path)
    if collection_name is None:
        collection_name = CHROMA_COLLECTION_NAME

    embeddings_file = embeddings_dir / "embeddings.npy"
    chunks_file = embeddings_dir / "chunks.json"

    for f in (embeddings_file, chunks_file):
        if not f.exists():
            print(f"❌ 缺少文件: {f}", file=sys.stderr)
            print(" 请先运行 embed_chunks.py 生成向量", file=sys.stderr)
            return None

    print(f"\n{'='*60}")
    print("JRXML Chunks 导入 Chroma 数据库")
    print(f"{'='*60}")

    print("\n📂 加载向量和 chunks...")
    embeddings = np.load(embeddings_file).astype('float32')
    with open(chunks_file, 'r', encoding='utf-8') as f:
        chunks = json.load(f)

    if embeddings.ndim != 2:
        print(f"❌ 向量矩阵维度错误: shape={embeddings.shape}", file=sys.stderr)
        return None
    if len(embeddings) != len(chunks):
        print(f"❌ 数量不匹配: {len(embeddings)} vs {len(chunks)}", file=sys.stderr)
        return None
    # Refuse empty input before touching the database, so an existing
    # collection is not deleted and replaced with nothing.
    if not chunks:
        print("❌ 没有可导入的 chunks", file=sys.stderr)
        return None

    print(f" 向量维度: {embeddings.shape[1]}")
    print(f" Chunks 数量: {len(chunks)}")

    print(f"\n💾 初始化 Chroma 数据库: {chroma_path}")
    chroma_path.mkdir(parents=True, exist_ok=True)
    client = chromadb.PersistentClient(path=str(chroma_path))

    # Best-effort removal of a previous import. The exception raised for a
    # missing collection is not the same across chromadb releases, hence the
    # broad catch; a missing collection is the expected case on first run.
    try:
        client.delete_collection(collection_name)
        print(f" 已删除旧集合 '{collection_name}'")
    except Exception:
        pass

    collection = client.create_collection(
        name=collection_name,
        metadata={"hnsw:space": "cosine"}
    )

    print("\n🛠️ 准备导入数据...")
    ids = []
    documents = []
    metadatas = []
    used_ids = set()
    dup_counts = {}
    for i, chunk in enumerate(tqdm(chunks, desc="准备数据")):
        # Fall back to the row index when a chunk has no explicit id.
        ids.append(_unique_id(str(chunk.get("chunk_id", i)), used_ids, dup_counts))
        # Chroma documents must be strings; treat a null description as empty.
        documents.append(chunk.get("human_description") or "")
        # NOTE(review): some chromadb releases reject an empty metadata dict;
        # None means "no metadata" and is accepted.
        metadatas.append(_build_metadata(chunk) or None)

    print(f"\n📥 分批导入到 Chroma (每批 {batch_size} 条)...")
    start_time = time.perf_counter()
    for start in tqdm(range(0, len(ids), batch_size), desc="导入进度"):
        end = start + batch_size  # slicing clamps at len(ids)
        collection.add(
            ids=ids[start:end],
            documents=documents[start:end],
            metadatas=metadatas[start:end],
            # Convert one batch at a time instead of materialising the whole
            # matrix as nested Python lists up front.
            embeddings=embeddings[start:end].tolist(),
        )
    duration = time.perf_counter() - start_time

    print(f"\n✅ 成功导入 {len(ids)} 个 chunks 到 '{collection_name}'")
    print(f" 数据库路径: {chroma_path}")
    print(f" 集合数量: {collection.count()}")
    print(f" 导入耗时: {duration:.2f}s")

    # Sanity check: query with the first stored vector; it should come back first.
    print("\n🔍 快速验证查询...")
    results = collection.query(
        query_embeddings=[embeddings[0].tolist()],
        n_results=min(3, len(ids)),
        include=["documents", "metadatas", "distances"]
    )
    distances = results.get('distances') or [[]]
    if distances and distances[0]:
        print(f" Top-3 相似度距离: {[round(d, 4) for d in distances[0]]}")
        first_doc = (results.get('documents') or [['']])[0][0] or ""
        print(f" 首位结果: {first_doc[:120]}...")

    print("\n📊 元数据字段分布:")
    key_counts = Counter(key for meta in metadatas if meta for key in meta)
    for key in sorted(key_counts):
        print(f" {key}: {key_counts[key]}")

    return collection


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="JRXML Chunks 导入 Chroma 工具")
    parser.add_argument("--embeddings_dir", "-e", default=None,
                        help=f"向量文件目录 (默认: {EMBEDDINGS_DIR})")
    parser.add_argument("--chroma_path", "-c", default=None,
                        help=f"Chroma 数据库路径 (默认: {CHROMA_DB_PATH})")
    parser.add_argument("--collection_name", "-n", default=None,
                        help=f"集合名称 (默认: {CHROMA_COLLECTION_NAME})")
    parser.add_argument("--batch_size", "-b", type=int, default=1000,
                        help="每批导入条数 (默认: 1000)")
    args = parser.parse_args()
    if args.batch_size <= 0:
        parser.error("--batch_size must be a positive integer")

    result = main(
        embeddings_dir=args.embeddings_dir,
        chroma_path=args.chroma_path,
        collection_name=args.collection_name,
        batch_size=args.batch_size,
    )
    # main() returns None on invalid input; report that to shells/CI.
    sys.exit(0 if result is not None else 1)