shivambhosale commited on
Commit
f61800c
1 Parent(s): a98ddcd

Create new file

Browse files
Files changed (1) hide show
  1. UNet.py +87 -0
UNet.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Bug fix: the file used torch names (Module, Conv2d, ReLU, MaxPool2d,
# ModuleList, ConvTranspose2d, torch, F) without ever importing them,
# so importing this module and running any class raised NameError.
import torch
import torch.nn.functional as F
from torch.nn import (
    Conv2d,
    ConvTranspose2d,
    MaxPool2d,
    Module,
    ModuleList,
    ReLU,
)

# Spatial resolution (pixels) that UNet resizes its output mask to when
# retainDim is enabled (see UNet.forward).
height, width = 512, 512
2
class Block(Module):
    """Basic UNet convolution block: CONV => RELU => CONV.

    Both convolutions use 3x3 kernels with no padding, so every forward
    pass shrinks each spatial dimension by 4 (2 pixels per convolution).
    """

    def __init__(self, inChannels, outChannels):
        super().__init__()
        # The first convolution changes the channel count; the second
        # keeps it, with a ReLU between them.
        self.conv1 = Conv2d(inChannels, outChannels, 3)
        self.relu = ReLU()
        self.conv2 = Conv2d(outChannels, outChannels, 3)

    def forward(self, x):
        """Apply CONV => RELU => CONV to ``x`` and return the result."""
        features = self.conv1(x)
        features = self.relu(features)
        return self.conv2(features)
12
+
13
class Encoder(Module):
    """Contracting path of the UNet.

    Runs a chain of Block modules, each followed by 2x2 max pooling, and
    records every block's pre-pool output so the decoder can use them as
    skip connections.
    """

    def __init__(self, channels=(3, 16, 32, 64)):
        super().__init__()
        # One encoder block per consecutive (in, out) channel pair.
        pairs = zip(channels[:-1], channels[1:])
        self.encBlocks = ModuleList(Block(cin, cout) for cin, cout in pairs)
        self.pool = MaxPool2d(2)

    def forward(self, x):
        """Return the list of intermediate (pre-pooling) block outputs."""
        blockOutputs = []
        for block in self.encBlocks:
            x = block(x)
            blockOutputs.append(x)
            # Downsample before the next block; the pooled result of the
            # final block is never used by the caller.
            x = self.pool(x)
        return blockOutputs
31
+
32
class Decoder(Module):
    """Expanding path of the UNet.

    Each step upsamples with a transposed convolution, concatenates the
    matching (center-cropped) encoder feature map along the channel
    axis, and refines the result with a Block.
    """

    def __init__(self, channels=(64, 32, 16)):
        super().__init__()
        # Channel progression plus the per-step upsamplers and blocks.
        self.channels = channels
        self.upconvs = ModuleList(
            [ConvTranspose2d(channels[i], channels[i + 1], 2, 2)
             for i in range(len(channels) - 1)]
        )
        self.dec_blocks = ModuleList(
            [Block(channels[i], channels[i + 1])
             for i in range(len(channels) - 1)]
        )

    def forward(self, x, encFeatures):
        """Decode ``x``, fusing one encoder feature map per step.

        ``encFeatures`` is expected deepest-first, one entry per
        decoder step.
        """
        for i in range(len(self.channels) - 1):
            # Upsample, then concatenate the cropped skip connection and
            # pass through the decoder block.
            x = self.upconvs[i](x)
            encFeat = self.crop(encFeatures[i], x)
            x = torch.cat([x, encFeat], dim=1)
            x = self.dec_blocks[i](x)
        return x

    def crop(self, encFeatures, x):
        """Center-crop ``encFeatures`` to the spatial size of ``x``.

        Bug fix: the original called ``CenterCrop``, which was never
        imported anywhere in this file (NameError at runtime).
        Reimplemented with plain tensor slicing, which also drops the
        torchvision dependency. Assumes the encoder features are at
        least as large as ``x`` — always true in this UNet, since the
        encoder path only shrinks later feature maps.
        """
        (_, _, H, W) = x.shape
        (_, _, encH, encW) = encFeatures.shape
        top = (encH - H) // 2
        left = (encW - W) // 2
        return encFeatures[:, :, top:top + H, left:left + W]
61
+
62
class UNet(Module):
    """Full UNet: encoder, decoder, and a 1x1 segmentation head."""

    def __init__(self, encChannels=(3, 64, 128, 256, 512, 1024),
                 decChannels=(1024, 512, 256, 128, 64),
                 nbClasses=1, retainDim=True, outSize=(height, width)):
        super().__init__()
        # Contracting and expanding paths.
        self.encoder = Encoder(encChannels)
        self.decoder = Decoder(decChannels)
        # 1x1 convolution mapping decoder features to per-class logits,
        # plus the output-resizing configuration.
        self.head = Conv2d(decChannels[-1], nbClasses, 1)
        self.retainDim = retainDim
        self.outSize = outSize

    def forward(self, x):
        """Return the segmentation map for input batch ``x``."""
        encFeatures = self.encoder(x)
        # Deepest encoder output seeds the decoder; the remaining maps
        # (deepest-first) become its skip connections.
        deepestFirst = encFeatures[::-1]
        decFeatures = self.decoder(deepestFirst[0], deepestFirst[1:])
        # Project decoder features to the segmentation mask.
        map_ = self.head(decFeatures)
        if self.retainDim:
            # Resize the logits back to the configured resolution.
            map_ = F.interpolate(map_, self.outSize)
        return map_