# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class AngularMargin(nn.Module):
    """
    An implementation of Angular Margin (AM) proposed in the following
    paper: '''Margin Matters: Towards More Discriminative Deep Neural
    Network Embeddings for Speaker Recognition'''
    (https://arxiv.org/abs/1906.07317)

    Arguments
    ---------
    margin : float
        The margin for cosine similiarity
    scale : float
        The scale for cosine similiarity

    Return
    ---------
    predictions : torch.Tensor

    Example
    -------
    >>> pred = AngularMargin()
    >>> outputs = torch.tensor([ [1., -1.], [-1., 1.], [0.9, 0.1], [0.1, 0.9] ])
    >>> targets = torch.tensor([ [1., 0.], [0., 1.], [ 1., 0.], [0., 1.] ])
    >>> predictions = pred(outputs, targets)
    >>> predictions[:,0] > predictions[:,1]
    tensor([ True, False,  True, False])
    """

    def __init__(self, margin=0.0, scale=1.0):
        super(AngularMargin, self).__init__()
        self.margin = margin
        self.scale = scale

    def forward(self, outputs, targets):
        """Compute AM between two tensors

        Arguments
        ---------
        outputs : torch.Tensor
            The outputs of shape [N, C], cosine similarity is required.
        targets : torch.Tensor
            The targets of shape [N, C], where the margin is applied for.

        Return
        ---------
        predictions : torch.Tensor
        """
        # Subtract the margin only at the target-class positions, then
        # rescale: scale * (cos_theta - m * one_hot_target).
        outputs = outputs - self.margin * targets
        return self.scale * outputs


class AdditiveAngularMargin(AngularMargin):
    """
    An implementation of Additive Angular Margin (AAM) proposed in the
    following paper: '''Margin Matters: Towards More Discriminative Deep
    Neural Network Embeddings for Speaker Recognition'''
    (https://arxiv.org/abs/1906.07317)

    Arguments
    ---------
    margin : float
        The margin for cosine similiarity, usually 0.2.
    scale: float
        The scale for cosine similiarity, usually 30.
    easy_margin : bool
        If True, fall back to plain cosine wherever cos(theta) <= 0
        instead of using the monotonicity correction.

    Returns
    -------
    predictions : torch.Tensor
        Tensor.

    Example
    -------
    >>> outputs = torch.tensor([ [1., -1.], [-1., 1.], [0.9, 0.1], [0.1, 0.9] ])
    >>> targets = torch.tensor([ [1., 0.], [0., 1.], [ 1., 0.], [0., 1.] ])
    >>> pred = AdditiveAngularMargin()
    >>> predictions = pred(outputs, targets)
    >>> predictions[:,0] > predictions[:,1]
    tensor([ True, False,  True, False])
    """

    def __init__(self, margin=0.0, scale=1.0, easy_margin=False):
        super(AdditiveAngularMargin, self).__init__(margin, scale)
        self.easy_margin = easy_margin

        # Precompute constants used by the angle-addition formula
        # cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m).
        self.cos_m = math.cos(self.margin)
        self.sin_m = math.sin(self.margin)
        # Threshold / correction terms keeping cos(theta + m) monotonic
        # when theta + m would exceed pi.
        self.th = math.cos(math.pi - self.margin)
        self.mm = math.sin(math.pi - self.margin) * self.margin

    def forward(self, outputs, targets):
        """
        Compute AAM between two tensors

        Arguments
        ---------
        outputs : torch.Tensor
            The outputs of shape [N, C], cosine similarity is required.
        targets : torch.Tensor
            The targets of shape [N, C], where the margin is applied for.

        Return
        ---------
        predictions : torch.Tensor
        """
        cosine = outputs.float()
        # sin(theta) from cos(theta); clamp guards against tiny negative
        # values caused by floating-point error before the sqrt.
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # Apply the shifted logit only to the target class.
        outputs = (targets * phi) + ((1.0 - targets) * cosine)
        return self.scale * outputs


class SpeakerDecoderPostnet(nn.Module):
    """Speaker Identification Postnet.

    Maps decoder outputs to a speaker embedding and class logits,
    optionally applying batch-norm, an embedding projection, and an
    AM/AAM-softmax margin head.

    Arguments
    ---------
    embed_dim : int
        The size of embedding.
    class_num: int
        The number of classes.
    args : Namespace
        Must provide ``decoder_output_dim``; optional attributes
        ``sid_no_pooling_bn``, ``sid_no_embed_postnet``,
        ``sid_normalize_postnet``, ``sid_softmax_type``,
        ``softmax_margin``, ``softmax_scale``, ``softmax_easy_margin``.

    Return
    ---------
    embed : torch.Tensor
    output : torch.Tensor
    """

    def __init__(self, embed_dim, class_num, args):
        super(SpeakerDecoderPostnet, self).__init__()
        self.embed_dim = embed_dim
        self.class_num = class_num

        # Optional components are toggled by args; all default to the
        # full pipeline (pooling BN + embedding postnet + plain softmax).
        self.no_pooling_bn = getattr(args, "sid_no_pooling_bn", False)
        self.no_embed_postnet = getattr(args, "sid_no_embed_postnet", False)
        self.normalize_postnet = getattr(args, "sid_normalize_postnet", False)
        self.softmax_head = getattr(args, "sid_softmax_type", "softmax")

        if not self.no_pooling_bn:
            self.bn_pooling = nn.BatchNorm1d(args.decoder_output_dim)
        else:
            self.bn_pooling = None

        if not self.no_embed_postnet:
            self.output_embedding = nn.Linear(args.decoder_output_dim, embed_dim, bias=False)
            self.bn_embedding = nn.BatchNorm1d(embed_dim)
        else:
            self.output_embedding = None
            self.bn_embedding = None
            # Without the postnet the raw decoder output is the embedding.
            self.embed_dim = args.decoder_output_dim

        self.output_projection = nn.Linear(self.embed_dim, class_num, bias=False)

        if self.softmax_head == "amsoftmax":
            self.output_layer = AngularMargin(args.softmax_margin, args.softmax_scale)
        elif self.softmax_head == "aamsoftmax":
            self.output_layer = AdditiveAngularMargin(args.softmax_margin, args.softmax_scale, args.softmax_easy_margin)
        else:
            self.output_layer = None

        if self.output_embedding is not None:
            nn.init.normal_(self.output_embedding.weight, mean=0, std=embed_dim ** -0.5)
        nn.init.normal_(self.output_projection.weight, mean=0, std=class_num ** -0.5)

    def forward(self, x, target=None):
        """
        Parameters
        ----------
        x : torch.Tensor of shape [batch, channel] or [batch, time, channel]
        target : torch.Tensor of shape [batch, channel]
        """
        # NOTE(review): nn.BatchNorm1d expects [batch, channel] or
        # [batch, channel, time]; for the documented [batch, time, channel]
        # case it would normalize over the time axis — confirm callers
        # only pass 2-D input here.
        if self.bn_pooling is not None:
            x = self.bn_pooling(x)
        if self.output_embedding is not None and self.bn_embedding is not None:
            embed = self.bn_embedding(self.output_embedding(x))
        else:
            embed = x
        if self.output_layer is not None or self.normalize_postnet:
            # Cosine-similarity logits: normalize both the embedding and
            # the projection rows, then take their inner product.
            x_norm = F.normalize(embed, p=2, dim=1)
            w_norm = F.normalize(self.output_projection.weight, p=2, dim=1)  # [out_dim, in_dim]
            output = F.linear(x_norm, w_norm)
            # The margin head is only applied during training when labels
            # are available; at eval time plain cosine logits are returned.
            if self.training and target is not None and self.output_layer is not None:
                output = self.output_layer(output, target)
        else:
            output = self.output_projection(embed)
        return output, embed