Fix label mapping (build it from the union of training and validation labels) and boost training with batch_size=64.
This commit is contained in:
@@ -29,6 +29,7 @@ os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
|
||||
# 清理可能由外部启动器注入的分布式环境变量,避免误触多卡/分布式
|
||||
for _k in ["RANK", "LOCAL_RANK", "WORLD_SIZE"]:
|
||||
os.environ.pop(_k, None)
|
||||
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -240,11 +241,22 @@ def main() -> None:
|
||||
text_col, label_col = autodetect_columns(train_df, args.text_col, args.label_col)
|
||||
print(f"[Info] 文本列: {text_col} | 标签列: {label_col}")
|
||||
|
||||
# 标签映射
|
||||
label2id, id2label = build_label_mappings(train_df, label_col)
|
||||
# 标签映射(使用 训练集∪验证集 的并集,避免验证集中出现新标签导致报错)
|
||||
combined_labels_df = pd.concat([train_df[[label_col]], valid_df[[label_col]]], ignore_index=True)
|
||||
label2id, id2label = build_label_mappings(combined_labels_df, label_col)
|
||||
if len(label2id) < 2:
|
||||
raise ValueError("标签类别数少于 2,无法训练分类模型。")
|
||||
print(f"[Info] 标签类别数: {len(label2id)}")
|
||||
# 提示验证集中未出现在训练集的标签数量
|
||||
try:
|
||||
train_label_set = set(str(x) for x in train_df[label_col].dropna().astype(str).tolist())
|
||||
valid_label_set = set(str(x) for x in valid_df[label_col].dropna().astype(str).tolist())
|
||||
unseen_in_train = sorted(valid_label_set - train_label_set)
|
||||
if unseen_in_train:
|
||||
preview = ", ".join(unseen_in_train[:10])
|
||||
print(f"[Warn] 验证集中存在 {len(unseen_in_train)} 个训练未出现的标签(已纳入映射以避免报错)。示例: {preview} ...")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 数据集
|
||||
train_dataset = TextClassificationDataset(train_df, tokenizer, text_col, label_col, label2id, args.max_length)
|
||||
|
||||
Reference in New Issue
Block a user