File size: 711 Bytes
2f56479
 
 
 
 
 
 
 
 
 
 
 
87b7a45
2f56479
 
 
87b7a45
2f56479
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from typing import List, Set

import torch


def sorted_list(s: Set[str]) -> List[str]:
    """Return the unique elements of ``s`` as an ascending sorted list.

    Although annotated as a set, ``s`` may be any iterable of strings;
    duplicates are removed before sorting.
    """
    # sorted() already returns a fresh list, so no intermediate list() is needed.
    return sorted(set(s))


def device():
    """Return the preferred torch device: CUDA if available, else CPU."""
    if torch.cuda.is_available():
        backend = "cuda"
    else:
        backend = "cpu"
    return torch.device(backend)

def nested_to_device(s):
    """Move a tensor, or a (possibly nested) dict of tensors, to ``device()``.

    Args:
        s: Either a ``torch.Tensor`` or a mapping-like object exposing
            ``.items()``. Its values must either be plain ``dict``s (handled
            recursively) or objects with a ``.to(device)`` method, such as
            tensors.

    Returns:
        The moved tensor when ``s`` is a tensor. Otherwise, a new plain
        ``dict`` with the same keys, where every non-dict leaf value has been
        passed through ``.to(device())``.
    """
    # Resolve the target device once, not once per value.
    target = device()

    def _move(value):
        # Recurse into nested plain dicts; every other value is moved directly.
        if isinstance(value, dict):
            return {k: _move(v) for k, v in value.items()}
        return value.to(target)

    if isinstance(s, torch.Tensor):
        return s.to(target)
    # Any non-tensor input is treated as a mapping (anything with ``.items()``).
    return {k: _move(v) for k, v in s.items()}

def nested_apply(h, s):
    """Apply a unary function ``h`` to every leaf of a nested structure.

    Leaves are strings (never iterated character by character) and any
    non-iterable value (e.g. ``int``, ``float``, ``None``). Every other
    iterable is traversed element by element and rebuilt:

    - tuples stay tuples;
    - sets stay sets, so ``h`` must then return hashable values;
    - any other iterable (list, frozenset, dict keys, generator, ...) becomes
      a list.

    Args:
        h: Unary function applied to each leaf.
        s: A leaf, or a (possibly nested) tuple/list/set of leaves.

    Returns:
        ``h(s)`` for a leaf, otherwise a container of the shape described
        above holding the transformed leaves.
    """
    if isinstance(s, str):
        return h(s)
    try:
        iter(s)
    except TypeError:
        # Non-iterable scalar: treat it as a leaf instead of failing.
        return h(s)
    mapped = [nested_apply(h, item) for item in s]
    if isinstance(s, tuple):
        return tuple(mapped)
    if isinstance(s, set):
        return set(mapped)
    return mapped