File size: 1,947 Bytes
d6d3a5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import numpy as np
import torch

"""
This file stores functions for conversion between numpy and torch, torch, list, etc.
Also deal with general operations such as to(dev), detach, etc.
"""


def thing2list(thing):
    """Convert tensors/arrays inside ``thing`` to plain Python lists.

    - ``torch.Tensor`` and ``np.ndarray`` are converted with ``.tolist()``.
      A 0-d tensor/array yields a Python scalar rather than a list.
    - dicts are converted value-by-value, keeping their keys.
    - lists are converted element-by-element.
    - anything else (including tuples) is returned unchanged.
    """
    if isinstance(thing, (torch.Tensor, np.ndarray)):
        return thing.tolist()
    if isinstance(thing, dict):
        # Iterate the dict itself; the old body referenced an undefined
        # name (``md``) here and raised NameError on every dict.
        return {k: thing2list(v) for k, v in thing.items()}
    if isinstance(thing, list):
        return [thing2list(item) for item in thing]
    return thing


def thing2dev(thing, dev):
    """Move ``thing`` (or the objects nested inside it) to device ``dev``.

    Args:
        thing: any object exposing a ``.to`` method (e.g. a ``torch.Tensor``),
            or a list / tuple / dict whose elements are handled recursively.
            Anything else is returned unchanged.
        dev: the target passed straight to ``.to`` (e.g. ``"cpu"``).

    Returns:
        The moved object. Lists stay lists, tuples come back as plain tuples,
        and dicts keep their keys.
    """
    # torch.Tensor defines ``.to``, so this branch covers tensors as well as
    # any other object with a ``.to`` method; no separate tensor check is needed.
    if hasattr(thing, "to"):
        return thing.to(dev)
    if isinstance(thing, list):
        return [thing2dev(item, dev) for item in thing]
    if isinstance(thing, tuple):
        return tuple(thing2dev(item, dev) for item in thing)
    if isinstance(thing, dict):
        return {k: thing2dev(v, dev) for k, v in thing.items()}
    return thing


def thing2np(thing):
    """Convert ``thing`` to numpy form.

    - a dict is converted value-by-value, keeping its keys;
    - a torch tensor is moved to CPU, detached, and returned as an ndarray;
    - a list is passed directly to ``np.array`` (its elements are not
      converted individually first);
    - anything else (including tuples and ndarrays) is returned unchanged.
    """
    if isinstance(thing, dict):
        return {key: thing2np(value) for key, value in thing.items()}
    if isinstance(thing, torch.Tensor):
        on_cpu = thing.cpu()
        return on_cpu.detach().numpy()
    if isinstance(thing, list):
        return np.array(thing)
    return thing


def thing2torch(thing):
    """Convert ``thing`` to torch form.

    - an ``np.ndarray`` is wrapped with ``torch.from_numpy``, so the result
      shares memory with the array;
    - a list is first turned into a numpy array, then into a tensor with
      ``torch.tensor``;
    - a dict is converted value-by-value, keeping its keys;
    - anything else (including tuples and tensors) is returned unchanged.
    """
    if isinstance(thing, np.ndarray):
        return torch.from_numpy(thing)
    if isinstance(thing, list):
        as_array = np.array(thing)
        return torch.tensor(as_array)
    if isinstance(thing, dict):
        return {key: thing2torch(value) for key, value in thing.items()}
    return thing


def detach_thing(thing):
    """Move tensors in ``thing`` to CPU and detach them from autograd.

    Lists, tuples and dicts are handled recursively. Lists come back as
    lists, tuples as plain tuples, and dicts keep their keys. Any other
    value is returned unchanged.
    """
    if isinstance(thing, torch.Tensor):
        on_cpu = thing.cpu()
        return on_cpu.detach()
    if isinstance(thing, (list, tuple)):
        converted = [detach_thing(item) for item in thing]
        if isinstance(thing, tuple):
            return tuple(converted)
        return converted
    if isinstance(thing, dict):
        return {key: detach_thing(value) for key, value in thing.items()}
    return thing