|
import numpy as np |
|
import torch |
|
from HD_BET.utils import SetNetworkToVal, softmax_helper |
|
from abc import abstractmethod |
|
from HD_BET.network_architecture import Network |
|
|
|
|
|
class BaseConfig(object):
    """Skeleton for experiment configuration objects.

    Subclasses are expected to provide the data split, the network and the
    data generators, and may override ``preprocess``.

    NOTE: this class does not use ``abc.ABCMeta``, so the ``@abstractmethod``
    markers below are documentation only. Python does not stop anyone from
    instantiating ``BaseConfig`` or a subclass that leaves them unimplemented.
    The unimplemented methods simply return ``None``.
    """

    def __init__(self):
        """Nothing to initialise here; subclasses set their own attributes."""

    @abstractmethod
    def get_split(self, fold, random_state=12345):
        """Return the data split for ``fold`` (subclass responsibility)."""

    @abstractmethod
    def get_network(self, mode="train"):
        """Build and return the network (subclass responsibility).

        Subclasses may use a different signature for this method.
        """

    @abstractmethod
    def get_basic_generators(self, fold):
        """Return the basic data generators for ``fold`` (subclass responsibility)."""

    @abstractmethod
    def get_data_generators(self, fold):
        """Return the data generators for ``fold`` (subclass responsibility)."""

    def preprocess(self, data):
        """Identity preprocessing: hand ``data`` back unchanged (same object)."""
        return data

    def __repr__(self):
        """Render public instance attributes as ``name: value`` lines.

        Instance attributes whose name starts with an underscore are skipped,
        and so is the ``dataset`` attribute. Every rendered entry ends with a
        newline, so an instance with no public attributes renders as ``""``.
        """
        entries = [
            "%s: %s\n" % (name, self.__getattribute__(name))
            for name in vars(self)
            if not name.startswith("_") and name != 'dataset'
        ]
        return "".join(entries)
|
|
|
|
|
class HD_BET_Config(BaseConfig):
    """Configuration used to build and run the HD_BET ``Network``.

    Holds network hyper-parameters (``net_*``), the input patch size, and
    validation/inference options (``val_*``). ``get_network`` builds the
    network. The data-pipeline hooks (split, generators, epoch callback) are
    no-ops in this class.
    """

    def __init__(self):
        super(HD_BET_Config, self).__init__()

        self.EXPERIMENT_NAME = self.__class__.__name__

        # Network hyper-parameters (consumed by get_network below).
        self.net_base_num_layers = 21
        self.BATCH_SIZE = 2
        self.net_do_DS = True
        self.net_dropout_p = 0.0
        # NOTE(review): never read in this class; get_network passes a literal
        # True in the corresponding-looking position. Confirm against Network.
        self.net_use_inst_norm = True
        self.net_conv_use_bias = True
        self.net_norm_use_affine = True
        self.net_leaky_relu_slope = 1e-1

        # Data layout. len(selected_data_channels) is the network's input
        # channel count.
        self.INPUT_PATCH_SIZE = (128, 128, 128)
        self.num_classes = 2
        self.selected_data_channels = range(1)

        # Presumably the array axes used for mirroring; not read in this class.
        self.da_mirror_axes = (2, 3, 4)

        # Validation / inference options. val_use_DO and
        # val_use_moving_averages are forwarded to SetNetworkToVal in
        # get_network; the rest are not read in this class.
        self.val_use_DO = False
        self.val_use_train_mode = False
        self.val_num_repeats = 1
        self.val_batch_size = 1
        self.val_save_npz = True
        self.val_do_mirroring = True
        self.val_write_images = True
        self.net_input_must_be_divisible_by = 16
        self.val_min_size = self.INPUT_PATCH_SIZE
        self.val_fn = None
        self.val_use_moving_averages = False

    def get_network(self, train=True, pretrained_weights=None):
        """Build the network and optionally load pretrained weights.

        Args:
            train: if truthy, the network is put into training mode.
                Otherwise it is put into eval mode, ``SetNetworkToVal`` is
                applied with ``val_use_DO`` / ``val_use_moving_averages``, and
                ``net.do_ds`` is set to ``False`` (presumably disabling
                deep-supervision outputs).
            pretrained_weights: optional checkpoint passed to ``torch.load``.
                The resulting state dict is loaded into the network, with
                tensors mapped to CPU storage.

        Returns:
            ``(net, optimizer)``, where ``optimizer`` is always ``None``.
            Side effect: sets ``self.lr_scheduler`` to ``None``.
        """
        net = Network(self.num_classes, len(self.selected_data_channels), self.net_base_num_layers,
                      self.net_dropout_p, softmax_helper, self.net_leaky_relu_slope, self.net_conv_use_bias,
                      self.net_norm_use_affine, True, self.net_do_DS)

        if pretrained_weights is not None:
            # SECURITY: torch.load unpickles the file, so only load trusted
            # checkpoints. map_location keeps storages on the CPU regardless
            # of the device they were saved from.
            net.load_state_dict(
                torch.load(pretrained_weights, map_location=lambda storage, loc: storage))

        # bool() because train is tested for truthiness, while Module.train
        # expects a real bool.
        net.train(bool(train))
        if not train:
            net.apply(SetNetworkToVal(self.val_use_DO, self.val_use_moving_averages))
            net.do_ds = False

        optimizer = None
        self.lr_scheduler = None
        return net, optimizer

    def get_data_generators(self, fold):
        """No-op: this config does not provide data generators."""
        pass

    def get_split(self, fold, random_state=12345):
        """No-op: this config does not provide a data split."""
        pass

    def get_basic_generators(self, fold):
        """No-op: this config does not provide basic generators."""
        pass

    def on_epoch_end(self, epoch):
        """No-op end-of-epoch hook."""
        pass

    def preprocess(self, data):
        """Return a per-channel z-score-normalised copy of ``data``.

        Each slice ``data[c]`` along the first axis is shifted to zero mean
        and then divided by its standard deviation. The input is never
        modified.

        - Integer or bool input is promoted to float32 first. Without this,
          the in-place float arithmetic below raises a numpy casting error.
          Floating input keeps its dtype.
        - A channel whose standard deviation is 0 (e.g. constant or all-zero)
          is only mean-centred. Dividing it by zero would fill it with NaN,
          which then propagates through the network.

        Args:
            data: array-like whose first axis indexes channels. Channel-first
                layout is assumed here; TODO confirm against callers.

        Returns:
            numpy array with the same shape as ``data``.
        """
        data = np.copy(data)
        if data.dtype.kind in "biu":
            data = data.astype(np.float32)
        for c in range(data.shape[0]):
            data[c] -= data[c].mean()
            std = data[c].std()
            if std > 0:
                data[c] /= std
        return data
|
|
|
|
|
# Module-level alias bound to the configuration class itself (not an
# instance). Presumably looked up by name by the code that loads this module
# — verify against callers.
config = HD_BET_Config
|
|
|
|