# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import unittest
from typing import Callable

import torch
from fvcore.common.benchmark import benchmark
from pytorchvideo.data.utils import thwc_to_cthw
from pytorchvideo.transforms.functional import short_side_scale
from utils import create_dummy_video_frames


class TestBenchmarkTransforms(unittest.TestCase):
    def setUp(self):
        super().setUp()
        # Seed the default RNG so the generated dummy frames are reproducible.
        torch.manual_seed(42)

    def test_benchmark_short_side_scale_pytorch(self, num_iters: int = 10) -> None:
        """
        Benchmark scale operation with pytorch backend.
        Args:
            num_iters (int): number of iterations to perform benchmarking.
        """
        kwargs_list = [
            {"temporal_size": 8, "ori_spatial_size": (128, 128), "dst_short_size": 112},
            {
                "temporal_size": 16,
                "ori_spatial_size": (128, 128),
                "dst_short_size": 112,
            },
            {
                "temporal_size": 32,
                "ori_spatial_size": (128, 128),
                "dst_short_size": 112,
            },
            {"temporal_size": 8, "ori_spatial_size": (256, 256), "dst_short_size": 224},
            {
                "temporal_size": 16,
                "ori_spatial_size": (256, 256),
                "dst_short_size": 224,
            },
            {
                "temporal_size": 32,
                "ori_spatial_size": (256, 256),
                "dst_short_size": 224,
            },
            {"temporal_size": 8, "ori_spatial_size": (320, 320), "dst_short_size": 224},
            {
                "temporal_size": 16,
                "ori_spatial_size": (320, 320),
                "dst_short_size": 224,
            },
            {
                "temporal_size": 32,
                "ori_spatial_size": (320, 320),
                "dst_short_size": 224,
            },
        ]

        def _init_benchmark_short_side_scale(**kwargs) -> Callable:
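            # Build a dummy clip once per configuration and convert it from THWC
            # to the (C, T, H, W) float32 layout that short_side_scale expects.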
            x = thwc_to_cthw(
                create_dummy_video_frames(
                    kwargs["temporal_size"],
                    kwargs["ori_spatial_size"][0],
                    kwargs["ori_spatial_size"][1],
                )
            ).to(dtype=torch.float32)

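            # The closure captures x, so only the scaling op itself is timed.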
            def func_to_benchmark() -> None:
                _ = short_side_scale(x, kwargs["dst_short_size"])
                return

            return func_to_benchmark

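        # fvcore's benchmark() calls the factory once per entry in kwargs_list,
        # runs warmup_iters warmup passes plus num_iters timed passes for each,
        # and prints a timing summary to stdout.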
        benchmark(
            _init_benchmark_short_side_scale,
            "benchmark_short_side_scale_pytorch",
            kwargs_list,
            num_iters=num_iters,
            warmup_iters=2,
        )
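        # The benchmark output is informational; the assertion just marks success.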
        self.assertTrue(True)