# Spaces:
# Runtime error
# Runtime error
# --------------------------------------------------------
# The YiTrans End-to-End Speech Translation System for IWSLT 2022 Offline Shared Task (https://arxiv.org/abs/2206.05777)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/YiTrans
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------
import math

import numpy as np
import torch
from fairseq.data import FairseqDataset, data_utils, DenoisingDataset
class DenoisingDatasetLang(DenoisingDataset):
    """
    DenoisingDataset variant that can tag every sample with a language token.

    Items are produced the same way as in the parent class (noise applied to
    the source, clean copy kept as the target). In addition, when
    ``tgt_lang_idx`` is set, the leading BOS token is dropped from both the
    source and the target and the language index is appended to their end.
    """

    def __init__(
        self,
        dataset,
        sizes,
        vocab,
        mask_idx,
        mask_whole_words,
        shuffle,
        seed,
        args,
        eos=None,
        item_transform_func=None,
        tgt_lang_idx=None,
    ):
        """
        Build the dataset.

        All arguments except ``tgt_lang_idx`` are forwarded positionally and
        unchanged to ``DenoisingDataset.__init__``; see the fairseq class for
        their meaning.

        Args:
            tgt_lang_idx: vocabulary index appended to every source/target
                pair in ``__getitem__``. ``None`` (the default) disables the
                language tagging.
        """
        super().__init__(
            dataset,
            sizes,
            vocab,
            mask_idx,
            mask_whole_words,
            shuffle,
            seed,
            args,
            eos,
            item_transform_func,
        )
        self.tgt_lang_idx = tgt_lang_idx

    def __getitem__(self, index):
        """
        Return the noised sample stored at ``index``.

        Returns:
            dict with keys ``"id"`` (the index), ``"source"`` (noised tokens)
            and ``"target"`` (clean tokens). When ``self.tgt_lang_idx`` is not
            ``None``, both tensors have their first token removed and the
            language index appended.

        Raises:
            AssertionError: if the raw item does not end with ``self.eos`` or
                if the processed source/target fail the sanity checks below.
        """
        # All random noise is drawn under a seed derived from
        # (seed, epoch, index) -- presumably so that a given item is noised
        # reproducibly within an epoch; see fairseq's data_utils.numpy_seed.
        with data_utils.numpy_seed(self.seed, self.epoch, index):
            tokens = self.dataset[index]
            assert tokens[-1] == self.eos
            # The target is a clean copy; only the source is corrupted.
            source, target = tokens, tokens.clone()

            if self.permute_sentence_ratio > 0.0:
                source = self.permute_sentences(source, self.permute_sentence_ratio)

            if self.mask_ratio > 0:
                source = self.add_whole_word_mask(source, self.mask_ratio)

            if self.insert_ratio > 0:
                source = self.add_insertion_noise(source, self.insert_ratio)

            if self.rotate_ratio > 0.0 and np.random.random() < self.rotate_ratio:
                source = self.add_rolling_noise(source)

        # Optional hook for additional per-item changes.
        if self.item_transform_func is not None:
            source, target = self.item_transform_func(source, target)

        # Sanity checks on the token ids and on the BOS/EOS framing.
        assert (source >= 0).all()
        assert (source[1:-1] >= 1).all()
        assert (source <= len(self.vocab)).all()
        assert source[0] == self.vocab.bos()
        assert target[0] == self.vocab.bos()
        assert source[-1] == self.eos

        if self.tgt_lang_idx is not None:
            # Drop the leading BOS and append the language index after EOS,
            # identically on both sides.
            tgt_lang_idx = torch.LongTensor([self.tgt_lang_idx])
            source = torch.cat([source[1:], tgt_lang_idx])
            target = torch.cat([target[1:], tgt_lang_idx])

        sample = {
            "id": index,
            "source": source,
            "target": target,
        }
        return sample