Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
def pad_x_to_y(x, y, axis: int = -1):
    """Right-pad (or right-trim) ``x`` so it matches ``y`` along ``axis``.

    Args:
        x (:class:`torch.Tensor`): Tensor to pad or trim.
        y (:class:`torch.Tensor`): Tensor whose size along ``axis`` is the
            target size.
        axis (int): Dimension to adjust. May be negative (counted from the
            end) or non-negative. Defaults to the last dimension.

    Returns:
        :class:`torch.Tensor`: ``x`` with ``x.shape[axis] == y.shape[axis]``.
        Missing samples are appended at the end of ``axis`` using
        ``torch.nn.functional.pad``'s default constant (zero) mode; when
        ``x`` is longer than ``y`` the pad amount is negative, which trims
        the end of ``axis`` instead.

    Raises:
        IndexError: If ``axis`` is out of range for ``x`` or ``y``.
    """
    target_len = y.shape[axis]
    current_len = x.shape[axis]  # also rejects an out-of-range ``axis``
    # ``F.pad`` expects (left, right) pairs ordered from the LAST dimension
    # backwards, so every dimension located after ``axis`` needs an explicit
    # (0, 0) placeholder before the pair that applies to ``axis`` itself.
    trailing_dims = x.dim() - 1 - (axis % x.dim())
    pad_spec = [0, 0] * trailing_dims + [0, target_len - current_len]
    return nn.functional.pad(x, pad_spec)
def shape_reconstructed(reconstructed, size):
    """Squeeze dimension 0 of ``reconstructed`` when ``size`` has length 1.

    Args:
        reconstructed (:class:`torch.Tensor`): Tensor to reshape.
        size: Any sized sequence; presumably the shape of the original
            input signal (TODO confirm against callers).

    Returns:
        :class:`torch.Tensor`: ``reconstructed.squeeze(0)`` if
        ``len(size) == 1``, otherwise ``reconstructed`` itself, unchanged.
    """
    was_one_dimensional = len(size) == 1
    return reconstructed.squeeze(0) if was_one_dimensional else reconstructed
def tensors_to_device(tensors, device):
    """Recursively move every tensor found in ``tensors`` to ``device``.

    Args:
        tensors: A :class:`torch.Tensor`, a list, tuple or dict (possibly
            nested) containing tensors, or any other object.
        device (:class:`torch.device`): Destination passed to
            :meth:`torch.Tensor.to`.

    Returns:
        The transferred structure:

        * a tensor is returned as ``tensors.to(device)``;
        * a list **or tuple** is returned as a new ``list`` (tuples are not
          preserved as tuples);
        * a dict is updated **in place** and the same dict object is
          returned;
        * anything else is returned untouched.
    """
    if isinstance(tensors, torch.Tensor):
        return tensors.to(device)

    if isinstance(tensors, (list, tuple)):
        # Note: the result is always a list, even for tuple input.
        return [tensors_to_device(item, device) for item in tensors]

    if isinstance(tensors, dict):
        # Only existing keys are reassigned, so the dict's size never changes
        # while it is being iterated.
        for name, value in tensors.items():
            tensors[name] = tensors_to_device(value, device)
        return tensors

    # Non-tensor leaves (ints, strings, None, ...) pass through unchanged.
    return tensors