import torch.nn as nn


def get_activation(act_type=None):
    """Map an activation name to its nn.Module (raises on unknown names)."""
    if act_type == 'relu':
        return nn.ReLU(inplace=True)
    elif act_type == 'lrelu':
        return nn.LeakyReLU(0.1, inplace=True)
    elif act_type == 'mish':
        return nn.Mish(inplace=True)
    elif act_type == 'silu':
        return nn.SiLU(inplace=True)
    else:
        raise NotImplementedError('Unknown activation type: {}'.format(act_type))


def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
    """Plain nn.Conv2d wrapper with explicit geometry arguments."""
    return nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)


def get_norm2d(norm_type, dim):
    """Map a normalization name to its 2D nn.Module (raises on unknown names)."""
    if norm_type == 'BN':
        return nn.BatchNorm2d(dim)
    elif norm_type == 'IN':
        return nn.InstanceNorm2d(dim)
    else:
        raise NotImplementedError('Unknown norm type: {}'.format(norm_type))


class Conv2d(nn.Module):
    """Conv -> (norm) -> (activation) block, with an optional depthwise-separable
    variant: a depthwise k x k conv followed by a pointwise 1x1 conv, each with
    its own norm and activation."""

    def __init__(self,
                 c1,                  # in channels
                 c2,                  # out channels
                 k=1,                 # kernel size
                 p=0,                 # padding
                 s=1,                 # stride
                 d=1,                 # dilation
                 g=1,                 # groups
                 act_type='',         # activation name; '' disables activation
                 norm_type='',        # norm name; '' disables normalization
                 depthwise=False):
        super(Conv2d, self).__init__()
        convs = []
        # A conv bias is redundant when a norm layer follows the conv.
        add_bias = not norm_type
        if depthwise:
            assert c1 == c2, "In depthwise conv, the in_dim (c1) should be equal to out_dim (c2)."
            # Depthwise conv: one filter per input channel (groups=c1).
            convs.append(get_conv2d(c1, c2, k=k, p=p, s=s, d=d, g=c1, bias=add_bias))
            if norm_type:
                convs.append(get_norm2d(norm_type, c2))
            if act_type:
                convs.append(get_activation(act_type))
            # Pointwise 1x1 conv mixes channels; dilation is a no-op for k=1.
            convs.append(get_conv2d(c2, c2, k=1, p=0, s=1, d=1, g=1, bias=add_bias))
            if norm_type:
                convs.append(get_norm2d(norm_type, c2))
            if act_type:
                convs.append(get_activation(act_type))
        else:
            convs.append(get_conv2d(c1, c2, k=k, p=p, s=s, d=d, g=g, bias=add_bias))
            if norm_type:
                convs.append(get_norm2d(norm_type, c2))
            if act_type:
                convs.append(get_activation(act_type))

        self.convs = nn.Sequential(*convs)

    def forward(self, x):
        return self.convs(x)
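
# Illustrative usage of Conv2d (an added sketch, not part of the original
# module): a standard 3x3 block followed by a depthwise-separable block.
#
#     import torch
#     conv = Conv2d(32, 64, k=3, p=1, act_type='silu', norm_type='BN')
#     dw   = Conv2d(64, 64, k=3, p=1, act_type='silu', norm_type='BN', depthwise=True)
#     x = torch.randn(2, 32, 64, 64)
#     y = dw(conv(x))  # y.shape == torch.Size([2, 64, 64, 64])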


def get_conv3d(c1, c2, k, p, s, d, g, bias=False):
    """Plain nn.Conv3d wrapper with explicit geometry arguments."""
    return nn.Conv3d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)


def get_norm3d(norm_type, dim):
    """Map a normalization name to its 3D nn.Module (raises on unknown names)."""
    if norm_type == 'BN':
        return nn.BatchNorm3d(dim)
    elif norm_type == 'IN':
        return nn.InstanceNorm3d(dim)
    else:
        raise NotImplementedError('Unknown norm type: {}'.format(norm_type))


class Conv3d(nn.Module):
    """Conv3d -> (norm) -> (activation) block; same layout as Conv2d above,
    including the optional depthwise-separable variant."""

    def __init__(self,
                 c1,                  # in channels
                 c2,                  # out channels
                 k=1,                 # kernel size
                 p=0,                 # padding
                 s=1,                 # stride
                 d=1,                 # dilation
                 g=1,                 # groups
                 act_type='',         # activation name; '' disables activation
                 norm_type='',        # norm name; '' disables normalization
                 depthwise=False):
        super(Conv3d, self).__init__()
        convs = []
        # A conv bias is redundant when a norm layer follows the conv.
        add_bias = not norm_type
        if depthwise:
            assert c1 == c2, "In depthwise conv, the in_dim (c1) should be equal to out_dim (c2)."
            # Depthwise conv: one filter per input channel (groups=c1).
            convs.append(get_conv3d(c1, c2, k=k, p=p, s=s, d=d, g=c1, bias=add_bias))
            if norm_type:
                convs.append(get_norm3d(norm_type, c2))
            if act_type:
                convs.append(get_activation(act_type))
            # Pointwise 1x1x1 conv mixes channels; dilation is a no-op for k=1.
            convs.append(get_conv3d(c2, c2, k=1, p=0, s=1, d=1, g=1, bias=add_bias))
            if norm_type:
                convs.append(get_norm3d(norm_type, c2))
            if act_type:
                convs.append(get_activation(act_type))
        else:
            convs.append(get_conv3d(c1, c2, k=k, p=p, s=s, d=d, g=g, bias=add_bias))
            if norm_type:
                convs.append(get_norm3d(norm_type, c2))
            if act_type:
                convs.append(get_activation(act_type))

        self.convs = nn.Sequential(*convs)

    def forward(self, x):
        return self.convs(x)
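

# A minimal smoke test (an added illustration, not part of the original module):
# it builds one 2D and one 3D block and checks output shapes on dummy inputs;
# it runs only when this file is executed directly and assumes PyTorch >= 1.9
# (needed for nn.Mish / nn.SiLU).
if __name__ == '__main__':
    import torch

    conv2d = Conv2d(32, 64, k=3, p=1, act_type='silu', norm_type='BN')
    x2d = torch.randn(2, 32, 64, 64)
    print(conv2d(x2d).shape)   # torch.Size([2, 64, 64, 64])

    conv3d = Conv3d(16, 16, k=3, p=1, act_type='relu', norm_type='BN', depthwise=True)
    x3d = torch.randn(2, 16, 8, 32, 32)
    print(conv3d(x3d).shape)   # torch.Size([2, 16, 8, 32, 32])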