rzimmerdev commited on
Commit
1de9461
1 Parent(s): 49b098d

fix: Changed default dataset image type

Browse files
Files changed (1) hide show
  1. src/dataset.py +5 -4
src/dataset.py CHANGED
@@ -2,10 +2,11 @@
2
  # coding: utf-8
3
  import gzip
4
 
5
- from src.downloader import download_dataset
6
-
7
- import numpy as np
8
  from torch.utils.data import Dataset
 
 
 
9
 
10
 
11
  def load_mnist(download_dir):
@@ -37,7 +38,7 @@ class DatasetMNIST(Dataset):
37
  def __getitem__(self, n):
38
  if n > self.total:
39
  raise ValueError(f"Dataset doesn't have enough elements to suffice request of {n} elements.")
40
- return self.data[n]
41
 
42
  def __len__(self):
43
  return len(self.data)
 
2
  # coding: utf-8
3
  import gzip
4
 
5
+ import torch
 
 
6
  from torch.utils.data import Dataset
7
+ import numpy as np
8
+
9
+ from src.downloader import download_dataset
10
 
11
 
12
  def load_mnist(download_dir):
 
38
  def __getitem__(self, n):
39
  if n > self.total:
40
  raise ValueError(f"Dataset doesn't have enough elements to suffice request of {n} elements.")
41
+ return torch.tensor(self.data[n][0].reshape(1, 28, 28), dtype=torch.float32), torch.tensor(self.data[n][1])
42
 
43
  def __len__(self):
44
  return len(self.data)