Fix: Provide a seed for the random_state parameter.
This commit is contained in:
+16
-3
@@ -219,7 +219,15 @@ class LSTMModelManager:
|
|||||||
|
|
||||||
def __init__(self, bert_model_path, model_save_path=None, vocab_size=30522,
|
def __init__(self, bert_model_path, model_save_path=None, vocab_size=30522,
|
||||||
embedding_dim=100, hidden_dim=64, output_dim=2, n_layers=1,
|
embedding_dim=100, hidden_dim=64, output_dim=2, n_layers=1,
|
||||||
bidirectional=True, dropout=0.3, word2vec_path=None):
|
bidirectional=True, dropout=0.3, word2vec_path=None, random_seed=42):
|
||||||
|
# 设置随机种子以确保可重现性
|
||||||
|
self.random_seed = random_seed
|
||||||
|
random.seed(random_seed)
|
||||||
|
np.random.seed(random_seed)
|
||||||
|
torch.manual_seed(random_seed)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed_all(random_seed)
|
||||||
|
|
||||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
self.tokenizer = BertTokenizer.from_pretrained(bert_model_path)
|
self.tokenizer = BertTokenizer.from_pretrained(bert_model_path)
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
@@ -305,13 +313,18 @@ class LSTMModelManager:
|
|||||||
|
|
||||||
if val_texts is None:
|
if val_texts is None:
|
||||||
X_train, X_val, y_train, y_val = train_test_split(
|
X_train, X_val, y_train, y_val = train_test_split(
|
||||||
X_train, train_labels, test_size=0.2, stratify=train_labels
|
X_train, train_labels, test_size=0.2,
|
||||||
|
stratify=train_labels,
|
||||||
|
random_state=self.random_seed # 添加随机种子
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
X_val = vectorizer.transform(val_texts)
|
X_val = vectorizer.transform(val_texts)
|
||||||
y_train, y_val = train_labels, val_labels
|
y_train, y_val = train_labels, val_labels
|
||||||
|
|
||||||
lr_model = LogisticRegression(class_weight='balanced')
|
lr_model = LogisticRegression(
|
||||||
|
class_weight='balanced',
|
||||||
|
random_state=self.random_seed # 添加随机种子
|
||||||
|
)
|
||||||
lr_model.fit(X_train, y_train)
|
lr_model.fit(X_train, y_train)
|
||||||
|
|
||||||
val_pred = lr_model.predict(X_val)
|
val_pred = lr_model.predict(X_val)
|
||||||
|
|||||||
Reference in New Issue
Block a user