Train and prediction script for a topic classification model based on bert-chinese.
This commit is contained in:
@@ -0,0 +1,179 @@
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import re
|
||||
import argparse
|
||||
from typing import Dict, Tuple
|
||||
|
||||
# ========== 单卡锁定(在导入 torch/transformers 前执行) ==========
|
||||
def _extract_gpu_arg(argv, default: str = "0") -> str:
|
||||
for i, arg in enumerate(argv):
|
||||
if arg.startswith("--gpu="):
|
||||
return arg.split("=", 1)[1]
|
||||
if arg == "--gpu" and i + 1 < len(argv):
|
||||
return argv[i + 1]
|
||||
return default
|
||||
|
||||
env_vis = os.environ.get("CUDA_VISIBLE_DEVICES", "").strip()
|
||||
try:
|
||||
gpu_to_use = _extract_gpu_arg(sys.argv, default="0")
|
||||
except Exception:
|
||||
gpu_to_use = "0"
|
||||
if (not env_vis) or ("," in env_vis):
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_to_use
|
||||
os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
|
||||
|
||||
for _k in ["RANK", "LOCAL_RANK", "WORLD_SIZE"]:
|
||||
os.environ.pop(_k, None)
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModel,
|
||||
AutoModelForSequenceClassification,
|
||||
)
|
||||
|
||||
|
||||
def preprocess_text(text: str) -> str:
|
||||
if text is None:
|
||||
return ""
|
||||
text = str(text)
|
||||
text = re.sub(r"\{%.+?%\}", " ", text)
|
||||
text = re.sub(r"@.+?( |$)", " ", text)
|
||||
text = re.sub(r"【.+?】", " ", text)
|
||||
text = re.sub(r"\u200b", " ", text)
|
||||
text = re.sub(
|
||||
r"[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F1E0-\U0001F1FF\U00002600-\U000027BF\U0001f900-\U0001f9ff\U0001f018-\U0001f270\U0000231a-\U0000231b\U0000238d-\U0000238d\U000024c2-\U0001f251]+",
|
||||
"",
|
||||
text,
|
||||
)
|
||||
text = re.sub(r"\s+", " ", text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
def ensure_base_model_local(model_name_or_path: str, local_model_root: str) -> Tuple[str, AutoTokenizer]:
|
||||
os.makedirs(local_model_root, exist_ok=True)
|
||||
base_dir = os.path.join(local_model_root, "bert-base-chinese")
|
||||
|
||||
def is_ready(path: str) -> bool:
|
||||
return os.path.isdir(path) and os.path.isfile(os.path.join(path, "config.json"))
|
||||
|
||||
if is_ready(base_dir):
|
||||
tokenizer = AutoTokenizer.from_pretrained(base_dir)
|
||||
return base_dir, tokenizer
|
||||
|
||||
# 本机缓存
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, local_files_only=True)
|
||||
base = AutoModel.from_pretrained(model_name_or_path, local_files_only=True)
|
||||
os.makedirs(base_dir, exist_ok=True)
|
||||
tokenizer.save_pretrained(base_dir)
|
||||
base.save_pretrained(base_dir)
|
||||
return base_dir, tokenizer
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 远程下载
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
||||
base = AutoModel.from_pretrained(model_name_or_path)
|
||||
os.makedirs(base_dir, exist_ok=True)
|
||||
tokenizer.save_pretrained(base_dir)
|
||||
base.save_pretrained(base_dir)
|
||||
return base_dir, tokenizer
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="使用本地/缓存/远程加载的中文 BERT 分类模型进行预测")
|
||||
parser.add_argument("--model_root", type=str, default="./model", help="本地模型根目录")
|
||||
parser.add_argument("--finetuned_subdir", type=str, default="bert-chinese-classifier", help="微调结果子目录")
|
||||
parser.add_argument("--pretrained_name", type=str, default="google-bert/bert-base-chinese", help="预训练模型名称或路径")
|
||||
parser.add_argument("--text", type=str, default=None, help="直接输入一条要预测的文本")
|
||||
parser.add_argument("--interactive", action="store_true", help="进入交互式预测模式")
|
||||
parser.add_argument("--max_length", type=int, default=128)
|
||||
parser.add_argument("--gpu", type=str, default=os.environ.get("CUDA_VISIBLE_DEVICES", "0"), help="指定单卡 GPU,如 0 或 1")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_finetuned(model_root: str, subdir: str) -> Tuple[str, Dict[int, str]]:
|
||||
finetuned_path = os.path.join(model_root, subdir)
|
||||
if not os.path.isdir(finetuned_path):
|
||||
raise FileNotFoundError(
|
||||
f"未找到微调模型目录: {finetuned_path},请先运行训练脚本。"
|
||||
)
|
||||
label_map_path = os.path.join(finetuned_path, "label_map.json")
|
||||
id2label = None
|
||||
if os.path.isfile(label_map_path):
|
||||
with open(label_map_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
id2label = {int(k): str(v) for k, v in data.get("id2label", {}).items()}
|
||||
return finetuned_path, id2label
|
||||
|
||||
|
||||
def predict_once(model_dir: str, text: str, max_length: int = 128) -> Tuple[str, float]:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(model_dir)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
processed = preprocess_text(text)
|
||||
encoded = tokenizer(
|
||||
processed,
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
)
|
||||
input_ids = encoded["input_ids"].to(device)
|
||||
attention_mask = encoded["attention_mask"].to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
logits = outputs.logits
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
pred = int(torch.argmax(probs, dim=-1).item())
|
||||
conf = float(probs[0, pred].item())
|
||||
id2label = getattr(model.config, "id2label", None)
|
||||
label_name = id2label.get(pred, str(pred)) if isinstance(id2label, dict) else str(pred)
|
||||
return label_name, conf
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
model_root = args.model_root if os.path.isabs(args.model_root) else os.path.join(script_dir, args.model_root)
|
||||
os.makedirs(model_root, exist_ok=True)
|
||||
|
||||
# 确保基础模型在本地
|
||||
ensure_base_model_local(args.pretrained_name, model_root)
|
||||
|
||||
finetuned_dir, _ = load_finetuned(model_root, args.finetuned_subdir)
|
||||
|
||||
if args.text is not None:
|
||||
label, conf = predict_once(finetuned_dir, args.text, args.max_length)
|
||||
print(f"预测结果: {label} (置信度: {conf:.4f})")
|
||||
return
|
||||
|
||||
if args.interactive:
|
||||
print("进入交互模式。输入 'q' 退出。")
|
||||
while True:
|
||||
try:
|
||||
text = input("请输入文本: ").strip()
|
||||
except EOFError:
|
||||
break
|
||||
if text.lower() == "q":
|
||||
break
|
||||
if not text:
|
||||
continue
|
||||
label, conf = predict_once(finetuned_dir, text, args.max_length)
|
||||
print(f"预测结果: {label} (置信度: {conf:.4f})")
|
||||
return
|
||||
|
||||
print("未提供 --text 或 --interactive,什么也没有发生。")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user