import torch
import os
import logging

from LSTM_model import lstm_model_manager

# Configure logging. Note: basicConfig runs at import time and configures the
# root logger, so it also affects any application that imports this module.
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger('lstm_predict')


def _is_empty_batch(texts):
    """Return True when *texts* is None or a sized container with no elements.

    ``not texts`` is deliberately avoided: it raises ValueError for numpy
    arrays and pandas Series ("truth value ... is ambiguous"). Unsized
    iterables (e.g. generators) are treated as non-empty, exactly like the
    plain truthiness check did, and are passed on to the model manager.
    """
    if texts is None:
        return True
    try:
        return len(texts) == 0
    except TypeError:
        return False


class LSTMPredictor:
    """LSTM predictor whose interface is compatible with the existing prediction system."""

    def __init__(self):
        # Set to True by load_models() when the manager holds a model, or
        # after train_model() completes without raising.
        self.model_loaded = False
        # Selected device is exposed as an attribute and logged; this class
        # does not itself move anything onto it.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"初始化LSTM预测器,使用设备: {self.device}")

    def load_models(self, model_save_path, bert_model_path, tokenizer_path=None):
        """Verify the models are available; mirrors ``model_manager.load_models``.

        The LSTM model is already loaded when ``lstm_model_manager`` is
        initialised; this method only checks that the given paths exist and
        that the manager actually holds a model.

        Args:
            model_save_path: Path of the saved LSTM model file.
            bert_model_path: Path of the BERT model.
            tokenizer_path: Tokenizer path. Unused - the LSTM model uses the
                BERT tokenizer; accepted only for interface compatibility.

        Returns:
            True if the model is available, False otherwise. Errors are
            logged, never raised.
        """
        try:
            if not os.path.exists(model_save_path):
                logger.warning(f"模型文件 {model_save_path} 不存在,需要先训练模型")
                return False
            if not os.path.exists(bert_model_path):
                logger.error(f"BERT模型路径 {bert_model_path} 不存在")
                return False

            # Reflect the manager's real state so a failed check does not
            # leave a stale True from an earlier load or training run.
            self.model_loaded = lstm_model_manager.model is not None
            if self.model_loaded:
                logger.info("LSTM模型已加载成功")
            else:
                logger.error("LSTM模型加载失败")
            return self.model_loaded
        except Exception as e:
            # Boundary handler: keep the "return False" contract, but log the
            # full traceback instead of only the exception message.
            logger.exception(f"加载模型过程中出错: {e}")
            return False

    def predict_batch(self, texts):
        """Predict the sentiment of a batch of texts.

        Args:
            texts: Sequence of texts (a list, or any sized sequence such as a
                numpy array / pandas Series).

        Returns:
            ``(predictions, probabilities)`` as returned by
            ``lstm_model_manager.predict_batch``: predicted labels
            (0 = good, 1 = bad) and their probabilities. ``(None, None)`` if
            no model is available, the batch is empty, or prediction fails.
        """
        if not self.model_loaded and lstm_model_manager.model is None:
            logger.error("模型未加载,无法进行预测")
            return None, None

        if _is_empty_batch(texts):
            logger.warning("未提供文本,无法进行预测")
            return None, None

        try:
            predictions, probabilities = lstm_model_manager.predict_batch(texts)
        except Exception as e:
            logger.exception(f"预测过程中出错: {e}")
            return None, None
        return predictions, probabilities

    def predict(self, text):
        """Predict the sentiment of a single text.

        Args:
            text: The text string to classify.

        Returns:
            ``(prediction, probability)`` as returned by
            ``lstm_model_manager.predict``: the predicted label
            (0 = good, 1 = bad) and its probability. ``(None, None)`` if no
            model is available, *text* is not a non-blank string, or
            prediction fails.
        """
        if not self.model_loaded and lstm_model_manager.model is None:
            logger.error("模型未加载,无法进行预测")
            return None, None

        # Reject non-strings here instead of crashing on .strip() below.
        if not isinstance(text, str) or not text.strip():
            logger.warning("未提供文本或文本为空,无法进行预测")
            return None, None

        try:
            prediction, probability = lstm_model_manager.predict(text)
        except Exception as e:
            logger.exception(f"预测过程中出错: {e}")
            return None, None
        return prediction, probability

    def train_model(self, train_texts, train_labels, val_texts=None, val_labels=None,
                    batch_size=32, learning_rate=2e-5, epochs=10):
        """Train the model via ``lstm_model_manager.train``.

        Args:
            train_texts: Training texts.
            train_labels: Training labels.
            val_texts: Validation texts (optional).
            val_labels: Validation labels (optional).
            batch_size: Batch size.
            learning_rate: Learning rate.
            epochs: Number of training epochs.

        Returns:
            Whatever ``lstm_model_manager.train`` returns, or None if training
            raised (the error is logged). On success ``model_loaded`` is set.
        """
        try:
            results = lstm_model_manager.train(
                train_texts, train_labels, val_texts, val_labels,
                batch_size, learning_rate, epochs,
            )
        except Exception as e:
            logger.exception(f"训练模型过程中出错: {e}")
            return None
        self.model_loaded = True
        return results


# Global predictor instance shared by the module-level compatibility wrappers.
lstm_predictor = LSTMPredictor()


def predict_batch(texts):
    """Compatibility wrapper with the same signature as ``model_manager.predict_batch``."""
    return lstm_predictor.predict_batch(texts)


def load_models(model_save_path, bert_model_path, tokenizer_path=None):
    """Compatibility wrapper with the same signature as ``model_manager.load_models``."""
    return lstm_predictor.load_models(model_save_path, bert_model_path, tokenizer_path)


# Smoke test
if __name__ == "__main__":
    # Load (verify) the models.
    load_models(
        model_save_path="model_pro/lstm_model.pt",
        bert_model_path="model_pro/bert_model",
    )

    # Exercise single-text prediction.
    test_sentences = [
        "这件事情做得非常好",
        "服务太差了,态度恶劣",
        "这个产品质量一般,但价格便宜",
        "我对这家公司非常满意",
    ]

    for sentence in test_sentences:
        pred, prob = lstm_predictor.predict(sentence)
        if pred is not None:
            label = '良好' if pred == 0 else '不良'
            # NOTE(review): assumes `prob` is indexable by class id (a
            # per-class probability vector) - confirm against
            # lstm_model_manager.predict.
            confidence = prob[pred]
            print(f"句子: '{sentence}' 预测结果: {label} (置信度: {confidence:.2%})")
        else:
            print(f"句子: '{sentence}' 预测失败")