Spaces:
Paused
Paused
# Copyright (c) Facebook, Inc. and its affiliates. | |
import os | |
import json | |
import argparse | |
from PIL import Image | |
import numpy as np | |
def parse_annotation_line(line):
    """Split one annotation line into ``(caption, url)``.

    Args:
        line: A raw line from the annotation TSV, with or without a trailing
            line terminator.

    Returns:
        A ``(caption, url)`` tuple, or ``None`` when the line is blank.

    Raises:
        ValueError: If a non-blank line does not consist of exactly two
            tab-separated fields.
    """
    # Strip only the line terminator. Slicing off the last character
    # unconditionally would eat the final character of the URL when the
    # file's last line has no trailing newline.
    line = line.rstrip('\r\n')
    if not line:
        return None
    fields = line.split('\t')
    if len(fields) != 2:
        raise ValueError(
            'expected 2 tab-separated fields, got {}: {!r}'.format(
                len(fields), line))
    return fields[0], fields[1]


def download_image(url, image_file):
    """Fetch ``url`` into ``image_file`` with ``wget`` (best effort).

    A failed download is not fatal: the caller decides whether the resulting
    file is usable when it tries to open it.

    Args:
        url: Address of the image to download.
        image_file: Destination path passed to ``wget -O``.
    """
    # Function-scope import so this block does not depend on an edit to the
    # file-level import list.
    import subprocess

    try:
        # An argument list (no shell) keeps characters such as '&', ';' or
        # '?' inside the URL from being interpreted by a shell. '--' stops
        # option parsing so the URL is always treated as a positional argument.
        subprocess.run(['wget', '-O', image_file, '--', url], check=False)
    except OSError as err:
        # E.g. wget is not installed. Report and carry on, so an image that is
        # already on disk can still be indexed.
        print('Could not run wget for {}: {}'.format(url, err))


def read_image_size(image_file):
    """Return ``(height, width)`` of the image stored at ``image_file``.

    The image is converted to RGB before its size is read, so a file that
    cannot be fully decoded raises here instead of being indexed.

    Args:
        image_file: Path of the image on disk.

    Returns:
        A ``(height, width)`` tuple in pixels.
    """
    with open(image_file, 'rb') as fh:
        with Image.open(fh) as img:
            # PIL reports size as (width, height).
            width, height = img.convert('RGB').size
    return height, width


def main():
    """Download the images listed in the annotation TSV and write image info.

    Each non-blank line of ``--ann`` holds a caption and an image URL separated
    by a tab. Line ``i`` (0-based) is stored as ``<i + 1>.jpg`` under
    ``--save_image_path``. Images that cannot be opened are skipped. The
    output JSON has three keys: ``categories`` (copied from ``--cat_info``),
    ``images`` (one record per readable image) and ``annotations`` (an empty
    list).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--ann', default='datasets/cc3m/Train_GCC-training.tsv')
    parser.add_argument('--save_image_path', default='datasets/cc3m/training/')
    parser.add_argument('--cat_info', default='datasets/lvis/lvis_v1_val.json')
    parser.add_argument('--out_path', default='datasets/cc3m/train_image_info.json')
    parser.add_argument('--not_download_image', action='store_true')
    args = parser.parse_args()

    with open(args.cat_info, 'r') as fh:
        categories = json.load(fh)['categories']

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(args.save_image_path, exist_ok=True)

    images = []
    with open(args.ann) as ann_file:
        for i, line in enumerate(ann_file):
            parsed = parse_annotation_line(line)
            if parsed is None:
                # Blank line: nothing to download. Ids stay tied to the line
                # number, so skipping does not shift later ids.
                continue
            cap, path = parsed
            print(i, cap, path)

            image_id = i + 1
            file_name = '{}.jpg'.format(image_id)
            image_file = os.path.join(args.save_image_path, file_name)

            if not args.not_download_image:
                download_image(path, image_file)

            try:
                h, w = read_image_size(image_file)
            except Exception as err:
                # Missing, empty or corrupt file: leave it out of the index.
                # `Exception` (not a bare except) lets Ctrl-C stop the run.
                print('Skipping {}: {}'.format(image_file, err))
                continue

            images.append({
                'id': image_id,
                'file_name': file_name,
                'height': h,
                'width': w,
                'captions': [cap],
            })

    data = {'categories': categories, 'images': images, 'annotations': []}
    for k, v in data.items():
        print(k, len(v))
    print('Saving to', args.out_path)
    with open(args.out_path, 'w') as fh:
        json.dump(data, fh)


if __name__ == '__main__':
    main()