|
|
|
|
|
import unittest |
|
from typing import Callable |
|
|
|
import torch |
|
from fvcore.common.benchmark import benchmark |
|
from pytorchvideo.data.utils import thwc_to_cthw |
|
from pytorchvideo.transforms.functional import short_side_scale |
|
from utils import create_dummy_video_frames |
|
|
|
|
|
class TestBenchmarkTransforms(unittest.TestCase): |
|
def setUp(self): |
|
super().setUp() |
|
torch.set_rng_state(torch.manual_seed(42).get_state()) |
|
|
|
def test_benchmark_short_side_scale_pytorch(self, num_iters: int = 10) -> None: |
|
""" |
|
Benchmark scale operation with pytorch backend. |
|
Args: |
|
num_iters (int): number of iterations to perform benchmarking. |
|
""" |
|
kwargs_list = [ |
|
{"temporal_size": 8, "ori_spatial_size": (128, 128), "dst_short_size": 112}, |
|
{ |
|
"temporal_size": 16, |
|
"ori_spatial_size": (128, 128), |
|
"dst_short_size": 112, |
|
}, |
|
{ |
|
"temporal_size": 32, |
|
"ori_spatial_size": (128, 128), |
|
"dst_short_size": 112, |
|
}, |
|
{"temporal_size": 8, "ori_spatial_size": (256, 256), "dst_short_size": 224}, |
|
{ |
|
"temporal_size": 16, |
|
"ori_spatial_size": (256, 256), |
|
"dst_short_size": 224, |
|
}, |
|
{ |
|
"temporal_size": 32, |
|
"ori_spatial_size": (256, 256), |
|
"dst_short_size": 224, |
|
}, |
|
{"temporal_size": 8, "ori_spatial_size": (320, 320), "dst_short_size": 224}, |
|
{ |
|
"temporal_size": 16, |
|
"ori_spatial_size": (320, 320), |
|
"dst_short_size": 224, |
|
}, |
|
{ |
|
"temporal_size": 32, |
|
"ori_spatial_size": (320, 320), |
|
"dst_short_size": 224, |
|
}, |
|
] |
|
|
|
def _init_benchmark_short_side_scale(**kwargs) -> Callable: |
|
x = thwc_to_cthw( |
|
create_dummy_video_frames( |
|
kwargs["temporal_size"], |
|
kwargs["ori_spatial_size"][0], |
|
kwargs["ori_spatial_size"][1], |
|
) |
|
).to(dtype=torch.float32) |
|
|
|
def func_to_benchmark() -> None: |
|
_ = short_side_scale(x, kwargs["dst_short_size"]) |
|
return |
|
|
|
return func_to_benchmark |
|
|
|
benchmark( |
|
_init_benchmark_short_side_scale, |
|
"benchmark_short_side_scale_pytorch", |
|
kwargs_list, |
|
num_iters=num_iters, |
|
warmup_iters=2, |
|
) |
|
self.assertTrue(True) |
|
|