File size: 2,301 Bytes
75b01a0
a8e4fc0
 
 
75b01a0
a8e4fc0
 
 
 
75b01a0
a8e4fc0
 
75b01a0
a8e4fc0
 
 
75b01a0
 
a8e4fc0
 
75b01a0
a8e4fc0
 
 
75b01a0
a8e4fc0
 
 
75b01a0
a8e4fc0
 
 
 
75b01a0
a8e4fc0
 
75b01a0
a8e4fc0
75b01a0
a8e4fc0
 
75b01a0
a8e4fc0
 
 
 
 
 
 
 
 
 
75b01a0
a8e4fc0
 
75b01a0
a8e4fc0
 
75b01a0
a8e4fc0
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
'''
This script was adapted from Luke Melas-Kyriazi's code. (https://twitter.com/lukemelas)
A few changes were made for this particular dataset. You need to have the `.tsv` file downloaded to your directory.
The files are available here: https://github.com/google-research-datasets/conceptual-captions
'''

import sys
import os
from datetime import datetime
import pandas as pd
import contexttimer
from urllib.request import urlopen
import requests
from PIL import Image
import torch
from torchvision.transforms import functional as TF
from multiprocessing import Pool
from tqdm import tqdm
import logging
import sys

# Setup
# Per-URL failures are written to download.log; the file is truncated on
# every run (filemode='w').
logging.basicConfig(filename='download.log', filemode='w', level=logging.INFO)
# Requests are made with verify=False, which would otherwise emit an
# InsecureRequestWarning for every single download.
requests.packages.urllib3.disable_warnings(requests.packages.urllib3.exceptions.InsecureRequestWarning)

# Expected invocation: <script> <tsv file> <output directory>
if len(sys.argv) != 3:
    print("Provide .tsv file name & output directory. e.g. python downloader.py Train-GCC-training.tsv training")
    # sys.exit is always available; the bare `exit` builtin is injected by the
    # `site` module and is missing when Python is started with -S.
    sys.exit(1)

# Load data
print(f'Starting to load at {datetime.now().isoformat(timespec="minutes")}')
with contexttimer.Timer(prefix="Loading from tsv"):
    # The file has no header row; each row is expected to be (caption, url).
    df = pd.read_csv(sys.argv[1], delimiter='\t', header=None)

# Map each URL to its row index. A URL that appears more than once keeps the
# index of its last occurrence, so duplicates are downloaded only once.
url_to_idx_map = {url: index for index, caption, url in df.itertuples()}
print(f'Loaded {len(url_to_idx_map)} urls')

# Output directory, resolved relative to the current working directory.
base_dir = os.path.join(os.getcwd(), sys.argv[2])
# Create it up front: without this every save fails when the directory does
# not exist yet, and the failure only shows up as entries in the log.
os.makedirs(base_dir, exist_ok=True)

def process(item):
    """Download one image, downscale it if large, and save it as a JPEG.

    Args:
        item: ``(url, image_id)`` pair. ``image_id`` is an integer used as
            the zero-padded prefix of the output filename.

    Returns:
        None. The image is written into ``base_dir`` as
        ``{image_id:08d}---{stem}.jpg``, where ``stem`` is the URL basename
        without its extension. If that file already exists the URL is
        skipped, which makes an interrupted run resumable.

    Any exception is caught and logged (the exception at INFO level, the URL
    at ERROR level) so a single bad URL does not abort the whole run.
    """
    url, image_id = item
    try:
        stem, _ = os.path.splitext(os.path.basename(url))
        filename = f'{image_id:08d}---{stem}.jpg'
        filepath = os.path.join(base_dir, filename)
        if os.path.isfile(filepath):
            return  # already downloaded by a previous run

        # The context manager releases the connection even if decoding fails.
        with requests.get(url, stream=True, timeout=1, verify=False) as response:
            # Do not try to decode an HTTP error page as an image.
            response.raise_for_status()
            # `raw` is the undecoded byte stream; have urllib3 undo any
            # gzip/deflate Content-Encoding so PIL receives real image bytes.
            response.raw.decode_content = True
            # convert() fully loads the pixel data, so the image stays usable
            # after the response is closed.
            image = Image.open(response.raw).convert('RGB')

        # Shrink only when even the shorter side exceeds 512 px; an int
        # `size` resizes the shorter side and keeps the aspect ratio.
        if min(image.size) > 512:
            image = TF.resize(image, size=512, interpolation=Image.LANCZOS)

        # Save under a temporary name and rename afterwards, so a crash in
        # the middle of a save never leaves a truncated .jpg that the
        # isfile() check above would treat as finished.
        partial_path = filepath + '.part'
        try:
            image.save(partial_path, format='JPEG')
            os.replace(partial_path, filepath)
        finally:
            if os.path.exists(partial_path):
                os.remove(partial_path)
    except Exception as e:
        logging.info(" ".join(repr(e).splitlines()))
        logging.error(url)

# Only the process started as the script drives the download. Under the
# "spawn" start method each worker re-imports this file, and without the
# guard every worker would try to create its own pool.
if __name__ == '__main__':
    list_of_items = list(url_to_idx_map.items())
    print(len(list_of_items))

    with Pool(128) as p:
        # imap is lazy: list() drains it so every task has finished before
        # the `with` block exits and terminates the pool. tqdm shows progress.
        r = list(tqdm(p.imap(process, list_of_items), total=len(list_of_items)))
        print('DONE')