# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for models."""
from unittest import mock
from absl.testing import absltest
from clu import metrics as clu_metrics_lib
from flax import core as flax_core
import jax.numpy as jnp
import numpy as np
from t5x import metrics as metrics_lib
from t5x.contrib.moe import models
Accuracy = clu_metrics_lib.Accuracy
AveragePerStep = metrics_lib.AveragePerStep
ExpertMetrics = models.ExpertMetrics
FrozenDict = flax_core.frozen_dict.FrozenDict


class ModelsTest(absltest.TestCase):

  def test_expert_losses(self):
    diversity_metrics = [
        ExpertMetrics(
            auxiliary_loss=1.,
            router_z_loss=0.,
            fraction_tokens_left_behind=0.5,
            expert_usage=0.5,
            router_confidence=0.5),
        ExpertMetrics(
            auxiliary_loss=2.,
            router_z_loss=1.,
            fraction_tokens_left_behind=0.5,
            expert_usage=0.5,
            router_confidence=0.5)
    ]
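    # The loss factors scale the diversity metrics, consistent with averaging
    # across the two expert layers: 0.1 * (1 + 2) / 2 = 0.15 and
    # 10 * (0 + 1) / 2 = 5.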
    aux_loss, router_z_loss = models._expert_losses(
        diversity_metrics, auxiliary_loss_factor=0.1, router_z_loss_factor=10)

    self.assertEqual(aux_loss, 0.15)
    self.assertEqual(router_z_loss, 5.)

  def test_expert_metrics(self):
    diversity_metrics = [
        ExpertMetrics(
            auxiliary_loss=1.,
            router_z_loss=0.,
            fraction_tokens_left_behind=1.,
            expert_usage=0.7,
            router_confidence=0.5),
        ExpertMetrics(
            auxiliary_loss=2.,
            router_z_loss=1.,
            fraction_tokens_left_behind=0.5,
            expert_usage=0.5,
            router_confidence=0.5)
    ]
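    # Fold the per-layer expert diversity metrics into a single metrics dict
    # alongside the model's loss terms.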
    actual_metrics = models._expert_metrics(
        diversity_metrics,
        total_loss=100.,
        z_loss=1.,
        auxiliary_loss=3.,
        router_z_loss=7.,
        num_tokens=2)
    actual_metrics = metrics_lib.set_step_metrics_num_steps(actual_metrics, 1)
    actual_computed_metrics = {
        k: v.compute() for k, v in actual_metrics.items()
    }
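    # cross_ent_loss is the total loss minus the z-loss and expert losses:
    # 100 - 1 - 3 - 7 = 89, i.e. 44.5 for each of the 2 target tokens. The
    # experts/* entries are means of the two ExpertMetrics above.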
    expected_metrics = {
        'cross_ent_loss': 89.0,
        'cross_ent_loss_per_all_target_tokens': 44.5,
        'experts/auxiliary_loss': 3.,
        'experts/expert_usage': 0.6,
        'experts/fraction_tokens_left_behind': 0.75,
        'experts/router_confidence': 0.5,
        'experts/router_z_loss': 7.
    }
    self.assertEqual(actual_computed_metrics, expected_metrics)

  def test_extract_from_non_expert_model(self):
    empty_state = FrozenDict({'intermediates': {}})
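    # A state with no recorded intermediates carries no expert diversity
    # metrics, so extraction is expected to fail loudly.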
    with self.assertRaisesRegex(ValueError,
                                'Unable to find any expert diversity metrics.'):
      models._extract_diversity_metrics(empty_state)

  def test_model(self):
    encoder_input_tokens = jnp.ones((2, 3))
    decoder_input_tokens = jnp.array([[1, 2, 1, 0], [0, 1, 0, 2]])
    decoder_target_tokens = jnp.array([[1, 2, 1, 0], [0, 1, 0, 2]])
    decoder_loss_weights = jnp.array([[1, 1, 1, 0], [0, 1, 0, 1]])
    logits = jnp.arange(0, 24).reshape((2, 4, 3))
    params = {'foo': jnp.zeros(3)}
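
    # A Mock stands in for the underlying Flax module so score_batch() can be
    # exercised without constructing a real MoE network.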
    mock_transformer = mock.Mock()
    mock_transformer.apply.return_value = logits
    mock_transformer.dtype = jnp.float32

    batch = {
        'encoder_input_tokens': encoder_input_tokens,
        'decoder_input_tokens': decoder_input_tokens,
        'decoder_target_tokens': decoder_target_tokens,
        'decoder_loss_weights': decoder_loss_weights
    }

    def mock_init(self):
      self.module = mock_transformer

    with mock.patch.object(
        models.MoeEncoderDecoderModel, '__init__', new=mock_init):
      model = models.MoeEncoderDecoderModel()
      result = model.score_batch(params, batch)
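
    # The module should be applied deterministically: dropout disabled, no
    # autoregressive decoding, and the batch features forwarded unchanged.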
    mock_transformer.apply.assert_called_with({'params': params},
                                              encoder_input_tokens,
                                              decoder_input_tokens,
                                              decoder_target_tokens,
                                              encoder_segment_ids=None,
                                              decoder_segment_ids=None,
                                              encoder_positions=None,
                                              decoder_positions=None,
                                              decode=False,
                                              enable_dropout=False,
                                              rngs=None,
                                              mutable=False)
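
    # Each score is the per-example sum of target-token log-likelihoods at
    # positions with nonzero decoder_loss_weights.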
    np.testing.assert_allclose(result, [-3.2228181, -1.8152122], rtol=1e-5)


if __name__ == '__main__':
  absltest.main()