Spaces:
Build error
Build error
# Copyright (c) OpenMMLab. All rights reserved. | |
import pytest | |
import torch | |
from mmcv.utils import digit_version | |
from mmdet.models.losses import (BalancedL1Loss, CrossEntropyLoss, DiceLoss, | |
DistributionFocalLoss, FocalLoss, | |
GaussianFocalLoss, | |
KnowledgeDistillationKLDivLoss, L1Loss, | |
MSELoss, QualityFocalLoss, SeesawLoss, | |
SmoothL1Loss, VarifocalLoss) | |
from mmdet.models.losses.ghm_loss import GHMC, GHMR | |
from mmdet.models.losses.iou_loss import (BoundedIoULoss, CIoULoss, DIoULoss, | |
GIoULoss, IoULoss) | |
def test_iou_type_loss_zeros_weight(loss_class):
    """Check that an all-zero weight vector produces a loss equal to zero.

    Args:
        loss_class: Loss class under test. It is instantiated without
            arguments and called as ``loss(pred, target, weight)``.
            NOTE(review): presumably supplied by a parametrize decorator
            that is not visible in this excerpt -- confirm.
    """
    num_samples = 10
    sample_shape = (num_samples, 4)

    pred = torch.rand(sample_shape)
    target = torch.rand(sample_shape)
    zero_weight = torch.zeros(num_samples)

    criterion = loss_class()
    loss = criterion(pred, target, zero_weight)
    assert loss == 0.
def test_loss_with_reduction_override(loss_class):
    """Check that an unsupported ``reduction_override`` value is rejected.

    Args:
        loss_class: Loss class under test. It is instantiated without
            arguments and called as
            ``loss(pred, target, weight, reduction_override=...)``.

    The call is expected to raise ``AssertionError``.
    """
    pred = torch.rand((10, 4))
    # ``target`` is a plain tensor with the same shape as ``pred``
    # (not a tuple wrapping one).
    target = torch.rand((10, 4))
    weight = None

    with pytest.raises(AssertionError):
        # reduction_override must be one of [None, 'none', 'mean', 'sum'];
        # any other value (here ``True``) is not allowed
        reduction_override = True
        loss_class()(
            pred, target, weight, reduction_override=reduction_override)
def test_regression_losses(loss_class, input_shape):
    """Exercise the forward pass of a regression loss in several call modes.

    Covers: plain call, call with weight, ``reduction_override='mean'``,
    ``avg_factor`` alone, the rejected ``avg_factor`` + ``'sum'``
    combination, and ``avg_factor`` with every accepted reduction.

    Args:
        loss_class: Loss class under test; instantiated without arguments.
        input_shape: Shape passed to ``torch.rand`` for ``pred``,
            ``target`` and ``weight``.
    """
    pred = torch.rand(input_shape)
    target = torch.rand(input_shape)
    weight = torch.rand(input_shape)

    # Test loss forward
    loss = loss_class()(pred, target)
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with weight
    loss = loss_class()(pred, target, weight)
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with reduction_override
    loss = loss_class()(pred, target, reduction_override='mean')
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with avg_factor
    loss = loss_class()(pred, target, avg_factor=10)
    assert isinstance(loss, torch.Tensor)

    with pytest.raises(ValueError):
        # loss can evaluate with avg_factor only if
        # reduction is None, 'none' or 'mean'.
        reduction_override = 'sum'
        loss_class()(
            pred, target, avg_factor=10, reduction_override=reduction_override)

    # Test loss forward with avg_factor and reduction
    for reduction_override in [None, 'none', 'mean']:
        # Bind the result so the assertion checks this call's output rather
        # than a value left over from an earlier call.
        loss = loss_class()(
            pred, target, avg_factor=10, reduction_override=reduction_override)
        assert isinstance(loss, torch.Tensor)
def test_classification_losses(loss_class, input_shape):
    """Exercise the forward pass of a classification loss in several modes.

    The test is skipped when the first dimension of ``input_shape`` is 0
    and the installed PyTorch is older than 1.5.0.

    Args:
        loss_class: Loss class under test; instantiated without arguments.
        input_shape: Shape passed to ``torch.rand`` for ``pred``. Its first
            entry is also the length of the integer ``target`` vector.
    """
    if input_shape[0] == 0 and digit_version(
            torch.__version__) < digit_version('1.5.0'):
        pytest.skip(
            f'CELoss in PyTorch {torch.__version__} does not support empty '
            'tensor.')

    pred = torch.rand(input_shape)
    # Integer labels drawn from [0, 5)
    target = torch.randint(0, 5, (input_shape[0], ))

    # Test loss forward
    loss = loss_class()(pred, target)
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with reduction_override
    loss = loss_class()(pred, target, reduction_override='mean')
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with avg_factor
    loss = loss_class()(pred, target, avg_factor=10)
    assert isinstance(loss, torch.Tensor)

    with pytest.raises(ValueError):
        # loss can evaluate with avg_factor only if
        # reduction is None, 'none' or 'mean'.
        reduction_override = 'sum'
        loss_class()(
            pred, target, avg_factor=10, reduction_override=reduction_override)

    # Test loss forward with avg_factor and reduction
    for reduction_override in [None, 'none', 'mean']:
        # Bind the result so the assertion checks this call's output rather
        # than a value left over from an earlier call.
        loss = loss_class()(
            pred, target, avg_factor=10, reduction_override=reduction_override)
        assert isinstance(loss, torch.Tensor)
def test_GHMR_loss(loss_class, input_shape):
    """Run a weighted forward pass and check that a tensor comes back.

    Args:
        loss_class: Loss class under test; instantiated without arguments
            and called as ``loss(pred, target, weight)``.
        input_shape: Shape passed to ``torch.rand`` for all three inputs.
    """
    # Drawn in the order pred, target, weight.
    pred = torch.rand(input_shape)
    target = torch.rand(input_shape)
    weight = torch.rand(input_shape)

    # Test loss forward
    criterion = loss_class()
    result = criterion(pred, target, weight)
    assert isinstance(result, torch.Tensor)
def test_loss_with_ignore_index(use_sigmoid, reduction, avg_non_ignore):
    """Check ``CrossEntropyLoss`` handling of targets equal to 255.

    Args:
        use_sigmoid: Forwarded to ``CrossEntropyLoss(use_sigmoid=...)``.
        reduction: Passed as ``reduction_override`` on every loss call.
        avg_non_ignore: Forwarded to
            ``CrossEntropyLoss(avg_non_ignore=...)``. When truthy, the
            reference loss is computed on inputs with the ignored entries
            removed by hand.
    """
    ignore_label = 255

    # Test cross_entropy loss
    criterion = CrossEntropyLoss(
        use_sigmoid=use_sigmoid,
        use_mask=False,
        ignore_index=ignore_label,
        avg_non_ignore=avg_non_ignore)

    pred = torch.rand((10, 5))
    target = torch.randint(0, 5, (10, ))
    ignored_positions = torch.randint(0, 10, (2, ), dtype=torch.long)
    target[ignored_positions] = ignore_label

    # Test loss forward with default ignore
    loss_with_ignore = criterion(pred, target, reduction_override=reduction)
    assert isinstance(loss_with_ignore, torch.Tensor)

    # Test loss forward with forward ignore
    # (the same positions are written again with the same value)
    target[ignored_positions] = ignore_label
    loss_with_forward_ignore = criterion(
        pred,
        target,
        ignore_index=ignore_label,
        reduction_override=reduction)
    assert isinstance(loss_with_forward_ignore, torch.Tensor)

    # Verify correctness
    if avg_non_ignore:
        # manually remove the ignored elements
        keep_mask = (target != ignore_label)
        pred, target = pred[keep_mask], target[keep_mask]
    reference = criterion(pred, target, reduction_override=reduction)
    for observed in (loss_with_ignore, loss_with_forward_ignore):
        assert torch.allclose(reference, observed)

    # test ignore all target
    pred = torch.rand((10, 5))
    target = torch.ones((10, ), dtype=torch.long) * ignore_label
    all_ignored_loss = criterion(pred, target, reduction_override=reduction)
    assert all_ignored_loss == 0
def test_dice_loss(naive_dice):
    """Exercise ``DiceLoss`` forward in several call modes and error paths.

    Covers: plain call, call with weight, ``reduction_override='mean'``,
    ``avg_factor`` alone, the rejected ``avg_factor`` + ``'sum'``
    combination, ``avg_factor`` with every accepted reduction, the
    unsupported ``use_sigmoid=False`` + ``activate=True`` configuration,
    and two mismatched weight shapes.

    Args:
        naive_dice: Forwarded to ``DiceLoss(naive_dice=...)`` on every
            construction.
    """
    loss_class = DiceLoss
    pred = torch.rand((10, 4, 4))
    target = torch.rand((10, 4, 4))
    weight = torch.rand((10))

    # Test loss forward
    loss = loss_class(naive_dice=naive_dice)(pred, target)
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with weight
    loss = loss_class(naive_dice=naive_dice)(pred, target, weight)
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with reduction_override
    loss = loss_class(naive_dice=naive_dice)(
        pred, target, reduction_override='mean')
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with avg_factor
    loss = loss_class(naive_dice=naive_dice)(pred, target, avg_factor=10)
    assert isinstance(loss, torch.Tensor)

    with pytest.raises(ValueError):
        # loss can evaluate with avg_factor only if
        # reduction is None, 'none' or 'mean'.
        reduction_override = 'sum'
        loss_class(naive_dice=naive_dice)(
            pred, target, avg_factor=10, reduction_override=reduction_override)

    # Test loss forward with avg_factor and reduction
    for reduction_override in [None, 'none', 'mean']:
        # Bind the result so the assertion checks this call's output rather
        # than a value left over from an earlier call.
        loss = loss_class(naive_dice=naive_dice)(
            pred, target, avg_factor=10, reduction_override=reduction_override)
        assert isinstance(loss, torch.Tensor)

    # Test loss forward with activate=True and use_sigmoid=False
    with pytest.raises(NotImplementedError):
        loss_class(
            use_sigmoid=False, activate=True, naive_dice=naive_dice)(pred,
                                                                     target)

    # Test loss forward with weight.ndim != loss.ndim
    # (weight is built outside the ``raises`` body so only the loss call
    # can satisfy the expected exception)
    weight = torch.rand((2, 8))
    with pytest.raises(AssertionError):
        loss_class(naive_dice=naive_dice)(pred, target, weight)

    # Test loss forward with len(weight) != len(pred)
    weight = torch.rand((8))
    with pytest.raises(AssertionError):
        loss_class(naive_dice=naive_dice)(pred, target, weight)