Spaces:
Runtime error
Runtime error
fix: reformat code for imagenette
Browse files
- .gitattributes +3 -0
- app.py +34 -0
- model/att_pool/keras_metadata.pb +3 -0
- model/att_pool/saved_model.pb +3 -0
- model/att_pool/variables/variables.data-00000-of-00001 +0 -0
- model/att_pool/variables/variables.index +0 -0
- model/stem/keras_metadata.pb +3 -0
- model/stem/saved_model.pb +3 -0
- model/stem/variables/variables.data-00000-of-00001 +0 -0
- model/stem/variables/variables.index +0 -0
- model/trunk/keras_metadata.pb +3 -0
- model/trunk/saved_model.pb +3 -0
- model/trunk/variables/variables.data-00000-of-00001 +0 -0
- model/trunk/variables/variables.index +0 -0
- utilities/config.py +8 -0
- utilities/model.py +30 -0
- utilities/visualization.py +45 -0
.gitattributes
CHANGED
@@ -25,3 +25,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
25 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
25 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
|
29 |
+
# for macOS
|
30 |
+
.DS_Store
|
app.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# import the necessary packages
|
2 |
+
from utilities import config
|
3 |
+
from utilities import model
|
4 |
+
from utilities import visualization
|
5 |
+
from tensorflow import keras
|
6 |
+
import gradio as gr
|
7 |
+
|
8 |
+
# load the three sub-models from disk; compile=False loads them without
# compiling (no optimizer/loss is needed for the forward passes done here)
conv_stem = keras.models.load_model(config.IMAGENETTE_STEM_PATH, compile=False)
conv_trunk = keras.models.load_model(config.IMAGENETTE_TRUNK_PATH, compile=False)
conv_attn = keras.models.load_model(config.IMAGENETTE_ATTN_PATH, compile=False)

# create the patch conv net from the three loaded stages
patch_conv_net = model.PatchConvNet(
    stem=conv_stem,
    trunk=conv_trunk,
    attention_pooling=conv_attn,
)

# get the plot attention function
plot_attention = visualization.PlotAttention(model=patch_conv_net)

# build the interface first and launch it in a separate statement, so that
# `iface` refers to the Interface object and not to the value returned by
# launch()
# NOTE(review): the `gr.inputs` namespace is deprecated in newer Gradio
# releases -- confirm against the Gradio version this Space is pinned to
iface = gr.Interface(
    fn=plot_attention,
    inputs=[gr.inputs.Image(label="Input Image")],
    outputs="image",
)
iface.launch()
|
model/att_pool/keras_metadata.pb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fefb606f0aeb214dcfa9cf9786955f6b7ecb7bdd116e007c44264af75936bfca
|
3 |
+
size 15848
|
model/att_pool/saved_model.pb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0057c7816b0d297c8cdb02958c1a0044220eea870f923e38ba2c1a7fb01c9448
|
3 |
+
size 324550
|
model/att_pool/variables/variables.data-00000-of-00001
ADDED
Binary file (1.61 MB). View file
|
|
model/att_pool/variables/variables.index
ADDED
Binary file (1.38 kB). View file
|
|
model/stem/keras_metadata.pb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e41ea234e1a8f95c77c1308242971d818b50ec3212d3b1f1f34d9042d77f1270
|
3 |
+
size 11998
|
model/stem/saved_model.pb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ee90a781360388d5b7e2035be35b349f4209a6e58881bbc02e2c01cea8130877
|
3 |
+
size 96815
|
model/stem/variables/variables.data-00000-of-00001
ADDED
Binary file (1.56 MB). View file
|
|
model/stem/variables/variables.index
ADDED
Binary file (667 Bytes). View file
|
|
model/trunk/keras_metadata.pb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c090252e37268009bf4e4fc8866b2984bfdfbe88341ddb54dea464c9eb365bc4
|
3 |
+
size 23883
|
model/trunk/saved_model.pb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:255bf760cd4df017b1d4ff670f889ac9c5985d9b41d1d2d1d401c3030797164b
|
3 |
+
size 359160
|
model/trunk/variables/variables.data-00000-of-00001
ADDED
Binary file (2.96 MB). View file
|
|
model/trunk/variables/variables.index
ADDED
Binary file (733 Bytes). View file
|
|
utilities/config.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# import the necessary packages
|
2 |
+
import os
|
3 |
+
|
4 |
+
# root directory that holds the saved sub-models
MODEL_PATH = "model"


def _model_subdir(name):
    """Return the path of the sub-directory *name* inside MODEL_PATH."""
    return os.path.join(MODEL_PATH, name)


# one directory per network stage
IMAGENETTE_STEM_PATH = _model_subdir("stem")
IMAGENETTE_TRUNK_PATH = _model_subdir("trunk")
IMAGENETTE_ATTN_PATH = _model_subdir("att_pool")
|
utilities/model.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# import the necessary packages
|
2 |
+
from tensorflow import keras
|
3 |
+
import tensorflow as tf
|
4 |
+
|
5 |
+
# Patch conv
|
6 |
+
class PatchConvNet(keras.Model):
    """Chain a stem, a trunk and an attention-pooling block into one model.

    Args:
        stem: callable applied first to the input images.
        trunk: callable applied to the output of ``stem``.
        attention_pooling: callable applied to the output of ``trunk``; its
            result is unpacked as a ``(predictions, viz_weights)`` pair.
        **kwargs: forwarded unchanged to ``keras.Model.__init__``.
    """

    def __init__(self, stem, trunk, attention_pooling, **kwargs):
        super().__init__(**kwargs)
        # keep a reference to each of the three stages
        self.stem, self.trunk, self.attention_pooling = (
            stem,
            trunk,
            attention_pooling,
        )

    # traced signature: a rank-4 uint8 tensor whose last axis has size 3
    @tf.function(
        input_signature=[
            tf.TensorSpec(shape=[None, None, None, 3], dtype=tf.uint8),
        ]
    )
    def call(self, images):
        """Run ``images`` through the stem, the trunk and attention pooling.

        Returns:
            The ``(predictions, viz_weights)`` pair produced by the
            attention-pooling block.
        """
        # stem first, then trunk
        trunk_features = self.trunk(self.stem(images))
        # the attention pooling block yields both outputs
        predictions, viz_weights = self.attention_pooling(trunk_features)
        return predictions, viz_weights
|
utilities/visualization.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# import the necessary packages
|
2 |
+
from tensorflow.keras import layers
|
3 |
+
import tensorflow as tf
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import math
|
6 |
+
|
7 |
+
class PlotAttention:
    """Callable that plots an image next to its attention overlay.

    Args:
        model: object exposing ``stem``, ``trunk`` and ``attention_pooling``
            callables; they are invoked in that order.
    """

    def __init__(self, model):
        self.model = model

    def __call__(self, image):
        """Plot ``image`` and the model's attention weights laid over it.

        Args:
            image: image data accepted by ``tf.image.convert_image_dtype``
                (presumably a single H x W x 3 array coming from the UI --
                TODO confirm against the caller).

        Returns:
            The ``matplotlib.pyplot`` module; its current figure holds the
            "Original" and "Attended" panels.
        """
        # Each call creates a new pyplot figure and the caller only receives
        # the pyplot module, so nothing ever closes the figures of earlier
        # calls. Drop them here to keep a long-running server from piling up
        # open figures.
        # NOTE(review): pyplot state is global; this assumes calls are not
        # served concurrently -- confirm.
        plt.close("all")

        # convert to float32, resize to 224x224 and add a leading batch axis
        image = tf.image.convert_image_dtype(image, tf.float32)
        image = tf.image.resize(image, (224, 224))
        image = image[tf.newaxis, ...]

        # pass through the stem
        test_x = self.model.stem(image)
        # pass through the trunk
        test_x = self.model.trunk(test_x)
        # pass through the attention pooling block; only the weights are used
        _, test_viz_weights = self.model.attention_pooling(test_x)
        test_viz_weights = test_viz_weights[tf.newaxis, ...]

        # reshape the visualization weights into a (height, width) grid
        # NOTE: assumes the size of the last axis is a perfect square
        num_patches = tf.shape(test_viz_weights)[-1]
        height = width = int(math.sqrt(num_patches))
        test_viz_weights = layers.Reshape((height, width))(test_viz_weights)

        # a single image was batched above, so take the first entry
        selected_image = image[0]
        selected_weight = test_viz_weights[0]

        _, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))

        # left panel: the resized input image
        ax[0].imshow(selected_image)
        ax[0].set_title("Original")
        ax[0].axis("off")

        # right panel: the same image with the weights drawn on top, stretched
        # to the image extent so both layers line up
        img = ax[1].imshow(selected_image)
        ax[1].imshow(
            selected_weight,
            cmap="inferno",
            alpha=0.6,
            extent=img.get_extent(),
        )
        ax[1].set_title("Attended")
        ax[1].axis("off")

        return plt
|