{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "87345732-d868-473b-b1a1-5c25839ce25b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai.vision.all import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "79b9fbad-7b99-40fd-8768-b0a091bf85cb",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/conda/envs/py310-cuda116/lib/python3.10/site-packages/paramiko/transport.py:236: CryptographyDeprecationWarning: Blowfish has been deprecated\n",
      " \"class\": algorithms.Blowfish,\n"
     ]
    }
   ],
   "source": [
    "import gradio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5409c6a7-5cae-42bb-8335-587a04471f22",
   "metadata": {},
   "outputs": [],
   "source": [
    "MODELS_PATH = Path(\"./models\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4e836799-6858-438a-8d70-d95f98cf54f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "EXAMPLES_PATH = Path('./examples')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9ed20c60-9f23-4795-bb4b-79b00af0f6d1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(#2) [Path('models/food-101-resnet34.pkl'),Path('models/food-101-resnet50.pkl')]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "MODELS_PATH.ls()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0969ba8e-b0df-4550-a900-5d5a30fb0187",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(#9) [Path('examples/pad_thai.jpeg'),Path('examples/takoyaki.jpeg'),Path('examples/momo.jpeg'),Path('examples/falafel.jpeg'),Path('examples/paella.jpeg'),Path('examples/ravioli.jpeg'),Path('examples/huevos_rancheros.jpeg'),Path('examples/edamame.jpeg'),Path('examples/sushi.jpeg')]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "EXAMPLES_PATH.ls()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e9143742-c6bc-44f6-8ecd-3826502c84ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "def label_func(filepath):\n",
    "    \"\"\"Return the class label for an image path: the name of its parent directory.\n",
    "\n",
    "    Not called anywhere in this notebook, but keep it defined before\n",
    "    `load_learner`: the exported Learner presumably references\n",
    "    `__main__.label_func` from training and cannot be unpickled without it\n",
    "    (verify against the training notebook before removing).\n",
    "    \"\"\"\n",
    "    return filepath.parent.name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c6ad64e8-f163-4472-b2f0-c0aa50ead4d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load_learner unpickles the file, which can execute arbitrary code: only load trusted model files\n",
    "learn = load_learner(MODELS_PATH/'food-101-resnet50.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d1370d20-fd51-4512-bd28-5f170d216c7b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare', 'beet_salad', 'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito', 'bruschetta', 'caesar_salad', 'cannoli', 'caprese_salad', 'carrot_cake', 'ceviche', 'cheese_plate', 'cheesecake', 'chicken_curry', 'chicken_quesadilla', 'chicken_wings', 'chocolate_cake', 'chocolate_mousse', 'churros', 'clam_chowder', 'club_sandwich', 'crab_cakes', 'creme_brulee', 'croque_madame', 'cup_cakes', 'deviled_eggs', 'donuts', 'dumplings', 'edamame', 'eggs_benedict', 'escargots', 'falafel', 'filet_mignon', 'fish_and_chips', 'foie_gras', 'french_fries', 'french_onion_soup', 'french_toast', 'fried_calamari', 'fried_rice', 'frozen_yogurt', 'garlic_bread', 'gnocchi', 'greek_salad', 'grilled_cheese_sandwich', 'grilled_salmon', 'guacamole', 'gyoza', 'hamburger', 'hot_and_sour_soup', 'hot_dog', 'huevos_rancheros', 'hummus', 'ice_cream', 'lasagna', 'lobster_bisque', 'lobster_roll_sandwich', 'macaroni_and_cheese', 'macarons', 'miso_soup', 'mussels', 'nachos', 'omelette', 'onion_rings', 'oysters', 'pad_thai', 'paella', 'pancakes', 'panna_cotta', 'peking_duck', 'pho', 'pizza', 'pork_chop', 'poutine', 'prime_rib', 'pulled_pork_sandwich', 'ramen', 'ravioli', 'red_velvet_cake', 'risotto', 'samosa', 'sashimi', 'scallops', 'seaweed_salad', 'shrimp_and_grits', 'spaghetti_bolognese', 'spaghetti_carbonara', 'spring_rolls', 'steak', 'strawberry_shortcake', 'sushi', 'tacos', 'takoyaki', 'tiramisu', 'tuna_tartare', 'waffles']"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "labels = learn.dls.vocab\n",
"labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "a360dd6b-75a9-43e5-b91d-c6963ea462ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(img):\n",
    "    \"\"\"Classify one image and return a probability for every label in `labels`.\n",
    "\n",
    "    `img` is whatever the Gradio Image input passes in (presumably a numpy array;\n",
    "    verify); `PILImage.create` converts it into the type fastai expects.\n",
    "    Returns a `{label: probability}` dict, the mapping Gradio's Label output renders.\n",
    "    \"\"\"\n",
    "    pil_img = PILImage.create(img)\n",
    "    # learn.predict returns (decoded label, label index, per-class probabilities); only the probabilities are needed\n",
    "    _, _, probs = learn.predict(pil_img)\n",
    "    return {label: float(prob) for label, prob in zip(labels, probs)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "febc7266-8587-4530-811b-f2fa9117dcd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# explicit encoding so the markdown article decodes the same regardless of system locale\n",
    "with open('gradio_article.md', encoding='utf-8') as f:\n",
    "    article = f.read()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "8fd4ffb4-11ca-4b25-999c-cde2a4e236b4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running on local URL: http://localhost:9999/\n",
      "\n",
      "To create a public link, set `share=True` in `launch()`.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       " \n",
       " "
      ],
      "text/plain": [
       ""
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "(,\n",
       " 'http://localhost:9999/',\n",
       " None)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "interface_options = {\n",
    "    \"title\": \"Food-101 Classifier\",\n",
    "    \"description\": \"A food image classifier trained on the Food-101 (https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/) dataset with fastai with a ResNet50 CNN model.\",\n",
    "    \"article\": article,\n",
    "    # sorted() keeps the example gallery order stable; iterdir() order is filesystem-dependent\n",
    "    \"examples\": [str(f) for f in sorted(EXAMPLES_PATH.iterdir()) if f.is_file()],\n",
    "    \"interpretation\": \"default\",\n",
    "    \"layout\": \"horizontal\",\n",
    "    \"allow_flagging\": \"never\",\n",
    "    # `enable_queue` is not set here: passing it to Interface is deprecated; it is set in launch() below\n",
    "}\n",
    "\n",
    "demo = gradio.Interface(fn=predict,\n",
    "                        inputs=gradio.inputs.Image(shape=(512, 512)),\n",
    "                        outputs=gradio.outputs.Label(num_top_classes=5),\n",
    "                        **interface_options)\n",
    "\n",
    "demo_options = {\n",
    "    \"inline\": True,\n",
    "    \"inbrowser\": False,\n",
    "    \"share\": False,\n",
    "    \"show_error\": True,\n",
    "    \"server_name\": \"0.0.0.0\",\n",
    "    \"server_port\": 9999,\n",
    "    \"enable_queue\": True,\n",
    "}\n",
    "\n",
    "demo.launch(**demo_options)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
"version": "3.10.4" } }, "nbformat": 4, "nbformat_minor": 5 }