meow
init
d6d3a5b
raw
history blame
No virus
3.07 kB
import itertools
import numpy as np
import torch
def sort_dict(disordered):
    """Return a new dict holding the same items, ordered by ascending key.

    The input mapping is not modified; keys must be mutually comparable.
    """
    ordered_keys = sorted(disordered)
    return {key: disordered[key] for key in ordered_keys}
def prefix_dict(mydict, prefix):
    """Return a copy of *mydict* whose keys are ``prefix + key``.

    Values are carried over unchanged; the original dict is not modified.
    """
    return {prefix + key: mydict[key] for key in mydict}
def postfix_dict(mydict, postfix):
    """Return a copy of *mydict* whose keys are ``key + postfix``.

    Values are carried over unchanged; the original dict is not modified.
    """
    return {key + postfix: mydict[key] for key in mydict}
def unsort(L, sort_idx):
    """Reorder ``L`` by ascending ``sort_idx``, undoing an earlier sort.

    ``L[i]`` is placed according to the rank of ``sort_idx[i]``. When
    ``sort_idx`` is a permutation of ``range(len(L))``, the result therefore
    satisfies ``result[sort_idx[i]] == L[i]``. Equal indices keep their
    relative order from ``L``, because ``sorted`` is stable. If the two lists
    differ in length, the extra trailing elements are ignored, since ``zip``
    truncates.

    Args:
        L: list of items, currently in "sorted" order.
        sort_idx: list of original positions, one per item of ``L``.

    Returns:
        list: the reordered items, or ``[]`` for empty input.
    """
    assert isinstance(sort_idx, list)
    assert isinstance(L, list)
    # Sort on the index only, never on the payload, so items that are not
    # mutually comparable (e.g. tensors) are fine.
    pairs = sorted(zip(sort_idx, L), key=lambda pair: pair[0])
    # Use a comprehension instead of ``_, L = zip(*pairs)``, which raised
    # ValueError when there were no pairs to unpack.
    return [item for _, item in pairs]
def cat_dl(out_list, dim, verbose=True, squeeze=True):
    """Concatenate each list of values in a dict of lists.

    Args:
        out_list: mapping ``key -> list of values``. Despite the name, a dict
            is expected, since ``.items()`` is called on it. The type of the
            first element of each list decides how that list is merged.
        dim: axis passed to ``torch.cat`` / ``np.concatenate``.
        verbose: if True, print a message for every key that is skipped.
        squeeze: if True, squeeze concatenated tensors/arrays. No dim is
            given, so every size-1 dimension is removed.

    Returns:
        dict: ``key -> merged value``. Tensors are joined with ``torch.cat``,
        ndarrays with ``np.concatenate``, and lists are flattened one level.
        Keys whose first element is of any other type are dropped.

    Raises:
        IndexError: if any value list is empty, since its first element is
            inspected.
    """
    out = {}
    for key, val in out_list.items():
        first = val[0]
        if isinstance(first, torch.Tensor):
            out[key] = torch.cat(val, dim=dim)
            if squeeze:
                out[key] = out[key].squeeze()
        elif isinstance(first, np.ndarray):
            out[key] = np.concatenate(val, axis=dim)
            if squeeze:
                out[key] = np.squeeze(out[key])
        elif isinstance(first, list):
            # Flatten in linear time. ``sum(val, [])`` rebuilt the accumulator
            # list at every step, which is quadratic in the total length.
            out[key] = list(itertools.chain.from_iterable(val))
        elif verbose:
            print(f"Ignoring {key} undefined type {type(first)}")
    return out
def stack_dl(out_list, dim, verbose=True, squeeze=True):
    """Stack each list of values in a dict of lists along a new axis.

    Args:
        out_list: mapping ``key -> list of values``. Despite the name, a dict
            is expected, since ``.items()`` is called on it. The type of the
            first element of each list decides how that list is merged.
        dim: position of the new axis for ``torch.stack`` / ``np.stack``.
        verbose: if True, print a message for every key passed through as-is.
        squeeze: if True, squeeze stacked tensors/arrays. No dim is given, so
            every size-1 dimension is removed.

    Returns:
        dict: ``key -> merged value``. Tensors are stacked with
        ``torch.stack``, ndarrays with ``np.stack``, and lists are flattened
        one level, not stacked. Any other type is stored unchanged as the
        original list.

    Raises:
        IndexError: if any value list is empty, since its first element is
            inspected.
    """
    out = {}
    for key, val in out_list.items():
        first = val[0]
        if isinstance(first, torch.Tensor):
            out[key] = torch.stack(val, dim=dim)
            if squeeze:
                out[key] = out[key].squeeze()
        elif isinstance(first, np.ndarray):
            out[key] = np.stack(val, axis=dim)
            if squeeze:
                out[key] = np.squeeze(out[key])
        elif isinstance(first, list):
            # Flatten in linear time. ``sum(val, [])`` rebuilt the accumulator
            # list at every step, which is quadratic in the total length.
            out[key] = list(itertools.chain.from_iterable(val))
        else:
            # Unlike cat_dl, unknown types are kept, not dropped.
            out[key] = val
            if verbose:
                print(f"Processing {key} undefined type {type(first)}")
    return out
def add_prefix_postfix(mydict, prefix="", postfix=""):
    """Return a copy of *mydict* with every key wrapped as ``prefix + key + postfix``.

    Values are kept as-is. With the default empty affixes, the result is a
    shallow copy of *mydict*.
    """
    assert isinstance(mydict, dict)
    renamed = {prefix + key + postfix: value for key, value in mydict.items()}
    return renamed
def ld2dl(LD):
    """Convert a list of dicts that share the same keys into a dict of lists.

    The keys are taken from the first dict, and every dict must contain each
    of them; a missing key raises ``KeyError``. An empty list gives an empty
    dict.

    Example:
        >>> ld2dl([{"a": 1, "b": 2}, {"a": 3, "b": 4}])
        {'a': [1, 3], 'b': [2, 4]}
    """
    # The docstring used to sit after these asserts, which made it a dead
    # string expression instead of the function's ``__doc__``.
    assert isinstance(LD, list)
    if not LD:
        # No dicts means no keys. ``LD[0]`` below would raise IndexError.
        return {}
    assert isinstance(LD[0], dict)
    return {key: [dic[key] for dic in LD] for key in LD[0]}
class NameSpace(object):
    """Lightweight attribute container built from a mapping.

    Each key of the mapping given to the constructor becomes an instance
    attribute holding the corresponding value.
    """

    def __init__(self, adict):
        # vars(self) is the instance __dict__; merge the mapping straight in.
        vars(self).update(adict)
def dict2ns(mydict):
    """Wrap *mydict* in a ``NameSpace`` so its keys can be read as attributes."""
    namespace = NameSpace(mydict)
    return namespace
def ld2dev(ld, dev):
    """Recursively move every ``torch.Tensor`` inside *ld* to device *dev*.

    - A tensor is returned as ``ld.to(dev)``.
    - A dict is updated in place, each value replaced by its converted form,
      and the same dict object is returned.
    - A list is rebuilt as a new list of converted elements.
    - Anything else, tuples included, is returned unchanged.
    """
    if isinstance(ld, torch.Tensor):
        moved = ld.to(dev)
    elif isinstance(ld, dict):
        for name, item in ld.items():
            ld[name] = ld2dev(item, dev)
        moved = ld
    elif isinstance(ld, list):
        moved = [ld2dev(element, dev) for element in ld]
    else:
        moved = ld
    return moved
def all_comb_dict(hyper_dict):
    """Expand a dict of candidate-value iterables into every combination (a grid).

    Args:
        hyper_dict: mapping ``name -> iterable of candidate values``.

    Returns:
        list[dict]: one dict per element of the Cartesian product of the
        value iterables, in ``itertools.product`` order (the last key varies
        fastest). A key with no candidates yields ``[]``. An empty
        ``hyper_dict`` yields ``[{}]``, a single empty combination.
    """
    assert isinstance(hyper_dict, dict)
    # Take the keys and values separately rather than ``zip(*items())``.
    # That unpacking raised ValueError when ``hyper_dict`` was empty.
    keys = list(hyper_dict)
    return [
        dict(zip(keys, combo))
        for combo in itertools.product(*hyper_dict.values())
    ]