{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 10,
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"\n",
"import plotly.express as px\n",
"from plotly.subplots import make_subplots\n",
"import numpy as np\n",
"\n",
"from dataset import DatasetMNIST, load_mnist\n",
"from trainer import LitTrainer\n",
"from models import CNN"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 9,
"outputs": [],
"source": [
"# Plain-PyTorch model saved as a whole module. torch.load unpickles the file,\n",
"# so only load checkpoints from a trusted source.\n",
"net = torch.load(\"../checkpoints/pytorch/version_0.pt\")\n",
"net.eval()\n",
"\n",
"# Lightning model restored from its checkpoint; the CNN(1, 10) instance is\n",
"# presumably forwarded to LitTrainer's constructor (confirm against trainer.py).\n",
"pl_net = LitTrainer.load_from_checkpoint(\n",
"    \"../checkpoints/lightning_logs/version_26/checkpoints/epoch=9-step=1000.ckpt\",\n",
"    model=CNN(1, 10),\n",
")\n",
"# Switch to inference mode as well, matching `net` above (matters for dropout /\n",
"# batch-norm layers). The trailing ';' suppresses the module repr in the output.\n",
"pl_net.eval();"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Read MNIST from disk and wrap its train / test splits in DatasetMNIST objects.\n",
"mnist = load_mnist(\"../downloads/mnist/\")\n",
"dataset = DatasetMNIST(*mnist[\"train\"])\n",
"test_data = DatasetMNIST(*mnist[\"test\"])"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 12,
"outputs": [],
"source": [
"def show_sequence(model, device=\"cuda\", threshold=0.95):\n",
"    \"\"\"Plot one confidently classified training image for each digit 0-9.\n",
"\n",
"    Starting from a random index into the global `dataset`, scans samples in order\n",
"    and, for digit = 0, 1, ..., 9 in turn, keeps the first sample that `model`\n",
"    predicts as that digit with softmax probability above `threshold`. The kept\n",
"    images are drawn as heatmaps in a 2x5 plotly subplot grid.\n",
"\n",
"    Args:\n",
"        model: callable mapping one sample to class scores. The softmax is taken\n",
"            over dim 0, so a 1-D output of 10 logits is assumed (TODO confirm).\n",
"        device: device each sample is moved to before the forward pass\n",
"            (default \"cuda\", as before).\n",
"        threshold: minimum softmax probability for a prediction to be accepted.\n",
"    \"\"\"\n",
"    fig = make_subplots(rows=2, cols=5)\n",
"\n",
"    n_samples = len(dataset)\n",
"    digit, start = 0, np.random.randint(0, 30000)\n",
"\n",
"    # Visit each sample at most once, wrapping around at the end, so the search can\n",
"    # neither index past the dataset nor loop forever when a digit is never found.\n",
"    for offset in range(n_samples):\n",
"        if digit == 10:\n",
"            break\n",
"        x, _ = dataset[(start + offset) % n_samples]\n",
"        # Inference only: no autograd graph is needed for these forward passes.\n",
"        with torch.no_grad():\n",
"            logits = model(x.to(device)).cpu()\n",
"        confidence = torch.max(nn.functional.softmax(logits, dim=0))\n",
"        predicted = int(torch.argmax(logits))\n",
"        if predicted == digit and confidence > threshold:\n",
"            # Vertical flip, presumably because a bare heatmap trace draws row 0 at the\n",
"            # bottom (px.imshow's reversed y-axis lives in its layout, not in data[0]).\n",
"            img = np.flip(np.array(x.reshape(28, 28)), 0)\n",
"            fig.add_trace(px.imshow(img).data[0], row=digit // 5 + 1, col=digit % 5 + 1)\n",
"            digit += 1\n",
"\n",
"    if digit < 10:\n",
"        print(f\"Stopped at digit {digit}: no sample predicted with confidence > {threshold}.\")\n",
"    fig.show()"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 13,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PyTorch Lightning Network\n"
]
},
{
"data": {
"text/html": " \n "
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.plotly.v1+json": {
"data": [
{
"coloraxis": "coloraxis",
"hovertemplate": "x: %{x}
y: %{y}
color: %{z}
y: %{y}
color: %{z}
y: %{y}
color: %{z}
y: %{y}
color: %{z}
y: %{y}
color: %{z}
y: %{y}
color: %{z}
y: %{y}
color: %{z}
y: %{y}
color: %{z}
y: %{y}
color: %{z}
y: %{y}
color: %{z}