Spaces:
Runtime error
Runtime error
File size: 338 Bytes
8d6cd57 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
import numpy as np
import gdown
import torch
def download_file(file_id: str, output_path: str):
    """Download a Google Drive file via gdown and return gdown's result.

    Args:
        file_id: Google Drive file ID, interpolated into a ``uc?id=`` URL.
        output_path: Destination path handed to ``gdown.download``.

    Returns:
        The value returned by ``gdown.download`` (the output filename on
        success). NOTE(review): depending on the installed gdown version, a
        failed download may return ``None`` instead of raising — callers that
        care should check the result.
    """
    url = f'https://drive.google.com/uc?id={file_id}'
    # Propagate the result so callers can detect a failed download.
    return gdown.download(url, output_path)
def sample_labels(labels: torch.Tensor, n: int) -> torch.Tensor:
    """Draw ``n`` entries of ``labels`` uniformly at random, with replacement.

    Sampling is along the first dimension. Indices come from NumPy's global
    RNG (``np.random.randint``), so results are reproducible through
    ``np.random.seed`` rather than ``torch.manual_seed``.

    Args:
        labels: Tensor to sample from along dim 0.
        n: Number of samples to draw.

    Returns:
        Tensor of shape ``(n, *labels.shape[1:])`` with the sampled entries.

    Raises:
        ValueError: If ``n`` is negative, or if ``labels`` is empty along
            dim 0 while ``n > 0``.
    """
    high = labels.shape[0]
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")
    if n == 0:
        # Nothing to draw; avoid asking NumPy to sample from a possibly empty range.
        return labels[:0]
    if high == 0:
        raise ValueError("cannot sample from an empty labels tensor")
    # Pin int64: the platform-default integer is int32 on Windows (NumPy < 2),
    # and older torch versions reject int32 index arrays.
    idx = np.random.randint(0, high, size=n, dtype=np.int64)
    return labels[idx]
|