sooks committed on
Commit
45f6aa8
1 Parent(s): 2a093da

Create download.py

Browse files
Files changed (1) hide show
  1. detector/download.py +49 -0
detector/download.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import requests
4
+ import torch.distributed as dist
5
+ from tqdm import tqdm
6
+
7
+ from .utils import distributed
8
+
9
# Names of every downloadable dataset: the human-written 'webtext' corpus
# followed by each model size in its plain, top-k 40 and nucleus variants.
ALL_DATASETS = ['webtext'] + [
    model + variant
    for model in ('small-117M', 'medium-345M', 'large-762M', 'xl-1542M')
    for variant in ('', '-k40', '-nucleus')
]
16
+
17
+
18
def download(*datasets, data_dir='data'):
    """Download the train/valid/test splits of the requested datasets.

    Each split is saved as ``<data_dir>/<dataset>.<split>.jsonl``. Splits whose
    file already exists are skipped. When ``distributed()`` is true, ranks
    other than 0 wait on a barrier while rank 0 performs the downloads.

    Args:
        *datasets: Dataset names; each must be listed in ``ALL_DATASETS``.
        data_dir: Target directory, created if it does not exist.

    Raises:
        AssertionError: If a dataset name is not in ``ALL_DATASETS``.
        requests.HTTPError: If the server responds with an error status.
    """
    # Validate every name up front, before any rank blocks on the barrier,
    # so a typo fails on all ranks instead of only on rank 0 mid-download.
    for ds in datasets:
        assert ds in ALL_DATASETS, f'Unknown dataset {ds}'

    os.makedirs(data_dir, exist_ok=True)

    # Non-zero ranks wait here until rank 0 has finished downloading.
    if distributed() and dist.get_rank() > 0:
        dist.barrier()

    for ds in datasets:
        for split in ['train', 'valid', 'test']:
            filename = ds + "." + split + '.jsonl'
            output_file = os.path.join(data_dir, filename)
            if os.path.isfile(output_file):
                continue

            url = "https://storage.googleapis.com/gpt-2/output-dataset/v1/" + filename
            # Write to a side file and rename on success, so an interrupted
            # download is never mistaken for a complete one on the next run.
            partial_file = output_file + '.part'
            try:
                with requests.get(url, stream=True) as r:
                    # Do not save an HTTP error page as if it were the dataset.
                    r.raise_for_status()

                    # The header may be absent; tqdm accepts total=None.
                    content_length = r.headers.get("content-length")
                    file_size = int(content_length) if content_length is not None else None

                    # 1k for chunk_size, since Ethernet packet size is around 1500 bytes
                    chunk_size = 1000
                    with open(partial_file, 'wb') as f, \
                            tqdm(ncols=100, desc="Fetching " + filename,
                                 total=file_size, unit_scale=True) as pbar:
                        for chunk in r.iter_content(chunk_size=chunk_size):
                            f.write(chunk)
                            # Chunks can be shorter than chunk_size (always
                            # the last one), so count the bytes actually written.
                            pbar.update(len(chunk))
                os.replace(partial_file, output_file)
            except BaseException:
                # Also covers KeyboardInterrupt: drop the incomplete file.
                if os.path.isfile(partial_file):
                    os.remove(partial_file)
                raise

    # Rank 0 joins the barrier last, releasing the waiting ranks.
    if distributed() and dist.get_rank() == 0:
        dist.barrier()
46
+
47
+
48
if __name__ == '__main__':
    # Run as a script: fetch every dataset listed in ALL_DATASETS.
    download(*ALL_DATASETS)