""" coding=utf-8 Copyright 2018, Antonio Mendoza Hao Tan, Mohit Bansal, Huggingface team :) Adapted From Facebook Inc, Detectron2 Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.import copy """ import copy import fnmatch import json import os import pickle as pkl import shutil import sys import tarfile import tempfile from collections import OrderedDict from contextlib import contextmanager from functools import partial from hashlib import sha256 from io import BytesIO from pathlib import Path from urllib.parse import urlparse from zipfile import ZipFile, is_zipfile import cv2 import numpy as np import requests import wget from filelock import FileLock from PIL import Image from tqdm.auto import tqdm from yaml import Loader, dump, load try: import torch _torch_available = True except ImportError: _torch_available = False try: from torch.hub import _get_torch_home torch_cache_home = _get_torch_home() except ImportError: torch_cache_home = os.path.expanduser( os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch")) ) default_cache_path = os.path.join(torch_cache_home, "transformers") CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co" S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert" PATH = "/".join(str(Path(__file__).resolve()).split("/")[:-1]) CONFIG = os.path.join(PATH, "config.yaml") ATTRIBUTES = os.path.join(PATH, "attributes.txt") OBJECTS = os.path.join(PATH, "objects.txt") PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path) PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE) TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE) WEIGHTS_NAME = "pytorch_model.bin" CONFIG_NAME = "config.yaml" def load_labels(objs=OBJECTS, attrs=ATTRIBUTES): vg_classes = [] with open(objs) as f: for object in f.readlines(): vg_classes.append(object.split(",")[0].lower().strip()) vg_attrs = [] with open(attrs) as f: for object in f.readlines(): vg_attrs.append(object.split(",")[0].lower().strip()) return vg_classes, vg_attrs def load_checkpoint(ckp): r = OrderedDict() with open(ckp, "rb") as f: ckp = pkl.load(f)["model"] for k in copy.deepcopy(list(ckp.keys())): v = ckp.pop(k) if isinstance(v, np.ndarray): v = torch.tensor(v) else: assert isinstance(v, torch.tensor), type(v) r[k] = v return r class Config: _pointer = {} def __init__(self, dictionary: dict, name: str = "root", level=0): self._name = name self._level = level d = {} for k, v in dictionary.items(): if v is None: raise ValueError() k = copy.deepcopy(k) v = copy.deepcopy(v) if isinstance(v, dict): v = Config(v, name=k, level=level + 1) d[k] = v setattr(self, k, v) self._pointer = d def __repr__(self): return str(list((self._pointer.keys()))) def __setattr__(self, key, val): self.__dict__[key] = val self.__dict__[key.upper()] = val levels = key.split(".") last_level = len(levels) - 1 pointer = self._pointer if len(levels) > 1: for i, l in enumerate(levels): if hasattr(self, l) and isinstance(getattr(self, l), Config): setattr(getattr(self, l), ".".join(levels[i:]), val) if l == last_level: pointer[l] = val else: pointer = pointer[l] def to_dict(self): return self._pointer def dump_yaml(self, data, file_name): with open(f"{file_name}", "w") as stream: dump(data, stream) def dump_json(self, data, file_name): with open(f"{file_name}", "w") as stream: json.dump(data, stream) @staticmethod def load_yaml(config): with open(config) as stream: data = load(stream, Loader=Loader) return data def __str__(self): t = " " if self._name != "root": r = f"{t * (self._level-1)}{self._name}:\n" else: r = "" level = self._level for i, (k, v) in enumerate(self._pointer.items()): if isinstance(v, Config): r += f"{t * (self._level)}{v}\n" self._level += 1 else: r += f"{t * (self._level)}{k}: {v} ({type(v).__name__})\n" self._level = level return r[:-1] @classmethod def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs): config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) return cls(config_dict) @classmethod def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs): cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", False) if os.path.isdir(pretrained_model_name_or_path): config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): config_file = pretrained_model_name_or_path else: config_file = hf_bucket_url(pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False) try: # Load from URL or cache if already cached resolved_config_file = cached_path( config_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, ) # Load config dict if resolved_config_file is None: raise EnvironmentError config_file = Config.load_yaml(resolved_config_file) except EnvironmentError: msg = "Can't load config for" raise EnvironmentError(msg) if resolved_config_file == config_file: print("loading configuration file from path") else: print("loading configuration file cache") return Config.load_yaml(resolved_config_file), kwargs # quick compare tensors def compare(in_tensor): out_tensor = torch.load("dump.pt", map_location=in_tensor.device) n1 = in_tensor.numpy() n2 = out_tensor.numpy()[0] print(n1.shape, n1[0, 0, :5]) print(n2.shape, n2[0, 0, :5]) assert np.allclose(n1, n2, rtol=0.01, atol=0.1), ( f"{sum([1 for x in np.isclose(n1, n2, rtol=0.01, atol=0.1).flatten() if x is False])/len(n1.flatten())*100:.4f} %" " element-wise mismatch" ) raise Exception("tensors are all good") # Hugging face functions below def is_remote_url(url_or_filename): parsed = urlparse(url_or_filename) return parsed.scheme in ("http", "https") def hf_bucket_url(model_id: str, filename: str, use_cdn=True) -> str: endpoint = CLOUDFRONT_DISTRIB_PREFIX if use_cdn else S3_BUCKET_PREFIX legacy_format = "/" not in model_id if legacy_format: return f"{endpoint}/{model_id}-{filename}" else: return f"{endpoint}/{model_id}/{filename}" def http_get( url, temp_file, proxies=None, resume_size=0, user_agent=None, ): ua = "python/{}".format(sys.version.split()[0]) if _torch_available: ua += "; torch/{}".format(torch.__version__) if isinstance(user_agent, dict): ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items()) elif isinstance(user_agent, str): ua += "; " + user_agent headers = {"user-agent": ua} if resume_size > 0: headers["Range"] = "bytes=%d-" % (resume_size,) response = requests.get(url, stream=True, proxies=proxies, headers=headers) if response.status_code == 416: # Range not satisfiable return content_length = response.headers.get("Content-Length") total = resume_size + int(content_length) if content_length is not None else None progress = tqdm( unit="B", unit_scale=True, total=total, initial=resume_size, desc="Downloading", ) for chunk in response.iter_content(chunk_size=1024): if chunk: # filter out keep-alive new chunks progress.update(len(chunk)) temp_file.write(chunk) progress.close() def get_from_cache( url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None, local_files_only=False, ): if cache_dir is None: cache_dir = TRANSFORMERS_CACHE if isinstance(cache_dir, Path): cache_dir = str(cache_dir) os.makedirs(cache_dir, exist_ok=True) etag = None if not local_files_only: try: response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout) if response.status_code == 200: etag = response.headers.get("ETag") except (EnvironmentError, requests.exceptions.Timeout): # etag is already None pass filename = url_to_filename(url, etag) # get cache path to put the file cache_path = os.path.join(cache_dir, filename) # etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible. # try to get the last downloaded one if etag is None: if os.path.exists(cache_path): return cache_path else: matching_files = [ file for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*") if not file.endswith(".json") and not file.endswith(".lock") ] if len(matching_files) > 0: return os.path.join(cache_dir, matching_files[-1]) else: # If files cannot be found and local_files_only=True, # the models might've been found if local_files_only=False # Notify the user about that if local_files_only: raise ValueError( "Cannot find the requested files in the cached path and outgoing traffic has been" " disabled. To enable model look-ups and downloads online, set 'local_files_only'" " to False." ) return None # From now on, etag is not None. if os.path.exists(cache_path) and not force_download: return cache_path # Prevent parallel downloads of the same file with a lock. lock_path = cache_path + ".lock" with FileLock(lock_path): # If the download just completed while the lock was activated. if os.path.exists(cache_path) and not force_download: # Even if returning early like here, the lock will be released. return cache_path if resume_download: incomplete_path = cache_path + ".incomplete" @contextmanager def _resumable_file_manager(): with open(incomplete_path, "a+b") as f: yield f temp_file_manager = _resumable_file_manager if os.path.exists(incomplete_path): resume_size = os.stat(incomplete_path).st_size else: resume_size = 0 else: temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False) resume_size = 0 # Download to temporary file, then copy to cache dir once finished. # Otherwise you get corrupt cache entries if the download gets interrupted. with temp_file_manager() as temp_file: print( "%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name, ) http_get( url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent, ) os.replace(temp_file.name, cache_path) meta = {"url": url, "etag": etag} meta_path = cache_path + ".json" with open(meta_path, "w") as meta_file: json.dump(meta, meta_file) return cache_path def url_to_filename(url, etag=None): url_bytes = url.encode("utf-8") url_hash = sha256(url_bytes) filename = url_hash.hexdigest() if etag: etag_bytes = etag.encode("utf-8") etag_hash = sha256(etag_bytes) filename += "." + etag_hash.hexdigest() if url.endswith(".h5"): filename += ".h5" return filename def cached_path( url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, user_agent=None, extract_compressed_file=False, force_extract=False, local_files_only=False, ): if cache_dir is None: cache_dir = TRANSFORMERS_CACHE if isinstance(url_or_filename, Path): url_or_filename = str(url_or_filename) if isinstance(cache_dir, Path): cache_dir = str(cache_dir) if is_remote_url(url_or_filename): # URL, so get it from the cache (downloading if necessary) output_path = get_from_cache( url_or_filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, user_agent=user_agent, local_files_only=local_files_only, ) elif os.path.exists(url_or_filename): # File, and it exists. output_path = url_or_filename elif urlparse(url_or_filename).scheme == "": # File, but it doesn't exist. raise EnvironmentError("file {} not found".format(url_or_filename)) else: # Something unknown raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) if extract_compressed_file: if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path): return output_path # Path where we extract compressed archives # We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/" output_dir, output_file = os.path.split(output_path) output_extract_dir_name = output_file.replace(".", "-") + "-extracted" output_path_extracted = os.path.join(output_dir, output_extract_dir_name) if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract: return output_path_extracted # Prevent parallel extractions lock_path = output_path + ".lock" with FileLock(lock_path): shutil.rmtree(output_path_extracted, ignore_errors=True) os.makedirs(output_path_extracted) if is_zipfile(output_path): with ZipFile(output_path, "r") as zip_file: zip_file.extractall(output_path_extracted) zip_file.close() elif tarfile.is_tarfile(output_path): tar_file = tarfile.open(output_path) tar_file.extractall(output_path_extracted) tar_file.close() else: raise EnvironmentError("Archive format of {} could not be identified".format(output_path)) return output_path_extracted return output_path def get_data(query, delim=","): assert isinstance(query, str) if os.path.isfile(query): with open(query) as f: data = eval(f.read()) else: req = requests.get(query) try: data = requests.json() except Exception: data = req.content.decode() assert data is not None, "could not connect" try: data = eval(data) except Exception: data = data.split("\n") req.close() return data def get_image_from_url(url): response = requests.get(url) img = np.array(Image.open(BytesIO(response.content))) return img # to load legacy frcnn checkpoint from detectron def load_frcnn_pkl_from_url(url): fn = url.split("/")[-1] if fn not in os.listdir(os.getcwd()): wget.download(url) with open(fn, "rb") as stream: weights = pkl.load(stream) model = weights.pop("model") new = {} for k, v in model.items(): new[k] = torch.from_numpy(v) if "running_var" in k: zero = torch.tensor([0]) k2 = k.replace("running_var", "num_batches_tracked") new[k2] = zero return new def get_demo_path(): print(f"{os.path.abspath(os.path.join(PATH, os.pardir))}/demo.ipynb") def img_tensorize(im, input_format="RGB"): assert isinstance(im, str) if os.path.isfile(im): img = cv2.imread(im) else: img = get_image_from_url(im) assert img is not None, f"could not connect to: {im}" img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) if input_format == "RGB": img = img[:, :, ::-1] return img def chunk(images, batch=1): return (images[i : i + batch] for i in range(0, len(images), batch))