Files
rag_jrxml/embed_chunks.py
T
panda bd98486de0 chore: 初始化JRXML RAG项目,添加基础文件
创建了完整的JRXML语义检索RAG项目,包含:
1. 新增.gitignore忽略项目生成的缓存、依赖目录和本地文件
2. 编写详细的项目README文档
3. 补充文件功能说明文档
4. 实现向量导入、向量化、查询等核心脚本
2026-05-12 08:14:55 +08:00

211 lines
7.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
embed_chunks.py
使用本地 Qwen3-Embedding-4B 模型对 JRXML chunks 进行向量化
支持 GPU (CUDA) 或 CPU
"""
import os
import sys
import json
import pickle
from pathlib import Path
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
def build_text_for_embedding(chunk: dict, max_xml_chars: int = 500) -> str:
    """Render one JRXML chunk as the text that is fed to the embedding model.

    The result is a newline-joined sequence of, in this order:
    a ``[ChunkType: ...]`` tag, the human description, the context (if any),
    a prefix of the raw XML (if any), and selected metadata entries.

    Args:
        chunk: Chunk dictionary. The keys read are ``chunk_type``,
            ``human_description``, ``context``, ``raw_xml`` and ``metadata``;
            all of them are optional.
        max_xml_chars: Maximum number of leading ``raw_xml`` characters to
            include. Defaults to 500.

    Returns:
        The text representation of the chunk.
    """
    description = chunk.get('human_description')
    parts = [
        f"[ChunkType: {chunk.get('chunk_type', 'unknown')}]",
        # A null description (JSON ``null``) would make "\n".join() raise,
        # so it is rendered as an empty line, same as a missing one.
        '' if description is None else str(description),
    ]

    context = chunk.get('context', '')
    if context:
        parts.append(f"Context: {context}")

    raw_xml = chunk.get('raw_xml', '')
    if raw_xml:
        parts.append(f"XML: {raw_xml[:max_xml_chars]}")

    meta = chunk.get('metadata') or {}

    # List-valued metadata: items are stringified so numeric or otherwise
    # non-string names cannot break the join; a null value is skipped.
    list_fields = (
        ('field_names', 'Fields'),
        ('parameter_names', 'Parameters'),
    )
    for key, label in list_fields:
        names = meta.get(key)
        if names is not None:
            parts.append(f"{label}: {', '.join(str(name) for name in names)}")

    # Scalar metadata is emitted whenever the key is present.
    scalar_fields = (
        ('report_name', 'Report'),
        ('band_name', 'Band'),
        ('element_kind', 'Element'),
        ('query_language', 'QueryLang'),
    )
    for key, label in scalar_fields:
        if key in meta:
            parts.append(f"{label}: {meta[key]}")

    return "\n".join(parts)
def _resolve_model_source(model_path):
    """Return the string to pass to SentenceTransformer, or None if unusable.

    An existing local path is returned unchanged. A path that does not exist
    is accepted only when it looks like a Hub id of the form ``org/name``
    (relative, exactly two non-empty segments). Anything else, such as a
    missing absolute directory, yields None so the caller can report it.
    """
    model_path = Path(model_path)
    if model_path.exists():
        return str(model_path)
    # Path() turns "/" into "\" on Windows; restore it before inspecting the id.
    candidate = str(model_path).replace("\\", "/")
    segments = candidate.split("/")
    looks_like_hub_id = (
        not model_path.is_absolute()
        and len(segments) == 2
        and all(segments)
        and segments[0] not in (".", "..", "~")
    )
    return candidate if looks_like_hub_id else None


def main(chunks_json_path: str = None, output_dir: str = None,
         model_path: str = None, batch_size: int = 64, normalize: bool = True,
         use_fp16: bool = True):
    """Embed all JRXML chunks and write the vectors plus lookup files to disk.

    Steps:
        1. Load the chunk JSON.
        2. Load the embedding model (CUDA when available, otherwise CPU).
        3. Build one text per chunk and encode the texts.
        4. Save the embeddings and mapping files, then print a short report.

    Args:
        chunks_json_path: Chunk JSON file. Defaults to
            ``jrxml_chunker_output/all_chunks.json`` next to this script.
        output_dir: Directory for the output files. Defaults to
            ``embeddings`` next to this script.
        model_path: Local model directory or a Hub id (``org/name``).
            Defaults to ``models/Qwen3-Embedding-4B`` next to this script.
        batch_size: Batch size passed to ``model.encode``.
        normalize: Passed to ``model.encode`` as ``normalize_embeddings``.
        use_fp16: Convert the model to half precision when running on CUDA.

    Returns:
        A dict with keys ``chunks`` (count), ``embedding_dim`` and
        ``output_dir``; or None when the chunk file is missing or empty, or
        the model cannot be located.
    """
    from collections import Counter

    project_root = Path(__file__).resolve().parent
    chunks_json_path = (
        project_root / "jrxml_chunker_output" / "all_chunks.json"
        if chunks_json_path is None else Path(chunks_json_path)
    )
    output_dir = (
        project_root / "embeddings"
        if output_dir is None else Path(output_dir)
    )
    model_path = (
        project_root / "models" / "Qwen3-Embedding-4B"
        if model_path is None else Path(model_path)
    )

    if not chunks_json_path.exists():
        print(f"❌ Chunks 文件不存在: {chunks_json_path}")
        print(f" 请先运行 jrxml_banch_chunker.py 生成 chunks")
        return None

    print(f"\n{'='*60}")
    print(f"JRXML Chunks 向量化")
    print(f"{'='*60}")
    print(f"📄 加载 chunks: {chunks_json_path}")
    with open(chunks_json_path, 'r', encoding='utf-8') as f:
        chunks = json.load(f)
    print(f" Total chunks: {len(chunks)}")
    if not chunks:
        # Stop before loading the model: encoding an empty list would only
        # fail later in the shape/norm reporting.
        print("❌ Chunks 文件为空,没有可向量化的内容")
        return None

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"\n🧠 加载嵌入模型: {model_path}")
    print(f" 设备: {device}")

    model_source = _resolve_model_source(model_path)
    if model_source is None:
        print(f"❌ 模型目录不存在: {model_path}")
        print(f" 请先下载模型到 {model_path}")
        print(f" 或者使用 HuggingFace Hub 模型,例如: sentence-transformers/all-MiniLM-L6-v2")
        return None

    model = SentenceTransformer(model_source, device=device)
    if device == "cuda":
        if use_fp16:
            model = model.half()
            torch.cuda.empty_cache()
            print(f" FP16 已启用")
        mem_used = torch.cuda.memory_allocated(0) / 1024**3
        total_mem = torch.cuda.get_device_properties(0).total_memory / 1024**3
        precision_suffix = " (FP16)" if use_fp16 else ""
        print(f" GPU: {torch.cuda.get_device_name(0)}")
        print(f" GPU memory: {mem_used:.2f} GB / {total_mem:.2f} GB{precision_suffix}")

    print(f"\n🛠️ 构建文本表示...")
    texts = [build_text_for_embedding(chunk) for chunk in chunks]
    chunk_ids = [chunk.get('chunk_id', -1) for chunk in chunks]
    chunk_types = [chunk.get('chunk_type', 'unknown') for chunk in chunks]

    print(f"\n🔢 向量化 {len(texts)} 个文本 (batch_size={batch_size})...")
    embeddings = model.encode(
        texts,
        batch_size=batch_size,
        show_progress_bar=True,
        normalize_embeddings=normalize,
        convert_to_numpy=True
    )
    # Convert once so the .npy file, the pickle and the norm statistics all
    # use the same float32 data regardless of the model's precision.
    embeddings = np.asarray(embeddings, dtype=np.float32)
    print(f" Embeddings shape: {embeddings.shape}")

    output_dir.mkdir(parents=True, exist_ok=True)
    np.save(output_dir / "embeddings.npy", embeddings)
    json_outputs = (
        ("chunk_id_map.json", chunk_ids),
        ("chunk_type_map.json", chunk_types),
        ("chunks.json", chunks),
    )
    for filename, payload in json_outputs:
        with open(output_dir / filename, 'w', encoding='utf-8') as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)
    with open(output_dir / "embeddings.pkl", 'wb') as f:
        pickle.dump({
            'chunks': chunks,
            'embeddings': embeddings,
            'texts': texts,
            'normalized': normalize
        }, f)

    nan_count = np.isnan(embeddings).sum()
    print(f"\n📊 质量检查:")
    print(f" NaN values: {nan_count}")
    norms = np.linalg.norm(embeddings, axis=1)
    print(f" Norms: min={norms.min():.4f}, max={norms.max():.4f}, mean={norms.mean():.4f}")
    print(f"\n✅ 向量数据已保存到: {output_dir}/")
    print(f" 文件: embeddings.npy, chunk_id_map.json, chunk_type_map.json, chunks.json, embeddings.pkl")

    print(f"\n📈 Chunk 类型分布:")
    # most_common() sorts by descending count and keeps first-seen order
    # among equal counts.
    for chunk_type, count in Counter(chunk_types).most_common():
        print(f" {chunk_type}: {count}")

    return {
        "chunks": len(chunks),
        "embedding_dim": embeddings.shape[1],
        "output_dir": str(output_dir)
    }
if __name__ == "__main__":
    # Command-line entry point: parse the options, run main(), and report
    # failure through the process exit status.
    import argparse

    project_root = Path(__file__).resolve().parent
    default_chunks = project_root / "jrxml_chunker_output" / "all_chunks.json"
    # argparse %-formats help strings, so a literal "%" in the interpolated
    # path must be doubled or rendering the help text fails.
    default_chunks_help = str(default_chunks).replace("%", "%%")

    parser = argparse.ArgumentParser(description="JRXML Chunks 向量化工具")
    parser.add_argument("chunks_json", nargs="?", default=str(default_chunks),
                        help=f"Chunks JSON 文件路径 (默认: {default_chunks_help})")
    parser.add_argument("--output_dir", "-o", default=None,
                        help="输出目录 (默认: embeddings)")
    parser.add_argument("--model_path", "-m", default=None,
                        help="模型路径 (默认: models/Qwen3-Embedding-4B)")
    parser.add_argument("--batch_size", "-b", type=int, default=64,
                        help="批处理大小 (默认: 64)")
    parser.add_argument("--no_normalize", action="store_true",
                        help="不做向量归一化")
    parser.add_argument("--no_fp16", action="store_true",
                        help="禁用 FP16 半精度(默认启用,可节省约 50%% 显存)")
    args = parser.parse_args()

    result = main(
        chunks_json_path=args.chunks_json,
        output_dir=args.output_dir,
        model_path=args.model_path,
        batch_size=args.batch_size,
        normalize=not args.no_normalize,
        use_fp16=not args.no_fp16
    )
    # main() returns None when it could not produce embeddings; surface that
    # as a non-zero exit status so calling scripts can detect the failure.
    sys.exit(0 if result is not None else 1)