Commit
β’
d26715f
1
Parent(s):
288210a
init commit
Browse files- LICENSE +21 -0
- README.md +9 -3
- model/leaky_vnet.py +192 -0
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2021 Sam
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -1,3 +1,9 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Leaky V-Net in Pytorch
|
2 |
+
|
3 |
+
This is a fork of @Dootmaan's VNet.PyTorch repo, whose author attempted a faithful recreation of the original "V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation" paper, with as few adaptations as possible.
|
4 |
+
|
5 |
+
This repo's model has some minor adaptations to fit its designated application:
|
6 |
+
|
7 |
+
* ReLU layers are now the leaky version, to allow for more consistent convergence on small training datasets
|
8 |
+
|
9 |
+
* Final Sigmoid layer has been removed from the network in favour of manual thresholding (for flexibility during testing)
|
model/leaky_vnet.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import torch
|
4 |
+
from huggingface_hub import PyTorchModelHubMixin
|
5 |
+
|
6 |
+
class conv3d(nn.Module, PyTorchModelHubMixin):
    """Basic 3D conv unit: 5x5x5 convolution -> instance norm -> leaky ReLU.

    padding=2 with a 5^3 kernel keeps the spatial dimensions unchanged.
    """

    def __init__(self, in_channels, out_channels):
        super(conv3d, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=5, padding=2)
        self.relu = nn.LeakyReLU(0.2)
        # affine=True gives the instance norm learnable scale/shift parameters
        self.norm = nn.InstanceNorm3d(out_channels, affine=True)

    def forward(self, x):
        """Apply conv, then norm, then the leaky ReLU."""
        out = self.conv(x)
        out = self.norm(out)
        return self.relu(out)
|
20 |
+
|
21 |
+
|
22 |
+
class conv3d_x3(nn.Module, PyTorchModelHubMixin):
    """Three serial conv3d units with a residual connection.

    Structure:
        inputs --> conv_1 --> conv_2 --> conv_3 --> add --> outputs
        inputs --> 1x1x1 skip conv ----------------> add
    """

    def __init__(self, in_channels, out_channels):
        super(conv3d_x3, self).__init__()
        self.conv_1 = conv3d(in_channels, out_channels)
        self.conv_2 = conv3d(out_channels, out_channels)
        self.conv_3 = conv3d(out_channels, out_channels)
        # 1x1x1 conv projects the input to out_channels for the residual add
        self.skip_connection = nn.Conv3d(in_channels, out_channels, 1)

    def forward(self, x):
        out = self.conv_2(self.conv_1(x))
        out = self.conv_3(out)
        return out + self.skip_connection(x)
|
40 |
+
|
41 |
+
class conv3d_x2(nn.Module, PyTorchModelHubMixin):
    """Two serial conv3d units with a residual connection.

    (Docstring corrected: the original said "Three serial convs", copied
    from conv3d_x3; this stage applies two.)

    Structure:
        inputs --> conv_1 --> conv_2 --> add --> outputs
        inputs --> 1x1x1 skip conv ----> add
    """

    def __init__(self, in_channels, out_channels):
        super(conv3d_x2, self).__init__()
        self.conv_1 = conv3d(in_channels, out_channels)
        self.conv_2 = conv3d(out_channels, out_channels)
        # 1x1x1 conv projects the input to out_channels for the residual add
        self.skip_connection = nn.Conv3d(in_channels, out_channels, 1)

    def forward(self, x):
        z_1 = self.conv_1(x)
        z_2 = self.conv_2(z_1)
        return z_2 + self.skip_connection(x)
|
58 |
+
|
59 |
+
|
60 |
+
class conv3d_x1(nn.Module, PyTorchModelHubMixin):
    """A single conv3d unit with a residual connection.

    (Docstring corrected: the original said "Three serial convs", copied
    from conv3d_x3; this stage applies one.)

    Structure:
        inputs --> conv_1 ------------> add --> outputs
        inputs --> 1x1x1 skip conv ---> add
    """

    def __init__(self, in_channels, out_channels):
        super(conv3d_x1, self).__init__()
        self.conv_1 = conv3d(in_channels, out_channels)
        # 1x1x1 conv projects the input to out_channels for the residual add
        self.skip_connection = nn.Conv3d(in_channels, out_channels, 1)

    def forward(self, x):
        z_1 = self.conv_1(x)
        return z_1 + self.skip_connection(x)
|
75 |
+
|
76 |
+
class deconv3d_x3(nn.Module, PyTorchModelHubMixin):
    """Decoder stage: 2x upsample, skip-concat, three refining convs, residual add.

    ``rhs`` (the deeper feature map) is upsampled 2x, concatenated along the
    channel axis with the convolved encoder feature ``lhs``, refined by three
    5x5x5 convs, and the upsampled tensor is added back as a residual.
    """

    def __init__(self, in_channels, out_channels):
        super(deconv3d_x3, self).__init__()
        self.up = deconv3d_as_up(in_channels, out_channels, 2, 2)
        # the encoder skip carries out_channels // 2 channels at this scale
        self.lhs_conv = conv3d(out_channels // 2, out_channels)
        self.conv_x3 = nn.Sequential(
            nn.Conv3d(2 * out_channels, out_channels, 5, 1, 2),
            nn.LeakyReLU(0.1),  # NOTE(review): slope 0.1 here vs 0.2 elsewhere — kept as-is
            nn.Conv3d(out_channels, out_channels, 5, 1, 2),
            nn.LeakyReLU(0.1),
            nn.Conv3d(out_channels, out_channels, 5, 1, 2),
            nn.LeakyReLU(0.1),
        )

    def forward(self, lhs, rhs):
        upsampled = self.up(rhs)
        merged = torch.cat((upsampled, self.lhs_conv(lhs)), dim=1)
        return self.conv_x3(merged) + upsampled
|
95 |
+
|
96 |
+
class deconv3d_x2(nn.Module, PyTorchModelHubMixin):
    """Decoder stage: 2x upsample, skip-concat, two refining convs, residual add."""

    def __init__(self, in_channels, out_channels):
        super(deconv3d_x2, self).__init__()
        self.up = deconv3d_as_up(in_channels, out_channels, 2, 2)
        # the encoder skip carries out_channels // 2 channels at this scale
        self.lhs_conv = conv3d(out_channels // 2, out_channels)
        self.conv_x2 = nn.Sequential(
            nn.Conv3d(2 * out_channels, out_channels, 5, 1, 2),
            nn.LeakyReLU(0.1),  # NOTE(review): slope 0.1 here vs 0.2 elsewhere — kept as-is
            nn.Conv3d(out_channels, out_channels, 5, 1, 2),
            nn.LeakyReLU(0.1),
        )

    def forward(self, lhs, rhs):
        upsampled = self.up(rhs)
        merged = torch.cat((upsampled, self.lhs_conv(lhs)), dim=1)
        return self.conv_x2(merged) + upsampled
|
113 |
+
|
114 |
+
class deconv3d_x1(nn.Module, PyTorchModelHubMixin):
    """Decoder stage: 2x upsample, skip-concat, one refining conv, residual add."""

    def __init__(self, in_channels, out_channels):
        super(deconv3d_x1, self).__init__()
        self.up = deconv3d_as_up(in_channels, out_channels, 2, 2)
        # the encoder skip carries out_channels // 2 channels at this scale
        self.lhs_conv = conv3d(out_channels // 2, out_channels)
        self.conv_x1 = nn.Sequential(
            nn.Conv3d(2 * out_channels, out_channels, 5, 1, 2),
            nn.LeakyReLU(0.2),
        )

    def forward(self, lhs, rhs):
        upsampled = self.up(rhs)
        merged = torch.cat((upsampled, self.lhs_conv(lhs)), dim=1)
        return self.conv_x1(merged) + upsampled
|
129 |
+
|
130 |
+
|
131 |
+
def conv3d_as_pool(in_channels, out_channels, kernel_size=2, stride=2):
    """Strided convolution used in place of pooling.

    With the default kernel_size=2, stride=2 and padding=0 this halves each
    spatial dimension while changing the channel count.
    """
    return nn.Sequential(
        nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding=0),
        nn.LeakyReLU(0.2),
    )
|
135 |
+
|
136 |
+
|
137 |
+
def deconv3d_as_up(in_channels, out_channels, kernel_size=2, stride=2):
    """Transposed convolution for upsampling.

    With the default kernel_size=2, stride=2 this doubles each spatial
    dimension; a PReLU activation follows.
    """
    return nn.Sequential(
        nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride),
        nn.PReLU(),
    )
|
142 |
+
|
143 |
+
|
144 |
+
class softmax_out(nn.Module, PyTorchModelHubMixin):
    """Output head: a 5x5x5 conv followed by a 1x1x1 conv.

    Despite the name, NO softmax/sigmoid is applied here — the network emits
    raw logits and thresholding is done manually downstream.
    """

    def __init__(self, in_channels, out_channels):
        super(softmax_out, self).__init__()
        self.conv_1 = nn.Conv3d(in_channels, out_channels, kernel_size=5, padding=2)
        self.conv_2 = nn.Conv3d(out_channels, out_channels, kernel_size=1, padding=0)

    def forward(self, x):
        """Return raw logits of shape [batch, out_channels, depth, height, width]."""
        # Deliberately no normalization layer here — per the original author,
        # adding one makes the values vanish.
        return self.conv_2(self.conv_1(x))
|
155 |
+
|
156 |
+
|
157 |
+
class VNet(nn.Module, PyTorchModelHubMixin):
    """Leaky V-Net: a 3D encoder-decoder with residual stages and skip connections.

    Takes a single-channel volume (batch, 1, D, H, W) and returns raw
    single-channel logits of the same spatial size; thresholding is applied
    externally. Four 2x downsamplings are undone by four 2x upsamplings, so
    D, H, W presumably must each be divisible by 16 — TODO confirm.
    """

    def __init__(self):
        super(VNet, self).__init__()
        # Encoder: residual conv stages interleaved with strided-conv "pooling".
        self.conv_1 = conv3d_x1(1, 16)
        self.pool_1 = conv3d_as_pool(16, 32)
        self.conv_2 = conv3d_x2(32, 32)
        self.pool_2 = conv3d_as_pool(32, 64)
        self.conv_3 = conv3d_x3(64, 64)
        self.pool_3 = conv3d_as_pool(64, 128)
        self.conv_4 = conv3d_x3(128, 128)
        self.pool_4 = conv3d_as_pool(128, 256)

        # Bottleneck at the coarsest resolution.
        self.bottom = conv3d_x3(256, 256)

        # Decoder: transposed-conv upsampling fused with the encoder skips.
        self.deconv_4 = deconv3d_x3(256, 256)
        self.deconv_3 = deconv3d_x3(256, 128)
        self.deconv_2 = deconv3d_x2(128, 64)
        self.deconv_1 = deconv3d_x1(64, 32)

        # Raw logits out (no sigmoid — see softmax_out).
        self.out = softmax_out(32, 1)

    def forward(self, x):
        # Encoder path, keeping each stage's output for the skip connections.
        skip_1 = self.conv_1(x)
        skip_2 = self.conv_2(self.pool_1(skip_1))
        skip_3 = self.conv_3(self.pool_2(skip_2))
        skip_4 = self.conv_4(self.pool_3(skip_3))
        bottom = self.bottom(self.pool_4(skip_4))
        # Decoder path, fusing each upsampled map with its encoder skip.
        up = self.deconv_4(skip_4, bottom)
        up = self.deconv_3(skip_3, up)
        up = self.deconv_2(skip_2, up)
        up = self.deconv_1(skip_1, up)
        return self.out(up)
|