Spaces:
Paused
Paused
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| # The 1.6 release of PyTorch switched torch.save to use a new zipfile-based | |
| # file format. It will cause RuntimeError when a checkpoint was saved in | |
| # torch >= 1.6.0 but loaded in torch < 1.7.0. | |
| # More details at https://github.com/open-mmlab/mmpose/issues/904 | |
| from ..path import mkdir_or_exist | |
| from ..version_utils import digit_version | |
| from .parrots_wrapper import TORCH_VERSION | |
| if TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) < digit_version( | |
| '1.7.0'): | |
| # Modified from https://github.com/pytorch/pytorch/blob/master/torch/hub.py | |
| import os | |
| import sys | |
| import warnings | |
| import zipfile | |
| from urllib.parse import urlparse | |
| import torch | |
| from torch.hub import HASH_REGEX, _get_torch_home, download_url_to_file | |
| # Hub used to support automatically extracts from zipfile manually | |
| # compressed by users. The legacy zip format expects only one file from | |
| # torch.save() < 1.6 in the zip. We should remove this support since | |
| # zipfile is now default zipfile format for torch.save(). | |
| def _is_legacy_zip_format(filename): | |
| if zipfile.is_zipfile(filename): | |
| infolist = zipfile.ZipFile(filename).infolist() | |
| return len(infolist) == 1 and not infolist[0].is_dir() | |
| return False | |
| def _legacy_zip_load(filename, model_dir, map_location): | |
| warnings.warn( | |
| 'Falling back to the old format < 1.6. This support will' | |
| ' be deprecated in favor of default zipfile format ' | |
| 'introduced in 1.6. Please redo torch.save() to save it ' | |
| 'in the new zipfile format.', DeprecationWarning) | |
| # Note: extractall() defaults to overwrite file if exists. No need to | |
| # clean up beforehand. We deliberately don't handle tarfile here | |
| # since our legacy serialization format was in tar. | |
| # E.g. resnet18-5c106cde.pth which is widely used. | |
| with zipfile.ZipFile(filename) as f: | |
| members = f.infolist() | |
| if len(members) != 1: | |
| raise RuntimeError( | |
| 'Only one file(not dir) is allowed in the zipfile') | |
| f.extractall(model_dir) | |
| extraced_name = members[0].filename | |
| extracted_file = os.path.join(model_dir, extraced_name) | |
| return torch.load(extracted_file, map_location=map_location) | |
| def load_url(url, | |
| model_dir=None, | |
| map_location=None, | |
| progress=True, | |
| check_hash=False, | |
| file_name=None): | |
| r"""Loads the Torch serialized object at the given URL. | |
| If downloaded file is a zip file, it will be automatically decompressed | |
| If the object is already present in `model_dir`, it's deserialized and | |
| returned. | |
| The default value of ``model_dir`` is ``<hub_dir>/checkpoints`` where | |
| ``hub_dir`` is the directory returned by :func:`~torch.hub.get_dir`. | |
| Args: | |
| url (str): URL of the object to download | |
| model_dir (str, optional): directory in which to save the object | |
| map_location (optional): a function or a dict specifying how to | |
| remap storage locations (see torch.load) | |
| progress (bool, optional): whether or not to display a progress bar | |
| to stderr. Defaults to True | |
| check_hash(bool, optional): If True, the filename part of the URL | |
| should follow the naming convention ``filename-<sha256>.ext`` | |
| where ``<sha256>`` is the first eight or more digits of the | |
| SHA256 hash of the contents of the file. The hash is used to | |
| ensure unique names and to verify the contents of the file. | |
| Defaults to False | |
| file_name (str, optional): name for the downloaded file. Filename | |
| from ``url`` will be used if not set. Defaults to None. | |
| Example: | |
| >>> url = ('https://s3.amazonaws.com/pytorch/models/resnet18-5c106' | |
| ... 'cde.pth') | |
| >>> state_dict = torch.hub.load_state_dict_from_url(url) | |
| """ | |
| # Issue warning to move data if old env is set | |
| if os.getenv('TORCH_MODEL_ZOO'): | |
| warnings.warn( | |
| 'TORCH_MODEL_ZOO is deprecated, please use env ' | |
| 'TORCH_HOME instead', DeprecationWarning) | |
| if model_dir is None: | |
| torch_home = _get_torch_home() | |
| model_dir = os.path.join(torch_home, 'checkpoints') | |
| mkdir_or_exist(model_dir) | |
| parts = urlparse(url) | |
| filename = os.path.basename(parts.path) | |
| if file_name is not None: | |
| filename = file_name | |
| cached_file = os.path.join(model_dir, filename) | |
| if not os.path.exists(cached_file): | |
| sys.stderr.write('Downloading: "{}" to {}\n'.format( | |
| url, cached_file)) | |
| hash_prefix = None | |
| if check_hash: | |
| r = HASH_REGEX.search(filename) # r is Optional[Match[str]] | |
| hash_prefix = r.group(1) if r else None | |
| download_url_to_file( | |
| url, cached_file, hash_prefix, progress=progress) | |
| if _is_legacy_zip_format(cached_file): | |
| return _legacy_zip_load(cached_file, model_dir, map_location) | |
| try: | |
| return torch.load(cached_file, map_location=map_location) | |
| except RuntimeError as error: | |
| if digit_version(TORCH_VERSION) < digit_version('1.5.0'): | |
| warnings.warn( | |
| f'If the error is the same as "{cached_file} is a zip ' | |
| 'archive (did you mean to use torch.jit.load()?)", you can' | |
| ' upgrade your torch to 1.5.0 or higher (current torch ' | |
| f'version is {TORCH_VERSION}). The error was raised ' | |
| ' because the checkpoint was saved in torch>=1.6.0 but ' | |
| 'loaded in torch<1.5.') | |
| raise error | |
| else: | |
| from torch.utils.model_zoo import load_url # type: ignore # noqa: F401 | |