chore: 初始化JRXML RAG项目,添加基础文件
创建了完整的JRXML语义检索RAG项目,包含: 1. 新增.gitignore忽略项目生成的缓存、依赖目录和本地文件 2. 编写详细的项目README文档 3. 补充文件功能说明文档 4. 实现向量导入、向量化、查询等核心脚本
This commit is contained in:
+52
@@ -0,0 +1,52 @@
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*.pyo
|
||||
*.egg-info/
|
||||
dist/
|
||||
build/
|
||||
*.egg
|
||||
|
||||
# 虚拟环境
|
||||
.venv/
|
||||
venv/
|
||||
env/
|
||||
|
||||
# 嵌入模型(体积大,通过 down_embedding_model.py 下载)
|
||||
models/
|
||||
|
||||
# 向量数据(体积大,通过脚本生成)
|
||||
embeddings/
|
||||
*.npy
|
||||
*.pkl
|
||||
|
||||
# Chroma 向量数据库(通过 import_to_chroma.py 生成)
|
||||
chroma_db/
|
||||
|
||||
# JRXML 源文件(通过 collect_jrxml.py 收集)
|
||||
jrxml_source/
|
||||
|
||||
# 分块输出(通过 jrxml_banch_chunker.py 生成)
|
||||
jrxml_chunker_output/
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
desktop.ini
|
||||
|
||||
# Jupyter
|
||||
.ipynb_checkpoints/
|
||||
|
||||
# 环境变量
|
||||
.env
|
||||
.env.local
|
||||
|
||||
# 日志
|
||||
*.log
|
||||
@@ -0,0 +1,131 @@
|
||||
# JRXML RAG 项目
|
||||
|
||||
基于 RAG(Retrieval-Augmented Generation)的 JasperReports JRXML 模板智能问答系统。
|
||||
|
||||
## 项目简介
|
||||
|
||||
本项目将 JasperReports 的 JRXML 模板文件进行语义分块、向量化,并存入 Chroma 向量数据库,实现通过自然语言查询来检索和理解报表模板的结构、配置和逻辑。
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
RAG-jaspersoft/
|
||||
├── collect_jrxml.py # JRXML 文件收集脚本
|
||||
├── jrxml_chunker.py # JRXML 语义分块核心引擎
|
||||
├── jrxml_banch_chunker.py # 批量分块入口脚本
|
||||
├── down_embedding_model.py # 嵌入模型下载脚本
|
||||
├── embed_chunks.py # Chunk 向量化脚本
|
||||
├── import_to_chroma.py # 向量导入 Chroma 数据库
|
||||
├── query_chroma.py # 语义搜索查询工具
|
||||
├── jrxml_source/ # JRXML 源文件目录
|
||||
├── jrxml_chunker_output/ # 分块输出目录
|
||||
│ ├── all_chunks.json # 所有 chunks 合并文件
|
||||
│ ├── processing_stats.json # 处理统计报告
|
||||
│ └── per_file/ # 按文件分类的 chunks
|
||||
├── models/ # 嵌入模型存放目录
|
||||
│ └── Qwen3-Embedding-4B/ # Qwen3 嵌入模型
|
||||
├── embeddings/ # 向量输出目录
|
||||
│ ├── embeddings.npy # 向量矩阵
|
||||
│ ├── chunks.json # 原始 chunks
|
||||
│ └── embeddings.pkl # 完整数据 pickle
|
||||
├── chroma_db/ # Chroma 向量数据库
|
||||
└── docs/ # 项目文档
|
||||
└── file_guide.md # 文件功能说明
|
||||
```
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 环境要求
|
||||
|
||||
- Python 3.11+
|
||||
- NVIDIA GPU(推荐,8GB+ 显存)或 CPU
|
||||
- CUDA 12.1+(GPU 模式)
|
||||
|
||||
### 安装依赖
|
||||
|
||||
```bash
|
||||
# 安装 PyTorch (CUDA 版本)
|
||||
uv pip install torch --index-url https://download.pytorch.org/whl/cu130
|
||||
|
||||
# 安装其他依赖
|
||||
uv pip install sentence-transformers chromadb numpy tqdm
|
||||
```
|
||||
|
||||
### 完整流程
|
||||
|
||||
```bash
|
||||
# 1. 收集 JRXML 文件
|
||||
python collect_jrxml.py
|
||||
|
||||
# 2. 语义分块
|
||||
python jrxml_banch_chunker.py ./jrxml_source --output ./jrxml_chunker_output
|
||||
|
||||
# 3. 下载嵌入模型(首次运行)
|
||||
python down_embedding_model.py
|
||||
|
||||
# 4. 向量化
|
||||
python embed_chunks.py --batch_size 2
|
||||
|
||||
# 5. 导入 Chroma 数据库
|
||||
python import_to_chroma.py
|
||||
|
||||
# 6. 开始查询
|
||||
python query_chroma.py
|
||||
```
|
||||
|
||||
### 快速查询
|
||||
|
||||
```bash
|
||||
# 交互模式
|
||||
python query_chroma.py
|
||||
|
||||
# 单次查询
|
||||
python query_chroma.py "如何修改报表标题"
|
||||
|
||||
# 按类型过滤
|
||||
python query_chroma.py "SQL查询怎么写" --filter_field query
|
||||
```
|
||||
|
||||
## 分块类型
|
||||
|
||||
系统将 JRXML 模板按以下语义类型进行分块:
|
||||
|
||||
| 类型 | 说明 |
|
||||
|------|------|
|
||||
| `report_overview` | 报告整体概览,含数据源分析 |
|
||||
| `datasource_config` | 数据源配置属性 |
|
||||
| `query` | 数据查询(SQL/HQL/XPath 等) |
|
||||
| `parameters` | 参数定义 |
|
||||
| `fields` | 字段定义 |
|
||||
| `sortFields` | 排序字段 |
|
||||
| `filterExpression` | 过滤表达式 |
|
||||
| `variables_*` | 变量定义(按重置类型分组) |
|
||||
| `styles` | 样式定义 |
|
||||
| `groups` | 分组定义 |
|
||||
| `band_*` | 标准带(title/detail/pageHeader 等) |
|
||||
| `chart` | 图表元素 |
|
||||
| `crosstab` | 交叉表元素 |
|
||||
| `subreport` | 子报表元素 |
|
||||
| `component` | 组件元素(列表等) |
|
||||
| `dataset` | 数据集定义 |
|
||||
|
||||
## 技术栈
|
||||
|
||||
- **分块引擎**: 基于 XML 解析的语义分块器
|
||||
- **嵌入模型**: Qwen3-Embedding-4B(支持 FP16 半精度)
|
||||
- **向量数据库**: ChromaDB(持久化模式,余弦相似度)
|
||||
- **嵌入框架**: Sentence-Transformers
|
||||
- **深度学习**: PyTorch + CUDA
|
||||
|
||||
## 性能参考
|
||||
|
||||
| 硬件 | 模型 | Batch Size | 速度 |
|
||||
|------|------|-----------|------|
|
||||
| RTX 4060 Laptop 8GB | Qwen3-Embedding-4B (FP16) | 2 | ~1.2s/chunk |
|
||||
| RTX 4060 Laptop 8GB | all-MiniLM-L6-v2 | 64 | ~0.001s/chunk |
|
||||
|
||||
> 离线建库是一次性开销,在线查询仅需 1-2 秒。
|
||||
|
||||
## License
|
||||
|
||||
MIT
|
||||
@@ -0,0 +1,282 @@
|
||||
# 文件功能说明
|
||||
|
||||
本文档详细解释项目中每个 Python 脚本的功能、输入输出和使用方式。
|
||||
|
||||
---
|
||||
|
||||
## 1. collect_jrxml.py — JRXML 文件收集脚本
|
||||
|
||||
**功能**: 从指定的 JasperReports 模板库目录递归收集所有 `.jrxml` 文件,复制到项目的 `jrxml_source` 目录。
|
||||
|
||||
**输入**:
|
||||
- 源目录: `C:\Users\zy187\JaspersoftWorkspace\JasperReportsSamples`(可修改)
|
||||
|
||||
**输出**:
|
||||
- `jrxml_source/` 目录,包含所有收集到的 JRXML 文件
|
||||
|
||||
**使用方式**:
|
||||
```bash
|
||||
python collect_jrxml.py
|
||||
```
|
||||
|
||||
**核心逻辑**:
|
||||
- 使用 `os.walk()` 递归遍历源目录
|
||||
- 筛选 `.jrxml` 后缀文件
|
||||
- 自动处理文件名冲突(添加数字后缀)
|
||||
- 使用 `shutil.copy2()` 保留文件元数据
|
||||
|
||||
---
|
||||
|
||||
## 2. jrxml_chunker.py — JRXML 语义分块核心引擎
|
||||
|
||||
**功能**: 将单个 JRXML 文件按语义结构拆分为多个 chunk,每个 chunk 包含人类可读描述、原始 XML 和元数据。
|
||||
|
||||
**输入**:
|
||||
- 单个 JRXML 文件路径
|
||||
|
||||
**输出**:
|
||||
- `JRXMLChunk` 对象列表,每个包含:
|
||||
- `chunk_id`: 唯一标识
|
||||
- `chunk_type`: 分块类型(如 `query`, `field`, `band_title` 等)
|
||||
- `human_description`: 人类可读的结构化描述
|
||||
- `raw_xml`: 原始 XML 片段
|
||||
- `context`: 上下文信息(所属报表名称)
|
||||
- `metadata`: 元数据字典
|
||||
|
||||
**核心类**:
|
||||
- `JRXMLChunk`: 单个 chunk 的数据结构
|
||||
- `JRXMLSemanticChunker`: 主分块器,支持多种数据源类型(SQL、HQL、XPath、JSON、CSV 等)
|
||||
|
||||
**分块策略**:
|
||||
- 按 XML 元素类型分类(field、parameter、variable、band、chart 等)
|
||||
- 提取数据源配置和查询语句
|
||||
- 保留元素间的层级关系
|
||||
- 为每个 chunk 生成结构化的人类可读描述
|
||||
|
||||
**使用方式**:
|
||||
```bash
|
||||
# 处理单个文件
|
||||
python jrxml_chunker.py report.jrxml
|
||||
|
||||
# 处理整个目录
|
||||
python jrxml_chunker.py ./jrxml_source/
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 3. jrxml_banch_chunker.py — 批量分块入口脚本
|
||||
|
||||
**功能**: 批量处理目录下所有 JRXML 文件,生成统计报告和分类输出。
|
||||
|
||||
**输入**:
|
||||
- JRXML 文件目录(默认: `jrxml_source`)
|
||||
|
||||
**输出**:
|
||||
- `jrxml_chunker_output/all_chunks.json`: 所有 chunks 合并文件
|
||||
- `jrxml_chunker_output/processing_stats.json`: 处理统计(成功/失败数、耗时、chunk 类型分布)
|
||||
- `jrxml_chunker_output/per_file/`: 按原文件分类的独立 chunk 文件
|
||||
|
||||
**核心函数**:
|
||||
- `batch_chunk_with_report()`: 批量处理目录
|
||||
- `chunk_single_file_with_report()`: 处理单个文件
|
||||
|
||||
**使用方式**:
|
||||
```bash
|
||||
# 使用默认输入目录
|
||||
python jrxml_banch_chunker.py
|
||||
|
||||
# 指定输入目录
|
||||
python jrxml_banch_chunker.py ./jrxml_source
|
||||
|
||||
# 指定输出目录
|
||||
python jrxml_banch_chunker.py ./jrxml_source --output ./my_output
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. down_embedding_model.py — 嵌入模型下载脚本
|
||||
|
||||
**功能**: 从 HuggingFace Hub 下载 Qwen3-Embedding-4B 嵌入模型到本地。
|
||||
|
||||
**输入**:
|
||||
- HuggingFace 模型仓库: `Qwen/Qwen3-Embedding-4B`
|
||||
|
||||
**输出**:
|
||||
- `models/Qwen3-Embedding-4B/` 目录,包含完整的模型文件
|
||||
|
||||
**特性**:
|
||||
- 使用国内镜像加速下载(`hf-mirror.com`)
|
||||
- 支持断点续传
|
||||
- 自动安装依赖
|
||||
|
||||
**使用方式**:
|
||||
```bash
|
||||
python down_embedding_model.py
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 5. embed_chunks.py — Chunk 向量化脚本
|
||||
|
||||
**功能**: 使用嵌入模型将分块后的文本转换为向量表示,支持 GPU 加速和 FP16 半精度。
|
||||
|
||||
**输入**:
|
||||
- `jrxml_chunker_output/all_chunks.json`(默认)
|
||||
|
||||
**输出**:
|
||||
- `embeddings/embeddings.npy`: 向量矩阵(float32)
|
||||
- `embeddings/chunk_id_map.json`: chunk ID 映射
|
||||
- `embeddings/chunk_type_map.json`: chunk 类型映射
|
||||
- `embeddings/chunks.json`: 原始 chunks 副本
|
||||
- `embeddings/embeddings.pkl`: 完整数据 pickle
|
||||
|
||||
**核心函数**:
|
||||
- `build_text_for_embedding()`: 将 chunk 转换为适合向量化的文本(拼接类型、描述、XML、元数据)
|
||||
- `main()`: 主流程(加载→编码→保存→质量检查)
|
||||
|
||||
**特性**:
|
||||
- 自动检测 CUDA/CPU
|
||||
- 默认启用 FP16 半精度(节省约 50% 显存)
|
||||
- 支持 HuggingFace Hub 在线模型
|
||||
- 向量归一化 + NaN 检测
|
||||
|
||||
**使用方式**:
|
||||
```bash
|
||||
# 使用默认设置
|
||||
python embed_chunks.py
|
||||
|
||||
# 指定模型和 batch size
|
||||
python embed_chunks.py --model_path "sentence-transformers/all-MiniLM-L6-v2" --batch_size 64
|
||||
|
||||
# 使用本地 Qwen3 模型
|
||||
python embed_chunks.py --batch_size 2
|
||||
|
||||
# 禁用 FP16
|
||||
python embed_chunks.py --no_fp16 --batch_size 1
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 6. import_to_chroma.py — 向量导入 Chroma 数据库
|
||||
|
||||
**功能**: 将已生成的向量和 chunks 导入 Chroma 持久化向量数据库。
|
||||
|
||||
**输入**:
|
||||
- `embeddings/embeddings.npy`: 向量矩阵
|
||||
- `embeddings/chunks.json`: chunks 数据
|
||||
|
||||
**输出**:
|
||||
- `chroma_db/`: Chroma 持久化数据库目录
|
||||
- 集合名称: `jrxml_chunks`(默认)
|
||||
|
||||
**核心逻辑**:
|
||||
- 加载向量和 chunks
|
||||
- 初始化 Chroma PersistentClient
|
||||
- 创建集合(余弦相似度)
|
||||
- 分批导入(每批 1000 条)
|
||||
- 提取元数据(chunk_type、report_name、band_name 等)
|
||||
- 快速验证查询
|
||||
|
||||
**使用方式**:
|
||||
```bash
|
||||
# 使用默认设置
|
||||
python import_to_chroma.py
|
||||
|
||||
# 指定路径
|
||||
python import_to_chroma.py --embeddings_dir ./embeddings --chroma_path ./chroma_db
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7. query_chroma.py — 语义搜索查询工具
|
||||
|
||||
**功能**: 通过自然语言查询 Chroma 数据库,检索相关的 JRXML chunk。
|
||||
|
||||
**输入**:
|
||||
- 用户自然语言查询
|
||||
- 可选的元数据过滤条件
|
||||
|
||||
**输出**:
|
||||
- 相似度排序的检索结果(含 chunk 类型、报表名称、区域、内容摘要)
|
||||
|
||||
**核心类**:
|
||||
- `JRXMLSearcher`: 搜索器,封装模型加载、向量编码和 Chroma 查询
|
||||
|
||||
**核心方法**:
|
||||
- `search()`: 基础语义搜索
|
||||
- `search_with_threshold()`: 带相似度阈值的搜索
|
||||
- `format_result()`: 格式化输出结果
|
||||
|
||||
**两种模式**:
|
||||
1. **命令行单次查询**: `python query_chroma.py "查询内容"`
|
||||
2. **交互模式**: `python query_chroma.py`(支持连续查询和内联命令)
|
||||
|
||||
**交互模式命令**:
|
||||
```
|
||||
filter:<类型> 按 chunk_type 过滤(如 filter:query)
|
||||
t:<阈值> 设置相似度阈值 0~1(如 t:0.5)
|
||||
k:<数量> 设置返回结果数(如 k:10)
|
||||
```
|
||||
|
||||
**使用方式**:
|
||||
```bash
|
||||
# 交互模式
|
||||
python query_chroma.py
|
||||
|
||||
# 单次查询
|
||||
python query_chroma.py "如何修改报表标题"
|
||||
|
||||
# 按类型过滤
|
||||
python query_chroma.py "SQL怎么写" --filter_field query
|
||||
|
||||
# 设置阈值和返回数量
|
||||
python query_chroma.py "报表参数" --threshold 0.5 --n_results 10
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 数据流全景
|
||||
|
||||
```
|
||||
┌─────────────────┐
|
||||
│ JasperReports │ C:\Users\...\JasperReportsSamples
|
||||
│ 模板库 │
|
||||
└────────┬────────┘
|
||||
│ collect_jrxml.py
|
||||
▼
|
||||
┌─────────────────┐
|
||||
│ jrxml_source/ │ 收集的 JRXML 文件
|
||||
└────────┬────────┘
|
||||
│ jrxml_banch_chunker.py (调用 jrxml_chunker.py)
|
||||
▼
|
||||
┌──────────────────────┐
|
||||
│ jrxml_chunker_output/│ all_chunks.json + per_file/
|
||||
└────────┬─────────────┘
|
||||
│ embed_chunks.py (使用 Qwen3-Embedding-4B)
|
||||
▼
|
||||
┌─────────────────┐
|
||||
│ embeddings/ │ embeddings.npy + chunks.json
|
||||
└────────┬────────┘
|
||||
│ import_to_chroma.py
|
||||
▼
|
||||
┌─────────────────┐
|
||||
│ chroma_db/ │ Chroma 向量数据库
|
||||
└────────┬────────┘
|
||||
│ query_chroma.py
|
||||
▼
|
||||
┌─────────────────┐
|
||||
│ 用户查询 │ 自然语言 → 相关 JRXML chunks
|
||||
└─────────────────┘
|
||||
```
|
||||
|
||||
## 依赖关系
|
||||
|
||||
```
|
||||
query_chroma.py ──────► chromadb, sentence_transformers, torch
|
||||
import_to_chroma.py ──► chromadb, numpy
|
||||
embed_chunks.py ──────► sentence_transformers, torch, numpy
|
||||
down_embedding_model.py ► huggingface_hub
|
||||
jrxml_banch_chunker.py ─► jrxml_chunker.py
|
||||
jrxml_chunker.py ─────► xml.etree.ElementTree (标准库)
|
||||
collect_jrxml.py ─────► 标准库 (os, shutil)
|
||||
```
|
||||
+115
-40
@@ -4,10 +4,13 @@ embed_chunks.py
|
||||
支持 GPU (CUDA) 或 CPU
|
||||
"""
|
||||
|
||||
import os, sys, json, pickle
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
def build_text_for_embedding(chunk: dict) -> str:
|
||||
@@ -23,12 +26,10 @@ def build_text_for_embedding(chunk: dict) -> str:
|
||||
if context:
|
||||
parts.append(f"Context: {context}")
|
||||
|
||||
# 添加部分 XML (前500字符)
|
||||
raw_xml = chunk.get('raw_xml', '')
|
||||
if raw_xml:
|
||||
parts.append(f"XML: {raw_xml[:500]}")
|
||||
|
||||
# 添加元数据
|
||||
meta = chunk.get('metadata', {})
|
||||
if meta:
|
||||
if 'field_names' in meta:
|
||||
@@ -45,9 +46,10 @@ def build_text_for_embedding(chunk: dict) -> str:
|
||||
parts.append(f"QueryLang: {meta['query_language']}")
|
||||
return "\n".join(parts)
|
||||
|
||||
def main(chunks_json_path: str, output_dir: str = "./embeddings",
|
||||
model_path: str = "./models/Qwen3-Embedding-4B",
|
||||
batch_size: int = 16, normalize: bool = True):
|
||||
|
||||
def main(chunks_json_path: str = None, output_dir: str = None,
|
||||
model_path: str = None, batch_size: int = 64, normalize: bool = True,
|
||||
use_fp16: bool = True):
|
||||
"""
|
||||
主流程:
|
||||
1. 加载 chunk JSON
|
||||
@@ -55,29 +57,80 @@ def main(chunks_json_path: str, output_dir: str = "./embeddings",
|
||||
3. 构造文本并向量化
|
||||
4. 保存向量及映射文件
|
||||
"""
|
||||
# --- 1. 加载 chunks ---
|
||||
print(f"📄 Loading chunks from {chunks_json_path}")
|
||||
project_root = Path(__file__).resolve().parent
|
||||
|
||||
if chunks_json_path is None:
|
||||
chunks_json_path = project_root / "jrxml_chunker_output" / "all_chunks.json"
|
||||
else:
|
||||
chunks_json_path = Path(chunks_json_path)
|
||||
|
||||
if output_dir is None:
|
||||
output_dir = project_root / "embeddings"
|
||||
else:
|
||||
output_dir = Path(output_dir)
|
||||
|
||||
if model_path is None:
|
||||
model_path = project_root / "models" / "Qwen3-Embedding-4B"
|
||||
else:
|
||||
model_path = Path(model_path)
|
||||
|
||||
if not chunks_json_path.exists():
|
||||
print(f"❌ Chunks 文件不存在: {chunks_json_path}")
|
||||
print(f" 请先运行 jrxml_banch_chunker.py 生成 chunks")
|
||||
return None
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"JRXML Chunks 向量化")
|
||||
print(f"{'='*60}")
|
||||
print(f"📄 加载 chunks: {chunks_json_path}")
|
||||
with open(chunks_json_path, 'r', encoding='utf-8') as f:
|
||||
chunks = json.load(f)
|
||||
print(f" Total chunks: {len(chunks)}")
|
||||
|
||||
# --- 2. 加载模型 ---
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(f"🧠 Loading embedding model from {model_path} on {device}")
|
||||
model = SentenceTransformer(model_path, device=device)
|
||||
if device == "cuda":
|
||||
print(f" GPU memory allocated: {torch.cuda.memory_allocated(0)/1024**3:.2f} GB")
|
||||
print(f"\n🧠 加载嵌入模型: {model_path}")
|
||||
print(f" 设备: {device}")
|
||||
|
||||
# --- 3. 构造文本 ---
|
||||
print("🛠️ Building text representations...")
|
||||
# 检查是否是 HuggingFace Hub 模型(格式为 username/model_name)
|
||||
model_path_str = str(model_path)
|
||||
# Windows PowerShell 会把 / 自动转成 \,需要还原
|
||||
if "\\" in model_path_str and not os.path.exists(model_path_str):
|
||||
model_path_str = model_path_str.replace("\\", "/")
|
||||
|
||||
is_hub_model = "/" in model_path_str and not os.path.exists(model_path_str)
|
||||
|
||||
# 如果是本地路径但不存在,则报错
|
||||
if not is_hub_model and not os.path.exists(model_path_str):
|
||||
print(f"❌ 模型目录不存在: {model_path}")
|
||||
print(f" 请先下载模型到 {model_path}")
|
||||
print(f" 或者使用 HuggingFace Hub 模型,例如: sentence-transformers/all-MiniLM-L6-v2")
|
||||
return None
|
||||
|
||||
model = SentenceTransformer(model_path_str, device=device)
|
||||
|
||||
if device == "cuda" and use_fp16:
|
||||
model = model.half()
|
||||
torch.cuda.empty_cache()
|
||||
mem_used = torch.cuda.memory_allocated(0) / 1024**3
|
||||
total_mem = torch.cuda.get_device_properties(0).total_memory / 1024**3
|
||||
print(f" FP16 已启用")
|
||||
print(f" GPU: {torch.cuda.get_device_name(0)}")
|
||||
print(f" GPU memory: {mem_used:.2f} GB / {total_mem:.2f} GB (FP16)")
|
||||
elif device == "cuda":
|
||||
print(f" GPU: {torch.cuda.get_device_name(0)}")
|
||||
print(f" GPU memory: {torch.cuda.memory_allocated(0)/1024**3:.2f} GB / {torch.cuda.get_device_properties(0).total_memory/1024**3:.2f} GB")
|
||||
|
||||
print(f"\n🛠️ 构建文本表示...")
|
||||
texts = []
|
||||
chunk_ids = []
|
||||
chunk_types = []
|
||||
|
||||
for chunk in chunks:
|
||||
texts.append(build_text_for_embedding(chunk))
|
||||
chunk_ids.append(chunk.get('chunk_id', -1))
|
||||
chunk_types.append(chunk.get('chunk_type', 'unknown'))
|
||||
|
||||
# --- 4. 向量化 ---
|
||||
print(f"🔢 Encoding {len(texts)} texts (batch_size={batch_size})...")
|
||||
print(f"\n🔢 向量化 {len(texts)} 个文本 (batch_size={batch_size})...")
|
||||
embeddings = model.encode(
|
||||
texts,
|
||||
batch_size=batch_size,
|
||||
@@ -87,19 +140,16 @@ def main(chunks_json_path: str, output_dir: str = "./embeddings",
|
||||
)
|
||||
print(f" Embeddings shape: {embeddings.shape}")
|
||||
|
||||
# --- 5. 保存到输出目录 ---
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 向量矩阵 (float32)
|
||||
np.save(os.path.join(output_dir, "embeddings.npy"), embeddings.astype('float32'))
|
||||
# chunk_id 映射
|
||||
with open(os.path.join(output_dir, "chunk_id_map.json"), 'w') as f:
|
||||
np.save(output_dir / "embeddings.npy", embeddings.astype('float32'))
|
||||
with open(output_dir / "chunk_id_map.json", 'w', encoding='utf-8') as f:
|
||||
json.dump(chunk_ids, f, ensure_ascii=False, indent=2)
|
||||
# 原始 chunks 副本
|
||||
with open(os.path.join(output_dir, "chunks.json"), 'w') as f:
|
||||
with open(output_dir / "chunk_type_map.json", 'w', encoding='utf-8') as f:
|
||||
json.dump(chunk_types, f, ensure_ascii=False, indent=2)
|
||||
with open(output_dir / "chunks.json", 'w', encoding='utf-8') as f:
|
||||
json.dump(chunks, f, ensure_ascii=False, indent=2)
|
||||
# pickle 方便调试
|
||||
with open(os.path.join(output_dir, "embeddings.pkl"), 'wb') as f:
|
||||
with open(output_dir / "embeddings.pkl", 'wb') as f:
|
||||
pickle.dump({
|
||||
'chunks': chunks,
|
||||
'embeddings': embeddings,
|
||||
@@ -107,24 +157,48 @@ def main(chunks_json_path: str, output_dir: str = "./embeddings",
|
||||
'normalized': normalize
|
||||
}, f)
|
||||
|
||||
# --- 6. 质量检查 ---
|
||||
nan_count = np.isnan(embeddings).sum()
|
||||
print(f"\n📊 Quality check:")
|
||||
print(f"\n📊 质量检查:")
|
||||
print(f" NaN values: {nan_count}")
|
||||
norms = np.linalg.norm(embeddings, axis=1)
|
||||
print(f" Norms: min={norms.min():.4f}, max={norms.max():.4f}, mean={norms.mean():.4f}")
|
||||
print(f"\n✅ Embeddings saved to {output_dir}/")
|
||||
print(f" Files: embeddings.npy, chunk_id_map.json, chunks.json, embeddings.pkl")
|
||||
|
||||
print(f"\n✅ 向量数据已保存到: {output_dir}/")
|
||||
print(f" 文件: embeddings.npy, chunk_id_map.json, chunk_type_map.json, chunks.json, embeddings.pkl")
|
||||
|
||||
type_counts = {}
|
||||
for ct in chunk_types:
|
||||
type_counts[ct] = type_counts.get(ct, 0) + 1
|
||||
print(f"\n📈 Chunk 类型分布:")
|
||||
for ct, count in sorted(type_counts.items(), key=lambda x: -x[1]):
|
||||
print(f" {ct}: {count}")
|
||||
|
||||
return {
|
||||
"chunks": len(chunks),
|
||||
"embedding_dim": embeddings.shape[1],
|
||||
"output_dir": str(output_dir)
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("chunks_json", help="Path to all_chunks.json")
|
||||
parser.add_argument("--output_dir", "-o", default="./embeddings")
|
||||
parser.add_argument("--model_path", "-m", default="./models/Qwen3-Embedding-4B")
|
||||
parser.add_argument("--batch_size", "-b", type=int, default=8,
|
||||
help="Batch size (lower if OOM)")
|
||||
parser.add_argument("--no_normalize", action="store_true")
|
||||
project_root = Path(__file__).resolve().parent
|
||||
default_chunks = project_root / "jrxml_chunker_output" / "all_chunks.json"
|
||||
|
||||
parser = argparse.ArgumentParser(description="JRXML Chunks 向量化工具")
|
||||
parser.add_argument("chunks_json", nargs="?", default=str(default_chunks),
|
||||
help=f"Chunks JSON 文件路径 (默认: {default_chunks})")
|
||||
parser.add_argument("--output_dir", "-o", default=None,
|
||||
help="输出目录 (默认: embeddings)")
|
||||
parser.add_argument("--model_path", "-m", default=None,
|
||||
help="模型路径 (默认: models/Qwen3-Embedding-4B)")
|
||||
parser.add_argument("--batch_size", "-b", type=int, default=64,
|
||||
help="批处理大小 (默认: 64)")
|
||||
parser.add_argument("--no_normalize", action="store_true",
|
||||
help="不做向量归一化")
|
||||
parser.add_argument("--no_fp16", action="store_true",
|
||||
help="禁用 FP16 半精度(默认启用,可节省约 50%% 显存)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(
|
||||
@@ -132,5 +206,6 @@ if __name__ == "__main__":
|
||||
output_dir=args.output_dir,
|
||||
model_path=args.model_path,
|
||||
batch_size=args.batch_size,
|
||||
normalize=not args.no_normalize
|
||||
normalize=not args.no_normalize,
|
||||
use_fp16=not args.no_fp16
|
||||
)
|
||||
@@ -0,0 +1,179 @@
|
||||
"""
|
||||
import_to_chroma.py
|
||||
将已生成的 chunk 向量导入 Chroma 数据库
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import chromadb
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def main(embeddings_dir: str = None,
|
||||
chroma_path: str = None,
|
||||
collection_name: str = "jrxml_chunks"):
|
||||
"""
|
||||
从 embeddings 目录读取向量和 chunks,导入 Chroma 持久化数据库
|
||||
|
||||
Args:
|
||||
embeddings_dir: 包含 embeddings.npy, chunks.json 的目录
|
||||
chroma_path: Chroma 持久化目录
|
||||
collection_name: 集合名称
|
||||
"""
|
||||
project_root = Path(__file__).resolve().parent
|
||||
|
||||
if embeddings_dir is None:
|
||||
embeddings_dir = project_root / "embeddings"
|
||||
else:
|
||||
embeddings_dir = Path(embeddings_dir)
|
||||
|
||||
if chroma_path is None:
|
||||
chroma_path = project_root / "chroma_db"
|
||||
else:
|
||||
chroma_path = Path(chroma_path)
|
||||
|
||||
embeddings_file = embeddings_dir / "embeddings.npy"
|
||||
chunks_file = embeddings_dir / "chunks.json"
|
||||
|
||||
for f in [embeddings_file, chunks_file]:
|
||||
if not f.exists():
|
||||
print(f"❌ 缺少文件: {f}")
|
||||
print(f" 请先运行 embed_chunks.py 生成向量")
|
||||
return None
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"JRXML Chunks 导入 Chroma 数据库")
|
||||
print(f"{'='*60}")
|
||||
|
||||
print(f"\n📂 加载向量和 chunks...")
|
||||
embeddings = np.load(embeddings_file).astype('float32')
|
||||
with open(chunks_file, 'r', encoding='utf-8') as f:
|
||||
chunks = json.load(f)
|
||||
|
||||
if len(embeddings) != len(chunks):
|
||||
print(f"❌ 数量不匹配: {len(embeddings)} vs {len(chunks)}")
|
||||
return None
|
||||
|
||||
print(f" 向量维度: {embeddings.shape[1]}")
|
||||
print(f" Chunks 数量: {len(chunks)}")
|
||||
|
||||
print(f"\n💾 初始化 Chroma 数据库: {chroma_path}")
|
||||
chroma_path.mkdir(parents=True, exist_ok=True)
|
||||
client = chromadb.PersistentClient(path=str(chroma_path))
|
||||
|
||||
try:
|
||||
client.delete_collection(collection_name)
|
||||
print(f" 已删除旧集合 '{collection_name}'")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
collection = client.create_collection(
|
||||
name=collection_name,
|
||||
metadata={"hnsw:space": "cosine"}
|
||||
)
|
||||
|
||||
print(f"\n🛠️ 准备导入数据...")
|
||||
ids = []
|
||||
documents = []
|
||||
metadatas = []
|
||||
embeddings_list = []
|
||||
|
||||
seen_ids = {}
|
||||
for i, chunk in enumerate(tqdm(chunks, desc="准备数据")):
|
||||
raw_id = str(chunk.get("chunk_id", i))
|
||||
if raw_id in seen_ids:
|
||||
seen_ids[raw_id] += 1
|
||||
chunk_id = f"{raw_id}_{seen_ids[raw_id]}"
|
||||
else:
|
||||
seen_ids[raw_id] = 0
|
||||
chunk_id = raw_id
|
||||
ids.append(chunk_id)
|
||||
|
||||
doc_text = chunk.get("human_description", "")
|
||||
documents.append(doc_text)
|
||||
|
||||
meta = {}
|
||||
chunk_type = chunk.get("chunk_type", "")
|
||||
if chunk_type:
|
||||
meta["chunk_type"] = chunk_type
|
||||
|
||||
context = chunk.get("context", "")
|
||||
if context:
|
||||
meta["context"] = context
|
||||
|
||||
chunk_meta = chunk.get("metadata", {})
|
||||
if "report_name" in chunk_meta:
|
||||
meta["report_name"] = chunk_meta["report_name"]
|
||||
if "band_name" in chunk_meta:
|
||||
meta["band_name"] = chunk_meta["band_name"]
|
||||
if "element_kind" in chunk_meta:
|
||||
meta["element_kind"] = chunk_meta["element_kind"]
|
||||
if "query_language" in chunk_meta:
|
||||
meta["query_language"] = chunk_meta["query_language"]
|
||||
|
||||
metadatas.append(meta)
|
||||
embeddings_list.append(embeddings[i].tolist())
|
||||
|
||||
print(f"\n📥 分批导入到 Chroma (每批 1000 条)...")
|
||||
import_batch_size = 1000
|
||||
start_time = time.time()
|
||||
|
||||
for start in tqdm(range(0, len(ids), import_batch_size), desc="导入进度"):
|
||||
end = min(start + import_batch_size, len(ids))
|
||||
collection.add(
|
||||
ids=ids[start:end],
|
||||
documents=documents[start:end],
|
||||
metadatas=metadatas[start:end],
|
||||
embeddings=embeddings_list[start:end]
|
||||
)
|
||||
|
||||
duration = time.time() - start_time
|
||||
print(f"\n✅ 成功导入 {len(ids)} 个 chunks 到 '{collection_name}'")
|
||||
print(f" 数据库路径: {chroma_path}")
|
||||
print(f" 集合数量: {collection.count()}")
|
||||
print(f" 导入耗时: {duration:.2f}s")
|
||||
|
||||
print(f"\n🔍 快速验证查询...")
|
||||
results = collection.query(
|
||||
query_embeddings=[embeddings_list[0]],
|
||||
n_results=3,
|
||||
include=["documents", "metadatas", "distances"]
|
||||
)
|
||||
distances = results.get('distances', [[]])
|
||||
if distances and distances[0]:
|
||||
print(f" Top-3 相似度距离: {[round(d, 4) for d in distances[0]]}")
|
||||
first_doc = results.get('documents', [['']])[0][0]
|
||||
print(f" 首位结果: {first_doc[:120]}...")
|
||||
|
||||
print(f"\n📊 元数据字段分布:")
|
||||
all_keys = set()
|
||||
for m in metadatas:
|
||||
all_keys.update(m.keys())
|
||||
for key in sorted(all_keys):
|
||||
count = sum(1 for m in metadatas if key in m)
|
||||
print(f" {key}: {count}")
|
||||
|
||||
return collection
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description="JRXML Chunks 导入 Chroma 工具")
|
||||
parser.add_argument("--embeddings_dir", "-e", default=None,
|
||||
help="向量文件目录 (默认: embeddings)")
|
||||
parser.add_argument("--chroma_path", "-c", default=None,
|
||||
help="Chroma 数据库路径 (默认: chroma_db)")
|
||||
parser.add_argument("--collection_name", "-n", default="jrxml_chunks",
|
||||
help="集合名称 (默认: jrxml_chunks)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(
|
||||
embeddings_dir=args.embeddings_dir,
|
||||
chroma_path=args.chroma_path,
|
||||
collection_name=args.collection_name
|
||||
)
|
||||
+269
@@ -0,0 +1,269 @@
|
||||
"""
|
||||
query_chroma.py
|
||||
查询 Chroma 数据库,从自然语言查找相关 JRXML chunk
|
||||
支持命令行单次查询和交互式连续查询
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import torch
|
||||
from sentence_transformers import SentenceTransformer
|
||||
import chromadb
|
||||
|
||||
|
||||
class JRXMLSearcher:
|
||||
def __init__(self, chroma_path: str = None,
|
||||
collection_name: str = "jrxml_chunks",
|
||||
model_path: str = None,
|
||||
use_fp16: bool = True):
|
||||
project_root = Path(__file__).resolve().parent
|
||||
|
||||
if chroma_path is None:
|
||||
chroma_path = str(project_root / "chroma_db")
|
||||
if model_path is None:
|
||||
model_path = str(project_root / "models" / "Qwen3-Embedding-4B")
|
||||
|
||||
# 处理 Hub 模型名称
|
||||
model_path_str = str(model_path)
|
||||
if "\\" in model_path_str and not os.path.exists(model_path_str):
|
||||
model_path_str = model_path_str.replace("\\", "/")
|
||||
|
||||
# 加载嵌入模型
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(f"🧠 加载模型: {model_path_str}")
|
||||
print(f" 设备: {device}")
|
||||
self.model = SentenceTransformer(model_path_str, device=device)
|
||||
|
||||
if device == "cuda" and use_fp16:
|
||||
self.model = self.model.half()
|
||||
torch.cuda.empty_cache()
|
||||
mem = torch.cuda.memory_allocated(0) / 1024**3
|
||||
total = torch.cuda.get_device_properties(0).total_memory / 1024**3
|
||||
print(f" FP16 已启用, 显存: {mem:.2f} GB / {total:.2f} GB")
|
||||
|
||||
# 连接 Chroma
|
||||
print(f"💾 连接 Chroma: {chroma_path}")
|
||||
self.client = chromadb.PersistentClient(path=chroma_path)
|
||||
self.collection = self.client.get_collection(collection_name)
|
||||
print(f" 集合 '{collection_name}': {self.collection.count()} 条记录\n")
|
||||
|
||||
def search(self, query: str, n_results: int = 5, filter_meta: dict = None):
|
||||
query_embedding = self.model.encode(
|
||||
query,
|
||||
normalize_embeddings=True,
|
||||
show_progress_bar=False
|
||||
).tolist()
|
||||
|
||||
where_filter = filter_meta if filter_meta else None
|
||||
results = self.collection.query(
|
||||
query_embeddings=[query_embedding],
|
||||
n_results=n_results,
|
||||
where=where_filter,
|
||||
include=["documents", "metadatas", "distances"]
|
||||
)
|
||||
return results
|
||||
|
||||
def search_with_threshold(self, query: str, n_results: int = 5,
|
||||
threshold: float = 0.3, filter_meta: dict = None):
|
||||
results = self.search(query, n_results, filter_meta)
|
||||
filtered = {"ids": [[]], "documents": [[]], "metadatas": [[]], "distances": [[]]}
|
||||
|
||||
for i, dist in enumerate(results["distances"][0]):
|
||||
if dist <= threshold:
|
||||
filtered["ids"][0].append(results["ids"][0][i])
|
||||
filtered["documents"][0].append(results["documents"][0][i])
|
||||
filtered["metadatas"][0].append(results["metadatas"][0][i])
|
||||
filtered["distances"][0].append(dist)
|
||||
return filtered
|
||||
|
||||
def format_result(self, results: dict) -> str:
|
||||
lines = []
|
||||
n = len(results["ids"][0])
|
||||
lines.append(f"找到 {n} 条结果:")
|
||||
for i, (doc_id, doc, dist, meta) in enumerate(zip(
|
||||
results["ids"][0],
|
||||
results["documents"][0],
|
||||
results["distances"][0],
|
||||
results["metadatas"][0]
|
||||
)):
|
||||
chunk_type = meta.get("chunk_type", "N/A")
|
||||
report = meta.get("report_name", "")
|
||||
band = meta.get("band_name", "")
|
||||
lines.append(f"\n--- 结果 {i+1} (相似度={1-dist:.4f}, id={doc_id}) ---")
|
||||
lines.append(f"类型: {chunk_type}")
|
||||
if report:
|
||||
lines.append(f"报表: {report}")
|
||||
if band:
|
||||
lines.append(f"区域: {band}")
|
||||
lines.append(f"内容: {doc[:300]}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
project_root = Path(__file__).resolve().parent
|
||||
|
||||
parser = argparse.ArgumentParser(description="JRXML Chunks 语义搜索工具")
|
||||
parser.add_argument("query", nargs="?", default="",
|
||||
help="搜索关键词(不提供则进入交互模式)")
|
||||
parser.add_argument("--chroma_path", "-c", default=None,
|
||||
help=f"Chroma 数据库路径 (默认: chroma_db)")
|
||||
parser.add_argument("--collection", "-n", default="jrxml_chunks",
|
||||
help="集合名称")
|
||||
parser.add_argument("--model_path", "-m", default=None,
|
||||
help="嵌入模型路径")
|
||||
parser.add_argument("--n_results", "-k", type=int, default=5,
|
||||
help="返回结果数 (默认: 5)")
|
||||
parser.add_argument("--filter_field", "-f",
|
||||
help="按 chunk_type 过滤,例如: field, query, chart")
|
||||
parser.add_argument("--threshold", "-t", type=float,
|
||||
help="相似度阈值 (0~1, 越高越相似)")
|
||||
parser.add_argument("--no_fp16", action="store_true",
|
||||
help="禁用 FP16 半精度")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.chroma_path is None:
|
||||
args.chroma_path = str(project_root / "chroma_db")
|
||||
|
||||
if args.model_path is None:
|
||||
default_model = project_root / "models" / "Qwen3-Embedding-4B"
|
||||
if not default_model.exists():
|
||||
args.model_path = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
else:
|
||||
args.model_path = str(default_model)
|
||||
|
||||
# 检查数据库
|
||||
if not os.path.exists(args.chroma_path):
|
||||
print(f"❌ Chroma 数据库不存在: {args.chroma_path}")
|
||||
print(f" 请先运行 import_to_chroma.py 导入数据")
|
||||
return
|
||||
|
||||
# 初始化搜索器
|
||||
try:
|
||||
searcher = JRXMLSearcher(
|
||||
chroma_path=args.chroma_path,
|
||||
collection_name=args.collection,
|
||||
model_path=args.model_path,
|
||||
use_fp16=not args.no_fp16
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"❌ 初始化失败: {e}")
|
||||
return
|
||||
|
||||
# 准备过滤条件
|
||||
filter_meta = None
|
||||
if args.filter_field:
|
||||
filter_meta = {"chunk_type": args.filter_field}
|
||||
|
||||
# 单次查询模式
|
||||
if args.query:
|
||||
query = args.query
|
||||
print(f"\n� 搜索: '{query}'")
|
||||
if filter_meta:
|
||||
print(f" 过滤: {filter_meta}")
|
||||
|
||||
start = time.time()
|
||||
if args.threshold is not None:
|
||||
results = searcher.search_with_threshold(
|
||||
query, args.n_results, args.threshold, filter_meta
|
||||
)
|
||||
else:
|
||||
results = searcher.search(query, args.n_results, filter_meta)
|
||||
elapsed = time.time() - start
|
||||
|
||||
print(searcher.format_result(results))
|
||||
print(f"\n⏱️ 耗时: {elapsed:.2f}s")
|
||||
return
|
||||
|
||||
# 交互模式
|
||||
print(f"\n{'='*60}")
|
||||
print(f"JRXML 语义搜索 - 交互模式")
|
||||
print(f"{'='*60}")
|
||||
print(f"可用过滤类型: report_overview, query, field, parameter,")
|
||||
print(f" variable, band_*, chart, crosstab, subreport, style 等")
|
||||
print(f"示例: '如何修改报表标题'")
|
||||
print(f" 'filter:query SQL数据源查询'")
|
||||
print(f" 't:0.5 band:title 标题区域'")
|
||||
print(f"输入 'help' 查看帮助, 'exit' 退出\n")
|
||||
|
||||
while True:
|
||||
try:
|
||||
user_input = input("🔍 搜索> ").strip()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
print("\n👋 再见!")
|
||||
break
|
||||
|
||||
if not user_input:
|
||||
continue
|
||||
if user_input.lower() in ("exit", "quit", "q"):
|
||||
print("👋 再见!")
|
||||
break
|
||||
if user_input.lower() == "help":
|
||||
print("""
|
||||
特殊命令:
|
||||
filter:<类型> 按 chunk_type 过滤 (如 filter:query)
|
||||
t:<阈值> 设置相似度阈值 0~1 (如 t:0.5)
|
||||
k:<数量> 设置返回结果数 (如 k:10)
|
||||
|
||||
示例:
|
||||
filter:field 数据源字段有哪些
|
||||
t:0.5 band:title 标题区域怎么设置
|
||||
k:10 报表参数定义
|
||||
""")
|
||||
continue
|
||||
|
||||
# 解析特殊命令
|
||||
query_text = user_input
|
||||
cur_filter = filter_meta
|
||||
cur_n = args.n_results
|
||||
cur_threshold = args.threshold
|
||||
|
||||
parts = user_input.split()
|
||||
new_parts = []
|
||||
for p in parts:
|
||||
if p.startswith("filter:"):
|
||||
field_val = p[len("filter:"):]
|
||||
cur_filter = {"chunk_type": field_val}
|
||||
print(f" 📌 过滤: {cur_filter}")
|
||||
elif p.startswith("t:"):
|
||||
try:
|
||||
cur_threshold = float(p[2:])
|
||||
print(f" 📌 阈值: {cur_threshold}")
|
||||
except ValueError:
|
||||
pass
|
||||
elif p.startswith("k:"):
|
||||
try:
|
||||
cur_n = int(p[2:])
|
||||
cur_n = max(1, min(cur_n, 50))
|
||||
print(f" 📌 返回数量: {cur_n}")
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
new_parts.append(p)
|
||||
query_text = " ".join(new_parts)
|
||||
|
||||
if not query_text:
|
||||
print(" ⚠️ 请输入搜索内容")
|
||||
continue
|
||||
|
||||
print(f"🔍 搜索: '{query_text}'")
|
||||
start = time.time()
|
||||
|
||||
if cur_threshold is not None:
|
||||
results = searcher.search_with_threshold(
|
||||
query_text, cur_n, cur_threshold, cur_filter
|
||||
)
|
||||
else:
|
||||
results = searcher.search(query_text, cur_n, cur_filter)
|
||||
|
||||
elapsed = time.time() - start
|
||||
print(searcher.format_result(results))
|
||||
print(f"⏱️ 耗时: {elapsed:.2f}s\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user