# NOTE: page-scrape residue kept as a comment (not part of the module):
# Spaces: Build error
import torch.nn as nn | |
import torch | |
import torch.nn.functional as F | |
from util.GDANet_util import local_operator, GDM, SGCAM | |
class GDANET(nn.Module):
    """Classification network built from local operators, GDM and SGCAM.

    The forward pass runs three stages. Each stage applies ``local_operator``
    followed by two 1x1 ``Conv2d`` + ``BatchNorm2d`` + ReLU layers and a max
    over the last dimension. In stages 1 and 2 the result is split by the
    Geometry-Disentangle Module (``GDM``) and refined by two Sharp-Gentle
    Complementary Attention Modules (``SGCAM``); stage 3 uses a plain 1x1
    ``Conv1d`` instead. The three stage outputs are concatenated, pooled
    (max and average) and classified by a three-layer MLP.

    Submodule attribute names and creation order are kept stable so that
    ``state_dict`` keys and seeded initialisation are unchanged.

    Args:
        num_classes: Size of the final linear layer's output (default 40).
        k: Value forwarded as ``k`` to every ``local_operator`` call
            (default 30).
        gdm_m: Value forwarded as ``M`` to every ``GDM`` call (default 256).
        dropout: Drop probability of both dropout layers in the classifier
            head (default 0.4).
    """

    # Class-level fallbacks so instances that were pickled before these
    # attributes existed (and therefore never ran the new __init__) still
    # resolve them in forward().
    k = 30
    gdm_m = 256

    def __init__(self, num_classes=40, k=30, gdm_m=256, dropout=0.4):
        super(GDANET, self).__init__()

        self.k = k
        self.gdm_m = gdm_m

        # Batch-norm layers are also registered under their own attribute
        # names (in addition to living inside the Sequential wrappers below),
        # so both key paths appear in the state dict.
        self.bn1 = nn.BatchNorm2d(64, momentum=0.1)
        self.bn11 = nn.BatchNorm2d(64, momentum=0.1)
        self.bn12 = nn.BatchNorm1d(64, momentum=0.1)

        self.bn2 = nn.BatchNorm2d(64, momentum=0.1)
        self.bn21 = nn.BatchNorm2d(64, momentum=0.1)
        self.bn22 = nn.BatchNorm1d(64, momentum=0.1)

        self.bn3 = nn.BatchNorm2d(128, momentum=0.1)
        self.bn31 = nn.BatchNorm2d(128, momentum=0.1)
        self.bn32 = nn.BatchNorm1d(128, momentum=0.1)

        self.bn4 = nn.BatchNorm1d(512, momentum=0.1)

        # Stage 1. NOTE(review): the 6 / 67 * 2 / 131 * 2 input widths are
        # consistent with a 3-channel input whose channel count is doubled by
        # local_operator -- confirm against util.GDANet_util.
        self.conv1 = nn.Sequential(
            nn.Conv2d(6, 64, kernel_size=1, bias=True), self.bn1)
        self.conv11 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=1, bias=True), self.bn11)
        self.conv12 = nn.Sequential(
            nn.Conv1d(64 * 2, 64, kernel_size=1, bias=True), self.bn12)

        # Stage 2.
        self.conv2 = nn.Sequential(
            nn.Conv2d(67 * 2, 64, kernel_size=1, bias=True), self.bn2)
        self.conv21 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=1, bias=True), self.bn21)
        self.conv22 = nn.Sequential(
            nn.Conv1d(64 * 2, 64, kernel_size=1, bias=True), self.bn22)

        # Stage 3.
        self.conv3 = nn.Sequential(
            nn.Conv2d(131 * 2, 128, kernel_size=1, bias=True), self.bn3)
        self.conv31 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=1, bias=True), self.bn31)
        self.conv32 = nn.Sequential(
            nn.Conv1d(128, 128, kernel_size=1, bias=True), self.bn32)

        # Fuses the concatenated stage outputs (64 + 64 + 128 channels).
        self.conv4 = nn.Sequential(
            nn.Conv1d(256, 512, kernel_size=1, bias=True), self.bn4)

        # One attention module per GDM output ("s" / "g") and per stage.
        self.SGCAM_1s = SGCAM(64)
        self.SGCAM_1g = SGCAM(64)
        self.SGCAM_2s = SGCAM(64)
        self.SGCAM_2g = SGCAM(64)

        # Classifier head; its input is max-pool and avg-pool concatenated
        # (2 * 512 features).
        self.linear1 = nn.Linear(1024, 512, bias=True)
        self.bn6 = nn.BatchNorm1d(512)
        self.dp1 = nn.Dropout(p=dropout)
        self.linear2 = nn.Linear(512, 256, bias=True)
        self.bn7 = nn.BatchNorm1d(256)
        self.dp2 = nn.Dropout(p=dropout)
        self.linear3 = nn.Linear(256, num_classes, bias=True)

    def _local_features(self, points, conv_a, conv_b):
        """Apply local_operator, two conv+ReLU layers, and a max over the last dim."""
        feat = local_operator(points, k=self.k)
        feat = F.relu(conv_a(feat))
        feat = F.relu(conv_b(feat))
        # Tensor.max(dim=...) returns (values, indices); keep the values only.
        return feat.max(dim=-1, keepdim=False)[0]

    def _sharp_gentle_fuse(self, feat, sgcam_s, sgcam_g, fuse_conv):
        """Split ``feat`` with GDM, attend with both SGCAMs, then fuse with a conv."""
        # Geometry-Disentangle Module:
        feat_s, feat_g = GDM(feat, M=self.gdm_m)
        # Sharp-Gentle Complementary Attention Module:
        attended_s = sgcam_s(feat, feat_s.transpose(2, 1))
        attended_g = sgcam_g(feat, feat_g.transpose(2, 1))
        fused = torch.cat([attended_s, attended_g], 1)
        return F.relu(fuse_conv(fused))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class scores for a batch.

        Args:
            x: 3-D tensor; its first dimension is treated as the batch.
                Presumably laid out as (batch, 3, num_points) -- TODO confirm
                against the caller.

        Returns:
            Tensor of shape (batch, num_classes) holding the raw outputs of
            the last linear layer (no softmax is applied here).
        """
        # Unpacking three values keeps the implicit "x must be 3-D" check.
        batch_size, _, _ = x.size()

        # Block 1
        x1 = self._local_features(x, self.conv1, self.conv11)
        z1 = self._sharp_gentle_fuse(
            x1, self.SGCAM_1s, self.SGCAM_1g, self.conv12)

        # Block 2: the raw input is concatenated with block-1 features.
        x1t = torch.cat((x, z1), dim=1)
        x2 = self._local_features(x1t, self.conv2, self.conv21)
        z2 = self._sharp_gentle_fuse(
            x2, self.SGCAM_2s, self.SGCAM_2g, self.conv22)

        # Block 3: no GDM / SGCAM, just a 1x1 Conv1d on the pooled features.
        x2t = torch.cat((x1t, z2), dim=1)
        x3 = self._local_features(x2t, self.conv3, self.conv31)
        z3 = F.relu(self.conv32(x3))

        # Aggregate all three blocks and pool over the last dimension.
        fused = torch.cat((z1, z2, z3), dim=1)
        fused = F.relu(self.conv4(fused))
        pooled_max = F.adaptive_max_pool1d(fused, 1).view(batch_size, -1)
        pooled_avg = F.adaptive_avg_pool1d(fused, 1).view(batch_size, -1)
        out = torch.cat((pooled_max, pooled_avg), 1)

        # Classifier head.
        out = F.relu(self.bn6(self.linear1(out)))
        out = self.dp1(out)
        out = F.relu(self.bn7(self.linear2(out)))
        out = self.dp2(out)
        return self.linear3(out)