doggywastaken commited on
Commit
d26715f
β€’
1 Parent(s): 288210a

init commit

Browse files
Files changed (3) hide show
  1. LICENSE +21 -0
  2. README.md +9 -3
  3. model/leaky_vnet.py +192 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2021 Sam
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,9 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
1
+ # Leaky V-Net in Pytorch
2
+
3
+ This is a fork of @Dootmaan's VNet.PyTorch repo, who attempted a faithful recreation of the original V-Net: Fully Convolutional Neural Network for Volumetric Medical Image paper, with as little adaptations as possible.
4
+
5
+ This repo's model has some minor adaptations to fit it's designated application:
6
+
7
+ * ReLU layers is now the leaky version to allow for more consistent convergence on small training datasets
8
+
9
+ * Final Sigmoid layer has been removed from the network in favour of manual thresholding (for flexibility during testing)
model/leaky_vnet.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+ import torch
4
+ from huggingface_hub import PyTorchModelHubMixin
5
+
6
+ class conv3d(nn.Module, PyTorchModelHubMixin):
7
+ def __init__(self, in_channels, out_channels):
8
+ """
9
+ + Instantiate modules: conv-relu-norm
10
+ + Assign them as member variables
11
+ """
12
+ super(conv3d, self).__init__()
13
+ self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=5, padding=2)
14
+ self.relu = nn.LeakyReLU(0.2)
15
+ # with learnable parameters
16
+ self.norm = nn.InstanceNorm3d(out_channels, affine=True)
17
+
18
+ def forward(self, x):
19
+ return self.relu(self.norm(self.conv(x)))
20
+
21
+
22
+ class conv3d_x3(nn.Module, PyTorchModelHubMixin):
23
+ """Three serial convs with a residual connection.
24
+ Structure:
25
+ inputs --> β‘  --> β‘‘ --> β‘’ --> outputs
26
+ ↓ --> add--> ↑
27
+ """
28
+
29
+ def __init__(self, in_channels, out_channels):
30
+ super(conv3d_x3, self).__init__()
31
+ self.conv_1 = conv3d(in_channels, out_channels)
32
+ self.conv_2 = conv3d(out_channels, out_channels)
33
+ self.conv_3 = conv3d(out_channels, out_channels)
34
+ self.skip_connection=nn.Conv3d(in_channels,out_channels,1)
35
+
36
+ def forward(self, x):
37
+ z_1 = self.conv_1(x)
38
+ z_3 = self.conv_3(self.conv_2(z_1))
39
+ return z_3 + self.skip_connection(x)
40
+
41
+ class conv3d_x2(nn.Module, PyTorchModelHubMixin):
42
+ """Three serial convs with a residual connection.
43
+ Structure:
44
+ inputs --> β‘  --> β‘‘ --> β‘’ --> outputs
45
+ ↓ --> add--> ↑
46
+ """
47
+
48
+ def __init__(self, in_channels, out_channels):
49
+ super(conv3d_x2, self).__init__()
50
+ self.conv_1 = conv3d(in_channels, out_channels)
51
+ self.conv_2 = conv3d(out_channels, out_channels)
52
+ self.skip_connection=nn.Conv3d(in_channels,out_channels,1)
53
+
54
+ def forward(self, x):
55
+ z_1 = self.conv_1(x)
56
+ z_2 = self.conv_2(z_1)
57
+ return z_2 + self.skip_connection(x)
58
+
59
+
60
+ class conv3d_x1(nn.Module, PyTorchModelHubMixin):
61
+ """Three serial convs with a residual connection.
62
+ Structure:
63
+ inputs --> β‘  --> β‘‘ --> β‘’ --> outputs
64
+ ↓ --> add--> ↑
65
+ """
66
+
67
+ def __init__(self, in_channels, out_channels):
68
+ super(conv3d_x1, self).__init__()
69
+ self.conv_1 = conv3d(in_channels, out_channels)
70
+ self.skip_connection=nn.Conv3d(in_channels,out_channels,1)
71
+
72
+ def forward(self, x):
73
+ z_1 = self.conv_1(x)
74
+ return z_1 + self.skip_connection(x)
75
+
76
+ class deconv3d_x3(nn.Module, PyTorchModelHubMixin):
77
+ def __init__(self, in_channels, out_channels):
78
+ super(deconv3d_x3, self).__init__()
79
+ self.up = deconv3d_as_up(in_channels, out_channels, 2, 2)
80
+ self.lhs_conv = conv3d(out_channels // 2, out_channels)
81
+ self.conv_x3 = nn.Sequential(
82
+ nn.Conv3d(2*out_channels, out_channels,5,1,2),
83
+ nn.LeakyReLU(0.1),
84
+ nn.Conv3d(out_channels, out_channels,5,1,2),
85
+ nn.LeakyReLU(0.1),
86
+ nn.Conv3d(out_channels, out_channels,5,1,2),
87
+ nn.LeakyReLU(0.1),
88
+ )
89
+
90
+ def forward(self, lhs, rhs):
91
+ rhs_up = self.up(rhs)
92
+ lhs_conv = self.lhs_conv(lhs)
93
+ rhs_add = torch.cat((rhs_up, lhs_conv),dim=1)
94
+ return self.conv_x3(rhs_add)+ rhs_up
95
+
96
+ class deconv3d_x2(nn.Module, PyTorchModelHubMixin):
97
+ def __init__(self, in_channels, out_channels):
98
+ super(deconv3d_x2, self).__init__()
99
+ self.up = deconv3d_as_up(in_channels, out_channels, 2, 2)
100
+ self.lhs_conv = conv3d(out_channels // 2, out_channels)
101
+ self.conv_x2= nn.Sequential(
102
+ nn.Conv3d(2*out_channels, out_channels,5,1,2),
103
+ nn.LeakyReLU(0.1),
104
+ nn.Conv3d(out_channels, out_channels,5,1,2),
105
+ nn.LeakyReLU(0.1),
106
+ )
107
+
108
+ def forward(self, lhs, rhs):
109
+ rhs_up = self.up(rhs)
110
+ lhs_conv = self.lhs_conv(lhs)
111
+ rhs_add = torch.cat((rhs_up, lhs_conv),dim=1)
112
+ return self.conv_x2(rhs_add)+ rhs_up
113
+
114
+ class deconv3d_x1(nn.Module, PyTorchModelHubMixin):
115
+ def __init__(self, in_channels, out_channels):
116
+ super(deconv3d_x1, self).__init__()
117
+ self.up = deconv3d_as_up(in_channels, out_channels, 2, 2)
118
+ self.lhs_conv = conv3d(out_channels // 2, out_channels)
119
+ self.conv_x1 = nn.Sequential(
120
+ nn.Conv3d(2*out_channels, out_channels,5,1,2),
121
+ nn.LeakyReLU(0.2),
122
+ )
123
+
124
+ def forward(self, lhs, rhs):
125
+ rhs_up = self.up(rhs)
126
+ lhs_conv = self.lhs_conv(lhs)
127
+ rhs_add = torch.cat((rhs_up, lhs_conv),dim=1)
128
+ return self.conv_x1(rhs_add)+ rhs_up
129
+
130
+
131
+ def conv3d_as_pool(in_channels, out_channels, kernel_size=2, stride=2):
132
+ return nn.Sequential(
133
+ nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding=0),
134
+ nn.LeakyReLU(0.2))
135
+
136
+
137
+ def deconv3d_as_up(in_channels, out_channels, kernel_size=2, stride=2):
138
+ return nn.Sequential(
139
+ nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride),
140
+ nn.PReLU()
141
+ )
142
+
143
+
144
+ class softmax_out(nn.Module, PyTorchModelHubMixin):
145
+ def __init__(self, in_channels, out_channels):
146
+ super(softmax_out, self).__init__()
147
+ self.conv_1 = nn.Conv3d(in_channels, out_channels, kernel_size=5, padding=2)
148
+ self.conv_2 = nn.Conv3d(out_channels, out_channels, kernel_size=1, padding=0)
149
+
150
+ def forward(self, x):
151
+ """Output with shape [batch_size, 1, depth, height, width]."""
152
+ # Do NOT add normalize layer, or its values vanish.
153
+ y_conv = self.conv_2(self.conv_1(x))
154
+ return y_conv
155
+
156
+
157
+ class VNet(nn.Module, PyTorchModelHubMixin):
158
+ def __init__(self):
159
+ super(VNet, self).__init__()
160
+ self.conv_1 = conv3d_x1(1, 16)
161
+ self.pool_1 = conv3d_as_pool(16, 32)
162
+ self.conv_2 = conv3d_x2(32, 32)
163
+ self.pool_2 = conv3d_as_pool(32, 64)
164
+ self.conv_3 = conv3d_x3(64, 64)
165
+ self.pool_3 = conv3d_as_pool(64, 128)
166
+ self.conv_4 = conv3d_x3(128, 128)
167
+ self.pool_4 = conv3d_as_pool(128, 256)
168
+
169
+ self.bottom = conv3d_x3(256, 256)
170
+
171
+ self.deconv_4 = deconv3d_x3(256, 256)
172
+ self.deconv_3 = deconv3d_x3(256, 128)
173
+ self.deconv_2 = deconv3d_x2(128, 64)
174
+ self.deconv_1 = deconv3d_x1(64, 32)
175
+
176
+ self.out = softmax_out(32, 1)
177
+
178
+ def forward(self, x):
179
+ conv_1 = self.conv_1(x)
180
+ pool = self.pool_1(conv_1)
181
+ conv_2 = self.conv_2(pool)
182
+ pool = self.pool_2(conv_2)
183
+ conv_3 = self.conv_3(pool)
184
+ pool = self.pool_3(conv_3)
185
+ conv_4 = self.conv_4(pool)
186
+ pool = self.pool_4(conv_4)
187
+ bottom = self.bottom(pool)
188
+ deconv = self.deconv_4(conv_4, bottom)
189
+ deconv = self.deconv_3(conv_3, deconv)
190
+ deconv = self.deconv_2(conv_2, deconv)
191
+ deconv = self.deconv_1(conv_1, deconv)
192
+ return self.out(deconv)