Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
from __future__ import print_function | |
import os | |
from subprocess import call | |
from builtins import input | |
# Work out where the pretrained weights should live. The location is relative
# to the current working directory; when the script is launched from inside a
# folder named 'scripts', the model folder is taken to be one level up.
curr_folder = os.path.basename(os.path.normpath(os.getcwd()))
weights_filename, weights_folder = 'pytorch_model.bin', 'model'
weights_path = (('../' if curr_folder == 'scripts' else '')
                + '{}/{}'.format(weights_folder, weights_filename))

# Where the weights file is fetched from.
# NOTE(review): Dropbox '?dl=0' links usually serve a preview page rather than
# the raw file ('?dl=1' forces a download) — confirm this link still yields
# the binary.
weights_download_link = 'https://www.dropbox.com/s/q8lax9ary32c7t9/pytorch_model.bin?dl=0#'

# Bytes per mebibyte; kept a float so `size / MB_FACTOR` is true division
# even under Python 2.
MB_FACTOR = float(2 ** 20)
def prompt():
    """Read yes/no answers from stdin until a valid one is given; return it.

    Accepted answers (case-insensitive, surrounding whitespace ignored) are
    'y', 'ye', 'yes' -> True and 'n', 'no' -> False. Any other answer prints
    a hint and another line is read. The question itself is printed by the
    caller before calling this function.

    Returns:
        bool: True for a yes-answer, False for a no-answer.

    Raises:
        EOFError: if stdin is exhausted before a valid answer is read
            (propagated from ``input``).
    """
    # Built once instead of on every retry of the loop below.
    valid = {
        'y': True,
        'ye': True,
        'yes': True,
        'n': False,
        'no': False,
    }
    while True:
        # strip() so stray leading/trailing spaces don't cause an otherwise
        # valid answer such as ' yes ' to be rejected.
        choice = input().strip().lower()
        if choice in valid:
            return valid[choice]
        print('Please respond with \'y\' or \'n\' (or \'yes\' or \'no\')')
# Decide whether a download is needed. If the weights file is already
# present, ask the user whether to fetch it again.
download = True
if os.path.exists(weights_path):
    print('Weight file already exists at {}. Would you like to redownload it anyway? [y/n]'.format(weights_path))
    download = prompt()
    already_exists = True
else:
    already_exists = False

if download:
    print('About to download the pretrained weights file from {}'.format(weights_download_link))
    if not already_exists:
        print('The size of the file is roughly 85MB. Continue? [y/n]')

    # For an existing file the user has just confirmed the redownload, so only
    # a fresh download needs the extra size confirmation.
    if already_exists or prompt():
        print('Downloading...')

        # Download into a side file and swap it in only after validation, so a
        # failed or truncated download never destroys an existing good copy
        # (previously the old file was unlinked before downloading).
        target_path = os.path.abspath(weights_path)
        partial_path = target_path + '.part'
        target_dir = os.path.dirname(target_path)
        if not os.path.isdir(target_dir):
            # wget -O does not create missing parent directories.
            os.makedirs(target_dir)

        # downloading using wget due to issues with urlretrieve and requests.
        # An argument list (no shell) keeps paths containing spaces intact and
        # stops the shell from interpreting characters such as '?' in the URL.
        wget_cmd = ['wget', weights_download_link, '-O', partial_path]
        print("Running system call: {}".format(' '.join(wget_cmd)))
        exit_code = call(wget_cmd)
        if exit_code != 0:
            if os.path.exists(partial_path):
                os.unlink(partial_path)
            raise RuntimeError('wget exited with status {}; the weights were not downloaded.'.format(exit_code))

        # The real file is roughly 85MB; anything under 80MB is most likely an
        # error page or a truncated transfer, so discard it.
        downloaded_bytes = os.path.getsize(partial_path)
        if downloaded_bytes / MB_FACTOR < 80:
            os.unlink(partial_path)
            raise ValueError("Download finished, but the resulting file is too small! " +
                             "It\'s only {} bytes.".format(downloaded_bytes))

        # os.rename does not overwrite an existing destination on Windows, so
        # remove the old copy first — only now that the replacement is valid.
        if os.path.exists(target_path):
            os.unlink(target_path)
        os.rename(partial_path, target_path)
        print('Downloaded weights to {}'.format(weights_path))
    else:
        print('Exiting.')