Guess-What-Moves / utils /convert.py
subhc's picture
Code Commit
5e88f62
import torch
import itertools
def cast_like(maybe_tensor, example_tensor):
    """Return ``maybe_tensor`` as a tensor matching ``example_tensor``.

    The value is wrapped with ``torch.tensor`` when it is not already a
    tensor, moved to the example's device, converted to the example's dtype
    and finally viewed with one trailing singleton dimension appended per
    dimension of ``example_tensor``.

    Args:
        maybe_tensor: A tensor, or anything ``torch.tensor`` accepts.
        example_tensor: Tensor supplying the target device, dtype and the
            number of trailing singleton dimensions to append.

    Returns:
        A tensor of shape ``(*value.shape, 1, ..., 1)``; when both the value
        and the example are 0-dimensional the result has shape ``(1,)``.
    """
    value = maybe_tensor if torch.is_tensor(maybe_tensor) else torch.tensor(maybe_tensor)
    value = value.to(example_tensor.device)
    value = value.to(example_tensor.dtype)

    trailing_ones = [1] * len(example_tensor.shape)
    target_shape = list(value.shape) + trailing_ones
    # A 0-dim value paired with a 0-dim example would yield an empty shape;
    # fall back to a single-element 1-D tensor in that case.
    return value.view(*(target_shape or [1]))
def lofd_2_dofl(list_of_dicts, make_tensor=True):
    """Convert a list of dicts into a dict of lists (or tensors).

    Every key that occurs in at least one input dict becomes a key of the
    result. Its value gathers that key's entries, in input order, from the
    dicts that contain it; dicts lacking the key are skipped, so different
    keys may collect a different number of entries.

    Keys are emitted in first-seen order so the result's iteration order is
    reproducible across runs (iterating a ``set`` of keys is not).

    Args:
        list_of_dicts: Iterable of dicts. It is materialised once, so
            one-shot iterables such as generators are accepted.
        make_tensor: When true, each gathered list is turned into a tensor.
            If none of the entries is a tensor the list is passed to
            ``torch.tensor``; otherwise every entry is passed through
            ``cast_like`` (using the first tensor entry as the example) and
            the results are concatenated along dimension 0.

    Returns:
        A dict mapping each key to a list of its entries, or to a tensor
        when ``make_tensor`` is true.
    """
    # The input is traversed once for the keys and once more per key, so it
    # must be re-iterable.
    records = list(list_of_dicts)
    # dict.fromkeys de-duplicates while keeping first-seen order.
    keys = dict.fromkeys(itertools.chain.from_iterable(records))

    out_dict = {}
    for k in keys:
        values = [d[k] for d in records if k in d]
        if make_tensor:
            example_tensor = next((v for v in values if torch.is_tensor(v)), None)
            if example_tensor is None:
                values = torch.tensor(values)
            else:
                values = torch.cat([cast_like(v, example_tensor) for v in values], 0)
        out_dict[k] = values
    return out_dict