zyliu's picture
release iChatApp
0f90f73
raw history blame
No virus
689 Bytes
from enum import Enum
import yaml
from easydict import EasyDict as edict
import torch.nn as nn
import torch
def load_yaml(path):
    """Parse the YAML file at ``path`` and wrap the result in an ``edict``.

    The file is opened in binary mode so that PyYAML performs its own
    encoding detection (UTF-8, or UTF-16 via BOM). Opening it in text mode
    without an explicit encoding would decode with the platform's locale
    default, which can corrupt or reject non-ASCII configs on some systems.

    ``yaml.safe_load`` is used, so only standard YAML tags are constructed;
    arbitrary Python-object tags are rejected.

    Args:
        path: Filesystem path (str or path-like) of the YAML file.

    Returns:
        The parsed document passed through ``edict``.
        NOTE(review): presumably the document is expected to be a top-level
        mapping; other top-level types go straight to ``edict`` — confirm.
    """
    with open(path, 'rb') as f:
        return edict(yaml.safe_load(f))
def move_to_device(obj, device):
    """Recursively send every module and tensor inside ``obj`` to ``device``.

    Args:
        obj: An ``nn.Module``, a tensor, or any nesting of tuples, lists and
            dicts whose leaves are modules or tensors.
        device: Target handed unchanged to each leaf's ``.to()`` call.

    Returns:
        For a module or tensor: the result of ``obj.to(device)``.
        For a tuple or list: a new ``list`` of the moved elements (tuples are
        not preserved as tuples).
        For a dict (including dict subclasses): a new plain ``dict`` with the
        same keys and moved values.

    Raises:
        ValueError: If any leaf is neither an ``nn.Module`` nor a tensor.
    """
    # Leaf case: modules and tensors both expose ``.to``.
    if isinstance(obj, nn.Module) or torch.is_tensor(obj):
        return obj.to(device)

    # Sequence case: always rebuilt as a list, whatever the input type.
    if isinstance(obj, (tuple, list)):
        moved_items = []
        for item in obj:
            moved_items.append(move_to_device(item, device))
        return moved_items

    # Mapping case: keys are kept as-is; only the values are transferred.
    if isinstance(obj, dict):
        moved_map = {}
        for key, value in obj.items():
            moved_map[key] = move_to_device(value, device)
        return moved_map

    raise ValueError(f'Unexpected type {type(obj)}')
class SmallMode(Enum):
    """Two-valued option for handling "small" cases, keyed by lowercase strings.

    NOTE(review): judging only by the member names, this presumably selects
    whether too-small inputs are discarded (``DROP``) or enlarged
    (``UPSCALE``); nothing in the visible code consumes it, so confirm
    against callers.
    """

    # Values are plain lowercase strings, so a member can be looked up by
    # value, e.g. SmallMode("drop") is SmallMode.DROP.
    DROP = "drop"
    UPSCALE = "upscale"