{ "cells": [
 { "cell_type": "code", "execution_count": 1, "id": "a3278dc9-0d83-4a37-aece-e46ac416988f", "metadata": {}, "outputs": [], "source": [ "#| default_exp app" ] },
 { "cell_type": "code", "execution_count": 2, "id": "d6810835-d62a-4f94-a52e-0e0cd163fb98", "metadata": {}, "outputs": [], "source": [
  "#| export\n",
  "from fastai.vision.all import *\n",
  "import gradio as gr\n",
  "import warnings\n",
  "warnings.filterwarnings('ignore')\n",
  "\n",
  "title = \"FastAI - Big Cats Classifier\"\n",
  "description = \"Classify big cats using all Resnet models available pre-trained in FastAI\""
 ] },
 { "cell_type": "code", "execution_count": 3, "id": "6092ad61-d5cd-40f7-b2d2-20a77b0c8b0f", "metadata": {}, "outputs": [], "source": [
  "#| export\n",
  "# Model name shown in the dropdown -> exported fastai learner file.\n",
  "learners = {\n",
  "    \"resnet-18\" : 'models/resnet18-model.pkl',\n",
  "    \"resnet-34\" : 'models/resnet34-model.pkl',\n",
  "    \"resnet-50\" : 'models/resnet50-model.pkl',\n",
  "    \"resnet-101\": 'models/resnet101-model.pkl',\n",
  "    \"resnet-152\": 'models/resnet152-model.pkl'\n",
  "}\n",
  "models = list(learners.keys())\n",
  "\n",
  "# Module-level selection read by the callbacks below; changed only via select_model().\n",
  "active_name = \"resnet-18\"\n",
  "active_model = learners[active_name]\n"
 ] },
 { "cell_type": "code", "execution_count": 4, "id": "632cbc1b-73b5-4992-8956-d4ae40f6b80b", "metadata": {}, "outputs": [], "source": [
  "#| export\n",
  "def classify_image(img):\n",
  "    \"\"\"Classify `img` with the currently active model.\n",
  "\n",
  "    Returns a dict mapping every class in the learner's vocab to its probability.\n",
  "    NOTE: `load_learner` unpickles the .pkl file, so only load trusted model files.\n",
  "    \"\"\"\n",
  "    learn = load_learner(active_model)\n",
  "    pred, idx, probs = learn.predict(img)\n",
  "    return dict(zip(learn.dls.vocab, map(float, probs)))\n",
  "\n",
  "def select_model(model_name):\n",
  "    \"\"\"Make `model_name` the active model and return its upper-cased name.\n",
  "\n",
  "    Any name not present in `models` falls back to \"resnet-18\".\n",
  "    \"\"\"\n",
  "    # Without `global`, the assignments below would only bind locals and the\n",
  "    # choice would never reach classify_image/update_matrix/update_losses.\n",
  "    global active_name, active_model\n",
  "    if model_name not in models:\n",
  "        model_name = \"resnet-18\"\n",
  "    active_name = model_name\n",
  "    active_model = learners[active_name]\n",
  "    return model_name.upper()\n",
  "\n",
  "def _plot_path(kind):\n",
  "    \"\"\"Path of the active model's saved `kind` plot, e.g. models/resnet18-top-losses.png.\"\"\"\n",
  "    return \"models/\" + active_name.replace('-', '', 1) + \"-\" + kind + \".png\"\n",
  "\n",
  "def update_matrix():\n",
  "    \"\"\"Return the confusion-matrix image path for the active model.\"\"\"\n",
  "    return _plot_path(\"confusion-matrix\")\n",
  "\n",
  "def update_losses():\n",
  "    \"\"\"Return the top-losses image path for the active model.\"\"\"\n",
  "    return _plot_path(\"top-losses\")"
 ] },
 { "cell_type": "code", "execution_count": 5, "id": 
"9b5f1cc6-5173-475a-9365-0cab11db2d03", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{'african leopard': 0.0005852991016581655, 'cheetah': 0.9993988275527954, 'clouded leopard': 1.7600793000838166e-07, 'cougar': 6.112059963925276e-06, 'jaguar': 7.491902579204179e-06, 'lion': 1.3097942428430542e-06, 'snow leopard': 6.794325599912554e-07, 'tiger': 1.22832446436405e-07}\n" ] }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{'african leopard': 0.2962114214897156, 'cheetah': 2.706606210267637e-05, 'clouded leopard': 0.0008470952161587775, 'cougar': 1.0193979505856987e-05, 'jaguar': 0.701975405216217, 'lion': 1.3766093616141006e-05, 'snow leopard': 0.0008549779886379838, 'tiger': 6.007726915413514e-05}\n" ] }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{'african leopard': 2.0210626061611947e-08, 'cheetah': 1.6748231246310752e-08, 'clouded leopard': 1.1174745395692298e-06, 'cougar': 2.63490710494807e-06, 'jaguar': 2.399448703727103e-06, 'lion': 6.196571433747522e-08, 'snow leopard': 2.4245096028607804e-06, 'tiger': 0.9999912977218628}\n" ] }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", 
"output_type": "stream", "text": [ "{'african leopard': 9.39465026021935e-05, 'cheetah': 0.00021114452101755887, 'clouded leopard': 8.688175876159221e-05, 'cougar': 0.9761292934417725, 'jaguar': 7.082346655806759e-06, 'lion': 0.02333180606365204, 'snow leopard': 0.00011577722762012854, 'tiger': 2.4006889361771755e-05}\n" ] }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{'african leopard': 1.3545766286426897e-08, 'cheetah': 2.635677674334147e-06, 'clouded leopard': 7.659965994832874e-09, 'cougar': 9.957815017003213e-09, 'jaguar': 1.497639772196635e-07, 'lion': 0.9999957084655762, 'snow leopard': 1.294516778216348e-07, 'tiger': 1.2779944427165901e-06}\n" ] }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{'african leopard': 0.024091463536024094, 'cheetah': 0.0014163728337734938, 'clouded leopard': 0.008692733943462372, 'cougar': 0.0010448594111949205, 'jaguar': 0.7156786322593689, 'lion': 0.017859801650047302, 'snow leopard': 0.22819218039512634, 'tiger': 0.0030239589978009462}\n" ] }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{'african leopard': 7.144178198359441e-06, 'cheetah': 3.725538704202336e-07, 'clouded leopard': 0.9994736313819885, 'cougar': 6.0378228226909414e-05, 'jaguar': 3.279747033957392e-05, 'lion': 1.1806019273308266e-07, 'snow leopard': 0.0003000575816258788, 'tiger': 
0.0001255277602467686}\n" ] }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{'african leopard': 2.8642458346439525e-05, 'cheetah': 0.00017579919949639589, 'clouded leopard': 0.08972200006246567, 'cougar': 7.897598698036745e-05, 'jaguar': 2.5307128453277983e-05, 'lion': 1.8576161892269738e-05, 'snow leopard': 0.9099361896514893, 'tiger': 1.4485961401078384e-05}\n" ] } ],
 "source": [
  "example_images = [ 'cheetah.jpg', 'jaguar.jpg', 'tiger.jpg', 'cougar.jpg', 'lion.jpg', 'african leopard.jpg', 'clouded leopard.jpg', 'snow leopard.jpg' ]\n",
  "\n",
  "# Smoke test: run every sample image through the active model and print its class probabilities.\n",
  "for img_path in example_images:\n",
  "    result = classify_image(PILImage.create(img_path))\n",
  "    print(result)"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "a48e7483-c04b-4048-a1ae-34a8c7986a57", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7860\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#| export\n", "example_images = [ 'cheetah.jpg', 'jaguar.jpg', 'tiger.jpg', 'cougar.jpg', 'lion.jpg', 'african leopard.jpg', 'clouded leopard.jpg', 'snow leopard.jpg', 'hidden.png', 'hidden2.png' ]\n", "\n", "model_matrix = [ 'models/resnet101-confusion-matrix.png', 'models/resnet18-confusion-matrix.png', 'models/resnet50-confusion-matrix.png',\n", "'models/resnet152-confusion-matrix.png', 'models/resnet34-confusion-matrix.png' ]\n", "\n", "model_losses = [ 'models/resnet101-top-losses.png', 'models/resnet18-top-losses.png', 'models/resnet50-top-losses.png',\n", 
"'models/resnet152-top-losses.png', 'models/resnet34-top-losses.png' ]\n",
  "\n",
  "def on_model_change(model_name):\n",
  "    \"\"\"Activate `model_name`, then return the new label and both plot paths.\n",
  "\n",
  "    The three callbacks run in order inside one event, so the plot paths are\n",
  "    computed only after select_model() has run.\n",
  "    \"\"\"\n",
  "    return select_model(model_name), update_matrix(), update_losses()\n",
  "\n",
  "demo = gr.Blocks()\n",
  "with demo:\n",
  "    gr.Markdown(\"# \" + title)\n",
  "    gr.Markdown(description)\n",
  "    with gr.Column(variant=\"panel\"):\n",
  "        image = gr.inputs.Image(label=\"Pick an image\")\n",
  "        model = gr.inputs.Dropdown(label=\"Select a model\", choices=models)\n",
  "        btnClassify = gr.Button(\"Classify\")\n",
  "    with gr.Column(variant=\"panel\"):\n",
  "        selected = gr.outputs.Textbox(label=\"Active Model\")\n",
  "        with gr.Row(equal_height=True):\n",
  "            matrix=gr.outputs.Image(type='filepath', label=\"Confusion Matrix\")\n",
  "            losses=gr.outputs.Image(type='filepath', label=\"Top Losses\")\n",
  "        result = gr.outputs.Label(label=\"Result\")\n",
  "\n",
  "    # Switching models refreshes the label and both plots in a single event.\n",
  "    model.change(fn=on_model_change, inputs=model, outputs=[selected, matrix, losses])\n",
  "    btnClassify.click(fn=classify_image, inputs=image, outputs=result)\n",
  "    img_gallery = gr.Examples(examples=example_images, inputs=image)\n",
  "    matrix_gallery = gr.Examples(examples=model_matrix, label='Models Confusion Matrix', inputs=matrix)\n",
  "    loss_gallery = gr.Examples(examples=model_losses, label='Models Top Losses', inputs=losses)\n",
  "    result.change(fn=update_matrix, outputs=matrix)\n",
  "    result.change(fn=update_losses, outputs=losses)\n",
  "\n",
  "# Guarded so that merely importing the exported app.py does not start a blocking server;\n",
  "# running `python app.py` (and an IPython kernel) still executes with __name__ == \"__main__\".\n",
  "if __name__ == \"__main__\":\n",
  "    demo.launch(debug=True, inline=False)\n"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "cab071f9-7c3b-4b35-a0d1-3687731ffce5", "metadata": {}, "outputs": [], "source": [ "import nbdev\n", "nbdev.export.nb_export('app.ipynb', './')\n", "print('Export successful')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": 
"python", "pygments_lexer": "ipython3", "version": "3.8.2" } }, "nbformat": 4, "nbformat_minor": 5 }