# --- Web page residue from the file viewer (not part of the module) ---
# ymzhang319's picture
# init
# 7f2690b
# raw
# history blame
# 5.1 kB
import hashlib
import os
import requests
from tqdm import tqdm
# Single table describing every downloadable artifact. Each row is
# (short name, remote URL, local file name, expected MD5 hex digest).
# The three public lookup dicts below are derived from it so that the
# per-artifact facts stay next to each other.
_ARTIFACTS = (
    (
        'vggishish_lpaps',
        'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/vggishish16.pt',
        'vggishish16.pt',
        '197040c524a07ccacf7715d7080a80bd',
    ),
    (
        'vggishish_mean_std_melspec_10s_22050hz',
        'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/train_means_stds_melspec_10s_22050hz.txt',
        'train_means_stds_melspec_10s_22050hz.txt',
        'f449c6fd0e248936c16f6d22492bb625',
    ),
    (
        'melception',
        'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/melception-21-05-10T09-28-40.pt',
        'melception-21-05-10T09-28-40.pt',
        'a71a41041e945b457c7d3d814bbcf72d',
    ),
)

# short name -> remote URL the artifact is fetched from
URL_MAP = {entry[0]: entry[1] for entry in _ARTIFACTS}

# short name -> file name used when the artifact is stored locally
CKPT_MAP = {entry[0]: entry[2] for entry in _ARTIFACTS}

# short name -> expected MD5 hex digest of the downloaded file
MD5_MAP = {entry[0]: entry[3] for entry in _ARTIFACTS}
def download(url, local_path, chunk_size=1024):
    """Stream the resource at ``url`` into ``local_path`` with a progress bar.

    The payload is first written to ``local_path + ".part"`` and only moved
    onto ``local_path`` once the transfer finished, so an interrupted download
    never leaves a truncated file under the final name.

    Parameters
    ----------
    url : str
        Address of the resource to fetch.
    local_path : str or os.PathLike
        Destination file. Its parent directory is created if necessary.
    chunk_size : int
        Number of bytes requested per streamed chunk.

    Raises
    ------
    requests.HTTPError
        If the server answers with a 4xx/5xx status.
    """
    local_path = os.fspath(local_path)
    target_dir = os.path.split(local_path)[0]
    if target_dir:
        # os.makedirs("") raises, so skip it for a bare file name.
        os.makedirs(target_dir, exist_ok=True)

    tmp_path = local_path + ".part"
    try:
        with requests.get(url, stream=True) as r:
            # Fail loudly instead of saving an HTML error page as the file.
            r.raise_for_status()
            total_size = int(r.headers.get("content-length", 0))
            with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
                with open(tmp_path, "wb") as f:
                    for data in r.iter_content(chunk_size=chunk_size):
                        if data:
                            f.write(data)
                            # Chunks may be shorter than chunk_size (e.g. the
                            # last one), so advance by what was received.
                            pbar.update(len(data))
        # Replace within the same directory so the final file appears whole.
        os.replace(tmp_path, local_path)
    finally:
        # Only still present when the transfer or the replace failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
def md5_hash(path):
    """Return the hexadecimal MD5 digest of the file at ``path``.

    The file is hashed incrementally in 1 MiB chunks so that large
    checkpoint files are never loaded into memory in one piece.

    Parameters
    ----------
    path : str or os.PathLike
        File to hash; opened in binary mode.

    Returns
    -------
    str
        32-character lowercase hexadecimal digest.
    """
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # iter() with a sentinel stops at the first empty read (end of file).
        for block in iter(lambda: f.read(1024 * 1024), b""):
            digest.update(block)
    return digest.hexdigest()
def get_ckpt_path(name, root, check=False):
    """Return the local path of artifact ``name``, downloading it if needed.

    The file is (re-)downloaded when it does not exist under ``root`` or,
    with ``check=True``, when its MD5 digest differs from ``MD5_MAP[name]``.
    A freshly downloaded file is always verified against ``MD5_MAP[name]``.

    Parameters
    ----------
    name : str
        Key into ``URL_MAP`` / ``CKPT_MAP`` / ``MD5_MAP``.
    root : str
        Directory in which the artifact is stored.
    check : bool
        Whether to verify the checksum of an already existing file.

    Returns
    -------
    str
        ``os.path.join(root, CKPT_MAP[name])``.

    Raises
    ------
    AssertionError
        If ``name`` is unknown, or if the digest of a freshly downloaded
        file does not match (the offending digest is the exception argument).
    """
    # Explicit raises instead of ``assert`` so the checks survive ``python -O``
    # while callers still see the same exception type.
    if name not in URL_MAP:
        raise AssertionError("Unknown checkpoint name: {!r}".format(name))
    path = os.path.join(root, CKPT_MAP[name])
    if not os.path.exists(path) or (check and md5_hash(path) != MD5_MAP[name]):
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        if md5 != MD5_MAP[name]:
            raise AssertionError(md5)
    return path
class KeyNotFoundError(Exception):
    """Raised when a key path cannot be resolved.

    Attributes set on the instance:
    ``cause`` (the underlying error or description), ``keys`` (the full list
    of requested keys, or ``None``) and ``visited`` (the keys resolved before
    the failure, or ``None``). The exception message lists whichever of
    ``keys`` / ``visited`` were given, followed by the cause.
    """

    def __init__(self, cause, keys=None, visited=None):
        self.cause = cause
        self.keys = keys
        self.visited = visited

        # Optional message lines; a line is emitted only for non-None values.
        optional_lines = (
            ("Key not found: {}", keys),
            ("Visited: {}", visited),
        )
        report = [
            template.format(value)
            for template, value in optional_lines
            if value is not None
        ]
        report.append("Cause:\n{}".format(cause))
        super().__init__("\n".join(report))
def retrieve(
    list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
):
    """Given a nested list or dict return the desired value at key expanding
    callable nodes if necessary and :attr:`expand` is ``True``. The expansion
    is done in-place.

    Parameters
    ----------
    list_or_dict : list or dict
        Possibly nested list or dictionary. May itself be a callable, which
        is then expanded (there is no parent to store the result in).
    key : str
        key/to/value, path like string describing all keys necessary to
        consider to get to the desired value. List indices can also be
        passed here.
    splitval : str
        String that defines the delimiter between keys of the
        different depth levels in `key`.
    default : obj
        Value returned if :attr:`key` is not found.
    expand : bool
        Whether to expand callable nodes on the path or not.
    pass_success : bool
        If ``True`` return a ``(value, success)`` tuple, where ``success``
        is ``False`` when ``default`` was substituted.

    Returns
    -------
        The desired value or if :attr:`default` is not ``None`` and the
        :attr:`key` is not found returns ``default``. With
        :attr:`pass_success` the value is paired with a success flag.

    Raises
    ------
        KeyNotFoundError if ``key`` not in ``list_or_dict`` and
        :attr:`default` is ``None``.
    """
    keys = key.split(splitval)

    success = True
    try:
        visited = []
        parent = None
        last_key = None
        for part in keys:
            if callable(list_or_dict):
                if not expand:
                    raise KeyNotFoundError(
                        ValueError(
                            "Trying to get past callable node with expand=False."
                        ),
                        keys=keys,
                        visited=visited,
                    )
                list_or_dict = list_or_dict()
                # A callable root has no container to write the result into.
                if parent is not None:
                    parent[last_key] = list_or_dict

            parent = list_or_dict

            try:
                # Remember the key in the form actually used for indexing
                # (int for sequences) so that the in-place write-back of an
                # expanded callable addresses the same slot.
                if isinstance(list_or_dict, dict):
                    last_key = part
                else:
                    last_key = int(part)
                list_or_dict = list_or_dict[last_key]
            except (KeyError, IndexError, ValueError) as e:
                raise KeyNotFoundError(e, keys=keys, visited=visited)

            visited.append(part)
        # final expansion of retrieved value
        if expand and callable(list_or_dict):
            list_or_dict = list_or_dict()
            parent[last_key] = list_or_dict
    except KeyNotFoundError:
        if default is None:
            raise
        list_or_dict = default
        success = False

    if not pass_success:
        return list_or_dict
    return list_or_dict, success