Andres Felipe Ruiz-Hurtado
initial
bc97962
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
class ConvBNRelu(nn.Module):
def __init__(self, in_ch, out_ch):
super(ConvBNRelu, self).__init__()
self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=True)
self.bn = nn.BatchNorm2d(out_ch)
self.activation = nn.ReLU()
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.activation(x)
# print(x.shape)
return x
class FirstBlock(nn.Module):
def __init__(self, in_ch, out_ch):
super(FirstBlock, self).__init__()
self.conv1 = ConvBNRelu(in_ch, out_ch)
self.conv2 = ConvBNRelu(out_ch, out_ch)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
return x
class DownBlock(nn.Module):
def __init__(self, in_ch, out_ch):
super(DownBlock, self).__init__()
self.conv1 = ConvBNRelu(in_ch, out_ch)
self.conv2 = ConvBNRelu(out_ch, out_ch)
def forward(self, x):
x = F.max_pool2d(x,kernel_size=2,stride=2)
x = self.conv1(x)
x = self.conv2(x)
return x
class Encoder(nn.Module):
def __init__(self, in_ch, out_ch, block_num=2):
super(Encoder, self).__init__()
layers = []
layers += [ConvBNRelu(in_ch, out_ch)]
for i in range(block_num-1):
layers += [ConvBNRelu(out_ch, out_ch)]
# layers += [nn.Dropout2d(0.5)]
self.features = nn.Sequential(*layers)
def forward(self, x):
x = self.features(x)
x, indices = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
return x, indices
class Decoder(nn.Module):
def __init__(self, in_ch, out_ch, block_num=2):
super(Decoder, self).__init__()
layers = []
layers += [ConvBNRelu(in_ch, out_ch)]
for i in range(block_num-1):
layers += [ConvBNRelu(out_ch, out_ch)]
# layers += [nn.Dropout2d(0.5)]
self.features = nn.Sequential(*layers)
def forward(self, x, indices):
x = F.max_unpool2d(x, indices=indices, kernel_size=2, stride=2)
x = self.features(x)
return x
class SegRoot(nn.Module):
def __init__(self, width=8, depth=5, num_classes=2):
super(SegRoot, self).__init__()
chs = []
for i in range(depth-1):
chs.append(width * (2**i))
chs.append(chs[-1])
self.e_ch_info = [3,] + chs
self.e_bl_info = [2,2,3,3]
for _ in range(depth - 4):
self.e_bl_info += [3,]
self.d_ch_info = chs[::-1] + [4,]
self.d_bl_info = self.e_bl_info[::-1]
# using same setup with Unet
if width == 4:
self.e_ch_info = [3,4,8,16,32,64]
self.d_ch_info = [64,32,16,8,4,4]
self.num_classes = num_classes
self.encoders = nn.ModuleList()
self.decoders = nn.ModuleList()
for i in range(1,len(self.e_ch_info)):
self.encoders.append(Encoder(self.e_ch_info[i-1], self.e_ch_info[i], self.e_bl_info[i-1]))
self.decoders.append(Decoder(self.d_ch_info[i-1], self.d_ch_info[i], self.d_bl_info[i-1]))
# self.classifier = nn.Conv2d(self.d_ch_info[-1], num_classes, kernel_size=3, padding=1)
self.classifier = nn.Conv2d(self.d_ch_info[-1], 1, 1)
def forward(self, x):
indices = []
bs = x.shape[0]
for i in range(len(self.e_bl_info)):
x, ind = self.encoders[i](x)
indices.append(ind)
indices = indices[::-1]
for i in range(len(self.e_bl_info)):
x = self.decoders[i](x, indices[i])
x = self.classifier(x)
# x = F.softmax(x,dim=1)
x = torch.sigmoid(x)
return x
if __name__ == '__main__':
x = torch.zeros((1, 3, 256, 256))
net = SegRoot(8,5)
print(net(x).shape)