# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------
import logging

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)

class SpeechEncoderPostnet(nn.Module):
    """Pretraining head applied on top of the speech encoder output.

    Computes HuBERT-style masked/unmasked prediction logits against a
    learned embedding table per label set.

    Args:
        dictionaries (List): one label dictionary per target set; entries
            may be None during fine-tuning, when this head is unused
        args: model config providing target_glu, skip_masked, skip_nomask,
            logit_temp, final_dim, encoder_embed_dim and untie_final_proj
    """

    def __init__(self, dictionaries, args):
        super(SpeechEncoderPostnet, self).__init__()
        # modules below are not needed during fine-tuning
        self.skip_masked = args.skip_masked
        self.skip_nomask = args.skip_nomask
        self.logit_temp = args.logit_temp
        final_dim = (
            args.final_dim if args.final_dim > 0 else args.encoder_embed_dim
        )

        # optional GLU head applied to the label embeddings; built as a
        # module (as in fairseq's HuBERT) so that self.target_glu(y) in
        # forward() is callable when the flag is set
        self.target_glu = None
        if args.target_glu:
            self.target_glu = nn.Sequential(
                nn.Linear(final_dim, final_dim * 2), nn.GLU()
            )

        if any([d is None for d in dictionaries]):
            logger.info(
                "cannot find dictionary. assume will be used for fine-tuning"
            )
        else:
            # one learnable embedding per class, concatenated over label sets
            self.num_classes = [len(d) for d in dictionaries]
            self.label_embs_concat = nn.Parameter(
                torch.FloatTensor(sum(self.num_classes), final_dim)
            )
            nn.init.uniform_(self.label_embs_concat)

        self.untie_final_proj = args.untie_final_proj
        if self.untie_final_proj:
            # separate projection per label set, stored as one wide linear
            self.final_proj = nn.Linear(
                args.encoder_embed_dim, final_dim * len(dictionaries)
            )
        else:
            self.final_proj = nn.Linear(args.encoder_embed_dim, final_dim)

    def compute_nce(self, x, pos, negs):
        # mark negatives that coincide with the positive so they can be
        # excluded from the contrastive comparison
        neg_is_pos = (pos == negs).all(-1)
        pos = pos.unsqueeze(0)
        targets = torch.cat([pos, negs], dim=0)  # (1 + Neg, S, D)

        # temperature-scaled cosine similarity; row 0 is the positive
        logits = torch.cosine_similarity(
            x.float(), targets.float(), dim=-1
        ).type_as(x)
        logits /= self.logit_temp
        if neg_is_pos.any():
            logits[1:][neg_is_pos] = float("-inf")
        logits = logits.transpose(0, 1)  # (num_x, num_cls + 1)
        return logits
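
    # Added note (not in the original source): for each selected frame s,
    # compute_nce returns logits[s, k] = cos(proj_x[s], targets[k, s]) / logit_temp,
    # where targets[0] is the embedding of the true label and targets[1:]
    # are all class embeddings used as negatives. Cross-entropy against
    # index 0 then gives the HuBERT-style InfoNCE masked-prediction loss.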

    def forward(self, x, padding_mask, mask_indices, target_list):
        def compute_pred(proj_x, target, label_embs):
            # compute logits for the i-th label set
            y = torch.index_select(label_embs, 0, target.long())
            negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1)
            if self.target_glu:
                y = self.target_glu(y)
                negs = self.target_glu(negs)
            # proj_x: (S, D)
            # y: (S, D)
            # negs: (Neg, S, D)
            return self.compute_nce(proj_x, y, negs)

        label_embs_list = self.label_embs_concat.split(self.num_classes, 0)

        if not self.skip_masked:
            # logits over frames that are masked and not padding
            masked_indices = torch.logical_and(~padding_mask, mask_indices)
            proj_x_m = self.final_proj(x[masked_indices])
            if self.untie_final_proj:
                proj_x_m_list = proj_x_m.chunk(len(target_list), dim=-1)
            else:
                proj_x_m_list = [proj_x_m for _ in range(len(target_list))]
            logit_m_list = [
                compute_pred(proj_x_m, t[masked_indices], label_embs_list[i])
                for i, (proj_x_m, t) in enumerate(
                    zip(proj_x_m_list, target_list)
                )
            ]
        else:
            logit_m_list = [None for _ in target_list]

        if not self.skip_nomask:
            # logits over frames that are neither masked nor padding
            nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
            proj_x_u = self.final_proj(x[nomask_indices])
            if self.untie_final_proj:
                proj_x_u_list = proj_x_u.chunk(len(target_list), dim=-1)
            else:
                proj_x_u_list = [proj_x_u for _ in range(len(target_list))]
            logit_u_list = [
                compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i])
                for i, (proj_x_u, t) in enumerate(
                    zip(proj_x_u_list, target_list)
                )
            ]
        else:
            logit_u_list = [None for _ in target_list]

        result = {
            "logit_m_list": logit_m_list,
            "logit_u_list": logit_u_list,
            "padding_mask": padding_mask,
        }
        return result
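

if __name__ == "__main__":
    # Minimal smoke-test sketch added for illustration; it is not part of
    # the original file. The Namespace fields and the dummy dictionary
    # below are assumptions standing in for the real fairseq config and
    # label dictionaries used during ArTST/SpeechT5 pretraining.
    from argparse import Namespace

    args = Namespace(
        target_glu=False,
        skip_masked=False,
        skip_nomask=False,
        logit_temp=0.1,
        final_dim=256,
        encoder_embed_dim=768,
        untie_final_proj=False,
    )
    dictionaries = [list(range(500))]  # one label set with 500 dummy classes

    postnet = SpeechEncoderPostnet(dictionaries, args)

    B, T = 2, 50
    x = torch.randn(B, T, args.encoder_embed_dim)  # fake encoder output
    padding_mask = torch.zeros(B, T, dtype=torch.bool)
    mask_indices = torch.rand(B, T) < 0.5          # which frames were masked
    target_list = [torch.randint(0, 500, (B, T))]  # frame-level labels

    out = postnet(x, padding_mask, mask_indices, target_list)
    # each entry: (num_selected_frames, num_classes + 1); column 0 is the positive
    print(out["logit_m_list"][0].shape, out["logit_u_list"][0].shape)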