refactor: 重构项目配置管理,统一使用.env配置

- 新增config.py统一读取.env配置,移除硬编码路径和参数
- 重构collect_jrxml.py支持命令行参数和环境变量配置源目录
- 新增.env.example示例配置文件,整理所有可配置项
- 重构down_embedding_model.py、import_to_chroma.py等所有脚本使用统一配置
- 新增Windows一键部署脚本setup.bat
- 修正jrxml_banch_chunker.py的文件名拼写错误
This commit is contained in:
2026-05-12 08:29:17 +08:00
parent bd98486de0
commit 9d78a49625
9 changed files with 396 additions and 67 deletions
+54
View File
@@ -0,0 +1,54 @@
# ============================================================
# JRXML RAG 项目 - 环境配置文件
# 复制此文件为 .env 并根据需要修改配置
# ============================================================
# -------------------- 嵌入模型配置 --------------------
# 模型名称或路径,支持以下格式:
# 1. HuggingFace Hub 模型: Qwen/Qwen3-Embedding-4B
# 2. HuggingFace Hub 模型: sentence-transformers/all-MiniLM-L6-v2
# 3. 本地模型路径: models/Qwen3-Embedding-4B
EMBEDDING_MODEL_NAME=Qwen/Qwen3-Embedding-4B
# 本地模型下载/存放目录(down_embedding_model.py 会把模型下载到此目录;该目录存在时优先于 EMBEDDING_MODEL_NAME 被加载)
EMBEDDING_MODEL_PATH=models/Qwen3-Embedding-4B
# HuggingFace 镜像站点(国内用户建议使用镜像加速)
HF_ENDPOINT=https://hf-mirror.com
# -------------------- 硬件配置 --------------------
# 是否使用 GPU 加速 (true/false)
USE_GPU=true
# 是否启用 FP16 半精度(可节省约 50% 显存)
USE_FP16=true
# 向量化批处理大小(根据显存调整,显存不足时减小此值)
BATCH_SIZE=64
# -------------------- 目录配置 --------------------
# JRXML 源文件目录
JRXML_SOURCE_DIR=jrxml_source
# 分块输出目录
CHUNKER_OUTPUT_DIR=jrxml_chunker_output
# 向量输出目录
EMBEDDINGS_DIR=embeddings
# Chroma 向量数据库目录
CHROMA_DB_PATH=chroma_db
# Chroma 集合名称
CHROMA_COLLECTION_NAME=jrxml_chunks
# -------------------- 分块配置 --------------------
# 单个 chunk 最大字符数
MAX_CHUNK_SIZE=2000
# -------------------- 查询配置 --------------------
# 默认返回结果数
DEFAULT_N_RESULTS=5
# 相似度阈值 (0~1,余弦距离,越小越相似)
SIMILARITY_THRESHOLD=0.3
+15 -3
View File
@@ -2,10 +2,12 @@
""" """
JRXML 文件收集脚本 JRXML 文件收集脚本
从指定目录递归查找所有 .jrxml 文件并复制到项目的 jrxml_source 目录 从指定目录递归查找所有 .jrxml 文件并复制到项目的 jrxml_source 目录
源目录和目标目录通过 .env / config.py 配置
""" """
import os import os
import shutil import shutil
from config import JRXML_SOURCE_DIR
def collect_jrxml_files(source_dir: str, target_dir: str) -> int: def collect_jrxml_files(source_dir: str, target_dir: str) -> int:
""" """
@@ -53,12 +55,22 @@ def collect_jrxml_files(source_dir: str, target_dir: str) -> int:
return copied_count return copied_count
if __name__ == "__main__": if __name__ == "__main__":
SOURCE_DIR = r"C:\Users\zy187\JaspersoftWorkspace\JasperReportsSamples" import sys
TARGET_DIR = os.path.join(os.path.dirname(__file__), "jrxml_source")
if len(sys.argv) >= 2:
SOURCE_DIR = sys.argv[1]
else:
SOURCE_DIR = os.environ.get(
"JRXML_COLLECT_SOURCE",
r"C:\Users\zy187\JaspersoftWorkspace\JasperReportsSamples"
)
TARGET_DIR = str(JRXML_SOURCE_DIR)
if not os.path.exists(SOURCE_DIR): if not os.path.exists(SOURCE_DIR):
print(f"错误:源目录不存在 - {SOURCE_DIR}") print(f"错误:源目录不存在 - {SOURCE_DIR}")
print("请检查路径是否正确") print("请检查路径是否正确,或通过命令行参数指定:")
print(f" python collect_jrxml.py <源目录路径>")
exit(1) exit(1)
collect_jrxml_files(SOURCE_DIR, TARGET_DIR) collect_jrxml_files(SOURCE_DIR, TARGET_DIR)
+138
View File
@@ -0,0 +1,138 @@
"""
config.py
统一配置管理,从 .env 文件读取环境变量
所有脚本通过此模块获取配置,避免硬编码
"""
import os
from pathlib import Path
def _get_project_root() -> Path:
    """Return the absolute directory that contains this module file."""
    module_file = Path(__file__).resolve()
    return module_file.parent
def _load_dotenv():
"""简易 .env 文件加载器,不依赖 python-dotenv"""
project_root = _get_project_root()
env_file = project_root / ".env"
if not env_file.exists():
return
with open(env_file, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line or line.startswith("#"):
continue
if "=" not in line:
continue
key, _, value = line.partition("=")
key = key.strip()
value = value.strip().strip('"').strip("'")
if key and key not in os.environ:
os.environ[key] = value
# Populate os.environ from the .env file at import time, before the
# configuration constants further down read their values from it.
_load_dotenv()
# Directory containing this module; _get_path resolves relative paths against it.
PROJECT_ROOT = _get_project_root()
def _get_path(key: str, default: str) -> Path:
    """Read a filesystem path setting from the environment.

    Args:
        key: Environment variable name.
        default: Path used when the variable is unset or blank.

    Returns:
        The configured path. A leading ``~`` is expanded to the user's home
        directory, and a relative path is anchored at ``PROJECT_ROOT``.
    """
    value = os.environ.get(key, default)
    # A blank entry (e.g. "CHROMA_DB_PATH=" in .env) falls back to the
    # default, matching how the bool/int/float getters treat empty values;
    # otherwise Path("") would silently resolve to the project root itself.
    if not value.strip():
        value = default
    # Without expansion "~/models" would be treated as a relative path.
    p = Path(value).expanduser()
    if not p.is_absolute():
        p = PROJECT_ROOT / p
    return p
def _get_str(key: str, default: str) -> str:
    """Look up ``key`` in the process environment, falling back to ``default``."""
    return os.getenv(key, default)
def _get_bool(key: str, default: bool) -> bool:
    """Interpret an environment variable as a boolean flag.

    The value is trimmed and compared case-insensitively: ``true``/``1``/``yes``
    yield True, ``false``/``0``/``no`` yield False. Anything else, including
    an unset or empty variable, yields ``default``.
    """
    recognized = {
        "true": True,
        "1": True,
        "yes": True,
        "false": False,
        "0": False,
        "no": False,
    }
    token = os.getenv(key, "").strip().lower()
    return recognized.get(token, default)
def _get_int(key: str, default: int) -> int:
    """Read an integer setting; return ``default`` when unset or not parseable."""
    raw = os.getenv(key)
    if raw is None:
        raw = str(default)
    try:
        parsed = int(raw)
    except ValueError:
        return default
    return parsed
def _get_float(key: str, default: float) -> float:
    """Read a float setting; return ``default`` when unset or not parseable."""
    raw = os.getenv(key)
    if raw is None:
        raw = str(default)
    try:
        parsed = float(raw)
    except ValueError:
        return default
    return parsed
# ==================== Model settings ====================
# Model name (e.g. a HuggingFace Hub repo id); resolve_model_path() returns it
# when EMBEDDING_MODEL_PATH does not exist locally.
EMBEDDING_MODEL_NAME = _get_str("EMBEDDING_MODEL_NAME", "Qwen/Qwen3-Embedding-4B")
# Local model directory; relative values are anchored at PROJECT_ROOT by _get_path.
EMBEDDING_MODEL_PATH = _get_path("EMBEDDING_MODEL_PATH", "models/Qwen3-Embedding-4B")
# HuggingFace endpoint URL (defaults to a mirror site).
HF_ENDPOINT = _get_str("HF_ENDPOINT", "https://hf-mirror.com")
# ==================== Hardware settings ====================
USE_GPU = _get_bool("USE_GPU", True)
USE_FP16 = _get_bool("USE_FP16", True)
BATCH_SIZE = _get_int("BATCH_SIZE", 64)
# ==================== Directory settings ====================
# All directory settings are Path objects; relative values are resolved
# against PROJECT_ROOT by _get_path.
JRXML_SOURCE_DIR = _get_path("JRXML_SOURCE_DIR", "jrxml_source")
CHUNKER_OUTPUT_DIR = _get_path("CHUNKER_OUTPUT_DIR", "jrxml_chunker_output")
EMBEDDINGS_DIR = _get_path("EMBEDDINGS_DIR", "embeddings")
CHROMA_DB_PATH = _get_path("CHROMA_DB_PATH", "chroma_db")
CHROMA_COLLECTION_NAME = _get_str("CHROMA_COLLECTION_NAME", "jrxml_chunks")
# ==================== Chunking settings ====================
MAX_CHUNK_SIZE = _get_int("MAX_CHUNK_SIZE", 2000)
# ==================== Query settings ====================
DEFAULT_N_RESULTS = _get_int("DEFAULT_N_RESULTS", 5)
SIMILARITY_THRESHOLD = _get_float("SIMILARITY_THRESHOLD", 0.3)
def resolve_model_path() -> str:
    """Choose what to load the embedding model from.

    Returns:
        ``str(EMBEDDING_MODEL_PATH)`` when that path exists locally and is
        not an empty directory; otherwise ``EMBEDDING_MODEL_NAME``, to be
        used as a Hub model name.
    """
    local_path = EMBEDDING_MODEL_PATH
    if local_path.exists():
        # An empty directory holds no model files (for example one created
        # ahead of a download that never completed), so it must not shadow
        # the Hub model name.
        is_empty_dir = local_path.is_dir() and not any(local_path.iterdir())
        if not is_empty_dir:
            return str(local_path)
    return EMBEDDING_MODEL_NAME
def print_config():
    """Print every resolved configuration value to stdout (debugging aid)."""
    rule = "=" * 60
    rows = [
        ("项目根目录", PROJECT_ROOT),
        ("嵌入模型名称", EMBEDDING_MODEL_NAME),
        ("嵌入模型路径", EMBEDDING_MODEL_PATH),
        ("模型解析结果", resolve_model_path()),
        ("HF 镜像", HF_ENDPOINT),
        ("GPU 加速", USE_GPU),
        ("FP16 半精度", USE_FP16),
        ("批处理大小", BATCH_SIZE),
        ("JRXML 源目录", JRXML_SOURCE_DIR),
        ("分块输出目录", CHUNKER_OUTPUT_DIR),
        ("向量输出目录", EMBEDDINGS_DIR),
        ("Chroma 数据库", CHROMA_DB_PATH),
        ("Chroma 集合名", CHROMA_COLLECTION_NAME),
        ("最大 Chunk 大小", MAX_CHUNK_SIZE),
        ("默认返回结果数", DEFAULT_N_RESULTS),
        ("相似度阈值", SIMILARITY_THRESHOLD),
    ]
    print(rule)
    print("JRXML RAG 当前配置")
    print(rule)
    # One "label: value" line per setting, in the same order as the constants.
    for label, value in rows:
        print(f" {label}: {value}")
    print(rule)
# Running this module directly prints the resolved configuration.
if __name__ == "__main__":
    print_config()
+11 -11
View File
@@ -1,26 +1,26 @@
""" """
down_embedding_model.py down_embedding_model.py
下载 Qwen3-Embedding-4B 嵌入模型 下载嵌入模型(模型名称通过 .env / config.py 配置)
""" """
import os import os
import sys import sys
from pathlib import Path from pathlib import Path
from config import EMBEDDING_MODEL_NAME, EMBEDDING_MODEL_PATH, HF_ENDPOINT
def download_model(): def download_model():
"""下载 Qwen3-Embedding-4B 模型""" """下载嵌入模型"""
project_root = Path(__file__).resolve().parent model_dir = EMBEDDING_MODEL_PATH
model_dir = project_root / "models" / "Qwen3-Embedding-4B"
print("=" * 60) print("=" * 60)
print("Qwen3-Embedding-4B 模型下载") print(f"{EMBEDDING_MODEL_NAME} 模型下载")
print("=" * 60) print("=" * 60)
print(f"模型名称: {EMBEDDING_MODEL_NAME}")
print(f"模型目录: {model_dir}") print(f"模型目录: {model_dir}")
print() print()
# 使用国内镜像加速 os.environ['HF_ENDPOINT'] = HF_ENDPOINT
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' print(f"使用 HuggingFace 镜像: {HF_ENDPOINT}")
print("使用 HuggingFace 镜像: https://hf-mirror.com")
print() print()
try: try:
@@ -34,13 +34,13 @@ def download_model():
# 创建模型目录 # 创建模型目录
os.makedirs(model_dir, exist_ok=True) os.makedirs(model_dir, exist_ok=True)
print(f"开始下载 Qwen3-Embedding-4B 模型...") print(f"开始下载 {EMBEDDING_MODEL_NAME} 模型...")
print(f"模型大小约 4GB请耐心等待...") print(f"请耐心等待...")
print() print()
try: try:
snapshot_download( snapshot_download(
repo_id="Qwen/Qwen3-Embedding-4B", repo_id=EMBEDDING_MODEL_NAME,
local_dir=str(model_dir), local_dir=str(model_dir),
local_dir_use_symlinks=False, local_dir_use_symlinks=False,
resume_download=True resume_download=True
+25 -18
View File
@@ -1,7 +1,7 @@
""" """
embed_chunks.py embed_chunks.py
使用本地 Qwen3-Embedding-4B 模型对 JRXML chunks 进行向量化 使用嵌入模型对 JRXML chunks 进行向量化
支持 GPU (CUDA) 或 CPU 支持 GPU (CUDA) 或 CPU,模型通过 .env / config.py 配置
""" """
import os import os
@@ -12,6 +12,10 @@ from pathlib import Path
import numpy as np import numpy as np
import torch import torch
from sentence_transformers import SentenceTransformer from sentence_transformers import SentenceTransformer
from config import (
EMBEDDING_MODEL_PATH, CHUNKER_OUTPUT_DIR, EMBEDDINGS_DIR,
USE_FP16, BATCH_SIZE, resolve_model_path
)
def build_text_for_embedding(chunk: dict) -> str: def build_text_for_embedding(chunk: dict) -> str:
""" """
@@ -48,8 +52,8 @@ def build_text_for_embedding(chunk: dict) -> str:
def main(chunks_json_path: str = None, output_dir: str = None, def main(chunks_json_path: str = None, output_dir: str = None,
model_path: str = None, batch_size: int = 64, normalize: bool = True, model_path: str = None, batch_size: int = None, normalize: bool = True,
use_fp16: bool = True): use_fp16: bool = None):
""" """
主流程: 主流程:
1. 加载 chunk JSON 1. 加载 chunk JSON
@@ -60,19 +64,25 @@ def main(chunks_json_path: str = None, output_dir: str = None,
project_root = Path(__file__).resolve().parent project_root = Path(__file__).resolve().parent
if chunks_json_path is None: if chunks_json_path is None:
chunks_json_path = project_root / "jrxml_chunker_output" / "all_chunks.json" chunks_json_path = CHUNKER_OUTPUT_DIR / "all_chunks.json"
else: else:
chunks_json_path = Path(chunks_json_path) chunks_json_path = Path(chunks_json_path)
if output_dir is None: if output_dir is None:
output_dir = project_root / "embeddings" output_dir = EMBEDDINGS_DIR
else: else:
output_dir = Path(output_dir) output_dir = Path(output_dir)
if model_path is None: if model_path is None:
model_path = project_root / "models" / "Qwen3-Embedding-4B" model_path = resolve_model_path()
else: else:
model_path = Path(model_path) model_path = str(model_path)
if batch_size is None:
batch_size = BATCH_SIZE
if use_fp16 is None:
use_fp16 = USE_FP16
if not chunks_json_path.exists(): if not chunks_json_path.exists():
print(f"❌ Chunks 文件不存在: {chunks_json_path}") print(f"❌ Chunks 文件不存在: {chunks_json_path}")
@@ -91,19 +101,16 @@ def main(chunks_json_path: str = None, output_dir: str = None,
print(f"\n🧠 加载嵌入模型: {model_path}") print(f"\n🧠 加载嵌入模型: {model_path}")
print(f" 设备: {device}") print(f" 设备: {device}")
# 检查是否是 HuggingFace Hub 模型(格式为 username/model_name
model_path_str = str(model_path) model_path_str = str(model_path)
# Windows PowerShell 会把 / 自动转成 \,需要还原
if "\\" in model_path_str and not os.path.exists(model_path_str): if "\\" in model_path_str and not os.path.exists(model_path_str):
model_path_str = model_path_str.replace("\\", "/") model_path_str = model_path_str.replace("\\", "/")
is_hub_model = "/" in model_path_str and not os.path.exists(model_path_str) is_hub_model = "/" in model_path_str and not os.path.exists(model_path_str)
# 如果是本地路径但不存在,则报错
if not is_hub_model and not os.path.exists(model_path_str): if not is_hub_model and not os.path.exists(model_path_str):
print(f"❌ 模型目录不存在: {model_path}") print(f"❌ 模型目录不存在: {model_path}")
print(f" 请先下载模型到 {model_path}") print(f" 请先运行 down_embedding_model.py 下载模型")
print(f"者使用 HuggingFace Hub 模型,例如: sentence-transformers/all-MiniLM-L6-v2") print(f"在 .env 中配置 EMBEDDING_MODEL_NAME 为 Hub 模型名")
return None return None
model = SentenceTransformer(model_path_str, device=device) model = SentenceTransformer(model_path_str, device=device)
@@ -183,17 +190,17 @@ def main(chunks_json_path: str = None, output_dir: str = None,
if __name__ == "__main__": if __name__ == "__main__":
import argparse import argparse
project_root = Path(__file__).resolve().parent project_root = Path(__file__).resolve().parent
default_chunks = project_root / "jrxml_chunker_output" / "all_chunks.json" default_chunks = CHUNKER_OUTPUT_DIR / "all_chunks.json"
parser = argparse.ArgumentParser(description="JRXML Chunks 向量化工具") parser = argparse.ArgumentParser(description="JRXML Chunks 向量化工具")
parser.add_argument("chunks_json", nargs="?", default=str(default_chunks), parser.add_argument("chunks_json", nargs="?", default=str(default_chunks),
help=f"Chunks JSON 文件路径 (默认: {default_chunks})") help=f"Chunks JSON 文件路径 (默认: {default_chunks})")
parser.add_argument("--output_dir", "-o", default=None, parser.add_argument("--output_dir", "-o", default=None,
help="输出目录 (默认: embeddings)") help=f"输出目录 (默认: {EMBEDDINGS_DIR})")
parser.add_argument("--model_path", "-m", default=None, parser.add_argument("--model_path", "-m", default=None,
help="模型路径 (默认: models/Qwen3-Embedding-4B)") help=f"模型路径 (默认: {resolve_model_path()})")
parser.add_argument("--batch_size", "-b", type=int, default=64, parser.add_argument("--batch_size", "-b", type=int, default=BATCH_SIZE,
help="批处理大小 (默认: 64)") help=f"批处理大小 (默认: {BATCH_SIZE})")
parser.add_argument("--no_normalize", action="store_true", parser.add_argument("--no_normalize", action="store_true",
help="不做向量归一化") help="不做向量归一化")
parser.add_argument("--no_fp16", action="store_true", parser.add_argument("--no_fp16", action="store_true",
+11 -7
View File
@@ -11,11 +11,12 @@ from pathlib import Path
import numpy as np import numpy as np
import chromadb import chromadb
from tqdm import tqdm from tqdm import tqdm
from config import EMBEDDINGS_DIR, CHROMA_DB_PATH, CHROMA_COLLECTION_NAME
def main(embeddings_dir: str = None, def main(embeddings_dir: str = None,
chroma_path: str = None, chroma_path: str = None,
collection_name: str = "jrxml_chunks"): collection_name: str = None):
""" """
从 embeddings 目录读取向量和 chunks,导入 Chroma 持久化数据库 从 embeddings 目录读取向量和 chunks,导入 Chroma 持久化数据库
@@ -27,15 +28,18 @@ def main(embeddings_dir: str = None,
project_root = Path(__file__).resolve().parent project_root = Path(__file__).resolve().parent
if embeddings_dir is None: if embeddings_dir is None:
embeddings_dir = project_root / "embeddings" embeddings_dir = EMBEDDINGS_DIR
else: else:
embeddings_dir = Path(embeddings_dir) embeddings_dir = Path(embeddings_dir)
if chroma_path is None: if chroma_path is None:
chroma_path = project_root / "chroma_db" chroma_path = CHROMA_DB_PATH
else: else:
chroma_path = Path(chroma_path) chroma_path = Path(chroma_path)
if collection_name is None:
collection_name = CHROMA_COLLECTION_NAME
embeddings_file = embeddings_dir / "embeddings.npy" embeddings_file = embeddings_dir / "embeddings.npy"
chunks_file = embeddings_dir / "chunks.json" chunks_file = embeddings_dir / "chunks.json"
@@ -164,11 +168,11 @@ if __name__ == "__main__":
import argparse import argparse
parser = argparse.ArgumentParser(description="JRXML Chunks 导入 Chroma 工具") parser = argparse.ArgumentParser(description="JRXML Chunks 导入 Chroma 工具")
parser.add_argument("--embeddings_dir", "-e", default=None, parser.add_argument("--embeddings_dir", "-e", default=None,
help="向量文件目录 (默认: embeddings)") help=f"向量文件目录 (默认: {EMBEDDINGS_DIR})")
parser.add_argument("--chroma_path", "-c", default=None, parser.add_argument("--chroma_path", "-c", default=None,
help="Chroma 数据库路径 (默认: chroma_db)") help=f"Chroma 数据库路径 (默认: {CHROMA_DB_PATH})")
parser.add_argument("--collection_name", "-n", default="jrxml_chunks", parser.add_argument("--collection_name", "-n", default=CHROMA_COLLECTION_NAME,
help="集合名称 (默认: jrxml_chunks)") help=f"集合名称 (默认: {CHROMA_COLLECTION_NAME})")
args = parser.parse_args() args = parser.parse_args()
+17 -11
View File
@@ -11,9 +11,10 @@ from pathlib import Path
from datetime import datetime from datetime import datetime
from collections import defaultdict from collections import defaultdict
from jrxml_chunker import JRXMLSemanticChunker, save_chunks_to_json, print_chunk_summary from jrxml_chunker import JRXMLSemanticChunker, save_chunks_to_json, print_chunk_summary
from config import JRXML_SOURCE_DIR, CHUNKER_OUTPUT_DIR, MAX_CHUNK_SIZE
def batch_chunk_with_report(input_dir: str, output_dir: str = None, max_chunk_size: int = 2000): def batch_chunk_with_report(input_dir: str = None, output_dir: str = None, max_chunk_size: int = None):
""" """
批量分块并生成详细报告 批量分块并生成详细报告
@@ -22,6 +23,8 @@ def batch_chunk_with_report(input_dir: str, output_dir: str = None, max_chunk_si
output_dir: 输出目录,默认为 input_dir/../chunked_output output_dir: 输出目录,默认为 input_dir/../chunked_output
max_chunk_size: 单个chunk最大字节数 max_chunk_size: 单个chunk最大字节数
""" """
if input_dir is None:
input_dir = str(JRXML_SOURCE_DIR)
input_path = Path(input_dir).resolve() input_path = Path(input_dir).resolve()
if not input_path.exists(): if not input_path.exists():
@@ -32,12 +35,14 @@ def batch_chunk_with_report(input_dir: str, output_dir: str = None, max_chunk_si
print(f"❌ 不是目录: {input_path}") print(f"❌ 不是目录: {input_path}")
return None return None
# 设置输出目录
if output_dir is None: if output_dir is None:
output_dir = input_path.parent / f"{input_path.name}_chunked_{datetime.now().strftime('%Y%m%d_%H%M%S')}" output_dir = str(CHUNKER_OUTPUT_DIR)
output_path = Path(output_dir) output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True) output_path.mkdir(parents=True, exist_ok=True)
if max_chunk_size is None:
max_chunk_size = MAX_CHUNK_SIZE
print(f"\n{'='*60}") print(f"\n{'='*60}")
print(f"JRXML 语义分块 v3.0 - 批量处理") print(f"JRXML 语义分块 v3.0 - 批量处理")
print(f"{'='*60}") print(f"{'='*60}")
@@ -214,21 +219,24 @@ if __name__ == "__main__":
print("=" * 60) print("=" * 60)
print("JRXML Semantic Chunking v3.0 - 批量处理工具") print("JRXML Semantic Chunking v3.0 - 批量处理工具")
print("=" * 60) print("=" * 60)
print(f"\n默认输入目录: {JRXML_SOURCE_DIR}")
print(f"默认输出目录: {CHUNKER_OUTPUT_DIR}")
print("\n用法:") print("\n用法:")
print(" python batch_chunker.py <目录路径>") print(" python jrxml_banch_chunker.py <目录路径>")
print(" python batch_chunker.py <文件路径>") print(" python jrxml_banch_chunker.py <文件路径>")
print(" python jrxml_banch_chunker.py (使用默认配置)")
print("\n参数:") print("\n参数:")
print(" <路径> JRXML文件所在目录 或 单个JRXML文件路径") print(" <路径> JRXML文件所在目录 或 单个JRXML文件路径")
print(" --output <目录> 指定输出目录 (可选)") print(" --output <目录> 指定输出目录 (可选)")
print("\n示例:") print("\n示例:")
print(" python batch_chunker.py ./jasper_reports") print(" python jrxml_banch_chunker.py")
print(" python batch_chunker.py ./jasper_reports --output ./chunks") print(" python jrxml_banch_chunker.py ./jasper_reports")
print(" python batch_chunker.py report.jrxml") print(" python jrxml_banch_chunker.py ./jasper_reports --output ./chunks")
print(" python jrxml_banch_chunker.py report.jrxml")
sys.exit(0) sys.exit(0)
input_path = sys.argv[1] input_path = sys.argv[1]
# 解析--output参数
output_dir = None output_dir = None
if "--output" in sys.argv: if "--output" in sys.argv:
idx = sys.argv.index("--output") idx = sys.argv.index("--output")
@@ -236,10 +244,8 @@ if __name__ == "__main__":
output_dir = sys.argv[idx + 1] output_dir = sys.argv[idx + 1]
if os.path.isdir(input_path): if os.path.isdir(input_path):
# 批量处理目录
batch_chunk_with_report(input_path, output_dir) batch_chunk_with_report(input_path, output_dir)
elif os.path.isfile(input_path): elif os.path.isfile(input_path):
# 处理单个文件
chunk_single_file_with_report(input_path, output_dir) chunk_single_file_with_report(input_path, output_dir)
else: else:
print(f"❌ 路径无效: {input_path}") print(f"❌ 路径无效: {input_path}")
+19 -14
View File
@@ -2,6 +2,7 @@
query_chroma.py query_chroma.py
查询 Chroma 数据库,从自然语言查找相关 JRXML chunk 查询 Chroma 数据库,从自然语言查找相关 JRXML chunk
支持命令行单次查询和交互式连续查询 支持命令行单次查询和交互式连续查询
模型通过 .env / config.py 配置
""" """
import os import os
@@ -12,19 +13,27 @@ import numpy as np
import torch import torch
from sentence_transformers import SentenceTransformer from sentence_transformers import SentenceTransformer
import chromadb import chromadb
from config import (
CHROMA_DB_PATH, CHROMA_COLLECTION_NAME, USE_FP16,
DEFAULT_N_RESULTS, SIMILARITY_THRESHOLD, resolve_model_path
)
class JRXMLSearcher: class JRXMLSearcher:
def __init__(self, chroma_path: str = None, def __init__(self, chroma_path: str = None,
collection_name: str = "jrxml_chunks", collection_name: str = None,
model_path: str = None, model_path: str = None,
use_fp16: bool = True): use_fp16: bool = None):
project_root = Path(__file__).resolve().parent project_root = Path(__file__).resolve().parent
if chroma_path is None: if chroma_path is None:
chroma_path = str(project_root / "chroma_db") chroma_path = str(CHROMA_DB_PATH)
if collection_name is None:
collection_name = CHROMA_COLLECTION_NAME
if model_path is None: if model_path is None:
model_path = str(project_root / "models" / "Qwen3-Embedding-4B") model_path = resolve_model_path()
if use_fp16 is None:
use_fp16 = USE_FP16
# 处理 Hub 模型名称 # 处理 Hub 模型名称
model_path_str = str(model_path) model_path_str = str(model_path)
@@ -110,13 +119,13 @@ def main():
parser.add_argument("query", nargs="?", default="", parser.add_argument("query", nargs="?", default="",
help="搜索关键词(不提供则进入交互模式)") help="搜索关键词(不提供则进入交互模式)")
parser.add_argument("--chroma_path", "-c", default=None, parser.add_argument("--chroma_path", "-c", default=None,
help=f"Chroma 数据库路径 (默认: chroma_db)") help=f"Chroma 数据库路径 (默认: {CHROMA_DB_PATH})")
parser.add_argument("--collection", "-n", default="jrxml_chunks", parser.add_argument("--collection", "-n", default=CHROMA_COLLECTION_NAME,
help="集合名称") help="集合名称")
parser.add_argument("--model_path", "-m", default=None, parser.add_argument("--model_path", "-m", default=None,
help="嵌入模型路径") help="嵌入模型路径")
parser.add_argument("--n_results", "-k", type=int, default=5, parser.add_argument("--n_results", "-k", type=int, default=DEFAULT_N_RESULTS,
help="返回结果数 (默认: 5)") help=f"返回结果数 (默认: {DEFAULT_N_RESULTS})")
parser.add_argument("--filter_field", "-f", parser.add_argument("--filter_field", "-f",
help="按 chunk_type 过滤,例如: field, query, chart") help="按 chunk_type 过滤,例如: field, query, chart")
parser.add_argument("--threshold", "-t", type=float, parser.add_argument("--threshold", "-t", type=float,
@@ -127,14 +136,10 @@ def main():
args = parser.parse_args() args = parser.parse_args()
if args.chroma_path is None: if args.chroma_path is None:
args.chroma_path = str(project_root / "chroma_db") args.chroma_path = str(CHROMA_DB_PATH)
if args.model_path is None: if args.model_path is None:
default_model = project_root / "models" / "Qwen3-Embedding-4B" args.model_path = resolve_model_path()
if not default_model.exists():
args.model_path = "sentence-transformers/all-MiniLM-L6-v2"
else:
args.model_path = str(default_model)
# 检查数据库 # 检查数据库
if not os.path.exists(args.chroma_path): if not os.path.exists(args.chroma_path):
+103
View File
@@ -0,0 +1,103 @@
@echo off
chcp 65001 >nul
setlocal enabledelayedexpansion
:: ============================================================
:: JRXML RAG project - one-click Windows deployment script
::
:: Steps: check Python, report PyTorch/CUDA status, create .env from
:: .env.example, create the working directories, install dependencies.
:: ============================================================
title JRXML RAG 环境部署
:: Work from the script's own directory so that .env, .env.example and the
:: project folders are found no matter where the script is launched from.
cd /d "%~dp0"
echo.
echo ============================================================
echo JRXML RAG 项目 - 环境部署脚本
echo ============================================================
echo.
:: Step 1: Python must be available on PATH.
echo [1/5] 检查 Python 环境...
python --version >nul 2>&1
if %errorlevel% neq 0 (
    echo [错误] 未找到 Python,请先安装 Python 3.11+
    pause
    exit /b 1
)
for /f "tokens=2" %%v in ('python --version 2^>^&1') do echo Python 版本: %%v
:: Step 2: report PyTorch / CUDA status (informational only; a failure here
:: just means PyTorch is not installed yet).
echo.
echo [2/5] 检查 CUDA 环境...
python -c "import torch; print(f' PyTorch: {torch.__version__}'); print(f' CUDA 可用: {torch.cuda.is_available()}'); print(f' GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else \"N/A\"}')" 2>nul
if %errorlevel% neq 0 (
    echo PyTorch 未安装,将在后续步骤安装
)
:: Step 3: create .env from the example file unless one already exists.
echo.
echo [3/5] 初始化环境配置...
if not exist ".env" (
    if exist ".env.example" (
        copy ".env.example" ".env" >nul
        echo 已从 .env.example 创建 .env 配置文件
        echo 请根据需要编辑 .env 文件修改配置
    ) else (
        echo [警告] 未找到 .env.example,跳过配置初始化
    )
) else (
    echo .env 配置文件已存在,跳过
)
:: Step 4: create the working directories.
echo.
echo [4/5] 创建项目目录...
set "DIRS=jrxml_source jrxml_chunker_output models embeddings chroma_db"
for %%d in (%DIRS%) do (
    if not exist "%%d" (
        mkdir "%%d" >nul 2>&1
        echo 创建目录: %%d
    )
)
echo 目录结构已就绪
:: Step 5: install dependencies.
echo.
echo [5/5] 安装 Python 依赖...
echo.
:: Detect an NVIDIA GPU by checking whether nvidia-smi runs successfully.
python -c "import subprocess; exit(0 if subprocess.run(['nvidia-smi'], capture_output=True).returncode == 0 else 1)" 2>nul
:: NOTE: inside a parenthesised block an unescaped ")" in echo text closes
:: the block early, so the parentheses below are escaped with "^".
:: "python -m pip" is used so packages go to the interpreter checked above.
if %errorlevel% equ 0 (
    echo [检测到 NVIDIA GPU,安装 CUDA 版 PyTorch]
    echo.
    echo 安装 PyTorch ^(CUDA 版本^)...
    python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
) else (
    echo [未检测到 GPU,安装 CPU 版 PyTorch]
    echo.
    echo 安装 PyTorch ^(CPU 版本^)...
    python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
)
:: pip is the last command in either branch, so errorlevel is its exit code.
if errorlevel 1 (
    echo [错误] PyTorch 安装失败,请检查网络连接后重试
    pause
    exit /b 1
)
echo.
echo 安装核心依赖...
python -m pip install sentence-transformers chromadb numpy tqdm huggingface_hub
if errorlevel 1 (
    echo [错误] 核心依赖安装失败,请检查网络连接后重试
    pause
    exit /b 1
)
echo.
echo ============================================================
echo 部署完成!
echo ============================================================
echo.
echo 下一步操作:
echo 1. 编辑 .env 配置文件(可选)
echo 2. 收集 JRXML 文件: python collect_jrxml.py
echo 3. 语义分块: python jrxml_banch_chunker.py
echo 4. 下载嵌入模型: python down_embedding_model.py
echo 5. 向量化: python embed_chunks.py
echo 6. 导入 Chroma: python import_to_chroma.py
echo 7. 开始查询: python query_chroma.py
echo.
echo 查看当前配置: python config.py
echo.
pause