File size: 947 Bytes
5e88f62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
import itertools

def cast_like(maybe_tensor, example_tensor):
    """Convert ``maybe_tensor`` into a tensor compatible with ``example_tensor``.

    The result is placed on ``example_tensor``'s device with its dtype.
    Trailing singleton dimensions are then appended until the result has at
    least as many dimensions as ``example_tensor``. For example, a value of
    shape ``(B,)`` cast against a ``(B, C, H, W)`` example becomes
    ``(B, 1, 1, 1)``, and a scalar becomes ``(1, 1, 1, 1)``. A result that
    would be 0-d is returned with shape ``(1,)`` instead.

    Args:
        maybe_tensor: A tensor, or anything ``torch.tensor`` accepts
            (e.g. a Python number or nested list).
        example_tensor: Tensor whose device, dtype and rank are matched.

    Returns:
        A view of the converted tensor with the padded shape.
    """
    if not torch.is_tensor(maybe_tensor):
        maybe_tensor = torch.tensor(maybe_tensor)
    maybe_tensor = maybe_tensor.to(device=example_tensor.device, dtype=example_tensor.dtype)
    # Pad only the dimensions that are actually missing. An input that
    # already has example_tensor's rank (or more) keeps its shape, so it
    # broadcasts against / concatenates with example_tensor correctly.
    missing_dims = max(example_tensor.dim() - maybe_tensor.dim(), 0)
    shape = [*maybe_tensor.shape, *[1] * missing_dims]
    if not shape:
        # Both inputs are 0-d: promote to 1-D so the result can be
        # concatenated along dim 0.
        shape = [1]
    return maybe_tensor.view(*shape)


def lofd_2_dofl(list_of_dicts, make_tensor=True):
    """Convert a list of dicts into a dict of lists (optionally of tensors).

    Every key appearing in any input dict becomes a key of the output. Its
    value collects that key's entries in input order, skipping dicts that
    lack the key, so different keys may end up with different lengths.
    Output keys are ordered by first appearance across the inputs.

    Args:
        list_of_dicts: Iterable-of-mappings to merge. It is iterated more
            than once, so it should be a list (or other re-iterable).
        make_tensor: If True, each collected list is converted to a tensor:
            - if none of its values is a tensor, via ``torch.tensor(values)``
              (values must be something ``torch.tensor`` accepts);
            - otherwise every value is passed through ``cast_like`` using the
              first tensor value as the example, and the results are
              concatenated along dim 0 with ``torch.cat``.

    Returns:
        dict mapping each key to a list, or to a tensor if ``make_tensor``.
    """
    # dict.fromkeys de-duplicates while preserving first-seen order. A set's
    # iteration order for str keys depends on hash randomization and would
    # make the output key order vary between interpreter runs.
    keys = dict.fromkeys(itertools.chain.from_iterable(list_of_dicts))
    out_dict = {}
    for key in keys:
        values = [d[key] for d in list_of_dicts if key in d]
        if make_tensor:
            example_tensor = next((v for v in values if torch.is_tensor(v)), None)
            if example_tensor is None:
                values = torch.tensor(values)
            else:
                values = torch.cat([cast_like(v, example_tensor) for v in values], 0)
        out_dict[key] = values
    return out_dict