import os |
|
import shutil |
|
from warnings import warn
|
from collections import OrderedDict |
|
from multiprocessing import Pool |
|
from time import sleep, time |
|
from typing import Tuple |
|
|
|
import numpy as np |
|
import torch |
|
import torch.distributed as dist |
|
from batchgenerators.utilities.file_and_folder_operations import maybe_mkdir_p, join, subfiles, isfile, load_pickle, \
    save_json
|
from nnunet.configuration import default_num_threads |
|
from nnunet.evaluation.evaluator import aggregate_scores |
|
from nnunet.inference.segmentation_export import save_segmentation_nifti_from_softmax |
|
from nnunet.network_architecture.neural_network import SegmentationNetwork |
|
from nnunet.postprocessing.connected_components import determine_postprocessing |
|
from nnunet.training.data_augmentation.data_augmentation_moreDA import get_moreDA_augmentation |
|
from nnunet.training.dataloading.dataset_loading import unpack_dataset |
|
from nnunet.training.loss_functions.crossentropy import RobustCrossEntropyLoss |
|
from nnunet.training.loss_functions.dice_loss import get_tp_fp_fn_tn |
|
from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2 |
|
from nnunet.utilities.distributed import awesome_allgather_function |
|
from nnunet.utilities.nd_softmax import softmax_helper |
|
from nnunet.utilities.tensor_utilities import sum_tensor |
|
from nnunet.utilities.to_torch import to_cuda, maybe_to_torch |
|
from torch import nn, distributed |
|
from torch.backends import cudnn |
|
from torch.cuda.amp import autocast |
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
from torch.optim.lr_scheduler import _LRScheduler |
|
from tqdm import trange |
|
|
|
|
|
class nnUNetTrainerV2_DDP(nnUNetTrainerV2): |
|
def __init__(self, plans_file, fold, local_rank, output_folder=None, dataset_directory=None, batch_dice=True, |
|
stage=None, |
|
unpack_data=True, deterministic=True, distribute_batch_size=False, fp16=False): |
|
super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, |
|
unpack_data, deterministic, fp16) |
|
self.init_args = ( |
|
plans_file, fold, local_rank, output_folder, dataset_directory, batch_dice, stage, unpack_data, |
|
deterministic, distribute_batch_size, fp16) |
|
self.distribute_batch_size = distribute_batch_size |
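        # seed each process with its rank so that random operations (augmentation seeds, oversampling) differ across
        # workers; DDP synchronizes the model weights from rank 0 when the network is wrapped, so training stays consistent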
|
np.random.seed(local_rank) |
|
torch.manual_seed(local_rank) |
|
if torch.cuda.is_available(): |
|
torch.cuda.manual_seed_all(local_rank) |
|
self.local_rank = local_rank |
|
|
|
if torch.cuda.is_available(): |
|
torch.cuda.set_device(local_rank) |
|
dist.init_process_group(backend='nccl', init_method='env://') |
|
|
|
self.loss = None |
|
self.ce_loss = RobustCrossEntropyLoss() |
|
|
|
self.global_batch_size = None |
|
|
|
def set_batch_size_and_oversample(self): |
|
batch_sizes = [] |
|
oversample_percents = [] |
|
|
|
world_size = dist.get_world_size() |
|
my_rank = dist.get_rank() |
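        # if distribute_batch_size is set, the batch size from the plans is treated as the global batch size and is
        # split across GPUs; otherwise every GPU uses the full plans batch size and the global batch size grows with
        # the number of GPUs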
|
|
|
if self.distribute_batch_size: |
|
self.global_batch_size = self.batch_size |
|
else: |
|
self.global_batch_size = self.batch_size * world_size |
|
|
|
batch_size_per_GPU = np.ceil(self.batch_size / world_size).astype(int) |
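        # every rank computes the batch sizes of all ranks so it knows which slice of the global batch it owns and can
        # derive its own foreground oversampling percentage from that slice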
|
|
|
for rank in range(world_size): |
|
if self.distribute_batch_size: |
|
if (rank + 1) * batch_size_per_GPU > self.batch_size: |
|
batch_size = batch_size_per_GPU - ((rank + 1) * batch_size_per_GPU - self.batch_size) |
|
else: |
|
batch_size = batch_size_per_GPU |
|
else: |
|
batch_size = self.batch_size |
|
|
|
batch_sizes.append(batch_size) |
|
|
|
sample_id_low = 0 if len(batch_sizes) == 0 else np.sum(batch_sizes[:-1]) |
|
sample_id_high = np.sum(batch_sizes) |
|
|
|
if sample_id_high / self.global_batch_size < (1 - self.oversample_foreground_percent): |
|
oversample_percents.append(0.0) |
|
elif sample_id_low / self.global_batch_size > (1 - self.oversample_foreground_percent): |
|
oversample_percents.append(1.0) |
|
else: |
|
percent_covered_by_this_rank = sample_id_high / self.global_batch_size - sample_id_low / self.global_batch_size |
|
oversample_percent_here = 1 - (((1 - self.oversample_foreground_percent) - |
|
sample_id_low / self.global_batch_size) / percent_covered_by_this_rank) |
|
oversample_percents.append(oversample_percent_here) |
|
|
|
print("worker", my_rank, "oversample", oversample_percents[my_rank]) |
|
print("worker", my_rank, "batch_size", batch_sizes[my_rank]) |
|
|
|
self.batch_size = batch_sizes[my_rank] |
|
self.oversample_foreground_percent = oversample_percents[my_rank] |
|
|
|
def save_checkpoint(self, fname, save_optimizer=True): |
|
if self.local_rank == 0: |
|
super().save_checkpoint(fname, save_optimizer) |
|
|
|
def plot_progress(self): |
|
if self.local_rank == 0: |
|
super().plot_progress() |
|
|
|
def print_to_log_file(self, *args, also_print_to_console=True): |
|
if self.local_rank == 0: |
|
super().print_to_log_file(*args, also_print_to_console=also_print_to_console) |
|
|
|
def process_plans(self, plans): |
|
super().process_plans(plans) |
|
self.set_batch_size_and_oversample() |
|
|
|
def initialize(self, training=True, force_load_plans=False): |
|
""" |
|
:param training: |
|
:return: |
|
""" |
|
if not self.was_initialized: |
|
maybe_mkdir_p(self.output_folder) |
|
|
|
if force_load_plans or (self.plans is None): |
|
self.load_plans_file() |
|
|
|
self.process_plans(self.plans) |
|
|
|
self.setup_DA_params() |
|
|
|
self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] + |
|
"_stage%d" % self.stage) |
|
if training: |
|
self.dl_tr, self.dl_val = self.get_basic_generators() |
|
if self.unpack_data: |
|
if self.local_rank == 0: |
|
print("unpacking dataset") |
|
unpack_dataset(self.folder_with_preprocessed_data) |
|
print("done") |
|
distributed.barrier() |
|
else: |
|
                    print(
                        "INFO: Not unpacking data! Training may be slow as a result. Pray you are not using 2d or you "
                        "will wait all winter for your model to finish!")
|
|
|
|
|
net_numpool = len(self.net_num_pool_op_kernel_sizes) |
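                # deep supervision: each resolution gets a weight that halves with every downsampling step, the
                # lowest-resolution output is masked out and the weights are normalized to sum to 1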
|
|
|
|
|
|
|
weights = np.array([1 / (2 ** i) for i in range(net_numpool)]) |
|
|
|
|
|
                mask = np.array([i < net_numpool - 1 for i in range(net_numpool)])
|
weights[~mask] = 0 |
|
weights = weights / weights.sum() |
|
self.ds_loss_weights = weights |
|
|
|
                seeds_train = np.random.randint(0, 99999, self.data_aug_params.get('num_threads'))
                seeds_val = np.random.randint(0, 99999, max(self.data_aug_params.get('num_threads') // 2, 1))
|
print("seeds train", seeds_train) |
|
print("seeds_val", seeds_val) |
|
self.tr_gen, self.val_gen = get_moreDA_augmentation(self.dl_tr, self.dl_val, |
|
self.data_aug_params[ |
|
'patch_size_for_spatialtransform'], |
|
self.data_aug_params, |
|
deep_supervision_scales=self.deep_supervision_scales, |
|
seeds_train=seeds_train, |
|
seeds_val=seeds_val, |
|
pin_memory=self.pin_memory) |
|
self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())), |
|
also_print_to_console=False) |
|
self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())), |
|
also_print_to_console=False) |
|
else: |
|
pass |
|
|
|
self.initialize_network() |
|
self.initialize_optimizer_and_scheduler() |
|
self.network = DDP(self.network, device_ids=[self.local_rank]) |
|
|
|
else: |
|
self.print_to_log_file('self.was_initialized is True, not running self.initialize again') |
|
self.was_initialized = True |
|
|
|
def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False): |
|
data_dict = next(data_generator) |
|
data = data_dict['data'] |
|
target = data_dict['target'] |
|
|
|
data = maybe_to_torch(data) |
|
target = maybe_to_torch(target) |
|
|
|
if torch.cuda.is_available(): |
|
data = to_cuda(data, gpu_id=None) |
|
target = to_cuda(target, gpu_id=None) |
|
|
|
self.optimizer.zero_grad() |
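        # forward/backward; with fp16 the loss is scaled by the GradScaler and the gradients are unscaled again
        # before clipping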
|
|
|
if self.fp16: |
|
with autocast(): |
|
output = self.network(data) |
|
del data |
|
l = self.compute_loss(output, target) |
|
|
|
if do_backprop: |
|
self.amp_grad_scaler.scale(l).backward() |
|
self.amp_grad_scaler.unscale_(self.optimizer) |
|
torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12) |
|
self.amp_grad_scaler.step(self.optimizer) |
|
self.amp_grad_scaler.update() |
|
else: |
|
output = self.network(data) |
|
del data |
|
l = self.compute_loss(output, target) |
|
|
|
if do_backprop: |
|
l.backward() |
|
torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12) |
|
self.optimizer.step() |
|
|
|
if run_online_evaluation: |
|
self.run_online_evaluation(output, target) |
|
|
|
del target |
|
|
|
return l.detach().cpu().numpy() |
|
|
|
def compute_loss(self, output, target): |
|
total_loss = None |
|
for i in range(len(output)): |
|
|
|
axes = tuple(range(2, len(output[i].size()))) |
|
|
|
|
|
output_softmax = softmax_helper(output[i]) |
|
|
|
|
|
tp, fp, fn, _ = get_tp_fp_fn_tn(output_softmax, target[i], axes, mask=None) |
|
|
|
|
|
nominator = 2 * tp[:, 1:] |
|
denominator = 2 * tp[:, 1:] + fp[:, 1:] + fn[:, 1:] |
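            # background (channel 0) is excluded above; with batch_dice the tp/fp/fn are gathered from all DDP ranks
            # so the soft Dice is computed over the global batch rather than per GPU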
|
|
|
if self.batch_dice: |
|
|
|
nominator = awesome_allgather_function.apply(nominator) |
|
denominator = awesome_allgather_function.apply(denominator) |
|
nominator = nominator.sum(0) |
|
denominator = denominator.sum(0) |
|
else: |
|
pass |
|
|
|
ce_loss = self.ce_loss(output[i], target[i][:, 0].long()) |
|
|
|
|
|
dice_loss = (- (nominator + 1e-5) / (denominator + 1e-5)).mean() |
|
if total_loss is None: |
|
total_loss = self.ds_loss_weights[i] * (ce_loss + dice_loss) |
|
else: |
|
total_loss += self.ds_loss_weights[i] * (ce_loss + dice_loss) |
|
return total_loss |
|
|
|
def run_online_evaluation(self, output, target): |
|
with torch.no_grad(): |
|
num_classes = output[0].shape[1] |
|
output_seg = output[0].argmax(1) |
|
target = target[0][:, 0] |
|
axes = tuple(range(1, len(target.shape))) |
|
tp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index) |
|
fp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index) |
|
fn_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index) |
|
for c in range(1, num_classes): |
|
tp_hard[:, c - 1] = sum_tensor((output_seg == c).float() * (target == c).float(), axes=axes) |
|
fp_hard[:, c - 1] = sum_tensor((output_seg == c).float() * (target != c).float(), axes=axes) |
|
fn_hard[:, c - 1] = sum_tensor((output_seg != c).float() * (target == c).float(), axes=axes) |
|
|
|
|
|
|
|
|
|
tp_hard = tp_hard.sum(0, keepdim=False)[None] |
|
fp_hard = fp_hard.sum(0, keepdim=False)[None] |
|
fn_hard = fn_hard.sum(0, keepdim=False)[None] |
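            # aggregate the hard tp/fp/fn counts of all DDP workers before computing the global foreground Dice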
|
|
|
tp_hard = awesome_allgather_function.apply(tp_hard) |
|
fp_hard = awesome_allgather_function.apply(fp_hard) |
|
fn_hard = awesome_allgather_function.apply(fn_hard) |
|
|
|
tp_hard = tp_hard.detach().cpu().numpy().sum(0) |
|
fp_hard = fp_hard.detach().cpu().numpy().sum(0) |
|
fn_hard = fn_hard.detach().cpu().numpy().sum(0) |
|
self.online_eval_foreground_dc.append(list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8))) |
|
self.online_eval_tp.append(list(tp_hard)) |
|
self.online_eval_fp.append(list(fp_hard)) |
|
self.online_eval_fn.append(list(fn_hard)) |
|
|
|
def run_training(self): |
|
""" |
|
if we run with -c then we need to set the correct lr for the first epoch, otherwise it will run the first |
|
continued epoch with self.initial_lr |
|
|
|
we also need to make sure deep supervision in the network is enabled for training, thus the wrapper |
|
:return: |
|
""" |
|
if self.local_rank == 0: |
|
self.save_debug_information() |
|
|
|
if not torch.cuda.is_available(): |
|
self.print_to_log_file("WARNING!!! You are attempting to run training on a CPU (torch.cuda.is_available() is False). This can be VERY slow!") |
|
|
|
self.maybe_update_lr(self.epoch) |
|
|
|
if isinstance(self.network, DDP): |
|
net = self.network.module |
|
else: |
|
net = self.network |
|
ds = net.do_ds |
|
net.do_ds = True |
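        # pull one batch from each generator so the background augmentation workers are up and running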
|
|
|
_ = self.tr_gen.next() |
|
_ = self.val_gen.next() |
|
|
|
if torch.cuda.is_available(): |
|
torch.cuda.empty_cache() |
|
|
|
self._maybe_init_amp() |
|
|
|
maybe_mkdir_p(self.output_folder) |
|
self.plot_network_architecture() |
|
|
|
if cudnn.benchmark and cudnn.deterministic: |
|
warn("torch.backends.cudnn.deterministic is True indicating a deterministic training is desired. " |
|
"But torch.backends.cudnn.benchmark is True as well and this will prevent deterministic training! " |
|
"If you want deterministic then set benchmark=False") |
|
|
|
if not self.was_initialized: |
|
self.initialize(True) |
|
|
|
while self.epoch < self.max_num_epochs: |
|
self.print_to_log_file("\nepoch: ", self.epoch) |
|
epoch_start_time = time() |
|
train_losses_epoch = [] |
|
|
|
|
|
self.network.train() |
|
|
|
if self.use_progress_bar: |
|
with trange(self.num_batches_per_epoch) as tbar: |
|
for b in tbar: |
|
tbar.set_description("Epoch {}/{}".format(self.epoch+1, self.max_num_epochs)) |
|
|
|
l = self.run_iteration(self.tr_gen, True) |
|
|
|
tbar.set_postfix(loss=l) |
|
train_losses_epoch.append(l) |
|
else: |
|
for _ in range(self.num_batches_per_epoch): |
|
l = self.run_iteration(self.tr_gen, True) |
|
train_losses_epoch.append(l) |
|
|
|
self.all_tr_losses.append(np.mean(train_losses_epoch)) |
|
self.print_to_log_file("train loss : %.4f" % self.all_tr_losses[-1]) |
|
|
|
with torch.no_grad(): |
|
|
|
self.network.eval() |
|
val_losses = [] |
|
for b in range(self.num_val_batches_per_epoch): |
|
l = self.run_iteration(self.val_gen, False, True) |
|
val_losses.append(l) |
|
self.all_val_losses.append(np.mean(val_losses)) |
|
self.print_to_log_file("validation loss: %.4f" % self.all_val_losses[-1]) |
|
|
|
if self.also_val_in_tr_mode: |
|
self.network.train() |
|
|
|
val_losses = [] |
|
for b in range(self.num_val_batches_per_epoch): |
|
l = self.run_iteration(self.val_gen, False) |
|
val_losses.append(l) |
|
self.all_val_losses_tr_mode.append(np.mean(val_losses)) |
|
self.print_to_log_file("validation loss (train=True): %.4f" % self.all_val_losses_tr_mode[-1]) |
|
|
|
self.update_train_loss_MA() |
|
|
|
continue_training = self.on_epoch_end() |
|
|
|
epoch_end_time = time() |
|
|
|
if not continue_training: |
|
|
|
break |
|
|
|
self.epoch += 1 |
|
self.print_to_log_file("This epoch took %f s\n" % (epoch_end_time - epoch_start_time)) |
|
|
|
self.epoch -= 1 |
|
|
|
        if self.save_final_checkpoint:
            self.save_checkpoint(join(self.output_folder, "model_final_checkpoint.model"))
|
|
|
if self.local_rank == 0: |
|
|
|
if isfile(join(self.output_folder, "model_latest.model")): |
|
os.remove(join(self.output_folder, "model_latest.model")) |
|
if isfile(join(self.output_folder, "model_latest.model.pkl")): |
|
os.remove(join(self.output_folder, "model_latest.model.pkl")) |
|
|
|
net.do_ds = ds |
|
|
|
def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True, |
|
step_size: float = 0.5, save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True, |
|
validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False, |
|
segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True): |
|
if isinstance(self.network, DDP): |
|
net = self.network.module |
|
else: |
|
net = self.network |
|
ds = net.do_ds |
|
net.do_ds = False |
|
|
|
current_mode = self.network.training |
|
self.network.eval() |
|
|
|
assert self.was_initialized, "must initialize, ideally with checkpoint (or train first)" |
|
if self.dataset_val is None: |
|
self.load_dataset() |
|
self.do_split() |
|
|
|
if segmentation_export_kwargs is None: |
|
if 'segmentation_export_params' in self.plans.keys(): |
|
force_separate_z = self.plans['segmentation_export_params']['force_separate_z'] |
|
interpolation_order = self.plans['segmentation_export_params']['interpolation_order'] |
|
interpolation_order_z = self.plans['segmentation_export_params']['interpolation_order_z'] |
|
else: |
|
force_separate_z = None |
|
interpolation_order = 1 |
|
interpolation_order_z = 0 |
|
else: |
|
force_separate_z = segmentation_export_kwargs['force_separate_z'] |
|
interpolation_order = segmentation_export_kwargs['interpolation_order'] |
|
interpolation_order_z = segmentation_export_kwargs['interpolation_order_z'] |
|
|
|
|
|
output_folder = join(self.output_folder, validation_folder_name) |
|
maybe_mkdir_p(output_folder) |
|
|
|
my_input_args = {'do_mirroring': do_mirroring, |
|
'use_sliding_window': use_sliding_window, |
|
'step_size': step_size, |
|
'save_softmax': save_softmax, |
|
'use_gaussian': use_gaussian, |
|
'overwrite': overwrite, |
|
'validation_folder_name': validation_folder_name, |
|
'debug': debug, |
|
'all_in_gpu': all_in_gpu, |
|
'segmentation_export_kwargs': segmentation_export_kwargs, |
|
} |
|
save_json(my_input_args, join(output_folder, "validation_args.json")) |
|
|
|
if do_mirroring: |
|
if not self.data_aug_params['do_mirror']: |
|
raise RuntimeError( |
|
"We did not train with mirroring so you cannot do inference with mirroring enabled") |
|
mirror_axes = self.data_aug_params['mirror_axes'] |
|
else: |
|
mirror_axes = () |
|
|
|
pred_gt_tuples = [] |
|
|
|
export_pool = Pool(default_num_threads) |
|
results = [] |
|
|
|
all_keys = list(self.dataset_val.keys()) |
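        # the validation cases are distributed across the DDP workers in an interleaved fashion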
|
my_keys = all_keys[self.local_rank::dist.get_world_size()] |
|
|
|
|
|
        for k in all_keys:
            properties = load_pickle(self.dataset[k]['properties_file'])
            fname = properties['list_of_data_files'][0].split("/")[-1][:-12]
            # all ranks need the complete list of (prediction, ground truth) pairs because the evaluation below is
            # run by rank 0 over the entire validation set
            pred_gt_tuples.append([join(output_folder, fname + ".nii.gz"),
                                   join(self.gt_niftis_folder, fname + ".nii.gz")])
            if k in my_keys:
|
if overwrite or (not isfile(join(output_folder, fname + ".nii.gz"))) or \ |
|
(save_softmax and not isfile(join(output_folder, fname + ".npz"))): |
|
data = np.load(self.dataset[k]['data_file'])['data'] |
|
|
|
print(k, data.shape) |
|
data[-1][data[-1] == -1] = 0 |
|
|
|
softmax_pred = self.predict_preprocessed_data_return_seg_and_softmax(data[:-1], |
|
do_mirroring=do_mirroring, |
|
mirror_axes=mirror_axes, |
|
use_sliding_window=use_sliding_window, |
|
step_size=step_size, |
|
use_gaussian=use_gaussian, |
|
all_in_gpu=all_in_gpu, |
|
mixed_precision=self.fp16)[1] |
|
|
|
softmax_pred = softmax_pred.transpose([0] + [i + 1 for i in self.transpose_backward]) |
|
|
|
if save_softmax: |
|
softmax_fname = join(output_folder, fname + ".npz") |
|
else: |
|
softmax_fname = None |
|
|
|
"""There is a problem with python process communication that prevents us from communicating obejcts |
|
larger than 2 GB between processes (basically when the length of the pickle string that will be sent is |
|
communicated by the multiprocessing.Pipe object then the placeholder (\%i I think) does not allow for long |
|
enough strings (lol). This could be fixed by changing i to l (for long) but that would require manually |
|
patching system python code. We circumvent that problem here by saving softmax_pred to a npy file that will |
|
then be read (and finally deleted) by the Process. save_segmentation_nifti_from_softmax can take either |
|
filename or np.ndarray and will handle this automatically""" |
|
if np.prod(softmax_pred.shape) > (2e9 / 4 * 0.85): |
|
np.save(join(output_folder, fname + ".npy"), softmax_pred) |
|
softmax_pred = join(output_folder, fname + ".npy") |
|
|
|
results.append(export_pool.starmap_async(save_segmentation_nifti_from_softmax, |
|
((softmax_pred, join(output_folder, fname + ".nii.gz"), |
|
properties, interpolation_order, |
|
self.regions_class_order, |
|
None, None, |
|
softmax_fname, None, force_separate_z, |
|
interpolation_order_z), |
|
) |
|
) |
|
) |
|
|
|
_ = [i.get() for i in results] |
|
self.print_to_log_file("finished prediction") |
|
|
|
distributed.barrier() |
|
|
|
if self.local_rank == 0: |
|
|
|
self.print_to_log_file("evaluation of raw predictions") |
|
task = self.dataset_directory.split("/")[-1] |
|
job_name = self.experiment_name |
|
_ = aggregate_scores(pred_gt_tuples, labels=list(range(self.num_classes)), |
|
json_output_file=join(output_folder, "summary.json"), |
|
json_name=job_name + " val tiled %s" % (str(use_sliding_window)), |
|
json_author="Fabian", |
|
json_task=task, num_threads=default_num_threads) |
|
|
|
if run_postprocessing_on_folds: |
|
|
|
|
|
|
|
|
|
self.print_to_log_file("determining postprocessing") |
|
determine_postprocessing(self.output_folder, self.gt_niftis_folder, validation_folder_name, |
|
final_subf_name=validation_folder_name + "_postprocessed", debug=debug) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
gt_nifti_folder = join(self.output_folder_base, "gt_niftis") |
|
maybe_mkdir_p(gt_nifti_folder) |
|
for f in subfiles(self.gt_niftis_folder, suffix=".nii.gz"): |
|
                success = False
                attempts = 0
                last_exception = None
                while not success and attempts < 10:
                    try:
                        shutil.copy(f, gt_nifti_folder)
                        success = True
                    except OSError as exc:
                        # the name bound by 'except ... as' is cleared when the block exits, so keep a reference
                        last_exception = exc
                        attempts += 1
                        sleep(1)
                if not success:
                    print("Could not copy gt nifti file %s into folder %s" % (f, gt_nifti_folder))
                    if last_exception is not None:
                        raise last_exception
|
|
|
self.network.train(current_mode) |
|
net.do_ds = ds |
|
|
|
def predict_preprocessed_data_return_seg_and_softmax(self, data: np.ndarray, do_mirroring: bool = True, |
|
mirror_axes: Tuple[int] = None, |
|
use_sliding_window: bool = True, step_size: float = 0.5, |
|
use_gaussian: bool = True, pad_border_mode: str = 'constant', |
|
pad_kwargs: dict = None, all_in_gpu: bool = False, |
|
verbose: bool = True, mixed_precision=True) -> Tuple[ |
|
np.ndarray, np.ndarray]: |
|
if pad_border_mode == 'constant' and pad_kwargs is None: |
|
pad_kwargs = {'constant_values': 0} |
|
|
|
if do_mirroring and mirror_axes is None: |
|
mirror_axes = self.data_aug_params['mirror_axes'] |
|
|
|
if do_mirroring: |
|
assert self.data_aug_params["do_mirror"], "Cannot do mirroring as test time augmentation when training " \ |
|
"was done without mirroring" |
|
|
|
        assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel, DDP))
|
if isinstance(self.network, DDP): |
|
net = self.network.module |
|
else: |
|
net = self.network |
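        # predict_3D expects a single segmentation output, so deep supervision is disabled here and restored afterwards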
|
ds = net.do_ds |
|
net.do_ds = False |
|
ret = net.predict_3D(data, do_mirroring=do_mirroring, mirror_axes=mirror_axes, |
|
use_sliding_window=use_sliding_window, step_size=step_size, |
|
patch_size=self.patch_size, regions_class_order=self.regions_class_order, |
|
use_gaussian=use_gaussian, pad_border_mode=pad_border_mode, |
|
pad_kwargs=pad_kwargs, all_in_gpu=all_in_gpu, verbose=verbose, |
|
mixed_precision=mixed_precision) |
|
net.do_ds = ds |
|
return ret |
|
|
|
def load_checkpoint_ram(self, checkpoint, train=True): |
|
""" |
|
used for if the checkpoint is already in ram |
|
:param checkpoint: |
|
:param train: |
|
:return: |
|
""" |
|
if not self.was_initialized: |
|
self.initialize(train) |
|
|
|
new_state_dict = OrderedDict() |
|
curr_state_dict_keys = list(self.network.state_dict().keys()) |
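        # keys that are not present in the current network are assumed to carry a 'module.' prefix
        # (from (Data)Parallel training) which is stripped before loading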
|
|
|
|
|
for k, value in checkpoint['state_dict'].items(): |
|
            key = k
            if key not in curr_state_dict_keys:
                print("key %s is not in the network state dict, stripping 'module.' prefix" % key)
                key = key[7:]
            new_state_dict[key] = value
|
|
|
if self.fp16: |
|
self._maybe_init_amp() |
|
if 'amp_grad_scaler' in checkpoint.keys(): |
|
self.amp_grad_scaler.load_state_dict(checkpoint['amp_grad_scaler']) |
|
|
|
self.network.load_state_dict(new_state_dict) |
|
self.epoch = checkpoint['epoch'] |
|
if train: |
|
optimizer_state_dict = checkpoint['optimizer_state_dict'] |
|
if optimizer_state_dict is not None: |
|
self.optimizer.load_state_dict(optimizer_state_dict) |
|
|
|
if self.lr_scheduler is not None and hasattr(self.lr_scheduler, 'load_state_dict') and checkpoint[ |
|
'lr_scheduler_state_dict'] is not None: |
|
self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict']) |
|
|
|
if issubclass(self.lr_scheduler.__class__, _LRScheduler): |
|
self.lr_scheduler.step(self.epoch) |
|
|
|
self.all_tr_losses, self.all_val_losses, self.all_val_losses_tr_mode, self.all_val_eval_metrics = checkpoint[ |
|
'plot_stuff'] |
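        # older versions incremented self.epoch once more than entries were appended to the loss curves; detect and
        # correct that here so plotting does not fail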
|
|
|
|
|
|
|
|
|
if self.epoch != len(self.all_tr_losses): |
|
self.print_to_log_file("WARNING in loading checkpoint: self.epoch != len(self.all_tr_losses). This is " |
|
"due to an old bug and should only appear when you are loading old models. New " |
|
"models should have this fixed! self.epoch is now set to len(self.all_tr_losses)") |
|
self.epoch = len(self.all_tr_losses) |
|
self.all_tr_losses = self.all_tr_losses[:self.epoch] |
|
self.all_val_losses = self.all_val_losses[:self.epoch] |
|
self.all_val_losses_tr_mode = self.all_val_losses_tr_mode[:self.epoch] |
|
self.all_val_eval_metrics = self.all_val_eval_metrics[:self.epoch] |
|
|