# ImageSeg_IPPR / unet.py
# Manthanx's picture
# Upload 4 files
# ba20db4
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
class UNet(nn.Module):
    """U-Net style encoder/decoder producing one score map per class.

    The encoder applies four (double-conv, 2x2 max-pool) stages with 64, 128,
    256 and 512 channels, followed by a 1024-channel bottleneck. The decoder
    mirrors it: every stage doubles the spatial size with a transposed
    convolution, concatenates the matching encoder feature map along the
    channel axis and applies another double-conv block.

    Args:
        num_classes: Number of output channels of the final convolution.
        in_channels: Number of channels of the input image (default 3).
    """

    def __init__(self, num_classes: int, in_channels: int = 3) -> None:
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels

        # Contracting path: conv block followed by a 2x down-sampling.
        self.contracting_11 = self.conv_block(in_channels=in_channels, out_channels=64)
        self.contracting_12 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.contracting_21 = self.conv_block(in_channels=64, out_channels=128)
        self.contracting_22 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.contracting_31 = self.conv_block(in_channels=128, out_channels=256)
        self.contracting_32 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.contracting_41 = self.conv_block(in_channels=256, out_channels=512)
        self.contracting_42 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Bottleneck.
        self.middle = self.conv_block(in_channels=512, out_channels=1024)

        # Expansive path: a transposed conv halves the channels and doubles
        # the spatial size; the following conv block takes twice the channels
        # because the encoder skip connection is concatenated first.
        self.expansive_11 = nn.ConvTranspose2d(
            in_channels=1024, out_channels=512,
            kernel_size=3, stride=2, padding=1, output_padding=1,
        )
        self.expansive_12 = self.conv_block(in_channels=1024, out_channels=512)
        self.expansive_21 = nn.ConvTranspose2d(
            in_channels=512, out_channels=256,
            kernel_size=3, stride=2, padding=1, output_padding=1,
        )
        self.expansive_22 = self.conv_block(in_channels=512, out_channels=256)
        self.expansive_31 = nn.ConvTranspose2d(
            in_channels=256, out_channels=128,
            kernel_size=3, stride=2, padding=1, output_padding=1,
        )
        self.expansive_32 = self.conv_block(in_channels=256, out_channels=128)
        self.expansive_41 = nn.ConvTranspose2d(
            in_channels=128, out_channels=64,
            kernel_size=3, stride=2, padding=1, output_padding=1,
        )
        self.expansive_42 = self.conv_block(in_channels=128, out_channels=64)

        # Per-pixel class scores (raw, no activation is applied here).
        self.output = nn.Conv2d(
            in_channels=64, out_channels=num_classes,
            kernel_size=3, stride=1, padding=1,
        )

    def conv_block(self, in_channels: int, out_channels: int) -> nn.Sequential:
        """Build two size-preserving 3x3 conv layers, each followed by ReLU and BatchNorm.

        Args:
            in_channels: Channels entering the first convolution.
            out_channels: Channels produced by both convolutions.

        Returns:
            An ``nn.Sequential`` of Conv2d -> ReLU -> BatchNorm2d, twice.
        """
        # NOTE: layer order (ReLU before BatchNorm) and Sequential indices are
        # kept as-is so parameter names in saved state dicts do not change.
        block = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(num_features=out_channels),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels,
                      kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(num_features=out_channels),
        )
        return block

    @staticmethod
    def _pad_to_match(upsampled: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """Zero-pad ``upsampled`` so its height/width equal those of ``skip``."""
        diff_h = skip.shape[-2] - upsampled.shape[-2]
        diff_w = skip.shape[-1] - upsampled.shape[-1]
        if diff_h == 0 and diff_w == 0:
            # Sizes already agree (always the case when H and W are
            # multiples of 16), so the tensor is passed through untouched.
            return upsampled
        # F.pad takes the last dimension first: (left, right, top, bottom).
        return F.pad(
            upsampled,
            [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Run the network on a batch of images.

        Args:
            X: Tensor of shape ``[N, in_channels, H, W]``. ``H`` and ``W``
                should be at least 16 so that four 2x poolings leave a
                non-empty bottleneck. Sizes that are not multiples of 16 are
                supported: the up-sampled maps are zero-padded to the size of
                the matching skip connection before concatenation.

        Returns:
            Tensor of shape ``[N, num_classes, H, W]`` with raw class scores.
        """
        # Encoder. Shapes in comments are for an input of size H x W;
        # max-pooling floors odd sizes.
        skip_1 = self.contracting_11(X)                             # [N, 64, H, W]
        skip_2 = self.contracting_21(self.contracting_12(skip_1))   # [N, 128, H/2, W/2]
        skip_3 = self.contracting_31(self.contracting_22(skip_2))   # [N, 256, H/4, W/4]
        skip_4 = self.contracting_41(self.contracting_32(skip_3))   # [N, 512, H/8, W/8]
        bottleneck = self.middle(self.contracting_42(skip_4))       # [N, 1024, H/16, W/16]

        # Decoder: upsample, align with the skip connection, concatenate
        # (upsampled features first, as in the original layout), then refine.
        up = self._pad_to_match(self.expansive_11(bottleneck), skip_4)
        up = self.expansive_12(torch.cat((up, skip_4), dim=1))      # [N, 512, H/8, W/8]
        up = self._pad_to_match(self.expansive_21(up), skip_3)
        up = self.expansive_22(torch.cat((up, skip_3), dim=1))      # [N, 256, H/4, W/4]
        up = self._pad_to_match(self.expansive_31(up), skip_2)
        up = self.expansive_32(torch.cat((up, skip_2), dim=1))      # [N, 128, H/2, W/2]
        up = self._pad_to_match(self.expansive_41(up), skip_1)
        up = self.expansive_42(torch.cat((up, skip_1), dim=1))      # [N, 64, H, W]

        return self.output(up)                                      # [N, num_classes, H, W]