LouayMagdy commited on
Commit
235e4f9
·
verified ·
1 Parent(s): 107f59e

Upload folder using huggingface_hub

Browse files
Files changed (5) hide show
  1. base_model.py +69 -0
  2. inference.py +75 -0
  3. model_loader.py +21 -0
  4. vnet_blocks.py +110 -0
  5. vnet_light_arch.py +38 -0
base_model.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from abc import ABC, abstractmethod
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torchsummary import summary
7
+
8
+ class BaseModel(nn.Module, ABC):
9
+ def __init__(self):
10
+ super().__init__()
11
+ self.best_loss = 1000000
12
+
13
+ @abstractmethod
14
+ def forward(self, x):
15
+ pass
16
+
17
+ @abstractmethod
18
+ def test(self):
19
+ pass
20
+
21
+ @property
22
+ def device(self):
23
+ return next(self.parameters()).device
24
+
25
+ def restore_checkpoint(self, ckpt_file, optimizer=None, affect_weights=True):
26
+ """
27
+ Restores checkpoint from a pth file and restores optimizer state.
28
+ Args:
29
+ ckpt_file (str): A PyTorch pth file containing model weights.
30
+ optimizer (Optimizer): A vanilla optimizer to have its state restored from.
31
+ Returns:
32
+ int: Global step variable where the model was last checkpointed.
33
+ """
34
+ if not ckpt_file:
35
+ raise ValueError("No checkpoint file to be restored.")
36
+
37
+ try:
38
+ ckpt_dict = torch.load(ckpt_file)
39
+
40
+ except RuntimeError:
41
+ ckpt_dict = torch.load(ckpt_file, map_location=lambda storage, loc: storage)
42
+ # Restore model weights if needed
43
+ if affect_weights:
44
+ self.load_state_dict(ckpt_dict['model_state_dict'])
45
+ # Restore optimizer status if existing. Evaluation doesn't need this
46
+ if optimizer:
47
+ optimizer.load_state_dict(ckpt_dict['optimizer_state_dict'])
48
+ # Return global step
49
+ return ckpt_dict, optimizer
50
+
51
+ def count_params(self):
52
+ """
53
+ Computes the number of parameters in this model.
54
+ Args: None
55
+ Returns:
56
+ int: Total number of weight parameters for this model.
57
+ int: Total number of trainable parameters for this model.
58
+ """
59
+ num_total_params = sum(p.numel() for p in self.parameters())
60
+ num_trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
61
+ return num_total_params, num_trainable_params
62
+
63
+ def inference(self, input_tensor):
64
+ self.eval()
65
+ with torch.no_grad():
66
+ output = self.forward(input_tensor)
67
+ if isinstance(output, tuple):
68
+ output = output[0]
69
+ return output.cpu().detach()
inference.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import tempfile
3
+ import numpy as np
4
+ import nibabel as nib
5
+ from fastapi import FastAPI, UploadFile, File
6
+
7
+ def preprocess_input(flair_file: UploadFile, t1_file: UploadFile, t1ce_file: UploadFile, t2_file: UploadFile):
8
+ with tempfile.NamedTemporaryFile(suffix=".nii") as temp_flair, \
9
+ tempfile.NamedTemporaryFile(suffix=".nii") as temp_t1, \
10
+ tempfile.NamedTemporaryFile(suffix=".nii") as temp_t1ce, \
11
+ tempfile.NamedTemporaryFile(suffix=".nii") as temp_t2:
12
+
13
+ # Save uploaded files to temporary NIfTI files
14
+ flair_content = flair_file.file.read()
15
+ t1_content = t1_file.file.read()
16
+ t1ce_content = t1ce_file.file.read()
17
+ t2_content = t2_file.file.read()
18
+
19
+ temp_flair.write(flair_content)
20
+ temp_t1.write(t1_content)
21
+ temp_t1ce.write(t1ce_content)
22
+ temp_t2.write(t2_content)
23
+
24
+ # Load and preprocess the NIfTI files
25
+ flair = nib.load(temp_flair.name).get_fdata()
26
+ t1 = nib.load(temp_t1.name).get_fdata()
27
+ t1ce = nib.load(temp_t1ce.name).get_fdata()
28
+ t2 = nib.load(temp_t2.name).get_fdata()
29
+
30
+ flair_tensor = torch.tensor(flair[56:184, 56:184, 13:141], dtype=torch.float32).unsqueeze(0)
31
+ t1_tensor = torch.tensor(t1[56:184, 56:184, 13:141], dtype=torch.float32).unsqueeze(0)
32
+ t1ce_tensor = torch.tensor(t1ce[56:184, 56:184, 13:141], dtype=torch.float32).unsqueeze(0)
33
+ t2_tensor = torch.tensor(t2[56:184, 56:184, 13:141], dtype=torch.float32).unsqueeze(0)
34
+
35
+ return flair_tensor, t1_tensor, t1ce_tensor, t2_tensor
36
+
37
+
38
+ def segment_wt(wt_model, flair_tensor, t1_tensor, t1ce_tensor, t2_tensor):
39
+ wt_input = torch.cat((flair_tensor, t1_tensor, t1ce_tensor, t2_tensor), dim=0).unsqueeze(0)
40
+ with torch.no_grad():
41
+ wt_output = wt_model(wt_input)
42
+ return wt_output
43
+
44
+
45
+ def segment_tc(tc_model, flair_tensor, t1_tensor, t1ce_tensor, t2_tensor, wt_output):
46
+ tc_input = torch.cat((flair_tensor, t1_tensor, t1ce_tensor, t2_tensor, wt_output.squeeze(0),
47
+ wt_output.squeeze(0), wt_output.squeeze(0), wt_output.squeeze(0)), dim=0).unsqueeze(0)
48
+ with torch.no_grad():
49
+ tc_output = tc_model(tc_input)
50
+ return tc_output
51
+
52
+ def segment_et(et_model, flair_tensor, t1_tensor, t1ce_tensor, t2_tensor, wt_output, tc_output):
53
+ et_input = torch.cat((flair_tensor, t1_tensor, t1ce_tensor, t2_tensor, wt_output.squeeze(0),
54
+ wt_output.squeeze(0), tc_output.squeeze(0), tc_output.squeeze(0)), dim=0).unsqueeze(0)
55
+ with torch.no_grad():
56
+ et_output = et_model(et_input)
57
+ return et_output
58
+
59
+ def segment_all(wt_model, tc_model, et_model, flair_tensor, t1_tensor, t1ce_tensor, t2_tensor):
60
+ wt_output = segment_wt(wt_model, flair_tensor, t1_tensor, t1ce_tensor, t2_tensor)
61
+ tc_output = segment_tc(tc_model, flair_tensor, t1_tensor, t1ce_tensor, t2_tensor, wt_output)
62
+ et_output = segment_et(et_model, flair_tensor, t1_tensor, t1ce_tensor, t2_tensor, wt_output, tc_output)
63
+
64
+ wt_label = torch.sigmoid(wt_output.squeeze(0, 1)) > 0.5
65
+ tc_label = torch.sigmoid(tc_output.squeeze(0, 1)) > 0.5
66
+ et_label = torch.sigmoid(et_output.squeeze(0, 1)) > 0.5
67
+
68
+ output = np.zeros((128, 128, 128))
69
+ output[(tc_label == 1) & (wt_label == 0)] = 2
70
+ output[(et_label == 1) & (tc_label == 0)] = 3
71
+ output[(et_label == 0) & (tc_label == 0)] = 1
72
+
73
+ output_in_original_size = np.zeros((240, 240, 155))
74
+ output_in_original_size[56:184, 56:184, 13:141] = output
75
+ return output_in_original_size
model_loader.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from vnet_light_arch import VNetLight
3
+
4
+ def load_light_model(checkpoint_path, channels):
5
+ print(f"Loading The Saved Model at {checkpoint_path}")
6
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
+ model = VNetLight(in_channels=channels, classes=1)
8
+ model.to(device)
9
+ model.restore_checkpoint(checkpoint_path)
10
+ for param in model.parameters():
11
+ param.requires_grad = False
12
+ return model
13
+
14
+ def load_vChain_model(wt_chkpt, tc_chkpt, et_chkpt):
15
+ wt_model = load_light_model(wt_chkpt, 4)
16
+ tc_model = load_light_model(tc_chkpt, 8)
17
+ et_model = load_light_model(et_chkpt, 8)
18
+ return wt_model, tc_model, et_model
19
+
20
+ # see this for modifying loss function:
21
+ # https://docs.monai.io/en/latest/losses.html
vnet_blocks.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from abc import ABC, abstractmethod
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torchsummary import summary
7
+
8
+ from base_model import BaseModel
9
+
10
+ def passthrough(x, **kwargs):
11
+ return x
12
+
13
+ def ELUCons(elu, nchan):
14
+ if elu:
15
+ return nn.ELU(inplace=True)
16
+ else:
17
+ return nn.PReLU(nchan)
18
+
19
+ class LUConv(nn.Module):
20
+ def __init__(self, nchan, elu):
21
+ super(LUConv, self).__init__()
22
+ self.relu1 = ELUCons(elu, nchan)
23
+ self.conv1 = nn.Conv3d(nchan, nchan, kernel_size=5, padding=2)
24
+ self.bn1 = torch.nn.BatchNorm3d(nchan)
25
+
26
+ def forward(self, x):
27
+ out = self.relu1(self.bn1(self.conv1(x)))
28
+ return out
29
+
30
+ def _make_nConv(nchan, depth, elu):
31
+ layers = []
32
+ for _ in range(depth):
33
+ layers.append(LUConv(nchan, elu))
34
+ return nn.Sequential(*layers)
35
+
36
+ class InputTransition(nn.Module):
37
+ def __init__(self, in_channels, elu):
38
+ super(InputTransition, self).__init__()
39
+ self.num_features = 16
40
+ self.in_channels = in_channels
41
+ self.conv1 = nn.Conv3d(self.in_channels, self.num_features, kernel_size=5, padding=2)
42
+ self.bn1 = torch.nn.BatchNorm3d(self.num_features)
43
+ self.relu1 = ELUCons(elu, self.num_features)
44
+
45
+ def forward(self, x):
46
+ out = self.conv1(x)
47
+ repeat_rate = int(self.num_features / self.in_channels)
48
+ out = self.bn1(out)
49
+ x16 = x.repeat(1, repeat_rate, 1, 1, 1)
50
+ return self.relu1(torch.add(out, x16))
51
+
52
+ class DownTransition(nn.Module):
53
+ def __init__(self, inChans, nConvs, elu, dropout=False):
54
+ super(DownTransition, self).__init__()
55
+ outChans = 2 * inChans
56
+ self.down_conv = nn.Conv3d(inChans, outChans, kernel_size=2, stride=2)
57
+ self.bn1 = torch.nn.BatchNorm3d(outChans)
58
+
59
+ self.do1 = passthrough
60
+ self.relu1 = ELUCons(elu, outChans)
61
+ self.relu2 = ELUCons(elu, outChans)
62
+ if dropout:
63
+ self.do1 = nn.Dropout3d()
64
+ self.ops = _make_nConv(outChans, nConvs, elu)
65
+
66
+ def forward(self, x):
67
+ down = self.relu1(self.bn1(self.down_conv(x)))
68
+ out = self.do1(down)
69
+ out = self.ops(out)
70
+ out = self.relu2(torch.add(out, down))
71
+ return out
72
+
73
+ class UpTransition(nn.Module):
74
+ def __init__(self, inChans, outChans, nConvs, elu, dropout=False):
75
+ super(UpTransition, self).__init__()
76
+ self.up_conv = nn.ConvTranspose3d(inChans, outChans // 2, kernel_size=2, stride=2)
77
+
78
+ self.bn1 = torch.nn.BatchNorm3d(outChans // 2)
79
+ self.do1 = passthrough
80
+ self.do2 = nn.Dropout3d()
81
+ self.relu1 = ELUCons(elu, outChans // 2)
82
+ self.relu2 = ELUCons(elu, outChans)
83
+ if dropout:
84
+ self.do1 = nn.Dropout3d()
85
+ self.ops = _make_nConv(outChans, nConvs, elu)
86
+
87
+ def forward(self, x, skipx):
88
+ out = self.do1(x)
89
+ skipxdo = self.do2(skipx)
90
+ out = self.relu1(self.bn1(self.up_conv(out)))
91
+ xcat = torch.cat((out, skipxdo), 1)
92
+ out = self.ops(xcat)
93
+ out = self.relu2(torch.add(out, xcat))
94
+ return out
95
+
96
+ class OutputTransition(nn.Module):
97
+ def __init__(self, in_channels, classes, elu):
98
+ super(OutputTransition, self).__init__()
99
+ self.classes = classes
100
+ self.conv1 = nn.Conv3d(in_channels, classes, kernel_size=5, padding=2)
101
+ self.bn1 = torch.nn.BatchNorm3d(classes)
102
+
103
+ self.conv2 = nn.Conv3d(classes, classes, kernel_size=1)
104
+ self.relu1 = ELUCons(elu, classes)
105
+
106
+ def forward(self, x):
107
+ # convolve 32 down to channels as the desired classes
108
+ out = self.relu1(self.bn1(self.conv1(x)))
109
+ out = self.conv2(out)
110
+ return out
vnet_light_arch.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from vnet_blocks import *
2
+
3
+ class VNetLight(BaseModel):
4
+ """
5
+ A lighter version of Vnet that skips down_tr256 and up_tr256 in oreder to reduce time and space complexity
6
+ """
7
+ def __init__(self, elu=True, in_channels=1, classes=4):
8
+ super(VNetLight, self).__init__()
9
+ self.classes = classes
10
+ self.in_channels = in_channels
11
+
12
+ self.in_tr = InputTransition(in_channels, elu)
13
+ self.down_tr32 = DownTransition(16, 1, elu)
14
+ self.down_tr64 = DownTransition(32, 2, elu)
15
+ self.down_tr128 = DownTransition(64, 3, elu, dropout=True)
16
+ self.up_tr128 = UpTransition(128, 128, 2, elu, dropout=True)
17
+ self.up_tr64 = UpTransition(128, 64, 1, elu)
18
+ self.up_tr32 = UpTransition(64, 32, 1, elu)
19
+ self.out_tr = OutputTransition(32, classes, elu)
20
+
21
+ def forward(self, x):
22
+ out16 = self.in_tr(x)
23
+ out32 = self.down_tr32(out16)
24
+ out64 = self.down_tr64(out32)
25
+ out128 = self.down_tr128(out64)
26
+ out = self.up_tr128(out128, out64)
27
+ out = self.up_tr64(out, out32)
28
+ out = self.up_tr32(out, out16)
29
+ out = self.out_tr(out)
30
+ return out
31
+
32
+ def test(self,device='cpu'):
33
+ input_tensor = torch.rand(1, self.in_channels, 32, 32, 32)
34
+ ideal_out = torch.rand(1, self.classes, 32, 32, 32)
35
+ out = self.forward(input_tensor)
36
+ assert ideal_out.shape == out.shape
37
+ summary(self.to(torch.device(device)), (self.in_channels, 32, 32, 32),device=device)
38
+ print("Vnet Light test is complete")