File size: 4,164 Bytes
b100e1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# Copyright 2022 The MT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for vocabularies."""

from absl.testing import absltest
from mt3 import vocabularies

import numpy as np
import tensorflow.compat.v2 as tf

# The tests below call `.numpy()` on the tensors returned by `encode_tf` /
# `decode_tf`, which needs eager execution; enable it explicitly at import time.
tf.compat.v1.enable_eager_execution()


class VocabulariesTest(absltest.TestCase):
  """Tests velocity binning and `GenericTokenVocabulary` encode/decode."""

  def _assert_python_and_tf_decode(self, vocab, ids, expected,
                                   expected_tf=None):
    """Asserts the outputs of both `vocab.decode` and `vocab.decode_tf`.

    The Python decoder is checked first, then the TF decoder.

    Args:
      vocab: Vocabulary under test.
      ids: List of encoded ids to decode.
      expected: Expected result of the Python `decode`.
      expected_tf: Expected result of `decode_tf`; when omitted, `expected`
        is used for both decoders.
    """
    self.assertSequenceEqual(vocab.decode(ids), expected)
    tf_expected = expected if expected_tf is None else expected_tf
    np.testing.assert_array_equal(
        vocab.decode_tf(tf.convert_to_tensor(ids)).numpy(), tf_expected)

  def test_velocity_quantization(self):
    """Velocity/bin 0 maps to 0, and nonzero bins survive a round trip."""
    # Zero is a fixed point of both conversions for 1 and 127 bins.
    for convert in (vocabularies.velocity_to_bin,
                    vocabularies.bin_to_velocity):
      for num_bins in (1, 127):
        self.assertEqual(0, convert(0, num_velocity_bins=num_bins))

    # (bin, num_velocity_bins) pairs: the single bin of the 1-bin scheme,
    # then every bin 1..127 of the 127-bin scheme.
    round_trip_cases = [(1, 1)] + [(b, 127) for b in range(1, 128)]
    for velocity_bin, num_bins in round_trip_cases:
      velocity = vocabularies.bin_to_velocity(
          velocity_bin, num_velocity_bins=num_bins)
      self.assertEqual(
          velocity_bin,
          vocabularies.velocity_to_bin(velocity, num_velocity_bins=num_bins))

  def test_encode_decode(self):
    """Tokens [1, 2, 3] encode to ids [4, 5, 6] and decode back again."""
    vocab = vocabularies.GenericTokenVocabulary(32)
    tokens = [1, 2, 3]
    ids = [4, 5, 6]

    # Encode
    self.assertSequenceEqual(vocab.encode(tokens), ids)
    np.testing.assert_array_equal(
        vocab.encode_tf(tf.convert_to_tensor(tokens)).numpy(), ids)

    # Decode
    self._assert_python_and_tf_decode(vocab, ids, tokens)

  def test_decode_invalid_ids(self):
    """Ids outside the regular token range (here 0, 2 and 35) decode to -2."""
    vocab = vocabularies.GenericTokenVocabulary(32, extra_ids=4)
    self._assert_python_and_tf_decode(
        vocab,
        ids=[0, 2, 3, 4, 34, 35],
        expected=[-2, -2, 0, 1, 31, -2])

  def test_decode_eos(self):
    """EOS (id 1) decodes to -1; the two decoders differ in output length."""
    vocab = vocabularies.GenericTokenVocabulary(32)
    # The Python decoder drops everything after the first EOS, whereas the TF
    # decoder keeps the input length and emits -1 from the EOS onward.
    self._assert_python_and_tf_decode(
        vocab,
        ids=[0, 2, 3, 4, 1, 0, 1, 0],
        expected=[-2, -2, 0, 1, -1],
        expected_tf=[-2, -2, 0, 1, -1, -1, -1, -1])

  def test_encode_invalid_id(self):
    """Tokens in [0, 31] encode cleanly; -1 and 32 are rejected."""
    vocab = vocabularies.GenericTokenVocabulary(32)

    # Boundary tokens 0 and 31 are valid, so neither encoder may raise.
    in_range = [0, 15, 31]
    vocab.encode(in_range)
    vocab.encode_tf(tf.convert_to_tensor(in_range))

    # One token just below and one just above the valid range.
    for bad_tokens in ([-1, 15, 31], [0, 15, 32]):
      with self.assertRaises(ValueError):
        vocab.encode(bad_tokens)
      with self.assertRaises(tf.errors.InvalidArgumentError):
        vocab.encode_tf(tf.convert_to_tensor(bad_tokens))

  def test_encode_dtypes(self):
    """`encode_tf` returns a tensor with the same integer dtype as its input."""
    vocab = vocabularies.GenericTokenVocabulary(32)
    tokens = [0, 15, 31]
    for dtype in (tf.int32, tf.int64):
      encoded = vocab.encode_tf(tf.convert_to_tensor(tokens, dtype))
      self.assertEqual(dtype, encoded.dtype)


if __name__ == '__main__':
  # When run directly as a script, hand control to absl's test runner.
  absltest.main()