""" | |
Utilities for working with the local dataset cache. | |
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp | |
Copyright by the AllenNLP authors. | |
""" | |
import fnmatch
import json
import logging
import os
import sys
import tempfile
from contextlib import contextmanager
from functools import partial, wraps
from hashlib import sha256
from typing import Optional
from urllib.parse import urlparse

import boto3
import requests
from botocore.config import Config
from botocore.exceptions import ClientError
from filelock import FileLock
from tqdm.auto import tqdm

from . import __version__


logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
try:
    USE_TF = os.environ.get("USE_TF", "AUTO").upper()
    USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
    if USE_TORCH in ("1", "ON", "YES", "AUTO") and USE_TF not in ("1", "ON", "YES"):
        import torch

        _torch_available = True  # pylint: disable=invalid-name
        logger.info("PyTorch version {} available.".format(torch.__version__))
    else:
        logger.info("Disabling PyTorch because USE_TF is set")
        _torch_available = False
except ImportError:
    _torch_available = False  # pylint: disable=invalid-name
try:
    USE_TF = os.environ.get("USE_TF", "AUTO").upper()
    USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
    if USE_TF in ("1", "ON", "YES", "AUTO") and USE_TORCH not in ("1", "ON", "YES"):
        import tensorflow as tf

        assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
        _tf_available = True  # pylint: disable=invalid-name
        logger.info("TensorFlow version {} available.".format(tf.__version__))
    else:
        logger.info("Disabling TensorFlow because USE_TORCH is set")
        _tf_available = False
except (ImportError, AssertionError):
    _tf_available = False  # pylint: disable=invalid-name
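
# Illustrative note (not part of the original file): framework selection is driven by
# the USE_TF / USE_TORCH environment variables read above, e.g.
#   USE_TF=0 python your_script.py     # PyTorch only, even if TensorFlow is installed
#   USE_TORCH=0 python your_script.py  # TensorFlow only, even if PyTorch is installed
# With both left at the default "AUTO", whichever framework imports successfully is enabled.
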
try:
    from torch.hub import _get_torch_home

    torch_cache_home = _get_torch_home()
except ImportError:
    torch_cache_home = os.path.expanduser(
        os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
    )
default_cache_path = os.path.join(torch_cache_home, "transformers")

try:
    from pathlib import Path

    PYTORCH_PRETRAINED_BERT_CACHE = Path(
        os.getenv("PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path))
    )
except (AttributeError, ImportError):
    PYTORCH_PRETRAINED_BERT_CACHE = os.getenv(
        "PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
    )

PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE  # Kept for backward compatibility
TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE  # Kept for backward compatibility

WEIGHTS_NAME = "pytorch_model.bin"
TF2_WEIGHTS_NAME = "tf_model.h5"
TF_WEIGHTS_NAME = "model.ckpt"
CONFIG_NAME = "config.json"
MODEL_CARD_NAME = "modelcard.json"

MULTIPLE_CHOICE_DUMMY_INPUTS = [[[0], [1]], [[0], [1]]]
DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]

S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
CLOUDFRONT_DISTRIB_PREFIX = "https://d2ws9o8vfrpkyk.cloudfront.net"
def is_torch_available():
    return _torch_available


def is_tf_available():
    return _tf_available
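
# Illustrative usage (not part of the original file): downstream code guards
# framework-specific imports on these flags rather than importing torch/tensorflow directly:
#   if is_torch_available():
#       import torch
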
def add_start_docstrings(*docstr):
    def docstring_decorator(fn):
        fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
        return fn

    return docstring_decorator
def add_start_docstrings_to_callable(*docstr):
    def docstring_decorator(fn):
        class_name = ":class:`~transformers.{}`".format(fn.__qualname__.split(".")[0])
        intro = " The {} forward method, overrides the :func:`__call__` special method.".format(class_name)
        note = r"""

    .. note::
        Although the recipe for the forward pass needs to be defined within
        this function, one should call the :class:`Module` instance afterwards
        instead of this, since the former takes care of running the
        pre- and post-processing steps while the latter silently ignores them.
        """
        fn.__doc__ = intro + note + "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
        return fn

    return docstring_decorator
def add_end_docstrings(*docstr):
    def docstring_decorator(fn):
        fn.__doc__ = fn.__doc__ + "".join(docstr)
        return fn

    return docstring_decorator
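
# Illustrative usage of the docstring decorators (not part of the original file;
# `forward` and the docstring fragments are hypothetical):
#   @add_start_docstrings("Shared intro paragraph.\n\n")
#   @add_end_docstrings("\n\nShared examples section.")
#   def forward(self, input_ids):
#       """Method-specific documentation."""
# add_start_docstrings_to_callable additionally prepends the owning class name and a
# note telling users to call the Module instance rather than `forward` directly.
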
def is_remote_url(url_or_filename):
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https", "s3")


def hf_bucket_url(identifier, postfix=None, cdn=False) -> str:
    endpoint = CLOUDFRONT_DISTRIB_PREFIX if cdn else S3_BUCKET_PREFIX
    if postfix is None:
        return "/".join((endpoint, identifier))
    else:
        return "/".join((endpoint, identifier, postfix))
def url_to_filename(url, etag=None):
    """
    Convert `url` into a hashed filename in a repeatable way.
    If `etag` is specified, append its hash to the url's, delimited
    by a period.
    If the url ends with .h5 (Keras HDF5 weights), '.h5' is appended to the name
    so that TF 2.0 can identify it as an HDF5 file
    (see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
    """
    url_bytes = url.encode("utf-8")
    url_hash = sha256(url_bytes)
    filename = url_hash.hexdigest()

    if etag:
        etag_bytes = etag.encode("utf-8")
        etag_hash = sha256(etag_bytes)
        filename += "." + etag_hash.hexdigest()

    if url.endswith(".h5"):
        filename += ".h5"

    return filename
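
# Illustrative naming scheme (not part of the original file):
#   url_to_filename(url)        -> "<sha256 hex of url>"
#   url_to_filename(url, etag)  -> "<sha256 hex of url>.<sha256 hex of etag>"
# with a trailing ".h5" kept for URLs ending in ".h5" (Keras weight files).
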
def filename_to_url(filename, cache_dir=None):
    """
    Return the url and etag (which may be ``None``) stored for `filename`.
    Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    cache_path = os.path.join(cache_dir, filename)
    if not os.path.exists(cache_path):
        raise EnvironmentError("file {} not found".format(cache_path))

    meta_path = cache_path + ".json"
    if not os.path.exists(meta_path):
        raise EnvironmentError("file {} not found".format(meta_path))

    with open(meta_path, encoding="utf-8") as meta_file:
        metadata = json.load(meta_file)
    url = metadata["url"]
    etag = metadata["etag"]

    return url, etag
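
# Illustrative round trip (not part of the original file): for a file stored by
# get_from_cache below, the ".json" sidecar lets the original URL and ETag be recovered:
#   name = url_to_filename(url, etag)   # hashed cache filename
#   url, etag = filename_to_url(name)   # read back from "<name>.json" in the cache dir
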
def cached_path(
    url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, user_agent=None
) -> Optional[str]:
    """
    Given something that might be a URL (or might be a local path),
    determine which. If it's a URL, download the file and cache it, and
    return the path to the cached file. If it's already a local path,
    make sure the file exists and then return the path.

    Args:
        cache_dir: specify a cache directory to save the file to (overrides the default cache dir).
        force_download: if True, re-download the file even if it's already cached in the cache dir.
        resume_download: if True, resume the download if an incompletely received file is found.
        user_agent: Optional string or dict that will be appended to the user-agent on remote requests.

    Return:
        None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
        Local path (string) otherwise
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if is_remote_url(url_or_filename):
        # URL, so get it from the cache (downloading if necessary)
        return get_from_cache(
            url_or_filename,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            user_agent=user_agent,
        )
    elif os.path.exists(url_or_filename):
        # File, and it exists.
        return url_or_filename
    elif urlparse(url_or_filename).scheme == "":
        # File, but it doesn't exist.
        raise EnvironmentError("file {} not found".format(url_or_filename))
    else:
        # Something unknown
        raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
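
# Illustrative usage (not part of the original file; assumes network access and uses
# "bert-base-uncased" as an example identifier):
#   config_file = cached_path(hf_bucket_url("bert-base-uncased", postfix=CONFIG_NAME))
#   local_file = cached_path("/path/to/local/config.json")  # returned as-is if it exists
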
def split_s3_path(url):
    """Split a full s3 path into the bucket name and path."""
    parsed = urlparse(url)
    if not parsed.netloc or not parsed.path:
        raise ValueError("bad s3 path {}".format(url))
    bucket_name = parsed.netloc
    s3_path = parsed.path
    # Remove '/' at beginning of path.
    if s3_path.startswith("/"):
        s3_path = s3_path[1:]
    return bucket_name, s3_path
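
# Illustrative output (not part of the original file), with a hypothetical bucket and key:
#   split_s3_path("s3://my-bucket/models/pytorch_model.bin")
#     -> ("my-bucket", "models/pytorch_model.bin")
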
def s3_request(func):
    """
    Wrapper function for s3 requests in order to create more helpful error
    messages.
    """

    @wraps(func)
    def wrapper(url, *args, **kwargs):
        try:
            return func(url, *args, **kwargs)
        except ClientError as exc:
            if int(exc.response["Error"]["Code"]) == 404:
                raise EnvironmentError("file {} not found".format(url))
            else:
                raise

    return wrapper
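
# Illustrative effect (not part of the original file), with a hypothetical key: when a
# decorated call such as s3_etag("s3://my-bucket/missing-key") hits a missing object,
# the boto3 ClientError with code 404 is re-raised as
#   EnvironmentError("file s3://my-bucket/missing-key not found")
# so missing remote files surface the same way as missing local files.
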
@s3_request
def s3_etag(url, proxies=None):
    """Check ETag on S3 object."""
    s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
    bucket_name, s3_path = split_s3_path(url)
    s3_object = s3_resource.Object(bucket_name, s3_path)
    return s3_object.e_tag
@s3_request
def s3_get(url, temp_file, proxies=None):
    """Pull a file directly from S3."""
    s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
    bucket_name, s3_path = split_s3_path(url)
    s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
    ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
    if is_torch_available():
        ua += "; torch/{}".format(torch.__version__)
    if is_tf_available():
        ua += "; tensorflow/{}".format(tf.__version__)
    if isinstance(user_agent, dict):
        ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
    elif isinstance(user_agent, str):
        ua += "; " + user_agent
    headers = {"user-agent": ua}
    if resume_size > 0:
        headers["Range"] = "bytes=%d-" % (resume_size,)
    response = requests.get(url, stream=True, proxies=proxies, headers=headers)
    if response.status_code == 416:  # Range not satisfiable
        return
    content_length = response.headers.get("Content-Length")
    total = resume_size + int(content_length) if content_length is not None else None
    progress = tqdm(
        unit="B",
        unit_scale=True,
        total=total,
        initial=resume_size,
        desc="Downloading",
        disable=bool(logger.getEffectiveLevel() == logging.NOTSET),
    )
    for chunk in response.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive new chunks
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()
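
# Illustrative resume behaviour (not part of the original file; paths are hypothetical):
#   resume_size = os.path.getsize("/tmp/part.bin") if os.path.exists("/tmp/part.bin") else 0
#   with open("/tmp/part.bin", "ab") as f:
#       http_get("https://example.com/model.bin", f, resume_size=resume_size)
# resume_size both seeds the tqdm progress bar and sets the HTTP Range header
# ("bytes=<resume_size>-"); a 416 response means the requested range is unsatisfiable
# (typically the partial file already covers the full content), so nothing more is written.
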
def get_from_cache(
    url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None
) -> Optional[str]:
    """
    Given a URL, look for the corresponding file in the local cache.
    If it's not there, download it. Then return the path to the cached file.

    Return:
        None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
        Local path (string) otherwise
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    os.makedirs(cache_dir, exist_ok=True)
    # Get eTag to add to filename, if it exists.
    if url.startswith("s3://"):
        etag = s3_etag(url, proxies=proxies)
    else:
        try:
            response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
            if response.status_code != 200:
                etag = None
            else:
                etag = response.headers.get("ETag")
        except (EnvironmentError, requests.exceptions.Timeout):
            etag = None

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    # etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible.
    # try to get the last downloaded one
    if etag is None:
        if os.path.exists(cache_path):
            return cache_path
        else:
            matching_files = [
                file
                for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
                if not file.endswith(".json") and not file.endswith(".lock")
            ]
            if len(matching_files) > 0:
                return os.path.join(cache_dir, matching_files[-1])
            else:
                return None

    # From now on, etag is not None.
    if os.path.exists(cache_path) and not force_download:
        return cache_path
    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):

        if resume_download:
            incomplete_path = cache_path + ".incomplete"

            @contextmanager
            def _resumable_file_manager():
                # Append to any partially downloaded file so the transfer can resume.
                with open(incomplete_path, "a+b") as f:
                    yield f

            temp_file_manager = _resumable_file_manager
            if os.path.exists(incomplete_path):
                resume_size = os.stat(incomplete_path).st_size
            else:
                resume_size = 0
        else:
            temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
            resume_size = 0
        # Download to a temporary file, then copy to the cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with temp_file_manager() as temp_file:
            logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)

            # GET file object
            if url.startswith("s3://"):
                if resume_download:
                    logger.warning('Resumable downloads are not implemented for "s3://" urls')
                s3_get(url, temp_file, proxies=proxies)
            else:
                http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)

        logger.info("storing %s in cache at %s", url, cache_path)
        os.rename(temp_file.name, cache_path)

        logger.info("creating metadata file for %s", cache_path)
        meta = {"url": url, "etag": etag}
        meta_path = cache_path + ".json"
        with open(meta_path, "w") as meta_file:
            json.dump(meta, meta_file)

    return cache_path
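
# Illustrative cache layout produced by get_from_cache (not part of the original file),
# writing into the TRANSFORMERS_CACHE directory:
#   <sha256(url)>.<sha256(etag)>             the downloaded file itself
#   <sha256(url)>.<sha256(etag)>.json        metadata sidecar: {"url": ..., "etag": ...}
#   <sha256(url)>.<sha256(etag)>.lock        FileLock guarding concurrent downloads
#   <sha256(url)>.<sha256(etag)>.incomplete  partial download kept when resume_download=True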