# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
import unittest | |
import torch | |
from pytorchvideo.layers import DropPath | |
class TestDropPath(unittest.TestCase):
    """Unit tests for the ``DropPath`` layer from ``pytorchvideo.layers``."""

    def setUp(self):
        super().setUp()
        # Seed the default torch RNG so the random inputs (and any random
        # sampling DropPath may do) are reproducible between runs.
        # torch.manual_seed already seeds the default generator; reading its
        # state back and re-setting it via torch.set_rng_state is a no-op.
        torch.manual_seed(42)

    def test_dropPath(self):
        """Check DropPath is an identity at p=0 and shape-preserving at p>0."""
        # With drop_prob = 0 the output should equal the input exactly.
        net_drop_path = DropPath(drop_prob=0.0)
        fake_input = torch.rand(64, 10, 20)
        output = net_drop_path(fake_input)
        self.assertTrue(output.equal(fake_input))

        # With drop_prob > 0 the output must keep the input's shape.
        # NOTE(review): the module is used in its default mode; confirm this
        # exercises the stochastic drop path rather than an eval-mode bypass.
        net_drop_path = DropPath(drop_prob=0.5)
        fake_input = torch.rand(64, 10, 20)
        output = net_drop_path(fake_input)
        # assertEqual, not assertTrue: assertTrue(a, b) treats b as the failure
        # message and passes whenever a is truthy, so the shapes were never
        # actually compared.
        self.assertEqual(output.shape, fake_input.shape)