# desco/tools/train_net.py
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
r"""
Basic training script for PyTorch
"""
# Set up custom environment before nearly anything else is imported
# NOTE: this should be the first import (do not reorder)
from maskrcnn_benchmark.utils.env import setup_environment # noqa F401 isort:skip
import argparse
import os
import torch
from maskrcnn_benchmark.config import cfg, try_to_find
from maskrcnn_benchmark.data import make_data_loader
from maskrcnn_benchmark.solver import make_lr_scheduler
from maskrcnn_benchmark.solver import make_optimizer
from maskrcnn_benchmark.engine.inference import inference
from maskrcnn_benchmark.engine.trainer import do_train
from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
from maskrcnn_benchmark.utils.collect_env import collect_env_info
from maskrcnn_benchmark.utils.comm import get_world_size, all_gather, is_main_process, broadcast_data, get_rank, synchronize
from maskrcnn_benchmark.utils.imports import import_file
from maskrcnn_benchmark.utils.logger import setup_logger
from maskrcnn_benchmark.utils.metric_logger import MetricLogger, TensorboardLogger
from maskrcnn_benchmark.utils.miscellaneous import mkdir, save_config
import numpy as np
import random
import wandb
from maskrcnn_benchmark.utils.amp import autocast, GradScaler


def train(cfg, local_rank, distributed, use_tensorboard=False, use_wandb=False):
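    """Build the model, data loaders, optimizer, and scheduler, then run do_train.

    Returns the model (wrapped in DistributedDataParallel when distributed=True).
    """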
model = build_detection_model(cfg)
device = torch.device(cfg.MODEL.DEVICE)
model.to(device)
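
    # Optionally reset BatchNorm running statistics (mean -> 0, var -> 1) so the
    # backbone starts with fresh normalization state rather than pretrained stats.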
if cfg.MODEL.BACKBONE.RESET_BN:
for name, param in model.named_buffers():
if "running_mean" in name:
torch.nn.init.constant_(param, 0)
if "running_var" in name:
torch.nn.init.constant_(param, 1)
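
    # Element-wise gradient clipping via backward hooks: each hook clamps the
    # parameter's gradient to [-clip_value, clip_value] as soon as it is computed.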
    if cfg.SOLVER.GRAD_CLIP > 0:
        clip_value = cfg.SOLVER.GRAD_CLIP
        # Filtering on p.grad would match nothing here, since .grad is still None
        # before the first backward pass; hook on trainable parameters instead.
        for p in filter(lambda p: p.requires_grad, model.parameters()):
            p.register_hook(lambda grad: torch.clamp(grad, -clip_value, clip_value))
data_loader = make_data_loader(
cfg,
is_train=True,
is_distributed=distributed,
        start_iter=0,  # TODO: resuming the data sampler is disabled due to a conflict with max_epoch
)
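
    # Validation loaders are needed when evaluating during training or when the
    # AutoStep scheduler is enabled (it consumes validation results).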
if cfg.TEST.DURING_TRAINING or cfg.SOLVER.USE_AUTOSTEP:
data_loaders_val = make_data_loader(cfg, is_train=False, is_distributed=distributed)
data_loaders_val = data_loaders_val[0]
else:
data_loaders_val = None
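
    # Freeze selected submodules by turning off gradients; make_optimizer only
    # collects parameters that still require grad.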
if cfg.MODEL.BACKBONE.FREEZE:
for p in model.backbone.body.parameters():
p.requires_grad = False
if cfg.MODEL.LANGUAGE_BACKBONE.FREEZE:
print("LANGUAGE_BACKBONE FROZEN.")
for p in model.language_backbone.body.parameters():
p.requires_grad = False
if cfg.MODEL.FPN.FREEZE:
for p in model.backbone.fpn.parameters():
p.requires_grad = False
if cfg.MODEL.RPN.FREEZE:
for p in model.rpn.parameters():
p.requires_grad = False
# if cfg.SOLVER.PROMPT_PROBING_LEVEL != -1:
# if cfg.SOLVER.PROMPT_PROBING_LEVEL == 1:
# for p in model.parameters():
# p.requires_grad = False
# for p in model.language_backbone.body.parameters():
# p.requires_grad = True
# for name, p in model.named_parameters():
# if p.requires_grad:
# print(name, " : Not Frozen")
# else:
# print(name, " : Frozen")
# else:
# assert(0)
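
    # Build the optimizer and LR scheduler over the remaining trainable parameters.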
optimizer = make_optimizer(cfg, model)
scheduler = make_lr_scheduler(cfg, optimizer)
if distributed:
model = torch.nn.parallel.DistributedDataParallel(
model,
device_ids=[local_rank],
output_device=local_rank,
broadcast_buffers=cfg.MODEL.BACKBONE.USE_BN,
find_unused_parameters=cfg.SOLVER.FIND_UNUSED_PARAMETERS,
)
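
    # "arguments" carries mutable training state (e.g. the current iteration)
    # across checkpoint save/load.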
arguments = {}
arguments["iteration"] = 0
output_dir = cfg.OUTPUT_DIR
save_to_disk = get_rank() == 0
checkpointer = DetectronCheckpointer(cfg, model, optimizer, scheduler, output_dir, save_to_disk)
    extra_checkpoint_data = checkpointer.load(try_to_find(cfg.MODEL.WEIGHT), skip_scheduler=cfg.SOLVER.RESUME_SKIP_SCHEDULE)
arguments.update(extra_checkpoint_data)
# For full model finetuning
# arguments["iteration"] = 0
# optimizer = make_optimizer(cfg, model)
# scheduler = make_lr_scheduler(cfg, optimizer)
checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
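
    # Log metrics either to TensorBoard or to stdout via a plain MetricLogger.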
if use_tensorboard:
meters = TensorboardLogger(log_dir=cfg.OUTPUT_DIR, start_iter=arguments["iteration"], delimiter=" ")
else:
meters = MetricLogger(delimiter=" ")
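
    # Run the main training loop; do_train handles checkpointing and optional
    # in-training evaluation on data_loaders_val.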
do_train(
cfg,
model,
data_loader,
optimizer,
scheduler,
checkpointer,
device,
checkpoint_period,
arguments,
data_loaders_val,
meters,
        use_wandb=use_wandb,
)
return model


def setup_for_distributed(is_master):
"""
    Disable printing when this process is not the master; pass force=True to print anyway.
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop("force", False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print


def main():
parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
parser.add_argument(
"--config-file",
default="",
metavar="FILE",
help="path to config file",
type=str,
)
parser.add_argument("--local_rank", type=int, default=0)
parser.add_argument(
"--skip-test",
dest="skip_test",
help="Do not test the final model",
action="store_true",
)
parser.add_argument(
"--use-tensorboard",
dest="use_tensorboard",
help="Use tensorboardX logger (Requires tensorboardX installed)",
action="store_true",
default=False,
)
parser.add_argument(
"opts",
help="Modify config options using the command-line",
default=None,
nargs=argparse.REMAINDER,
)
parser.add_argument("--save_original_config", action="store_true")
parser.add_argument("--disable_output_distributed", action="store_true")
parser.add_argument("--debug_nan_checkpoint", default=None)
parser.add_argument("--override_output_dir", default=None)
parser.add_argument("--wandb_name", default="__test__")
parser.add_argument("--use_wandb", action="store_true")
args = parser.parse_args()
num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
args.distributed = num_gpus > 1
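
    # When launched with multiple processes (WORLD_SIZE > 1), bind each process
    # to its GPU and initialize NCCL with a 2-hour (7200 s) timeout.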
if args.distributed:
import datetime
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend="nccl", init_method="env://", timeout=datetime.timedelta(0, 7200))
if args.disable_output_distributed:
setup_for_distributed(args.local_rank <= 0)
cfg.local_rank = args.local_rank
cfg.num_gpus = num_gpus
cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
# specify output dir for models
cfg.OUTPUT_DIR = "OUTPUTS/" + args.wandb_name
if is_main_process():
mkdir(cfg.OUTPUT_DIR)
if args.wandb_name != "__test__" and args.use_wandb:
if is_main_process():
            run = wandb.init(
                project='lang_det',
                job_type='train_model',
                name=args.wandb_name,
            )
with open(os.path.join(cfg.OUTPUT_DIR, 'wandb_run_id.txt'), 'w') as f:
f.write(run.id)
if args.override_output_dir:
cfg.OUTPUT_DIR = args.override_output_dir
cfg.freeze()
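
    # Offset the seed by local rank so each process draws a different random stream.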
seed = cfg.SOLVER.SEED + args.local_rank
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
output_dir = cfg.OUTPUT_DIR
if output_dir:
mkdir(output_dir)
logger = setup_logger("maskrcnn_benchmark", output_dir, get_rank())
logger.info(args)
logger.info("Using {} GPUs".format(num_gpus))
# logger.info("Collecting env info (might take some time)")
# logger.info("\n" + collect_env_info())
logger.info("Loaded configuration file {}".format(args.config_file))
with open(args.config_file, "r") as cf:
config_str = "\n" + cf.read()
logger.info(config_str)
logger.info("Running with config:\n{}".format(cfg))
output_config_path = os.path.join(cfg.OUTPUT_DIR, 'config.yml')
logger.info("Saving config into: {}".format(output_config_path))
    # Optionally keep a copy of the original config file in the output directory.
    if args.save_original_config:
        import shutil
        shutil.copy(args.config_file, os.path.join(cfg.OUTPUT_DIR, "config_original.yml"))
    # Save the fully-resolved (merged) config for reproducibility.
    save_config(cfg, output_config_path)
model = train(
cfg=cfg,
local_rank=args.local_rank,
distributed=args.distributed,
use_tensorboard=args.use_tensorboard,
use_wandb=args.wandb_name != "__test__" and args.use_wandb)


if __name__ == "__main__":
main()
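

# Example launch (hypothetical config path; adjust --nproc_per_node to your setup):
#   python -m torch.distributed.launch --nproc_per_node=4 tools/train_net.py \
#       --config-file configs/my_config.yaml --use-tensorboard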