import torch import torch.nn.functional as F def get_mask_from_lengths(lengths, max_len=None): lengths = lengths.to(torch.long) if max_len is None: max_len = torch.max(lengths).item() ids = torch.arange(0, max_len).unsqueeze(0).expand(lengths.shape[0], -1).to(lengths.device) mask = ids < lengths.unsqueeze(1).expand(-1, max_len) return mask def linear_interpolation(features, seq_len): features = features.transpose(1, 2) output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear') return output_features.transpose(1, 2) if __name__ == "__main__": import numpy as np mask = ~get_mask_from_lengths(torch.from_numpy(np.array([4,6]))) import pdb; pdb.set_trace()