boris committed
Commit 1055c3d
Parents: 8b9d1f5 3df3a47

Merge pull request #7 from khalidsaifullaah/main

Files changed (2)
  1. data/CC12M_downloader.py +91 -0
  2. data/CC3M_downloader.py +46 -140
data/CC12M_downloader.py ADDED
@@ -0,0 +1,91 @@
+ # Luke Melas-Kyriazi's code. (https://twitter.com/lukemelas)
+
+ #%%
+ import sys
+ import os
+ from datetime import datetime
+ import pandas as pd
+ import contexttimer
+ from urllib.request import urlopen
+ import requests
+ from PIL import Image
+ import torch
+ from torchvision.transforms import functional as TF
+ from multiprocessing import Pool
+ from tqdm import tqdm
+ import logging
+
+ # Setup
+ logging.basicConfig(filename='download.log', filemode='w', level=logging.INFO)
+ requests.packages.urllib3.disable_warnings(requests.packages.urllib3.exceptions.InsecureRequestWarning)
+
+
+ # # For downloading SVG images (I can't get this to work)
+ # from io import BytesIO
+ # import cairosvg
+
+ #%%
+ # Load data
+ print(f'Starting to load at {datetime.now().isoformat(timespec="minutes")}')
+ with contexttimer.Timer(prefix="Loading from tsv"):
+     df = pd.read_csv('./cc12m.tsv', delimiter='\t', header=None)
+
+ url_to_idx_map = {url: index for index, url, caption in df.itertuples()}
+ print(f'Loaded {len(url_to_idx_map)} urls')
+
+ #%%
+ df.head()
+
+ #%%
+
+ # Note: it seems that there are no SVG images
+ df.sample(10000)[1].str.contains('.svg').sum()
+
+ #%%
+ # Resize function
+ def resize(img):
+     max_size_of_short_side = 512
+     if min(img.size) > max_size_of_short_side:
+         img = TF.resize(img, size=max_size_of_short_side, interpolation=Image.LANCZOS)
+     return img
+
+ base_dir = os.path.join(os.getcwd(), 'images')
+
+ def process(item):
+     url, image_id = item
+     try:
+         base_url = os.path.basename(url)  # extract base url
+         stem, ext = os.path.splitext(base_url)  # split into stem and extension
+         filename = f'{image_id:08d}---{stem}.jpg'  # create filename
+         filepath = os.path.join(base_dir, filename)  # concat to get filepath
+         if not os.path.isfile(filepath):
+             # if filepath.endswith('.svg'):
+             #     raise NotImplementedError()
+             #     image_bytes = BytesIO()  # create a bytestream
+             #     cairosvg.svg2png(url=url, write_to=image_bytes)  # convert svg into image
+             # else:
+             req = requests.get(url, stream=True, timeout=1, verify=False).raw
+             image = Image.open(req).convert('RGB')
+             if min(image.size) > 512:
+                 image = TF.resize(image, size=512, interpolation=Image.LANCZOS)
+             # image = resize(image)  # resize PIL image
+             image.save(filepath)  # save PIL image
+     except Exception as e:
+         logging.info(" ".join(repr(e).splitlines()))
+         logging.error(url)
+
+ #%%
+ # for i, item in enumerate(tqdm(url_to_idx_map.items(), total=len(url_to_idx_map))):
+ #     process(item)
+ #     if i > 100:
+ #         break
+
+ # Use multiprocessing for speed
+ list_of_items = list(url_to_idx_map.items())
+ print(len(list_of_items))
+ list_of_items = list_of_items[10_000_000:]  # skip the first 10M urls (presumably already processed in an earlier run)
+ print(len(list_of_items))
+ with Pool(128) as p:
+     r = list(tqdm(p.imap(process, list_of_items), total=len(list_of_items)))
+ print('DONE')
+
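For context on the new CC12M script: the filename scheme in process maps each (url, image_id) pair to a deterministic local path, which is what lets re-runs skip files that already exist. A minimal sketch of that mapping (the sample url and id below are hypothetical, chosen only to illustrate it):

import os

# Hypothetical record, shaped like one entry of url_to_idx_map.items()
url, image_id = 'https://example.com/photos/cat.png', 42

stem, ext = os.path.splitext(os.path.basename(url))  # 'cat', '.png'
filename = f'{image_id:08d}---{stem}.jpg'            # zero-padded id + url stem
print(filename)                                      # prints: 00000042---cat.jpg

Note that the script saves into an images/ directory under the working directory; if that directory does not exist, every save raises, is caught by the except clause, and is merely logged to download.log.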
data/CC3M_downloader.py CHANGED
@@ -1,156 +1,62 @@
- # It expects you to have the train and validation `.tsv` file downloaded in the current directory
- # Head around to this link to download the `.tsv` files
- # https://ai.google.com/research/ConceptualCaptions/download
-
  '''
- This script was adapted from https://github.com/igorbrigadir/DownloadConceptualCaptions
- Few changes were made post that (excluding the post processing of data). We'll have
- only csv file with image url and captions written in different languages but not images
- as we do not own any of the images in the dataset and hence cannot legally provide them to you.
  '''
  import pandas as pd
- import numpy as np
  import requests
- import zlib
- import os
- import shelve
- import magic
  from multiprocessing import Pool
  from tqdm import tqdm

- headers = {
-     'User-Agent': 'Googlebot-Image/1.0',  # Pretend to be googlebot
-     'X-Forwarded-For': '64.18.15.200'
- }
-
- def _df_split_apply(tup_arg):
-     split_ind, subset, func = tup_arg
-     r = subset.apply(func, axis=1)
-     return (split_ind, r)

- def df_multiprocess(df, processes, chunk_size, func, dataset_name):
-     print("Generating parts...")
-     with shelve.open('%s_%s_%s_results.tmp' % (dataset_name, func.__name__, chunk_size)) as results:
-         pbar = tqdm(total=len(df), position=0)
-         # Resume:
-         finished_chunks = set([int(k) for k in results.keys()])
-         pbar.desc = "Resuming"
-         for k in results.keys():
-             pbar.update(len(results[str(k)][1]))
-
-         pool_data = ((index, df[i:i + chunk_size], func) for index, i in enumerate(range(0, len(df), chunk_size)) if index not in finished_chunks)
-         print(int(len(df) / chunk_size), "parts.", chunk_size, "per part.", "Using", processes, "processes")
-
-         pbar.desc = "Downloading"
-         with Pool(processes) as pool:
-             for i, result in enumerate(pool.imap_unordered(_df_split_apply, pool_data, 2)):
-                 results[str(result[0])] = result
-                 pbar.update(len(result[1]))
-         pbar.close()
-
-     print("Finished Downloading.")
-     return
-
- # Unique name based on url
- def _file_name(row):
-     return "%s/%s_%s" % (row['folder'], row.name, (zlib.crc32(row['url'].encode('utf-8')) & 0xffffffff))
-
- # For checking mimetypes separately without download
- def check_mimetype(row):
-     if os.path.isfile(str(row['file'])):
-         row['mimetype'] = magic.from_file(row['file'], mime=True)
-         row['size'] = os.stat(row['file']).st_size
-     return row
-
- # Don't download image, just check with a HEAD request, can't resume.
- # Can use this instead of download_image to get HTTP status codes.
- def check_download(row):
-     fname = _file_name(row)
      try:
-         # not all sites will support HEAD
-         response = requests.head(row['url'], stream=False, timeout=5, allow_redirects=True, headers=headers)
-         row['status'] = response.status_code
-         row['headers'] = dict(response.headers)
-     except:
-         # log errors later, set error as 408 timeout
-         row['status'] = 408
-         return row
-     if response.ok:
-         row['file'] = fname
-     return row
-
- def download_image(row):
-     fname = _file_name(row)
-     # Skip already downloaded, retry others later
-     if os.path.isfile(fname):
-         row['status'] = 200
-         row['file'] = fname
-         row['mimetype'] = magic.from_file(row['file'], mime=True)
-         row['size'] = os.stat(row['file']).st_size
-         return row
-
-     try:
-         # use smaller timeout to skip errors, but can result in failed downloads
-         response = requests.get(row['url'], stream=False, timeout=10, allow_redirects=True, headers=headers)
-         row['status'] = response.status_code
-         #row['headers'] = dict(response.headers)
      except Exception as e:
-         # log errors later, set error as 408 timeout
-         row['status'] = 408
-         return row
-
-     if response.ok:
-         try:
-             with open(fname, 'wb') as out_file:
-                 # some sites respond with gzip transport encoding
-                 response.raw.decode_content = True
-                 out_file.write(response.content)
-             row['mimetype'] = magic.from_file(fname, mime=True)
-             row['size'] = os.stat(fname).st_size
-         except:
-             # This is if it times out during a download or decode
-             row['status'] = 408
-             return row
-     row['file'] = fname
-     return row
-
- def open_tsv(fname, folder):
-     print("Opening %s Data File..." % fname)
-     df = pd.read_csv(fname, sep='\t', names=["caption", "url"], usecols=range(1, 2))
-     df['folder'] = folder
-     print("Processing", len(df), " Images:")
-     return df

- def df_from_shelve(chunk_size, func, dataset_name):
-     print("Generating Dataframe from results...")
-     with shelve.open('%s_%s_%s_results.tmp' % (dataset_name, func.__name__, chunk_size)) as results:
-         keylist = sorted([int(k) for k in results.keys()])
-         df = pd.concat([results[str(k)][1] for k in keylist], sort=True)
-     return df
-
- # number of processes in the pool can be larger than cores
- num_processes = 256
- # chunk_size is how many images per chunk per process - changing this resets progress when restarting.
- images_per_part = 200
-
- '''
- A bunch of them will fail to download, and return web pages instead. These will
- need to be cleaned up later. See downloaded_validation_report.tsv after it downloads
- for HTTP errors. Around 10-11% of images are gone, based on validation set results. Setting
- the user agent could fix some errors too maybe - not sure if any requests are rejected by
- sites based on this.
- '''
- data_name = "validation"
- df = open_tsv("Validation_GCC-1.1.0-Validation.tsv", data_name)
- df_multiprocess(df=df, processes=num_processes, chunk_size=images_per_part, func=download_image, dataset_name=data_name)
- df = df_from_shelve(chunk_size=images_per_part, func=download_image, dataset_name=data_name)
- df.to_csv("downloaded_%s_report.tsv.gz" % data_name, compression='gzip', sep='\t', header=False, index=False)
- print("Saved.")

- data_name = "training"
- df = open_tsv("Train-GCC-training.tsv", data_name)
- df_multiprocess(df=df, processes=num_processes, chunk_size=images_per_part, func=download_image, dataset_name=data_name)
- df = df_from_shelve(chunk_size=images_per_part, func=download_image, dataset_name=data_name)
- df.to_csv("downloaded_%s_report.tsv.gz" % data_name, compression='gzip', sep='\t', header=False, index=False)
- print("Saved.")

  '''
+ This script was adapted from Luke Melas-Kyriazi's code (https://twitter.com/lukemelas).
+ A few changes were made for this particular dataset. You're required to have the `.tsv` file downloaded in your directory.
+ Find it here: https://github.com/google-research-datasets/conceptual-captions
  '''
+
+ import sys
+ import os
+ from datetime import datetime
  import pandas as pd
+ import contexttimer
+ from urllib.request import urlopen
  import requests
+ from PIL import Image
+ import torch
+ from torchvision.transforms import functional as TF
  from multiprocessing import Pool
  from tqdm import tqdm
+ import logging
+
+ # Setup
+ logging.basicConfig(filename='download.log', filemode='w', level=logging.INFO)
+ requests.packages.urllib3.disable_warnings(requests.packages.urllib3.exceptions.InsecureRequestWarning)
+
+ if len(sys.argv) != 3:
+     print("Provide .tsv file name & output directory. e.g. python downloader.py Train-GCC-training.tsv training")
+     exit(1)
+
+ # Load data
+ print(f'Starting to load at {datetime.now().isoformat(timespec="minutes")}')
+ with contexttimer.Timer(prefix="Loading from tsv"):
+     df = pd.read_csv(sys.argv[1], delimiter='\t', header=None)
+
+ url_to_idx_map = {url: index for index, caption, url in df.itertuples()}
+ print(f'Loaded {len(url_to_idx_map)} urls')
+
+ base_dir = os.path.join(os.getcwd(), sys.argv[2])
+
+ def process(item):
+     url, image_id = item
      try:
+         base_url = os.path.basename(url)  # extract base url
+         stem, ext = os.path.splitext(base_url)  # split into stem and extension
+         filename = f'{image_id:08d}---{stem}.jpg'  # create filename
+         filepath = os.path.join(base_dir, filename)  # concat to get filepath
+         if not os.path.isfile(filepath):
+             req = requests.get(url, stream=True, timeout=1, verify=False).raw
+             image = Image.open(req).convert('RGB')
+             if min(image.size) > 512:
+                 image = TF.resize(image, size=512, interpolation=Image.LANCZOS)
+             image.save(filepath)  # save PIL image
      except Exception as e:
+         logging.info(" ".join(repr(e).splitlines()))
+         logging.error(url)
+
+ list_of_items = list(url_to_idx_map.items())
+ print(len(list_of_items))
+
+ with Pool(128) as p:
+     r = list(tqdm(p.imap(process, list_of_items), total=len(list_of_items)))
+ print('DONE')
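Per the usage message in the script itself, a run now looks like: python data/CC3M_downloader.py Train-GCC-training.tsv training. A small post-run sanity check is sketched below, under the assumption that the output directory was named training and that download.log uses the default logging format (so each failed url appears as an ERROR line):

import os

# Count saved images and failed urls after a run (a sketch, not part of the script).
out_dir = os.path.join(os.getcwd(), 'training')
n_saved = sum(1 for f in os.listdir(out_dir) if f.endswith('.jpg'))
with open('download.log') as log:
    n_failed = sum(1 for line in log if line.startswith('ERROR'))
print(f'{n_saved} images saved, {n_failed} failed urls logged')

As with the CC12M script, the output directory must exist before launching the pool; failures inside process are swallowed and logged rather than raised.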