SuperFeatures / how /utils /io_helpers.py
YannisK's picture
temp state
32408ed
raw history blame
No virus
3.47 kB
"""Helper functions related to io"""
import os.path
import sys
import shutil
import urllib.request
from pathlib import Path
import yaml
import torch
def progress(iterable, *, size=None, print_freq=1, handle=sys.stdout):
    """Yield the elements of ``iterable`` while writing a progress line to ``handle``.

    A status is written after the first element, after every ``print_freq``-th
    element, and after the element whose 1-based position equals ``size``. Each
    status starts with a carriage return so it overwrites the previous one. When
    ``size`` is truthy the status reads ``done/size``, otherwise just ``done``.
    A newline is written once the iterable is exhausted.
    """
    for done, item in enumerate(iterable, start=1):
        yield item

        # The report happens after the consumer has finished with the item.
        is_report_step = done == 1 or done % print_freq == 0 or done == size
        if not is_report_step:
            continue

        message = f'\r>>>> {done}/{size} done...' if size else f'\r>>>> {done} done...'
        handle.write(message)

    handle.write("\n")
# Params
def load_params(path):
    """Load parameters from the yaml file at ``path``.

    The file is parsed with ``yaml.safe_load`` and the result is passed through
    ``load_nested_templates`` so that any ``__template__`` references are
    resolved relative to the directory containing ``path``.
    """
    with open(path, "r") as stream:
        raw_params = yaml.safe_load(stream)

    base_dir = os.path.dirname(path)
    return load_nested_templates(raw_params, base_dir)
def save_params(path, params):
    """Write ``params`` to the yaml file at ``path`` in block (non-flow) style.

    The file is opened in write mode, so existing content is replaced.
    """
    with open(path, "w") as stream:
        yaml.safe_dump(params, stream, default_flow_style=False)
def load_nested_templates(params, root_path):
    """Recursively expand ``'__template__'`` keys found in nested dictionaries.

    Non-dictionary values (including lists) are returned untouched. For a
    dictionary carrying a ``'__template__'`` key, the key is removed, the
    referenced yaml file (user-expanded, joined onto ``root_path``) is loaded,
    and the dictionary is overlaid on top of it, i.e. the template supplies the
    defaults. Values are then processed recursively; after a template was
    merged, nested references are resolved relative to that template's
    directory.

    The input dictionary is modified in place. Without a template the same
    object is returned; with a template the returned dictionary is the loaded
    template updated with the given values.
    """
    if isinstance(params, dict):
        if "__template__" in params:
            relative = os.path.expanduser(params.pop("__template__"))
            template_file = os.path.join(root_path, relative)
            root_path = os.path.dirname(template_file)
            # The template provides defaults; explicitly given values win.
            params = dict_deep_overlay(load_params(template_file), params)

        # Only existing keys are reassigned, so the dict size stays constant.
        for name in params:
            params[name] = load_nested_templates(params[name], root_path)

    return params
def dict_deep_overlay(defaults, params):
    """Overlay ``params`` on top of ``defaults`` and return the result.

    When both arguments are dictionaries, every key of ``params`` is merged
    recursively into ``defaults`` (which is modified in place and returned);
    keys present only in ``defaults`` are kept. In every other case ``params``
    is returned as is, replacing whatever ``defaults`` held.
    """
    both_dicts = isinstance(defaults, dict) and isinstance(params, dict)
    if not both_dicts:
        return params

    for name, override in params.items():
        defaults[name] = dict_deep_overlay(defaults.get(name), override)
    return defaults
def dict_deep_set(dct, key, value):
    """Set ``value`` in the nested dictionary ``dct`` at the path given by ``key``.

    ``key`` is a non-empty sequence (e.g. a list) of keys, one per nesting
    level. Intermediate levels that are missing, or that hold something other
    than a dictionary, are replaced by a new empty dictionary. ``dct`` is
    modified in place and nothing is returned.

    Raises:
        IndexError: if ``key`` is an empty sequence.
    """
    node = dct
    for part in key[:-1]:
        child = node.get(part)
        # ``get`` covers the missing-key case without raising KeyError, so a
        # single isinstance test handles both "absent" and "not a dict".
        if not isinstance(child, dict):
            child = {}
            node[part] = child
        node = child

    node[key[-1]] = value
# Download
def download_files(names, root_path, base_url, logfunc=None):
    """Download the given file names from a base url into a directory.

    Args:
        names: iterable of file names; each is appended to ``base_url`` to form
            the source url and joined onto ``root_path`` to form the target.
        root_path: target directory (string or path object).
        base_url: url prefix that the names are appended to.
        logfunc: optional callable receiving a status message for every file
            that is actually downloaded.

    Files that already exist at the target location are skipped. Missing parent
    directories are created. Each file is first retrieved into a sibling
    ``<name>.part`` file and moved to its final name only after the transfer
    succeeded, so an interrupted download is never mistaken for a complete
    file on the next call.
    """
    root_path = Path(root_path)

    for name in names:
        path = root_path / name
        if path.exists():
            continue

        if logfunc:
            logfunc(f"Downloading file '{name}'")

        path.parent.mkdir(parents=True, exist_ok=True)
        partial_path = path.with_name(path.name + ".part")
        try:
            urllib.request.urlretrieve(base_url + name, partial_path)
            partial_path.replace(path)
        finally:
            # Only present here if the transfer or the move failed.
            if partial_path.exists():
                partial_path.unlink()
# Checkpoints
def save_checkpoint(state, is_best, keep_epoch, directory):
    """Store the ``state`` dictionary as a checkpoint file inside ``directory``.

    ``state['epoch']`` determines the per-epoch file name
    (``model_epoch<N>.pth``). With ``keep_epoch`` the per-epoch file is written;
    with ``is_best`` the state ends up in ``model_best.pth`` (as a copy of the
    per-epoch file when both flags are set). When neither flag is set, nothing
    is written.
    """
    epoch_path = os.path.join(directory, 'model_epoch%d.pth' % state['epoch'])
    best_path = os.path.join(directory, 'model_best.pth')

    if keep_epoch:
        torch.save(state, epoch_path)
        if is_best:
            # Reuse the file just written instead of serializing twice.
            shutil.copyfile(epoch_path, best_path)
    elif is_best:
        torch.save(state, best_path)