# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Loss functions."""

import enum
from typing import Tuple, Mapping, Optional, Union

from flax.training import common_utils
import jax
import jax.numpy as jnp
import numpy as np


@jax.custom_vjp
def cross_entropy_with_logits(
    logits: jnp.ndarray, targets: jnp.ndarray,
    z_loss: float) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """Computes cross entropy loss with stable custom gradient.

  Computes a stabilized-gradient version of:

    -jnp.sum(targets * nn.log_softmax(logits), axis=-1)

  If z_loss > 0, then an auxiliary loss equal to z_loss*log(z)^2
  will be added to the cross entropy loss (z = softmax normalization constant).
  The two uses of z_loss are:
  1. To keep the logits from drifting too far from zero, which can cause
     unacceptable roundoff errors in bfloat16.
  2. To encourage the logits to be normalized log-probabilities.

  Args:
    logits: [batch, length, num_classes] float array.
    targets: categorical one-hot targets [batch, length, num_classes] float
      array.
    z_loss: coefficient for the auxiliary z-loss term.

  Returns:
    Tuple of the total loss and the z_loss, both float arrays with shape
    [batch, length].
  """
  logits_sum = jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
  log_softmax = logits - logits_sum
  loss = -jnp.sum(targets * log_softmax, axis=-1)
  # Add the auxiliary z-loss term.
  log_z = jnp.squeeze(logits_sum, axis=-1)
  total_z_loss = z_loss * jax.lax.square(log_z)
  loss += total_z_loss
  return loss, total_z_loss


def _cross_entropy_with_logits_fwd(
    logits: jnp.ndarray,
    targets: jnp.ndarray,
    z_loss: float = 0.0
) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray],
           Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray,
                 jnp.ndarray, jnp.ndarray, jnp.ndarray]]:
  """Forward pass of `cross_entropy_with_logits` for the custom VJP."""
  max_logit = logits.max(axis=-1, keepdims=True)
  shifted = logits - max_logit
  exp_shifted = jnp.exp(shifted)
  sum_exp = jnp.sum(exp_shifted, axis=-1, keepdims=True)
  log_softmax = shifted - jnp.log(sum_exp)
  loss = -jnp.sum(targets * log_softmax, axis=-1)
  # Add the auxiliary z-loss term.
  log_z = jnp.squeeze(jnp.log(sum_exp) + max_logit, axis=-1)
  total_z_loss = z_loss * jax.lax.square(log_z)
  loss += total_z_loss
  return (loss, total_z_loss), (logits, targets, z_loss, exp_shifted, sum_exp,
                                log_softmax, log_z)


def _cross_entropy_with_logits_bwd(
    res: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray,
               jnp.ndarray, jnp.ndarray],
    g: Tuple[jnp.ndarray, jnp.ndarray]
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
  """Backward pass of `cross_entropy_with_logits` for the custom VJP."""
  g = g[0]  # Ignore z_loss component as that is only used for logging.
  logits, targets, z_loss, exp_shifted, sum_exp, log_softmax, log_z = res
  # z-loss term adds the (2 * z_loss * log_z) factor.
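  # For reference, with p = softmax(logits) = exp_shifted / sum_exp and
  # targets summing to 1 over the class axis:
  #   d/d(logits) [-sum(targets * log_softmax(logits))] = p - targets
  #   d/d(logits) [z_loss * log_z**2] = 2 * z_loss * log_z * p
  # which combine into the (1 + 2 * z_loss * log_z) * p - targets term below.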
  deriv = (
      jnp.expand_dims(1 + 2 * z_loss * log_z, -1) * exp_shifted / sum_exp -
      targets)
  g_logits = jnp.expand_dims(g, axis=-1) * deriv
  g_targets = -jnp.expand_dims(g, axis=-1) * log_softmax
  return (jnp.asarray(g_logits, logits.dtype),
          jnp.asarray(g_targets, targets.dtype),
          jnp.array(0.0))  # sets z-loss coeff gradient to 0


cross_entropy_with_logits.defvjp(_cross_entropy_with_logits_fwd,
                                 _cross_entropy_with_logits_bwd)


def compute_weighted_cross_entropy(
    logits: jnp.ndarray,
    targets: jnp.ndarray,
    weights: Optional[jnp.ndarray] = None,
    label_smoothing: float = 0.0,
    z_loss: float = 0.0,
    loss_normalizing_factor: Optional[float] = None
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
  """Compute weighted cross entropy and entropy for log probs and targets.

  Args:
    logits: [batch, length, num_classes] float array.
    targets: categorical targets [batch, length] int array.
    weights: None or array of shape [batch, length].
    label_smoothing: label smoothing constant, used to determine the on and off
      values.
    z_loss: coefficient for the auxiliary z-loss term.
    loss_normalizing_factor: Constant to divide loss by. If not specified, loss
      will not be normalized. Intended for backward compatibility with T5-MTF
      training. Should not normally be used.

  Returns:
    Tuple of scalar loss, z_loss, and weight sum.
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' %
                     (str(logits.shape), str(targets.shape)))
  vocab_size = logits.shape[-1]
  confidence = 1.0 - label_smoothing
  low_confidence = (1.0 - confidence) / (vocab_size - 1)
  normalizing_constant = -(
      confidence * jnp.log(confidence) +
      (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20))
  soft_targets = common_utils.onehot(
      targets, vocab_size, on_value=confidence, off_value=low_confidence)
  total_loss, total_z_loss = cross_entropy_with_logits(
      logits, soft_targets, z_loss=z_loss)
  total_loss = total_loss - normalizing_constant

  weight_sum = np.prod(targets.shape)
  if weights is not None:
    total_loss = total_loss * weights
    total_z_loss = total_z_loss * weights
    weight_sum = jnp.sum(weights)

  # By default, we do not normalize loss based on anything.
  # We don't normalize based on batch size because the optimizers we use are
  # pretty much scale invariant, so this simplifies things.
  # We don't normalize based on number of non-padding tokens in order to treat
  # each token as equally important regardless of sequence length.
  if loss_normalizing_factor is not None:
    total_loss /= loss_normalizing_factor
    total_z_loss /= loss_normalizing_factor
  return jnp.sum(total_loss), jnp.sum(total_z_loss), weight_sum


@enum.unique
class SpecialLossNormalizingFactor(enum.Enum):
  """Specially calculated loss_normalizing_factors that are not constants.

  Attributes:
    NUM_REAL_TARGET_TOKENS: Divides the loss by the number of real
      (non-padding) tokens in the current target batch. If
      'decoder_loss_weights' are specified, it will be the sum of the weights.
      Otherwise it will be the number of non-zero 'decoder_target_tokens'.
    NUM_TOTAL_TARGET_TOKENS: Divides the loss by the total number of target
      tokens, i.e., batch_size * target_seq_length (including padding).
    AVERAGE_PER_SEQUENCE: This will first compute the per-sequence loss
      (averaged over the number of real target tokens in the sequence), and
      then compute the average of that over the sequences.
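      For example, with two sequences containing 3 and 7 real target tokens
      (and unit per-token weights), the weights are rescaled to roughly 1/3
      and 1/7, so each sequence contributes equally to the total loss.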
      This can be preferable to NUM_REAL_TARGET_TOKENS for finetuning, because
      it will weigh all examples equally, regardless of sequence length (which
      can be especially important for multi-task finetuning).
  """
  NUM_REAL_TARGET_TOKENS = 1
  NUM_TOTAL_TARGET_TOKENS = 2
  AVERAGE_PER_SEQUENCE = 3


def convert_special_loss_normalizing_factor_to_enum(
    x: str) -> SpecialLossNormalizingFactor:
  """Converts a stringified version of an LNF to an enum.

  This is useful because gin dynamic registration does not (currently) have
  support for enums.

  Args:
    x: stringified version of a SpecialLossNormalizingFactor enum value.

  Returns:
    The corresponding SpecialLossNormalizingFactor enum instance.
  """
  x = x.upper()
  if x == 'NUM_REAL_TARGET_TOKENS':
    return SpecialLossNormalizingFactor.NUM_REAL_TARGET_TOKENS
  if x == 'NUM_TOTAL_TARGET_TOKENS':
    return SpecialLossNormalizingFactor.NUM_TOTAL_TARGET_TOKENS
  if x == 'AVERAGE_PER_SEQUENCE':
    return SpecialLossNormalizingFactor.AVERAGE_PER_SEQUENCE
  raise ValueError(
      'Could not convert string "%s" to SpecialLossNormalizingFactor' % x)


def get_loss_normalizing_factor_and_weights(
    loss_normalizing_factor: Optional[Union[float, int, str,
                                            SpecialLossNormalizingFactor]],
    batch: Mapping[str, jnp.ndarray]):
  """Gets the float loss_normalizing_factor and loss weights.

  If loss_normalizing_factor is a float or None, this simply returns the input
  loss_normalizing_factor and the 'decoder_loss_weights' from the batch (or
  None if they are not present).

  If loss_normalizing_factor is a SpecialLossNormalizingFactor, it returns a
  float loss_normalizing_factor and loss weights corresponding to the special
  LNF. See SpecialLossNormalizingFactor for more details.

  Args:
    loss_normalizing_factor: The input LNF, which may be a float, None, or a
      SpecialLossNormalizingFactor (or its stringified name).
    batch: Input data batch.

  Returns:
    Tuple of (output_loss_normalizing_factor, loss_weights).
    'output_loss_normalizing_factor' is a scalar float (Python float or jnp
    float). 'loss_weights' is the per-token loss weight JNP array.
  """
  loss_weights = batch.get('decoder_loss_weights', None)
  if (loss_normalizing_factor is None or
      not isinstance(loss_normalizing_factor,
                     (str, SpecialLossNormalizingFactor))):
    return (loss_normalizing_factor, loss_weights)

  if isinstance(loss_normalizing_factor, str):
    loss_normalizing_factor = convert_special_loss_normalizing_factor_to_enum(
        loss_normalizing_factor)

  # If `loss_weights` are not provided, we assume that the padding id is 0 and
  # that non-padding tokens in the decoder all correspond to the positions
  # where loss should be taken. If more fine-grained behavior (e.g., taking
  # loss on a subset of 'decoder_target_tokens') is desired, provide
  # `loss_weights` that account for this.
  if loss_weights is None:
    loss_weights = jnp.asarray(batch['decoder_target_tokens'] > 0, jnp.float32)

  output_normalizing_factor = None
  if (loss_normalizing_factor ==
      SpecialLossNormalizingFactor.NUM_REAL_TARGET_TOKENS):
    output_normalizing_factor = jnp.sum(loss_weights)
  elif (loss_normalizing_factor ==
        SpecialLossNormalizingFactor.NUM_TOTAL_TARGET_TOKENS):
    output_normalizing_factor = np.prod(batch['decoder_target_tokens'].shape)
  elif (loss_normalizing_factor ==
        SpecialLossNormalizingFactor.AVERAGE_PER_SEQUENCE):
    # Rescale each sequence's weights to sum to ~1, then normalize by the
    # (approximate) number of sequences.
    loss_weights /= jnp.sum(loss_weights, axis=-1, keepdims=True) + 1e-3
    output_normalizing_factor = jnp.sum(loss_weights)
  else:
    raise ValueError('Unsupported value of loss_normalizing_factor: %s' %
                     str(loss_normalizing_factor))

  return (output_normalizing_factor, loss_weights)
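

# ----------------------------------------------------------------------------
# Illustrative usage sketch: wires `get_loss_normalizing_factor_and_weights`
# into `compute_weighted_cross_entropy` for a tiny assumed batch. The batch
# contents, shapes, and hyperparameters below are arbitrary assumptions chosen
# only for demonstration; they are not part of the module's API.
if __name__ == '__main__':
  # Two sequences of length 4; token id 0 is treated as padding.
  example_batch = {
      'decoder_target_tokens':
          jnp.array([[3, 9, 1, 0], [7, 5, 0, 0]], dtype=jnp.int32),
  }
  example_logits = jnp.zeros((2, 4, 16), dtype=jnp.float32)

  # Resolve the special LNF into a concrete normalizing factor and per-token
  # loss weights (padding positions get weight 0).
  lnf, weights = get_loss_normalizing_factor_and_weights(
      SpecialLossNormalizingFactor.AVERAGE_PER_SEQUENCE, example_batch)

  loss, z_loss, weight_sum = compute_weighted_cross_entropy(
      example_logits,
      targets=example_batch['decoder_target_tokens'],
      weights=weights,
      label_smoothing=0.1,
      z_loss=1e-4,
      loss_normalizing_factor=lnf)
  print('loss:', loss, 'z_loss:', z_loss, 'weight_sum:', weight_sum)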