Updated the topic prediction script to support Top-K predictions and added an interactive selection feature for optional Chinese foundation models.
This commit is contained in:
@@ -185,6 +185,7 @@ WeiboSentiment_Finetuned/BertChinese-Lora/model/
|
||||
WeiboMultilingualSentiment/model/
|
||||
WeiboSentiment_MachineLearning/model/chinese_wwm_pytorch/
|
||||
WeiboSentiment_SmallQwen/models/
|
||||
BertTopicDetection_Finetuned/model/
|
||||
|
||||
# LoRA 和 Adapter 权重
|
||||
*/adapter_model.safetensors
|
||||
|
||||
@@ -63,6 +63,23 @@ python train.py \
|
||||
- 支持早停(默认耐心 5 次评估),并在评估/保存策略一致时自动回滚到最佳模型;
|
||||
- 分词器、权重与 `label_map.json` 保存到 `model/bert-chinese-classifier/`。
|
||||
|
||||
### 可选中文基座模型(训练前交互选择)
|
||||
|
||||
默认基座:`google-bert/bert-base-chinese`。启动训练时,若终端可交互,程序会提示从下列选项中选择(或输入任意 Hugging Face 模型 ID):
|
||||
|
||||
1) `google-bert/bert-base-chinese`
|
||||
2) `hfl/chinese-roberta-wwm-ext-large`
|
||||
3) `hfl/chinese-macbert-large`
|
||||
4) `IDEA-CCNL/Erlangshen-DeBERTa-v2-710M-Chinese`
|
||||
5) `IDEA-CCNL/Erlangshen-DeBERTa-v3-Base-Chinese`
|
||||
6) `Langboat/mengzi-bert-base`
|
||||
7) `BAAI/bge-base-zh`(更适合检索式/对比学习范式)
|
||||
8) `nghuyong/ernie-3.0-base-zh`
|
||||
|
||||
说明:
|
||||
- 非交互环境(如调度系统)或设置 `NON_INTERACTIVE=1` 时,会直接使用命令行参数 `--pretrained_name` 指定的模型(默认为 `google-bert/bert-base-chinese`)。
|
||||
- 选择后,基础模型将下载/缓存至 `model/` 目录,统一管理。
|
||||
|
||||
### 预测
|
||||
|
||||
单条:
|
||||
@@ -96,3 +113,14 @@ python predict.py --interactive --model_root ./model --finetuned_subdir bert-chi
|
||||
- 单卡稳定运行:默认仅使用一张 GPU,可通过 `--gpu` 指定;脚本会清理分布式环境变量。
|
||||
|
||||
|
||||
### 作者说明(关于超大规模多分类)
|
||||
|
||||
- 当话题类别达到上万级时,直接在编码器后接单一线性分类头(大 softmax)往往受限:长尾类别难学、语义稀疏、新增话题无法增量适配、上线后需频繁重训。
|
||||
- 改进思路(推荐优先级):
|
||||
- 检索式/双塔范式(文本 vs. 话题名称/描述 对比学习)+ 近邻检索 + 小头重排,天然支持增量扩类与快速更新;
|
||||
- 分层分类(先粗分再细分),显著降低单头难度与计算;
|
||||
- 文本-标签联合建模(使用标签描述),提升近义话题的可迁移性;
|
||||
- 训练细节:class-balanced/focal/label smoothing、sampled softmax、对比预训练等。
|
||||
- 重要声明:本目录使用的“静态分类头微调”仅作为备选与学习参考。对于英文/多语微短文场景,话题变化极快,传统静态分类器难以及时覆盖,我们的工作重点在 `TopicGPT` 等生成式/自监督话题发现与动态体系构建方向;本实现旨在提供一个可运行的基线与工程示例。
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import sys
|
||||
import json
|
||||
import re
|
||||
import argparse
|
||||
from typing import Dict, Tuple
|
||||
from typing import Dict, Tuple, List
|
||||
|
||||
# ========== 单卡锁定(在导入 torch/transformers 前执行) ==========
|
||||
def _extract_gpu_arg(argv, default: str = "0") -> str:
|
||||
@@ -109,14 +109,8 @@ def load_finetuned(model_root: str, subdir: str) -> Tuple[str, Dict[int, str]]:
|
||||
return finetuned_path, id2label
|
||||
|
||||
|
||||
def predict_once(model_dir: str, text: str, max_length: int = 128) -> Tuple[str, float]:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(model_dir)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
processed = preprocess_text(text)
|
||||
def predict_topk(model: AutoModelForSequenceClassification, tokenizer: AutoTokenizer, device: torch.device, text: str, max_length: int = 128, top_k: int = 3) -> List[Tuple[str, float]]:
|
||||
processed = preprocess_text(text or "")
|
||||
encoded = tokenizer(
|
||||
processed,
|
||||
max_length=max_length,
|
||||
@@ -130,12 +124,17 @@ def predict_once(model_dir: str, text: str, max_length: int = 128) -> Tuple[str,
|
||||
with torch.no_grad():
|
||||
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
logits = outputs.logits
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
pred = int(torch.argmax(probs, dim=-1).item())
|
||||
conf = float(probs[0, pred].item())
|
||||
id2label = getattr(model.config, "id2label", None)
|
||||
label_name = id2label.get(pred, str(pred)) if isinstance(id2label, dict) else str(pred)
|
||||
return label_name, conf
|
||||
probs = torch.softmax(logits, dim=-1)[0]
|
||||
k = min(top_k, probs.shape[-1])
|
||||
confs, idxs = torch.topk(probs, k)
|
||||
id2label = getattr(model.config, "id2label", {}) if isinstance(getattr(model.config, "id2label", None), dict) else {}
|
||||
results: List[Tuple[str, float]] = []
|
||||
for i in range(k):
|
||||
idx = int(idxs[i].item())
|
||||
conf = float(confs[i].item())
|
||||
label_name = id2label.get(idx, str(idx))
|
||||
results.append((label_name, conf))
|
||||
return results
|
||||
|
||||
|
||||
def main() -> None:
|
||||
@@ -149,13 +148,21 @@ def main() -> None:
|
||||
ensure_base_model_local(args.pretrained_name, model_root)
|
||||
|
||||
finetuned_dir, _ = load_finetuned(model_root, args.finetuned_subdir)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
tokenizer = AutoTokenizer.from_pretrained(finetuned_dir)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(finetuned_dir)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
if args.text is not None:
|
||||
label, conf = predict_once(finetuned_dir, args.text, args.max_length)
|
||||
print(f"预测结果: {label} (置信度: {conf:.4f})")
|
||||
topk = predict_topk(model, tokenizer, device, args.text, args.max_length, top_k=3)
|
||||
print("Top-3 预测:")
|
||||
for rank, (label, conf) in enumerate(topk, 1):
|
||||
print(f"{rank}. {label} (p={conf:.4f})")
|
||||
return
|
||||
|
||||
if args.interactive:
|
||||
# 默认进入交互模式(未显式指定 --text 且未显式关闭交互)
|
||||
if args.interactive or (args.text is None):
|
||||
print("进入交互模式。输入 'q' 退出。")
|
||||
while True:
|
||||
try:
|
||||
@@ -166,11 +173,13 @@ def main() -> None:
|
||||
break
|
||||
if not text:
|
||||
continue
|
||||
label, conf = predict_once(finetuned_dir, text, args.max_length)
|
||||
print(f"预测结果: {label} (置信度: {conf:.4f})")
|
||||
topk = predict_topk(model, tokenizer, device, text, args.max_length, top_k=3)
|
||||
print("Top-3 预测:")
|
||||
for rank, (label, conf) in enumerate(topk, 1):
|
||||
print(f"{rank}. {label} (p={conf:.4f})")
|
||||
return
|
||||
|
||||
print("未提供 --text 或 --interactive,什么也没有发生。")
|
||||
# 理论上不会到达这里
|
||||
print("未提供输入。")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -52,6 +52,52 @@ try:
|
||||
except Exception: # pragma: no cover
|
||||
EarlyStoppingCallback = None # type: ignore
|
||||
|
||||
# 预置可选中文基座模型(可扩展)
|
||||
BACKBONE_CANDIDATES: List[Tuple[str, str]] = [
|
||||
("1) google-bert/bert-base-chinese", "google-bert/bert-base-chinese"),
|
||||
("2) hfl/chinese-roberta-wwm-ext-large", "hfl/chinese-roberta-wwm-ext-large"),
|
||||
("3) hfl/chinese-macbert-large", "hfl/chinese-macbert-large"),
|
||||
("4) IDEA-CCNL/Erlangshen-DeBERTa-v2-710M-Chinese", "IDEA-CCNL/Erlangshen-DeBERTa-v2-710M-Chinese"),
|
||||
("5) IDEA-CCNL/Erlangshen-DeBERTa-v3-Base-Chinese", "IDEA-CCNL/Erlangshen-DeBERTa-v3-Base-Chinese"),
|
||||
("6) Langboat/mengzi-bert-base", "Langboat/mengzi-bert-base"),
|
||||
("7) BAAI/bge-base-zh", "BAAI/bge-base-zh"),
|
||||
("8) nghuyong/ernie-3.0-base-zh", "nghuyong/ernie-3.0-base-zh"),
|
||||
]
|
||||
|
||||
|
||||
def prompt_backbone_interactive(current_id: str) -> str:
|
||||
"""交互式选择基座模型。
|
||||
|
||||
- 当处于非交互环境(stdin 非 TTY)或设置了环境变量 NON_INTERACTIVE=1 时,直接返回 current_id。
|
||||
- 用户可输入序号选择预置项,或直接输入任意 Hugging Face 模型 ID。
|
||||
- 空回车使用当前默认。
|
||||
"""
|
||||
if os.environ.get("NON_INTERACTIVE", "0") == "1":
|
||||
return current_id
|
||||
try:
|
||||
if not sys.stdin.isatty():
|
||||
return current_id
|
||||
except Exception:
|
||||
return current_id
|
||||
|
||||
print("\n可选中文基座模型(直接回车使用默认):")
|
||||
for label, hf_id in BACKBONE_CANDIDATES:
|
||||
print(f" {label}")
|
||||
print(f"当前默认: {current_id}")
|
||||
choice = input("请输入序号或直接粘贴模型ID(回车沿用默认): ").strip()
|
||||
if not choice:
|
||||
return current_id
|
||||
# 数字选项
|
||||
if choice.isdigit():
|
||||
idx = int(choice)
|
||||
for label, hf_id in BACKBONE_CANDIDATES:
|
||||
if label.startswith(f"{idx})"):
|
||||
return hf_id
|
||||
print("未找到该序号,沿用默认。")
|
||||
return current_id
|
||||
# 自定义 HF 模型 ID
|
||||
return choice
|
||||
|
||||
|
||||
def preprocess_text(text: str) -> str:
|
||||
if text is None:
|
||||
@@ -191,7 +237,7 @@ def parse_args() -> argparse.Namespace:
|
||||
parser.add_argument("--text_col", type=str, default="auto", help="文本列名,默认自动识别")
|
||||
parser.add_argument("--label_col", type=str, default="auto", help="标签列名,默认自动识别")
|
||||
parser.add_argument("--model_root", type=str, default="./model", help="本地模型根目录")
|
||||
parser.add_argument("--pretrained_name", type=str, default="google-bert/bert-base-chinese")
|
||||
parser.add_argument("--pretrained_name", type=str, default="google-bert/bert-base-chinese", help="Hugging Face 模型ID;留空则进入交互选择")
|
||||
parser.add_argument("--save_subdir", type=str, default="bert-chinese-classifier")
|
||||
parser.add_argument("--max_length", type=int, default=128)
|
||||
parser.add_argument("--batch_size", type=int, default=64)
|
||||
@@ -216,8 +262,10 @@ def main() -> None:
|
||||
model_root = args.model_root if os.path.isabs(args.model_root) else os.path.join(script_dir, args.model_root)
|
||||
os.makedirs(model_root, exist_ok=True)
|
||||
|
||||
# 交互式选择基座模型(若允许交互且未通过环境禁用)
|
||||
selected_model_id = prompt_backbone_interactive(args.pretrained_name)
|
||||
# 确保基础模型就绪
|
||||
base_dir, tokenizer = ensure_base_model_local(args.pretrained_name, model_root)
|
||||
base_dir, tokenizer = ensure_base_model_local(selected_model_id, model_root)
|
||||
print(f"[Info] 使用基础模型目录: {base_dir}")
|
||||
|
||||
# 读取数据
|
||||
|
||||
Reference in New Issue
Block a user