""" embed_chunks.py 使用本地 Qwen3-Embedding-4B 模型对 JRXML chunks 进行向量化 支持 GPU (CUDA) 或 CPU """ import os import sys import json import pickle from pathlib import Path import numpy as np import torch from sentence_transformers import SentenceTransformer def build_text_for_embedding(chunk: dict) -> str: """ 将单个 chunk 转换为适合向量化的文本 拼接:类型、描述、上下文、关键元数据、部分 XML """ parts = [ f"[ChunkType: {chunk.get('chunk_type', 'unknown')}]", chunk.get('human_description', ''), ] context = chunk.get('context', '') if context: parts.append(f"Context: {context}") raw_xml = chunk.get('raw_xml', '') if raw_xml: parts.append(f"XML: {raw_xml[:500]}") meta = chunk.get('metadata', {}) if meta: if 'field_names' in meta: parts.append(f"Fields: {', '.join(meta['field_names'])}") if 'parameter_names' in meta: parts.append(f"Parameters: {', '.join(meta['parameter_names'])}") if 'report_name' in meta: parts.append(f"Report: {meta['report_name']}") if 'band_name' in meta: parts.append(f"Band: {meta['band_name']}") if 'element_kind' in meta: parts.append(f"Element: {meta['element_kind']}") if 'query_language' in meta: parts.append(f"QueryLang: {meta['query_language']}") return "\n".join(parts) def main(chunks_json_path: str = None, output_dir: str = None, model_path: str = None, batch_size: int = 64, normalize: bool = True, use_fp16: bool = True): """ 主流程: 1. 加载 chunk JSON 2. 加载嵌入模型 3. 构造文本并向量化 4. 保存向量及映射文件 """ project_root = Path(__file__).resolve().parent if chunks_json_path is None: chunks_json_path = project_root / "jrxml_chunker_output" / "all_chunks.json" else: chunks_json_path = Path(chunks_json_path) if output_dir is None: output_dir = project_root / "embeddings" else: output_dir = Path(output_dir) if model_path is None: model_path = project_root / "models" / "Qwen3-Embedding-4B" else: model_path = Path(model_path) if not chunks_json_path.exists(): print(f"❌ Chunks 文件不存在: {chunks_json_path}") print(f" 请先运行 jrxml_banch_chunker.py 生成 chunks") return None print(f"\n{'='*60}") print(f"JRXML Chunks 向量化") print(f"{'='*60}") print(f"📄 加载 chunks: {chunks_json_path}") with open(chunks_json_path, 'r', encoding='utf-8') as f: chunks = json.load(f) print(f" Total chunks: {len(chunks)}") device = "cuda" if torch.cuda.is_available() else "cpu" print(f"\n🧠 加载嵌入模型: {model_path}") print(f" 设备: {device}") # 检查是否是 HuggingFace Hub 模型(格式为 username/model_name) model_path_str = str(model_path) # Windows PowerShell 会把 / 自动转成 \,需要还原 if "\\" in model_path_str and not os.path.exists(model_path_str): model_path_str = model_path_str.replace("\\", "/") is_hub_model = "/" in model_path_str and not os.path.exists(model_path_str) # 如果是本地路径但不存在,则报错 if not is_hub_model and not os.path.exists(model_path_str): print(f"❌ 模型目录不存在: {model_path}") print(f" 请先下载模型到 {model_path}") print(f" 或者使用 HuggingFace Hub 模型,例如: sentence-transformers/all-MiniLM-L6-v2") return None model = SentenceTransformer(model_path_str, device=device) if device == "cuda" and use_fp16: model = model.half() torch.cuda.empty_cache() mem_used = torch.cuda.memory_allocated(0) / 1024**3 total_mem = torch.cuda.get_device_properties(0).total_memory / 1024**3 print(f" FP16 已启用") print(f" GPU: {torch.cuda.get_device_name(0)}") print(f" GPU memory: {mem_used:.2f} GB / {total_mem:.2f} GB (FP16)") elif device == "cuda": print(f" GPU: {torch.cuda.get_device_name(0)}") print(f" GPU memory: {torch.cuda.memory_allocated(0)/1024**3:.2f} GB / {torch.cuda.get_device_properties(0).total_memory/1024**3:.2f} GB") print(f"\n🛠️ 构建文本表示...") texts = [] chunk_ids = [] chunk_types = [] for chunk in chunks: texts.append(build_text_for_embedding(chunk)) chunk_ids.append(chunk.get('chunk_id', -1)) chunk_types.append(chunk.get('chunk_type', 'unknown')) print(f"\n🔢 向量化 {len(texts)} 个文本 (batch_size={batch_size})...") embeddings = model.encode( texts, batch_size=batch_size, show_progress_bar=True, normalize_embeddings=normalize, convert_to_numpy=True ) print(f" Embeddings shape: {embeddings.shape}") output_dir.mkdir(parents=True, exist_ok=True) np.save(output_dir / "embeddings.npy", embeddings.astype('float32')) with open(output_dir / "chunk_id_map.json", 'w', encoding='utf-8') as f: json.dump(chunk_ids, f, ensure_ascii=False, indent=2) with open(output_dir / "chunk_type_map.json", 'w', encoding='utf-8') as f: json.dump(chunk_types, f, ensure_ascii=False, indent=2) with open(output_dir / "chunks.json", 'w', encoding='utf-8') as f: json.dump(chunks, f, ensure_ascii=False, indent=2) with open(output_dir / "embeddings.pkl", 'wb') as f: pickle.dump({ 'chunks': chunks, 'embeddings': embeddings, 'texts': texts, 'normalized': normalize }, f) nan_count = np.isnan(embeddings).sum() print(f"\n📊 质量检查:") print(f" NaN values: {nan_count}") norms = np.linalg.norm(embeddings, axis=1) print(f" Norms: min={norms.min():.4f}, max={norms.max():.4f}, mean={norms.mean():.4f}") print(f"\n✅ 向量数据已保存到: {output_dir}/") print(f" 文件: embeddings.npy, chunk_id_map.json, chunk_type_map.json, chunks.json, embeddings.pkl") type_counts = {} for ct in chunk_types: type_counts[ct] = type_counts.get(ct, 0) + 1 print(f"\n📈 Chunk 类型分布:") for ct, count in sorted(type_counts.items(), key=lambda x: -x[1]): print(f" {ct}: {count}") return { "chunks": len(chunks), "embedding_dim": embeddings.shape[1], "output_dir": str(output_dir) } if __name__ == "__main__": import argparse project_root = Path(__file__).resolve().parent default_chunks = project_root / "jrxml_chunker_output" / "all_chunks.json" parser = argparse.ArgumentParser(description="JRXML Chunks 向量化工具") parser.add_argument("chunks_json", nargs="?", default=str(default_chunks), help=f"Chunks JSON 文件路径 (默认: {default_chunks})") parser.add_argument("--output_dir", "-o", default=None, help="输出目录 (默认: embeddings)") parser.add_argument("--model_path", "-m", default=None, help="模型路径 (默认: models/Qwen3-Embedding-4B)") parser.add_argument("--batch_size", "-b", type=int, default=64, help="批处理大小 (默认: 64)") parser.add_argument("--no_normalize", action="store_true", help="不做向量归一化") parser.add_argument("--no_fp16", action="store_true", help="禁用 FP16 半精度(默认启用,可节省约 50%% 显存)") args = parser.parse_args() main( chunks_json_path=args.chunks_json, output_dir=args.output_dir, model_path=args.model_path, batch_size=args.batch_size, normalize=not args.no_normalize, use_fp16=not args.no_fp16 )