# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for vocabularies."""

from absl.testing import absltest
from mt3 import vocabularies

import numpy as np
import tensorflow.compat.v2 as tf

tf.compat.v1.enable_eager_execution()


class VocabulariesTest(absltest.TestCase):
    """Exercises velocity binning helpers and `GenericTokenVocabulary`."""

    def test_velocity_quantization(self):
        """Checks the zero case and bin -> velocity -> bin round trips."""

        def round_trip(bin_index, bin_count):
            # Convert a bin to a velocity and straight back to a bin.
            velocity = vocabularies.bin_to_velocity(
                bin_index, num_velocity_bins=bin_count)
            return vocabularies.velocity_to_bin(
                velocity, num_velocity_bins=bin_count)

        # Velocity 0 must land in bin 0, for one bin and for 127 bins.
        for bin_count in (1, 127):
            self.assertEqual(
                0, vocabularies.velocity_to_bin(0, num_velocity_bins=bin_count))

        # Bin 0 must map back to velocity 0, for one bin and for 127 bins.
        for bin_count in (1, 127):
            self.assertEqual(
                0, vocabularies.bin_to_velocity(0, num_velocity_bins=bin_count))

        # With a single bin, bin 1 must survive the round trip.
        self.assertEqual(1, round_trip(1, 1))

        # With 127 bins, every bin from 1 through 127 must survive it.
        for velocity_bin in range(1, 128):
            self.assertEqual(velocity_bin, round_trip(velocity_bin, 127))

    def test_encode_decode(self):
        """Checks that encode/decode agree between the Python and TF paths."""
        token_vocab = vocabularies.GenericTokenVocabulary(32)
        raw_tokens = [1, 2, 3]
        # The expected ids are the raw tokens shifted up by 3.
        encoded_ids = [4, 5, 6]

        # Encode
        self.assertSequenceEqual(token_vocab.encode(raw_tokens), encoded_ids)
        encoded_tensor = token_vocab.encode_tf(
            tf.convert_to_tensor(raw_tokens))
        np.testing.assert_array_equal(encoded_tensor.numpy(), encoded_ids)

        # Decode
        self.assertSequenceEqual(token_vocab.decode(encoded_ids), raw_tokens)
        decoded_tensor = token_vocab.decode_tf(
            tf.convert_to_tensor(encoded_ids))
        np.testing.assert_array_equal(decoded_tensor.numpy(), raw_tokens)

    def test_decode_invalid_ids(self):
        """Checks that ids outside the regular token range decode to -2."""
        token_vocab = vocabularies.GenericTokenVocabulary(32, extra_ids=4)
        encoded_ids = [0, 2, 3, 4, 34, 35]
        # Ids 0, 2 and 35 are expected to come back as -2; 3, 4 and 34 map to
        # tokens 0, 1 and 31 respectively.
        want_decoded = [-2, -2, 0, 1, 31, -2]

        self.assertSequenceEqual(token_vocab.decode(encoded_ids), want_decoded)
        decoded_tensor = token_vocab.decode_tf(
            tf.convert_to_tensor(encoded_ids))
        np.testing.assert_array_equal(decoded_tensor.numpy(), want_decoded)

    def test_decode_eos(self):
        """Checks how the Python and TF decoders treat ids after an EOS."""
        token_vocab = vocabularies.GenericTokenVocabulary(32)
        encoded_ids = [0, 2, 3, 4, 1, 0, 1, 0]

        # The Python decoder drops everything that follows the first EOS.
        want_decoded = [-2, -2, 0, 1, -1]
        self.assertSequenceEqual(token_vocab.decode(encoded_ids), want_decoded)

        # The TF decoder keeps the input length; the tail is filled with -1.
        want_decoded_tf = [-2, -2, 0, 1, -1, -1, -1, -1]
        decoded_tensor = token_vocab.decode_tf(
            tf.convert_to_tensor(encoded_ids))
        np.testing.assert_array_equal(decoded_tensor.numpy(), want_decoded_tf)

    def test_encode_invalid_id(self):
        """Checks that out-of-range tokens are rejected by both encoders."""
        token_vocab = vocabularies.GenericTokenVocabulary(32)

        # Tokens 0..31 are in range, so neither encoder should raise here.
        in_range = [0, 15, 31]
        token_vocab.encode(in_range)
        token_vocab.encode_tf(tf.convert_to_tensor(in_range))

        # A negative token is rejected.
        below_range = [-1, 15, 31]
        with self.assertRaises(ValueError):
            token_vocab.encode(below_range)
        with self.assertRaises(tf.errors.InvalidArgumentError):
            token_vocab.encode_tf(tf.convert_to_tensor(below_range))

        # A token equal to the vocabulary size (32) is rejected.
        above_range = [0, 15, 32]
        with self.assertRaises(ValueError):
            token_vocab.encode(above_range)
        with self.assertRaises(tf.errors.InvalidArgumentError):
            token_vocab.encode_tf(tf.convert_to_tensor(above_range))

    def test_encode_dtypes(self):
        """Checks that `encode_tf` returns the integer dtype it was given."""
        token_vocab = vocabularies.GenericTokenVocabulary(32)
        raw_tokens = [0, 15, 31]

        for int_dtype in (tf.int32, tf.int64):
            encoded_tensor = token_vocab.encode_tf(
                tf.convert_to_tensor(raw_tokens, int_dtype))
            self.assertEqual(int_dtype, encoded_tensor.dtype)


if __name__ == '__main__':
    absltest.main()