# NOTE: scraped file-viewer header, not part of the module (Spaces status: "Runtime error"; file size: 7,049 bytes; revision 1547a56).
# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------
import torch.nn as nn
import math
import torch
import torch.nn.functional as F
class AngularMargin(nn.Module):
    """
    An implementation of Angular Margin (AM) proposed in the following
    paper: '''Margin Matters: Towards More Discriminative Deep Neural Network
    Embeddings for Speaker Recognition''' (https://arxiv.org/abs/1906.07317)

    The margin is subtracted from the entries selected by ``targets`` and the
    result is multiplied by ``scale``.

    Arguments
    ---------
    margin : float
        The margin for cosine similarity.
    scale : float
        The scale for cosine similarity.

    Return
    ---------
    predictions : torch.Tensor

    Example
    -------
    >>> pred = AngularMargin()
    >>> outputs = torch.tensor([ [1., -1.], [-1., 1.], [0.9, 0.1], [0.1, 0.9] ])
    >>> targets = torch.tensor([ [1., 0.], [0., 1.], [ 1., 0.], [0., 1.] ])
    >>> predictions = pred(outputs, targets)
    >>> predictions[:,0] > predictions[:,1]
    tensor([ True, False,  True, False])
    """

    def __init__(self, margin=0.0, scale=1.0):
        super().__init__()
        self.margin = margin
        self.scale = scale

    def forward(self, outputs, targets):
        """Compute AM between two tensors

        Arguments
        ---------
        outputs : torch.Tensor
            The outputs of shape [N, C], cosine similarity is required.
        targets : torch.Tensor
            The targets of shape [N, C], where the margin is applied for.
            Entries equal to 1 receive the full margin, entries equal to 0
            are left untouched.

        Return
        ---------
        predictions : torch.Tensor
            ``scale * (outputs - margin * targets)``.
        """
        # Penalise only the positions flagged by ``targets``.
        outputs = outputs - self.margin * targets
        return self.scale * outputs
class AdditiveAngularMargin(AngularMargin):
    """
    An implementation of Additive Angular Margin (AAM) proposed
    in the following paper: '''Margin Matters: Towards More Discriminative Deep
    Neural Network Embeddings for Speaker Recognition'''
    (https://arxiv.org/abs/1906.07317)

    For the entries selected by ``targets`` the cosine ``cos(theta)`` is
    replaced by ``cos(theta + margin)``; all other entries keep their cosine.

    Arguments
    ---------
    margin : float
        The margin for cosine similarity, usually 0.2.
    scale: float
        The scale for cosine similarity, usually 30.
    easy_margin : bool
        If True, the margin is applied only where the cosine is positive and
        the raw cosine is used elsewhere. If False, the margin is applied
        where ``cosine > cos(pi - margin)`` and ``cosine - mm`` is used
        elsewhere.

    Returns
    -------
    predictions : torch.Tensor
        Tensor.

    Example
    -------
    >>> outputs = torch.tensor([ [1., -1.], [-1., 1.], [0.9, 0.1], [0.1, 0.9] ])
    >>> targets = torch.tensor([ [1., 0.], [0., 1.], [ 1., 0.], [0., 1.] ])
    >>> pred = AdditiveAngularMargin()
    >>> predictions = pred(outputs, targets)
    >>> predictions[:,0] > predictions[:,1]
    tensor([ True, False,  True, False])
    """

    def __init__(self, margin=0.0, scale=1.0, easy_margin=False):
        super().__init__(margin, scale)
        self.easy_margin = easy_margin

        # Constants for the identity cos(t + m) = cos(t)cos(m) - sin(t)sin(m).
        self.cos_m = math.cos(self.margin)
        self.sin_m = math.sin(self.margin)
        # cosine > th  <=>  theta + margin < pi for theta in [0, pi].
        self.th = math.cos(math.pi - self.margin)
        # Offset subtracted from the cosine when the threshold is not met.
        self.mm = math.sin(math.pi - self.margin) * self.margin

    def forward(self, outputs, targets):
        """
        Compute AAM between two tensors

        Arguments
        ---------
        outputs : torch.Tensor
            The outputs of shape [N, C], cosine similarity is required.
        targets : torch.Tensor
            The targets of shape [N, C], where the margin is applied for.

        Return
        ---------
        predictions : torch.Tensor
            ``scale * (targets * phi + (1 - targets) * cosine)``.
        """
        # The trigonometry is carried out on a float32 copy of the input.
        cosine = outputs.float()
        # Clamp guards against tiny negative values when |cosine| rounds above 1.
        # NOTE(review): sqrt has an unbounded derivative at 0, so inputs with
        # |cosine| == 1 may produce non-finite gradients - confirm upstream.
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # Margin-adjusted value on target entries, plain cosine elsewhere.
        outputs = (targets * phi) + ((1.0 - targets) * cosine)
        return self.scale * outputs
class SpeakerDecoderPostnet(nn.Module):
    """Speaker Identification Postnet.

    Optionally batch-normalises the incoming features, optionally projects
    them to an embedding (linear layer + batch norm), and maps the embedding
    to class scores. The scores are either plain linear logits or cosine
    similarities between the L2-normalised embedding and the L2-normalised
    class weights, optionally passed through a margin head during training.

    Arguments
    ---------
    embed_dim : int
        The size of embedding.
    class_num: int
        The number of classes.
    args : Namespace
        Must provide ``decoder_output_dim``. Optional flags read with
        defaults: ``sid_no_pooling_bn`` (False), ``sid_no_embed_postnet``
        (False), ``sid_normalize_postnet`` (False) and ``sid_softmax_type``
        ("softmax"). When ``sid_softmax_type`` is "amsoftmax" or
        "aamsoftmax", ``softmax_margin`` and ``softmax_scale`` are required;
        "aamsoftmax" additionally requires ``softmax_easy_margin``.

    Return
    ---------
    embed : torch.Tensor
    output : torch.Tensor
    """

    def __init__(self, embed_dim, class_num, args):
        super().__init__()
        self.embed_dim = embed_dim
        self.class_num = class_num
        self.no_pooling_bn = getattr(args, "sid_no_pooling_bn", False)
        self.no_embed_postnet = getattr(args, "sid_no_embed_postnet", False)
        self.normalize_postnet = getattr(args, "sid_normalize_postnet", False)
        self.softmax_head = getattr(args, "sid_softmax_type", "softmax")

        if not self.no_pooling_bn:
            self.bn_pooling = nn.BatchNorm1d(args.decoder_output_dim)
        else:
            self.bn_pooling = None

        if not self.no_embed_postnet:
            self.output_embedding = nn.Linear(args.decoder_output_dim, embed_dim, bias=False)
            self.bn_embedding = nn.BatchNorm1d(embed_dim)
        else:
            self.output_embedding = None
            self.bn_embedding = None
            # Without the embedding projection the decoder output itself is
            # the embedding, so the classifier must consume that width.
            self.embed_dim = args.decoder_output_dim

        self.output_projection = nn.Linear(self.embed_dim, class_num, bias=False)

        if self.softmax_head == "amsoftmax":
            self.output_layer = AngularMargin(args.softmax_margin, args.softmax_scale)
        elif self.softmax_head == "aamsoftmax":
            self.output_layer = AdditiveAngularMargin(args.softmax_margin, args.softmax_scale, args.softmax_easy_margin)
        else:
            self.output_layer = None

        if self.output_embedding is not None:
            nn.init.normal_(self.output_embedding.weight, mean=0, std=embed_dim ** -0.5)
        nn.init.normal_(self.output_projection.weight, mean=0, std=class_num ** -0.5)

    @staticmethod
    def _batch_norm(bn, x):
        """Apply a BatchNorm1d layer over the last (channel) axis of ``x``."""
        if x.dim() == 3:
            # BatchNorm1d expects [batch, channel, length]; the input here is
            # [batch, time, channel], so swap the axes around the call.
            return bn(x.transpose(1, 2)).transpose(1, 2)
        return bn(x)

    def forward(self, x, target=None):
        """
        Parameters
        ----------
        x : torch.Tensor of shape [batch, channel] or [batch, time, channel]
        target : torch.Tensor of shape [batch, channel]
            Only used in training mode when a margin head is configured; it
            is passed to the margin head together with the cosine scores and
            must broadcast against them.

        Returns
        -------
        output : torch.Tensor
            Class scores; the last axis has size ``class_num``.
        embed : torch.Tensor
            The embedding the scores were computed from.
        """
        if self.bn_pooling is not None:
            x = self._batch_norm(self.bn_pooling, x)

        if self.output_embedding is not None and self.bn_embedding is not None:
            embed = self._batch_norm(self.bn_embedding, self.output_embedding(x))
        else:
            embed = x

        if self.output_layer is not None or self.normalize_postnet:
            # Cosine similarity: normalise features over the channel axis
            # (the last one, for both 2-D and 3-D input) and each class
            # weight row over its input dimension.
            x_norm = F.normalize(embed, p=2, dim=-1)
            w_norm = F.normalize(self.output_projection.weight, p=2, dim=1)  # [out_dim, in_dim]
            output = F.linear(x_norm, w_norm)
            # The margin is a training-time penalty; evaluation uses the
            # plain cosine scores.
            if self.training and target is not None and self.output_layer is not None:
                output = self.output_layer(output, target)
        else:
            output = self.output_projection(embed)

        return output, embed
|