File size: 2,196 Bytes
cc0b62b 1927812 cc0b62b 1927812 cc0b62b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
from __future__ import print_function
import os
from subprocess import call
from builtins import input
# Where the pretrained weights live and where they are fetched from.
weights_folder = 'model'
weights_filename = 'pytorch_model.bin'
weights_download_link = 'https://www.dropbox.com/s/q8lax9ary32c7t9/pytorch_model.bin?dl=0#'

# Number of bytes in one megabyte (2**20); kept as a float so that dividing
# a byte count by it yields a fractional result.
MB_FACTOR = float(2 ** 20)

# Name of the directory the script is currently being run from.
curr_folder = os.path.split(os.path.normpath(os.getcwd()))[1]

# The weights path is relative to the working directory; when that directory
# is named 'scripts', step up one level first.
weights_path = '/'.join((weights_folder, weights_filename))
if curr_folder == 'scripts':
    weights_path = '../{}'.format(weights_path)
def prompt():
    """Read a yes/no answer from standard input and return it as a bool.

    Keeps asking until the user enters one of ``y``, ``ye`` or ``yes``
    (returns True) or ``n`` or ``no`` (returns False). Matching is
    case-insensitive and ignores leading/trailing whitespace, so answers
    such as ``"Y "`` or ``" no"`` are accepted.

    Returns:
        bool: True for an affirmative answer, False for a negative one.

    Raises:
        EOFError: propagated from ``input()`` if standard input is
            exhausted before a valid answer is given.
    """
    # Built once, outside the retry loop, since it never changes.
    valid = {
        'y': True,
        'ye': True,
        'yes': True,
        'n': False,
        'no': False,
    }
    while True:
        # Strip before matching so stray spaces/tabs do not cause a rejection.
        choice = input().strip().lower()
        if choice in valid:
            return valid[choice]
        print('Please respond with \'y\' or \'n\' (or \'yes\' or \'no\')')
# Script body: decide whether a download is wanted, confirm it with the user,
# then fetch the weights with wget and sanity-check the size of the result.
already_exists = os.path.exists(weights_path)

download = True
if already_exists:
    print('Weight file already exists at {}. Would you like to redownload it anyway? [y/n]'.format(weights_path))
    download = prompt()

if download:
    print('About to download the pretrained weights file from {}'.format(weights_download_link))
    if already_exists:
        # The redownload was already confirmed above; remove the old copy first.
        os.unlink(weights_path)
    else:
        # A fresh download gets its own confirmation because of the file size.
        print('The size of the file is roughly 85MB. Continue? [y/n]')
        download = prompt()

if download:
    print('Downloading...')

    # Downloading using wget due to issues with urlretrieve and requests.
    # The command is passed as an argument list (no shell), so the '?' and '#'
    # in the URL and any spaces in the target path reach wget untouched.
    wget_args = ['wget', weights_download_link, '-O', os.path.abspath(weights_path)]
    print("Running system call: {}".format(' '.join(wget_args)))
    exit_code = call(wget_args)
    if exit_code != 0:
        # Report the real failure instead of falling through to the size check.
        raise RuntimeError("wget exited with status {}; the weights were not downloaded.".format(exit_code))

    # Guard against truncated downloads or an error page saved in place of
    # the weights: anything under 80 MB is treated as a failure.
    downloaded_bytes = os.path.getsize(weights_path)
    if downloaded_bytes / MB_FACTOR < 80:
        raise ValueError("Download finished, but the resulting file is too small! " +
                         "It\'s only {} bytes.".format(downloaded_bytes))
    print('Downloaded weights to {}'.format(weights_path))
else:
    # Reached when the user declines either confirmation prompt.
    print('Exiting.')
|