PolyFormer / data /create_pretraining_data.py
jiang
init commit
650c5f6
import json
import os
from tqdm import tqdm
import random
import pickle
# set up image paths
imgsfile = dict(
coco='mscoco/train2014',
vg='visual-genome',
saiaprtc12='saiaprtc12',
flickr='flickr30k'
)
# load annotation files
f = open("datasets/annotations/instances.json")
print("Loading annotation file")
data = json.load(f)
f.close()
# load the validation and test image list of refcoco, refcoco+, and refcocog
val_test_files = pickle.load(open("data/val_test_files.p", "rb"))
# create result folder
os.makedirs("datasets/pretrain", exist_ok=True)
# generate training tsv file
train_instances = data['train']
tsv_filename = "datasets/pretrain/train_shuffled.tsv"
writer = open(tsv_filename, 'w')
print("generating ", tsv_filename)
lines = []
for i, data_i in enumerate(tqdm(train_instances)):
data_source = data_i['data_source']
image_id = data_i['image_id']
bbox = data_i['bbox']
expressions = data_i['expressions']
height, width = data_i['height'], data_i['width']
x, y, w, h = bbox
box_string = f'{x},{y},{x + w},{y + h}'
img_name = "COCO_train2014_%012d.jpg" if "coco" in data_source else "%d.jpg"
img_name = img_name % image_id
filepath = os.path.join(imgsfile[data_source], img_name)
line = '\t'.join([str(i), expressions[0].replace('\n', ''), box_string, filepath]) + '\n'
lines.append(line)
# shuffle the training set
random.shuffle(lines)
# write training tsv file
writer.writelines(lines)
writer.close()
# generate validation tsv files
val_sets = ['val_refcoco_unc', 'val_refcocoplus_unc', 'val_refcocog_umd', 'val_flickr30k', 'val_referitgame_berkeley']
for val_set in val_sets:
val_instances = data[val_set]
tsv_filename = f"datasets/pretrain/{val_set}.tsv"
writer = open(tsv_filename, 'w')
print("generating ", tsv_filename)
lines = []
for i, data_i in enumerate(tqdm(val_instances)):
data_source = data_i['data_source']
image_id = data_i['image_id']
bbox = data_i['bbox']
expressions = data_i['expressions']
height, width = data_i['height'], data_i['width']
x, y, w, h = bbox
box_string = f'{x},{y},{x + w},{y + h}'
img_name = "COCO_train2014_%012d.jpg" if "coco" in data_source else "%d.jpg"
img_name = img_name % image_id
filepath = os.path.join(imgsfile[data_source], img_name)
line = '\t'.join([str(i), expressions[0].replace('\n', ''), box_string, filepath]) + '\n'
lines.append(line)
# write tsv file
writer.writelines(lines)
writer.close()