377 lines
15 KiB
Python
377 lines
15 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
Qwen3微博情感分析统一预测接口
|
|
支持0.6B、4B、8B三种规格的Embedding和LoRA模型
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import argparse
|
|
import torch
|
|
from typing import List, Dict, Tuple, Any
|
|
|
|
# 添加当前目录到路径
|
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
|
|
|
from models_config import QWEN3_MODELS, MODEL_PATHS
|
|
from qwen3_embedding_universal import Qwen3EmbeddingUniversal
|
|
from qwen3_lora_universal import Qwen3LoRAUniversal
|
|
|
|
|
|
class Qwen3UniversalPredictor:
    """Unified front-end for the Qwen3 Weibo sentiment-analysis models.

    Holds any combination of the "Embedding + classification head" and
    "LoRA fine-tuned" variants in the 0.6B / 4B / 8B sizes, and offers
    single-text, batch, ensemble and interactive prediction over whatever
    has been loaded.

    Attributes:
        models: Loaded models keyed by ``"<model_type>_<model_size>"``; each
            value is ``{'model': <model object>, 'display_name': str}``.
    """

    # Values accepted by load_model(); also the iteration order of
    # load_all_models().
    _SUPPORTED_TYPES = ('embedding', 'lora')
    _SUPPORTED_SIZES = ('0.6B', '4B', '8B')

    def __init__(self):
        # {model_key: {'model': obj, 'display_name': str}}
        self.models = {}

    def _get_model_key(self, model_type: str, model_size: str) -> str:
        """Return the ``self.models`` key for a (type, size) pair, e.g. ``"lora_4B"``."""
        return f"{model_type}_{model_size}"

    def _is_loaded(self, model_type: str, model_size: str) -> bool:
        """Return True if the (type, size) model is registered in ``self.models``."""
        return self._get_model_key(model_type, model_size) in self.models

    @staticmethod
    def _sentiment_label(pred: int) -> str:
        """Map a predicted label to its display text (1 -> positive, else negative)."""
        return "正面" if pred == 1 else "负面"

    @staticmethod
    def _prompt_choice(prompt: str, valid: Tuple[str, ...], error_msg: str) -> str:
        """Prompt on stdin until the (stripped) answer is one of ``valid``; return it."""
        while True:
            choice = input(prompt).strip()
            if choice in valid:
                return choice
            print(error_msg)

    def load_model(self, model_type: str, model_size: str) -> None:
        """Load one trained model and register it in ``self.models``.

        Args:
            model_type: ``'embedding'`` or ``'lora'``.
            model_size: ``'0.6B'``, ``'4B'`` or ``'8B'``.

        Raises:
            ValueError: If ``model_type`` or ``model_size`` is not supported.

        A missing trained-model path, or any exception raised while building or
        loading the model, is reported on stdout and does NOT raise; in that
        case ``self.models`` is left unchanged, so callers must check
        registration to know whether loading succeeded.
        """
        if model_type not in self._SUPPORTED_TYPES:
            raise ValueError(f"不支持的模型类型: {model_type}")
        if model_size not in self._SUPPORTED_SIZES:
            raise ValueError(f"不支持的模型大小: {model_size}")

        model_path = MODEL_PATHS[model_type][model_size]
        model_key = self._get_model_key(model_type, model_size)

        # Check that the trained model file exists before touching the base model.
        if not os.path.exists(model_path):
            print(f"训练好的模型文件不存在: {model_path}")
            print(f"请先训练 {model_type.upper()}-{model_size} 模型,或检查模型路径配置")
            return

        print(f"加载 {model_type.upper()}-{model_size} 模型...")

        try:
            model_cls = Qwen3EmbeddingUniversal if model_type == 'embedding' else Qwen3LoRAUniversal
            model = model_cls(model_size)
            model.load_model(model_path)
        except Exception as e:
            print(f"加载 {model_type.upper()}-{model_size} 模型失败: {e}")
            print("这可能是因为基础模型下载失败或训练好的模型文件损坏")
            return

        self.models[model_key] = {
            'model': model,
            'display_name': f"Qwen3-{model_type.title()}-{model_size}",
        }
        print(f"{model_type.upper()}-{model_size} 模型加载成功")

    def load_all_models(self, model_dir: str = './models') -> None:
        """Try to load every supported (type, size) combination.

        Args:
            model_dir: Accepted for interface compatibility but not used here;
                model locations come from ``MODEL_PATHS``.
        """
        print("开始加载所有可用的Qwen3模型...")

        for model_type in self._SUPPORTED_TYPES:
            for model_size in self._SUPPORTED_SIZES:
                try:
                    self.load_model(model_type, model_size)
                except Exception as e:
                    print(f"跳过 {model_type}-{model_size}: {e}")

        # load_model reports missing files / load errors without raising, so a
        # per-call counter over-counts; report what is actually registered.
        print(f"\n已加载 {len(self.models)} 个模型")
        self._print_loaded_models()

    def load_specific_models(self, model_configs: List[Tuple[str, str]]) -> None:
        """Load each requested model configuration.

        Args:
            model_configs: List of ``(model_type, model_size)`` pairs. Pairs that
                raise (e.g. unsupported values) are reported and skipped.
        """
        print("加载指定的Qwen3模型...")

        for model_type, model_size in model_configs:
            try:
                self.load_model(model_type, model_size)
            except Exception as e:
                print(f"跳过 {model_type}-{model_size}: {e}")

        print(f"\n已加载 {len(self.models)} 个模型")
        self._print_loaded_models()

    def _print_loaded_models(self):
        """Print the display names of all loaded models, or a notice if there are none."""
        if self.models:
            print("已加载模型:")
            for model_info in self.models.values():
                print(f" - {model_info['display_name']}")
        else:
            print("没有成功加载任何模型")

    def predict_single(self, text: str, model_key: str = None) -> Dict[str, Tuple[int, float]]:
        """Classify a single text.

        Args:
            text: Text to classify.
            model_key: Key of one loaded model (e.g. ``"lora_4B"``) to use alone.
                ``None`` -- or a key that is not loaded -- means every loaded model.

        Returns:
            ``{display_name: (prediction, confidence)}``. A model that raises is
            reported on stdout and recorded as ``(0, 0.0)``.
        """
        if model_key and model_key in self.models:
            selected = [self.models[model_key]]
        else:
            if model_key:
                # Warn instead of silently falling back, so a mistyped key is noticed.
                print(f"未找到模型 {model_key},使用所有已加载模型")
            selected = list(self.models.values())

        results = {}
        for model_info in selected:
            name = model_info['display_name']
            try:
                prediction, confidence = model_info['model'].predict_single(text)
            except Exception as e:
                print(f"模型 {name} 预测失败: {e}")
                results[name] = (0, 0.0)
            else:
                results[name] = (prediction, confidence)
        return results

    def predict_batch(self, texts: List[str]) -> Dict[str, List[int]]:
        """Run every loaded model's ``predict`` on ``texts``.

        Returns:
            ``{display_name: predictions}``. A model that raises is reported on
            stdout and recorded as ``[0] * len(texts)``.
        """
        results = {}
        for model_info in self.models.values():
            name = model_info['display_name']
            try:
                results[name] = model_info['model'].predict(texts)
            except Exception as e:
                print(f"模型 {name} 预测失败: {e}")
                results[name] = [0] * len(texts)
        return results

    def ensemble_predict(self, text: str) -> Tuple[int, float]:
        """Combine all loaded models into one prediction by simple averaging.

        Each model's ``(pred, conf)`` is converted to a positive-class
        probability (``conf`` if ``pred == 1`` else ``1 - conf``). This assumes
        ``conf`` is the probability of the predicted label (TODO confirm in the
        model classes). The probabilities are averaged with equal weight.
        Results with ``conf == 0`` are skipped, which is how ``predict_single``
        marks a failed model.

        Returns:
            ``(label, confidence)``; ``(0, 0.5)`` if no model gave a usable result.

        Raises:
            ValueError: If fewer than 2 models are loaded.
        """
        if len(self.models) < 2:
            raise ValueError("集成预测需要至少2个模型")

        positive_probs = [
            conf if pred == 1 else 1 - conf
            for pred, conf in self.predict_single(text).values()
            if conf > 0
        ]
        if not positive_probs:
            return 0, 0.5

        final_prob = sum(positive_probs) / len(positive_probs)
        final_pred = int(final_prob > 0.5)
        final_conf = final_prob if final_pred == 1 else 1 - final_prob
        return final_pred, final_conf

    def _select_and_load_model(self):
        """Ask the user for a method and a size on stdin, then load that model.

        Prompts repeat until a valid choice is given. Load failures are reported
        on stdout; inspect ``self.models`` afterwards to see whether it loaded.
        """
        print("Qwen3微博情感分析预测系统")
        print("=" * 40)
        print("请选择要使用的模型:")
        print("\n方法选择:")
        print(" 1. Embedding + 分类头 (推理快速,显存占用少)")
        print(" 2. LoRA微调 (效果更好,显存占用较多)")

        method_choice = self._prompt_choice("\n请选择方法 (1/2): ", ('1', '2'), "无效选择,请输入 1 或 2")
        if method_choice == '1':
            method_type, method_name = "embedding", "Embedding + 分类头"
        else:
            method_type, method_name = "lora", "LoRA微调"

        print(f"\n已选择: {method_name}")
        print("\n模型大小选择:")
        print(" 1. 0.6B - 轻量级,推理快速")
        print(" 2. 4B - 中等规模,性能均衡")
        print(" 3. 8B - 大规模,性能最佳")

        size_choice = self._prompt_choice("\n请选择模型大小 (1/2/3): ", ('1', '2', '3'), "无效选择,请输入 1、2 或 3")
        model_size = {'1': '0.6B', '2': '4B', '3': '8B'}[size_choice]

        print(f"已选择: Qwen3-{method_name}-{model_size}")
        print("正在加载模型...")

        try:
            self.load_model(method_type, model_size)
        except Exception as e:
            print(f"模型加载失败: {e}")
            print("请检查模型文件是否存在,或先进行训练")
            return

        # load_model reports missing files / load errors without raising, so
        # only claim success if the model was actually registered.
        if self._is_loaded(method_type, model_size):
            print("模型加载成功!")
        else:
            print("请检查模型文件是否存在,或先进行训练")

    def interactive_predict(self):
        """Run the interactive prediction loop on stdin/stdout.

        If no model is loaded yet, the user is first asked to pick one. Commands:
        ``q`` quits, ``switch`` replaces the loaded model(s), ``models`` lists
        them, and ``compare`` shows every model's result for one text. Any other
        non-empty input is classified by all loaded models. Ctrl-C or end of
        input (EOF) ends the session instead of looping on the error.
        """
        if not self.models:
            try:
                self._select_and_load_model()
            except (KeyboardInterrupt, EOFError):
                print("\n\n程序被中断,再见!")
                return
            if not self.models:
                print("没有加载任何模型,退出预测")
                return

        print("\n" + "=" * 60)
        print("Qwen3微博情感分析预测系统")
        print("=" * 60)
        self._print_loaded_models()
        print("\n命令提示:")
        print(" 输入 'q' 退出程序")
        print(" 输入 'switch' 切换模型")
        print(" 输入 'models' 查看已加载模型")
        print(" 输入 'compare' 比较所有模型性能")
        print("-" * 60)

        while True:
            try:
                text = input("\n请输入要分析的微博内容: ").strip()
                command = text.lower()

                if command == 'q':
                    print("感谢使用,再见!")
                    break

                if command == 'models':
                    self._print_loaded_models()
                    continue

                if command == 'switch':
                    print("切换模型...")
                    # Drop the current model(s) before loading the new one so
                    # two large models are not held at the same time.
                    self.models.clear()
                    self._select_and_load_model()
                    if self.models:
                        print("模型切换成功!")
                        for model_info in self.models.values():
                            print(f" 当前模型: {model_info['display_name']}")
                    continue

                if command == 'compare':
                    test_text = input("请输入要比较的文本: ").strip()
                    if not test_text:
                        print("请输入有效内容")
                        continue
                    self._compare_models(test_text)
                    continue

                if not text:
                    print("请输入有效内容")
                    continue

                if not self.models:
                    # Can happen after a 'switch' whose new model failed to load.
                    print("没有加载任何模型,请输入 'switch' 重新选择模型")
                    continue

                results = self.predict_single(text)

                print(f"\n原文: {text}")
                print("预测结果:")
                # Sorted by display name so output order is stable.
                for model_name, (pred, conf) in sorted(results.items()):
                    sentiment = self._sentiment_label(pred)
                    print(f" {model_name:20}: {sentiment} (置信度: {conf:.4f})")

            except (KeyboardInterrupt, EOFError):
                # EOFError: stdin closed. Without this it would hit the generic
                # handler below and the loop would never terminate.
                print("\n\n程序被中断,再见!")
                break
            except Exception as e:
                print(f"预测过程中出现错误: {e}")

    def _compare_models(self, text: str):
        """Print every loaded model's prediction for ``text``, grouped by method."""
        print(f"\n模型性能比较 - 文本: {text}")
        print("-" * 60)

        results = self.predict_single(text)

        # Grouping relies on the display name built in load_model:
        # "Qwen3-<Type>-<size>" where <Type> is "Embedding" or "Lora".
        embedding_models = []
        lora_models = []
        for model_name, (pred, conf) in results.items():
            entry = (model_name, self._sentiment_label(pred), conf)
            if "Embedding" in model_name:
                embedding_models.append(entry)
            elif "Lora" in model_name:
                lora_models.append(entry)

        for header, group in (("Embedding + 分类头方法:", embedding_models),
                              ("LoRA微调方法:", lora_models)):
            if group:
                print(header)
                for name, sentiment, conf in group:
                    print(f" {name}: {sentiment} ({conf:.4f})")
|
|
|
|
|
|
def main():
    """Command-line entry point.

    Loads models according to the CLI flags. If ``--text`` is given, that text
    is classified once (averaged with ``--ensemble`` when at least two models
    are loaded); otherwise the interactive prediction loop is started.
    """
    parser = argparse.ArgumentParser(description='Qwen3微博情感分析统一预测接口')
    parser.add_argument('--model_dir', type=str, default='./models',
                        help='模型文件目录')
    parser.add_argument('--model_type', type=str, choices=['embedding', 'lora'],
                        help='指定模型类型')
    parser.add_argument('--model_size', type=str, choices=['0.6B', '4B', '8B'],
                        help='指定模型大小')
    parser.add_argument('--text', type=str,
                        help='直接预测指定文本')
    # NOTE: store_true with default=True makes this flag always True, and it
    # is never read below. Interactive mode is selected simply by omitting
    # --text. The flag is kept so existing command lines still parse.
    parser.add_argument('--interactive', action='store_true', default=True,
                        help='交互式预测模式(默认)')
    parser.add_argument('--ensemble', action='store_true',
                        help='使用集成预测')
    parser.add_argument('--load_all', action='store_true',
                        help='加载所有可用模型')

    args = parser.parse_args()

    # --model_type and --model_size only work as a pair; reject a lone one
    # instead of silently ignoring it.
    if not args.load_all and bool(args.model_type) != bool(args.model_size):
        parser.error('--model_type 和 --model_size 需要同时指定')

    predictor = Qwen3UniversalPredictor()

    if args.load_all:
        predictor.load_all_models(args.model_dir)
    elif args.model_type and args.model_size:
        predictor.load_model(args.model_type, args.model_size)
    # With no model flags, interactive mode asks the user to choose a model.

    if not args.text:
        predictor.interactive_predict()
        return

    # One-shot mode cannot prompt for a model. Fail clearly instead of
    # printing the text with no predictions.
    if not predictor.models:
        print("没有加载任何模型,无法预测。请使用 --load_all 或同时指定 --model_type 和 --model_size")
        sys.exit(1)

    if args.ensemble and len(predictor.models) > 1:
        pred, conf = predictor.ensemble_predict(args.text)
        sentiment = "正面" if pred == 1 else "负面"
        print(f"文本: {args.text}")
        print(f"集成预测: {sentiment} (置信度: {conf:.4f})")
        return

    if args.ensemble:
        print("集成预测需要至少2个模型,改为逐个模型预测")

    results = predictor.predict_single(args.text)
    print(f"文本: {args.text}")
    for model_name, (pred, conf) in results.items():
        sentiment = "正面" if pred == 1 else "负面"
        print(f"{model_name}: {sentiment} (置信度: {conf:.4f})")
|
|
|
|
|
|
# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()