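# Source-page metadata captured with this file (not code): author mshukor,
# commit "init" (3eb682b), file size 3.3 kB; page labels: raw, history blame,
# No virus.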
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Functions for benchmarks.
"""
import numpy as np
import pprint
import torch
import tqdm
from fvcore.common.timer import Timer
import timesformer.utils.logging as logging
import timesformer.utils.misc as misc
from timesformer.datasets import loader
from timesformer.utils.env import setup_environment
# Module-level logger, obtained from the project's logging utilities
# (timesformer.utils.logging) and keyed by this module's name.
logger = logging.get_logger(__name__)
def benchmark_data_loading(cfg):
    """
    Benchmark the speed of data loading for the "train" split.

    Builds the training data loader from ``cfg`` and iterates over it for
    ``cfg.BENCHMARK.NUM_EPOCHS`` epochs without running any model. It logs:
    the loader construction time; the wall-clock time and RAM usage of every
    ``cfg.BENCHMARK.LOG_PERIOD`` iterations; a per-epoch summary; and the
    mean/std epoch time across all epochs.

    Args:
        cfg (CfgNode): configs. Fields read here: RNG_SEED, OUTPUT_DIR,
            NUM_SHARDS, TRAIN.BATCH_SIZE, BENCHMARK.LOG_PERIOD,
            BENCHMARK.NUM_EPOCHS and BENCHMARK.SHUFFLE. Details can be
            found in the project's config defaults module.
    Raises:
        ValueError: if ``cfg.BENCHMARK.LOG_PERIOD`` is less than 1.
    """
    log_period = cfg.BENCHMARK.LOG_PERIOD
    if log_period < 1:
        # The modulo check in the iteration loop needs a positive period; a
        # period of 0 would raise ZeroDivisionError mid-benchmark.
        raise ValueError(
            "cfg.BENCHMARK.LOG_PERIOD must be >= 1, got {}.".format(log_period)
        )

    # Set up environment.
    setup_environment()
    # Set random seed from configs.
    np.random.seed(cfg.RNG_SEED)
    torch.manual_seed(cfg.RNG_SEED)
    # Setup logging format.
    logging.setup_logging(cfg.OUTPUT_DIR)
    # Print config.
    logger.info("Benchmark data loading with config:")
    logger.info(pprint.pformat(cfg))

    timer = Timer()
    dataloader = loader.construct_loader(cfg, "train")
    logger.info(
        "Initialize loader using {:.2f} seconds.".format(timer.seconds())
    )
    # Total batch size across different machines.
    batch_size = cfg.TRAIN.BATCH_SIZE * cfg.NUM_SHARDS
    epoch_times = []
    # Test for a few epochs.
    for cur_epoch in range(cfg.BENCHMARK.NUM_EPOCHS):
        # `timer` measures each logging interval, `timer_epoch` the whole
        # epoch. Both start before the optional shuffle and before the loader
        # iterator is created, so those costs land in the first interval.
        timer = Timer()
        timer_epoch = Timer()
        iter_times = []
        if cfg.BENCHMARK.SHUFFLE:
            loader.shuffle_dataset(dataloader, cur_epoch)
        for cur_iter, _ in enumerate(tqdm.tqdm(dataloader)):
            # `cur_iter + 1` batches have been fetched at this point, so this
            # fires after every `log_period` batches and every interval covers
            # exactly `log_period` batches. (Testing `cur_iter` itself made the
            # first interval span `log_period + 1` batches.)
            if (cur_iter + 1) % log_period == 0:
                iter_times.append(timer.seconds())
                ram_usage, ram_total = misc.cpu_mem_usage()
                logger.info(
                    "Epoch {}: {} iters ({} videos) in {:.2f} seconds. "
                    "RAM Usage: {:.2f}/{:.2f} GB.".format(
                        cur_epoch,
                        log_period,
                        log_period * batch_size,
                        iter_times[-1],
                        ram_usage,
                        ram_total,
                    )
                )
                timer.reset()
        epoch_times.append(timer_epoch.seconds())
        ram_usage, ram_total = misc.cpu_mem_usage()
        logger.info(
            "Epoch {}: in total {} iters ({} videos) in {:.2f} seconds. "
            "RAM Usage: {:.2f}/{:.2f} GB.".format(
                cur_epoch,
                len(dataloader),
                len(dataloader) * batch_size,
                epoch_times[-1],
                ram_usage,
                ram_total,
            )
        )
        if iter_times:
            logger.info(
                "Epoch {}: on average every {} iters ({} videos) take "
                "{:.2f}/{:.2f} (avg/std) seconds.".format(
                    cur_epoch,
                    log_period,
                    log_period * batch_size,
                    np.mean(iter_times),
                    np.std(iter_times),
                )
            )
        else:
            # np.mean/np.std of an empty list yield nan plus a RuntimeWarning,
            # so report the situation explicitly instead.
            logger.info(
                "Epoch {}: fewer than {} iters completed; no per-interval "
                "timings to average.".format(cur_epoch, log_period)
            )
    # Skip the cross-epoch summary when no epoch ran (NUM_EPOCHS == 0), which
    # would otherwise average an empty list.
    if epoch_times:
        logger.info(
            "On average every epoch ({} videos) takes {:.2f}/{:.2f} "
            "(avg/std) seconds.".format(
                len(dataloader) * batch_size,
                np.mean(epoch_times),
                np.std(epoch_times),
            )
        )