KyanChen's picture
Upload 303 files
4d0eb62
raw
history blame
3.69 kB
# Copyright (c) OpenMMLab. All rights reserved.
import functools
import torch
import torch.nn.functional as F
def reduce_loss(loss, reduction):
"""Reduce loss as specified.
Args:
loss (Tensor): Elementwise loss tensor.
reduction (str): Options are "none", "mean" and "sum".
Return:
Tensor: Reduced loss tensor.
"""
reduction_enum = F._Reduction.get_enum(reduction)
# none: 0, elementwise_mean:1, sum: 2
if reduction_enum == 0:
return loss
elif reduction_enum == 1:
return loss.mean()
elif reduction_enum == 2:
return loss.sum()
def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
"""Apply element-wise weight and reduce loss.
Args:
loss (Tensor): Element-wise loss.
weight (Tensor): Element-wise weights.
reduction (str): Same as built-in losses of PyTorch.
avg_factor (float): Average factor when computing the mean of losses.
Returns:
Tensor: Processed loss values.
"""
# if weight is specified, apply element-wise weight
if weight is not None:
loss = loss * weight
# if avg_factor is not specified, just reduce the loss
if avg_factor is None:
loss = reduce_loss(loss, reduction)
else:
# if reduction is mean, then average the loss by avg_factor
if reduction == 'mean':
loss = loss.sum() / avg_factor
# if reduction is 'none', then do nothing, otherwise raise an error
elif reduction != 'none':
raise ValueError('avg_factor can not be used with reduction="sum"')
return loss
def weighted_loss(loss_func):
"""Create a weighted version of a given loss function.
To use this decorator, the loss function must have the signature like
``loss_func(pred, target, **kwargs)``. The function only needs to compute
element-wise loss without any reduction. This decorator will add weight
and reduction arguments to the function. The decorated function will have
the signature like ``loss_func(pred, target, weight=None, reduction='mean',
avg_factor=None, **kwargs)``.
:Example:
>>> import torch
>>> @weighted_loss
>>> def l1_loss(pred, target):
>>> return (pred - target).abs()
>>> pred = torch.Tensor([0, 2, 3])
>>> target = torch.Tensor([1, 1, 1])
>>> weight = torch.Tensor([1, 0, 1])
>>> l1_loss(pred, target)
tensor(1.3333)
>>> l1_loss(pred, target, weight)
tensor(1.)
>>> l1_loss(pred, target, reduction='none')
tensor([1., 1., 2.])
>>> l1_loss(pred, target, weight, avg_factor=2)
tensor(1.5000)
"""
@functools.wraps(loss_func)
def wrapper(pred,
target,
weight=None,
reduction='mean',
avg_factor=None,
**kwargs):
# get element-wise loss
loss = loss_func(pred, target, **kwargs)
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
return wrapper
def convert_to_one_hot(targets: torch.Tensor, classes) -> torch.Tensor:
"""This function converts target class indices to one-hot vectors, given
the number of classes.
Args:
targets (Tensor): The ground truth label of the prediction
with shape (N, 1)
classes (int): the number of classes.
Returns:
Tensor: Processed loss values.
"""
assert (torch.max(targets).item() <
classes), 'Class Index must be less than number of classes'
one_hot_targets = F.one_hot(
targets.long().squeeze(-1), num_classes=classes)
return one_hot_targets