|
from __future__ import print_function |
|
import os |
|
from subprocess import call |
|
from builtins import input |
|
|
|
# Name of the directory the script is being run from. It is used to locate
# the model folder when the script is invoked from inside ``scripts/``.
curr_folder = os.path.basename(os.path.normpath(os.getcwd()))

weights_filename = 'pytorch_model.bin'
weights_folder = 'model'
# Build the path with os.path.join so the platform's separator is used.
weights_path = os.path.join(weights_folder, weights_filename)
# When run from the ``scripts`` directory, the model folder is one level up.
if curr_folder == 'scripts':
    weights_path = os.path.join(os.pardir, weights_path)

weights_download_link = 'https://www.dropbox.com/s/q8lax9ary32c7t9/pytorch_model.bin?dl=0#'

# Bytes per mebibyte. It is a float so that size divisions are true divisions
# even on Python 2.
MB_FACTOR = float(1 << 20)
|
|
|
# Accepted answers (compared lower-cased) mapped to the boolean they mean.
_VALID_ANSWERS = {
    'y': True,
    'ye': True,
    'yes': True,
    'n': False,
    'no': False,
}


def prompt(read_answer=input):
    """Read yes/no answers until a valid one is given and return it as a bool.

    The question itself must already have been printed by the caller; this
    function only reads answers and re-asks on invalid input.

    Args:
        read_answer: Zero-argument callable returning one line of user input.
            Defaults to ``input``. It can be swapped out, for example in tests.

    Returns:
        True for 'y', 'ye' or 'yes'; False for 'n' or 'no'. Matching is
        case-insensitive and ignores surrounding whitespace.
    """
    while True:
        # Strip so that stray spaces around the answer are not rejected.
        choice = read_answer().strip().lower()
        if choice in _VALID_ANSWERS:
            return _VALID_ANSWERS[choice]
        print('Please respond with \'y\' or \'n\' (or \'yes\' or \'no\')')
|
|
|
# Ask before overwriting an existing weight file. Otherwise a download is
# always proposed.
already_exists = os.path.exists(weights_path)
if already_exists:
    print('Weight file already exists at {}. Would you like to redownload it anyway? [y/n]'.format(weights_path))
    download = prompt()
else:
    download = True

if download:
    print('About to download the pretrained weights file from {}'.format(weights_download_link))

    # A redownload was already confirmed above, so the size confirmation is
    # only requested for a first-time download.
    if not already_exists:
        print('The size of the file is roughly 85MB. Continue? [y/n]')

    if already_exists or prompt():
        print('Downloading...')

        target_path = os.path.abspath(weights_path)
        target_dir = os.path.dirname(target_path)
        # wget -O does not create missing directories, so make sure the
        # destination folder exists first.
        if not os.path.isdir(target_dir):
            os.makedirs(target_dir)

        # Download to a side file and move it into place only after it has
        # been validated. An existing, working weight file therefore survives
        # a failed redownload. A truncated or HTML response is also never
        # left at weights_path, where the next run would treat it as valid.
        partial_path = target_path + '.part'

        # Pass an argument list (no shell) so that paths containing spaces
        # or shell metacharacters reach wget verbatim.
        wget_args = ['wget', weights_download_link, '-O', partial_path]
        print("Running system call: {}".format(' '.join(wget_args)))
        try:
            return_code = call(wget_args)
        except OSError as err:
            raise RuntimeError('Could not run wget ({}). Is it installed and on PATH?'.format(err))

        if return_code != 0:
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise RuntimeError('wget exited with status {}; the weights were not downloaded.'.format(return_code))

        size_bytes = os.path.getsize(partial_path)
        if size_bytes / MB_FACTOR < 80:
            os.remove(partial_path)
            raise ValueError("Download finished, but the resulting file is too small! " +
                             "It\'s only {} bytes.".format(size_bytes))

        # os.rename does not overwrite an existing destination on Windows,
        # so remove the old weights only now that the new file is validated.
        if os.path.exists(target_path):
            os.remove(target_path)
        os.rename(partial_path, target_path)
        print('Downloaded weights to {}'.format(weights_path))
    else:
        print('Exiting.')
|
|