Define the linear transformation layer
This commit is contained in:
+6
-1
@@ -9,10 +9,15 @@ class MultiHeadAttentionLayer(nn.Module):
|
|||||||
self.head_dim = embed_size // num_heads
|
self.head_dim = embed_size // num_heads
|
||||||
|
|
||||||
assert (self.head_dim * num_heads == embed_size), "Embedding size needs to be divisible by num_heads"
|
assert (self.head_dim * num_heads == embed_size), "Embedding size needs to be divisible by num_heads"
|
||||||
|
|
||||||
|
# Define linear layers for Q, K, V
|
||||||
|
self.q_linear = nn.Linear(embed_size, embed_size)
|
||||||
|
self.k_linear = nn.Linear(embed_size, embed_size)
|
||||||
|
self.v_linear = nn.Linear(embed_size, embed_size)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
embed_size = 512
|
embed_size = 512
|
||||||
num_heads = 8
|
num_heads = 8
|
||||||
mha_layer = MultiHeadAttentionLayer(embed_size, num_heads)
|
mha_layer = MultiHeadAttentionLayer(embed_size, num_heads)
|
||||||
print("Model initialized successfully.")
|
print("Linear layers for Q, K, V initialized.")
|
||||||
|
|||||||
Reference in New Issue
Block a user