lmzjms's picture
Upload 1162 files
0b32ad6 verified
raw
history blame
2.11 kB
import os
import hashlib
from filelock import FileLock
import torch
import gdown
def _download(filename, url, refresh, agent):
    """
    Download `url` into the s3prl cache directory and return the local path.

    The cache directory is `{torch.hub.get_dir()}/s3prl_cache` and is created
    on demand. The existence check and the download both run while holding a
    FileLock on `<filepath>.lock`.

    Args:
        filename: Name of the cache file inside the cache directory.
        url: Location to download from.
        refresh: When true, download even if the cache file already exists.
        agent: Download backend, either 'wget' or 'gdown'.
    Return:
        Path of the cached file.
    Raises:
        NotImplementedError: `agent` is neither 'wget' nor 'gdown'. This is
            only checked when a download is actually needed.
        RuntimeError: wget could not be started or exited with a non-zero
            status.
    """
    # Imported locally so this block does not depend on the module's
    # top-level import list.
    import subprocess

    dirpath = f'{torch.hub.get_dir()}/s3prl_cache'
    os.makedirs(dirpath, exist_ok=True)
    filepath = f'{dirpath}/{filename}'
    with FileLock(filepath + ".lock"):
        if not os.path.isfile(filepath) or refresh:
            if agent == 'wget':
                # Pass an argument list instead of a shell string: URLs often
                # contain '&', '?' or ';', which a shell would interpret
                # (truncating the URL or running extra commands).
                try:
                    result = subprocess.run(['wget', url, '-O', filepath])
                except OSError as err:
                    raise RuntimeError(
                        f'[Download] - Unable to run wget for {url}'
                    ) from err
                if result.returncode != 0:
                    # Do not leave a partial/empty file behind: it would be
                    # mistaken for a valid cache entry on the next call.
                    if os.path.isfile(filepath):
                        os.remove(filepath)
                    raise RuntimeError(
                        f'[Download] - wget exited with status '
                        f'{result.returncode} for {url}'
                    )
            elif agent == 'gdown':
                gdown.download(url, filepath, use_cookies=False)
            else:
                print('[Download] - Unknown download agent. Only \'wget\' and \'gdown\' are supported.')
                raise NotImplementedError
        else:
            print(f'Using cache found in {filepath}\nfor {url}')
    return filepath
def _urls_to_filepaths(*args, refresh=False, agent='wget'):
"""
Preprocess the URL specified in *args into local file paths after downloading
Args:
Any number of URLs (1 ~ any)
Return:
Same number of downloaded file paths
"""
def url_to_filename(url):
assert type(url) is str
m = hashlib.sha256()
m.update(str.encode(url))
return str(m.hexdigest())
def url_to_path(url, refresh):
if type(url) is str and len(url) > 0:
return _download(url_to_filename(url), url, refresh, agent=agent)
else:
return None
paths = [url_to_path(url, refresh) for url in args]
return paths if len(paths) > 1 else paths[0]
def _gdriveids_to_filepaths(*args, refresh=False):
    """
    Download each Google Drive id given in *args and return the local paths.

    Every id is turned into a `https://drive.google.com/uc?id=<id>` URL and
    the resulting URLs are handed to _urls_to_filepaths with agent='gdown'.

    Args:
        *args: Any number of Google Drive ids (1 ~ any). An entry that is not
            a non-empty str is forwarded as None.
        refresh: Forwarded to _urls_to_filepaths.
    Return:
        Whatever _urls_to_filepaths returns for the generated URLs.
    """
    urls = []
    for gdriveid in args:
        if type(gdriveid) is str and gdriveid:
            urls.append(f'https://drive.google.com/uc?id={gdriveid}')
        else:
            urls.append(None)
    return _urls_to_filepaths(*urls, refresh=refresh, agent='gdown')