Files
2025-08-23 15:55:07 +08:00

310 lines
11 KiB
Python

# -*- coding: utf-8 -*-
"""
统一的情感分析预测程序
支持加载所有模型进行情感预测
"""
import argparse
import os
import re
import warnings
from typing import Dict, List, Optional, Tuple
warnings.filterwarnings("ignore")
# 导入所有模型类
from bayes_train import BayesModel
from svm_train import SVMModel
from xgboost_train import XGBoostModel
from lstm_train import LSTMModel
from bert_train import BertModel_Custom
from utils import processing
class SentimentPredictor:
    """Sentiment-analysis predictor that manages several trained models.

    Loaded models are kept in ``self.models`` (model type -> model instance).
    Each model is used through ``load_model(path)``,
    ``predict_single(text) -> (label, confidence)`` and
    ``predict(texts) -> labels``. Label 1 is treated as positive and any
    other label as negative.
    """

    def __init__(self):
        # Model type -> loaded model instance (filled by load_model).
        self.models = {}
        # Model type -> class used to construct a fresh, not-yet-loaded model.
        self.available_models = {
            'bayes': BayesModel,
            'svm': SVMModel,
            'xgboost': XGBoostModel,
            'lstm': LSTMModel,
            'bert': BertModel_Custom,
        }

    def load_model(self, model_type: str, model_path: str, **kwargs) -> None:
        """Load one model of the given type from disk and register it.

        A missing file or a load error is reported on stdout and the model is
        simply not registered (best effort), so callers can continue with
        whatever did load.

        Args:
            model_type: one of 'bayes', 'svm', 'xgboost', 'lstm', 'bert'.
            model_path: path of the saved model file.
            **kwargs: extra options; ``bert_path`` is the pretrained BERT
                directory passed to the BERT constructor
                (default './model/chinese_wwm_pytorch').

        Raises:
            ValueError: if ``model_type`` is not a supported type.
        """
        if model_type not in self.available_models:
            raise ValueError(f"不支持的模型类型: {model_type}")
        if not os.path.exists(model_path):
            print(f"警告: 模型文件不存在: {model_path}")
            return
        print(f"加载 {model_type.upper()} 模型...")
        try:
            if model_type == 'bert':
                # BERT is built from the pretrained directory first; model_path
                # is then loaded into it like every other model type.
                bert_path = kwargs.get('bert_path', './model/chinese_wwm_pytorch')
                model = BertModel_Custom(bert_path)
            else:
                model = self.available_models[model_type]()
            model.load_model(model_path)
            self.models[model_type] = model
            print(f"{model_type.upper()} 模型加载成功")
        except Exception as e:
            # Best effort: one broken model must not stop the others loading.
            print(f"加载 {model_type.upper()} 模型失败: {e}")

    def load_all_models(self, model_dir: str = './model', bert_path: str = './model/chinese_wwm_pytorch') -> None:
        """Try to load every supported model from ``model_dir``.

        Args:
            model_dir: directory containing the saved model files.
            bert_path: pretrained BERT directory used for the 'bert' model.
        """
        model_files = {
            'bayes': os.path.join(model_dir, 'bayes_model.pkl'),
            'svm': os.path.join(model_dir, 'svm_model.pkl'),
            'xgboost': os.path.join(model_dir, 'xgboost_model.pkl'),
            'lstm': os.path.join(model_dir, 'lstm_model.pth'),
            'bert': os.path.join(model_dir, 'bert_model.pth'),
        }
        print("开始加载所有可用模型...")
        for model_type, model_path in model_files.items():
            self.load_model(model_type, model_path, bert_path=bert_path)
        print(f"\n已加载 {len(self.models)} 个模型: {list(self.models.keys())}")

    def _run_models(self, processed_text: str) -> Dict[str, Tuple[int, float]]:
        """Run every loaded model on already-preprocessed text.

        Returns only the successful results (model type -> (label, confidence)).
        A model that raises is reported on stdout and left out of the dict.
        """
        results = {}
        for name, model in self.models.items():
            try:
                prediction, confidence = model.predict_single(processed_text)
            except Exception as e:
                print(f"模型 {name} 预测失败: {e}")
                continue
            results[name] = (prediction, confidence)
        return results

    def predict_single(self, text: str, model_type: Optional[str] = None) -> Dict[str, Tuple[int, float]]:
        """Predict the sentiment of a single text.

        Args:
            text: raw text; it is passed through ``processing`` first.
            model_type: use only this model; when None, use every loaded model.

        Returns:
            Dict mapping model type to (prediction, confidence). With
            ``model_type=None`` there is one entry per loaded model, and a
            model that failed gets the placeholder (0, 0.0).

        Raises:
            ValueError: if ``model_type`` is given but that model is not loaded.
        """
        processed_text = processing(text)
        if model_type:
            if model_type not in self.models:
                raise ValueError(f"模型 {model_type} 未加载")
            prediction, confidence = self.models[model_type].predict_single(processed_text)
            return {model_type: (prediction, confidence)}
        successes = self._run_models(processed_text)
        # One entry per loaded model, in load order; failures keep the
        # historical (0, 0.0) placeholder for backward compatibility.
        return {name: successes.get(name, (0, 0.0)) for name in self.models}

    def predict_batch(self, texts: List[str], model_type: Optional[str] = None) -> Dict[str, List[int]]:
        """Predict the sentiment of a batch of texts.

        Args:
            texts: raw texts; each is passed through ``processing`` first.
            model_type: use only this model; when None, use every loaded model.

        Returns:
            Dict mapping model type to that model's predictions. A model that
            fails gets a list of zeros of the same length as ``texts``.

        Raises:
            ValueError: if ``model_type`` is given but that model is not loaded.
        """
        processed_texts = [processing(text) for text in texts]
        if model_type:
            if model_type not in self.models:
                raise ValueError(f"模型 {model_type} 未加载")
            predictions = self.models[model_type].predict(processed_texts)
            return {model_type: predictions}
        results = {}
        for name, model in self.models.items():
            try:
                predictions = model.predict(processed_texts)
                results[name] = predictions
            except Exception as e:
                print(f"模型 {name} 预测失败: {e}")
                results[name] = [0] * len(texts)
        return results

    @staticmethod
    def _combine(results: Dict[str, Tuple[int, float]],
                 weights: Optional[Dict[str, float]] = None) -> Tuple[int, float]:
        """Weighted soft vote over per-model (label, confidence) pairs.

        Each confidence is converted to P(positive): ``conf`` for label 1,
        ``1 - conf`` otherwise. These are averaged with ``weights`` (equal
        weights when None; models absent from ``weights`` are ignored) and
        thresholded at 0.5. Returns (0, 0.5) when no model contributes any
        weight.
        """
        if weights is None:
            weights = {name: 1.0 for name in results}
        total_weight = 0.0
        weighted_prob = 0.0
        for name, (pred, conf) in results.items():
            if name not in weights:
                continue
            weight = weights[name]
            prob_positive = conf if pred == 1 else 1 - conf
            weighted_prob += prob_positive * weight
            total_weight += weight
        if total_weight == 0:
            return 0, 0.5
        final_prob = weighted_prob / total_weight
        final_pred = int(final_prob > 0.5)
        final_conf = final_prob if final_pred == 1 else 1 - final_prob
        return final_pred, final_conf

    def ensemble_predict(self, text: str, weights: Optional[Dict[str, float]] = None) -> Tuple[int, float]:
        """Ensemble prediction (weighted soft vote) over all loaded models.

        Args:
            text: raw text; it is passed through ``processing`` first.
            weights: per-model weights; when None, every model weighs the same.

        Returns:
            (prediction, confidence) of the combined vote.

        Raises:
            ValueError: if no model is loaded.
        """
        if len(self.models) == 0:
            raise ValueError("没有加载任何模型")
        # Only successful predictions vote. A failed model's (0, 0.0)
        # placeholder would otherwise count as P(positive) = 1.0.
        results = self._run_models(processing(text))
        return self._combine(results, weights)

    def _interactive_ensemble(self):
        """Prompt for a text and print the ensemble prediction for it."""
        if len(self.models) <= 1:
            print("❌ 集成预测需要至少2个模型")
            return
        text = input("请输入要集成预测的微博内容: ").strip()
        if not text:
            print("❌ 请输入有效内容")
            return
        pred, conf = self.ensemble_predict(text)
        sentiment = "😊 正面" if pred == 1 else "😞 负面"
        print("\n🤖 集成预测结果:")
        print(f" 情感倾向: {sentiment}")
        print(f" 置信度: {conf:.4f}")

    def _show_predictions(self, text):
        """Print each loaded model's prediction for ``text``, plus the ensemble."""
        successes = self._run_models(processing(text))
        print(f"\n📝 原文: {text}")
        print("🔍 预测结果:")
        for model_name in self.models:
            if model_name not in successes:
                print(f" {model_name.upper():8}: ❌ 预测失败")
                continue
            pred, conf = successes[model_name]
            sentiment = "😊 正面" if pred == 1 else "😞 负面"
            print(f" {model_name.upper():8}: {sentiment} (置信度: {conf:.4f})")
        # Reuse the results above instead of running every model a second time.
        if len(self.models) > 1:
            ensemble_pred, ensemble_conf = self._combine(successes)
            ensemble_sentiment = "😊 正面" if ensemble_pred == 1 else "😞 负面"
            print(f" {'集成':8}: {ensemble_sentiment} (置信度: {ensemble_conf:.4f})")

    def interactive_predict(self):
        """Interactive read-predict loop on stdin.

        Commands: 'q' quits, 'models' lists loaded models, and 'ensemble'
        prompts for a text to run the ensemble on. Any other non-empty input
        is predicted by every loaded model.
        """
        if len(self.models) == 0:
            print("错误: 没有加载任何模型,请先加载模型")
            return
        print("\n" + "="*50)
        print("="*50)
        print(f"已加载模型: {', '.join(self.models.keys())}")
        print("输入 'q' 退出程序")
        print("输入 'models' 查看模型列表")
        print("输入 'ensemble' 使用集成预测")
        print("-"*50)
        while True:
            try:
                text = input("\n请输入要分析的微博内容: ").strip()
                command = text.lower()
                if command == 'q':
                    print("👋 再见!")
                    break
                if command == 'models':
                    print(f"已加载模型: {list(self.models.keys())}")
                    continue
                if command == 'ensemble':
                    self._interactive_ensemble()
                    continue
                if not text:
                    print("❌ 请输入有效内容")
                    continue
                self._show_predictions(text)
            except (KeyboardInterrupt, EOFError):
                # EOF (Ctrl-D / exhausted piped stdin) must end the loop;
                # otherwise input() raises again immediately, forever.
                print("\n\n👋 程序被中断,再见!")
                break
            except Exception as e:
                print(f"❌ 预测过程中出现错误: {e}")
def main():
    """Command-line entry point: load models, then predict one text or go interactive.

    With ``--text`` the text is predicted once. The ensemble is used when
    ``--ensemble`` is given and at least two models loaded; otherwise each
    loaded model's prediction is printed. Without ``--text`` the interactive
    loop runs.
    """
    parser = argparse.ArgumentParser(description='微博情感分析统一预测程序')
    parser.add_argument('--model_dir', type=str, default='./model',
                        help='模型文件目录')
    parser.add_argument('--bert_path', type=str, default='./model/chinese_wwm_pytorch',
                        help='BERT预训练模型路径')
    parser.add_argument('--model_type', type=str, choices=['bayes', 'svm', 'xgboost', 'lstm', 'bert'],
                        help='指定单个模型类型进行预测')
    parser.add_argument('--text', type=str,
                        help='直接预测指定文本')
    # NOTE: store_true combined with default=True means this flag is always
    # on; interactive mode effectively runs whenever --text is absent.
    parser.add_argument('--interactive', action='store_true', default=True,
                        help='交互式预测模式(默认)')
    parser.add_argument('--ensemble', action='store_true',
                        help='使用集成预测')
    args = parser.parse_args()

    predictor = SentimentPredictor()

    if args.model_type:
        # Load only the requested model.
        model_files = {
            'bayes': 'bayes_model.pkl',
            'svm': 'svm_model.pkl',
            'xgboost': 'xgboost_model.pkl',
            'lstm': 'lstm_model.pth',
            'bert': 'bert_model.pth',
        }
        model_path = os.path.join(args.model_dir, model_files[args.model_type])
        predictor.load_model(args.model_type, model_path, bert_path=args.bert_path)
    else:
        predictor.load_all_models(args.model_dir, args.bert_path)

    if args.text:
        if not predictor.models:
            # load_model only warns on a missing/broken file. Without this guard,
            # predict_single(text, model_type) raises an uncaught ValueError, and
            # with no model_type nothing at all is predicted.
            print("错误: 没有加载任何模型,请先加载模型")
            return
        if args.ensemble and len(predictor.models) > 1:
            pred, conf = predictor.ensemble_predict(args.text)
            sentiment = "正面" if pred == 1 else "负面"
            print(f"文本: {args.text}")
            print(f"集成预测: {sentiment} (置信度: {conf:.4f})")
        else:
            if args.ensemble:
                # Tell the user why the ensemble was not used instead of
                # silently falling back to per-model output.
                print("提示: 集成预测需要至少2个模型,改为单模型预测")
            results = predictor.predict_single(args.text, args.model_type)
            print(f"文本: {args.text}")
            for model_name, (pred, conf) in results.items():
                sentiment = "正面" if pred == 1 else "负面"
                print(f"{model_name.upper()}: {sentiment} (置信度: {conf:.4f})")
    elif args.interactive:
        predictor.interactive_predict()


if __name__ == "__main__":
    main()