Added inference function for the model
This commit is contained in:
+24
-1
@@ -35,7 +35,7 @@ class BERT_CTM_Model:
|
|||||||
inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=80).to(self.device)
|
inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=80).to(self.device)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = self.model(**inputs)
|
outputs = self.model(**inputs)
|
||||||
embeddings.append(outputs.last_hidden_state.cpu().numpy()) # [batch_size, sequence_length, hidden_size]
|
embeddings.append(outputs.last_hidden_state[:, 0, :].cpu().numpy()) # [batch_size, hidden_size]
|
||||||
return np.vstack(embeddings)
|
return np.vstack(embeddings)
|
||||||
|
|
||||||
def chinese_tokenize(self, text):
|
def chinese_tokenize(self, text):
|
||||||
@@ -57,6 +57,20 @@ class BERT_CTM_Model:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"训练CTM模型时发生错误: {e}")
|
print(f"训练CTM模型时发生错误: {e}")
|
||||||
|
|
||||||
|
def predict(self, texts):
|
||||||
|
"""使用训练好的CTM模型预测新文本的主题分布"""
|
||||||
|
if not self.ctm_model:
|
||||||
|
raise ValueError("模型尚未训练或加载,无法进行预测")
|
||||||
|
|
||||||
|
try:
|
||||||
|
bow_texts = [self.chinese_tokenize(text) for text in texts]
|
||||||
|
testing_dataset = self.tp.transform(text_for_contextual=texts, text_for_bow=bow_texts)
|
||||||
|
topic_distributions = self.ctm_model.get_doc_topic_distribution(testing_dataset)
|
||||||
|
return topic_distributions
|
||||||
|
except Exception as e:
|
||||||
|
print(f"预测主题时发生错误: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
def save_model(self, path):
|
def save_model(self, path):
|
||||||
"""保存训练后的CTM模型"""
|
"""保存训练后的CTM模型"""
|
||||||
if self.ctm_model:
|
if self.ctm_model:
|
||||||
@@ -92,3 +106,12 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# 加载CTM模型
|
# 加载CTM模型
|
||||||
model.load_model('./trained_ctm_model')
|
model.load_model('./trained_ctm_model')
|
||||||
|
|
||||||
|
# 预测新文本的主题分布
|
||||||
|
new_texts = ["这是一个新的文本", "另外一个新文本"]
|
||||||
|
topic_distributions = model.predict(new_texts)
|
||||||
|
|
||||||
|
# 输出预测结果
|
||||||
|
if topic_distributions is not None:
|
||||||
|
for idx, distribution in enumerate(topic_distributions):
|
||||||
|
print(f"文本 {idx+1} 的主题分布: {distribution}")
|
||||||
|
|||||||
Reference in New Issue
Block a user