#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import numpy as np
import torch
from examples.speech_recognition.data.collaters import Seq2SeqCollater


class TestSeq2SeqCollator(unittest.TestCase):
    """Unit test for ``Seq2SeqCollater.collate`` on a two-sample batch."""

    def test_collate(self):
        """Collate two samples of different lengths and check every batch field.

        The expectations below pin the ordering of the batch, the padding of
        the frames and targets, and the token/sentence counts.
        """
        eos_idx = 1
        pad_idx = 0
        collater = Seq2SeqCollater(
            feature_index=0, label_index=1, pad_index=pad_idx, eos_index=eos_idx
        )

        # 2 frames in the first sample and 3 frames in the second one
        frames1 = np.array([[7, 8], [9, 10]])
        frames2 = np.array([[1, 2], [3, 4], [5, 6]])
        target1 = np.array([4, 2, 3, eos_idx])
        target2 = np.array([3, 2, eos_idx])
        sample1 = {"id": 0, "data": [frames1, target1]}
        sample2 = {"id": 1, "data": [frames2, target2]}
        batch = collater.collate([sample1, sample2])

        # collate sorts inputs by frame length before creating the batch,
        # so the 3-frame sample (id 1) is expected first.
        self.assertTensorEqual(batch["id"], torch.tensor([1, 0]))
        # 4 target tokens in sample1 + 3 in sample2
        self.assertEqual(batch["ntokens"], 7)
        # The shorter frame sequence is expected to be padded with pad_idx.
        self.assertTensorEqual(
            batch["net_input"]["src_tokens"],
            torch.tensor(
                [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [pad_idx, pad_idx]]]
            ),
        )
        # Expected decoder input: eos first, then the target without its
        # trailing eos, padded to the longest row.
        self.assertTensorEqual(
            batch["net_input"]["prev_output_tokens"],
            torch.tensor([[eos_idx, 3, 2, pad_idx], [eos_idx, 4, 2, 3]]),
        )
        self.assertTensorEqual(batch["net_input"]["src_lengths"], torch.tensor([3, 2]))
        self.assertTensorEqual(
            batch["target"],
            torch.tensor([[3, 2, eos_idx, pad_idx], [4, 2, 3, eos_idx]]),
        )
        self.assertEqual(batch["nsentences"], 2)

    def assertTensorEqual(self, t1, t2):
        """Assert that two tensors have the same size and identical elements.

        Args:
            t1: First tensor to compare.
            t2: Second tensor to compare.
        """
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        # Count the positions where the tensors differ; it must be zero.
        self.assertEqual(t1.ne(t2).long().sum(), 0)


if __name__ == "__main__":
    unittest.main()