Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import logging | |
| import torch | |
| import torch.nn.functional as F | |
| logger = logging.getLogger(__name__) | |
| def _cross_entropy_pytorch(logits, target, ignore_index=None, reduction="mean"): | |
| lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) | |
| return F.nll_loss( | |
| lprobs, | |
| target, | |
| ignore_index=ignore_index, | |
| reduction=reduction, | |
| ) | |
| try: | |
| import xentropy_cuda | |
| from apex.contrib import xentropy | |
| def cross_entropy(logits, target, ignore_index=-100, reduction="mean"): | |
| if logits.device == torch.device("cpu"): | |
| return _cross_entropy_pytorch(logits, target, ignore_index, reduction) | |
| else: | |
| if not getattr(cross_entropy, "_has_logged_once", False): | |
| logger.info("using fused cross entropy") | |
| cross_entropy._has_logged_once = True | |
| half_to_float = logits.dtype == torch.half | |
| losses = xentropy.SoftmaxCrossEntropyLoss.apply( | |
| logits, | |
| target, | |
| 0.0, | |
| ignore_index, | |
| half_to_float, | |
| ) | |
| if reduction == "sum": | |
| return losses.sum() | |
| elif reduction == "mean": | |
| if ignore_index >= 0: | |
| return losses.sum() / target.ne(ignore_index).sum() | |
| else: | |
| return losses.mean() | |
| elif reduction == "none": | |
| return losses | |
| else: | |
| raise NotImplementedError | |
| except ImportError: | |
| def cross_entropy(logits, target, ignore_index=-100, reduction="mean"): | |
| return _cross_entropy_pytorch(logits, target, ignore_index, reduction) | |