Updated the topic prediction script to support Top-K predictions and added an interactive selection feature for optional Chinese foundation models.
This commit is contained in:
@@ -3,7 +3,7 @@ import sys
|
||||
import json
|
||||
import re
|
||||
import argparse
|
||||
from typing import Dict, Tuple
|
||||
from typing import Dict, Tuple, List
|
||||
|
||||
# ========== 单卡锁定(在导入 torch/transformers 前执行) ==========
|
||||
def _extract_gpu_arg(argv, default: str = "0") -> str:
|
||||
@@ -109,14 +109,8 @@ def load_finetuned(model_root: str, subdir: str) -> Tuple[str, Dict[int, str]]:
|
||||
return finetuned_path, id2label
|
||||
|
||||
|
||||
def predict_once(model_dir: str, text: str, max_length: int = 128) -> Tuple[str, float]:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(model_dir)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
processed = preprocess_text(text)
|
||||
def predict_topk(model: AutoModelForSequenceClassification, tokenizer: AutoTokenizer, device: torch.device, text: str, max_length: int = 128, top_k: int = 3) -> List[Tuple[str, float]]:
|
||||
processed = preprocess_text(text or "")
|
||||
encoded = tokenizer(
|
||||
processed,
|
||||
max_length=max_length,
|
||||
@@ -130,12 +124,17 @@ def predict_once(model_dir: str, text: str, max_length: int = 128) -> Tuple[str,
|
||||
with torch.no_grad():
|
||||
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
logits = outputs.logits
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
pred = int(torch.argmax(probs, dim=-1).item())
|
||||
conf = float(probs[0, pred].item())
|
||||
id2label = getattr(model.config, "id2label", None)
|
||||
label_name = id2label.get(pred, str(pred)) if isinstance(id2label, dict) else str(pred)
|
||||
return label_name, conf
|
||||
probs = torch.softmax(logits, dim=-1)[0]
|
||||
k = min(top_k, probs.shape[-1])
|
||||
confs, idxs = torch.topk(probs, k)
|
||||
id2label = getattr(model.config, "id2label", {}) if isinstance(getattr(model.config, "id2label", None), dict) else {}
|
||||
results: List[Tuple[str, float]] = []
|
||||
for i in range(k):
|
||||
idx = int(idxs[i].item())
|
||||
conf = float(confs[i].item())
|
||||
label_name = id2label.get(idx, str(idx))
|
||||
results.append((label_name, conf))
|
||||
return results
|
||||
|
||||
|
||||
def main() -> None:
|
||||
@@ -149,13 +148,21 @@ def main() -> None:
|
||||
ensure_base_model_local(args.pretrained_name, model_root)
|
||||
|
||||
finetuned_dir, _ = load_finetuned(model_root, args.finetuned_subdir)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
tokenizer = AutoTokenizer.from_pretrained(finetuned_dir)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(finetuned_dir)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
if args.text is not None:
|
||||
label, conf = predict_once(finetuned_dir, args.text, args.max_length)
|
||||
print(f"预测结果: {label} (置信度: {conf:.4f})")
|
||||
topk = predict_topk(model, tokenizer, device, args.text, args.max_length, top_k=3)
|
||||
print("Top-3 预测:")
|
||||
for rank, (label, conf) in enumerate(topk, 1):
|
||||
print(f"{rank}. {label} (p={conf:.4f})")
|
||||
return
|
||||
|
||||
if args.interactive:
|
||||
# 默认进入交互模式(未显式指定 --text 且未显式关闭交互)
|
||||
if args.interactive or (args.text is None):
|
||||
print("进入交互模式。输入 'q' 退出。")
|
||||
while True:
|
||||
try:
|
||||
@@ -166,11 +173,13 @@ def main() -> None:
|
||||
break
|
||||
if not text:
|
||||
continue
|
||||
label, conf = predict_once(finetuned_dir, text, args.max_length)
|
||||
print(f"预测结果: {label} (置信度: {conf:.4f})")
|
||||
topk = predict_topk(model, tokenizer, device, text, args.max_length, top_k=3)
|
||||
print("Top-3 预测:")
|
||||
for rank, (label, conf) in enumerate(topk, 1):
|
||||
print(f"{rank}. {label} (p={conf:.4f})")
|
||||
return
|
||||
|
||||
print("未提供 --text 或 --interactive,什么也没有发生。")
|
||||
# 理论上不会到达这里
|
||||
print("未提供输入。")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user