|
|
|
|
|
import unittest |
|
|
|
import hydra |
|
from hydra.experimental import compose, initialize_config_module |
|
from pytorchvideo_trainer.module.byol import BYOLModule |
|
from pytorchvideo_trainer.module.moco_v2 import MOCOV2Module |
|
from pytorchvideo_trainer.module.simclr import SimCLRModule |
|
from pytorchvideo_trainer.module.video_classification import VideoClassificationModule |
|
|
|
|
|
class TestVideoClassificationModuleConf(unittest.TestCase):
    """Smoke test: the video-classification Hydra config yields a usable module."""

    # Python package that Hydra searches for the config tree.
    _CONFIG_MODULE = "pytorchvideo_trainer.conf"
    # Top-level config composed for this test.
    _CONFIG_NAME = "video_classification_train_app_conf"

    def test_init_with_hydra(self) -> None:
        """Compose the config with the slow_r50 model override and instantiate `module`."""
        with initialize_config_module(config_module=self._CONFIG_MODULE):
            cfg = compose(
                config_name=self._CONFIG_NAME,
                overrides=["module/model=slow_r50"],
            )
            # _recursive_=False passes nested sub-configs through un-instantiated;
            # presumably the module builds its own children — confirm in the module.
            lightning_module = hydra.utils.instantiate(cfg.module, _recursive_=False)
            self.assertIsInstance(lightning_module, VideoClassificationModule)
            self.assertIsNotNone(lightning_module.model)
|
|
|
|
|
class TestVideoSimCLRModuleConf(unittest.TestCase):
    """Smoke test: the SimCLR Hydra config yields a SimCLRModule with a model."""

    # Python package that Hydra searches for the config tree.
    _CONFIG_MODULE = "pytorchvideo_trainer.conf"
    # Top-level config composed for this test (no overrides applied).
    _CONFIG_NAME = "simclr_train_app_conf"

    def test_init_with_hydra(self) -> None:
        """Compose the SimCLR train-app config and instantiate its `module` node."""
        with initialize_config_module(config_module=self._CONFIG_MODULE):
            cfg = compose(config_name=self._CONFIG_NAME)
            # Nested configs are handed to the module as-is (not instantiated).
            ssl_module = hydra.utils.instantiate(cfg.module, _recursive_=False)
            self.assertIsInstance(ssl_module, SimCLRModule)
            self.assertIsNotNone(ssl_module.model)
|
|
|
|
|
class TestVideoBYOLModuleConf(unittest.TestCase):
    """Smoke test: the BYOL Hydra config yields a BYOLModule with a model."""

    # Python package that Hydra searches for the config tree.
    _CONFIG_MODULE = "pytorchvideo_trainer.conf"
    # Top-level config composed for this test (no overrides applied).
    _CONFIG_NAME = "byol_train_app_conf"

    def test_init_with_hydra(self) -> None:
        """Compose the BYOL train-app config and instantiate its `module` node."""
        with initialize_config_module(config_module=self._CONFIG_MODULE):
            cfg = compose(config_name=self._CONFIG_NAME)
            # Nested configs are handed to the module as-is (not instantiated).
            ssl_module = hydra.utils.instantiate(cfg.module, _recursive_=False)
            self.assertIsInstance(ssl_module, BYOLModule)
            self.assertIsNotNone(ssl_module.model)
|
|
|
|
|
class TestVideoMOCOV2ModuleConf(unittest.TestCase):
    """Smoke test: the MoCo v2 Hydra config yields a MOCOV2Module with a model."""

    # Python package that Hydra searches for the config tree.
    _CONFIG_MODULE = "pytorchvideo_trainer.conf"
    # Top-level config composed for this test (no overrides applied).
    _CONFIG_NAME = "moco_v2_train_app_conf"

    def test_init_with_hydra(self) -> None:
        """Compose the MoCo v2 train-app config and instantiate its `module` node."""
        with initialize_config_module(config_module=self._CONFIG_MODULE):
            cfg = compose(config_name=self._CONFIG_NAME)
            # Nested configs are handed to the module as-is (not instantiated).
            ssl_module = hydra.utils.instantiate(cfg.module, _recursive_=False)
            self.assertIsInstance(ssl_module, MOCOV2Module)
            self.assertIsNotNone(ssl_module.model)
|
|