|
""" |
|
coding=utf-8 |
|
Copyright 2018, Antonio Mendoza, Hao Tan, Mohit Bansal, Huggingface team :)
|
Adapted From Facebook Inc, Detectron2 |
|
|
|
Licensed under the Apache License, Version 2.0 (the "License"); |
|
you may not use this file except in compliance with the License. |
|
You may obtain a copy of the License at |
|
|
|
http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
Unless required by applicable law or agreed to in writing, software |
|
distributed under the License is distributed on an "AS IS" BASIS, |
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
See the License for the specific language governing permissions and |
|
limitations under the License.
|
""" |
|
|
|
import copy |
|
import fnmatch |
|
import json |
|
import os |
|
import pickle as pkl |
|
import shutil |
|
import sys |
|
import tarfile |
|
import tempfile |
|
from collections import OrderedDict |
|
from contextlib import contextmanager |
|
from functools import partial |
|
from hashlib import sha256 |
|
from io import BytesIO |
|
from pathlib import Path |
|
from urllib.parse import urlparse |
|
from zipfile import ZipFile, is_zipfile |
|
|
|
import cv2 |
|
import numpy as np |
|
import requests |
|
import wget |
|
from filelock import FileLock |
|
from PIL import Image |
|
from tqdm.auto import tqdm |
|
from yaml import Loader, dump, load |
|
|
|
|
|
try: |
|
import torch |
|
|
|
_torch_available = True |
|
except ImportError: |
|
_torch_available = False |
|
|
|
|
|
# Root directory for cached downloads. Prefer torch.hub's own helper; if it
# cannot be imported, use $TORCH_HOME, else $XDG_CACHE_HOME/torch
# (with ~/.cache as the XDG default), user-expanded.
try:
    from torch.hub import _get_torch_home

    torch_cache_home = _get_torch_home()
except ImportError:
    torch_cache_home = os.path.expanduser(
        os.getenv(
            "TORCH_HOME",
            os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"),
        )
    )
|
|
|
default_cache_path = os.path.join(torch_cache_home, "transformers")

# Endpoints used by ``hf_bucket_url`` (CDN vs. direct S3 bucket).
CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"

# Directory containing this module. ``Path.parent`` is used instead of
# splitting on "/" so the result is also correct with Windows separators.
PATH = str(Path(__file__).resolve().parent)
CONFIG = os.path.join(PATH, "config.yaml")
ATTRIBUTES = os.path.join(PATH, "attributes.txt")
OBJECTS = os.path.join(PATH, "objects.txt")

# Cache directory resolution: each env var overrides the previous default.
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)

WEIGHTS_NAME = "pytorch_model.bin"
CONFIG_NAME = "config.yaml"
|
|
|
|
|
def load_labels(objs=OBJECTS, attrs=ATTRIBUTES):
    """Read the object-class and attribute vocabularies from two text files.

    Each line contributes the text before its first comma, lower-cased and
    stripped of surrounding whitespace.

    Args:
        objs: path to the object-class list (defaults to ``objects.txt``
            next to this module).
        attrs: path to the attribute list (defaults to ``attributes.txt``
            next to this module).

    Returns:
        ``(classes, attributes)`` as two lists of strings, in file order.
    """

    def _first_column(path):
        # One entry per line: first comma-separated field, normalized.
        with open(path) as f:
            return [line.split(",")[0].lower().strip() for line in f]

    vg_classes = _first_column(objs)
    vg_attrs = _first_column(attrs)
    return vg_classes, vg_attrs
|
|
|
|
|
def load_checkpoint(ckp):
    """Load a pickled checkpoint and return its ``"model"`` entries as tensors.

    Args:
        ckp: path to a pickle file whose top-level object has a ``"model"``
            key mapping parameter names to ``np.ndarray`` or ``torch.Tensor``.

    Returns:
        ``OrderedDict`` of name -> ``torch.Tensor`` in the file's key order.
        numpy arrays are copied into new tensors via ``torch.tensor``.

    Raises:
        AssertionError: if a value is neither an ndarray nor a tensor.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load
    # checkpoints from trusted sources.
    with open(ckp, "rb") as f:
        weights = pkl.load(f)["model"]

    r = OrderedDict()
    for k, v in weights.items():
        if isinstance(v, np.ndarray):
            v = torch.tensor(v)
        else:
            # Must be the class ``torch.Tensor``: ``torch.tensor`` is a factory
            # function and isinstance() against it raises TypeError.
            assert isinstance(v, torch.Tensor), type(v)
        r[k] = v
    return r
|
|
|
|
|
class Config:
    """Attribute-style wrapper around a (possibly nested) configuration dict.

    Every key becomes an attribute, reachable under both its original and its
    upper-cased name (``cfg.model`` / ``cfg.MODEL``). Nested dicts become
    nested ``Config`` instances. ``to_dict()`` returns the wrapped mapping.
    """

    # Class-level default so ``__setattr__`` can run during ``__init__`` before
    # the per-instance mapping is assigned. Shared by all instances: never mutate it.
    _pointer = {}

    def __init__(self, dictionary: dict, name: str = "root", level=0):
        """Wrap ``dictionary``.

        Args:
            dictionary: mapping to expose; values must not be ``None``.
            name: label printed by ``__str__`` for nested configs
                ("root" marks the top level, which prints no header).
            level: nesting depth, used for indentation in ``__str__``.

        Raises:
            ValueError: if any value is ``None``.
        """
        self._name = name
        self._level = level
        d = {}
        for k, v in dictionary.items():
            if v is None:
                raise ValueError(f"Config value for key {k!r} is None")
            # Deep copies keep this Config independent of the caller's dict.
            k = copy.deepcopy(k)
            v = copy.deepcopy(v)
            if isinstance(v, dict):
                v = Config(v, name=k, level=level + 1)
            d[k] = v
            setattr(self, k, v)
        self._pointer = d

    def __repr__(self):
        return str(list(self._pointer.keys()))

    def __setattr__(self, key, val):
        """Set ``key`` and ``key.upper()`` to ``val``.

        A dotted key such as ``"model.device"`` is also forwarded to the nested
        ``Config`` named by its first component. The nested config's
        ``to_dict()`` mapping is updated for the final component.
        """
        self.__dict__[key] = val
        self.__dict__[key.upper()] = val
        head, sep, rest = key.partition(".")
        if not sep:
            return
        child = self.__dict__.get(head)
        if not isinstance(child, Config):
            return
        # Recurse one level at a time; each level handles its own first component.
        setattr(child, rest, val)
        if "." not in rest:
            pointer = child.__dict__.get("_pointer")
            if pointer is not None:
                pointer[rest] = val

    def to_dict(self):
        """Return the wrapped mapping (nested sections are ``Config`` objects)."""
        return self._pointer

    def dump_yaml(self, data, file_name):
        """Write ``data`` to ``file_name`` as YAML."""
        with open(f"{file_name}", "w") as stream:
            dump(data, stream)

    def dump_json(self, data, file_name):
        """Write ``data`` to ``file_name`` as JSON."""
        with open(f"{file_name}", "w") as stream:
            json.dump(data, stream)

    @staticmethod
    def load_yaml(config):
        """Parse the YAML file at path ``config`` and return the loaded data."""
        # NOTE(review): yaml's full ``Loader`` can construct arbitrary Python
        # objects; configs fetched from remote URLs may warrant ``SafeLoader``.
        with open(config) as stream:
            data = load(stream, Loader=Loader)
        return data

    def __str__(self):
        """Render the config as an indented ``key: value (type)`` listing."""
        t = " "
        lines = []
        if self._name != "root":
            lines.append(f"{t * (self._level - 1)}{self._name}:")
        for k, v in self._pointer.items():
            if isinstance(v, Config):
                # Nested configs indent their own header and entries from their
                # own level, so no extra prefix is added here.
                lines.append(str(v))
            else:
                lines.append(f"{t * self._level}{k}: {v} ({type(v).__name__})")
        return "\n".join(lines)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """Build a ``Config`` from a local/remote YAML file or a model id."""
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        return cls(config_dict)

    @classmethod
    def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs):
        """Resolve, fetch if remote, and parse a YAML config.

        ``pretrained_model_name_or_path`` may be one of:
            - a directory containing ``CONFIG_NAME``;
            - a local file;
            - an http(s) URL;
            - a model id resolved with ``hf_bucket_url``.

        The options ``cache_dir``, ``force_download``, ``resume_download``,
        ``proxies`` and ``local_files_only`` are popped from ``kwargs`` and
        forwarded to ``cached_path``.

        Returns:
            ``(parsed YAML data, remaining kwargs)``.

        Raises:
            EnvironmentError: if the config cannot be resolved or read.
        """
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)

        if os.path.isdir(pretrained_model_name_or_path):
            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
            config_file = pretrained_model_name_or_path
        else:
            config_file = hf_bucket_url(pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False)

        try:
            resolved_config_file = cached_path(
                config_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
            )
            if resolved_config_file is None:
                raise EnvironmentError
            # Parsed into its own name so ``config_file`` still holds the path
            # for the comparison below (and the file is parsed only once).
            config = Config.load_yaml(resolved_config_file)
        except EnvironmentError as err:
            msg = f"Can't load config for '{pretrained_model_name_or_path}' (looked for {config_file})"
            raise EnvironmentError(msg) from err

        if resolved_config_file == config_file:
            print("loading configuration file from path")
        else:
            print("loading configuration file cache")

        return config, kwargs
|
|
|
|
|
|
|
def compare(in_tensor):
    """Debug helper: compare ``in_tensor`` with the reference saved in ``dump.pt``.

    ``dump.pt`` is loaded from the current working directory onto
    ``in_tensor``'s device. Its first leading index is dropped (``[0]``)
    before the element-wise comparison (rtol=0.01, atol=0.1). The first five
    entries of both arrays are printed, indexed as ``[0, 0, :5]``, so both must
    be at least 3-D after that.

    Raises:
        AssertionError: if the tensors differ; the message gives the percentage
            of mismatching elements.
        Exception: "tensors are all good" when they match. This is raised
            deliberately to halt execution at the comparison point.
    """
    out_tensor = torch.load("dump.pt", map_location=in_tensor.device)
    # detach/cpu so GPU tensors and tensors that require grad convert to numpy.
    n1 = in_tensor.detach().cpu().numpy()
    n2 = out_tensor.detach().cpu().numpy()[0]
    print(n1.shape, n1[0, 0, :5])
    print(n2.shape, n2[0, 0, :5])
    close = np.isclose(n1, n2, rtol=0.01, atol=0.1)
    # Count via the boolean array itself: comparing numpy bools with ``is False``
    # never matches, which previously reported 0% mismatch.
    mismatch_pct = (~close).sum() / n1.size * 100 if n1.size else 0.0
    assert close.all(), f"{mismatch_pct:.4f} %" " element-wise mismatch"
    raise Exception("tensors are all good")
|
|
|
|
|
|
|
|
|
def is_remote_url(url_or_filename):
    """Return True when ``url_or_filename`` parses with an http or https scheme."""
    scheme = urlparse(url_or_filename).scheme
    return scheme == "http" or scheme == "https"
|
|
|
|
|
def hf_bucket_url(model_id: str, filename: str, use_cdn=True) -> str:
    """Build the download URL of ``filename`` for ``model_id``.

    Ids containing "/" (namespaced) map to ``<endpoint>/<id>/<filename>``.
    Ids without "/" use the flat ``<endpoint>/<id>-<filename>`` layout.
    ``use_cdn`` selects the CloudFront endpoint instead of the S3 bucket.
    """
    if use_cdn:
        endpoint = CLOUDFRONT_DISTRIB_PREFIX
    else:
        endpoint = S3_BUCKET_PREFIX
    separator = "/" if "/" in model_id else "-"
    return f"{endpoint}/{model_id}{separator}{filename}"
|
|
|
|
|
def http_get(
    url,
    temp_file,
    proxies=None,
    resume_size=0,
    user_agent=None,
):
    """Stream ``url`` into the already-open binary file ``temp_file``.

    Args:
        url: resource to download.
        temp_file: writable binary file object that receives the body.
        proxies: passed through to ``requests.get``.
        resume_size: number of bytes already downloaded. When > 0, a
            ``Range`` header requests only the remainder.
        user_agent: extra user-agent info, as a dict (``k/v`` pairs) or a str.

    Raises:
        requests.HTTPError: for error responses other than 416.
    """
    ua = "python/{}".format(sys.version.split()[0])
    if _torch_available:
        ua += "; torch/{}".format(torch.__version__)
    if isinstance(user_agent, dict):
        ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
    elif isinstance(user_agent, str):
        ua += "; " + user_agent
    headers = {"user-agent": ua}
    if resume_size > 0:
        headers["Range"] = "bytes=%d-" % (resume_size,)

    response = requests.get(url, stream=True, proxies=proxies, headers=headers)
    try:
        if response.status_code == 416:
            # Range not satisfiable; presumably the partial file is already complete.
            return
        # Fail loudly instead of writing an error page into the cache file.
        response.raise_for_status()
        content_length = response.headers.get("Content-Length")
        total = resume_size + int(content_length) if content_length is not None else None
        with tqdm(
            unit="B",
            unit_scale=True,
            total=total,
            initial=resume_size,
            desc="Downloading",
        ) as progress:
            for piece in response.iter_content(chunk_size=1024):
                if piece:  # skip keep-alive chunks
                    progress.update(len(piece))
                    temp_file.write(piece)
    finally:
        response.close()
|
|
|
|
|
def get_from_cache(
    url,
    cache_dir=None,
    force_download=False,
    proxies=None,
    etag_timeout=10,
    resume_download=False,
    user_agent=None,
    local_files_only=False,
):
    """Return the local cache path for ``url``, downloading it if necessary.

    The cache filename is derived from the URL and, when available, the
    server's ETag (see ``url_to_filename``). A ``<file>.json`` sidecar records
    ``{"url", "etag"}``.

    Args:
        url: remote resource.
        cache_dir: cache directory; defaults to ``TRANSFORMERS_CACHE``.
        force_download: re-download even if a cached copy exists.
        proxies: passed to ``requests``.
        etag_timeout: timeout in seconds for the HEAD request that fetches
            the ETag.
        resume_download: append to ``<file>.incomplete`` instead of starting
            over.
        user_agent: forwarded to ``http_get``.
        local_files_only: skip the network and only look in the cache.

    Returns:
        The cached file path, or ``None`` when no ETag could be obtained and
        nothing matching is cached (and ``local_files_only`` is False).

    Raises:
        ValueError: if ``local_files_only`` is set and nothing is cached.
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    os.makedirs(cache_dir, exist_ok=True)

    etag = None
    if not local_files_only:
        try:
            response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
            if response.status_code == 200:
                etag = response.headers.get("ETag")
        except (EnvironmentError, requests.exceptions.Timeout):
            # Best effort: if the server is unreachable, fall back to the cache below.
            pass

    filename = url_to_filename(url, etag)
    cache_path = os.path.join(cache_dir, filename)

    if etag is None:
        # No ETag: serve the exact cache entry, or any entry for this URL
        # (cached under "<url-hash>.<etag-hash>"), ignoring sidecar/lock files.
        if os.path.exists(cache_path):
            return cache_path
        matching_files = [
            file
            for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
            if not file.endswith(".json") and not file.endswith(".lock")
        ]
        if len(matching_files) > 0:
            return os.path.join(cache_dir, matching_files[-1])
        if local_files_only:
            raise ValueError(
                "Cannot find the requested files in the cached path and outgoing traffic has been"
                " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
                " to False."
            )
        return None

    if os.path.exists(cache_path) and not force_download:
        return cache_path

    lock_path = cache_path + ".lock"
    with FileLock(lock_path):
        # Re-check under the lock: another process may have finished the download.
        if os.path.exists(cache_path) and not force_download:
            return cache_path

        if resume_download:
            incomplete_path = cache_path + ".incomplete"

            @contextmanager
            def _resumable_file_manager():
                with open(incomplete_path, "a+b") as f:
                    yield f

            temp_file_manager = _resumable_file_manager
            if os.path.exists(incomplete_path):
                resume_size = os.stat(incomplete_path).st_size
            else:
                resume_size = 0
        else:
            temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
            resume_size = 0

        with temp_file_manager() as temp_file:
            print(f"{url} not found in cache or force_download set to True, downloading to {temp_file.name}")
            http_get(
                url,
                temp_file,
                proxies=proxies,
                resume_size=resume_size,
                user_agent=user_agent,
            )

        # Rename only after the temp file is closed (flushed; required on Windows).
        os.replace(temp_file.name, cache_path)

        meta = {"url": url, "etag": etag}
        meta_path = cache_path + ".json"
        with open(meta_path, "w") as meta_file:
            json.dump(meta, meta_file)

    return cache_path
|
|
|
|
|
def url_to_filename(url, etag=None):
    """Map ``url`` (and optionally its ``etag``) to a deterministic cache filename.

    The name is the SHA-256 hex digest of the URL. When ``etag`` is truthy,
    "." plus the digest of the etag is appended. URLs ending in ".h5" keep
    that suffix.
    """
    pieces = [sha256(url.encode("utf-8")).hexdigest()]
    if etag:
        pieces.append(sha256(etag.encode("utf-8")).hexdigest())
    name = ".".join(pieces)
    return name + ".h5" if url.endswith(".h5") else name
|
|
|
|
|
def cached_path(
    url_or_filename,
    cache_dir=None,
    force_download=False,
    proxies=None,
    resume_download=False,
    user_agent=None,
    extract_compressed_file=False,
    force_extract=False,
    local_files_only=False,
):
    """Return a local path for ``url_or_filename``.

    Remote http(s) URLs are fetched through ``get_from_cache``. Existing local
    paths are returned as-is.

    Args:
        url_or_filename: URL or local path (``str`` or ``Path``).
        cache_dir: cache directory; defaults to ``TRANSFORMERS_CACHE``.
        force_download, proxies, resume_download, user_agent, local_files_only:
            forwarded to ``get_from_cache`` for remote URLs.
        extract_compressed_file: if the resolved file is a zip or tar archive,
            extract it into a sibling ``<name-with-dots-as-dashes>-extracted``
            directory and return that directory instead.
        force_extract: re-extract even if that directory is already non-empty.

    Returns:
        A local file or directory path. May be ``None`` if ``get_from_cache``
        could not obtain the file.

    Raises:
        EnvironmentError: local path missing, or archive format unrecognized.
        ValueError: input is neither a URL nor a local path.
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if is_remote_url(url_or_filename):
        output_path = get_from_cache(
            url_or_filename,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            user_agent=user_agent,
            local_files_only=local_files_only,
        )
    elif os.path.exists(url_or_filename):
        output_path = url_or_filename
    elif urlparse(url_or_filename).scheme == "":
        # Looks like a local path, but nothing is there.
        raise EnvironmentError("file {} not found".format(url_or_filename))
    else:
        raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))

    if not extract_compressed_file:
        return output_path

    # get_from_cache can return None (no ETag and nothing cached): nothing to
    # extract, and the archive checks below would fail on None.
    if output_path is None:
        return output_path
    if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
        return output_path

    output_dir, output_file = os.path.split(output_path)
    output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
    output_path_extracted = os.path.join(output_dir, output_extract_dir_name)

    if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
        return output_path_extracted

    # Serialize extraction across processes sharing the cache.
    lock_path = output_path + ".lock"
    with FileLock(lock_path):
        shutil.rmtree(output_path_extracted, ignore_errors=True)
        os.makedirs(output_path_extracted)
        if is_zipfile(output_path):
            with ZipFile(output_path, "r") as zip_file:
                zip_file.extractall(output_path_extracted)
        elif tarfile.is_tarfile(output_path):
            # NOTE(review): TarFile.extractall trusts member paths; archives from
            # untrusted sources can write outside the target directory.
            with tarfile.open(output_path) as tar_file:
                tar_file.extractall(output_path_extracted)
        else:
            raise EnvironmentError("Archive format of {} could not be identified".format(output_path))

    return output_path_extracted
|
|
|
|
|
def get_data(query, delim=","):
    """Load data from a local file or a URL.

    Args:
        query: local file path or URL.
        delim: unused; kept for backward compatibility.

    Returns:
        - Local files: the result of ``eval`` on the file contents.
        - URLs: the parsed JSON body if it is valid JSON. Otherwise ``eval`` of
          the decoded text, falling back to its lines (split on "\\n").
    """
    assert isinstance(query, str)
    if os.path.isfile(query):
        # NOTE(review): eval() executes arbitrary code from the file; consider
        # ast.literal_eval if only literals are expected.
        with open(query) as f:
            data = eval(f.read())
    else:
        req = requests.get(query)
        try:
            try:
                # The response object's json(); ``requests.json`` does not exist.
                data = req.json()
            except Exception:
                data = req.content.decode()
                assert data is not None, "could not connect"
                try:
                    # NOTE(review): eval() on downloaded content runs remote code.
                    data = eval(data)
                except Exception:
                    data = data.split("\n")
        finally:
            req.close()
    return data
|
|
|
|
|
def get_image_from_url(url):
    """Download ``url`` and decode the image bytes with PIL into a numpy array."""
    payload = requests.get(url).content
    return np.array(Image.open(BytesIO(payload)))
|
|
|
|
|
|
|
def load_frcnn_pkl_from_url(url):
    """Load a pickled FRCNN checkpoint, downloading it into the CWD if absent.

    The file is saved under the last path component of ``url``. Every array
    in its ``"model"`` entry is converted with ``torch.from_numpy``. For each
    key containing "running_var", a matching "num_batches_tracked" entry set
    to ``tensor([0])`` is added.

    Returns:
        dict of name -> ``torch.Tensor``.
    """
    local_name = url.split("/")[-1]
    already_present = local_name in os.listdir(os.getcwd())
    if not already_present:
        wget.download(url)
    # NOTE(review): pickle.load executes arbitrary code from the downloaded file.
    with open(local_name, "rb") as stream:
        checkpoint = pkl.load(stream)
    state = checkpoint.pop("model")

    converted = {}
    for name, array in state.items():
        converted[name] = torch.from_numpy(array)
        if "running_var" in name:
            converted[name.replace("running_var", "num_batches_tracked")] = torch.tensor([0])
    return converted
|
|
|
|
|
def get_demo_path():
    """Print and return the path of ``demo.ipynb`` in the parent of this module's directory.

    Earlier callers relied only on the printed output; returning the path as
    well lets the getter be used programmatically.
    """
    demo_path = f"{os.path.abspath(os.path.join(PATH, os.pardir))}/demo.ipynb"
    print(demo_path)
    return demo_path
|
|
|
|
|
def img_tensorize(im, input_format="RGB"):
    """Load an image from a local path or URL as a numpy array.

    Local files are read with ``cv2.imread``; anything else is fetched with
    ``get_image_from_url``. The array is passed through
    ``cv2.COLOR_BGR2RGB``, and its last axis is reversed when
    ``input_format == "RGB"``.

    NOTE(review): ``cv2.imread`` yields BGR while the PIL-based URL loader
    yields RGB, yet both go through the same BGR2RGB conversion. The same image
    therefore comes back with different channel order depending on its source.
    Confirm which order the model expects before relying on URL inputs.
    """
    assert isinstance(im, str)
    if not os.path.isfile(im):
        img = get_image_from_url(im)
        assert img is not None, f"could not connect to: {im}"
    else:
        img = cv2.imread(im)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img[:, :, ::-1] if input_format == "RGB" else img
|
|
|
|
|
def chunk(images, batch=1):
    """Lazily yield consecutive slices of ``images`` of length ``batch`` (last may be shorter)."""
    starts = range(0, len(images), batch)
    return (images[start : start + batch] for start in starts)
|
|