Schrodingers's picture
Upload folder using huggingface_hub
ffbe0b4
raw
history blame
829 Bytes
import torch
def generate_permute_matrix(dim, num, keep_first=True, gpu_id=0, device=None):
    """Return a stack of ``num`` random ``dim x dim`` permutation matrices.

    Each matrix is the identity matrix with its rows shuffled by
    ``torch.randperm``.

    Args:
        dim: Size of each square permutation matrix.
        num: Number of matrices to generate. ``num <= 0`` yields an empty
            batch of shape ``(0, dim, dim)``.
        keep_first: If True, row 0 stays in place and only rows
            ``1..dim-1`` are shuffled, so index 0 always maps to itself.
        gpu_id: CUDA device index, used only when ``device`` is None.
        device: Optional device for the result. Defaults to
            ``torch.device('cuda', gpu_id)`` (the original behaviour);
            pass e.g. ``'cpu'`` to run without a GPU.

    Returns:
        Float tensor of shape ``(num, dim, dim)`` on ``device``.
    """
    if device is None:
        device = torch.device('cuda', gpu_id)
    # Built once: row indexing and torch.cat below produce new tensors,
    # so the shared identity is never mutated.
    identity = torch.eye(dim, device=device)
    if num <= 0:
        # torch.stack rejects an empty list; return an empty batch instead.
        return identity.new_empty((0, dim, dim))
    matrices = []
    for _ in range(num):
        if keep_first:
            shuffled_rest = identity[1:][torch.randperm(dim - 1)]
            matrix = torch.cat([identity[0:1], shuffled_rest], dim=0)
        else:
            matrix = identity[torch.randperm(dim)]
        matrices.append(matrix)
    return torch.stack(matrices, dim=0)
def truncated_normal_(tensor, mean=0, std=.02):
    """Fill ``tensor`` in place with samples from a truncated normal distribution.

    Each element is drawn from N(0, 1) restricted to the open interval
    (-2, 2), then scaled by ``std`` and shifted by ``mean``. For
    ``std > 0`` the values therefore lie in
    ``(mean - 2 * std, mean + 2 * std)``.

    Four candidates are drawn per element, and the first one inside (-2, 2)
    is kept. All four miss with probability ~4e-6, which is about 4 elements
    per million. Any element still out of range is redrawn until it falls
    inside the interval, so the truncation holds for every element.

    Args:
        tensor: Floating-point tensor to overwrite. It is written through
            ``.data``, so the update is not recorded by autograd.
        mean: Mean of the (untruncated) distribution.
        std: Standard deviation of the (untruncated) distribution.

    Returns:
        ``tensor`` itself, after being filled.
    """
    size = tensor.shape
    tmp = tensor.new_empty(size + (4, )).normal_()
    valid = (tmp < 2) & (tmp > -2)
    # Index of the first in-range candidate along the last axis. If no
    # candidate is in range, the index points at an out-of-range value,
    # which the redraw loop below replaces.
    ind = valid.max(-1, keepdim=True)[1]
    sample = tmp.gather(-1, ind).squeeze(-1)
    bad = ~((sample < 2) & (sample > -2))
    while bad.any():
        sample[bad] = sample.new_empty(int(bad.sum())).normal_()
        bad = ~((sample < 2) & (sample > -2))
    tensor.data.copy_(sample)
    tensor.data.mul_(std).add_(mean)
    return tensor