import numpy as np
import torch
from abc import abstractmethod

from HD_BET.network_architecture import Network
from HD_BET.utils import SetNetworkToVal, softmax_helper

class BaseConfig(object):
    """Abstract interface that concrete experiment configurations implement."""

    def __init__(self):
        pass

    @abstractmethod
    def get_split(self, fold, random_state=12345):
        pass

    @abstractmethod
    def get_network(self, mode="train"):
        pass

    @abstractmethod
    def get_basic_generators(self, fold):
        pass

    @abstractmethod
    def get_data_generators(self, fold):
        pass

    def preprocess(self, data):
        return data

    def __repr__(self):
        res = ""
        for v in vars(self):
            if not v.startswith("_") and v != 'dataset':
                res += (v + ": " + str(self.__getattribute__(v)) + "\n")
        return res

class HD_BET_Config(BaseConfig):
    def __init__(self):
        super(HD_BET_Config, self).__init__()

        self.EXPERIMENT_NAME = self.__class__.__name__  # just a generic experiment name

        # network parameters
        self.net_base_num_layers = 21
        self.BATCH_SIZE = 2
        self.net_do_DS = True
        self.net_dropout_p = 0.0
        self.net_use_inst_norm = True
        self.net_conv_use_bias = True
        self.net_norm_use_affine = True
        self.net_leaky_relu_slope = 1e-1

        # hyperparameters
        self.INPUT_PATCH_SIZE = (128, 128, 128)
        self.num_classes = 2
        self.selected_data_channels = range(1)

        # data augmentation
        self.da_mirror_axes = (2, 3, 4)

        # validation
        self.val_use_DO = False
        self.val_use_train_mode = False  # for dropout sampling
        self.val_num_repeats = 1  # only useful if dropout sampling
        self.val_batch_size = 1  # only useful if dropout sampling
        self.val_save_npz = True
        self.val_do_mirroring = True  # test-time data augmentation via mirroring
        self.val_write_images = True
        self.net_input_must_be_divisible_by = 16  # we could make a network class that has this as a property
        self.val_min_size = self.INPUT_PATCH_SIZE
        self.val_fn = None
        # CAREFUL! THIS IS A HACK to make PyTorch 0.3 state dicts compatible with PyTorch 0.4:
        # we set track_running_stats=True but do not use the running averages in validation
        # (the flag was effectively on but unused before 0.4, and defaults to False from 0.4 on).
        self.val_use_moving_averages = False
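        # Note (editor's addition): the default INPUT_PATCH_SIZE of (128, 128, 128) already
        # satisfies net_input_must_be_divisible_by = 16, since 128 = 8 * 16; arbitrary input
        # volumes would have to be padded to the next multiple of 16 along each spatial axis.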

    def get_network(self, train=True, pretrained_weights=None):
        net = Network(self.num_classes, len(self.selected_data_channels), self.net_base_num_layers,
                      self.net_dropout_p, softmax_helper, self.net_leaky_relu_slope, self.net_conv_use_bias,
                      self.net_norm_use_affine, True, self.net_do_DS)

        if pretrained_weights is not None:
            net.load_state_dict(
                torch.load(pretrained_weights, map_location=lambda storage, loc: storage))

        if train:
            net.train(True)
        else:
            net.train(False)
            net.apply(SetNetworkToVal(self.val_use_DO, self.val_use_moving_averages))
            net.do_ds = False

        optimizer = None
        self.lr_scheduler = None
        return net, optimizer
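    # Usage sketch (editor's illustration, not part of the original file): for inference one
    # would pass train=False together with a checkpoint path (the path below is hypothetical),
    # which routes through SetNetworkToVal so dropout and running stats behave as configured:
    #
    #   cf = HD_BET_Config()
    #   net, _ = cf.get_network(train=False, pretrained_weights="/path/to/checkpoint.model")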

    def get_data_generators(self, fold):
        pass

    def get_split(self, fold, random_state=12345):
        pass

    def get_basic_generators(self, fold):
        pass

    def on_epoch_end(self, epoch):
        pass

    def preprocess(self, data):
        # per-channel z-score normalization (zero mean, unit variance)
        data = np.copy(data)
        for c in range(data.shape[0]):
            data[c] -= data[c].mean()
            data[c] /= data[c].std()
        return data

config = HD_BET_Config  # note: the class itself is exported here, not an instance
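
# Minimal smoke test (editor's sketch, not part of the original module). It assumes that
# HD_BET.network_architecture.Network is importable and accepts the constructor arguments
# used in get_network above; the dummy volume shape follows INPUT_PATCH_SIZE with a single
# input channel.
if __name__ == "__main__":
    cf = config()  # `config` holds the class itself, so instantiate it first
    net, _ = cf.get_network(train=False)  # no pretrained weights -> random initialization
    dummy = np.random.randn(1, 128, 128, 128).astype(np.float32)  # (channels, x, y, z)
    dummy = cf.preprocess(dummy)  # per-channel z-score normalization
    with torch.no_grad():
        pred = net(torch.from_numpy(dummy)[None])  # prepend a batch dimension
    print(cf)  # __repr__ lists all public configuration attributes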