# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for models.""" | |
from unittest import mock

from absl.testing import absltest
from clu import metrics as clu_metrics_lib
from flax import core as flax_core
import jax.numpy as jnp
import numpy as np

from t5x import metrics as metrics_lib
from t5x.contrib.moe import models
# Short aliases for types used throughout the tests. (`Accuracy` and
# `AveragePerStep` are not referenced in this file's visible tests; kept
# for parity with the originals they alias.)
Accuracy = clu_metrics_lib.Accuracy
AveragePerStep = metrics_lib.AveragePerStep
ExpertMetrics = models.ExpertMetrics
FrozenDict = flax_core.frozen_dict.FrozenDict
class ModelsTest(absltest.TestCase):
  """Unit tests for MoE loss/metric helpers and the MoE encoder-decoder model."""

  def test_expert_losses(self):
    """_expert_losses averages per-checkpoint losses and applies scale factors."""
    diversity_metrics = [
        ExpertMetrics(
            auxiliary_loss=1.,
            router_z_loss=0.,
            fraction_tokens_left_behind=0.5,
            expert_usage=0.5,
            router_confidence=0.5),
        ExpertMetrics(
            auxiliary_loss=2.,
            router_z_loss=1.,
            fraction_tokens_left_behind=0.5,
            expert_usage=0.5,
            router_confidence=0.5)
    ]
    aux_loss, router_z_loss = models._expert_losses(
        diversity_metrics, auxiliary_loss_factor=0.1, router_z_loss_factor=10)
    # Mean auxiliary loss is (1 + 2) / 2 = 1.5; scaled by 0.1 -> 0.15.
    self.assertEqual(aux_loss, 0.15)
    # Mean router z-loss is (0 + 1) / 2 = 0.5; scaled by 10 -> 5.
    self.assertEqual(router_z_loss, 5.)

  def test_expert_metrics(self):
    """_expert_metrics aggregates diversity stats and derives cross-ent losses."""
    diversity_metrics = [
        ExpertMetrics(
            auxiliary_loss=1.,
            router_z_loss=0.,
            fraction_tokens_left_behind=1.,
            expert_usage=0.7,
            router_confidence=0.5),
        ExpertMetrics(
            auxiliary_loss=2.,
            router_z_loss=1.,
            fraction_tokens_left_behind=0.5,
            expert_usage=0.5,
            router_confidence=0.5)
    ]
    actual_metrics = models._expert_metrics(
        diversity_metrics,
        total_loss=100.,
        z_loss=1.,
        auxiliary_loss=3.,
        router_z_loss=7.,
        num_tokens=2)
    # Step-normalized metrics need an explicit step count before compute().
    actual_metrics = metrics_lib.set_step_metrics_num_steps(actual_metrics, 1)
    actual_computed_metrics = {
        k: v.compute() for k, v in actual_metrics.items()
    }
    expected_metrics = {
        # cross_ent = total_loss - z_loss - auxiliary_loss - router_z_loss
        #           = 100 - 1 - 3 - 7 = 89.
        'cross_ent_loss': 89.0,
        'cross_ent_loss_per_all_target_tokens': 44.5,
        'experts/auxiliary_loss': 3.,
        'experts/expert_usage': 0.6,
        'experts/fraction_tokens_left_behind': 0.75,
        'experts/router_confidence': 0.5,
        'experts/router_z_loss': 7.
    }
    self.assertEqual(actual_computed_metrics, expected_metrics)

  def test_extract_from_non_expert_model(self):
    """_extract_diversity_metrics raises when no expert metrics are recorded."""
    empty_state = FrozenDict({'intermediates': {}})
    with self.assertRaisesRegex(ValueError,
                                'Unable to find any expert diversity metrics.'):
      models._extract_diversity_metrics(empty_state)

  def test_model(self):
    """score_batch forwards the batch to the module and returns log-likelihoods."""
    encoder_input_tokens = jnp.ones((2, 3))
    decoder_input_tokens = jnp.array([[1, 2, 1, 0], [0, 1, 0, 2]])
    decoder_target_tokens = jnp.array([[1, 2, 1, 0], [0, 1, 0, 2]])
    decoder_loss_weights = jnp.array([[1, 1, 1, 0], [0, 1, 0, 1]])
    logits = jnp.arange(0, 24).reshape((2, 4, 3))
    params = {'foo': jnp.zeros(3)}

    # Stub out the underlying Flax module so only model-level logic is tested.
    mock_transformer = mock.Mock()
    mock_transformer.apply.return_value = logits
    mock_transformer.dtype = jnp.float32

    batch = {
        'encoder_input_tokens': encoder_input_tokens,
        'decoder_input_tokens': decoder_input_tokens,
        'decoder_target_tokens': decoder_target_tokens,
        'decoder_loss_weights': decoder_loss_weights
    }

    def mock_init(self):
      # Bypass the real constructor; install only the mocked module.
      self.module = mock_transformer

    with mock.patch.object(
        models.MoeEncoderDecoderModel, '__init__', new=mock_init):
      model = models.MoeEncoderDecoderModel()
      result = model.score_batch(params, batch)

    mock_transformer.apply.assert_called_with({'params': params},
                                              encoder_input_tokens,
                                              decoder_input_tokens,
                                              decoder_target_tokens,
                                              encoder_segment_ids=None,
                                              decoder_segment_ids=None,
                                              encoder_positions=None,
                                              decoder_positions=None,
                                              decode=False,
                                              enable_dropout=False,
                                              rngs=None,
                                              mutable=False)
    np.testing.assert_allclose(result, [-3.2228181, -1.8152122], rtol=1e-5)
# Standard absl test entry point.
if __name__ == '__main__':
  absltest.main()