diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..55190db --- /dev/null +++ b/.gitignore @@ -0,0 +1,52 @@ +# Python +__pycache__/ +*.py[cod] +*.pyo +*.egg-info/ +dist/ +build/ +*.egg + +# 虚拟环境 +.venv/ +venv/ +env/ + +# 嵌入模型(体积大,通过 down_embedding_model.py 下载) +models/ + +# 向量数据(体积大,通过脚本生成) +embeddings/ +*.npy +*.pkl + +# Chroma 向量数据库(通过 import_to_chroma.py 生成) +chroma_db/ + +# JRXML 源文件(通过 collect_jrxml.py 收集) +jrxml_source/ + +# 分块输出(通过 jrxml_banch_chunker.py 生成) +jrxml_chunker_output/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo +*~ + +# OS +.DS_Store +Thumbs.db +desktop.ini + +# Jupyter +.ipynb_checkpoints/ + +# 环境变量 +.env +.env.local + +# 日志 +*.log \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..1a7f1f3 --- /dev/null +++ b/README.md @@ -0,0 +1,131 @@ +# JRXML RAG 项目 + +基于 RAG(Retrieval-Augmented Generation)的 JasperReports JRXML 模板智能问答系统。 + +## 项目简介 + +本项目将 JasperReports 的 JRXML 模板文件进行语义分块、向量化,并存入 Chroma 向量数据库,实现通过自然语言查询来检索和理解报表模板的结构、配置和逻辑。 + +## 项目结构 + +``` +RAG-jaspersoft/ +├── collect_jrxml.py # JRXML 文件收集脚本 +├── jrxml_chunker.py # JRXML 语义分块核心引擎 +├── jrxml_banch_chunker.py # 批量分块入口脚本 +├── down_embedding_model.py # 嵌入模型下载脚本 +├── embed_chunks.py # Chunk 向量化脚本 +├── import_to_chroma.py # 向量导入 Chroma 数据库 +├── query_chroma.py # 语义搜索查询工具 +├── jrxml_source/ # JRXML 源文件目录 +├── jrxml_chunker_output/ # 分块输出目录 +│ ├── all_chunks.json # 所有 chunks 合并文件 +│ ├── processing_stats.json # 处理统计报告 +│ └── per_file/ # 按文件分类的 chunks +├── models/ # 嵌入模型存放目录 +│ └── Qwen3-Embedding-4B/ # Qwen3 嵌入模型 +├── embeddings/ # 向量输出目录 +│ ├── embeddings.npy # 向量矩阵 +│ ├── chunks.json # 原始 chunks +│ └── embeddings.pkl # 完整数据 pickle +├── chroma_db/ # Chroma 向量数据库 +└── docs/ # 项目文档 + └── file_guide.md # 文件功能说明 +``` + +## 快速开始 + +### 环境要求 + +- Python 3.11+ +- NVIDIA GPU(推荐,8GB+ 显存)或 CPU +- CUDA 12.1+(GPU 模式) + +### 安装依赖 + +```bash +# 安装 PyTorch (CUDA 版本) +uv pip install torch --index-url https://download.pytorch.org/whl/cu130 + +# 安装其他依赖 +uv pip install sentence-transformers chromadb numpy tqdm +``` + +### 完整流程 + +```bash +# 1. 收集 JRXML 文件 +python collect_jrxml.py + +# 2. 语义分块 +python jrxml_banch_chunker.py ./jrxml_source --output ./jrxml_chunker_output + +# 3. 下载嵌入模型(首次运行) +python down_embedding_model.py + +# 4. 向量化 +python embed_chunks.py --batch_size 2 + +# 5. 导入 Chroma 数据库 +python import_to_chroma.py + +# 6. 开始查询 +python query_chroma.py +``` + +### 快速查询 + +```bash +# 交互模式 +python query_chroma.py + +# 单次查询 +python query_chroma.py "如何修改报表标题" + +# 按类型过滤 +python query_chroma.py "SQL查询怎么写" --filter_field query +``` + +## 分块类型 + +系统将 JRXML 模板按以下语义类型进行分块: + +| 类型 | 说明 | +|------|------| +| `report_overview` | 报告整体概览,含数据源分析 | +| `datasource_config` | 数据源配置属性 | +| `query` | 数据查询(SQL/HQL/XPath 等) | +| `parameters` | 参数定义 | +| `fields` | 字段定义 | +| `sortFields` | 排序字段 | +| `filterExpression` | 过滤表达式 | +| `variables_*` | 变量定义(按重置类型分组) | +| `styles` | 样式定义 | +| `groups` | 分组定义 | +| `band_*` | 标准带(title/detail/pageHeader 等) | +| `chart` | 图表元素 | +| `crosstab` | 交叉表元素 | +| `subreport` | 子报表元素 | +| `component` | 组件元素(列表等) | +| `dataset` | 数据集定义 | + +## 技术栈 + +- **分块引擎**: 基于 XML 解析的语义分块器 +- **嵌入模型**: Qwen3-Embedding-4B(支持 FP16 半精度) +- **向量数据库**: ChromaDB(持久化模式,余弦相似度) +- **嵌入框架**: Sentence-Transformers +- **深度学习**: PyTorch + CUDA + +## 性能参考 + +| 硬件 | 模型 | Batch Size | 速度 | +|------|------|-----------|------| +| RTX 4060 Laptop 8GB | Qwen3-Embedding-4B (FP16) | 2 | ~1.2s/chunk | +| RTX 4060 Laptop 8GB | all-MiniLM-L6-v2 | 64 | ~0.001s/chunk | + +> 离线建库是一次性开销,在线查询仅需 1-2 秒。 + +## License + +MIT \ No newline at end of file diff --git a/docs/file_guide.md b/docs/file_guide.md new file mode 100644 index 0000000..ee85dc4 --- /dev/null +++ b/docs/file_guide.md @@ -0,0 +1,282 @@ +# 文件功能说明 + +本文档详细解释项目中每个 Python 脚本的功能、输入输出和使用方式。 + +--- + +## 1. collect_jrxml.py — JRXML 文件收集脚本 + +**功能**: 从指定的 JasperReports 模板库目录递归收集所有 `.jrxml` 文件,复制到项目的 `jrxml_source` 目录。 + +**输入**: +- 源目录: `C:\Users\zy187\JaspersoftWorkspace\JasperReportsSamples`(可修改) + +**输出**: +- `jrxml_source/` 目录,包含所有收集到的 JRXML 文件 + +**使用方式**: +```bash +python collect_jrxml.py +``` + +**核心逻辑**: +- 使用 `os.walk()` 递归遍历源目录 +- 筛选 `.jrxml` 后缀文件 +- 自动处理文件名冲突(添加数字后缀) +- 使用 `shutil.copy2()` 保留文件元数据 + +--- + +## 2. jrxml_chunker.py — JRXML 语义分块核心引擎 + +**功能**: 将单个 JRXML 文件按语义结构拆分为多个 chunk,每个 chunk 包含人类可读描述、原始 XML 和元数据。 + +**输入**: +- 单个 JRXML 文件路径 + +**输出**: +- `JRXMLChunk` 对象列表,每个包含: + - `chunk_id`: 唯一标识 + - `chunk_type`: 分块类型(如 `query`, `field`, `band_title` 等) + - `human_description`: 人类可读的结构化描述 + - `raw_xml`: 原始 XML 片段 + - `context`: 上下文信息(所属报表名称) + - `metadata`: 元数据字典 + +**核心类**: +- `JRXMLChunk`: 单个 chunk 的数据结构 +- `JRXMLSemanticChunker`: 主分块器,支持多种数据源类型(SQL、HQL、XPath、JSON、CSV 等) + +**分块策略**: +- 按 XML 元素类型分类(field、parameter、variable、band、chart 等) +- 提取数据源配置和查询语句 +- 保留元素间的层级关系 +- 为每个 chunk 生成结构化的人类可读描述 + +**使用方式**: +```bash +# 处理单个文件 +python jrxml_chunker.py report.jrxml + +# 处理整个目录 +python jrxml_chunker.py ./jrxml_source/ +``` + +--- + +## 3. jrxml_banch_chunker.py — 批量分块入口脚本 + +**功能**: 批量处理目录下所有 JRXML 文件,生成统计报告和分类输出。 + +**输入**: +- JRXML 文件目录(默认: `jrxml_source`) + +**输出**: +- `jrxml_chunker_output/all_chunks.json`: 所有 chunks 合并文件 +- `jrxml_chunker_output/processing_stats.json`: 处理统计(成功/失败数、耗时、chunk 类型分布) +- `jrxml_chunker_output/per_file/`: 按原文件分类的独立 chunk 文件 + +**核心函数**: +- `batch_chunk_with_report()`: 批量处理目录 +- `chunk_single_file_with_report()`: 处理单个文件 + +**使用方式**: +```bash +# 使用默认输入目录 +python jrxml_banch_chunker.py + +# 指定输入目录 +python jrxml_banch_chunker.py ./jrxml_source + +# 指定输出目录 +python jrxml_banch_chunker.py ./jrxml_source --output ./my_output +``` + +--- + +## 4. down_embedding_model.py — 嵌入模型下载脚本 + +**功能**: 从 HuggingFace Hub 下载 Qwen3-Embedding-4B 嵌入模型到本地。 + +**输入**: +- HuggingFace 模型仓库: `Qwen/Qwen3-Embedding-4B` + +**输出**: +- `models/Qwen3-Embedding-4B/` 目录,包含完整的模型文件 + +**特性**: +- 使用国内镜像加速下载(`hf-mirror.com`) +- 支持断点续传 +- 自动安装依赖 + +**使用方式**: +```bash +python down_embedding_model.py +``` + +--- + +## 5. embed_chunks.py — Chunk 向量化脚本 + +**功能**: 使用嵌入模型将分块后的文本转换为向量表示,支持 GPU 加速和 FP16 半精度。 + +**输入**: +- `jrxml_chunker_output/all_chunks.json`(默认) + +**输出**: +- `embeddings/embeddings.npy`: 向量矩阵(float32) +- `embeddings/chunk_id_map.json`: chunk ID 映射 +- `embeddings/chunk_type_map.json`: chunk 类型映射 +- `embeddings/chunks.json`: 原始 chunks 副本 +- `embeddings/embeddings.pkl`: 完整数据 pickle + +**核心函数**: +- `build_text_for_embedding()`: 将 chunk 转换为适合向量化的文本(拼接类型、描述、XML、元数据) +- `main()`: 主流程(加载→编码→保存→质量检查) + +**特性**: +- 自动检测 CUDA/CPU +- 默认启用 FP16 半精度(节省约 50% 显存) +- 支持 HuggingFace Hub 在线模型 +- 向量归一化 + NaN 检测 + +**使用方式**: +```bash +# 使用默认设置 +python embed_chunks.py + +# 指定模型和 batch size +python embed_chunks.py --model_path "sentence-transformers/all-MiniLM-L6-v2" --batch_size 64 + +# 使用本地 Qwen3 模型 +python embed_chunks.py --batch_size 2 + +# 禁用 FP16 +python embed_chunks.py --no_fp16 --batch_size 1 +``` + +--- + +## 6. import_to_chroma.py — 向量导入 Chroma 数据库 + +**功能**: 将已生成的向量和 chunks 导入 Chroma 持久化向量数据库。 + +**输入**: +- `embeddings/embeddings.npy`: 向量矩阵 +- `embeddings/chunks.json`: chunks 数据 + +**输出**: +- `chroma_db/`: Chroma 持久化数据库目录 +- 集合名称: `jrxml_chunks`(默认) + +**核心逻辑**: +- 加载向量和 chunks +- 初始化 Chroma PersistentClient +- 创建集合(余弦相似度) +- 分批导入(每批 1000 条) +- 提取元数据(chunk_type、report_name、band_name 等) +- 快速验证查询 + +**使用方式**: +```bash +# 使用默认设置 +python import_to_chroma.py + +# 指定路径 +python import_to_chroma.py --embeddings_dir ./embeddings --chroma_path ./chroma_db +``` + +--- + +## 7. query_chroma.py — 语义搜索查询工具 + +**功能**: 通过自然语言查询 Chroma 数据库,检索相关的 JRXML chunk。 + +**输入**: +- 用户自然语言查询 +- 可选的元数据过滤条件 + +**输出**: +- 相似度排序的检索结果(含 chunk 类型、报表名称、区域、内容摘要) + +**核心类**: +- `JRXMLSearcher`: 搜索器,封装模型加载、向量编码和 Chroma 查询 + +**核心方法**: +- `search()`: 基础语义搜索 +- `search_with_threshold()`: 带相似度阈值的搜索 +- `format_result()`: 格式化输出结果 + +**两种模式**: +1. **命令行单次查询**: `python query_chroma.py "查询内容"` +2. **交互模式**: `python query_chroma.py`(支持连续查询和内联命令) + +**交互模式命令**: +``` +filter:<类型> 按 chunk_type 过滤(如 filter:query) +t:<阈值> 设置相似度阈值 0~1(如 t:0.5) +k:<数量> 设置返回结果数(如 k:10) +``` + +**使用方式**: +```bash +# 交互模式 +python query_chroma.py + +# 单次查询 +python query_chroma.py "如何修改报表标题" + +# 按类型过滤 +python query_chroma.py "SQL怎么写" --filter_field query + +# 设置阈值和返回数量 +python query_chroma.py "报表参数" --threshold 0.5 --n_results 10 +``` + +--- + +## 数据流全景 + +``` +┌─────────────────┐ +│ JasperReports │ C:\Users\...\JasperReportsSamples +│ 模板库 │ +└────────┬────────┘ + │ collect_jrxml.py + ▼ +┌─────────────────┐ +│ jrxml_source/ │ 收集的 JRXML 文件 +└────────┬────────┘ + │ jrxml_banch_chunker.py (调用 jrxml_chunker.py) + ▼ +┌──────────────────────┐ +│ jrxml_chunker_output/│ all_chunks.json + per_file/ +└────────┬─────────────┘ + │ embed_chunks.py (使用 Qwen3-Embedding-4B) + ▼ +┌─────────────────┐ +│ embeddings/ │ embeddings.npy + chunks.json +└────────┬────────┘ + │ import_to_chroma.py + ▼ +┌─────────────────┐ +│ chroma_db/ │ Chroma 向量数据库 +└────────┬────────┘ + │ query_chroma.py + ▼ +┌─────────────────┐ +│ 用户查询 │ 自然语言 → 相关 JRXML chunks +└─────────────────┘ +``` + +## 依赖关系 + +``` +query_chroma.py ──────► chromadb, sentence_transformers, torch +import_to_chroma.py ──► chromadb, numpy +embed_chunks.py ──────► sentence_transformers, torch, numpy +down_embedding_model.py ► huggingface_hub +jrxml_banch_chunker.py ─► jrxml_chunker.py +jrxml_chunker.py ─────► xml.etree.ElementTree (标准库) +collect_jrxml.py ─────► 标准库 (os, shutil) +``` \ No newline at end of file diff --git a/embed_chunks.py b/embed_chunks.py index 0a0be8c..3ab8e2e 100644 --- a/embed_chunks.py +++ b/embed_chunks.py @@ -4,10 +4,13 @@ embed_chunks.py 支持 GPU (CUDA) 或 CPU """ -import os, sys, json, pickle +import os +import sys +import json +import pickle +from pathlib import Path import numpy as np import torch -from tqdm import tqdm from sentence_transformers import SentenceTransformer def build_text_for_embedding(chunk: dict) -> str: @@ -22,13 +25,11 @@ def build_text_for_embedding(chunk: dict) -> str: context = chunk.get('context', '') if context: parts.append(f"Context: {context}") - - # 添加部分 XML (前500字符) + raw_xml = chunk.get('raw_xml', '') if raw_xml: parts.append(f"XML: {raw_xml[:500]}") - - # 添加元数据 + meta = chunk.get('metadata', {}) if meta: if 'field_names' in meta: @@ -45,9 +46,10 @@ def build_text_for_embedding(chunk: dict) -> str: parts.append(f"QueryLang: {meta['query_language']}") return "\n".join(parts) -def main(chunks_json_path: str, output_dir: str = "./embeddings", - model_path: str = "./models/Qwen3-Embedding-4B", - batch_size: int = 16, normalize: bool = True): + +def main(chunks_json_path: str = None, output_dir: str = None, + model_path: str = None, batch_size: int = 64, normalize: bool = True, + use_fp16: bool = True): """ 主流程: 1. 加载 chunk JSON @@ -55,29 +57,80 @@ def main(chunks_json_path: str, output_dir: str = "./embeddings", 3. 构造文本并向量化 4. 保存向量及映射文件 """ - # --- 1. 加载 chunks --- - print(f"📄 Loading chunks from {chunks_json_path}") + project_root = Path(__file__).resolve().parent + + if chunks_json_path is None: + chunks_json_path = project_root / "jrxml_chunker_output" / "all_chunks.json" + else: + chunks_json_path = Path(chunks_json_path) + + if output_dir is None: + output_dir = project_root / "embeddings" + else: + output_dir = Path(output_dir) + + if model_path is None: + model_path = project_root / "models" / "Qwen3-Embedding-4B" + else: + model_path = Path(model_path) + + if not chunks_json_path.exists(): + print(f"❌ Chunks 文件不存在: {chunks_json_path}") + print(f" 请先运行 jrxml_banch_chunker.py 生成 chunks") + return None + + print(f"\n{'='*60}") + print(f"JRXML Chunks 向量化") + print(f"{'='*60}") + print(f"📄 加载 chunks: {chunks_json_path}") with open(chunks_json_path, 'r', encoding='utf-8') as f: chunks = json.load(f) print(f" Total chunks: {len(chunks)}") - # --- 2. 加载模型 --- device = "cuda" if torch.cuda.is_available() else "cpu" - print(f"🧠 Loading embedding model from {model_path} on {device}") - model = SentenceTransformer(model_path, device=device) - if device == "cuda": - print(f" GPU memory allocated: {torch.cuda.memory_allocated(0)/1024**3:.2f} GB") + print(f"\n🧠 加载嵌入模型: {model_path}") + print(f" 设备: {device}") - # --- 3. 构造文本 --- - print("🛠️ Building text representations...") + # 检查是否是 HuggingFace Hub 模型(格式为 username/model_name) + model_path_str = str(model_path) + # Windows PowerShell 会把 / 自动转成 \,需要还原 + if "\\" in model_path_str and not os.path.exists(model_path_str): + model_path_str = model_path_str.replace("\\", "/") + + is_hub_model = "/" in model_path_str and not os.path.exists(model_path_str) + + # 如果是本地路径但不存在,则报错 + if not is_hub_model and not os.path.exists(model_path_str): + print(f"❌ 模型目录不存在: {model_path}") + print(f" 请先下载模型到 {model_path}") + print(f" 或者使用 HuggingFace Hub 模型,例如: sentence-transformers/all-MiniLM-L6-v2") + return None + + model = SentenceTransformer(model_path_str, device=device) + + if device == "cuda" and use_fp16: + model = model.half() + torch.cuda.empty_cache() + mem_used = torch.cuda.memory_allocated(0) / 1024**3 + total_mem = torch.cuda.get_device_properties(0).total_memory / 1024**3 + print(f" FP16 已启用") + print(f" GPU: {torch.cuda.get_device_name(0)}") + print(f" GPU memory: {mem_used:.2f} GB / {total_mem:.2f} GB (FP16)") + elif device == "cuda": + print(f" GPU: {torch.cuda.get_device_name(0)}") + print(f" GPU memory: {torch.cuda.memory_allocated(0)/1024**3:.2f} GB / {torch.cuda.get_device_properties(0).total_memory/1024**3:.2f} GB") + + print(f"\n🛠️ 构建文本表示...") texts = [] chunk_ids = [] + chunk_types = [] + for chunk in chunks: texts.append(build_text_for_embedding(chunk)) chunk_ids.append(chunk.get('chunk_id', -1)) + chunk_types.append(chunk.get('chunk_type', 'unknown')) - # --- 4. 向量化 --- - print(f"🔢 Encoding {len(texts)} texts (batch_size={batch_size})...") + print(f"\n🔢 向量化 {len(texts)} 个文本 (batch_size={batch_size})...") embeddings = model.encode( texts, batch_size=batch_size, @@ -87,19 +140,16 @@ def main(chunks_json_path: str, output_dir: str = "./embeddings", ) print(f" Embeddings shape: {embeddings.shape}") - # --- 5. 保存到输出目录 --- - os.makedirs(output_dir, exist_ok=True) + output_dir.mkdir(parents=True, exist_ok=True) - # 向量矩阵 (float32) - np.save(os.path.join(output_dir, "embeddings.npy"), embeddings.astype('float32')) - # chunk_id 映射 - with open(os.path.join(output_dir, "chunk_id_map.json"), 'w') as f: + np.save(output_dir / "embeddings.npy", embeddings.astype('float32')) + with open(output_dir / "chunk_id_map.json", 'w', encoding='utf-8') as f: json.dump(chunk_ids, f, ensure_ascii=False, indent=2) - # 原始 chunks 副本 - with open(os.path.join(output_dir, "chunks.json"), 'w') as f: + with open(output_dir / "chunk_type_map.json", 'w', encoding='utf-8') as f: + json.dump(chunk_types, f, ensure_ascii=False, indent=2) + with open(output_dir / "chunks.json", 'w', encoding='utf-8') as f: json.dump(chunks, f, ensure_ascii=False, indent=2) - # pickle 方便调试 - with open(os.path.join(output_dir, "embeddings.pkl"), 'wb') as f: + with open(output_dir / "embeddings.pkl", 'wb') as f: pickle.dump({ 'chunks': chunks, 'embeddings': embeddings, @@ -107,24 +157,48 @@ def main(chunks_json_path: str, output_dir: str = "./embeddings", 'normalized': normalize }, f) - # --- 6. 质量检查 --- nan_count = np.isnan(embeddings).sum() - print(f"\n📊 Quality check:") + print(f"\n📊 质量检查:") print(f" NaN values: {nan_count}") norms = np.linalg.norm(embeddings, axis=1) print(f" Norms: min={norms.min():.4f}, max={norms.max():.4f}, mean={norms.mean():.4f}") - print(f"\n✅ Embeddings saved to {output_dir}/") - print(f" Files: embeddings.npy, chunk_id_map.json, chunks.json, embeddings.pkl") + + print(f"\n✅ 向量数据已保存到: {output_dir}/") + print(f" 文件: embeddings.npy, chunk_id_map.json, chunk_type_map.json, chunks.json, embeddings.pkl") + + type_counts = {} + for ct in chunk_types: + type_counts[ct] = type_counts.get(ct, 0) + 1 + print(f"\n📈 Chunk 类型分布:") + for ct, count in sorted(type_counts.items(), key=lambda x: -x[1]): + print(f" {ct}: {count}") + + return { + "chunks": len(chunks), + "embedding_dim": embeddings.shape[1], + "output_dir": str(output_dir) + } + if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser() - parser.add_argument("chunks_json", help="Path to all_chunks.json") - parser.add_argument("--output_dir", "-o", default="./embeddings") - parser.add_argument("--model_path", "-m", default="./models/Qwen3-Embedding-4B") - parser.add_argument("--batch_size", "-b", type=int, default=8, - help="Batch size (lower if OOM)") - parser.add_argument("--no_normalize", action="store_true") + project_root = Path(__file__).resolve().parent + default_chunks = project_root / "jrxml_chunker_output" / "all_chunks.json" + + parser = argparse.ArgumentParser(description="JRXML Chunks 向量化工具") + parser.add_argument("chunks_json", nargs="?", default=str(default_chunks), + help=f"Chunks JSON 文件路径 (默认: {default_chunks})") + parser.add_argument("--output_dir", "-o", default=None, + help="输出目录 (默认: embeddings)") + parser.add_argument("--model_path", "-m", default=None, + help="模型路径 (默认: models/Qwen3-Embedding-4B)") + parser.add_argument("--batch_size", "-b", type=int, default=64, + help="批处理大小 (默认: 64)") + parser.add_argument("--no_normalize", action="store_true", + help="不做向量归一化") + parser.add_argument("--no_fp16", action="store_true", + help="禁用 FP16 半精度(默认启用,可节省约 50%% 显存)") + args = parser.parse_args() main( @@ -132,5 +206,6 @@ if __name__ == "__main__": output_dir=args.output_dir, model_path=args.model_path, batch_size=args.batch_size, - normalize=not args.no_normalize + normalize=not args.no_normalize, + use_fp16=not args.no_fp16 ) \ No newline at end of file diff --git a/import_to_chroma.py b/import_to_chroma.py new file mode 100644 index 0000000..b97d384 --- /dev/null +++ b/import_to_chroma.py @@ -0,0 +1,179 @@ +""" +import_to_chroma.py +将已生成的 chunk 向量导入 Chroma 数据库 +""" + +import os +import json +import sys +import time +from pathlib import Path +import numpy as np +import chromadb +from tqdm import tqdm + + +def main(embeddings_dir: str = None, + chroma_path: str = None, + collection_name: str = "jrxml_chunks"): + """ + 从 embeddings 目录读取向量和 chunks,导入 Chroma 持久化数据库 + + Args: + embeddings_dir: 包含 embeddings.npy, chunks.json 的目录 + chroma_path: Chroma 持久化目录 + collection_name: 集合名称 + """ + project_root = Path(__file__).resolve().parent + + if embeddings_dir is None: + embeddings_dir = project_root / "embeddings" + else: + embeddings_dir = Path(embeddings_dir) + + if chroma_path is None: + chroma_path = project_root / "chroma_db" + else: + chroma_path = Path(chroma_path) + + embeddings_file = embeddings_dir / "embeddings.npy" + chunks_file = embeddings_dir / "chunks.json" + + for f in [embeddings_file, chunks_file]: + if not f.exists(): + print(f"❌ 缺少文件: {f}") + print(f" 请先运行 embed_chunks.py 生成向量") + return None + + print(f"\n{'='*60}") + print(f"JRXML Chunks 导入 Chroma 数据库") + print(f"{'='*60}") + + print(f"\n📂 加载向量和 chunks...") + embeddings = np.load(embeddings_file).astype('float32') + with open(chunks_file, 'r', encoding='utf-8') as f: + chunks = json.load(f) + + if len(embeddings) != len(chunks): + print(f"❌ 数量不匹配: {len(embeddings)} vs {len(chunks)}") + return None + + print(f" 向量维度: {embeddings.shape[1]}") + print(f" Chunks 数量: {len(chunks)}") + + print(f"\n💾 初始化 Chroma 数据库: {chroma_path}") + chroma_path.mkdir(parents=True, exist_ok=True) + client = chromadb.PersistentClient(path=str(chroma_path)) + + try: + client.delete_collection(collection_name) + print(f" 已删除旧集合 '{collection_name}'") + except Exception: + pass + + collection = client.create_collection( + name=collection_name, + metadata={"hnsw:space": "cosine"} + ) + + print(f"\n🛠️ 准备导入数据...") + ids = [] + documents = [] + metadatas = [] + embeddings_list = [] + + seen_ids = {} + for i, chunk in enumerate(tqdm(chunks, desc="准备数据")): + raw_id = str(chunk.get("chunk_id", i)) + if raw_id in seen_ids: + seen_ids[raw_id] += 1 + chunk_id = f"{raw_id}_{seen_ids[raw_id]}" + else: + seen_ids[raw_id] = 0 + chunk_id = raw_id + ids.append(chunk_id) + + doc_text = chunk.get("human_description", "") + documents.append(doc_text) + + meta = {} + chunk_type = chunk.get("chunk_type", "") + if chunk_type: + meta["chunk_type"] = chunk_type + + context = chunk.get("context", "") + if context: + meta["context"] = context + + chunk_meta = chunk.get("metadata", {}) + if "report_name" in chunk_meta: + meta["report_name"] = chunk_meta["report_name"] + if "band_name" in chunk_meta: + meta["band_name"] = chunk_meta["band_name"] + if "element_kind" in chunk_meta: + meta["element_kind"] = chunk_meta["element_kind"] + if "query_language" in chunk_meta: + meta["query_language"] = chunk_meta["query_language"] + + metadatas.append(meta) + embeddings_list.append(embeddings[i].tolist()) + + print(f"\n📥 分批导入到 Chroma (每批 1000 条)...") + import_batch_size = 1000 + start_time = time.time() + + for start in tqdm(range(0, len(ids), import_batch_size), desc="导入进度"): + end = min(start + import_batch_size, len(ids)) + collection.add( + ids=ids[start:end], + documents=documents[start:end], + metadatas=metadatas[start:end], + embeddings=embeddings_list[start:end] + ) + + duration = time.time() - start_time + print(f"\n✅ 成功导入 {len(ids)} 个 chunks 到 '{collection_name}'") + print(f" 数据库路径: {chroma_path}") + print(f" 集合数量: {collection.count()}") + print(f" 导入耗时: {duration:.2f}s") + + print(f"\n🔍 快速验证查询...") + results = collection.query( + query_embeddings=[embeddings_list[0]], + n_results=3, + include=["documents", "metadatas", "distances"] + ) + distances = results.get('distances', [[]]) + if distances and distances[0]: + print(f" Top-3 相似度距离: {[round(d, 4) for d in distances[0]]}") + first_doc = results.get('documents', [['']])[0][0] + print(f" 首位结果: {first_doc[:120]}...") + + print(f"\n📊 元数据字段分布:") + all_keys = set() + for m in metadatas: + all_keys.update(m.keys()) + for key in sorted(all_keys): + count = sum(1 for m in metadatas if key in m) + print(f" {key}: {count}") + + return collection + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description="JRXML Chunks 导入 Chroma 工具") + parser.add_argument("--embeddings_dir", "-e", default=None, + help="向量文件目录 (默认: embeddings)") + parser.add_argument("--chroma_path", "-c", default=None, + help="Chroma 数据库路径 (默认: chroma_db)") + parser.add_argument("--collection_name", "-n", default="jrxml_chunks", + help="集合名称 (默认: jrxml_chunks)") + + args = parser.parse_args() + + main( + embeddings_dir=args.embeddings_dir, + chroma_path=args.chroma_path, + collection_name=args.collection_name + ) \ No newline at end of file diff --git a/query_chroma.py b/query_chroma.py new file mode 100644 index 0000000..64d5da5 --- /dev/null +++ b/query_chroma.py @@ -0,0 +1,269 @@ +""" +query_chroma.py +查询 Chroma 数据库,从自然语言查找相关 JRXML chunk +支持命令行单次查询和交互式连续查询 +""" + +import os +import sys +import time +from pathlib import Path +import numpy as np +import torch +from sentence_transformers import SentenceTransformer +import chromadb + + +class JRXMLSearcher: + def __init__(self, chroma_path: str = None, + collection_name: str = "jrxml_chunks", + model_path: str = None, + use_fp16: bool = True): + project_root = Path(__file__).resolve().parent + + if chroma_path is None: + chroma_path = str(project_root / "chroma_db") + if model_path is None: + model_path = str(project_root / "models" / "Qwen3-Embedding-4B") + + # 处理 Hub 模型名称 + model_path_str = str(model_path) + if "\\" in model_path_str and not os.path.exists(model_path_str): + model_path_str = model_path_str.replace("\\", "/") + + # 加载嵌入模型 + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"🧠 加载模型: {model_path_str}") + print(f" 设备: {device}") + self.model = SentenceTransformer(model_path_str, device=device) + + if device == "cuda" and use_fp16: + self.model = self.model.half() + torch.cuda.empty_cache() + mem = torch.cuda.memory_allocated(0) / 1024**3 + total = torch.cuda.get_device_properties(0).total_memory / 1024**3 + print(f" FP16 已启用, 显存: {mem:.2f} GB / {total:.2f} GB") + + # 连接 Chroma + print(f"💾 连接 Chroma: {chroma_path}") + self.client = chromadb.PersistentClient(path=chroma_path) + self.collection = self.client.get_collection(collection_name) + print(f" 集合 '{collection_name}': {self.collection.count()} 条记录\n") + + def search(self, query: str, n_results: int = 5, filter_meta: dict = None): + query_embedding = self.model.encode( + query, + normalize_embeddings=True, + show_progress_bar=False + ).tolist() + + where_filter = filter_meta if filter_meta else None + results = self.collection.query( + query_embeddings=[query_embedding], + n_results=n_results, + where=where_filter, + include=["documents", "metadatas", "distances"] + ) + return results + + def search_with_threshold(self, query: str, n_results: int = 5, + threshold: float = 0.3, filter_meta: dict = None): + results = self.search(query, n_results, filter_meta) + filtered = {"ids": [[]], "documents": [[]], "metadatas": [[]], "distances": [[]]} + + for i, dist in enumerate(results["distances"][0]): + if dist <= threshold: + filtered["ids"][0].append(results["ids"][0][i]) + filtered["documents"][0].append(results["documents"][0][i]) + filtered["metadatas"][0].append(results["metadatas"][0][i]) + filtered["distances"][0].append(dist) + return filtered + + def format_result(self, results: dict) -> str: + lines = [] + n = len(results["ids"][0]) + lines.append(f"找到 {n} 条结果:") + for i, (doc_id, doc, dist, meta) in enumerate(zip( + results["ids"][0], + results["documents"][0], + results["distances"][0], + results["metadatas"][0] + )): + chunk_type = meta.get("chunk_type", "N/A") + report = meta.get("report_name", "") + band = meta.get("band_name", "") + lines.append(f"\n--- 结果 {i+1} (相似度={1-dist:.4f}, id={doc_id}) ---") + lines.append(f"类型: {chunk_type}") + if report: + lines.append(f"报表: {report}") + if band: + lines.append(f"区域: {band}") + lines.append(f"内容: {doc[:300]}") + return "\n".join(lines) + + +def main(): + import argparse + project_root = Path(__file__).resolve().parent + + parser = argparse.ArgumentParser(description="JRXML Chunks 语义搜索工具") + parser.add_argument("query", nargs="?", default="", + help="搜索关键词(不提供则进入交互模式)") + parser.add_argument("--chroma_path", "-c", default=None, + help=f"Chroma 数据库路径 (默认: chroma_db)") + parser.add_argument("--collection", "-n", default="jrxml_chunks", + help="集合名称") + parser.add_argument("--model_path", "-m", default=None, + help="嵌入模型路径") + parser.add_argument("--n_results", "-k", type=int, default=5, + help="返回结果数 (默认: 5)") + parser.add_argument("--filter_field", "-f", + help="按 chunk_type 过滤,例如: field, query, chart") + parser.add_argument("--threshold", "-t", type=float, + help="相似度阈值 (0~1, 越高越相似)") + parser.add_argument("--no_fp16", action="store_true", + help="禁用 FP16 半精度") + + args = parser.parse_args() + + if args.chroma_path is None: + args.chroma_path = str(project_root / "chroma_db") + + if args.model_path is None: + default_model = project_root / "models" / "Qwen3-Embedding-4B" + if not default_model.exists(): + args.model_path = "sentence-transformers/all-MiniLM-L6-v2" + else: + args.model_path = str(default_model) + + # 检查数据库 + if not os.path.exists(args.chroma_path): + print(f"❌ Chroma 数据库不存在: {args.chroma_path}") + print(f" 请先运行 import_to_chroma.py 导入数据") + return + + # 初始化搜索器 + try: + searcher = JRXMLSearcher( + chroma_path=args.chroma_path, + collection_name=args.collection, + model_path=args.model_path, + use_fp16=not args.no_fp16 + ) + except Exception as e: + print(f"❌ 初始化失败: {e}") + return + + # 准备过滤条件 + filter_meta = None + if args.filter_field: + filter_meta = {"chunk_type": args.filter_field} + + # 单次查询模式 + if args.query: + query = args.query + print(f"\n� 搜索: '{query}'") + if filter_meta: + print(f" 过滤: {filter_meta}") + + start = time.time() + if args.threshold is not None: + results = searcher.search_with_threshold( + query, args.n_results, args.threshold, filter_meta + ) + else: + results = searcher.search(query, args.n_results, filter_meta) + elapsed = time.time() - start + + print(searcher.format_result(results)) + print(f"\n⏱️ 耗时: {elapsed:.2f}s") + return + + # 交互模式 + print(f"\n{'='*60}") + print(f"JRXML 语义搜索 - 交互模式") + print(f"{'='*60}") + print(f"可用过滤类型: report_overview, query, field, parameter,") + print(f" variable, band_*, chart, crosstab, subreport, style 等") + print(f"示例: '如何修改报表标题'") + print(f" 'filter:query SQL数据源查询'") + print(f" 't:0.5 band:title 标题区域'") + print(f"输入 'help' 查看帮助, 'exit' 退出\n") + + while True: + try: + user_input = input("🔍 搜索> ").strip() + except (EOFError, KeyboardInterrupt): + print("\n👋 再见!") + break + + if not user_input: + continue + if user_input.lower() in ("exit", "quit", "q"): + print("👋 再见!") + break + if user_input.lower() == "help": + print(""" +特殊命令: + filter:<类型> 按 chunk_type 过滤 (如 filter:query) + t:<阈值> 设置相似度阈值 0~1 (如 t:0.5) + k:<数量> 设置返回结果数 (如 k:10) + +示例: + filter:field 数据源字段有哪些 + t:0.5 band:title 标题区域怎么设置 + k:10 报表参数定义 +""") + continue + + # 解析特殊命令 + query_text = user_input + cur_filter = filter_meta + cur_n = args.n_results + cur_threshold = args.threshold + + parts = user_input.split() + new_parts = [] + for p in parts: + if p.startswith("filter:"): + field_val = p[len("filter:"):] + cur_filter = {"chunk_type": field_val} + print(f" 📌 过滤: {cur_filter}") + elif p.startswith("t:"): + try: + cur_threshold = float(p[2:]) + print(f" 📌 阈值: {cur_threshold}") + except ValueError: + pass + elif p.startswith("k:"): + try: + cur_n = int(p[2:]) + cur_n = max(1, min(cur_n, 50)) + print(f" 📌 返回数量: {cur_n}") + except ValueError: + pass + else: + new_parts.append(p) + query_text = " ".join(new_parts) + + if not query_text: + print(" ⚠️ 请输入搜索内容") + continue + + print(f"🔍 搜索: '{query_text}'") + start = time.time() + + if cur_threshold is not None: + results = searcher.search_with_threshold( + query_text, cur_n, cur_threshold, cur_filter + ) + else: + results = searcher.search(query_text, cur_n, cur_filter) + + elapsed = time.time() - start + print(searcher.format_result(results)) + print(f"⏱️ 耗时: {elapsed:.2f}s\n") + + +if __name__ == "__main__": + main() \ No newline at end of file