refactor: 重构项目配置管理,统一使用.env配置
- 新增config.py统一读取.env配置,移除硬编码路径和参数 - 重构collect_jrxml.py支持命令行参数和环境变量配置源目录 - 新增.env.example示例配置文件,整理所有可配置项 - 重构down_embedding_model.py、import_to_chroma.py等所有脚本使用统一配置 - 新增Windows一键部署脚本setup.bat - 修正jrxml_banch_chunker.py的文件名拼写错误
This commit is contained in:
@@ -0,0 +1,138 @@
|
||||
"""
|
||||
config.py
|
||||
统一配置管理,从 .env 文件读取环境变量
|
||||
所有脚本通过此模块获取配置,避免硬编码
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _get_project_root() -> Path:
|
||||
return Path(__file__).resolve().parent
|
||||
|
||||
|
||||
def _load_dotenv():
|
||||
"""简易 .env 文件加载器,不依赖 python-dotenv"""
|
||||
project_root = _get_project_root()
|
||||
env_file = project_root / ".env"
|
||||
|
||||
if not env_file.exists():
|
||||
return
|
||||
|
||||
with open(env_file, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
if "=" not in line:
|
||||
continue
|
||||
key, _, value = line.partition("=")
|
||||
key = key.strip()
|
||||
value = value.strip().strip('"').strip("'")
|
||||
if key and key not in os.environ:
|
||||
os.environ[key] = value
|
||||
|
||||
|
||||
_load_dotenv()
|
||||
|
||||
PROJECT_ROOT = _get_project_root()
|
||||
|
||||
|
||||
def _get_path(key: str, default: str) -> Path:
|
||||
value = os.environ.get(key, default)
|
||||
p = Path(value)
|
||||
if not p.is_absolute():
|
||||
p = PROJECT_ROOT / p
|
||||
return p
|
||||
|
||||
|
||||
def _get_str(key: str, default: str) -> str:
|
||||
return os.environ.get(key, default)
|
||||
|
||||
|
||||
def _get_bool(key: str, default: bool) -> bool:
|
||||
value = os.environ.get(key, "").strip().lower()
|
||||
if value in ("true", "1", "yes"):
|
||||
return True
|
||||
if value in ("false", "0", "no"):
|
||||
return False
|
||||
return default
|
||||
|
||||
|
||||
def _get_int(key: str, default: int) -> int:
|
||||
try:
|
||||
return int(os.environ.get(key, str(default)))
|
||||
except ValueError:
|
||||
return default
|
||||
|
||||
|
||||
def _get_float(key: str, default: float) -> float:
|
||||
try:
|
||||
return float(os.environ.get(key, str(default)))
|
||||
except ValueError:
|
||||
return default
|
||||
|
||||
|
||||
# ==================== 模型配置 ====================
|
||||
EMBEDDING_MODEL_NAME = _get_str("EMBEDDING_MODEL_NAME", "Qwen/Qwen3-Embedding-4B")
|
||||
EMBEDDING_MODEL_PATH = _get_path("EMBEDDING_MODEL_PATH", "models/Qwen3-Embedding-4B")
|
||||
HF_ENDPOINT = _get_str("HF_ENDPOINT", "https://hf-mirror.com")
|
||||
|
||||
# ==================== 硬件配置 ====================
|
||||
USE_GPU = _get_bool("USE_GPU", True)
|
||||
USE_FP16 = _get_bool("USE_FP16", True)
|
||||
BATCH_SIZE = _get_int("BATCH_SIZE", 64)
|
||||
|
||||
# ==================== 目录配置 ====================
|
||||
JRXML_SOURCE_DIR = _get_path("JRXML_SOURCE_DIR", "jrxml_source")
|
||||
CHUNKER_OUTPUT_DIR = _get_path("CHUNKER_OUTPUT_DIR", "jrxml_chunker_output")
|
||||
EMBEDDINGS_DIR = _get_path("EMBEDDINGS_DIR", "embeddings")
|
||||
CHROMA_DB_PATH = _get_path("CHROMA_DB_PATH", "chroma_db")
|
||||
CHROMA_COLLECTION_NAME = _get_str("CHROMA_COLLECTION_NAME", "jrxml_chunks")
|
||||
|
||||
# ==================== 分块配置 ====================
|
||||
MAX_CHUNK_SIZE = _get_int("MAX_CHUNK_SIZE", 2000)
|
||||
|
||||
# ==================== 查询配置 ====================
|
||||
DEFAULT_N_RESULTS = _get_int("DEFAULT_N_RESULTS", 5)
|
||||
SIMILARITY_THRESHOLD = _get_float("SIMILARITY_THRESHOLD", 0.3)
|
||||
|
||||
|
||||
def resolve_model_path() -> str:
|
||||
"""
|
||||
解析模型路径:
|
||||
1. 如果 EMBEDDING_MODEL_PATH 本地存在,使用本地路径
|
||||
2. 否则使用 EMBEDDING_MODEL_NAME 作为 Hub 模型名
|
||||
"""
|
||||
if EMBEDDING_MODEL_PATH.exists():
|
||||
return str(EMBEDDING_MODEL_PATH)
|
||||
return EMBEDDING_MODEL_NAME
|
||||
|
||||
|
||||
def print_config():
|
||||
"""打印当前配置(调试用)"""
|
||||
print(f"{'='*60}")
|
||||
print(f"JRXML RAG 当前配置")
|
||||
print(f"{'='*60}")
|
||||
print(f" 项目根目录: {PROJECT_ROOT}")
|
||||
print(f" 嵌入模型名称: {EMBEDDING_MODEL_NAME}")
|
||||
print(f" 嵌入模型路径: {EMBEDDING_MODEL_PATH}")
|
||||
print(f" 模型解析结果: {resolve_model_path()}")
|
||||
print(f" HF 镜像: {HF_ENDPOINT}")
|
||||
print(f" GPU 加速: {USE_GPU}")
|
||||
print(f" FP16 半精度: {USE_FP16}")
|
||||
print(f" 批处理大小: {BATCH_SIZE}")
|
||||
print(f" JRXML 源目录: {JRXML_SOURCE_DIR}")
|
||||
print(f" 分块输出目录: {CHUNKER_OUTPUT_DIR}")
|
||||
print(f" 向量输出目录: {EMBEDDINGS_DIR}")
|
||||
print(f" Chroma 数据库: {CHROMA_DB_PATH}")
|
||||
print(f" Chroma 集合名: {CHROMA_COLLECTION_NAME}")
|
||||
print(f" 最大 Chunk 大小: {MAX_CHUNK_SIZE}")
|
||||
print(f" 默认返回结果数: {DEFAULT_N_RESULTS}")
|
||||
print(f" 相似度阈值: {SIMILARITY_THRESHOLD}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print_config()
|
||||
Reference in New Issue
Block a user