The BERT_CTM module is finally completed
This commit is contained in:
+94
-97
@@ -1,117 +1,114 @@
|
|||||||
import os
|
import os
|
||||||
from transformers.models.bert import BertTokenizer, BertModel
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
import torch
|
import pandas as pd
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import numpy as np
|
from transformers.models.bert import BertTokenizer, BertModel
|
||||||
import jieba
|
|
||||||
from contextualized_topic_models.utils.data_preparation import TopicModelDataPreparation
|
|
||||||
from contextualized_topic_models.models.ctm import CombinedTM
|
from contextualized_topic_models.models.ctm import CombinedTM
|
||||||
|
from contextualized_topic_models.utils.data_preparation import TopicModelDataPreparation
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import jieba
|
||||||
|
import pickle # 用于保存和加载模型
|
||||||
|
|
||||||
class BERT_CTM_Model:
|
class BERT_CTM_Model:
|
||||||
def __init__(self, bert_model_path, ctm_tokenizer_path, n_components=12, num_epochs=50, device=None):
|
def __init__(self, bert_model_path, ctm_tokenizer_path, n_components=12, num_epochs=50, model_save_path='./ctm_model'):
|
||||||
# 确定设备 (CPU/GPU)
|
self.bert_model_path = bert_model_path
|
||||||
self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
|
self.ctm_tokenizer_path = ctm_tokenizer_path
|
||||||
|
|
||||||
# 检查模型路径是否存在
|
|
||||||
if not os.path.exists(bert_model_path):
|
|
||||||
raise ValueError(f"BERT模型路径不存在: {bert_model_path}")
|
|
||||||
if not os.path.exists(ctm_tokenizer_path):
|
|
||||||
raise ValueError(f"CTM分词器路径不存在: {ctm_tokenizer_path}")
|
|
||||||
|
|
||||||
# 加载BERT模型和tokenizer
|
|
||||||
self.tokenizer = BertTokenizer.from_pretrained(bert_model_path)
|
|
||||||
self.model = BertModel.from_pretrained(bert_model_path).to(self.device)
|
|
||||||
|
|
||||||
# 创建CTM数据预处理对象
|
|
||||||
self.tp = TopicModelDataPreparation(ctm_tokenizer_path)
|
|
||||||
self.n_components = n_components
|
self.n_components = n_components
|
||||||
self.num_epochs = num_epochs
|
self.num_epochs = num_epochs
|
||||||
self.ctm_model = None
|
self.model_save_path = model_save_path
|
||||||
|
# 加载BERT模型和tokenizer
|
||||||
|
self.tokenizer = BertTokenizer.from_pretrained(self.bert_model_path)
|
||||||
|
self.model = BertModel.from_pretrained(self.bert_model_path)
|
||||||
|
|
||||||
def get_bert_embeddings(self, texts):
|
# 创建CTM数据预处理对象
|
||||||
"""使用BERT模型批量生成文本的嵌入向量"""
|
self.tp = TopicModelDataPreparation(self.ctm_tokenizer_path)
|
||||||
embeddings = []
|
|
||||||
for text in tqdm(texts, desc="Processing texts with BERT"):
|
|
||||||
inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=80).to(self.device)
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs = self.model(**inputs)
|
|
||||||
embeddings.append(outputs.last_hidden_state[:, 0, :].cpu().numpy()) # [batch_size, hidden_size]
|
|
||||||
return np.vstack(embeddings)
|
|
||||||
|
|
||||||
def chinese_tokenize(self, text):
|
def chinese_tokenize(self, text):
|
||||||
"""使用jieba对中文文本进行分词"""
|
"""使用jieba对中文文本进行分词"""
|
||||||
return " ".join(jieba.cut(text))
|
return " ".join(jieba.cut(text))
|
||||||
|
|
||||||
|
def get_bert_embeddings(self, texts):
|
||||||
|
"""使用BERT模型生成文本的嵌入向量"""
|
||||||
|
embeddings = []
|
||||||
|
for text in tqdm(texts, desc="Processing texts with BERT"):
|
||||||
|
inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=80)
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = self.model(**inputs)
|
||||||
|
embeddings.append(outputs.last_hidden_state.cpu().numpy()) # [batch_size, sequence_length, hidden_size]
|
||||||
|
return np.vstack(embeddings)
|
||||||
|
|
||||||
def train_ctm(self, texts):
|
def save_model(self, ctm):
|
||||||
"""训练CTM模型"""
|
"""保存CTM模型、词袋和BoW的vectorizer"""
|
||||||
try:
|
os.makedirs(self.model_save_path, exist_ok=True)
|
||||||
# 分词并准备BOW文本
|
with open(f"{self.model_save_path}/ctm_model.pkl", 'wb') as f:
|
||||||
bow_texts = [self.chinese_tokenize(text) for text in texts]
|
pickle.dump(ctm, f)
|
||||||
training_dataset = self.tp.fit(text_for_contextual=texts, text_for_bow=bow_texts)
|
with open(f"{self.model_save_path}/vocab.pkl", 'wb') as f:
|
||||||
|
pickle.dump(self.tp.vocab, f)
|
||||||
|
with open(f"{self.model_save_path}/vectorizer.pkl", 'wb') as f: # 保存BoW的vectorizer
|
||||||
|
pickle.dump(self.tp.vectorizer, f)
|
||||||
|
print(f"CTM模型和词袋保存到: {self.model_save_path}")
|
||||||
|
|
||||||
# 训练CTM
|
def load_model(self):
|
||||||
self.ctm_model = CombinedTM(bow_size=len(self.tp.vocab), contextual_size=768,
|
"""加载CTM模型、词袋和BoW的vectorizer"""
|
||||||
n_components=self.n_components, num_epochs=self.num_epochs)
|
with open(f"{self.model_save_path}/ctm_model.pkl", 'rb') as f:
|
||||||
self.ctm_model.fit(training_dataset)
|
ctm = pickle.load(f)
|
||||||
print("CTM模型训练完成")
|
with open(f"{self.model_save_path}/vocab.pkl", 'rb') as f:
|
||||||
except Exception as e:
|
self.tp.vocab = pickle.load(f)
|
||||||
print(f"训练CTM模型时发生错误: {e}")
|
with open(f"{self.model_save_path}/vectorizer.pkl", 'rb') as f: # 加载BoW的vectorizer
|
||||||
|
self.tp.vectorizer = pickle.load(f)
|
||||||
|
print(f"CTM模型、词袋和vectorizer加载成功")
|
||||||
|
return ctm
|
||||||
|
|
||||||
def predict(self, texts):
|
def train(self, csv_file):
|
||||||
"""使用训练好的CTM模型预测新文本的主题分布"""
|
"""训练BERT + CTM模型并保存最终的特征向量和标签"""
|
||||||
if not self.ctm_model:
|
# 读取CSV文件中的文本和标签
|
||||||
raise ValueError("模型尚未训练或加载,无法进行预测")
|
data = pd.read_csv(csv_file)
|
||||||
|
texts = data['TEXT'].tolist()
|
||||||
try:
|
labels = data['label'].tolist()
|
||||||
bow_texts = [self.chinese_tokenize(text) for text in texts]
|
|
||||||
testing_dataset = self.tp.transform(text_for_contextual=texts, text_for_bow=bow_texts)
|
|
||||||
topic_distributions = self.ctm_model.get_doc_topic_distribution(testing_dataset)
|
|
||||||
return topic_distributions
|
|
||||||
except Exception as e:
|
|
||||||
print(f"预测主题时发生错误: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def save_model(self, path):
|
# Step 1: 获取BERT的嵌入向量
|
||||||
"""保存训练后的CTM模型"""
|
print("Extracting BERT embeddings...")
|
||||||
if self.ctm_model:
|
bert_embeddings = self.get_bert_embeddings(texts) # [batch_size, sequence_length, hidden_size]
|
||||||
self.ctm_model.save(path)
|
|
||||||
print(f"CTM模型已保存至: {path}")
|
|
||||||
else:
|
|
||||||
print("未找到已训练的CTM模型,无法保存")
|
|
||||||
|
|
||||||
def load_model(self, path):
|
# Step 2: 准备CTM数据
|
||||||
"""加载已保存的CTM模型"""
|
print("Preparing data for CTM using training set...")
|
||||||
if os.path.exists(path):
|
bow_texts = [self.chinese_tokenize(text) for text in texts]
|
||||||
self.ctm_model = CombinedTM.load(path)
|
training_dataset = self.tp.fit(text_for_contextual=texts, text_for_bow=bow_texts)
|
||||||
print(f"CTM模型已加载自: {path}")
|
|
||||||
else:
|
# Step 3: 替换BERT嵌入
|
||||||
print(f"无法加载模型,路径不存在: {path}")
|
training_dataset._X = bert_embeddings[:, 0, :] # 只使用第一个token的向量用于CTM
|
||||||
|
|
||||||
|
# Step 4: 训练CTM模型
|
||||||
|
print("Training CTM model...")
|
||||||
|
ctm = CombinedTM(bow_size=len(self.tp.vocab), contextual_size=768, n_components=self.n_components, num_epochs=self.num_epochs)
|
||||||
|
ctm.fit(train_dataset=training_dataset, verbose=True)
|
||||||
|
|
||||||
|
# Step 5: 保存CTM模型和词袋
|
||||||
|
self.save_model(ctm)
|
||||||
|
|
||||||
|
# Step 6: 获取CTM的特征向量
|
||||||
|
print("Generating CTM features...")
|
||||||
|
ctm_features = ctm.get_doc_topic_distribution(training_dataset) # [batch_size, n_components]
|
||||||
|
|
||||||
|
# Step 7: 将CTM特征扩展为与BERT的sequence长度一致
|
||||||
|
sequence_length = bert_embeddings.shape[1]
|
||||||
|
ctm_features_expanded = np.repeat(ctm_features[:, np.newaxis, :], sequence_length, axis=1) # [batch_size, sequence_length, n_components]
|
||||||
|
|
||||||
|
# Step 8: 拼接BERT嵌入和CTM特征
|
||||||
|
final_embeddings = np.concatenate([bert_embeddings, ctm_features_expanded], axis=-1) # [batch_size, sequence_length, hidden_size + n_components]
|
||||||
|
|
||||||
|
return bert_embeddings
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 设定BERT和CTM模型的路径
|
# 创建BERT_CTM_Model实例
|
||||||
bert_model_path = './bert_model'
|
model = BERT_CTM_Model(
|
||||||
ctm_tokenizer_path = './sentence_bert_model'
|
bert_model_path='./bert_model', # BERT模型的路径
|
||||||
|
ctm_tokenizer_path='./sentence_bert_model', # CTM分词器的路径
|
||||||
# 初始化模型
|
n_components=12, # 主题数量
|
||||||
model = BERT_CTM_Model(bert_model_path, ctm_tokenizer_path)
|
num_epochs=50, # 训练轮次
|
||||||
|
model_save_path='./ctm_model', # 保存路径
|
||||||
|
)
|
||||||
|
|
||||||
# 示例文本
|
# 传入CSV文件路径进行训练
|
||||||
texts = ["这是第一个文本", "这是第二个文本"]
|
model.train("./train.csv")
|
||||||
|
|
||||||
# 训练CTM模型
|
|
||||||
model.train_ctm(texts)
|
|
||||||
|
|
||||||
# 保存CTM模型
|
|
||||||
model.save_model('./trained_ctm_model')
|
|
||||||
|
|
||||||
# 加载CTM模型
|
|
||||||
model.load_model('./trained_ctm_model')
|
|
||||||
|
|
||||||
# 预测新文本的主题分布
|
|
||||||
new_texts = ["这是一个新的文本", "另外一个新文本"]
|
|
||||||
topic_distributions = model.predict(new_texts)
|
|
||||||
|
|
||||||
# 输出预测结果
|
|
||||||
if topic_distributions is not None:
|
|
||||||
for idx, distribution in enumerate(topic_distributions):
|
|
||||||
print(f"文本 {idx+1} 的主题分布: {distribution}")
|
|
||||||
|
|||||||
Reference in New Issue
Block a user