From f5e307d3f80999cb047c5d46ed2833dc4e688df7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=88=92=E9=85=92=E7=9A=84=E6=9D=8E=E7=99=BD?= <670939375@qq.com> Date: Sun, 6 Oct 2024 11:34:31 +0800 Subject: [PATCH] Define the linear transformation layer --- model_pro/MHA.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/model_pro/MHA.py b/model_pro/MHA.py index f1cb4bb..d4c4e68 100644 --- a/model_pro/MHA.py +++ b/model_pro/MHA.py @@ -9,10 +9,15 @@ class MultiHeadAttentionLayer(nn.Module): self.head_dim = embed_size // num_heads assert (self.head_dim * num_heads == embed_size), "Embedding size needs to be divisible by num_heads" + + # Define linear layers for Q, K, V + self.q_linear = nn.Linear(embed_size, embed_size) + self.k_linear = nn.Linear(embed_size, embed_size) + self.v_linear = nn.Linear(embed_size, embed_size) if __name__ == "__main__": embed_size = 512 num_heads = 8 mha_layer = MultiHeadAttentionLayer(embed_size, num_heads) - print("Model initialized successfully.") + print("Linear layers for Q, K, V initialized.")