|
import pandas as pd |
|
import os.path |
|
import sys |
|
import json |
|
import logging |
|
import contexttimer |
|
import numpy as np |
|
|
|
|
|
# Build train/validation JSON-lines caption datasets from a TSV annotation
# file with ``caption`` and ``url`` columns.
#
# Usage: python <script> ANNOTATIONS.tsv IMAGES_DIR OUTPUT_PREFIX
#
# Only rows whose image exists in IMAGES_DIR (named
# ``{row_index:08d}---{url_stem}.jpg``) are kept. Results are written to
# ``OUTPUT_PREFIX_train.json`` and ``OUTPUT_PREFIX_val.json``, one JSON object
# per line. Missing images are logged to ``download.log``.

# Number of trailing records written to the validation split. Kept at the
# original value; NOTE(review): 100_001 may have been intended as 100_000 —
# confirm before changing.
VALID_SIZE = 100_001


def _image_filename(index, image_url):
    """Return the file name expected for the image of TSV row *index*.

    The name is ``{index:08d}---{stem}.jpg``, where *stem* is the basename of
    *image_url* with its extension removed. The ``.jpg`` suffix is fixed
    regardless of the URL's own extension.
    """
    stem, _ext = os.path.splitext(os.path.basename(image_url))
    return f"{index:08d}---{stem}.jpg"


def _load_annotations(annotation_file):
    """Read the TSV and return its ``caption``/``url`` columns without empty rows.

    The original row index is preserved (``dropna`` does not reset it).
    ``_image_filename`` relies on that index to reconstruct the names of the
    downloaded files.
    """
    with contexttimer.Timer(prefix="Loading from tsv"):
        df = pd.read_csv(annotation_file, delimiter="\t")
    df = df[["caption", "url"]]
    # Treat empty strings like missing values so both kinds of row are dropped.
    return df.replace("", np.nan).dropna()


def _build_records(df, images_dir):
    """Return one JSON string per row whose image file exists in *images_dir*.

    Each record is ``{"image_path": <path>, "captions": [<caption>]}``.
    Rows whose image file is missing are logged at ERROR level and skipped.
    """
    records = []
    for index, caption, image_url in df.itertuples():
        full_image_path = os.path.join(images_dir, _image_filename(index, image_url))
        if os.path.isfile(full_image_path):
            records.append(
                json.dumps({"image_path": full_image_path, "captions": [caption]})
            )
        else:
            logging.error(full_image_path)
    return records


def _split_train_val(lines, valid_size=VALID_SIZE):
    """Split *lines* into ``(train, valid)``, holding out the last *valid_size*.

    If there are no more than *valid_size* lines, everything goes to the
    validation split, and a warning is logged because train is then empty.
    A non-positive *valid_size* puts everything in train.
    """
    if valid_size <= 0:
        # lines[:-0] would be empty, so "no validation split" needs its own branch.
        return list(lines), []
    train, valid = lines[:-valid_size], lines[-valid_size:]
    if lines and not train:
        logging.warning(
            "Only %d records for a validation size of %d; train split is empty.",
            len(lines),
            valid_size,
        )
    return train, valid


def _write_json_lines(path, lines):
    """Write *lines* to *path*, newline-separated (no trailing newline)."""
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))


def main(argv=None):
    """Run the conversion described at the top of this section.

    Args:
        argv: Argument vector ``[prog, annotations_tsv, images_dir, output_prefix]``;
            defaults to ``sys.argv``.

    Returns:
        Process exit status: 0 on success, 1 on a usage error.
    """
    argv = sys.argv if argv is None else argv
    logging.basicConfig(filename='download.log', filemode='w', level=logging.INFO)

    if len(argv) != 4:
        print(
            "Provide .tsv file name, images dir, output file name. e.g. python coco.py coco_captions_train2017.json /mnt/disks/data-1/flickr8k/coco_train.json coco_dataset_train.json",
            file=sys.stderr,
        )
        return 1

    annotation_file, images_dir, output_file = argv[1:4]

    logging.info("Processing subcaption dataset")

    df = _load_annotations(annotation_file)
    print(f"Loaded {len(df)} images.")

    lines = _build_records(df, images_dir)
    train_lines, valid_lines = _split_train_val(lines)

    _write_json_lines(output_file + "_train.json", train_lines)
    _write_json_lines(output_file + "_val.json", valid_lines)

    logging.info("Processing subcaption dataset done. %d images processed.", len(lines))
    return 0


if __name__ == "__main__":
    sys.exit(main())
|
|
|
|