Added a base model class and training scripts for various sentiment analysis models, including Naive Bayes, SVM, XGBoost, LSTM, and BERT. Also, improved prediction functionality and the model loading mechanism.
This commit is contained in:
@@ -0,0 +1,120 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
基础模型类,为所有情感分析模型提供统一接口
|
||||
"""
|
||||
import os
|
||||
import pickle
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Tuple, Dict, Any
|
||||
import pandas as pd
|
||||
from sklearn.metrics import accuracy_score, f1_score, classification_report
|
||||
from utils import load_corpus
|
||||
|
||||
|
||||
class BaseModel(ABC):
|
||||
"""情感分析模型基类"""
|
||||
|
||||
def __init__(self, model_name: str):
|
||||
self.model_name = model_name
|
||||
self.model = None
|
||||
self.vectorizer = None
|
||||
self.is_trained = False
|
||||
|
||||
@abstractmethod
|
||||
def train(self, train_data: List[Tuple[str, int]], **kwargs) -> None:
|
||||
"""训练模型"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def predict(self, texts: List[str]) -> List[int]:
|
||||
"""预测文本情感"""
|
||||
pass
|
||||
|
||||
def predict_single(self, text: str) -> Tuple[int, float]:
|
||||
"""预测单条文本的情感
|
||||
|
||||
Args:
|
||||
text: 待预测文本
|
||||
|
||||
Returns:
|
||||
(predicted_label, confidence)
|
||||
"""
|
||||
predictions = self.predict([text])
|
||||
return predictions[0], 0.0 # 默认置信度为0
|
||||
|
||||
def evaluate(self, test_data: List[Tuple[str, int]]) -> Dict[str, float]:
|
||||
"""评估模型性能"""
|
||||
if not self.is_trained:
|
||||
raise ValueError(f"模型 {self.model_name} 尚未训练,请先调用train方法")
|
||||
|
||||
texts = [item[0] for item in test_data]
|
||||
labels = [item[1] for item in test_data]
|
||||
|
||||
predictions = self.predict(texts)
|
||||
|
||||
accuracy = accuracy_score(labels, predictions)
|
||||
f1 = f1_score(labels, predictions, average='weighted')
|
||||
|
||||
print(f"\n{self.model_name} 模型评估结果:")
|
||||
print(f"准确率: {accuracy:.4f}")
|
||||
print(f"F1分数: {f1:.4f}")
|
||||
print("\n详细报告:")
|
||||
print(classification_report(labels, predictions))
|
||||
|
||||
return {
|
||||
'accuracy': accuracy,
|
||||
'f1_score': f1,
|
||||
'classification_report': classification_report(labels, predictions)
|
||||
}
|
||||
|
||||
def save_model(self, model_path: str = None) -> None:
|
||||
"""保存模型到文件"""
|
||||
if not self.is_trained:
|
||||
raise ValueError(f"模型 {self.model_name} 尚未训练,无法保存")
|
||||
|
||||
if model_path is None:
|
||||
model_path = f"model/{self.model_name}_model.pkl"
|
||||
|
||||
# 创建保存目录
|
||||
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
||||
|
||||
# 保存模型数据
|
||||
model_data = {
|
||||
'model': self.model,
|
||||
'vectorizer': self.vectorizer,
|
||||
'model_name': self.model_name,
|
||||
'is_trained': self.is_trained
|
||||
}
|
||||
|
||||
with open(model_path, 'wb') as f:
|
||||
pickle.dump(model_data, f)
|
||||
|
||||
print(f"模型已保存到: {model_path}")
|
||||
|
||||
def load_model(self, model_path: str) -> None:
|
||||
"""从文件加载模型"""
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(f"模型文件不存在: {model_path}")
|
||||
|
||||
with open(model_path, 'rb') as f:
|
||||
model_data = pickle.load(f)
|
||||
|
||||
self.model = model_data['model']
|
||||
self.vectorizer = model_data.get('vectorizer')
|
||||
self.model_name = model_data['model_name']
|
||||
self.is_trained = model_data['is_trained']
|
||||
|
||||
print(f"已加载模型: {model_path}")
|
||||
|
||||
@staticmethod
|
||||
def load_data(train_path: str, test_path: str) -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]:
|
||||
"""加载训练和测试数据"""
|
||||
print("加载训练数据...")
|
||||
train_data = load_corpus(train_path)
|
||||
print(f"训练数据量: {len(train_data)}")
|
||||
|
||||
print("加载测试数据...")
|
||||
test_data = load_corpus(test_path)
|
||||
print(f"测试数据量: {len(test_data)}")
|
||||
|
||||
return train_data, test_data
|
||||
Reference in New Issue
Block a user