"""Helpers for traversing and converting nested batches of tensors.

Author: Paul-Edouard Sarlin (skydes)
"""
import collections.abc as collections | |
import numpy as np | |
import torch | |
string_classes = (str, bytes)


def map_tensor(input_, func):
    """Recursively apply ``func`` to every leaf of a nested container.

    Strings/bytes and ``None`` are returned untouched. Mappings are rebuilt as
    a plain ``dict`` with the same keys. Any other sequence (list, tuple, ...)
    is rebuilt as a ``list``. Every remaining object is treated as a leaf and
    replaced by ``func(leaf)``.
    """
    # Text must be caught before the Sequence check: str/bytes are sequences
    # and would otherwise be exploded character by character.
    if input_ is None or isinstance(input_, string_classes):
        return input_

    if isinstance(input_, collections.Mapping):
        mapped = {}
        for key, value in input_.items():
            mapped[key] = map_tensor(value, func)
        return mapped

    if isinstance(input_, collections.Sequence):
        return list(map(lambda item: map_tensor(item, func), input_))

    return func(input_)
def batch_to_numpy(batch):
    """Convert every leaf of a nested batch to a NumPy array.

    The structure is traversed with ``map_tensor``: strings and ``None`` pass
    through unchanged, mappings become ``dict`` and other sequences become
    ``list``. Each remaining leaf is converted with ``.cpu().numpy()``.

    ``torch.Tensor`` leaves are detached from the autograd graph first, so
    tensors with ``requires_grad=True`` convert instead of raising.
    """

    def _to_numpy(leaf):
        if isinstance(leaf, torch.Tensor):
            # Tensor.numpy() refuses tensors that require grad; detach()
            # returns a view without the graph (no data copy).
            leaf = leaf.detach()
        return leaf.cpu().numpy()

    return map_tensor(batch, _to_numpy)
def batch_to_device(batch, device, non_blocking=True):
    """Move every leaf of a nested batch to ``device``.

    The structure is traversed with ``map_tensor``. Each leaf that is not a
    string or ``None`` has ``.to(device=device, non_blocking=non_blocking)``
    called on it, and the result replaces the leaf.
    """
    return map_tensor(
        batch,
        lambda leaf: leaf.to(device=device, non_blocking=non_blocking),
    )
def rbd(data: dict) -> dict:
    """Remove the leading (batch) dimension from the entries of ``data``.

    Values that are a ``torch.Tensor``, ``np.ndarray`` or ``list`` are
    replaced by their first element (``value[0]``). All other values are kept
    as-is. A new dict with the same key order is returned.
    """
    batched_types = (torch.Tensor, np.ndarray, list)
    unbatched = {}
    for key, value in data.items():
        unbatched[key] = value[0] if isinstance(value, batched_types) else value
    return unbatched
def index_batch(tensor_dict):
    """Yield the batch one sample at a time.

    The batch size is taken as ``len()`` of the first value in
    ``tensor_dict``. For each index ``i`` in that range, ``map_tensor``
    rebuilds the structure with every leaf ``t`` replaced by ``t[i]``.
    An empty ``tensor_dict`` yields nothing.
    """
    if not tensor_dict:
        # Without this guard, next() on the empty values iterator raises
        # StopIteration, which PEP 479 converts into a RuntimeError here.
        return
    batch_size = len(next(iter(tensor_dict.values())))
    for i in range(batch_size):
        # map_tensor runs eagerly, so the lambda sees the current ``i``.
        yield map_tensor(tensor_dict, lambda t: t[i])