import hashlib
import os
import subprocess

import gdown
import torch
from filelock import FileLock


def _wget(url, destination):
    """Fetch ``url`` into ``destination`` with the ``wget`` executable.

    The file is first written to ``destination + '.part'`` and only moved
    into place when wget exits with status 0, so an interrupted or failed
    transfer never leaves a truncated file at ``destination``.

    Raises:
        RuntimeError: if wget cannot be started or exits with a non-zero
            status.
    """
    partial = destination + '.part'
    try:
        # An argument list (no shell) keeps characters such as '&', ';' or
        # spaces inside the URL from being interpreted by a shell.
        try:
            result = subprocess.run(['wget', url, '-O', partial])
        except OSError as err:
            raise RuntimeError(
                f'[Download] - Could not run wget for {url}: {err}'
            ) from err
        if result.returncode != 0:
            raise RuntimeError(
                f'[Download] - wget exited with status {result.returncode} '
                f'for {url}'
            )
        os.replace(partial, destination)
    finally:
        # After a successful os.replace the partial file is already gone;
        # otherwise drop whatever wget managed to write.
        if os.path.exists(partial):
            os.remove(partial)


def _download(filename, url, refresh, agent):
    """Download ``url`` into the s3prl cache directory and return its path.

    The cache lives in ``<torch.hub.get_dir()>/s3prl_cache``. A lock file
    (``<filepath>.lock``) is held for the whole check-and-download step.

    Args:
        filename: Name of the cached file inside the cache directory.
        url: Location to download from.
        refresh: When true, download again even if the file already exists.
        agent: Download backend, either ``'wget'`` or ``'gdown'``.

    Returns:
        The path of the cached file.

    Raises:
        NotImplementedError: if a download is required and ``agent`` is
            neither ``'wget'`` nor ``'gdown'``.
        RuntimeError: if the download fails or the file is missing
            afterwards.
    """
    dirpath = f'{torch.hub.get_dir()}/s3prl_cache'
    os.makedirs(dirpath, exist_ok=True)
    filepath = f'{dirpath}/{filename}'

    with FileLock(filepath + ".lock"):
        if os.path.isfile(filepath) and not refresh:
            print(f'Using cache found in {filepath}\nfor {url}')
            return filepath

        if agent == 'wget':
            _wget(url, filepath)
        elif agent == 'gdown':
            gdown.download(url, filepath, use_cookies=False)
        else:
            print('[Download] - Unknown download agent. Only \'wget\' and \'gdown\' are supported.')
            raise NotImplementedError

        # NOTE(review): how gdown reports a failed download is not visible
        # from this file, so check the result on disk rather than relying on
        # its return value.
        if not os.path.isfile(filepath):
            raise RuntimeError(
                f'[Download] - Failed to download {url} to {filepath}'
            )

    return filepath


def _urls_to_filepaths(*args, refresh=False, agent='wget'):
    """
    Preprocess the URL specified in *args into local file paths after downloading

    Each URL is cached under a file named after the SHA-256 hex digest of
    the URL string.

    Args:
        Any number of URLs (1 ~ any). An entry that is not a non-empty
        string is not downloaded and yields None.
        refresh: Re-download even when a cached copy exists.
        agent: Download backend passed to _download ('wget' or 'gdown').

    Return:
        Same number of downloaded file paths: a single path (or None) when
        exactly one URL is given, otherwise a list (empty when no URL is
        given).
    """

    def url_to_filename(url):
        return hashlib.sha256(url.encode()).hexdigest()

    def url_to_path(url, refresh):
        if isinstance(url, str) and url:
            return _download(url_to_filename(url), url, refresh, agent=agent)
        return None

    paths = [url_to_path(url, refresh) for url in args]
    # Exactly one URL unwraps to a bare path; zero URLs stay an empty list
    # instead of failing on paths[0].
    return paths[0] if len(paths) == 1 else paths


def _gdriveids_to_filepaths(*args, refresh=False):
    """
    Preprocess the Google Drive id specified in *args into local file paths after downloading

    Args:
        Any number of Google Drive ids (1 ~ any). An entry that is not a
        non-empty string yields None.
        refresh: Re-download even when a cached copy exists.

    Return:
        Same number of downloaded file paths, in the form returned by
        _urls_to_filepaths.
    """

    def gdriveid_to_url(gdriveid):
        if isinstance(gdriveid, str) and gdriveid:
            return f'https://drive.google.com/uc?id={gdriveid}'
        return None

    return _urls_to_filepaths(
        *[gdriveid_to_url(gid) for gid in args],
        refresh=refresh,
        agent='gdown',
    )