# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------

import logging

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)


class SpeechEncoderPostnet(nn.Module):
    """HuBERT-style masked-prediction head for speech pre-training.

    Projects encoder outputs into the label-embedding space and computes
    contrastive (NCE) logits against each label dictionary.

    Args:
        dictionaries (List): one label dictionary per target set; entries may
            be None when the module is only instantiated for fine-tuning
        args: config namespace providing final_dim, encoder_embed_dim,
            logit_temp, target_glu, skip_masked, skip_nomask and
            untie_final_proj
    """

    def __init__(self, dictionaries, args):
        super(SpeechEncoderPostnet, self).__init__()
        # modules below are not needed during fine-tuning
        self.skip_masked = args.skip_masked
        self.skip_nomask = args.skip_nomask
        self.logit_temp = args.logit_temp
        final_dim = (
            args.final_dim if args.final_dim > 0 else args.encoder_embed_dim
        )

        # optional GLU applied to targets/negatives before the NCE loss;
        # built as a module (as in fairseq's HuBERT) rather than storing the
        # raw boolean flag, which would crash when called in forward()
        self.target_glu = None
        if args.target_glu:
            self.target_glu = nn.Sequential(
                nn.Linear(final_dim, final_dim * 2), nn.GLU()
            )

        if any([d is None for d in dictionaries]):
            logger.info(
                "cannot find dictionary. assume will be used for fine-tuning"
            )
        else:
            self.num_classes = [len(d) for d in dictionaries]
            # one learned embedding per label class, concatenated over all
            # target sets and split again in forward()
            self.label_embs_concat = nn.Parameter(
                torch.FloatTensor(sum(self.num_classes), final_dim)
            )
            nn.init.uniform_(self.label_embs_concat)

        self.untie_final_proj = args.untie_final_proj
        if self.untie_final_proj:
            # one projection per target set, packed into a single Linear and
            # chunked along the feature dimension in forward()
            self.final_proj = nn.Linear(
                args.encoder_embed_dim, final_dim * len(dictionaries)
            )
        else:
            self.final_proj = nn.Linear(args.encoder_embed_dim, final_dim)

    def compute_nce(self, x, pos, negs):
        neg_is_pos = (pos == negs).all(-1)
        pos = pos.unsqueeze(0)
        targets = torch.cat([pos, negs], dim=0)

        logits = torch.cosine_similarity(
            x.float(), targets.float(), dim=-1
        ).type_as(x)
        logits /= self.logit_temp
        if neg_is_pos.any():
            # exclude negatives that happen to equal the positive
            logits[1:][neg_is_pos] = float("-inf")
        logits = logits.transpose(0, 1)  # (num_x, num_cls+1)
        return logits

    def forward(self, x, padding_mask, mask_indices, target_list):
        def compute_pred(proj_x, target, label_embs):
            # compute logits for the i-th label set
            y = torch.index_select(label_embs, 0, target.long())
            negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1)
            if self.target_glu:
                y = self.target_glu(y)
                negs = self.target_glu(negs)
            # proj_x: (S, D)
            # y: (S, D)
            # negs: (Neg, S, D)
            return self.compute_nce(proj_x, y, negs)

        label_embs_list = self.label_embs_concat.split(self.num_classes, 0)

        if not self.skip_masked:
            masked_indices = torch.logical_and(~padding_mask, mask_indices)
            proj_x_m = self.final_proj(x[masked_indices])
            if self.untie_final_proj:
                proj_x_m_list = proj_x_m.chunk(len(target_list), dim=-1)
            else:
                proj_x_m_list = [proj_x_m for _ in range(len(target_list))]
            logit_m_list = [
                compute_pred(proj_x_m, t[masked_indices], label_embs_list[i])
                for i, (proj_x_m, t) in enumerate(
                    zip(proj_x_m_list, target_list)
                )
            ]
        else:
            logit_m_list = [None for _ in target_list]

        if not self.skip_nomask:
            nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
            proj_x_u = self.final_proj(x[nomask_indices])
            if self.untie_final_proj:
                proj_x_u_list = proj_x_u.chunk(len(target_list), dim=-1)
            else:
                proj_x_u_list = [proj_x_u for _ in range(len(target_list))]
            logit_u_list = [
                compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i])
                for i, (proj_x_u, t) in enumerate(
                    zip(proj_x_u_list, target_list)
                )
            ]
        else:
            logit_u_list = [None for _ in target_list]

        result = {
            "logit_m_list": logit_m_list,
            "logit_u_list": logit_u_list,
            "padding_mask": padding_mask,
        }
        return result
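

# A minimal smoke test (an illustrative sketch, not part of the original
# module). The config fields and shapes below mirror what __init__ and
# forward() read; the dictionary is a plain list standing in for a fairseq
# Dictionary, since only len() is used here.
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        target_glu=False,
        skip_masked=False,
        skip_nomask=False,
        logit_temp=0.1,
        final_dim=256,
        encoder_embed_dim=768,
        untie_final_proj=False,
    )
    dictionaries = [list(range(500))]  # one target set with 500 classes
    postnet = SpeechEncoderPostnet(dictionaries, args)

    B, T = 2, 10
    x = torch.randn(B, T, args.encoder_embed_dim)  # encoder output
    padding_mask = torch.zeros(B, T, dtype=torch.bool)  # no padded frames
    mask_indices = torch.rand(B, T) < 0.5  # frames masked for prediction
    target_list = [torch.randint(0, 500, (B, T))]  # frame-level labels

    out = postnet(x, padding_mask, mask_indices, target_list)
    # each entry is (num_masked_frames, num_classes + 1) NCE logits
    print(out["logit_m_list"][0].shape)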