bd98486de0
创建了完整的JRXML语义检索RAG项目,包含: 1. 新增.gitignore忽略项目生成的缓存、依赖目录和本地文件 2. 编写详细的项目README文档 3. 补充文件功能说明文档 4. 实现向量导入、向量化、查询等核心脚本
269 lines
9.5 KiB
Python
269 lines
9.5 KiB
Python
"""
|
|
query_chroma.py
|
|
查询 Chroma 数据库,从自然语言查找相关 JRXML chunk
|
|
支持命令行单次查询和交互式连续查询
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
import numpy as np
|
|
import torch
|
|
from sentence_transformers import SentenceTransformer
|
|
import chromadb
|
|
|
|
|
|
class JRXMLSearcher:
    """Semantic search over JRXML chunks stored in a Chroma collection.

    Loads a SentenceTransformer embedding model (on CUDA when available,
    optionally converted to FP16) and opens an existing persistent Chroma
    collection.  Search methods return Chroma's raw ``collection.query``
    layout: a dict whose ``ids`` / ``documents`` / ``metadatas`` /
    ``distances`` values hold one inner list per query embedding (this class
    always sends exactly one, so callers read index ``[0]``).
    """

    def __init__(self, chroma_path: str = None,
                 collection_name: str = "jrxml_chunks",
                 model_path: str = None,
                 use_fp16: bool = True):
        """Load the embedding model and connect to the Chroma collection.

        Args:
            chroma_path: Directory of the persistent Chroma DB.  Defaults to
                ``chroma_db`` next to this file.
            collection_name: Name of an existing collection.  It is opened
                with ``get_collection``, so it is not created when missing.
            model_path: Local model directory or Hugging Face Hub id passed to
                ``SentenceTransformer``.  Defaults to
                ``models/Qwen3-Embedding-4B`` next to this file.
            use_fp16: Convert the model to half precision when running on CUDA.

        Raises:
            Whatever ``SentenceTransformer`` or ``chromadb`` raise when the
            model or the collection cannot be loaded (not caught here).
        """
        project_root = Path(__file__).resolve().parent

        if chroma_path is None:
            chroma_path = str(project_root / "chroma_db")
        if model_path is None:
            model_path = str(project_root / "models" / "Qwen3-Embedding-4B")

        # Handle Hub model names: a non-existent path containing backslashes
        # is presumably a Hub id written with Windows separators
        # ("org\\name"); Hub ids use forward slashes.
        model_path_str = str(model_path)
        if "\\" in model_path_str and not os.path.exists(model_path_str):
            model_path_str = model_path_str.replace("\\", "/")

        # Load the embedding model.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"🧠 加载模型: {model_path_str}")
        print(f" 设备: {device}")
        self.model = SentenceTransformer(model_path_str, device=device)

        if device == "cuda" and use_fp16:
            self.model = self.model.half()
            torch.cuda.empty_cache()
            # Report memory for device 0 only.
            mem = torch.cuda.memory_allocated(0) / 1024**3
            total = torch.cuda.get_device_properties(0).total_memory / 1024**3
            print(f" FP16 已启用, 显存: {mem:.2f} GB / {total:.2f} GB")

        # Connect to Chroma.
        print(f"💾 连接 Chroma: {chroma_path}")
        self.client = chromadb.PersistentClient(path=chroma_path)
        self.collection = self.client.get_collection(collection_name)
        print(f" 集合 '{collection_name}': {self.collection.count()} 条记录\n")

    def search(self, query: str, n_results: int = 5, filter_meta: dict = None):
        """Return the ``n_results`` chunks nearest to ``query``.

        The query is embedded with ``normalize_embeddings=True``.
        ``filter_meta`` is passed to Chroma as the ``where`` clause; an empty
        dict or ``None`` means "no filter".  The result includes documents,
        metadatas and distances.
        """
        query_embedding = self.model.encode(
            query,
            normalize_embeddings=True,
            show_progress_bar=False,
        ).tolist()

        # Empty filter -> None, so Chroma receives no where clause at all.
        where_filter = filter_meta if filter_meta else None
        return self.collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            where=where_filter,
            include=["documents", "metadatas", "distances"],
        )

    def search_with_threshold(self, query: str, n_results: int = 5,
                              threshold: float = 0.3, filter_meta: dict = None):
        """Like :meth:`search`, but keep only hits with ``distance <= threshold``.

        ``threshold`` is a maximum *distance* as returned by Chroma (smaller
        means closer), NOT a similarity.  If the collection uses cosine
        distance, this keeps hits with similarity >= ``1 - threshold``.
        NOTE(review): the collection's distance space is set at import time
        and is not visible here; confirm it before relying on that relation.

        Returns a dict with the same four keys and the same single-query
        nesting as :meth:`search`, preserving Chroma's ordering.
        """
        results = self.search(query, n_results, filter_meta)
        keys = ("ids", "documents", "metadatas", "distances")
        filtered = {key: [[]] for key in keys}

        for idx, dist in enumerate(results["distances"][0]):
            if dist <= threshold:
                for key in keys:
                    filtered[key][0].append(results[key][0][idx])
        return filtered

    def format_result(self, results: dict) -> str:
        """Render a single-query Chroma result dict as human-readable text.

        Similarity is displayed as ``1 - distance``.  That is the cosine
        similarity only if the collection uses cosine distance (TODO
        confirm).  Documents are truncated to 300 characters.  Rows whose
        metadata or document entry is ``None`` are still rendered instead of
        raising.
        """
        rows = zip(
            results["ids"][0],
            results["documents"][0],
            results["distances"][0],
            results["metadatas"][0],
        )
        lines = [f"找到 {len(results['ids'][0])} 条结果:"]

        for rank, (doc_id, doc, dist, meta) in enumerate(rows, start=1):
            # Individual entries may be None (e.g. a record stored without
            # metadata or document); guard so one row cannot break the listing.
            meta = meta or {}
            doc = doc or ""

            chunk_type = meta.get("chunk_type", "N/A")
            report = meta.get("report_name", "")
            band = meta.get("band_name", "")

            lines.append(f"\n--- 结果 {rank} (相似度={1 - dist:.4f}, id={doc_id}) ---")
            lines.append(f"类型: {chunk_type}")
            if report:
                lines.append(f"报表: {report}")
            if band:
                lines.append(f"区域: {band}")
            lines.append(f"内容: {doc[:300]}")

        return "\n".join(lines)
|
|
|
|
|
|
def _run_search(searcher, query, n_results, threshold, filter_meta):
    """Run one search and return ``(results, elapsed_seconds)``.

    ``threshold`` is a *similarity* cut-off in [0, 1], where higher is
    stricter, as the CLI and interactive help describe it.
    ``JRXMLSearcher.search_with_threshold`` instead keeps hits with
    ``distance <= threshold``, so the value is converted with
    ``distance = 1 - similarity``.  This is the same relation
    ``format_result`` uses to display similarity.  ``None`` disables
    threshold filtering.
    """
    start = time.perf_counter()
    if threshold is None:
        results = searcher.search(query, n_results, filter_meta)
    else:
        results = searcher.search_with_threshold(
            query, n_results, 1.0 - threshold, filter_meta
        )
    return results, time.perf_counter() - start


def _parse_command(user_input, filter_meta, n_results, threshold):
    """Split inline commands out of one interactive input line.

    Recognised tokens:
      - ``filter:<chunk_type>``
      - ``t:<similarity 0~1>``
      - ``k:<count>`` (clamped to 1..50)

    Invalid ``t:`` / ``k:`` tokens are dropped with a warning.  All other
    tokens form the query text.  Returns
    ``(query_text, filter_meta, n_results, threshold)``; the last three start
    from the session defaults passed in.
    """
    words = []
    for token in user_input.split():
        if token.startswith("filter:"):
            filter_meta = {"chunk_type": token[len("filter:"):]}
            print(f" 📌 过滤: {filter_meta}")
        elif token.startswith("t:"):
            try:
                value = float(token[2:])
            except ValueError:
                value = None
            # Also rejects NaN, since NaN fails both comparisons.
            if value is None or not 0.0 <= value <= 1.0:
                print(f" ⚠️ 无效阈值 (需在 0~1 之间): {token}")
                continue
            threshold = value
            print(f" 📌 阈值: {threshold}")
        elif token.startswith("k:"):
            try:
                n_results = max(1, min(int(token[2:]), 50))
            except ValueError:
                print(f" ⚠️ 无效数量: {token}")
                continue
            print(f" 📌 返回数量: {n_results}")
        else:
            words.append(token)
    return " ".join(words), filter_meta, n_results, threshold


def main():
    """Command-line entry point for JRXML semantic search.

    With a positional ``query`` argument, run one search, print the results
    and return.  Without it, start an interactive prompt where each line may
    carry inline ``filter:`` / ``t:`` / ``k:`` commands (type ``help``).
    ``--threshold`` and ``t:`` are similarity cut-offs: higher means stricter.
    """
    import argparse

    project_root = Path(__file__).resolve().parent

    parser = argparse.ArgumentParser(description="JRXML Chunks 语义搜索工具")
    parser.add_argument("query", nargs="?", default="",
                        help="搜索关键词(不提供则进入交互模式)")
    parser.add_argument("--chroma_path", "-c", default=None,
                        help="Chroma 数据库路径 (默认: chroma_db)")
    parser.add_argument("--collection", "-n", default="jrxml_chunks",
                        help="集合名称")
    parser.add_argument("--model_path", "-m", default=None,
                        help="嵌入模型路径")
    parser.add_argument("--n_results", "-k", type=int, default=5,
                        help="返回结果数 (默认: 5)")
    parser.add_argument("--filter_field", "-f",
                        help="按 chunk_type 过滤,例如: field, query, chart")
    parser.add_argument("--threshold", "-t", type=float,
                        help="相似度阈值 (0~1, 越高越相似)")
    parser.add_argument("--no_fp16", action="store_true",
                        help="禁用 FP16 半精度")
    args = parser.parse_args()

    # Validate up front, so bad values produce a usage message instead of
    # failing later inside the query.
    if args.n_results < 1:
        parser.error("--n_results 必须 >= 1")
    if args.threshold is not None and not 0.0 <= args.threshold <= 1.0:
        parser.error("--threshold 必须在 0~1 之间")

    if args.chroma_path is None:
        args.chroma_path = str(project_root / "chroma_db")

    if args.model_path is None:
        default_model = project_root / "models" / "Qwen3-Embedding-4B"
        if default_model.exists():
            args.model_path = str(default_model)
        else:
            args.model_path = "sentence-transformers/all-MiniLM-L6-v2"
            # NOTE(review): query embeddings presumably must come from the same
            # model import_to_chroma.py used to build the collection; a
            # different model yields vectors of another size/space.  Warn
            # rather than fall back silently.
            print(f"⚠️ 未找到本地模型 {default_model}, 改用 {args.model_path}")
            print(" 请确认与导入数据时使用的嵌入模型一致")

    # Check that the database exists.
    if not os.path.exists(args.chroma_path):
        print(f"❌ Chroma 数据库不存在: {args.chroma_path}")
        print(" 请先运行 import_to_chroma.py 导入数据")
        return

    # Initialise the searcher (loads the model and opens the collection).
    try:
        searcher = JRXMLSearcher(
            chroma_path=args.chroma_path,
            collection_name=args.collection,
            model_path=args.model_path,
            use_fp16=not args.no_fp16,
        )
    except Exception as e:
        print(f"❌ 初始化失败: {e}")
        return

    filter_meta = {"chunk_type": args.filter_field} if args.filter_field else None

    # Single-query mode.  Errors propagate, giving a traceback and a non-zero exit.
    if args.query:
        print(f"\n🔍 搜索: '{args.query}'")
        if filter_meta:
            print(f" 过滤: {filter_meta}")
        results, elapsed = _run_search(
            searcher, args.query, args.n_results, args.threshold, filter_meta
        )
        print(searcher.format_result(results))
        print(f"\n⏱️ 耗时: {elapsed:.2f}s")
        return

    # Interactive mode.
    print(f"\n{'=' * 60}")
    print("JRXML 语义搜索 - 交互模式")
    print(f"{'=' * 60}")
    print("可用过滤类型: report_overview, query, field, parameter,")
    print(" variable, band_*, chart, crosstab, subreport, style 等")
    print("示例: '如何修改报表标题'")
    print(" 'filter:query SQL数据源查询'")
    print(" 't:0.5 band:title 标题区域'")
    print("输入 'help' 查看帮助, 'exit' 退出\n")

    while True:
        try:
            user_input = input("🔍 搜索> ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\n👋 再见!")
            break

        if not user_input:
            continue
        if user_input.lower() in ("exit", "quit", "q"):
            print("👋 再见!")
            break
        if user_input.lower() == "help":
            print("""
特殊命令:
 filter:<类型> 按 chunk_type 过滤 (如 filter:query)
 t:<阈值> 设置相似度阈值 0~1 (如 t:0.5)
 k:<数量> 设置返回结果数 (如 k:10)

示例:
 filter:field 数据源字段有哪些
 t:0.5 band:title 标题区域怎么设置
 k:10 报表参数定义
""")
            continue

        # Inline commands apply to this query only; defaults come from the CLI.
        query_text, cur_filter, cur_n, cur_threshold = _parse_command(
            user_input, filter_meta, args.n_results, args.threshold
        )
        if not query_text:
            print(" ⚠️ 请输入搜索内容")
            continue

        print(f"🔍 搜索: '{query_text}'")
        try:
            results, elapsed = _run_search(
                searcher, query_text, cur_n, cur_threshold, cur_filter
            )
        except Exception as e:
            # Keep the session alive after a bad filter or backend error.
            print(f"❌ 搜索失败: {e}")
            continue

        print(searcher.format_result(results))
        print(f"⏱️ 耗时: {elapsed:.2f}s\n")
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: running the file starts the CLI; importing it does not.
    # main() returns None, so the process exit status stays 0 as before.
    raise SystemExit(main())