# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
r"""
Basic training script for PyTorch
"""

# Set up custom environment before nearly anything else is imported
# NOTE: this should be the first import (do not reorder)
from maskrcnn_benchmark.utils.env import setup_environment  # noqa F401 isort:skip

import argparse
import os

import torch
from maskrcnn_benchmark.config import cfg, try_to_find
from maskrcnn_benchmark.data import make_data_loader
from maskrcnn_benchmark.solver import make_lr_scheduler
from maskrcnn_benchmark.solver import make_optimizer
from maskrcnn_benchmark.engine.inference import inference
from maskrcnn_benchmark.engine.trainer import do_train
from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
from maskrcnn_benchmark.utils.collect_env import collect_env_info
from maskrcnn_benchmark.utils.comm import get_world_size, all_gather, is_main_process, broadcast_data, get_rank, synchronize
from maskrcnn_benchmark.utils.imports import import_file
from maskrcnn_benchmark.utils.logger import setup_logger
from maskrcnn_benchmark.utils.metric_logger import MetricLogger, TensorboardLogger
from maskrcnn_benchmark.utils.miscellaneous import mkdir, save_config

import numpy as np
import random
import pdb
import wandb

from maskrcnn_benchmark.utils.amp import autocast, GradScaler


def train(cfg, local_rank, distributed, use_tensorboard=False, use_wandb=False):
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    # Optionally reset BatchNorm running statistics before training
    if cfg.MODEL.BACKBONE.RESET_BN:
        for name, param in model.named_buffers():
            if "running_mean" in name:
                torch.nn.init.constant_(param, 0)
            if "running_var" in name:
                torch.nn.init.constant_(param, 1)

    # Clamp gradients element-wise via backward hooks when GRAD_CLIP is enabled
    if cfg.SOLVER.GRAD_CLIP > 0:
        clip_value = cfg.SOLVER.GRAD_CLIP
        for p in filter(lambda p: p.grad is not None, model.parameters()):
            p.register_hook(lambda grad: torch.clamp(grad, -clip_value, clip_value))

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=0,  # sampling data from the resume iteration is disabled due to a conflict with max_epoch
    )

    if cfg.TEST.DURING_TRAINING or cfg.SOLVER.USE_AUTOSTEP:
        data_loaders_val = make_data_loader(cfg, is_train=False, is_distributed=distributed)
        data_loaders_val = data_loaders_val[0]
    else:
        data_loaders_val = None

    # Freeze the requested sub-modules by disabling their gradients
    if cfg.MODEL.BACKBONE.FREEZE:
        for p in model.backbone.body.parameters():
            p.requires_grad = False

    if cfg.MODEL.LANGUAGE_BACKBONE.FREEZE:
        print("LANGUAGE_BACKBONE FROZEN.")
        for p in model.language_backbone.body.parameters():
            p.requires_grad = False

    if cfg.MODEL.FPN.FREEZE:
        for p in model.backbone.fpn.parameters():
            p.requires_grad = False

    if cfg.MODEL.RPN.FREEZE:
        for p in model.rpn.parameters():
            p.requires_grad = False

    # if cfg.SOLVER.PROMPT_PROBING_LEVEL != -1:
    #     if cfg.SOLVER.PROMPT_PROBING_LEVEL == 1:
    #         for p in model.parameters():
    #             p.requires_grad = False
    #         for p in model.language_backbone.body.parameters():
    #             p.requires_grad = True
    #         for name, p in model.named_parameters():
    #             if p.requires_grad:
    #                 print(name, " : Not Frozen")
    #             else:
    #                 print(name, " : Frozen")
    #     else:
    #         assert(0)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            broadcast_buffers=cfg.MODEL.BACKBONE.USE_BN,
            find_unused_parameters=cfg.SOLVER.FIND_UNUSED_PARAMETERS,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
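    # Only rank 0 writes checkpoints; the checkpointer below also restores
    # model/optimizer/scheduler state when cfg.MODEL.WEIGHT points at a checkpoint.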
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )
    extra_checkpoint_data = checkpointer.load(
        try_to_find(cfg.MODEL.WEIGHT), skip_scheduler=cfg.SOLVER.RESUME_SKIP_SCHEDULE
    )
    arguments.update(extra_checkpoint_data)

    # For full model finetuning
    # arguments["iteration"] = 0
    # optimizer = make_optimizer(cfg, model)
    # scheduler = make_lr_scheduler(cfg, optimizer)

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    if use_tensorboard:
        meters = TensorboardLogger(
            log_dir=cfg.OUTPUT_DIR,
            start_iter=arguments["iteration"],
            delimiter=" ",
        )
    else:
        meters = MetricLogger(delimiter=" ")

    do_train(
        cfg,
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
        data_loaders_val,
        meters,
        use_wandb=use_wandb,
    )

    return model


def setup_for_distributed(is_master):
    """
    Disable printing when not in the master process.
    """
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def main():
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "--skip-test",
        dest="skip_test",
        help="Do not test the final model",
        action="store_true",
    )
    parser.add_argument(
        "--use-tensorboard",
        dest="use_tensorboard",
        help="Use tensorboardX logger (requires tensorboardX installed)",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    parser.add_argument("--save_original_config", action="store_true")
    parser.add_argument("--disable_output_distributed", action="store_true")
    parser.add_argument("--debug_nan_checkpoint", default=None)
    parser.add_argument("--override_output_dir", default=None)
    parser.add_argument("--wandb_name", default="__test__")
    parser.add_argument("--use_wandb", action="store_true")

    args = parser.parse_args()

    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    args.distributed = num_gpus > 1

    if args.distributed:
        import datetime

        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl", init_method="env://", timeout=datetime.timedelta(0, 7200)
        )

    if args.disable_output_distributed:
        setup_for_distributed(args.local_rank <= 0)

    cfg.local_rank = args.local_rank
    cfg.num_gpus = num_gpus

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)

    # specify the output dir for models
    cfg.OUTPUT_DIR = "OUTPUTS/" + args.wandb_name
    if is_main_process():
        mkdir(cfg.OUTPUT_DIR)

    if args.wandb_name != "__test__" and args.use_wandb:
        if is_main_process():
            run = wandb.init(
                project='lang_det',
                job_type='train_model',
                name=args.wandb_name,
            )
            with open(os.path.join(cfg.OUTPUT_DIR, 'wandb_run_id.txt'), 'w') as f:
                f.write(run.id)

    if args.override_output_dir:
        cfg.OUTPUT_DIR = args.override_output_dir
    cfg.freeze()

    # Seed every RNG, offset by rank so processes do not share a seed
    seed = cfg.SOLVER.SEED + args.local_rank
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        mkdir(output_dir)

    logger = setup_logger("maskrcnn_benchmark", output_dir, get_rank())
    logger.info(args)
    logger.info("Using {} GPUs".format(num_gpus))

    # logger.info("Collecting env info (might take some time)")
    # logger.info("\n" + collect_env_info())
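    # Log both the raw config file and the final merged config for reproducibility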
logger.info("Loaded configuration file {}".format(args.config_file)) with open(args.config_file, "r") as cf: config_str = "\n" + cf.read() logger.info(config_str) logger.info("Running with config:\n{}".format(cfg)) output_config_path = os.path.join(cfg.OUTPUT_DIR, 'config.yml') logger.info("Saving config into: {}".format(output_config_path)) # save overloaded model config in the output directory if args.save_original_config: import shutil shutil.copy(args.config_file, os.path.join(cfg.OUTPUT_DIR, "config_original.yml")) save_config(cfg, output_config_path) model = train( cfg=cfg, local_rank=args.local_rank, distributed=args.distributed, use_tensorboard=args.use_tensorboard, use_wandb=args.wandb_name != "__test__" and args.use_wandb) if __name__ == "__main__": main()