Spaces:
Runtime error
Runtime error
import torch | |
def onehot(y, num_classes):
    """Build a one-hot (or multi-hot) encoding of the class indices in ``y``.

    Args:
        y: Tensor of class indices with shape ``[B]`` or ``[B, C]``. For a
            ``[B, C]`` input every index in a row is set to 1 in that row of
            the output. Any integer dtype is accepted; indices are cast to
            int64 because ``scatter_`` requires an int64 index tensor.
        num_classes: Width of the encoding (size of the second output dim).

    Returns:
        Tensor of shape ``[B, num_classes]`` in the default float dtype,
        located on the same device as ``y``.

    Raises:
        ValueError: If ``y`` is not 1- or 2-dimensional.
    """
    # Validate the rank before allocating so bad input fails with the
    # intended error rather than one raised by an intermediate call.
    if y.dim() == 1:
        index = y.unsqueeze(-1)
    elif y.dim() == 2:
        index = y
    else:
        raise ValueError("[onehot]: y should be in shape [B], or [B, C]")

    # Only widen integer-like indices; a floating point index is left as is
    # so scatter_ still rejects it instead of it being silently truncated.
    if not index.is_floating_point():
        index = index.long()

    # Allocate directly on the target device to avoid a host-to-device copy.
    y_onehot = torch.zeros(y.size(0), num_classes, device=y.device)
    return y_onehot.scatter_(1, index, 1)
def sum(tensor, dim=None, keepdim=False):
    """Sum ``tensor`` over every element or over the given dimension(s).

    Note: this intentionally keeps the name ``sum`` (shadowing the builtin
    inside this module) because callers refer to it by that name.

    Args:
        tensor: Input tensor.
        dim: ``None`` to reduce over all elements, a single int, or an
            iterable of ints. Negative indices count from the last dimension
            and duplicates are ignored.
        keepdim: If True, reduced dimensions are kept with size 1; otherwise
            they are removed from the result.

    Returns:
        The reduced tensor. With ``dim=None`` this is a 0-dim tensor. With an
        empty ``dim`` iterable the input is returned unchanged.

    Raises:
        IndexError: If any entry of ``dim`` is out of range for ``tensor``.
    """
    if dim is None:
        # No dims requested: reduce over every element.
        return torch.sum(tensor)

    if isinstance(dim, int):
        dim = [dim]

    # torch accepts dim 0 / -1 on a 0-dim tensor, so treat it as rank 1.
    ndim = max(tensor.dim(), 1)
    normalized = set()
    for d in dim:
        if not -ndim <= d < ndim:
            raise IndexError(
                "Dimension out of range (expected to be in range of "
                "[{}, {}], but got {})".format(-ndim, ndim - 1, d)
            )
        # Map negative indices to their non-negative equivalent so that the
        # squeeze offsets below are computed on real axis positions.
        normalized.add(d % ndim)
    dims = sorted(normalized)

    for d in dims:
        tensor = tensor.sum(dim=d, keepdim=True)

    if not keepdim:
        # Each squeeze removes one axis, shifting later axes left by one.
        for offset, d in enumerate(dims):
            tensor = tensor.squeeze(d - offset)
    return tensor
def mean(tensor, dim=None, keepdim=False):
    """Average ``tensor`` over every element or over the given dimension(s).

    Args:
        tensor: Input tensor (must have a dtype ``torch.mean`` supports).
        dim: ``None`` to average over all elements, a single int, or an
            iterable of ints. Negative indices count from the last dimension
            and duplicates are ignored.
        keepdim: If True, reduced dimensions are kept with size 1; otherwise
            they are removed from the result.

    Returns:
        The reduced tensor. With ``dim=None`` this is a 0-dim tensor. With an
        empty ``dim`` iterable the input is returned unchanged.

    Raises:
        IndexError: If any entry of ``dim`` is out of range for ``tensor``.
    """
    if dim is None:
        # No dims requested: average over every element.
        return torch.mean(tensor)

    if isinstance(dim, int):
        dim = [dim]

    # torch accepts dim 0 / -1 on a 0-dim tensor, so treat it as rank 1.
    ndim = max(tensor.dim(), 1)
    normalized = set()
    for d in dim:
        if not -ndim <= d < ndim:
            raise IndexError(
                "Dimension out of range (expected to be in range of "
                "[{}, {}], but got {})".format(-ndim, ndim - 1, d)
            )
        # Map negative indices to their non-negative equivalent so that the
        # squeeze offsets below are computed on real axis positions.
        normalized.add(d % ndim)
    dims = sorted(normalized)

    # Averaging one axis at a time equals the joint mean because every
    # group along an axis has the same number of elements.
    for d in dims:
        tensor = tensor.mean(dim=d, keepdim=True)

    if not keepdim:
        # Each squeeze removes one axis, shifting later axes left by one.
        for offset, d in enumerate(dims):
            tensor = tensor.squeeze(d - offset)
    return tensor
def split_feature(tensor, type="split"):
    """Split ``tensor`` into parts along dimension 1.

    Args:
        tensor: Tensor with at least two dimensions; dimension 1 is split.
        type: Splitting scheme, one of:
            ``"split"``  - two contiguous parts: the first ``C // 2`` entries
                           and the remaining entries (``C`` = size of dim 1).
            ``"cross"``  - two interleaved parts: even and odd indices.
            ``"cross3"`` - three interleaved parts: indices 0, 1, 2 mod 3.

    Returns:
        A tuple of two tensors (``"split"``, ``"cross"``) or three tensors
        (``"cross3"``), each a slice of ``tensor``.

    Raises:
        ValueError: If ``type`` is not one of the supported schemes.
    """
    if type == "split":
        half = tensor.size(1) // 2
        return tensor[:, :half, ...], tensor[:, half:, ...]
    if type == "cross3":
        return tensor[:, 0::3, ...], tensor[:, 1::3, ...], tensor[:, 2::3, ...]
    if type == "cross":
        return tensor[:, 0::2, ...], tensor[:, 1::2, ...]
    # Fail loudly: returning None here would only surface later as an
    # unrelated unpacking error at the call site.
    raise ValueError(
        "[split_feature]: type should be one of 'split', 'cross', 'cross3', "
        "got {!r}".format(type)
    )
def cat_feature(tensor_a, tensor_b):
    """Concatenate two tensors along dimension 1.

    Args:
        tensor_a: First tensor; its entries come first along dimension 1.
        tensor_b: Second tensor, appended after ``tensor_a``.

    Returns:
        The result of ``torch.cat`` over ``(tensor_a, tensor_b)`` on dim 1.
    """
    pieces = (tensor_a, tensor_b)
    joined = torch.cat(pieces, dim=1)
    return joined
def timesteps(tensor):
    """Return the size of dimension 2 of ``tensor`` as a Python int.

    Args:
        tensor: Tensor with at least three dimensions.

    Returns:
        ``tensor.size(2)`` converted to ``int``.
    """
    length = tensor.size(2)
    return int(length)