File size: 783 Bytes
3133fdb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import unittest
import torch
from pytorchvideo.layers import DropPath
class TestDropPath(unittest.TestCase):
    """Unit tests for ``pytorchvideo.layers.DropPath``."""

    def setUp(self):
        super().setUp()
        # Seed the default torch RNG so the random inputs generated in the
        # tests are reproducible across runs. (The original
        # set_rng_state(manual_seed(42).get_state()) round-trip only restored
        # the state manual_seed had just set, so this is equivalent.)
        torch.manual_seed(42)

    def test_dropPath(self):
        # With drop_prob = 0 the layer must return its input unchanged.
        net_drop_path = DropPath(drop_prob=0.0)
        fake_input = torch.rand(64, 10, 20)
        output = net_drop_path(fake_input)
        self.assertTrue(output.equal(fake_input))

        # With drop_prob > 0 the output must keep the input's shape.
        net_drop_path = DropPath(drop_prob=0.5)
        fake_input = torch.rand(64, 10, 20)
        output = net_drop_path(fake_input)
        # Use assertEqual: assertTrue(a, b) treats b as the failure message and
        # only checks that a is truthy, so the shape was never compared.
        self.assertEqual(output.shape, fake_input.shape)
|