mshukor
init
3eb682b
raw
history blame
542 Bytes
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""Loss functions."""
import torch.nn as nn
# Registry mapping a loss name to its torch.nn loss *class* (not an instance);
# callers look entries up via get_loss_func and instantiate them themselves.
_LOSSES = dict(
    cross_entropy=nn.CrossEntropyLoss,
    bce=nn.BCELoss,
    bce_logit=nn.BCEWithLogitsLoss,
)
def get_loss_func(loss_name):
    """
    Retrieve the loss class registered under the given name.

    Args:
        loss_name (str): the name of the loss to use; must be a key of
            ``_LOSSES``.
    Returns:
        The loss class (not an instance) registered under ``loss_name``;
        the caller is responsible for instantiating it.
    Raises:
        NotImplementedError: if ``loss_name`` is not a registered loss.
    """
    try:
        return _LOSSES[loss_name]
    except KeyError:
        # Suppress the KeyError context so callers see only the
        # NotImplementedError describing the unsupported name.
        raise NotImplementedError(
            "Loss {} is not supported".format(loss_name)
        ) from None