Files
2025-08-23 15:55:07 +08:00

171 lines
6.3 KiB
Python

# -*- coding: utf-8 -*-
"""
Qwen3模型基础类,统一接口
"""
import os
import pickle
from abc import ABC, abstractmethod
from typing import List, Tuple, Dict, Any
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score, classification_report
from sklearn.model_selection import train_test_split
class BaseQwenModel(ABC):
"""Qwen3情感分析模型基类"""
def __init__(self, model_name: str):
self.model_name = model_name
self.model = None
self.is_trained = False
@abstractmethod
def train(self, train_data: List[Tuple[str, int]], **kwargs) -> None:
"""训练模型"""
pass
@abstractmethod
def predict(self, texts: List[str]) -> List[int]:
"""预测文本情感"""
pass
def predict_single(self, text: str) -> Tuple[int, float]:
"""预测单条文本的情感
Args:
text: 待预测文本
Returns:
(predicted_label, confidence)
"""
predictions = self.predict([text])
return predictions[0], 0.0 # 默认置信度为0
def evaluate(self, test_data: List[Tuple[str, int]]) -> Dict[str, float]:
"""评估模型性能"""
if not self.is_trained:
raise ValueError(f"模型 {self.model_name} 尚未训练,请先调用train方法")
texts = [item[0] for item in test_data]
labels = [item[1] for item in test_data]
predictions = self.predict(texts)
accuracy = accuracy_score(labels, predictions)
f1 = f1_score(labels, predictions, average='weighted')
print(f"\n{self.model_name} 模型评估结果:")
print(f"准确率: {accuracy:.4f}")
print(f"F1分数: {f1:.4f}")
print("\n详细报告:")
print(classification_report(labels, predictions))
return {
'accuracy': accuracy,
'f1_score': f1,
'classification_report': classification_report(labels, predictions)
}
@abstractmethod
def save_model(self, model_path: str = None) -> None:
"""保存模型到文件"""
pass
@abstractmethod
def load_model(self, model_path: str) -> None:
"""从文件加载模型"""
pass
@staticmethod
def load_data(train_path: str = None, test_path: str = None, csv_path: str = 'dataset/weibo_senti_100k.csv') -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]:
"""加载训练和测试数据
Args:
train_path: 训练数据txt文件路径(可选)
test_path: 测试数据txt文件路径(可选)
csv_path: CSV数据文件路径(默认使用)
"""
# 优先尝试使用CSV文件
if os.path.exists(csv_path):
print(f"从CSV文件加载数据: {csv_path}")
df = pd.read_csv(csv_path)
# 检查数据格式
if 'review' in df.columns and 'label' in df.columns:
# 将DataFrame转换为元组列表
data = [(row['review'], row['label']) for _, row in df.iterrows()]
# 分割训练和测试数据,固定测试集为5000条
total_samples = len(data)
if total_samples > 5000:
test_size = 5000
train_data, test_data = train_test_split(
data,
test_size=test_size,
random_state=42,
stratify=[label for _, label in data]
)
else:
# 如果总数据不足5000条,使用20%作为测试集
train_data, test_data = train_test_split(
data,
test_size=0.2,
random_state=42,
stratify=[label for _, label in data]
)
print(f"训练数据量: {len(train_data)}")
print(f"测试数据量: {len(test_data)}")
return train_data, test_data
else:
print(f"CSV文件格式不正确,缺少'review''label'")
# 如果CSV不存在,尝试使用txt文件
elif train_path and test_path and os.path.exists(train_path) and os.path.exists(test_path):
def load_corpus(path):
data = []
with open(path, "r", encoding="utf8") as f:
for line in f:
parts = line.strip().split("\t")
if len(parts) >= 2:
content = parts[0]
sentiment = int(parts[1])
data.append((content, sentiment))
return data
print("从txt文件加载训练数据...")
train_data = load_corpus(train_path)
print(f"训练数据量: {len(train_data)}")
print("从txt文件加载测试数据...")
test_data = load_corpus(test_path)
print(f"测试数据量: {len(test_data)}")
return train_data, test_data
else:
# 如果都没有,提供样例数据创建指导
print("未找到数据文件!")
print("请确保以下文件之一存在:")
print(f"1. CSV文件: {csv_path}")
print(f"2. txt文件: {train_path}{test_path}")
print("\n数据格式要求:")
print("CSV文件: 包含'review''label'")
print("txt文件: 每行格式为'文本内容\\t标签'")
# 创建样例数据
sample_data = [
("今天天气真好,心情很棒!", 1),
("这部电影太无聊了", 0),
("非常喜欢这个产品", 1),
("服务态度很差", 0),
("质量不错,值得推荐", 1)
]
print("使用样例数据进行演示...")
train_data = sample_data * 20 # 扩充样例数据
test_data = sample_data * 5
return train_data, test_data