from collections import defaultdict

import torch
import torch.nn.functional as F


def make_positions(tensor, padding_idx):
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx+1. Padding symbols are ignored
    (they keep the value ``padding_idx``).

    Args:
        tensor: Tensor of symbol ids; positions are counted along dim 1.
        padding_idx: Id that marks padding entries.

    Returns:
        LongTensor with the same shape as ``tensor``.
    """
    # The series of casts and type-conversions here are carefully
    # balanced to both work with ONNX export and XLA. In particular XLA
    # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
    # how to handle the dtype kwarg in cumsum.
    mask = tensor.ne(padding_idx).int()
    return (
        torch.cumsum(mask, dim=1).type_as(mask) * mask
    ).long() + padding_idx


def softmax(x, dim):
    """Softmax over ``dim`` that always computes and returns float32."""
    return F.softmax(x, dim=dim, dtype=torch.float32)


def sequence_mask(lengths, maxlen=None, dtype=torch.bool):
    """Build a ``(len(lengths), maxlen)`` mask marking valid positions.

    Entry ``[b, j]`` is set when ``j + 1 <= lengths[b]``.

    Args:
        lengths: 1-D tensor of sequence lengths.
        maxlen: Number of mask columns; defaults to ``lengths.max()``.
        dtype: dtype of the returned mask.

    Returns:
        Mask tensor of the requested ``dtype`` on ``lengths.device``.
    """
    if maxlen is None:
        maxlen = lengths.max()
    # 1-based positions, created directly on the target device.
    positions = torch.arange(1, int(maxlen) + 1, device=lengths.device)
    mask = positions[None, :] <= lengths[:, None]
    # Tensor.type() is NOT in-place: the converted tensor must be returned,
    # otherwise the ``dtype`` argument is silently ignored.
    return mask.type(dtype)


# Per-class counter used to hand out unique instance ids to modules.
INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)


def _get_full_incremental_state_key(module_instance, key):
    """Return the ``'<ClassName>.<instance_id>.<key>'`` state key.

    Side effect: the first time a module instance is seen, it is tagged with
    a fresh ``_instance_id`` attribute.
    """
    module_name = module_instance.__class__.__name__

    # assign a unique ID to each module instance, so that incremental state is
    # not shared across module instances
    if not hasattr(module_instance, '_instance_id'):
        INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1
        module_instance._instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name]

    return '{}.{}.{}'.format(module_name, module_instance._instance_id, key)


def get_incremental_state(module, incremental_state, key):
    """Helper for getting incremental state for an nn.Module.

    Returns ``None`` when ``incremental_state`` is ``None`` or holds no entry
    for this module/key pair.
    """
    full_key = _get_full_incremental_state_key(module, key)
    if incremental_state is None or full_key not in incremental_state:
        return None
    return incremental_state[full_key]


def set_incremental_state(module, incremental_state, key, value):
    """Helper for setting incremental state for an nn.Module.

    Does nothing when ``incremental_state`` is ``None``.
    """
    if incremental_state is not None:
        full_key = _get_full_incremental_state_key(module, key)
        incremental_state[full_key] = value


def fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf.

    The fill is done in float32 and the result is cast back to ``t``'s type.
    NOTE: when ``t`` is already float32, ``t.float()`` is ``t`` itself, so
    ``t`` is filled in place.
    """
    return t.float().fill_(float('-inf')).type_as(t)


def fill_with_neg_inf2(t):
    """Fill a tensor with the large negative value -1e8 (not a true -inf).

    The fill is done in float32 and the result is cast back to ``t``'s type.
    NOTE(review): -1e8 is outside the finite fp16 range, so for half tensors
    the cast back presumably yields -inf -- confirm if that matters.
    """
    return t.float().fill_(-1e8).type_as(t)


def get_focus_rate(attn, src_padding_mask=None, tgt_padding_mask=None):
    """Fraction of attention mass carried by each target step's peak.

    Args:
        attn: bs x L_t x L_s attention weights.
        src_padding_mask: optional bs x L_s mask, non-zero at padded sources.
        tgt_padding_mask: optional bs x L_t mask, non-zero at padded targets.

    Returns:
        Tensor of shape [bs]: sum over target steps of the max source weight,
        divided by the total (masked) attention mass.
    """
    if src_padding_mask is not None:
        attn = attn * (1 - src_padding_mask.float())[:, None, :]

    if tgt_padding_mask is not None:
        attn = attn * (1 - tgt_padding_mask.float())[:, :, None]

    focus_rate = attn.max(-1).values.sum(-1)
    focus_rate = focus_rate / attn.sum(-1).sum(-1)
    return focus_rate


def get_phone_coverage_rate(attn, src_padding_mask=None, src_seg_mask=None,
                            tgt_padding_mask=None):
    """Average peak attention received by each non-masked source position.

    Args:
        attn: bs x L_t x L_s attention weights.
        src_padding_mask: optional bs x L_s mask, non-zero at padded sources.
        src_seg_mask: optional bs x L_s mask of further source positions to
            exclude.
        tgt_padding_mask: optional bs x L_t mask, non-zero at padded targets.

    Returns:
        Tensor of shape [bs]: sum over source positions of the max weight over
        target steps, divided by the number of non-masked source positions.
    """
    src_mask = torch.zeros(
        attn.size(0), attn.size(-1), dtype=torch.bool, device=attn.device)
    # Cast to bool so byte/float masks are accepted here as well, matching
    # the ``.float()``-based handling in the sibling helpers.
    if src_padding_mask is not None:
        src_mask |= src_padding_mask.bool()
    if src_seg_mask is not None:
        src_mask |= src_seg_mask.bool()

    attn = attn * (1 - src_mask.float())[:, None, :]
    if tgt_padding_mask is not None:
        attn = attn * (1 - tgt_padding_mask.float())[:, :, None]

    phone_coverage_rate = attn.max(1).values.sum(-1)
    # Normalise by the count of valid source positions rather than by the
    # total attention mass.
    phone_coverage_rate = phone_coverage_rate / (1 - src_mask.float()).sum(-1)
    return phone_coverage_rate


def get_diagonal_focus_rate(attn, attn_ks, target_len, src_padding_mask=None,
                            tgt_padding_mask=None, band_mask_factor=5,
                            band_width=50):
    """Fraction of attention mass that lies inside a diagonal band.

    attn: bs x L_t x L_s
    attn_ks: shape: tensor with shape [batch_size], input_lens/output_lens
    target_len: tensor with shape [batch_size]

    diagonal: y=k*x (k=attn_ks, x:output, y:input)
    1 0 0
    0 1 0
    0 0 1
    y>=k*(x-width) and y<=k*(x+width):1
    else:0

    Returns:
        (diagonal_focus_rate, mask): a [bs] tensor with the in-band share of
        the (masked) attention mass, and the float band mask shaped like
        ``attn``.
    """
    # width = min(target_len / band_mask_factor, band_width)
    width = torch.clamp(target_len / band_mask_factor, max=band_width).float()

    device = attn.device
    base = torch.ones(attn.size(), device=device)
    zero = torch.zeros(attn.size(), device=device)
    # x indexes target (output) steps, y indexes source (input) positions.
    x = torch.arange(0, attn.size(1), device=device)[None, :, None].float() * base
    y = torch.arange(0, attn.size(2), device=device)[None, None, :].float() * base

    slope = attn_ks[:, None, None]
    band = slope * width[:, None, None]
    cond = y - slope * x
    cond1 = cond + band
    cond2 = cond - band
    mask1 = torch.where(cond1 < 0, zero, base)  # drop y < k * (x - width)
    mask2 = torch.where(cond2 > 0, zero, base)  # drop y > k * (x + width)
    mask = mask1 * mask2

    if src_padding_mask is not None:
        attn = attn * (1 - src_padding_mask.float())[:, None, :]
    if tgt_padding_mask is not None:
        attn = attn * (1 - tgt_padding_mask.float())[:, :, None]

    diagonal_attn = attn * mask
    diagonal_focus_rate = diagonal_attn.sum(-1).sum(-1) / attn.sum(-1).sum(-1)
    return diagonal_focus_rate, mask


def select_attn(attn_logits, type='best'):
    """Reduce per-layer/per-head attention logits to one map per batch item.

    :param attn_logits: sequence of n_layers tensors, stacked to
        [n_layers, B, n_head, T_sp, T_txt]
    :param type: 'best' picks, per batch item, the layer/head whose summed
        per-step maximum attention is largest; 'mean' averages over all
        layers and heads.
    :return: [B, T_sp, T_txt] attention (softmax over T_txt), or None when
        ``type`` is neither 'best' nor 'mean'.
    """
    encdec_attn = torch.stack(attn_logits, 0).transpose(1, 2)
    # [n_layers * n_head, B, T_sp, T_txt]
    encdec_attn = (encdec_attn.reshape([-1, *encdec_attn.shape[2:]])).softmax(-1)
    if type == 'best':
        indices = encdec_attn.max(-1).values.sum(-1).argmax(0)  # [B]
        gather_index = indices[None, :, None, None].repeat(
            1, 1, encdec_attn.size(-2), encdec_attn.size(-1))
        return encdec_attn.gather(0, gather_index)[0]
    if type == 'mean':
        return encdec_attn.mean(0)
    # Unknown selection type: keep the historical behaviour of returning None.
    return None


def make_pad_mask(lengths, xs=None, length_dim=-1):
    """Make mask tensor containing indices of padded part.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor.
            See the example.

    Returns:
        Tensor: Mask tensor containing indices of padded part.
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)

    Raises:
        ValueError: If ``length_dim`` is 0 (the batch dimension).

    Examples:
        With only lengths.

        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0 ,0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]

        With the reference tensor.

        >>> xs = torch.zeros((3, 2, 4))
        >>> make_pad_mask(lengths, xs)
        tensor([[[0, 0, 0, 0],
                 [0, 0, 0, 0]],
                [[0, 0, 0, 1],
                 [0, 0, 0, 1]],
                [[0, 0, 1, 1],
                 [0, 0, 1, 1]]], dtype=torch.uint8)
        >>> xs = torch.zeros((3, 2, 6))
        >>> make_pad_mask(lengths, xs)
        tensor([[[0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1]],
                [[0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1]],
                [[0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)

        With the reference tensor and dimension indicator.

        >>> xs = torch.zeros((3, 6, 6))
        >>> make_pad_mask(lengths, xs, 1)
        tensor([[[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1]],
                [[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1]],
                [[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
        >>> make_pad_mask(lengths, xs, 2)
        tensor([[[0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1]],
                [[0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1]],
                [[0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)

    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))

    if not isinstance(lengths, list):
        lengths = lengths.tolist()
    bs = len(lengths)
    if xs is None:
        maxlen = int(max(lengths))
    else:
        maxlen = xs.size(length_dim)

    seq_range = torch.arange(0, maxlen, dtype=torch.int64)
    seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
    seq_length_expand = torch.tensor(lengths, dtype=torch.int64).unsqueeze(-1)
    # A position is padding when its index is >= the sequence length.
    mask = seq_range_expand >= seq_length_expand

    if xs is not None:
        assert xs.size(0) == bs, (xs.size(0), bs)

        if length_dim < 0:
            length_dim = xs.dim() + length_dim
        # ind = (:, None, ..., None, :, , None, ..., None)
        ind = tuple(
            slice(None) if i in (0, length_dim) else None
            for i in range(xs.dim())
        )
        mask = mask[ind].expand_as(xs).to(xs.device)
    return mask


def make_non_pad_mask(lengths, xs=None, length_dim=-1):
    """Make mask tensor containing indices of non-padded part.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor.
            See the example.

    Returns:
        ByteTensor: mask tensor containing indices of non-padded part.
                    dtype=torch.uint8 in PyTorch 1.2-
                    dtype=torch.bool in PyTorch 1.2+ (including 1.2)

    Examples:
        With only lengths.

        >>> lengths = [5, 3, 2]
        >>> make_non_pad_mask(lengths)
        masks = [[1, 1, 1, 1 ,1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]

        With the reference tensor.

        >>> xs = torch.zeros((3, 2, 4))
        >>> make_non_pad_mask(lengths, xs)
        tensor([[[1, 1, 1, 1],
                 [1, 1, 1, 1]],
                [[1, 1, 1, 0],
                 [1, 1, 1, 0]],
                [[1, 1, 0, 0],
                 [1, 1, 0, 0]]], dtype=torch.uint8)
        >>> xs = torch.zeros((3, 2, 6))
        >>> make_non_pad_mask(lengths, xs)
        tensor([[[1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0]],
                [[1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0]],
                [[1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)

        With the reference tensor and dimension indicator.

        >>> xs = torch.zeros((3, 6, 6))
        >>> make_non_pad_mask(lengths, xs, 1)
        tensor([[[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0]],
                [[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0]],
                [[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
        >>> make_non_pad_mask(lengths, xs, 2)
        tensor([[[1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0]],
                [[1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0]],
                [[1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)

    """
    return ~make_pad_mask(lengths, xs, length_dim)


def get_mask_from_lengths(lengths):
    """Return a bool mask [B, max(lengths)] that is True at valid positions.

    Args:
        lengths: 1-D tensor of sequence lengths.

    Returns:
        Bool tensor on ``lengths.device``; entry ``[b, j]`` is True when
        ``j < lengths[b]``.
    """
    max_len = torch.max(lengths).item()
    ids = torch.arange(0, max_len, device=lengths.device)
    mask = (ids < lengths.unsqueeze(1)).bool()
    return mask


def group_hidden_by_segs(h, seg_ids, max_len):
    """Average hidden states that share a segment id.

    :param h: [B, T, H]
    :param seg_ids: [B, T] integer segment index of each frame; index 0 is
        accumulated but dropped from the outputs
    :param max_len: largest segment index (T_ph)
    :return: h_ph: [B, T_ph, H] mean hidden state of each segment 1..max_len
        (zeros for empty segments), and cnt: [B, T_ph] frames per segment
    """
    B, T, H = h.shape
    h_gby_segs = h.new_zeros([B, max_len + 1, H]).scatter_add_(
        1, seg_ids[:, :, None].repeat([1, 1, H]), h)
    all_ones = h.new_ones(h.shape[:2])
    cnt_gby_segs = h.new_zeros([B, max_len + 1]).scatter_add_(
        1, seg_ids, all_ones).contiguous()
    # Slot 0 is discarded from both the sums and the counts.
    h_gby_segs = h_gby_segs[:, 1:]
    cnt_gby_segs = cnt_gby_segs[:, 1:]
    # Clamp the divisor so empty segments yield zeros instead of NaN.
    h_gby_segs = h_gby_segs / torch.clamp(cnt_gby_segs[:, :, None], min=1)
    return h_gby_segs, cnt_gby_segs