# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

"""Loss functions."""

import torch.nn as nn

# Registry mapping a loss name to the torch.nn loss *class* (not an instance);
# callers instantiate the returned class themselves.
_LOSSES = {
    "cross_entropy": nn.CrossEntropyLoss,
    "bce": nn.BCELoss,
    "bce_logit": nn.BCEWithLogitsLoss,
}


def get_loss_func(loss_name):
    """
    Retrieve the loss class registered under the given loss name.

    Args:
        loss_name (str): the name of the loss to use. Must be one of the
            keys of `_LOSSES`: "cross_entropy", "bce" or "bce_logit".

    Returns:
        The loss class registered under `loss_name`. The class itself is
        returned, not an instance of it.

    Raises:
        NotImplementedError: if `loss_name` is not a key of `_LOSSES`.
    """
    # Unknown names are reported as NotImplementedError rather than the
    # KeyError a plain dictionary lookup would raise.
    if loss_name not in _LOSSES:
        raise NotImplementedError("Loss {} is not supported".format(loss_name))
    return _LOSSES[loss_name]