ariG23498 HF staff commited on
Commit
9ac731e
1 Parent(s): b812a94

add: plot and bar

Browse files
Files changed (2) hide show
  1. app.py +32 -18
  2. utilities/load_model.py +19 -0
app.py CHANGED
@@ -1,27 +1,34 @@
1
  # import the necessary packages
2
  from utilities import config
 
3
  from tensorflow.keras import layers
4
- from tensorflow import keras
5
  import tensorflow as tf
6
  import matplotlib.pyplot as plt
7
  import math
8
  import gradio as gr
9
 
10
  # load the models from disk
11
- conv_stem = keras.models.load_model(
12
- config.IMAGENETTE_STEM_PATH,
13
- compile=False
14
- )
15
- conv_trunk = keras.models.load_model(
16
- config.IMAGENETTE_TRUNK_PATH,
17
- compile=False
18
- )
19
- conv_attn = keras.models.load_model(
20
- config.IMAGENETTE_ATTN_PATH,
21
- compile=False
22
  )
23
 
24
- def plot_attention(image):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  # resize the image to a 224, 224 dim
26
  image = tf.image.convert_image_dtype(image, tf.float32)
27
  image = tf.image.resize(image, (224, 224))
@@ -32,7 +39,7 @@ def plot_attention(image):
32
  # pass through the trunk
33
  test_x = conv_trunk(test_x)
34
  # pass through the attention pooling block
35
- _, test_viz_weights = conv_attn(test_x)
36
  test_viz_weights = test_viz_weights[tf.newaxis, ...]
37
 
38
  # reshape the vizualization weights
@@ -52,9 +59,16 @@ def plot_attention(image):
52
  extent=img.get_extent()
53
  )
54
  plt.axis("off")
55
- return plt
 
 
 
56
 
57
  iface = gr.Interface(
58
- fn=plot_attention,
59
- inputs=[gr.inputs.Image(label="Input Image")],
60
- outputs="image").launch()
 
 
 
 
 
1
  # import the necessary packages
2
  from utilities import config
3
+ from utilities import load_model
4
  from tensorflow.keras import layers
 
5
  import tensorflow as tf
6
  import matplotlib.pyplot as plt
7
  import math
8
  import gradio as gr
9
 
10
  # load the models from disk
11
+ (conv_stem, conv_trunk, conv_attn) = load_model.loader(
12
+ stem=config.IMAGENETTE_STEM_PATH,
13
+ trunk=config.IMAGENETTE_TRUNK_PATH,
14
+ attn=config.IMAGENETTE_ATTN_PATH,
 
 
 
 
 
 
 
15
  )
16
 
17
+ # load labels
18
+ labels = [
19
+ 'tench',
20
+ 'english springer',
21
+ 'cassette player',
22
+ 'chain saw',
23
+ 'church',
24
+ 'french horn',
25
+ 'garbage truck',
26
+ 'gas pump',
27
+ 'golf ball',
28
+ 'parachute'
29
+ ]
30
+
31
+ def get_results(image):
32
  # resize the image to a 224, 224 dim
33
  image = tf.image.convert_image_dtype(image, tf.float32)
34
  image = tf.image.resize(image, (224, 224))
 
39
  # pass through the trunk
40
  test_x = conv_trunk(test_x)
41
  # pass through the attention pooling block
42
+ logits, test_viz_weights = conv_attn(test_x)
43
  test_viz_weights = test_viz_weights[tf.newaxis, ...]
44
 
45
  # reshape the vizualization weights
 
59
  extent=img.get_extent()
60
  )
61
  plt.axis("off")
62
+
63
+ prediction = tf.nn.softmax(logits, axis=-1)
64
+
65
+ return plt, {labels[i]: float(prediction[i]) for i in range(10)}
66
 
67
  iface = gr.Interface(
68
+ fn=get_results,
69
+ inputs=gr.inputs.Image(label="Input Image"),
70
+ outputs=[
71
+ gr.outputs.Image(label="Attention Map"),
72
+ gr.outputs.Label(num_top_classes=10)
73
+ ]
74
+ ).launch()
utilities/load_model.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import the necessary packages
2
+ from tensorflow import keras
3
+
4
+ def loader(stem, trunk, attn):
5
+ # load the models from disk
6
+ conv_stem = keras.models.load_model(
7
+ stem,
8
+ compile=False
9
+ )
10
+ conv_trunk = keras.models.load_model(
11
+ trunk,
12
+ compile=False
13
+ )
14
+ conv_attn = keras.models.load_model(
15
+ attn,
16
+ compile=False
17
+ )
18
+
19
+ return (conv_stem, conv_trunk, conv_attn)