Files
bettafish-company/model_pro/lstm_predict.py
T

165 lines
5.8 KiB
Python

import torch
import os
import logging
from LSTM_model import lstm_model_manager
# 配置日志记录
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger('lstm_predict')
class LSTMPredictor:
"""LSTM预测器,与当前系统的预测接口兼容"""
def __init__(self):
self.model_loaded = False
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"初始化LSTM预测器,使用设备: {self.device}")
def load_models(self, model_save_path, bert_model_path, tokenizer_path=None):
"""
加载模型,与当前系统的model_manager.load_models接口兼容
参数:
model_save_path: LSTM模型保存路径
bert_model_path: BERT模型路径
tokenizer_path: 分词器路径(LSTM模型中使用BERT的分词器,可忽略)
"""
try:
# 检查模型文件是否存在
if not os.path.exists(model_save_path):
logger.warning(f"模型文件 {model_save_path} 不存在,需要先训练模型")
return False
if not os.path.exists(bert_model_path):
logger.error(f"BERT模型路径 {bert_model_path} 不存在")
return False
# 实际上我们在lstm_model_manager初始化时已经加载了模型,这里只是检查一下
if lstm_model_manager.model is not None:
self.model_loaded = True
logger.info("LSTM模型已加载成功")
return True
else:
logger.error("LSTM模型加载失败")
return False
except Exception as e:
logger.error(f"加载模型过程中出错: {e}")
return False
def predict_batch(self, texts):
"""
批量预测文本的情感
参数:
texts: 文本列表
返回:
predictions: 预测结果列表(0表示良好,1表示不良)
probabilities: 预测概率列表
"""
if not self.model_loaded and lstm_model_manager.model is None:
logger.error("模型未加载,无法进行预测")
return None, None
if not texts:
logger.warning("未提供文本,无法进行预测")
return None, None
try:
# 调用LSTM模型管理器的批量预测函数
predictions, probabilities = lstm_model_manager.predict_batch(texts)
return predictions, probabilities
except Exception as e:
logger.error(f"预测过程中出错: {e}")
return None, None
def predict(self, text):
"""
预测单个文本的情感
参数:
text: 文本字符串
返回:
prediction: 预测结果(0表示良好,1表示不良)
probability: 预测概率
"""
if not self.model_loaded and lstm_model_manager.model is None:
logger.error("模型未加载,无法进行预测")
return None, None
if not text or len(text.strip()) == 0:
logger.warning("未提供文本或文本为空,无法进行预测")
return None, None
try:
# 调用LSTM模型管理器的单个文本预测函数
prediction, probability = lstm_model_manager.predict(text)
return prediction, probability
except Exception as e:
logger.error(f"预测过程中出错: {e}")
return None, None
def train_model(self, train_texts, train_labels, val_texts=None, val_labels=None,
batch_size=32, learning_rate=2e-5, epochs=10):
"""
训练模型
参数:
train_texts: 训练集文本
train_labels: 训练集标签
val_texts: 验证集文本
val_labels: 验证集标签
batch_size: 批次大小
learning_rate: 学习率
epochs: 训练轮数
返回:
训练结果
"""
try:
results = lstm_model_manager.train(
train_texts, train_labels, val_texts, val_labels,
batch_size, learning_rate, epochs
)
self.model_loaded = True
return results
except Exception as e:
logger.error(f"训练模型过程中出错: {e}")
return None
# 创建全局预测器实例
lstm_predictor = LSTMPredictor()
# 为了与现有代码兼容,提供一个与model_manager相同的predict_batch函数
def predict_batch(texts):
return lstm_predictor.predict_batch(texts)
# 为了与现有代码兼容,提供一个与model_manager相同的load_models函数
def load_models(model_save_path, bert_model_path, tokenizer_path=None):
return lstm_predictor.load_models(model_save_path, bert_model_path, tokenizer_path)
# 测试代码
if __name__ == "__main__":
# 加载模型
load_models(
model_save_path="model_pro/lstm_model.pt",
bert_model_path="model_pro/bert_model"
)
# 测试预测功能
test_sentences = [
"这件事情做得非常好",
"服务太差了,态度恶劣",
"这个产品质量一般,但价格便宜",
"我对这家公司非常满意",
]
for sentence in test_sentences:
pred, prob = lstm_predictor.predict(sentence)
if pred is not None:
label = '良好' if pred == 0 else '不良'
confidence = prob[pred]
print(f"句子: '{sentence}' 预测结果: {label} (置信度: {confidence:.2%})")
else:
print(f"句子: '{sentence}' 预测失败")