import collections
import collections.abc

import numpy as np
import torch
from torch.autograd import Variable

__all__ = ['as_variable', 'as_numpy', 'mark_volatile']

# Text types are Sequences too, but must be treated as leaf values:
# iterating a str yields one-character strs, which are again Sequences, so
# recursing into them never terminates. bytes iterate into ints, which is
# never a meaningful "container" split either.
_TEXT_TYPES = (str, bytes)


def _is_sequence(obj):
    """Return True if *obj* is a non-text sequence (list, tuple, ...)."""
    return (isinstance(obj, collections.abc.Sequence)
            and not isinstance(obj, _TEXT_TYPES))


def as_variable(obj):
    """Recursively wrap the leaves of *obj* in ``Variable``.

    Sequences (other than str/bytes) become lists and mappings become dicts,
    with every element converted recursively. Objects that are already a
    ``Variable`` are returned unchanged; any other leaf is passed to
    ``Variable(...)``, which may raise if it is not a tensor.
    """
    if isinstance(obj, Variable):
        return obj
    if _is_sequence(obj):
        return [as_variable(v) for v in obj]
    elif isinstance(obj, collections.abc.Mapping):
        return {k: as_variable(v) for k, v in obj.items()}
    else:
        return Variable(obj)


def as_numpy(obj):
    """Recursively convert the leaves of *obj* to NumPy arrays.

    Sequences (other than str/bytes) become lists and mappings become dicts,
    with every element converted recursively. ``Variable`` leaves are
    converted through ``.data`` (dropping autograd history) and moved to the
    CPU. Plain tensors are moved to the CPU. Any other leaf is passed to
    ``np.array``.
    """
    if _is_sequence(obj):
        return [as_numpy(v) for v in obj]
    elif isinstance(obj, collections.abc.Mapping):
        return {k: as_numpy(v) for k, v in obj.items()}
    elif isinstance(obj, Variable):
        return obj.data.cpu().numpy()
    elif torch.is_tensor(obj):
        return obj.cpu().numpy()
    else:
        return np.array(obj)


def mark_volatile(obj):
    """Recursively tag tensors/Variables in *obj* with ``no_grad = True``.

    Tensors are first wrapped in ``Variable``. Mappings become dicts and
    sequences (other than str/bytes) become lists, converted recursively. Any
    other value is returned unchanged.
    """
    if torch.is_tensor(obj):
        obj = Variable(obj)
    if isinstance(obj, Variable):
        # NOTE(review): ``no_grad`` is not an attribute autograd reads; this
        # only tags the object and does not stop graph construction. The
        # legacy API was ``volatile = True``; the modern equivalent is the
        # ``torch.no_grad()`` context or ``.detach()``. Kept as-is because
        # callers may read the tag.
        obj.no_grad = True
        return obj
    elif isinstance(obj, collections.abc.Mapping):
        return {k: mark_volatile(o) for k, o in obj.items()}
    elif _is_sequence(obj):
        return [mark_volatile(o) for o in obj]
    else:
        return obj