""" config.py 统一配置管理,从 .env 文件读取环境变量 所有脚本通过此模块获取配置,避免硬编码 """ import os from pathlib import Path def _get_project_root() -> Path: return Path(__file__).resolve().parent def _load_dotenv(): """简易 .env 文件加载器,不依赖 python-dotenv""" project_root = _get_project_root() env_file = project_root / ".env" if not env_file.exists(): return with open(env_file, "r", encoding="utf-8") as f: for line in f: line = line.strip() if not line or line.startswith("#"): continue if "=" not in line: continue key, _, value = line.partition("=") key = key.strip() value = value.strip().strip('"').strip("'") if key and key not in os.environ: os.environ[key] = value _load_dotenv() PROJECT_ROOT = _get_project_root() def _get_path(key: str, default: str) -> Path: value = os.environ.get(key, default) p = Path(value) if not p.is_absolute(): p = PROJECT_ROOT / p return p def _get_str(key: str, default: str) -> str: return os.environ.get(key, default) def _get_bool(key: str, default: bool) -> bool: value = os.environ.get(key, "").strip().lower() if value in ("true", "1", "yes"): return True if value in ("false", "0", "no"): return False return default def _get_int(key: str, default: int) -> int: try: return int(os.environ.get(key, str(default))) except ValueError: return default def _get_float(key: str, default: float) -> float: try: return float(os.environ.get(key, str(default))) except ValueError: return default # ==================== 模型配置 ==================== EMBEDDING_MODEL_NAME = _get_str("EMBEDDING_MODEL_NAME", "Qwen/Qwen3-Embedding-4B") EMBEDDING_MODEL_PATH = _get_path("EMBEDDING_MODEL_PATH", "models/Qwen3-Embedding-4B") HF_ENDPOINT = _get_str("HF_ENDPOINT", "https://hf-mirror.com") # ==================== 硬件配置 ==================== USE_GPU = _get_bool("USE_GPU", True) USE_FP16 = _get_bool("USE_FP16", True) BATCH_SIZE = _get_int("BATCH_SIZE", 64) # ==================== 目录配置 ==================== JRXML_SOURCE_DIR = _get_path("JRXML_SOURCE_DIR", "jrxml_source") CHUNKER_OUTPUT_DIR = _get_path("CHUNKER_OUTPUT_DIR", "jrxml_chunker_output") EMBEDDINGS_DIR = _get_path("EMBEDDINGS_DIR", "embeddings") CHROMA_DB_PATH = _get_path("CHROMA_DB_PATH", "chroma_db") CHROMA_COLLECTION_NAME = _get_str("CHROMA_COLLECTION_NAME", "jrxml_chunks") # ==================== 分块配置 ==================== MAX_CHUNK_SIZE = _get_int("MAX_CHUNK_SIZE", 2000) # ==================== 查询配置 ==================== DEFAULT_N_RESULTS = _get_int("DEFAULT_N_RESULTS", 5) SIMILARITY_THRESHOLD = _get_float("SIMILARITY_THRESHOLD", 0.3) def resolve_model_path() -> str: """ 解析模型路径: 1. 如果 EMBEDDING_MODEL_PATH 本地存在,使用本地路径 2. 否则使用 EMBEDDING_MODEL_NAME 作为 Hub 模型名 """ if EMBEDDING_MODEL_PATH.exists(): return str(EMBEDDING_MODEL_PATH) return EMBEDDING_MODEL_NAME def print_config(): """打印当前配置(调试用)""" print(f"{'='*60}") print(f"JRXML RAG 当前配置") print(f"{'='*60}") print(f" 项目根目录: {PROJECT_ROOT}") print(f" 嵌入模型名称: {EMBEDDING_MODEL_NAME}") print(f" 嵌入模型路径: {EMBEDDING_MODEL_PATH}") print(f" 模型解析结果: {resolve_model_path()}") print(f" HF 镜像: {HF_ENDPOINT}") print(f" GPU 加速: {USE_GPU}") print(f" FP16 半精度: {USE_FP16}") print(f" 批处理大小: {BATCH_SIZE}") print(f" JRXML 源目录: {JRXML_SOURCE_DIR}") print(f" 分块输出目录: {CHUNKER_OUTPUT_DIR}") print(f" 向量输出目录: {EMBEDDINGS_DIR}") print(f" Chroma 数据库: {CHROMA_DB_PATH}") print(f" Chroma 集合名: {CHROMA_COLLECTION_NAME}") print(f" 最大 Chunk 大小: {MAX_CHUNK_SIZE}") print(f" 默认返回结果数: {DEFAULT_N_RESULTS}") print(f" 相似度阈值: {SIMILARITY_THRESHOLD}") print(f"{'='*60}") if __name__ == "__main__": print_config()