membrfab commited on
Commit
c4dfd3e
1 Parent(s): b779da0

Upload app.ipynb

Browse files
Files changed (1) hide show
  1. app.ipynb +157 -0
app.ipynb ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import gradio as gr\n",
10
+ "import tensorflow as tf\n",
11
+ "import numpy as np\n",
12
+ "from PIL import Image"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": 3,
18
+ "metadata": {},
19
+ "outputs": [
20
+ {
21
+ "name": "stdout",
22
+ "output_type": "stream",
23
+ "text": [
24
+ "Running on local URL: http://127.0.0.1:7861\n",
25
+ "\n",
26
+ "To create a public link, set `share=True` in `launch()`.\n"
27
+ ]
28
+ },
29
+ {
30
+ "data": {
31
+ "text/html": [
32
+ "<div><iframe src=\"http://127.0.0.1:7861/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
33
+ ],
34
+ "text/plain": [
35
+ "<IPython.core.display.HTML object>"
36
+ ]
37
+ },
38
+ "metadata": {},
39
+ "output_type": "display_data"
40
+ },
41
+ {
42
+ "name": "stdout",
43
+ "output_type": "stream",
44
+ "text": [
45
+ "Running on local URL: http://127.0.0.1:7862\n",
46
+ "\n",
47
+ "To create a public link, set `share=True` in `launch()`.\n"
48
+ ]
49
+ },
50
+ {
51
+ "data": {
52
+ "text/html": [
53
+ "<div><iframe src=\"http://127.0.0.1:7862/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
54
+ ],
55
+ "text/plain": [
56
+ "<IPython.core.display.HTML object>"
57
+ ]
58
+ },
59
+ "metadata": {},
60
+ "output_type": "display_data"
61
+ },
62
+ {
63
+ "data": {
64
+ "text/plain": []
65
+ },
66
+ "execution_count": 3,
67
+ "metadata": {},
68
+ "output_type": "execute_result"
69
+ }
70
+ ],
71
+ "source": [
72
+ "# Load your custom classification models\n",
73
+ "scratch_model = tf.keras.models.load_model('animal_classifier_model_scratch.keras')\n",
74
+ "transfer_learning_model = tf.keras.models.load_model('animal_classifier_model_transfer.keras')\n",
75
+ "\n",
76
+ "# Class names, should match your dataset\n",
77
+ "class_names = ['butterfly', 'cat', 'elephant', 'horse', 'squirrel']\n",
78
+ "\n",
79
+ "def classify_image(image, model):\n",
80
+ " # Convert the Gradio input image to a PIL image\n",
81
+ " if isinstance(image, np.ndarray):\n",
82
+ " image = Image.fromarray(image.astype('uint8'), 'RGB')\n",
83
+ " \n",
84
+ " # Resize the image using np.resize\n",
85
+ " image = np.resize(image, (300, 300, 3)) # Add the channel dimension\n",
86
+ " \n",
87
+ " image = image / 255.0 # Normalize the image\n",
88
+ " image = np.expand_dims(image, axis=0) # Add batch dimension\n",
89
+ " \n",
90
+ " # Predict the class of the image\n",
91
+ " predictions = model.predict(image)\n",
92
+ " \n",
93
+ " # Get the indices of the top 3 predictions\n",
94
+ " top_indices = np.argsort(predictions[0])[::-1][:3]\n",
95
+ " \n",
96
+ " # Get the corresponding class names and confidences\n",
97
+ " top_classes = [class_names[i] for i in top_indices]\n",
98
+ " confidences = [predictions[0][i] for i in top_indices]\n",
99
+ " \n",
100
+ " return {class_name: float(confidence) for class_name, confidence in zip(top_classes, confidences)}\n",
101
+ "\n",
102
+ "image_input = gr.Image()\n",
103
+ "label = gr.Label(num_top_classes=3)\n",
104
+ "\n",
105
+ "scratch_interface = gr.Interface(\n",
106
+ " fn=lambda image: classify_image(image, scratch_model), \n",
107
+ " inputs=image_input, \n",
108
+ " outputs=label,\n",
109
+ " title='Animal Classifier (Scratch Model)',\n",
110
+ " description='Upload an image of an animal, and the classifier will tell you which animal it is, along with the confidence level of the prediction.'\n",
111
+ ")\n",
112
+ "\n",
113
+ "transfer_learning_interface = gr.Interface(\n",
114
+ " fn=lambda image: classify_image(image, transfer_learning_model), \n",
115
+ " inputs=image_input, \n",
116
+ " outputs=label,\n",
117
+ " title='Animal Classifier (Transfer Learning Model)',\n",
118
+ " description='Upload an image of an animal, and the classifier will tell you which animal it is, along with the confidence level of the prediction.'\n",
119
+ ")\n",
120
+ "\n",
121
+ "scratch_interface.launch()\n",
122
+ "transfer_learning_interface.launch()\n"
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "markdown",
127
+ "metadata": {},
128
+ "source": []
129
+ },
130
+ {
131
+ "cell_type": "markdown",
132
+ "metadata": {},
133
+ "source": []
134
+ }
135
+ ],
136
+ "metadata": {
137
+ "kernelspec": {
138
+ "display_name": "venv_new",
139
+ "language": "python",
140
+ "name": "python3"
141
+ },
142
+ "language_info": {
143
+ "codemirror_mode": {
144
+ "name": "ipython",
145
+ "version": 3
146
+ },
147
+ "file_extension": ".py",
148
+ "mimetype": "text/x-python",
149
+ "name": "python",
150
+ "nbconvert_exporter": "python",
151
+ "pygments_lexer": "ipython3",
152
+ "version": "3.11.5"
153
+ }
154
+ },
155
+ "nbformat": 4,
156
+ "nbformat_minor": 2
157
+ }