Enhanced public opinion prediction system by integrating LSTM model.
This commit is contained in:
@@ -0,0 +1,165 @@
|
||||
import torch
|
||||
import os
|
||||
import logging
|
||||
from LSTM_model import lstm_model_manager
|
||||
|
||||
# 配置日志记录
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger('lstm_predict')
|
||||
|
||||
class LSTMPredictor:
|
||||
"""LSTM预测器,与当前系统的预测接口兼容"""
|
||||
|
||||
def __init__(self):
|
||||
self.model_loaded = False
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
logger.info(f"初始化LSTM预测器,使用设备: {self.device}")
|
||||
|
||||
def load_models(self, model_save_path, bert_model_path, tokenizer_path=None):
|
||||
"""
|
||||
加载模型,与当前系统的model_manager.load_models接口兼容
|
||||
|
||||
参数:
|
||||
model_save_path: LSTM模型保存路径
|
||||
bert_model_path: BERT模型路径
|
||||
tokenizer_path: 分词器路径(LSTM模型中使用BERT的分词器,可忽略)
|
||||
"""
|
||||
try:
|
||||
# 检查模型文件是否存在
|
||||
if not os.path.exists(model_save_path):
|
||||
logger.warning(f"模型文件 {model_save_path} 不存在,需要先训练模型")
|
||||
return False
|
||||
|
||||
if not os.path.exists(bert_model_path):
|
||||
logger.error(f"BERT模型路径 {bert_model_path} 不存在")
|
||||
return False
|
||||
|
||||
# 实际上我们在lstm_model_manager初始化时已经加载了模型,这里只是检查一下
|
||||
if lstm_model_manager.model is not None:
|
||||
self.model_loaded = True
|
||||
logger.info("LSTM模型已加载成功")
|
||||
return True
|
||||
else:
|
||||
logger.error("LSTM模型加载失败")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"加载模型过程中出错: {e}")
|
||||
return False
|
||||
|
||||
def predict_batch(self, texts):
|
||||
"""
|
||||
批量预测文本的情感
|
||||
|
||||
参数:
|
||||
texts: 文本列表
|
||||
|
||||
返回:
|
||||
predictions: 预测结果列表(0表示良好,1表示不良)
|
||||
probabilities: 预测概率列表
|
||||
"""
|
||||
if not self.model_loaded and lstm_model_manager.model is None:
|
||||
logger.error("模型未加载,无法进行预测")
|
||||
return None, None
|
||||
|
||||
if not texts:
|
||||
logger.warning("未提供文本,无法进行预测")
|
||||
return None, None
|
||||
|
||||
try:
|
||||
# 调用LSTM模型管理器的批量预测函数
|
||||
predictions, probabilities = lstm_model_manager.predict_batch(texts)
|
||||
return predictions, probabilities
|
||||
except Exception as e:
|
||||
logger.error(f"预测过程中出错: {e}")
|
||||
return None, None
|
||||
|
||||
def predict(self, text):
|
||||
"""
|
||||
预测单个文本的情感
|
||||
|
||||
参数:
|
||||
text: 文本字符串
|
||||
|
||||
返回:
|
||||
prediction: 预测结果(0表示良好,1表示不良)
|
||||
probability: 预测概率
|
||||
"""
|
||||
if not self.model_loaded and lstm_model_manager.model is None:
|
||||
logger.error("模型未加载,无法进行预测")
|
||||
return None, None
|
||||
|
||||
if not text or len(text.strip()) == 0:
|
||||
logger.warning("未提供文本或文本为空,无法进行预测")
|
||||
return None, None
|
||||
|
||||
try:
|
||||
# 调用LSTM模型管理器的单个文本预测函数
|
||||
prediction, probability = lstm_model_manager.predict(text)
|
||||
return prediction, probability
|
||||
except Exception as e:
|
||||
logger.error(f"预测过程中出错: {e}")
|
||||
return None, None
|
||||
|
||||
def train_model(self, train_texts, train_labels, val_texts=None, val_labels=None,
|
||||
batch_size=32, learning_rate=2e-5, epochs=10):
|
||||
"""
|
||||
训练模型
|
||||
|
||||
参数:
|
||||
train_texts: 训练集文本
|
||||
train_labels: 训练集标签
|
||||
val_texts: 验证集文本
|
||||
val_labels: 验证集标签
|
||||
batch_size: 批次大小
|
||||
learning_rate: 学习率
|
||||
epochs: 训练轮数
|
||||
|
||||
返回:
|
||||
训练结果
|
||||
"""
|
||||
try:
|
||||
results = lstm_model_manager.train(
|
||||
train_texts, train_labels, val_texts, val_labels,
|
||||
batch_size, learning_rate, epochs
|
||||
)
|
||||
self.model_loaded = True
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.error(f"训练模型过程中出错: {e}")
|
||||
return None
|
||||
|
||||
# 创建全局预测器实例
|
||||
lstm_predictor = LSTMPredictor()
|
||||
|
||||
# 为了与现有代码兼容,提供一个与model_manager相同的predict_batch函数
|
||||
def predict_batch(texts):
|
||||
return lstm_predictor.predict_batch(texts)
|
||||
|
||||
# 为了与现有代码兼容,提供一个与model_manager相同的load_models函数
|
||||
def load_models(model_save_path, bert_model_path, tokenizer_path=None):
|
||||
return lstm_predictor.load_models(model_save_path, bert_model_path, tokenizer_path)
|
||||
|
||||
# 测试代码
|
||||
if __name__ == "__main__":
|
||||
# 加载模型
|
||||
load_models(
|
||||
model_save_path="model_pro/lstm_model.pt",
|
||||
bert_model_path="model_pro/bert_model"
|
||||
)
|
||||
|
||||
# 测试预测功能
|
||||
test_sentences = [
|
||||
"这件事情做得非常好",
|
||||
"服务太差了,态度恶劣",
|
||||
"这个产品质量一般,但价格便宜",
|
||||
"我对这家公司非常满意",
|
||||
]
|
||||
|
||||
for sentence in test_sentences:
|
||||
pred, prob = lstm_predictor.predict(sentence)
|
||||
if pred is not None:
|
||||
label = '良好' if pred == 0 else '不良'
|
||||
confidence = prob[pred]
|
||||
print(f"句子: '{sentence}' 预测结果: {label} (置信度: {confidence:.2%})")
|
||||
else:
|
||||
print(f"句子: '{sentence}' 预测失败")
|
||||
Reference in New Issue
Block a user