Upload folder using huggingface_hub
- base_model.py +69 -0
- inference.py +75 -0
- model_loader.py +21 -0
- vnet_blocks.py +110 -0
- vnet_light_arch.py +38 -0
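Together, these files implement a cascaded ("vChain") brain-tumor segmentation pipeline: three lightweight V-Net models predict the whole tumor (WT), tumor core (TC), and enhancing tumor (ET) masks in sequence, with each later model conditioned on the earlier predictions.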
base_model.py
ADDED
@@ -0,0 +1,69 @@
+from abc import ABC, abstractmethod
+
+import torch
+import torch.nn as nn
+
+class BaseModel(nn.Module, ABC):
+    def __init__(self):
+        super().__init__()
+        self.best_loss = 1000000
+
+    @abstractmethod
+    def forward(self, x):
+        pass
+
+    @abstractmethod
+    def test(self):
+        pass
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    def restore_checkpoint(self, ckpt_file, optimizer=None, affect_weights=True):
+        """
+        Restores model weights from a .pth checkpoint and, optionally, the optimizer state.
+        Args:
+            ckpt_file (str): A PyTorch .pth file containing model weights.
+            optimizer (Optimizer): An optimizer whose state should be restored.
+            affect_weights (bool): If True, load the model weights from the checkpoint.
+        Returns:
+            tuple: The loaded checkpoint dict and the (possibly restored) optimizer.
+        """
+        if not ckpt_file:
+            raise ValueError("No checkpoint file to be restored.")
+
+        try:
+            ckpt_dict = torch.load(ckpt_file)
+        # Fall back to CPU loading when the checkpoint was saved on another device
+        except RuntimeError:
+            ckpt_dict = torch.load(ckpt_file, map_location=lambda storage, loc: storage)
+        # Restore model weights if needed
+        if affect_weights:
+            self.load_state_dict(ckpt_dict['model_state_dict'])
+        # Restore optimizer state if given. Evaluation doesn't need this
+        if optimizer:
+            optimizer.load_state_dict(ckpt_dict['optimizer_state_dict'])
+        return ckpt_dict, optimizer
+
+    def count_params(self):
+        """
+        Computes the number of parameters in this model.
+        Args: None
+        Returns:
+            int: Total number of weight parameters for this model.
+            int: Total number of trainable parameters for this model.
+        """
+        num_total_params = sum(p.numel() for p in self.parameters())
+        num_trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
+        return num_total_params, num_trainable_params
+
+    def inference(self, input_tensor):
+        # Switch to eval mode and run a no-grad forward pass
+        self.eval()
+        with torch.no_grad():
+            output = self.forward(input_tensor)
+            if isinstance(output, tuple):
+                output = output[0]
+            return output.cpu().detach()
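As a brief usage sketch (the checkpoint path is hypothetical, and VNetLight from this commit serves as the concrete subclass): a BaseModel subclass only implements forward and test, and then gets checkpoint restoration, parameter counting, and no-grad inference for free.

import torch
from vnet_light_arch import VNetLight

model = VNetLight(in_channels=4, classes=1)
model.restore_checkpoint("checkpoints/wt_model.pth")  # hypothetical path
total, trainable = model.count_params()
print(f"{total} parameters ({trainable} trainable) on {model.device}")

volume = torch.rand(1, 4, 128, 128, 128)  # dummy 4-modality input
prediction = model.inference(volume)      # eval mode + torch.no_grad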
inference.py
ADDED
@@ -0,0 +1,75 @@
+import torch
+import tempfile
+import numpy as np
+import nibabel as nib
+from fastapi import UploadFile
+
+def preprocess_input(flair_file: UploadFile, t1_file: UploadFile, t1ce_file: UploadFile, t2_file: UploadFile):
+    with tempfile.NamedTemporaryFile(suffix=".nii") as temp_flair, \
+            tempfile.NamedTemporaryFile(suffix=".nii") as temp_t1, \
+            tempfile.NamedTemporaryFile(suffix=".nii") as temp_t1ce, \
+            tempfile.NamedTemporaryFile(suffix=".nii") as temp_t2:
+
+        # Save uploaded files to temporary NIfTI files
+        temp_flair.write(flair_file.file.read())
+        temp_t1.write(t1_file.file.read())
+        temp_t1ce.write(t1ce_file.file.read())
+        temp_t2.write(t2_file.file.read())
+
+        # Flush so nibabel sees the full contents on disk
+        temp_flair.flush()
+        temp_t1.flush()
+        temp_t1ce.flush()
+        temp_t2.flush()
+
+        # Load the NIfTI files
+        flair = nib.load(temp_flair.name).get_fdata()
+        t1 = nib.load(temp_t1.name).get_fdata()
+        t1ce = nib.load(temp_t1ce.name).get_fdata()
+        t2 = nib.load(temp_t2.name).get_fdata()
+
+        # Crop each 240x240x155 BraTS volume to a 128x128x128 center region
+        flair_tensor = torch.tensor(flair[56:184, 56:184, 13:141], dtype=torch.float32).unsqueeze(0)
+        t1_tensor = torch.tensor(t1[56:184, 56:184, 13:141], dtype=torch.float32).unsqueeze(0)
+        t1ce_tensor = torch.tensor(t1ce[56:184, 56:184, 13:141], dtype=torch.float32).unsqueeze(0)
+        t2_tensor = torch.tensor(t2[56:184, 56:184, 13:141], dtype=torch.float32).unsqueeze(0)
+
+        return flair_tensor, t1_tensor, t1ce_tensor, t2_tensor
+
+
+def segment_wt(wt_model, flair_tensor, t1_tensor, t1ce_tensor, t2_tensor):
+    # Stack the four modalities into a (1, 4, 128, 128, 128) batch on the model's device
+    wt_input = torch.cat((flair_tensor, t1_tensor, t1ce_tensor, t2_tensor), dim=0).unsqueeze(0).to(wt_model.device)
+    with torch.no_grad():
+        wt_output = wt_model(wt_input)
+    return wt_output
+
+
+def segment_tc(tc_model, flair_tensor, t1_tensor, t1ce_tensor, t2_tensor, wt_output):
+    # The TC model takes the four modalities plus the WT prediction repeated four times (8 channels)
+    wt_map = wt_output.squeeze(0).cpu()
+    tc_input = torch.cat((flair_tensor, t1_tensor, t1ce_tensor, t2_tensor,
+                          wt_map, wt_map, wt_map, wt_map), dim=0).unsqueeze(0).to(tc_model.device)
+    with torch.no_grad():
+        tc_output = tc_model(tc_input)
+    return tc_output
+
+def segment_et(et_model, flair_tensor, t1_tensor, t1ce_tensor, t2_tensor, wt_output, tc_output):
+    # The ET model takes the four modalities plus two copies each of the WT and TC predictions (8 channels)
+    wt_map = wt_output.squeeze(0).cpu()
+    tc_map = tc_output.squeeze(0).cpu()
+    et_input = torch.cat((flair_tensor, t1_tensor, t1ce_tensor, t2_tensor,
+                          wt_map, wt_map, tc_map, tc_map), dim=0).unsqueeze(0).to(et_model.device)
+    with torch.no_grad():
+        et_output = et_model(et_input)
+    return et_output
+
+def segment_all(wt_model, tc_model, et_model, flair_tensor, t1_tensor, t1ce_tensor, t2_tensor):
+    wt_output = segment_wt(wt_model, flair_tensor, t1_tensor, t1ce_tensor, t2_tensor)
+    tc_output = segment_tc(tc_model, flair_tensor, t1_tensor, t1ce_tensor, t2_tensor, wt_output)
+    et_output = segment_et(et_model, flair_tensor, t1_tensor, t1ce_tensor, t2_tensor, wt_output, tc_output)
+
+    # Threshold the logits into binary masks (Tensor.squeeze(0, 1) needs PyTorch >= 2.0)
+    wt_label = (torch.sigmoid(wt_output.squeeze(0, 1)) > 0.5).cpu().numpy()
+    tc_label = (torch.sigmoid(tc_output.squeeze(0, 1)) > 0.5).cpu().numpy()
+    et_label = (torch.sigmoid(et_output.squeeze(0, 1)) > 0.5).cpu().numpy()
+
+    # Merge the three binary masks into a single label map (labels 1-3)
+    output = np.zeros((128, 128, 128))
+    output[(tc_label == 1) & (wt_label == 0)] = 2
+    output[(et_label == 1) & (tc_label == 0)] = 3
+    output[(et_label == 0) & (tc_label == 0)] = 1
+
+    # Pad the 128^3 crop back into the original 240x240x155 volume
+    output_in_original_size = np.zeros((240, 240, 155))
+    output_in_original_size[56:184, 56:184, 13:141] = output
+    return output_in_original_size
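inference.py imports FastAPI types but defines no app. A hedged sketch of how these helpers might be exposed as an endpoint (the app, route name, checkpoint paths, and response format are assumptions, not part of this commit):

# Hypothetical FastAPI wiring; only preprocess_input/segment_all come from this repo
from fastapi import FastAPI, UploadFile, File
from model_loader import load_vChain_model
from inference import preprocess_input, segment_all

app = FastAPI()
wt_model, tc_model, et_model = load_vChain_model(
    "wt.pth", "tc.pth", "et.pth")  # hypothetical checkpoint paths

@app.post("/segment")
async def segment(flair: UploadFile = File(...), t1: UploadFile = File(...),
                  t1ce: UploadFile = File(...), t2: UploadFile = File(...)):
    tensors = preprocess_input(flair, t1, t1ce, t2)
    label_map = segment_all(wt_model, tc_model, et_model, *tensors)
    # Return e.g. per-label voxel counts; serializing the full volume is app-specific
    return {"voxels_per_label": {k: int((label_map == k).sum()) for k in (1, 2, 3)}}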
model_loader.py
ADDED
@@ -0,0 +1,21 @@
+import torch
+from vnet_light_arch import VNetLight
+
+def load_light_model(checkpoint_path, channels):
+    print(f"Loading the saved model at {checkpoint_path}")
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model = VNetLight(in_channels=channels, classes=1)
+    model.to(device)
+    model.restore_checkpoint(checkpoint_path)
+    # Freeze the weights and switch to eval mode: these models are used for inference only
+    for param in model.parameters():
+        param.requires_grad = False
+    model.eval()
+    return model
+
+def load_vChain_model(wt_chkpt, tc_chkpt, et_chkpt):
+    # WT takes the 4 MRI modalities; TC and ET additionally take earlier predictions (8 channels)
+    wt_model = load_light_model(wt_chkpt, 4)
+    tc_model = load_light_model(tc_chkpt, 8)
+    et_model = load_light_model(et_chkpt, 8)
+    return wt_model, tc_model, et_model
+
+# see this for modifying loss function:
+# https://docs.monai.io/en/latest/losses.html
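A brief usage sketch (the checkpoint filenames are placeholders): load_vChain_model returns the three frozen models in WT, TC, ET order, ready to pass to segment_all.

# Hypothetical checkpoint paths; in_channels (4, 8, 8) are fixed inside load_vChain_model
wt_model, tc_model, et_model = load_vChain_model(
    "checkpoints/wt.pth", "checkpoints/tc.pth", "checkpoints/et.pth")
print(wt_model.count_params())  # (total, trainable) -- trainable is 0 after freezing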
vnet_blocks.py
ADDED
@@ -0,0 +1,110 @@
+import torch
+import torch.nn as nn
+from torchsummary import summary
+
+# BaseModel, torch, and summary are also re-used by vnet_light_arch via its star import
+from base_model import BaseModel
+
+def passthrough(x, **kwargs):
+    return x
+
+def ELUCons(elu, nchan):
+    # Choose ELU or channel-wise PReLU as the activation
+    if elu:
+        return nn.ELU(inplace=True)
+    else:
+        return nn.PReLU(nchan)
+
+class LUConv(nn.Module):
+    def __init__(self, nchan, elu):
+        super(LUConv, self).__init__()
+        self.relu1 = ELUCons(elu, nchan)
+        self.conv1 = nn.Conv3d(nchan, nchan, kernel_size=5, padding=2)
+        self.bn1 = torch.nn.BatchNorm3d(nchan)
+
+    def forward(self, x):
+        out = self.relu1(self.bn1(self.conv1(x)))
+        return out
+
+def _make_nConv(nchan, depth, elu):
+    layers = []
+    for _ in range(depth):
+        layers.append(LUConv(nchan, elu))
+    return nn.Sequential(*layers)
+
+class InputTransition(nn.Module):
+    def __init__(self, in_channels, elu):
+        super(InputTransition, self).__init__()
+        self.num_features = 16
+        self.in_channels = in_channels
+        self.conv1 = nn.Conv3d(self.in_channels, self.num_features, kernel_size=5, padding=2)
+        self.bn1 = torch.nn.BatchNorm3d(self.num_features)
+        self.relu1 = ELUCons(elu, self.num_features)
+
+    def forward(self, x):
+        out = self.conv1(x)
+        out = self.bn1(out)
+        # Repeat the input along the channel axis so it can be added as a residual
+        repeat_rate = int(self.num_features / self.in_channels)
+        x16 = x.repeat(1, repeat_rate, 1, 1, 1)
+        return self.relu1(torch.add(out, x16))
+
+class DownTransition(nn.Module):
+    def __init__(self, inChans, nConvs, elu, dropout=False):
+        super(DownTransition, self).__init__()
+        outChans = 2 * inChans
+        # Strided convolution halves the spatial resolution and doubles the channels
+        self.down_conv = nn.Conv3d(inChans, outChans, kernel_size=2, stride=2)
+        self.bn1 = torch.nn.BatchNorm3d(outChans)
+
+        self.do1 = passthrough
+        self.relu1 = ELUCons(elu, outChans)
+        self.relu2 = ELUCons(elu, outChans)
+        if dropout:
+            self.do1 = nn.Dropout3d()
+        self.ops = _make_nConv(outChans, nConvs, elu)
+
+    def forward(self, x):
+        down = self.relu1(self.bn1(self.down_conv(x)))
+        out = self.do1(down)
+        out = self.ops(out)
+        # Residual connection around the convolution stack
+        out = self.relu2(torch.add(out, down))
+        return out
+
+class UpTransition(nn.Module):
+    def __init__(self, inChans, outChans, nConvs, elu, dropout=False):
+        super(UpTransition, self).__init__()
+        # Transposed convolution doubles the spatial resolution
+        self.up_conv = nn.ConvTranspose3d(inChans, outChans // 2, kernel_size=2, stride=2)
+
+        self.bn1 = torch.nn.BatchNorm3d(outChans // 2)
+        self.do1 = passthrough
+        self.do2 = nn.Dropout3d()
+        self.relu1 = ELUCons(elu, outChans // 2)
+        self.relu2 = ELUCons(elu, outChans)
+        if dropout:
+            self.do1 = nn.Dropout3d()
+        self.ops = _make_nConv(outChans, nConvs, elu)
+
+    def forward(self, x, skipx):
+        out = self.do1(x)
+        skipxdo = self.do2(skipx)
+        out = self.relu1(self.bn1(self.up_conv(out)))
+        # Concatenate with the skip connection from the encoder path
+        xcat = torch.cat((out, skipxdo), 1)
+        out = self.ops(xcat)
+        out = self.relu2(torch.add(out, xcat))
+        return out
+
+class OutputTransition(nn.Module):
+    def __init__(self, in_channels, classes, elu):
+        super(OutputTransition, self).__init__()
+        self.classes = classes
+        self.conv1 = nn.Conv3d(in_channels, classes, kernel_size=5, padding=2)
+        self.bn1 = torch.nn.BatchNorm3d(classes)
+
+        self.conv2 = nn.Conv3d(classes, classes, kernel_size=1)
+        self.relu1 = ELUCons(elu, classes)
+
+    def forward(self, x):
+        # Convolve the 32 feature channels down to the desired number of classes
+        out = self.relu1(self.bn1(self.conv1(x)))
+        out = self.conv2(out)
+        return out
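A quick shape check of the building blocks (sizes chosen purely for illustration): DownTransition halves each spatial dimension while doubling the channels, and UpTransition reverses this using the encoder skip.

import torch
from vnet_blocks import DownTransition, UpTransition

down = DownTransition(16, 1, elu=True)
up = UpTransition(32, 32, 1, elu=True)

x = torch.rand(1, 16, 32, 32, 32)
d = down(x)    # -> (1, 32, 16, 16, 16)
u = up(d, x)   # upsample, then concat with the skip -> (1, 32, 32, 32, 32)
print(d.shape, u.shape)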
vnet_light_arch.py
ADDED
@@ -0,0 +1,38 @@
+# The star import brings in torch, summary, BaseModel, and the V-Net building blocks
+from vnet_blocks import *
+
+class VNetLight(BaseModel):
+    """
+    A lighter version of V-Net that skips down_tr256 and up_tr256 in order to reduce time and space complexity.
+    """
+    def __init__(self, elu=True, in_channels=1, classes=4):
+        super(VNetLight, self).__init__()
+        self.classes = classes
+        self.in_channels = in_channels
+
+        self.in_tr = InputTransition(in_channels, elu)
+        self.down_tr32 = DownTransition(16, 1, elu)
+        self.down_tr64 = DownTransition(32, 2, elu)
+        self.down_tr128 = DownTransition(64, 3, elu, dropout=True)
+        self.up_tr128 = UpTransition(128, 128, 2, elu, dropout=True)
+        self.up_tr64 = UpTransition(128, 64, 1, elu)
+        self.up_tr32 = UpTransition(64, 32, 1, elu)
+        self.out_tr = OutputTransition(32, classes, elu)
+
+    def forward(self, x):
+        out16 = self.in_tr(x)
+        out32 = self.down_tr32(out16)
+        out64 = self.down_tr64(out32)
+        out128 = self.down_tr128(out64)
+        out = self.up_tr128(out128, out64)
+        out = self.up_tr64(out, out32)
+        out = self.up_tr32(out, out16)
+        out = self.out_tr(out)
+        return out
+
+    def test(self, device='cpu'):
+        # Sanity-check that the output shape matches the input spatial shape
+        input_tensor = torch.rand(1, self.in_channels, 32, 32, 32)
+        ideal_out = torch.rand(1, self.classes, 32, 32, 32)
+        out = self.forward(input_tensor)
+        assert ideal_out.shape == out.shape
+        summary(self.to(torch.device(device)), (self.in_channels, 32, 32, 32), device=device)
+        print("VNetLight test is complete")
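The built-in test can be run directly to verify the architecture end to end on a 32^3 dummy volume (torchsummary must be installed):

from vnet_light_arch import VNetLight

# Smoke test on CPU with a dummy single-channel volume
model = VNetLight(elu=True, in_channels=1, classes=4)
model.test(device='cpu')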