Updated the topic prediction script to support Top-K predictions and added an interactive selection feature for optional Chinese foundation models.

This commit is contained in:
戒酒的李白
2025-08-09 23:16:05 +08:00
parent d726941d95
commit de68be59f5
4 changed files with 110 additions and 24 deletions
+50 -2
View File
@@ -52,6 +52,52 @@ try:
except Exception: # pragma: no cover
EarlyStoppingCallback = None # type: ignore
# 预置可选中文基座模型(可扩展)
BACKBONE_CANDIDATES: List[Tuple[str, str]] = [
("1) google-bert/bert-base-chinese", "google-bert/bert-base-chinese"),
("2) hfl/chinese-roberta-wwm-ext-large", "hfl/chinese-roberta-wwm-ext-large"),
("3) hfl/chinese-macbert-large", "hfl/chinese-macbert-large"),
("4) IDEA-CCNL/Erlangshen-DeBERTa-v2-710M-Chinese", "IDEA-CCNL/Erlangshen-DeBERTa-v2-710M-Chinese"),
("5) IDEA-CCNL/Erlangshen-DeBERTa-v3-Base-Chinese", "IDEA-CCNL/Erlangshen-DeBERTa-v3-Base-Chinese"),
("6) Langboat/mengzi-bert-base", "Langboat/mengzi-bert-base"),
("7) BAAI/bge-base-zh", "BAAI/bge-base-zh"),
("8) nghuyong/ernie-3.0-base-zh", "nghuyong/ernie-3.0-base-zh"),
]
def prompt_backbone_interactive(current_id: str) -> str:
"""交互式选择基座模型。
- 当处于非交互环境(stdin 非 TTY)或设置了环境变量 NON_INTERACTIVE=1 时,直接返回 current_id。
- 用户可输入序号选择预置项,或直接输入任意 Hugging Face 模型 ID。
- 空回车使用当前默认。
"""
if os.environ.get("NON_INTERACTIVE", "0") == "1":
return current_id
try:
if not sys.stdin.isatty():
return current_id
except Exception:
return current_id
print("\n可选中文基座模型(直接回车使用默认):")
for label, hf_id in BACKBONE_CANDIDATES:
print(f" {label}")
print(f"当前默认: {current_id}")
choice = input("请输入序号或直接粘贴模型ID(回车沿用默认): ").strip()
if not choice:
return current_id
# 数字选项
if choice.isdigit():
idx = int(choice)
for label, hf_id in BACKBONE_CANDIDATES:
if label.startswith(f"{idx})"):
return hf_id
print("未找到该序号,沿用默认。")
return current_id
# 自定义 HF 模型 ID
return choice
def preprocess_text(text: str) -> str:
if text is None:
@@ -191,7 +237,7 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--text_col", type=str, default="auto", help="文本列名,默认自动识别")
parser.add_argument("--label_col", type=str, default="auto", help="标签列名,默认自动识别")
parser.add_argument("--model_root", type=str, default="./model", help="本地模型根目录")
parser.add_argument("--pretrained_name", type=str, default="google-bert/bert-base-chinese")
parser.add_argument("--pretrained_name", type=str, default="google-bert/bert-base-chinese", help="Hugging Face 模型ID;留空则进入交互选择")
parser.add_argument("--save_subdir", type=str, default="bert-chinese-classifier")
parser.add_argument("--max_length", type=int, default=128)
parser.add_argument("--batch_size", type=int, default=64)
@@ -216,8 +262,10 @@ def main() -> None:
model_root = args.model_root if os.path.isabs(args.model_root) else os.path.join(script_dir, args.model_root)
os.makedirs(model_root, exist_ok=True)
# 交互式选择基座模型(若允许交互且未通过环境禁用)
selected_model_id = prompt_backbone_interactive(args.pretrained_name)
# 确保基础模型就绪
base_dir, tokenizer = ensure_base_model_local(args.pretrained_name, model_root)
base_dir, tokenizer = ensure_base_model_local(selected_model_id, model_root)
print(f"[Info] 使用基础模型目录: {base_dir}")
# 读取数据