From ee739c3c811d699834c81cdd9e14ec80cf96cdc2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=88=92=E9=85=92=E7=9A=84=E6=9D=8E=E7=99=BD?= <670939375@qq.com> Date: Sat, 5 Oct 2024 00:49:24 +0800 Subject: [PATCH] Multi-head attention mechanism infrastructure and input dimension settings. --- model_pro/MHA.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 model_pro/MHA.py diff --git a/model_pro/MHA.py b/model_pro/MHA.py new file mode 100644 index 0000000..f1cb4bb --- /dev/null +++ b/model_pro/MHA.py @@ -0,0 +1,18 @@ +import torch +import torch.nn as nn + +class MultiHeadAttentionLayer(nn.Module): + def __init__(self, embed_size, num_heads): + super(MultiHeadAttentionLayer, self).__init__() + self.embed_size = embed_size + self.num_heads = num_heads + self.head_dim = embed_size // num_heads + + assert (self.head_dim * num_heads == embed_size), "Embedding size needs to be divisible by num_heads" + + +if __name__ == "__main__": + embed_size = 512 + num_heads = 8 + mha_layer = MultiHeadAttentionLayer(embed_size, num_heads) + print("Model initialized successfully.")