File size: 2,727 Bytes
3133fdb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import unittest
from typing import Callable
import torch
from fvcore.common.benchmark import benchmark
from pytorchvideo.data.utils import thwc_to_cthw
from pytorchvideo.transforms.functional import short_side_scale
from utils import create_dummy_video_frames
class TestBenchmarkTransforms(unittest.TestCase):
    """Micro-benchmarks for pytorchvideo functional transforms."""

    def setUp(self):
        super().setUp()
        # Seed the default generator, then load that generator's state back
        # as the global RNG state so each run starts from the same stream.
        seeded_generator = torch.manual_seed(42)
        torch.set_rng_state(seeded_generator.get_state())

    def test_benchmark_short_side_scale_pytorch(self, num_iters: int = 10) -> None:
        """
        Benchmark ``short_side_scale`` (PyTorch backend) on float32 dummy clips.

        Every (input spatial size, target short side) setting is run with
        8, 16 and 32 frames, giving nine benchmark configurations.

        Args:
            num_iters (int): number of iterations to perform benchmarking.
        """
        # (original spatial size, destination short-side size) pairs.
        spatial_settings = (
            ((128, 128), 112),
            ((256, 256), 224),
            ((320, 320), 224),
        )
        frame_counts = (8, 16, 32)
        # Spatial setting is the outer loop and frame count the inner one, so
        # the resulting order matches the hand-written list this replaces.
        kwargs_list = [
            {
                "temporal_size": frames,
                "ori_spatial_size": spatial_size,
                "dst_short_size": short_size,
            }
            for spatial_size, short_size in spatial_settings
            for frames in frame_counts
        ]

        def _init_benchmark_short_side_scale(**kwargs) -> Callable:
            # Build the input clip once per configuration; only the scale
            # call inside the returned closure is timed.
            spatial = kwargs["ori_spatial_size"]
            raw_frames = create_dummy_video_frames(
                kwargs["temporal_size"], spatial[0], spatial[1]
            )
            clip = thwc_to_cthw(raw_frames).to(dtype=torch.float32)

            def func_to_benchmark() -> None:
                short_side_scale(clip, kwargs["dst_short_size"])

            return func_to_benchmark

        benchmark(
            _init_benchmark_short_side_scale,
            "benchmark_short_side_scale_pytorch",
            kwargs_list,
            num_iters=num_iters,
            warmup_iters=2,
        )
        # The benchmark produces no value to check; this assertion only marks
        # the test as having completed.
        self.assertTrue(True)
|