equ1 commited on
Commit
9b30bbe
1 Parent(s): 569aa5e

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +9 -8
model.py CHANGED
@@ -101,14 +101,15 @@ class Net(nn.Module):
101
  return x
102
 
103
 
104
- # downloads and loads MNIST train set
105
- transform = transforms.Compose([transforms.ToTensor(), transforms.RandomAffine(degrees=10, translate=(0.1,0.1))])
106
- train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
107
- train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True, pin_memory=True)
108
-
109
- # downloads and loads MNIST test set
110
- val_data = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())
111
- val_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=False, pin_memory=True)
 
112
 
113
  # uses GPU if available
114
  if torch.cuda.is_available():
 
101
  return x
102
 
103
 
104
+ def download_data():
105
+ # downloads and loads MNIST train set
106
+ transform = transforms.Compose([transforms.ToTensor(), transforms.RandomAffine(degrees=10, translate=(0.1,0.1))])
107
+ train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
108
+ train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True, pin_memory=True)
109
+
110
+ # downloads and loads MNIST test set
111
+ val_data = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())
112
+ val_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=False, pin_memory=True)
113
 
114
  # uses GPU if available
115
  if torch.cuda.is_available():