ariG23498 commited on
Commit
310a06c
1 Parent(s): 9454924

fix: reformat code for imagenette

Browse files
.gitattributes CHANGED
@@ -25,3 +25,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
  *.zstandard filter=lfs diff=lfs merge=lfs -text
27
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
  *.zstandard filter=lfs diff=lfs merge=lfs -text
27
  *tfevents* filter=lfs diff=lfs merge=lfs -text
28
+
29
+ # for macOS
30
+ .DS_Store
app.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import the necessary packages
2
+ from utilities import config
3
+ from utilities import model
4
+ from utilities import visualization
5
+ from tensorflow import keras
6
+ import gradio as gr
7
+
8
+ # load the models from disk
9
+ conv_stem = keras.models.load_model(
10
+ config.IMAGENETTE_STEM_PATH,
11
+ compile=False
12
+ )
13
+ conv_trunk = keras.models.load_model(
14
+ config.IMAGENETTE_TRUNK_PATH,
15
+ compile=False
16
+ )
17
+ conv_attn = keras.models.load_model(
18
+ config.IMAGENETTE_ATTN_PATH,
19
+ compile=False
20
+ )
21
+
22
+ # create the patch conv net
23
+ patch_conv_net = model.PatchConvNet(
24
+ stem=conv_stem,
25
+ trunk=conv_trunk,
26
+ attention_pooling=conv_attn,
27
+ )
28
+
29
+ # get the plot attention function
30
+ plot_attention = visualization.PlotAttention(model=patch_conv_net)
31
+ iface = gr.Interface(
32
+ fn=plot_attention,
33
+ inputs=[gr.inputs.Image(label="Input Image")],
34
+ outputs="image").launch()
model/att_pool/keras_metadata.pb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fefb606f0aeb214dcfa9cf9786955f6b7ecb7bdd116e007c44264af75936bfca
3
+ size 15848
model/att_pool/saved_model.pb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0057c7816b0d297c8cdb02958c1a0044220eea870f923e38ba2c1a7fb01c9448
3
+ size 324550
model/att_pool/variables/variables.data-00000-of-00001 ADDED
Binary file (1.61 MB). View file
 
model/att_pool/variables/variables.index ADDED
Binary file (1.38 kB). View file
 
model/stem/keras_metadata.pb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e41ea234e1a8f95c77c1308242971d818b50ec3212d3b1f1f34d9042d77f1270
3
+ size 11998
model/stem/saved_model.pb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee90a781360388d5b7e2035be35b349f4209a6e58881bbc02e2c01cea8130877
3
+ size 96815
model/stem/variables/variables.data-00000-of-00001 ADDED
Binary file (1.56 MB). View file
 
model/stem/variables/variables.index ADDED
Binary file (667 Bytes). View file
 
model/trunk/keras_metadata.pb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c090252e37268009bf4e4fc8866b2984bfdfbe88341ddb54dea464c9eb365bc4
3
+ size 23883
model/trunk/saved_model.pb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:255bf760cd4df017b1d4ff670f889ac9c5985d9b41d1d2d1d401c3030797164b
3
+ size 359160
model/trunk/variables/variables.data-00000-of-00001 ADDED
Binary file (2.96 MB). View file
 
model/trunk/variables/variables.index ADDED
Binary file (733 Bytes). View file
 
utilities/config.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # import the necessary packages
2
+ import os
3
+
4
+ # define the path to the model
5
+ MODEL_PATH = "model"
6
+ IMAGENETTE_ATTN_PATH = os.path.join(MODEL_PATH, "att_pool")
7
+ IMAGENETTE_STEM_PATH = os.path.join(MODEL_PATH, "stem")
8
+ IMAGENETTE_TRUNK_PATH = os.path.join(MODEL_PATH, "trunk")
utilities/model.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import the necessary packages
2
+ from tensorflow import keras
3
+ import tensorflow as tf
4
+
5
+ # Patch conv
6
+ class PatchConvNet(keras.Model):
7
+ def __init__(
8
+ self,
9
+ stem,
10
+ trunk,
11
+ attention_pooling,
12
+ **kwargs,
13
+ ):
14
+ super().__init__(**kwargs)
15
+ self.stem = stem
16
+ self.trunk = trunk
17
+ self.attention_pooling = attention_pooling
18
+
19
+ @tf.function(
20
+ input_signature=[
21
+ tf.TensorSpec(shape=[None, None, None, 3], dtype=tf.uint8)
22
+ ])
23
+ def call(self, images):
24
+ # pass through the stem
25
+ x = self.stem(images)
26
+ # pass through the trunk
27
+ x = self.trunk(x)
28
+ # pass through the attention pooling block
29
+ predictions, viz_weights = self.attention_pooling(x)
30
+ return predictions, viz_weights
utilities/visualization.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import the necessary packages
2
+ from tensorflow.keras import layers
3
+ import tensorflow as tf
4
+ import matplotlib.pyplot as plt
5
+ import math
6
+
7
+ class PlotAttention:
8
+ def __init__(self, model):
9
+ self.model = model
10
+
11
+ def __call__(self, image):
12
+ # resize the image to a 224, 224 dim
13
+ image = tf.image.convert_image_dtype(image, tf.float32)
14
+ image = tf.image.resize(image, (224, 224))
15
+ image = image[tf.newaxis, ...]
16
+
17
+ # pass through the stem
18
+ test_x = self.model.stem(image)
19
+ # pass through the trunk
20
+ test_x = self.model.trunk(test_x)
21
+ # pass through the attention pooling block
22
+ _, test_viz_weights = self.model.attention_pooling(test_x)
23
+ test_viz_weights = test_viz_weights[tf.newaxis, ...]
24
+
25
+ # reshape the vizualization weights
26
+ num_patches = tf.shape(test_viz_weights)[-1]
27
+ height = width = int(math.sqrt(num_patches))
28
+ test_viz_weights = layers.Reshape((height, width))(test_viz_weights)
29
+
30
+ index = 0
31
+ selected_image = image[index]
32
+ selected_weight = test_viz_weights[index]
33
+
34
+ fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
35
+ ax[0].imshow(selected_image)
36
+ ax[0].set_title(f"Original")
37
+ ax[0].axis("off")
38
+
39
+ img = ax[1].imshow(selected_image)
40
+ ax[1].imshow(selected_weight, cmap='inferno', alpha=0.6, extent=img.get_extent())
41
+ ax[1].set_title(f"Attended")
42
+ ax[1].axis("off")
43
+
44
+ plt.axis("off")
45
+ return plt