From de68be59f58504359bf1cddf6fb5b0805ebb1584 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=88=92=E9=85=92=E7=9A=84=E6=9D=8E=E7=99=BD?= <670939375@qq.com> Date: Sat, 9 Aug 2025 23:16:05 +0800 Subject: [PATCH] Updated the topic prediction script to support Top-K predictions and added an interactive selection feature for optional Chinese foundation models. --- .gitignore | 1 + BertTopicDetection_Finetuned/README.md | 28 +++++++++++++ BertTopicDetection_Finetuned/predict.py | 53 +++++++++++++++---------- BertTopicDetection_Finetuned/train.py | 52 +++++++++++++++++++++++- 4 files changed, 110 insertions(+), 24 deletions(-) diff --git a/.gitignore b/.gitignore index f1233f9..e996b22 100644 --- a/.gitignore +++ b/.gitignore @@ -185,6 +185,7 @@ WeiboSentiment_Finetuned/BertChinese-Lora/model/ WeiboMultilingualSentiment/model/ WeiboSentiment_MachineLearning/model/chinese_wwm_pytorch/ WeiboSentiment_SmallQwen/models/ +BertTopicDetection_Finetuned/model/ # LoRA 和 Adapter 权重 */adapter_model.safetensors diff --git a/BertTopicDetection_Finetuned/README.md b/BertTopicDetection_Finetuned/README.md index d654167..c17c8b7 100644 --- a/BertTopicDetection_Finetuned/README.md +++ b/BertTopicDetection_Finetuned/README.md @@ -63,6 +63,23 @@ python train.py \ - 支持早停(默认耐心 5 次评估),并在评估/保存策略一致时自动回滚到最佳模型; - 分词器、权重与 `label_map.json` 保存到 `model/bert-chinese-classifier/`。 +### 可选中文基座模型(训练前交互选择) + +默认基座:`google-bert/bert-base-chinese`。启动训练时,若终端可交互,程序会提示从下列选项中选择(或输入任意 Hugging Face 模型 ID): + +1) `google-bert/bert-base-chinese` +2) `hfl/chinese-roberta-wwm-ext-large` +3) `hfl/chinese-macbert-large` +4) `IDEA-CCNL/Erlangshen-DeBERTa-v2-710M-Chinese` +5) `IDEA-CCNL/Erlangshen-DeBERTa-v3-Base-Chinese` +6) `Langboat/mengzi-bert-base` +7) `BAAI/bge-base-zh`(更适合检索式/对比学习范式) +8) `nghuyong/ernie-3.0-base-zh` + +说明: +- 非交互环境(如调度系统)或设置 `NON_INTERACTIVE=1` 时,会直接使用命令行参数 `--pretrained_name` 指定的模型(默认为 `google-bert/bert-base-chinese`)。 +- 选择后,基础模型将下载/缓存至 `model/` 目录,统一管理。 + ### 预测 单条: @@ -96,3 +113,14 @@ python predict.py --interactive --model_root ./model --finetuned_subdir bert-chi - 单卡稳定运行:默认仅使用一张 GPU,可通过 `--gpu` 指定;脚本会清理分布式环境变量。 +### 作者说明(关于超大规模多分类) + +- 当话题类别达到上万级时,直接在编码器后接单一线性分类头(大 softmax)往往受限:长尾类别难学、语义稀疏、新增话题无法增量适配、上线后需频繁重训。 +- 改进思路(推荐优先级): + - 检索式/双塔范式(文本 vs. 话题名称/描述 对比学习)+ 近邻检索 + 小头重排,天然支持增量扩类与快速更新; + - 分层分类(先粗分再细分),显著降低单头难度与计算; + - 文本-标签联合建模(使用标签描述),提升近义话题的可迁移性; + - 训练细节:class-balanced/focal/label smoothing、sampled softmax、对比预训练等。 +- 重要声明:本目录使用的“静态分类头微调”仅作为备选与学习参考。对于英文/多语微短文场景,话题变化极快,传统静态分类器难以及时覆盖,我们的工作重点在 `TopicGPT` 等生成式/自监督话题发现与动态体系构建方向;本实现旨在提供一个可运行的基线与工程示例。 + + diff --git a/BertTopicDetection_Finetuned/predict.py b/BertTopicDetection_Finetuned/predict.py index 6663fe6..7dc86dc 100644 --- a/BertTopicDetection_Finetuned/predict.py +++ b/BertTopicDetection_Finetuned/predict.py @@ -3,7 +3,7 @@ import sys import json import re import argparse -from typing import Dict, Tuple +from typing import Dict, Tuple, List # ========== 单卡锁定(在导入 torch/transformers 前执行) ========== def _extract_gpu_arg(argv, default: str = "0") -> str: @@ -109,14 +109,8 @@ def load_finetuned(model_root: str, subdir: str) -> Tuple[str, Dict[int, str]]: return finetuned_path, id2label -def predict_once(model_dir: str, text: str, max_length: int = 128) -> Tuple[str, float]: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - tokenizer = AutoTokenizer.from_pretrained(model_dir) - model = AutoModelForSequenceClassification.from_pretrained(model_dir) - model.to(device) - model.eval() - - processed = preprocess_text(text) +def predict_topk(model: AutoModelForSequenceClassification, tokenizer: AutoTokenizer, device: torch.device, text: str, max_length: int = 128, top_k: int = 3) -> List[Tuple[str, float]]: + processed = preprocess_text(text or "") encoded = tokenizer( processed, max_length=max_length, @@ -130,12 +124,17 @@ def predict_once(model_dir: str, text: str, max_length: int = 128) -> Tuple[str, with torch.no_grad(): outputs = model(input_ids=input_ids, attention_mask=attention_mask) logits = outputs.logits - probs = torch.softmax(logits, dim=-1) - pred = int(torch.argmax(probs, dim=-1).item()) - conf = float(probs[0, pred].item()) - id2label = getattr(model.config, "id2label", None) - label_name = id2label.get(pred, str(pred)) if isinstance(id2label, dict) else str(pred) - return label_name, conf + probs = torch.softmax(logits, dim=-1)[0] + k = min(top_k, probs.shape[-1]) + confs, idxs = torch.topk(probs, k) + id2label = getattr(model.config, "id2label", {}) if isinstance(getattr(model.config, "id2label", None), dict) else {} + results: List[Tuple[str, float]] = [] + for i in range(k): + idx = int(idxs[i].item()) + conf = float(confs[i].item()) + label_name = id2label.get(idx, str(idx)) + results.append((label_name, conf)) + return results def main() -> None: @@ -149,13 +148,21 @@ def main() -> None: ensure_base_model_local(args.pretrained_name, model_root) finetuned_dir, _ = load_finetuned(model_root, args.finetuned_subdir) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + tokenizer = AutoTokenizer.from_pretrained(finetuned_dir) + model = AutoModelForSequenceClassification.from_pretrained(finetuned_dir) + model.to(device) + model.eval() if args.text is not None: - label, conf = predict_once(finetuned_dir, args.text, args.max_length) - print(f"预测结果: {label} (置信度: {conf:.4f})") + topk = predict_topk(model, tokenizer, device, args.text, args.max_length, top_k=3) + print("Top-3 预测:") + for rank, (label, conf) in enumerate(topk, 1): + print(f"{rank}. {label} (p={conf:.4f})") return - if args.interactive: + # 默认进入交互模式(未显式指定 --text 且未显式关闭交互) + if args.interactive or (args.text is None): print("进入交互模式。输入 'q' 退出。") while True: try: @@ -166,11 +173,13 @@ def main() -> None: break if not text: continue - label, conf = predict_once(finetuned_dir, text, args.max_length) - print(f"预测结果: {label} (置信度: {conf:.4f})") + topk = predict_topk(model, tokenizer, device, text, args.max_length, top_k=3) + print("Top-3 预测:") + for rank, (label, conf) in enumerate(topk, 1): + print(f"{rank}. {label} (p={conf:.4f})") return - - print("未提供 --text 或 --interactive,什么也没有发生。") + # 理论上不会到达这里 + print("未提供输入。") if __name__ == "__main__": diff --git a/BertTopicDetection_Finetuned/train.py b/BertTopicDetection_Finetuned/train.py index 7eb9c89..6b68f04 100644 --- a/BertTopicDetection_Finetuned/train.py +++ b/BertTopicDetection_Finetuned/train.py @@ -52,6 +52,52 @@ try: except Exception: # pragma: no cover EarlyStoppingCallback = None # type: ignore +# 预置可选中文基座模型(可扩展) +BACKBONE_CANDIDATES: List[Tuple[str, str]] = [ + ("1) google-bert/bert-base-chinese", "google-bert/bert-base-chinese"), + ("2) hfl/chinese-roberta-wwm-ext-large", "hfl/chinese-roberta-wwm-ext-large"), + ("3) hfl/chinese-macbert-large", "hfl/chinese-macbert-large"), + ("4) IDEA-CCNL/Erlangshen-DeBERTa-v2-710M-Chinese", "IDEA-CCNL/Erlangshen-DeBERTa-v2-710M-Chinese"), + ("5) IDEA-CCNL/Erlangshen-DeBERTa-v3-Base-Chinese", "IDEA-CCNL/Erlangshen-DeBERTa-v3-Base-Chinese"), + ("6) Langboat/mengzi-bert-base", "Langboat/mengzi-bert-base"), + ("7) BAAI/bge-base-zh", "BAAI/bge-base-zh"), + ("8) nghuyong/ernie-3.0-base-zh", "nghuyong/ernie-3.0-base-zh"), +] + + +def prompt_backbone_interactive(current_id: str) -> str: + """交互式选择基座模型。 + + - 当处于非交互环境(stdin 非 TTY)或设置了环境变量 NON_INTERACTIVE=1 时,直接返回 current_id。 + - 用户可输入序号选择预置项,或直接输入任意 Hugging Face 模型 ID。 + - 空回车使用当前默认。 + """ + if os.environ.get("NON_INTERACTIVE", "0") == "1": + return current_id + try: + if not sys.stdin.isatty(): + return current_id + except Exception: + return current_id + + print("\n可选中文基座模型(直接回车使用默认):") + for label, hf_id in BACKBONE_CANDIDATES: + print(f" {label}") + print(f"当前默认: {current_id}") + choice = input("请输入序号或直接粘贴模型ID(回车沿用默认): ").strip() + if not choice: + return current_id + # 数字选项 + if choice.isdigit(): + idx = int(choice) + for label, hf_id in BACKBONE_CANDIDATES: + if label.startswith(f"{idx})"): + return hf_id + print("未找到该序号,沿用默认。") + return current_id + # 自定义 HF 模型 ID + return choice + def preprocess_text(text: str) -> str: if text is None: @@ -191,7 +237,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--text_col", type=str, default="auto", help="文本列名,默认自动识别") parser.add_argument("--label_col", type=str, default="auto", help="标签列名,默认自动识别") parser.add_argument("--model_root", type=str, default="./model", help="本地模型根目录") - parser.add_argument("--pretrained_name", type=str, default="google-bert/bert-base-chinese") + parser.add_argument("--pretrained_name", type=str, default="google-bert/bert-base-chinese", help="Hugging Face 模型ID;留空则进入交互选择") parser.add_argument("--save_subdir", type=str, default="bert-chinese-classifier") parser.add_argument("--max_length", type=int, default=128) parser.add_argument("--batch_size", type=int, default=64) @@ -216,8 +262,10 @@ def main() -> None: model_root = args.model_root if os.path.isabs(args.model_root) else os.path.join(script_dir, args.model_root) os.makedirs(model_root, exist_ok=True) + # 交互式选择基座模型(若允许交互且未通过环境禁用) + selected_model_id = prompt_backbone_interactive(args.pretrained_name) # 确保基础模型就绪 - base_dir, tokenizer = ensure_base_model_local(args.pretrained_name, model_root) + base_dir, tokenizer = ensure_base_model_local(selected_model_id, model_root) print(f"[Info] 使用基础模型目录: {base_dir}") # 读取数据