import torch def onehot(y, num_classes): y_onehot = torch.zeros(y.size(0), num_classes).to(y.device) if len(y.size()) == 1: y_onehot = y_onehot.scatter_(1, y.unsqueeze(-1), 1) elif len(y.size()) == 2: y_onehot = y_onehot.scatter_(1, y, 1) else: raise ValueError("[onehot]: y should be in shape [B], or [B, C]") return y_onehot def sum(tensor, dim=None, keepdim=False): if dim is None: # sum up all dim return torch.sum(tensor) else: if isinstance(dim, int): dim = [dim] dim = sorted(dim) for d in dim: tensor = tensor.sum(dim=d, keepdim=True) if not keepdim: for i, d in enumerate(dim): tensor.squeeze_(d-i) return tensor def mean(tensor, dim=None, keepdim=False): if dim is None: # mean all dim return torch.mean(tensor) else: if isinstance(dim, int): dim = [dim] dim = sorted(dim) for d in dim: tensor = tensor.mean(dim=d, keepdim=True) if not keepdim: for i, d in enumerate(dim): tensor.squeeze_(d-i) return tensor def split_feature(tensor, type="split"): """ type = ["split", "cross"] """ C = tensor.size(1) if type == "split": return tensor[:, :C // 2, ...], tensor[:, C // 2:, ...] elif type == "cross3": return tensor[:, 0::3, ...], tensor[:, 1::3, ...], tensor[:, 2::3, ...] elif type == "cross": return tensor[:, 0::2, ...], tensor[:, 1::2, ...] def cat_feature(tensor_a, tensor_b): return torch.cat((tensor_a, tensor_b), dim=1) def timesteps(tensor): return int(tensor.size(2))