{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4c1e7a20-5b3d-4f8e-9a61-0d2b7c9e1f01",
   "metadata": {},
   "source": [
    "# AlexNet image classifier\n",
    "\n",
    "Loads a trained AlexNet model from `AlexNet_Model.pt`, classifies local images with it, and serves an interactive Gradio demo.\n",
    "\n",
    "The label order in `classes` must match the order the model was trained with (it follows the CIFAR-10 class order; confirm against the training code)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c1e7a20-5b3d-4f8e-9a61-0d2b7c9e1f02",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import gradio as gr\n",
    "import numpy as np\n",
    "import torch\n",
    "import torchvision.transforms as transforms\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d154a7e-f3ba-4f52-8489-c45e9c736a9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trained model checkpoint: a whole pickled model object (it is called directly\n",
    "# below), not a state_dict.\n",
    "MODEL_PATH = 'AlexNet_Model.pt'\n",
    "DEVICE = torch.device('cpu')\n",
    "\n",
    "# Class names, indexed by the model's output position.\n",
    "classes = ('Airplane', 'Car', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck')\n",
    "\n",
    "# Resize and center-crop to 224x224, then normalise each channel with the\n",
    "# mean/std below (presumably the statistics the model was trained with).\n",
    "transform = transforms.Compose([\n",
    "    transforms.Resize(256),\n",
    "    transforms.CenterCrop(224),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c1e7a20-5b3d-4f8e-9a61-0d2b7c9e1f03",
   "metadata": {},
   "source": [
    "## Load the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b60d911e-b970-4ac1-ac19-d8e7a3b8652e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The checkpoint is a whole pickled model object, so it needs weights_only=False\n",
    "# (the default became True in PyTorch 2.6 and rejects such files). Unpickling can\n",
    "# execute arbitrary code: only load checkpoint files from a trusted source.\n",
    "AlexNet_Model = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)\n",
    "\n",
    "# Inference mode: disables training-only behaviour (e.g. dropout, batch-norm\n",
    "# statistic updates) so predictions are deterministic.\n",
    "AlexNet_Model.eval();"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c1e7a20-5b3d-4f8e-9a61-0d2b7c9e1f04",
   "metadata": {},
   "source": [
    "## Preprocessing and inference helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c1e7a20-5b3d-4f8e-9a61-0d2b7c9e1f05",
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess_image(image):\n",
    "    \"\"\"Turn one image into a normalised model input of shape (1, 3, 224, 224).\n",
    "\n",
    "    Args:\n",
    "        image: a file path (str or os.PathLike), a PIL image, or a NumPy array\n",
    "            (e.g. the one Gradio's image component passes to callbacks).\n",
    "\n",
    "    Returns:\n",
    "        torch.Tensor with a leading batch dimension of 1.\n",
    "    \"\"\"\n",
    "    # Every branch forces 3-channel RGB: grayscale or RGBA inputs would\n",
    "    # otherwise break the 3-channel Normalize step and the model input.\n",
    "    if isinstance(image, (str, os.PathLike)):\n",
    "        # convert() loads the pixels into a new image, so the file can be closed.\n",
    "        with Image.open(image) as img:\n",
    "            pil_image = img.convert('RGB')\n",
    "    elif isinstance(image, np.ndarray):\n",
    "        pil_image = Image.fromarray(image).convert('RGB')\n",
    "    else:\n",
    "        pil_image = image.convert('RGB')\n",
    "    return transform(pil_image).unsqueeze(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c1e7a20-5b3d-4f8e-9a61-0d2b7c9e1f06",
   "metadata": {},
   "outputs": [],
   "source": [
    "def class_probabilities(images):\n",
    "    \"\"\"Return softmax class probabilities for a batch of preprocessed images.\n",
    "\n",
    "    Args:\n",
    "        images: stacked `preprocess_image` outputs, shape (batch, 3, 224, 224).\n",
    "\n",
    "    Returns:\n",
    "        Tensor of shape (batch, num_outputs). Assumes the model emits one logit\n",
    "        per entry of `classes` (TODO confirm against the training code).\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        predictions = AlexNet_Model(images)\n",
    "    return torch.nn.functional.softmax(predictions, dim=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c1e7a20-5b3d-4f8e-9a61-0d2b7c9e1f07",
   "metadata": {},
   "source": [
    "## Classify local images\n",
    "\n",
    "Prints every class for each image, most probable first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "390a39ff-b95c-48bd-8556-2e6cc8fe8d40",
   "metadata": {},
   "outputs": [],
   "source": [
    "image_paths = [\n",
    "    \"car.jpeg\",\n",
    "]\n",
    "\n",
    "preprocessed_images = [preprocess_image(image_path) for image_path in image_paths]\n",
    "predicted_probabilities = class_probabilities(torch.cat(preprocessed_images, dim=0))\n",
    "\n",
    "for image_path, probabilities in zip(image_paths, predicted_probabilities):\n",
    "    print(f\"Image: {image_path}\")\n",
    "    for idx in torch.argsort(probabilities, descending=True).tolist():\n",
    "        print(f\"Class: {classes[idx]}, Probability: {probabilities[idx].item()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c1e7a20-5b3d-4f8e-9a61-0d2b7c9e1f08",
   "metadata": {},
   "source": [
    "## Interactive Gradio demo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c1e7a20-5b3d-4f8e-9a61-0d2b7c9e1f09",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(image):\n",
    "    \"\"\"Gradio callback: classify one image.\n",
    "\n",
    "    Args:\n",
    "        image: the image passed by the Gradio input component (e.g. a NumPy\n",
    "            array), or None when the user submits without an image.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping each class name to its probability, most probable first;\n",
    "        an empty dict when no image was given.\n",
    "    \"\"\"\n",
    "    if image is None:\n",
    "        return {}\n",
    "    probabilities = class_probabilities(preprocess_image(image))[0]\n",
    "    ranked = torch.argsort(probabilities, descending=True).tolist()\n",
    "    return {classes[idx]: probabilities[idx].item() for idx in ranked}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e01098f6-0b9b-4991-a6d4-29f9fa146ca9",
   "metadata": {},
   "outputs": [],
   "source": [
    "iface = gr.Interface(\n",
    "    predict,\n",
    "    inputs=\"image\",\n",
    "    outputs=\"label\",\n",
    "    title=\"AlexNet Image Classifier\",\n",
    "    description=\"Classify images into one of 10 classes: Airplane, Car, Bird, Cat, Deer, Dog, Frog, Horse, Ship, or Truck.\",\n",
    ")\n",
    "iface.launch()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}