# Thesis/models/basic/conv.py
# Basic 2D/3D convolution building blocks: conv + optional norm + optional activation.
import torch.nn as nn
def get_activation(act_type=None):
    """Return an activation module for the given type name.

    Args:
        act_type: one of 'relu', 'lrelu', 'mish', 'silu'.

    Raises:
        NotImplementedError: for any unsupported type. The original code
        silently returned None here, which only crashed later inside
        nn.Sequential with an opaque TypeError; fail fast instead.
    """
    if act_type == 'relu':
        return nn.ReLU(inplace=True)
    elif act_type == 'lrelu':
        return nn.LeakyReLU(0.1, inplace=True)
    elif act_type == 'mish':
        return nn.Mish(inplace=True)
    elif act_type == 'silu':
        return nn.SiLU(inplace=True)
    else:
        raise NotImplementedError(
            'Unsupported activation type: {}'.format(act_type))
# 2D Conv
def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
    """Build a plain nn.Conv2d.

    Args:
        c1: input channels.   c2: output channels.
        k: kernel size.  p: padding.  s: stride.  d: dilation.  g: groups.
        bias: whether the conv layer carries a bias term.
    """
    return nn.Conv2d(c1, c2, kernel_size=k, stride=s, padding=p,
                     dilation=d, groups=g, bias=bias)
def get_norm2d(norm_type, dim):
    """Build a 2D normalization layer ('BN' or 'IN') over `dim` channels.

    Returns None for any other norm_type (matching the callers, which
    only invoke this when norm_type is a recognized, truthy string).
    """
    builders = {'BN': nn.BatchNorm2d, 'IN': nn.InstanceNorm2d}
    builder = builders.get(norm_type)
    return builder(dim) if builder is not None else None
class Conv2d(nn.Module):
    """Conv2d -> optional norm -> optional activation.

    With depthwise=True, builds a depthwise-separable stack instead:
    a depthwise conv (groups=c1) followed by a 1x1 pointwise conv,
    each followed by its own optional norm/activation.
    """

    def __init__(self,
                 c1,                # in channels
                 c2,                # out channels
                 k=1,               # kernel size
                 p=0,               # padding
                 s=1,               # stride  (original comment wrongly said "padding")
                 d=1,               # dilation
                 g=1,               # groups
                 act_type='',       # activation type; '' disables
                 norm_type='',      # normalization type; '' disables
                 depthwise=False):  # use depthwise-separable conv
        super(Conv2d, self).__init__()
        convs = []
        # Bias is redundant when a norm layer immediately follows the conv.
        add_bias = not norm_type

        def _append_norm_act():
            # Shared tail: optional norm then optional activation on c2 channels.
            if norm_type:
                convs.append(get_norm2d(norm_type, c2))
            if act_type:
                convs.append(get_activation(act_type))

        if depthwise:
            assert c1 == c2, "In depthwise conv, the in_dim (c1) should be equal to out_dim (c2)."
            # depthwise conv: one filter per input channel (groups=c1)
            convs.append(get_conv2d(c1, c2, k=k, p=p, s=s, d=d, g=c1, bias=add_bias))
            _append_norm_act()
            # pointwise (1x1) conv; dilation has no effect at kernel size 1,
            # so pass d=1 explicitly (behavior-identical to the original d=d)
            convs.append(get_conv2d(c1, c2, k=1, p=0, s=1, d=1, g=1, bias=add_bias))
            _append_norm_act()
        else:
            convs.append(get_conv2d(c1, c2, k=k, p=p, s=s, d=d, g=g, bias=add_bias))
            _append_norm_act()

        self.convs = nn.Sequential(*convs)

    def forward(self, x):
        """Run the conv/norm/activation stack on `x` (N, c1, H, W) -> (N, c2, H', W')."""
        return self.convs(x)
# 3D Conv
def get_conv3d(c1, c2, k, p, s, d, g, bias=False):
    """Build a plain nn.Conv3d.

    Args:
        c1: input channels.   c2: output channels.
        k: kernel size.  p: padding.  s: stride.  d: dilation.  g: groups.
        bias: whether the conv layer carries a bias term.
    """
    return nn.Conv3d(c1, c2, kernel_size=k, stride=s, padding=p,
                     dilation=d, groups=g, bias=bias)
def get_norm3d(norm_type, dim):
    """Build a 3D normalization layer ('BN' or 'IN') over `dim` channels.

    Returns None for any other norm_type (matching the callers, which
    only invoke this when norm_type is a recognized, truthy string).
    """
    builders = {'BN': nn.BatchNorm3d, 'IN': nn.InstanceNorm3d}
    builder = builders.get(norm_type)
    return builder(dim) if builder is not None else None
class Conv3d(nn.Module):
    """Conv3d -> optional norm -> optional activation.

    With depthwise=True, builds a depthwise-separable stack instead:
    a depthwise conv (groups=c1) followed by a 1x1x1 pointwise conv,
    each followed by its own optional norm/activation.
    """

    def __init__(self,
                 c1,                # in channels
                 c2,                # out channels
                 k=1,               # kernel size
                 p=0,               # padding
                 s=1,               # stride  (original comment wrongly said "padding")
                 d=1,               # dilation
                 g=1,               # groups
                 act_type='',       # activation type; '' disables
                 norm_type='',      # normalization type; '' disables
                 depthwise=False):  # use depthwise-separable conv
        super(Conv3d, self).__init__()
        convs = []
        # Bias is redundant when a norm layer immediately follows the conv.
        add_bias = not norm_type

        def _append_norm_act():
            # Shared tail: optional norm then optional activation on c2 channels.
            if norm_type:
                convs.append(get_norm3d(norm_type, c2))
            if act_type:
                convs.append(get_activation(act_type))

        if depthwise:
            assert c1 == c2, "In depthwise conv, the in_dim (c1) should be equal to out_dim (c2)."
            # depthwise conv: one filter per input channel (groups=c1)
            convs.append(get_conv3d(c1, c2, k=k, p=p, s=s, d=d, g=c1, bias=add_bias))
            _append_norm_act()
            # pointwise (1x1x1) conv; dilation has no effect at kernel size 1,
            # so pass d=1 explicitly (behavior-identical to the original d=d)
            convs.append(get_conv3d(c1, c2, k=1, p=0, s=1, d=1, g=1, bias=add_bias))
            _append_norm_act()
        else:
            convs.append(get_conv3d(c1, c2, k=k, p=p, s=s, d=d, g=g, bias=add_bias))
            _append_norm_act()

        self.convs = nn.Sequential(*convs)

    def forward(self, x):
        """Run the conv/norm/activation stack on `x` (N, c1, D, H, W) -> (N, c2, D', H', W')."""
        return self.convs(x)