ariG23498 HF staff committed on
Commit
0f130d4
1 Parent(s): 0469f00

chore: house cleaning

Browse files
Files changed (3) hide show
  1. app.py +39 -9
  2. utilities/model.py +0 -30
  3. utilities/visualization.py +0 -45
app.py CHANGED
@@ -1,8 +1,11 @@
1
  # import the necessary packages
2
  from utilities import config
3
  from utilities import model
4
- from utilities import visualization
5
  from tensorflow import keras
 
 
 
 
6
  import gradio as gr
7
 
8
  # load the models from disk
@@ -19,15 +22,42 @@ conv_attn = keras.models.load_model(
19
  compile=False
20
  )
21
 
22
- # create the patch conv net
23
- patch_conv_net = model.PatchConvNet(
24
- stem=conv_stem,
25
- trunk=conv_trunk,
26
- attention_pooling=conv_attn,
27
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- # get the plot attention function
30
- plot_attention = visualization.PlotAttention(model=patch_conv_net)
31
  iface = gr.Interface(
32
  fn=plot_attention,
33
  inputs=[gr.inputs.Image(label="Input Image")],
 
1
  # import the necessary packages
2
  from utilities import config
3
  from utilities import model
 
4
  from tensorflow import keras
5
+ from tensorflow.keras import layers
6
+ import tensorflow as tf
7
+ import matplotlib.pyplot as plt
8
+ import math
9
  import gradio as gr
10
 
11
  # load the models from disk
 
22
  compile=False
23
  )
24
 
25
def plot_attention(image):
    """Visualize the attention-pooling weights of the model for one image.

    Runs the image through the three module-level models (``conv_stem`` ->
    ``conv_trunk`` -> ``conv_attn``), folds the per-patch attention weights
    into a square map, and draws the original image next to an attention
    overlay.

    Args:
        image: an H x W x 3 image array; it is resized to 224 x 224.

    Returns:
        The ``matplotlib.pyplot`` module with the rendered comparison as its
        current figure (consumed by the gradio interface as a plot output).
    """
    # normalize to float32 and resize to the models' expected input size,
    # then add a batch dimension of one
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, (224, 224))
    image = image[tf.newaxis, ...]

    # forward pass: stem -> trunk -> attention pooling
    test_x = conv_stem(image)
    test_x = conv_trunk(test_x)
    # the attention-pooling model returns (predictions, viz_weights);
    # only the weights are needed for the plot
    _, test_viz_weights = conv_attn(test_x)
    test_viz_weights = test_viz_weights[tf.newaxis, ...]

    # reshape the flat per-patch weights into a square attention map
    # (assumes the number of patches is a perfect square -- TODO confirm)
    num_patches = tf.shape(test_viz_weights)[-1]
    height = width = int(math.sqrt(num_patches))
    test_viz_weights = layers.Reshape((height, width))(test_viz_weights)

    # the batch holds exactly one image
    selected_image = image[0]
    selected_weight = test_viz_weights[0]

    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
    ax[0].imshow(selected_image)
    ax[0].set_title("Original")
    ax[0].axis("off")

    # draw the image first, then overlay the semi-transparent attention map
    # aligned to the same extent
    img = ax[1].imshow(selected_image)
    ax[1].imshow(
        selected_weight, cmap="inferno", alpha=0.6, extent=img.get_extent()
    )
    ax[1].set_title("Attended")
    ax[1].axis("off")

    plt.axis("off")
    return plt
60
 
 
 
61
  iface = gr.Interface(
62
  fn=plot_attention,
63
  inputs=[gr.inputs.Image(label="Input Image")],
utilities/model.py DELETED
@@ -1,30 +0,0 @@
1
- # import the necessary packages
2
- from tensorflow import keras
3
- import tensorflow as tf
4
-
5
# Patch conv net: wraps the three exported sub-models into one forward pass
class PatchConvNet(keras.Model):
    """Composite model: stem -> trunk -> attention pooling.

    Each stage is a pre-built Keras model passed in at construction time;
    this class only chains them and exposes a serializable ``call``.
    """

    def __init__(
        self,
        stem,
        trunk,
        attention_pooling,
        **kwargs,
    ):
        """Store the three sub-models.

        Args:
            stem: patchifying/embedding front-end model.
            trunk: main feature-extraction model.
            attention_pooling: head returning (predictions, viz_weights).
            **kwargs: forwarded to ``keras.Model``.
        """
        super().__init__(**kwargs)
        self.stem = stem
        self.trunk = trunk
        self.attention_pooling = attention_pooling

    # fixed input signature so the traced function accepts uint8 image
    # batches of any spatial size with 3 channels
    @tf.function(
        input_signature=[
            tf.TensorSpec(shape=[None, None, None, 3], dtype=tf.uint8)
        ])
    def call(self, images):
        """Run the full pipeline; returns (predictions, viz_weights)."""
        # pass through the stem
        x = self.stem(images)
        # pass through the trunk
        x = self.trunk(x)
        # pass through the attention pooling block
        predictions, viz_weights = self.attention_pooling(x)
        return predictions, viz_weights
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utilities/visualization.py DELETED
@@ -1,45 +0,0 @@
1
- # import the necessary packages
2
- from tensorflow.keras import layers
3
- import tensorflow as tf
4
- import matplotlib.pyplot as plt
5
- import math
6
-
7
class PlotAttention(object):
    """Callable that plots a model's attention map over an input image.

    The wrapped model must expose ``stem``, ``trunk`` and
    ``attention_pooling`` sub-models (a ``PatchConvNet``).
    """

    def __init__(self, model):
        # keep a reference to the composite model whose stages we call
        self.model = model

    def __call__(self, image):
        """Render the image and its attention overlay; returns pyplot."""
        # preprocess: float32 values, 224 x 224, batch of one
        batch = tf.image.convert_image_dtype(image, tf.float32)
        batch = tf.image.resize(batch, (224, 224))
        batch = batch[tf.newaxis, ...]

        # forward pass through the three stages of the network
        features = self.model.stem(batch)
        features = self.model.trunk(features)
        _, attn_weights = self.model.attention_pooling(features)
        attn_weights = attn_weights[tf.newaxis, ...]

        # fold the flat per-patch weights into a square map
        patch_count = tf.shape(attn_weights)[-1]
        side = int(math.sqrt(patch_count))
        attn_weights = layers.Reshape((side, side))(attn_weights)

        # single-image batch: pick the first (and only) entry
        index = 0
        selected_image = batch[index]
        selected_weight = attn_weights[index]

        # left panel: the plain image; right panel: attention overlay
        fig, (orig_ax, attn_ax) = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
        orig_ax.imshow(selected_image)
        orig_ax.set_title(f"Original")
        orig_ax.axis("off")

        img = attn_ax.imshow(selected_image)
        attn_ax.imshow(selected_weight, cmap='inferno', alpha=0.6, extent=img.get_extent())
        attn_ax.set_title(f"Attended")
        attn_ax.axis("off")

        plt.axis("off")
        return plt