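# NOTE: the lines below are residue from a web file-viewer page, not source code.
# Author: JustinLin610 | commit message: "update" | commit (apparently): 8437114
# Viewer labels: raw / history / blame | scan status: "No virus" | file size: 1.87 kB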
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import torch
import torch.nn.functional as F
# Logger for this module; the fused cross-entropy path reports through it.
logger = logging.getLogger(name=__name__)
def _cross_entropy_pytorch(logits, target, ignore_index=None, reduction="mean"):
    """Reference cross entropy built from ``log_softmax`` + ``nll_loss``.

    Args:
        logits: unnormalized scores; the softmax is taken over the last
            dimension. NOTE(review): ``F.nll_loss`` reads classes from dim 1,
            so this presumably expects 2-D ``(N, C)`` (or 1-D ``(C,)``)
            logits -- confirm against callers.
        target: integer class indices accepted by ``F.nll_loss``.
        ignore_index: target value that contributes no loss and is excluded
            from the ``"mean"`` denominator. ``None`` (the default) means
            ``nll_loss``'s own default of ``-100``.
        reduction: ``"none"``, ``"mean"`` or ``"sum"``, forwarded to
            ``F.nll_loss``.

    Returns:
        The float32 loss tensor produced by ``F.nll_loss``.
    """
    if ignore_index is None:
        # F.nll_loss requires an int here and raises TypeError on None;
        # -100 is its documented default.
        ignore_index = -100
    # dtype=float32 casts the input before the softmax, so half-precision
    # logits are normalized in full precision.
    lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    return F.nll_loss(
        lprobs,
        target,
        ignore_index=ignore_index,
        reduction=reduction,
    )
try:
    # Availability probe: the fused path needs apex's compiled
    # ``xentropy_cuda`` extension. If either import fails, the ``except``
    # branch below defines a pure-PyTorch ``cross_entropy`` instead.
    import xentropy_cuda  # noqa: F401
    from apex.contrib import xentropy

    def cross_entropy(logits, target, ignore_index=-100, reduction="mean"):
        """Cross entropy using apex's fused kernel for CUDA tensors.

        Non-CUDA tensors are delegated to ``_cross_entropy_pytorch``.

        Args:
            logits: unnormalized scores, classes on the last dimension
                (presumably ``(N, C)`` -- TODO confirm for the fused kernel).
            target: integer class indices.
            ignore_index: target value that contributes no loss and is
                excluded from the ``"mean"`` denominator, as in
                ``F.nll_loss``.
            reduction: ``"none"``, ``"mean"`` or ``"sum"``.

        Returns:
            Per-element losses for ``"none"``, otherwise a scalar tensor.

        Raises:
            NotImplementedError: on the CUDA path, if ``reduction`` is not
                one of the supported values.
        """
        # The fused kernel is CUDA-only; CPU and any other device type use
        # the reference implementation.
        if not logits.is_cuda:
            return _cross_entropy_pytorch(logits, target, ignore_index, reduction)

        # Validate before launching the kernel rather than after.
        if reduction not in ("none", "mean", "sum"):
            raise NotImplementedError(
                f"unsupported reduction for fused cross entropy: {reduction!r}"
            )

        if not getattr(cross_entropy, "_has_logged_once", False):
            logger.info("using fused cross entropy")
            cross_entropy._has_logged_once = True

        # Request float32 losses when the logits are fp16.
        half_to_float = logits.dtype == torch.half
        losses = xentropy.SoftmaxCrossEntropyLoss.apply(
            logits,
            target,
            0.0,  # label smoothing per apex's signature; 0.0 disables it
            ignore_index,  # apex's ``padding_idx``
            half_to_float,
        )
        if reduction == "none":
            return losses
        if reduction == "sum":
            return losses.sum()
        # "mean": like F.nll_loss, average over non-ignored targets only, for
        # every ignore_index including the negative default. This relies on
        # the kernel returning zero loss at positions where
        # target == ignore_index.
        # NOTE(review): confirm the fused kernel tolerates negative labels.
        return losses.sum() / target.ne(ignore_index).sum()

except ImportError:

    def cross_entropy(logits, target, ignore_index=-100, reduction="mean"):
        """Cross entropy via the pure-PyTorch reference (apex unavailable)."""
        return _cross_entropy_pytorch(logits, target, ignore_index, reduction)