akhaliq HF staff commited on
Commit
94837e8
1 Parent(s): cf37188

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +224 -0
app.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import tensorflow_hub as hub
3
+
4
+ import requests
5
+ from PIL import Image
6
+ from io import BytesIO
7
+
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ import gradio as gr
11
+
12
+ #@title Helper functions for loading image (hidden)
13
+
14
+ original_image_cache = {}
15
+
16
+ def preprocess_image(image):
17
+ image = np.array(image)
18
+ # reshape into shape [batch_size, height, width, num_channels]
19
+ img_reshaped = tf.reshape(image, [1, image.shape[0], image.shape[1], image.shape[2]])
20
+ # Use `convert_image_dtype` to convert to floats in the [0,1] range.
21
+ image = tf.image.convert_image_dtype(img_reshaped, tf.float32)
22
+ return image
23
+
24
+ def load_image_from_url(img_url):
25
+ """Returns an image with shape [1, height, width, num_channels]."""
26
+ user_agent = {'User-agent': 'Colab Sample (https://tensorflow.org)'}
27
+ response = requests.get(img_url, headers=user_agent)
28
+ image = Image.open(BytesIO(response.content))
29
+ image = preprocess_image(image)
30
+ return image
31
+
32
+ def load_image(image_url, image_size=256, dynamic_size=False, max_dynamic_size=512):
33
+ """Loads and preprocesses images."""
34
+ # Cache image file locally.
35
+ if image_url in original_image_cache:
36
+ img = original_image_cache[image_url]
37
+ elif image_url.startswith('https://'):
38
+ img = load_image_from_url(image_url)
39
+ else:
40
+ fd = tf.io.gfile.GFile(image_url, 'rb')
41
+ img = preprocess_image(Image.open(fd))
42
+ original_image_cache[image_url] = img
43
+ # Load and convert to float32 numpy array, add batch dimension, and normalize to range [0, 1].
44
+ img_raw = img
45
+ if tf.reduce_max(img) > 1.0:
46
+ img = img / 255.
47
+ if len(img.shape) == 3:
48
+ img = tf.stack([img, img, img], axis=-1)
49
+ if not dynamic_size:
50
+ img = tf.image.resize_with_pad(img, image_size, image_size)
51
+ elif img.shape[1] > max_dynamic_size or img.shape[2] > max_dynamic_size:
52
+ img = tf.image.resize_with_pad(img, max_dynamic_size, max_dynamic_size)
53
+ return img, img_raw
54
+
55
+
56
+
57
+ image_size = 224
58
+ dynamic_size = False
59
+
60
+ model_name = "efficientnet_b2"
61
+
62
+ model_handle_map = {
63
+ "efficientnetv2-s": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_s/classification/2",
64
+ "efficientnetv2-m": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_m/classification/2",
65
+ "efficientnetv2-l": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_l/classification/2",
66
+ "efficientnetv2-s-21k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_s/classification/2",
67
+ "efficientnetv2-m-21k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_m/classification/2",
68
+ "efficientnetv2-l-21k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_l/classification/2",
69
+ "efficientnetv2-xl-21k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_xl/classification/2",
70
+ "efficientnetv2-b0-21k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_b0/classification/2",
71
+ "efficientnetv2-b1-21k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_b1/classification/2",
72
+ "efficientnetv2-b2-21k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_b2/classification/2",
73
+ "efficientnetv2-b3-21k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_b3/classification/2",
74
+ "efficientnetv2-s-21k-ft1k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_s/classification/2",
75
+ "efficientnetv2-m-21k-ft1k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_m/classification/2",
76
+ "efficientnetv2-l-21k-ft1k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_l/classification/2",
77
+ "efficientnetv2-xl-21k-ft1k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_xl/classification/2",
78
+ "efficientnetv2-b0-21k-ft1k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_b0/classification/2",
79
+ "efficientnetv2-b1-21k-ft1k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_b1/classification/2",
80
+ "efficientnetv2-b2-21k-ft1k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_b2/classification/2",
81
+ "efficientnetv2-b3-21k-ft1k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_b3/classification/2",
82
+ "efficientnetv2-b0": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_b0/classification/2",
83
+ "efficientnetv2-b1": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_b1/classification/2",
84
+ "efficientnetv2-b2": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_b2/classification/2",
85
+ "efficientnetv2-b3": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_b3/classification/2",
86
+ "efficientnet_b0": "https://tfhub.dev/tensorflow/efficientnet/b0/classification/1",
87
+ "efficientnet_b1": "https://tfhub.dev/tensorflow/efficientnet/b1/classification/1",
88
+ "efficientnet_b2": "https://tfhub.dev/tensorflow/efficientnet/b2/classification/1",
89
+ "efficientnet_b3": "https://tfhub.dev/tensorflow/efficientnet/b3/classification/1",
90
+ "efficientnet_b4": "https://tfhub.dev/tensorflow/efficientnet/b4/classification/1",
91
+ "efficientnet_b5": "https://tfhub.dev/tensorflow/efficientnet/b5/classification/1",
92
+ "efficientnet_b6": "https://tfhub.dev/tensorflow/efficientnet/b6/classification/1",
93
+ "efficientnet_b7": "https://tfhub.dev/tensorflow/efficientnet/b7/classification/1",
94
+ "bit_s-r50x1": "https://tfhub.dev/google/bit/s-r50x1/ilsvrc2012_classification/1",
95
+ "inception_v3": "https://tfhub.dev/google/imagenet/inception_v3/classification/4",
96
+ "inception_resnet_v2": "https://tfhub.dev/google/imagenet/inception_resnet_v2/classification/4",
97
+ "resnet_v1_50": "https://tfhub.dev/google/imagenet/resnet_v1_50/classification/4",
98
+ "resnet_v1_101": "https://tfhub.dev/google/imagenet/resnet_v1_101/classification/4",
99
+ "resnet_v1_152": "https://tfhub.dev/google/imagenet/resnet_v1_152/classification/4",
100
+ "resnet_v2_50": "https://tfhub.dev/google/imagenet/resnet_v2_50/classification/4",
101
+ "resnet_v2_101": "https://tfhub.dev/google/imagenet/resnet_v2_101/classification/4",
102
+ "resnet_v2_152": "https://tfhub.dev/google/imagenet/resnet_v2_152/classification/4",
103
+ "nasnet_large": "https://tfhub.dev/google/imagenet/nasnet_large/classification/4",
104
+ "nasnet_mobile": "https://tfhub.dev/google/imagenet/nasnet_mobile/classification/4",
105
+ "pnasnet_large": "https://tfhub.dev/google/imagenet/pnasnet_large/classification/4",
106
+ "mobilenet_v2_100_224": "https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/classification/4",
107
+ "mobilenet_v2_130_224": "https://tfhub.dev/google/imagenet/mobilenet_v2_130_224/classification/4",
108
+ "mobilenet_v2_140_224": "https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/classification/4",
109
+ "mobilenet_v3_small_100_224": "https://tfhub.dev/google/imagenet/mobilenet_v3_small_100_224/classification/5",
110
+ "mobilenet_v3_small_075_224": "https://tfhub.dev/google/imagenet/mobilenet_v3_small_075_224/classification/5",
111
+ "mobilenet_v3_large_100_224": "https://tfhub.dev/google/imagenet/mobilenet_v3_large_100_224/classification/5",
112
+ "mobilenet_v3_large_075_224": "https://tfhub.dev/google/imagenet/mobilenet_v3_large_075_224/classification/5",
113
+ }
114
+
115
+ model_image_size_map = {
116
+ "efficientnetv2-s": 384,
117
+ "efficientnetv2-m": 480,
118
+ "efficientnetv2-l": 480,
119
+ "efficientnetv2-b0": 224,
120
+ "efficientnetv2-b1": 240,
121
+ "efficientnetv2-b2": 260,
122
+ "efficientnetv2-b3": 300,
123
+ "efficientnetv2-s-21k": 384,
124
+ "efficientnetv2-m-21k": 480,
125
+ "efficientnetv2-l-21k": 480,
126
+ "efficientnetv2-xl-21k": 512,
127
+ "efficientnetv2-b0-21k": 224,
128
+ "efficientnetv2-b1-21k": 240,
129
+ "efficientnetv2-b2-21k": 260,
130
+ "efficientnetv2-b3-21k": 300,
131
+ "efficientnetv2-s-21k-ft1k": 384,
132
+ "efficientnetv2-m-21k-ft1k": 480,
133
+ "efficientnetv2-l-21k-ft1k": 480,
134
+ "efficientnetv2-xl-21k-ft1k": 512,
135
+ "efficientnetv2-b0-21k-ft1k": 224,
136
+ "efficientnetv2-b1-21k-ft1k": 240,
137
+ "efficientnetv2-b2-21k-ft1k": 260,
138
+ "efficientnetv2-b3-21k-ft1k": 300,
139
+ "efficientnet_b0": 224,
140
+ "efficientnet_b1": 240,
141
+ "efficientnet_b2": 260,
142
+ "efficientnet_b3": 300,
143
+ "efficientnet_b4": 380,
144
+ "efficientnet_b5": 456,
145
+ "efficientnet_b6": 528,
146
+ "efficientnet_b7": 600,
147
+ "inception_v3": 299,
148
+ "inception_resnet_v2": 299,
149
+ "mobilenet_v2_100_224": 224,
150
+ "mobilenet_v2_130_224": 224,
151
+ "mobilenet_v2_140_224": 224,
152
+ "nasnet_large": 331,
153
+ "nasnet_mobile": 224,
154
+ "pnasnet_large": 331,
155
+ "resnet_v1_50": 224,
156
+ "resnet_v1_101": 224,
157
+ "resnet_v1_152": 224,
158
+ "resnet_v2_50": 224,
159
+ "resnet_v2_101": 224,
160
+ "resnet_v2_152": 224,
161
+ "mobilenet_v3_small_100_224": 224,
162
+ "mobilenet_v3_small_075_224": 224,
163
+ "mobilenet_v3_large_100_224": 224,
164
+ "mobilenet_v3_large_075_224": 224,
165
+ }
166
+
167
+ model_handle = model_handle_map[model_name]
168
+
169
+
170
+ max_dynamic_size = 512
171
+ if model_name in model_image_size_map:
172
+ image_size = model_image_size_map[model_name]
173
+ dynamic_size = False
174
+ print(f"Images will be converted to {image_size}x{image_size}")
175
+ else:
176
+ dynamic_size = True
177
+ print(f"Images will be capped to a max size of {max_dynamic_size}x{max_dynamic_size}")
178
+
179
+ labels_file = "https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt"
180
+
181
+ #download labels and creates a maps
182
+ downloaded_file = tf.keras.utils.get_file("labels.txt", origin=labels_file)
183
+
184
+ classes = []
185
+
186
+ with open(downloaded_file) as f:
187
+ labels = f.readlines()
188
+ classes = [l.strip() for l in labels]
189
+
190
+
191
+ classifier = hub.load(model_handle)
192
+
193
+
194
+ def inference(img):
195
+ image, original_image = load_image(img, image_size, dynamic_size, max_dynamic_size)
196
+
197
+
198
+ input_shape = image.shape
199
+ warmup_input = tf.random.uniform(input_shape, 0, 1.0)
200
+ warmup_logits = classifier(warmup_input).numpy()
201
+
202
+ # Run model on image
203
+ probabilities = tf.nn.softmax(classifier(image)).numpy()
204
+
205
+ top_5 = tf.argsort(probabilities, axis=-1, direction="DESCENDING")[0][:5].numpy()
206
+ np_classes = np.array(classes)
207
+
208
+ # Some models include an additional 'background' class in the predictions, so
209
+ # we must account for this when reading the class labels.
210
+ includes_background_class = probabilities.shape[1] == 1001
211
+ result = {}
212
+ for i, item in enumerate(top_5):
213
+ class_index = item if includes_background_class else item + 1
214
+ line = f'({i+1}) {class_index:4} - {classes[class_index]}: {probabilities[0][top_5][i]}'
215
+ result[classes[class_index]] = probabilities[0][top_5][i].item()
216
+ return result
217
+
218
+ title="efficientnet_b2"
219
+ description="Gradio Demo for efficientnet_b2: Imagenet (ILSVRC-2012-CLS) classification with EfficientNet-B2. To use it, simply upload your image or click on one of the examples to load them. Read more at the links below"
220
+
221
+ article = "<p style='text-align: center'><a href='https://tfhub.dev/google/efficientnet/b2/classification/1' target='_blank'>Tensorflow Hub</a></p>"
222
+ examples=[['apple1.jpg']]
223
+
224
+ gr.Interface(inference,gr.inputs.Image(type="filepath"),"label",title=title,description=description,article=article,examples=examples).launch(enable_queue=True)