Files
rag_jrxml/query_chroma.py
panda 9d78a49625 refactor: 重构项目配置管理,统一使用.env配置
- 新增config.py统一读取.env配置,移除硬编码路径和参数
- 重构collect_jrxml.py支持命令行参数和环境变量配置源目录
- 新增.env.example示例配置文件,整理所有可配置项
- 重构down_embedding_model.py、import_to_chroma.py等所有脚本使用统一配置
- 新增Windows一键部署脚本setup.bat
- 修正jrxml_banch_chunker.py的文件名拼写错误
2026-05-12 08:29:17 +08:00

274 lines
9.6 KiB
Python

"""
query_chroma.py
查询 Chroma 数据库,从自然语言查找相关 JRXML chunk
支持命令行单次查询和交互式连续查询
模型通过 .env / config.py 配置
"""
import os
import sys
import time
from pathlib import Path
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
import chromadb
from config import (
CHROMA_DB_PATH, CHROMA_COLLECTION_NAME, USE_FP16,
DEFAULT_N_RESULTS, SIMILARITY_THRESHOLD, resolve_model_path
)
class JRXMLSearcher:
    """Semantic search over JRXML chunks stored in a Chroma collection.

    The query text is embedded with a SentenceTransformer model and matched
    against the embeddings already stored in the collection.
    """

    def __init__(self, chroma_path: str = None,
                 collection_name: str = None,
                 model_path: str = None,
                 use_fp16: bool = None):
        """Load the embedding model and open the Chroma collection.

        Any argument left as ``None`` falls back to the project config:
        ``CHROMA_DB_PATH``, ``CHROMA_COLLECTION_NAME``,
        ``resolve_model_path()`` and ``USE_FP16`` respectively.

        Errors from loading the model or from ``get_collection`` are not
        caught here; they propagate to the caller.
        """
        if chroma_path is None:
            chroma_path = str(CHROMA_DB_PATH)
        if collection_name is None:
            collection_name = CHROMA_COLLECTION_NAME
        if model_path is None:
            model_path = resolve_model_path()
        if use_fp16 is None:
            use_fp16 = USE_FP16

        # Handle Hub model names: a backslash-containing string that is not
        # an existing local path is normalised to the "org/model" form.
        model_path_str = str(model_path)
        if "\\" in model_path_str and not os.path.exists(model_path_str):
            model_path_str = model_path_str.replace("\\", "/")

        # Load the embedding model, on GPU when one is available.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"🧠 加载模型: {model_path_str}")
        print(f" 设备: {device}")
        self.model = SentenceTransformer(model_path_str, device=device)
        # Half precision is only applied on CUDA; CPU keeps full precision.
        if device == "cuda" and use_fp16:
            self.model = self.model.half()
            torch.cuda.empty_cache()
            mem = torch.cuda.memory_allocated(0) / 1024**3
            total = torch.cuda.get_device_properties(0).total_memory / 1024**3
            print(f" FP16 已启用, 显存: {mem:.2f} GB / {total:.2f} GB")

        # Connect to the persistent Chroma store.
        print(f"💾 连接 Chroma: {chroma_path}")
        self.client = chromadb.PersistentClient(path=chroma_path)
        self.collection = self.client.get_collection(collection_name)
        print(f" 集合 '{collection_name}': {self.collection.count()} 条记录\n")

    def search(self, query: str, n_results: int = 5, filter_meta: dict = None):
        """Embed *query* and return the *n_results* nearest chunks.

        Args:
            query: natural-language search text.
            n_results: maximum number of hits to request from Chroma.
            filter_meta: optional Chroma ``where`` filter, e.g.
                ``{"chunk_type": "field"}``; an empty/falsy value means
                "no filter".

        Returns:
            The raw Chroma query result. Its ``ids``, ``documents``,
            ``metadatas`` and ``distances`` entries each hold one list per
            query embedding (exactly one here, hence ``[0]`` elsewhere).
        """
        query_embedding = self.model.encode(
            query,
            normalize_embeddings=True,
            show_progress_bar=False,
        ).tolist()
        # Normalise an empty/falsy filter to None so no where-clause is sent.
        return self.collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            where=filter_meta or None,
            include=["documents", "metadatas", "distances"],
        )

    def search_with_threshold(self, query: str, n_results: int = 5,
                              threshold: float = 0.3, filter_meta: dict = None):
        """Like :meth:`search`, but drop hits whose distance exceeds *threshold*.

        NOTE: *threshold* is a maximum **distance** (lower = closer), not a
        similarity. ``format_result`` shows similarity as ``1 - distance``,
        so to keep hits with similarity >= s, pass ``threshold=1 - s``.

        Returns:
            A dict with the same ``ids``/``documents``/``metadatas``/
            ``distances`` single-query layout as :meth:`search`. It contains
            only the kept hits, in their original order.
        """
        results = self.search(query, n_results, filter_meta)
        filtered = {"ids": [[]], "documents": [[]], "metadatas": [[]], "distances": [[]]}
        hits = zip(
            results["ids"][0],
            results["documents"][0],
            results["metadatas"][0],
            results["distances"][0],
        )
        for doc_id, doc, meta, dist in hits:
            if dist > threshold:
                continue
            filtered["ids"][0].append(doc_id)
            filtered["documents"][0].append(doc)
            filtered["metadatas"][0].append(meta)
            filtered["distances"][0].append(dist)
        return filtered

    def format_result(self, results: dict) -> str:
        """Render single-query Chroma *results* as human-readable text.

        Each hit shows its rank, ``1 - distance`` as the similarity score,
        its id and chunk type, the report/band name when present, and the
        first 300 characters of the document.
        """
        lines = [f"找到 {len(results['ids'][0])} 条结果:"]
        hits = zip(
            results["ids"][0],
            results["documents"][0],
            results["distances"][0],
            results["metadatas"][0],
        )
        for rank, (doc_id, doc, dist, meta) in enumerate(hits, start=1):
            # Chroma may return None for hits stored without metadata or text;
            # treat those as empty instead of crashing on .get / slicing.
            meta = meta or {}
            doc = doc or ""
            report = meta.get("report_name", "")
            band = meta.get("band_name", "")
            lines.append(f"\n--- 结果 {rank} (相似度={1-dist:.4f}, id={doc_id}) ---")
            lines.append(f"类型: {meta.get('chunk_type', 'N/A')}")
            if report:
                lines.append(f"报表: {report}")
            if band:
                lines.append(f"区域: {band}")
            lines.append(f"内容: {doc[:300]}")
        return "\n".join(lines)
_INTERACTIVE_HELP = """
特殊命令:
filter:<类型> 按 chunk_type 过滤 (如 filter:query)
t:<阈值> 设置相似度阈值 0~1 (如 t:0.5)
k:<数量> 设置返回结果数 (如 k:10)
示例:
filter:field 数据源字段有哪些
t:0.5 band:title 标题区域怎么设置
k:10 报表参数定义
"""


def _similarity_to_distance(similarity: float) -> float:
    """Convert a user-facing similarity threshold into a maximum distance.

    Scores are shown to the user as ``1 - distance``, and
    ``JRXMLSearcher.search_with_threshold`` keeps hits with
    ``distance <= threshold``. "Similarity >= s" is therefore
    "distance <= 1 - s".
    """
    return 1.0 - similarity


def _timed_search(searcher, query, n_results, similarity, filter_meta):
    """Run one search, applying *similarity* as a minimum score if not None.

    Returns ``(results, elapsed_seconds)``.
    """
    start = time.perf_counter()
    if similarity is None:
        results = searcher.search(query, n_results, filter_meta)
    else:
        results = searcher.search_with_threshold(
            query, n_results, _similarity_to_distance(similarity), filter_meta
        )
    return results, time.perf_counter() - start


def _parse_inline_options(user_input, filter_meta, n_results, similarity):
    """Split interactive input into query text plus per-query overrides.

    Recognised tokens are removed from the query text:
      ``filter:<type>``  restrict to ``{"chunk_type": <type>}``
      ``t:<float>``      minimum similarity (0~1)
      ``k:<int>``        number of results, clamped to 1..50
    Malformed option tokens are reported and dropped. The passed-in values
    act as defaults; overrides only apply to this one query.

    Returns ``(query_text, filter_meta, n_results, similarity)``.
    """
    words = []
    for token in user_input.split():
        if token.startswith("filter:"):
            value = token[len("filter:"):]
            if value:
                filter_meta = {"chunk_type": value}
                print(f" 📌 过滤: {filter_meta}")
            else:
                print(f" ⚠️ 无效参数: {token}")
        elif token.startswith("t:"):
            try:
                similarity = float(token[2:])
            except ValueError:
                print(f" ⚠️ 无效参数: {token}")
            else:
                print(f" 📌 阈值: {similarity}")
        elif token.startswith("k:"):
            try:
                n_results = max(1, min(int(token[2:]), 50))
            except ValueError:
                print(f" ⚠️ 无效参数: {token}")
            else:
                print(f" 📌 返回数量: {n_results}")
        else:
            words.append(token)
    return " ".join(words), filter_meta, n_results, similarity


def _interactive_loop(searcher, filter_meta, n_results, similarity):
    """Prompt for queries until exit/quit/q, EOF or Ctrl-C at the prompt."""
    print(f"\n{'='*60}")
    print("JRXML 语义搜索 - 交互模式")
    print(f"{'='*60}")
    print("可用过滤类型: report_overview, query, field, parameter,")
    print(" variable, band_*, chart, crosstab, subreport, style 等")
    print("示例: '如何修改报表标题'")
    print(" 'filter:query SQL数据源查询'")
    print(" 't:0.5 band:title 标题区域'")
    print("输入 'help' 查看帮助, 'exit' 退出\n")
    while True:
        try:
            user_input = input("🔍 搜索> ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\n👋 再见!")
            break
        if not user_input:
            continue
        command = user_input.lower()
        if command in ("exit", "quit", "q"):
            print("👋 再见!")
            break
        if command == "help":
            print(_INTERACTIVE_HELP)
            continue

        query_text, cur_filter, cur_n, cur_similarity = _parse_inline_options(
            user_input, filter_meta, n_results, similarity
        )
        if not query_text:
            print(" ⚠️ 请输入搜索内容")
            continue

        print(f"🔍 搜索: '{query_text}'")
        try:
            results, elapsed = _timed_search(
                searcher, query_text, cur_n, cur_similarity, cur_filter
            )
        except Exception as e:
            # Top-level boundary of the REPL: report a failed query (e.g. a
            # filter Chroma rejects) and keep the session alive.
            print(f"❌ 搜索失败: {e}")
            continue
        print(searcher.format_result(results))
        print(f"⏱️ 耗时: {elapsed:.2f}s\n")


def main() -> int:
    """Command-line entry point for JRXML semantic search.

    Runs one query when a positional query is given, otherwise starts an
    interactive prompt. ``--threshold`` / ``t:`` are *similarity* thresholds
    (higher = more similar), matching the displayed ``1 - distance`` score.

    Returns:
        Process exit status: 0 on success, 1 when the Chroma database is
        missing or the searcher cannot be initialised.
    """
    import argparse

    parser = argparse.ArgumentParser(description="JRXML Chunks 语义搜索工具")
    parser.add_argument("query", nargs="?", default="",
                        help="搜索关键词(不提供则进入交互模式)")
    parser.add_argument("--chroma_path", "-c", default=None,
                        help=f"Chroma 数据库路径 (默认: {CHROMA_DB_PATH})")
    parser.add_argument("--collection", "-n", default=CHROMA_COLLECTION_NAME,
                        help="集合名称")
    parser.add_argument("--model_path", "-m", default=None,
                        help="嵌入模型路径")
    parser.add_argument("--n_results", "-k", type=int, default=DEFAULT_N_RESULTS,
                        help=f"返回结果数 (默认: {DEFAULT_N_RESULTS})")
    parser.add_argument("--filter_field", "-f",
                        help="按 chunk_type 过滤,例如: field, query, chart")
    parser.add_argument("--threshold", "-t", type=float,
                        help="相似度阈值 (0~1, 越高越相似)")
    parser.add_argument("--no_fp16", action="store_true",
                        help="禁用 FP16 半精度")
    args = parser.parse_args()

    # Chroma rejects non-positive result counts; fail fast with a usage error.
    if args.n_results < 1:
        parser.error("--n_results 必须 >= 1")
    if args.chroma_path is None:
        args.chroma_path = str(CHROMA_DB_PATH)
    if args.model_path is None:
        args.model_path = resolve_model_path()

    # Check that the database exists before loading the (slow) model.
    if not os.path.exists(args.chroma_path):
        print(f"❌ Chroma 数据库不存在: {args.chroma_path}")
        print(" 请先运行 import_to_chroma.py 导入数据")
        return 1

    try:
        searcher = JRXMLSearcher(
            chroma_path=args.chroma_path,
            collection_name=args.collection,
            model_path=args.model_path,
            use_fp16=not args.no_fp16,
        )
    except Exception as e:
        print(f"❌ 初始化失败: {e}")
        return 1

    filter_meta = {"chunk_type": args.filter_field} if args.filter_field else None

    # Single-query mode.
    if args.query:
        print(f"\n 搜索: '{args.query}'")
        if filter_meta:
            print(f" 过滤: {filter_meta}")
        results, elapsed = _timed_search(
            searcher, args.query, args.n_results, args.threshold, filter_meta
        )
        print(searcher.format_result(results))
        print(f"\n⏱️ 耗时: {elapsed:.2f}s")
        return 0

    _interactive_loop(searcher, filter_meta, args.n_results, args.threshold)
    return 0


if __name__ == "__main__":
    sys.exit(main())