# Divyanshu Tak
# Initial commit of BrainIAC Docker application
# f5288df
import numpy as np
import torch
from HD_BET.utils import SetNetworkToVal, softmax_helper
from abc import abstractmethod
from HD_BET.network_architecture import Network
class BaseConfig(object):
    """Common interface for experiment configurations.

    Subclasses are expected to override the hooks below. The hooks carry
    ``@abstractmethod`` markers, but the class does not use ``ABCMeta``, so
    the markers are not enforced and ``BaseConfig`` itself can be
    instantiated; every hook then simply returns ``None``.
    """

    def __init__(self):
        """Create a configuration that carries no attributes yet."""

    @abstractmethod
    def get_split(self, fold, random_state=12345):
        """Hook for subclasses; the base version does nothing."""

    @abstractmethod
    def get_network(self, mode="train"):
        """Hook for subclasses; the base version does nothing."""

    @abstractmethod
    def get_basic_generators(self, fold):
        """Hook for subclasses; the base version does nothing."""

    @abstractmethod
    def get_data_generators(self, fold):
        """Hook for subclasses; the base version does nothing."""

    def preprocess(self, data):
        """Return ``data`` unchanged (identity preprocessing)."""
        return data

    def __repr__(self):
        """Render the public instance attributes, one ``name: value`` per line.

        Attributes whose name starts with an underscore, and the attribute
        called ``dataset``, are left out. Every emitted line ends with a
        newline; an instance without visible attributes yields ``""``.
        """
        shown = (
            name for name in vars(self)
            if not (name.startswith("_") or name == "dataset")
        )
        return "".join(
            "{}: {}\n".format(name, str(getattr(self, name))) for name in shown
        )
class HD_BET_Config(BaseConfig):
    """Configuration for the HD-BET network.

    Holds the network/validation settings as plain instance attributes,
    builds the ``Network`` in ``get_network`` and normalises input volumes in
    ``preprocess``. The data-split and generator hooks are deliberate no-ops.
    """

    def __init__(self):
        super(HD_BET_Config, self).__init__()

        self.EXPERIMENT_NAME = self.__class__.__name__  # just a generic experiment name

        # network parameters
        self.net_base_num_layers = 21
        self.BATCH_SIZE = 2
        self.net_do_DS = True
        self.net_dropout_p = 0.0
        self.net_use_inst_norm = True
        self.net_conv_use_bias = True
        self.net_norm_use_affine = True
        self.net_leaky_relu_slope = 1e-1

        # hyperparameters
        self.INPUT_PATCH_SIZE = (128, 128, 128)
        self.num_classes = 2
        self.selected_data_channels = range(1)

        # data augmentation
        self.da_mirror_axes = (2, 3, 4)

        # validation
        self.val_use_DO = False
        self.val_use_train_mode = False  # for dropout sampling
        self.val_num_repeats = 1  # only useful if dropout sampling
        self.val_batch_size = 1  # only useful if dropout sampling
        self.val_save_npz = True
        self.val_do_mirroring = True  # test time data augmentation via mirroring
        self.val_write_images = True
        self.net_input_must_be_divisible_by = 16  # we could make a network class that has this as a property
        self.val_min_size = self.INPUT_PATCH_SIZE
        self.val_fn = None

        # CAREFUL! THIS IS A HACK TO MAKE PYTORCH 0.3 STATE DICTS COMPATIBLE
        # WITH PYTORCH 0.4: running statistics are kept in the state dict
        # (keep_running_stats=True) but are not used in validation.
        # keep_running_stats was True before 0.3 but unused, and defaults to
        # False in 0.4.
        self.val_use_moving_averages = False

    def get_network(self, train=True, pretrained_weights=None):
        """Build the network and put it into training or validation mode.

        Args:
            train: if truthy the network is switched to training mode;
                otherwise it is switched to evaluation mode, ``SetNetworkToVal``
                is applied to it and its ``do_ds`` attribute is turned off.
            pretrained_weights: optional checkpoint path (or file-like object)
                passed to ``torch.load``; the loaded state dict is applied to
                the network before the mode switch.

        Returns:
            A ``(net, optimizer)`` tuple; ``optimizer`` is always ``None``.

        Side effects:
            Sets ``self.lr_scheduler`` to ``None``.
        """
        # NOTE(review): the ninth positional argument is a hard-coded True and
        # self.net_use_inst_norm is never passed on; what that slot means
        # depends on Network's signature, which is not visible here.
        net = Network(self.num_classes, len(self.selected_data_channels), self.net_base_num_layers,
                      self.net_dropout_p, softmax_helper, self.net_leaky_relu_slope, self.net_conv_use_bias,
                      self.net_norm_use_affine, True, self.net_do_DS)

        if pretrained_weights is not None:
            # SECURITY: torch.load unpickles the file, so only checkpoints from
            # trusted sources must be loaded. The map_location callable returns
            # each storage as deserialised, so GPU-saved weights also load on
            # CPU-only machines.
            net.load_state_dict(
                torch.load(pretrained_weights, map_location=lambda storage, loc: storage))

        if train:
            net.train(True)
        else:
            net.train(False)
            net.apply(SetNetworkToVal(self.val_use_DO, self.val_use_moving_averages))
            net.do_ds = False

        optimizer = None
        self.lr_scheduler = None
        return net, optimizer

    def get_data_generators(self, fold):
        """No-op: this configuration provides no data generators."""
        pass

    def get_split(self, fold, random_state=12345):
        """No-op: this configuration provides no data split."""
        pass

    def get_basic_generators(self, fold):
        """No-op: this configuration provides no basic generators."""
        pass

    def on_epoch_end(self, epoch):
        """No-op epoch hook."""
        pass

    def preprocess(self, data):
        """Normalise each entry along the first axis to zero mean, unit std.

        The input is never modified; a normalised copy is returned.

        Args:
            data: array-like whose first axis is iterated (one entry per
                index, presumably one channel each -- confirm with callers).

        Returns:
            A new ``numpy.ndarray``. Floating-point input keeps its dtype;
            integer and boolean input is promoted to ``float32`` because the
            in-place float arithmetic below would otherwise raise a casting
            error. An entry with zero standard deviation (constant values) is
            only mean-centred, i.e. becomes all zeros, instead of being
            divided by zero and turned into NaN.
        """
        data = np.copy(data)
        if data.dtype.kind in "biu":
            # bool / signed / unsigned ints cannot hold the normalised values
            data = data.astype(np.float32)

        for c in range(data.shape[0]):
            data[c] -= data[c].mean()
            std = data[c].std()
            # Skip the division for constant entries (std == 0); a NaN std
            # fails the comparison too, and those values are already NaN.
            if std > 0:
                data[c] /= std
        return data
# Module-level alias: `config` is bound to the HD_BET_Config class itself (not an instance).
config = HD_BET_Config