# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# Cross entropy helper that uses apex's fused CUDA kernel when available and
# falls back to a pure PyTorch implementation otherwise.

import logging

import torch
import torch.nn.functional as F


logger = logging.getLogger(__name__)


def _cross_entropy_pytorch(logits, target, ignore_index=-100, reduction="mean"):
    # Pure PyTorch fallback: compute log-probabilities in fp32 for numerical
    # stability, then apply negative log-likelihood loss.
    # Note: F.nll_loss requires an integer ignore_index (its default is -100),
    # so the default here matches the public cross_entropy() below.
    lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    return F.nll_loss(
        lprobs,
        target,
        ignore_index=ignore_index,
        reduction=reduction,
    )


try:
    # xentropy_cuda is the compiled CUDA extension backing apex's fused
    # softmax + cross entropy kernel; importing it up front verifies that the
    # extension is actually built before committing to the fused code path.
    import xentropy_cuda
    from apex.contrib import xentropy

    def cross_entropy(logits, target, ignore_index=-100, reduction="mean"):
        if logits.device == torch.device("cpu"):
            # The fused kernel is CUDA-only, so CPU tensors use the fallback.
            return _cross_entropy_pytorch(logits, target, ignore_index, reduction)
        else:
            if not getattr(cross_entropy, "_has_logged_once", False):
                logger.info("using fused cross entropy")
                cross_entropy._has_logged_once = True

            # When logits are fp16, the apex kernel can return losses in fp32.
            half_to_float = logits.dtype == torch.half
            # Arguments: logits, labels, label smoothing (0.0 = none),
            # padding index, half_to_float. The kernel returns one loss per
            # target element, so the reduction is applied manually below.
            losses = xentropy.SoftmaxCrossEntropyLoss.apply(
                logits,
                target,
                0.0,
                ignore_index,
                half_to_float,
            )
            if reduction == "sum":
                return losses.sum()
            elif reduction == "mean":
                if ignore_index >= 0:
                    # Average only over non-ignored target positions.
                    return losses.sum() / target.ne(ignore_index).sum()
                else:
                    return losses.mean()
            elif reduction == "none":
                return losses
            else:
                raise NotImplementedError

except ImportError:

    def cross_entropy(logits, target, ignore_index=-100, reduction="mean"):
        # apex is not installed: always use the PyTorch implementation.
        return _cross_entropy_pytorch(logits, target, ignore_index, reduction)
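

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of calling cross_entropy(); the shapes, vocabulary size,
# and the -100 padding value below are assumptions chosen for demonstration.
# Whether the fused apex kernel or the PyTorch fallback runs depends on which
# imports succeeded above and on the device holding `logits`.
if __name__ == "__main__":
    batch, vocab = 8, 100
    logits = torch.randn(batch, vocab)
    target = torch.randint(0, vocab, (batch,))
    # Mark one position as padding so it is excluded from the mean.
    target[0] = -100
    loss = cross_entropy(logits, target, ignore_index=-100, reduction="mean")
    print(loss.item())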