# UNet.py — author: shivambhosale
# Last commit: 3904368 ("Update UNet.py"), file size 4.25 kB.
import torch
from torch.nn import Module, Conv2d, ReLU, ModuleList, MaxPool2d, ConvTranspose2d, BCELoss, BCEWithLogitsLoss, functional as F
from torch.optim import Adam
from torchvision import transforms
from torchvision.transforms import CenterCrop
from torch.utils.data import Dataset, DataLoader
import cv2
class Block(Module):
    """Two unpadded 3x3 convolutions with a ReLU between them.

    Neither convolution pads its input, so each spatial dimension of the
    output is 4 pixels smaller than the corresponding input dimension.
    """

    def __init__(self, inChannels, outChannels):
        super().__init__()
        kernelSize = 3
        # Attribute names (and creation order) are kept stable: they are the
        # state_dict keys and determine the order of random initialisation.
        self.conv1 = Conv2d(inChannels, outChannels, kernelSize)
        self.relu = ReLU()
        self.conv2 = Conv2d(outChannels, outChannels, kernelSize)

    def forward(self, x):
        """Return ``conv2(relu(conv1(x)))``."""
        hidden = self.conv1(x)
        hidden = self.relu(hidden)
        return self.conv2(hidden)
class Encoder(Module):
    """Contracting path of the U-Net.

    Builds one ``Block`` per consecutive pair in ``channels`` and applies a
    2x2 max-pool between successive blocks.

    Args:
        channels: Channel counts; block ``i`` maps ``channels[i]`` to
            ``channels[i + 1]`` channels.
    """

    def __init__(self, channels=(3, 16, 32, 64)):
        super().__init__()
        # store the encoder blocks and maxpooling layer
        self.encBlocks = ModuleList(
            [Block(inCh, outCh) for inCh, outCh in zip(channels[:-1], channels[1:])]
        )
        self.pool = MaxPool2d(2)

    def forward(self, x):
        """Run ``x`` through every encoder block.

        Returns:
            A list with each block's output (taken before pooling), ordered
            from the first (shallowest) block to the last (deepest).
        """
        blockOutputs = []
        lastIndex = len(self.encBlocks) - 1
        for i, block in enumerate(self.encBlocks):
            x = block(x)
            blockOutputs.append(x)
            # Pooling after the final block would produce a tensor nobody
            # uses, so skip it (this also avoids a spurious error when the
            # deepest map is too small to pool).
            if i < lastIndex:
                x = self.pool(x)
        return blockOutputs
class Decoder(Module):
    """Expanding path of the U-Net.

    Each stage ``i`` upsamples with a stride-2 transposed convolution
    (``channels[i]`` -> ``channels[i + 1]``). It then concatenates a
    center-cropped skip feature map along the channel axis and applies a
    ``Block`` that expects ``channels[i]`` input channels. The skip map at
    stage ``i`` must therefore carry ``channels[i] - channels[i + 1]``
    channels.
    """

    def __init__(self, channels=(64, 32, 16)):
        super().__init__()
        self.channels = channels
        stagePairs = list(zip(channels[:-1], channels[1:]))
        # Upsamplers are created before the decoder blocks, as before, so that
        # seeded parameter initialisation is unchanged.
        self.upconvs = ModuleList(
            [ConvTranspose2d(inCh, outCh, 2, 2) for inCh, outCh in stagePairs]
        )
        self.dec_blocks = ModuleList(
            [Block(inCh, outCh) for inCh, outCh in stagePairs]
        )

    def forward(self, x, encFeatures):
        """Decode ``x``, consuming ``encFeatures[i]`` as the skip input of stage ``i``."""
        for stage, (upconv, decBlock) in enumerate(zip(self.upconvs, self.dec_blocks)):
            upsampled = upconv(x)
            skip = self.crop(encFeatures[stage], upsampled)
            x = decBlock(torch.cat([upsampled, skip], dim=1))
        return x

    def crop(self, encFeatures, x):
        """Center-crop ``encFeatures`` to the spatial (H, W) size of ``x``."""
        height, width = x.shape[2], x.shape[3]
        cropper = CenterCrop([height, width])
        return cropper(encFeatures)
class UNet(Module):
    """U-Net built from this file's ``Encoder`` and ``Decoder``.

    Args:
        encChannels: Channel counts for the encoder blocks.
        decChannels: Channel counts for the decoder stages. Presumably this
            should mirror the tail of ``encChannels`` (deepest first) so that
            the skip connections line up; verify for custom configurations.
        nbClasses: Number of output channels produced by the 1x1 head.
        retainDim: If true, resize the output map with ``F.interpolate``.
            Without this the output is smaller than the input, because the
            convolutions do not pad.
        outSize: Target (H, W) used when ``retainDim`` is true. Pass ``None``
            to resize to the spatial size of the input instead.
    """

    def __init__(self, encChannels=(3, 64, 128, 256, 512, 1024), decChannels=(1024, 512, 256, 128, 64),
                 nbClasses=1, retainDim=True, outSize=(256, 256)):
        super().__init__()
        self.encoder = Encoder(encChannels)
        self.decoder = Decoder(decChannels)
        # 1x1 convolution mapping the final decoder features to nbClasses channels
        self.head = Conv2d(decChannels[-1], nbClasses, 1)
        self.retainDim = retainDim
        self.outSize = outSize

    def forward(self, x):
        """Return the raw (un-activated) output map for the batch ``x``."""
        encFeatures = self.encoder(x)
        # The deepest encoder output seeds the decoder. The remaining outputs,
        # deepest first, are the skip connections for successive stages.
        deepestFirst = encFeatures[::-1]
        decFeatures = self.decoder(deepestFirst[0], deepestFirst[1:])
        map_ = self.head(decFeatures)
        if self.retainDim:
            # outSize=None means "match the input's spatial dimensions"
            size = tuple(x.shape[-2:]) if self.outSize is None else self.outSize
            map_ = F.interpolate(map_, size)
        return map_