File size: 783 Bytes
3133fdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import unittest

import torch
from pytorchvideo.layers import DropPath


class TestDropPath(unittest.TestCase):
    """Tests for the ``DropPath`` layer imported from ``pytorchvideo.layers``."""

    def setUp(self):
        super().setUp()
        # Seed the global RNG so the random inputs (and any randomness DropPath
        # may use) are reproducible across runs. ``torch.manual_seed`` already
        # seeds the default generator, so re-applying its state via
        # ``torch.set_rng_state`` is unnecessary.
        torch.manual_seed(42)

    def test_dropPath(self):
        """Check DropPath output for ``drop_prob == 0`` and ``drop_prob > 0``."""
        # With drop_prob = 0 the output should be identical to the input.
        net_drop_path = DropPath(drop_prob=0.0)
        fake_input = torch.rand(64, 10, 20)
        output = net_drop_path(fake_input)
        self.assertTrue(torch.equal(output, fake_input))

        # With drop_prob > 0 the values may change, but the shape must not.
        # NOTE: assertTrue(a, b) would treat ``b`` as the failure message and
        # pass for any non-empty shape, so the shapes are compared explicitly.
        net_drop_path = DropPath(drop_prob=0.5)
        fake_input = torch.rand(64, 10, 20)
        output = net_drop_path(fake_input)
        self.assertEqual(output.shape, fake_input.shape)