diff --git a/scripts/init_kb.py b/scripts/init_kb.py index 7aef308..b50ad4c 100644 --- a/scripts/init_kb.py +++ b/scripts/init_kb.py @@ -13,6 +13,31 @@ sys.path.insert(0, str(Path(__file__).parent.parent)) load_dotenv() + +def download_embeddings_model(): + """预下载 Qwen3-Embedding 模型(从 HuggingFace)。 + + 用法: python scripts/init_kb.py --download-model + """ + model_name = os.getenv("LOCAL_EMBED_MODEL", "Qwen/Qwen3-Embedding-0.6B") + print(f"正在下载嵌入模型: {model_name}") + print("如遇网络超时,可手动执行以下命令后重试:") + print(f" huggingface-cli download {model_name} --local-dir ./models/{model_name.replace('/', '_')}") + print() + + try: + from langchain_huggingface import HuggingFaceEmbeddings + except ImportError: + print("错误: 请先安装 huggingface 依赖") + print(" pip install langchain-huggingface sentence-transformers") + return + + # HuggingFaceEmbeddings 会在首次调用时自动下载模型 + embeddings = HuggingFaceEmbeddings(model_name=model_name) + # 调用一次以确保完全下载 + embeddings.embed_query("测试") + print(f"嵌入模型下载完成: {model_name}") + from backend.embeddings import get_embeddings @@ -84,4 +109,12 @@ def main(): if __name__ == '__main__': - main() + import argparse + parser = argparse.ArgumentParser(description='初始化 Chroma 知识库') + parser.add_argument('--download-model', action='store_true', help='仅下载嵌入模型到本地') + args = parser.parse_args() + + if args.download_model: + download_embeddings_model() + else: + main()