Ananda Bollu committed on
Commit 8cd6dcb
1 Parent(s): c62696f

rename to pytorch_model.bin


The model Ab0/foo-model does not seem to have model files. Please check that it contains either `pytorch_model.bin` or `tf_model.h5`.
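For context, the Hub looks for a weights file named exactly `pytorch_model.bin` (or `tf_model.h5`) at the top level of the repo, which is what this rename provides. Below is a minimal sketch of producing and reloading such a file, assuming `model` is the `NeuralNetwork` instance defined in the notebook added in this commit:

import torch

# Serialize only the weights (the state_dict); the file name is what the Hub checks for.
torch.save(model.state_dict(), "pytorch_model.bin")

# Restoring later requires re-instantiating the architecture first,
# then loading the saved state_dict into it.
restored = NeuralNetwork()  # class defined in quickstart_tutorial.ipynb below
restored.load_state_dict(torch.load("pytorch_model.bin"))
restored.eval()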

model.pth → pytorch_model.bin RENAMED
File without changes
quickstart_tutorial.ipynb ADDED
@@ -0,0 +1,731 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {
7
+ "collapsed": false
8
+ },
9
+ "outputs": [],
10
+ "source": [
11
+ "%matplotlib inline"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "markdown",
16
+ "metadata": {},
17
+ "source": [
18
+ "\n",
19
+ "`Learn the Basics <intro.html>`_ ||\n",
20
+ "**Quickstart** ||\n",
21
+ "`Tensors <tensorqs_tutorial.html>`_ ||\n",
22
+ "`Datasets & DataLoaders <data_tutorial.html>`_ ||\n",
23
+ "`Transforms <transforms_tutorial.html>`_ ||\n",
24
+ "`Build Model <buildmodel_tutorial.html>`_ ||\n",
25
+ "`Autograd <autogradqs_tutorial.html>`_ ||\n",
26
+ "`Optimization <optimization_tutorial.html>`_ ||\n",
27
+ "`Save & Load Model <saveloadrun_tutorial.html>`_\n",
28
+ "\n",
29
+ "Quickstart\n",
30
+ "===================\n",
31
+ "This section runs through the API for common tasks in machine learning. Refer to the links in each section to dive deeper.\n",
32
+ "\n",
33
+ "Working with data\n",
34
+ "-----------------\n",
35
+ "PyTorch has two `primitives to work with data <https://pytorch.org/docs/stable/data.html>`_:\n",
36
+ "``torch.utils.data.DataLoader`` and ``torch.utils.data.Dataset``.\n",
37
+ "``Dataset`` stores the samples and their corresponding labels, and ``DataLoader`` wraps an iterable around\n",
38
+ "the ``Dataset``.\n",
39
+ "\n",
40
+ "\n"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": 2,
46
+ "metadata": {
47
+ "collapsed": false
48
+ },
49
+ "outputs": [],
50
+ "source": [
51
+ "import torch\n",
52
+ "from torch import nn\n",
53
+ "from torch.utils.data import DataLoader\n",
54
+ "from torchvision import datasets\n",
55
+ "from torchvision.transforms import ToTensor, Lambda, Compose\n",
56
+ "import matplotlib.pyplot as plt\n",
57
+ "from huggingface_hub import push_to_hub_keras"
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "markdown",
62
+ "metadata": {},
63
+ "source": [
64
+ "PyTorch offers domain-specific libraries such as `TorchText <https://pytorch.org/text/stable/index.html>`_,\n",
65
+ "`TorchVision <https://pytorch.org/vision/stable/index.html>`_, and `TorchAudio <https://pytorch.org/audio/stable/index.html>`_,\n",
66
+ "all of which include datasets. For this tutorial, we will be using a TorchVision dataset.\n",
67
+ "\n",
68
+ "The ``torchvision.datasets`` module contains ``Dataset`` objects for many real-world vision data like\n",
69
+ "CIFAR, COCO (`full list here <https://pytorch.org/vision/stable/datasets.html>`_). In this tutorial, we\n",
70
+ "use the FashionMNIST dataset. Every TorchVision ``Dataset`` includes two arguments: ``transform`` and\n",
71
+ "``target_transform`` to modify the samples and labels respectively.\n",
72
+ "\n"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": 3,
78
+ "metadata": {
79
+ "collapsed": false
80
+ },
81
+ "outputs": [],
82
+ "source": [
83
+ "# Download training data from open datasets.\n",
84
+ "training_data = datasets.FashionMNIST(\n",
85
+ " root=\"data\",\n",
86
+ " train=True,\n",
87
+ " download=True,\n",
88
+ " transform=ToTensor(),\n",
89
+ ")\n",
90
+ "\n",
91
+ "# Download test data from open datasets.\n",
92
+ "test_data = datasets.FashionMNIST(\n",
93
+ " root=\"data\",\n",
94
+ " train=False,\n",
95
+ " download=True,\n",
96
+ " transform=ToTensor(),\n",
97
+ ")"
98
+ ]
99
+ },
100
+ {
101
+ "cell_type": "markdown",
102
+ "metadata": {},
103
+ "source": [
104
+ "We pass the ``Dataset`` as an argument to ``DataLoader``. This wraps an iterable over our dataset, and supports\n",
105
+ "automatic batching, sampling, shuffling and multiprocess data loading. Here we define a batch size of 64, i.e. each element\n",
106
+ "in the dataloader iterable will return a batch of 64 features and labels.\n",
107
+ "\n"
108
+ ]
109
+ },
110
+ {
111
+ "cell_type": "code",
112
+ "execution_count": 4,
113
+ "metadata": {
114
+ "collapsed": false
115
+ },
116
+ "outputs": [
117
+ {
118
+ "name": "stdout",
119
+ "output_type": "stream",
120
+ "text": [
121
+ "Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])\n",
122
+ "Shape of y: torch.Size([64]) torch.int64\n"
123
+ ]
124
+ }
125
+ ],
126
+ "source": [
127
+ "batch_size = 64\n",
128
+ "\n",
129
+ "# Create data loaders.\n",
130
+ "train_dataloader = DataLoader(training_data, batch_size=batch_size)\n",
131
+ "test_dataloader = DataLoader(test_data, batch_size=batch_size)\n",
132
+ "\n",
133
+ "for X, y in test_dataloader:\n",
134
+ " print(\"Shape of X [N, C, H, W]: \", X.shape)\n",
135
+ " print(\"Shape of y: \", y.shape, y.dtype)\n",
136
+ " break"
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "markdown",
141
+ "metadata": {},
142
+ "source": [
143
+ "Read more about `loading data in PyTorch <data_tutorial.html>`_.\n",
144
+ "\n",
145
+ "\n"
146
+ ]
147
+ },
148
+ {
149
+ "cell_type": "markdown",
150
+ "metadata": {},
151
+ "source": [
152
+ "--------------\n",
153
+ "\n",
154
+ "\n"
155
+ ]
156
+ },
157
+ {
158
+ "cell_type": "markdown",
159
+ "metadata": {},
160
+ "source": [
161
+ "Creating Models\n",
162
+ "------------------\n",
163
+ "To define a neural network in PyTorch, we create a class that inherits\n",
164
+ "from `nn.Module <https://pytorch.org/docs/stable/generated/torch.nn.Module.html>`_. We define the layers of the network\n",
165
+ "in the ``__init__`` function and specify how data will pass through the network in the ``forward`` function. To accelerate\n",
166
+ "operations in the neural network, we move it to the GPU if available.\n",
167
+ "\n"
168
+ ]
169
+ },
170
+ {
171
+ "cell_type": "code",
172
+ "execution_count": 5,
173
+ "metadata": {
174
+ "collapsed": false
175
+ },
176
+ "outputs": [
177
+ {
178
+ "name": "stdout",
179
+ "output_type": "stream",
180
+ "text": [
181
+ "Using cpu device\n",
182
+ "NeuralNetwork(\n",
183
+ " (flatten): Flatten(start_dim=1, end_dim=-1)\n",
184
+ " (linear_relu_stack): Sequential(\n",
185
+ " (0): Linear(in_features=784, out_features=512, bias=True)\n",
186
+ " (1): ReLU()\n",
187
+ " (2): Linear(in_features=512, out_features=512, bias=True)\n",
188
+ " (3): ReLU()\n",
189
+ " (4): Linear(in_features=512, out_features=10, bias=True)\n",
190
+ " )\n",
191
+ ")\n"
192
+ ]
193
+ }
194
+ ],
195
+ "source": [
196
+ "# Get cpu or gpu device for training.\n",
197
+ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
198
+ "print(f\"Using {device} device\")\n",
199
+ "\n",
200
+ "# Define model\n",
201
+ "class NeuralNetwork(nn.Module):\n",
202
+ " def __init__(self):\n",
203
+ " super(NeuralNetwork, self).__init__()\n",
204
+ " self.flatten = nn.Flatten()\n",
205
+ " self.linear_relu_stack = nn.Sequential(\n",
206
+ " nn.Linear(28*28, 512),\n",
207
+ " nn.ReLU(),\n",
208
+ " nn.Linear(512, 512),\n",
209
+ " nn.ReLU(),\n",
210
+ " nn.Linear(512, 10)\n",
211
+ " )\n",
212
+ "\n",
213
+ " def forward(self, x):\n",
214
+ " x = self.flatten(x)\n",
215
+ " logits = self.linear_relu_stack(x)\n",
216
+ " return logits\n",
217
+ "\n",
218
+ "model = NeuralNetwork().to(device)\n",
219
+ "print(model)"
220
+ ]
221
+ },
222
+ {
223
+ "cell_type": "markdown",
224
+ "metadata": {},
225
+ "source": [
226
+ "Read more about `building neural networks in PyTorch <buildmodel_tutorial.html>`_.\n",
227
+ "\n",
228
+ "\n"
229
+ ]
230
+ },
231
+ {
232
+ "cell_type": "markdown",
233
+ "metadata": {},
234
+ "source": [
235
+ "--------------\n",
236
+ "\n",
237
+ "\n"
238
+ ]
239
+ },
240
+ {
241
+ "cell_type": "markdown",
242
+ "metadata": {},
243
+ "source": [
244
+ "Optimizing the Model Parameters\n",
245
+ "----------------------------------------\n",
246
+ "To train a model, we need a `loss function <https://pytorch.org/docs/stable/nn.html#loss-functions>`_\n",
247
+ "and an `optimizer <https://pytorch.org/docs/stable/optim.html>`_.\n",
248
+ "\n"
249
+ ]
250
+ },
251
+ {
252
+ "cell_type": "code",
253
+ "execution_count": 6,
254
+ "metadata": {
255
+ "collapsed": false
256
+ },
257
+ "outputs": [],
258
+ "source": [
259
+ "loss_fn = nn.CrossEntropyLoss()\n",
260
+ "optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)"
261
+ ]
262
+ },
263
+ {
264
+ "cell_type": "markdown",
265
+ "metadata": {},
266
+ "source": [
267
+ "In a single training loop, the model makes predictions on the training dataset (fed to it in batches), and\n",
268
+ "backpropagates the prediction error to adjust the model's parameters.\n",
269
+ "\n"
270
+ ]
271
+ },
272
+ {
273
+ "cell_type": "code",
274
+ "execution_count": 7,
275
+ "metadata": {
276
+ "collapsed": false
277
+ },
278
+ "outputs": [],
279
+ "source": [
280
+ "def train(dataloader, model, loss_fn, optimizer):\n",
281
+ " size = len(dataloader.dataset)\n",
282
+ " model.train()\n",
283
+ " for batch, (X, y) in enumerate(dataloader):\n",
284
+ " X, y = X.to(device), y.to(device)\n",
285
+ "\n",
286
+ " # Compute prediction error\n",
287
+ " pred = model(X)\n",
288
+ " loss = loss_fn(pred, y)\n",
289
+ "\n",
290
+ " # Backpropagation\n",
291
+ " optimizer.zero_grad()\n",
292
+ " loss.backward()\n",
293
+ " optimizer.step()\n",
294
+ "\n",
295
+ " if batch % 100 == 0:\n",
296
+ " loss, current = loss.item(), batch * len(X)\n",
297
+ " print(f\"loss: {loss:>7f} [{current:>5d}/{size:>5d}]\")"
298
+ ]
299
+ },
300
+ {
301
+ "cell_type": "markdown",
302
+ "metadata": {},
303
+ "source": [
304
+ "We also check the model's performance against the test dataset to ensure it is learning.\n",
305
+ "\n"
306
+ ]
307
+ },
308
+ {
309
+ "cell_type": "code",
310
+ "execution_count": 8,
311
+ "metadata": {
312
+ "collapsed": false
313
+ },
314
+ "outputs": [],
315
+ "source": [
316
+ "def test(dataloader, model, loss_fn):\n",
317
+ " size = len(dataloader.dataset)\n",
318
+ " num_batches = len(dataloader)\n",
319
+ " model.eval()\n",
320
+ " test_loss, correct = 0, 0\n",
321
+ " with torch.no_grad():\n",
322
+ " for X, y in dataloader:\n",
323
+ " X, y = X.to(device), y.to(device)\n",
324
+ " pred = model(X)\n",
325
+ " test_loss += loss_fn(pred, y).item()\n",
326
+ " correct += (pred.argmax(1) == y).type(torch.float).sum().item()\n",
327
+ " test_loss /= num_batches\n",
328
+ " correct /= size\n",
329
+ " print(f\"Test Error: \\n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \\n\")"
330
+ ]
331
+ },
332
+ {
333
+ "cell_type": "markdown",
334
+ "metadata": {},
335
+ "source": [
336
+ "The training process is conducted over several iterations (*epochs*). During each epoch, the model learns\n",
337
+ "parameters to make better predictions. We print the model's accuracy and loss at each epoch; we'd like to see the\n",
338
+ "accuracy increase and the loss decrease with every epoch.\n",
339
+ "\n"
340
+ ]
341
+ },
342
+ {
343
+ "cell_type": "code",
344
+ "execution_count": 9,
345
+ "metadata": {
346
+ "collapsed": false
347
+ },
348
+ "outputs": [
349
+ {
350
+ "name": "stdout",
351
+ "output_type": "stream",
352
+ "text": [
353
+ "Epoch 1\n",
354
+ "-------------------------------\n",
355
+ "loss: 2.293067 [ 0/60000]\n",
356
+ "loss: 2.287422 [ 6400/60000]\n",
357
+ "loss: 2.265790 [12800/60000]\n",
358
+ "loss: 2.274793 [19200/60000]\n",
359
+ "loss: 2.257332 [25600/60000]\n",
360
+ "loss: 2.222204 [32000/60000]\n",
361
+ "loss: 2.240200 [38400/60000]\n",
362
+ "loss: 2.206084 [44800/60000]\n",
363
+ "loss: 2.190236 [51200/60000]\n",
364
+ "loss: 2.176934 [57600/60000]\n",
365
+ "Test Error: \n",
366
+ " Accuracy: 42.4%, Avg loss: 2.162450 \n",
367
+ "\n",
368
+ "Epoch 2\n",
369
+ "-------------------------------\n",
370
+ "loss: 2.161891 [ 0/60000]\n",
371
+ "loss: 2.160867 [ 6400/60000]\n",
372
+ "loss: 2.099223 [12800/60000]\n",
373
+ "loss: 2.127940 [19200/60000]\n",
374
+ "loss: 2.089684 [25600/60000]\n",
375
+ "loss: 2.018054 [32000/60000]\n",
376
+ "loss: 2.060461 [38400/60000]\n",
377
+ "loss: 1.981958 [44800/60000]\n",
378
+ "loss: 1.971331 [51200/60000]\n",
379
+ "loss: 1.930486 [57600/60000]\n",
380
+ "Test Error: \n",
381
+ " Accuracy: 58.1%, Avg loss: 1.909495 \n",
382
+ "\n",
383
+ "Epoch 3\n",
384
+ "-------------------------------\n",
385
+ "loss: 1.930542 [ 0/60000]\n",
386
+ "loss: 1.913976 [ 6400/60000]\n",
387
+ "loss: 1.788895 [12800/60000]\n",
388
+ "loss: 1.838503 [19200/60000]\n",
389
+ "loss: 1.757226 [25600/60000]\n",
390
+ "loss: 1.682464 [32000/60000]\n",
391
+ "loss: 1.722755 [38400/60000]\n",
392
+ "loss: 1.617113 [44800/60000]\n",
393
+ "loss: 1.632282 [51200/60000]\n",
394
+ "loss: 1.548769 [57600/60000]\n",
395
+ "Test Error: \n",
396
+ " Accuracy: 61.0%, Avg loss: 1.543196 \n",
397
+ "\n",
398
+ "Epoch 4\n",
399
+ "-------------------------------\n",
400
+ "loss: 1.601020 [ 0/60000]\n",
401
+ "loss: 1.574128 [ 6400/60000]\n",
402
+ "loss: 1.412696 [12800/60000]\n",
403
+ "loss: 1.496537 [19200/60000]\n",
404
+ "loss: 1.391789 [25600/60000]\n",
405
+ "loss: 1.360881 [32000/60000]\n",
406
+ "loss: 1.398112 [38400/60000]\n",
407
+ "loss: 1.316551 [44800/60000]\n",
408
+ "loss: 1.347136 [51200/60000]\n",
409
+ "loss: 1.253991 [57600/60000]\n",
410
+ "Test Error: \n",
411
+ " Accuracy: 62.8%, Avg loss: 1.267020 \n",
412
+ "\n",
413
+ "Epoch 5\n",
414
+ "-------------------------------\n",
415
+ "loss: 1.336873 [ 0/60000]\n",
416
+ "loss: 1.324502 [ 6400/60000]\n",
417
+ "loss: 1.153551 [12800/60000]\n",
418
+ "loss: 1.265215 [19200/60000]\n",
419
+ "loss: 1.149221 [25600/60000]\n",
420
+ "loss: 1.156962 [32000/60000]\n",
421
+ "loss: 1.194912 [38400/60000]\n",
422
+ "loss: 1.133846 [44800/60000]\n",
423
+ "loss: 1.164861 [51200/60000]\n",
424
+ "loss: 1.080542 [57600/60000]\n",
425
+ "Test Error: \n",
426
+ " Accuracy: 64.1%, Avg loss: 1.094896 \n",
427
+ "\n",
428
+ "Done!\n"
429
+ ]
430
+ }
431
+ ],
432
+ "source": [
433
+ "epochs = 5\n",
434
+ "for t in range(epochs):\n",
435
+ " print(f\"Epoch {t+1}\\n-------------------------------\")\n",
436
+ " train(train_dataloader, model, loss_fn, optimizer)\n",
437
+ " test(test_dataloader, model, loss_fn)\n",
438
+ "print(\"Done!\")"
439
+ ]
440
+ },
441
+ {
442
+ "cell_type": "markdown",
443
+ "metadata": {},
444
+ "source": [
445
+ "Read more about `Training your model <optimization_tutorial.html>`_.\n",
446
+ "\n",
447
+ "\n"
448
+ ]
449
+ },
450
+ {
451
+ "cell_type": "markdown",
452
+ "metadata": {},
453
+ "source": [
454
+ "--------------\n",
455
+ "\n",
456
+ "\n"
457
+ ]
458
+ },
459
+ {
460
+ "cell_type": "markdown",
461
+ "metadata": {},
462
+ "source": [
463
+ "Saving Models\n",
464
+ "-------------\n",
465
+ "A common way to save a model is to serialize the internal state dictionary (containing the model parameters).\n",
466
+ "\n"
467
+ ]
468
+ },
469
+ {
470
+ "cell_type": "code",
471
+ "execution_count": 10,
472
+ "metadata": {
473
+ "collapsed": false
474
+ },
475
+ "outputs": [
476
+ {
477
+ "name": "stdout",
478
+ "output_type": "stream",
479
+ "text": [
480
+ "Saved PyTorch Model State to model.pth\n"
481
+ ]
482
+ }
483
+ ],
484
+ "source": [
485
+ "torch.save(model.state_dict(), \"pytorch_model.bin\")\n",
486
+ "print(\"Saved PyTorch Model State to pytorch_model.bin\")"
487
+ ]
488
+ },
489
+ {
490
+ "cell_type": "markdown",
491
+ "metadata": {},
492
+ "source": [
493
+ "Loading Models\n",
494
+ "----------------------------\n",
495
+ "\n",
496
+ "The process for loading a model includes re-creating the model structure and loading\n",
497
+ "the state dictionary into it.\n",
498
+ "\n"
499
+ ]
500
+ },
501
+ {
502
+ "cell_type": "code",
503
+ "execution_count": 13,
504
+ "metadata": {
505
+ "collapsed": false
506
+ },
507
+ "outputs": [
508
+ {
509
+ "data": {
510
+ "text/plain": [
511
+ "<All keys matched successfully>"
512
+ ]
513
+ },
514
+ "execution_count": 13,
515
+ "metadata": {},
516
+ "output_type": "execute_result"
517
+ }
518
+ ],
519
+ "source": [
520
+ "model = NeuralNetwork()\n",
521
+ "model.load_state_dict(torch.load(\"model.pth\"))"
522
+ ]
523
+ },
524
+ {
525
+ "cell_type": "markdown",
526
+ "metadata": {},
527
+ "source": [
528
+ "This model can now be used to make predictions.\n",
529
+ "\n"
530
+ ]
531
+ },
532
+ {
533
+ "cell_type": "code",
534
+ "execution_count": 15,
535
+ "metadata": {
536
+ "collapsed": false
537
+ },
538
+ "outputs": [
539
+ {
540
+ "name": "stdout",
541
+ "output_type": "stream",
542
+ "text": [
543
+ "Predicted: \"Shirt\", Actual: \"Shirt\"\n"
544
+ ]
545
+ }
546
+ ],
547
+ "source": [
548
+ "classes = [\n",
549
+ " \"T-shirt/top\",\n",
550
+ " \"Trouser\",\n",
551
+ " \"Pullover\",\n",
552
+ " \"Dress\",\n",
553
+ " \"Coat\",\n",
554
+ " \"Sandal\",\n",
555
+ " \"Shirt\",\n",
556
+ " \"Sneaker\",\n",
557
+ " \"Bag\",\n",
558
+ " \"Ankle boot\",\n",
559
+ "]\n",
560
+ "\n",
561
+ "model.eval()\n",
562
+ "x, y = test_data[4][0], test_data[4][1]\n",
563
+ "with torch.no_grad():\n",
564
+ " pred = model(x)\n",
565
+ " predicted, actual = classes[pred[0].argmax(0)], classes[y]\n",
566
+ " print(f'Predicted: \"{predicted}\", Actual: \"{actual}\"')"
567
+ ]
568
+ },
569
+ {
570
+ "cell_type": "code",
571
+ "execution_count": 16,
572
+ "metadata": {},
573
+ "outputs": [
574
+ {
575
+ "data": {
576
+ "text/plain": [
577
+ "tensor([[[0.0000, 0.0000, 0.0000, 0.0078, 0.0000, 0.0039, 0.0039, 0.0000,\n",
578
+ " 0.0000, 0.0000, 0.0000, 0.2235, 0.2627, 0.2863, 0.2980, 0.2980,\n",
579
+ " 0.3255, 0.2431, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
580
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
581
+ " [0.0000, 0.0000, 0.0000, 0.0039, 0.0039, 0.0039, 0.0000, 0.0000,\n",
582
+ " 0.0510, 0.3098, 0.5020, 0.7882, 0.6353, 0.6314, 0.6784, 0.7529,\n",
583
+ " 0.6745, 0.7098, 0.7216, 0.4235, 0.1176, 0.0000, 0.0000, 0.0000,\n",
584
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
585
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000, 0.4000,\n",
586
+ " 0.5451, 0.5569, 0.4039, 0.4510, 0.6353, 0.6039, 0.6471, 0.6000,\n",
587
+ " 0.5451, 0.5059, 0.5882, 0.5412, 0.6706, 0.6314, 0.1020, 0.0000,\n",
588
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
589
+ " [0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000, 0.4157, 0.4863,\n",
590
+ " 0.4235, 0.4039, 0.4157, 0.3647, 0.3922, 0.7059, 0.6118, 0.5765,\n",
591
+ " 0.5412, 0.3333, 0.6157, 0.4471, 0.4863, 0.6039, 0.6157, 0.0000,\n",
592
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
593
+ " [0.0000, 0.0000, 0.0000, 0.0078, 0.0000, 0.1137, 0.5255, 0.3961,\n",
594
+ " 0.4431, 0.4235, 0.3804, 0.4549, 0.3176, 0.5725, 0.7176, 0.6431,\n",
595
+ " 0.4353, 0.5725, 0.5137, 0.4784, 0.5176, 0.5686, 0.6627, 0.3647,\n",
596
+ " 0.0000, 0.0039, 0.0000, 0.0000],\n",
597
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2549, 0.5137, 0.4118,\n",
598
+ " 0.3961, 0.4235, 0.3922, 0.4078, 0.3804, 0.2902, 0.8078, 0.6824,\n",
599
+ " 0.4510, 0.5882, 0.4235, 0.4667, 0.5725, 0.5961, 0.6353, 0.5529,\n",
600
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
601
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4235, 0.4824, 0.4392,\n",
602
+ " 0.4157, 0.3843, 0.3922, 0.3961, 0.4353, 0.2824, 0.5333, 0.5176,\n",
603
+ " 0.4392, 0.4510, 0.4275, 0.5569, 0.5882, 0.6275, 0.6353, 0.7647,\n",
604
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
605
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5294, 0.4784, 0.4667,\n",
606
+ " 0.4392, 0.3255, 0.3647, 0.3804, 0.4157, 0.4510, 0.3569, 0.4275,\n",
607
+ " 0.3255, 0.4275, 0.4902, 0.6471, 0.5490, 0.7569, 0.6275, 0.6902,\n",
608
+ " 0.0235, 0.0000, 0.0000, 0.0000],\n",
609
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0902, 0.5294, 0.5176, 0.5843,\n",
610
+ " 0.4078, 0.3059, 0.3765, 0.3804, 0.4039, 0.4235, 0.4235, 0.4510,\n",
611
+ " 0.3294, 0.4471, 0.5843, 0.6196, 0.5765, 0.8196, 0.6275, 0.6980,\n",
612
+ " 0.2039, 0.0000, 0.0000, 0.0000],\n",
613
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.2235, 0.4863, 0.5137, 0.6275,\n",
614
+ " 0.4039, 0.3765, 0.3961, 0.4275, 0.4275, 0.4353, 0.4235, 0.4471,\n",
615
+ " 0.4157, 0.4431, 0.6118, 0.6392, 0.6118, 0.7686, 0.6549, 0.6824,\n",
616
+ " 0.3333, 0.0000, 0.0000, 0.0000],\n",
617
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.3373, 0.4549, 0.4941, 0.6275,\n",
618
+ " 0.5176, 0.4000, 0.3765, 0.4078, 0.4196, 0.3843, 0.3647, 0.4824,\n",
619
+ " 0.4549, 0.4392, 0.5843, 0.6275, 0.7098, 0.7294, 0.6353, 0.6353,\n",
620
+ " 0.4824, 0.0000, 0.0000, 0.0000],\n",
621
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.4392, 0.4471, 0.4392, 0.6549,\n",
622
+ " 0.5725, 0.3922, 0.3922, 0.3961, 0.4196, 0.3765, 0.3922, 0.4941,\n",
623
+ " 0.4039, 0.4706, 0.5529, 0.6196, 0.6549, 0.7333, 0.5765, 0.5804,\n",
624
+ " 0.6667, 0.0000, 0.0000, 0.0000],\n",
625
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.4863, 0.4627, 0.3961, 0.7725,\n",
626
+ " 0.3490, 0.3961, 0.3922, 0.3765, 0.4235, 0.4039, 0.4235, 0.4784,\n",
627
+ " 0.4196, 0.4980, 0.5451, 0.5882, 0.4667, 0.7686, 0.5686, 0.5569,\n",
628
+ " 0.7020, 0.0000, 0.0000, 0.0000],\n",
629
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.5137, 0.4510, 0.3804, 0.7765,\n",
630
+ " 0.1843, 0.4235, 0.3765, 0.3765, 0.4157, 0.4667, 0.4000, 0.4706,\n",
631
+ " 0.4039, 0.4824, 0.5490, 0.5882, 0.3176, 0.8078, 0.5725, 0.5294,\n",
632
+ " 0.7608, 0.0000, 0.0000, 0.0000],\n",
633
+ " [0.0000, 0.0000, 0.0000, 0.0157, 0.5333, 0.4627, 0.3843, 0.7569,\n",
634
+ " 0.0824, 0.4275, 0.3765, 0.4157, 0.4000, 0.5059, 0.3922, 0.4667,\n",
635
+ " 0.4000, 0.4627, 0.5529, 0.6000, 0.1765, 0.8471, 0.5804, 0.5451,\n",
636
+ " 0.8039, 0.0471, 0.0000, 0.0000],\n",
637
+ " [0.0000, 0.0000, 0.0000, 0.0941, 0.5373, 0.4588, 0.3961, 0.7333,\n",
638
+ " 0.0980, 0.4431, 0.3608, 0.4392, 0.3686, 0.4706, 0.4118, 0.4980,\n",
639
+ " 0.3804, 0.4510, 0.5569, 0.5882, 0.0745, 0.8353, 0.5804, 0.5137,\n",
640
+ " 0.8000, 0.1412, 0.0000, 0.0000],\n",
641
+ " [0.0000, 0.0000, 0.0000, 0.1569, 0.5529, 0.4275, 0.4588, 0.6196,\n",
642
+ " 0.0471, 0.4863, 0.3529, 0.4549, 0.3765, 0.4588, 0.4431, 0.5333,\n",
643
+ " 0.3686, 0.4353, 0.5765, 0.6392, 0.1216, 0.7490, 0.5725, 0.5255,\n",
644
+ " 0.8078, 0.2275, 0.0000, 0.0000],\n",
645
+ " [0.0000, 0.0000, 0.0000, 0.1529, 0.5059, 0.4000, 0.5765, 0.4667,\n",
646
+ " 0.0000, 0.4706, 0.3529, 0.4667, 0.3961, 0.4549, 0.4157, 0.4980,\n",
647
+ " 0.4000, 0.4471, 0.5725, 0.7059, 0.0784, 0.5725, 0.6235, 0.5059,\n",
648
+ " 0.8000, 0.2745, 0.0000, 0.0000],\n",
649
+ " [0.0000, 0.0000, 0.0000, 0.2275, 0.4941, 0.4353, 0.6353, 0.3961,\n",
650
+ " 0.0824, 0.5176, 0.3490, 0.4824, 0.4235, 0.4157, 0.4000, 0.4941,\n",
651
+ " 0.4353, 0.4549, 0.5529, 0.6980, 0.1961, 0.4392, 0.6627, 0.5412,\n",
652
+ " 0.6431, 0.3294, 0.0000, 0.0000],\n",
653
+ " [0.0000, 0.0000, 0.0000, 0.4235, 0.5255, 0.5255, 0.7255, 0.3294,\n",
654
+ " 0.2863, 0.4824, 0.3412, 0.4784, 0.4353, 0.4000, 0.4157, 0.5020,\n",
655
+ " 0.4471, 0.4275, 0.5255, 0.6824, 0.3804, 0.3843, 0.6275, 0.5765,\n",
656
+ " 0.6863, 0.5294, 0.0000, 0.0000],\n",
657
+ " [0.0000, 0.0000, 0.0000, 0.3804, 0.5569, 0.6627, 0.7765, 0.1451,\n",
658
+ " 0.3294, 0.4196, 0.3804, 0.4784, 0.4392, 0.4275, 0.4392, 0.4941,\n",
659
+ " 0.4000, 0.3765, 0.5137, 0.6745, 0.5020, 0.2000, 0.9961, 0.6588,\n",
660
+ " 0.6431, 0.4353, 0.0000, 0.0000],\n",
661
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0471, 0.1804, 0.0078,\n",
662
+ " 0.4667, 0.4000, 0.4275, 0.4824, 0.3765, 0.4549, 0.4784, 0.5176,\n",
663
+ " 0.4157, 0.4157, 0.5059, 0.5922, 0.7216, 0.1020, 0.0784, 0.0314,\n",
664
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
665
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0510,\n",
666
+ " 0.5373, 0.3961, 0.4471, 0.3922, 0.4157, 0.5255, 0.5294, 0.5059,\n",
667
+ " 0.4078, 0.4353, 0.4824, 0.5922, 0.7608, 0.2902, 0.0000, 0.0000,\n",
668
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
669
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0118, 0.0000, 0.2863,\n",
670
+ " 0.5176, 0.3961, 0.4078, 0.4000, 0.5490, 0.4235, 0.4235, 0.5137,\n",
671
+ " 0.4157, 0.4667, 0.4431, 0.5569, 0.6549, 0.5294, 0.0000, 0.0039,\n",
672
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
673
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4392,\n",
674
+ " 0.4627, 0.4196, 0.4078, 0.5451, 0.4275, 0.3804, 0.4824, 0.5412,\n",
675
+ " 0.4196, 0.4980, 0.4706, 0.5333, 0.6314, 0.6235, 0.0000, 0.0000,\n",
676
+ " 0.0039, 0.0000, 0.0000, 0.0000],\n",
677
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0118, 0.0000, 0.5569,\n",
678
+ " 0.5804, 0.4392, 0.4118, 0.3961, 0.3255, 0.4902, 0.4824, 0.5608,\n",
679
+ " 0.4078, 0.4510, 0.3922, 0.4941, 0.6588, 0.6980, 0.0275, 0.0000,\n",
680
+ " 0.0078, 0.0000, 0.0000, 0.0000],\n",
681
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0078, 0.0000, 0.0353,\n",
682
+ " 0.4941, 0.7216, 0.7843, 0.6549, 0.6392, 0.6706, 0.5882, 0.6549,\n",
683
+ " 0.6118, 0.6824, 0.7725, 0.7137, 0.6353, 0.2392, 0.0000, 0.0000,\n",
684
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
685
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
686
+ " 0.0000, 0.0000, 0.1176, 0.2824, 0.3725, 0.4275, 0.4353, 0.4353,\n",
687
+ " 0.4157, 0.3961, 0.2784, 0.0471, 0.0000, 0.0000, 0.0000, 0.0000,\n",
688
+ " 0.0000, 0.0000, 0.0000, 0.0000]]])"
689
+ ]
690
+ },
691
+ "execution_count": 16,
692
+ "metadata": {},
693
+ "output_type": "execute_result"
694
+ }
695
+ ],
696
+ "source": [
697
+ "test_data[4][0]"
698
+ ]
699
+ },
700
+ {
701
+ "cell_type": "markdown",
702
+ "metadata": {},
703
+ "source": [
704
+ "Read more about `Saving & Loading your model <saveloadrun_tutorial.html>`_.\n",
705
+ "\n",
706
+ "\n"
707
+ ]
708
+ }
709
+ ],
710
+ "metadata": {
711
+ "kernelspec": {
712
+ "display_name": "Python 3",
713
+ "language": "python",
714
+ "name": "python3"
715
+ },
716
+ "language_info": {
717
+ "codemirror_mode": {
718
+ "name": "ipython",
719
+ "version": 3
720
+ },
721
+ "file_extension": ".py",
722
+ "mimetype": "text/x-python",
723
+ "name": "python",
724
+ "nbconvert_exporter": "python",
725
+ "pygments_lexer": "ipython3",
726
+ "version": "3.8.1"
727
+ }
728
+ },
729
+ "nbformat": 4,
730
+ "nbformat_minor": 0
731
+ }
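Once `pytorch_model.bin` is in the repo, the weights can be pulled back down and loaded into the same architecture. A sketch, assuming the file above has been pushed to `Ab0/foo-model` (the repo named in the warning) and that the notebook's `NeuralNetwork` class is available where this runs; `hf_hub_download` comes from the `huggingface_hub` library:

from huggingface_hub import hf_hub_download
import torch

# Download the renamed weight file from the Hub (cached locally on repeat calls).
weights_path = hf_hub_download(repo_id="Ab0/foo-model", filename="pytorch_model.bin")

model = NeuralNetwork()  # same architecture as defined in the notebook
model.load_state_dict(torch.load(weights_path, map_location="cpu"))
model.eval()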