Vincentqyw
fix: roma
8b973ee
raw
history blame
2.16 kB
"""
Utility functions shared by the training and validation code.
"""
import os
import numpy as np
import torch
#################
## image utils ##
#################
def convert_image(input_tensor, axis):
    """Convert single channel images to 3-channel images.

    The input is concatenated with itself three times along ``axis``. This
    presumably expects a singleton channel dimension at ``axis`` (e.g. a
    grayscale image) — TODO confirm with callers. Other shapes are still
    accepted, and the data is simply repeated block-wise.

    Args:
        input_tensor: Array-like image data accepted by ``np.concatenate``.
        axis: Axis along which the three copies are joined.

    Returns:
        numpy.ndarray: A new array holding three copies of ``input_tensor``.
    """
    # The list holds three references to one object; np.concatenate copies
    # the data into a fresh array, so no aliasing leaks into the result.
    copies = [input_tensor] * 3
    return np.concatenate(copies, axis)
######################
## checkpoint utils ##
######################
def get_latest_checkpoint(
    checkpoint_root, checkpoint_name, device=torch.device("cuda")
):
    """Load a checkpoint by file name, or the latest ``.tar`` checkpoint.

    Args:
        checkpoint_root: Directory containing the checkpoint files.
        checkpoint_name: File name (relative to ``checkpoint_root``) of the
            checkpoint to load. If ``None``, the ``.tar`` file whose name
            sorts last lexicographically is loaded instead.
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The object deserialized by ``torch.load``.

    Raises:
        FileNotFoundError: If ``checkpoint_name`` is ``None`` and
            ``checkpoint_root`` does not exist or contains no ``.tar`` files,
            or if the requested checkpoint file does not exist.
    """
    if checkpoint_name is None:
        # os.listdir does not expand wildcards (joining "*.tar" onto the root
        # just names a non-existent directory), so filter the listing by
        # suffix instead.
        tar_files = sorted(
            name for name in os.listdir(checkpoint_root) if name.endswith(".tar")
        )
        if not tar_files:
            raise FileNotFoundError(
                "No '.tar' checkpoint found in %s" % checkpoint_root
            )
        # "Latest" means lexicographically last; this matches chronological
        # order only if the file names are zero-padded / sortable.
        checkpoint_name = tar_files[-1]

    # NOTE(review): torch.load unpickles arbitrary objects; only load
    # checkpoints from trusted sources.
    return torch.load(
        os.path.join(checkpoint_root, checkpoint_name), map_location=device
    )
def remove_old_checkpoints(checkpoint_root, max_ckpt=15):
    """Remove the outdated checkpoints, keeping the newest ``max_ckpt``.

    Only files ending in ``.tar`` are considered. They are ordered by file
    name (lexicographically), and the ones that sort first are treated as the
    oldest and deleted. A debug line is printed for each removed file.

    Args:
        checkpoint_root: Directory containing the checkpoint files.
        max_ckpt: Number of checkpoints to keep; ``0`` removes all of them.

    Raises:
        ValueError: If ``max_ckpt`` is negative.
    """
    if max_ckpt < 0:
        raise ValueError("max_ckpt must be non-negative, got %r" % (max_ckpt,))

    checkpoint_list = sorted(
        name for name in os.listdir(checkpoint_root) if name.endswith(".tar")
    )
    # Slice by count rather than `[:-max_ckpt]`: with max_ckpt == 0 that
    # negative slice becomes `[:0]` and would silently keep everything.
    num_to_remove = max(len(checkpoint_list) - max_ckpt, 0)
    for name in checkpoint_list[:num_to_remove]:
        full_name = os.path.join(checkpoint_root, name)
        os.remove(full_name)
        print("[Debug] Remove outdated checkpoint %s" % (full_name))
def adapt_checkpoint(state_dict):
    """Return a copy of ``state_dict`` with a leading ``"module."`` removed from keys.

    Keys without the prefix are kept unchanged, and only one leading
    occurrence is stripped. The prefix is presumably the one added when a
    model is wrapped in ``torch.nn.DataParallel`` / ``DistributedDataParallel``
    (TODO confirm with callers). If two keys collapse to the same name, the
    one appearing later in ``state_dict`` wins.

    Args:
        state_dict: Mapping with string keys (``startswith`` is called on each).

    Returns:
        dict: A new plain dict with the adjusted keys, in the original order.
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
################
## HDF5 utils ##
################
def parse_h5_data(h5_data):
    """Parse an h5 dataset/group into an in-memory dict of numpy arrays.

    Args:
        h5_data: Any object exposing ``keys()`` and ``__getitem__`` (e.g. an
            h5py group — assumed from the name, TODO confirm).

    Returns:
        dict: Maps each key to ``np.array(h5_data[key])``, i.e. a materialized
        numpy copy of the stored value.
    """
    return {name: np.array(h5_data[name]) for name in h5_data.keys()}