|
|
|
|
|
|
|
|
|
|
|
|
|
import unittest |
|
|
|
import numpy as np |
|
import torch |
|
from examples.speech_recognition.data.collaters import Seq2SeqCollater |
|
|
|
|
|
class TestSeq2SeqCollator(unittest.TestCase):
    """Checks the batch that ``Seq2SeqCollater.collate`` builds from two samples."""

    def test_collate(self):
        """Collate two samples of unequal length and verify every batch field."""
        eos_idx = 1
        pad_idx = 0
        collater = Seq2SeqCollater(
            feature_index=0,
            label_index=1,
            pad_index=pad_idx,
            eos_index=eos_idx,
        )

        # Sample 0 has two feature frames and a four-token target;
        # sample 1 has three feature frames and a three-token target.
        short_frames = np.array([[7, 8], [9, 10]])
        long_frames = np.array([[1, 2], [3, 4], [5, 6]])
        short_frames_target = np.array([4, 2, 3, eos_idx])
        long_frames_target = np.array([3, 2, eos_idx])
        samples = [
            {"id": 0, "data": [short_frames, short_frames_target]},
            {"id": 1, "data": [long_frames, long_frames_target]},
        ]
        batch = collater.collate(samples)

        # The expected batch lists sample 1 (three frames) ahead of sample 0
        # (two frames) -- presumably the collater orders by descending source
        # length; confirm against Seq2SeqCollater.
        expected_ids = torch.tensor([1, 0])
        self.assertTensorEqual(batch["id"], expected_ids)

        # 7 equals the combined length of both targets (4 + 3).
        self.assertEqual(batch["ntokens"], 7)

        net_input = batch["net_input"]

        # The two-frame sample is expected to gain one trailing frame of padding.
        expected_src_tokens = torch.tensor(
            [
                [[1, 2], [3, 4], [5, 6]],
                [[7, 8], [9, 10], [pad_idx, pad_idx]],
            ]
        )
        self.assertTensorEqual(net_input["src_tokens"], expected_src_tokens)

        # Decoder input is expected to start with EOS followed by the target
        # tokens without their final EOS, right-padded to a common width.
        expected_prev_output = torch.tensor(
            [[eos_idx, 3, 2, pad_idx], [eos_idx, 4, 2, 3]]
        )
        self.assertTensorEqual(net_input["prev_output_tokens"], expected_prev_output)

        expected_src_lengths = torch.tensor([3, 2])
        self.assertTensorEqual(net_input["src_lengths"], expected_src_lengths)

        # Targets keep their EOS and are right-padded to a common width.
        expected_target = torch.tensor(
            [[3, 2, eos_idx, pad_idx], [4, 2, 3, eos_idx]]
        )
        self.assertTensorEqual(batch["target"], expected_target)

        self.assertEqual(batch["nsentences"], 2)

    def assertTensorEqual(self, t1, t2):
        """Fail unless ``t1`` and ``t2`` have the same size and equal elements."""
        # Compare sizes first so the element-wise check below cannot
        # succeed through broadcasting of differently sized tensors.
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        mismatch_count = t1.ne(t2).long().sum()
        self.assertEqual(mismatch_count, 0)
|
|
|
|
|
if __name__ == "__main__":
    # Run this module's test cases when the file is executed as a script.
    unittest.main()
|
|