|
|
import torch |
|
|
|
|
|
class CNN2D(torch.nn.Module):
    """Plain 2-D convolutional classifier.

    The network is a sequence of convolutional stages. Each stage stacks one
    ``Conv2d -> BatchNorm2d -> ReLU`` triple per entry of ``channels[i]`` and
    may be followed by a 2x2 max-pool. The final feature map is reduced with
    global average pooling and fed to a linear classification layer.

    Args:
        channels: One sequence of output-channel counts per stage, e.g.
            ``[[64], [128, 128]]``.
        conv_kernels: Kernel size used by every convolution of stage ``i``.
        conv_strides: Stride used by every convolution of stage ``i``.
        conv_padding: Padding used by every convolution of stage ``i``.
        pool_padding: Padding of the max-pool applied after stage ``i``. It may
            be shorter than ``channels``; stages without an entry are not
            pooled.
        num_classes: Size of the output layer.

    Raises:
        ValueError: If ``channels``, ``conv_kernels``, ``conv_strides`` and
            ``conv_padding`` do not all have the same length.
    """

    def __init__(self, channels, conv_kernels, conv_strides, conv_padding, pool_padding, num_classes=15):
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        if not (len(conv_kernels) == len(channels) == len(conv_strides) == len(conv_padding)):
            raise ValueError(
                "channels, conv_kernels, conv_strides and conv_padding must have the same length"
            )
        super().__init__()

        # Convolutional stages; the very first convolution reads one input channel.
        self.conv_blocks = torch.nn.ModuleList()
        prev_channel = 1
        for stage_channels, kernel, stride, padding in zip(
            channels, conv_kernels, conv_strides, conv_padding
        ):
            block = []
            for conv_channel in stage_channels:
                block.append(
                    torch.nn.Conv2d(
                        in_channels=prev_channel,
                        out_channels=conv_channel,
                        kernel_size=kernel,
                        stride=stride,
                        padding=padding,
                    )
                )
                block.append(torch.nn.BatchNorm2d(conv_channel))
                block.append(torch.nn.ReLU())
                prev_channel = conv_channel
            self.conv_blocks.append(torch.nn.Sequential(*block))

        # One 2x2 max-pool per entry of ``pool_padding``.
        self.pool_blocks = torch.nn.ModuleList(
            torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=padding)
            for padding in pool_padding
        )

        # Classification head on top of the globally averaged feature map.
        self.global_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.linear = torch.nn.Linear(prev_channel, num_classes)

    def forward(self, inwav):
        """Compute class scores for a batch of single-channel 2-D inputs.

        Args:
            inwav: Tensor of shape ``(batch, 1, height, width)``.

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        num_pools = len(self.pool_blocks)
        for index, conv_block in enumerate(self.conv_blocks):
            inwav = conv_block(inwav)
            # Pooling is optional for the trailing stages.
            if index < num_pools:
                inwav = self.pool_blocks[index](inwav)

        # Flatten from dim 1 so the batch dimension survives even when it is 1
        # (a bare ``squeeze()`` would drop it as well).
        out = torch.flatten(self.global_pool(inwav), 1)
        return self.linear(out)
|
|
|
|
|
class ResBlock2D(torch.nn.Module):
    """Residual block: two conv/batch-norm layers plus a parameter-free shortcut.

    The residual branch is ``Conv2d -> BatchNorm2d -> ReLU -> Conv2d ->
    BatchNorm2d``. Its output is added to a shortcut built from the block
    input, then passed through a further ``BatchNorm2d`` and ``ReLU``. The
    block always outputs ``channel`` channels.

    The shortcut has no parameters, so the channel counts are matched as
    follows:

    * equal channels: the input is used as is;
    * more output channels: the input is tiled along the channel axis;
    * fewer output channels: groups of input channels are summed (the inverse
      layout of the tiling above).

    The convolution settings must preserve the spatial size, otherwise the
    shortcut cannot be added to the residual branch.

    Args:
        prev_channel: Number of channels of the block input.
        channel: Number of channels produced by the block.
        conv_kernel: Kernel size of both convolutions.
        conv_stride: Stride of both convolutions.
        conv_pad: Padding of both convolutions.
    """

    def __init__(self, prev_channel, channel, conv_kernel, conv_stride, conv_pad):
        super().__init__()
        self.res = torch.nn.Sequential(
            torch.nn.Conv2d(
                in_channels=prev_channel,
                out_channels=channel,
                kernel_size=conv_kernel,
                stride=conv_stride,
                padding=conv_pad,
            ),
            torch.nn.BatchNorm2d(channel),
            torch.nn.ReLU(),
            torch.nn.Conv2d(
                in_channels=channel,
                out_channels=channel,
                kernel_size=conv_kernel,
                stride=conv_stride,
                padding=conv_pad,
            ),
            torch.nn.BatchNorm2d(channel),
        )
        self.bn = torch.nn.BatchNorm2d(channel)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Apply the residual block.

        Args:
            x: Tensor of shape ``(batch, prev_channel, height, width)``. It is
                never modified in place.

        Returns:
            Tensor of shape ``(batch, channel, height, width)``.

        Raises:
            RuntimeError: If neither channel count is a multiple of the other.
        """
        identity = x
        out = self.res(x)
        in_channels = identity.shape[1]
        out_channels = out.shape[1]

        if out_channels == in_channels:
            shortcut = identity
        elif out_channels > in_channels:
            if out_channels % in_channels != 0:
                raise RuntimeError("Dims in ResBlock needs to be divisible on the previous dims!!")
            # Tile the input channels until they cover the wider output.
            shortcut = identity.repeat(1, out_channels // in_channels, 1, 1)
        else:
            if in_channels % out_channels != 0:
                raise RuntimeError("Dims in ResBlock needs to be divisible on the previous dims!!")
            # Fold the wider input down to ``out_channels`` by summing channels
            # that share the same index modulo ``out_channels``. The result must
            # have ``channel`` channels because ``self.bn`` is sized for that.
            batch, _, height, width = identity.shape
            shortcut = identity.reshape(
                batch, in_channels // out_channels, out_channels, height, width
            ).sum(dim=1)

        # Out-of-place add: an in-place update could overwrite the caller's
        # tensor or values autograd still needs for the backward pass.
        out = out + shortcut
        return self.relu(self.bn(out))
|
|
|
|
|
class CNNRes2D(torch.nn.Module):
    """Residual 2-D convolutional classifier.

    The network starts with a stem (``Conv2d -> BatchNorm2d -> ReLU ->
    MaxPool2d``) configured by the first entry of every argument. Each
    remaining stage stacks one ``ResBlock2D`` per entry of ``channels[i]`` and
    may be followed by a 2x2 max-pool. The final feature map is reduced with
    global average pooling and fed to a linear classification layer.

    Args:
        channels: One sequence of channel counts per stage. Only
            ``channels[0][0]`` is used for the stem; later stages use every
            entry.
        conv_kernels: Kernel size for stage ``i``.
        conv_strides: Stride for stage ``i``.
        conv_padding: Padding for stage ``i``.
        pool_padding: ``pool_padding[0]`` pads the stem's max-pool;
            ``pool_padding[i]`` (``i >= 1``) pads the pool after residual stage
            ``i``. Stages without an entry are not pooled.
        num_classes: Size of the output layer.

    Raises:
        ValueError: If ``channels``, ``conv_kernels``, ``conv_strides`` and
            ``conv_padding`` do not all have the same length.
    """

    def __init__(self, channels, conv_kernels, conv_strides, conv_padding, pool_padding, num_classes=15):
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        if not (len(conv_kernels) == len(channels) == len(conv_strides) == len(conv_padding)):
            raise ValueError(
                "channels, conv_kernels, conv_strides and conv_padding must have the same length"
            )
        super().__init__()

        # Stem: a single convolution on the one-channel input, then pooling.
        stem_channel = channels[0][0]
        self.conv_block = torch.nn.Sequential(
            torch.nn.Conv2d(
                in_channels=1,
                out_channels=stem_channel,
                kernel_size=conv_kernels[0],
                stride=conv_strides[0],
                padding=conv_padding[0],
            ),
            torch.nn.BatchNorm2d(stem_channel),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=pool_padding[0]),
        )

        # Residual stages built from every configuration entry after the first.
        prev_channel = stem_channel
        self.res_blocks = torch.nn.ModuleList()
        for stage_channels, kernel, stride, padding in zip(
            channels[1:], conv_kernels[1:], conv_strides[1:], conv_padding[1:]
        ):
            block = []
            for conv_channel in stage_channels:
                block.append(ResBlock2D(prev_channel, conv_channel, kernel, stride, padding))
                prev_channel = conv_channel
            self.res_blocks.append(torch.nn.Sequential(*block))

        # Pools for the residual stages; ``pool_padding[0]`` belongs to the stem.
        self.pool_blocks = torch.nn.ModuleList(
            torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=padding)
            for padding in pool_padding[1:]
        )

        # Classification head on top of the globally averaged feature map.
        self.global_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.linear = torch.nn.Linear(prev_channel, num_classes)

    def forward(self, inwav):
        """Compute class scores for a batch of single-channel 2-D inputs.

        Args:
            inwav: Tensor of shape ``(batch, 1, height, width)``.

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        inwav = self.conv_block(inwav)

        num_pools = len(self.pool_blocks)
        for index, res_block in enumerate(self.res_blocks):
            inwav = res_block(inwav)
            # Pooling is optional for the trailing stages.
            if index < num_pools:
                inwav = self.pool_blocks[index](inwav)

        # Flatten from dim 1 so the batch dimension survives even when it is 1
        # (a bare ``squeeze()`` would drop it as well).
        out = torch.flatten(self.global_pool(inwav), 1)
        return self.linear(out)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|