|
import hashlib
import importlib
import os
from collections.abc import Mapping

import requests
from tqdm import tqdm
|
|
|
# Download URL of each known pretrained checkpoint, keyed by checkpoint name.
URL_MAP = dict(
    vgg_lpips="https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1",
)

# File name each checkpoint is stored under, relative to the root directory
# passed to get_ckpt_path.
CKPT_MAP = dict(
    vgg_lpips="vgg.pth",
)

# Expected MD5 hex digest of each checkpoint file; get_ckpt_path compares
# downloaded (and, with check=True, existing) files against it.
MD5_MAP = dict(
    vgg_lpips="d507d7349b931f0638a25a48a722f98a",
)
|
|
|
|
|
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"package.module.Name"`` to the object it names.

    Parameters
    ----------
    string : str
        Module path and attribute name joined by a final ``"."``.
    reload : bool
        If ``True``, the module is re-executed with :func:`importlib.reload`
        before the attribute is looked up.

    Returns
    -------
    The attribute named after the last ``"."`` of the module named before it.

    Raises
    ------
    ValueError
        If ``string`` contains no ``"."`` (nothing to split into module/attribute).
    """
    module_path, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_path))
    target_module = importlib.import_module(module_path, package=None)
    return getattr(target_module, attr_name)
|
|
|
|
|
def instantiate_from_config(config):
    """Build an object from a ``{"target": ..., "params": {...}}`` style config.

    ``config["target"]`` is a dotted path resolved with :func:`get_obj_from_str`;
    the resolved object is called with ``config["params"]`` unpacked as keyword
    arguments (no arguments when ``"params"`` is absent).

    Raises
    ------
    KeyError
        If ``config`` has no ``"target"`` entry.
    """
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    factory = get_obj_from_str(config["target"])
    kwargs = config.get("params", dict())
    return factory(**kwargs)
|
|
|
|
|
def download(url, local_path, chunk_size=1024):
    """Stream ``url`` to ``local_path``, showing a tqdm progress bar.

    The body is first written to ``local_path + ".part"`` and moved onto
    ``local_path`` only after the transfer finishes, so an interrupted or
    failed download never leaves a truncated file at ``local_path``.

    Parameters
    ----------
    url : str
        Address to fetch with ``requests.get(..., stream=True)``.
    local_path : str
        Destination file; its parent directory is created if missing.
    chunk_size : int
        Size in bytes requested per ``iter_content`` read.

    Raises
    ------
    requests.HTTPError
        If the server answers with an error status (previously the error
        page itself was saved as the file).
    """
    directory = os.path.dirname(local_path)
    if directory:  # os.makedirs("") raises for a bare file name
        os.makedirs(directory, exist_ok=True)

    tmp_path = local_path + ".part"
    try:
        with requests.get(url, stream=True) as r:
            r.raise_for_status()
            total_size = int(r.headers.get("content-length", 0))
            # Without a content-length, let tqdm show plain counters instead of a 0-total bar.
            with tqdm(total=total_size or None, unit="B", unit_scale=True) as pbar, open(
                tmp_path, "wb"
            ) as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:  # skip keep-alive chunks
                        f.write(data)
                        # Count the bytes actually received; the last chunk is usually short.
                        pbar.update(len(data))
        # Atomically publish the complete file.
        os.replace(tmp_path, local_path)
    except BaseException:
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # best-effort cleanup of the partial file; the original error is re-raised
        raise
|
|
|
|
|
def md5_hash(path, chunk_size=1 << 20):
    """Return the hexadecimal MD5 digest of the file at ``path``.

    The file is read in ``chunk_size``-byte pieces, so large checkpoints are
    hashed without loading the whole file into memory at once.

    Parameters
    ----------
    path : str
        File to hash.
    chunk_size : int
        Positive number of bytes read per step (default 1 MiB).

    Raises
    ------
    ValueError
        If ``chunk_size`` is not positive (a zero read would stop immediately).
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive, got {}".format(chunk_size))
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
|
|
|
|
|
def get_ckpt_path(name, root, check=False):
    """Return the local path of pretrained checkpoint ``name``, downloading it if needed.

    Parameters
    ----------
    name : str
        Checkpoint identifier; must be a key of ``URL_MAP``.
    root : str
        Directory that holds the checkpoint file ``CKPT_MAP[name]``.
    check : bool
        If ``True``, an existing file whose MD5 digest differs from
        ``MD5_MAP[name]`` is treated as missing and downloaded again.

    Returns
    -------
    str
        ``os.path.join(root, CKPT_MAP[name])``.

    Raises
    ------
    ValueError
        If ``name`` is not a known checkpoint.
    RuntimeError
        If a freshly downloaded file does not match ``MD5_MAP[name]``; the
        bad file is removed so later calls (even with ``check=False``) retry.
    """
    # Explicit raises instead of ``assert``: assertions are stripped under ``python -O``.
    if name not in URL_MAP:
        raise ValueError(
            "Unknown checkpoint {!r}; expected one of {}".format(name, sorted(URL_MAP))
        )
    path = os.path.join(root, CKPT_MAP[name])
    if not os.path.exists(path) or (check and md5_hash(path) != MD5_MAP[name]):
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        if md5 != MD5_MAP[name]:
            os.remove(path)
            raise RuntimeError(
                "MD5 mismatch for {}: expected {}, got {}".format(
                    path, MD5_MAP[name], md5
                )
            )
    return path
|
|
|
|
|
class KeyNotFoundError(Exception):
    """Raised by :func:`retrieve` when a key path cannot be resolved.

    Attributes
    ----------
    cause : object
        The underlying failure (an exception instance when raised by ``retrieve``).
    keys : list or None
        The full list of keys being looked up, if provided.
    visited : list or None
        Keys traversed before the failure, if provided.

    The exception message lists the keys and visited keys (when given)
    followed by the cause, one section per line.
    """

    def __init__(self, cause, keys=None, visited=None):
        self.cause, self.keys, self.visited = cause, keys, visited
        sections = [
            template.format(value)
            for template, value in (
                ("Key not found: {}", keys),
                ("Visited: {}", visited),
            )
            if value is not None
        ]
        sections.append("Cause:\n{}".format(cause))
        super().__init__("\n".join(sections))
|
|
|
|
|
def retrieve(
    list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
):
    """Given a nested list or dict return the desired value at key expanding
    callable nodes if necessary and :attr:`expand` is ``True``. The expansion
    is done in-place: each expanded result is stored back into its container.

    Parameters
    ----------
    list_or_dict : list or dict
        Possibly nested list or dictionary. Any
        :class:`collections.abc.Mapping` is indexed by the key string;
        every other node is indexed with ``int(key)``. A callable root is
        expanded too, but its result has no container to be stored in.
    key : str
        key/to/value, path like string describing all keys necessary to
        consider to get to the desired value. List indices can also be
        passed here.
    splitval : str
        String that defines the delimiter between keys of the
        different depth levels in `key`.
    default : obj
        Value returned if :attr:`key` is not found. ``None`` means "no
        default": an unresolvable key raises instead.
    expand : bool
        Whether to expand callable nodes on the path or not.
    pass_success : bool
        If ``True``, return ``(value, success)`` where ``success`` is
        ``False`` when ``default`` was substituted.

    Returns
    -------
    The desired value, or ``default`` if :attr:`key` is not found and
    :attr:`default` is not ``None``; wrapped as ``(value, success)`` when
    ``pass_success`` is ``True``.

    Raises
    ------
    KeyNotFoundError
        If the path cannot be resolved (missing key, bad or out-of-range
        index, non-subscriptable node, or a callable node reached with
        ``expand=False``) and :attr:`default` is ``None``.
    """

    keys = key.split(splitval)

    success = True
    try:
        visited = []
        parent = None
        last_key = None
        for key in keys:
            if callable(list_or_dict):
                if not expand:
                    raise KeyNotFoundError(
                        ValueError(
                            "Trying to get past callable node with expand=False."
                        ),
                        keys=keys,
                        visited=visited,
                    )
                list_or_dict = list_or_dict()
                # Store the expansion back in place; a callable root has no container.
                if parent is not None:
                    parent[last_key] = list_or_dict

            parent = list_or_dict
            try:
                # Keep the *resolved* index (int for sequences) so the in-place
                # write-back targets the right slot; a string index on a list
                # would raise TypeError.
                last_key = key if isinstance(parent, Mapping) else int(key)
                list_or_dict = parent[last_key]
            except (KeyError, IndexError, ValueError, TypeError) as e:
                # TypeError: the node is not subscriptable (e.g. a scalar in the
                # middle of the path), which is also a "not found".
                raise KeyNotFoundError(e, keys=keys, visited=visited) from e

            visited += [key]

        if expand and callable(list_or_dict):
            list_or_dict = list_or_dict()
            parent[last_key] = list_or_dict
    except KeyNotFoundError:
        if default is None:
            raise
        list_or_dict = default
        success = False

    if pass_success:
        return list_or_dict, success
    return list_or_dict
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: wrap a nested dict in an OmegaConf config and look up
    # values by path with retrieve(), printing the results.
    from omegaconf import OmegaConf

    config = {
        "keya": "a",
        "keyb": "b",
        "keyc": {
            "cc1": 1,
            "cc2": 2,
        },
    }
    config = OmegaConf.create(config)
    print(config)
    print(retrieve(config, "keya"))
    print(retrieve(config, "keyc/cc1"))
|
|