chore: 初始化JRXML RAG项目,添加基础文件

创建了完整的JRXML语义检索RAG项目,包含:
1. 新增.gitignore忽略项目生成的缓存、依赖目录和本地文件
2. 编写详细的项目README文档
3. 补充文件功能说明文档
4. 实现向量导入、向量化、查询等核心脚本
This commit is contained in:
2026-05-12 08:14:55 +08:00
parent 4f475e9e36
commit bd98486de0
6 changed files with 1030 additions and 42 deletions
+52
View File
@@ -0,0 +1,52 @@
# Python
__pycache__/
*.py[cod]
*.pyo
*.egg-info/
dist/
build/
*.egg
# 虚拟环境
.venv/
venv/
env/
# 嵌入模型(体积大,通过 down_embedding_model.py 下载)
models/
# 向量数据(体积大,通过脚本生成)
embeddings/
*.npy
*.pkl
# Chroma 向量数据库(通过 import_to_chroma.py 生成)
chroma_db/
# JRXML 源文件(通过 collect_jrxml.py 收集)
jrxml_source/
# 分块输出(通过 jrxml_banch_chunker.py 生成)
jrxml_chunker_output/
# IDE
.idea/
.vscode/
*.swp
*.swo
*~
# OS
.DS_Store
Thumbs.db
desktop.ini
# Jupyter
.ipynb_checkpoints/
# 环境变量
.env
.env.local
# 日志
*.log
+131
View File
@@ -0,0 +1,131 @@
# JRXML RAG 项目
基于 RAGRetrieval-Augmented Generation)的 JasperReports JRXML 模板智能问答系统。
## 项目简介
本项目将 JasperReports 的 JRXML 模板文件进行语义分块、向量化,并存入 Chroma 向量数据库,实现通过自然语言查询来检索和理解报表模板的结构、配置和逻辑。
## 项目结构
```
RAG-jaspersoft/
├── collect_jrxml.py # JRXML 文件收集脚本
├── jrxml_chunker.py # JRXML 语义分块核心引擎
├── jrxml_banch_chunker.py # 批量分块入口脚本
├── down_embedding_model.py # 嵌入模型下载脚本
├── embed_chunks.py # Chunk 向量化脚本
├── import_to_chroma.py # 向量导入 Chroma 数据库
├── query_chroma.py # 语义搜索查询工具
├── jrxml_source/ # JRXML 源文件目录
├── jrxml_chunker_output/ # 分块输出目录
│ ├── all_chunks.json # 所有 chunks 合并文件
│ ├── processing_stats.json # 处理统计报告
│ └── per_file/ # 按文件分类的 chunks
├── models/ # 嵌入模型存放目录
│ └── Qwen3-Embedding-4B/ # Qwen3 嵌入模型
├── embeddings/ # 向量输出目录
│ ├── embeddings.npy # 向量矩阵
│ ├── chunks.json # 原始 chunks
│ └── embeddings.pkl # 完整数据 pickle
├── chroma_db/ # Chroma 向量数据库
└── docs/ # 项目文档
└── file_guide.md # 文件功能说明
```
## 快速开始
### 环境要求
- Python 3.11+
- NVIDIA GPU(推荐,8GB+ 显存)或 CPU
- CUDA 12.1+GPU 模式)
### 安装依赖
```bash
# 安装 PyTorch (CUDA 版本)
uv pip install torch --index-url https://download.pytorch.org/whl/cu130
# 安装其他依赖
uv pip install sentence-transformers chromadb numpy tqdm
```
### 完整流程
```bash
# 1. 收集 JRXML 文件
python collect_jrxml.py
# 2. 语义分块
python jrxml_banch_chunker.py ./jrxml_source --output ./jrxml_chunker_output
# 3. 下载嵌入模型(首次运行)
python down_embedding_model.py
# 4. 向量化
python embed_chunks.py --batch_size 2
# 5. 导入 Chroma 数据库
python import_to_chroma.py
# 6. 开始查询
python query_chroma.py
```
### 快速查询
```bash
# 交互模式
python query_chroma.py
# 单次查询
python query_chroma.py "如何修改报表标题"
# 按类型过滤
python query_chroma.py "SQL查询怎么写" --filter_field query
```
## 分块类型
系统将 JRXML 模板按以下语义类型进行分块:
| 类型 | 说明 |
|------|------|
| `report_overview` | 报告整体概览,含数据源分析 |
| `datasource_config` | 数据源配置属性 |
| `query` | 数据查询(SQL/HQL/XPath 等) |
| `parameters` | 参数定义 |
| `fields` | 字段定义 |
| `sortFields` | 排序字段 |
| `filterExpression` | 过滤表达式 |
| `variables_*` | 变量定义(按重置类型分组) |
| `styles` | 样式定义 |
| `groups` | 分组定义 |
| `band_*` | 标准带(title/detail/pageHeader 等) |
| `chart` | 图表元素 |
| `crosstab` | 交叉表元素 |
| `subreport` | 子报表元素 |
| `component` | 组件元素(列表等) |
| `dataset` | 数据集定义 |
## 技术栈
- **分块引擎**: 基于 XML 解析的语义分块器
- **嵌入模型**: Qwen3-Embedding-4B(支持 FP16 半精度)
- **向量数据库**: ChromaDB(持久化模式,余弦相似度)
- **嵌入框架**: Sentence-Transformers
- **深度学习**: PyTorch + CUDA
## 性能参考
| 硬件 | 模型 | Batch Size | 速度 |
|------|------|-----------|------|
| RTX 4060 Laptop 8GB | Qwen3-Embedding-4B (FP16) | 2 | ~1.2s/chunk |
| RTX 4060 Laptop 8GB | all-MiniLM-L6-v2 | 64 | ~0.001s/chunk |
> 离线建库是一次性开销,在线查询仅需 1-2 秒。
## License
MIT
+282
View File
@@ -0,0 +1,282 @@
# 文件功能说明
本文档详细解释项目中每个 Python 脚本的功能、输入输出和使用方式。
---
## 1. collect_jrxml.py — JRXML 文件收集脚本
**功能**: 从指定的 JasperReports 模板库目录递归收集所有 `.jrxml` 文件,复制到项目的 `jrxml_source` 目录。
**输入**:
- 源目录: `C:\Users\zy187\JaspersoftWorkspace\JasperReportsSamples`(可修改)
**输出**:
- `jrxml_source/` 目录,包含所有收集到的 JRXML 文件
**使用方式**:
```bash
python collect_jrxml.py
```
**核心逻辑**:
- 使用 `os.walk()` 递归遍历源目录
- 筛选 `.jrxml` 后缀文件
- 自动处理文件名冲突(添加数字后缀)
- 使用 `shutil.copy2()` 保留文件元数据
---
## 2. jrxml_chunker.py — JRXML 语义分块核心引擎
**功能**: 将单个 JRXML 文件按语义结构拆分为多个 chunk,每个 chunk 包含人类可读描述、原始 XML 和元数据。
**输入**:
- 单个 JRXML 文件路径
**输出**:
- `JRXMLChunk` 对象列表,每个包含:
- `chunk_id`: 唯一标识
- `chunk_type`: 分块类型(如 `query`, `field`, `band_title` 等)
- `human_description`: 人类可读的结构化描述
- `raw_xml`: 原始 XML 片段
- `context`: 上下文信息(所属报表名称)
- `metadata`: 元数据字典
**核心类**:
- `JRXMLChunk`: 单个 chunk 的数据结构
- `JRXMLSemanticChunker`: 主分块器,支持多种数据源类型(SQL、HQL、XPath、JSON、CSV 等)
**分块策略**:
- 按 XML 元素类型分类(field、parameter、variable、band、chart 等)
- 提取数据源配置和查询语句
- 保留元素间的层级关系
- 为每个 chunk 生成结构化的人类可读描述
**使用方式**:
```bash
# 处理单个文件
python jrxml_chunker.py report.jrxml
# 处理整个目录
python jrxml_chunker.py ./jrxml_source/
```
---
## 3. jrxml_banch_chunker.py — 批量分块入口脚本
**功能**: 批量处理目录下所有 JRXML 文件,生成统计报告和分类输出。
**输入**:
- JRXML 文件目录(默认: `jrxml_source`
**输出**:
- `jrxml_chunker_output/all_chunks.json`: 所有 chunks 合并文件
- `jrxml_chunker_output/processing_stats.json`: 处理统计(成功/失败数、耗时、chunk 类型分布)
- `jrxml_chunker_output/per_file/`: 按原文件分类的独立 chunk 文件
**核心函数**:
- `batch_chunk_with_report()`: 批量处理目录
- `chunk_single_file_with_report()`: 处理单个文件
**使用方式**:
```bash
# 使用默认输入目录
python jrxml_banch_chunker.py
# 指定输入目录
python jrxml_banch_chunker.py ./jrxml_source
# 指定输出目录
python jrxml_banch_chunker.py ./jrxml_source --output ./my_output
```
---
## 4. down_embedding_model.py — 嵌入模型下载脚本
**功能**: 从 HuggingFace Hub 下载 Qwen3-Embedding-4B 嵌入模型到本地。
**输入**:
- HuggingFace 模型仓库: `Qwen/Qwen3-Embedding-4B`
**输出**:
- `models/Qwen3-Embedding-4B/` 目录,包含完整的模型文件
**特性**:
- 使用国内镜像加速下载(`hf-mirror.com`
- 支持断点续传
- 自动安装依赖
**使用方式**:
```bash
python down_embedding_model.py
```
---
## 5. embed_chunks.py — Chunk 向量化脚本
**功能**: 使用嵌入模型将分块后的文本转换为向量表示,支持 GPU 加速和 FP16 半精度。
**输入**:
- `jrxml_chunker_output/all_chunks.json`(默认)
**输出**:
- `embeddings/embeddings.npy`: 向量矩阵(float32
- `embeddings/chunk_id_map.json`: chunk ID 映射
- `embeddings/chunk_type_map.json`: chunk 类型映射
- `embeddings/chunks.json`: 原始 chunks 副本
- `embeddings/embeddings.pkl`: 完整数据 pickle
**核心函数**:
- `build_text_for_embedding()`: 将 chunk 转换为适合向量化的文本(拼接类型、描述、XML、元数据)
- `main()`: 主流程(加载→编码→保存→质量检查)
**特性**:
- 自动检测 CUDA/CPU
- 默认启用 FP16 半精度(节省约 50% 显存)
- 支持 HuggingFace Hub 在线模型
- 向量归一化 + NaN 检测
**使用方式**:
```bash
# 使用默认设置
python embed_chunks.py
# 指定模型和 batch size
python embed_chunks.py --model_path "sentence-transformers/all-MiniLM-L6-v2" --batch_size 64
# 使用本地 Qwen3 模型
python embed_chunks.py --batch_size 2
# 禁用 FP16
python embed_chunks.py --no_fp16 --batch_size 1
```
---
## 6. import_to_chroma.py — 向量导入 Chroma 数据库
**功能**: 将已生成的向量和 chunks 导入 Chroma 持久化向量数据库。
**输入**:
- `embeddings/embeddings.npy`: 向量矩阵
- `embeddings/chunks.json`: chunks 数据
**输出**:
- `chroma_db/`: Chroma 持久化数据库目录
- 集合名称: `jrxml_chunks`(默认)
**核心逻辑**:
- 加载向量和 chunks
- 初始化 Chroma PersistentClient
- 创建集合(余弦相似度)
- 分批导入(每批 1000 条)
- 提取元数据(chunk_type、report_name、band_name 等)
- 快速验证查询
**使用方式**:
```bash
# 使用默认设置
python import_to_chroma.py
# 指定路径
python import_to_chroma.py --embeddings_dir ./embeddings --chroma_path ./chroma_db
```
---
## 7. query_chroma.py — 语义搜索查询工具
**功能**: 通过自然语言查询 Chroma 数据库,检索相关的 JRXML chunk。
**输入**:
- 用户自然语言查询
- 可选的元数据过滤条件
**输出**:
- 相似度排序的检索结果(含 chunk 类型、报表名称、区域、内容摘要)
**核心类**:
- `JRXMLSearcher`: 搜索器,封装模型加载、向量编码和 Chroma 查询
**核心方法**:
- `search()`: 基础语义搜索
- `search_with_threshold()`: 带相似度阈值的搜索
- `format_result()`: 格式化输出结果
**两种模式**:
1. **命令行单次查询**: `python query_chroma.py "查询内容"`
2. **交互模式**: `python query_chroma.py`(支持连续查询和内联命令)
**交互模式命令**:
```
filter:<类型> 按 chunk_type 过滤(如 filter:query
t:<阈值> 设置相似度阈值 0~1(如 t:0.5)
k:<数量> 设置返回结果数(如 k:10)
```
**使用方式**:
```bash
# 交互模式
python query_chroma.py
# 单次查询
python query_chroma.py "如何修改报表标题"
# 按类型过滤
python query_chroma.py "SQL怎么写" --filter_field query
# 设置阈值和返回数量
python query_chroma.py "报表参数" --threshold 0.5 --n_results 10
```
---
## 数据流全景
```
┌─────────────────┐
│ JasperReports │ C:\Users\...\JasperReportsSamples
│ 模板库 │
└────────┬────────┘
│ collect_jrxml.py
┌─────────────────┐
│ jrxml_source/ │ 收集的 JRXML 文件
└────────┬────────┘
│ jrxml_banch_chunker.py (调用 jrxml_chunker.py)
┌──────────────────────┐
│ jrxml_chunker_output/│ all_chunks.json + per_file/
└────────┬─────────────┘
│ embed_chunks.py (使用 Qwen3-Embedding-4B)
┌─────────────────┐
│ embeddings/ │ embeddings.npy + chunks.json
└────────┬────────┘
│ import_to_chroma.py
┌─────────────────┐
│ chroma_db/ │ Chroma 向量数据库
└────────┬────────┘
│ query_chroma.py
┌─────────────────┐
│ 用户查询 │ 自然语言 → 相关 JRXML chunks
└─────────────────┘
```
## 依赖关系
```
query_chroma.py ──────► chromadb, sentence_transformers, torch
import_to_chroma.py ──► chromadb, numpy
embed_chunks.py ──────► sentence_transformers, torch, numpy
down_embedding_model.py ► huggingface_hub
jrxml_banch_chunker.py ─► jrxml_chunker.py
jrxml_chunker.py ─────► xml.etree.ElementTree (标准库)
collect_jrxml.py ─────► 标准库 (os, shutil)
```
+117 -42
View File
@@ -4,10 +4,13 @@ embed_chunks.py
支持 GPU (CUDA) 或 CPU
"""
import os, sys, json, pickle
import os
import sys
import json
import pickle
from pathlib import Path
import numpy as np
import torch
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
def build_text_for_embedding(chunk: dict) -> str:
@@ -22,13 +25,11 @@ def build_text_for_embedding(chunk: dict) -> str:
context = chunk.get('context', '')
if context:
parts.append(f"Context: {context}")
# 添加部分 XML (前500字符)
raw_xml = chunk.get('raw_xml', '')
if raw_xml:
parts.append(f"XML: {raw_xml[:500]}")
# 添加元数据
meta = chunk.get('metadata', {})
if meta:
if 'field_names' in meta:
@@ -45,9 +46,10 @@ def build_text_for_embedding(chunk: dict) -> str:
parts.append(f"QueryLang: {meta['query_language']}")
return "\n".join(parts)
def main(chunks_json_path: str, output_dir: str = "./embeddings",
model_path: str = "./models/Qwen3-Embedding-4B",
batch_size: int = 16, normalize: bool = True):
def main(chunks_json_path: str = None, output_dir: str = None,
model_path: str = None, batch_size: int = 64, normalize: bool = True,
use_fp16: bool = True):
"""
主流程:
1. 加载 chunk JSON
@@ -55,29 +57,80 @@ def main(chunks_json_path: str, output_dir: str = "./embeddings",
3. 构造文本并向量化
4. 保存向量及映射文件
"""
# --- 1. 加载 chunks ---
print(f"📄 Loading chunks from {chunks_json_path}")
project_root = Path(__file__).resolve().parent
if chunks_json_path is None:
chunks_json_path = project_root / "jrxml_chunker_output" / "all_chunks.json"
else:
chunks_json_path = Path(chunks_json_path)
if output_dir is None:
output_dir = project_root / "embeddings"
else:
output_dir = Path(output_dir)
if model_path is None:
model_path = project_root / "models" / "Qwen3-Embedding-4B"
else:
model_path = Path(model_path)
if not chunks_json_path.exists():
print(f"❌ Chunks 文件不存在: {chunks_json_path}")
print(f" 请先运行 jrxml_banch_chunker.py 生成 chunks")
return None
print(f"\n{'='*60}")
print(f"JRXML Chunks 向量化")
print(f"{'='*60}")
print(f"📄 加载 chunks: {chunks_json_path}")
with open(chunks_json_path, 'r', encoding='utf-8') as f:
chunks = json.load(f)
print(f" Total chunks: {len(chunks)}")
# --- 2. 加载模型 ---
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🧠 Loading embedding model from {model_path} on {device}")
model = SentenceTransformer(model_path, device=device)
if device == "cuda":
print(f" GPU memory allocated: {torch.cuda.memory_allocated(0)/1024**3:.2f} GB")
print(f"\n🧠 加载嵌入模型: {model_path}")
print(f" 设备: {device}")
# --- 3. 构造文本 ---
print("🛠️ Building text representations...")
# 检查是否是 HuggingFace Hub 模型(格式为 username/model_name
model_path_str = str(model_path)
# Windows PowerShell 会把 / 自动转成 \,需要还原
if "\\" in model_path_str and not os.path.exists(model_path_str):
model_path_str = model_path_str.replace("\\", "/")
is_hub_model = "/" in model_path_str and not os.path.exists(model_path_str)
# 如果是本地路径但不存在,则报错
if not is_hub_model and not os.path.exists(model_path_str):
print(f"❌ 模型目录不存在: {model_path}")
print(f" 请先下载模型到 {model_path}")
print(f" 或者使用 HuggingFace Hub 模型,例如: sentence-transformers/all-MiniLM-L6-v2")
return None
model = SentenceTransformer(model_path_str, device=device)
if device == "cuda" and use_fp16:
model = model.half()
torch.cuda.empty_cache()
mem_used = torch.cuda.memory_allocated(0) / 1024**3
total_mem = torch.cuda.get_device_properties(0).total_memory / 1024**3
print(f" FP16 已启用")
print(f" GPU: {torch.cuda.get_device_name(0)}")
print(f" GPU memory: {mem_used:.2f} GB / {total_mem:.2f} GB (FP16)")
elif device == "cuda":
print(f" GPU: {torch.cuda.get_device_name(0)}")
print(f" GPU memory: {torch.cuda.memory_allocated(0)/1024**3:.2f} GB / {torch.cuda.get_device_properties(0).total_memory/1024**3:.2f} GB")
print(f"\n🛠️ 构建文本表示...")
texts = []
chunk_ids = []
chunk_types = []
for chunk in chunks:
texts.append(build_text_for_embedding(chunk))
chunk_ids.append(chunk.get('chunk_id', -1))
chunk_types.append(chunk.get('chunk_type', 'unknown'))
# --- 4. 向量化 ---
print(f"🔢 Encoding {len(texts)} texts (batch_size={batch_size})...")
print(f"\n🔢 向量化 {len(texts)} 个文本 (batch_size={batch_size})...")
embeddings = model.encode(
texts,
batch_size=batch_size,
@@ -87,19 +140,16 @@ def main(chunks_json_path: str, output_dir: str = "./embeddings",
)
print(f" Embeddings shape: {embeddings.shape}")
# --- 5. 保存到输出目录 ---
os.makedirs(output_dir, exist_ok=True)
output_dir.mkdir(parents=True, exist_ok=True)
# 向量矩阵 (float32)
np.save(os.path.join(output_dir, "embeddings.npy"), embeddings.astype('float32'))
# chunk_id 映射
with open(os.path.join(output_dir, "chunk_id_map.json"), 'w') as f:
np.save(output_dir / "embeddings.npy", embeddings.astype('float32'))
with open(output_dir / "chunk_id_map.json", 'w', encoding='utf-8') as f:
json.dump(chunk_ids, f, ensure_ascii=False, indent=2)
# 原始 chunks 副本
with open(os.path.join(output_dir, "chunks.json"), 'w') as f:
with open(output_dir / "chunk_type_map.json", 'w', encoding='utf-8') as f:
json.dump(chunk_types, f, ensure_ascii=False, indent=2)
with open(output_dir / "chunks.json", 'w', encoding='utf-8') as f:
json.dump(chunks, f, ensure_ascii=False, indent=2)
# pickle 方便调试
with open(os.path.join(output_dir, "embeddings.pkl"), 'wb') as f:
with open(output_dir / "embeddings.pkl", 'wb') as f:
pickle.dump({
'chunks': chunks,
'embeddings': embeddings,
@@ -107,24 +157,48 @@ def main(chunks_json_path: str, output_dir: str = "./embeddings",
'normalized': normalize
}, f)
# --- 6. 质量检查 ---
nan_count = np.isnan(embeddings).sum()
print(f"\n📊 Quality check:")
print(f"\n📊 质量检查:")
print(f" NaN values: {nan_count}")
norms = np.linalg.norm(embeddings, axis=1)
print(f" Norms: min={norms.min():.4f}, max={norms.max():.4f}, mean={norms.mean():.4f}")
print(f"\n✅ Embeddings saved to {output_dir}/")
print(f" Files: embeddings.npy, chunk_id_map.json, chunks.json, embeddings.pkl")
print(f"\n✅ 向量数据已保存到: {output_dir}/")
print(f" 文件: embeddings.npy, chunk_id_map.json, chunk_type_map.json, chunks.json, embeddings.pkl")
type_counts = {}
for ct in chunk_types:
type_counts[ct] = type_counts.get(ct, 0) + 1
print(f"\n📈 Chunk 类型分布:")
for ct, count in sorted(type_counts.items(), key=lambda x: -x[1]):
print(f" {ct}: {count}")
return {
"chunks": len(chunks),
"embedding_dim": embeddings.shape[1],
"output_dir": str(output_dir)
}
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("chunks_json", help="Path to all_chunks.json")
parser.add_argument("--output_dir", "-o", default="./embeddings")
parser.add_argument("--model_path", "-m", default="./models/Qwen3-Embedding-4B")
parser.add_argument("--batch_size", "-b", type=int, default=8,
help="Batch size (lower if OOM)")
parser.add_argument("--no_normalize", action="store_true")
project_root = Path(__file__).resolve().parent
default_chunks = project_root / "jrxml_chunker_output" / "all_chunks.json"
parser = argparse.ArgumentParser(description="JRXML Chunks 向量化工具")
parser.add_argument("chunks_json", nargs="?", default=str(default_chunks),
help=f"Chunks JSON 文件路径 (默认: {default_chunks})")
parser.add_argument("--output_dir", "-o", default=None,
help="输出目录 (默认: embeddings)")
parser.add_argument("--model_path", "-m", default=None,
help="模型路径 (默认: models/Qwen3-Embedding-4B)")
parser.add_argument("--batch_size", "-b", type=int, default=64,
help="批处理大小 (默认: 64)")
parser.add_argument("--no_normalize", action="store_true",
help="不做向量归一化")
parser.add_argument("--no_fp16", action="store_true",
help="禁用 FP16 半精度(默认启用,可节省约 50%% 显存)")
args = parser.parse_args()
main(
@@ -132,5 +206,6 @@ if __name__ == "__main__":
output_dir=args.output_dir,
model_path=args.model_path,
batch_size=args.batch_size,
normalize=not args.no_normalize
normalize=not args.no_normalize,
use_fp16=not args.no_fp16
)
+179
View File
@@ -0,0 +1,179 @@
"""
import_to_chroma.py
将已生成的 chunk 向量导入 Chroma 数据库
"""
import os
import json
import sys
import time
from pathlib import Path
import numpy as np
import chromadb
from tqdm import tqdm
def main(embeddings_dir: str = None,
chroma_path: str = None,
collection_name: str = "jrxml_chunks"):
"""
从 embeddings 目录读取向量和 chunks,导入 Chroma 持久化数据库
Args:
embeddings_dir: 包含 embeddings.npy, chunks.json 的目录
chroma_path: Chroma 持久化目录
collection_name: 集合名称
"""
project_root = Path(__file__).resolve().parent
if embeddings_dir is None:
embeddings_dir = project_root / "embeddings"
else:
embeddings_dir = Path(embeddings_dir)
if chroma_path is None:
chroma_path = project_root / "chroma_db"
else:
chroma_path = Path(chroma_path)
embeddings_file = embeddings_dir / "embeddings.npy"
chunks_file = embeddings_dir / "chunks.json"
for f in [embeddings_file, chunks_file]:
if not f.exists():
print(f"❌ 缺少文件: {f}")
print(f" 请先运行 embed_chunks.py 生成向量")
return None
print(f"\n{'='*60}")
print(f"JRXML Chunks 导入 Chroma 数据库")
print(f"{'='*60}")
print(f"\n📂 加载向量和 chunks...")
embeddings = np.load(embeddings_file).astype('float32')
with open(chunks_file, 'r', encoding='utf-8') as f:
chunks = json.load(f)
if len(embeddings) != len(chunks):
print(f"❌ 数量不匹配: {len(embeddings)} vs {len(chunks)}")
return None
print(f" 向量维度: {embeddings.shape[1]}")
print(f" Chunks 数量: {len(chunks)}")
print(f"\n💾 初始化 Chroma 数据库: {chroma_path}")
chroma_path.mkdir(parents=True, exist_ok=True)
client = chromadb.PersistentClient(path=str(chroma_path))
try:
client.delete_collection(collection_name)
print(f" 已删除旧集合 '{collection_name}'")
except Exception:
pass
collection = client.create_collection(
name=collection_name,
metadata={"hnsw:space": "cosine"}
)
print(f"\n🛠️ 准备导入数据...")
ids = []
documents = []
metadatas = []
embeddings_list = []
seen_ids = {}
for i, chunk in enumerate(tqdm(chunks, desc="准备数据")):
raw_id = str(chunk.get("chunk_id", i))
if raw_id in seen_ids:
seen_ids[raw_id] += 1
chunk_id = f"{raw_id}_{seen_ids[raw_id]}"
else:
seen_ids[raw_id] = 0
chunk_id = raw_id
ids.append(chunk_id)
doc_text = chunk.get("human_description", "")
documents.append(doc_text)
meta = {}
chunk_type = chunk.get("chunk_type", "")
if chunk_type:
meta["chunk_type"] = chunk_type
context = chunk.get("context", "")
if context:
meta["context"] = context
chunk_meta = chunk.get("metadata", {})
if "report_name" in chunk_meta:
meta["report_name"] = chunk_meta["report_name"]
if "band_name" in chunk_meta:
meta["band_name"] = chunk_meta["band_name"]
if "element_kind" in chunk_meta:
meta["element_kind"] = chunk_meta["element_kind"]
if "query_language" in chunk_meta:
meta["query_language"] = chunk_meta["query_language"]
metadatas.append(meta)
embeddings_list.append(embeddings[i].tolist())
print(f"\n📥 分批导入到 Chroma (每批 1000 条)...")
import_batch_size = 1000
start_time = time.time()
for start in tqdm(range(0, len(ids), import_batch_size), desc="导入进度"):
end = min(start + import_batch_size, len(ids))
collection.add(
ids=ids[start:end],
documents=documents[start:end],
metadatas=metadatas[start:end],
embeddings=embeddings_list[start:end]
)
duration = time.time() - start_time
print(f"\n✅ 成功导入 {len(ids)} 个 chunks 到 '{collection_name}'")
print(f" 数据库路径: {chroma_path}")
print(f" 集合数量: {collection.count()}")
print(f" 导入耗时: {duration:.2f}s")
print(f"\n🔍 快速验证查询...")
results = collection.query(
query_embeddings=[embeddings_list[0]],
n_results=3,
include=["documents", "metadatas", "distances"]
)
distances = results.get('distances', [[]])
if distances and distances[0]:
print(f" Top-3 相似度距离: {[round(d, 4) for d in distances[0]]}")
first_doc = results.get('documents', [['']])[0][0]
print(f" 首位结果: {first_doc[:120]}...")
print(f"\n📊 元数据字段分布:")
all_keys = set()
for m in metadatas:
all_keys.update(m.keys())
for key in sorted(all_keys):
count = sum(1 for m in metadatas if key in m)
print(f" {key}: {count}")
return collection
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="JRXML Chunks 导入 Chroma 工具")
parser.add_argument("--embeddings_dir", "-e", default=None,
help="向量文件目录 (默认: embeddings)")
parser.add_argument("--chroma_path", "-c", default=None,
help="Chroma 数据库路径 (默认: chroma_db)")
parser.add_argument("--collection_name", "-n", default="jrxml_chunks",
help="集合名称 (默认: jrxml_chunks)")
args = parser.parse_args()
main(
embeddings_dir=args.embeddings_dir,
chroma_path=args.chroma_path,
collection_name=args.collection_name
)
+269
View File
@@ -0,0 +1,269 @@
"""
query_chroma.py
查询 Chroma 数据库,从自然语言查找相关 JRXML chunk
支持命令行单次查询和交互式连续查询
"""
import os
import sys
import time
from pathlib import Path
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
import chromadb
class JRXMLSearcher:
def __init__(self, chroma_path: str = None,
collection_name: str = "jrxml_chunks",
model_path: str = None,
use_fp16: bool = True):
project_root = Path(__file__).resolve().parent
if chroma_path is None:
chroma_path = str(project_root / "chroma_db")
if model_path is None:
model_path = str(project_root / "models" / "Qwen3-Embedding-4B")
# 处理 Hub 模型名称
model_path_str = str(model_path)
if "\\" in model_path_str and not os.path.exists(model_path_str):
model_path_str = model_path_str.replace("\\", "/")
# 加载嵌入模型
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🧠 加载模型: {model_path_str}")
print(f" 设备: {device}")
self.model = SentenceTransformer(model_path_str, device=device)
if device == "cuda" and use_fp16:
self.model = self.model.half()
torch.cuda.empty_cache()
mem = torch.cuda.memory_allocated(0) / 1024**3
total = torch.cuda.get_device_properties(0).total_memory / 1024**3
print(f" FP16 已启用, 显存: {mem:.2f} GB / {total:.2f} GB")
# 连接 Chroma
print(f"💾 连接 Chroma: {chroma_path}")
self.client = chromadb.PersistentClient(path=chroma_path)
self.collection = self.client.get_collection(collection_name)
print(f" 集合 '{collection_name}': {self.collection.count()} 条记录\n")
def search(self, query: str, n_results: int = 5, filter_meta: dict = None):
query_embedding = self.model.encode(
query,
normalize_embeddings=True,
show_progress_bar=False
).tolist()
where_filter = filter_meta if filter_meta else None
results = self.collection.query(
query_embeddings=[query_embedding],
n_results=n_results,
where=where_filter,
include=["documents", "metadatas", "distances"]
)
return results
def search_with_threshold(self, query: str, n_results: int = 5,
threshold: float = 0.3, filter_meta: dict = None):
results = self.search(query, n_results, filter_meta)
filtered = {"ids": [[]], "documents": [[]], "metadatas": [[]], "distances": [[]]}
for i, dist in enumerate(results["distances"][0]):
if dist <= threshold:
filtered["ids"][0].append(results["ids"][0][i])
filtered["documents"][0].append(results["documents"][0][i])
filtered["metadatas"][0].append(results["metadatas"][0][i])
filtered["distances"][0].append(dist)
return filtered
def format_result(self, results: dict) -> str:
lines = []
n = len(results["ids"][0])
lines.append(f"找到 {n} 条结果:")
for i, (doc_id, doc, dist, meta) in enumerate(zip(
results["ids"][0],
results["documents"][0],
results["distances"][0],
results["metadatas"][0]
)):
chunk_type = meta.get("chunk_type", "N/A")
report = meta.get("report_name", "")
band = meta.get("band_name", "")
lines.append(f"\n--- 结果 {i+1} (相似度={1-dist:.4f}, id={doc_id}) ---")
lines.append(f"类型: {chunk_type}")
if report:
lines.append(f"报表: {report}")
if band:
lines.append(f"区域: {band}")
lines.append(f"内容: {doc[:300]}")
return "\n".join(lines)
def main():
import argparse
project_root = Path(__file__).resolve().parent
parser = argparse.ArgumentParser(description="JRXML Chunks 语义搜索工具")
parser.add_argument("query", nargs="?", default="",
help="搜索关键词(不提供则进入交互模式)")
parser.add_argument("--chroma_path", "-c", default=None,
help=f"Chroma 数据库路径 (默认: chroma_db)")
parser.add_argument("--collection", "-n", default="jrxml_chunks",
help="集合名称")
parser.add_argument("--model_path", "-m", default=None,
help="嵌入模型路径")
parser.add_argument("--n_results", "-k", type=int, default=5,
help="返回结果数 (默认: 5)")
parser.add_argument("--filter_field", "-f",
help="按 chunk_type 过滤,例如: field, query, chart")
parser.add_argument("--threshold", "-t", type=float,
help="相似度阈值 (0~1, 越高越相似)")
parser.add_argument("--no_fp16", action="store_true",
help="禁用 FP16 半精度")
args = parser.parse_args()
if args.chroma_path is None:
args.chroma_path = str(project_root / "chroma_db")
if args.model_path is None:
default_model = project_root / "models" / "Qwen3-Embedding-4B"
if not default_model.exists():
args.model_path = "sentence-transformers/all-MiniLM-L6-v2"
else:
args.model_path = str(default_model)
# 检查数据库
if not os.path.exists(args.chroma_path):
print(f"❌ Chroma 数据库不存在: {args.chroma_path}")
print(f" 请先运行 import_to_chroma.py 导入数据")
return
# 初始化搜索器
try:
searcher = JRXMLSearcher(
chroma_path=args.chroma_path,
collection_name=args.collection,
model_path=args.model_path,
use_fp16=not args.no_fp16
)
except Exception as e:
print(f"❌ 初始化失败: {e}")
return
# 准备过滤条件
filter_meta = None
if args.filter_field:
filter_meta = {"chunk_type": args.filter_field}
# 单次查询模式
if args.query:
query = args.query
print(f"\n 搜索: '{query}'")
if filter_meta:
print(f" 过滤: {filter_meta}")
start = time.time()
if args.threshold is not None:
results = searcher.search_with_threshold(
query, args.n_results, args.threshold, filter_meta
)
else:
results = searcher.search(query, args.n_results, filter_meta)
elapsed = time.time() - start
print(searcher.format_result(results))
print(f"\n⏱️ 耗时: {elapsed:.2f}s")
return
# 交互模式
print(f"\n{'='*60}")
print(f"JRXML 语义搜索 - 交互模式")
print(f"{'='*60}")
print(f"可用过滤类型: report_overview, query, field, parameter,")
print(f" variable, band_*, chart, crosstab, subreport, style 等")
print(f"示例: '如何修改报表标题'")
print(f" 'filter:query SQL数据源查询'")
print(f" 't:0.5 band:title 标题区域'")
print(f"输入 'help' 查看帮助, 'exit' 退出\n")
while True:
try:
user_input = input("🔍 搜索> ").strip()
except (EOFError, KeyboardInterrupt):
print("\n👋 再见!")
break
if not user_input:
continue
if user_input.lower() in ("exit", "quit", "q"):
print("👋 再见!")
break
if user_input.lower() == "help":
print("""
特殊命令:
filter:<类型> 按 chunk_type 过滤 (如 filter:query)
t:<阈值> 设置相似度阈值 0~1 (如 t:0.5)
k:<数量> 设置返回结果数 (如 k:10)
示例:
filter:field 数据源字段有哪些
t:0.5 band:title 标题区域怎么设置
k:10 报表参数定义
""")
continue
# 解析特殊命令
query_text = user_input
cur_filter = filter_meta
cur_n = args.n_results
cur_threshold = args.threshold
parts = user_input.split()
new_parts = []
for p in parts:
if p.startswith("filter:"):
field_val = p[len("filter:"):]
cur_filter = {"chunk_type": field_val}
print(f" 📌 过滤: {cur_filter}")
elif p.startswith("t:"):
try:
cur_threshold = float(p[2:])
print(f" 📌 阈值: {cur_threshold}")
except ValueError:
pass
elif p.startswith("k:"):
try:
cur_n = int(p[2:])
cur_n = max(1, min(cur_n, 50))
print(f" 📌 返回数量: {cur_n}")
except ValueError:
pass
else:
new_parts.append(p)
query_text = " ".join(new_parts)
if not query_text:
print(" ⚠️ 请输入搜索内容")
continue
print(f"🔍 搜索: '{query_text}'")
start = time.time()
if cur_threshold is not None:
results = searcher.search_with_threshold(
query_text, cur_n, cur_threshold, cur_filter
)
else:
results = searcher.search(query_text, cur_n, cur_filter)
elapsed = time.time() - start
print(searcher.format_result(results))
print(f"⏱️ 耗时: {elapsed:.2f}s\n")
if __name__ == "__main__":
main()