Spaces:
Runtime error
Runtime error
import numpy as np | |
import torch | |
import common.thing as thing | |
def _print_stat(key, thing):
    """
    Print a one-line summary of a single xdict entry.

    The line always starts with the key (left-aligned, 20 columns) and ends
    with the value's Python type. Depending on the value it also shows:

    - list / tuple: the number of elements;
    - torch.Tensor: the shape (spaces removed) and the device;
    - np.ndarray: the shape (spaces removed);
    - anything else: only the type.

    Args:
        key: The entry's key; rendered with ``str.format``.
        thing: The entry's value to summarise.
    """
    kind = type(thing)

    if isinstance(thing, (list, tuple)):
        line = "{:<20}: {:<30}\t{:}".format(key, len(thing), kind)
    elif isinstance(thing, torch.Tensor):
        # Strip blanks so the shape fits the fixed-width column.
        compact_shape = str(thing.shape).replace(" ", "")
        line = "{:<20}: {:<30}\t{:}\t{}".format(
            key, compact_shape, kind, thing.device
        )
    elif isinstance(thing, np.ndarray):
        compact_shape = str(thing.shape).replace(" ", "")
        line = "{:<20}: {:<30}\t{:}".format(key, compact_shape, kind)
    else:
        line = "{:<20}: {:}".format(key, kind)

    print(line)
class xdict(dict):
    """
    A dict subclass with helper methods for filtering, renaming and converting entries.

    Plain item assignment (``d[k] = v``) refuses to overwrite an existing key;
    use :meth:`overwrite` when replacing a value is intended. Most helpers
    return a new xdict and leave the instance untouched; :meth:`merge` and
    :meth:`overwrite` modify the instance in place.
    """

    def __init__(self, mydict=None):
        """
        Create an xdict, optionally filled from an existing mapping.

        Args:
            mydict (dict or xdict, optional): Source of the initial key-value
                pairs (anything exposing ``.items()``). If None, the xdict
                starts empty.
        """
        super().__init__()
        if mydict is None:
            return
        # Go through dict's own update so the duplicate-key guard in
        # __setitem__ is not involved while copying the initial content.
        super().update(mydict.items())

    def subset(self, keys):
        """
        Return a new xdict restricted to the given keys.

        Args:
            keys (iterable): Keys to keep; each must exist in this xdict.

        Returns:
            xdict: A new xdict holding only the requested entries.

        Raises:
            KeyError: If a requested key is not present.
        """
        return xdict({k: self[k] for k in keys})

    def __setitem__(self, key, val):
        """
        Assign ``val`` to ``key``, refusing to overwrite an existing key.

        Raises:
            AssertionError: If ``key`` is already present.
        """
        assert key not in self, f"Key already exists {key}"
        super().__setitem__(key, val)

    def search(self, keyword, replace_to=None):
        """
        Return a new xdict with the entries whose key contains ``keyword``.

        Args:
            keyword (str): Substring a key must contain to be selected.
            replace_to (str, optional): If given, every occurrence of
                ``keyword`` in a selected key is replaced by this string in
                the returned xdict.

        Returns:
            xdict: A new xdict with the selected (and possibly renamed) entries.

        Raises:
            AssertionError: If renaming maps two selected keys to the same
                new key (a value would otherwise be lost silently).
        """
        out = xdict()
        for k, v in self.items():
            if keyword not in k:
                continue
            new_key = k if replace_to is None else k.replace(keyword, replace_to)
            # Assigning through xdict.__setitem__ turns a key collision into
            # an AssertionError instead of dropping one of the values.
            out[new_key] = v
        return out

    def rm(self, keyword, keep_list=None, verbose=False):
        """
        Return a new xdict without the entries whose key contains ``keyword``.

        Args:
            keyword (str): Substring marking the keys to remove.
            keep_list (container of str, optional): Keys that are kept even
                if they contain ``keyword``. Defaults to keeping none.
            verbose (bool, optional): If True, print each removed key.

        Returns:
            xdict: A new xdict with the matching entries removed.
        """
        # None instead of a shared mutable [] default.
        protected = () if keep_list is None else keep_list
        out_dict = {}
        for k, v in self.items():
            if keyword not in k or k in protected:
                out_dict[k] = v
            elif verbose:
                print(f"Removing: {k}")
        return xdict(out_dict)

    def overwrite(self, k, v):
        """
        Assign ``v`` to ``k`` like a plain dict would, replacing any existing value.
        """
        super().__setitem__(k, v)

    def merge(self, dict2):
        """
        Add all entries of ``dict2`` to this xdict in place.

        Same as dict.update(), but refuses to proceed if the two dictionaries
        share any key.

        Args:
            dict2 (dict or xdict): The dictionary or xdict instance to merge with.

        Raises:
            AssertionError: If dict2 is not a dictionary or xdict instance.
            AssertionError: If there are duplicate keys between the two instances.
        """
        # xdict is a dict subclass, so one isinstance check covers both.
        assert isinstance(dict2, dict)
        intersect = set(self).intersection(dict2.keys())
        assert len(intersect) == 0, f"Merge failed: duplicate keys ({intersect})"
        self.update(dict2)

    def mul(self, scalar):
        """
        Return a new xdict with every value multiplied by ``scalar``.

        List values are multiplied element by element; any other value is
        multiplied directly with ``*`` (so it must support multiplication by
        a float, e.g. a tensor or a numpy array).

        Args:
            scalar (float or int): The multiplier. An int is promoted to float.

        Returns:
            xdict: A new xdict with the scaled values.

        Raises:
            AssertionError: If scalar is neither a float nor an int.
        """
        if isinstance(scalar, int):
            scalar = float(scalar)
        assert isinstance(scalar, float)
        return xdict(
            {
                k: [item * scalar for item in v] if isinstance(v, list) else v * scalar
                for k, v in self.items()
            }
        )

    def prefix(self, text):
        """
        Return a new xdict with ``text`` prepended to every key.

        Args:
            text (str): The prefix to add.

        Returns:
            xdict: A new xdict with the prefixed keys.
        """
        return xdict({text + k: v for k, v in self.items()})

    def replace_keys(self, str_src, str_tar):
        """
        Return a new xdict with a substring replaced in every key.

        Args:
            str_src (str): The substring to replace.
            str_tar (str): The replacement string.

        Returns:
            xdict: A new xdict with the renamed keys.

        Raises:
            AssertionError: If two keys become identical after the
                replacement (a value would otherwise be lost silently).
        """
        out = xdict()
        for k, v in self.items():
            # xdict.__setitem__ rejects a colliding new key.
            out[k.replace(str_src, str_tar)] = v
        return out

    def postfix(self, text):
        """
        Return a new xdict with ``text`` appended to every key.

        Args:
            text (str): The postfix to add.

        Returns:
            xdict: A new xdict with the postfixed keys.
        """
        return xdict({k + text: v for k, v in self.items()})

    def sorted_keys(self):
        """
        Return the keys of the xdict as a sorted list.

        Returns:
            list: A sorted list of keys in the xdict instance.
        """
        return sorted(self)

    def to(self, dev):
        """
        Move the content of the xdict to a device via ``thing.thing2dev``.

        Args:
            dev (torch.device or None): The target device. If None, nothing
                is moved.

        Returns:
            xdict: This very instance when ``dev`` is None; otherwise a new
            xdict built from the result of ``thing.thing2dev``.
        """
        if dev is None:
            return self
        # Hand a plain dict to the helper rather than the xdict itself.
        return xdict(thing.thing2dev(dict(self), dev))

    def to_torch(self):
        """
        Convert the content with ``thing.thing2torch`` and return a new xdict.

        Returns:
            xdict: A new xdict wrapping the converted content (Torch tensors,
            per the helper's name -- behavior is defined in common.thing).
        """
        return xdict(thing.thing2torch(self))

    def to_np(self):
        """
        Convert the content with ``thing.thing2np`` and return a new xdict.

        Returns:
            xdict: A new xdict wrapping the converted content (numpy arrays,
            per the helper's name -- behavior is defined in common.thing).
        """
        return xdict(thing.thing2np(self))

    def tolist(self):
        """
        Convert the content with ``thing.thing2list`` and return a new xdict.

        Returns:
            xdict: A new xdict wrapping the converted content (Python lists,
            per the helper's name -- behavior is defined in common.thing).
        """
        return xdict(thing.thing2list(self))

    def print_stat(self):
        """
        Print a one-line summary (see ``_print_stat``) for each item in the xdict.
        """
        for k, v in self.items():
            _print_stat(k, v)

    def detach(self):
        """
        Detach the content with ``thing.detach_thing`` and return a new xdict.

        NOTE(review): the original documentation states that tensors are
        detached from the computational graph and moved to the CPU, and that
        non-tensor objects are ignored; that behavior lives in
        ``common.thing.detach_thing`` and is not visible here -- confirm.

        Returns:
            xdict: A new xdict wrapping the result of ``thing.detach_thing``.
        """
        return xdict(thing.detach_thing(self))

    def has_invalid(self):
        """
        Check whether any Torch tensor value contains NaN or Inf.

        Only values that are ``torch.Tensor`` instances are inspected. The
        check stops at the first offending tensor and prints its key.

        Returns:
            bool: True if at least one tensor contains NaN or Inf values, False otherwise.
        """
        for k, v in self.items():
            if not isinstance(v, torch.Tensor):
                continue
            if torch.isnan(v).any():
                print(f"{k} contains nan values")
                return True
            if torch.isinf(v).any():
                print(f"{k} contains inf values")
                return True
        return False

    def apply(self, operation, criterion=None):
        """
        Apply an operation to the values of the xdict, optionally filtered.

        Args:
            operation (callable): Takes a value and returns the new value.
            criterion (callable, optional): Takes ``(key, value)`` and returns
                a boolean. If given, only entries for which it returns True
                are processed.

        Returns:
            xdict: A new xdict holding ``operation(value)`` for every selected
            entry. Entries rejected by ``criterion`` are NOT included in the
            result.
        """
        return xdict(
            {
                k: operation(v)
                for k, v in self.items()
                if criterion is None or criterion(k, v)
            }
        )

    def save(self, path, dev=None, verbose=True):
        """
        Save the xdict to disk with ``torch.save``.

        Args:
            path (str): The path to save the xdict.
            dev (torch.device, optional): Device to move the content to before
                saving. If None (default), the content is saved as it is,
                without being moved.
            verbose (bool, optional): Whether to print a message indicating
                where the xdict is being saved (default is True).
        """
        if verbose:
            print(f"Saving to {path}")
        torch.save(self.to(dev), path)