"""
query_chroma.py

Query the Chroma database to find JRXML chunks relevant to a natural-language
question. Supports a one-shot command-line query and an interactive loop.
Model and database settings come from .env / config.py.
"""
import os
import sys
import time
from pathlib import Path

import numpy as np
import torch
from sentence_transformers import SentenceTransformer
import chromadb

from config import (
    CHROMA_DB_PATH, CHROMA_COLLECTION_NAME, USE_FP16,
    DEFAULT_N_RESULTS, SIMILARITY_THRESHOLD, resolve_model_path
)


class JRXMLSearcher:
    """Semantic search over JRXML chunks stored in an existing Chroma collection.

    The query is embedded with a SentenceTransformer model. The model is placed
    on CUDA when available, optionally in FP16. The embedding is then matched
    against the collection with `collection.query`.
    """

    def __init__(self, chroma_path: str = None, collection_name: str = None,
                 model_path: str = None, use_fp16: bool = None):
        """Load the embedding model and open the Chroma collection.

        Args:
            chroma_path: Chroma persistent directory. Defaults to CHROMA_DB_PATH.
            collection_name: Collection to open. Defaults to CHROMA_COLLECTION_NAME.
            model_path: Local path or Hub id of the embedding model. Defaults to
                resolve_model_path().
            use_fp16: Convert the model to half precision. Only applied on CUDA.
                Defaults to USE_FP16.

        Raises:
            Propagates whatever SentenceTransformer(...) or
            client.get_collection(...) raise, e.g. when the collection is missing.
        """
        if chroma_path is None:
            chroma_path = str(CHROMA_DB_PATH)
        if collection_name is None:
            collection_name = CHROMA_COLLECTION_NAME
        if model_path is None:
            model_path = resolve_model_path()
        if use_fp16 is None:
            use_fp16 = USE_FP16

        # A Hub model id written with Windows separators is not a local path;
        # turn "\" into "/" so the Hub lookup can resolve it.
        model_path_str = str(model_path)
        if "\\" in model_path_str and not os.path.exists(model_path_str):
            model_path_str = model_path_str.replace("\\", "/")

        # Load the embedding model.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"🧠 加载模型: {model_path_str}")
        print(f" 设备: {device}")
        self.model = SentenceTransformer(model_path_str, device=device)
        if device == "cuda" and use_fp16:
            self.model = self.model.half()
            torch.cuda.empty_cache()
            mem = torch.cuda.memory_allocated(0) / 1024**3
            total = torch.cuda.get_device_properties(0).total_memory / 1024**3
            print(f" FP16 已启用, 显存: {mem:.2f} GB / {total:.2f} GB")

        # Connect to Chroma; get_collection does not create a missing collection.
        print(f"💾 连接 Chroma: {chroma_path}")
        self.client = chromadb.PersistentClient(path=chroma_path)
        self.collection = self.client.get_collection(collection_name)
        print(f" 集合 '{collection_name}': {self.collection.count()} 条记录\n")

    def search(self, query: str, n_results: int = 5, filter_meta: dict = None):
        """Embed *query* and return the *n_results* nearest chunks.

        Args:
            query: Natural-language query text.
            n_results: Number of neighbours to request from Chroma.
            filter_meta: Optional Chroma `where` clause. A falsy value (None or
                an empty dict) means no filter.

        Returns:
            The dict returned by collection.query, holding ids, documents,
            metadatas and distances. Each value is a list with one inner list,
            because exactly one query embedding is sent.
        """
        query_embedding = self.model.encode(
            query, normalize_embeddings=True, show_progress_bar=False
        ).tolist()
        return self.collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            where=filter_meta if filter_meta else None,
            include=["documents", "metadatas", "distances"],
        )

    def search_with_threshold(self, query: str, n_results: int = 5,
                              threshold: float = 0.3, filter_meta: dict = None):
        """Like search(), but drop hits whose distance is greater than *threshold*.

        *threshold* is a maximum Chroma **distance**: smaller values are stricter.
        It is not a similarity. A caller holding a similarity cut-off ``s``, as
        shown by format_result (``1 - distance``), should pass ``1 - s``.

        Returns:
            A dict with the same ids/documents/metadatas/distances shape as
            search(), containing only the hits that pass the threshold.
        """
        results = self.search(query, n_results, filter_meta)
        filtered = {"ids": [[]], "documents": [[]], "metadatas": [[]], "distances": [[]]}
        hits = zip(results["ids"][0], results["documents"][0],
                   results["metadatas"][0], results["distances"][0])
        for doc_id, doc, meta, dist in hits:
            if dist <= threshold:
                filtered["ids"][0].append(doc_id)
                filtered["documents"][0].append(doc)
                filtered["metadatas"][0].append(meta)
                filtered["distances"][0].append(dist)
        return filtered

    def format_result(self, results: dict) -> str:
        """Render a search()/search_with_threshold() result as readable text.

        Similarity is shown as ``1 - distance``. Document text is truncated to
        300 characters.
        """
        n = len(results["ids"][0])
        lines = [f"找到 {n} 条结果:"]
        hits = zip(results["ids"][0], results["documents"][0],
                   results["distances"][0], results["metadatas"][0])
        for i, (doc_id, doc, dist, meta) in enumerate(hits, start=1):
            # Chroma can return None for a record stored without metadata/document.
            meta = meta or {}
            chunk_type = meta.get("chunk_type", "N/A")
            report = meta.get("report_name", "")
            band = meta.get("band_name", "")
            lines.append(f"\n--- 结果 {i} (相似度={1-dist:.4f}, id={doc_id}) ---")
            lines.append(f"类型: {chunk_type}")
            if report:
                lines.append(f"报表: {report}")
            if band:
                lines.append(f"区域: {band}")
            lines.append(f"内容: {(doc or '')[:300]}")
        return "\n".join(lines)


# Upper bound applied to the per-query result count set via `k:<n>` in interactive mode.
_MAX_INTERACTIVE_RESULTS = 50

_INTERACTIVE_HELP = """
特殊命令:
  filter:<类型>   按 chunk_type 过滤 (如 filter:query)
  t:<阈值>        设置相似度阈值 0~1 (如 t:0.5)
  k:<数量>        设置返回结果数 (如 k:10)

示例:
  filter:field 数据源字段有哪些
  t:0.5 band:title 标题区域怎么设置
  k:10 报表参数定义
"""


def _timed_search(searcher, query, n_results, similarity, filter_meta):
    """Run one query and return ``(results, elapsed_seconds)``.

    *similarity* is the user-facing threshold: 0~1, higher is stricter. None
    means no threshold.
    """
    start = time.perf_counter()
    if similarity is None:
        results = searcher.search(query, n_results, filter_meta)
    else:
        # format_result shows similarity as 1 - distance, so "similarity >= s"
        # is "distance <= 1 - s".
        # NOTE(review): 1 - distance is a true cosine similarity only if the
        # collection was created with cosine space. Chroma's default is l2, and
        # the space is set by the importer, which is not visible here. Confirm.
        results = searcher.search_with_threshold(
            query, n_results, 1.0 - similarity, filter_meta
        )
    return results, time.perf_counter() - start


def _parse_interactive_input(user_input, filter_meta, n_results, threshold):
    """Separate the special tokens (filter:, t:, k:) from an interactive input line.

    The arguments are this query's defaults. Tokens override them for this
    query only.

    Returns:
        ``(query_text, filter_meta, n_results, threshold)``.
    """
    words = []
    for token in user_input.split():
        if token.startswith("filter:"):
            filter_meta = {"chunk_type": token[len("filter:"):]}
            print(f" 📌 过滤: {filter_meta}")
        elif token.startswith("t:"):
            try:
                value = float(token[2:])
            except ValueError:
                print(f" ⚠️ 无效阈值: {token}")
                continue
            # The comparison is False for NaN, so NaN is rejected too.
            if not 0.0 <= value <= 1.0:
                print(f" ⚠️ 阈值需在 0~1 之间: {token}")
                continue
            threshold = value
            print(f" 📌 阈值: {threshold}")
        elif token.startswith("k:"):
            try:
                value = int(token[2:])
            except ValueError:
                print(f" ⚠️ 无效数量: {token}")
                continue
            n_results = max(1, min(value, _MAX_INTERACTIVE_RESULTS))
            print(f" 📌 返回数量: {n_results}")
        else:
            # NOTE(review): the help and examples show "band:title", but no
            # "band:" prefix is parsed, so such tokens become query text.
            # Confirm whether a band_name filter was intended.
            words.append(token)
    return " ".join(words), filter_meta, n_results, threshold


def main():
    """Command-line entry point.

    Runs a single query if a positional query is given. Otherwise it starts
    the interactive loop.

    Returns:
        Process exit status: 0 on success, 1 if setup fails.
    """
    import argparse

    parser = argparse.ArgumentParser(description="JRXML Chunks 语义搜索工具")
    parser.add_argument("query", nargs="?", default="",
                        help="搜索关键词(不提供则进入交互模式)")
    parser.add_argument("--chroma_path", "-c", default=None,
                        help=f"Chroma 数据库路径 (默认: {CHROMA_DB_PATH})")
    parser.add_argument("--collection", "-n", default=CHROMA_COLLECTION_NAME,
                        help="集合名称")
    parser.add_argument("--model_path", "-m", default=None, help="嵌入模型路径")
    parser.add_argument("--n_results", "-k", type=int, default=DEFAULT_N_RESULTS,
                        help=f"返回结果数 (默认: {DEFAULT_N_RESULTS})")
    parser.add_argument("--filter_field", "-f",
                        help="按 chunk_type 过滤,例如: field, query, chart")
    parser.add_argument("--threshold", "-t", type=float,
                        help="相似度阈值 (0~1, 越高越相似)")
    parser.add_argument("--no_fp16", action="store_true", help="禁用 FP16 半精度")
    args = parser.parse_args()

    if args.n_results < 1:
        parser.error("--n_results 必须 >= 1")
    if args.threshold is not None and not 0.0 <= args.threshold <= 1.0:
        parser.error("--threshold 必须在 0~1 之间")
    if args.chroma_path is None:
        args.chroma_path = str(CHROMA_DB_PATH)

    # Check that the database exists.
    if not os.path.exists(args.chroma_path):
        print(f"❌ Chroma 数据库不存在: {args.chroma_path}", file=sys.stderr)
        print(" 请先运行 import_to_chroma.py 导入数据", file=sys.stderr)
        return 1

    # Create the searcher. Passing model_path=None lets the constructor call
    # resolve_model_path(). use_fp16=None defers to USE_FP16 from config; only
    # an explicit --no_fp16 overrides it.
    try:
        searcher = JRXMLSearcher(
            chroma_path=args.chroma_path,
            collection_name=args.collection,
            model_path=args.model_path,
            use_fp16=False if args.no_fp16 else None,
        )
    except Exception as e:
        print(f"❌ 初始化失败: {e}", file=sys.stderr)
        return 1

    filter_meta = {"chunk_type": args.filter_field} if args.filter_field else None

    # One-shot mode.
    if args.query:
        query = args.query
        print(f"\n🔍 搜索: '{query}'")
        if filter_meta:
            print(f" 过滤: {filter_meta}")
        results, elapsed = _timed_search(
            searcher, query, args.n_results, args.threshold, filter_meta
        )
        print(searcher.format_result(results))
        print(f"\n⏱️ 耗时: {elapsed:.2f}s")
        return 0

    # Interactive mode.
    print(f"\n{'='*60}")
    print("JRXML 语义搜索 - 交互模式")
    print(f"{'='*60}")
    print("可用过滤类型: report_overview, query, field, parameter,")
    print(" variable, band_*, chart, crosstab, subreport, style 等")
    print("示例: '如何修改报表标题'")
    print(" 'filter:query SQL数据源查询'")
    print(" 't:0.5 band:title 标题区域'")
    print("输入 'help' 查看帮助, 'exit' 退出\n")

    while True:
        try:
            user_input = input("🔍 搜索> ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\n👋 再见!")
            break

        if not user_input:
            continue
        command = user_input.lower()
        if command in ("exit", "quit", "q"):
            print("👋 再见!")
            break
        if command == "help":
            print(_INTERACTIVE_HELP)
            continue

        query_text, cur_filter, cur_n, cur_threshold = _parse_interactive_input(
            user_input, filter_meta, args.n_results, args.threshold
        )
        if not query_text:
            print(" ⚠️ 请输入搜索内容")
            continue

        print(f"🔍 搜索: '{query_text}'")
        # Report a failed query instead of letting the error end the session.
        try:
            results, elapsed = _timed_search(
                searcher, query_text, cur_n, cur_threshold, cur_filter
            )
        except Exception as e:
            print(f"❌ 搜索失败: {e}\n")
            continue
        print(searcher.format_result(results))
        print(f"⏱️ 耗时: {elapsed:.2f}s\n")

    return 0


if __name__ == "__main__":
    sys.exit(main())