{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a3278dc9-0d83-4a37-aece-e46ac416988f",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp app"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d6810835-d62a-4f94-a52e-0e0cd163fb98",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "from fastai.vision.all import *\n",
    "import gradio as gr\n",
    "\n",
    "# Gradio app settings.\n",
    "interpretation='default'\n",
    "enable_queue=True\n",
    "\n",
    "title = \"FastAI - Big Cats Classifier\"\n",
    "description = \"Classify big cats using all Resnet models available pre-trained in FastAI\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6092ad61-d5cd-40f7-b2d2-20a77b0c8b0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "# Model display name -> path of its pickled fastai Learner.\n",
    "learners = {\n",
    "    \"resnet-18\" : 'models/resnet18-model.pkl',\n",
    "    \"resnet-34\" : 'models/resnet34-model.pkl',\n",
    "    \"resnet-50\" : 'models/resnet50-model.pkl',\n",
    "    \"resnet-101\": 'models/resnet101-model.pkl',\n",
    "    \"resnet-152\": 'models/resnet152-model.pkl'\n",
    "}\n",
    "models = list(learners.keys())\n",
    "\n",
    "# Model used by `classify_image`; switched at runtime by `select_model`.\n",
    "active_name = \"resnet-18\"\n",
    "active_model = learners[active_name]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "632cbc1b-73b5-4992-8956-d4ae40f6b80b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "# Holds at most one loaded Learner: reloading the pickle on every click is slow,\n",
    "# and keeping only the most recent model bounds memory use.\n",
    "_learner_cache = {}\n",
    "\n",
    "def _get_learner(model_path):\n",
    "    \"Return the fastai Learner stored at `model_path`, loading it only when the path changes.\"\n",
    "    # NOTE: load_learner unpickles the file, so only point it at trusted model files.\n",
    "    if model_path not in _learner_cache:\n",
    "        _learner_cache.clear()\n",
    "        _learner_cache[model_path] = load_learner(model_path)\n",
    "    return _learner_cache[model_path]\n",
    "\n",
    "def _resolve_model_name(model_name):\n",
    "    \"Return `model_name` if it is a known model, otherwise fall back to 'resnet-18'.\"\n",
    "    return model_name if model_name in models else \"resnet-18\"\n",
    "\n",
    "def _artifact_path(model_name, suffix):\n",
    "    \"Per-model artifact path, e.g. ('resnet-18', '-top-losses.png') -> 'models/resnet18-top-losses.png'.\"\n",
    "    return \"models/\" + model_name.replace('-', '', 1) + suffix\n",
    "\n",
    "def classify_image(img):\n",
    "    \"Classify `img` with the active model and return a {label: probability} dict.\"\n",
    "    learn = _get_learner(active_model)\n",
    "    _, _, probs = learn.predict(img)\n",
    "    return dict(zip(learn.dls.vocab, map(float, probs)))\n",
    "\n",
    "def select_model(model_name):\n",
    "    \"Make `model_name` the active model (unknown names fall back to 'resnet-18') and return its name upper-cased.\"\n",
    "    # `global` is required: without it these assignments only bind locals\n",
    "    # and the model used by `classify_image` never changes.\n",
    "    global active_name, active_model\n",
    "    active_name = _resolve_model_name(model_name)\n",
    "    active_model = learners[active_name]\n",
    "    return active_name.upper()\n",
    "\n",
    "def update_matrix(model_name=None):\n",
    "    \"Confusion-matrix image path for `model_name` (defaults to the active model).\"\n",
    "    name = active_name if model_name is None else _resolve_model_name(model_name)\n",
    "    return _artifact_path(name, \"-confusion-matrix.png\")\n",
    "\n",
    "def update_losses(model_name=None):\n",
    "    \"Top-losses image path for `model_name` (defaults to the active model).\"\n",
    "    name = active_name if model_name is None else _resolve_model_name(model_name)\n",
    "    return _artifact_path(name, \"-top-losses.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id":
"9b5f1cc6-5173-475a-9365-0cab11db2d03", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{'african leopard': 0.00045245178625918925, 'cheetah': 0.9994743466377258, 'clouded leopard': 3.061432778395101e-07, 'cougar': 8.726581654627807e-06, 'jaguar': 4.878858817392029e-05, 'lion': 1.4129628652881365e-05, 'snow leopard': 1.2738197483486147e-06, 'tiger': 1.1983513736879559e-08}\n" ] }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{'african leopard': 8.918660228118824e-07, 'cheetah': 3.004239079729132e-09, 'clouded leopard': 1.0275688282490592e-06, 'cougar': 1.8215871477877954e-08, 'jaguar': 0.9999979734420776, 'lion': 7.327587425720594e-10, 'snow leopard': 1.3988608316140017e-07, 'tiger': 4.418302523845341e-08}\n" ] }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{'african leopard': 1.279351291572084e-08, 'cheetah': 3.040315732505405e-08, 'clouded leopard': 4.387358387702989e-08, 'cougar': 1.2642824458453106e-06, 'jaguar': 3.0061545430726255e-07, 'lion': 2.5054502472698914e-08, 'snow leopard': 4.821659516096588e-08, 'tiger': 0.9999983310699463}\n" ] }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": 
"stdout", "output_type": "stream", "text": [ "{'african leopard': 2.2317146886052797e-06, 'cheetah': 6.153353297122521e-06, 'clouded leopard': 3.5761433991865488e-06, 'cougar': 0.9940788745880127, 'jaguar': 7.271950153153739e-08, 'lion': 0.005906379781663418, 'snow leopard': 1.0360908220263809e-07, 'tiger': 2.569006483099656e-06}\n" ] }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{'african leopard': 7.383512135028525e-10, 'cheetah': 1.6924343526625307e-06, 'clouded leopard': 3.8847122740826023e-10, 'cougar': 1.4941306858418102e-08, 'jaguar': 3.277633942033731e-09, 'lion': 0.9999983310699463, 'snow leopard': 4.2623696572263725e-08, 'tiger': 5.7686470711360016e-08}\n" ] }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{'african leopard': 0.11080536246299744, 'cheetah': 0.00025237080990336835, 'clouded leopard': 0.0003655211767181754, 'cougar': 1.1126862773380708e-05, 'jaguar': 0.8603838086128235, 'lion': 8.311066630994901e-05, 'snow leopard': 0.028046416118741035, 'tiger': 5.234780110185966e-05}\n" ] }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{'african leopard': 5.991949336703328e-08, 'cheetah': 1.2888077272066312e-08, 'clouded leopard': 0.9999984502792358, 'cougar': 7.355600928349304e-07, 'jaguar': 5.131531679580803e-07, 'lion': 5.543293823961903e-09, 'snow leopard': 3.404375448212704e-08, 
'tiger': 2.0324510785485472e-07}\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "\n"
      ],
      "text/plain": [
       ""
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       ""
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'african leopard': 2.2017589799361303e-05, 'cheetah': 9.802879503695294e-05, 'clouded leopard': 0.0109814228489995, 'cougar': 1.8166520021623e-06, 'jaguar': 5.0095695769414306e-06, 'lion': 5.28784084963263e-06, 'snow leopard': 0.988881528377533, 'tiger': 4.889693173026899e-06}\n"
     ]
    }
   ],
   "source": [
    "example_images = [ 'cheetah.jpg', 'jaguar.jpg', 'tiger.jpg', 'cougar.jpg', 'lion.jpg', 'african leopard.jpg', 'clouded leopard.jpg', 'snow leopard.jpg' ]\n",
    "\n",
    "# Smoke test (not exported): classify every sample image with the active model.\n",
    "for fname in example_images:\n",
    "    print(classify_image(PILImage.create(fname)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a48e7483-c04b-4048-a1ae-34a8c7986a57",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running on local URL: http://127.0.0.1:7860\n",
      "\n",
      "To create a public link, set `share=True` in `launch()`.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "\n"
      ],
      "text/plain": [
       ""
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       ""
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#| export\n",
    "example_images = [ 'cheetah.jpg', 'jaguar.jpg', 'tiger.jpg', 'cougar.jpg', 'lion.jpg', 'african leopard.jpg', 'clouded leopard.jpg', 'snow leopard.jpg', 'hidden.png', 'hidden2.png' ]\n",
    "\n",
    "def _on_model_change(model_name):\n",
    "    \"Activate `model_name` and return its label plus its confusion-matrix and top-losses image paths.\"\n",
    "    # One handler (rather than one listener per output) guarantees the image paths\n",
    "    # are computed after `select_model(model_name)` has run.\n",
    "    label = select_model(model_name)\n",
    "    return label, update_matrix(), update_losses()\n",
    "\n",
    "def _classify_selected(img, model_name):\n",
    "    \"Classify `img` with the model picked in the dropdown; return None (clearing the label) when no image is loaded.\"\n",
    "    if img is None:\n",
    "        return None\n",
    "    select_model(model_name)\n",
    "    return classify_image(img)\n",
    "\n",
    "with gr.Blocks(title=title) as demo:\n",
    "    gr.Markdown(f\"# {title}\")\n",
    "    gr.Markdown(description)\n",
    "    with gr.Column(variant=\"panel\"):\n",
    "        image = gr.Image(label=\"Pick an image\")\n",
    "        model = gr.Dropdown(label=\"Select a model\", choices=models, value=active_name)\n",
    "        with gr.Row(equal_height=True):\n",
    "            classify_btn = gr.Button(\"Classify\")\n",
    "            clear_btn = gr.Button(\"Clear\")\n",
    "    with gr.Column(variant=\"panel\"):\n",
    "        selected = gr.Textbox(label=\"Active Model\")\n",
    "        with gr.Row(equal_height=True):\n",
    "            matrix = gr.Image(type='filepath', label=\"Confusion Matrix\")\n",
    "            losses = gr.Image(type='filepath', label=\"Top Losses\")\n",
    "        result = gr.Label(label=\"Result\")\n",
    "\n",
    "    img_gallery = gr.Examples(examples=example_images, inputs=image)\n",
    "\n",
    "    # Register all event listeners\n",
    "    model_outputs = [selected, matrix, losses]\n",
    "    demo.load(fn=_on_model_change, inputs=model, outputs=model_outputs)  # show the default model's plots on page load\n",
    "    model.change(fn=_on_model_change, inputs=model, outputs=model_outputs)\n",
    "    classify_btn.click(fn=_classify_selected, inputs=[image, model], outputs=result)\n",
    "    # Returning None for each target clears both the input image and the prediction.\n",
    "    clear_btn.click(fn=lambda: (None, None), inputs=None, outputs=[image, result])\n",
    "\n",
    "if enable_queue:\n",
    "    demo.queue()\n",
    "\n",
    "# Guarded so that importing the exported app module does not start a server.\n",
    "if __name__ == \"__main__\":\n",
    "    demo.launch(debug=True, inline=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cab071f9-7c3b-4b35-a0d1-3687731ffce5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from nbdev.export import nb_export\n",
    "\n",
    "# Export this notebook's `#| export` cells into the `app` module (see `#| default_exp app`) under ./\n",
    "nb_export('app.ipynb', './')\n",
    "print('Export successful')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}