| """ | |
| Thread-safe file downloading and cacheing | |
| Authors | |
| * Leo 2022 | |
| * Cheng Liang 2022 | |
| """ | |
| import hashlib | |
| import logging | |
| import os | |
| import shutil | |
| import sys | |
| import tempfile | |
| import time | |
| from pathlib import Path | |
| from urllib.request import Request, urlopen | |
| import requests | |
| from filelock import FileLock | |
| from tqdm import tqdm | |
| logger = logging.getLogger(__name__) | |
| _download_dir = Path.home() / ".cache" / "s3prl" / "download" | |

__all__ = [
    "get_dir",
    "set_dir",
    "download",
    "urls_to_filepaths",
]


def get_dir():
    """Return the download cache directory, creating it if necessary."""
    _download_dir.mkdir(exist_ok=True, parents=True)
    return _download_dir


def set_dir(d):
    """Redirect the download cache to a custom directory."""
    global _download_dir
    _download_dir = Path(d)
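
# A minimal usage sketch (the cache path below is only an example):
#
#     set_dir("/tmp/my_s3prl_cache")
#     root = get_dir()  # Path("/tmp/my_s3prl_cache"), created on demand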

def _download_url_to_file(url, dst, hash_prefix=None, progress=True):
    """
    This function is not thread-safe. Please ensure that only a single
    thread or process enters it at a time.
    """
    file_size = None
    req = Request(url, headers={"User-Agent": "torch.hub"})
    u = urlopen(req)
    meta = u.info()
    if hasattr(meta, "getheaders"):
        content_length = meta.getheaders("Content-Length")
    else:
        content_length = meta.get_all("Content-Length")
    if content_length is not None and len(content_length) > 0:
        file_size = int(content_length[0])

    # Download into a temporary file in the destination directory first, then
    # move it into place, so readers never observe a partially written file.
    dst = os.path.expanduser(dst)
    dst_dir = os.path.dirname(dst)
    f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)

    try:
        if hash_prefix is not None:
            sha256 = hashlib.sha256()
        tqdm.write(f"Downloading: {url}", file=sys.stderr)
        tqdm.write(f"Destination: {dst}", file=sys.stderr)
        with tqdm(
            total=file_size,
            disable=not progress,
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
        ) as pbar:
            while True:
                buffer = u.read(8192)
                if len(buffer) == 0:
                    break
                f.write(buffer)
                if hash_prefix is not None:
                    sha256.update(buffer)
                pbar.update(len(buffer))
        f.close()
        if hash_prefix is not None:
            digest = sha256.hexdigest()
            if digest[: len(hash_prefix)] != hash_prefix:
                raise RuntimeError(
                    'invalid hash value (expected "{}", got "{}")'.format(
                        hash_prefix, digest
                    )
                )
        shutil.move(f.name, dst)
    finally:
        f.close()
        if os.path.exists(f.name):
            os.remove(f.name)
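
# Example of the optional integrity check (the URL and digest prefix below are
# hypothetical): passing a prefix of the expected sha256 digest rejects
# truncated or tampered downloads.
#
#     _download_url_to_file(
#         "https://example.com/model.ckpt",
#         "/tmp/model.ckpt",
#         hash_prefix="9f86d081",  # hypothetical digest prefix
#     )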

def _download_url_to_file_requests(url, dst, hash_prefix=None, progress=True):
    """
    Alternative download path, used as a fallback when urllib.Request fails.
    Also not thread-safe; see _download_url_to_file.
    """
    req = requests.get(url, stream=True, headers={"User-Agent": "torch.hub"})
    content_length = req.headers.get("Content-Length")
    file_size = int(content_length) if content_length is not None else None
    dst = os.path.expanduser(dst)
    dst_dir = os.path.dirname(dst)
    f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
    try:
        if hash_prefix is not None:
            sha256 = hashlib.sha256()
        tqdm.write(
            "urllib.Request method failed. Trying another method...",
            file=sys.stderr,
        )
        tqdm.write(f"Downloading: {url}", file=sys.stderr)
        tqdm.write(f"Destination: {dst}", file=sys.stderr)
        with tqdm(
            total=file_size,
            disable=not progress,
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
        ) as pbar:
            for chunk in req.iter_content(chunk_size=1024 * 1024 * 10):
                if chunk:  # skip keep-alive chunks
                    f.write(chunk)
                    f.flush()
                    os.fsync(f.fileno())
                    if hash_prefix is not None:
                        sha256.update(chunk)
                    pbar.update(len(chunk))
        f.close()
        if hash_prefix is not None:
            digest = sha256.hexdigest()
            if digest[: len(hash_prefix)] != hash_prefix:
                raise RuntimeError(
                    'invalid hash value (expected "{}", got "{}")'.format(
                        hash_prefix, digest
                    )
                )
        shutil.move(f.name, dst)
    finally:
        f.close()
        if os.path.exists(f.name):
            os.remove(f.name)

def _download(filepath: Path, url, refresh: bool, new_enough_secs: float = 2.0):
    """
    If refresh is True, check the last-modified time of the filepath.
    If the file is new enough (no older than `new_enough_secs`), use it
    directly; if it is older than `new_enough_secs`, re-download it.
    This is useful when multiple processes all download the same large file.
    """
    Path(filepath).parent.mkdir(exist_ok=True, parents=True)
    lock_file = Path(str(filepath) + ".lock")

    logger.info(f"Requesting URL: {url}")
    # The file lock serializes downloads of the same filepath across threads
    # and processes.
    with FileLock(str(lock_file)):
        if not filepath.is_file() or (
            refresh and (time.time() - os.path.getmtime(filepath)) > new_enough_secs
        ):
            try:
                _download_url_to_file(url, filepath)
            except Exception:
                _download_url_to_file_requests(url, filepath)
    logger.info(f"Using URL's local file: {filepath}")

def _urls_to_filepaths(*args, refresh=False, download: bool = True):
    """
    Resolve the URLs given in *args into local file paths, downloading
    each file when needed.

    Args:
        Any number of URLs (one or more)

    Return:
        The same number of downloaded file paths
    """

    def _url_to_filepath(url):
        assert isinstance(url, str)
        m = hashlib.sha256()
        m.update(str.encode(url))
        filepath = get_dir() / f"{str(m.hexdigest())}.{Path(url).name}"
        if download:
            _download(filepath, url, refresh=refresh)
        return str(filepath.resolve())

    paths = [_url_to_filepath(url) for url in args]
    return paths if len(paths) > 1 else paths[0]

# Public aliases exported via __all__
download = _download
urls_to_filepaths = _urls_to_filepaths
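
# A minimal end-to-end sketch, guarded so importing this module stays free of
# side effects. The URL is purely illustrative, not a real s3prl asset.
if __name__ == "__main__":
    example_url = "https://example.com/checkpoints/model.ckpt"  # hypothetical
    local_path = urls_to_filepaths(example_url)
    print(local_path)  # e.g. ~/.cache/s3prl/download/<sha256>.model.ckpt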