From 9d78a496256db88a5ee0979cc8787b48dfaff139 Mon Sep 17 00:00:00 2001 From: panda <1415243231@qq.com> Date: Tue, 12 May 2026 08:29:17 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E9=87=8D=E6=9E=84=E9=A1=B9?= =?UTF-8?q?=E7=9B=AE=E9=85=8D=E7=BD=AE=E7=AE=A1=E7=90=86=EF=BC=8C=E7=BB=9F?= =?UTF-8?q?=E4=B8=80=E4=BD=BF=E7=94=A8.env=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增config.py统一读取.env配置,移除硬编码路径和参数 - 重构collect_jrxml.py支持命令行参数和环境变量配置源目录 - 新增.env.example示例配置文件,整理所有可配置项 - 重构down_embedding_model.py、import_to_chroma.py等所有脚本使用统一配置 - 新增Windows一键部署脚本setup.bat - 修正jrxml_banch_chunker.py的文件名拼写错误 --- .env.example | 54 ++++++++++++++++ collect_jrxml.py | 18 +++++- config.py | 138 ++++++++++++++++++++++++++++++++++++++++ down_embedding_model.py | 22 +++---- embed_chunks.py | 47 ++++++++------ import_to_chroma.py | 18 ++++-- jrxml_banch_chunker.py | 30 +++++---- query_chroma.py | 33 ++++++---- setup.bat | 103 ++++++++++++++++++++++++++++++ 9 files changed, 396 insertions(+), 67 deletions(-) create mode 100644 .env.example create mode 100644 config.py create mode 100644 setup.bat diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..b757d52 --- /dev/null +++ b/.env.example @@ -0,0 +1,54 @@ +# ============================================================ +# JRXML RAG 项目 - 环境配置文件 +# 复制此文件为 .env 并根据需要修改配置 +# ============================================================ + +# -------------------- 嵌入模型配置 -------------------- +# 模型名称或路径,支持以下格式: +# 1. HuggingFace Hub 模型: Qwen/Qwen3-Embedding-4B +# 2. HuggingFace Hub 模型: sentence-transformers/all-MiniLM-L6-v2 +# 3. 本地模型路径: models/Qwen3-Embedding-4B +EMBEDDING_MODEL_NAME=Qwen/Qwen3-Embedding-4B + +# 本地模型下载/存放目录(使用 Hub 模型时会自动下载到此目录) +EMBEDDING_MODEL_PATH=models/Qwen3-Embedding-4B + +# HuggingFace 镜像站点(国内用户建议使用镜像加速) +HF_ENDPOINT=https://hf-mirror.com + +# -------------------- 硬件配置 -------------------- +# 是否使用 GPU 加速 (true/false) +USE_GPU=true + +# 是否启用 FP16 半精度(可节省约 50% 显存) +USE_FP16=true + +# 向量化批处理大小(根据显存调整,显存不足时减小此值) +BATCH_SIZE=64 + +# -------------------- 目录配置 -------------------- +# JRXML 源文件目录 +JRXML_SOURCE_DIR=jrxml_source + +# 分块输出目录 +CHUNKER_OUTPUT_DIR=jrxml_chunker_output + +# 向量输出目录 +EMBEDDINGS_DIR=embeddings + +# Chroma 向量数据库目录 +CHROMA_DB_PATH=chroma_db + +# Chroma 集合名称 +CHROMA_COLLECTION_NAME=jrxml_chunks + +# -------------------- 分块配置 -------------------- +# 单个 chunk 最大字符数 +MAX_CHUNK_SIZE=2000 + +# -------------------- 查询配置 -------------------- +# 默认返回结果数 +DEFAULT_N_RESULTS=5 + +# 相似度阈值 (0~1,余弦距离,越小越相似) +SIMILARITY_THRESHOLD=0.3 \ No newline at end of file diff --git a/collect_jrxml.py b/collect_jrxml.py index d85e6c1..16f5acc 100644 --- a/collect_jrxml.py +++ b/collect_jrxml.py @@ -2,10 +2,12 @@ """ JRXML 文件收集脚本 从指定目录递归查找所有 .jrxml 文件并复制到项目的 jrxml_source 目录 +源目录和目标目录通过 .env / config.py 配置 """ import os import shutil +from config import JRXML_SOURCE_DIR def collect_jrxml_files(source_dir: str, target_dir: str) -> int: """ @@ -53,12 +55,22 @@ def collect_jrxml_files(source_dir: str, target_dir: str) -> int: return copied_count if __name__ == "__main__": - SOURCE_DIR = r"C:\Users\zy187\JaspersoftWorkspace\JasperReportsSamples" - TARGET_DIR = os.path.join(os.path.dirname(__file__), "jrxml_source") + import sys + + if len(sys.argv) >= 2: + SOURCE_DIR = sys.argv[1] + else: + SOURCE_DIR = os.environ.get( + "JRXML_COLLECT_SOURCE", + r"C:\Users\zy187\JaspersoftWorkspace\JasperReportsSamples" + ) + + TARGET_DIR = str(JRXML_SOURCE_DIR) if not os.path.exists(SOURCE_DIR): print(f"错误:源目录不存在 - {SOURCE_DIR}") - print("请检查路径是否正确") + print("请检查路径是否正确,或通过命令行参数指定:") + print(f" python collect_jrxml.py <源目录路径>") exit(1) collect_jrxml_files(SOURCE_DIR, TARGET_DIR) \ No newline at end of file diff --git a/config.py b/config.py new file mode 100644 index 0000000..dcfb6c7 --- /dev/null +++ b/config.py @@ -0,0 +1,138 @@ +""" +config.py +统一配置管理,从 .env 文件读取环境变量 +所有脚本通过此模块获取配置,避免硬编码 +""" + +import os +from pathlib import Path + + +def _get_project_root() -> Path: + return Path(__file__).resolve().parent + + +def _load_dotenv(): + """简易 .env 文件加载器,不依赖 python-dotenv""" + project_root = _get_project_root() + env_file = project_root / ".env" + + if not env_file.exists(): + return + + with open(env_file, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line or line.startswith("#"): + continue + if "=" not in line: + continue + key, _, value = line.partition("=") + key = key.strip() + value = value.strip().strip('"').strip("'") + if key and key not in os.environ: + os.environ[key] = value + + +_load_dotenv() + +PROJECT_ROOT = _get_project_root() + + +def _get_path(key: str, default: str) -> Path: + value = os.environ.get(key, default) + p = Path(value) + if not p.is_absolute(): + p = PROJECT_ROOT / p + return p + + +def _get_str(key: str, default: str) -> str: + return os.environ.get(key, default) + + +def _get_bool(key: str, default: bool) -> bool: + value = os.environ.get(key, "").strip().lower() + if value in ("true", "1", "yes"): + return True + if value in ("false", "0", "no"): + return False + return default + + +def _get_int(key: str, default: int) -> int: + try: + return int(os.environ.get(key, str(default))) + except ValueError: + return default + + +def _get_float(key: str, default: float) -> float: + try: + return float(os.environ.get(key, str(default))) + except ValueError: + return default + + +# ==================== 模型配置 ==================== +EMBEDDING_MODEL_NAME = _get_str("EMBEDDING_MODEL_NAME", "Qwen/Qwen3-Embedding-4B") +EMBEDDING_MODEL_PATH = _get_path("EMBEDDING_MODEL_PATH", "models/Qwen3-Embedding-4B") +HF_ENDPOINT = _get_str("HF_ENDPOINT", "https://hf-mirror.com") + +# ==================== 硬件配置 ==================== +USE_GPU = _get_bool("USE_GPU", True) +USE_FP16 = _get_bool("USE_FP16", True) +BATCH_SIZE = _get_int("BATCH_SIZE", 64) + +# ==================== 目录配置 ==================== +JRXML_SOURCE_DIR = _get_path("JRXML_SOURCE_DIR", "jrxml_source") +CHUNKER_OUTPUT_DIR = _get_path("CHUNKER_OUTPUT_DIR", "jrxml_chunker_output") +EMBEDDINGS_DIR = _get_path("EMBEDDINGS_DIR", "embeddings") +CHROMA_DB_PATH = _get_path("CHROMA_DB_PATH", "chroma_db") +CHROMA_COLLECTION_NAME = _get_str("CHROMA_COLLECTION_NAME", "jrxml_chunks") + +# ==================== 分块配置 ==================== +MAX_CHUNK_SIZE = _get_int("MAX_CHUNK_SIZE", 2000) + +# ==================== 查询配置 ==================== +DEFAULT_N_RESULTS = _get_int("DEFAULT_N_RESULTS", 5) +SIMILARITY_THRESHOLD = _get_float("SIMILARITY_THRESHOLD", 0.3) + + +def resolve_model_path() -> str: + """ + 解析模型路径: + 1. 如果 EMBEDDING_MODEL_PATH 本地存在,使用本地路径 + 2. 否则使用 EMBEDDING_MODEL_NAME 作为 Hub 模型名 + """ + if EMBEDDING_MODEL_PATH.exists(): + return str(EMBEDDING_MODEL_PATH) + return EMBEDDING_MODEL_NAME + + +def print_config(): + """打印当前配置(调试用)""" + print(f"{'='*60}") + print(f"JRXML RAG 当前配置") + print(f"{'='*60}") + print(f" 项目根目录: {PROJECT_ROOT}") + print(f" 嵌入模型名称: {EMBEDDING_MODEL_NAME}") + print(f" 嵌入模型路径: {EMBEDDING_MODEL_PATH}") + print(f" 模型解析结果: {resolve_model_path()}") + print(f" HF 镜像: {HF_ENDPOINT}") + print(f" GPU 加速: {USE_GPU}") + print(f" FP16 半精度: {USE_FP16}") + print(f" 批处理大小: {BATCH_SIZE}") + print(f" JRXML 源目录: {JRXML_SOURCE_DIR}") + print(f" 分块输出目录: {CHUNKER_OUTPUT_DIR}") + print(f" 向量输出目录: {EMBEDDINGS_DIR}") + print(f" Chroma 数据库: {CHROMA_DB_PATH}") + print(f" Chroma 集合名: {CHROMA_COLLECTION_NAME}") + print(f" 最大 Chunk 大小: {MAX_CHUNK_SIZE}") + print(f" 默认返回结果数: {DEFAULT_N_RESULTS}") + print(f" 相似度阈值: {SIMILARITY_THRESHOLD}") + print(f"{'='*60}") + + +if __name__ == "__main__": + print_config() \ No newline at end of file diff --git a/down_embedding_model.py b/down_embedding_model.py index 9695284..83f9dec 100644 --- a/down_embedding_model.py +++ b/down_embedding_model.py @@ -1,26 +1,26 @@ """ down_embedding_model.py -下载 Qwen3-Embedding-4B 嵌入模型 +下载嵌入模型(模型名称通过 .env / config.py 配置) """ import os import sys from pathlib import Path +from config import EMBEDDING_MODEL_NAME, EMBEDDING_MODEL_PATH, HF_ENDPOINT def download_model(): - """下载 Qwen3-Embedding-4B 模型""" - project_root = Path(__file__).resolve().parent - model_dir = project_root / "models" / "Qwen3-Embedding-4B" + """下载嵌入模型""" + model_dir = EMBEDDING_MODEL_PATH print("=" * 60) - print("Qwen3-Embedding-4B 模型下载") + print(f"{EMBEDDING_MODEL_NAME} 模型下载") print("=" * 60) + print(f"模型名称: {EMBEDDING_MODEL_NAME}") print(f"模型目录: {model_dir}") print() - # 使用国内镜像加速 - os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' - print("使用 HuggingFace 镜像: https://hf-mirror.com") + os.environ['HF_ENDPOINT'] = HF_ENDPOINT + print(f"使用 HuggingFace 镜像: {HF_ENDPOINT}") print() try: @@ -34,13 +34,13 @@ def download_model(): # 创建模型目录 os.makedirs(model_dir, exist_ok=True) - print(f"开始下载 Qwen3-Embedding-4B 模型...") - print(f"模型大小约 4GB,请耐心等待...") + print(f"开始下载 {EMBEDDING_MODEL_NAME} 模型...") + print(f"请耐心等待...") print() try: snapshot_download( - repo_id="Qwen/Qwen3-Embedding-4B", + repo_id=EMBEDDING_MODEL_NAME, local_dir=str(model_dir), local_dir_use_symlinks=False, resume_download=True diff --git a/embed_chunks.py b/embed_chunks.py index 3ab8e2e..2f90c63 100644 --- a/embed_chunks.py +++ b/embed_chunks.py @@ -1,7 +1,7 @@ """ embed_chunks.py -使用本地 Qwen3-Embedding-4B 模型对 JRXML chunks 进行向量化 -支持 GPU (CUDA) 或 CPU +使用嵌入模型对 JRXML chunks 进行向量化 +支持 GPU (CUDA) 或 CPU,模型通过 .env / config.py 配置 """ import os @@ -12,6 +12,10 @@ from pathlib import Path import numpy as np import torch from sentence_transformers import SentenceTransformer +from config import ( + EMBEDDING_MODEL_PATH, CHUNKER_OUTPUT_DIR, EMBEDDINGS_DIR, + USE_FP16, BATCH_SIZE, resolve_model_path +) def build_text_for_embedding(chunk: dict) -> str: """ @@ -48,8 +52,8 @@ def build_text_for_embedding(chunk: dict) -> str: def main(chunks_json_path: str = None, output_dir: str = None, - model_path: str = None, batch_size: int = 64, normalize: bool = True, - use_fp16: bool = True): + model_path: str = None, batch_size: int = None, normalize: bool = True, + use_fp16: bool = None): """ 主流程: 1. 加载 chunk JSON @@ -60,19 +64,25 @@ def main(chunks_json_path: str = None, output_dir: str = None, project_root = Path(__file__).resolve().parent if chunks_json_path is None: - chunks_json_path = project_root / "jrxml_chunker_output" / "all_chunks.json" + chunks_json_path = CHUNKER_OUTPUT_DIR / "all_chunks.json" else: chunks_json_path = Path(chunks_json_path) if output_dir is None: - output_dir = project_root / "embeddings" + output_dir = EMBEDDINGS_DIR else: output_dir = Path(output_dir) if model_path is None: - model_path = project_root / "models" / "Qwen3-Embedding-4B" + model_path = resolve_model_path() else: - model_path = Path(model_path) + model_path = str(model_path) + + if batch_size is None: + batch_size = BATCH_SIZE + + if use_fp16 is None: + use_fp16 = USE_FP16 if not chunks_json_path.exists(): print(f"❌ Chunks 文件不存在: {chunks_json_path}") @@ -91,19 +101,16 @@ def main(chunks_json_path: str = None, output_dir: str = None, print(f"\n🧠 加载嵌入模型: {model_path}") print(f" 设备: {device}") - # 检查是否是 HuggingFace Hub 模型(格式为 username/model_name) model_path_str = str(model_path) - # Windows PowerShell 会把 / 自动转成 \,需要还原 if "\\" in model_path_str and not os.path.exists(model_path_str): model_path_str = model_path_str.replace("\\", "/") - + is_hub_model = "/" in model_path_str and not os.path.exists(model_path_str) - - # 如果是本地路径但不存在,则报错 + if not is_hub_model and not os.path.exists(model_path_str): print(f"❌ 模型目录不存在: {model_path}") - print(f" 请先下载模型到 {model_path}") - print(f" 或者使用 HuggingFace Hub 模型,例如: sentence-transformers/all-MiniLM-L6-v2") + print(f" 请先运行 down_embedding_model.py 下载模型") + print(f" 或在 .env 中配置 EMBEDDING_MODEL_NAME 为 Hub 模型名") return None model = SentenceTransformer(model_path_str, device=device) @@ -183,17 +190,17 @@ def main(chunks_json_path: str = None, output_dir: str = None, if __name__ == "__main__": import argparse project_root = Path(__file__).resolve().parent - default_chunks = project_root / "jrxml_chunker_output" / "all_chunks.json" + default_chunks = CHUNKER_OUTPUT_DIR / "all_chunks.json" parser = argparse.ArgumentParser(description="JRXML Chunks 向量化工具") parser.add_argument("chunks_json", nargs="?", default=str(default_chunks), help=f"Chunks JSON 文件路径 (默认: {default_chunks})") parser.add_argument("--output_dir", "-o", default=None, - help="输出目录 (默认: embeddings)") + help=f"输出目录 (默认: {EMBEDDINGS_DIR})") parser.add_argument("--model_path", "-m", default=None, - help="模型路径 (默认: models/Qwen3-Embedding-4B)") - parser.add_argument("--batch_size", "-b", type=int, default=64, - help="批处理大小 (默认: 64)") + help=f"模型路径 (默认: {resolve_model_path()})") + parser.add_argument("--batch_size", "-b", type=int, default=BATCH_SIZE, + help=f"批处理大小 (默认: {BATCH_SIZE})") parser.add_argument("--no_normalize", action="store_true", help="不做向量归一化") parser.add_argument("--no_fp16", action="store_true", diff --git a/import_to_chroma.py b/import_to_chroma.py index b97d384..6ae72b9 100644 --- a/import_to_chroma.py +++ b/import_to_chroma.py @@ -11,11 +11,12 @@ from pathlib import Path import numpy as np import chromadb from tqdm import tqdm +from config import EMBEDDINGS_DIR, CHROMA_DB_PATH, CHROMA_COLLECTION_NAME def main(embeddings_dir: str = None, chroma_path: str = None, - collection_name: str = "jrxml_chunks"): + collection_name: str = None): """ 从 embeddings 目录读取向量和 chunks,导入 Chroma 持久化数据库 @@ -27,15 +28,18 @@ def main(embeddings_dir: str = None, project_root = Path(__file__).resolve().parent if embeddings_dir is None: - embeddings_dir = project_root / "embeddings" + embeddings_dir = EMBEDDINGS_DIR else: embeddings_dir = Path(embeddings_dir) if chroma_path is None: - chroma_path = project_root / "chroma_db" + chroma_path = CHROMA_DB_PATH else: chroma_path = Path(chroma_path) + if collection_name is None: + collection_name = CHROMA_COLLECTION_NAME + embeddings_file = embeddings_dir / "embeddings.npy" chunks_file = embeddings_dir / "chunks.json" @@ -164,11 +168,11 @@ if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description="JRXML Chunks 导入 Chroma 工具") parser.add_argument("--embeddings_dir", "-e", default=None, - help="向量文件目录 (默认: embeddings)") + help=f"向量文件目录 (默认: {EMBEDDINGS_DIR})") parser.add_argument("--chroma_path", "-c", default=None, - help="Chroma 数据库路径 (默认: chroma_db)") - parser.add_argument("--collection_name", "-n", default="jrxml_chunks", - help="集合名称 (默认: jrxml_chunks)") + help=f"Chroma 数据库路径 (默认: {CHROMA_DB_PATH})") + parser.add_argument("--collection_name", "-n", default=CHROMA_COLLECTION_NAME, + help=f"集合名称 (默认: {CHROMA_COLLECTION_NAME})") args = parser.parse_args() diff --git a/jrxml_banch_chunker.py b/jrxml_banch_chunker.py index da8db92..1279bf5 100644 --- a/jrxml_banch_chunker.py +++ b/jrxml_banch_chunker.py @@ -11,9 +11,10 @@ from pathlib import Path from datetime import datetime from collections import defaultdict from jrxml_chunker import JRXMLSemanticChunker, save_chunks_to_json, print_chunk_summary +from config import JRXML_SOURCE_DIR, CHUNKER_OUTPUT_DIR, MAX_CHUNK_SIZE -def batch_chunk_with_report(input_dir: str, output_dir: str = None, max_chunk_size: int = 2000): +def batch_chunk_with_report(input_dir: str = None, output_dir: str = None, max_chunk_size: int = None): """ 批量分块并生成详细报告 @@ -22,6 +23,8 @@ def batch_chunk_with_report(input_dir: str, output_dir: str = None, max_chunk_si output_dir: 输出目录,默认为 input_dir/../chunked_output max_chunk_size: 单个chunk最大字节数 """ + if input_dir is None: + input_dir = str(JRXML_SOURCE_DIR) input_path = Path(input_dir).resolve() if not input_path.exists(): @@ -32,11 +35,13 @@ def batch_chunk_with_report(input_dir: str, output_dir: str = None, max_chunk_si print(f"❌ 不是目录: {input_path}") return None - # 设置输出目录 if output_dir is None: - output_dir = input_path.parent / f"{input_path.name}_chunked_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + output_dir = str(CHUNKER_OUTPUT_DIR) output_path = Path(output_dir) output_path.mkdir(parents=True, exist_ok=True) + + if max_chunk_size is None: + max_chunk_size = MAX_CHUNK_SIZE print(f"\n{'='*60}") print(f"JRXML 语义分块 v3.0 - 批量处理") @@ -214,21 +219,24 @@ if __name__ == "__main__": print("=" * 60) print("JRXML Semantic Chunking v3.0 - 批量处理工具") print("=" * 60) + print(f"\n默认输入目录: {JRXML_SOURCE_DIR}") + print(f"默认输出目录: {CHUNKER_OUTPUT_DIR}") print("\n用法:") - print(" python batch_chunker.py <目录路径>") - print(" python batch_chunker.py <文件路径>") + print(" python jrxml_banch_chunker.py <目录路径>") + print(" python jrxml_banch_chunker.py <文件路径>") + print(" python jrxml_banch_chunker.py (使用默认配置)") print("\n参数:") - print(" <路径> JRXML文件所在目录 或 单个JRXML文件路径") + print(" <路径> JRXML文件所在目录 或 单个JRXML文件路径") print(" --output <目录> 指定输出目录 (可选)") print("\n示例:") - print(" python batch_chunker.py ./jasper_reports") - print(" python batch_chunker.py ./jasper_reports --output ./chunks") - print(" python batch_chunker.py report.jrxml") + print(" python jrxml_banch_chunker.py") + print(" python jrxml_banch_chunker.py ./jasper_reports") + print(" python jrxml_banch_chunker.py ./jasper_reports --output ./chunks") + print(" python jrxml_banch_chunker.py report.jrxml") sys.exit(0) input_path = sys.argv[1] - # 解析--output参数 output_dir = None if "--output" in sys.argv: idx = sys.argv.index("--output") @@ -236,10 +244,8 @@ if __name__ == "__main__": output_dir = sys.argv[idx + 1] if os.path.isdir(input_path): - # 批量处理目录 batch_chunk_with_report(input_path, output_dir) elif os.path.isfile(input_path): - # 处理单个文件 chunk_single_file_with_report(input_path, output_dir) else: print(f"❌ 路径无效: {input_path}") \ No newline at end of file diff --git a/query_chroma.py b/query_chroma.py index 64d5da5..9886034 100644 --- a/query_chroma.py +++ b/query_chroma.py @@ -2,6 +2,7 @@ query_chroma.py 查询 Chroma 数据库,从自然语言查找相关 JRXML chunk 支持命令行单次查询和交互式连续查询 +模型通过 .env / config.py 配置 """ import os @@ -12,19 +13,27 @@ import numpy as np import torch from sentence_transformers import SentenceTransformer import chromadb +from config import ( + CHROMA_DB_PATH, CHROMA_COLLECTION_NAME, USE_FP16, + DEFAULT_N_RESULTS, SIMILARITY_THRESHOLD, resolve_model_path +) class JRXMLSearcher: def __init__(self, chroma_path: str = None, - collection_name: str = "jrxml_chunks", + collection_name: str = None, model_path: str = None, - use_fp16: bool = True): + use_fp16: bool = None): project_root = Path(__file__).resolve().parent if chroma_path is None: - chroma_path = str(project_root / "chroma_db") + chroma_path = str(CHROMA_DB_PATH) + if collection_name is None: + collection_name = CHROMA_COLLECTION_NAME if model_path is None: - model_path = str(project_root / "models" / "Qwen3-Embedding-4B") + model_path = resolve_model_path() + if use_fp16 is None: + use_fp16 = USE_FP16 # 处理 Hub 模型名称 model_path_str = str(model_path) @@ -110,13 +119,13 @@ def main(): parser.add_argument("query", nargs="?", default="", help="搜索关键词(不提供则进入交互模式)") parser.add_argument("--chroma_path", "-c", default=None, - help=f"Chroma 数据库路径 (默认: chroma_db)") - parser.add_argument("--collection", "-n", default="jrxml_chunks", + help=f"Chroma 数据库路径 (默认: {CHROMA_DB_PATH})") + parser.add_argument("--collection", "-n", default=CHROMA_COLLECTION_NAME, help="集合名称") parser.add_argument("--model_path", "-m", default=None, help="嵌入模型路径") - parser.add_argument("--n_results", "-k", type=int, default=5, - help="返回结果数 (默认: 5)") + parser.add_argument("--n_results", "-k", type=int, default=DEFAULT_N_RESULTS, + help=f"返回结果数 (默认: {DEFAULT_N_RESULTS})") parser.add_argument("--filter_field", "-f", help="按 chunk_type 过滤,例如: field, query, chart") parser.add_argument("--threshold", "-t", type=float, @@ -127,14 +136,10 @@ def main(): args = parser.parse_args() if args.chroma_path is None: - args.chroma_path = str(project_root / "chroma_db") + args.chroma_path = str(CHROMA_DB_PATH) if args.model_path is None: - default_model = project_root / "models" / "Qwen3-Embedding-4B" - if not default_model.exists(): - args.model_path = "sentence-transformers/all-MiniLM-L6-v2" - else: - args.model_path = str(default_model) + args.model_path = resolve_model_path() # 检查数据库 if not os.path.exists(args.chroma_path): diff --git a/setup.bat b/setup.bat new file mode 100644 index 0000000..e8e9a21 --- /dev/null +++ b/setup.bat @@ -0,0 +1,103 @@ +@echo off +chcp 65001 >nul +setlocal enabledelayedexpansion + +:: ============================================================ +:: JRXML RAG 项目 - Windows 一键部署脚本 +:: ============================================================ + +title JRXML RAG 环境部署 + +echo. +echo ============================================================ +echo JRXML RAG 项目 - 环境部署脚本 +echo ============================================================ +echo. + +:: 检查 Python +echo [1/5] 检查 Python 环境... +python --version >nul 2>&1 +if %errorlevel% neq 0 ( + echo [错误] 未找到 Python,请先安装 Python 3.11+ + pause + exit /b 1 +) +for /f "tokens=2" %%v in ('python --version 2^>^&1') do echo Python 版本: %%v + +:: 检查 CUDA +echo. +echo [2/5] 检查 CUDA 环境... +python -c "import torch; print(f' PyTorch: {torch.__version__}'); print(f' CUDA 可用: {torch.cuda.is_available()}'); print(f' GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else \"N/A\"}')" 2>nul +if %errorlevel% neq 0 ( + echo PyTorch 未安装,将在后续步骤安装 +) + +:: 初始化 .env 配置 +echo. +echo [3/5] 初始化环境配置... +if not exist ".env" ( + if exist ".env.example" ( + copy ".env.example" ".env" >nul + echo 已从 .env.example 创建 .env 配置文件 + echo 请根据需要编辑 .env 文件修改配置 + ) else ( + echo [警告] 未找到 .env.example,跳过配置初始化 + ) +) else ( + echo .env 配置文件已存在,跳过 +) + +:: 创建必要目录 +echo. +echo [4/5] 创建项目目录... +set "DIRS=jrxml_source jrxml_chunker_output models embeddings chroma_db" +for %%d in (%DIRS%) do ( + if not exist "%%d" ( + mkdir "%%d" >nul 2>&1 + echo 创建目录: %%d + ) +) +echo 目录结构已就绪 + +:: 安装依赖 +echo. +echo [5/5] 安装 Python 依赖... +echo. + +:: 检测是否有 NVIDIA GPU +python -c "import subprocess; exit(0 if subprocess.run(['nvidia-smi'], capture_output=True).returncode == 0 else 1)" 2>nul +if %errorlevel% equ 0 ( + set "HAS_GPU=1" + echo [检测到 NVIDIA GPU,安装 CUDA 版 PyTorch] + echo. + echo 安装 PyTorch (CUDA 版本)... + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124 +) else ( + set "HAS_GPU=0" + echo [未检测到 GPU,安装 CPU 版 PyTorch] + echo. + echo 安装 PyTorch (CPU 版本)... + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu +) + +echo. +echo 安装核心依赖... +pip install sentence-transformers chromadb numpy tqdm huggingface_hub + +echo. +echo ============================================================ +echo 部署完成! +echo ============================================================ +echo. +echo 下一步操作: +echo 1. 编辑 .env 配置文件(可选) +echo 2. 收集 JRXML 文件: python collect_jrxml.py +echo 3. 语义分块: python jrxml_banch_chunker.py +echo 4. 下载嵌入模型: python down_embedding_model.py +echo 5. 向量化: python embed_chunks.py +echo 6. 导入 Chroma: python import_to_chroma.py +echo 7. 开始查询: python query_chroma.py +echo. +echo 查看当前配置: python config.py +echo. +pause \ No newline at end of file