Local sentiment analysis upload.
This commit is contained in:
@@ -0,0 +1,310 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
统一的情感分析预测程序
|
||||
支持加载所有模型进行情感预测
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, Tuple, List
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
# 导入所有模型类
|
||||
from bayes_train import BayesModel
|
||||
from svm_train import SVMModel
|
||||
from xgboost_train import XGBoostModel
|
||||
from lstm_train import LSTMModel
|
||||
from bert_train import BertModel_Custom
|
||||
from utils import processing
|
||||
|
||||
|
||||
class SentimentPredictor:
    """Load sentiment models and run single, batch and ensemble predictions.

    ``models`` maps a model type name to a loaded model instance and
    ``available_models`` maps a model type name to the class used to build it.
    Model objects are expected to offer ``load_model(path)``,
    ``predict_single(text)`` returning ``(prediction, confidence)`` and
    ``predict(texts)``; those are the only methods this class calls on them.
    """

    def __init__(self):
        # Loaded model instances, keyed by model type name.
        self.models = {}
        # Registry of model type name -> model class, consulted by load_model().
        self.available_models = {
            'bayes': BayesModel,
            'svm': SVMModel,
            'xgboost': XGBoostModel,
            'lstm': LSTMModel,
            'bert': BertModel_Custom,
        }

    def load_model(self, model_type: str, model_path: str, **kwargs) -> None:
        """Build one model, load its weights and register it in ``self.models``.

        A missing model file or a failure while loading is reported on stdout
        and leaves ``self.models`` untouched; neither case raises.

        Args:
            model_type: One of the keys of ``self.available_models``.
            model_path: Path of the saved model file.
            **kwargs: Extra options. Only ``bert_path`` is read (the
                pre-trained BERT directory, used when ``model_type`` is
                ``'bert'``).

        Raises:
            ValueError: If ``model_type`` is not a registered model type.
        """
        if model_type not in self.available_models:
            raise ValueError(f"不支持的模型类型: {model_type}")

        if not os.path.exists(model_path):
            print(f"警告: 模型文件不存在: {model_path}")
            return

        print(f"加载 {model_type.upper()} 模型...")

        try:
            # Always go through the registry so every model type is built the
            # same way; BERT is the only one that takes a constructor argument.
            model_cls = self.available_models[model_type]
            if model_type == 'bert':
                bert_path = kwargs.get('bert_path', './model/chinese_wwm_pytorch')
                model = model_cls(bert_path)
            else:
                model = model_cls()

            model.load_model(model_path)
            self.models[model_type] = model
            print(f"{model_type.upper()} 模型加载成功")
        except Exception as e:
            # Best effort: one broken model must not stop the others loading.
            print(f"加载 {model_type.upper()} 模型失败: {e}")

    def load_all_models(self, model_dir: str = './model', bert_path: str = './model/chinese_wwm_pytorch') -> None:
        """Try to load every known model from ``model_dir``.

        Args:
            model_dir: Directory that holds the saved model files.
            bert_path: Directory of the pre-trained BERT model.
        """
        model_files = {
            'bayes': os.path.join(model_dir, 'bayes_model.pkl'),
            'svm': os.path.join(model_dir, 'svm_model.pkl'),
            'xgboost': os.path.join(model_dir, 'xgboost_model.pkl'),
            'lstm': os.path.join(model_dir, 'lstm_model.pth'),
            'bert': os.path.join(model_dir, 'bert_model.pth'),
        }

        print("开始加载所有可用模型...")
        for model_type, model_path in model_files.items():
            self.load_model(model_type, model_path, bert_path=bert_path)

        print(f"\n已加载 {len(self.models)} 个模型: {list(self.models.keys())}")

    def _predict_all(self, processed_text):
        """Run every loaded model on already pre-processed text.

        Returns:
            ``(results, failed)``. ``results`` maps each model name to
            ``(prediction, confidence)``; a model that raised gets the
            placeholder ``(0, 0.0)``. ``failed`` is the set of model names
            whose entry is such a placeholder.
        """
        results = {}
        failed = set()
        for name, model in self.models.items():
            try:
                prediction, confidence = model.predict_single(processed_text)
            except Exception as e:
                print(f"模型 {name} 预测失败: {e}")
                results[name] = (0, 0.0)
                failed.add(name)
            else:
                results[name] = (prediction, confidence)
        return results, failed

    @staticmethod
    def _weighted_vote(results, weights=None):
        """Combine per-model ``(prediction, confidence)`` pairs into one vote.

        Each pair is converted to a probability of the positive class
        (``confidence`` when the prediction is 1, ``1 - confidence``
        otherwise) and the probabilities are averaged using ``weights``.
        Models without an entry in ``weights`` are ignored; ``None`` gives
        every model a weight of 1.0.

        Returns:
            ``(prediction, confidence)``, or ``(0, 0.5)`` when the total
            weight is zero (nothing usable to vote with).
        """
        if weights is None:
            weights = {name: 1.0 for name in results}

        total_weight = 0
        weighted_prob = 0
        for model_name, (pred, conf) in results.items():
            if model_name in weights:
                weight = weights[model_name]
                prob = conf if pred == 1 else 1 - conf
                weighted_prob += prob * weight
                total_weight += weight

        if total_weight == 0:
            return 0, 0.5

        final_prob = weighted_prob / total_weight
        final_pred = int(final_prob > 0.5)
        final_conf = final_prob if final_pred == 1 else 1 - final_prob
        return final_pred, final_conf

    def predict_single(self, text: str, model_type: str = None) -> Dict[str, Tuple[int, float]]:
        """Predict the sentiment of one text.

        Args:
            text: Raw text; it is passed through ``processing`` first.
            model_type: Name of a loaded model to use. When omitted, every
                loaded model is used.

        Returns:
            Mapping of model name to ``(prediction, confidence)``. When all
            models are used, a model that raises is reported on stdout and
            gets the placeholder ``(0, 0.0)``.

        Raises:
            ValueError: If ``model_type`` is given but that model is not
                loaded. Errors raised by an explicitly selected model
                propagate to the caller.
        """
        processed_text = processing(text)

        if model_type:
            if model_type not in self.models:
                raise ValueError(f"模型 {model_type} 未加载")

            prediction, confidence = self.models[model_type].predict_single(processed_text)
            return {model_type: (prediction, confidence)}

        results, _ = self._predict_all(processed_text)
        return results

    def predict_batch(self, texts: List[str], model_type: str = None) -> Dict[str, List[int]]:
        """Predict the sentiment of several texts.

        Args:
            texts: Raw texts; each is passed through ``processing`` first.
            model_type: Name of a loaded model to use. When omitted, every
                loaded model is used.

        Returns:
            Mapping of model name to whatever that model's ``predict``
            returns. When all models are used, a model that raises is
            reported on stdout and gets a list of zeros, one per input text.

        Raises:
            ValueError: If ``model_type`` is given but that model is not
                loaded.
        """
        processed_texts = [processing(text) for text in texts]

        if model_type:
            if model_type not in self.models:
                raise ValueError(f"模型 {model_type} 未加载")

            predictions = self.models[model_type].predict(processed_texts)
            return {model_type: predictions}

        results = {}
        for name, model in self.models.items():
            try:
                results[name] = model.predict(processed_texts)
            except Exception as e:
                print(f"模型 {name} 预测失败: {e}")
                results[name] = [0] * len(texts)

        return results

    def ensemble_predict(self, text: str, weights: Dict[str, float] = None) -> Tuple[int, float]:
        """Predict with a weighted vote over all loaded models.

        Models that raise while predicting are left out of the vote. Their
        ``(0, 0.0)`` placeholder would otherwise be read as a positive
        probability of 1.0 and push the result towards "positive".

        Args:
            text: Raw text to classify.
            weights: Per-model weights; ``None`` weights all models equally.

        Returns:
            ``(prediction, confidence)``; ``(0, 0.5)`` when no model produced
            a usable vote.

        Raises:
            ValueError: If no model is loaded.
        """
        if not self.models:
            raise ValueError("没有加载任何模型")

        results, failed = self._predict_all(processing(text))
        usable = {name: res for name, res in results.items() if name not in failed}
        return self._weighted_vote(usable, weights)

    def interactive_predict(self):
        """Run a read-predict-print loop on stdin/stdout.

        Commands: ``q`` quits, ``models`` lists the loaded models and
        ``ensemble`` asks for a text and prints only the ensemble result.
        Any other non-empty input is classified by every loaded model. The
        loop ends on ``q``, Ctrl-C or end of input.
        """
        if not self.models:
            print("错误: 没有加载任何模型,请先加载模型")
            return

        print("\n" + "="*50)
        print("="*50)
        print(f"已加载模型: {', '.join(self.models.keys())}")
        print("输入 'q' 退出程序")
        print("输入 'models' 查看模型列表")
        print("输入 'ensemble' 使用集成预测")
        print("-"*50)

        while True:
            try:
                text = input("\n请输入要分析的微博内容: ").strip()
                command = text.lower()

                if command == 'q':
                    print("👋 再见!")
                    break

                if command == 'models':
                    print(f"已加载模型: {list(self.models.keys())}")
                    continue

                if command == 'ensemble':
                    if len(self.models) > 1:
                        # Ask for the text to classify; the command word
                        # itself is not content to analyse.
                        target = input("\n请输入要分析的微博内容: ").strip()
                        if not target:
                            print("❌ 请输入有效内容")
                            continue
                        pred, conf = self.ensemble_predict(target)
                        sentiment = "😊 正面" if pred == 1 else "😞 负面"
                        print("\n🤖 集成预测结果:")
                        print(f" 情感倾向: {sentiment}")
                        print(f" 置信度: {conf:.4f}")
                    else:
                        print("❌ 集成预测需要至少2个模型")
                    continue

                if not text:
                    print("❌ 请输入有效内容")
                    continue

                # Run every model once and reuse the results for the
                # ensemble line below instead of predicting a second time.
                results, failed = self._predict_all(processing(text))

                print(f"\n📝 原文: {text}")
                print("🔍 预测结果:")

                for model_name, (pred, conf) in results.items():
                    sentiment = "😊 正面" if pred == 1 else "😞 负面"
                    print(f" {model_name.upper():8}: {sentiment} (置信度: {conf:.4f})")

                # With more than one model, also show the combined verdict.
                if len(results) > 1:
                    usable = {name: res for name, res in results.items() if name not in failed}
                    ensemble_pred, ensemble_conf = self._weighted_vote(usable)
                    ensemble_sentiment = "😊 正面" if ensemble_pred == 1 else "😞 负面"
                    print(f" {'集成':8}: {ensemble_sentiment} (置信度: {ensemble_conf:.4f})")

            except (KeyboardInterrupt, EOFError):
                # EOFError must end the loop too: once stdin is closed every
                # further input() raises again, which would spin forever.
                print("\n\n👋 程序被中断,再见!")
                break
            except Exception as e:
                print(f"❌ 预测过程中出现错误: {e}")
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: load models, then predict once or interactively.

    With ``--model_type`` only that model is loaded from ``--model_dir``;
    otherwise every known model is tried. With ``--text`` the text is
    classified and the result printed (combined when ``--ensemble`` is given
    and more than one model is loaded); without it the interactive loop runs.

    Raises:
        SystemExit: With status 1 when ``--text`` is given but no model
            could be loaded.
    """
    parser = argparse.ArgumentParser(description='微博情感分析统一预测程序')
    parser.add_argument('--model_dir', type=str, default='./model',
                        help='模型文件目录')
    parser.add_argument('--bert_path', type=str, default='./model/chinese_wwm_pytorch',
                        help='BERT预训练模型路径')
    parser.add_argument('--model_type', type=str, choices=['bayes', 'svm', 'xgboost', 'lstm', 'bert'],
                        help='指定单个模型类型进行预测')
    parser.add_argument('--text', type=str,
                        help='直接预测指定文本')
    # NOTE: store_true combined with default=True means this flag is always
    # true; it is kept so existing command lines that pass it keep working.
    parser.add_argument('--interactive', action='store_true', default=True,
                        help='交互式预测模式(默认)')
    parser.add_argument('--ensemble', action='store_true',
                        help='使用集成预测')

    args = parser.parse_args()

    predictor = SentimentPredictor()

    if args.model_type:
        # Load only the requested model.
        model_files = {
            'bayes': 'bayes_model.pkl',
            'svm': 'svm_model.pkl',
            'xgboost': 'xgboost_model.pkl',
            'lstm': 'lstm_model.pth',
            'bert': 'bert_model.pth',
        }
        model_path = os.path.join(args.model_dir, model_files[args.model_type])
        predictor.load_model(args.model_type, model_path, bert_path=args.bert_path)
    else:
        predictor.load_all_models(args.model_dir, args.bert_path)

    if args.text:
        # load_model() only prints a warning when a file is missing or
        # loading fails, so check here rather than crash or print nothing.
        if not predictor.models:
            print("错误: 没有加载任何模型,请先加载模型")
            raise SystemExit(1)

        if args.ensemble and len(predictor.models) > 1:
            pred, conf = predictor.ensemble_predict(args.text)
            sentiment = "正面" if pred == 1 else "负面"
            print(f"文本: {args.text}")
            print(f"集成预测: {sentiment} (置信度: {conf:.4f})")
        else:
            results = predictor.predict_single(args.text, args.model_type)
            print(f"文本: {args.text}")
            for model_name, (pred, conf) in results.items():
                sentiment = "正面" if pred == 1 else "负面"
                print(f"{model_name.upper()}: {sentiment} (置信度: {conf:.4f})")
    elif args.interactive:
        predictor.interactive_predict()


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user