# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2020 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de

from typing import Dict, List, Optional, Sequence, Union

import torch
from torch import Tensor


def lengths_to_mask(
    lengths: Union[Sequence[int], Tensor],
    device: Optional[torch.device] = None,
) -> Tensor:
    """Build a boolean padding mask from a batch of sequence lengths.

    Row ``i`` of the result is ``True`` in its first ``lengths[i]`` columns
    and ``False`` afterwards; the number of columns is ``max(lengths)``.

    Args:
        lengths: Per-sample lengths, as a Python sequence of ints or a
            1-D integer tensor.
        device: Device to place the mask on. When ``None``, a tensor input
            keeps its own device and a Python sequence goes to the CPU.

    Returns:
        A ``torch.bool`` tensor of shape ``(len(lengths), max(lengths))``.
        An empty ``lengths`` yields a ``(0, 0)`` mask.
    """
    # as_tensor avoids the copy + UserWarning torch.tensor emits for tensor input.
    lengths_t = torch.as_tensor(lengths, device=device)

    # Reduce on-device and convert once. The builtin max() iterates the tensor
    # element by element (a host sync per element on CUDA) and fails when empty.
    max_len = int(lengths_t.max()) if lengths_t.numel() > 0 else 0

    # Broadcasting (max_len,) against (B, 1) gives the (B, max_len) mask
    # directly; no explicit expand is needed.
    positions = torch.arange(max_len, device=lengths_t.device)
    return positions < lengths_t.unsqueeze(1)