Spaces:
Build error
Build error
"""Helper functions related to io""" | |
import os.path | |
import sys | |
import shutil | |
import urllib.request | |
from pathlib import Path | |
import yaml | |
import torch | |
def progress(iterable, *, size=None, print_freq=1, handle=None):
    """Yield the items of ``iterable`` unchanged while writing a progress counter.

    After an item has been consumed by the caller, a carriage-return-prefixed
    status line (``>>>> i/size done...``, or ``>>>> i done...`` when ``size`` is
    falsy) is written to ``handle``. This happens for the first item, for every
    ``print_freq``-th item, and for the item whose 1-based index equals ``size``.
    A single newline is written once the iterable is exhausted.

    Args:
        iterable: Any iterable; its elements are yielded as-is.
        size: Optional total number of items, shown in the status line.
        print_freq: Report every ``print_freq`` items (must be non-zero).
        handle: Writable text stream. Defaults to ``sys.stdout`` as it is when
            iteration starts, so ``contextlib.redirect_stdout`` is honoured.
    """
    if handle is None:
        # Looked up lazily: a default bound at definition time would ignore redirection.
        handle = sys.stdout

    for i, element in enumerate(iterable):
        yield element

        count = i + 1
        if i == 0 or count % print_freq == 0 or count == size:
            if size:
                handle.write(f'\r>>>> {count}/{size} done...')
            else:
                handle.write(f'\r>>>> {count} done...')
            # Status lines contain no newline, so line-buffered streams would not show them.
            handle.flush()

    handle.write("\n")
# Params | |
def load_params(path):
    """Load a YAML parameter file and expand any ``__template__`` references in it.

    The file is opened in binary mode so that PyYAML detects the encoding itself
    (UTF-8, or UTF-16 with a BOM) instead of decoding with the platform's locale
    encoding. Only ``yaml.safe_load`` is used, so arbitrary Python objects are
    not constructed from the file contents.

    Args:
        path: Path of the YAML file.

    Returns:
        The parsed content after ``load_nested_templates`` has expanded it.
        Relative template paths are resolved against the directory containing
        ``path``.
    """
    with open(path, "rb") as handle:
        content = yaml.safe_load(handle)
    return load_nested_templates(content, os.path.dirname(path))
def save_params(path, params):
    """Write ``params`` to ``path`` as YAML, overwriting any existing file.

    Serialization goes through ``yaml.safe_dump`` with
    ``default_flow_style=False``, so nested collections are emitted in block
    (indented) style rather than inline.
    """
    dump_options = {"default_flow_style": False}
    with open(path, mode="w") as out_stream:
        yaml.safe_dump(params, out_stream, **dump_options)
def load_nested_templates(params, root_path):
    """Recursively expand ``'__template__'`` references inside nested dictionaries.

    Whenever a dictionary has the key ``'__template__'``, that key is popped,
    and its value (``~``-expanded, then joined onto ``root_path``) is loaded with
    ``load_params``. The loaded content then serves as the defaults underneath
    the dictionary's remaining keys (via ``dict_deep_overlay``). Every value of
    the resulting dictionary is processed the same way. Values that are not
    dictionaries, including lists, are returned unchanged without descending
    into them.

    Args:
        params: Value to expand.
        root_path: Directory that relative template paths are resolved against.

    Returns:
        The expanded value. Dictionaries are modified in place (the template key
        is popped and child values are reassigned).
    """
    if not isinstance(params, dict):
        return params

    if "__template__" in params:
        template_file = os.path.join(root_path, os.path.expanduser(params.pop("__template__")))
        # From here on, child templates are resolved relative to the template's own directory.
        # NOTE(review): this also applies to '__template__' keys that come from `params`
        # itself (they are merged into the overlay result below), so those resolve
        # relative to the template's directory rather than this file's — confirm intended.
        root_path = os.path.dirname(template_file)
        params = dict_deep_overlay(load_params(template_file), params)

    for child_key in list(params):
        params[child_key] = load_nested_templates(params[child_key], root_path)
    return params
def dict_deep_overlay(defaults, params):
    """Recursively overlay ``params`` on top of ``defaults``.

    If both arguments are dictionaries, every key of ``params`` is written into
    ``defaults``, recursing when both sides hold dictionaries for that key.
    Keys present only in ``defaults`` are kept, and the updated ``defaults``
    object is returned. In every other case ``params`` wins and is returned
    as-is. This includes a ``None`` or scalar ``params`` replacing a dictionary
    default.

    Note: ``defaults`` is modified in place.

    Args:
        defaults: Base value (typically a dictionary of default parameters).
        params: Overriding value.

    Returns:
        The merged value.
    """
    if not (isinstance(defaults, dict) and isinstance(params, dict)):
        return params

    for key, override in params.items():
        defaults[key] = dict_deep_overlay(defaults.get(key), override)
    return defaults
def dict_deep_set(dct, key, value):
    """Set ``value`` in a nested dictionary at the location given by a key path.

    ``key`` is a non-empty sequence (e.g. list or tuple) with one key per
    nesting level. Intermediate levels that are missing, or that hold a
    non-dictionary value, are replaced by new empty dictionaries. The final key
    is then set to ``value``. ``dct`` is modified in place.

    Example::

        >>> d = {}
        >>> dict_deep_set(d, ["a", "b"], 1)
        >>> d
        {'a': {'b': 1}}

    Raises:
        IndexError: If ``key`` is empty.
    """
    if len(key) == 0:
        raise IndexError("dict_deep_set() requires a non-empty key sequence")

    node = dct
    for part in key[:-1]:
        # Membership must be tested first: indexing a missing level would raise KeyError.
        if part not in node or not isinstance(node[part], dict):
            node[part] = {}
        node = node[part]
    node[key[-1]] = value
# Download | |
def download_files(names, root_path, base_url, logfunc=None):
    """Download each of ``names`` from ``base_url`` into ``root_path``, skipping existing files.

    For every name, the target is ``root_path / name`` and the source URL is the
    plain string concatenation ``base_url + name``, so ``base_url`` normally ends
    with ``/``. Missing parent directories are created. Each file is first
    downloaded to a sibling ``<name>.part`` file and moved into place only once
    the transfer has finished. This way an interrupted or failed download cannot
    leave a truncated file that later calls would treat as already present.

    Args:
        names: Iterable of relative file names (may contain sub-directories).
        root_path: Destination directory (``str`` or path-like).
        base_url: URL prefix the names are appended to.
        logfunc: Optional callable receiving a status message before each download.

    Raises:
        Whatever ``urllib.request.urlretrieve`` raises (e.g. ``urllib.error.URLError``);
        the partial ``.part`` file is removed before the exception propagates.
    """
    root_path = Path(root_path)
    for name in names:
        path = root_path / name
        if path.exists():
            continue

        if logfunc:
            logfunc(f"Downloading file '{name}'")
        path.parent.mkdir(parents=True, exist_ok=True)

        partial_path = path.with_name(path.name + ".part")
        try:
            urllib.request.urlretrieve(base_url + name, partial_path)
            # Same directory, hence same filesystem: the rename is atomic.
            os.replace(partial_path, path)
        except BaseException:
            # Also covers KeyboardInterrupt: never leave a half-written file behind.
            if partial_path.exists():
                partial_path.unlink()
            raise
# Checkpoints | |
def save_checkpoint(state, is_best, keep_epoch, directory):
    """Persist a training checkpoint with ``torch.save`` according to two flags.

    Args:
        state: Checkpoint dictionary passed to ``torch.save``; its ``'epoch'``
            entry is formatted with ``%d`` into the per-epoch file name.
        is_best: If truthy, the checkpoint ends up in ``model_best.pth``.
        keep_epoch: If truthy, the checkpoint ends up in ``model_epoch<epoch>.pth``.
        directory: Target directory; it is not created here.

    When both flags are set, the state is serialized only once and the epoch
    file is then copied to the best file. When neither is set, nothing is
    written (the file names are still computed).
    """
    epoch_path = os.path.join(directory, 'model_epoch%d.pth' % state['epoch'])
    best_path = os.path.join(directory, 'model_best.pth')

    if not is_best and not keep_epoch:
        return

    if not keep_epoch:
        torch.save(state, best_path)
    elif not is_best:
        torch.save(state, epoch_path)
    else:
        torch.save(state, epoch_path)
        shutil.copyfile(epoch_path, best_path)