Matthew
initial commit
0392181
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Grid features extraction script.
"""
import argparse
import os
import torch
import tqdm
from fvcore.common.file_io import PathManager
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.engine import default_setup
from detectron2.evaluation import inference_context
from detectron2.modeling import build_model
from grid_feats import (
add_attribute_config,
build_detection_test_loader_with_attributes,
)
# A simple mapper from object detection dataset to VQA dataset names
dataset_to_folder_mapper = {}
dataset_to_folder_mapper['coco_2014_train'] = 'train2014'
dataset_to_folder_mapper['coco_2014_val'] = 'val2014'
# One may need to change the Detectron2 code to support coco_2015_test
# insert "coco_2015_test": ("coco/test2015", "coco/annotations/image_info_test2015.json"),
# at: https://github.com/facebookresearch/detectron2/blob/master/detectron2/data/datasets/builtin.py#L36
dataset_to_folder_mapper['coco_2015_test'] = 'test2015'
def extract_grid_feature_argument_parser():
parser = argparse.ArgumentParser(description="Grid feature extraction")
parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
parser.add_argument("--dataset", help="name of the dataset", default="coco_2014_train",
choices=['coco_2014_train', 'coco_2014_val', 'coco_2015_test'])
parser.add_argument(
"opts",
help="Modify config options using the command-line",
default=None,
nargs=argparse.REMAINDER,
)
return parser
def extract_grid_feature_on_dataset(model, data_loader, dump_folder):
for idx, inputs in enumerate(tqdm.tqdm(data_loader)):
with torch.no_grad():
image_id = inputs[0]['image_id']
file_name = '%d.pth' % image_id
# compute features
images = model.preprocess_image(inputs)
features = model.backbone(images.tensor)
outputs = model.roi_heads.get_conv5_features(features)
with PathManager.open(os.path.join(dump_folder, file_name), "wb") as f:
# save as CPU tensors
torch.save(outputs.cpu(), f)
def do_feature_extraction(cfg, model, dataset_name):
with inference_context(model):
dump_folder = os.path.join(cfg.OUTPUT_DIR, "features", dataset_to_folder_mapper[dataset_name])
PathManager.mkdirs(dump_folder)
data_loader = build_detection_test_loader_with_attributes(cfg, dataset_name)
extract_grid_feature_on_dataset(model, data_loader, dump_folder)
def setup(args):
"""
Create configs and perform basic setups.
"""
cfg = get_cfg()
add_attribute_config(cfg)
cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
# force the final residual block to have dilations 1
cfg.MODEL.RESNETS.RES5_DILATION = 1
cfg.freeze()
default_setup(cfg, args)
return cfg
def main(args):
cfg = setup(args)
model = build_model(cfg)
DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
cfg.MODEL.WEIGHTS, resume=True
)
do_feature_extraction(cfg, model, args.dataset)
if __name__ == "__main__":
args = extract_grid_feature_argument_parser().parse_args()
print("Command Line Args:", args)
main(args)