{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Using the produced models" ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "# import sys\n", "# sys.path.append('..')\n", "from pytorch_utils import *\n", "from lightning_utils import *\n", "from pytorch_vision_utils import *\n", "%load_ext autoreload\n", "%autoreload 2" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Root folder that holds the `test` images and the `Models` folder with the saved parameters.\n", "# Set the KAGGLE_CARDS_DATA_PATH environment variable to point elsewhere; the default is the original location.\n", "data_path = os.environ.get('KAGGLE_CARDS_DATA_PATH', r'E:\\Data_and_Models\\Kaggle_Cards')\n", "\n", "# Use the GPU when PyTorch can see one, otherwise fall back to the CPU\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Reconstructing the models" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Timings (on GTX 970) and last-epoch (therefore not necessarily best) loss and F1:\n", "\n", "| **Model** | **Retrained Portion** | **Epochs** | **Time** | **Epoch Time** | **train_loss** | **train_metric** | **test_loss** | **test_metric** |\n", "| ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- |\n", "| **EfficientNet B0** | Classification | 10 | 7:43 | 0:25 | 1.4660 | 0.5907 | 1.8822 | 0.4228 |\n", "| **EfficientNet B2** | Classification | 10 | 21:50 | 2:11 | 1.4443 | 0.6009 | 1.8505 | 0.4279 |\n", "| **RexNet 1.0** | Classification | 10 | 7:21 | 0:44 | 1.5509 | 0.5605 | 2.0055 | 0.4186 |\n", "| **RexNet 1.5** | Classification | 10 | 8:00 | 0:48 | 0.8666 | 0.7560 | 1.5839 | 0.4884 |\n", "| **RexNet 1.0** | Full Retrain | 10 | 30:49 | 3:04 | 0.0903 | 0.9732 | 0.2218 | 0.9390 |\n", "| **RexNet 1.0** | Full Retrain | val_loss early stop at 5 (equalling 2)| 16:32 | 2:40 | 0.2707 | 0.8750 | 0.0996 | 1 |\n", "| **RexNet 1.0** | Classification | val_loss early stop but all 10 (7 selected) | 6:55 | 0:38 | 2.564 | 0.1250 | 2.9857 | 0.2222 |\n", "| **RexNet 1.0 features -> LightGBM** | No retraining; feature extraction -> GB | 50 bagging gbdt iterations (OpenCL, not CUDA) | 1:08 | 1:53 to extract train and validation 
features | | | | 0.4188 |\n", "| **RexNet 1.0 features -> LightGBM** | No retraining; feature extraction -> GB | 100 bagging gbdt iterations (OpenCL, not CUDA) | 2:40 | 1:53 to extract train and validation features | | | | 0.4679 |\n", "| **RexNet 1.5 features -> LightGBM** | No retraining; feature extraction -> GB | 100 bagging gbdt iterations (OpenCL, not CUDA) | 3:37 | 1:50 to extract train and validation features | | | | 0.5433 |" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Recreate the model and load the saved parameters\n", "\n", "# Class names are the sub-folder names of the test set. They are sorted, and plain files are skipped,\n", "# so that index i always maps to the same label regardless of the directory enumeration order.\n", "# torchvision's ImageFolder also numbers classes in sorted folder order\n", "# (assumes training followed that convention - TODO confirm).\n", "classes = sorted(p.name for p in (Path(data_path) / 'test').iterdir() if p.is_dir())\n", "\n", "experiment_name = ''\n", "\n", "# # EfficientNet B0\n", "# model_name, extra = 'EffNetB0', '0_First_Adam001_10_epochs'\n", "# model = tv.models.efficientnet_b0(weights = (weights := tv.models.EfficientNet_B0_Weights.DEFAULT)).to(device)\n", "# transforms = weights.transforms()\n", "# for param in model.features.parameters(): param.requires_grad = False\n", "# model.classifier = torch.nn.Sequential(\n", "# torch.nn.Dropout(p = 0.2, inplace = True),\n", "# torch.nn.Linear(in_features = 1280, out_features = len(classes), bias = True)\n", "# ).to(device)\n", "\n", "# # EfficientNet B2 (the B2 weights and the 1408-feature head need the b2 architecture, not b0)\n", "# model_name, extra = 'EffNetB2', '0_First_Adam001_10_epochs'\n", "# model = tv.models.efficientnet_b2(weights = (weights := tv.models.EfficientNet_B2_Weights.DEFAULT)).to(device)\n", "# transforms = weights.transforms()\n", "# for param in model.features.parameters(): param.requires_grad = False\n", "# model.classifier = torch.nn.Sequential(\n", "# torch.nn.Dropout(p = 0.3, inplace = True),\n", "# torch.nn.Linear(in_features = 1408, out_features = len(classes), bias = True)\n", "# ).to(device)\n", "\n", "# RexNet 1.0\n", "# model_name, extra = 'RexNet10', '0_First_Adam001_10_epochs'\n", "experiment_name, model_name, extra = 'FullRetrain_EarlyStop', 'RexNet10', 'Adam001_max10_epochs'\n", "# experiment_name, model_name, extra = 'ClassRetrain_EarlyStop', 'RexNet10', 'Adam001_max10_epochs'\n", "# The head size follows the class list, as in the EfficientNet variants above\n", "model = timm.create_model('rexnet_100.nav_in1k', pretrained = True, num_classes = len(classes)).eval().to(device)\n", "transforms = timm.data.create_transform(**timm.data.resolve_model_data_config(model), is_training = False)\n", "for param in model.features.parameters(): param.requires_grad = False\n", "for param in model.stem.parameters(): param.requires_grad = False\n", "\n", "# # RexNet 1.5\n", "# model_name, extra = 'RexNet15', '0_First_Adam001_10_epochs'\n", "# model = timm.create_model('rexnet_150.nav_in1k', pretrained = True, num_classes = len(classes)).eval().to(device)\n", "# transforms = timm.data.create_transform(**timm.data.resolve_model_data_config(model), is_training = False)\n", "# for param in model.features.parameters(): param.requires_grad = False\n", "# for param in model.stem.parameters(): param.requires_grad = False\n", "\n", "# model.classifier\n", "\n", "# Build the file path with pathlib so it works on any OS; map_location lets parameters that were\n", "# saved from the GPU load on a CPU-only machine as well\n", "checkpoint_path = Path(data_path) / 'Models' / f'{experiment_name}_{model_name}_{extra}.pth'\n", "model.load_state_dict(torch.load(checkpoint_path, map_location = device))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### PyTorch Lightning Models (can just import as non-lightning above though)" ] },
{ "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "experiment_name, model_name, extra = 'FullRetrain_EarlyStop', 'RexNet10', 'Adam001_max10_epochs'\n", "# experiment_name, model_name, extra = 'ClassRetrain_EarlyStop', 'RexNet10', 'Adam001_max10_epochs'\n", "\n", "# Number of card classes; it has to match the classification head stored in the checkpoint\n", "num_classes = 53\n", "\n", "model = timm.create_model('rexnet_100.nav_in1k', pretrained = True, num_classes = num_classes).eval()\n", "# for param in model.features.parameters(): param.requires_grad = False\n", "# for param in model.stem.parameters(): param.requires_grad = False\n", "\n", "transforms = timm.data.create_transform(**timm.data.resolve_model_data_config(model), is_training = False)\n", "\n", "loss_fn = nn.CrossEntropyLoss()\n", "\n", "# Define an extra metric beside the loss\n", "f1_fn = torchmetrics.F1Score(task = 'multiclass', num_classes = num_classes)\n", "accuracy_fn = torchmetrics.Accuracy(task = 'multiclass', num_classes = num_classes)\n", "\n", "# Individual prediction input pipeline (though better through automated batches)\n", "# with torch.inference_mode(): pred_logit = model(transforms(img).unsqueeze(dim = 0).to(device)) # Prepend \"batch\" dimension (-> [batch_size, color_channels, height, width])\n", "\n", "def prediction_fn(logits):\n", "    '''Return the predicted class index for every row of a [batch_size, num_classes] logits tensor.'''\n", "    # softmax is strictly increasing, so the argmax of the raw logits is the same class\n", "    return torch.argmax(logits, dim = 1)\n", "\n", "\n", "## The two below imports are for .pth vs .ckpt saved files;\n", "## in both cases the parameters are only required for the definition and are not actually used\n", "## Again, though, there is no real need to import a pth into the Lightning-wrapped (i.e. Strike) version of the model,\n", "## as .load_state_dict on the wrapped model type works just fine\n", "\n", "# # .pth\n", "# model = Strike(model, loss_fn = loss_fn, metric_name_and_fn = ('F1', f1_fn),\n", "# optimiser_factory = lambda m: torch.optim.Adam(m.parameters(), lr = m.learning_rate),\n", "# prediction_fn = prediction_fn, learning_rate = 0.001, log_at_every_step = False)\n", "# state_dict = torch.load(Path(data_path) / 'Models' / f'{experiment_name}_{model_name}_{extra}.pth', map_location = device)\n", "# model.load_state_dict(state_dict)\n", "\n", "# .ckpt\n", "# Need to give arguments again since many are non-pickleable with .save_hyperparameters\n", "# Could solve by moving those ones to a function producing the class\n", "# map_location (a standard LightningModule.load_from_checkpoint argument; assumes Strike does not override it)\n", "# lets a checkpoint that was saved from the GPU load on a CPU-only machine as well\n", "checkpoint_path = Path(data_path) / 'Models' / f'{experiment_name}_{model_name}_{extra}.ckpt'\n", "model = Strike.load_from_checkpoint(checkpoint_path, map_location = device,\n", "    model = model, loss_fn = loss_fn, metric_name_and_fn = ('F1', f1_fn),\n", "    optimiser_factory = lambda m: torch.optim.Adam(m.parameters(), lr = m.learning_rate),\n", "    prediction_fn = prediction_fn, learning_rate = 0.001, log_at_every_step = False)" ] },
{ "cell_type": "code", "execution_count": 4, 
"metadata": {}, "outputs": [ { "data": { "text/plain": [ "=============================================================================================================================\n", "Layer (type (var_name)) Input Shape Output Shape Param # Trainable\n", "=============================================================================================================================\n", "RexNet (RexNet) [32, 3, 224, 224] [32, 53] -- Partial\n", "├─ConvNormAct (stem) [32, 3, 224, 224] [32, 32, 112, 112] -- False\n", "│ └─Conv2d (conv) [32, 3, 224, 224] [32, 32, 112, 112] (864) False\n", "│ └─BatchNormAct2d (bn) [32, 32, 112, 112] [32, 32, 112, 112] 64 False\n", "│ │ └─Identity (drop) [32, 32, 112, 112] [32, 32, 112, 112] -- --\n", "│ │ └─SiLU (act) [32, 32, 112, 112] [32, 32, 112, 112] -- --\n", "├─Sequential (features) [32, 32, 112, 112] [32, 1280, 7, 7] -- False\n", "│ └─LinearBottleneck (0) [32, 32, 112, 112] [32, 16, 112, 112] -- False\n", "│ │ └─ConvNormAct (conv_dw) [32, 32, 112, 112] [32, 32, 112, 112] (352) False\n", "│ │ └─ReLU6 (act_dw) [32, 32, 112, 112] [32, 32, 112, 112] -- --\n", "│ │ └─ConvNormAct (conv_pwl) [32, 32, 112, 112] [32, 16, 112, 112] (544) False\n", "│ └─LinearBottleneck (1) [32, 16, 112, 112] [32, 27, 56, 56] -- False\n", "│ │ └─ConvNormAct (conv_exp) [32, 16, 112, 112] [32, 96, 112, 112] (1,728) False\n", "│ │ └─ConvNormAct (conv_dw) [32, 96, 112, 112] [32, 96, 56, 56] (1,056) False\n", "│ │ └─ReLU6 (act_dw) [32, 96, 56, 56] [32, 96, 56, 56] -- --\n", "│ │ └─ConvNormAct (conv_pwl) [32, 96, 56, 56] [32, 27, 56, 56] (2,646) False\n", "│ └─LinearBottleneck (2) [32, 27, 56, 56] [32, 38, 56, 56] -- False\n", "│ │ └─ConvNormAct (conv_exp) [32, 27, 56, 56] [32, 162, 56, 56] (4,698) False\n", "│ │ └─ConvNormAct (conv_dw) [32, 162, 56, 56] [32, 162, 56, 56] (1,782) False\n", "│ │ └─ReLU6 (act_dw) [32, 162, 56, 56] [32, 162, 56, 56] -- --\n", "│ │ └─ConvNormAct (conv_pwl) [32, 162, 56, 56] [32, 38, 56, 56] (6,232) False\n", "│ └─LinearBottleneck 
(3) [32, 38, 56, 56] [32, 50, 28, 28] -- False\n", "│ │ └─ConvNormAct (conv_exp) [32, 38, 56, 56] [32, 228, 56, 56] (9,120) False\n", "│ │ └─ConvNormAct (conv_dw) [32, 228, 56, 56] [32, 228, 28, 28] (2,508) False\n", "│ │ └─SEModule (se) [32, 228, 28, 28] [32, 228, 28, 28] (8,949) False\n", "│ │ └─ReLU6 (act_dw) [32, 228, 28, 28] [32, 228, 28, 28] -- --\n", "│ │ └─ConvNormAct (conv_pwl) [32, 228, 28, 28] [32, 50, 28, 28] (11,500) False\n", "│ └─LinearBottleneck (4) [32, 50, 28, 28] [32, 61, 28, 28] -- False\n", "│ │ └─ConvNormAct (conv_exp) [32, 50, 28, 28] [32, 300, 28, 28] (15,600) False\n", "│ │ └─ConvNormAct (conv_dw) [32, 300, 28, 28] [32, 300, 28, 28] (3,300) False\n", "│ │ └─SEModule (se) [32, 300, 28, 28] [32, 300, 28, 28] (15,375) False\n", "│ │ └─ReLU6 (act_dw) [32, 300, 28, 28] [32, 300, 28, 28] -- --\n", "│ │ └─ConvNormAct (conv_pwl) [32, 300, 28, 28] [32, 61, 28, 28] (18,422) False\n", "│ └─LinearBottleneck (5) [32, 61, 28, 28] [32, 72, 14, 14] -- False\n", "│ │ └─ConvNormAct (conv_exp) [32, 61, 28, 28] [32, 366, 28, 28] (23,058) False\n", "│ │ └─ConvNormAct (conv_dw) [32, 366, 28, 28] [32, 366, 14, 14] (4,026) False\n", "│ │ └─SEModule (se) [32, 366, 14, 14] [32, 366, 14, 14] (22,416) False\n", "│ │ └─ReLU6 (act_dw) [32, 366, 14, 14] [32, 366, 14, 14] -- --\n", "│ │ └─ConvNormAct (conv_pwl) [32, 366, 14, 14] [32, 72, 14, 14] (26,496) False\n", "│ └─LinearBottleneck (6) [32, 72, 14, 14] [32, 84, 14, 14] -- False\n", "│ │ └─ConvNormAct (conv_exp) [32, 72, 14, 14] [32, 432, 14, 14] (31,968) False\n", "│ │ └─ConvNormAct (conv_dw) [32, 432, 14, 14] [32, 432, 14, 14] (4,752) False\n", "│ │ └─SEModule (se) [32, 432, 14, 14] [32, 432, 14, 14] (31,644) False\n", "│ │ └─ReLU6 (act_dw) [32, 432, 14, 14] [32, 432, 14, 14] -- --\n", "│ │ └─ConvNormAct (conv_pwl) [32, 432, 14, 14] [32, 84, 14, 14] (36,456) False\n", "│ └─LinearBottleneck (7) [32, 84, 14, 14] [32, 95, 14, 14] -- False\n", "│ │ └─ConvNormAct (conv_exp) [32, 84, 14, 14] [32, 504, 14, 14] (43,344) 
False\n", "│ │ └─ConvNormAct (conv_dw) [32, 504, 14, 14] [32, 504, 14, 14] (5,544) False\n", "│ │ └─SEModule (se) [32, 504, 14, 14] [32, 504, 14, 14] (42,966) False\n", "│ │ └─ReLU6 (act_dw) [32, 504, 14, 14] [32, 504, 14, 14] -- --\n", "│ │ └─ConvNormAct (conv_pwl) [32, 504, 14, 14] [32, 95, 14, 14] (48,070) False\n", "│ └─LinearBottleneck (8) [32, 95, 14, 14] [32, 106, 14, 14] -- False\n", "│ │ └─ConvNormAct (conv_exp) [32, 95, 14, 14] [32, 570, 14, 14] (55,290) False\n", "│ │ └─ConvNormAct (conv_dw) [32, 570, 14, 14] [32, 570, 14, 14] (6,270) False\n", "│ │ └─SEModule (se) [32, 570, 14, 14] [32, 570, 14, 14] (54,291) False\n", "│ │ └─ReLU6 (act_dw) [32, 570, 14, 14] [32, 570, 14, 14] -- --\n", "│ │ └─ConvNormAct (conv_pwl) [32, 570, 14, 14] [32, 106, 14, 14] (60,632) False\n", "│ └─LinearBottleneck (9) [32, 106, 14, 14] [32, 117, 14, 14] -- False\n", "│ │ └─ConvNormAct (conv_exp) [32, 106, 14, 14] [32, 636, 14, 14] (68,688) False\n", "│ │ └─ConvNormAct (conv_dw) [32, 636, 14, 14] [32, 636, 14, 14] (6,996) False\n", "│ │ └─SEModule (se) [32, 636, 14, 14] [32, 636, 14, 14] (68,211) False\n", "│ │ └─ReLU6 (act_dw) [32, 636, 14, 14] [32, 636, 14, 14] -- --\n", "│ │ └─ConvNormAct (conv_pwl) [32, 636, 14, 14] [32, 117, 14, 14] (74,646) False\n", "│ └─LinearBottleneck (10) [32, 117, 14, 14] [32, 128, 14, 14] -- False\n", "│ │ └─ConvNormAct (conv_exp) [32, 117, 14, 14] [32, 702, 14, 14] (83,538) False\n", "│ │ └─ConvNormAct (conv_dw) [32, 702, 14, 14] [32, 702, 14, 14] (7,722) False\n", "│ │ └─SEModule (se) [32, 702, 14, 14] [32, 702, 14, 14] (82,308) False\n", "│ │ └─ReLU6 (act_dw) [32, 702, 14, 14] [32, 702, 14, 14] -- --\n", "│ │ └─ConvNormAct (conv_pwl) [32, 702, 14, 14] [32, 128, 14, 14] (90,112) False\n", "│ └─LinearBottleneck (11) [32, 128, 14, 14] [32, 140, 7, 7] -- False\n", "│ │ └─ConvNormAct (conv_exp) [32, 128, 14, 14] [32, 768, 14, 14] (99,840) False\n", "│ │ └─ConvNormAct (conv_dw) [32, 768, 14, 14] [32, 768, 7, 7] (8,448) False\n", "│ │ └─SEModule (se) 
[32, 768, 7, 7] [32, 768, 7, 7] (99,264) False\n", "│ │ └─ReLU6 (act_dw) [32, 768, 7, 7] [32, 768, 7, 7] -- --\n", "│ │ └─ConvNormAct (conv_pwl) [32, 768, 7, 7] [32, 140, 7, 7] (107,800) False\n", "│ └─LinearBottleneck (12) [32, 140, 7, 7] [32, 151, 7, 7] -- False\n", "│ │ └─ConvNormAct (conv_exp) [32, 140, 7, 7] [32, 840, 7, 7] (119,280) False\n", "│ │ └─ConvNormAct (conv_dw) [32, 840, 7, 7] [32, 840, 7, 7] (9,240) False\n", "│ │ └─SEModule (se) [32, 840, 7, 7] [32, 840, 7, 7] (118,650) False\n", "│ │ └─ReLU6 (act_dw) [32, 840, 7, 7] [32, 840, 7, 7] -- --\n", "│ │ └─ConvNormAct (conv_pwl) [32, 840, 7, 7] [32, 151, 7, 7] (127,142) False\n", "│ └─LinearBottleneck (13) [32, 151, 7, 7] [32, 162, 7, 7] -- False\n", "│ │ └─ConvNormAct (conv_exp) [32, 151, 7, 7] [32, 906, 7, 7] (138,618) False\n", "│ │ └─ConvNormAct (conv_dw) [32, 906, 7, 7] [32, 906, 7, 7] (9,966) False\n", "│ │ └─SEModule (se) [32, 906, 7, 7] [32, 906, 7, 7] (137,031) False\n", "│ │ └─ReLU6 (act_dw) [32, 906, 7, 7] [32, 906, 7, 7] -- --\n", "│ │ └─ConvNormAct (conv_pwl) [32, 906, 7, 7] [32, 162, 7, 7] (147,096) False\n", "│ └─LinearBottleneck (14) [32, 162, 7, 7] [32, 174, 7, 7] -- False\n", "│ │ └─ConvNormAct (conv_exp) [32, 162, 7, 7] [32, 972, 7, 7] (159,408) False\n", "│ │ └─ConvNormAct (conv_dw) [32, 972, 7, 7] [32, 972, 7, 7] (10,692) False\n", "│ │ └─SEModule (se) [32, 972, 7, 7] [32, 972, 7, 7] (158,679) False\n", "│ │ └─ReLU6 (act_dw) [32, 972, 7, 7] [32, 972, 7, 7] -- --\n", "│ │ └─ConvNormAct (conv_pwl) [32, 972, 7, 7] [32, 174, 7, 7] (169,476) False\n", "│ └─LinearBottleneck (15) [32, 174, 7, 7] [32, 185, 7, 7] -- False\n", "│ │ └─ConvNormAct (conv_exp) [32, 174, 7, 7] [32, 1044, 7, 7] (183,744) False\n", "│ │ └─ConvNormAct (conv_dw) [32, 1044, 7, 7] [32, 1044, 7, 7] (11,484) False\n", "│ │ └─SEModule (se) [32, 1044, 7, 7] [32, 1044, 7, 7] (182,961) False\n", "│ │ └─ReLU6 (act_dw) [32, 1044, 7, 7] [32, 1044, 7, 7] -- --\n", "│ │ └─ConvNormAct (conv_pwl) [32, 1044, 7, 7] [32, 185, 7, 7] 
(193,510) False\n", "│ └─ConvNormAct (16) [32, 185, 7, 7] [32, 1280, 7, 7] -- False\n", "│ │ └─Conv2d (conv) [32, 185, 7, 7] [32, 1280, 7, 7] (236,800) False\n", "│ │ └─BatchNormAct2d (bn) [32, 1280, 7, 7] [32, 1280, 7, 7] (2,560) False\n", "├─ClassifierHead (head) [32, 1280, 7, 7] [32, 53] -- True\n", "│ └─SelectAdaptivePool2d (global_pool) [32, 1280, 7, 7] [32, 1280] -- --\n", "│ │ └─AdaptiveAvgPool2d (pool) [32, 1280, 7, 7] [32, 1280, 1, 1] -- --\n", "│ │ └─Flatten (flatten) [32, 1280, 1, 1] [32, 1280] -- --\n", "│ └─Dropout (drop) [32, 1280] [32, 1280] -- --\n", "│ └─Linear (fc) [32, 1280] [32, 53] 67,893 True\n", "│ └─Identity (flatten) [32, 53] [32, 53] -- --\n", "=============================================================================================================================\n", "Total params: 3,583,766\n", "Trainable params: 67,893\n", "Non-trainable params: 3,515,873\n", "Total mult-adds (Units.GIGABYTES): 12.71\n", "=============================================================================================================================\n", "Input size (MB): 19.27\n", "Forward/backward pass size (MB): 1904.75\n", "Params size (MB): 14.18\n", "Estimated Total Size (MB): 1938.20\n", "=============================================================================================================================" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "summ(model, input_size = (32, 3, 224, 224))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Predictions" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "972b30b0cbce4f5781fcebc69190e9ab", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/265 [00:00\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
pathtrue_classpred_classpred_probcorrect
189E:\\Data_and_Models\\Kaggle_Cards\\test\\six of cl...six of clubsthree of clubs0.852508False
112E:\\Data_and_Models\\Kaggle_Cards\\test\\king of d...king of diamondsqueen of diamonds0.711170False
167E:\\Data_and_Models\\Kaggle_Cards\\test\\seven of ...seven of clubsfive of clubs0.691422False
153E:\\Data_and_Models\\Kaggle_Cards\\test\\queen of ...queen of diamondssix of diamonds0.667886False
2E:\\Data_and_Models\\Kaggle_Cards\\test\\ace of cl...ace of clubstwo of clubs0.582781False
..................
194E:\\Data_and_Models\\Kaggle_Cards\\test\\six of di...six of diamondssix of diamonds0.562002True
195E:\\Data_and_Models\\Kaggle_Cards\\test\\six of he...six of heartssix of hearts0.408652True
97E:\\Data_and_Models\\Kaggle_Cards\\test\\jack of s...jack of spadesjack of spades0.376405True
148E:\\Data_and_Models\\Kaggle_Cards\\test\\queen of ...queen of clubsqueen of clubs0.319073True
104E:\\Data_and_Models\\Kaggle_Cards\\test\\joker\\5.jpgjokerjoker0.276672True
\n", "

265 rows × 5 columns

\n", "" ], "text/plain": [ " path true_class \\\n", "189 E:\\Data_and_Models\\Kaggle_Cards\\test\\six of cl... six of clubs \n", "112 E:\\Data_and_Models\\Kaggle_Cards\\test\\king of d... king of diamonds \n", "167 E:\\Data_and_Models\\Kaggle_Cards\\test\\seven of ... seven of clubs \n", "153 E:\\Data_and_Models\\Kaggle_Cards\\test\\queen of ... queen of diamonds \n", "2 E:\\Data_and_Models\\Kaggle_Cards\\test\\ace of cl... ace of clubs \n", ".. ... ... \n", "194 E:\\Data_and_Models\\Kaggle_Cards\\test\\six of di... six of diamonds \n", "195 E:\\Data_and_Models\\Kaggle_Cards\\test\\six of he... six of hearts \n", "97 E:\\Data_and_Models\\Kaggle_Cards\\test\\jack of s... jack of spades \n", "148 E:\\Data_and_Models\\Kaggle_Cards\\test\\queen of ... queen of clubs \n", "104 E:\\Data_and_Models\\Kaggle_Cards\\test\\joker\\5.jpg joker \n", "\n", " pred_class pred_prob correct \n", "189 three of clubs 0.852508 False \n", "112 queen of diamonds 0.711170 False \n", "167 five of clubs 0.691422 False \n", "153 six of diamonds 0.667886 False \n", "2 two of clubs 0.582781 False \n", ".. ... ... ... \n", "194 six of diamonds 0.562002 True \n", "195 six of hearts 0.408652 True \n", "97 jack of spades 0.376405 True \n", "148 queen of clubs 0.319073 True \n", "104 joker 0.276672 True \n", "\n", "[265 rows x 5 columns]" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Collect every test image (one sub-folder per class). The path is built with pathlib so it works on any OS,\n", "# and the list is sorted so that the order of the images is the same on every run\n", "image_paths = sorted((Path(data_path) / 'test').glob('*/*.jpg'))\n", "\n", "image_df = record_image_preds(image_paths = image_paths, model = model, transform = transforms, class_names = classes)\n", "image_df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Random sample" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", "
\n", "" ], "text/plain": [ "alt.VConcatChart(...)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Fixed seed so that Restart & Run All always shows the same sample;\n", "# never ask for more rows than there are predictions (DataFrame.sample would raise)\n", "sample_df = image_df.sample(n = min(12, len(image_df)), random_state = 42)\n", "image_pred_grid(sample_df.copy(), ncols = 6, img_width = 200, img_height = 200, allow_1_col_reduction = True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Most-wrong predictions" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", "
\n", "" ], "text/plain": [ "alt.VConcatChart(...)" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Sort explicitly instead of relying on the row order returned by record_image_preds:\n", "# wrong predictions first, and within those the most confident ones first\n", "most_wrong_df = image_df.sort_values(['correct', 'pred_prob'], ascending = [True, False]).head(12)\n", "image_pred_grid(most_wrong_df.copy(), ncols = 6, img_width = 200, img_height = 200, allow_1_col_reduction = True)" ] } ], "metadata": { "kernelspec": { "display_name": "ML11", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.7" } }, "nbformat": 4, "nbformat_minor": 2 }