diff --git a/model_pro/BERT_CTM.py b/model_pro/BERT_CTM.py index 61128b0..dbc7dde 100644 --- a/model_pro/BERT_CTM.py +++ b/model_pro/BERT_CTM.py @@ -35,7 +35,7 @@ class BERT_CTM_Model: inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=80).to(self.device) with torch.no_grad(): outputs = self.model(**inputs) - embeddings.append(outputs.last_hidden_state.cpu().numpy()) # [batch_size, sequence_length, hidden_size] + embeddings.append(outputs.last_hidden_state[:, 0, :].cpu().numpy()) # [batch_size, hidden_size] return np.vstack(embeddings) def chinese_tokenize(self, text): @@ -57,6 +57,20 @@ class BERT_CTM_Model: except Exception as e: print(f"训练CTM模型时发生错误: {e}") + def predict(self, texts): + """使用训练好的CTM模型预测新文本的主题分布""" + if not self.ctm_model: + raise ValueError("模型尚未训练或加载,无法进行预测") + + try: + bow_texts = [self.chinese_tokenize(text) for text in texts] + testing_dataset = self.tp.transform(text_for_contextual=texts, text_for_bow=bow_texts) + topic_distributions = self.ctm_model.get_doc_topic_distribution(testing_dataset) + return topic_distributions + except Exception as e: + print(f"预测主题时发生错误: {e}") + return None + def save_model(self, path): """保存训练后的CTM模型""" if self.ctm_model: @@ -92,3 +106,12 @@ if __name__ == "__main__": # 加载CTM模型 model.load_model('./trained_ctm_model') + + # 预测新文本的主题分布 + new_texts = ["这是一个新的文本", "另外一个新文本"] + topic_distributions = model.predict(new_texts) + + # 输出预测结果 + if topic_distributions is not None: + for idx, distribution in enumerate(topic_distributions): + print(f"文本 {idx+1} 的主题分布: {distribution}")