120 lines
3.9 KiB
Python
120 lines
3.9 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
基础模型类,为所有情感分析模型提供统一接口
|
|
"""
|
|
import os
import pickle
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple

import pandas as pd
from sklearn.metrics import accuracy_score, f1_score, classification_report

from utils import load_corpus
|
|
|
|
|
|
class BaseModel(ABC):
    """Abstract base class for sentiment analysis models.

    Subclasses implement ``train`` and ``predict``; this class supplies
    shared evaluation, persistence (pickle) and data-loading helpers.

    Attributes are plain instance fields set in ``__init__``:
    ``model_name``, ``model``, ``vectorizer`` and ``is_trained``.
    """

    def __init__(self, model_name: str):
        """Initialise an untrained model wrapper.

        Args:
            model_name: Name used in log messages and in the default
                file name chosen by ``save_model``.
        """
        self.model_name = model_name
        self.model = None
        self.vectorizer = None
        self.is_trained = False

    @abstractmethod
    def train(self, train_data: List[Tuple[str, int]], **kwargs) -> None:
        """Train the model on ``(text, label)`` pairs.

        NOTE(review): the base class only ever sets ``is_trained`` from a
        loaded file, so implementations presumably need to set
        ``self.is_trained = True`` themselves for ``evaluate`` and
        ``save_model`` to work -- confirm against the subclasses.
        """
        pass

    @abstractmethod
    def predict(self, texts: List[str]) -> List[int]:
        """Predict a sentiment label for each text in ``texts``."""
        pass

    def predict_single(self, text: str) -> Tuple[int, float]:
        """Predict the sentiment of a single text.

        Args:
            text: The text to classify.

        Returns:
            ``(predicted_label, confidence)``. This base implementation
            has no confidence estimate and always reports ``0.0``.
        """
        predictions = self.predict([text])
        return predictions[0], 0.0  # default confidence is 0

    def evaluate(self, test_data: List[Tuple[str, int]]) -> Dict[str, Any]:
        """Evaluate the model on labelled data and print a summary.

        Args:
            test_data: ``(text, label)`` pairs to score against.

        Returns:
            A dict with ``'accuracy'`` (float), ``'f1_score'`` (float,
            weighted average) and ``'classification_report'`` (the text
            report produced by scikit-learn).

        Raises:
            ValueError: If the model has not been trained yet.
        """
        if not self.is_trained:
            raise ValueError(f"模型 {self.model_name} 尚未训练,请先调用train方法")

        texts = [item[0] for item in test_data]
        labels = [item[1] for item in test_data]

        predictions = self.predict(texts)

        accuracy = accuracy_score(labels, predictions)
        f1 = f1_score(labels, predictions, average='weighted')
        # Build the report once and reuse it for both printing and the
        # returned dict instead of computing it twice.
        report = classification_report(labels, predictions)

        print(f"\n{self.model_name} 模型评估结果:")
        print(f"准确率: {accuracy:.4f}")
        print(f"F1分数: {f1:.4f}")
        print("\n详细报告:")
        print(report)

        return {
            'accuracy': accuracy,
            'f1_score': f1,
            'classification_report': report,
        }

    def save_model(self, model_path: Optional[str] = None) -> None:
        """Pickle the model, vectorizer, name and trained flag to a file.

        Args:
            model_path: Destination file. Defaults to
                ``model/<model_name>_model.pkl`` when omitted.

        Raises:
            ValueError: If the model has not been trained yet.
        """
        if not self.is_trained:
            raise ValueError(f"模型 {self.model_name} 尚未训练,无法保存")

        if model_path is None:
            model_path = f"model/{self.model_name}_model.pkl"

        # Create the target directory. A bare file name has an empty
        # dirname, and os.makedirs("") raises FileNotFoundError, so only
        # create a directory when the path actually contains one.
        directory = os.path.dirname(model_path)
        if directory:
            os.makedirs(directory, exist_ok=True)

        # Everything load_model needs to restore this instance.
        model_data = {
            'model': self.model,
            'vectorizer': self.vectorizer,
            'model_name': self.model_name,
            'is_trained': self.is_trained,
        }

        with open(model_path, 'wb') as f:
            pickle.dump(model_data, f)

        print(f"模型已保存到: {model_path}")

    def load_model(self, model_path: str) -> None:
        """Restore model state from a file written by ``save_model``.

        Overwrites ``model``, ``vectorizer``, ``model_name`` and
        ``is_trained`` on this instance.

        Args:
            model_path: Path of the pickle file to read.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
            KeyError: If the file lacks the ``'model'``, ``'model_name'``
                or ``'is_trained'`` entries.
        """
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"模型文件不存在: {model_path}")

        # SECURITY NOTE: pickle.load can execute arbitrary code embedded
        # in the file. Only load model files from trusted sources.
        with open(model_path, 'rb') as f:
            model_data = pickle.load(f)

        self.model = model_data['model']
        # The vectorizer is optional; a missing key yields None.
        self.vectorizer = model_data.get('vectorizer')
        self.model_name = model_data['model_name']
        self.is_trained = model_data['is_trained']

        print(f"已加载模型: {model_path}")

    @staticmethod
    def load_data(train_path: str, test_path: str) -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]:
        """Load the training and test corpora via ``load_corpus``.

        Args:
            train_path: Path handed to ``load_corpus`` for training data.
            test_path: Path handed to ``load_corpus`` for test data.

        Returns:
            ``(train_data, test_data)`` exactly as returned by
            ``load_corpus`` (presumably lists of ``(text, label)`` pairs,
            per the annotation -- verify against ``utils.load_corpus``).
        """
        print("加载训练数据...")
        train_data = load_corpus(train_path)
        print(f"训练数据量: {len(train_data)}")

        print("加载测试数据...")
        test_data = load_corpus(test_path)
        print(f"测试数据量: {len(test_data)}")

        return train_data, test_data