mshukor
init
3eb682b
raw
history blame
8.73 kB
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""Helper functions for multigrid training."""
import numpy as np
import timesformer.utils.logging as logging
logger = logging.get_logger(__name__)
class MultigridSchedule(object):
    """
    This class defines multigrid training schedule and update cfg accordingly.
    """

    def init_multigrid(self, cfg):
        """
        Update cfg based on multigrid settings.
        Args:
            cfg (configs): configs that contains training and multigrid specific
                hyperparameters. Details can be seen in
                slowfast/config/defaults.py.
        Returns:
            cfg (configs): the updated cfg.
        """
        self.schedule = None
        # We may modify cfg.TRAIN.BATCH_SIZE, cfg.DATA.NUM_FRAMES, and
        # cfg.DATA.TRAIN_CROP_SIZE during training, so we store their original
        # value in cfg and use them as global variables.
        cfg.MULTIGRID.DEFAULT_B = cfg.TRAIN.BATCH_SIZE
        cfg.MULTIGRID.DEFAULT_T = cfg.DATA.NUM_FRAMES
        cfg.MULTIGRID.DEFAULT_S = cfg.DATA.TRAIN_CROP_SIZE

        if cfg.MULTIGRID.LONG_CYCLE:
            self.schedule = self.get_long_cycle_schedule(cfg)
            # Each schedule entry is (long cycle index, base shape, end epoch);
            # LR step boundaries are the stage end epochs.
            cfg.SOLVER.STEPS = [0] + [s[-1] for s in self.schedule]
            # Fine-tuning phase: move the last LR step to the midpoint of the
            # final stage.
            cfg.SOLVER.STEPS[-1] = (
                cfg.SOLVER.STEPS[-2] + cfg.SOLVER.STEPS[-1]
            ) // 2
            # Per-stage LR: decay by GAMMA per long cycle index, scaled by the
            # stage's relative batch size (s[1][0]).
            cfg.SOLVER.LRS = [
                cfg.SOLVER.GAMMA ** s[0] * s[1][0] for s in self.schedule
            ]
            # Fine-tuning phase: insert an extra LR step repeating the
            # second-to-last LR before the final one.
            cfg.SOLVER.LRS = cfg.SOLVER.LRS[:-1] + [
                cfg.SOLVER.LRS[-2],
                cfg.SOLVER.LRS[-1],
            ]
            cfg.SOLVER.MAX_EPOCH = self.schedule[-1][-1]
        elif cfg.MULTIGRID.SHORT_CYCLE:
            # Short-cycle-only training: stretch the existing step schedule
            # and total epochs by EPOCH_FACTOR.
            cfg.SOLVER.STEPS = [
                int(s * cfg.MULTIGRID.EPOCH_FACTOR) for s in cfg.SOLVER.STEPS
            ]
            cfg.SOLVER.MAX_EPOCH = int(
                cfg.SOLVER.MAX_EPOCH * cfg.MULTIGRID.EPOCH_FACTOR
            )
        return cfg

    def update_long_cycle(self, cfg, cur_epoch):
        """
        Before every epoch, check if long cycle shape should change. If it
        should, update cfg accordingly.
        Args:
            cfg (configs): configs that contains training and multigrid specific
                hyperparameters. Details can be seen in
                slowfast/config/defaults.py.
            cur_epoch (int): current epoch index.
        Returns:
            cfg (configs): the updated cfg.
            changed (bool): do we change long cycle shape at this epoch?
        """
        base_b, base_t, base_s = get_current_long_cycle_shape(
            self.schedule, cur_epoch
        )
        if base_s != cfg.DATA.TRAIN_CROP_SIZE or base_t != cfg.DATA.NUM_FRAMES:
            # Shape changed: adopt the new base shape for this stage.
            cfg.DATA.NUM_FRAMES = base_t
            cfg.DATA.TRAIN_CROP_SIZE = base_s
            cfg.TRAIN.BATCH_SIZE = base_b * cfg.MULTIGRID.DEFAULT_B

            # Pick a BN variant keeping the effective per-normalization batch
            # close to BN_BASE_SIZE.
            bs_factor = (
                float(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS)
                / cfg.MULTIGRID.BN_BASE_SIZE
            )
            if bs_factor < 1:
                # Per-GPU batch smaller than base: sync BN across devices.
                cfg.BN.NORM_TYPE = "sync_batchnorm"
                cfg.BN.NUM_SYNC_DEVICES = int(1.0 / bs_factor)
            elif bs_factor > 1:
                # Per-GPU batch larger than base: split into sub-batches.
                cfg.BN.NORM_TYPE = "sub_batchnorm"
                cfg.BN.NUM_SPLITS = int(bs_factor)
            else:
                cfg.BN.NORM_TYPE = "batchnorm"

            # Fewer frames -> proportionally larger sampling rate, so the
            # temporal span of a clip stays constant.
            cfg.MULTIGRID.LONG_CYCLE_SAMPLING_RATE = cfg.DATA.SAMPLING_RATE * (
                cfg.MULTIGRID.DEFAULT_T // cfg.DATA.NUM_FRAMES
            )
            logger.info("Long cycle updates:")
            logger.info("\tBN.NORM_TYPE: {}".format(cfg.BN.NORM_TYPE))
            if cfg.BN.NORM_TYPE == "sync_batchnorm":
                logger.info(
                    "\tBN.NUM_SYNC_DEVICES: {}".format(cfg.BN.NUM_SYNC_DEVICES)
                )
            elif cfg.BN.NORM_TYPE == "sub_batchnorm":
                logger.info("\tBN.NUM_SPLITS: {}".format(cfg.BN.NUM_SPLITS))
            logger.info("\tTRAIN.BATCH_SIZE: {}".format(cfg.TRAIN.BATCH_SIZE))
            logger.info(
                "\tDATA.NUM_FRAMES x LONG_CYCLE_SAMPLING_RATE: {}x{}".format(
                    cfg.DATA.NUM_FRAMES, cfg.MULTIGRID.LONG_CYCLE_SAMPLING_RATE
                )
            )
            logger.info(
                "\tDATA.TRAIN_CROP_SIZE: {}".format(cfg.DATA.TRAIN_CROP_SIZE)
            )
            return cfg, True
        else:
            return cfg, False

    def get_long_cycle_schedule(self, cfg):
        """
        Based on multigrid hyperparameters, define the schedule of a long cycle.
        Args:
            cfg (configs): configs that contains training and multigrid specific
                hyperparameters. Details can be seen in
                slowfast/config/defaults.py.
        Returns:
            schedule (list): Specifies a list long cycle base shapes and their
                corresponding training epochs.
        """
        steps = cfg.SOLVER.STEPS

        default_size = float(
            cfg.DATA.NUM_FRAMES * cfg.DATA.TRAIN_CROP_SIZE ** 2
        )
        default_iters = steps[-1]

        # Get shapes and average batch size for each long cycle shape.
        avg_bs = []
        all_shapes = []
        for t_factor, s_factor in cfg.MULTIGRID.LONG_CYCLE_FACTORS:
            base_t = int(round(cfg.DATA.NUM_FRAMES * t_factor))
            base_s = int(round(cfg.DATA.TRAIN_CROP_SIZE * s_factor))
            if cfg.MULTIGRID.SHORT_CYCLE:
                # Three shapes per long cycle stage: two short-cycle spatial
                # sizes plus the stage's own base shape (kept last).
                shapes = [
                    [
                        base_t,
                        cfg.MULTIGRID.DEFAULT_S
                        * cfg.MULTIGRID.SHORT_CYCLE_FACTORS[0],
                    ],
                    [
                        base_t,
                        cfg.MULTIGRID.DEFAULT_S
                        * cfg.MULTIGRID.SHORT_CYCLE_FACTORS[1],
                    ],
                    [base_t, base_s],
                ]
            else:
                shapes = [[base_t, base_s]]

            # (T, S) -> (B, T, S): the batch multiplier compensates for the
            # smaller spatio-temporal size so total work stays ~constant.
            shapes = [
                [int(round(default_size / (s[0] * s[1] * s[1]))), s[0], s[1]]
                for s in shapes
            ]
            avg_bs.append(np.mean([s[0] for s in shapes]))
            all_shapes.append(shapes)

        # Get schedule regardless of cfg.MULTIGRID.EPOCH_FACTOR.
        total_iters = 0
        schedule = []
        for step_index in range(len(steps) - 1):
            step_epochs = steps[step_index + 1] - steps[step_index]

            for long_cycle_index, shapes in enumerate(all_shapes):
                cur_epochs = (
                    step_epochs * avg_bs[long_cycle_index] / sum(avg_bs)
                )

                cur_iters = cur_epochs / avg_bs[long_cycle_index]
                total_iters += cur_iters
                schedule.append((step_index, shapes[-1], cur_epochs))

        iter_saving = default_iters / total_iters

        final_step_epochs = cfg.SOLVER.MAX_EPOCH - steps[-1]

        # We define the fine-tuning phase to have the same amount of iteration
        # saving as the rest of the training.
        ft_epochs = final_step_epochs / iter_saving * avg_bs[-1]

        # Bug fix: index the last (full-size) shape with [-1] rather than [2].
        # With SHORT_CYCLE disabled each all_shapes entry holds a single
        # shape, so [2] raised IndexError; [-1] is identical when short
        # cycle is enabled (three shapes, base shape last).
        schedule.append((step_index + 1, all_shapes[-1][-1], ft_epochs))

        # Obtain final schedule given desired cfg.MULTIGRID.EPOCH_FACTOR.
        x = (
            cfg.SOLVER.MAX_EPOCH
            * cfg.MULTIGRID.EPOCH_FACTOR
            / sum(s[-1] for s in schedule)
        )

        final_schedule = []
        total_epochs = 0
        for s in schedule:
            epochs = s[2] * x
            total_epochs += epochs
            # Store cumulative (end) epochs, rounded to integers.
            final_schedule.append((s[0], s[1], int(round(total_epochs))))
        print_schedule(final_schedule)
        return final_schedule
def print_schedule(schedule):
    """
    Log schedule.
    """
    logger.info("Long cycle index\tBase shape\tEpochs")
    # Each entry is a (long cycle index, base shape, epochs) triple.
    for cycle_index, base_shape, epochs in schedule:
        logger.info("{}\t{}\t{}".format(cycle_index, base_shape, epochs))
def get_current_long_cycle_shape(schedule, epoch):
    """
    Given a schedule and epoch index, return the long cycle base shape.
    Args:
        schedule (list): list of (long cycle index, base shape, end epoch)
            entries as produced by MultigridSchedule.get_long_cycle_schedule.
        epoch (int): current epoch index.
    Returns:
        shapes (list): A list describing the base shape in a long cycle:
            [batch size relative to default,
            number of frames, spatial dimension].
    """
    # Base shape of the first stage whose end epoch lies strictly ahead of
    # `epoch`; past the end of the schedule, keep the last stage's shape.
    active = (entry[1] for entry in schedule if epoch < entry[-1])
    return next(active, schedule[-1][1])