File size: 2,196 Bytes
cc0b62b
 
 
1927812
cc0b62b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1927812
cc0b62b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from __future__ import print_function
import os
from subprocess import call
from builtins import input

# Basename of the current working directory. The weights path below is
# relative, so it resolves against wherever the script is launched from.
curr_folder = os.path.basename(os.path.normpath(os.getcwd()))

# Location of the pretrained weights: <weights_folder>/<weights_filename>,
# prefixed with '../' when the script is launched from inside 'scripts'.
weights_filename = 'pytorch_model.bin'
weights_folder = 'model'
weights_path = '{}{}/{}'.format('../' if curr_folder == 'scripts' else '',
                                weights_folder, weights_filename)
weights_download_link = 'https://www.dropbox.com/s/q8lax9ary32c7t9/pytorch_model.bin?dl=0#'


# Bytes per mebibyte. Kept a float, presumably so that `size / MB_FACTOR` is a
# true division even under Python 2 (the file imports from __future__).
MB_FACTOR = 1024.0 * 1024.0

def prompt():
    """Read yes/no answers from stdin until a valid one is given.

    Accepts 'y', 'ye', 'yes', 'n' and 'no', ignoring case and surrounding
    whitespace. Any other reply prints a reminder and reads another line.
    If stdin is exhausted, the EOFError raised by input() propagates.

    Returns:
        True for a yes-answer, False for a no-answer.
    """
    # Built once, not on every retry.
    valid = {
        'y': True,
        'ye': True,
        'yes': True,
        'n': False,
        'no': False,
    }
    while True:
        # strip() so stray spaces or a trailing '\r' (e.g. CRLF input piped
        # from a file) do not cause a valid answer to be rejected.
        choice = input().strip().lower()
        if choice in valid:
            return valid[choice]
        print('Please respond with \'y\' or \'n\' (or \'yes\' or \'no\')')

# Ask before replacing an existing weights file; a missing file is always a
# download candidate (subject to the size confirmation below).
download = True
if os.path.exists(weights_path):
    print('Weight file already exists at {}. Would you like to redownload it anyway? [y/n]'.format(weights_path))
    download = prompt()
    already_exists = True
else:
    already_exists = False

if download:
    print('About to download the pretrained weights file from {}'.format(weights_download_link))
    if not already_exists:
        print('The size of the file is roughly 85MB. Continue? [y/n]')

    # A redownload was already confirmed above, so only a first-time download
    # needs the extra size confirmation.
    if already_exists or prompt():
        print('Downloading...')

        target_path = os.path.abspath(weights_path)
        # Download under a temporary name next to the target and move it into
        # place only once it passes the size check. An existing copy therefore
        # survives a failed download, and an interrupted download never leaves
        # a partial file at weights_path that a later run would treat as
        # already present.
        partial_path = target_path + '.part'
        target_dir = os.path.dirname(target_path)
        if not os.path.isdir(target_dir):
            # wget -O does not create missing parent directories.
            os.makedirs(target_dir)

        # Downloading using wget due to issues with urlretrieve and requests.
        # The arguments are passed as a list without a shell, so a space in
        # the target path or the '?'/'#' characters in the URL reach wget
        # verbatim instead of being split or glob-expanded by the shell.
        wget_args = ['wget', weights_download_link, '-O', partial_path]
        print("Running system call: {}".format(' '.join(wget_args)))
        try:
            exit_code = call(wget_args)
        except OSError as err:
            # Raised when the wget executable cannot be found or started.
            raise RuntimeError('Could not run wget: {}'.format(err))

        if exit_code != 0:
            if os.path.exists(partial_path):
                os.unlink(partial_path)
            raise RuntimeError('wget exited with status {}; the weights were not downloaded.'.format(exit_code))

        downloaded_bytes = os.path.getsize(partial_path)
        if downloaded_bytes / MB_FACTOR < 80:
            os.unlink(partial_path)
            raise ValueError("Download finished, but the resulting file is too small! " +
                             "It\'s only {} bytes.".format(downloaded_bytes))

        # os.rename does not overwrite an existing file on Windows, so remove
        # the old copy first; the new file has already been size-checked.
        if os.path.exists(target_path):
            os.unlink(target_path)
        os.rename(partial_path, target_path)
        print('Downloaded weights to {}'.format(weights_path))
    else:
        print('Exiting.')
else:
    print('Exiting.')