File size: 2,600 Bytes
3133fdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

# pyre-strict
from torchrecipes.core.base_train_app import BaseTrainApp
from util import (
    BaseTrainAppTestCase,
    create_small_kinetics_dataset,
    run_locally,
    tempdir,
)


class TestBYOLTrainApp(BaseTrainAppTestCase):
    """Integration tests for the BYOL train app built from ``byol_train_app_conf``.

    Each test builds the app against a small Kinetics-style dataset written to a
    temporary directory by ``create_small_kinetics_dataset``.
    """

    def get_train_app(
        self,
        root_dir: str,
        fast_dev_run: bool = True,
        logger: bool = False,
    ) -> BaseTrainApp:
        """Create a BYOL train app wired to a small dataset under ``root_dir``.

        Args:
            root_dir: Directory in which the small dataset is created. Its
                ``train.csv`` / ``val.csv`` files and the directory itself (as the
                video path prefix) are passed to the dataloaders as Hydra overrides.
            fast_dev_run: Passed to ``mock_trainer_params`` as the trainer's
                ``fast_dev_run`` value.
            logger: Passed to ``mock_trainer_params`` as the trainer's ``logger``
                value.

        Returns:
            The app created from ``pytorchvideo_trainer.conf`` /
            ``byol_train_app_conf`` with the overrides below applied.
        """
        create_small_kinetics_dataset(root_dir)
        overrides = [
            # The test split reuses val.csv; no separate test.csv is referenced.
            f"datamodule.dataloader.train.dataset.data_path={root_dir}/train.csv",
            f"datamodule.dataloader.val.dataset.data_path={root_dir}/val.csv",
            f"datamodule.dataloader.test.dataset.data_path={root_dir}/val.csv",
            f"datamodule.dataloader.train.dataset.video_path_prefix={root_dir}",
            f"datamodule.dataloader.val.dataset.video_path_prefix={root_dir}",
            f"datamodule.dataloader.test.dataset.video_path_prefix={root_dir}",
            # In-process loading, a small kNN memory and batch size 2 --
            # presumably to keep the test run small; confirm against the config.
            "datamodule.dataloader.train.num_workers=0",
            "datamodule.dataloader.val.num_workers=0",
            "datamodule.dataloader.test.num_workers=0",
            "module.knn_memory.length=50",
            "module.knn_memory.knn_k=2",
            "datamodule.dataloader.train.batch_size=2",
            "datamodule.dataloader.val.batch_size=2",
            "datamodule.dataloader.test.batch_size=2",
            # Disabled in the Hydra config regardless of ``logger``; the argument
            # is passed separately to mock_trainer_params below.
            "trainer.logger=false",
        ]
        app = self.create_app_from_hydra(
            config_module="pytorchvideo_trainer.conf",
            config_name="byol_train_app_conf",
            overrides=overrides,
        )
        trainer_overrides = {"fast_dev_run": fast_dev_run, "logger": logger}
        self.mock_trainer_params(app, trainer_overrides)
        return app

    @run_locally
    @tempdir
    def test_byol_app_train_test_30_views(self, root_dir: str) -> None:
        """Train, then test the app, and check the module's clip counts.

        Every value in the module's ``video_clips_cnts`` mapping must equal
        ``num_ensemble_views * num_spatial_crops``. That is 30 when the
        datamodule does not expose those attributes and the fallbacks 10 and 3
        are used. The mapping presumably holds per-video clip counts collected
        during ``test()`` -- confirm against the module.
        """
        train_app = self.get_train_app(
            root_dir=root_dir, fast_dev_run=False, logger=False
        )
        self.assertIsNotNone(train_app.train())
        self.assertIsNotNone(train_app.test())

        video_clips_cnts = getattr(train_app.module, "video_clips_cnts", None)
        # Fallbacks give 10 ensemble views x 3 spatial crops = 30 views, which
        # matches the test name, when the datamodule lacks these attributes.
        num_ensemble_views = getattr(train_app.datamodule, "num_ensemble_views", 10)
        num_spatial_crops = getattr(train_app.datamodule, "num_spatial_crops", 3)
        expected_views = num_ensemble_views * num_spatial_crops

        self.assertIsNotNone(video_clips_cnts)
        # Without this, an empty mapping would make the loop below vacuous and
        # the test would pass without checking any counts.
        self.assertTrue(video_clips_cnts, "no per-video clip counts were recorded")
        for video_name, sample_cnts in video_clips_cnts.items():
            # subTest reports which entry has the wrong count.
            with self.subTest(video=video_name):
                self.assertEqual(expected_views, sample_cnts)