refactor: 重构项目配置管理,统一使用.env配置
- 新增config.py统一读取.env配置,移除硬编码路径和参数 - 重构collect_jrxml.py支持命令行参数和环境变量配置源目录 - 新增.env.example示例配置文件,整理所有可配置项 - 重构down_embedding_model.py、import_to_chroma.py等所有脚本使用统一配置 - 新增Windows一键部署脚本setup.bat - 修正jrxml_banch_chunker.py的文件名拼写错误
This commit is contained in:
@@ -0,0 +1,54 @@
|
|||||||
|
# ============================================================
|
||||||
|
# JRXML RAG 项目 - 环境配置文件
|
||||||
|
# 复制此文件为 .env 并根据需要修改配置
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
# -------------------- 嵌入模型配置 --------------------
|
||||||
|
# 模型名称或路径,支持以下格式:
|
||||||
|
# 1. HuggingFace Hub 模型: Qwen/Qwen3-Embedding-4B
|
||||||
|
# 2. HuggingFace Hub 模型: sentence-transformers/all-MiniLM-L6-v2
|
||||||
|
# 3. 本地模型路径: models/Qwen3-Embedding-4B
|
||||||
|
EMBEDDING_MODEL_NAME=Qwen/Qwen3-Embedding-4B
|
||||||
|
|
||||||
|
# 本地模型下载/存放目录(使用 Hub 模型时会自动下载到此目录)
|
||||||
|
EMBEDDING_MODEL_PATH=models/Qwen3-Embedding-4B
|
||||||
|
|
||||||
|
# HuggingFace 镜像站点(国内用户建议使用镜像加速)
|
||||||
|
HF_ENDPOINT=https://hf-mirror.com
|
||||||
|
|
||||||
|
# -------------------- 硬件配置 --------------------
|
||||||
|
# 是否使用 GPU 加速 (true/false)
|
||||||
|
USE_GPU=true
|
||||||
|
|
||||||
|
# 是否启用 FP16 半精度(可节省约 50% 显存)
|
||||||
|
USE_FP16=true
|
||||||
|
|
||||||
|
# 向量化批处理大小(根据显存调整,显存不足时减小此值)
|
||||||
|
BATCH_SIZE=64
|
||||||
|
|
||||||
|
# -------------------- 目录配置 --------------------
|
||||||
|
# JRXML 源文件目录
|
||||||
|
JRXML_SOURCE_DIR=jrxml_source
|
||||||
|
|
||||||
|
# 分块输出目录
|
||||||
|
CHUNKER_OUTPUT_DIR=jrxml_chunker_output
|
||||||
|
|
||||||
|
# 向量输出目录
|
||||||
|
EMBEDDINGS_DIR=embeddings
|
||||||
|
|
||||||
|
# Chroma 向量数据库目录
|
||||||
|
CHROMA_DB_PATH=chroma_db
|
||||||
|
|
||||||
|
# Chroma 集合名称
|
||||||
|
CHROMA_COLLECTION_NAME=jrxml_chunks
|
||||||
|
|
||||||
|
# -------------------- 分块配置 --------------------
|
||||||
|
# 单个 chunk 最大字符数
|
||||||
|
MAX_CHUNK_SIZE=2000
|
||||||
|
|
||||||
|
# -------------------- 查询配置 --------------------
|
||||||
|
# 默认返回结果数
|
||||||
|
DEFAULT_N_RESULTS=5
|
||||||
|
|
||||||
|
# 相似度阈值 (0~1,余弦距离,越小越相似)
|
||||||
|
SIMILARITY_THRESHOLD=0.3
|
||||||
+15
-3
@@ -2,10 +2,12 @@
|
|||||||
"""
|
"""
|
||||||
JRXML 文件收集脚本
|
JRXML 文件收集脚本
|
||||||
从指定目录递归查找所有 .jrxml 文件并复制到项目的 jrxml_source 目录
|
从指定目录递归查找所有 .jrxml 文件并复制到项目的 jrxml_source 目录
|
||||||
|
源目录和目标目录通过 .env / config.py 配置
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
from config import JRXML_SOURCE_DIR
|
||||||
|
|
||||||
def collect_jrxml_files(source_dir: str, target_dir: str) -> int:
|
def collect_jrxml_files(source_dir: str, target_dir: str) -> int:
|
||||||
"""
|
"""
|
||||||
@@ -53,12 +55,22 @@ def collect_jrxml_files(source_dir: str, target_dir: str) -> int:
|
|||||||
return copied_count
|
return copied_count
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
SOURCE_DIR = r"C:\Users\zy187\JaspersoftWorkspace\JasperReportsSamples"
|
import sys
|
||||||
TARGET_DIR = os.path.join(os.path.dirname(__file__), "jrxml_source")
|
|
||||||
|
if len(sys.argv) >= 2:
|
||||||
|
SOURCE_DIR = sys.argv[1]
|
||||||
|
else:
|
||||||
|
SOURCE_DIR = os.environ.get(
|
||||||
|
"JRXML_COLLECT_SOURCE",
|
||||||
|
r"C:\Users\zy187\JaspersoftWorkspace\JasperReportsSamples"
|
||||||
|
)
|
||||||
|
|
||||||
|
TARGET_DIR = str(JRXML_SOURCE_DIR)
|
||||||
|
|
||||||
if not os.path.exists(SOURCE_DIR):
|
if not os.path.exists(SOURCE_DIR):
|
||||||
print(f"错误:源目录不存在 - {SOURCE_DIR}")
|
print(f"错误:源目录不存在 - {SOURCE_DIR}")
|
||||||
print("请检查路径是否正确")
|
print("请检查路径是否正确,或通过命令行参数指定:")
|
||||||
|
print(f" python collect_jrxml.py <源目录路径>")
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
collect_jrxml_files(SOURCE_DIR, TARGET_DIR)
|
collect_jrxml_files(SOURCE_DIR, TARGET_DIR)
|
||||||
@@ -0,0 +1,138 @@
|
|||||||
|
"""
|
||||||
|
config.py
|
||||||
|
统一配置管理,从 .env 文件读取环境变量
|
||||||
|
所有脚本通过此模块获取配置,避免硬编码
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def _get_project_root() -> Path:
|
||||||
|
return Path(__file__).resolve().parent
|
||||||
|
|
||||||
|
|
||||||
|
def _load_dotenv():
|
||||||
|
"""简易 .env 文件加载器,不依赖 python-dotenv"""
|
||||||
|
project_root = _get_project_root()
|
||||||
|
env_file = project_root / ".env"
|
||||||
|
|
||||||
|
if not env_file.exists():
|
||||||
|
return
|
||||||
|
|
||||||
|
with open(env_file, "r", encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if not line or line.startswith("#"):
|
||||||
|
continue
|
||||||
|
if "=" not in line:
|
||||||
|
continue
|
||||||
|
key, _, value = line.partition("=")
|
||||||
|
key = key.strip()
|
||||||
|
value = value.strip().strip('"').strip("'")
|
||||||
|
if key and key not in os.environ:
|
||||||
|
os.environ[key] = value
|
||||||
|
|
||||||
|
|
||||||
|
_load_dotenv()
|
||||||
|
|
||||||
|
PROJECT_ROOT = _get_project_root()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_path(key: str, default: str) -> Path:
|
||||||
|
value = os.environ.get(key, default)
|
||||||
|
p = Path(value)
|
||||||
|
if not p.is_absolute():
|
||||||
|
p = PROJECT_ROOT / p
|
||||||
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
def _get_str(key: str, default: str) -> str:
|
||||||
|
return os.environ.get(key, default)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_bool(key: str, default: bool) -> bool:
|
||||||
|
value = os.environ.get(key, "").strip().lower()
|
||||||
|
if value in ("true", "1", "yes"):
|
||||||
|
return True
|
||||||
|
if value in ("false", "0", "no"):
|
||||||
|
return False
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
def _get_int(key: str, default: int) -> int:
|
||||||
|
try:
|
||||||
|
return int(os.environ.get(key, str(default)))
|
||||||
|
except ValueError:
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
def _get_float(key: str, default: float) -> float:
|
||||||
|
try:
|
||||||
|
return float(os.environ.get(key, str(default)))
|
||||||
|
except ValueError:
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 模型配置 ====================
|
||||||
|
EMBEDDING_MODEL_NAME = _get_str("EMBEDDING_MODEL_NAME", "Qwen/Qwen3-Embedding-4B")
|
||||||
|
EMBEDDING_MODEL_PATH = _get_path("EMBEDDING_MODEL_PATH", "models/Qwen3-Embedding-4B")
|
||||||
|
HF_ENDPOINT = _get_str("HF_ENDPOINT", "https://hf-mirror.com")
|
||||||
|
|
||||||
|
# ==================== 硬件配置 ====================
|
||||||
|
USE_GPU = _get_bool("USE_GPU", True)
|
||||||
|
USE_FP16 = _get_bool("USE_FP16", True)
|
||||||
|
BATCH_SIZE = _get_int("BATCH_SIZE", 64)
|
||||||
|
|
||||||
|
# ==================== 目录配置 ====================
|
||||||
|
JRXML_SOURCE_DIR = _get_path("JRXML_SOURCE_DIR", "jrxml_source")
|
||||||
|
CHUNKER_OUTPUT_DIR = _get_path("CHUNKER_OUTPUT_DIR", "jrxml_chunker_output")
|
||||||
|
EMBEDDINGS_DIR = _get_path("EMBEDDINGS_DIR", "embeddings")
|
||||||
|
CHROMA_DB_PATH = _get_path("CHROMA_DB_PATH", "chroma_db")
|
||||||
|
CHROMA_COLLECTION_NAME = _get_str("CHROMA_COLLECTION_NAME", "jrxml_chunks")
|
||||||
|
|
||||||
|
# ==================== 分块配置 ====================
|
||||||
|
MAX_CHUNK_SIZE = _get_int("MAX_CHUNK_SIZE", 2000)
|
||||||
|
|
||||||
|
# ==================== 查询配置 ====================
|
||||||
|
DEFAULT_N_RESULTS = _get_int("DEFAULT_N_RESULTS", 5)
|
||||||
|
SIMILARITY_THRESHOLD = _get_float("SIMILARITY_THRESHOLD", 0.3)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_model_path() -> str:
|
||||||
|
"""
|
||||||
|
解析模型路径:
|
||||||
|
1. 如果 EMBEDDING_MODEL_PATH 本地存在,使用本地路径
|
||||||
|
2. 否则使用 EMBEDDING_MODEL_NAME 作为 Hub 模型名
|
||||||
|
"""
|
||||||
|
if EMBEDDING_MODEL_PATH.exists():
|
||||||
|
return str(EMBEDDING_MODEL_PATH)
|
||||||
|
return EMBEDDING_MODEL_NAME
|
||||||
|
|
||||||
|
|
||||||
|
def print_config():
|
||||||
|
"""打印当前配置(调试用)"""
|
||||||
|
print(f"{'='*60}")
|
||||||
|
print(f"JRXML RAG 当前配置")
|
||||||
|
print(f"{'='*60}")
|
||||||
|
print(f" 项目根目录: {PROJECT_ROOT}")
|
||||||
|
print(f" 嵌入模型名称: {EMBEDDING_MODEL_NAME}")
|
||||||
|
print(f" 嵌入模型路径: {EMBEDDING_MODEL_PATH}")
|
||||||
|
print(f" 模型解析结果: {resolve_model_path()}")
|
||||||
|
print(f" HF 镜像: {HF_ENDPOINT}")
|
||||||
|
print(f" GPU 加速: {USE_GPU}")
|
||||||
|
print(f" FP16 半精度: {USE_FP16}")
|
||||||
|
print(f" 批处理大小: {BATCH_SIZE}")
|
||||||
|
print(f" JRXML 源目录: {JRXML_SOURCE_DIR}")
|
||||||
|
print(f" 分块输出目录: {CHUNKER_OUTPUT_DIR}")
|
||||||
|
print(f" 向量输出目录: {EMBEDDINGS_DIR}")
|
||||||
|
print(f" Chroma 数据库: {CHROMA_DB_PATH}")
|
||||||
|
print(f" Chroma 集合名: {CHROMA_COLLECTION_NAME}")
|
||||||
|
print(f" 最大 Chunk 大小: {MAX_CHUNK_SIZE}")
|
||||||
|
print(f" 默认返回结果数: {DEFAULT_N_RESULTS}")
|
||||||
|
print(f" 相似度阈值: {SIMILARITY_THRESHOLD}")
|
||||||
|
print(f"{'='*60}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print_config()
|
||||||
+11
-11
@@ -1,26 +1,26 @@
|
|||||||
"""
|
"""
|
||||||
down_embedding_model.py
|
down_embedding_model.py
|
||||||
下载 Qwen3-Embedding-4B 嵌入模型
|
下载嵌入模型(模型名称通过 .env / config.py 配置)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from config import EMBEDDING_MODEL_NAME, EMBEDDING_MODEL_PATH, HF_ENDPOINT
|
||||||
|
|
||||||
def download_model():
|
def download_model():
|
||||||
"""下载 Qwen3-Embedding-4B 模型"""
|
"""下载嵌入模型"""
|
||||||
project_root = Path(__file__).resolve().parent
|
model_dir = EMBEDDING_MODEL_PATH
|
||||||
model_dir = project_root / "models" / "Qwen3-Embedding-4B"
|
|
||||||
|
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print("Qwen3-Embedding-4B 模型下载")
|
print(f"{EMBEDDING_MODEL_NAME} 模型下载")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
print(f"模型名称: {EMBEDDING_MODEL_NAME}")
|
||||||
print(f"模型目录: {model_dir}")
|
print(f"模型目录: {model_dir}")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
# 使用国内镜像加速
|
os.environ['HF_ENDPOINT'] = HF_ENDPOINT
|
||||||
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
|
print(f"使用 HuggingFace 镜像: {HF_ENDPOINT}")
|
||||||
print("使用 HuggingFace 镜像: https://hf-mirror.com")
|
|
||||||
print()
|
print()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -34,13 +34,13 @@ def download_model():
|
|||||||
# 创建模型目录
|
# 创建模型目录
|
||||||
os.makedirs(model_dir, exist_ok=True)
|
os.makedirs(model_dir, exist_ok=True)
|
||||||
|
|
||||||
print(f"开始下载 Qwen3-Embedding-4B 模型...")
|
print(f"开始下载 {EMBEDDING_MODEL_NAME} 模型...")
|
||||||
print(f"模型大小约 4GB,请耐心等待...")
|
print(f"请耐心等待...")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
snapshot_download(
|
snapshot_download(
|
||||||
repo_id="Qwen/Qwen3-Embedding-4B",
|
repo_id=EMBEDDING_MODEL_NAME,
|
||||||
local_dir=str(model_dir),
|
local_dir=str(model_dir),
|
||||||
local_dir_use_symlinks=False,
|
local_dir_use_symlinks=False,
|
||||||
resume_download=True
|
resume_download=True
|
||||||
|
|||||||
+25
-18
@@ -1,7 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
embed_chunks.py
|
embed_chunks.py
|
||||||
使用本地 Qwen3-Embedding-4B 模型对 JRXML chunks 进行向量化
|
使用嵌入模型对 JRXML chunks 进行向量化
|
||||||
支持 GPU (CUDA) 或 CPU
|
支持 GPU (CUDA) 或 CPU,模型通过 .env / config.py 配置
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@@ -12,6 +12,10 @@ from pathlib import Path
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
|
from config import (
|
||||||
|
EMBEDDING_MODEL_PATH, CHUNKER_OUTPUT_DIR, EMBEDDINGS_DIR,
|
||||||
|
USE_FP16, BATCH_SIZE, resolve_model_path
|
||||||
|
)
|
||||||
|
|
||||||
def build_text_for_embedding(chunk: dict) -> str:
|
def build_text_for_embedding(chunk: dict) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -48,8 +52,8 @@ def build_text_for_embedding(chunk: dict) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def main(chunks_json_path: str = None, output_dir: str = None,
|
def main(chunks_json_path: str = None, output_dir: str = None,
|
||||||
model_path: str = None, batch_size: int = 64, normalize: bool = True,
|
model_path: str = None, batch_size: int = None, normalize: bool = True,
|
||||||
use_fp16: bool = True):
|
use_fp16: bool = None):
|
||||||
"""
|
"""
|
||||||
主流程:
|
主流程:
|
||||||
1. 加载 chunk JSON
|
1. 加载 chunk JSON
|
||||||
@@ -60,19 +64,25 @@ def main(chunks_json_path: str = None, output_dir: str = None,
|
|||||||
project_root = Path(__file__).resolve().parent
|
project_root = Path(__file__).resolve().parent
|
||||||
|
|
||||||
if chunks_json_path is None:
|
if chunks_json_path is None:
|
||||||
chunks_json_path = project_root / "jrxml_chunker_output" / "all_chunks.json"
|
chunks_json_path = CHUNKER_OUTPUT_DIR / "all_chunks.json"
|
||||||
else:
|
else:
|
||||||
chunks_json_path = Path(chunks_json_path)
|
chunks_json_path = Path(chunks_json_path)
|
||||||
|
|
||||||
if output_dir is None:
|
if output_dir is None:
|
||||||
output_dir = project_root / "embeddings"
|
output_dir = EMBEDDINGS_DIR
|
||||||
else:
|
else:
|
||||||
output_dir = Path(output_dir)
|
output_dir = Path(output_dir)
|
||||||
|
|
||||||
if model_path is None:
|
if model_path is None:
|
||||||
model_path = project_root / "models" / "Qwen3-Embedding-4B"
|
model_path = resolve_model_path()
|
||||||
else:
|
else:
|
||||||
model_path = Path(model_path)
|
model_path = str(model_path)
|
||||||
|
|
||||||
|
if batch_size is None:
|
||||||
|
batch_size = BATCH_SIZE
|
||||||
|
|
||||||
|
if use_fp16 is None:
|
||||||
|
use_fp16 = USE_FP16
|
||||||
|
|
||||||
if not chunks_json_path.exists():
|
if not chunks_json_path.exists():
|
||||||
print(f"❌ Chunks 文件不存在: {chunks_json_path}")
|
print(f"❌ Chunks 文件不存在: {chunks_json_path}")
|
||||||
@@ -91,19 +101,16 @@ def main(chunks_json_path: str = None, output_dir: str = None,
|
|||||||
print(f"\n🧠 加载嵌入模型: {model_path}")
|
print(f"\n🧠 加载嵌入模型: {model_path}")
|
||||||
print(f" 设备: {device}")
|
print(f" 设备: {device}")
|
||||||
|
|
||||||
# 检查是否是 HuggingFace Hub 模型(格式为 username/model_name)
|
|
||||||
model_path_str = str(model_path)
|
model_path_str = str(model_path)
|
||||||
# Windows PowerShell 会把 / 自动转成 \,需要还原
|
|
||||||
if "\\" in model_path_str and not os.path.exists(model_path_str):
|
if "\\" in model_path_str and not os.path.exists(model_path_str):
|
||||||
model_path_str = model_path_str.replace("\\", "/")
|
model_path_str = model_path_str.replace("\\", "/")
|
||||||
|
|
||||||
is_hub_model = "/" in model_path_str and not os.path.exists(model_path_str)
|
is_hub_model = "/" in model_path_str and not os.path.exists(model_path_str)
|
||||||
|
|
||||||
# 如果是本地路径但不存在,则报错
|
|
||||||
if not is_hub_model and not os.path.exists(model_path_str):
|
if not is_hub_model and not os.path.exists(model_path_str):
|
||||||
print(f"❌ 模型目录不存在: {model_path}")
|
print(f"❌ 模型目录不存在: {model_path}")
|
||||||
print(f" 请先下载模型到 {model_path}")
|
print(f" 请先运行 down_embedding_model.py 下载模型")
|
||||||
print(f" 或者使用 HuggingFace Hub 模型,例如: sentence-transformers/all-MiniLM-L6-v2")
|
print(f" 或在 .env 中配置 EMBEDDING_MODEL_NAME 为 Hub 模型名")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
model = SentenceTransformer(model_path_str, device=device)
|
model = SentenceTransformer(model_path_str, device=device)
|
||||||
@@ -183,17 +190,17 @@ def main(chunks_json_path: str = None, output_dir: str = None,
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import argparse
|
import argparse
|
||||||
project_root = Path(__file__).resolve().parent
|
project_root = Path(__file__).resolve().parent
|
||||||
default_chunks = project_root / "jrxml_chunker_output" / "all_chunks.json"
|
default_chunks = CHUNKER_OUTPUT_DIR / "all_chunks.json"
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="JRXML Chunks 向量化工具")
|
parser = argparse.ArgumentParser(description="JRXML Chunks 向量化工具")
|
||||||
parser.add_argument("chunks_json", nargs="?", default=str(default_chunks),
|
parser.add_argument("chunks_json", nargs="?", default=str(default_chunks),
|
||||||
help=f"Chunks JSON 文件路径 (默认: {default_chunks})")
|
help=f"Chunks JSON 文件路径 (默认: {default_chunks})")
|
||||||
parser.add_argument("--output_dir", "-o", default=None,
|
parser.add_argument("--output_dir", "-o", default=None,
|
||||||
help="输出目录 (默认: embeddings)")
|
help=f"输出目录 (默认: {EMBEDDINGS_DIR})")
|
||||||
parser.add_argument("--model_path", "-m", default=None,
|
parser.add_argument("--model_path", "-m", default=None,
|
||||||
help="模型路径 (默认: models/Qwen3-Embedding-4B)")
|
help=f"模型路径 (默认: {resolve_model_path()})")
|
||||||
parser.add_argument("--batch_size", "-b", type=int, default=64,
|
parser.add_argument("--batch_size", "-b", type=int, default=BATCH_SIZE,
|
||||||
help="批处理大小 (默认: 64)")
|
help=f"批处理大小 (默认: {BATCH_SIZE})")
|
||||||
parser.add_argument("--no_normalize", action="store_true",
|
parser.add_argument("--no_normalize", action="store_true",
|
||||||
help="不做向量归一化")
|
help="不做向量归一化")
|
||||||
parser.add_argument("--no_fp16", action="store_true",
|
parser.add_argument("--no_fp16", action="store_true",
|
||||||
|
|||||||
+11
-7
@@ -11,11 +11,12 @@ from pathlib import Path
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import chromadb
|
import chromadb
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from config import EMBEDDINGS_DIR, CHROMA_DB_PATH, CHROMA_COLLECTION_NAME
|
||||||
|
|
||||||
|
|
||||||
def main(embeddings_dir: str = None,
|
def main(embeddings_dir: str = None,
|
||||||
chroma_path: str = None,
|
chroma_path: str = None,
|
||||||
collection_name: str = "jrxml_chunks"):
|
collection_name: str = None):
|
||||||
"""
|
"""
|
||||||
从 embeddings 目录读取向量和 chunks,导入 Chroma 持久化数据库
|
从 embeddings 目录读取向量和 chunks,导入 Chroma 持久化数据库
|
||||||
|
|
||||||
@@ -27,15 +28,18 @@ def main(embeddings_dir: str = None,
|
|||||||
project_root = Path(__file__).resolve().parent
|
project_root = Path(__file__).resolve().parent
|
||||||
|
|
||||||
if embeddings_dir is None:
|
if embeddings_dir is None:
|
||||||
embeddings_dir = project_root / "embeddings"
|
embeddings_dir = EMBEDDINGS_DIR
|
||||||
else:
|
else:
|
||||||
embeddings_dir = Path(embeddings_dir)
|
embeddings_dir = Path(embeddings_dir)
|
||||||
|
|
||||||
if chroma_path is None:
|
if chroma_path is None:
|
||||||
chroma_path = project_root / "chroma_db"
|
chroma_path = CHROMA_DB_PATH
|
||||||
else:
|
else:
|
||||||
chroma_path = Path(chroma_path)
|
chroma_path = Path(chroma_path)
|
||||||
|
|
||||||
|
if collection_name is None:
|
||||||
|
collection_name = CHROMA_COLLECTION_NAME
|
||||||
|
|
||||||
embeddings_file = embeddings_dir / "embeddings.npy"
|
embeddings_file = embeddings_dir / "embeddings.npy"
|
||||||
chunks_file = embeddings_dir / "chunks.json"
|
chunks_file = embeddings_dir / "chunks.json"
|
||||||
|
|
||||||
@@ -164,11 +168,11 @@ if __name__ == "__main__":
|
|||||||
import argparse
|
import argparse
|
||||||
parser = argparse.ArgumentParser(description="JRXML Chunks 导入 Chroma 工具")
|
parser = argparse.ArgumentParser(description="JRXML Chunks 导入 Chroma 工具")
|
||||||
parser.add_argument("--embeddings_dir", "-e", default=None,
|
parser.add_argument("--embeddings_dir", "-e", default=None,
|
||||||
help="向量文件目录 (默认: embeddings)")
|
help=f"向量文件目录 (默认: {EMBEDDINGS_DIR})")
|
||||||
parser.add_argument("--chroma_path", "-c", default=None,
|
parser.add_argument("--chroma_path", "-c", default=None,
|
||||||
help="Chroma 数据库路径 (默认: chroma_db)")
|
help=f"Chroma 数据库路径 (默认: {CHROMA_DB_PATH})")
|
||||||
parser.add_argument("--collection_name", "-n", default="jrxml_chunks",
|
parser.add_argument("--collection_name", "-n", default=CHROMA_COLLECTION_NAME,
|
||||||
help="集合名称 (默认: jrxml_chunks)")
|
help=f"集合名称 (默认: {CHROMA_COLLECTION_NAME})")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|||||||
+17
-11
@@ -11,9 +11,10 @@ from pathlib import Path
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from jrxml_chunker import JRXMLSemanticChunker, save_chunks_to_json, print_chunk_summary
|
from jrxml_chunker import JRXMLSemanticChunker, save_chunks_to_json, print_chunk_summary
|
||||||
|
from config import JRXML_SOURCE_DIR, CHUNKER_OUTPUT_DIR, MAX_CHUNK_SIZE
|
||||||
|
|
||||||
|
|
||||||
def batch_chunk_with_report(input_dir: str, output_dir: str = None, max_chunk_size: int = 2000):
|
def batch_chunk_with_report(input_dir: str = None, output_dir: str = None, max_chunk_size: int = None):
|
||||||
"""
|
"""
|
||||||
批量分块并生成详细报告
|
批量分块并生成详细报告
|
||||||
|
|
||||||
@@ -22,6 +23,8 @@ def batch_chunk_with_report(input_dir: str, output_dir: str = None, max_chunk_si
|
|||||||
output_dir: 输出目录,默认为 input_dir/../chunked_output
|
output_dir: 输出目录,默认为 input_dir/../chunked_output
|
||||||
max_chunk_size: 单个chunk最大字节数
|
max_chunk_size: 单个chunk最大字节数
|
||||||
"""
|
"""
|
||||||
|
if input_dir is None:
|
||||||
|
input_dir = str(JRXML_SOURCE_DIR)
|
||||||
input_path = Path(input_dir).resolve()
|
input_path = Path(input_dir).resolve()
|
||||||
|
|
||||||
if not input_path.exists():
|
if not input_path.exists():
|
||||||
@@ -32,12 +35,14 @@ def batch_chunk_with_report(input_dir: str, output_dir: str = None, max_chunk_si
|
|||||||
print(f"❌ 不是目录: {input_path}")
|
print(f"❌ 不是目录: {input_path}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 设置输出目录
|
|
||||||
if output_dir is None:
|
if output_dir is None:
|
||||||
output_dir = input_path.parent / f"{input_path.name}_chunked_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
output_dir = str(CHUNKER_OUTPUT_DIR)
|
||||||
output_path = Path(output_dir)
|
output_path = Path(output_dir)
|
||||||
output_path.mkdir(parents=True, exist_ok=True)
|
output_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
if max_chunk_size is None:
|
||||||
|
max_chunk_size = MAX_CHUNK_SIZE
|
||||||
|
|
||||||
print(f"\n{'='*60}")
|
print(f"\n{'='*60}")
|
||||||
print(f"JRXML 语义分块 v3.0 - 批量处理")
|
print(f"JRXML 语义分块 v3.0 - 批量处理")
|
||||||
print(f"{'='*60}")
|
print(f"{'='*60}")
|
||||||
@@ -214,21 +219,24 @@ if __name__ == "__main__":
|
|||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print("JRXML Semantic Chunking v3.0 - 批量处理工具")
|
print("JRXML Semantic Chunking v3.0 - 批量处理工具")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
print(f"\n默认输入目录: {JRXML_SOURCE_DIR}")
|
||||||
|
print(f"默认输出目录: {CHUNKER_OUTPUT_DIR}")
|
||||||
print("\n用法:")
|
print("\n用法:")
|
||||||
print(" python batch_chunker.py <目录路径>")
|
print(" python jrxml_banch_chunker.py <目录路径>")
|
||||||
print(" python batch_chunker.py <文件路径>")
|
print(" python jrxml_banch_chunker.py <文件路径>")
|
||||||
|
print(" python jrxml_banch_chunker.py (使用默认配置)")
|
||||||
print("\n参数:")
|
print("\n参数:")
|
||||||
print(" <路径> JRXML文件所在目录 或 单个JRXML文件路径")
|
print(" <路径> JRXML文件所在目录 或 单个JRXML文件路径")
|
||||||
print(" --output <目录> 指定输出目录 (可选)")
|
print(" --output <目录> 指定输出目录 (可选)")
|
||||||
print("\n示例:")
|
print("\n示例:")
|
||||||
print(" python batch_chunker.py ./jasper_reports")
|
print(" python jrxml_banch_chunker.py")
|
||||||
print(" python batch_chunker.py ./jasper_reports --output ./chunks")
|
print(" python jrxml_banch_chunker.py ./jasper_reports")
|
||||||
print(" python batch_chunker.py report.jrxml")
|
print(" python jrxml_banch_chunker.py ./jasper_reports --output ./chunks")
|
||||||
|
print(" python jrxml_banch_chunker.py report.jrxml")
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
input_path = sys.argv[1]
|
input_path = sys.argv[1]
|
||||||
|
|
||||||
# 解析--output参数
|
|
||||||
output_dir = None
|
output_dir = None
|
||||||
if "--output" in sys.argv:
|
if "--output" in sys.argv:
|
||||||
idx = sys.argv.index("--output")
|
idx = sys.argv.index("--output")
|
||||||
@@ -236,10 +244,8 @@ if __name__ == "__main__":
|
|||||||
output_dir = sys.argv[idx + 1]
|
output_dir = sys.argv[idx + 1]
|
||||||
|
|
||||||
if os.path.isdir(input_path):
|
if os.path.isdir(input_path):
|
||||||
# 批量处理目录
|
|
||||||
batch_chunk_with_report(input_path, output_dir)
|
batch_chunk_with_report(input_path, output_dir)
|
||||||
elif os.path.isfile(input_path):
|
elif os.path.isfile(input_path):
|
||||||
# 处理单个文件
|
|
||||||
chunk_single_file_with_report(input_path, output_dir)
|
chunk_single_file_with_report(input_path, output_dir)
|
||||||
else:
|
else:
|
||||||
print(f"❌ 路径无效: {input_path}")
|
print(f"❌ 路径无效: {input_path}")
|
||||||
+19
-14
@@ -2,6 +2,7 @@
|
|||||||
query_chroma.py
|
query_chroma.py
|
||||||
查询 Chroma 数据库,从自然语言查找相关 JRXML chunk
|
查询 Chroma 数据库,从自然语言查找相关 JRXML chunk
|
||||||
支持命令行单次查询和交互式连续查询
|
支持命令行单次查询和交互式连续查询
|
||||||
|
模型通过 .env / config.py 配置
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@@ -12,19 +13,27 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
import chromadb
|
import chromadb
|
||||||
|
from config import (
|
||||||
|
CHROMA_DB_PATH, CHROMA_COLLECTION_NAME, USE_FP16,
|
||||||
|
DEFAULT_N_RESULTS, SIMILARITY_THRESHOLD, resolve_model_path
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class JRXMLSearcher:
|
class JRXMLSearcher:
|
||||||
def __init__(self, chroma_path: str = None,
|
def __init__(self, chroma_path: str = None,
|
||||||
collection_name: str = "jrxml_chunks",
|
collection_name: str = None,
|
||||||
model_path: str = None,
|
model_path: str = None,
|
||||||
use_fp16: bool = True):
|
use_fp16: bool = None):
|
||||||
project_root = Path(__file__).resolve().parent
|
project_root = Path(__file__).resolve().parent
|
||||||
|
|
||||||
if chroma_path is None:
|
if chroma_path is None:
|
||||||
chroma_path = str(project_root / "chroma_db")
|
chroma_path = str(CHROMA_DB_PATH)
|
||||||
|
if collection_name is None:
|
||||||
|
collection_name = CHROMA_COLLECTION_NAME
|
||||||
if model_path is None:
|
if model_path is None:
|
||||||
model_path = str(project_root / "models" / "Qwen3-Embedding-4B")
|
model_path = resolve_model_path()
|
||||||
|
if use_fp16 is None:
|
||||||
|
use_fp16 = USE_FP16
|
||||||
|
|
||||||
# 处理 Hub 模型名称
|
# 处理 Hub 模型名称
|
||||||
model_path_str = str(model_path)
|
model_path_str = str(model_path)
|
||||||
@@ -110,13 +119,13 @@ def main():
|
|||||||
parser.add_argument("query", nargs="?", default="",
|
parser.add_argument("query", nargs="?", default="",
|
||||||
help="搜索关键词(不提供则进入交互模式)")
|
help="搜索关键词(不提供则进入交互模式)")
|
||||||
parser.add_argument("--chroma_path", "-c", default=None,
|
parser.add_argument("--chroma_path", "-c", default=None,
|
||||||
help=f"Chroma 数据库路径 (默认: chroma_db)")
|
help=f"Chroma 数据库路径 (默认: {CHROMA_DB_PATH})")
|
||||||
parser.add_argument("--collection", "-n", default="jrxml_chunks",
|
parser.add_argument("--collection", "-n", default=CHROMA_COLLECTION_NAME,
|
||||||
help="集合名称")
|
help="集合名称")
|
||||||
parser.add_argument("--model_path", "-m", default=None,
|
parser.add_argument("--model_path", "-m", default=None,
|
||||||
help="嵌入模型路径")
|
help="嵌入模型路径")
|
||||||
parser.add_argument("--n_results", "-k", type=int, default=5,
|
parser.add_argument("--n_results", "-k", type=int, default=DEFAULT_N_RESULTS,
|
||||||
help="返回结果数 (默认: 5)")
|
help=f"返回结果数 (默认: {DEFAULT_N_RESULTS})")
|
||||||
parser.add_argument("--filter_field", "-f",
|
parser.add_argument("--filter_field", "-f",
|
||||||
help="按 chunk_type 过滤,例如: field, query, chart")
|
help="按 chunk_type 过滤,例如: field, query, chart")
|
||||||
parser.add_argument("--threshold", "-t", type=float,
|
parser.add_argument("--threshold", "-t", type=float,
|
||||||
@@ -127,14 +136,10 @@ def main():
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.chroma_path is None:
|
if args.chroma_path is None:
|
||||||
args.chroma_path = str(project_root / "chroma_db")
|
args.chroma_path = str(CHROMA_DB_PATH)
|
||||||
|
|
||||||
if args.model_path is None:
|
if args.model_path is None:
|
||||||
default_model = project_root / "models" / "Qwen3-Embedding-4B"
|
args.model_path = resolve_model_path()
|
||||||
if not default_model.exists():
|
|
||||||
args.model_path = "sentence-transformers/all-MiniLM-L6-v2"
|
|
||||||
else:
|
|
||||||
args.model_path = str(default_model)
|
|
||||||
|
|
||||||
# 检查数据库
|
# 检查数据库
|
||||||
if not os.path.exists(args.chroma_path):
|
if not os.path.exists(args.chroma_path):
|
||||||
|
|||||||
@@ -0,0 +1,103 @@
|
|||||||
|
@echo off
|
||||||
|
chcp 65001 >nul
|
||||||
|
setlocal enabledelayedexpansion
|
||||||
|
|
||||||
|
:: ============================================================
|
||||||
|
:: JRXML RAG 项目 - Windows 一键部署脚本
|
||||||
|
:: ============================================================
|
||||||
|
|
||||||
|
title JRXML RAG 环境部署
|
||||||
|
|
||||||
|
echo.
|
||||||
|
echo ============================================================
|
||||||
|
echo JRXML RAG 项目 - 环境部署脚本
|
||||||
|
echo ============================================================
|
||||||
|
echo.
|
||||||
|
|
||||||
|
:: 检查 Python
|
||||||
|
echo [1/5] 检查 Python 环境...
|
||||||
|
python --version >nul 2>&1
|
||||||
|
if %errorlevel% neq 0 (
|
||||||
|
echo [错误] 未找到 Python,请先安装 Python 3.11+
|
||||||
|
pause
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
for /f "tokens=2" %%v in ('python --version 2^>^&1') do echo Python 版本: %%v
|
||||||
|
|
||||||
|
:: 检查 CUDA
|
||||||
|
echo.
|
||||||
|
echo [2/5] 检查 CUDA 环境...
|
||||||
|
python -c "import torch; print(f' PyTorch: {torch.__version__}'); print(f' CUDA 可用: {torch.cuda.is_available()}'); print(f' GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else \"N/A\"}')" 2>nul
|
||||||
|
if %errorlevel% neq 0 (
|
||||||
|
echo PyTorch 未安装,将在后续步骤安装
|
||||||
|
)
|
||||||
|
|
||||||
|
:: 初始化 .env 配置
|
||||||
|
echo.
|
||||||
|
echo [3/5] 初始化环境配置...
|
||||||
|
if not exist ".env" (
|
||||||
|
if exist ".env.example" (
|
||||||
|
copy ".env.example" ".env" >nul
|
||||||
|
echo 已从 .env.example 创建 .env 配置文件
|
||||||
|
echo 请根据需要编辑 .env 文件修改配置
|
||||||
|
) else (
|
||||||
|
echo [警告] 未找到 .env.example,跳过配置初始化
|
||||||
|
)
|
||||||
|
) else (
|
||||||
|
echo .env 配置文件已存在,跳过
|
||||||
|
)
|
||||||
|
|
||||||
|
:: 创建必要目录
|
||||||
|
echo.
|
||||||
|
echo [4/5] 创建项目目录...
|
||||||
|
set "DIRS=jrxml_source jrxml_chunker_output models embeddings chroma_db"
|
||||||
|
for %%d in (%DIRS%) do (
|
||||||
|
if not exist "%%d" (
|
||||||
|
mkdir "%%d" >nul 2>&1
|
||||||
|
echo 创建目录: %%d
|
||||||
|
)
|
||||||
|
)
|
||||||
|
echo 目录结构已就绪
|
||||||
|
|
||||||
|
:: 安装依赖
|
||||||
|
echo.
|
||||||
|
echo [5/5] 安装 Python 依赖...
|
||||||
|
echo.
|
||||||
|
|
||||||
|
:: 检测是否有 NVIDIA GPU
|
||||||
|
python -c "import subprocess; exit(0 if subprocess.run(['nvidia-smi'], capture_output=True).returncode == 0 else 1)" 2>nul
|
||||||
|
if %errorlevel% equ 0 (
|
||||||
|
set "HAS_GPU=1"
|
||||||
|
echo [检测到 NVIDIA GPU,安装 CUDA 版 PyTorch]
|
||||||
|
echo.
|
||||||
|
echo 安装 PyTorch (CUDA 版本)...
|
||||||
|
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
|
||||||
|
) else (
|
||||||
|
set "HAS_GPU=0"
|
||||||
|
echo [未检测到 GPU,安装 CPU 版 PyTorch]
|
||||||
|
echo.
|
||||||
|
echo 安装 PyTorch (CPU 版本)...
|
||||||
|
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||||
|
)
|
||||||
|
|
||||||
|
echo.
|
||||||
|
echo 安装核心依赖...
|
||||||
|
pip install sentence-transformers chromadb numpy tqdm huggingface_hub
|
||||||
|
|
||||||
|
echo.
|
||||||
|
echo ============================================================
|
||||||
|
echo 部署完成!
|
||||||
|
echo ============================================================
|
||||||
|
echo.
|
||||||
|
echo 下一步操作:
|
||||||
|
echo 1. 编辑 .env 配置文件(可选)
|
||||||
|
echo 2. 收集 JRXML 文件: python collect_jrxml.py
|
||||||
|
echo 3. 语义分块: python jrxml_banch_chunker.py
|
||||||
|
echo 4. 下载嵌入模型: python down_embedding_model.py
|
||||||
|
echo 5. 向量化: python embed_chunks.py
|
||||||
|
echo 6. 导入 Chroma: python import_to_chroma.py
|
||||||
|
echo 7. 开始查询: python query_chroma.py
|
||||||
|
echo.
|
||||||
|
echo 查看当前配置: python config.py
|
||||||
|
echo.
|
||||||
|
pause
|
||||||
Reference in New Issue
Block a user