pandagpt-vicuna-v0-7b / code /pytorchvideo /tests /benchmark_transforms.py
mvsoom's picture
Upload folder using huggingface_hub
3133fdb
raw
history blame contribute delete
2.73 kB
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import unittest
from typing import Callable
import torch
from fvcore.common.benchmark import benchmark
from pytorchvideo.data.utils import thwc_to_cthw
from pytorchvideo.transforms.functional import short_side_scale
from utils import create_dummy_video_frames
class TestBenchmarkTransforms(unittest.TestCase):
    """Benchmarks for transforms in ``pytorchvideo.transforms.functional``."""

    def setUp(self):
        super().setUp()
        # torch.manual_seed already seeds the default generator; reading its state
        # back and passing it to torch.set_rng_state would just re-apply the same state.
        torch.manual_seed(42)

    def test_benchmark_short_side_scale_pytorch(self, num_iters: int = 10) -> None:
        """
        Benchmark scale operation with pytorch backend.

        For every configuration a float32 dummy clip is built once, the result of
        ``short_side_scale`` is sanity-checked once (untimed), and then the call
        is timed with fvcore's ``benchmark``.

        Args:
            num_iters (int): number of iterations to perform benchmarking.
        """
        # Each (input side, target short side) pair is crossed with every clip
        # length. The spatial config is the outer loop and the temporal size is the
        # inner loop, matching the original hand-written ordering.
        kwargs_list = [
            {
                "temporal_size": temporal_size,
                "ori_spatial_size": (side, side),
                "dst_short_size": dst_short_size,
            }
            for side, dst_short_size in ((128, 112), (256, 224), (320, 224))
            for temporal_size in (8, 16, 32)
        ]

        def _init_benchmark_short_side_scale(**kwargs) -> Callable:
            # Build the input once per configuration, outside the timed closure.
            x = thwc_to_cthw(
                create_dummy_video_frames(
                    kwargs["temporal_size"],
                    kwargs["ori_spatial_size"][0],
                    kwargs["ori_spatial_size"][1],
                )
            ).to(dtype=torch.float32)
            dst_short_size = kwargs["dst_short_size"]

            # Untimed sanity check so the test asserts more than "did not raise":
            # the shorter spatial side of the output must equal the requested size.
            # Assumes the last two output dims are (H, W), as implied by the
            # thwc -> cthw conversion of the input; confirm if layouts change.
            scaled = short_side_scale(x, dst_short_size)
            self.assertEqual(min(scaled.shape[-2:]), dst_short_size)

            def func_to_benchmark() -> None:
                short_side_scale(x, dst_short_size)

            return func_to_benchmark

        benchmark(
            _init_benchmark_short_side_scale,
            "benchmark_short_side_scale_pytorch",
            kwargs_list,
            num_iters=num_iters,
            warmup_iters=2,
        )