From 4500b2719e5477ab6424de205665a23d15d12b50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=88=92=E9=85=92=E7=9A=84=E6=9D=8E=E7=99=BD?= <670939375@qq.com> Date: Sun, 6 Oct 2024 11:54:32 +0800 Subject: [PATCH] Divide the input into long heads --- model_pro/MHA.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/model_pro/MHA.py b/model_pro/MHA.py index d4c4e68..9d95a21 100644 --- a/model_pro/MHA.py +++ b/model_pro/MHA.py @@ -15,9 +15,31 @@ class MultiHeadAttentionLayer(nn.Module): self.k_linear = nn.Linear(embed_size, embed_size) self.v_linear = nn.Linear(embed_size, embed_size) + def forward(self, values, keys, query): + N = query.shape[0] # batch_size + + # Linear transformations for Q, K, V + Q = self.q_linear(query) # shape: (N, seq_len, embed_size) + K = self.k_linear(keys) # shape: (N, seq_len, embed_size) + V = self.v_linear(values) # shape: (N, seq_len, embed_size) + + # Reshape Q, K, V into multiple heads + Q = Q.reshape(N, -1, self.num_heads, self.head_dim) + K = K.reshape(N, -1, self.num_heads, self.head_dim) + V = V.reshape(N, -1, self.num_heads, self.head_dim) + + return Q, K, V + if __name__ == "__main__": embed_size = 512 num_heads = 8 mha_layer = MultiHeadAttentionLayer(embed_size, num_heads) - print("Linear layers for Q, K, V initialized.") + + # Dummy data + values = torch.randn(2, 10, embed_size) + keys = torch.randn(2, 10, embed_size) + query = torch.randn(2, 10, embed_size) + + Q, K, V = mha_layer(values, keys, query) + print(f"Q shape: {Q.shape}, K shape: {K.shape}, V shape: {V.shape}")