From 645242a55269cafc06af2d08b276f9853a970de6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=88=92=E9=85=92=E7=9A=84=E6=9D=8E=E7=99=BD?= <670939375@qq.com> Date: Mon, 4 Aug 2025 14:06:45 +0800 Subject: [PATCH] Updated how the fine-tuned BERT model is stored. --- .gitignore | 1 + .../BertChinese-Lora/README.md | 11 +++++-- .../BertChinese-Lora/predict.py | 20 +++++++++-- .../BertChinese-Lora/predict_pipeline.py | 33 ++++++++++++++++--- 4 files changed, 55 insertions(+), 10 deletions(-) diff --git a/.gitignore b/.gitignore index fa4b95f..b2a1796 100644 --- a/.gitignore +++ b/.gitignore @@ -181,6 +181,7 @@ WeiboSentiment_Finetuned/GPT2-Lora/models/ WeiboSentiment_Finetuned/GPT2-AdapterTuning/models/ WeiboSentiment_Finetuned/BertChinese-Lora/models/ WeiboSentiment_LLM/models/ +WeiboSentiment_Finetuned/BertChinese-Lora/model/ # LoRA 和 Adapter 权重 */adapter_model.safetensors diff --git a/WeiboSentiment_Finetuned/BertChinese-Lora/README.md b/WeiboSentiment_Finetuned/BertChinese-Lora/README.md index c164ef3..a85c403 100644 --- a/WeiboSentiment_Finetuned/BertChinese-Lora/README.md +++ b/WeiboSentiment_Finetuned/BertChinese-Lora/README.md @@ -64,8 +64,15 @@ print("正面情感" if prediction == 1 else "负面情感") - `predict_pipeline.py`: 使用pipeline方式的预测程序 - `README.md`: 使用说明 +## 模型存储 + +- 首次运行时会自动下载模型到当前目录的 `model` 文件夹 +- 后续运行会直接从本地加载,无需重复下载 +- 模型大小约400MB,首次下载需要网络连接 + ## 注意事项 - 首次运行时会自动下载模型,需要网络连接 -- 模型大小约400MB,下载可能需要一些时间 -- 支持GPU加速,会自动检测可用设备 \ No newline at end of file +- 模型会保存到当前目录,方便后续使用 +- 支持GPU加速,会自动检测可用设备 +- 如需清理模型文件,删除 `model` 文件夹即可 \ No newline at end of file diff --git a/WeiboSentiment_Finetuned/BertChinese-Lora/predict.py b/WeiboSentiment_Finetuned/BertChinese-Lora/predict.py index e8f90ca..627d879 100644 --- a/WeiboSentiment_Finetuned/BertChinese-Lora/predict.py +++ b/WeiboSentiment_Finetuned/BertChinese-Lora/predict.py @@ -16,11 +16,25 @@ def main(): # 使用HuggingFace预训练模型 model_name = "wsqstar/GISchat-weibo-100k-fine-tuned-bert" + local_model_path = "./model" try: - # 加载模型和分词器 - tokenizer = AutoTokenizer.from_pretrained(model_name) - model = AutoModelForSequenceClassification.from_pretrained(model_name) + # 检查本地是否已有模型 + import os + if os.path.exists(local_model_path): + print("从本地加载模型...") + tokenizer = AutoTokenizer.from_pretrained(local_model_path) + model = AutoModelForSequenceClassification.from_pretrained(local_model_path) + else: + print("首次使用,正在下载模型到本地...") + # 下载并保存到本地 + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = AutoModelForSequenceClassification.from_pretrained(model_name) + + # 保存到本地 + tokenizer.save_pretrained(local_model_path) + model.save_pretrained(local_model_path) + print(f"模型已保存到: {local_model_path}") # 设置设备 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') diff --git a/WeiboSentiment_Finetuned/BertChinese-Lora/predict_pipeline.py b/WeiboSentiment_Finetuned/BertChinese-Lora/predict_pipeline.py index 4233697..b8e2482 100644 --- a/WeiboSentiment_Finetuned/BertChinese-Lora/predict_pipeline.py +++ b/WeiboSentiment_Finetuned/BertChinese-Lora/predict_pipeline.py @@ -15,13 +15,36 @@ def main(): # 使用pipeline方式 - 更简单 model_name = "wsqstar/GISchat-weibo-100k-fine-tuned-bert" + local_model_path = "./model" try: - classifier = pipeline( - "text-classification", - model=model_name, - return_all_scores=True - ) + # 检查本地是否已有模型 + import os + if os.path.exists(local_model_path): + print("从本地加载模型...") + classifier = pipeline( + "text-classification", + model=local_model_path, + return_all_scores=True + ) + else: + print("首次使用,正在下载模型到本地...") + # 先下载模型 + from transformers import AutoTokenizer, AutoModelForSequenceClassification + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = AutoModelForSequenceClassification.from_pretrained(model_name) + + # 保存到本地 + tokenizer.save_pretrained(local_model_path) + model.save_pretrained(local_model_path) + print(f"模型已保存到: {local_model_path}") + + # 使用本地模型创建pipeline + classifier = pipeline( + "text-classification", + model=local_model_path, + return_all_scores=True + ) print("模型加载成功!") except Exception as e: