File size: 3,474 Bytes
32408ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""Helper functions related to io"""

import os.path
import sys
import shutil
import urllib.request
from pathlib import Path
import yaml
import torch


def progress(iterable, *, size=None, print_freq=1, handle=sys.stdout):
    """Generator wrapping an iterable to print progress.

    Yields every element of `iterable` unchanged. After the consumer has
    finished with an element (i.e. when the generator is resumed), a status
    line starting with a carriage return is written to `handle` for the first
    element, for every `print_freq`-th element and for the element number
    equal to `size`. A final newline is written once the iterable is
    exhausted.

    Args:
        iterable: the iterable to wrap.
        size: optional total number of elements; when truthy the status reads
            "<done>/<size>", otherwise only "<done>".
        print_freq: report every `print_freq` elements.
        handle: object with a `write` method receiving the status text.
    """
    # The status line has no trailing newline, so a line-buffered stream
    # (e.g. an interactive stdout) would not display it until the very end
    # unless flushed explicitly. Handles without `flush` are still accepted.
    flush = getattr(handle, "flush", lambda: None)

    for i, element in enumerate(iterable):
        yield element

        done = i + 1
        if i == 0 or done % print_freq == 0 or done == size:
            if size:
                handle.write(f'\r>>>> {done}/{size} done...')
            else:
                handle.write(f'\r>>>> {done} done...')
            flush()

    handle.write("\n")
    flush()


# Params

def load_params(path):
    """Read the yaml file at `path` and return its parameters.

    The parsed content is passed through `load_nested_templates`, using the
    directory of `path` as the root for resolving '__template__' entries.
    """
    with open(path, mode="r") as stream:
        parsed = yaml.safe_load(stream)
    base_dir = os.path.dirname(path)
    return load_nested_templates(parsed, base_dir)

def save_params(path, params):
    """Write `params` to the file at `path` as block-style yaml.

    The file is opened for writing, so existing content is replaced.
    """
    with open(path, mode="w") as stream:
        yaml.safe_dump(params, stream, default_flow_style=False)

def load_nested_templates(params, root_path):
    """Resolve '__template__' entries of a nested dictionary.

    Non-dictionary values are returned untouched. For a dictionary holding a
    '__template__' key, that key is removed, the yaml file it names (user
    home expanded, joined onto `root_path`) is loaded and used as defaults
    that the remaining entries are overlaid on. Every value is then processed
    recursively; below a template, paths are resolved relative to the
    template file's directory. The given dictionary is modified in place.
    """
    if not isinstance(params, dict):
        return params

    marker = "__template__"
    if marker in params:
        relative = os.path.expanduser(params.pop(marker))
        template_file = os.path.join(root_path, relative)
        root_path = os.path.dirname(template_file)
        # The template supplies the defaults, explicit entries win
        params = dict_deep_overlay(load_params(template_file), params)

    for key in params:
        params[key] = load_nested_templates(params[key], root_path)
    return params

def dict_deep_overlay(defaults, params):
    """Overlay `params` on top of `defaults`.

    When both arguments are dictionaries, every key of `params` is merged
    recursively into `defaults` (which is modified in place and returned).
    In any other case `params` is returned as is.
    """
    both_are_dicts = isinstance(defaults, dict) and isinstance(params, dict)
    if not both_are_dicts:
        return params

    for key, override in params.items():
        defaults[key] = dict_deep_overlay(defaults.get(key), override)
    return defaults

def dict_deep_set(dct, key, value):
    """Set key to value for a nested dictionary where the key is a sequence (e.g. list).

    All elements of `key` but the last name the nested dictionaries to
    descend into; the last element is the entry that receives `value`.
    Intermediate entries that are missing, or that are not dictionaries, are
    replaced with a new empty dictionary. `dct` is modified in place.

    Raises:
        IndexError: if `key` is empty.
    """
    # Evaluated first so that an empty key fails before anything is modified
    last = key[-1]

    node = dct
    for part in key[:-1]:
        # Membership has to be tested before indexing, otherwise a missing
        # intermediate key raises KeyError instead of being created
        if part not in node or not isinstance(node[part], dict):
            node[part] = {}
        node = node[part]

    node[last] = value


# Download

def download_files(names, root_path, base_url, logfunc=None):
    """Download file names from given url to given directory path.

    Each name is appended to `base_url` to build the download url and joined
    onto `root_path` to build the destination; missing parent directories
    are created. Names whose destination already exists are skipped.

    Args:
        names: iterable of file names (relative paths) to download.
        root_path: directory the files are stored under.
        base_url: url prefix the names are appended to.
        logfunc: optional callable receiving a status message per download.
    """
    root_path = Path(root_path)
    for name in names:
        path = root_path / name
        if path.exists():
            continue
        if logfunc:
            logfunc(f"Downloading file '{name}'")
        path.parent.mkdir(parents=True, exist_ok=True)

        # Download next to the destination and rename on success only, so an
        # interrupted transfer never leaves a partial file that the existence
        # check above would later take for a complete one
        partial = path.with_name(path.name + ".part")
        try:
            urllib.request.urlretrieve(base_url + name, partial)
            partial.replace(path)
        finally:
            # Only still present when the download or the rename failed
            if partial.exists():
                partial.unlink()


# Checkpoints

def save_checkpoint(state, is_best, keep_epoch, directory):
    """Store a state dictionary as checkpoint file(s) inside `directory`.

    `state['epoch']` determines the name of the per-epoch file. That file is
    written when `keep_epoch` is set, 'model_best.pth' is written when
    `is_best` is set; when both are set the epoch file is saved and then
    copied to the best file. Nothing is written when neither is set.
    """
    epoch_file = os.path.join(directory, 'model_epoch%d.pth' % state['epoch'])
    best_file = os.path.join(directory, 'model_best.pth')

    if not (is_best or keep_epoch):
        return

    if is_best and keep_epoch:
        # Serialize once, then duplicate the file on disk
        torch.save(state, epoch_file)
        shutil.copyfile(epoch_file, best_file)
        return

    target = best_file if is_best else epoch_file
    torch.save(state, target)