"""
import_to_chroma.py

Import chunk embeddings into a persistent Chroma database.
Supports importing a mix of JRXML chunks and Markdown chunks.
"""

import os
import json
import sys
import time
from collections import Counter
from pathlib import Path
from typing import Optional, Union

import numpy as np
import chromadb
from tqdm import tqdm

from config import EMBEDDINGS_DIR, CHROMA_DB_PATH, CHROMA_COLLECTION_NAME

# Keys copied from a chunk's "metadata" dict into the Chroma metadata,
# mapped to the field name they are stored under.
_CHUNK_META_FIELDS = {
    "report_name": "report_name",
    "band_name": "band_name",
    "element_kind": "element_kind",
    "query_language": "query_language",
    # Markdown-specific metadata
    "heading": "heading",
    "heading_level": "heading_level",
    "language": "code_language",
}


def _to_meta_value(value):
    """Coerce *value* into a Chroma-compatible metadata value.

    Chroma metadata values must be str, int, float or bool. ``None`` is
    passed through so the caller can drop the field; any other value
    (e.g. a list or dict loaded from chunks.json) is stored as JSON text
    rather than letting ``collection.add`` reject the whole batch.
    """
    if value is None or isinstance(value, (str, int, float, bool)):
        return value
    return json.dumps(value, ensure_ascii=False)


def _build_metadata(chunk):
    """Build the Chroma metadata dict for one chunk, skipping empty fields."""
    meta = {}
    # chunk_type / context are only stored when non-empty.
    for key in ("chunk_type", "context"):
        value = chunk.get(key, "")
        if value:
            meta[key] = _to_meta_value(value)

    # Tolerate a missing or explicit-null "metadata" entry.
    chunk_meta = chunk.get("metadata") or {}
    for src_key, dst_key in _CHUNK_META_FIELDS.items():
        if src_key in chunk_meta:
            value = _to_meta_value(chunk_meta[src_key])
            if value is not None:
                meta[dst_key] = value
    return meta


def _assign_ids(chunks):
    """Return one unique string id per chunk, in chunk order.

    The id is ``chunk["chunk_id"]`` (or the chunk's index when the key is
    absent). Repeated ids get ``_1``, ``_2``, ... suffixes, e.g. raw ids
    ``["a", "a", "b"]`` become ``["a", "a_1", "b"]``. The scheme is
    deterministic, so incremental runs over the same chunks.json produce
    the same ids.
    """
    counts = {}
    assigned = set()
    result = []
    for i, chunk in enumerate(chunks):
        raw_id = str(chunk.get("chunk_id", i))
        if raw_id in counts:
            counts[raw_id] += 1
            candidate = f"{raw_id}_{counts[raw_id]}"
        else:
            counts[raw_id] = 0
            candidate = raw_id
        # A generated "<raw>_<n>" can clash with a chunk whose own id already
        # has that form; Chroma rejects duplicate ids, so keep bumping.
        while candidate in assigned:
            counts[raw_id] += 1
            candidate = f"{raw_id}_{counts[raw_id]}"
        assigned.add(candidate)
        result.append(candidate)
    return result


def _create_collection(client, collection_name):
    """Create *collection_name* with cosine distance for its HNSW index."""
    return client.create_collection(
        name=collection_name,
        metadata={"hnsw:space": "cosine"},
    )


def _open_collection(client, collection_name, incremental):
    """Return ``(collection, existing_ids)`` ready for import.

    In incremental mode the existing collection is reused (or created when
    missing) and the ids it already holds are returned. Otherwise any
    existing collection is dropped and an empty one is created.
    """
    if incremental:
        try:
            collection = client.get_collection(collection_name)
        except Exception:
            # The "not found" exception type differs across chromadb versions,
            # so only the lookup itself is guarded here.
            collection = _create_collection(client, collection_name)
            print(f" 增量模式: 创建新集合 '{collection_name}'")
            return collection, set()
        # Only ids are needed; don't pull documents/metadatas back.
        existing_ids = set(collection.get(include=[])["ids"])
        print(f" 增量模式: 集合 '{collection_name}' 已有 {len(existing_ids)} 条记录")
        return collection, existing_ids

    try:
        client.delete_collection(collection_name)
        print(f" 已删除旧集合 '{collection_name}'")
    except Exception:
        # Best effort: there is nothing to delete on a fresh database.
        pass
    return _create_collection(client, collection_name), set()


def _quick_verify(collection, probe_embedding):
    """Query the collection with one imported vector and print the top hits."""
    print("\n🔍 快速验证查询...")
    results = collection.query(
        query_embeddings=[probe_embedding],
        # Never ask for more neighbours than the collection holds.
        n_results=max(1, min(3, collection.count())),
        include=["documents", "metadatas", "distances"],
    )
    distances = results.get("distances") or [[]]
    if distances[0]:
        print(f" Top-3 相似度距离: {[round(d, 4) for d in distances[0]]}")
        documents = results.get("documents") or [[]]
        first_doc = (documents[0][0] if documents[0] else "") or ""
        print(f" 首位结果: {first_doc[:120]}...")


def _print_meta_distribution(metadatas):
    """Print how many imported records carry each metadata field."""
    print("\n📊 元数据字段分布:")
    key_counts = Counter(key for meta in metadatas for key in meta)
    for key in sorted(key_counts):
        print(f" {key}: {key_counts[key]}")


def main(embeddings_dir: Optional[Union[str, Path]] = None,
         chroma_path: Optional[Union[str, Path]] = None,
         collection_name: Optional[str] = None,
         incremental: bool = False):
    """Import embeddings and chunks into a persistent Chroma database.

    Args:
        embeddings_dir: Directory containing ``embeddings.npy`` and
            ``chunks.json`` (defaults to ``config.EMBEDDINGS_DIR``).
        chroma_path: Chroma persistence directory (defaults to
            ``config.CHROMA_DB_PATH``); created if missing.
        collection_name: Collection name (defaults to
            ``config.CHROMA_COLLECTION_NAME``).
        incremental: If True, keep existing data and only add chunks whose
            ids are not yet in the collection. Otherwise the collection is
            dropped and rebuilt.

    Returns:
        The Chroma collection, or ``None`` when the input files are missing
        or inconsistent.
    """
    # Path() accepts both str and Path, so config values of either type work.
    embeddings_dir = Path(EMBEDDINGS_DIR if embeddings_dir is None else embeddings_dir)
    chroma_path = Path(CHROMA_DB_PATH if chroma_path is None else chroma_path)
    if collection_name is None:
        collection_name = CHROMA_COLLECTION_NAME

    embeddings_file = embeddings_dir / "embeddings.npy"
    chunks_file = embeddings_dir / "chunks.json"
    for path in (embeddings_file, chunks_file):
        if not path.exists():
            print(f"❌ 缺少文件: {path}")
            print(" 请先运行 embed_chunks.py 生成向量")
            return None

    print(f"\n{'=' * 60}")
    print("JRXML Chunks 导入 Chroma 数据库")
    print(f"{'=' * 60}")

    print("\n📂 加载向量和 chunks...")
    embeddings = np.load(embeddings_file).astype("float32")
    with open(chunks_file, "r", encoding="utf-8") as fh:
        chunks = json.load(fh)

    # One row per chunk is required; reject 0-D/1-D arrays before using shape[1].
    if embeddings.ndim != 2:
        print(f"❌ 向量数组应为二维, 实际形状: {embeddings.shape}")
        return None
    if len(embeddings) != len(chunks):
        print(f"❌ 数量不匹配: {len(embeddings)} vs {len(chunks)}")
        return None

    print(f" 向量维度: {embeddings.shape[1]}")
    print(f" Chunks 数量: {len(chunks)}")

    print(f"\n💾 初始化 Chroma 数据库: {chroma_path}")
    chroma_path.mkdir(parents=True, exist_ok=True)
    client = chromadb.PersistentClient(path=str(chroma_path))
    collection, existing_ids = _open_collection(client, collection_name, incremental)

    print("\n🛠️ 准备导入数据...")
    ids, documents, metadatas, embeddings_list = [], [], [], []
    skipped = 0
    chunk_ids = _assign_ids(chunks)
    for i, (chunk, chunk_id) in enumerate(
        tqdm(zip(chunks, chunk_ids), total=len(chunks), desc="准备数据")
    ):
        # Incremental mode: skip chunks that were imported before
        # (existing_ids is empty otherwise).
        if chunk_id in existing_ids:
            skipped += 1
            continue
        ids.append(chunk_id)
        # Chroma documents must be strings; treat a missing/None description as "".
        documents.append(chunk.get("human_description") or "")
        metadatas.append(_build_metadata(chunk))
        embeddings_list.append(embeddings[i].tolist())

    if incremental and skipped > 0:
        print(f" 增量模式: 跳过 {skipped} 条已存在记录")

    if not ids:
        print("\n✅ 没有新数据需要导入,集合已是最新")
        print(f" 数据库路径: {chroma_path}")
        print(f" 集合数量: {collection.count()}")
        return collection

    print("\n📥 分批导入到 Chroma (每批 1000 条)...")
    import_batch_size = 1000
    start_time = time.perf_counter()
    for start in tqdm(range(0, len(ids), import_batch_size), desc="导入进度"):
        end = start + import_batch_size  # slicing clamps at the list end
        collection.add(
            ids=ids[start:end],
            documents=documents[start:end],
            metadatas=metadatas[start:end],
            embeddings=embeddings_list[start:end],
        )
    duration = time.perf_counter() - start_time

    print(f"\n✅ 成功导入 {len(ids)} 个 chunks 到 '{collection_name}'")
    print(f" 数据库路径: {chroma_path}")
    print(f" 集合数量: {collection.count()}")
    print(f" 导入耗时: {duration:.2f}s")

    _quick_verify(collection, embeddings_list[0])
    _print_meta_distribution(metadatas)
    return collection


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="JRXML Chunks 导入 Chroma 工具")
    parser.add_argument("--embeddings_dir", "-e", default=None,
                        help=f"向量文件目录 (默认: {EMBEDDINGS_DIR})")
    parser.add_argument("--chroma_path", "-c", default=None,
                        help=f"Chroma 数据库路径 (默认: {CHROMA_DB_PATH})")
    parser.add_argument("--collection_name", "-n", default=CHROMA_COLLECTION_NAME,
                        help=f"集合名称 (默认: {CHROMA_COLLECTION_NAME})")
    parser.add_argument("--incremental", "-i", action="store_true",
                        help="增量模式:只导入新增记录,不删除已有数据")
    args = parser.parse_args()

    result = main(
        embeddings_dir=args.embeddings_dir,
        chroma_path=args.chroma_path,
        collection_name=args.collection_name,
        incremental=args.incremental,
    )
    # Signal failure to the shell when main() reported missing/inconsistent input.
    sys.exit(0 if result is not None else 1)