# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import datetime
import logging
import math
import os
import sys
import time

import torch
import torch.distributed as dist
import wandb

from maskrcnn_benchmark.utils.comm import get_world_size, all_gather, is_main_process, broadcast_data, get_rank
from maskrcnn_benchmark.utils.metric_logger import MetricLogger
from maskrcnn_benchmark.utils.ema import ModelEma
from maskrcnn_benchmark.utils.amp import autocast, GradScaler
from maskrcnn_benchmark.data.datasets.evaluation import evaluate
from .inference import inference
from .tsv_saver import TSVResultWriter


def reduce_loss_dict(loss_dict):
    """
    Reduce the loss dictionary from all processes so that the process with
    rank 0 has the averaged results. Returns a dict with the same fields as
    loss_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return loss_dict
    with torch.no_grad():
        loss_names = []
        all_losses = []
        for k in sorted(loss_dict.keys()):
            loss_names.append(k)
            all_losses.append(loss_dict[k])
        all_losses = torch.stack(all_losses, dim=0)
        dist.reduce(all_losses, dst=0)
        if dist.get_rank() == 0:
            # only the main process accumulates the sum, so only it divides
            # by the world size
            all_losses /= world_size
        reduced_losses = {k: v for k, v in zip(loss_names, all_losses)}
    return reduced_losses


def do_train(
    cfg,
    model,
    data_loader,
    optimizer,
    scheduler,
    checkpointer,
    device,
    checkpoint_period,
    arguments,
    val_data_loader=None,
    meters=None,
    use_wandb=False,
):
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    logger.info("Start training")
    # meters = MetricLogger(delimiter="  ")
    max_iter = len(data_loader)
    start_iter = arguments["iteration"]
    model.train()

    model_ema = None
    if cfg.SOLVER.MODEL_EMA > 0:
        model_ema = ModelEma(model, decay=cfg.SOLVER.MODEL_EMA)

    start_training_time = time.time()
    end = time.time()

    if cfg.SOLVER.USE_AMP:
        scaler = GradScaler()

    global_rank = get_rank()

    if cfg.SOLVER.CHECKPOINT_PER_EPOCH != -1 and cfg.SOLVER.MAX_EPOCH >= 1:
        checkpoint_period = len(data_loader) * cfg.SOLVER.CHECKPOINT_PER_EPOCH // cfg.SOLVER.MAX_EPOCH

    if global_rank <= 0 and cfg.SOLVER.MAX_EPOCH >= 1:
        print("Iter per epoch ", len(data_loader) // cfg.SOLVER.MAX_EPOCH)

    if cfg.SOLVER.AUTO_TERMINATE_PATIENCE != -1:
        patience_counter = 0
        previous_best = 0.0

    # Adapt the weight decay: find the first milestone that has not been passed yet
    if cfg.SOLVER.WEIGHT_DECAY_SCHEDULE and hasattr(scheduler, "milestones"):
        milestone_target = 0
        for i, milestone in enumerate(list(scheduler.milestones)):
            if scheduler.last_epoch >= milestone * cfg.SOLVER.WEIGHT_DECAY_SCHEDULE_RATIO:
                milestone_target = i + 1

    # Visualize a sample of the training data; grab the tokenizer to decode captions
    if hasattr(model, "module"):
        tokenizer = model.module.tokenizer
    else:
        tokenizer = model.tokenizer
    tsv_visualizer = TSVResultWriter(
        tokenizer=tokenizer,
        max_visualize_num=1000,
        file_name=cfg.OUTPUT_DIR + "/train_visualize/train.tsv",
        write_freq=100,
    )

    for iteration, (images, targets, idxs, positive_map, positive_map_eval, greenlight_map) in enumerate(
        data_loader, start_iter
    ):
        nnegative = sum(len(target) < 1 for target in targets)
        nsample = len(targets)
        if nnegative > nsample * cfg.SOLVER.MAX_NEG_PER_BATCH:
            logger.info(
                "[WARNING] Sampled {} negatives out of {} in a batch, greater than the allowed ratio {}, skip".format(
                    nnegative, nsample, cfg.SOLVER.MAX_NEG_PER_BATCH
                )
            )
            continue

        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration

        images = images.to(device)
        captions = []
        try:
            targets = [target.to(device) for target in targets]
            captions = [t.get_field("caption") for t in targets if "caption" in t.fields()]
        except Exception:
            pass
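
        # If no target carries a "caption" field, `captions` stays empty and the
        # forward pass below falls back to plain detection (text/grounding inputs are skipped).
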
        # Freeze the language backbone (keep it in eval mode during training)
        if cfg.MODEL.LANGUAGE_BACKBONE.FREEZE:
            if hasattr(model, "module"):
                if hasattr(model.module, "fusion_backbone"):
                    model.module.fusion_backbone.language_backbone.eval()
                else:
                    model.module.language_backbone.eval()
            else:
                if hasattr(model, "fusion_backbone"):
                    model.fusion_backbone.language_backbone.eval()
                else:
                    model.language_backbone.eval()

        if is_main_process():
            # only visualize for the main process
            tsv_visualizer.update_train_data(images, targets)

        if cfg.SOLVER.USE_AMP:
            with autocast():
                if len(captions) > 0:
                    loss_dict = model(images, targets, captions, positive_map, greenlight_map=greenlight_map)
                else:
                    loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

            # NaN/Inf guard: zero out the loss so the bad batch contributes no gradient
            loss_value = losses.item()
            if torch.isnan(losses) or torch.isinf(losses):
                logger.error("NaN encountered, ignoring")
                losses[losses != losses] = 0
            # if the loss is too large, treat the batch as a bad example and ignore it
            if loss_value > 10 and iteration > 10000:
                losses[losses == losses] = 0
                print("Loss is too large, ignore it, loss: ", loss_value)

            optimizer.zero_grad()
            scaler.scale(losses).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
        else:
            if len(captions) > 0:
                loss_dict = model(images, targets, captions, positive_map)
            else:
                loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

            # if the loss blows up, save the inputs and model state for debugging and stop
            loss_value = losses.item()
            if not math.isfinite(loss_value):
                logger.error(f"=> loss is {loss_value}, stopping training")
                time_str = time.strftime("%Y-%m-%d-%H-%M")
                fname = os.path.join(checkpointer.save_dir, f"{time_str}_states.pth")
                logger.info(f"=> save error state to {fname}")
                dict_to_save = {
                    "x": images,
                    "y": targets,
                    "loss": losses,
                    "states": model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
                }
                if len(captions) > 0:
                    dict_to_save["captions"] = captions
                    dict_to_save["positive_map"] = positive_map
                torch.save(dict_to_save, fname)
                sys.exit(-1)

            if torch.isnan(losses) or torch.isinf(losses):
                losses[losses != losses] = 0
            # if the loss is too large, ignore it
            # if loss_value > 10 and iteration > 10000:
            #     losses[losses == losses] = 0  # this is a bad example
            #     print("Loss is too large, ignore it, loss: ", loss_value)

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            scheduler.step()

        # Adapt the weight decay: only supported for MultiStepLR-style schedulers
        if cfg.SOLVER.WEIGHT_DECAY_SCHEDULE and hasattr(scheduler, "milestones"):
            if milestone_target < len(scheduler.milestones):
                next_milestone = list(scheduler.milestones)[milestone_target]
            else:
                next_milestone = float("inf")
            if scheduler.last_epoch >= next_milestone * cfg.SOLVER.WEIGHT_DECAY_SCHEDULE_RATIO:
                gamma = scheduler.gamma
                logger.info("Drop the weight decay by a factor of {}!".format(gamma))
                for param in optimizer.param_groups:
                    if "weight_decay" in param:
                        param["weight_decay"] *= gamma
                # move the target forward
                milestone_target += 1

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss=losses_reduced, **loss_dict_reduced)

        if model_ema is not None:
            model_ema.update(model)
            arguments["model_ema"] = model_ema.state_dict()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)
        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
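
        # Periodic console logging (every 20 iterations and at the final one): ETA,
        # per-loss meters, learning rate, weight decay and peak GPU memory, optionally
        # mirrored to Weights & Biases from the main process.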
        if iteration % 20 == 0 or iteration == max_iter:
            if global_rank <= 0:
                print(
                    meters.delimiter.join(
                        [
                            "eta: {eta}",
                            "iter: {iter}",
                            "{meters}",
                            "lr: {lr:.6f}",
                            "wd: {wd:.6f}",
                            "max mem: {memory:.0f}",
                        ]
                    ).format(
                        eta=eta_string,
                        iter=iteration,
                        meters=str(meters),
                        lr=optimizer.param_groups[0]["lr"],
                        wd=optimizer.param_groups[0]["weight_decay"],
                        memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                    )
                )
            if use_wandb and is_main_process():
                wandb.log(
                    {
                        "train_loss": losses_reduced,
                        "lr": optimizer.param_groups[0]["lr"],
                        "wd": optimizer.param_groups[0]["weight_decay"],
                        **loss_dict_reduced,
                    }
                )

        if val_data_loader and (iteration % checkpoint_period == 0 or iteration == max_iter):
            if is_main_process():
                print("Evaluating")
            eval_result = 0.0
            model.eval()
            if cfg.SOLVER.TEST_WITH_INFERENCE:
                with torch.no_grad():
                    try:
                        _model = model.module
                    except AttributeError:
                        _model = model
                    _result = inference(
                        model=_model,
                        data_loader=val_data_loader,
                        dataset_name="val",
                        device=device,
                        expected_results=cfg.TEST.EXPECTED_RESULTS,
                        expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
                        output_folder=None,
                        cfg=cfg,
                        verbose=False,
                    )
                    if is_main_process():
                        try:
                            eval_result = _result[0].results["bbox"]["AP"]
                        except Exception:
                            pass
            else:
                results_dict = {}
                cpu_device = torch.device("cpu")
                for i, batch in enumerate(val_data_loader):
                    images, targets, image_ids, positive_map, *_ = batch
                    with torch.no_grad():
                        images = images.to(device)
                        if positive_map is None:
                            output = model(images)
                        else:
                            captions = [t.get_field("caption") for t in targets if "caption" in t.fields()]
                            output = model(images, captions, positive_map)
                        output = [o.to(cpu_device) for o in output]
                    results_dict.update({img_id: result for img_id, result in zip(image_ids, output)})
                all_predictions = all_gather(results_dict)
                if is_main_process():
                    predictions = {}
                    for p in all_predictions:
                        predictions.update(p)
                    predictions = [predictions[i] for i in list(sorted(predictions.keys()))]
                    eval_result, _ = evaluate(
                        val_data_loader.dataset,
                        predictions,
                        output_folder=None,
                        box_only=cfg.DATASETS.CLASS_AGNOSTIC,
                    )
                    if cfg.DATASETS.CLASS_AGNOSTIC:
                        eval_result = eval_result.results["box_proposal"]["AR@100"]
                    else:
                        eval_result = eval_result.results["bbox"]["AP"]
            model.train()

            # Optionally monitor the EMA weights on the validation set as well
            if model_ema is not None and cfg.SOLVER.USE_EMA_FOR_MONITOR:
                model_ema.ema.eval()
                results_dict = {}
                cpu_device = torch.device("cpu")
                for i, batch in enumerate(val_data_loader):
                    images, targets, image_ids, positive_map, positive_map_eval = batch
                    with torch.no_grad():
                        images = images.to(device)
                        if positive_map is None:
                            output = model_ema.ema(images)
                        else:
                            captions = [t.get_field("caption") for t in targets if "caption" in t.fields()]
                            output = model_ema.ema(images, captions, positive_map)
                        output = [o.to(cpu_device) for o in output]
                    results_dict.update({img_id: result for img_id, result in zip(image_ids, output)})
                all_predictions = all_gather(results_dict)
                if is_main_process():
                    predictions = {}
                    for p in all_predictions:
                        predictions.update(p)
                    predictions = [predictions[i] for i in list(sorted(predictions.keys()))]
                    eval_result, _ = evaluate(
                        val_data_loader.dataset,
                        predictions,
                        output_folder=None,
                        box_only=cfg.DATASETS.CLASS_AGNOSTIC,
                    )
                    if cfg.DATASETS.CLASS_AGNOSTIC:
                        eval_result = eval_result.results["box_proposal"]["AR@100"]
                    else:
                        eval_result = eval_result.results["bbox"]["AP"]

            if eval_result is not None:
                arguments.update(eval_result=eval_result)
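
            # With SOLVER.USE_AUTOSTEP the LR scheduler is metric-driven: the validation
            # score is gathered across ranks so every process steps the scheduler with
            # the main process's value.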
            if cfg.SOLVER.USE_AUTOSTEP:
                assert eval_result is not None
                eval_result = all_gather(eval_result)[0]  # broadcast_data([eval_result])[0]
                # print("Rank {} eval result gathered".format(cfg.local_rank), eval_result)
                scheduler.step(eval_result)

            if cfg.SOLVER.AUTO_TERMINATE_PATIENCE != -1:
                if eval_result < previous_best:
                    patience_counter += 1
                else:
                    patience_counter = 0
                    previous_best = eval_result
                    checkpointer.save("model_best", **arguments)
                print(
                    "Previous Best", previous_best,
                    "Patience Counter", patience_counter,
                    "Eval Result", eval_result,
                )
                if patience_counter >= cfg.SOLVER.AUTO_TERMINATE_PATIENCE:
                    if is_main_process():
                        print(
                            "\n\n\n\nAuto Termination at {}, current best {}\n\n\n".format(
                                iteration, previous_best
                            )
                        )
                    break

        if iteration % checkpoint_period == 0 and checkpoint_period > 100:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)

        if iteration == max_iter:
            checkpointer.save("model_final", **arguments)
            break

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info(
        "Total training time: {} ({:.4f} s / it)".format(total_time_str, total_training_time / max_iter)
    )
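

# ---------------------------------------------------------------------------
# Illustrative only: a minimal sketch of how do_train is typically wired up by a
# maskrcnn_benchmark-style training script. The helper names and signatures below
# (build_detection_model, make_optimizer, make_lr_scheduler, make_data_loader,
# DetectronCheckpointer) are assumptions borrowed from that ecosystem and may
# differ in this repository; this is not the canonical entry point.
#
#     from maskrcnn_benchmark.modeling.detector import build_detection_model
#     from maskrcnn_benchmark.solver import make_optimizer, make_lr_scheduler
#     from maskrcnn_benchmark.data import make_data_loader
#     from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
#
#     device = torch.device(cfg.MODEL.DEVICE)
#     model = build_detection_model(cfg).to(device)
#     optimizer = make_optimizer(cfg, model)
#     scheduler = make_lr_scheduler(cfg, optimizer)
#     checkpointer = DetectronCheckpointer(cfg, model, optimizer, scheduler, cfg.OUTPUT_DIR, save_to_disk=True)
#     arguments = {"iteration": 0}
#     arguments.update(checkpointer.load(cfg.MODEL.WEIGHT))
#     data_loader = make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=arguments["iteration"])
#     do_train(cfg, model, data_loader, optimizer, scheduler, checkpointer, device,
#              cfg.SOLVER.CHECKPOINT_PERIOD, arguments, meters=MetricLogger(delimiter="  "))
# ---------------------------------------------------------------------------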