import torch


def generate_permute_matrix(dim, num, keep_first=True, gpu_id=0, *, device=None):
    """Return ``num`` random ``dim x dim`` permutation matrices stacked on axis 0.

    Each matrix is the identity matrix with its rows shuffled.

    Args:
        dim: Size of each square permutation matrix.
        num: Number of matrices to generate.
        keep_first: If True, row 0 stays in place and only rows ``1..dim-1``
            are shuffled (presumably so index 0, e.g. a background class,
            is never permuted -- confirm with callers).
        gpu_id: CUDA device index used when ``device`` is not given.
        device: Optional device override (e.g. ``"cpu"``). When None, the
            matrices are created on ``torch.device('cuda', gpu_id)``.

    Returns:
        Float tensor of shape ``(num, dim, dim)``. For ``num == 0`` an empty
        tensor of that shape is returned.
    """
    if device is None:
        device = torch.device('cuda', gpu_id)
    if num == 0:
        # torch.stack([]) raises; return a correctly shaped empty batch instead.
        return torch.empty((0, dim, dim), device=device)

    all_matrix = []
    for _ in range(num):
        random_matrix = torch.eye(dim, device=device)
        if keep_first:
            # Shuffle rows 1.. only; max() guards dim == 0, where randperm(-1) raises.
            fg = random_matrix[1:][torch.randperm(max(dim - 1, 0))]
            random_matrix = torch.cat([random_matrix[0:1], fg], dim=0)
        else:
            random_matrix = random_matrix[torch.randperm(dim)]
        all_matrix.append(random_matrix)
    return torch.stack(all_matrix, dim=0)


def truncated_normal_(tensor, mean=0, std=.02):
    """Fill ``tensor`` in place with samples from a truncated normal distribution.

    Samples are drawn from N(0, 1) restricted to the open interval (-2, 2),
    then scaled by ``std`` and shifted by ``mean``. The truncation bounds
    are therefore ``mean +/- 2 * std``.

    Four candidates are drawn per element and the first in-range one is kept.
    If all four are out of range, that element is redrawn until it lands in
    range. The chance is roughly 4e-6 per element, so it does happen for
    large tensors. As a result, every output value respects the bound.

    Args:
        tensor: Floating-point tensor to overwrite.
        mean: Mean of the resulting distribution.
        std: Standard deviation of the underlying (untruncated) normal.

    Returns:
        ``tensor`` itself, to allow chaining.
    """
    with torch.no_grad():
        size = tensor.shape
        tmp = tensor.new_empty(size + (4, )).normal_()
        valid = (tmp < 2) & (tmp > -2)
        # Index of the first in-range candidate (index 0 if none is in range).
        ind = valid.max(-1, keepdim=True)[1]
        sample = tmp.gather(-1, ind).squeeze(-1)
        # Rejection step: redraw only the elements whose 4 candidates all missed.
        out_of_range = (sample >= 2) | (sample <= -2)
        while out_of_range.any():
            sample = torch.where(out_of_range, torch.randn_like(sample), sample)
            out_of_range = (sample >= 2) | (sample <= -2)
        tensor.copy_(sample)
        tensor.mul_(std).add_(mean)
    return tensor