# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import datetime
import logging
import time
import random

import torch
import torch.distributed as dist

from maskrcnn_benchmark.utils.comm import get_world_size, synchronize, broadcast_data
from maskrcnn_benchmark.utils.metric_logger import MetricLogger
from maskrcnn_benchmark.utils.ema import ModelEma


def reduce_loss_dict(loss_dict):
    """
    Reduce the loss dictionary from all processes so that the process with
    rank 0 has the averaged results. Returns a dict with the same fields as
    loss_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return loss_dict
    with torch.no_grad():
        loss_names = []
        all_losses = []
        for k in sorted(loss_dict.keys()):
            loss_names.append(k)
            all_losses.append(loss_dict[k])
        all_losses = torch.stack(all_losses, dim=0)
        dist.reduce(all_losses, dst=0)
        if dist.get_rank() == 0:
            # only the main process accumulates the sum, so only divide by
            # world_size in this case
            all_losses /= world_size
        reduced_losses = {k: v for k, v in zip(loss_names, all_losses)}
    return reduced_losses


def do_train(
    cfg,
    model,
    data_loader,
    optimizer,
    scheduler,
    checkpointer,
    device,
    checkpoint_period,
    arguments,
    rngs=None,
):
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    logger.info("Start training")
    meters = MetricLogger(delimiter="  ")
    max_iter = len(data_loader)
    start_iter = arguments["iteration"]
    model.train()

    # optionally keep an exponential moving average of the model weights
    model_ema = None
    if cfg.SOLVER.MODEL_EMA > 0:
        model_ema = ModelEma(model, decay=cfg.SOLVER.MODEL_EMA)

    start_training_time = time.time()
    end = time.time()
    for iteration, (images, targets, _) in enumerate(data_loader, start_iter):
        # skip batches that contain an image with no annotations
        if any(len(target) < 1 for target in targets):
            logger.error(
                f"Iteration={iteration + 1} || Image Ids used for training {_} || "
                f"targets Length={[len(target) for target in targets]}"
            )
            continue
        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration

        images = images.to(device)
        targets = [target.to(device) for target in targets]

        # synchronize rngs
        if rngs is None:
            if isinstance(model, torch.nn.parallel.DistributedDataParallel):
                mix_nums = model.module.mix_nums
            else:
                mix_nums = model.mix_nums
            rngs = [random.randint(0, mix - 1) for mix in mix_nums]
            rngs = broadcast_data(rngs)

        # freeze all parameters before the forward pass
        for param in model.parameters():
            param.requires_grad = False

        loss_dict = model(images, targets, rngs)
        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss=losses_reduced, **loss_dict_reduced)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        scheduler.step()

        if model_ema is not None:
            model_ema.update(model)
            arguments["model_ema"] = model_ema.state_dict()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 20 == 0 or iteration == max_iter:
            logger.info(
                meters.delimiter.join(
                    [
                        "eta: {eta}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.0f}",
                    ]
                ).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                )
            )
        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
        if iteration == max_iter:
            # at the end of training, save the EMA weights (if enabled) as the final model
            if model_ema is not None:
                model.load_state_dict(model_ema.state_dict())
            checkpointer.save("model_final", **arguments)

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info(
        "Total training time: {} ({:.4f} s / it)".format(
            total_time_str, total_training_time / (max_iter)
        )
    )
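

# ---------------------------------------------------------------------------
# Illustrative sketch only (an assumption added for clarity, not part of the
# training pipeline above): emulate the arithmetic that reduce_loss_dict
# performs for a hypothetical 2-process run, i.e. a sum-reduce of the
# per-process loss tensors followed by division by the world size on rank 0.
# The loss names and values below are made up; run this file directly to see
# the averaged values that rank 0 would hand to MetricLogger.
if __name__ == "__main__":
    fake_loss_dicts = [
        {"loss_classifier": torch.tensor(0.8), "loss_box_reg": torch.tensor(0.4)},  # rank 0
        {"loss_classifier": torch.tensor(0.6), "loss_box_reg": torch.tensor(0.2)},  # rank 1
    ]
    names = sorted(fake_loss_dicts[0].keys())
    stacked = torch.stack(
        [torch.stack([d[k] for k in names]) for d in fake_loss_dicts]
    )  # shape: (world_size, num_losses)
    averaged = dict(zip(names, stacked.sum(dim=0) / len(fake_loss_dicts)))
    print({k: round(v.item(), 4) for k, v in averaged.items()})
    # expected output: {'loss_box_reg': 0.3, 'loss_classifier': 0.7}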