# Source path: artst-tts-demo/artst/models/modules/speech_encoder_postnet.py
# Hosting-page residue kept for provenance (not part of the module):
#   herwoww's picture | first upload | 1547a56 | raw | history blame | 4.68 kB
# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------
import logging
import torch.nn as nn
import torch
logger = logging.getLogger(__name__)
class SpeechEncoderPostnet(nn.Module):
    """Prediction head scoring encoder frames against label embeddings.

    Selected frames are passed through ``final_proj`` and compared, via
    temperature-scaled cosine similarity, with the label-embedding rows of
    each target set.

    Args:
        dictionaries: sequence with one entry per target set. Only
            ``len()`` of each entry is used (the number of classes). If any
            entry is ``None`` no prediction parameters are created (the
            module is assumed to be used for fine-tuning) and ``forward``
            cannot be called.
        args: namespace providing ``target_glu``, ``skip_masked``,
            ``skip_nomask``, ``logit_temp``, ``final_dim``,
            ``encoder_embed_dim`` and ``untie_final_proj``.
    """

    def __init__(self, dictionaries, args):
        super(SpeechEncoderPostnet, self).__init__()
        # modules below are not needed during fine-tuning
        self.target_glu = args.target_glu
        self.skip_masked = args.skip_masked
        self.skip_nomask = args.skip_nomask
        self.logit_temp = args.logit_temp
        # Gated projection applied to label embeddings when ``target_glu``
        # is enabled; only built together with the label embeddings below.
        self.target_glu_proj = None

        # A non-positive ``final_dim`` falls back to the encoder width.
        final_dim = (
            args.final_dim if args.final_dim > 0 else args.encoder_embed_dim
        )

        if any(d is None for d in dictionaries):
            logger.info(
                "cannot find dictionary. assume will be used for fine-tuning"
            )
        else:
            self.num_classes = [len(d) for d in dictionaries]
            # One embedding table for all target sets, split per set in
            # ``forward`` according to ``num_classes``.
            self.label_embs_concat = nn.Parameter(
                torch.FloatTensor(sum(self.num_classes), final_dim)
            )
            nn.init.uniform_(self.label_embs_concat)
            self.untie_final_proj = args.untie_final_proj
            if self.untie_final_proj:
                # A separate slice of the projection for every target set.
                self.final_proj = nn.Linear(
                    args.encoder_embed_dim, final_dim * len(dictionaries)
                )
            else:
                self.final_proj = nn.Linear(args.encoder_embed_dim, final_dim)
            if self.target_glu:
                # ``target_glu`` itself is only a flag; the callable lives
                # here. Created last so the initialisation order of the
                # parameters above is unaffected.
                self.target_glu_proj = nn.Sequential(
                    nn.Linear(final_dim, final_dim * 2), nn.GLU()
                )

    def compute_nce(self, x, pos, negs):
        """Return similarity logits of ``x`` against positives and negatives.

        Args:
            x: projected frames, ``(S, D)``.
            pos: positive target embedding for each frame, ``(S, D)``.
            negs: candidate embeddings, ``(Neg, S, D)``.

        Returns:
            Tensor of shape ``(S, Neg + 1)``; column 0 is the positive.
            Candidates identical to the positive are set to ``-inf``.
        """
        neg_is_pos = (pos == negs).all(-1)
        pos = pos.unsqueeze(0)
        targets = torch.cat([pos, negs], dim=0)
        logits = torch.cosine_similarity(
            x.float(), targets.float(), dim=-1
        ).type_as(x)
        logits /= self.logit_temp
        if neg_is_pos.any():
            # Row 0 holds the positive itself, hence the offset of one.
            logits[1:][neg_is_pos] = float("-inf")
        logits = logits.transpose(0, 1)  # (num_x, num_cls+1)
        return logits

    def _compute_pred(self, proj_x, target, label_embs):
        """Compute the logits of one label set for the given frames."""
        y = torch.index_select(label_embs, 0, target.long())
        negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1)
        if self.target_glu:
            y = self.target_glu_proj(y)
            negs = self.target_glu_proj(negs)
        # proj_x: (S, D)
        # y: (S, D)
        # negs: (Neg, S, D)
        return self.compute_nce(proj_x, y, negs)

    def _score_selected(self, x, indices, target_list, label_embs_list):
        """Project the frames picked by ``indices`` and score every label set."""
        proj_x = self.final_proj(x[indices])
        if self.untie_final_proj:
            proj_x_list = proj_x.chunk(len(target_list), dim=-1)
        else:
            proj_x_list = [proj_x for _ in range(len(target_list))]
        return [
            self._compute_pred(proj, target[indices], label_embs_list[i])
            for i, (proj, target) in enumerate(zip(proj_x_list, target_list))
        ]

    def forward(self, x, padding_mask, mask_indices, target_list):
        """Compute logits for masked and unmasked non-padded frames.

        Args:
            x: encoder output, indexed by the boolean masks below.
            padding_mask: boolean tensor, ``True`` at padded positions.
            mask_indices: boolean tensor, ``True`` at masked positions.
            target_list: one label tensor per target set, indexable by the
                same masks.

        Returns:
            Dict with ``"logit_m_list"`` (masked frames), ``"logit_u_list"``
            (unmasked frames) and the unchanged ``"padding_mask"``. A list
            holds ``None`` entries when the corresponding ``skip_*`` flag
            is set.
        """
        label_embs_list = self.label_embs_concat.split(self.num_classes, 0)

        if not self.skip_masked:
            masked_indices = torch.logical_and(~padding_mask, mask_indices)
            logit_m_list = self._score_selected(
                x, masked_indices, target_list, label_embs_list
            )
        else:
            logit_m_list = [None for _ in target_list]

        if not self.skip_nomask:
            nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
            logit_u_list = self._score_selected(
                x, nomask_indices, target_list, label_embs_list
            )
        else:
            logit_u_list = [None for _ in target_list]

        result = {
            "logit_m_list": logit_m_list,
            "logit_u_list": logit_u_list,
            "padding_mask": padding_mask,
        }
        return result