File size: 1,372 Bytes
6c9ac8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import Mock

import torch

from mmdet.engine.hooks import CheckInvalidLossHook


class TestCheckInvalidLossHook(TestCase):
    """Unit tests for ``CheckInvalidLossHook``."""

    def test_after_train_iter(self):
        """Exercise ``after_train_iter`` on and off the check iteration.

        Expectations encoded here:

        * at an iteration that is not a check point (``runner.iter == 10``
          with ``n == 50``), finite, NaN and inf losses are all accepted;
        * at ``runner.iter == n - 1`` (presumably the n-th iteration, if
          ``runner.iter`` is 0-based), a finite loss is accepted while NaN
          and inf losses raise ``AssertionError``.
        """
        n = 50
        hook = CheckInvalidLossHook(n)

        # Explicit logger mocks so the hook can report through runner.logger.
        logger = Mock()
        logger.info = Mock()
        runner = Mock()
        runner.logger = logger

        # Phase 1: an iteration inside the window -- nothing may raise.
        off_check_iter = 10
        runner.iter = off_check_iter
        for loss in (torch.LongTensor([2]), torch.tensor(float('nan')),
                     torch.tensor(float('inf'))):
            hook.after_train_iter(
                runner, off_check_iter, outputs=dict(loss=loss))

        # Phase 2: the check iteration -- only the finite loss may pass.
        check_iter = n - 1
        runner.iter = check_iter
        hook.after_train_iter(
            runner, check_iter, outputs=dict(loss=torch.LongTensor([2])))
        for loss in (torch.tensor(float('nan')), torch.tensor(float('inf'))):
            with self.assertRaises(AssertionError):
                hook.after_train_iter(
                    runner, check_iter, outputs=dict(loss=loss))