from enum import Enum import yaml from easydict import EasyDict as edict import torch.nn as nn import torch def load_yaml(path): with open(path, 'r') as f: return edict(yaml.safe_load(f)) def move_to_device(obj, device): if isinstance(obj, nn.Module): return obj.to(device) if torch.is_tensor(obj): return obj.to(device) if isinstance(obj, (tuple, list)): return [move_to_device(el, device) for el in obj] if isinstance(obj, dict): return {name: move_to_device(val, device) for name, val in obj.items()} raise ValueError(f'Unexpected type {type(obj)}') class SmallMode(Enum): DROP = "drop" UPSCALE = "upscale"