""" This file contains some useful functions for train / val. """
import os

import numpy as np
import torch


#################
## image utils ##
#################
def convert_image(input_tensor, axis):
    """Convert single-channel images to 3-channel images.

    The input is repeated three times and concatenated along ``axis``. For
    the result to be a 3-channel image, ``axis`` should be the (size-1)
    channel axis of ``input_tensor``.

    Args:
        input_tensor: array-like image data accepted by ``np.concatenate``.
        axis: axis along which the three copies are concatenated.

    Returns:
        np.ndarray whose size along ``axis`` is three times the input's.
    """
    return np.concatenate([input_tensor] * 3, axis)


######################
## checkpoint utils ##
######################
def _list_checkpoints(checkpoint_root):
    """Return the ``.tar`` file names in ``checkpoint_root``, sorted by name."""
    # NOTE: plain lexicographic order. "Latest"/"oldest" in the callers below
    # therefore assume checkpoint names sort chronologically (e.g. zero-padded
    # step numbers) -- confirm against the saving code.
    return sorted(
        name for name in os.listdir(checkpoint_root) if name.endswith(".tar")
    )


def get_latest_checkpoint(
    checkpoint_root, checkpoint_name, device=torch.device("cuda")
):
    """Load a checkpoint by file name, or the latest one if no name is given.

    Args:
        checkpoint_root: directory containing the checkpoint files.
        checkpoint_name: file name (relative to ``checkpoint_root``) to load,
            or ``None`` to load the last ``.tar`` file in name-sorted order.
        device: passed to ``torch.load`` as ``map_location``.

    Returns:
        The object deserialized by ``torch.load``.

    Raises:
        FileNotFoundError: if ``checkpoint_name`` is ``None`` and
            ``checkpoint_root`` contains no ``.tar`` files (or does not exist).
    """
    if checkpoint_name is None:
        # os.listdir does not expand wildcards, so list the directory and
        # filter by suffix instead of listing "<root>/*.tar".
        checkpoints = _list_checkpoints(checkpoint_root)
        if not checkpoints:
            raise FileNotFoundError(
                "No .tar checkpoint found in %s" % checkpoint_root
            )
        checkpoint_name = checkpoints[-1]
    return torch.load(
        os.path.join(checkpoint_root, checkpoint_name), map_location=device
    )


def remove_old_checkpoints(checkpoint_root, max_ckpt=15):
    """Delete the oldest checkpoints, keeping at most ``max_ckpt`` of them.

    "Oldest" means first in name-sorted order among the ``.tar`` files in
    ``checkpoint_root``. Other files are left untouched.

    Args:
        checkpoint_root: directory containing the checkpoint files.
        max_ckpt: number of most recent checkpoints to keep (0 removes all).
    """
    checkpoint_list = _list_checkpoints(checkpoint_root)
    # Slice by an explicit count rather than ``[:-max_ckpt]``. With
    # max_ckpt == 0 that slice is ``[:0]`` (empty), so nothing would be removed.
    num_to_remove = max(len(checkpoint_list) - max_ckpt, 0)
    for name in checkpoint_list[:num_to_remove]:
        full_name = os.path.join(checkpoint_root, name)
        os.remove(full_name)
        print("[Debug] Remove outdated checkpoint %s" % (full_name))


def adapt_checkpoint(state_dict):
    """Strip a leading ``"module."`` from every key of ``state_dict``.

    Such prefixes appear, e.g., when a model wrapped in ``nn.DataParallel`` /
    ``DistributedDataParallel`` is saved. Stripping them lets the weights load
    into the unwrapped model. Keys without the prefix are kept as-is, and
    values are not copied.

    Args:
        state_dict: mapping from parameter names to values.

    Returns:
        A new dict with the adapted keys, in the original iteration order.
        If a stripped key collides with an existing one, the later entry wins.
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


################
## HDF5 utils ##
################
def parse_h5_data(h5_data):
    """Read every top-level entry of an HDF5 file/group into memory.

    Args:
        h5_data: mapping-like object (e.g. an ``h5py.File`` or ``Group``) whose
            ``keys()`` name the entries to read.

    Returns:
        dict mapping each key to ``np.array(h5_data[key])``.
    """
    # NOTE(review): assumes every entry is a dataset; nested groups would not
    # convert to a meaningful array -- confirm against callers.
    return {key: np.array(h5_data[key]) for key in h5_data.keys()}