# WSCL/models/utils.py
# Last commit by yhzhai: "release code" (482ab8a)
import os
import sys
try:
from urllib import urlretrieve
except ImportError:
from urllib.request import urlretrieve
import torch
def load_url(url, model_dir="./pretrained", map_location=torch.device("cpu")):
    """Download a checkpoint from *url* once, cache it on disk, and load it.

    The cache file is named after the last ``/``-separated component of
    *url* and stored inside *model_dir*. If that file already exists, no
    download takes place and the cached copy is loaded directly.

    Args:
        url: Location of the serialized checkpoint. Anything accepted by
            ``urlretrieve`` works.
        model_dir: Directory used as the download cache. It is created
            (including parents) when missing.
        map_location: Forwarded unchanged to ``torch.load``.

    Returns:
        Whatever ``torch.load`` returns for the cached file.

    Raises:
        OSError: If *model_dir* cannot be created, or the download fails
            (``urlretrieve`` errors are ``OSError`` subclasses on Python 3).
    """
    # Create the cache directory without a check-then-create race: another
    # process may create it between an existence test and makedirs().
    try:
        os.makedirs(model_dir)
    except OSError:
        if not os.path.isdir(model_dir):
            raise

    filename = url.split("/")[-1]
    cached_file = os.path.join(model_dir, filename)

    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        # Download to a per-process temporary name and move it into place only
        # on success, so an interrupted transfer never leaves a truncated file
        # that later calls would mistake for a valid cache entry.
        partial_file = "{}.{}.part".format(cached_file, os.getpid())
        try:
            urlretrieve(url, partial_file)
            # os.replace does not exist on Python 2, which the module's
            # urllib import fallback still caters for; fall back to rename.
            getattr(os, "replace", os.rename)(partial_file, cached_file)
        finally:
            # Only present if the download or the move failed.
            if os.path.exists(partial_file):
                os.remove(partial_file)

    # NOTE(security): torch.load unpickles the file; only use URLs that point
    # at trusted checkpoints.
    return torch.load(cached_file, map_location=map_location)