meow
init
d6d3a5b
raw
history blame
No virus
1.95 kB
import numpy as np
import torch
"""
This file stores functions for converting between numpy arrays, torch tensors, lists, etc.
It also deals with general operations such as to(dev), detach, etc.
"""
def thing2list(thing):
    """Recursively convert tensors and arrays inside ``thing`` to plain lists.

    Args:
        thing: A ``torch.Tensor``, ``np.ndarray``, ``dict``, ``list``, or any
            other object.

    Returns:
        * For a tensor or ndarray: the result of its ``tolist()`` method.
        * For a dict: a new dict with the same keys and converted values.
        * For a list: a new list with every element converted.
        * Anything else (including tuples) is returned unchanged.
    """
    if isinstance(thing, (torch.Tensor, np.ndarray)):
        return thing.tolist()
    if isinstance(thing, dict):
        # Iterate over the argument itself; the previous code referenced an
        # undefined name here and raised NameError for every dict input.
        return {k: thing2list(v) for k, v in thing.items()}
    if isinstance(thing, list):
        return [thing2list(item) for item in thing]
    return thing
def thing2dev(thing, dev):
    """Recursively move ``thing`` (or the objects nested inside it) to ``dev``.

    Args:
        thing: Any object. Objects exposing a ``to`` attribute (e.g.
            ``torch.Tensor``) are moved directly; lists, tuples and dicts are
            traversed recursively.
        dev: The target passed verbatim to ``thing.to(dev)``.

    Returns:
        * For an object with a ``to`` attribute: ``thing.to(dev)``.
        * For a list: a new list with every element processed.
        * For a tuple: a new plain ``tuple`` with every element processed.
        * For a dict: a new dict with the same keys and processed values.
        * Anything else is returned unchanged.
    """
    # Duck-typed check first: this already covers torch.Tensor, so no separate
    # isinstance(thing, torch.Tensor) branch is needed afterwards.
    if hasattr(thing, "to"):
        return thing.to(dev)
    if isinstance(thing, list):
        return [thing2dev(item, dev) for item in thing]
    if isinstance(thing, tuple):
        return tuple(thing2dev(item, dev) for item in thing)
    if isinstance(thing, dict):
        return {k: thing2dev(v, dev) for k, v in thing.items()}
    return thing
def thing2np(thing):
    """Convert ``thing`` to numpy where a conversion rule exists.

    Args:
        thing: A ``list``, ``torch.Tensor``, ``dict``, or any other object.

    Returns:
        * For a list: ``np.array(thing)`` (the list is converted as a whole,
          not element by element).
        * For a tensor: its numpy view after ``.cpu().detach()``.
        * For a dict: a new dict with the same keys and converted values.
        * Anything else is returned unchanged.
    """
    if isinstance(thing, list):
        converted = np.array(thing)
    elif isinstance(thing, torch.Tensor):
        host_tensor = thing.cpu().detach()
        converted = host_tensor.numpy()
    elif isinstance(thing, dict):
        converted = {}
        for key, value in thing.items():
            converted[key] = thing2np(value)
    else:
        converted = thing
    return converted
def thing2torch(thing):
    """Convert ``thing`` to torch tensors where a conversion rule exists.

    Args:
        thing: A ``list``, ``np.ndarray``, ``dict``, or any other object.

    Returns:
        * For a dict: a new dict with the same keys and converted values.
        * For an ndarray: ``torch.from_numpy(thing)``.
        * For a list: ``torch.tensor(np.array(thing))`` (the list is converted
          as a whole, not element by element).
        * Anything else is returned unchanged.
    """
    if isinstance(thing, dict):
        return {key: thing2torch(value) for key, value in thing.items()}
    if isinstance(thing, np.ndarray):
        return torch.from_numpy(thing)
    if isinstance(thing, list):
        # Go through numpy first, then build the tensor from that array.
        as_array = np.array(thing)
        return torch.tensor(as_array)
    return thing
def detach_thing(thing):
    """Recursively apply ``.cpu().detach()`` to every tensor inside ``thing``.

    Args:
        thing: A ``torch.Tensor``, ``list``, ``tuple``, ``dict``, or any other
            object.

    Returns:
        * For a tensor: ``thing.cpu().detach()``.
        * For a dict: a new dict with the same keys and processed values.
        * For a tuple: a new plain ``tuple`` with every element processed.
        * For a list: a new list with every element processed.
        * Anything else is returned unchanged.
    """
    if isinstance(thing, torch.Tensor):
        on_host = thing.cpu()
        return on_host.detach()
    if isinstance(thing, dict):
        return {key: detach_thing(value) for key, value in thing.items()}
    if isinstance(thing, tuple):
        return tuple(detach_thing(item) for item in thing)
    if isinstance(thing, list):
        return [detach_thing(item) for item in thing]
    return thing