equ1 commited on
Commit
e4b9ec4
1 Parent(s): a232c52

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +10 -12
model.py CHANGED
@@ -101,19 +101,14 @@ class Net(nn.Module):
101
  return x
102
 
103
 
104
- def download_data():
105
- # downloads and loads MNIST train set
106
- transform = transforms.Compose([transforms.ToTensor(), transforms.RandomAffine(degrees=10, translate=(0.1,0.1))])
107
- train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
108
- train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True, pin_memory=True)
109
-
110
- # downloads and loads MNIST test set
111
- val_data = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())
112
- val_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=False, pin_memory=True)
113
-
114
- # gets mean and std of dataset
115
- mean, std = get_mean_std(train_loader)
116
 
 
 
 
117
 
118
  # uses GPU if available
119
  if torch.cuda.is_available():
@@ -123,6 +118,9 @@ else:
123
 
124
  device = torch.device(dev)
125
 
 
 
 
126
 
127
  def run_model():
128
  # defines parameters
 
101
  return x
102
 
103
 
104
+ # downloads and loads MNIST train set
105
+ transform = transforms.Compose([transforms.ToTensor(), transforms.RandomAffine(degrees=10, translate=(0.1,0.1))])
106
+ train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
107
+ train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True, pin_memory=True)
 
 
 
 
 
 
 
 
108
 
109
+ # downloads and loads MNIST test set
110
+ val_data = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())
111
+ val_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=False, pin_memory=True)
112
 
113
  # uses GPU if available
114
  if torch.cuda.is_available():
 
118
 
119
  device = torch.device(dev)
120
 
121
+ # gets mean and std of dataset
122
+ mean, std = get_mean_std(train_loader)
123
+
124
 
125
  def run_model():
126
  # defines parameters