File size: 1,683 Bytes
ba7a003
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import pandas as pd
import os.path
import sys
import json
import logging
import contexttimer
import numpy as np

# Setup
logging.basicConfig(filename='download.log', filemode='w', level=logging.INFO)

if len(sys.argv) != 4:
    print("Provide .tsv file name, images dir, output file name. e.g. python coco.py coco_captions_train2017.json /mnt/disks/data-1/flickr8k/coco_train.json coco_dataset_train.json")
    exit(1)

annotation_file = sys.argv[1]
images_dir = sys.argv[2]
output_file = sys.argv[3]

logging.info("Processing cc3m dataset")

with contexttimer.Timer(prefix="Loading from tsv"):
    df = pd.read_csv(annotation_file, delimiter='\t')

lines = []

df = df[["caption", "url", "index_row"]]

df = df.replace('', np.nan)
df = df.dropna()

print(f"Loaded {len(df)} images.")

for index, caption_reference_description, image_url, index_row in df.itertuples():
    index_row+=1
    base_url = os.path.basename(image_url)  # extract base url
    stem, ext = os.path.splitext(base_url)  # split into stem and extension
    filename = f'{index_row:08d}---{stem}.jpg'

    full_image_path =  images_dir+"/"+filename

    if os.path.isfile(full_image_path):
        lines.append(json.dumps({"image_path": full_image_path, "captions": [caption_reference_description]}))
    else:
        #print(f"{full_image_path} doesn't exist")
        logging.error(full_image_path)

train_lines = lines[:-300_001]
valid_lines = lines[-300_001:]

with open(output_file+"_train.json", "w") as f:
    f.write("\n".join(train_lines))

with open(output_file+"_val.json", "w") as f:
    f.write("\n".join(valid_lines))

logging.info(f"Processing cc3m dataset done. {len(lines)} images processed.")