File size: 338 Bytes
8d6cd57
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import numpy as np
import gdown

import torch


def download_file(file_id: str, output_path: str):
    """Download a Google Drive file via ``gdown``.

    Args:
        file_id: Value substituted into the ``id`` query parameter of the
            ``https://drive.google.com/uc`` URL.
        output_path: Passed to ``gdown.download`` as its second positional
            argument (the download destination).

    Returns:
        None. Whatever ``gdown.download`` returns is discarded.
    """
    url = f'https://drive.google.com/uc?id={file_id}'
    gdown.download(url, output_path)


def sample_labels(labels: torch.Tensor, n: int) -> torch.Tensor:
    """Pick ``n`` random entries of ``labels`` along its first dimension.

    Indices are drawn independently from NumPy's global random state
    (``np.random.randint``), so the same entry may be selected more than
    once and results are reproducible through ``np.random.seed``.

    Args:
        labels: Tensor to sample from; its first dimension is indexed.
        n: Number of indices to draw.

    Returns:
        ``labels`` indexed by the ``n`` drawn indices.
    """
    # Upper bound is exclusive, so valid indices are 0 .. shape[0] - 1.
    chosen = np.random.randint(low=0, high=labels.shape[0], size=n)
    sampled = labels[chosen]
    return sampled