Spaces:
Sleeping
Sleeping
import torch | |
class DoubleConv(torch.nn.Module):
    """Two stacked 3x3x3 convolutions, each followed by a ReLU.

    Both convolutions use ``padding=1`` with a kernel of 3 (stride 1), so the
    spatial size of the input is preserved; only the channel count changes
    from ``in_channels`` to ``out_channels``.

    Args:
        in_channels: number of channels of the incoming tensor.
        out_channels: number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        kernel = 3
        first = torch.nn.Conv3d(in_channels, out_channels, kernel, padding=1)
        second = torch.nn.Conv3d(out_channels, out_channels, kernel, padding=1)
        # A single Sequential named ``step`` keeps the parameter names
        # (step.0.*, step.2.*) identical to the original layout.
        self.step = torch.nn.Sequential(
            first,
            torch.nn.ReLU(),
            second,
            torch.nn.ReLU(),
        )

    def forward(self, X):
        """Apply conv -> ReLU -> conv -> ReLU to ``X`` and return the result."""
        out = self.step(X)
        return out
class UNet(torch.nn.Module):
    """3D U-Net for volumetric segmentation.

    Encoder: three times DoubleConv followed by 2x max pooling, then a
    bottleneck DoubleConv. Decoder: three times trilinear upsampling to the
    size of the matching encoder feature map, concatenation with that map
    along the channel axis (skip connection), then DoubleConv. A final
    1x1x1 convolution maps the 32 decoder channels to per-class scores.

    The default of 6 output channels corresponds to: background, upper jaw,
    lower jaw, upper teeth, lower teeth, artery.

    Args:
        in_channels: channels of the input volume (default 1).
        out_channels: number of output channels / classes (default 6).
    """

    def __init__(self, in_channels=1, out_channels=6):
        """Set up the U-Net layers.

        Attribute names and construction order match the original layout so
        existing state_dicts load unchanged.
        """
        super().__init__()
        # ------------------------- down sampling -------------------------
        self.layer1 = DoubleConv(in_channels, 32)
        self.layer2 = DoubleConv(32, 64)
        self.layer3 = DoubleConv(64, 128)
        self.layer4 = DoubleConv(128, 256)
        # -------------------------- up sampling --------------------------
        # Input channels = upsampled channels + skip-connection channels.
        self.layer5 = DoubleConv(256 + 128, 128)
        self.layer6 = DoubleConv(128 + 64, 64)
        self.layer7 = DoubleConv(64 + 32, 32)
        # 1x1x1 conv producing one score map per class.
        self.layer8 = torch.nn.Conv3d(32, out_channels, 1)
        self.maxpool = torch.nn.MaxPool3d(2)

    @staticmethod
    def _up_and_concat(x, skip):
        """Upsample ``x`` to ``skip``'s spatial size and concat along channels.

        Targeting the skip tensor's exact size (instead of a fixed factor of 2)
        keeps the concatenation valid when a spatial dimension was odd before
        pooling; for even sizes the result equals 2x trilinear upsampling.
        """
        up = torch.nn.functional.interpolate(
            x, size=skip.shape[2:], mode="trilinear", align_corners=False
        )
        return torch.cat([up, skip], dim=1)

    def forward(self, x):
        """Compute per-class scores for the input volume.

        Args:
            x: 5D tensor (N, in_channels, D, H, W); trilinear interpolation
                requires 5D input. Each spatial dimension should be at least 8
                so the three 2x poolings leave a non-empty volume.

        Returns:
            Tensor of shape (N, out_channels, D, H, W).
        """
        # Encoder
        x1 = self.layer1(x)
        x2 = self.layer2(self.maxpool(x1))
        x3 = self.layer3(self.maxpool(x2))
        # Bottleneck
        x4 = self.layer4(self.maxpool(x3))
        # Decoder with skip connections
        x5 = self.layer5(self._up_and_concat(x4, x3))
        x6 = self.layer6(self._up_and_concat(x5, x2))
        x7 = self.layer7(self._up_and_concat(x6, x1))
        # Predicted segmentation scores
        return self.layer8(x7)