|
import generators |
|
import monai |
|
import torch |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import matplotlib |
|
import os |
|
import sys |
|
from pathlib import Path |
|
import pickle |
|
|
|
ROOT_DIR = str(Path(os.getcwd()).parent.parent.absolute()) |
|
sys.path.insert(0, os.path.join(ROOT_DIR, 'deepatlas/utils')) |
|
sys.path.insert(0, os.path.join(ROOT_DIR, 'deepatlas/loss_function')) |
|
from utils import ( |
|
preview_image, preview_3D_vector_field, preview_3D_deformation, |
|
jacobian_determinant, plot_progress, make_if_dont_exist, save_seg_checkpoint, save_reg_checkpoint, load_latest_checkpoint, |
|
load_best_checkpoint, load_valid_checkpoint, plot_architecture |
|
) |
|
from losses import ( |
|
warp_func, warp_nearest_func, lncc_loss_func, dice_loss_func, reg_losses, dice_loss_func2 |
|
) |
|
|
|
|
|
def swap_training(network_to_train, network_to_not_train):
    """
        Hand the "trainable" role from one network to the other.

        Every parameter of ``network_to_not_train`` gets ``requires_grad = False``
        and the network is put in eval mode; every parameter of
        ``network_to_train`` gets ``requires_grad = True`` and the network is put
        in train mode. Both networks are modified in place; nothing is returned.
    """
    # Freeze first, then unfreeze, so that the "train" settings win if the same
    # network object is passed for both arguments.
    roles = (
        (network_to_not_train, False),
        (network_to_train, True),
    )
    for network, trainable in roles:
        for weight in network.parameters():
            weight.requires_grad = trainable

    network_to_not_train.eval()
    network_to_train.train()
|
|
|
def _remove_if_exists(path):
    """Delete the file at ``path`` if it exists; otherwise do nothing."""
    if os.path.exists(path):
        os.remove(path)


def _restore_previous_best_model(model_dir, prefix):
    """Promote ``<prefix>_last_best.pth`` back to ``<prefix>_best.pth`` if the backup exists."""
    last_best_path = os.path.join(model_dir, f'{prefix}_last_best.pth')
    if os.path.exists(last_best_path):
        # os.replace overwrites an existing destination, so this also works when
        # the current best file is missing (e.g. after an interrupted save).
        os.replace(last_best_path, os.path.join(model_dir, f'{prefix}_best.pth'))


def _save_best_model(net, model_dir, prefix):
    """Save ``net``'s state dict as ``<prefix>_best.pth``, keeping the previous best as ``<prefix>_last_best.pth``."""
    best_path = os.path.join(model_dir, f'{prefix}_best.pth')
    if os.path.exists(best_path):
        # Overwrites any older backup in a single step.
        os.replace(best_path, os.path.join(model_dir, f'{prefix}_last_best.pth'))
    torch.save(net.state_dict(), best_path)


def train_network(dataloader_train_reg,
                  dataloader_valid_reg,
                  dataloader_train_seg,
                  dataloader_valid_seg,
                  device,
                  seg_net,
                  reg_net,
                  num_segmentation_classes,
                  lr_reg,
                  lr_seg,
                  lam_a,
                  lam_sp,
                  lam_re,
                  max_epoch,
                  val_step,
                  result_seg_path,
                  result_reg_path,
                  logger,
                  img_shape,
                  plot_network=False,
                  continue_training=False
                  ):
    """
        Alternately train a registration network and a segmentation network.

        Each epoch first trains ``reg_net`` (with ``seg_net`` frozen) and then
        trains ``seg_net`` (with ``reg_net`` frozen). After each phase the latest
        checkpoint is written, validation runs every ``val_step`` epochs, the best
        model so far is saved under ``<result_*_path>/model``, and loss curves are
        plotted under ``<result_*_path>/training_plot``. When training finishes,
        the loss histories are pickled into the ``training_plot`` directories.

        Args:
            dataloader_train_reg: indexable by the availability codes
                ``'01'``, ``'10'``, ``'11'`` (``len`` is taken of each entry) and
                accepted by ``generators.create_batch_generator``. It feeds both
                the registration phase and the segmentation phase.
            dataloader_valid_reg: validation counterpart for registration; when
                its length is 0, registration validation is skipped.
            dataloader_train_seg: not used by this function.
            dataloader_valid_seg: iterable of batches with ``'img'`` and ``'seg'``
                entries; when its length is 0, segmentation validation is skipped.
            device: device the networks and batches are moved to.
            seg_net: segmentation network.
            reg_net: registration network.
            num_segmentation_classes: class count passed to ``reg_losses`` and
                ``monai.networks.one_hot``.
            lr_reg: Adam learning rate for ``reg_net``.
            lr_seg: Adam learning rate for ``seg_net``.
            lam_a: weight of the anatomy loss term (both phases).
            lam_sp: weight of the supervised segmentation loss term.
            lam_re: weight of the registration regularization loss term.
            max_epoch: total number of epochs to train up to.
            val_step: validate when ``epoch_number % val_step == 0``.
            result_seg_path: output root for segmentation artifacts.
            result_reg_path: output root for registration artifacts.
            logger: logger used for progress messages.
            img_shape: forwarded to ``plot_architecture`` when ``plot_network``.
            plot_network: when true, plot both network architectures first.
            continue_training: when true, try to resume from the checkpoints
                found under the two ``checkpoints`` directories.

        Returns:
            None. All results are written to disk.
    """
    seg_plot_dir = os.path.join(result_seg_path, 'training_plot')
    reg_plot_dir = os.path.join(result_reg_path, 'training_plot')
    seg_model_dir = os.path.join(result_seg_path, 'model')
    reg_model_dir = os.path.join(result_reg_path, 'model')
    seg_ckpt_dir = os.path.join(result_seg_path, 'checkpoints')
    reg_ckpt_dir = os.path.join(result_reg_path, 'checkpoints')

    make_if_dont_exist(seg_plot_dir)
    make_if_dont_exist(reg_plot_dir)
    make_if_dont_exist(seg_model_dir)
    make_if_dont_exist(reg_model_dir)
    make_if_dont_exist(seg_ckpt_dir)
    make_if_dont_exist(reg_ckpt_dir)

    seg_availabilities = ['00', '01', '10', '11']
    batch_generator_train_reg = generators.create_batch_generator(
        dataloader_train_reg)
    batch_generator_valid_reg = generators.create_batch_generator(
        dataloader_valid_reg)
    # Weight 0 for '00' so that pairs without any segmentation are never drawn
    # while training seg_net.
    seg_train_sampling_weights = [
        0] + [len(dataloader_train_reg[s]) for s in seg_availabilities[1:]]
    print('----------'*10)
    print(f"""When training seg_net alone, segmentation availabilities {seg_availabilities}
    will be sampled with respective weights {seg_train_sampling_weights}""")
    batch_generator_train_seg = generators.create_batch_generator(
        dataloader_train_reg, seg_train_sampling_weights)
    seg_net = seg_net.to(device)
    reg_net = reg_net.to(device)

    learning_rate_reg = lr_reg
    optimizer_reg = torch.optim.Adam(reg_net.parameters(), learning_rate_reg)
    # NOTE(review): both schedulers are constructed but never stepped in this
    # function, so the learning rates logged each epoch stay constant.
    # Confirm whether that is intentional.
    scheduler_reg = torch.optim.lr_scheduler.StepLR(optimizer_reg, step_size=70, gamma=0.2, verbose=True)
    learning_rate_seg = lr_seg
    optimizer_seg = torch.optim.Adam(seg_net.parameters(), learning_rate_seg)
    scheduler_seg = torch.optim.lr_scheduler.StepLR(optimizer_seg, step_size=50, gamma=0.2, verbose=True)
    last_epoch = 0

    # Loss histories; every entry appended below is an [epoch, value] pair.
    training_losses_reg = []
    validation_losses_reg = []
    training_losses_seg = []
    validation_losses_seg = []
    regularization_loss_reg = []
    anatomy_loss_reg = []
    similarity_loss_reg = []
    supervised_loss_seg = []
    anatomy_loss_seg = []
    best_seg_validation_loss = float('inf')
    best_reg_validation_loss = float('inf')

    last_epoch_valid = 0
    if continue_training:
        seg_valid_ckpt = os.path.join(seg_ckpt_dir, 'valid_checkpoint.pth')
        reg_valid_ckpt = os.path.join(reg_ckpt_dir, 'valid_checkpoint.pth')
        if os.path.exists(seg_valid_ckpt) and os.path.exists(reg_valid_ckpt):
            seg_best_ckpt = os.path.join(seg_ckpt_dir, 'best_checkpoint.pth')
            reg_best_ckpt = os.path.join(reg_ckpt_dir, 'best_checkpoint.pth')
            if os.path.exists(seg_best_ckpt) and os.path.exists(reg_best_ckpt):
                # Each network's best loss comes from its own checkpoint
                # directory (the two directories used to be crossed here).
                best_reg_validation_loss = load_best_checkpoint(reg_ckpt_dir, device)
                best_seg_validation_loss = load_best_checkpoint(seg_ckpt_dir, device)

            all_validation_losses_reg = load_valid_checkpoint(reg_ckpt_dir, device)
            all_validation_losses_seg = load_valid_checkpoint(seg_ckpt_dir, device)
            validation_losses_reg = all_validation_losses_reg['total_loss']
            validation_losses_seg = all_validation_losses_seg['total_loss']
            # Truncate both histories to the shorter one so the two networks
            # resume from the same validation step.
            last_epoch_valid = np.minimum(len(validation_losses_reg), len(validation_losses_seg))
            validation_losses_reg = validation_losses_reg[:last_epoch_valid]
            validation_losses_seg = validation_losses_seg[:last_epoch_valid]
            np_validation_losses_reg = np.array(validation_losses_reg)
            np_validation_losses_seg = np.array(validation_losses_seg)
            # If the recorded best loss is not part of the (truncated) history,
            # fall back to the best value that is, and roll the saved best model
            # back to its backup.
            if best_reg_validation_loss not in np_validation_losses_reg[:, 1]:
                best_reg_validation_loss = np.min(np_validation_losses_reg[:, 1])
                _restore_previous_best_model(reg_model_dir, 'reg_net')
            if best_seg_validation_loss not in np_validation_losses_seg[:, 1]:
                best_seg_validation_loss = np.min(np_validation_losses_seg[:, 1])
                _restore_previous_best_model(seg_model_dir, 'seg_net')
        else:
            # At most one of the two files exists in this branch; drop it so the
            # two networks do not resume from mismatched state.
            _remove_if_exists(seg_valid_ckpt)
            _remove_if_exists(reg_valid_ckpt)

        seg_latest_ckpt = os.path.join(seg_ckpt_dir, 'latest_checkpoint.pth')
        reg_latest_ckpt = os.path.join(reg_ckpt_dir, 'latest_checkpoint.pth')
        if last_epoch_valid != 0 and os.path.exists(seg_latest_ckpt) and os.path.exists(reg_latest_ckpt):
            reg_net, optimizer_reg, all_training_losses_reg = load_latest_checkpoint(reg_ckpt_dir, reg_net, optimizer_reg, device)
            seg_net, optimizer_seg, all_training_losses_seg = load_latest_checkpoint(seg_ckpt_dir, seg_net, optimizer_seg, device)
            regularization_loss_reg = all_training_losses_reg['regular_loss']
            anatomy_loss_reg = all_training_losses_reg['ana_loss']
            similarity_loss_reg = all_training_losses_reg['sim_loss']
            supervised_loss_seg = all_training_losses_seg['super_loss']
            anatomy_loss_seg = all_training_losses_seg['ana_loss']
            training_losses_reg = all_training_losses_reg['total_loss']
            training_losses_seg = all_training_losses_seg['total_loss']
            last_epoch_train = np.min(np.array([last_epoch_valid * val_step, len(training_losses_reg), len(training_losses_seg)]))
            regularization_loss_reg = regularization_loss_reg[:last_epoch_train]
            anatomy_loss_reg = anatomy_loss_reg[:last_epoch_train]
            similarity_loss_reg = similarity_loss_reg[:last_epoch_train]
            supervised_loss_seg = supervised_loss_seg[:last_epoch_train]
            anatomy_loss_seg = anatomy_loss_seg[:last_epoch_train]
            training_losses_reg = training_losses_reg[:last_epoch_train]
            training_losses_seg = training_losses_seg[:last_epoch_train]

            last_epoch = last_epoch_train
        else:
            # Both files can exist here (when there is no usable validation
            # history), so remove each one independently rather than via elif.
            _remove_if_exists(seg_latest_ckpt)
            _remove_if_exists(reg_latest_ckpt)

    if len(dataloader_valid_reg) == 0:
        validation_losses_reg = []

    if len(dataloader_valid_seg) == 0:
        validation_losses_seg = []

    lambda_a = lam_a
    lambda_sp = lam_sp
    lambda_r = lam_re

    max_epochs = max_epoch
    reg_phase_training_batches_per_epoch = 10
    seg_phase_training_batches_per_epoch = 5
    reg_phase_num_validation_batches_to_use = 10
    val_interval = val_step
    if plot_network:
        plot_architecture(seg_net, img_shape, seg_phase_training_batches_per_epoch, 'SegNet', result_seg_path)
        plot_architecture(reg_net, img_shape, reg_phase_training_batches_per_epoch, 'RegNet', result_reg_path)

    # Seed the per-epoch summaries from the restored history so the final report
    # still works when the loop below runs zero times (resumed run that already
    # reached max_epoch).
    training_loss_reg = training_losses_reg[-1][1] if len(training_losses_reg) else float('nan')
    training_loss_seg = training_losses_seg[-1][1] if len(training_losses_seg) else float('nan')

    logger.info('Start Training')

    for epoch_number in range(last_epoch, max_epochs):

        logger.info(f"Epoch {epoch_number+1}/{max_epochs}:")

        # ------------------------------------------------
        #         reg_net training, with seg_net frozen
        # ------------------------------------------------
        swap_training(reg_net, seg_net)

        losses = []
        regularization_loss = []
        similarity_loss = []
        anatomy_loss = []
        for batch in batch_generator_train_reg(reg_phase_training_batches_per_epoch):
            optimizer_reg.zero_grad()
            loss_sim, loss_reg, loss_ana, df = reg_losses(
                batch, device, reg_net, seg_net, num_segmentation_classes)
            loss = loss_sim + lambda_r * loss_reg + lambda_a * loss_ana
            loss.backward()
            optimizer_reg.step()
            losses.append(loss.item())
            regularization_loss.append(loss_reg.item())
            similarity_loss.append(loss_sim.item())
            anatomy_loss.append(loss_ana.item())

        training_loss_reg = np.mean(losses)
        regularization_loss_reg.append(
            [epoch_number+1, np.mean(regularization_loss)])
        similarity_loss_reg.append([epoch_number+1, np.mean(similarity_loss)])
        anatomy_loss_reg.append([epoch_number+1, np.mean(anatomy_loss)])
        logger.info(f"\treg training loss: {training_loss_reg}")
        training_losses_reg.append([epoch_number+1, training_loss_reg])
        logger.info("\tsave latest reg_net checkpoint")
        save_reg_checkpoint(reg_net, optimizer_reg, epoch_number, training_loss_reg, sim_loss=similarity_loss_reg, regular_loss=regularization_loss_reg, ana_loss=anatomy_loss_reg, total_loss=training_losses_reg, save_dir=reg_ckpt_dir, name='latest')

        if len(dataloader_valid_reg) == 0:
            # No validation data: treat every epoch's model as the best one.
            logger.info("\tno enough dataset for validation")
            save_reg_checkpoint(reg_net, optimizer_reg, epoch_number, training_loss_reg, sim_loss=similarity_loss_reg, regular_loss=regularization_loss_reg, ana_loss=anatomy_loss_reg, total_loss=training_losses_reg, save_dir=reg_ckpt_dir, name='best')
            save_reg_checkpoint(reg_net, optimizer_reg, epoch_number, training_loss_reg, sim_loss=similarity_loss_reg, regular_loss=regularization_loss_reg, ana_loss=anatomy_loss_reg, total_loss=training_losses_reg, save_dir=reg_ckpt_dir, name='valid')
            _save_best_model(reg_net, reg_model_dir, 'reg_net')
        else:
            if epoch_number % val_interval == 0:
                reg_net.eval()
                losses = []
                with torch.no_grad():
                    for batch in batch_generator_valid_reg(reg_phase_num_validation_batches_to_use):
                        loss_sim, loss_reg, loss_ana, dv = reg_losses(
                            batch, device, reg_net, seg_net, num_segmentation_classes)
                        loss = loss_sim + lambda_r * loss_reg + lambda_a * loss_ana
                        losses.append(loss.item())

                validation_loss_reg = np.mean(losses)
                logger.info(f"\treg validation loss: {validation_loss_reg}")
                validation_losses_reg.append([epoch_number+1, validation_loss_reg])

                if validation_loss_reg < best_reg_validation_loss:
                    best_reg_validation_loss = validation_loss_reg
                    logger.info("\tsave best reg_net checkpoint and model")
                    save_reg_checkpoint(reg_net, optimizer_reg, epoch_number, best_reg_validation_loss, total_loss=validation_losses_reg, save_dir=reg_ckpt_dir, name='best')
                    _save_best_model(reg_net, reg_model_dir, 'reg_net')
                save_reg_checkpoint(reg_net, optimizer_reg, epoch_number, validation_loss_reg, total_loss=validation_losses_reg, save_dir=reg_ckpt_dir, name='valid')

        plot_progress(logger, reg_plot_dir, training_losses_reg, validation_losses_reg, 'reg_net_training_loss')
        plot_progress(logger, reg_plot_dir, regularization_loss_reg, [], 'regularization_reg_net_loss')
        plot_progress(logger, reg_plot_dir, anatomy_loss_reg, [], 'anatomy_reg_net_loss')
        plot_progress(logger, reg_plot_dir, similarity_loss_reg, [], 'similarity_reg_net_loss')

        # Drop references to the last batch's tensors before the next phase.
        del loss, loss_sim, loss_reg, loss_ana
        torch.cuda.empty_cache()

        # ------------------------------------------------
        #         seg_net training, with reg_net frozen
        # ------------------------------------------------
        logger.info('\t'+'----'*10)
        swap_training(seg_net, reg_net)
        losses = []
        supervised_loss = []
        anatomy_loss = []
        dice_loss = dice_loss_func()
        # NOTE(review): `warp` is not used below; kept as in the original.
        warp = warp_func()
        warp_nearest = warp_nearest_func()
        dice_loss2 = dice_loss_func2()
        for batch in batch_generator_train_seg(seg_phase_training_batches_per_epoch):
            optimizer_seg.zero_grad()

            img12 = batch['img12'].to(device)

            displacement_fields = reg_net(img12)
            # Channel 0 and channel 1 of img12 are segmented separately.
            seg1_predicted = seg_net(img12[:, [0], :, :, :]).softmax(dim=1)
            seg2_predicted = seg_net(img12[:, [1], :, :, :]).softmax(dim=1)

            # Supervised loss uses whichever ground-truth segmentations the
            # batch carries; a missing one is replaced by the prediction so the
            # anatomy loss below can still be computed.
            if 'seg1' in batch.keys() and 'seg2' in batch.keys():
                seg1 = monai.networks.one_hot(
                    batch['seg1'].to(device), num_segmentation_classes)
                seg2 = monai.networks.one_hot(
                    batch['seg2'].to(device), num_segmentation_classes)
                loss_metric = dice_loss(seg2_predicted, seg2)
                loss_supervised = loss_metric + dice_loss(seg1_predicted, seg1)

            elif 'seg1' in batch.keys():
                seg1 = monai.networks.one_hot(
                    batch['seg1'].to(device), num_segmentation_classes)
                loss_metric = dice_loss(seg1_predicted, seg1)
                loss_supervised = loss_metric
                seg2 = seg2_predicted

            else:
                assert 'seg2' in batch.keys()
                seg2 = monai.networks.one_hot(
                    batch['seg2'].to(device), num_segmentation_classes)
                loss_metric = dice_loss(seg2_predicted, seg2)
                loss_supervised = loss_metric
                seg1 = seg1_predicted

            # Anatomy loss: seg2 warped by the displacement field is compared
            # against seg1.
            loss_anatomy = (
                dice_loss(warp_nearest(seg2, displacement_fields), seg1)
                if 'seg1' in batch.keys() or 'seg2' in batch.keys()
                else 0.
            )

            loss = lambda_a * loss_anatomy + lambda_sp * loss_supervised
            loss.backward()
            optimizer_seg.step()

            # NOTE(review): the reported seg training loss is loss_metric (one
            # Dice term), not the optimized total `loss` — confirm intent.
            losses.append(loss_metric.item())
            supervised_loss.append(loss_supervised.item())
            anatomy_loss.append(loss_anatomy.item())

        training_loss_seg = np.mean(losses)
        supervised_loss_seg.append([epoch_number+1, np.mean(supervised_loss)])
        anatomy_loss_seg.append([epoch_number+1, np.mean(anatomy_loss)])
        logger.info(f"\tseg training loss: {training_loss_seg}")
        training_losses_seg.append([epoch_number+1, training_loss_seg])
        logger.info("\tsave latest seg_net checkpoint")
        save_seg_checkpoint(seg_net, optimizer_seg, epoch_number, training_loss_seg, super_loss=supervised_loss_seg,ana_loss=anatomy_loss_seg, total_loss=training_losses_seg, save_dir=seg_ckpt_dir, name='latest')

        if len(dataloader_valid_seg) == 0:
            # No validation data: treat every epoch's model as the best one.
            logger.info("\tno enough dataset for validation")
            save_seg_checkpoint(seg_net, optimizer_seg, epoch_number, training_loss_seg, super_loss=supervised_loss_seg,ana_loss=anatomy_loss_seg, total_loss=training_losses_seg, save_dir=seg_ckpt_dir, name='valid')
            save_seg_checkpoint(seg_net, optimizer_seg, epoch_number, training_loss_seg, super_loss=supervised_loss_seg,ana_loss=anatomy_loss_seg, total_loss=training_losses_seg, save_dir=seg_ckpt_dir, name='best')
            _save_best_model(seg_net, seg_model_dir, 'seg_net')
        else:
            if epoch_number % val_interval == 0:
                seg_net.eval()
                losses = []
                with torch.no_grad():
                    for batch in dataloader_valid_seg:
                        imgs = batch['img'].to(device)
                        true_segs = batch['seg'].to(device)
                        predicted_segs = seg_net(imgs)
                        loss = dice_loss2(predicted_segs, true_segs)
                        losses.append(loss.item())

                validation_loss_seg = np.mean(losses)
                logger.info(f"\tseg validation loss: {validation_loss_seg}")
                validation_losses_seg.append([epoch_number+1, validation_loss_seg])

                if validation_loss_seg < best_seg_validation_loss:
                    best_seg_validation_loss = validation_loss_seg
                    logger.info("\tsave best seg_net checkpoint and model")
                    save_seg_checkpoint(seg_net, optimizer_seg, epoch_number, best_seg_validation_loss, total_loss=validation_losses_seg, save_dir=seg_ckpt_dir, name='best')
                    _save_best_model(seg_net, seg_model_dir, 'seg_net')
                save_seg_checkpoint(seg_net, optimizer_seg, epoch_number, validation_loss_seg, total_loss=validation_losses_seg, save_dir=seg_ckpt_dir, name='valid')

        plot_progress(logger, seg_plot_dir, training_losses_seg, validation_losses_seg, 'seg_net_training_loss')
        plot_progress(logger, seg_plot_dir, anatomy_loss_seg, [], 'anatomy_seg_net_loss')
        plot_progress(logger, seg_plot_dir, supervised_loss_seg, [], 'supervised_seg_net_loss')
        logger.info(f"\tseg lr: {optimizer_seg.param_groups[0]['lr']}")
        logger.info(f"\treg lr: {optimizer_reg.param_groups[0]['lr']}")

        # Drop references to the last batch's tensors before the next epoch.
        del (loss, seg1, seg2, displacement_fields, img12, loss_supervised, loss_anatomy, loss_metric,
             seg1_predicted, seg2_predicted)
        torch.cuda.empty_cache()

    if len(validation_losses_reg) == 0:
        logger.info('Only small number of pairs are used for training, no need to do validation. Replace best validation loss with training loss !!!')
        logger.info(f'Best reg_net validation loss: {training_loss_reg}')
    else:
        logger.info(f"Best reg_net validation loss: {best_reg_validation_loss}")

    if len(validation_losses_seg) == 0:
        logger.info('Only one label is used for training, no need to do validation. Replace best validation loss with training loss !!!')
        logger.info(f'Best seg_net validation loss: {training_loss_seg}')
    else:
        logger.info(f"Best seg_net validation loss: {best_seg_validation_loss}")

    # Persist the loss histories as a list of single-key dicts.
    reg_training_pkl = [{'training_losses': training_losses_reg},
                        {'anatomy_loss': anatomy_loss_reg},
                        {'similarity_loss': similarity_loss_reg},
                        {'regularization_loss': regularization_loss_reg}
                        ]
    if len(validation_losses_reg) != 0:
        reg_training_pkl.append({'validation_losses': validation_losses_reg})
        reg_training_pkl.append({'best_reg_validation_loss': best_reg_validation_loss})
    else:
        reg_training_pkl.append({'best_reg_validation_loss': training_loss_reg})

    seg_training_pkl = [{'training_losses': training_losses_seg},
                        {'anatomy_loss': anatomy_loss_seg},
                        {'supervised_loss': supervised_loss_seg}
                        ]
    if len(validation_losses_seg) != 0:
        seg_training_pkl.append({'validation_losses': validation_losses_seg})
        seg_training_pkl.append({'best_seg_validation_loss': best_seg_validation_loss})
    else:
        seg_training_pkl.append({'best_seg_validation_loss': training_loss_seg})

    with open(os.path.join(reg_plot_dir, 'reg_training_losses.pkl'), 'wb') as f:
        pickle.dump(reg_training_pkl, f)

    with open(os.path.join(seg_plot_dir, 'seg_training_losses.pkl'), 'wb') as ff:
        pickle.dump(seg_training_pkl, ff)