File size: 5,985 Bytes
db6a3b7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
import numpy as np
import cv2
import torch
# Dictionary utils
def _dict_merge(dicta, dictb, prefix=''):
    """
    Recursively merge two dictionaries whose leaf keys do not overlap.

    Keys present in only one input are copied to the result as-is (values,
    including nested dictionaries, are shared by reference, not copied).
    Keys present in both inputs are merged recursively when both values are
    dictionaries; any other collision is an error.

    The result keeps insertion order: all keys of ``dicta`` first, followed
    by the keys that only appear in ``dictb``.

    Args:
        dicta: First dictionary.
        dictb: Second dictionary.
        prefix: Dotted path of the dictionaries being merged, used only to
            build a readable error message. Empty at the top level.

    Returns:
        A new dictionary containing the keys of both inputs.

    Raises:
        ValueError: If a key exists in both inputs and the two values are
            not both dictionaries.
    """
    assert isinstance(dicta, dict), 'input must be a dictionary'
    assert isinstance(dictb, dict), 'input must be a dictionary'
    merged = {}
    for key, value in dicta.items():
        if key not in dictb:
            merged[key] = value
            continue
        # Build the path without a leading separator at the top level so the
        # error reads "a.b" rather than ".a.b".
        path = f'{prefix}.{key}' if prefix else str(key)
        other = dictb[key]
        if isinstance(value, dict) and isinstance(other, dict):
            merged[key] = _dict_merge(value, other, prefix=path)
        else:
            raise ValueError(f'Duplicate key {path} found in both dictionaries. Types: {type(value)}, {type(other)}')
    for key, value in dictb.items():
        if key not in dicta:
            merged[key] = value
    return merged


def dict_merge(dicta, dictb):
    """
    Merge two (possibly nested) dictionaries into a new dictionary.

    Nested dictionaries found under the same key are merged recursively.

    Args:
        dicta: First dictionary.
        dictb: Second dictionary.

    Returns:
        A new dictionary containing the keys of both inputs.

    Raises:
        ValueError: If the same key maps to non-dictionary values in both
            inputs.
    """
    return _dict_merge(dicta, dictb, prefix='')
def dict_foreach(dic, func, special_func={}):
    """
    Recursively apply a function to all non-dictionary leaf values in a dictionary.

    The dictionary is modified in place: every leaf value is replaced by the
    result of the function applied to it, and the same dictionary object is
    returned.

    Args:
        dic: Dictionary to transform; nested dictionaries are descended into.
        func: Callable applied to each leaf value.
        special_func: Mapping from key name to a callable that is used
            instead of ``func`` for leaves stored under that key, at any
            nesting depth. It is only read, never mutated.

    Returns:
        The input dictionary ``dic`` after in-place transformation.
    """
    assert isinstance(dic, dict), 'input must be a dictionary'
    for key in dic.keys():
        value = dic[key]
        if isinstance(value, dict):
            # Forward special_func so per-key overrides also apply to nested
            # dictionaries (same convention as dict_reduce).
            dic[key] = dict_foreach(value, func, special_func)
        else:
            dic[key] = special_func.get(key, func)(value)
    return dic
def dict_reduce(dicts, func, special_func={}):
    """
    Combine a list of (possibly nested) dictionaries into a single dictionary.

    For every key found in any input, the values from the dictionaries that
    contain that key are gathered into a list and passed to a reducer. If the
    first gathered value is a dictionary, the gathered values are reduced
    recursively instead.

    Args:
        dicts: Non-empty list of dictionaries.
        func: Reducer called with the list of values collected for a key.
        special_func: Mapping from key name to a reducer used instead of
            ``func`` for that key.

    Returns:
        A new dictionary holding the reduced value for every key.
    """
    assert isinstance(dicts, list), 'input must be a list of dictionaries'
    assert all(isinstance(entry, dict) for entry in dicts), 'input must be a list of dictionaries'
    assert len(dicts) > 0, 'input must be a non-empty list of dictionaries'
    key_pool = set([name for entry in dicts for name in entry.keys()])
    result = {}
    for name in key_pool:
        # Only dictionaries that actually carry this key contribute a value.
        collected = [entry[name] for entry in dicts if name in entry]
        if isinstance(collected[0], dict):
            result[name] = dict_reduce(collected, func, special_func)
            continue
        reducer = special_func.get(name, func)
        result[name] = reducer(collected)
    return result
def dict_any(dic, func):
    """
    Return True if the predicate holds for at least one leaf value.

    Nested dictionaries are searched recursively; evaluation stops at the
    first leaf for which ``func`` returns a truthy value. An empty dictionary
    yields False.
    """
    assert isinstance(dic, dict), 'input must be a dictionary'
    return any(
        dict_any(value, func) if isinstance(value, dict) else func(value)
        for value in dic.values()
    )
def dict_all(dic, func):
    """
    Return True if the predicate holds for every leaf value.

    Nested dictionaries are checked recursively; evaluation stops at the
    first leaf for which ``func`` returns a falsy value. An empty dictionary
    yields True.
    """
    assert isinstance(dic, dict), 'input must be a dictionary'
    return all(
        dict_all(value, func) if isinstance(value, dict) else func(value)
        for value in dic.values()
    )
def dict_flatten(dic, sep='.'):
    """
    Collapse a nested dictionary into a single-level dictionary.

    Keys of nested entries are the string forms of the keys along the path,
    joined with ``sep``. Top-level leaves keep their original key object.
    Empty nested dictionaries contribute nothing to the result.
    """
    assert isinstance(dic, dict), 'input must be a dictionary'
    flattened = {}
    for name, value in dic.items():
        if not isinstance(value, dict):
            flattened[name] = value
            continue
        nested = dict_flatten(value, sep=sep)
        for inner_name, inner_value in nested.items():
            flattened[str(name) + sep + str(inner_name)] = inner_value
    return flattened
def make_grid(images, nrow=None, ncol=None, aspect_ratio=None):
    """
    Tile a sequence of images into one grid image, filled row by row.

    The cell size, channel count and dtype of the grid are taken from
    ``images[0]``, which is indexed as a 3-D array (``shape[0]``,
    ``shape[1]``, ``shape[2]``). Unused cells stay zero.
    NOTE(review): presumably all images share the same (H, W, C) shape;
    confirm against callers.

    Args:
        images: Non-empty sequence of array images.
        nrow: Number of grid rows. Derived from ``ncol`` or from the image
            count when omitted.
        ncol: Number of grid columns. Derived from ``nrow`` when omitted.
        aspect_ratio: Only used when both ``nrow`` and ``ncol`` are omitted;
            the row count is then ``round(sqrt(num_images / aspect_ratio))``.

    Returns:
        A numpy array of shape ``(nrow * H, ncol * W, C)``.

    Raises:
        ValueError: If ``images`` is empty.
    """
    num_images = len(images)
    if num_images == 0:
        raise ValueError('images must contain at least one image')
    if nrow is None and ncol is None:
        if aspect_ratio is not None:
            nrow = int(np.round(np.sqrt(num_images / aspect_ratio)))
        else:
            nrow = int(np.sqrt(num_images))
        # Rounding can produce 0 rows (e.g. one image with a wide aspect
        # ratio), which would divide by zero below.
        nrow = max(nrow, 1)
        ncol = (num_images + nrow - 1) // nrow
    elif nrow is None:
        nrow = (num_images + ncol - 1) // ncol
    elif ncol is None:
        ncol = (num_images + nrow - 1) // nrow
    else:
        assert nrow * ncol >= num_images, 'nrow * ncol must be greater than or equal to the number of images'
    first = images[0]
    grid = np.zeros((nrow * first.shape[0], ncol * first.shape[1], first.shape[2]), dtype=first.dtype)
    for i, img in enumerate(images):
        row, col = divmod(i, ncol)
        grid[row * img.shape[0]:(row + 1) * img.shape[0], col * img.shape[1]:(col + 1) * img.shape[1]] = img
    return grid
def notes_on_image(img, notes=None):
    """
    Append a 32-row zero-filled strip below an image and draw text in it.

    The image is padded along its first axis only, so it is expected to be a
    3-D array. The text is drawn in white with OpenCV, anchored 4 pixels
    above the bottom edge of the padded image.
    NOTE(review): the RGB<->BGR round trip suggests the input is RGB —
    confirm against callers.

    Args:
        img: Array image to annotate.
        notes: Text to draw; nothing is drawn when None.

    Returns:
        The padded (and optionally annotated) image.
    """
    padded = np.pad(img, ((0, 32), (0, 0), (0, 0)), 'constant', constant_values=0)
    canvas = cv2.cvtColor(padded, cv2.COLOR_RGB2BGR)
    if notes is not None:
        anchor = (0, canvas.shape[0] - 4)
        canvas = cv2.putText(canvas, notes, anchor, cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1)
    return cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
def save_image_with_notes(img, path, notes=None):
    """
    Save an image with notes.

    Torch tensors are moved to the CPU, converted to numpy and have axis 0
    moved to the last position (presumably CHW -> HWC; confirm against
    callers). Floating-point images are multiplied by 255, clipped to
    [0, 255] and cast to uint8 before the notes strip is added and the file
    is written with OpenCV.

    Args:
        img: Image as a torch tensor or numpy array.
        path: Destination file path passed to ``cv2.imwrite``.
        notes: Optional text drawn below the image.
    """
    if isinstance(img, torch.Tensor):
        # detach() first: Tensor.numpy() raises on tensors that require grad.
        img = img.detach().cpu().numpy().transpose(1, 2, 0)
    # Cover every floating dtype (e.g. float16), not only float32/float64, so
    # that all float images are rescaled to uint8 before reaching OpenCV.
    if np.issubdtype(img.dtype, np.floating):
        img = np.clip(img * 255, 0, 255).astype(np.uint8)
    img = notes_on_image(img, notes)
    cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
# debug utils
def atol(x, y):
    """
    Element-wise absolute difference ``|x - y|``.
    """
    difference = x - y
    return torch.abs(difference)
def rtol(x, y):
    """
    Element-wise relative difference.

    Computes ``|x - y| / max(|x|, |y|)``; the denominator is clamped to at
    least 1e-12 so that a pair of zeros does not divide by zero.
    """
    gap = torch.abs(x - y)
    scale = torch.maximum(torch.abs(x), torch.abs(y))
    return gap / torch.clamp_min(scale, 1e-12)
# print utils
def indent(s, n=4):
    """
    Indent every line of a string except the first by ``n`` spaces.

    The first line is left untouched so the result can be embedded after an
    existing prefix on the same line.
    """
    first, *rest = s.split('\n')
    pad = ' ' * n
    return '\n'.join([first] + [pad + line for line in rest])
|