Spaces:
Runtime error
Runtime error
chore: house cleaning
Browse files- app.py +39 -9
- utilities/model.py +0 -30
- utilities/visualization.py +0 -45
app.py
CHANGED
@@ -1,8 +1,11 @@
|
|
1 |
# import the necessary packages
|
2 |
from utilities import config
|
3 |
from utilities import model
|
4 |
-
from utilities import visualization
|
5 |
from tensorflow import keras
|
|
|
|
|
|
|
|
|
6 |
import gradio as gr
|
7 |
|
8 |
# load the models from disk
|
@@ -19,15 +22,42 @@ conv_attn = keras.models.load_model(
|
|
19 |
compile=False
|
20 |
)
|
21 |
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
-
# get the plot attention function
|
30 |
-
plot_attention = visualization.PlotAttention(model=patch_conv_net)
|
31 |
iface = gr.Interface(
|
32 |
fn=plot_attention,
|
33 |
inputs=[gr.inputs.Image(label="Input Image")],
|
|
|
1 |
# import the necessary packages
|
2 |
from utilities import config
|
3 |
from utilities import model
|
|
|
4 |
from tensorflow import keras
|
5 |
+
from tensorflow.keras import layers
|
6 |
+
import tensorflow as tf
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
import math
|
9 |
import gradio as gr
|
10 |
|
11 |
# load the models from disk
|
|
|
22 |
compile=False
|
23 |
)
|
24 |
|
25 |
+
def plot_attention(image):
|
26 |
+
# resize the image to a 224, 224 dim
|
27 |
+
image = tf.image.convert_image_dtype(image, tf.float32)
|
28 |
+
image = tf.image.resize(image, (224, 224))
|
29 |
+
image = image[tf.newaxis, ...]
|
30 |
+
|
31 |
+
# pass through the stem
|
32 |
+
test_x = conv_stem(image)
|
33 |
+
# pass through the trunk
|
34 |
+
test_x = conv_trunk(test_x)
|
35 |
+
# pass through the attention pooling block
|
36 |
+
_, test_viz_weights = conv_attn(test_x)
|
37 |
+
test_viz_weights = test_viz_weights[tf.newaxis, ...]
|
38 |
+
|
39 |
+
# reshape the vizualization weights
|
40 |
+
num_patches = tf.shape(test_viz_weights)[-1]
|
41 |
+
height = width = int(math.sqrt(num_patches))
|
42 |
+
test_viz_weights = layers.Reshape((height, width))(test_viz_weights)
|
43 |
+
|
44 |
+
index = 0
|
45 |
+
selected_image = image[index]
|
46 |
+
selected_weight = test_viz_weights[index]
|
47 |
+
|
48 |
+
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
|
49 |
+
ax[0].imshow(selected_image)
|
50 |
+
ax[0].set_title(f"Original")
|
51 |
+
ax[0].axis("off")
|
52 |
+
|
53 |
+
img = ax[1].imshow(selected_image)
|
54 |
+
ax[1].imshow(selected_weight, cmap='inferno', alpha=0.6, extent=img.get_extent())
|
55 |
+
ax[1].set_title(f"Attended")
|
56 |
+
ax[1].axis("off")
|
57 |
+
|
58 |
+
plt.axis("off")
|
59 |
+
return plt
|
60 |
|
|
|
|
|
61 |
iface = gr.Interface(
|
62 |
fn=plot_attention,
|
63 |
inputs=[gr.inputs.Image(label="Input Image")],
|
utilities/model.py
DELETED
@@ -1,30 +0,0 @@
|
|
1 |
-
# import the necessary packages
|
2 |
-
from tensorflow import keras
|
3 |
-
import tensorflow as tf
|
4 |
-
|
5 |
-
# Patch conv
|
6 |
-
class PatchConvNet(keras.Model):
|
7 |
-
def __init__(
|
8 |
-
self,
|
9 |
-
stem,
|
10 |
-
trunk,
|
11 |
-
attention_pooling,
|
12 |
-
**kwargs,
|
13 |
-
):
|
14 |
-
super().__init__(**kwargs)
|
15 |
-
self.stem = stem
|
16 |
-
self.trunk = trunk
|
17 |
-
self.attention_pooling = attention_pooling
|
18 |
-
|
19 |
-
@tf.function(
|
20 |
-
input_signature=[
|
21 |
-
tf.TensorSpec(shape=[None, None, None, 3], dtype=tf.uint8)
|
22 |
-
])
|
23 |
-
def call(self, images):
|
24 |
-
# pass through the stem
|
25 |
-
x = self.stem(images)
|
26 |
-
# pass through the trunk
|
27 |
-
x = self.trunk(x)
|
28 |
-
# pass through the attention pooling block
|
29 |
-
predictions, viz_weights = self.attention_pooling(x)
|
30 |
-
return predictions, viz_weights
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utilities/visualization.py
DELETED
@@ -1,45 +0,0 @@
|
|
1 |
-
# import the necessary packages
|
2 |
-
from tensorflow.keras import layers
|
3 |
-
import tensorflow as tf
|
4 |
-
import matplotlib.pyplot as plt
|
5 |
-
import math
|
6 |
-
|
7 |
-
class PlotAttention(object):
|
8 |
-
def __init__(self, model):
|
9 |
-
self.model = model
|
10 |
-
|
11 |
-
def __call__(self, image):
|
12 |
-
# resize the image to a 224, 224 dim
|
13 |
-
image = tf.image.convert_image_dtype(image, tf.float32)
|
14 |
-
image = tf.image.resize(image, (224, 224))
|
15 |
-
image = image[tf.newaxis, ...]
|
16 |
-
|
17 |
-
# pass through the stem
|
18 |
-
test_x = self.model.stem(image)
|
19 |
-
# pass through the trunk
|
20 |
-
test_x = self.model.trunk(test_x)
|
21 |
-
# pass through the attention pooling block
|
22 |
-
_, test_viz_weights = self.model.attention_pooling(test_x)
|
23 |
-
test_viz_weights = test_viz_weights[tf.newaxis, ...]
|
24 |
-
|
25 |
-
# reshape the vizualization weights
|
26 |
-
num_patches = tf.shape(test_viz_weights)[-1]
|
27 |
-
height = width = int(math.sqrt(num_patches))
|
28 |
-
test_viz_weights = layers.Reshape((height, width))(test_viz_weights)
|
29 |
-
|
30 |
-
index = 0
|
31 |
-
selected_image = image[index]
|
32 |
-
selected_weight = test_viz_weights[index]
|
33 |
-
|
34 |
-
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
|
35 |
-
ax[0].imshow(selected_image)
|
36 |
-
ax[0].set_title(f"Original")
|
37 |
-
ax[0].axis("off")
|
38 |
-
|
39 |
-
img = ax[1].imshow(selected_image)
|
40 |
-
ax[1].imshow(selected_weight, cmap='inferno', alpha=0.6, extent=img.get_extent())
|
41 |
-
ax[1].set_title(f"Attended")
|
42 |
-
ax[1].axis("off")
|
43 |
-
|
44 |
-
plt.axis("off")
|
45 |
-
return plt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|