# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
r"""
Basic training script for PyTorch
"""

# Set up custom environment before nearly anything else is imported
# NOTE: this should be the first import (do not reorder)
from maskrcnn_benchmark.utils.env import setup_environment  # noqa F401 isort:skip

import argparse
import os

import torch
from maskrcnn_benchmark.config import cfg, try_to_find
from maskrcnn_benchmark.data import make_data_loader
from maskrcnn_benchmark.solver import make_lr_scheduler
from maskrcnn_benchmark.solver import make_optimizer
from maskrcnn_benchmark.engine.inference import inference
from maskrcnn_benchmark.engine.trainer import do_train
from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
from maskrcnn_benchmark.utils.collect_env import collect_env_info
from maskrcnn_benchmark.utils.comm import get_world_size, all_gather, is_main_process, broadcast_data, get_rank, synchronize
from maskrcnn_benchmark.utils.imports import import_file
from maskrcnn_benchmark.utils.logger import setup_logger
from maskrcnn_benchmark.utils.metric_logger import MetricLogger, TensorboardLogger
from maskrcnn_benchmark.utils.miscellaneous import mkdir, save_config

import numpy as np
import random
import pdb, wandb

from maskrcnn_benchmark.utils.amp import autocast, GradScaler


def train(cfg, local_rank, distributed, use_tensorboard=False, use_wandb=False):
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
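    # Optionally reset BatchNorm running statistics to their defaults
    # (running_mean = 0, running_var = 1).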
    if cfg.MODEL.BACKBONE.RESET_BN:
        for name, param in model.named_buffers():
            if "running_mean" in name:
                torch.nn.init.constant_(param, 0)
            if "running_var" in name:
                torch.nn.init.constant_(param, 1)
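    # Element-wise gradient clipping installed as backward hooks; note that the
    # filter only matches parameters whose .grad is already populated at this point.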
    if cfg.SOLVER.GRAD_CLIP > 0:
        clip_value = cfg.SOLVER.GRAD_CLIP
        for p in filter(lambda p: p.grad is not None, model.parameters()):
            p.register_hook(lambda grad: torch.clamp(grad, -clip_value, clip_value))
    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=0,  # <TODO> Sampling from the resume iteration is disabled due to a conflict with max_epoch
    )

    if cfg.TEST.DURING_TRAINING or cfg.SOLVER.USE_AUTOSTEP:
        data_loaders_val = make_data_loader(cfg, is_train=False, is_distributed=distributed)
        data_loaders_val = data_loaders_val[0]
    else:
        data_loaders_val = None
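    # Optionally freeze the visual backbone, language backbone, FPN, and RPN by
    # disabling gradients for their parameters.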
    if cfg.MODEL.BACKBONE.FREEZE:
        for p in model.backbone.body.parameters():
            p.requires_grad = False

    if cfg.MODEL.LANGUAGE_BACKBONE.FREEZE:
        print("LANGUAGE_BACKBONE FROZEN.")
        for p in model.language_backbone.body.parameters():
            p.requires_grad = False

    if cfg.MODEL.FPN.FREEZE:
        for p in model.backbone.fpn.parameters():
            p.requires_grad = False

    if cfg.MODEL.RPN.FREEZE:
        for p in model.rpn.parameters():
            p.requires_grad = False

    # if cfg.SOLVER.PROMPT_PROBING_LEVEL != -1:
    #     if cfg.SOLVER.PROMPT_PROBING_LEVEL == 1:
    #         for p in model.parameters():
    #             p.requires_grad = False
    #         for p in model.language_backbone.body.parameters():
    #             p.requires_grad = True
    #         for name, p in model.named_parameters():
    #             if p.requires_grad:
    #                 print(name, " : Not Frozen")
    #             else:
    #                 print(name, " : Frozen")
    #     else:
    #         assert(0)
    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            broadcast_buffers=cfg.MODEL.BACKBONE.USE_BN,
            find_unused_parameters=cfg.SOLVER.FIND_UNUSED_PARAMETERS,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR
    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(cfg, model, optimizer, scheduler, output_dir, save_to_disk)
    extra_checkpoint_data = checkpointer.load(try_to_find(cfg.MODEL.WEIGHT), skip_scheduler=cfg.SOLVER.RESUME_SKIP_SCHEDULE)
    arguments.update(extra_checkpoint_data)
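    # extra_checkpoint_data carries any metadata stored alongside the checkpoint
    # (e.g. the last "iteration"), so an interrupted run resumes from the saved step.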
    # For full model finetuning:
    # arguments["iteration"] = 0
    # optimizer = make_optimizer(cfg, model)
    # scheduler = make_lr_scheduler(cfg, optimizer)

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    if use_tensorboard:
        meters = TensorboardLogger(log_dir=cfg.OUTPUT_DIR, start_iter=arguments["iteration"], delimiter=" ")
    else:
        meters = MetricLogger(delimiter=" ")
    do_train(
        cfg,
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
        data_loaders_val,
        meters,
        use_wandb=use_wandb,
    )

    return model


def setup_for_distributed(is_master):
    """
    This function disables printing when not in the master process
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print
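    # After this, print() on a non-master rank is silent unless the caller passes
    # force=True, e.g. print("rank-local message", force=True).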


def main():
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "--skip-test",
        dest="skip_test",
        help="Do not test the final model",
        action="store_true",
    )
    parser.add_argument(
        "--use-tensorboard",
        dest="use_tensorboard",
        help="Use the tensorboardX logger (requires tensorboardX to be installed)",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    parser.add_argument("--save_original_config", action="store_true")
    parser.add_argument("--disable_output_distributed", action="store_true")
    parser.add_argument("--debug_nan_checkpoint", default=None)
    parser.add_argument("--override_output_dir", default=None)
    parser.add_argument("--wandb_name", default="__test__")
    parser.add_argument("--use_wandb", action="store_true")

    args = parser.parse_args()
    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    args.distributed = num_gpus > 1
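    # WORLD_SIZE is set by torch.distributed launchers (one process per GPU), so
    # more than one process implies distributed training.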
    if args.distributed:
        import datetime
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://", timeout=datetime.timedelta(0, 7200))

    if args.disable_output_distributed:
        setup_for_distributed(args.local_rank <= 0)

    cfg.local_rank = args.local_rank
    cfg.num_gpus = num_gpus

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    # specify output dir for models
    cfg.OUTPUT_DIR = "OUTPUTS/" + args.wandb_name
    if is_main_process():
        mkdir(cfg.OUTPUT_DIR)

    if args.wandb_name != "__test__" and args.use_wandb:
        if is_main_process():
            run = wandb.init(
                project="lang_det",
                job_type="train_model",
                name=args.wandb_name,
            )
            with open(os.path.join(cfg.OUTPUT_DIR, "wandb_run_id.txt"), "w") as f:
                f.write(run.id)

    if args.override_output_dir:
        cfg.OUTPUT_DIR = args.override_output_dir
    cfg.freeze()
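    # Seed each process differently (base seed + local_rank) so ranks draw
    # independent random streams.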
    seed = cfg.SOLVER.SEED + args.local_rank
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        mkdir(output_dir)

    logger = setup_logger("maskrcnn_benchmark", output_dir, get_rank())
    logger.info(args)
    logger.info("Using {} GPUs".format(num_gpus))

    # logger.info("Collecting env info (might take some time)")
    # logger.info("\n" + collect_env_info())

    logger.info("Loaded configuration file {}".format(args.config_file))
    with open(args.config_file, "r") as cf:
        config_str = "\n" + cf.read()
        logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    output_config_path = os.path.join(cfg.OUTPUT_DIR, "config.yml")
    logger.info("Saving config into: {}".format(output_config_path))
    # save the overloaded model config in the output directory
    if args.save_original_config:
        import shutil
        shutil.copy(args.config_file, os.path.join(cfg.OUTPUT_DIR, "config_original.yml"))

    save_config(cfg, output_config_path)
    model = train(
        cfg=cfg,
        local_rank=args.local_rank,
        distributed=args.distributed,
        use_tensorboard=args.use_tensorboard,
        use_wandb=args.wandb_name != "__test__" and args.use_wandb,
    )


if __name__ == "__main__":
    main()
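
# Example launches (illustrative only; the script filename, GPU count, and config
# path below are placeholders, not taken from this repository):
#
#   # single process / single GPU
#   python train_net.py --config-file path/to/config.yaml --use-tensorboard
#
#   # multi-GPU: the --local_rank flag and WORLD_SIZE check above expect a
#   # torch.distributed launcher, e.g.
#   python -m torch.distributed.launch --nproc_per_node=8 train_net.py \
#       --config-file path/to/config.yaml --wandb_name my_run --use_wandb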