import torch
from torch.nn import (Module, Conv2d, ReLU, ModuleList, MaxPool2d,
                      ConvTranspose2d, BCELoss, BCEWithLogitsLoss,
                      functional as F)
from torch.optim import Adam
from torchvision import transforms
from torchvision.transforms import CenterCrop
from torch.utils.data import Dataset, DataLoader
import cv2

class Block(Module):
    def __init__(self, inChannels, outChannels):
        super().__init__()
        # store the convolution and ReLU layers
        self.conv1 = Conv2d(inChannels, outChannels, 3)
        self.relu = ReLU()
        self.conv2 = Conv2d(outChannels, outChannels, 3)

    def forward(self, x):
        # apply CONV => RELU => CONV to the inputs and return the result
        return self.conv2(self.relu(self.conv1(x)))
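
# A minimal shape check for Block (illustrative only, not part of the original
# app). The two unpadded 3x3 convolutions shrink each spatial dimension by 4:
#
#   block = Block(3, 16)
#   print(block(torch.randn(1, 3, 32, 32)).shape)  # torch.Size([1, 16, 28, 28])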

class Encoder(Module):
    def __init__(self, channels=(3, 16, 32, 64)):
        super().__init__()
        # store the encoder blocks and the max-pooling layer
        self.encBlocks = ModuleList(
            [Block(channels[i], channels[i + 1]) for i in range(len(channels) - 1)])
        self.pool = MaxPool2d(2)

    def forward(self, x):
        # initialize an empty list to store the intermediate outputs
        blockOutputs = []
        # loop through the encoder blocks
        for block in self.encBlocks:
            # pass the inputs through the current encoder block, store
            # the output, and then apply max-pooling to it
            x = block(x)
            blockOutputs.append(x)
            x = self.pool(x)
        # return the list containing the intermediate outputs
        return blockOutputs
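
# Illustrative shape check for Encoder (the 128x128 RGB input is an assumption,
# not part of the original app). Each Block shrinks height and width by 4, and
# MaxPool2d(2) halves them between stages; the stored features are pre-pooling:
#
#   enc = Encoder(channels=(3, 16, 32, 64))
#   feats = enc(torch.randn(1, 3, 128, 128))
#   print([tuple(f.shape) for f in feats])
#   # [(1, 16, 124, 124), (1, 32, 58, 58), (1, 64, 25, 25)]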

class Decoder(Module):
    def __init__(self, channels=(64, 32, 16)):
        super().__init__()
        # initialize the number of channels, upsampler blocks, and
        # decoder blocks
        self.channels = channels
        self.upconvs = ModuleList(
            [ConvTranspose2d(channels[i], channels[i + 1], 2, 2) for i in range(len(channels) - 1)])
        self.dec_blocks = ModuleList(
            [Block(channels[i], channels[i + 1]) for i in range(len(channels) - 1)])

    def forward(self, x, encFeatures):
        # loop through the number of channels
        for i in range(len(self.channels) - 1):
            # pass the inputs through the upsampler block
            x = self.upconvs[i](x)
            # crop the current features from the encoder blocks,
            # concatenate them with the current upsampled features,
            # and pass the concatenated output through the current
            # decoder block
            encFeat = self.crop(encFeatures[i], x)
            x = torch.cat([x, encFeat], dim=1)
            x = self.dec_blocks[i](x)
        # return the final decoder output
        return x

    def crop(self, encFeatures, x):
        # grab the spatial dimensions of the inputs, and crop the encoder
        # features to match them
        (_, _, H, W) = x.shape
        encFeatures = CenterCrop([H, W])(encFeatures)
        # return the cropped features
        return encFeatures
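
# Illustrative shape check for Decoder, fed with dummy tensors whose shapes
# match the Encoder example above (hypothetical values, not part of the
# original app). Each stage doubles the spatial size via the transposed
# convolution, center-crops the skip feature to match, then loses 4 pixels
# per dimension in the unpadded Block:
#
#   dec = Decoder(channels=(64, 32, 16))
#   out = dec(torch.randn(1, 64, 25, 25),
#             [torch.randn(1, 32, 58, 58), torch.randn(1, 16, 124, 124)])
#   print(tuple(out.shape))  # (1, 16, 88, 88)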

class UNet(Module):
    def __init__(self, encChannels=(3, 64, 128, 256, 512, 1024),
                 decChannels=(1024, 512, 256, 128, 64),
                 nbClasses=1, retainDim=True, outSize=(256, 256)):
        super().__init__()
        # initialize the encoder and decoder
        self.encoder = Encoder(encChannels)
        self.decoder = Decoder(decChannels)
        # initialize the 1x1 convolution head that maps the decoder output to
        # per-class logits, and store the class variables
        self.head = Conv2d(decChannels[-1], nbClasses, 1)
        self.retainDim = retainDim
        self.outSize = outSize

    def forward(self, x):
        # grab the features from the encoder
        encFeatures = self.encoder(x)
        # pass the deepest encoder feature through the decoder along with the
        # remaining features in reverse order, so each skip connection lines
        # up with the matching decoder stage
        decFeatures = self.decoder(encFeatures[-1], encFeatures[::-1][1:])
        # pass the decoder features through the head to obtain the
        # segmentation map
        map_ = self.head(decFeatures)
        # check to see if we are retaining the original output
        # dimensions and, if so, resize the output to match them
        if self.retainDim:
            map_ = F.interpolate(map_, self.outSize)
        # return the segmentation map
        return map_
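
if __name__ == "__main__":
    # A minimal smoke test (added for illustration; the 256x256 input size is
    # an assumption chosen to match the default outSize). Note that the head
    # emits raw logits, which pairs with the BCEWithLogitsLoss imported above.
    model = UNet()
    logits = model(torch.randn(1, 3, 256, 256))
    print(tuple(logits.shape))  # (1, 1, 256, 256)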