Spaces:
Runtime error
Runtime error
Update model.py
Browse files
model.py
CHANGED
@@ -110,6 +110,10 @@ def download_data():
|
|
110 |
# downloads and loads MNIST test set
|
111 |
val_data = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())
|
112 |
val_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=False, pin_memory=True)
|
|
|
|
|
|
|
|
|
113 |
|
114 |
# uses GPU if available
|
115 |
if torch.cuda.is_available():
|
@@ -119,9 +123,6 @@ else:
|
|
119 |
|
120 |
device = torch.device(dev)
|
121 |
|
122 |
-
# gets mean and std of dataset
|
123 |
-
mean, std = get_mean_std(train_loader)
|
124 |
-
|
125 |
|
126 |
def run_model():
|
127 |
# defines parameters
|
|
|
110 |
# downloads and loads MNIST test set
|
111 |
val_data = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())
|
112 |
val_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=False, pin_memory=True)
|
113 |
+
|
114 |
+
# gets mean and std of dataset
|
115 |
+
mean, std = get_mean_std(train_loader)
|
116 |
+
|
117 |
|
118 |
# uses GPU if available
|
119 |
if torch.cuda.is_available():
|
|
|
123 |
|
124 |
device = torch.device(dev)
|
125 |
|
|
|
|
|
|
|
126 |
|
127 |
def run_model():
|
128 |
# defines parameters
|