diff --git a/model_pro/MHA.py b/model_pro/MHA.py index 9d95a21..3de7ff3 100644 --- a/model_pro/MHA.py +++ b/model_pro/MHA.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn +import torch.nn.functional as F class MultiHeadAttentionLayer(nn.Module): def __init__(self, embed_size, num_heads): @@ -19,16 +20,21 @@ class MultiHeadAttentionLayer(nn.Module): N = query.shape[0] # batch_size # Linear transformations for Q, K, V - Q = self.q_linear(query) # shape: (N, seq_len, embed_size) - K = self.k_linear(keys) # shape: (N, seq_len, embed_size) - V = self.v_linear(values) # shape: (N, seq_len, embed_size) + Q = self.q_linear(query) + K = self.k_linear(keys) + V = self.v_linear(values) - # Reshape Q, K, V into multiple heads + # Reshape into multiple heads Q = Q.reshape(N, -1, self.num_heads, self.head_dim) K = K.reshape(N, -1, self.num_heads, self.head_dim) V = V.reshape(N, -1, self.num_heads, self.head_dim) - return Q, K, V + # Compute scaled dot-product attention scores + attention_scores = torch.einsum("nqhd,nkhd->nhqk", [Q, K]) + attention_scores = attention_scores / (self.head_dim ** 0.5) + attention = torch.softmax(attention_scores, dim=-1) # Normalize + + return attention if __name__ == "__main__": @@ -36,10 +42,9 @@ if __name__ == "__main__": num_heads = 8 mha_layer = MultiHeadAttentionLayer(embed_size, num_heads) - # Dummy data values = torch.randn(2, 10, embed_size) keys = torch.randn(2, 10, embed_size) query = torch.randn(2, 10, embed_size) - Q, K, V = mha_layer(values, keys, query) - print(f"Q shape: {Q.shape}, K shape: {K.shape}, V shape: {V.shape}") + attention = mha_layer(values, keys, query) + print(f"Attention shape: {attention.shape}")