# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
""" | |
This module contains collection of classes which implement | |
collate functionalities for various tasks. | |
Collaters should know what data to expect for each sample | |
and they should pack / collate them into batches | |
""" | |
from __future__ import absolute_import, division, print_function, unicode_literals

import numpy as np
import torch
from fairseq.data import data_utils as fairseq_data_utils


class Seq2SeqCollater(object):
    """
    Implements a collate function mainly for seq2seq tasks.
    Expects each sample to contain a feature matrix (src_tokens) and
    a target token sequence.
    This collater is also used for the aligned training task.
    (A minimal usage sketch is included at the end of this file.)
    """

    def __init__(
        self,
        feature_index=0,
        label_index=1,
        pad_index=1,
        eos_index=2,
        move_eos_to_beginning=True,
    ):
        self.feature_index = feature_index
        self.label_index = label_index
        self.pad_index = pad_index
        self.eos_index = eos_index
        self.move_eos_to_beginning = move_eos_to_beginning

    def _collate_frames(self, frames):
        """Convert a list of 2d frames into a padded 3d tensor

        Args:
            frames (list): list of 2d tensors of size L[i] x f_dim, where L[i]
                is the number of frames in the i-th sample and f_dim is the
                fixed feature dimension.

        Returns:
            3d tensor of size len(frames) x len_max x f_dim, where len_max is
            the maximum of L[i].
        """
        len_max = max(frame.size(0) for frame in frames)
        f_dim = frames[0].size(1)
        res = frames[0].new(len(frames), len_max, f_dim).fill_(0.0)

        for i, v in enumerate(frames):
            res[i, : v.size(0)] = v
        return res

    def collate(self, samples):
        """
        Utility function to collate samples into a batch for speech recognition.
        """
        if len(samples) == 0:
            return {}

        # parse samples into torch tensors
        parsed_samples = []
        for s in samples:
            # skip invalid samples
            if s["data"][self.feature_index] is None:
                continue
            source = s["data"][self.feature_index]
            if isinstance(source, (np.ndarray, np.generic)):
                source = torch.from_numpy(source)
            target = s["data"][self.label_index]
            if isinstance(target, (np.ndarray, np.generic)):
                target = torch.from_numpy(target).long()
            elif isinstance(target, list):
                target = torch.LongTensor(target)

            parsed_sample = {"id": s["id"], "source": source, "target": target}
            parsed_samples.append(parsed_sample)
        samples = parsed_samples

        id = torch.LongTensor([s["id"] for s in samples])
        frames = self._collate_frames([s["source"] for s in samples])
        # sort samples by descending number of frames
        frames_lengths = torch.LongTensor([s["source"].size(0) for s in samples])
        frames_lengths, sort_order = frames_lengths.sort(descending=True)
        id = id.index_select(0, sort_order)
        frames = frames.index_select(0, sort_order)

        target = None
        target_lengths = None
        prev_output_tokens = None
        if samples[0].get("target", None) is not None:
            ntokens = sum(len(s["target"]) for s in samples)
            target = fairseq_data_utils.collate_tokens(
                [s["target"] for s in samples],
                self.pad_index,
                self.eos_index,
                left_pad=False,
                move_eos_to_beginning=False,
            )
            target = target.index_select(0, sort_order)
            target_lengths = torch.LongTensor(
                [s["target"].size(0) for s in samples]
            ).index_select(0, sort_order)
            prev_output_tokens = fairseq_data_utils.collate_tokens(
                [s["target"] for s in samples],
                self.pad_index,
                self.eos_index,
                left_pad=False,
                move_eos_to_beginning=self.move_eos_to_beginning,
            )
            prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
        else:
            ntokens = sum(len(s["source"]) for s in samples)

        batch = {
            "id": id,
            "ntokens": ntokens,
            "net_input": {"src_tokens": frames, "src_lengths": frames_lengths},
            "target": target,
            "target_lengths": target_lengths,
            "nsentences": len(samples),
        }
        if prev_output_tokens is not None:
            batch["net_input"]["prev_output_tokens"] = prev_output_tokens
        return batch
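

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original module. It builds two
    # toy samples in the layout collate() expects: a dict with an "id" and a
    # "data" list holding the feature matrix at feature_index and the label
    # sequence at label_index. The feature shapes and token ids below are made
    # up for illustration; pad_index=1 and eos_index=2 follow the defaults above.
    collater = Seq2SeqCollater(feature_index=0, label_index=1)
    samples = [
        {"id": 0, "data": [np.random.rand(12, 80).astype(np.float32), [4, 5, 6, 2]]},
        {"id": 1, "data": [np.random.rand(7, 80).astype(np.float32), [7, 8, 2]]},
    ]
    batch = collater.collate(samples)
    # Frames are zero-padded to the longest utterance and sorted by length.
    print(batch["net_input"]["src_tokens"].shape)    # torch.Size([2, 12, 80])
    print(batch["net_input"]["src_lengths"])         # tensor([12, 7])
    # Targets are padded with pad_index; prev_output_tokens has the eos token
    # moved to the beginning because move_eos_to_beginning=True by default.
    print(batch["target"])
    print(batch["net_input"]["prev_output_tokens"])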