Train and prediction script for a topic classification model based on bert-chinese.
This commit is contained in:
@@ -0,0 +1,98 @@
|
||||
## 话题分类(BERT 中文基座)
|
||||
|
||||
本目录提供一个使用 `google-bert/bert-base-chinese` 的中文话题分类实现:
|
||||
- 自动处理本地/缓存/远程三段式加载逻辑;
|
||||
- `train.py` 进行微调训练;`predict.py` 进行单条或交互式预测;
|
||||
- 所有模型与权重统一保存至本目录的 `model/`。
|
||||
|
||||
参考模型卡片: [google-bert/bert-base-chinese](https://huggingface.co/google-bert/bert-base-chinese)
|
||||
|
||||
### 数据集亮点
|
||||
|
||||
- 约 **410 万**条预过滤高质量问题与回复;
|
||||
- 每个问题对应一个“【话题】”,覆盖 **约 2.8 万**个多样主题;
|
||||
- 从 **1400 万**原始问答中筛选,保留至少 **3 个点赞以上**的答案,确保内容质量与有趣度;
|
||||
- 除了问题、话题与一个或多个回复外,每个回复还带有点赞数、回复 ID、回复者标签;
|
||||
- 数据清洗去重后划分三部分:示例划分训练集约 **412 万**、验证/测试若干(可按需调整)。
|
||||
|
||||
> 实际训练时,请以 `dataset/` 下的 CSV 为准;脚本会自动识别常见列名或允许通过命令参数显式指定。
|
||||
|
||||
### 目录结构
|
||||
|
||||
```
|
||||
BertTopicDetection_Finetuned/
|
||||
├─ dataset/ # 已放置数据
|
||||
├─ model/ # 训练生成;亦缓存基础 BERT
|
||||
├─ train.py
|
||||
├─ predict.py
|
||||
└─ README.md
|
||||
```
|
||||
|
||||
### 环境
|
||||
|
||||
```
|
||||
pip install torch transformers scikit-learn pandas
|
||||
```
|
||||
|
||||
或使用你既有的 Conda 环境。
|
||||
|
||||
### 数据格式
|
||||
|
||||
CSV 至少包含文本列与标签列,脚本会尝试自动识别:
|
||||
- 文本列候选:`text`/`content`/`sentence`/`title`/`desc`/`question`
|
||||
- 标签列候选:`label`/`labels`/`category`/`topic`/`class`
|
||||
|
||||
如需显式指定,请使用 `--text_col` 与 `--label_col`。
|
||||
|
||||
### 训练
|
||||
|
||||
```
|
||||
python train.py \
|
||||
--train_file ./dataset/web_text_zh_train.csv \
|
||||
--valid_file ./dataset/web_text_zh_valid.csv \
|
||||
--text_col auto \
|
||||
--label_col auto \
|
||||
--model_root ./model \
|
||||
--save_subdir bert-chinese-classifier \
|
||||
--num_epochs 10 --batch_size 16 --learning_rate 2e-5 --fp16
|
||||
```
|
||||
|
||||
要点:
|
||||
- 首次运行会检查 `model/bert-base-chinese`;若无则尝试本机缓存,再不行则自动下载并保存;
|
||||
- 训练过程按步评估与保存(默认每 1/4 个 epoch),最多保留 5 个最近 checkpoint(可通过环境变量 `SAVE_TOTAL_LIMIT` 调整);
|
||||
- 支持早停(默认耐心 5 次评估),并在评估/保存策略一致时自动回滚到最佳模型;
|
||||
- 分词器、权重与 `label_map.json` 保存到 `model/bert-chinese-classifier/`。
|
||||
|
||||
### 预测
|
||||
|
||||
单条:
|
||||
```
|
||||
python predict.py --text "这条微博讨论的是哪个话题?" --model_root ./model --finetuned_subdir bert-chinese-classifier
|
||||
```
|
||||
|
||||
交互:
|
||||
```
|
||||
python predict.py --interactive --model_root ./model --finetuned_subdir bert-chinese-classifier
|
||||
```
|
||||
|
||||
示例输出:
|
||||
```
|
||||
预测结果: 体育-足球 (置信度: 0.9412)
|
||||
```
|
||||
|
||||
### 说明
|
||||
|
||||
- 训练与预测均内置简易中文文本清洗。
|
||||
- 标签集合以训练集为准,脚本自动生成并保存 `label_map.json`。
|
||||
|
||||
### 训练策略(简述)
|
||||
|
||||
- 基座:`google-bert/bert-base-chinese`;分类头维度=训练集唯一标签数。
|
||||
- 学习率与正则:`lr=2e-5`,`weight_decay=0.01`,可在大型数据上微调到 `1e-5~3e-5`。
|
||||
- 序列长度与批量:`max_length=128`,`batch_size=16`;若截断严重可升至 256(成本上升)。
|
||||
- Warmup:若环境支持,使用 `warmup_ratio=0.1`;否则回退 `warmup_steps=0`。
|
||||
- 评估/保存:按 `--eval_fraction` 折算步数(默认 0.25),`save_total_limit=5` 限制磁盘占用。
|
||||
- 早停:监控加权 F1(越大越好),默认耐心 5、改善阈值 0.0。
|
||||
- 单卡稳定运行:默认仅使用一张 GPU,可通过 `--gpu` 指定;脚本会清理分布式环境变量。
|
||||
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
{
|
||||
"architectures": [
|
||||
"BertModel"
|
||||
],
|
||||
"attention_probs_dropout_prob": 0.1,
|
||||
"classifier_dropout": null,
|
||||
"directionality": "bidi",
|
||||
"hidden_act": "gelu",
|
||||
"hidden_dropout_prob": 0.1,
|
||||
"hidden_size": 768,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 3072,
|
||||
"layer_norm_eps": 1e-12,
|
||||
"max_position_embeddings": 512,
|
||||
"model_type": "bert",
|
||||
"num_attention_heads": 12,
|
||||
"num_hidden_layers": 12,
|
||||
"pad_token_id": 0,
|
||||
"pooler_fc_size": 768,
|
||||
"pooler_num_attention_heads": 12,
|
||||
"pooler_num_fc_layers": 3,
|
||||
"pooler_size_per_head": 128,
|
||||
"pooler_type": "first_token_transform",
|
||||
"position_embedding_type": "absolute",
|
||||
"torch_dtype": "float32",
|
||||
"transformers_version": "4.51.3",
|
||||
"type_vocab_size": 2,
|
||||
"use_cache": true,
|
||||
"vocab_size": 21128
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"cls_token": "[CLS]",
|
||||
"mask_token": "[MASK]",
|
||||
"pad_token": "[PAD]",
|
||||
"sep_token": "[SEP]",
|
||||
"unk_token": "[UNK]"
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,56 @@
|
||||
{
|
||||
"added_tokens_decoder": {
|
||||
"0": {
|
||||
"content": "[PAD]",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"100": {
|
||||
"content": "[UNK]",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"101": {
|
||||
"content": "[CLS]",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"102": {
|
||||
"content": "[SEP]",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"103": {
|
||||
"content": "[MASK]",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
}
|
||||
},
|
||||
"clean_up_tokenization_spaces": false,
|
||||
"cls_token": "[CLS]",
|
||||
"do_lower_case": false,
|
||||
"extra_special_tokens": {},
|
||||
"mask_token": "[MASK]",
|
||||
"model_max_length": 512,
|
||||
"pad_token": "[PAD]",
|
||||
"sep_token": "[SEP]",
|
||||
"strip_accents": null,
|
||||
"tokenize_chinese_chars": true,
|
||||
"tokenizer_class": "BertTokenizer",
|
||||
"unk_token": "[UNK]"
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,179 @@
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import re
|
||||
import argparse
|
||||
from typing import Dict, Tuple
|
||||
|
||||
# ========== 单卡锁定(在导入 torch/transformers 前执行) ==========
|
||||
def _extract_gpu_arg(argv, default: str = "0") -> str:
|
||||
for i, arg in enumerate(argv):
|
||||
if arg.startswith("--gpu="):
|
||||
return arg.split("=", 1)[1]
|
||||
if arg == "--gpu" and i + 1 < len(argv):
|
||||
return argv[i + 1]
|
||||
return default
|
||||
|
||||
env_vis = os.environ.get("CUDA_VISIBLE_DEVICES", "").strip()
|
||||
try:
|
||||
gpu_to_use = _extract_gpu_arg(sys.argv, default="0")
|
||||
except Exception:
|
||||
gpu_to_use = "0"
|
||||
if (not env_vis) or ("," in env_vis):
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_to_use
|
||||
os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
|
||||
|
||||
for _k in ["RANK", "LOCAL_RANK", "WORLD_SIZE"]:
|
||||
os.environ.pop(_k, None)
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModel,
|
||||
AutoModelForSequenceClassification,
|
||||
)
|
||||
|
||||
|
||||
def preprocess_text(text: str) -> str:
|
||||
if text is None:
|
||||
return ""
|
||||
text = str(text)
|
||||
text = re.sub(r"\{%.+?%\}", " ", text)
|
||||
text = re.sub(r"@.+?( |$)", " ", text)
|
||||
text = re.sub(r"【.+?】", " ", text)
|
||||
text = re.sub(r"\u200b", " ", text)
|
||||
text = re.sub(
|
||||
r"[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F1E0-\U0001F1FF\U00002600-\U000027BF\U0001f900-\U0001f9ff\U0001f018-\U0001f270\U0000231a-\U0000231b\U0000238d-\U0000238d\U000024c2-\U0001f251]+",
|
||||
"",
|
||||
text,
|
||||
)
|
||||
text = re.sub(r"\s+", " ", text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
def ensure_base_model_local(model_name_or_path: str, local_model_root: str) -> Tuple[str, AutoTokenizer]:
|
||||
os.makedirs(local_model_root, exist_ok=True)
|
||||
base_dir = os.path.join(local_model_root, "bert-base-chinese")
|
||||
|
||||
def is_ready(path: str) -> bool:
|
||||
return os.path.isdir(path) and os.path.isfile(os.path.join(path, "config.json"))
|
||||
|
||||
if is_ready(base_dir):
|
||||
tokenizer = AutoTokenizer.from_pretrained(base_dir)
|
||||
return base_dir, tokenizer
|
||||
|
||||
# 本机缓存
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, local_files_only=True)
|
||||
base = AutoModel.from_pretrained(model_name_or_path, local_files_only=True)
|
||||
os.makedirs(base_dir, exist_ok=True)
|
||||
tokenizer.save_pretrained(base_dir)
|
||||
base.save_pretrained(base_dir)
|
||||
return base_dir, tokenizer
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 远程下载
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
||||
base = AutoModel.from_pretrained(model_name_or_path)
|
||||
os.makedirs(base_dir, exist_ok=True)
|
||||
tokenizer.save_pretrained(base_dir)
|
||||
base.save_pretrained(base_dir)
|
||||
return base_dir, tokenizer
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="使用本地/缓存/远程加载的中文 BERT 分类模型进行预测")
|
||||
parser.add_argument("--model_root", type=str, default="./model", help="本地模型根目录")
|
||||
parser.add_argument("--finetuned_subdir", type=str, default="bert-chinese-classifier", help="微调结果子目录")
|
||||
parser.add_argument("--pretrained_name", type=str, default="google-bert/bert-base-chinese", help="预训练模型名称或路径")
|
||||
parser.add_argument("--text", type=str, default=None, help="直接输入一条要预测的文本")
|
||||
parser.add_argument("--interactive", action="store_true", help="进入交互式预测模式")
|
||||
parser.add_argument("--max_length", type=int, default=128)
|
||||
parser.add_argument("--gpu", type=str, default=os.environ.get("CUDA_VISIBLE_DEVICES", "0"), help="指定单卡 GPU,如 0 或 1")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_finetuned(model_root: str, subdir: str) -> Tuple[str, Dict[int, str]]:
|
||||
finetuned_path = os.path.join(model_root, subdir)
|
||||
if not os.path.isdir(finetuned_path):
|
||||
raise FileNotFoundError(
|
||||
f"未找到微调模型目录: {finetuned_path},请先运行训练脚本。"
|
||||
)
|
||||
label_map_path = os.path.join(finetuned_path, "label_map.json")
|
||||
id2label = None
|
||||
if os.path.isfile(label_map_path):
|
||||
with open(label_map_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
id2label = {int(k): str(v) for k, v in data.get("id2label", {}).items()}
|
||||
return finetuned_path, id2label
|
||||
|
||||
|
||||
def predict_once(model_dir: str, text: str, max_length: int = 128) -> Tuple[str, float]:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(model_dir)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
processed = preprocess_text(text)
|
||||
encoded = tokenizer(
|
||||
processed,
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
)
|
||||
input_ids = encoded["input_ids"].to(device)
|
||||
attention_mask = encoded["attention_mask"].to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
logits = outputs.logits
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
pred = int(torch.argmax(probs, dim=-1).item())
|
||||
conf = float(probs[0, pred].item())
|
||||
id2label = getattr(model.config, "id2label", None)
|
||||
label_name = id2label.get(pred, str(pred)) if isinstance(id2label, dict) else str(pred)
|
||||
return label_name, conf
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
model_root = args.model_root if os.path.isabs(args.model_root) else os.path.join(script_dir, args.model_root)
|
||||
os.makedirs(model_root, exist_ok=True)
|
||||
|
||||
# 确保基础模型在本地
|
||||
ensure_base_model_local(args.pretrained_name, model_root)
|
||||
|
||||
finetuned_dir, _ = load_finetuned(model_root, args.finetuned_subdir)
|
||||
|
||||
if args.text is not None:
|
||||
label, conf = predict_once(finetuned_dir, args.text, args.max_length)
|
||||
print(f"预测结果: {label} (置信度: {conf:.4f})")
|
||||
return
|
||||
|
||||
if args.interactive:
|
||||
print("进入交互模式。输入 'q' 退出。")
|
||||
while True:
|
||||
try:
|
||||
text = input("请输入文本: ").strip()
|
||||
except EOFError:
|
||||
break
|
||||
if text.lower() == "q":
|
||||
break
|
||||
if not text:
|
||||
continue
|
||||
label, conf = predict_once(finetuned_dir, text, args.max_length)
|
||||
print(f"预测结果: {label} (置信度: {conf:.4f})")
|
||||
return
|
||||
|
||||
print("未提供 --text 或 --interactive,什么也没有发生。")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
@@ -0,0 +1,440 @@
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import re
|
||||
import argparse
|
||||
import math
|
||||
import inspect
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
# ========== 单卡锁定(在导入 torch/transformers 前执行) ==========
|
||||
def _extract_gpu_arg(argv: List[str], default: str = "0") -> str:
|
||||
for i, arg in enumerate(argv):
|
||||
if arg.startswith("--gpu="):
|
||||
return arg.split("=", 1)[1]
|
||||
if arg == "--gpu" and i + 1 < len(argv):
|
||||
return argv[i + 1]
|
||||
return default
|
||||
|
||||
env_vis = os.environ.get("CUDA_VISIBLE_DEVICES", "").strip()
|
||||
try:
|
||||
gpu_to_use = _extract_gpu_arg(sys.argv, default="0")
|
||||
except Exception:
|
||||
gpu_to_use = "0"
|
||||
# 若未设置或暴露了多卡,则强制只暴露单卡(默认0)以确保直接运行稳定
|
||||
if (not env_vis) or ("," in env_vis):
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_to_use
|
||||
os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
|
||||
|
||||
# 清理可能由外部启动器注入的分布式环境变量,避免误触多卡/分布式
|
||||
for _k in ["RANK", "LOCAL_RANK", "WORLD_SIZE"]:
|
||||
os.environ.pop(_k, None)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from sklearn.metrics import accuracy_score, f1_score, precision_recall_fscore_support
|
||||
import pandas as pd
|
||||
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModel,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoConfig,
|
||||
DataCollatorWithPadding,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
set_seed,
|
||||
)
|
||||
try:
|
||||
from transformers import EarlyStoppingCallback # type: ignore
|
||||
except Exception: # pragma: no cover
|
||||
EarlyStoppingCallback = None # type: ignore
|
||||
|
||||
|
||||
def preprocess_text(text: str) -> str:
|
||||
if text is None:
|
||||
return ""
|
||||
text = str(text)
|
||||
text = re.sub(r"\{%.+?%\}", " ", text)
|
||||
text = re.sub(r"@.+?( |$)", " ", text)
|
||||
text = re.sub(r"【.+?】", " ", text)
|
||||
text = re.sub(r"\u200b", " ", text)
|
||||
text = re.sub(
|
||||
r"[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F1E0-\U0001F1FF\U00002600-\U000027BF\U0001f900-\U0001f9ff\U0001f018-\U0001f270\U0000231a-\U0000231b\U0000238d-\U0000238d\U000024c2-\U0001f251]+",
|
||||
"",
|
||||
text,
|
||||
)
|
||||
text = re.sub(r"\s+", " ", text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
def ensure_base_model_local(model_name_or_path: str, local_model_root: str) -> Tuple[str, AutoTokenizer]:
|
||||
os.makedirs(local_model_root, exist_ok=True)
|
||||
base_dir = os.path.join(local_model_root, "bert-base-chinese")
|
||||
|
||||
def is_ready(path: str) -> bool:
|
||||
return os.path.isdir(path) and os.path.isfile(os.path.join(path, "config.json"))
|
||||
|
||||
# 1) 本地现成
|
||||
if is_ready(base_dir):
|
||||
tokenizer = AutoTokenizer.from_pretrained(base_dir)
|
||||
return base_dir, tokenizer
|
||||
|
||||
# 2) 本机缓存
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, local_files_only=True)
|
||||
base = AutoModel.from_pretrained(model_name_or_path, local_files_only=True)
|
||||
os.makedirs(base_dir, exist_ok=True)
|
||||
tokenizer.save_pretrained(base_dir)
|
||||
base.save_pretrained(base_dir)
|
||||
return base_dir, tokenizer
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 3) 远程下载
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
||||
base = AutoModel.from_pretrained(model_name_or_path)
|
||||
os.makedirs(base_dir, exist_ok=True)
|
||||
tokenizer.save_pretrained(base_dir)
|
||||
base.save_pretrained(base_dir)
|
||||
return base_dir, tokenizer
|
||||
|
||||
|
||||
class TextClassificationDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
dataframe: pd.DataFrame,
|
||||
tokenizer: AutoTokenizer,
|
||||
text_column: str,
|
||||
label_column: str,
|
||||
label2id: Dict[str, int],
|
||||
max_length: int,
|
||||
) -> None:
|
||||
self.dataframe = dataframe.reset_index(drop=True)
|
||||
self.tokenizer = tokenizer
|
||||
self.text_column = text_column
|
||||
self.label_column = label_column
|
||||
self.label2id = label2id
|
||||
self.max_length = max_length
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.dataframe)
|
||||
|
||||
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
||||
row = self.dataframe.iloc[idx]
|
||||
text = preprocess_text(row[self.text_column])
|
||||
encoding = self.tokenizer(
|
||||
text,
|
||||
max_length=self.max_length,
|
||||
truncation=True,
|
||||
padding=False,
|
||||
return_tensors="pt",
|
||||
)
|
||||
item = {k: v.squeeze(0) for k, v in encoding.items()}
|
||||
if self.label_column in row and pd.notna(row[self.label_column]):
|
||||
label_str = str(row[self.label_column])
|
||||
item["labels"] = torch.tensor(self.label2id[label_str], dtype=torch.long)
|
||||
return item
|
||||
|
||||
|
||||
def build_label_mappings(train_df: pd.DataFrame, label_column: str) -> Tuple[Dict[str, int], Dict[int, str]]:
|
||||
labels: List[str] = [str(x) for x in train_df[label_column].dropna().astype(str).tolist()]
|
||||
unique_sorted = sorted(set(labels))
|
||||
label2id = {label: i for i, label in enumerate(unique_sorted)}
|
||||
id2label = {i: label for label, i in label2id.items()}
|
||||
return label2id, id2label
|
||||
|
||||
|
||||
def compute_metrics_fn(eval_pred) -> Dict[str, float]:
|
||||
logits, labels = eval_pred
|
||||
preds = np.argmax(logits, axis=-1)
|
||||
precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="weighted", zero_division=0)
|
||||
acc = accuracy_score(labels, preds)
|
||||
return {
|
||||
"accuracy": float(acc),
|
||||
"precision": float(precision),
|
||||
"recall": float(recall),
|
||||
"f1": float(f1),
|
||||
}
|
||||
|
||||
|
||||
def autodetect_columns(df: pd.DataFrame, text_col: str, label_col: str) -> Tuple[str, str]:
|
||||
if text_col != "auto" and label_col != "auto":
|
||||
return text_col, label_col
|
||||
candidates_text = ["text", "content", "sentence", "title", "desc", "question"]
|
||||
candidates_label = ["label", "labels", "category", "topic", "class"]
|
||||
t = text_col
|
||||
l = label_col
|
||||
if text_col == "auto":
|
||||
for name in candidates_text:
|
||||
if name in df.columns:
|
||||
t = name
|
||||
break
|
||||
if label_col == "auto":
|
||||
for name in candidates_label:
|
||||
if name in df.columns:
|
||||
l = name
|
||||
break
|
||||
if t == "auto" or l == "auto":
|
||||
raise ValueError(
|
||||
f"无法自动识别列名,请显式传入 --text_col 与 --label_col。现有列: {list(df.columns)}"
|
||||
)
|
||||
return t, l
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="使用 google-bert/bert-base-chinese 在本目录数据集上进行文本分类微调")
|
||||
parser.add_argument("--train_file", type=str, default="./dataset/web_text_zh_train.csv")
|
||||
parser.add_argument("--valid_file", type=str, default="./dataset/web_text_zh_valid.csv")
|
||||
parser.add_argument("--text_col", type=str, default="auto", help="文本列名,默认自动识别")
|
||||
parser.add_argument("--label_col", type=str, default="auto", help="标签列名,默认自动识别")
|
||||
parser.add_argument("--model_root", type=str, default="./model", help="本地模型根目录")
|
||||
parser.add_argument("--pretrained_name", type=str, default="google-bert/bert-base-chinese")
|
||||
parser.add_argument("--save_subdir", type=str, default="bert-chinese-classifier")
|
||||
parser.add_argument("--max_length", type=int, default=128)
|
||||
parser.add_argument("--batch_size", type=int, default=64)
|
||||
parser.add_argument("--num_epochs", type=int, default=10)
|
||||
parser.add_argument("--learning_rate", type=float, default=2e-5)
|
||||
parser.add_argument("--weight_decay", type=float, default=0.01)
|
||||
parser.add_argument("--warmup_ratio", type=float, default=0.1)
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
parser.add_argument("--fp16", action="store_true")
|
||||
parser.add_argument("--gpu", type=str, default=os.environ.get("CUDA_VISIBLE_DEVICES", "0"), help="指定单卡 GPU,如 0 或 1")
|
||||
parser.add_argument("--eval_fraction", type=float, default=0.25, help="每多少个 epoch 做一次评估与保存,例如 0.25 表示每四分之一个 epoch")
|
||||
parser.add_argument("--early_stop_patience", type=int, default=5, help="早停耐心(以评估轮次计)")
|
||||
parser.add_argument("--early_stop_threshold", type=float, default=0.0, help="早停最小改善阈值(与 metric_for_best_model 同单位)")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
set_seed(args.seed)
|
||||
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
model_root = args.model_root if os.path.isabs(args.model_root) else os.path.join(script_dir, args.model_root)
|
||||
os.makedirs(model_root, exist_ok=True)
|
||||
|
||||
# 确保基础模型就绪
|
||||
base_dir, tokenizer = ensure_base_model_local(args.pretrained_name, model_root)
|
||||
print(f"[Info] 使用基础模型目录: {base_dir}")
|
||||
|
||||
# 读取数据
|
||||
train_path = args.train_file if os.path.isabs(args.train_file) else os.path.join(script_dir, args.train_file)
|
||||
valid_path = args.valid_file if os.path.isabs(args.valid_file) else os.path.join(script_dir, args.valid_file)
|
||||
if not os.path.isfile(train_path):
|
||||
raise FileNotFoundError(f"训练集不存在: {train_path}")
|
||||
train_df = pd.read_csv(train_path)
|
||||
if not os.path.isfile(valid_path):
|
||||
# 若未提供或不存在验证集,自动切分
|
||||
shuffled = train_df.sample(frac=1.0, random_state=args.seed).reset_index(drop=True)
|
||||
split_idx = int(len(shuffled) * 0.9)
|
||||
valid_df = shuffled.iloc[split_idx:].reset_index(drop=True)
|
||||
train_df = shuffled.iloc[:split_idx].reset_index(drop=True)
|
||||
else:
|
||||
valid_df = pd.read_csv(valid_path)
|
||||
print(f"[Info] 训练集: {train_path} | 样本数: {len(train_df)}")
|
||||
print(f"[Info] 验证集: {valid_path if os.path.isfile(valid_path) else '(从训练集切分)'} | 样本数: {len(valid_df)}")
|
||||
|
||||
# 自动识别列名
|
||||
text_col, label_col = autodetect_columns(train_df, args.text_col, args.label_col)
|
||||
print(f"[Info] 文本列: {text_col} | 标签列: {label_col}")
|
||||
|
||||
# 标签映射
|
||||
label2id, id2label = build_label_mappings(train_df, label_col)
|
||||
if len(label2id) < 2:
|
||||
raise ValueError("标签类别数少于 2,无法训练分类模型。")
|
||||
print(f"[Info] 标签类别数: {len(label2id)}")
|
||||
|
||||
# 数据集
|
||||
train_dataset = TextClassificationDataset(train_df, tokenizer, text_col, label_col, label2id, args.max_length)
|
||||
eval_dataset = TextClassificationDataset(valid_df, tokenizer, text_col, label_col, label2id, args.max_length)
|
||||
collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
||||
|
||||
# 模型
|
||||
config = AutoConfig.from_pretrained(
|
||||
base_dir,
|
||||
num_labels=len(label2id),
|
||||
id2label={int(i): str(l) for i, l in id2label.items()},
|
||||
label2id={str(l): int(i) for l, i in label2id.items()},
|
||||
)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
base_dir,
|
||||
config=config,
|
||||
ignore_mismatched_sizes=True,
|
||||
)
|
||||
|
||||
# 训练参数
|
||||
output_dir = os.path.join(model_root, args.save_subdir)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
# 训练参数(兼容不同 transformers 版本)
|
||||
args_dict = {
|
||||
"output_dir": output_dir,
|
||||
"per_device_train_batch_size": args.batch_size,
|
||||
"per_device_eval_batch_size": args.batch_size,
|
||||
"learning_rate": args.learning_rate,
|
||||
"weight_decay": args.weight_decay,
|
||||
"num_train_epochs": args.num_epochs,
|
||||
"logging_steps": 100,
|
||||
"fp16": args.fp16,
|
||||
"seed": args.seed,
|
||||
}
|
||||
|
||||
sig = inspect.signature(TrainingArguments.__init__)
|
||||
allowed = set(sig.parameters.keys())
|
||||
|
||||
# 可选参数(仅在支持时添加,尽量简化与参考实现一致以提升兼容性)
|
||||
if "warmup_ratio" in allowed:
|
||||
args_dict["warmup_ratio"] = args.warmup_ratio
|
||||
if "report_to" in allowed:
|
||||
args_dict["report_to"] = []
|
||||
# 评估/保存步进:按 eval_fraction 折算每个 epoch 的步数
|
||||
steps_per_epoch = max(1, math.ceil(len(train_dataset) / max(1, args.batch_size)))
|
||||
eval_every_steps = max(1, math.ceil(steps_per_epoch * max(0.01, min(1.0, args.eval_fraction))))
|
||||
# 策略式(新/旧版本字段名兼容)
|
||||
key_eval = "evaluation_strategy" if "evaluation_strategy" in allowed else ("eval_strategy" if "eval_strategy" in allowed else None)
|
||||
if key_eval:
|
||||
args_dict[key_eval] = "steps"
|
||||
if "save_strategy" in allowed:
|
||||
args_dict["save_strategy"] = "steps"
|
||||
if "eval_steps" in allowed:
|
||||
args_dict["eval_steps"] = eval_every_steps
|
||||
if "save_steps" in allowed:
|
||||
args_dict["save_steps"] = eval_every_steps
|
||||
if "save_total_limit" in allowed:
|
||||
args_dict["save_total_limit"] = 5
|
||||
# 将日志步长与评估/保存步长对齐,减少刷屏
|
||||
if "logging_steps" in allowed:
|
||||
args_dict["logging_steps"] = eval_every_steps
|
||||
# 最优模型回滚(仅当评估与保存策略一致时开启)
|
||||
if "metric_for_best_model" in allowed:
|
||||
args_dict["metric_for_best_model"] = "f1"
|
||||
if "greater_is_better" in allowed:
|
||||
args_dict["greater_is_better"] = True
|
||||
if "load_best_model_at_end" in allowed:
|
||||
eval_strat = args_dict.get("evaluation_strategy", args_dict.get("eval_strategy"))
|
||||
save_strat = args_dict.get("save_strategy")
|
||||
if eval_strat == save_strat and eval_strat in ("steps", "epoch"):
|
||||
args_dict["load_best_model_at_end"] = True
|
||||
|
||||
# 兼容无 warmup_ratio 的版本:若支持 warmup_steps 则忽略比例
|
||||
if "warmup_ratio" not in allowed and "warmup_steps" in allowed:
|
||||
# 不计算总步数,默认 0
|
||||
args_dict["warmup_steps"] = 0
|
||||
|
||||
# 若不支持策略式参数:退化为每 eval_every_steps 步保存/评估
|
||||
if "save_strategy" not in allowed and "save_steps" in allowed:
|
||||
args_dict["save_steps"] = eval_every_steps
|
||||
if ("evaluation_strategy" not in allowed and "eval_strategy" not in allowed) and "eval_steps" in allowed:
|
||||
args_dict["eval_steps"] = eval_every_steps
|
||||
|
||||
# 如果支持 load_best_model_at_end,但无法同时设置评估/保存策略,则关闭它以避免报错
|
||||
if "load_best_model_at_end" in allowed:
|
||||
want_load_best = args_dict.get("load_best_model_at_end", False)
|
||||
eval_set = args_dict.get("evaluation_strategy", None)
|
||||
save_set = args_dict.get("save_strategy", None)
|
||||
if want_load_best and (eval_set is None or save_set is None or eval_set != save_set):
|
||||
args_dict["load_best_model_at_end"] = False
|
||||
|
||||
training_args = TrainingArguments(**args_dict)
|
||||
print("[Info] 训练参数要点:")
|
||||
print(f" epochs={args.num_epochs}, batch_size={args.batch_size}, lr={args.learning_rate}, weight_decay={args.weight_decay}")
|
||||
print(f" max_length={args.max_length}, seed={args.seed}, fp16={args.fp16}")
|
||||
if "warmup_ratio" in allowed and "warmup_ratio" in args_dict:
|
||||
print(f" warmup_ratio={args_dict['warmup_ratio']}")
|
||||
elif "warmup_steps" in allowed and "warmup_steps" in args_dict:
|
||||
print(f" warmup_steps={args_dict['warmup_steps']}")
|
||||
print(f" steps_per_epoch={steps_per_epoch}, eval_every_steps={eval_every_steps}")
|
||||
print(f" eval_strategy={args_dict.get('evaluation_strategy', args_dict.get('eval_strategy'))}, save_strategy={args_dict.get('save_strategy')}, logging_steps={args_dict.get('logging_steps')}")
|
||||
print(f" save_total_limit={args_dict.get('save_total_limit', 'n/a')}, load_best_model_at_end={args_dict.get('load_best_model_at_end', False)}")
|
||||
|
||||
callbacks = []
|
||||
if EarlyStoppingCallback is not None and (args_dict.get("evaluation_strategy") in ("steps", "epoch") or "eval_steps" in allowed):
|
||||
try:
|
||||
callbacks.append(
|
||||
EarlyStoppingCallback(
|
||||
early_stopping_patience=args.early_stop_patience,
|
||||
early_stopping_threshold=args.early_stop_threshold,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=collator,
|
||||
compute_metrics=compute_metrics_fn,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
# 设备与 GPU 信息
|
||||
try:
|
||||
device_cnt = torch.cuda.device_count()
|
||||
dev_name = torch.cuda.get_device_name(0) if device_cnt > 0 else "cpu"
|
||||
print(f"[Info] CUDA 可见设备数: {device_cnt}, 当前设备: {dev_name}, CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
print("[Info] 开始训练 ...")
|
||||
|
||||
trainer.train()
|
||||
|
||||
# 保存
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
trainer.model.config.id2label = {int(i): str(l) for i, l in id2label.items()}
|
||||
trainer.model.config.label2id = {str(l): int(i) for l, i in label2id.items()}
|
||||
trainer.save_model(output_dir)
|
||||
try:
|
||||
best_metric = getattr(trainer.state, "best_metric", None)
|
||||
best_ckpt = getattr(trainer.state, "best_model_checkpoint", None)
|
||||
if best_metric is not None and best_ckpt is not None:
|
||||
print(f"[Info] 最优模型: metric={best_metric:.6f} | checkpoint={best_ckpt}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
with open(os.path.join(output_dir, "label_map.json"), "w", encoding="utf-8") as f:
|
||||
json.dump(
|
||||
{"label2id": trainer.model.config.label2id, "id2label": trainer.model.config.id2label},
|
||||
f,
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
# 训练曲线:可选保存训练与评估 loss
|
||||
try:
|
||||
import matplotlib.pyplot as plt # type: ignore
|
||||
logs = trainer.state.log_history
|
||||
t_steps, t_losses, e_steps, e_losses = [], [], [], []
|
||||
step_counter = 0
|
||||
for rec in logs:
|
||||
if "loss" in rec and "epoch" in rec:
|
||||
step_counter += 1
|
||||
t_steps.append(step_counter)
|
||||
t_losses.append(rec["loss"])
|
||||
if "eval_loss" in rec:
|
||||
e_steps.append(step_counter)
|
||||
e_losses.append(rec["eval_loss"])
|
||||
if t_losses or e_losses:
|
||||
plt.figure(figsize=(8,4))
|
||||
if t_losses:
|
||||
plt.plot(t_steps, t_losses, label="train_loss")
|
||||
if e_losses:
|
||||
plt.plot(e_steps, e_losses, label="eval_loss")
|
||||
plt.xlabel("training step (logged)")
|
||||
plt.ylabel("loss")
|
||||
plt.legend()
|
||||
plt.tight_layout()
|
||||
plt.savefig(os.path.join(output_dir, "training_curve.png"))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
print(f"微调完成,模型已保存到: {output_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user