File size: 948 Bytes
3133fdb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import unittest
from typing import Iterator, Tuple

import torch
from pytorchvideo.models.byol import BYOL
from torch import nn
class TestBYOL(unittest.TestCase):
    """Smoke test for the BYOL model wrapper."""

    def setUp(self) -> None:
        super().setUp()
        # Seed the global RNG so the random inputs drawn in ``_get_inputs``
        # are reproducible across runs. ``torch.manual_seed`` already seeds
        # the default generator; reading its state back and re-applying it
        # with ``torch.set_rng_state`` would be a no-op.
        torch.manual_seed(42)

    def test_byol(self) -> None:
        """Run a BYOL forward pass on each pair of random crops.

        This only checks that the forward call completes without raising;
        it makes no assertion about the returned value.
        """
        byol = BYOL(
            backbone=nn.Linear(8, 4),
            projector=nn.Linear(4, 4),
            feature_dim=4,
            norm=nn.BatchNorm1d,
        )
        for crop1, crop2 in TestBYOL._get_inputs():
            byol(crop1, crop2)

    @staticmethod
    def _get_inputs() -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Provide pairs of random tensors as test cases.

        Yields:
            (Tuple[torch.Tensor, torch.Tensor]): two independently sampled
            uniform random tensors of the same shape, passed to the model
            as the two crops.
        """
        # The trailing dimension 8 matches the backbone ``nn.Linear(8, 4)``
        # used in ``test_byol``. A batch of 2 (rather than 1) is presumably
        # chosen so BatchNorm1d in training mode sees more than one sample.
        shapes = ((2, 8),)
        for shape in shapes:
            yield torch.rand(shape), torch.rand(shape)
|