Added a logging utility class and supplemented, standardized the logging output for all modules.
This commit is contained in:
+111
-94
@@ -5,109 +5,126 @@ from tqdm import tqdm
|
||||
from transformers.models.bert import BertTokenizer, BertModel
|
||||
from contextualized_topic_models.models.ctm import CombinedTM
|
||||
from contextualized_topic_models.utils.data_preparation import TopicModelDataPreparation
|
||||
from contextualized_topic_models.utils.preprocessing import WhiteSpacePreprocessing
|
||||
import numpy as np
|
||||
import torch
|
||||
import jieba
|
||||
import pickle # 用于保存和加载模型
|
||||
from utils.logger import model_logger as logging
|
||||
|
||||
class BERT_CTM_Model:
|
||||
def __init__(self, bert_model_path, ctm_tokenizer_path, n_components=12, num_epochs=50, model_save_path='./ctm_model'):
|
||||
self.bert_model_path = bert_model_path
|
||||
self.ctm_tokenizer_path = ctm_tokenizer_path
|
||||
self.n_components = n_components
|
||||
self.num_epochs = num_epochs
|
||||
class BERT_CTM:
|
||||
def __init__(self, model_save_path='model_pro/saved_models/ctm_model.pkl'):
|
||||
self.model_save_path = model_save_path
|
||||
# 加载BERT模型和tokenizer
|
||||
self.tokenizer = BertTokenizer.from_pretrained(self.bert_model_path)
|
||||
self.model = BertModel.from_pretrained(self.bert_model_path)
|
||||
|
||||
# 创建CTM数据预处理对象
|
||||
self.tp = TopicModelDataPreparation(self.ctm_tokenizer_path)
|
||||
|
||||
def chinese_tokenize(self, text):
|
||||
"""使用jieba对中文文本进行分词"""
|
||||
return " ".join(jieba.cut(text))
|
||||
|
||||
def get_bert_embeddings(self, texts):
|
||||
"""使用BERT模型生成文本的嵌入向量"""
|
||||
embeddings = []
|
||||
for text in tqdm(texts, desc="Processing texts with BERT"):
|
||||
inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=80)
|
||||
with torch.no_grad():
|
||||
outputs = self.model(**inputs)
|
||||
embeddings.append(outputs.last_hidden_state.cpu().numpy()) # [batch_size, sequence_length, hidden_size]
|
||||
return np.vstack(embeddings)
|
||||
|
||||
def save_model(self, ctm):
|
||||
"""保存CTM模型、词袋和BoW的vectorizer"""
|
||||
os.makedirs(self.model_save_path, exist_ok=True)
|
||||
with open(f"{self.model_save_path}/ctm_model.pkl", 'wb') as f:
|
||||
pickle.dump(ctm, f)
|
||||
with open(f"{self.model_save_path}/vocab.pkl", 'wb') as f:
|
||||
pickle.dump(self.tp.vocab, f)
|
||||
with open(f"{self.model_save_path}/vectorizer.pkl", 'wb') as f: # 保存BoW的vectorizer
|
||||
pickle.dump(self.tp.vectorizer, f)
|
||||
print(f"CTM模型和词袋保存到: {self.model_save_path}")
|
||||
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
self.bert_model = None
|
||||
self.tokenizer = None
|
||||
self.ctm_model = None
|
||||
self.vocab = None
|
||||
self.vectorizer = None
|
||||
|
||||
def save_model(self):
|
||||
"""保存模型和词袋"""
|
||||
try:
|
||||
with open(self.model_save_path, 'wb') as f:
|
||||
pickle.dump({
|
||||
'ctm_model': self.ctm_model,
|
||||
'vocab': self.vocab,
|
||||
'vectorizer': self.vectorizer
|
||||
}, f)
|
||||
logging.info(f"CTM模型和词袋保存到: {self.model_save_path}")
|
||||
except Exception as e:
|
||||
logging.error(f"保存模型时发生错误: {e}")
|
||||
|
||||
def load_model(self):
|
||||
"""加载CTM模型、词袋和BoW的vectorizer"""
|
||||
with open(f"{self.model_save_path}/ctm_model.pkl", 'rb') as f:
|
||||
ctm = pickle.load(f)
|
||||
with open(f"{self.model_save_path}/vocab.pkl", 'rb') as f:
|
||||
self.tp.vocab = pickle.load(f)
|
||||
with open(f"{self.model_save_path}/vectorizer.pkl", 'rb') as f: # 加载BoW的vectorizer
|
||||
self.tp.vectorizer = pickle.load(f)
|
||||
print(f"CTM模型、词袋和vectorizer加载成功")
|
||||
return ctm
|
||||
|
||||
def train(self, csv_file):
|
||||
"""训练BERT + CTM模型并保存最终的特征向量和标签"""
|
||||
# 读取CSV文件中的文本和标签
|
||||
data = pd.read_csv(csv_file)
|
||||
texts = data['TEXT'].tolist()
|
||||
labels = data['label'].tolist()
|
||||
|
||||
# Step 1: 获取BERT的嵌入向量
|
||||
print("Extracting BERT embeddings...")
|
||||
bert_embeddings = self.get_bert_embeddings(texts) # [batch_size, sequence_length, hidden_size]
|
||||
|
||||
# Step 2: 准备CTM数据
|
||||
print("Preparing data for CTM using training set...")
|
||||
bow_texts = [self.chinese_tokenize(text) for text in texts]
|
||||
training_dataset = self.tp.fit(text_for_contextual=texts, text_for_bow=bow_texts)
|
||||
|
||||
# Step 3: 替换BERT嵌入
|
||||
training_dataset._X = bert_embeddings[:, 0, :] # 只使用第一个token的向量用于CTM
|
||||
|
||||
# Step 4: 训练CTM模型
|
||||
print("Training CTM model...")
|
||||
ctm = CombinedTM(bow_size=len(self.tp.vocab), contextual_size=768, n_components=self.n_components, num_epochs=self.num_epochs)
|
||||
ctm.fit(train_dataset=training_dataset, verbose=True)
|
||||
|
||||
# Step 5: 保存CTM模型和词袋
|
||||
self.save_model(ctm)
|
||||
|
||||
# Step 6: 获取CTM的特征向量
|
||||
print("Generating CTM features...")
|
||||
ctm_features = ctm.get_doc_topic_distribution(training_dataset) # [batch_size, n_components]
|
||||
|
||||
# Step 7: 将CTM特征扩展为与BERT的sequence长度一致
|
||||
sequence_length = bert_embeddings.shape[1]
|
||||
ctm_features_expanded = np.repeat(ctm_features[:, np.newaxis, :], sequence_length, axis=1) # [batch_size, sequence_length, n_components]
|
||||
|
||||
# Step 8: 拼接BERT嵌入和CTM特征
|
||||
final_embeddings = np.concatenate([bert_embeddings, ctm_features_expanded], axis=-1) # [batch_size, sequence_length, hidden_size + n_components]
|
||||
|
||||
return bert_embeddings
|
||||
"""加载模型和词袋"""
|
||||
try:
|
||||
with open(self.model_save_path, 'rb') as f:
|
||||
saved_data = pickle.load(f)
|
||||
self.ctm_model = saved_data['ctm_model']
|
||||
self.vocab = saved_data['vocab']
|
||||
self.vectorizer = saved_data['vectorizer']
|
||||
logging.info("CTM模型、词袋和vectorizer加载成功")
|
||||
except Exception as e:
|
||||
logging.error(f"加载模型时发生错误: {e}")
|
||||
raise
|
||||
|
||||
def train(self, texts, num_topics=10, num_epochs=100):
|
||||
"""训练CTM模型"""
|
||||
try:
|
||||
# 初始化BERT
|
||||
if not self.bert_model:
|
||||
self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
|
||||
self.bert_model = BertModel.from_pretrained('bert-base-chinese').to(self.device)
|
||||
|
||||
# 提取BERT嵌入
|
||||
logging.info("正在提取BERT嵌入...")
|
||||
embeddings = self._get_bert_embeddings(texts)
|
||||
|
||||
# 准备CTM数据
|
||||
logging.info("正在准备CTM训练数据...")
|
||||
preprocessor = WhiteSpacePreprocessing(texts)
|
||||
dataset = TopicModelDataPreparation(embeddings)
|
||||
|
||||
# 训练CTM模型
|
||||
logging.info("正在训练CTM模型...")
|
||||
self.ctm_model = CombinedTM(
|
||||
bow_size=len(preprocessor.vocab),
|
||||
contextual_size=768, # BERT输出维度
|
||||
n_components=num_topics,
|
||||
num_epochs=num_epochs
|
||||
)
|
||||
self.ctm_model.fit(dataset)
|
||||
|
||||
# 保存词袋相关数据
|
||||
self.vocab = preprocessor.vocab
|
||||
self.vectorizer = preprocessor.vectorizer
|
||||
|
||||
# 保存模型
|
||||
self.save_model()
|
||||
logging.info("模型训练完成并保存")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"训练模型时发生错误: {e}")
|
||||
raise
|
||||
|
||||
def _get_bert_embeddings(self, texts):
|
||||
"""获取文本的BERT嵌入"""
|
||||
embeddings = []
|
||||
try:
|
||||
for text in texts:
|
||||
inputs = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)
|
||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = self.bert_model(**inputs)
|
||||
# 使用[CLS]标记的输出作为文档表示
|
||||
embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy()
|
||||
embeddings.append(embedding[0])
|
||||
|
||||
return np.array(embeddings)
|
||||
except Exception as e:
|
||||
logging.error(f"获取BERT嵌入时发生错误: {e}")
|
||||
raise
|
||||
|
||||
def get_topics(self, num_words=10):
|
||||
"""获取主题词"""
|
||||
try:
|
||||
if not self.ctm_model or not self.vocab:
|
||||
raise ValueError("模型未训练或未加载")
|
||||
|
||||
topics = []
|
||||
for topic_idx in range(self.ctm_model.n_components):
|
||||
topic = self.ctm_model.get_topic_lists(top_n=num_words)[topic_idx]
|
||||
topics.append(topic)
|
||||
return topics
|
||||
except Exception as e:
|
||||
logging.error(f"获取主题词时发生错误: {e}")
|
||||
raise
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 创建BERT_CTM_Model实例
|
||||
model = BERT_CTM_Model(
|
||||
bert_model_path='./bert_model', # BERT模型的路径
|
||||
ctm_tokenizer_path='./sentence_bert_model', # CTM分词器的路径
|
||||
n_components=12, # 主题数量
|
||||
num_epochs=50, # 训练轮次
|
||||
model_save_path='./ctm_model', # 保存路径
|
||||
# 创建BERT_CTM实例
|
||||
model = BERT_CTM(
|
||||
model_save_path='model_pro/saved_models/ctm_model.pkl', # 保存路径
|
||||
)
|
||||
|
||||
# 传入CSV文件路径进行训练
|
||||
|
||||
Reference in New Issue
Block a user