import json from torch.utils import data from torchvision.datasets import ImageFolder import torch import os from PIL import Image import numpy as np import argparse from tqdm import tqdm from munkres import Munkres import multiprocessing from multiprocessing import Process, Manager import collections import torchvision.transforms as transforms import torchvision.transforms.functional as TF import random import torchvision import cv2 from label_str_to_imagenet_classes import label_str_to_imagenet_classes torch.manual_seed(0) normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize, ]) class ObjectNetDataset(ImageFolder): def __init__(self, imagenet_path): self._imagenet_path = imagenet_path self._all_images = [] o_dataset = ImageFolder(self._imagenet_path) # get mappings folder mappings_folder = os.path.abspath( os.path.join(self._imagenet_path, "../mappings") ) # get ObjectNet label to ImageNet label mapping with open( os.path.join(mappings_folder, "objectnet_to_imagenet_1k.json") ) as file_handle: o_label_to_all_i_labels = json.load(file_handle) # now remove double i labels to avoid confusion o_label_to_i_labels = { o_label: all_i_label.split("; ") for o_label, all_i_label in o_label_to_all_i_labels.items() } # some in-between mappings ... o_folder_to_o_idx = o_dataset.class_to_idx with open( os.path.join(mappings_folder, "folder_to_objectnet_label.json") ) as file_handle: o_folder_o_label = json.load(file_handle) # now get mapping from o_label to o_idx o_label_to_o_idx = { o_label: o_folder_to_o_idx[o_folder] for o_folder, o_label in o_folder_o_label.items() } # some in-between mappings ... with open( os.path.join(mappings_folder, "pytorch_to_imagenet_2012_id.json") ) as file_handle: i_idx_to_i_line = json.load(file_handle) with open( os.path.join(mappings_folder, "imagenet_to_label_2012_v2") ) as file_handle: i_line_to_i_label = file_handle.readlines() i_line_to_i_label = { i_line: i_label[:-1] for i_line, i_label in enumerate(i_line_to_i_label) } # now get mapping from i_label to i_idx i_label_to_i_idx = { i_line_to_i_label[i_line]: int(i_idx) for i_idx, i_line in i_idx_to_i_line.items() } # now get the final mapping of interest!!! o_idx_to_i_idxs = { o_label_to_o_idx[o_label]: [ i_label_to_i_idx[i_label] for i_label in i_labels ] for o_label, i_labels in o_label_to_i_labels.items() } self._tag_list = [] # now get a list of files of interest for filepath, o_idx in o_dataset.samples: if o_idx not in o_idx_to_i_idxs: continue rel_file = os.path.relpath(filepath, self._imagenet_path) if o_idx_to_i_idxs[o_idx][0] not in self._tag_list: self._tag_list.append(o_idx_to_i_idxs[o_idx][0]) self._all_images.append((rel_file, o_idx_to_i_idxs[o_idx][0])) def __getitem__(self, item): image_path, classification = self._all_images[item] image_path = os.path.join(self._imagenet_path, image_path) image = Image.open(image_path) image = image.convert('RGB') image = transform(image) return image, classification def __len__(self): return len(self._all_images)