import argparse
import os
import tempfile
from typing import Any

from loguru import logger
import torch
from torch import nn
from torch.cuda import amp
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# fmt: off
from virtex.config import Config
from virtex.factories import (
    PretrainingDatasetFactory, PretrainingModelFactory, OptimizerFactory,
    LRSchedulerFactory,
)
from virtex.utils.checkpointing import CheckpointManager
from virtex.utils.common import common_parser, common_setup, cycle
import virtex.utils.distributed as dist
from virtex.utils.timer import Timer

parser = common_parser(
    description="Train a VirTex model (CNN + Transformer) on COCO Captions."
)
group = parser.add_argument_group("Checkpointing and Logging")
group.add_argument(
    "--resume-from", default=None,
    help="Path to a checkpoint to resume training from (if provided)."
)
group.add_argument(
    "--checkpoint-every", type=int, default=2000,
    help="Serialize the model to a checkpoint after this many iterations.",
)
group.add_argument(
    "--log-every", type=int, default=50,
    help="""Log training curves to tensorboard after this many iterations.
    Only the master process logs loss values averaged across processes.""",
)
# fmt: on


def main(_A: argparse.Namespace):

    if _A.num_gpus_per_machine == 0:
        # Set device as CPU if num_gpus_per_machine = 0.
        device: Any = torch.device("cpu")
    else:
        # Get the current device as set for current distributed process.
        # Check `launch` function in `virtex.utils.distributed` module.
        device = torch.cuda.current_device()
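        # NOTE: this is an integer device ordinal; both `Module.to()` and
        # DDP's `device_ids` accept it below.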

    # Create a config object (this will be immutable) and perform common setup
    # such as logging and setting up serialization directory.
    _C = Config(_A.config, _A.config_override)
    common_setup(_C, _A)

    # -------------------------------------------------------------------------
    #   INSTANTIATE DATALOADER, MODEL, OPTIMIZER, SCHEDULER
    # -------------------------------------------------------------------------
    # fmt: off
    train_dataset = PretrainingDatasetFactory.from_config(_C)
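    # NOTE: batch_size=None disables the DataLoader's automatic batching; the
    # dataset returned by the factory is assumed to yield already-collated batches.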
    train_dataloader = DataLoader(
        train_dataset, batch_size=None, shuffle=False,
        num_workers=_A.cpu_workers, pin_memory=True,
    )
    # fmt: on

    model = PretrainingModelFactory.from_config(_C).to(device)
    optimizer = OptimizerFactory.from_config(_C, model.named_parameters())
    scheduler = LRSchedulerFactory.from_config(_C, optimizer)

    # -------------------------------------------------------------------------
    #   BEFORE TRAINING STARTS
    # -------------------------------------------------------------------------

    # Create a gradient scaler for automatic mixed precision.
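    # With enabled=False (when _C.AMP is off) the scaler's scale/step/update
    # calls become no-ops, so the same training loop runs in full precision.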
    scaler = amp.GradScaler(enabled=_C.AMP)

    # Load checkpoint to resume training if specified.
    if _A.resume_from is not None:
        start_iteration = CheckpointManager(
            model=model, optimizer=optimizer, scheduler=scheduler,
        ).load(_A.resume_from)
    else:
        start_iteration = 0

    # Create an iterator from dataloader to sample batches perpetually.
    train_dataloader_iter = cycle(train_dataloader, device, start_iteration)
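    # NOTE: `cycle` receives the device and start iteration, so it presumably
    # moves each batch to `device` and fast-forwards the stream when resuming.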

    # Wrap model in DDP if using more than one process.
    if dist.get_world_size() > 1:
        dist.synchronize()
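        # find_unused_parameters=True lets DDP tolerate parameters that receive
        # no gradient in a given forward pass (at the cost of some overhead).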
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[device], find_unused_parameters=True
        )

    # Keep track of time per iteration and ETA.
    timer = Timer(
        start_from=start_iteration + 1, total_iterations=_C.OPTIM.NUM_ITERATIONS
    )
    # Create tensorboard writer and checkpoint manager (only in master process).
    if dist.is_master_process():
        tensorboard_writer = SummaryWriter(log_dir=_A.serialization_dir)
        tensorboard_writer.add_text("config", f"```\n{_C}\n```")

        checkpoint_manager = CheckpointManager(
            _A.serialization_dir,
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            scaler=scaler,
        )

    # -------------------------------------------------------------------------
    #   TRAINING LOOP
    # -------------------------------------------------------------------------
    for iteration in range(start_iteration + 1, _C.OPTIM.NUM_ITERATIONS + 1):
        timer.tic()
        optimizer.zero_grad()
        batch = next(train_dataloader_iter)

        with amp.autocast(enabled=_C.AMP):
            output_dict = model(batch)
            loss = output_dict["loss"]
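        # Backpropagate through the scaled loss: scaling the loss up before
        # `backward()` keeps small FP16 gradients from underflowing to zero.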
        scaler.scale(loss).backward()

        # First clip norm of gradients, and then perform optimizer step.
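        # Gradients must be un-scaled before clipping so the max-norm threshold
        # applies to the true gradient values; `scaler.step` then skips the
        # optimizer update if any gradients are inf/NaN.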
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), _C.OPTIM.CLIP_GRAD_NORM)
        scaler.step(optimizer)
        scaler.update()

        scheduler.step()
        timer.toc()

        # ---------------------------------------------------------------------
        #   LOGGING
        # ---------------------------------------------------------------------
        if iteration % _A.log_every == 0:
            logger.info(
                f"{timer.stats} [Loss {loss:.3f}] [GPU {dist.gpu_mem_usage()} MB]"
            )
            if dist.is_master_process():
                tensorboard_writer.add_scalars(
                    "train", output_dict["loss_components"], iteration
                )

        if iteration % _A.checkpoint_every == 0 and dist.is_master_process():
            checkpoint_manager.step(iteration)


if __name__ == "__main__":
    _A = parser.parse_args()

    if _A.num_gpus_per_machine == 0:
        main(_A)
    else:
        # This will launch `main` and set the appropriate CUDA device (GPU ID)
        # for each process (accessed at the beginning of `main`).
        dist.launch(
            main,
            num_machines=_A.num_machines,
            num_gpus_per_machine=_A.num_gpus_per_machine,
            machine_rank=_A.machine_rank,
            dist_url=_A.dist_url,
            args=(_A, ),
        )
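
# Example invocation (a sketch only: the script path and flag names are inferred
# from the `_A` attributes used above, and all paths are placeholders):
#
#   python scripts/pretrain_virtex.py \
#       --config configs/pretrain_config.yaml \
#       --num-gpus-per-machine 8 --cpu-workers 4 \
#       --serialization-dir /tmp/virtex_pretraining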