Updated the topic prediction script to support Top-K predictions and added an interactive selection feature for optional Chinese foundation models.
This commit is contained in:
@@ -52,6 +52,52 @@ try:
|
||||
except Exception: # pragma: no cover
|
||||
EarlyStoppingCallback = None # type: ignore
|
||||
|
||||
# Preset Chinese backbone models offered in the interactive menu (extensible).
# Only the Hugging Face model IDs are spelled out here; each menu label has the
# form "<position>) <model id>" and is generated from the ID's 1-based position.
_BACKBONE_MODEL_IDS: Tuple[str, ...] = (
    "google-bert/bert-base-chinese",
    "hfl/chinese-roberta-wwm-ext-large",
    "hfl/chinese-macbert-large",
    "IDEA-CCNL/Erlangshen-DeBERTa-v2-710M-Chinese",
    "IDEA-CCNL/Erlangshen-DeBERTa-v3-Base-Chinese",
    "Langboat/mengzi-bert-base",
    "BAAI/bge-base-zh",
    "nghuyong/ernie-3.0-base-zh",
)

# List of (menu label, Hugging Face model ID) pairs, in display order.
BACKBONE_CANDIDATES: List[Tuple[str, str]] = [
    (f"{position}) {model_id}", model_id)
    for position, model_id in enumerate(_BACKBONE_MODEL_IDS, start=1)
]
|
||||
|
||||
|
||||
def prompt_backbone_interactive(current_id: str) -> str:
    """Interactively choose the backbone (base) model.

    The preset entries of ``BACKBONE_CANDIDATES`` are printed as a menu and
    one line is read from the user.

    - If the environment variable ``NON_INTERACTIVE`` equals ``"1"``, or stdin
      is not a TTY (or cannot be queried), ``current_id`` is returned without
      prompting.
    - An all-digit answer selects the preset whose label starts with
      ``"<number>)"``; an unknown number keeps ``current_id``.
    - Any other non-empty answer is returned as-is and is treated as a
      Hugging Face model ID.
    - An empty answer, or end-of-input at the prompt, keeps ``current_id``.

    Args:
        current_id: Model ID used as the default when nothing is selected.

    Returns:
        The selected model ID, or ``current_id`` when no valid selection
        was made.
    """
    if os.environ.get("NON_INTERACTIVE", "0") == "1":
        return current_id
    try:
        if not sys.stdin.isatty():
            return current_id
    except Exception:
        # Best effort: if stdin cannot be queried at all (e.g. it is None or
        # closed), behave as in a non-interactive environment.
        return current_id

    print("\n可选中文基座模型(直接回车使用默认):")
    for label, _ in BACKBONE_CANDIDATES:
        print(f" {label}")
    print(f"当前默认: {current_id}")
    try:
        choice = input("请输入序号或直接粘贴模型ID(回车沿用默认): ").strip()
    except EOFError:
        # Ctrl-D / closed stream at the prompt: fall back to the default
        # instead of aborting the whole run with a traceback.
        print()
        return current_id
    if not choice:
        return current_id

    # Numeric menu choice.
    if choice.isdigit():
        try:
            # int() normalises inputs such as "01" to 1.  str.isdigit() also
            # accepts characters like "²" that int() rejects, hence the guard.
            wanted_prefix = f"{int(choice)})"
        except ValueError:
            wanted_prefix = None
        if wanted_prefix is not None:
            for label, hf_id in BACKBONE_CANDIDATES:
                if label.startswith(wanted_prefix):
                    return hf_id
        print("未找到该序号,沿用默认。")
        return current_id

    # Anything else is taken verbatim as a custom Hugging Face model ID.
    return choice
|
||||
|
||||
|
||||
def preprocess_text(text: str) -> str:
|
||||
if text is None:
|
||||
@@ -191,7 +237,7 @@ def parse_args() -> argparse.Namespace:
|
||||
parser.add_argument("--text_col", type=str, default="auto", help="文本列名,默认自动识别")
|
||||
parser.add_argument("--label_col", type=str, default="auto", help="标签列名,默认自动识别")
|
||||
parser.add_argument("--model_root", type=str, default="./model", help="本地模型根目录")
|
||||
parser.add_argument("--pretrained_name", type=str, default="google-bert/bert-base-chinese")
|
||||
parser.add_argument("--pretrained_name", type=str, default="google-bert/bert-base-chinese", help="Hugging Face 模型ID;留空则进入交互选择")
|
||||
parser.add_argument("--save_subdir", type=str, default="bert-chinese-classifier")
|
||||
parser.add_argument("--max_length", type=int, default=128)
|
||||
parser.add_argument("--batch_size", type=int, default=64)
|
||||
@@ -216,8 +262,10 @@ def main() -> None:
|
||||
model_root = args.model_root if os.path.isabs(args.model_root) else os.path.join(script_dir, args.model_root)
|
||||
os.makedirs(model_root, exist_ok=True)
|
||||
|
||||
# 交互式选择基座模型(若允许交互且未通过环境禁用)
|
||||
selected_model_id = prompt_backbone_interactive(args.pretrained_name)
|
||||
# 确保基础模型就绪
|
||||
base_dir, tokenizer = ensure_base_model_local(args.pretrained_name, model_root)
|
||||
base_dir, tokenizer = ensure_base_model_local(selected_model_id, model_root)
|
||||
print(f"[Info] 使用基础模型目录: {base_dir}")
|
||||
|
||||
# 读取数据
|
||||
|
||||
Reference in New Issue
Block a user