artst-demo-asr / SpeechT5 /SpeechLM /speechlm /data /language_trible_dataset.py
amupd's picture
SpeechT5 upload
62e9ca6
raw
history blame
25.9 kB
# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
import logging
import numpy as np
import torch
import os
import itertools
from fairseq.data import FairseqDataset, data_utils
from fairseq.data import (
AppendTokenDataset,
ConcatDataset,
PrependTokenDataset,
data_utils,
indexed_dataset,
)
logger = logging.getLogger(__name__)
def load_langtriple_dataset(
    data_path,
    split,
    src,
    src_dict,
    ref,
    ref_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    num_buckets=0,
    shuffle=True,
    pad_to_multiple=1,
    prepend_bos_src=None,
    lang_format="[{}]",
):
    """Build a :class:`LanguageTripleDataset` from indexed shards on disk.

    Shards are looked up under ``data_path`` as ``{split}``, ``{split}1``,
    ``{split}2``, ...  For each shard the file stem
    ``{shard}.{src}-{ref}-{tgt}.`` is tried first and
    ``{shard}.{tgt}-{ref}-{src}.`` second; existence is probed on the
    ``src`` language file of that stem.  Loading stops at the first missing
    shard, or right after the first shard when ``combine`` is false.

    Args:
        data_path (str): directory containing the indexed dataset files.
        split (str): base name of the split; also used for the alignment file.
        src, ref, tgt (str): language codes used to build file names and,
            with ``append_source_id``, the language-token symbols.
        src_dict, ref_dict, tgt_dict: dictionaries handed to
            ``data_utils.load_indexed_dataset`` and to the returned dataset.
        combine (bool): if true, keep loading numbered shards and concatenate.
        dataset_impl: implementation identifier forwarded to fairseq loaders.
        upsample_primary: sample ratio given to the first shard when several
            shards are concatenated (all other shards get ratio 1).
        left_pad_source, left_pad_target, num_buckets, shuffle,
        pad_to_multiple: forwarded unchanged to ``LanguageTripleDataset``.
        max_source_positions, max_target_positions: accepted but not used
            by this function.
        prepend_bos (bool): prepend each dictionary's ``bos()`` to the
            corresponding dataset.
        load_alignments (bool): also load ``{split}.align.{src}-{tgt}`` when
            that dataset exists.
        truncate_source (bool): must be false (asserted).
        append_source_id (bool): append the ``lang_format`` token of each
            language to its dataset and use the target one as ``eos``.
        prepend_bos_src (optional): only consulted when ``prepend_bos`` is
            false; token prepended to both the source and reference datasets.
        lang_format (str): format string producing the language-token symbol.

    Returns:
        LanguageTripleDataset: dataset wrapping the loaded triple.

    Raises:
        FileNotFoundError: if the first (un-numbered) shard cannot be found.
    """
    assert not truncate_source

    def locate_prefix(shard_name):
        """Return the file prefix of ``shard_name`` or None if it is absent."""
        # Same probing order as before: src-ref-tgt first, then tgt-ref-src.
        for first, last in ((src, tgt), (tgt, src)):
            probe = os.path.join(
                data_path,
                "{}.{}-{}-{}.{}".format(shard_name, first, ref, last, src),
            )
            if indexed_dataset.dataset_exists(probe, impl=dataset_impl):
                return os.path.join(
                    data_path, "{}.{}-{}-{}.".format(shard_name, first, ref, last)
                )
        return None

    src_datasets, ref_datasets, tgt_datasets = [], [], []

    for shard_idx in itertools.count():
        # The first shard carries no numeric suffix.
        shard_name = split if shard_idx == 0 else split + str(shard_idx)

        prefix = locate_prefix(shard_name)
        if prefix is None:
            if shard_idx == 0:
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(split, data_path)
                )
            break

        src_datasets.append(
            data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl)
        )
        ref_datasets.append(
            data_utils.load_indexed_dataset(prefix + ref, ref_dict, dataset_impl)
        )
        # The target side is optional: the loader may hand back None.
        loaded_tgt = data_utils.load_indexed_dataset(
            prefix + tgt, tgt_dict, dataset_impl
        )
        if loaded_tgt is not None:
            tgt_datasets.append(loaded_tgt)

        logger.info(
            "{} {} {}-{}-{} {} examples".format(
                data_path, shard_name, src, ref, tgt, len(src_datasets[-1])
            )
        )

        if not combine:
            break

    assert len(src_datasets) == len(ref_datasets)
    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    has_target = len(tgt_datasets) > 0
    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        ref_dataset = ref_datasets[0]
        tgt_dataset = tgt_datasets[0] if has_target else None
    else:
        # Only the primary (first) shard is up-sampled.
        sample_ratios = [upsample_primary] + [1] * (len(src_datasets) - 1)
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        ref_dataset = ConcatDataset(ref_datasets, sample_ratios)
        tgt_dataset = (
            ConcatDataset(tgt_datasets, sample_ratios) if has_target else None
        )

    if prepend_bos:
        assert (
            hasattr(src_dict, "bos_index")
            and hasattr(ref_dict, "bos_index")
            and hasattr(tgt_dict, "bos_index")
        )
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        ref_dataset = PrependTokenDataset(ref_dataset, ref_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
    elif prepend_bos_src is not None:
        logger.info(f"prepending src bos: {prepend_bos_src}")
        src_dataset = PrependTokenDataset(src_dataset, prepend_bos_src)
        ref_dataset = PrependTokenDataset(ref_dataset, prepend_bos_src)

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(
            src_dataset, src_dict.index(lang_format.format(src))
        )
        ref_dataset = AppendTokenDataset(
            ref_dataset, ref_dict.index(lang_format.format(ref))
        )
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(
                tgt_dataset, tgt_dict.index(lang_format.format(tgt))
            )
        # The target language token doubles as the end-of-sentence index.
        eos = tgt_dict.index(lang_format.format(tgt))

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(
            data_path, "{}.align.{}-{}".format(split, src, tgt)
        )
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(
                align_path, None, dataset_impl
            )

    tgt_dataset_sizes = None if tgt_dataset is None else tgt_dataset.sizes

    return LanguageTripleDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        ref_dataset,
        ref_dataset.sizes,
        ref_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        align_dataset=align_dataset,
        eos=eos,
        num_buckets=num_buckets,
        shuffle=shuffle,
        pad_to_multiple=pad_to_multiple,
    )
def collate(
    samples,
    pad_idx,
    eos_idx,
    left_pad_source=True,
    left_pad_target=False,
    input_feeding=True,
    pad_to_length=None,
    pad_to_multiple=1,
):
    """Merge a list of samples into a mini-batch sorted by source length.

    Args:
        samples (List[dict]): samples with keys ``id``, ``source``,
            ``reference`` and optionally ``target``, ``prev_output_tokens``,
            ``alignment`` and ``constraints``.
        pad_idx (int): padding index used for every padded tensor.
        eos_idx: accepted for interface compatibility; this function does
            not use it (``None`` is forwarded to ``collate_tokens``).
        left_pad_source (bool): pad source and reference on the left.
        left_pad_target (bool): pad target tensors on the left.
        input_feeding (bool): when the samples carry no
            ``prev_output_tokens``, build them from ``target`` with
            ``move_eos_to_beginning=True``.
        pad_to_length (dict, optional): ``{"source": int, "target": int}``;
            the ``"source"`` entry is applied to both source and reference.
        pad_to_multiple (int): forwarded to ``data_utils.collate_tokens``.

    Returns:
        dict: ``{}`` for an empty ``samples``; otherwise a batch with keys
        ``id``, ``nsentences``, ``ntokens``, ``net_input`` (``src_tokens``,
        ``src_lengths`` and, when available, ``prev_output_tokens``),
        ``target``, ``ref_tokens``, ``ref_lengths`` and, when present in the
        samples, ``alignments``/``align_weights`` and ``constraints``.
    """
    if len(samples) == 0:
        return {}

    def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
        """Pad and stack ``sample[key]`` across all samples."""
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            None,  # eos index is deliberately not forwarded (see docstring)
            left_pad,
            move_eos_to_beginning,
            pad_to_length=pad_to_length,
            pad_to_multiple=pad_to_multiple,
        )

    def check_alignment(alignment, src_len, tgt_len):
        """Return True when ``alignment`` is non-empty and within bounds."""
        if alignment is None or len(alignment) == 0:
            return False
        if (
            alignment[:, 0].max().item() >= src_len - 1
            or alignment[:, 1].max().item() >= tgt_len - 1
        ):
            logger.warning("alignment size mismatch found, skipping alignment!")
            return False
        return True

    def compute_alignment_weights(alignments):
        """
        Given a tensor of shape [:, 2] containing the source-target indices
        corresponding to the alignments, a weight vector containing the
        inverse frequency of each target index is computed.
        For e.g. if alignments = [[5, 7], [2, 3], [1, 3], [4, 2]], then
        a tensor containing [1., 0.5, 0.5, 1] should be returned (since target
        index 3 is repeated twice)
        """
        align_tgt = alignments[:, 1]
        _, align_tgt_i, align_tgt_c = torch.unique(
            align_tgt, return_inverse=True, return_counts=True
        )
        # Map every alignment row to the count of its target index.
        align_weights = align_tgt_c[align_tgt_i]
        return 1.0 / align_weights.float()

    # Named ``ids`` locally so the builtin ``id`` is not shadowed.
    ids = torch.LongTensor([s["id"] for s in samples])

    src_pad_to = pad_to_length["source"] if pad_to_length is not None else None
    src_tokens = merge("source", left_pad=left_pad_source, pad_to_length=src_pad_to)
    # The reference shares the source-side padding settings.
    ref_tokens = merge(
        "reference", left_pad=left_pad_source, pad_to_length=src_pad_to
    )

    # sort by descending source length
    src_lengths = torch.LongTensor(
        [s["source"].ne(pad_idx).long().sum() for s in samples]
    )
    ref_lengths = torch.LongTensor(
        [s["reference"].ne(pad_idx).long().sum() for s in samples]
    )
    src_lengths, sort_order = src_lengths.sort(descending=True)
    ids = ids.index_select(0, sort_order)
    src_tokens = src_tokens.index_select(0, sort_order)
    ref_lengths = ref_lengths.index_select(0, sort_order)
    ref_tokens = ref_tokens.index_select(0, sort_order)

    prev_output_tokens = None
    target = None
    if samples[0].get("target", None) is not None:
        tgt_pad_to = pad_to_length["target"] if pad_to_length is not None else None
        target = merge("target", left_pad=left_pad_target, pad_to_length=tgt_pad_to)
        target = target.index_select(0, sort_order)
        tgt_lengths = torch.LongTensor(
            [s["target"].ne(pad_idx).long().sum() for s in samples]
        ).index_select(0, sort_order)
        ntokens = tgt_lengths.sum().item()

        if samples[0].get("prev_output_tokens", None) is not None:
            prev_output_tokens = merge("prev_output_tokens", left_pad=left_pad_target)
        elif input_feeding:
            # we create a shifted version of targets for feeding the
            # previous output token(s) into the next decoder step
            prev_output_tokens = merge(
                "target",
                left_pad=left_pad_target,
                move_eos_to_beginning=True,
                pad_to_length=tgt_pad_to,
            )
    else:
        ntokens = src_lengths.sum().item()

    batch = {
        "id": ids,
        "nsentences": len(samples),
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
        },
        "target": target,
        "ref_tokens": ref_tokens,
        "ref_lengths": ref_lengths,
    }
    if prev_output_tokens is not None:
        # prev_output_tokens was merged in sample order; sort it like the rest.
        batch["net_input"]["prev_output_tokens"] = prev_output_tokens.index_select(
            0, sort_order
        )

    if samples[0].get("alignment", None) is not None:
        tgt_sz = batch["target"].shape[1]
        src_sz = batch["net_input"]["src_tokens"].shape[1]

        # Per-sentence offsets: column 0 shifts source indices, column 1
        # shifts target indices into the flattened (bsz * tgt_sz) layout.
        offsets = torch.zeros((len(sort_order), 2), dtype=torch.long)
        offsets[:, 1] += torch.arange(len(sort_order), dtype=torch.long) * tgt_sz
        if left_pad_source:
            offsets[:, 0] += src_sz - src_lengths
        if left_pad_target:
            offsets[:, 1] += tgt_sz - tgt_lengths

        alignments = []
        for align_idx, offset, src_len, tgt_len in zip(
            sort_order, offsets, src_lengths, tgt_lengths
        ):
            alignment = samples[align_idx]["alignment"].view(-1, 2)
            if check_alignment(alignment, src_len, tgt_len):
                alignments.append(alignment + offset)

        if len(alignments) > 0:
            alignments = torch.cat(alignments, dim=0)
            align_weights = compute_alignment_weights(alignments)

            batch["alignments"] = alignments
            batch["align_weights"] = align_weights

    if samples[0].get("constraints", None) is not None:
        # Collate the packed constraints across the samples, padding to
        # the length of the longest sample.
        lens = [sample.get("constraints").size(0) for sample in samples]
        constraints = torch.zeros((len(samples), max(lens))).long()
        for i, sample in enumerate(samples):
            constraints[i, 0 : lens[i]] = sample.get("constraints")
        batch["constraints"] = constraints.index_select(0, sort_order)

    return batch
class LanguageTripleDataset(FairseqDataset):
    """
    A triple of torch.utils.data.Datasets: source, reference and an optional
    target.

    Args:
        src (torch.utils.data.Dataset): source dataset to wrap
        src_sizes (List[int]): source sentence lengths
        src_dict (~fairseq.data.Dictionary): source vocabulary
        ref (torch.utils.data.Dataset): reference dataset to wrap; must hold
            the same number of examples as *src*
        ref_sizes (List[int]): reference sentence lengths
        ref_dict (~fairseq.data.Dictionary): reference vocabulary
        tgt (torch.utils.data.Dataset, optional): target dataset to wrap
        tgt_sizes (List[int], optional): target sentence lengths
        tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary
        left_pad_source (bool, optional): pad source and reference tensors on
            the left side (default: True).
        left_pad_target (bool, optional): pad target tensors on the left side
            (default: False).
        shuffle (bool, optional): shuffle dataset elements before batching
            (default: True).
        input_feeding (bool, optional): create a shifted version of the targets
            to be passed into the model for teacher forcing (default: True).
        remove_eos_from_source (bool, optional): if set, removes eos from end
            of source and reference if it's present (default: False).
        append_eos_to_target (bool, optional): if set, appends eos to end of
            target if it's absent (default: False).
        align_dataset (torch.utils.data.Dataset, optional): dataset
            containing alignments.
        constraints (Tensor, optional): 2d tensor with a concatenated, zero-
            delimited list of constraints for each sentence.
        append_bos (bool, optional): if set, prepends bos to the beginning of
            source/reference/target sentence when it is absent.
        eos (int, optional): index handed to the collate function as
            ``eos_idx`` (default: ``src_dict.eos()``).
        num_buckets (int, optional): if set to a value greater than 0, then
            batches will be bucketed into the given number of batch shapes.
        src_lang_id (int, optional): source language ID, if set, the collated batch
            will contain a field 'src_lang_id' in 'net_input' which indicates the
            source language of the samples.
        tgt_lang_id (int, optional): target language ID, if set, the collated batch
            will contain a field 'tgt_lang_id' which indicates the target language
            of the samples.
        pad_to_multiple (int, optional): forwarded to the collate function
            (default: 1).
    """

    def __init__(
        self,
        src,
        src_sizes,
        src_dict,
        ref,
        ref_sizes,
        ref_dict,
        tgt=None,
        tgt_sizes=None,
        tgt_dict=None,
        left_pad_source=True,
        left_pad_target=False,
        shuffle=True,
        input_feeding=True,
        remove_eos_from_source=False,
        append_eos_to_target=False,
        align_dataset=None,
        constraints=None,
        append_bos=False,
        eos=None,
        num_buckets=0,
        src_lang_id=None,
        tgt_lang_id=None,
        pad_to_multiple=1,
    ):
        if tgt_dict is not None:
            # Source and target must agree on the special symbols because a
            # single pad index is used when collating.
            assert src_dict.pad() == tgt_dict.pad()
            assert src_dict.eos() == tgt_dict.eos()
            assert src_dict.unk() == tgt_dict.unk()
        if tgt is not None:
            assert len(src) == len(
                tgt
            ), "Source and target must contain the same number of examples"
        assert len(src) == len(
            ref
        ), "Source and reference must contain the same number of examples"
        self.src = src
        self.ref = ref
        self.tgt = tgt
        self.src_sizes = np.array(src_sizes)
        self.ref_sizes = np.array(ref_sizes)
        self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
        # NOTE: ``sizes`` only reflects source (and target) lengths; reference
        # lengths are kept separately in ``ref_sizes``.
        self.sizes = (
            np.vstack((self.src_sizes, self.tgt_sizes)).T
            if self.tgt_sizes is not None
            else self.src_sizes
        )
        self.src_dict = src_dict
        self.ref_dict = ref_dict
        self.tgt_dict = tgt_dict
        self.left_pad_source = left_pad_source
        self.left_pad_target = left_pad_target
        self.shuffle = shuffle
        self.input_feeding = input_feeding
        self.remove_eos_from_source = remove_eos_from_source
        self.append_eos_to_target = append_eos_to_target
        self.align_dataset = align_dataset
        if self.align_dataset is not None:
            assert (
                self.tgt_sizes is not None
            ), "Both source and target needed when alignments are provided"
        self.constraints = constraints
        self.append_bos = append_bos
        self.eos = eos if eos is not None else src_dict.eos()
        self.src_lang_id = src_lang_id
        self.tgt_lang_id = tgt_lang_id
        if num_buckets > 0:
            from fairseq.data import BucketPadLengthDataset

            self.src = BucketPadLengthDataset(
                self.src,
                sizes=self.src_sizes,
                num_buckets=num_buckets,
                pad_idx=self.src_dict.pad(),
                left_pad=self.left_pad_source,
            )
            self.src_sizes = self.src.sizes
            logger.info("bucketing source lengths: {}".format(list(self.src.buckets)))
            self.ref = BucketPadLengthDataset(
                self.ref,
                sizes=self.ref_sizes,
                num_buckets=num_buckets,
                pad_idx=self.ref_dict.pad(),
                left_pad=self.left_pad_source,
            )
            self.ref_sizes = self.ref.sizes
            # Fixed: report the reference buckets (previously logged the
            # source buckets under the reference label).
            logger.info(
                "bucketing reference lengths: {}".format(list(self.ref.buckets))
            )
            if self.tgt is not None:
                self.tgt = BucketPadLengthDataset(
                    self.tgt,
                    sizes=self.tgt_sizes,
                    num_buckets=num_buckets,
                    pad_idx=self.tgt_dict.pad(),
                    left_pad=self.left_pad_target,
                )
                self.tgt_sizes = self.tgt.sizes
                logger.info(
                    "bucketing target lengths: {}".format(list(self.tgt.buckets))
                )

            # determine bucket sizes using self.num_tokens, which will return
            # the padded lengths (thanks to BucketPadLengthDataset)
            # np.int64 replaces np.compat.long, which newer NumPy no longer
            # exposes.
            num_tokens = np.vectorize(self.num_tokens, otypes=[np.int64])
            self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
            self.buckets = [
                (None, num_tokens) for num_tokens in np.unique(self.bucketed_num_tokens)
            ]
        else:
            self.buckets = None
        self.pad_to_multiple = pad_to_multiple

    def get_batch_shapes(self):
        """Return the bucket shapes, or None when bucketing is disabled."""
        return self.buckets

    def __getitem__(self, index):
        """Return the example at ``index`` as a dict.

        The dict holds ``id``, ``source``, ``reference`` and ``target``
        (``None`` without a target dataset) plus ``alignment`` and
        ``constraints`` when those were supplied.
        """
        # Each underlying dataset is indexed exactly once; the optional
        # adjustments below operate on these items so they compose instead of
        # overwriting one another.
        tgt_item = self.tgt[index] if self.tgt is not None else None
        src_item = self.src[index]
        ref_item = self.ref[index]
        # Append EOS to end of tgt sentence if it does not have an EOS and remove
        # EOS from end of src sentence if it exists. This is useful when we
        # use existing datasets for opposite directions i.e., when we want to
        # use tgt_dataset as src_dataset and vice versa
        if self.append_eos_to_target:
            eos = (
                self.tgt_dict.eos()
                if self.tgt_dict is not None
                else self.src_dict.eos()
            )
            if tgt_item is not None and tgt_item[-1] != eos:
                tgt_item = torch.cat([tgt_item, torch.LongTensor([eos])])

        if self.append_bos:
            bos = (
                self.tgt_dict.bos()
                if self.tgt_dict is not None
                else self.src_dict.bos()
            )
            if tgt_item is not None and tgt_item[0] != bos:
                tgt_item = torch.cat([torch.LongTensor([bos]), tgt_item])

            # NOTE(review): the reference is checked against the *source*
            # dictionary's bos; this presumes src_dict and ref_dict share
            # special symbols -- confirm against callers.
            bos = self.src_dict.bos()
            if src_item[0] != bos:
                src_item = torch.cat([torch.LongTensor([bos]), src_item])
            if ref_item[0] != bos:
                ref_item = torch.cat([torch.LongTensor([bos]), ref_item])

        if self.remove_eos_from_source:
            eos = self.src_dict.eos()
            if src_item[-1] == eos:
                src_item = src_item[:-1]
            if ref_item[-1] == eos:
                ref_item = ref_item[:-1]

        example = {
            "id": index,
            "source": src_item,
            "reference": ref_item,
            "target": tgt_item,
        }
        if self.align_dataset is not None:
            example["alignment"] = self.align_dataset[index]
        if self.constraints is not None:
            example["constraints"] = self.constraints[index]
        return example

    def __len__(self):
        return len(self.src)

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate
            pad_to_length (dict, optional): a dictionary of
                {'source': source_pad_to_length, 'target': target_pad_to_length}
                to indicate the max length to pad to in source and target respectively.

        Returns:
            dict: a mini-batch with the following keys (an empty dict is
            returned for an empty *samples*):

                - `id` (LongTensor): example IDs, reordered by descending
                  source length
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
                    the source sentence of shape `(bsz, src_len)`. Padding will
                    appear on the left if *left_pad_source* is ``True``.
                  - `src_lengths` (LongTensor): 1D Tensor of the unpadded
                    lengths of each source sentence of shape `(bsz)`
                  - `prev_output_tokens` (LongTensor): a padded 2D Tensor of
                    tokens in the target sentence, shifted right by one
                    position for teacher forcing, of shape `(bsz, tgt_len)`.
                    This key will not be present if *input_feeding* is
                    ``False``.  Padding will appear on the left if
                    *left_pad_target* is ``True``.
                  - `src_lang_id` (LongTensor): a long Tensor which contains source
                    language IDs of each sample in the batch

                - `target` (LongTensor): a padded 2D Tensor of tokens in the
                  target sentence of shape `(bsz, tgt_len)`. Padding will appear
                  on the left if *left_pad_target* is ``True``.
                - `ref_tokens` (LongTensor): padded reference tokens, padded
                  like the source
                - `ref_lengths` (LongTensor): unpadded reference lengths
                - `tgt_lang_id` (LongTensor): a long Tensor which contains target language
                   IDs of each sample in the batch
        """
        res = collate(
            samples,
            pad_idx=self.src_dict.pad(),
            eos_idx=self.eos,
            left_pad_source=self.left_pad_source,
            left_pad_target=self.left_pad_target,
            input_feeding=self.input_feeding,
            pad_to_length=pad_to_length,
            pad_to_multiple=self.pad_to_multiple,
        )
        # ``collate`` returns {} for an empty batch; there is nothing to tag
        # with language ids in that case.
        if res and (self.src_lang_id is not None or self.tgt_lang_id is not None):
            src_tokens = res["net_input"]["src_tokens"]
            bsz = src_tokens.size(0)
            if self.src_lang_id is not None:
                res["net_input"]["src_lang_id"] = (
                    torch.LongTensor([[self.src_lang_id]]).expand(bsz, 1).to(src_tokens)
                )
            if self.tgt_lang_id is not None:
                res["tgt_lang_id"] = (
                    torch.LongTensor([[self.tgt_lang_id]]).expand(bsz, 1).to(src_tokens)
                )
        return res

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return max(
            self.src_sizes[index],
            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
        )

    def num_tokens_vec(self, indices):
        """Return the number of tokens for a set of positions defined by indices.
        This value is used to enforce ``--max-tokens`` during batching."""
        sizes = self.src_sizes[indices]
        if self.tgt_sizes is not None:
            sizes = np.maximum(sizes, self.tgt_sizes[indices])
        return sizes

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return (
            self.src_sizes[index],
            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
        )

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self)).astype(np.int64)
        else:
            indices = np.arange(len(self), dtype=np.int64)
        if self.buckets is None:
            # sort by target length, then source length
            if self.tgt_sizes is not None:
                indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")]
            return indices[np.argsort(self.src_sizes[indices], kind="mergesort")]
        else:
            # sort by bucketed_num_tokens, which is:
            #   max(padded_src_len, padded_tgt_len)
            return indices[
                np.argsort(self.bucketed_num_tokens[indices], kind="mergesort")
            ]

    @property
    def supports_prefetch(self):
        """True when the source and (if present) target support prefetching."""
        return getattr(self.src, "supports_prefetch", False) and (
            getattr(self.tgt, "supports_prefetch", False) or self.tgt is None
        )

    def prefetch(self, indices):
        """Prefetch ``indices`` on every wrapped dataset that is read later."""
        self.src.prefetch(indices)
        # The reference is read in __getitem__ as well, so warm it up too
        # whenever it advertises prefetch support.
        if getattr(self.ref, "supports_prefetch", False):
            self.ref.prefetch(indices)
        if self.tgt is not None:
            self.tgt.prefetch(indices)
        if self.align_dataset is not None:
            self.align_dataset.prefetch(indices)

    def filter_indices_by_size(self, indices, max_sizes):
        """Filter a list of sample indices. Remove those that are longer
            than specified in max_sizes.

        Only source and target sizes are passed to the filter; reference
        sizes are not.

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """
        return data_utils.filter_paired_dataset_indices_by_size(
            self.src_sizes,
            self.tgt_sizes,
            indices,
            max_sizes,
        )