Spaces:
Running
Running
File size: 1,770 Bytes
6ec3bf6 1de9461 6ec3bf6 1de9461 6ec3bf6 1de9461 6ec3bf6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
#!/usr/bin/env python
# coding: utf-8
import gzip
import os

import numpy as np
import torch
from torch.utils.data import Dataset

from src.downloader import download_dataset
def load_mnist(download_dir):
    """Download the MNIST files into *download_dir* and return their paths.

    Args:
        download_dir: Directory passed unchanged to ``download_dataset`` and
            used as the parent of the returned paths; a trailing separator
            is optional.

    Returns:
        Dict mapping ``"train"`` and ``"test"`` to an
        ``(images_path, labels_path)`` tuple, suitable for
        ``DatasetMNIST(*paths)``.
    """
    download_dataset("mnist", download_dir)
    # os.path.join instead of "+" so a directory given without a trailing
    # separator still yields "<dir>/train_images" rather than "<dir>train_images".
    return {
        split: (
            os.path.join(download_dir, f"{split}_images"),
            os.path.join(download_dir, f"{split}_labels"),
        )
        for split in ("train", "test")
    }
class _DatasetIndexError(IndexError, ValueError):
    """Out-of-range index error raised by ``DatasetMNIST.__getitem__``.

    Subclasses ``IndexError`` so Python's sequence-iteration protocol
    (``for item in dataset``) terminates cleanly, and ``ValueError`` so callers
    that caught the exception type raised by earlier versions keep working.
    """


class DatasetMNIST(Dataset):
    """Map-style PyTorch dataset over a gzip-compressed IDX image/label pair.

    Each item is ``(image, label)``, where ``image`` is a ``float32`` tensor of
    shape ``(1, rows, columns)`` holding the raw pixel bytes (0-255, not
    normalised) and ``label`` is a 0-dim ``uint8`` tensor.

    Attributes:
        total: Number of images declared in the image-file header.
        rows: Image height declared in the image-file header.
        columns: Image width declared in the image-file header.
        images: ``uint8`` array of shape ``(total, rows, columns)``.
        labels: ``uint8`` array of shape ``(total,)``.
        data: List of ``(image, label)`` pairs built from the two arrays.
    """

    def __init__(self, images, labels):
        """Read and decode both files.

        Args:
            images: Path to the gzip-compressed IDX image file.
            labels: Path to the gzip-compressed IDX label file.

        Raises:
            ValueError: If the pixel payload does not match the header
                dimensions, or the label count differs from the image count.
        """
        with gzip.open(images, "rb") as f:
            f.read(4)  # magic number; not validated
            self.total = int.from_bytes(f.read(4), "big")
            self.rows = int.from_bytes(f.read(4), "big")
            self.columns = int.from_bytes(f.read(4), "big")
            pixel_bytes = f.read()
        # reshape raises ValueError if the payload size disagrees with the header.
        self.images = np.frombuffer(pixel_bytes, dtype=np.uint8).reshape(
            (self.total, self.rows, self.columns)
        )

        with gzip.open(labels, "rb") as f:
            f.read(8)  # magic number + item count; count is checked below instead
            label_bytes = f.read()
        self.labels = np.frombuffer(label_bytes, dtype=np.uint8)

        # zip() would silently truncate to the shorter input, hiding a
        # mismatched image/label pair (e.g. train images with test labels).
        if len(self.labels) != self.total:
            raise ValueError(
                f"Label file has {len(self.labels)} labels but image file has "
                f"{self.total} images."
            )
        self.data = list(zip(self.images, self.labels))

    def __getitem__(self, n):
        """Return ``(image, label)`` for index ``n``; negative indices allowed.

        Raises:
            IndexError: If ``n`` is outside ``[-len(self), len(self))``. The
                exception is also a ``ValueError`` for backward compatibility.
        """
        size = len(self)
        # n == size is already out of range, so compare with "<", not ">".
        if not -size <= n < size:
            raise _DatasetIndexError(
                f"Index {n} is out of range for a dataset of {size} elements."
            )
        image, label = self.data[n]
        # Add a leading channel axis using the header dimensions rather than a
        # hard-coded 28x28, so other IDX image sizes load too.
        image_tensor = torch.tensor(
            image.reshape(1, self.rows, self.columns), dtype=torch.float32
        )
        return image_tensor, torch.tensor(label)

    def __len__(self):
        """Return the number of (image, label) pairs."""
        return len(self.data)
if __name__ == "__main__":
    # Demo: load the training split and display one sample with its label.
    # matplotlib is only needed here; importing it first fails fast, before
    # any download happens, if it is not installed.
    import matplotlib.pyplot as plt

    download_dir = "../downloads/mnist/"
    mnist = load_mnist(download_dir)
    dataset = DatasetMNIST(*mnist["train"])

    X, y = dataset[4]
    # Dataset items are (1, H, W); imshow needs a 2-D array for grayscale.
    plt.imshow(X.squeeze(0), cmap="gray")
    # .item() shows the plain label value instead of the tensor repr.
    plt.title(label="Annotated label: " + str(y.item()))
    plt.show()
|