refactor: 重构项目配置管理,统一使用.env配置

- 新增config.py统一读取.env配置,移除硬编码路径和参数
- 重构collect_jrxml.py支持命令行参数和环境变量配置源目录
- 新增.env.example示例配置文件,整理所有可配置项
- 重构down_embedding_model.py、import_to_chroma.py等所有脚本使用统一配置
- 新增Windows一键部署脚本setup.bat
- 修正jrxml_banch_chunker.py的文件名拼写错误
This commit is contained in:
2026-05-12 08:29:17 +08:00
parent bd98486de0
commit 9d78a49625
9 changed files with 396 additions and 67 deletions
+138
View File
@@ -0,0 +1,138 @@
"""
config.py
统一配置管理,从 .env 文件读取环境变量
所有脚本通过此模块获取配置,避免硬编码
"""
import os
from pathlib import Path
def _get_project_root() -> Path:
return Path(__file__).resolve().parent
def _load_dotenv():
"""简易 .env 文件加载器,不依赖 python-dotenv"""
project_root = _get_project_root()
env_file = project_root / ".env"
if not env_file.exists():
return
with open(env_file, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line or line.startswith("#"):
continue
if "=" not in line:
continue
key, _, value = line.partition("=")
key = key.strip()
value = value.strip().strip('"').strip("'")
if key and key not in os.environ:
os.environ[key] = value
_load_dotenv()
PROJECT_ROOT = _get_project_root()
def _get_path(key: str, default: str) -> Path:
value = os.environ.get(key, default)
p = Path(value)
if not p.is_absolute():
p = PROJECT_ROOT / p
return p
def _get_str(key: str, default: str) -> str:
return os.environ.get(key, default)
def _get_bool(key: str, default: bool) -> bool:
value = os.environ.get(key, "").strip().lower()
if value in ("true", "1", "yes"):
return True
if value in ("false", "0", "no"):
return False
return default
def _get_int(key: str, default: int) -> int:
try:
return int(os.environ.get(key, str(default)))
except ValueError:
return default
def _get_float(key: str, default: float) -> float:
try:
return float(os.environ.get(key, str(default)))
except ValueError:
return default
# ==================== 模型配置 ====================
EMBEDDING_MODEL_NAME = _get_str("EMBEDDING_MODEL_NAME", "Qwen/Qwen3-Embedding-4B")
EMBEDDING_MODEL_PATH = _get_path("EMBEDDING_MODEL_PATH", "models/Qwen3-Embedding-4B")
HF_ENDPOINT = _get_str("HF_ENDPOINT", "https://hf-mirror.com")
# ==================== 硬件配置 ====================
USE_GPU = _get_bool("USE_GPU", True)
USE_FP16 = _get_bool("USE_FP16", True)
BATCH_SIZE = _get_int("BATCH_SIZE", 64)
# ==================== 目录配置 ====================
JRXML_SOURCE_DIR = _get_path("JRXML_SOURCE_DIR", "jrxml_source")
CHUNKER_OUTPUT_DIR = _get_path("CHUNKER_OUTPUT_DIR", "jrxml_chunker_output")
EMBEDDINGS_DIR = _get_path("EMBEDDINGS_DIR", "embeddings")
CHROMA_DB_PATH = _get_path("CHROMA_DB_PATH", "chroma_db")
CHROMA_COLLECTION_NAME = _get_str("CHROMA_COLLECTION_NAME", "jrxml_chunks")
# ==================== 分块配置 ====================
MAX_CHUNK_SIZE = _get_int("MAX_CHUNK_SIZE", 2000)
# ==================== 查询配置 ====================
DEFAULT_N_RESULTS = _get_int("DEFAULT_N_RESULTS", 5)
SIMILARITY_THRESHOLD = _get_float("SIMILARITY_THRESHOLD", 0.3)
def resolve_model_path() -> str:
"""
解析模型路径:
1. 如果 EMBEDDING_MODEL_PATH 本地存在,使用本地路径
2. 否则使用 EMBEDDING_MODEL_NAME 作为 Hub 模型名
"""
if EMBEDDING_MODEL_PATH.exists():
return str(EMBEDDING_MODEL_PATH)
return EMBEDDING_MODEL_NAME
def print_config():
"""打印当前配置(调试用)"""
print(f"{'='*60}")
print(f"JRXML RAG 当前配置")
print(f"{'='*60}")
print(f" 项目根目录: {PROJECT_ROOT}")
print(f" 嵌入模型名称: {EMBEDDING_MODEL_NAME}")
print(f" 嵌入模型路径: {EMBEDDING_MODEL_PATH}")
print(f" 模型解析结果: {resolve_model_path()}")
print(f" HF 镜像: {HF_ENDPOINT}")
print(f" GPU 加速: {USE_GPU}")
print(f" FP16 半精度: {USE_FP16}")
print(f" 批处理大小: {BATCH_SIZE}")
print(f" JRXML 源目录: {JRXML_SOURCE_DIR}")
print(f" 分块输出目录: {CHUNKER_OUTPUT_DIR}")
print(f" 向量输出目录: {EMBEDDINGS_DIR}")
print(f" Chroma 数据库: {CHROMA_DB_PATH}")
print(f" Chroma 集合名: {CHROMA_COLLECTION_NAME}")
print(f" 最大 Chunk 大小: {MAX_CHUNK_SIZE}")
print(f" 默认返回结果数: {DEFAULT_N_RESULTS}")
print(f" 相似度阈值: {SIMILARITY_THRESHOLD}")
print(f"{'='*60}")
if __name__ == "__main__":
print_config()