{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "`Learn the Basics `_ ||\n", "**Quickstart** ||\n", "`Tensors `_ ||\n", "`Datasets & DataLoaders `_ ||\n", "`Transforms `_ ||\n", "`Build Model `_ ||\n", "`Autograd `_ ||\n", "`Optimization `_ ||\n", "`Save & Load Model `_\n", "\n", "Quickstart\n", "===================\n", "This section runs through the API for common tasks in machine learning. Refer to the links in each section to dive deeper.\n", "\n", "Working with data\n", "-----------------\n", "PyTorch has two `primitives to work with data `_:\n", "``torch.utils.data.DataLoader`` and ``torch.utils.data.Dataset``.\n", "``Dataset`` stores the samples and their corresponding labels, and ``DataLoader`` wraps an iterable around\n", "the ``Dataset``.\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets\n", "from torchvision.transforms import ToTensor, Lambda, Compose\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "PyTorch offers domain-specific libraries such as `TorchText `_,\n", "`TorchVision `_, and `TorchAudio `_,\n", "all of which include datasets. For this tutorial, we will be using a TorchVision dataset.\n", "\n", "The ``torchvision.datasets`` module contains ``Dataset`` objects for many real-world vision data like\n", "CIFAR, COCO (`full list here `_). In this tutorial, we\n", "use the FashionMNIST dataset. 
Every TorchVision ``Dataset`` includes two arguments: ``transform`` and\n", "``target_transform`` to modify the samples and labels respectively.\n", "\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Download training data from open datasets.\n", "training_data = datasets.FashionMNIST(\n", " root=\"data\",\n", " train=True,\n", " download=True,\n", " transform=ToTensor(),\n", ")\n", "\n", "# Download test data from open datasets.\n", "test_data = datasets.FashionMNIST(\n", " root=\"data\",\n", " train=False,\n", " download=True,\n", " transform=ToTensor(),\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We pass the ``Dataset`` as an argument to ``DataLoader``. This wraps an iterable over our dataset, and supports\n", "automatic batching, sampling, shuffling and multiprocess data loading. Here we define a batch size of 64, i.e. each element\n", "in the dataloader iterable will return a batch of 64 features and labels.\n", "\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])\n", "Shape of y: torch.Size([64]) torch.int64\n" ] } ], "source": [ "batch_size = 64\n", "\n", "# Create data loaders.\n", "train_dataloader = DataLoader(training_data, batch_size=batch_size)\n", "test_dataloader = DataLoader(test_data, batch_size=batch_size)\n", "\n", "for X, y in test_dataloader:\n", " print(\"Shape of X [N, C, H, W]: \", X.shape)\n", " print(\"Shape of y: \", y.shape, y.dtype)\n", " break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Read more about `loading data in PyTorch `_.\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "--------------\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Creating Models\n", "------------------\n", "To define a neural network in 
PyTorch, we create a class that inherits\n", "from `nn.Module `_. We define the layers of the network\n", "in the ``__init__`` function and specify how data will pass through the network in the ``forward`` function. To accelerate\n", "operations in the neural network, we move it to the GPU if available.\n", "\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using cpu device\n", "NeuralNetwork(\n", "  (flatten): Flatten(start_dim=1, end_dim=-1)\n", "  (linear_relu_stack): Sequential(\n", "    (0): Linear(in_features=784, out_features=512, bias=True)\n", "    (1): ReLU()\n", "    (2): Linear(in_features=512, out_features=512, bias=True)\n", "    (3): ReLU()\n", "    (4): Linear(in_features=512, out_features=10, bias=True)\n", "  )\n", ")\n" ] } ], "source": [ "# Get cpu or gpu device for training.\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "print(f\"Using {device} device\")\n", "\n", "# Define model\n", "class NeuralNetwork(nn.Module):\n", "    def __init__(self):\n", "        super(NeuralNetwork, self).__init__()\n", "        self.flatten = nn.Flatten()\n", "        self.linear_relu_stack = nn.Sequential(\n", "            nn.Linear(28*28, 512),\n", "            nn.ReLU(),\n", "            nn.Linear(512, 512),\n", "            nn.ReLU(),\n", "            nn.Linear(512, 10)\n", "        )\n", "\n", "    def forward(self, x):\n", "        x = self.flatten(x)\n", "        logits = self.linear_relu_stack(x)\n", "        return logits\n", "\n", "model = NeuralNetwork().to(device)\n", "print(model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Read more about `building neural networks in PyTorch `_.\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "--------------\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Optimizing the Model Parameters\n", "----------------------------------------\n", "To train a model, we need a `loss function `_\n", "and an `optimizer `_.\n", "\n" ] }, { "cell_type": 
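"markdown", "metadata": {}, "source": [ "``nn.CrossEntropyLoss`` takes raw, unnormalized scores (logits) together with integer class indices, so the ``logits`` returned by ``NeuralNetwork`` can be passed to it directly, without a softmax.\n", "The following is a minimal sketch, using random made-up logits rather than this model's output, that checks the loss equals ``log_softmax`` followed by the negative log-likelihood loss (``F.nll_loss``).\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "import torch.nn.functional as F\n", "\n", "# Made-up logits for a batch of 3 samples over 10 classes (illustration only)\n", "logits = torch.randn(3, 10)\n", "# Class indices as int64, the same dtype as the FashionMNIST labels\n", "targets = torch.tensor([6, 0, 9])\n", "\n", "ce = nn.CrossEntropyLoss()(logits, targets)\n", "# Same value by hand: log-softmax over the class dimension, then the mean negative log-likelihood\n", "manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)\n", "print(ce.item(), manual.item())\n", "assert torch.allclose(ce, manual)" ] }, { "cell_type": 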
"code", "execution_count": 6, "metadata": { "collapsed": false }, "outputs": [], "source": [ "loss_fn = nn.CrossEntropyLoss()\n", "optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In a single training loop, the model makes predictions on the training dataset (fed to it in batches), and\n", "backpropagates the prediction error to adjust the model's parameters.\n", "\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def train(dataloader, model, loss_fn, optimizer):\n", "    size = len(dataloader.dataset)\n", "    model.train()\n", "    for batch, (X, y) in enumerate(dataloader):\n", "        X, y = X.to(device), y.to(device)\n", "\n", "        # Compute prediction error\n", "        pred = model(X)\n", "        loss = loss_fn(pred, y)\n", "\n", "        # Backpropagation\n", "        optimizer.zero_grad()\n", "        loss.backward()\n", "        optimizer.step()\n", "\n", "        if batch % 100 == 0:\n", "            loss, current = loss.item(), batch * len(X)\n", "            print(f\"loss: {loss:>7f} [{current:>5d}/{size:>5d}]\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We also check the model's performance against the test dataset to ensure it is learning.\n", "\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def test(dataloader, model, loss_fn):\n", "    size = len(dataloader.dataset)\n", "    num_batches = len(dataloader)\n", "    model.eval()\n", "    test_loss, correct = 0, 0\n", "    with torch.no_grad():\n", "        for X, y in dataloader:\n", "            X, y = X.to(device), y.to(device)\n", "            pred = model(X)\n", "            test_loss += loss_fn(pred, y).item()\n", "            correct += (pred.argmax(1) == y).type(torch.float).sum().item()\n", "    test_loss /= num_batches\n", "    correct /= size\n", "    print(f\"Test Error: \\n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \\n\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The training process is 
conducted over several iterations (*epochs*). During each epoch, the model learns\n", "parameters to make better predictions. We print the model's accuracy and loss at each epoch; we'd like to see the\n", "accuracy increase and the loss decrease with every epoch.\n", "\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1\n", "-------------------------------\n", "loss: 2.293067 [ 0/60000]\n", "loss: 2.287422 [ 6400/60000]\n", "loss: 2.265790 [12800/60000]\n", "loss: 2.274793 [19200/60000]\n", "loss: 2.257332 [25600/60000]\n", "loss: 2.222204 [32000/60000]\n", "loss: 2.240200 [38400/60000]\n", "loss: 2.206084 [44800/60000]\n", "loss: 2.190236 [51200/60000]\n", "loss: 2.176934 [57600/60000]\n", "Test Error: \n", " Accuracy: 42.4%, Avg loss: 2.162450 \n", "\n", "Epoch 2\n", "-------------------------------\n", "loss: 2.161891 [ 0/60000]\n", "loss: 2.160867 [ 6400/60000]\n", "loss: 2.099223 [12800/60000]\n", "loss: 2.127940 [19200/60000]\n", "loss: 2.089684 [25600/60000]\n", "loss: 2.018054 [32000/60000]\n", "loss: 2.060461 [38400/60000]\n", "loss: 1.981958 [44800/60000]\n", "loss: 1.971331 [51200/60000]\n", "loss: 1.930486 [57600/60000]\n", "Test Error: \n", " Accuracy: 58.1%, Avg loss: 1.909495 \n", "\n", "Epoch 3\n", "-------------------------------\n", "loss: 1.930542 [ 0/60000]\n", "loss: 1.913976 [ 6400/60000]\n", "loss: 1.788895 [12800/60000]\n", "loss: 1.838503 [19200/60000]\n", "loss: 1.757226 [25600/60000]\n", "loss: 1.682464 [32000/60000]\n", "loss: 1.722755 [38400/60000]\n", "loss: 1.617113 [44800/60000]\n", "loss: 1.632282 [51200/60000]\n", "loss: 1.548769 [57600/60000]\n", "Test Error: \n", " Accuracy: 61.0%, Avg loss: 1.543196 \n", "\n", "Epoch 4\n", "-------------------------------\n", "loss: 1.601020 [ 0/60000]\n", "loss: 1.574128 [ 6400/60000]\n", "loss: 1.412696 [12800/60000]\n", "loss: 1.496537 [19200/60000]\n", "loss: 1.391789 
[25600/60000]\n", "loss: 1.360881 [32000/60000]\n", "loss: 1.398112 [38400/60000]\n", "loss: 1.316551 [44800/60000]\n", "loss: 1.347136 [51200/60000]\n", "loss: 1.253991 [57600/60000]\n", "Test Error: \n", " Accuracy: 62.8%, Avg loss: 1.267020 \n", "\n", "Epoch 5\n", "-------------------------------\n", "loss: 1.336873 [ 0/60000]\n", "loss: 1.324502 [ 6400/60000]\n", "loss: 1.153551 [12800/60000]\n", "loss: 1.265215 [19200/60000]\n", "loss: 1.149221 [25600/60000]\n", "loss: 1.156962 [32000/60000]\n", "loss: 1.194912 [38400/60000]\n", "loss: 1.133846 [44800/60000]\n", "loss: 1.164861 [51200/60000]\n", "loss: 1.080542 [57600/60000]\n", "Test Error: \n", " Accuracy: 64.1%, Avg loss: 1.094896 \n", "\n", "Done!\n" ] } ], "source": [ "epochs = 5\n", "for t in range(epochs):\n", "    print(f\"Epoch {t+1}\\n-------------------------------\")\n", "    train(train_dataloader, model, loss_fn, optimizer)\n", "    test(test_dataloader, model, loss_fn)\n", "print(\"Done!\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Read more about `Training your model `_.\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "--------------\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Saving Models\n", "-------------\n", "A common way to save a model is to serialize the internal state dictionary (containing the model parameters).\n", "\n" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Saved PyTorch Model State to pytorch_model.bin\n" ] } ], "source": [ "torch.save(model.state_dict(), \"pytorch_model.bin\")\n", "print(\"Saved PyTorch Model State to pytorch_model.bin\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Loading Models\n", "----------------------------\n", "\n", "The process for loading a model includes re-creating the model structure and loading\n", "the state dictionary into it.\n", "\n" ] }, { "cell_type": 
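"markdown", "metadata": {}, "source": [ "The same save and load round trip can be checked without touching the disk. The following is a minimal sketch, assuming an in-memory ``io.BytesIO`` buffer in place of the file and a small hypothetical ``make_model`` helper in place of ``NeuralNetwork``; it re-creates the structure, loads the saved state into it, and confirms every tensor matches.\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import io\n", "import torch\n", "from torch import nn\n", "\n", "def make_model():\n", "    # Hypothetical stand-in architecture; any nn.Module round-trips the same way\n", "    return nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))\n", "\n", "original = make_model()\n", "buffer = io.BytesIO()  # in-memory stand-in for the file on disk\n", "torch.save(original.state_dict(), buffer)\n", "buffer.seek(0)\n", "\n", "restored = make_model()  # re-create the structure first\n", "restored.load_state_dict(torch.load(buffer))  # then load the saved state into it\n", "\n", "# Compare every saved tensor with its reloaded counterpart, matched by name\n", "restored_state = restored.state_dict()\n", "assert all(torch.equal(t, restored_state[k]) for k, t in original.state_dict().items())\n", "print(\"Round trip OK\")" ] }, { "cell_type": 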
"code", "execution_count": 13, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "<All keys matched successfully>" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = NeuralNetwork()\n", "model.load_state_dict(torch.load(\"pytorch_model.bin\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This model can now be used to make predictions.\n", "\n" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicted: \"Shirt\", Actual: \"Shirt\"\n" ] } ], "source": [ "classes = [\n", "    \"T-shirt/top\",\n", "    \"Trouser\",\n", "    \"Pullover\",\n", "    \"Dress\",\n", "    \"Coat\",\n", "    \"Sandal\",\n", "    \"Shirt\",\n", "    \"Sneaker\",\n", "    \"Bag\",\n", "    \"Ankle boot\",\n", "]\n", "\n", "model.eval()\n", "x, y = test_data[4][0], test_data[4][1]\n", "with torch.no_grad():\n", "    pred = model(x)\n", "    predicted, actual = classes[pred[0].argmax(0)], classes[y]\n", "    print(f'Predicted: \"{predicted}\", Actual: \"{actual}\"')" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[[0.0000, 0.0000, 0.0000, 0.0078, 0.0000, 0.0039, 0.0039, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.2235, 0.2627, 0.2863, 0.2980, 0.2980,\n", " 0.3255, 0.2431, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0039, 0.0039, 0.0039, 0.0000, 0.0000,\n", " 0.0510, 0.3098, 0.5020, 0.7882, 0.6353, 0.6314, 0.6784, 0.7529,\n", " 0.6745, 0.7098, 0.7216, 0.4235, 0.1176, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000, 0.4000,\n", " 0.5451, 0.5569, 0.4039, 0.4510, 0.6353, 0.6039, 0.6471, 0.6000,\n", " 0.5451, 0.5059, 0.5882, 0.5412, 0.6706, 0.6314, 0.1020, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000, 0.4157, 
0.4863,\n", " 0.4235, 0.4039, 0.4157, 0.3647, 0.3922, 0.7059, 0.6118, 0.5765,\n", " 0.5412, 0.3333, 0.6157, 0.4471, 0.4863, 0.6039, 0.6157, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0078, 0.0000, 0.1137, 0.5255, 0.3961,\n", " 0.4431, 0.4235, 0.3804, 0.4549, 0.3176, 0.5725, 0.7176, 0.6431,\n", " 0.4353, 0.5725, 0.5137, 0.4784, 0.5176, 0.5686, 0.6627, 0.3647,\n", " 0.0000, 0.0039, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2549, 0.5137, 0.4118,\n", " 0.3961, 0.4235, 0.3922, 0.4078, 0.3804, 0.2902, 0.8078, 0.6824,\n", " 0.4510, 0.5882, 0.4235, 0.4667, 0.5725, 0.5961, 0.6353, 0.5529,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4235, 0.4824, 0.4392,\n", " 0.4157, 0.3843, 0.3922, 0.3961, 0.4353, 0.2824, 0.5333, 0.5176,\n", " 0.4392, 0.4510, 0.4275, 0.5569, 0.5882, 0.6275, 0.6353, 0.7647,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5294, 0.4784, 0.4667,\n", " 0.4392, 0.3255, 0.3647, 0.3804, 0.4157, 0.4510, 0.3569, 0.4275,\n", " 0.3255, 0.4275, 0.4902, 0.6471, 0.5490, 0.7569, 0.6275, 0.6902,\n", " 0.0235, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0902, 0.5294, 0.5176, 0.5843,\n", " 0.4078, 0.3059, 0.3765, 0.3804, 0.4039, 0.4235, 0.4235, 0.4510,\n", " 0.3294, 0.4471, 0.5843, 0.6196, 0.5765, 0.8196, 0.6275, 0.6980,\n", " 0.2039, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.2235, 0.4863, 0.5137, 0.6275,\n", " 0.4039, 0.3765, 0.3961, 0.4275, 0.4275, 0.4353, 0.4235, 0.4471,\n", " 0.4157, 0.4431, 0.6118, 0.6392, 0.6118, 0.7686, 0.6549, 0.6824,\n", " 0.3333, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.3373, 0.4549, 0.4941, 0.6275,\n", " 0.5176, 0.4000, 0.3765, 0.4078, 0.4196, 0.3843, 0.3647, 0.4824,\n", " 0.4549, 0.4392, 0.5843, 0.6275, 0.7098, 0.7294, 0.6353, 0.6353,\n", " 0.4824, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.4392, 0.4471, 0.4392, 
0.6549,\n", " 0.5725, 0.3922, 0.3922, 0.3961, 0.4196, 0.3765, 0.3922, 0.4941,\n", " 0.4039, 0.4706, 0.5529, 0.6196, 0.6549, 0.7333, 0.5765, 0.5804,\n", " 0.6667, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.4863, 0.4627, 0.3961, 0.7725,\n", " 0.3490, 0.3961, 0.3922, 0.3765, 0.4235, 0.4039, 0.4235, 0.4784,\n", " 0.4196, 0.4980, 0.5451, 0.5882, 0.4667, 0.7686, 0.5686, 0.5569,\n", " 0.7020, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.5137, 0.4510, 0.3804, 0.7765,\n", " 0.1843, 0.4235, 0.3765, 0.3765, 0.4157, 0.4667, 0.4000, 0.4706,\n", " 0.4039, 0.4824, 0.5490, 0.5882, 0.3176, 0.8078, 0.5725, 0.5294,\n", " 0.7608, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0157, 0.5333, 0.4627, 0.3843, 0.7569,\n", " 0.0824, 0.4275, 0.3765, 0.4157, 0.4000, 0.5059, 0.3922, 0.4667,\n", " 0.4000, 0.4627, 0.5529, 0.6000, 0.1765, 0.8471, 0.5804, 0.5451,\n", " 0.8039, 0.0471, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0941, 0.5373, 0.4588, 0.3961, 0.7333,\n", " 0.0980, 0.4431, 0.3608, 0.4392, 0.3686, 0.4706, 0.4118, 0.4980,\n", " 0.3804, 0.4510, 0.5569, 0.5882, 0.0745, 0.8353, 0.5804, 0.5137,\n", " 0.8000, 0.1412, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.1569, 0.5529, 0.4275, 0.4588, 0.6196,\n", " 0.0471, 0.4863, 0.3529, 0.4549, 0.3765, 0.4588, 0.4431, 0.5333,\n", " 0.3686, 0.4353, 0.5765, 0.6392, 0.1216, 0.7490, 0.5725, 0.5255,\n", " 0.8078, 0.2275, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.1529, 0.5059, 0.4000, 0.5765, 0.4667,\n", " 0.0000, 0.4706, 0.3529, 0.4667, 0.3961, 0.4549, 0.4157, 0.4980,\n", " 0.4000, 0.4471, 0.5725, 0.7059, 0.0784, 0.5725, 0.6235, 0.5059,\n", " 0.8000, 0.2745, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.2275, 0.4941, 0.4353, 0.6353, 0.3961,\n", " 0.0824, 0.5176, 0.3490, 0.4824, 0.4235, 0.4157, 0.4000, 0.4941,\n", " 0.4353, 0.4549, 0.5529, 0.6980, 0.1961, 0.4392, 0.6627, 0.5412,\n", " 0.6431, 0.3294, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.4235, 0.5255, 0.5255, 0.7255, 
0.3294,\n", " 0.2863, 0.4824, 0.3412, 0.4784, 0.4353, 0.4000, 0.4157, 0.5020,\n", " 0.4471, 0.4275, 0.5255, 0.6824, 0.3804, 0.3843, 0.6275, 0.5765,\n", " 0.6863, 0.5294, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.3804, 0.5569, 0.6627, 0.7765, 0.1451,\n", " 0.3294, 0.4196, 0.3804, 0.4784, 0.4392, 0.4275, 0.4392, 0.4941,\n", " 0.4000, 0.3765, 0.5137, 0.6745, 0.5020, 0.2000, 0.9961, 0.6588,\n", " 0.6431, 0.4353, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0471, 0.1804, 0.0078,\n", " 0.4667, 0.4000, 0.4275, 0.4824, 0.3765, 0.4549, 0.4784, 0.5176,\n", " 0.4157, 0.4157, 0.5059, 0.5922, 0.7216, 0.1020, 0.0784, 0.0314,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0510,\n", " 0.5373, 0.3961, 0.4471, 0.3922, 0.4157, 0.5255, 0.5294, 0.5059,\n", " 0.4078, 0.4353, 0.4824, 0.5922, 0.7608, 0.2902, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0118, 0.0000, 0.2863,\n", " 0.5176, 0.3961, 0.4078, 0.4000, 0.5490, 0.4235, 0.4235, 0.5137,\n", " 0.4157, 0.4667, 0.4431, 0.5569, 0.6549, 0.5294, 0.0000, 0.0039,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4392,\n", " 0.4627, 0.4196, 0.4078, 0.5451, 0.4275, 0.3804, 0.4824, 0.5412,\n", " 0.4196, 0.4980, 0.4706, 0.5333, 0.6314, 0.6235, 0.0000, 0.0000,\n", " 0.0039, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0118, 0.0000, 0.5569,\n", " 0.5804, 0.4392, 0.4118, 0.3961, 0.3255, 0.4902, 0.4824, 0.5608,\n", " 0.4078, 0.4510, 0.3922, 0.4941, 0.6588, 0.6980, 0.0275, 0.0000,\n", " 0.0078, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0078, 0.0000, 0.0353,\n", " 0.4941, 0.7216, 0.7843, 0.6549, 0.6392, 0.6706, 0.5882, 0.6549,\n", " 0.6118, 0.6824, 0.7725, 0.7137, 0.6353, 0.2392, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 
0.0000,\n", " 0.0000, 0.0000, 0.1176, 0.2824, 0.3725, 0.4275, 0.4353, 0.4353,\n", " 0.4157, 0.3961, 0.2784, 0.0471, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000]]])" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_data[4][0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Read more about `Saving & Loading your model `_.\n", "\n", "\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.1" } }, "nbformat": 4, "nbformat_minor": 0 }