Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Grid features extraction script.
"""
import argparse | |
import os | |
import torch | |
import tqdm | |
from fvcore.common.file_io import PathManager | |
from detectron2.checkpoint import DetectionCheckpointer | |
from detectron2.config import get_cfg | |
from detectron2.engine import default_setup | |
from detectron2.evaluation import inference_context | |
from detectron2.modeling import build_model | |
from grid_feats import ( | |
add_attribute_config, | |
build_detection_test_loader_with_attributes, | |
) | |
# Maps each supported object detection dataset name to the folder name used
# by the VQA dataset; the value is used as the feature dump sub-directory.
dataset_to_folder_mapper = {
    'coco_2014_train': 'train2014',
    'coco_2014_val': 'val2014',
    # One may need to change the Detectron2 code to support coco_2015_test
    # insert "coco_2015_test": ("coco/test2015", "coco/annotations/image_info_test2015.json"),
    # at: https://github.com/facebookresearch/detectron2/blob/master/detectron2/data/datasets/builtin.py#L36
    'coco_2015_test': 'test2015',
}
def extract_grid_feature_argument_parser():
    """
    Build the command-line parser for the grid feature extraction script.

    Recognized arguments:
        --config-file: path to the config file (defaults to an empty string).
        --dataset: one of the supported dataset names
            (defaults to ``coco_2014_train``).
        opts: all remaining command-line tokens, collected as a list and
            meant for overriding config options.

    Returns:
        The configured ``argparse.ArgumentParser``; nothing is parsed here.
    """
    supported_datasets = ['coco_2014_train', 'coco_2014_val', 'coco_2015_test']

    cli = argparse.ArgumentParser(description="Grid feature extraction")
    cli.add_argument(
        "--config-file",
        metavar="FILE",
        default="",
        help="path to config file",
    )
    cli.add_argument(
        "--dataset",
        choices=supported_datasets,
        default="coco_2014_train",
        help="name of the dataset",
    )
    # REMAINDER gathers every trailing token into ``opts``.
    cli.add_argument(
        "opts",
        nargs=argparse.REMAINDER,
        default=None,
        help="Modify config options using the command-line",
    )
    return cli
def extract_grid_feature_on_dataset(model, data_loader, dump_folder):
    """
    Compute grid features for every batch of ``data_loader`` and save them.

    Each batch is preprocessed by the model, passed through the backbone and
    then through ``model.roi_heads.get_conv5_features``. The result is moved
    to CPU and written with ``torch.save`` to ``<image_id>.pth`` inside
    ``dump_folder``.

    Args:
        model: object exposing ``preprocess_image``, ``backbone`` and
            ``roi_heads.get_conv5_features``.
        data_loader: iterable of batches; the first element of each batch
            must be a mapping with a numeric ``image_id`` entry.
        dump_folder: directory the ``.pth`` files are written into.

    Note:
        The output file is named after ``inputs[0]`` only, while features are
        computed for the whole batch -- presumably the loader yields one
        image per batch; verify against the data loader.
    """
    # The batch index is not needed, so iterate the progress bar directly.
    for inputs in tqdm.tqdm(data_loader):
        with torch.no_grad():
            image_id = inputs[0]['image_id']
            file_name = '%d.pth' % image_id
            # compute features
            images = model.preprocess_image(inputs)
            features = model.backbone(images.tensor)
            outputs = model.roi_heads.get_conv5_features(features)
            with PathManager.open(os.path.join(dump_folder, file_name), "wb") as f:
                # save as CPU tensors
                torch.save(outputs.cpu(), f)
def do_feature_extraction(cfg, model, dataset_name):
    """
    Extract grid features for one dataset and dump them to disk.

    The dump directory is ``<cfg.OUTPUT_DIR>/features/<folder>``, where
    ``<folder>`` is looked up in ``dataset_to_folder_mapper``. The directory
    is created, a test data loader is built for ``dataset_name``, and the
    extraction loop runs inside ``inference_context(model)``.

    Args:
        cfg: config object providing ``OUTPUT_DIR``; also passed to the
            data loader builder.
        model: model whose features are extracted.
        dataset_name: key of ``dataset_to_folder_mapper``.

    Raises:
        KeyError: if ``dataset_name`` is not in ``dataset_to_folder_mapper``.
    """
    with inference_context(model):
        output_root = cfg.OUTPUT_DIR
        folder_name = dataset_to_folder_mapper[dataset_name]
        target_dir = os.path.join(output_root, "features", folder_name)
        PathManager.mkdirs(target_dir)
        loader = build_detection_test_loader_with_attributes(cfg, dataset_name)
        extract_grid_feature_on_dataset(model, loader, target_dir)
def setup(args):
    """
    Build the frozen config for this run and perform the default setup.

    Starts from ``get_cfg()``, adds the attribute config entries, merges the
    config file named by ``args.config_file`` and then the overrides in
    ``args.opts``. ``MODEL.RESNETS.RES5_DILATION`` is set to 1 after both
    merges, so it wins over any value they supply.

    Args:
        args: parsed command-line arguments providing ``config_file`` and
            ``opts``; also forwarded to ``default_setup``.

    Returns:
        The frozen config object.
    """
    config = get_cfg()
    add_attribute_config(config)

    # Later merges take precedence: file first, then command-line overrides.
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)

    # force the final residual block to have dilations 1
    config.MODEL.RESNETS.RES5_DILATION = 1

    config.freeze()
    default_setup(config, args)
    return config
def main(args):
    """
    Run grid feature extraction end to end.

    Builds the config and the model, loads weights through
    ``DetectionCheckpointer.resume_or_load`` with ``resume=True``, and then
    extracts features for ``args.dataset``.

    Args:
        args: parsed command-line arguments (see
            ``extract_grid_feature_argument_parser``).
    """
    cfg = setup(args)
    model = build_model(cfg)

    checkpointer = DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR)
    checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=True)

    do_feature_extraction(cfg, model, args.dataset)
# Script entry point: parse the command line, echo it, and run the extraction.
if __name__ == "__main__":
    parser = extract_grid_feature_argument_parser()
    args = parser.parse_args()
    print("Command Line Args:", args)
    main(args)