from collections import abc as container_abc from inspect import getfullargspec import functools import numpy as np import torch from torch.nn import functional as F class PlaceHolder: def __init__(self, X, E, y=None): self.X = X self.E = E self.y = y def to(self, device=None, to_type=None): self.X = to_device(self.X, device, to_type) self.E = to_device(self.E, device, to_type) if self.y is not None: self.y = to_device(self.y, device, to_type) return self def type_as(self, x: torch.Tensor): """ Changes the device and dtype of X, E, y. """ self.X = self.X.type_as(x) self.E = self.E.type_as(x) if self.y is not None: self.y = self.y.type_as(x) return self def mask(self, node_mask, collapse=False): x_mask = node_mask.unsqueeze(-1) # bs, n, 1 e_mask1 = x_mask.unsqueeze(2) # bs, n, 1, 1 e_mask2 = x_mask.unsqueeze(1) # bs, 1, n, 1 if collapse: self.X = torch.argmax(self.X, dim=-1) self.E = torch.argmax(self.E, dim=-1) self.X[node_mask == 0] = - 1 self.E[(e_mask1 * e_mask2).squeeze(-1) == 0] = - 1 else: self.X = self.X * x_mask self.E = self.E * e_mask1 * e_mask2 assert torch.allclose(self.E, torch.transpose(self.E, 1, 2)) return self def clamp_to_prob(x, inplace=True): if not inplace: x = x.clone() x[x >= 1] = 1 x[x < 0] = 0 return x def sigmoid_with_clamp(x, inf=1e-4, sup=1 - 1e-4): x = x.clone() x = torch.clamp(x.sigmoid_(), min=inf, max=sup) return x def tensor2prob_map(in_tensor, trivial=False): cls_num = in_tensor.size()[1] prob_map = sigmoid_with_clamp(in_tensor) if cls_num == 1 else F.softmax(in_tensor, dim=1) if cls_num == 1 and trivial: prob_map = torch.cat([prob_map, 1 - prob_map], dim=1) return prob_map def feature_transform(feature, t_matrix, device=None, dtype='float'): if device is not None: feature, t_matrix = tuple(map(lambda x: to_device(x, device, dtype), [feature, t_matrix])) grid = F.affine_grid(t_matrix, feature.size(), align_corners=True) re = F.grid_sample(feature, grid, align_corners=True) return re def make_transform_matrix(mag): basic_matrix = np.array([[1., 0., 0.], [0., 1., 0.]]) shift, scale_list, shift_index_list, scale_index_list = \ [-mag, mag], [1 - mag, 1 / (1 - mag)], [(0, 2), (1, 2)], [(0, 0), (1, 1)] transform_matrix_list = [basic_matrix] for op, scale in zip(shift, scale_list): for shift_index, scale_index in zip(shift_index_list, scale_index_list): plus = np.zeros_like(basic_matrix) plus[shift_index] = op * 1 transform_matrix_list.append(basic_matrix + plus) scale_matrix = basic_matrix.copy() scale_matrix[scale_index] *= scale transform_matrix_list.append(scale_matrix) return transform_matrix_list def tensor2array(inp): if isinstance(inp, container_abc.Mapping): return {key: tensor2array(inp[key]) for key in inp} if isinstance(inp, list): return [tensor2array(item) for item in inp] if not isinstance(inp, torch.Tensor): return inp inp = inp.detach() if inp.device.type == 'cuda': inp = inp.cpu() return inp.numpy() def to_device(inp, device=None, to_type=None): if inp is None: return inp if isinstance(inp, container_abc.Mapping): return {key: to_device(inp[key], device, to_type) for key in inp} if isinstance(inp, container_abc.Sequence): return [to_device(item, device, to_type) for item in inp] if isinstance(inp, np.ndarray): inp = torch.as_tensor(inp) if to_type is None or to_type == 'identity': if device == 'identity' or device is None: return inp return inp.to(device) if device is None or device == 'identity': device = inp.device return inp.to(device).__getattribute__(to_type)() def cast_tensor_type(inputs, src_type, dst_type): """Recursively convert Tensor in inputs from src_type to dst_type. Args: inputs: Inputs that to be casted. src_type (torch.dtype): Source type.. dst_type (torch.dtype): Destination type. Returns: The same type with inputs, but all contained Tensors have been cast. """ if isinstance(inputs, torch.Tensor): return inputs.to(dst_type) elif isinstance(inputs, str): return inputs elif isinstance(inputs, np.ndarray): return inputs elif isinstance(inputs, container_abc.Mapping): return type(inputs)({ k: cast_tensor_type(v, src_type, dst_type) for k, v in inputs.items() }) elif isinstance(inputs, container_abc.Iterable): return type(inputs)( cast_tensor_type(item, src_type, dst_type) for item in inputs) else: return inputs def force_fp32(apply_to=None, out_fp16=False): """Decorator to convert input arguments to fp32 in force. This decorator is useful when you write custom modules and want to support mixed precision training. If there are some inputs that must be processed in fp32 mode, then this decorator can handle it. If inputs arguments are fp16 tensors, they will be converted to fp32 automatically. Arguments other than fp16 tensors are ignored. Args: apply_to (Iterable, optional): The argument names to be converted. `None` indicates all arguments. out_fp16 (bool): Whether to convert the output back to fp16. Example: >>> import torch.nn as nn >>> class MyModule1(nn.Module): >>> >>> # Convert x and y to fp32 >>> @force_fp32() >>> def loss(self, x, y): >>> pass >>> import torch.nn as nn >>> class MyModule2(nn.Module): >>> >>> # convert pred to fp32 >>> @force_fp32(apply_to=('pred', )) >>> def post_process(self, pred, others): >>> pass """ def force_fp32_wrapper(old_func): @functools.wraps(old_func) def new_func(*args, **kwargs): # check if the module has set the attribute `fp16_enabled`, if not, # just fallback to the original method. # if not isinstance(args[0], torch.nn.Module): # raise TypeError('@force_fp32 can only be used to decorate the ' # 'method of nn.Module') # if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled): # return old_func(*args, **kwargs) # get the arg spec of the decorated method args_info = getfullargspec(old_func) # get the argument names to be casted args_to_cast = args_info.args if apply_to is None else apply_to # convert the args that need to be processed new_args = [] if args: arg_names = args_info.args[:len(args)] for i, arg_name in enumerate(arg_names): if arg_name in args_to_cast: new_args.append( cast_tensor_type(args[i], torch.half, torch.float)) else: new_args.append(args[i]) # convert the kwargs that need to be processed new_kwargs = dict() if kwargs: for arg_name, arg_value in kwargs.items(): if arg_name in args_to_cast: new_kwargs[arg_name] = cast_tensor_type( arg_value, torch.half, torch.float) else: new_kwargs[arg_name] = arg_value # apply converted arguments to the decorated method output = old_func(*new_args, **new_kwargs) # cast the results back to fp32 if necessary if out_fp16: output = cast_tensor_type(output, torch.float, torch.half) return output return new_func return force_fp32_wrapper class TemporaryGrad(object): def __enter__(self): self.prev = torch.is_grad_enabled() torch.set_grad_enabled(True) def __exit__(self, exc_type, exc_value, traceback) -> None: torch.set_grad_enabled(self.prev) def make_one_hot(in_tensor, cls): if len(in_tensor) == 0: return in_tensor in_shape = tuple(in_tensor.size()) in_tensor = in_tensor.view(-1) res_one_hot = torch.zeros(in_tensor.size() + (cls,), device=in_tensor.device) select_index = (torch.arange(0, in_tensor.size()[-1]).long(), in_tensor.long()) res_one_hot[select_index] = 1 res = res_one_hot.reshape(in_shape[:-1] + (cls,)) return res def repeat_on_dim(in_tensor, repeat_num, dim=0): dims_num = list(in_tensor.size()) sq = in_tensor.unsqueeze(dim) repeat_v = np.ones(len(dims_num) + 1, dtype=np.int) repeat_v[dim + 1] = repeat_num rep = sq.repeat(*list(repeat_v)) dims_num[dim] *= repeat_num res = rep.reshape(dims_num) return res # ############# graphs def convert_node_matrix(x): atom_idx = torch.nonzero(x, as_tuple=True) atom = x[atom_idx] atom = torch.unsqueeze(atom, -1) return atom, atom_idx def convert_edge_matrix(edge): pass