# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Loss functions."""
import enum
from typing import Tuple, Mapping, Optional, Union
from flax.training import common_utils
import jax
import jax.numpy as jnp
import numpy as np
@jax.custom_vjp
def cross_entropy_with_logits(logits: jnp.ndarray, targets: jnp.ndarray,
                              z_loss: float
                              ) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Computes cross entropy loss with stable custom gradient.
Computes a stabilized-gradient version of:
-jnp.sum(targets * nn.log_softmax(logits), axis=-1)
If z_loss > 0, then an auxiliary loss equal to z_loss*log(z)^2
will be added to the cross entropy loss (z = softmax normalization constant).
The two uses of z_loss are:
1. To keep the logits from drifting too far from zero, which can cause
unacceptable roundoff errors in bfloat16.
2. To encourage the logits to be normalized log-probabilities.
Args:
logits: [batch, length, num_classes] float array.
targets: categorical one-hot targets [batch, length, num_classes] float
array.
    z_loss: coefficient for the auxiliary z-loss term.
Returns:
tuple with the total loss and the z_loss, both
float arrays with shape [batch, length].
"""
logits_sum = jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
log_softmax = logits - logits_sum
loss = -jnp.sum(targets * log_softmax, axis=-1)
  # Add auxiliary z-loss term.
log_z = jnp.squeeze(logits_sum, axis=-1)
total_z_loss = z_loss * jax.lax.square(log_z)
loss += total_z_loss
return loss, total_z_loss
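
# Example usage (an illustrative sketch, not part of the original file; kept as
# a comment so importing this module does no extra work). The shapes and the
# z_loss coefficient are assumptions chosen for demonstration, and the call
# relies on the custom VJP registered further below:
#
#   example_logits = jnp.zeros((2, 3, 5))       # [batch, length, num_classes]
#   example_targets = common_utils.onehot(
#       jnp.zeros((2, 3), dtype=jnp.int32), 5)  # one-hot float targets
#   loss, aux_z_loss = cross_entropy_with_logits(
#       example_logits, example_targets, z_loss=1e-4)
#   # Both outputs have shape (2, 3), i.e. [batch, length].
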
def _cross_entropy_with_logits_fwd(
logits: jnp.ndarray,
targets: jnp.ndarray,
z_loss: float = 0.0
) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray],
           Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray,
                 jnp.ndarray, jnp.ndarray, jnp.ndarray]]:
"""Forward-mode of `cross_entropy_with_logits`."""
max_logit = logits.max(axis=-1, keepdims=True)
shifted = logits - max_logit
exp_shifted = jnp.exp(shifted)
sum_exp = jnp.sum(exp_shifted, axis=-1, keepdims=True)
log_softmax = shifted - jnp.log(sum_exp)
loss = -jnp.sum(targets * log_softmax, axis=-1)
  # Add auxiliary z-loss term.
log_z = jnp.squeeze(jnp.log(sum_exp) + max_logit, axis=-1)
total_z_loss = z_loss * jax.lax.square(log_z)
loss += total_z_loss
return (loss, total_z_loss), (logits, targets, z_loss, exp_shifted, sum_exp,
log_softmax, log_z)
def _cross_entropy_with_logits_bwd(
res: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray,
jnp.ndarray, jnp.ndarray], g: Tuple[jnp.ndarray, jnp.ndarray]
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Backward-mode of `cross_entropy_with_logits`."""
g = g[0] # Ignore z_loss component as that is only used for logging.
logits, targets, z_loss, exp_shifted, sum_exp, log_softmax, log_z = res
# z-loss term adds the (2 * z_loss * log_z) factor.
deriv = (
jnp.expand_dims(1 + 2 * z_loss * log_z, -1) * exp_shifted / sum_exp -
targets)
g_logits = jnp.expand_dims(g, axis=-1) * deriv
g_targets = -jnp.expand_dims(g, axis=-1) * log_softmax
return (jnp.asarray(g_logits,
logits.dtype), jnp.asarray(g_targets, targets.dtype),
jnp.array(0.0)) # sets z-loss coeff gradient to 0
cross_entropy_with_logits.defvjp(_cross_entropy_with_logits_fwd,
_cross_entropy_with_logits_bwd)
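
# Illustrative gradient check (a sketch, not part of the original file; kept as
# a comment to avoid import-time work). With the VJP registered above, jax.grad
# uses the hand-written backward rule; the cotangent of the z_loss output is
# ignored there, since that output is intended only for logging:
#
#   dummy_logits = jnp.zeros((2, 3, 5))
#   dummy_targets = common_utils.onehot(jnp.zeros((2, 3), jnp.int32), 5)
#
#   def scalar_loss(lg):
#     loss, _ = cross_entropy_with_logits(lg, dummy_targets, z_loss=1e-4)
#     return jnp.sum(loss)
#
#   logit_grads = jax.grad(scalar_loss)(dummy_logits)  # shape (2, 3, 5)
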
def compute_weighted_cross_entropy(
logits: jnp.ndarray,
targets: jnp.ndarray,
weights: Optional[jnp.ndarray] = None,
label_smoothing: float = 0.0,
z_loss: float = 0.0,
loss_normalizing_factor: Optional[float] = None
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Compute weighted cross entropy and entropy for log probs and targets.
Args:
logits: [batch, length, num_classes] float array.
targets: categorical targets [batch, length] int array.
weights: None or array of shape [batch, length].
label_smoothing: label smoothing constant, used to determine the on and off
values.
    z_loss: coefficient for the auxiliary z-loss term.
loss_normalizing_factor: Constant to divide loss by. If not specified, loss
will not be normalized. Intended for backward compatibility with T5-MTF
training. Should not normally be used.
Returns:
Tuple of scalar loss, z_loss, and weight sum.
"""
if logits.ndim != targets.ndim + 1:
raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' %
(str(logits.shape), str(targets.shape)))
vocab_size = logits.shape[-1]
confidence = 1.0 - label_smoothing
low_confidence = (1.0 - confidence) / (vocab_size - 1)
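  # `normalizing_constant` below is the entropy of the smoothed target
  # distribution; subtracting it from the cross entropy further down makes a
  # model that exactly predicts the smoothed targets incur zero loss.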
normalizing_constant = -(
confidence * jnp.log(confidence) +
(vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20))
soft_targets = common_utils.onehot(
targets, vocab_size, on_value=confidence, off_value=low_confidence)
total_loss, total_z_loss = cross_entropy_with_logits(
logits, soft_targets, z_loss=z_loss)
total_loss = total_loss - normalizing_constant
weight_sum = np.prod(targets.shape)
if weights is not None:
total_loss = total_loss * weights
total_z_loss = total_z_loss * weights
weight_sum = jnp.sum(weights)
# By default, we do not normalize loss based on anything.
# We don't normalize based on batch size because the optimizers we use are
# pretty much scale invariant, so this simplifies things.
# We don't normalize based on number of non-padding tokens in order to treat
# each token as equally important regardless of sequence length.
if loss_normalizing_factor is not None:
total_loss /= loss_normalizing_factor
total_z_loss /= loss_normalizing_factor
return jnp.sum(total_loss), jnp.sum(total_z_loss), weight_sum
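
# Example call (a sketch; the shapes, smoothing, and z_loss values are
# illustrative assumptions, kept as a comment to avoid import-time work):
#
#   logits = jnp.zeros((4, 8, 32))               # [batch, length, num_classes]
#   targets = jnp.ones((4, 8), dtype=jnp.int32)  # integer class ids
#   weights = (targets > 0).astype(jnp.float32)  # mask out padding id 0
#   loss, z_loss, weight_sum = compute_weighted_cross_entropy(
#       logits, targets, weights=weights, label_smoothing=0.1, z_loss=1e-4)
#   # loss and z_loss are scalars; weight_sum == jnp.sum(weights) == 32.
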
@enum.unique
class SpecialLossNormalizingFactor(enum.Enum):
"""Specially calcualted loss_normalizing_factors, that are not a constant.
Attributes:
    NUM_REAL_TARGET_TOKENS: Divide the loss by the number of real
      (non-padding) tokens in the current target batch. If
      'decoder_loss_weights' are specified, this is the sum of the weights;
      otherwise it is the number of non-zero 'decoder_target_tokens'.
    NUM_TOTAL_TARGET_TOKENS: Divide the loss by the total number of
      target tokens, i.e., batch_size * target_seq_length (including padding).
AVERAGE_PER_SEQUENCE: This will first compute the per-sequence loss
(averaged over the number of real target tokens in the sequence), and then
compute the average of that over the sequences. This can be preferable to
NUM_REAL_TARGET_TOKENS for finetuning, because it will weigh all examples
equally, regardless of sequence length (which can be especially important
for multi-task finetuning).
"""
NUM_REAL_TARGET_TOKENS = 1
NUM_TOTAL_TARGET_TOKENS = 2
AVERAGE_PER_SEQUENCE = 3
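
# Worked illustration of the three options above (informal): for a batch of two
# sequences with 2 and 4 non-padding target tokens and a padded length of 5,
#   NUM_REAL_TARGET_TOKENS  -> divide the summed loss by 2 + 4 = 6,
#   NUM_TOTAL_TARGET_TOKENS -> divide by batch_size * length = 2 * 5 = 10,
#   AVERAGE_PER_SEQUENCE    -> rescale each sequence's weights to sum to ~1, so
#                              the result is a mean of per-sequence mean losses.
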
def convert_special_loss_normalizing_factor_to_enum(
x: str) -> SpecialLossNormalizingFactor:
"""Converts stringified version of LNF to an enum.
  This is useful because gin dynamic registration does not (currently)
  support enums.
Args:
x: stringified version of SpecialLossNormalizingFactor enum.
Returns:
SpecialLossNormalizingFactor enum instance.
"""
x = x.upper()
if x == 'NUM_REAL_TARGET_TOKENS':
return SpecialLossNormalizingFactor.NUM_REAL_TARGET_TOKENS
if x == 'NUM_TOTAL_TARGET_TOKENS':
return SpecialLossNormalizingFactor.NUM_TOTAL_TARGET_TOKENS
if x == 'AVERAGE_PER_SEQUENCE':
return SpecialLossNormalizingFactor.AVERAGE_PER_SEQUENCE
raise ValueError(
      'Could not convert string "%s" to SpecialLossNormalizingFactor' % x)
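
# Example (illustrative): the conversion is case-insensitive, so a gin config
# can pass the lowercase string form:
#
#   lnf = convert_special_loss_normalizing_factor_to_enum(
#       'num_real_target_tokens')
#   assert lnf is SpecialLossNormalizingFactor.NUM_REAL_TARGET_TOKENS
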
def get_loss_normalizing_factor_and_weights(
loss_normalizing_factor: Optional[Union[float, int, str,
SpecialLossNormalizingFactor]],
batch: Mapping[str, jnp.ndarray]):
"""Get the float loss_normalizing_factor and loss weights.
If loss_normalizing_factor is float or None, this will simply return the
input loss_normalizing_factor and batch.
If loss_normalizing_factor is a SpecialLossNormalizingFactor, it will
return a float loss_normalizing_factor and loss weights corresponding to
the special LNF. See SpecialLossNormalizingFactor for more details.
Args:
loss_normalizing_factor: The input LNF, which may be a float, None, or
SpecialLossNormalizingFactor (or a stringified SLNF).
batch: Input data batch.
Returns:
Tuple of (output_loss_normalizing_factor, loss_weights).
'output_loss_normalizing_factor' is a scalar float (Python float
or jnp float).
'loss_weights' is the per token loss weight JNP array.
"""
loss_weights = batch.get('decoder_loss_weights', None)
if (loss_normalizing_factor is None or
not isinstance(loss_normalizing_factor,
(str, SpecialLossNormalizingFactor))):
return (loss_normalizing_factor, loss_weights)
if isinstance(loss_normalizing_factor, str):
loss_normalizing_factor = convert_special_loss_normalizing_factor_to_enum(
loss_normalizing_factor)
# If `loss_weights` are not provided, we assume that the padding id is 0 and
# that non-padding tokens in the decoder all correspond to the positions
# where loss should be taken. If more fine-grained behavior (e.g., taking
# loss on subset of 'decoder_target_tokens') is desired, provide
# `loss_weights` that account for this.
if loss_weights is None:
loss_weights = jnp.asarray(batch['decoder_target_tokens'] > 0, jnp.float32)
output_normalizing_factor = None
if (loss_normalizing_factor ==
SpecialLossNormalizingFactor.NUM_REAL_TARGET_TOKENS):
output_normalizing_factor = jnp.sum(loss_weights)
elif (loss_normalizing_factor ==
SpecialLossNormalizingFactor.NUM_TOTAL_TARGET_TOKENS):
output_normalizing_factor = np.prod(batch['decoder_target_tokens'].shape)
elif (loss_normalizing_factor ==
SpecialLossNormalizingFactor.AVERAGE_PER_SEQUENCE):
loss_weights /= jnp.sum(loss_weights, axis=-1, keepdims=True) + 1e-3
output_normalizing_factor = jnp.sum(loss_weights)
else:
raise ValueError('Unsupported value of loss_normalizing_factor: %s' %
str(loss_normalizing_factor))
return (output_normalizing_factor, loss_weights)
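
# End-to-end sketch tying the helpers together (illustrative; the batch below
# is an assumption, and in real training it comes from the data pipeline; kept
# as a comment so importing this module does no work):
#
#   batch = {'decoder_target_tokens': jnp.array([[5, 7, 0, 0], [2, 9, 4, 0]])}
#   logits = jnp.zeros((2, 4, 16))
#   lnf, loss_weights = get_loss_normalizing_factor_and_weights(
#       'AVERAGE_PER_SEQUENCE', batch)
#   loss, z_loss, weight_sum = compute_weighted_cross_entropy(
#       logits, batch['decoder_target_tokens'], weights=loss_weights,
#       loss_normalizing_factor=lnf)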