# Detic / tools / preprocess_imagenet22k.py
# AK391
# files
# 159f437
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
import os
import numpy as np
import sys
sys.path.insert(0, 'third_party/CenterNet2/projects/CenterNet2/')
sys.path.insert(0, 'third_party/Deformable-DETR')
from detic.data.tar_dataset import _TarDataset, DiskTarDataset
import pickle
import io
import gzip
import time
class _RawTarDataset(object):
def __init__(self, filename, indexname, preload=False):
self.filename = filename
self.names = []
self.offsets = []
for l in open(indexname):
ll = l.split()
a, b, c = ll[:3]
offset = int(b[:-1])
if l.endswith('** Block of NULs **\n'):
self.offsets.append(offset)
break
else:
if c.endswith('JPEG'):
self.names.append(c)
self.offsets.append(offset)
else:
# ignore directories
pass
if preload:
self.data = np.memmap(filename, mode='r', dtype='uint8')
else:
self.data = None
def __len__(self):
return len(self.names)
def __getitem__(self, idx):
if self.data is None:
self.data = np.memmap(self.filename, mode='r', dtype='uint8')
ofs = self.offsets[idx] * 512
fsize = 512 * (self.offsets[idx + 1] - self.offsets[idx])
data = self.data[ofs:ofs + fsize]
if data[:13].tostring() == '././@LongLink':
data = data[3 * 512:]
else:
data = data[512:]
# just to make it more fun a few JPEGs are GZIP compressed...
# catch this case
if tuple(data[:2]) == (0x1f, 0x8b):
s = io.StringIO(data.tostring())
g = gzip.GzipFile(None, 'r', 0, s)
sdata = g.read()
else:
sdata = data.tostring()
return sdata
def preprocess(i22kdir='/datasets01/imagenet-22k/062717/',
               i22ktarlogs='/checkpoint/imisra/datasets/imagenet-22k/tarindex',
               class_names_file='/checkpoint/imisra/datasets/imagenet-22k/words.txt',
               output_dir='/checkpoint/zhouxy/Datasets/ImageNet/metadata-22k/',
               i22knpytarlogs=None,
               min_count=0,
               create_npy_tarlogs=True):
    """Build the ImageNet-22k metadata arrays from per-synset tar indexes.

    For every ``<synset>.tarlog`` found in ``i22ktarlogs`` this optionally
    converts the text index into ``<synset>_names.npy`` /
    ``<synset>_offsets.npy``, counts the samples of the matching tar via
    ``_TarDataset``, and finally saves ``tarlog_files.npy``,
    ``tar_files.npy`` and ``class_names.npy`` into ``output_dir`` for all
    synsets with more than ``min_count`` samples.

    Called without arguments it uses the original hard-coded locations.

    Args:
        i22kdir: Directory holding one ``<synset>.tar`` per class.
        i22ktarlogs: Directory holding the ``<synset>.tarlog`` text indexes.
        class_names_file: Tab-separated ``synset<TAB>name`` file.
        output_dir: Directory receiving the three metadata arrays.
        i22knpytarlogs: Directory for the ``.npy`` indexes; defaults to
            ``<output_dir>/tarindex_npy``.
        min_count: Synsets with at most this many samples are dropped.
        create_npy_tarlogs: Whether to (re)generate the ``.npy`` indexes.

    Raises:
        KeyError: If a kept synset is missing from ``class_names_file``.
    """
    # Follow https://github.com/Alibaba-MIIL/ImageNet21K/blob/main/dataset_preprocessing/processing_script.sh
    # Expect 12358684 samples with 11221 classes
    # ImageNet folder has 21841 classes (synsets)
    if i22knpytarlogs is None:
        i22knpytarlogs = os.path.join(output_dir, 'tarindex_npy')

    print('Listing dir')
    log_files = sorted(
        x for x in os.listdir(i22ktarlogs) if x.endswith(".tarlog"))
    synsets = [log_file.replace(".tarlog", "") for log_file in log_files]

    print('Creating folders')
    if create_npy_tarlogs:
        os.makedirs(i22knpytarlogs, exist_ok=True)
        for syn in synsets:
            dataset = _RawTarDataset(os.path.join(i22kdir, syn + ".tar"),
                                     os.path.join(i22ktarlogs, syn + ".tarlog"),
                                     preload=False)
            names = np.array(dataset.names)
            offsets = np.array(dataset.offsets, dtype=np.int64)
            np.save(os.path.join(i22knpytarlogs, f"{syn}_names.npy"), names)
            np.save(os.path.join(i22knpytarlogs, f"{syn}_offsets.npy"), offsets)
    os.makedirs(output_dir, exist_ok=True)

    # Count the samples per synset through the .npy-backed dataset
    # (presumably it reads the arrays written above -- confirm in
    # detic.data.tar_dataset).
    start_time = time.time()
    dataset_lens = []
    for syn in synsets:
        dataset = _TarDataset(os.path.join(i22kdir, syn + ".tar"), i22knpytarlogs)
        dataset_lens.append(len(dataset))
    end_time = time.time()
    print(f"Time {end_time - start_time}")

    dataset_lens = np.array(dataset_lens)
    dataset_valid = dataset_lens > min_count

    # Map synset id -> human-readable class name.
    syn2class = {}
    with open(class_names_file) as fh:
        for line in fh:
            line = line.strip().split("\t")
            syn2class[line[0]] = line[1]

    tarlog_files = []
    class_names = []
    tar_files = []
    for syn, valid in zip(synsets, dataset_valid):
        if not valid:
            continue
        tarlog_files.append(os.path.join(i22ktarlogs, syn + ".tarlog"))
        tar_files.append(os.path.join(i22kdir, syn + ".tar"))
        class_names.append(syn2class[syn])

    tarlog_files = np.array(tarlog_files)
    tar_files = np.array(tar_files)
    class_names = np.array(class_names)
    print(f"Have {len(class_names)} classes and {dataset_lens[dataset_valid].sum()} samples")

    # Each array is written exactly once.
    np.save(os.path.join(output_dir, "tarlog_files.npy"), tarlog_files)
    np.save(os.path.join(output_dir, "tar_files.npy"), tar_files)
    np.save(os.path.join(output_dir, "class_names.npy"), class_names)
# Script entry point: run the preprocessing only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    preprocess()