#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
import argparse
import os
from itertools import chain
import cv2
import tqdm

from detectron2.config import get_cfg
from detectron2.data import DatasetCatalog, MetadataCatalog, build_detection_train_loader
from detectron2.data import detection_utils as utils
from detectron2.data.build import filter_images_with_few_keypoints
from detectron2.utils.logger import setup_logger
from detectron2.utils.visualizer import Visualizer


def setup(args):
    cfg = get_cfg()
    if args.config_file:
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.DATALOADER.NUM_WORKERS = 0
    cfg.freeze()
    return cfg


def parse_args(in_args=None):
    parser = argparse.ArgumentParser(description="Visualize ground-truth data")
    parser.add_argument(
        "--source",
        choices=["annotation", "dataloader"],
        required=True,
        help="visualize the annotations or the data loader (with pre-processing)",
    )
    parser.add_argument("--config-file", metavar="FILE", help="path to config file")
    parser.add_argument("--output-dir", default="./", help="path to output directory")
    parser.add_argument("--show", action="store_true", help="show output in a window")
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    return parser.parse_args(in_args)


if __name__ == "__main__":
    args = parse_args()
    logger = setup_logger()
    logger.info("Arguments: " + str(args))
    cfg = setup(args)

    dirname = args.output_dir
    os.makedirs(dirname, exist_ok=True)
    metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])

    def output(vis, fname):
        if args.show:
            print(fname)
            cv2.imshow("window", vis.get_image()[:, :, ::-1])
            cv2.waitKey()
        else:
            filepath = os.path.join(dirname, fname)
            print("Saving to {} ...".format(filepath))
            vis.save(filepath)

    scale = 1.0
    if args.source == "dataloader":
        train_data_loader = build_detection_train_loader(cfg)
        for batch in train_data_loader:
            for per_image in batch:
                # Pytorch tensor is in (C, H, W) format
                img = per_image["image"].permute(1, 2, 0).cpu().detach().numpy()
                img = utils.convert_image_to_rgb(img, cfg.INPUT.FORMAT)

                visualizer = Visualizer(img, metadata=metadata, scale=scale)
                target_fields = per_image["instances"].get_fields()
                labels = [metadata.thing_classes[i] for i in target_fields["gt_classes"]]
                vis = visualizer.overlay_instances(
                    labels=labels,
                    boxes=target_fields.get("gt_boxes", None),
                    masks=target_fields.get("gt_masks", None),
                    keypoints=target_fields.get("gt_keypoints", None),
                )
                output(vis, str(per_image["image_id"]) + ".jpg")
    else:
        dicts = list(chain.from_iterable([DatasetCatalog.get(k) for k in cfg.DATASETS.TRAIN]))
        if cfg.MODEL.KEYPOINT_ON:
            dicts = filter_images_with_few_keypoints(dicts, 1)
        for dic in tqdm.tqdm(dicts):
            img = utils.read_image(dic["file_name"], "RGB")
            visualizer = Visualizer(img, metadata=metadata, scale=scale)
            vis = visualizer.draw_dataset_dict(dic)
            output(vis, os.path.basename(dic["file_name"]))
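
# Example invocations (a sketch, not part of the original script): the config path below is
# an assumption about where a standard detectron2 config lives relative to this file, and the
# dataset named in cfg.DATASETS.TRAIN must already be registered in DatasetCatalog.
#
#   ./visualize_data.py --source annotation \
#       --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml \
#       --output-dir ./vis_gt
#
#   ./visualize_data.py --source dataloader \
#       --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml --show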