Spaces:
Sleeping
Sleeping
File size: 3,049 Bytes
02443c1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import torch
class DoubleConv(torch.nn.Module):
    """Two stacked 3x3x3 convolution + ReLU stages used at every U-Net level.

    The first convolution maps ``in_channels`` to ``out_channels``; the second
    keeps ``out_channels``. A kernel of size 3 with ``padding=1`` keeps the
    spatial extent of the input unchanged. The layers are stored in
    ``self.step`` as a ``torch.nn.Sequential`` in the order
    conv, ReLU, conv, ReLU.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # Both convolutions emit out_channels; only the input width of the
        # first one differs. Creation order (conv, relu, conv, relu) is kept
        # so parameter initialisation and state_dict keys stay the same.
        for conv_in in (in_channels, out_channels):
            stages += [
                torch.nn.Conv3d(conv_in, out_channels, 3, padding=1),
                torch.nn.ReLU(),
            ]
        self.step = torch.nn.Sequential(*stages)

    def forward(self, X):
        """Apply conv -> ReLU -> conv -> ReLU to ``X``."""
        return self.step(X)
class UNet(torch.nn.Module):
    """3D U-Net for volumetric segmentation.

    Encoder: three DoubleConv stages (32, 64, 128 channels), each followed by
    2x max-pooling, then a 256-channel bottleneck DoubleConv.
    Decoder: three stages that upsample trilinearly, concatenate the matching
    encoder feature map along the channel axis (skip connection) and apply a
    DoubleConv (128, 64, 32 channels). A final 1x1x1 convolution maps the 32
    features to per-voxel class scores.

    Input must be batched, shape (batch, in_channels, D, H, W). The output has
    shape (batch, num_classes, D, H, W). Every spatial dimension must be at
    least 8 so that three rounds of 2x pooling leave a non-empty tensor. It
    does not need to be divisible by 8, because each decoder stage is
    upsampled to the exact size of its skip connection.
    """

    def __init__(self, in_channels=1, num_classes=6):
        """Set up the U-Net layers.

        Args:
            in_channels: channels of the input volume (default 1).
            num_classes: output channels of the final 1x1x1 convolution
                (default 6; the original label order is background, upper
                jaw, lower jaw, upper teeth, lower teeth, artery).
        """
        super().__init__()
        # Layer names and creation order are unchanged, so existing
        # checkpoints load and seeded initialisation is reproducible.
        # ---- Down-sampling path (encoder) ----
        self.layer1 = DoubleConv(in_channels, 32)
        self.layer2 = DoubleConv(32, 64)
        self.layer3 = DoubleConv(64, 128)
        # ---- Bottleneck ----
        self.layer4 = DoubleConv(128, 256)
        # ---- Up-sampling path (decoder): upsampled + skip channels in ----
        self.layer5 = DoubleConv(256 + 128, 128)
        self.layer6 = DoubleConv(128 + 64, 64)
        self.layer7 = DoubleConv(64 + 32, 32)
        # 1x1x1 conv producing num_classes per-voxel scores (6 by default).
        self.layer8 = torch.nn.Conv3d(32, num_classes, 1)
        self.maxpool = torch.nn.MaxPool3d(2)

    @staticmethod
    def _up_and_concat(x, skip):
        """Upsample ``x`` to ``skip``'s spatial size and concatenate on channels."""
        # Matching the skip's exact size (instead of a fixed scale_factor=2)
        # handles sizes that MaxPool3d floored; for even sizes the effective
        # scale is still exactly 2. align_corners=False is Upsample's default.
        x = torch.nn.functional.interpolate(
            x, size=tuple(skip.shape[2:]), mode="trilinear", align_corners=False
        )
        return torch.cat([x, skip], dim=1)

    def forward(self, x):
        """Return per-voxel class scores for a (batch, C, D, H, W) volume."""
        # Encoder: keep the pre-pooling features for the skip connections.
        x1 = self.layer1(x)
        x2 = self.layer2(self.maxpool(x1))
        x3 = self.layer3(self.maxpool(x2))
        # Bottleneck at 1/8 resolution.
        x4 = self.layer4(self.maxpool(x3))
        # Decoder: upsample, join with the matching encoder level, convolve.
        x5 = self.layer5(self._up_and_concat(x4, x3))
        x6 = self.layer6(self._up_and_concat(x5, x2))
        x7 = self.layer7(self._up_and_concat(x6, x1))
        # Predicted segmentation logits.
        return self.layer8(x7)