virtex-redcaps / virtex /scripts /eval_detectron2.py
zamborg's picture
added datasets and virtex
a5f8a35
raw history blame
No virus
9.45 kB
"""
Finetune a pre-trained model on a downstream task, one of those available in
Detectron2.
Supported downstream:
- LVIS Instance Segmentation
- COCO Instance Segmentation
- Pascal VOC 2007+12 Object Detection
Reference: https://github.com/facebookresearch/detectron2/blob/master/tools/train_net.py
Thanks to the developers of Detectron2!
"""
import argparse
import os
import re
from typing import Any, Dict, Union
import torch
from torch.utils.tensorboard import SummaryWriter
import detectron2 as d2
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.engine import DefaultTrainer, default_setup
from detectron2.evaluation import (
LVISEvaluator,
PascalVOCDetectionEvaluator,
COCOEvaluator,
)
from detectron2.modeling.roi_heads import ROI_HEADS_REGISTRY, Res5ROIHeads
from virtex.config import Config
from virtex.factories import PretrainingModelFactory
from virtex.utils.checkpointing import CheckpointManager
from virtex.utils.common import common_parser
import virtex.utils.distributed as dist
# fmt: off
parser = common_parser(
description="Train object detectors from pretrained visual backbone."
)
parser.add_argument(
"--d2-config", required=True,
help="Path to a detectron2 config for downstream task finetuning."
)
parser.add_argument(
"--d2-config-override", nargs="*", default=[],
help="""Key-value pairs from Detectron2 config to override from file.
Some keys will be ignored because they are set from other args:
[DATALOADER.NUM_WORKERS, SOLVER.EVAL_PERIOD, SOLVER.CHECKPOINT_PERIOD,
TEST.EVAL_PERIOD, OUTPUT_DIR]""",
)
parser.add_argument_group("Checkpointing and Logging")
parser.add_argument(
"--weight-init", choices=["random", "imagenet", "torchvision", "virtex"],
default="virtex", help="""How to initialize weights:
1. 'random' initializes all weights randomly
2. 'imagenet' initializes backbone weights from torchvision model zoo
3. {'torchvision', 'virtex'} load state dict from --checkpoint-path
- with 'torchvision', state dict would be from PyTorch's training
script.
- with 'virtex' it should be for our full pretrained model."""
)
parser.add_argument(
"--checkpoint-path",
help="Path to load checkpoint and run downstream task evaluation."
)
parser.add_argument(
"--resume", action="store_true", help="""Specify this flag when resuming
training from a checkpoint saved by Detectron2."""
)
parser.add_argument(
"--eval-only", action="store_true",
help="Skip training and evaluate checkpoint provided at --checkpoint-path.",
)
parser.add_argument(
"--checkpoint-every", type=int, default=5000,
help="Serialize model to a checkpoint after every these many iterations.",
)
# fmt: on
@ROI_HEADS_REGISTRY.register()
class Res5ROIHeadsExtraNorm(Res5ROIHeads):
r"""
ROI head with ``res5`` stage followed by a BN layer. Used with Faster R-CNN
C4/DC5 backbones for VOC detection.
"""
def _build_res5_block(self, cfg):
seq, out_channels = super()._build_res5_block(cfg)
norm = d2.layers.get_norm(cfg.MODEL.RESNETS.NORM, out_channels)
seq.add_module("norm", norm)
return seq, out_channels
def build_detectron2_config(_C: Config, _A: argparse.Namespace):
r"""Build detectron2 config based on our pre-training config and args."""
_D2C = d2.config.get_cfg()
# Override some default values based on our config file.
_D2C.merge_from_file(_A.d2_config)
_D2C.merge_from_list(_A.d2_config_override)
# Set some config parameters from args.
_D2C.DATALOADER.NUM_WORKERS = _A.cpu_workers
_D2C.SOLVER.CHECKPOINT_PERIOD = _A.checkpoint_every
_D2C.OUTPUT_DIR = _A.serialization_dir
# Set ResNet depth to override in Detectron2's config.
_D2C.MODEL.RESNETS.DEPTH = int(
re.search(r"resnet(\d+)", _C.MODEL.VISUAL.NAME).group(1)
if "torchvision" in _C.MODEL.VISUAL.NAME
else re.search(r"_R_(\d+)", _C.MODEL.VISUAL.NAME).group(1)
if "detectron2" in _C.MODEL.VISUAL.NAME
else 0
)
return _D2C
class DownstreamTrainer(DefaultTrainer):
r"""
Extension of detectron2's ``DefaultTrainer``: custom evaluator and hooks.
Parameters
----------
cfg: detectron2.config.CfgNode
Detectron2 config object containing all config params.
weights: Union[str, Dict[str, Any]]
Weights to load in the initialized model. If ``str``, then we assume path
to a checkpoint, or if a ``dict``, we assume a state dict. This will be
an ``str`` only if we resume training from a Detectron2 checkpoint.
"""
def __init__(self, cfg, weights: Union[str, Dict[str, Any]]):
super().__init__(cfg)
# Load pre-trained weights before wrapping to DDP because `ApexDDP` has
# some weird issue with `DetectionCheckpointer`.
# fmt: off
if isinstance(weights, str):
# weights are ``str`` means ImageNet init or resume training.
self.start_iter = (
DetectionCheckpointer(
self._trainer.model,
optimizer=self._trainer.optimizer,
scheduler=self.scheduler
).resume_or_load(weights, resume=True).get("iteration", -1) + 1
)
elif isinstance(weights, dict):
# weights are a state dict means our pretrain init.
DetectionCheckpointer(self._trainer.model)._load_model(weights)
# fmt: on
@classmethod
def build_evaluator(cls, cfg, dataset_name, output_folder=None):
if output_folder is None:
output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
evaluator_list = []
evaluator_type = d2.data.MetadataCatalog.get(dataset_name).evaluator_type
if evaluator_type == "pascal_voc":
return PascalVOCDetectionEvaluator(dataset_name)
elif evaluator_type == "coco":
return COCOEvaluator(dataset_name, cfg, True, output_folder)
elif evaluator_type == "lvis":
return LVISEvaluator(dataset_name, cfg, True, output_folder)
def test(self, cfg=None, model=None, evaluators=None):
r"""Evaluate the model and log results to stdout and tensorboard."""
cfg = cfg or self.cfg
model = model or self.model
tensorboard_writer = SummaryWriter(log_dir=cfg.OUTPUT_DIR)
results = super().test(cfg, model)
flat_results = d2.evaluation.testing.flatten_results_dict(results)
for k, v in flat_results.items():
tensorboard_writer.add_scalar(k, v, self.start_iter)
def main(_A: argparse.Namespace):
# Get the current device as set for current distributed process.
# Check `launch` function in `virtex.utils.distributed` module.
device = torch.cuda.current_device()
# Local process group is needed for detectron2.
pg = list(range(dist.get_world_size()))
d2.utils.comm._LOCAL_PROCESS_GROUP = torch.distributed.new_group(pg)
# Create a config object (this will be immutable) and perform common setup
# such as logging and setting up serialization directory.
if _A.weight_init == "imagenet":
_A.config_override.extend(["MODEL.VISUAL.PRETRAINED", True])
_C = Config(_A.config, _A.config_override)
# We use `default_setup` from detectron2 to do some common setup, such as
# logging, setting up serialization etc. For more info, look into source.
_D2C = build_detectron2_config(_C, _A)
default_setup(_D2C, _A)
# Prepare weights to pass in instantiation call of trainer.
if _A.weight_init in {"virtex", "torchvision"}:
if _A.resume:
# If resuming training, let detectron2 load weights by providing path.
model = None
weights = _A.checkpoint_path
else:
# Load backbone weights from VirTex pretrained checkpoint.
model = PretrainingModelFactory.from_config(_C)
if _A.weight_init == "virtex":
CheckpointManager(model=model).load(_A.checkpoint_path)
else:
model.visual.cnn.load_state_dict(
torch.load(_A.checkpoint_path, map_location="cpu")["state_dict"],
strict=False,
)
weights = model.visual.detectron2_backbone_state_dict()
else:
# If random or imagenet init, just load weights after initializing model.
model = PretrainingModelFactory.from_config(_C)
weights = model.visual.detectron2_backbone_state_dict()
# Back up pretrain config and model checkpoint (if provided).
_C.dump(os.path.join(_A.serialization_dir, "pretrain_config.yaml"))
if _A.weight_init == "virtex" and not _A.resume:
torch.save(
model.state_dict(),
os.path.join(_A.serialization_dir, "pretrain_model.pth"),
)
del model
trainer = DownstreamTrainer(_D2C, weights)
trainer.test() if _A.eval_only else trainer.train()
if __name__ == "__main__":
_A = parser.parse_args()
# This will launch `main` and set appropriate CUDA device (GPU ID) as
# per process (accessed in the beginning of `main`).
dist.launch(
main,
num_machines=_A.num_machines,
num_gpus_per_machine=_A.num_gpus_per_machine,
machine_rank=_A.machine_rank,
dist_url=_A.dist_url,
args=(_A, ),
)