Updated how the fine-tuned BERT model is stored.
This commit is contained in:
@@ -181,6 +181,7 @@ WeiboSentiment_Finetuned/GPT2-Lora/models/
|
||||
WeiboSentiment_Finetuned/GPT2-AdapterTuning/models/
|
||||
WeiboSentiment_Finetuned/BertChinese-Lora/models/
|
||||
WeiboSentiment_LLM/models/
|
||||
WeiboSentiment_Finetuned/BertChinese-Lora/model/
|
||||
|
||||
# LoRA 和 Adapter 权重
|
||||
*/adapter_model.safetensors
|
||||
|
||||
@@ -64,8 +64,15 @@ print("正面情感" if prediction == 1 else "负面情感")
|
||||
- `predict_pipeline.py`: 使用pipeline方式的预测程序
|
||||
- `README.md`: 使用说明
|
||||
|
||||
## 模型存储
|
||||
|
||||
- 首次运行时会自动下载模型到当前目录的 `model` 文件夹
|
||||
- 后续运行会直接从本地加载,无需重复下载
|
||||
- 模型大小约400MB,首次下载需要网络连接
|
||||
|
||||
## 注意事项
|
||||
|
||||
- 首次运行时会自动下载模型,需要网络连接
|
||||
- 模型大小约400MB,下载可能需要一些时间
|
||||
- 模型会保存到当前目录,方便后续使用
|
||||
- 支持GPU加速,会自动检测可用设备
|
||||
- 如需清理模型文件,删除 `model` 文件夹即可
|
||||
@@ -16,12 +16,26 @@ def main():
|
||||
|
||||
# 使用HuggingFace预训练模型
|
||||
model_name = "wsqstar/GISchat-weibo-100k-fine-tuned-bert"
|
||||
local_model_path = "./model"
|
||||
|
||||
try:
|
||||
# 加载模型和分词器
|
||||
# 检查本地是否已有模型
|
||||
import os
|
||||
if os.path.exists(local_model_path):
|
||||
print("从本地加载模型...")
|
||||
tokenizer = AutoTokenizer.from_pretrained(local_model_path)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(local_model_path)
|
||||
else:
|
||||
print("首次使用,正在下载模型到本地...")
|
||||
# 下载并保存到本地
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
||||
|
||||
# 保存到本地
|
||||
tokenizer.save_pretrained(local_model_path)
|
||||
model.save_pretrained(local_model_path)
|
||||
print(f"模型已保存到: {local_model_path}")
|
||||
|
||||
# 设置设备
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
model.to(device)
|
||||
|
||||
@@ -15,11 +15,34 @@ def main():
|
||||
|
||||
# 使用pipeline方式 - 更简单
|
||||
model_name = "wsqstar/GISchat-weibo-100k-fine-tuned-bert"
|
||||
local_model_path = "./model"
|
||||
|
||||
try:
|
||||
# 检查本地是否已有模型
|
||||
import os
|
||||
if os.path.exists(local_model_path):
|
||||
print("从本地加载模型...")
|
||||
classifier = pipeline(
|
||||
"text-classification",
|
||||
model=model_name,
|
||||
model=local_model_path,
|
||||
return_all_scores=True
|
||||
)
|
||||
else:
|
||||
print("首次使用,正在下载模型到本地...")
|
||||
# 先下载模型
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
||||
|
||||
# 保存到本地
|
||||
tokenizer.save_pretrained(local_model_path)
|
||||
model.save_pretrained(local_model_path)
|
||||
print(f"模型已保存到: {local_model_path}")
|
||||
|
||||
# 使用本地模型创建pipeline
|
||||
classifier = pipeline(
|
||||
"text-classification",
|
||||
model=local_model_path,
|
||||
return_all_scores=True
|
||||
)
|
||||
print("模型加载成功!")
|
||||
|
||||
Reference in New Issue
Block a user