# Source: artst-tts-demo/artst/models/modules/speaker_decoder_postnet.py
# Uploaded by herwoww ("first upload", commit 1547a56)
# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------
import torch.nn as nn
import math
import torch
import torch.nn.functional as F
class AngularMargin(nn.Module):
    """Cosine-margin (AM-softmax style) adjustment of similarity scores.

    Implements the Angular Margin (AM) described in '''Margin Matters: Towards
    More Discriminative Deep Neural Network Embeddings for Speaker
    Recognition''' (https://arxiv.org/abs/1906.07317).

    Arguments
    ---------
    margin : float
        Amount subtracted from the scores at positions selected by ``targets``.
    scale : float
        Multiplier applied to the margin-adjusted scores.

    Example
    -------
    >>> pred = AngularMargin()
    >>> outputs = torch.tensor([ [1., -1.], [-1., 1.], [0.9, 0.1], [0.1, 0.9] ])
    >>> targets = torch.tensor([ [1., 0.], [0., 1.], [ 1., 0.], [0., 1.] ])
    >>> predictions = pred(outputs, targets)
    >>> predictions[:,0] > predictions[:,1]
    tensor([ True, False,  True, False])
    """

    def __init__(self, margin=0.0, scale=1.0):
        super().__init__()
        self.margin = margin
        self.scale = scale

    def forward(self, outputs, targets):
        """Subtract the margin at the target positions, then rescale.

        Arguments
        ---------
        outputs : torch.Tensor
            Scores of shape [N, C]; expected to be cosine similarities.
        targets : torch.Tensor
            Tensor of shape [N, C] weighting where the margin is applied
            (the margin is multiplied by it element-wise).

        Returns
        -------
        predictions : torch.Tensor
            ``scale * (outputs - margin * targets)``.
        """
        shifted = outputs - self.margin * targets
        return self.scale * shifted
class AdditiveAngularMargin(AngularMargin):
    """Additive Angular Margin (AAM) adjustment of cosine scores.

    Implements the AAM described in '''Margin Matters: Towards More
    Discriminative Deep Neural Network Embeddings for Speaker Recognition'''
    (https://arxiv.org/abs/1906.07317).

    Arguments
    ---------
    margin : float
        The angular margin (in radians) added to the target angle,
        usually 0.2.
    scale : float
        The scale applied to the adjusted cosine similarity, usually 30.
    easy_margin : bool
        If True, the margin is only applied where the cosine is positive and
        the plain cosine is used elsewhere; otherwise a linear fallback
        ``cosine - mm`` is used where ``cosine <= th``.

    Example
    -------
    >>> outputs = torch.tensor([ [1., -1.], [-1., 1.], [0.9, 0.1], [0.1, 0.9] ])
    >>> targets = torch.tensor([ [1., 0.], [0., 1.], [ 1., 0.], [0., 1.] ])
    >>> pred = AdditiveAngularMargin()
    >>> predictions = pred(outputs, targets)
    >>> predictions[:,0] > predictions[:,1]
    tensor([ True, False,  True, False])
    """

    def __init__(self, margin=0.0, scale=1.0, easy_margin=False):
        super(AdditiveAngularMargin, self).__init__(margin, scale)
        self.easy_margin = easy_margin
        # Constants for cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m).
        self.cos_m = math.cos(self.margin)
        self.sin_m = math.sin(self.margin)
        # th = cos(pi - m): cosine threshold used to switch to the fallback.
        self.th = math.cos(math.pi - self.margin)
        # mm: constant penalty subtracted from the cosine in the fallback.
        self.mm = math.sin(math.pi - self.margin) * self.margin

    def forward(self, outputs, targets):
        """Apply the additive angular margin at the target positions.

        Arguments
        ---------
        outputs : torch.Tensor
            Scores of shape [N, C]; expected to be cosine similarities.
        targets : torch.Tensor
            Tensor of shape [N, C] selecting where the margin is applied
            (used as an element-wise blend weight between the
            margin-adjusted and the plain cosine).

        Returns
        -------
        predictions : torch.Tensor
            Scaled scores, computed in float32.
        """
        cosine = outputs.float()
        sin_sq = (1.0 - torch.pow(cosine, 2)).clamp(0, 1)
        # The derivative of sqrt is unbounded at 0, so entries with
        # |cosine| >= 1 would inject inf/NaN into the backward pass (NaN even
        # for non-target entries, via 0/0). Take the sqrt only where the
        # argument is strictly positive; elsewhere the sine is exactly 0 with
        # a zero gradient. Forward values are unchanged.
        positive = sin_sq > 0
        safe_sq = torch.where(positive, sin_sq, torch.ones_like(sin_sq))
        sine = torch.where(
            positive, torch.sqrt(safe_sq), torch.zeros_like(sin_sq)
        )
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # Margin-adjusted score at target positions, plain cosine elsewhere.
        outputs = (targets * phi) + ((1.0 - targets) * cosine)
        return self.scale * outputs
class SpeakerDecoderPostnet(nn.Module):
    """Speaker-identification head: optional BN, embedding and classifier.

    Arguments
    ---------
    embed_dim : int
        Size of the speaker embedding produced by the embedding layer.
        Replaced by ``args.decoder_output_dim`` when the embedding layer is
        disabled.
    class_num : int
        Number of output classes.
    args : Namespace
        Must provide ``decoder_output_dim``. Optional flags (read with
        ``getattr``): ``sid_no_pooling_bn``, ``sid_no_embed_postnet``,
        ``sid_normalize_postnet`` (all default False) and
        ``sid_softmax_type`` (default "softmax"). For "amsoftmax" /
        "aamsoftmax" the attributes ``softmax_margin`` and ``softmax_scale``
        (plus ``softmax_easy_margin`` for "aamsoftmax") are also required.

    Returns
    -------
    output : torch.Tensor
        Class scores.
    embed : torch.Tensor
        The embedding the scores were computed from.
    """

    def __init__(self, embed_dim, class_num, args):
        super().__init__()
        self.embed_dim = embed_dim
        self.class_num = class_num

        # Behaviour switches, all optional on ``args``.
        self.no_pooling_bn = getattr(args, "sid_no_pooling_bn", False)
        self.no_embed_postnet = getattr(args, "sid_no_embed_postnet", False)
        self.normalize_postnet = getattr(args, "sid_normalize_postnet", False)
        self.softmax_head = getattr(args, "sid_softmax_type", "softmax")

        # Batch norm applied directly to the incoming features.
        self.bn_pooling = (
            None if self.no_pooling_bn else nn.BatchNorm1d(args.decoder_output_dim)
        )

        if self.no_embed_postnet:
            # No projection: the (normalised) input itself is the embedding.
            self.output_embedding = None
            self.bn_embedding = None
            self.embed_dim = args.decoder_output_dim
        else:
            self.output_embedding = nn.Linear(
                args.decoder_output_dim, embed_dim, bias=False
            )
            self.bn_embedding = nn.BatchNorm1d(embed_dim)

        # Classifier over the embedding.
        self.output_projection = nn.Linear(self.embed_dim, class_num, bias=False)

        # Optional margin head applied to cosine scores.
        head_type = self.softmax_head
        if head_type == "amsoftmax":
            self.output_layer = AngularMargin(
                args.softmax_margin, args.softmax_scale
            )
        elif head_type == "aamsoftmax":
            self.output_layer = AdditiveAngularMargin(
                args.softmax_margin, args.softmax_scale, args.softmax_easy_margin
            )
        else:
            self.output_layer = None

        # Normal init with std = width ** -0.5; embedding first, then classifier.
        for layer, width in (
            (self.output_embedding, embed_dim),
            (self.output_projection, class_num),
        ):
            if layer is not None:
                nn.init.normal_(layer.weight, mean=0, std=width ** -0.5)

    def forward(self, x, target=None):
        """Compute class scores and the embedding for ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Input features. The batch norm layers and the L2 normalisation
            both act on dimension 1.
            NOTE(review): the original documentation also lists
            [batch, time, channel] input - confirm that 3-D use is intended.
        target : torch.Tensor, optional
            Passed to the margin head together with the cosine scores; only
            used in training mode when a margin head is configured.

        Returns
        -------
        (output, embed) : tuple of torch.Tensor
        """
        feats = x if self.bn_pooling is None else self.bn_pooling(x)

        if self.output_embedding is None or self.bn_embedding is None:
            embed = feats
        else:
            embed = self.bn_embedding(self.output_embedding(feats))

        cosine_scoring = self.output_layer is not None or self.normalize_postnet
        if not cosine_scoring:
            # Plain linear classifier.
            return self.output_projection(embed), embed

        # Cosine similarity between unit-norm embeddings and class weights.
        unit_embed = F.normalize(embed, p=2, dim=1)
        unit_weight = F.normalize(
            self.output_projection.weight, p=2, dim=1
        )  # [out_dim, in_dim]
        output = F.linear(unit_embed, unit_weight)

        # The margin is only applied while training and when labels are given.
        if self.training and target is not None and self.output_layer is not None:
            output = self.output_layer(output, target)
        return output, embed