From af5e2265eec1dd25c9b2700961df42ee1cfce782 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=88=92=E9=85=92=E7=9A=84=E6=9D=8E=E7=99=BD?= <670939375@qq.com> Date: Tue, 15 Oct 2024 08:13:08 +0800 Subject: [PATCH] The final classification layer is complete --- model_pro/classifier.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 model_pro/classifier.py diff --git a/model_pro/classifier.py b/model_pro/classifier.py new file mode 100644 index 0000000..5215811 --- /dev/null +++ b/model_pro/classifier.py @@ -0,0 +1,17 @@ +import torch +import torch.nn as nn + +class FinalClassifier(nn.Module): + def __init__(self, input_dim, num_classes, hidden_dim=512, dropout_rate=0.3): + super(FinalClassifier, self).__init__() + # 增加一个隐藏层 + self.fc1 = nn.Linear(input_dim, hidden_dim) # 第一层全连接层 + self.fc2 = nn.Linear(hidden_dim, num_classes) # 第二层全连接层 + self.dropout = nn.Dropout(dropout_rate) # Dropout 防止过拟合 + self.relu = nn.ReLU() # 激活函数 + + def forward(self, x): + x = self.relu(self.fc1(x)) # 第一层全连接 + ReLU 激活 + x = self.dropout(x) # Dropout + out = self.fc2(x) # 最终输出层(未应用 softmax) + return out