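"""Build a COCO-style image-info JSON for the ImageNet-21k tar dataset.

Each tar shard holds the images of one WordNet synset. The script turns the
shard list into a category list (via NLTK's WordNet interface), then iterates
over every image to record its size and position in the shards. The output
JSON contains 'categories', 'images', and an empty 'annotations' list.
"""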
import argparse
import json
import operator
import sys
import time

import numpy as np
import torch
from nltk.corpus import wordnet
from tqdm import tqdm

# The vendored CenterNet2 and Deformable-DETR checkouts must be on sys.path
# before detic itself is imported.
sys.path.insert(0, 'third_party/CenterNet2/projects/CenterNet2/')
sys.path.insert(0, 'third_party/Deformable-DETR')
from detic.data.tar_dataset import DiskTarDataset


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--imagenet_dir", default='datasets/imagenet/ImageNet-21k/')
    parser.add_argument("--tarfile_path", default='datasets/imagenet/metadata-22k/tar_files.npy',
                        help='npy array listing the per-synset tar shards')
    parser.add_argument("--tar_index_dir", default='datasets/imagenet/metadata-22k/tarindex_npy',
                        help='directory with the precomputed per-shard index files')
    parser.add_argument("--out_path", default='datasets/imagenet/annotations/imagenet-22k_image_info.json',
                        help='where to write the image-info JSON')
    parser.add_argument("--workers", default=16, type=int,
                        help='number of DataLoader workers')
    args = parser.parse_args()

    # Build the dataset that reads images directly from the tar shards,
    # using the precomputed per-shard indices.
    start_time = time.time()
    print('Building dataset')
    dataset = DiskTarDataset(args.tarfile_path, args.tar_index_dir)
    end_time = time.time()
    print(f"Took {end_time - start_time:.1f} seconds to build the dataset.")
    print(f"Have {len(dataset)} samples.")
    print('dataset', dataset)

    # One category per tar shard: shard filenames are WordNet synset ids
    # ("wnid"s), e.g. '<...>/n01440764.tar'.
    tar_files = np.load(args.tarfile_path)
    categories = []
    for i, tar_file in enumerate(tar_files):
        # A wnid is 'n' plus an 8-digit offset (9 chars), followed by '.tar'.
        wnid = tar_file[-13:-4]
        synset = wordnet.synset_from_pos_and_offset('n', int(wnid[1:]))
        synonyms = [x.name() for x in synset.lemmas()]
        category = {
            'id': i + 1,  # 1-indexed, COCO-style
            'synset': synset.name(),
            'name': synonyms[0],
            'def': synset.definition(),
            'synonyms': synonyms,
        }
        categories.append(category)
    print('categories', len(categories))
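    # For instance (values from WordNet 3.0), the shard 'n01440764.tar' maps
    # to synset 'tench.n.01' with synonyms ['tench', 'Tinca_tinca'].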

    # Iterate over every image once to record its size. batch_size=1 with
    # itemgetter(0) as collate_fn hands back each (img, label, index) sample
    # unbatched; the workers parallelize the tar reads and image decoding.
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False,
        num_workers=args.workers,
        collate_fn=operator.itemgetter(0),
    )
    images = []
    for img, label, index in tqdm(data_loader):
        if label == -1:
            continue  # the dataset flags unreadable samples with label -1
        image = {
            'id': int(index) + 1,
            'pos_category_ids': [int(label) + 1],  # match the 1-indexed category ids
            'height': int(img.height),
            'width': int(img.width),
            'tar_index': int(index),
        }
        images.append(image)

    data = {'categories': categories, 'images': images, 'annotations': []}
    for k, v in data.items():
        print(k, len(v))
    print('Saving to', args.out_path)
    with open(args.out_path, 'w') as f:
        json.dump(data, f)
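    # A minimal sanity check of the resulting JSON, if desired:
    #   data = json.load(open(args.out_path))
    #   assert len(data['categories']) == len(tar_files)
    #   assert all(im['height'] > 0 and im['width'] > 0 for im in data['images'])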