import torch | |
def lengths_to_mask(lengths: list[int],
                    device: torch.device,
                    max_len: int = None) -> torch.Tensor:
    """Build a boolean mask marking the valid (non-padded) positions.

    Args:
        lengths: Number of valid positions for each sequence in the batch.
        device: Device on which the mask is created.
        max_len: Width of the mask. When ``None`` (the default) the largest
            value in ``lengths`` is used, or 0 if ``lengths`` is empty. An
            explicit ``0`` is honored and yields a mask with no columns.

    Returns:
        A bool tensor of shape ``(len(lengths), max_len)`` in which entry
        ``[i, j]`` is ``True`` exactly when ``j < lengths[i]``.
    """
    # as_tensor avoids the copy (and the warning) torch.tensor() produces
    # when the caller already hands us a tensor.
    lengths_t = torch.as_tensor(lengths, device=device)

    # Compare against None rather than relying on truthiness, so that a
    # caller-supplied max_len of 0 is not silently replaced.
    if max_len is None:
        # Tensor.max() raises on an empty tensor, so guard the empty batch.
        max_len = int(lengths_t.max()) if lengths_t.numel() > 0 else 0

    # Broadcasting (1, max_len) against (batch, 1) gives (batch, max_len).
    positions = torch.arange(max_len, device=device).unsqueeze(0)
    return positions < lengths_t.unsqueeze(1)
def remove_padding(tensors: torch.Tensor, lengths: list[int]) -> list:
    """Trim each item of ``tensors`` to its corresponding length.

    Items obtained by iterating ``tensors`` are paired positionally with the
    entries of ``lengths``; each item is sliced to keep only its first
    ``length`` elements along its leading axis. Pairing stops at the shorter
    of the two inputs.

    Args:
        tensors: Iterable of padded items (e.g. a batched tensor iterated
            along its first dimension).
        lengths: Number of leading elements to keep for each item.

    Returns:
        A list with one trimmed item per (item, length) pair.
    """
    trimmed = []
    for padded, keep in zip(tensors, lengths):
        trimmed.append(padded[:keep])
    return trimmed