File size: 1,372 Bytes
6c9ac8f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import Mock
import torch
from mmdet.engine.hooks import CheckInvalidLossHook
class TestCheckInvalidLossHook(TestCase):
    """Unit tests for ``CheckInvalidLossHook``."""

    def test_after_train_iter(self):
        """Check that the hook only validates the loss on its check iteration.

        The hook is built with an interval of ``n``. At ``runner.iter == 10``
        (not a check iteration) finite, NaN and inf losses must all pass
        without raising. At ``runner.iter == n - 1`` a finite loss must pass,
        while NaN and inf losses must raise ``AssertionError``.
        """
        n = 50
        hook = CheckInvalidLossHook(n)
        runner = Mock()
        # NOTE(review): the logger is mocked explicitly, presumably because
        # the hook reports invalid losses through ``runner.logger.info`` --
        # confirm against the hook implementation.
        runner.logger = Mock()
        runner.logger.info = Mock()

        finite_loss = torch.LongTensor([2])
        invalid_losses = (
            torch.tensor(float('nan')),
            torch.tensor(float('inf')),
        )

        # Within the interval: the loss is not inspected, so even
        # non-finite losses must not raise.
        runner.iter = 10
        hook.after_train_iter(runner, 10, outputs=dict(loss=finite_loss))
        for loss in invalid_losses:
            with self.subTest(runner_iter=runner.iter, loss=loss.item()):
                hook.after_train_iter(runner, 10, outputs=dict(loss=loss))

        # At iteration n - 1 the loss is inspected: a finite loss passes,
        # NaN and inf must trigger an AssertionError.
        runner.iter = n - 1
        hook.after_train_iter(runner, n - 1, outputs=dict(loss=finite_loss))
        for loss in invalid_losses:
            with self.subTest(runner_iter=runner.iter, loss=loss.item()):
                with self.assertRaises(AssertionError):
                    hook.after_train_iter(
                        runner, n - 1, outputs=dict(loss=loss))
|