vdprabhu commited on
Commit
d6466d7
1 Parent(s): 38f02a0

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +140 -0
app.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Imports
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+
5
+ import tensorflow as tf
6
+ from tensorflow import keras
7
+
8
+ import streamlit as st
9
+
10
+ from app_utils import *
11
+
12
+ # The functions (except main) are taken straight from Keras Example
13
+ def compute_loss(feature_extractor, input_image, filter_index):
14
+ activation = feature_extractor(input_image)
15
+ # We avoid border artifacts by only involving non-border pixels in the loss.
16
+ filter_activation = activation[:, 2:-2, 2:-2, filter_index]
17
+ return tf.reduce_mean(filter_activation)
18
+
19
+
20
+ @tf.function
21
+ def gradient_ascent_step(feature_extractor, img, filter_index, learning_rate):
22
+ with tf.GradientTape() as tape:
23
+ tape.watch(img)
24
+ loss = compute_loss(feature_extractor, img, filter_index)
25
+ # Compute gradients.
26
+ grads = tape.gradient(loss, img)
27
+ # Normalize gradients.
28
+ grads = tf.math.l2_normalize(grads)
29
+ img += learning_rate * grads
30
+ return loss, img
31
+
32
+
33
+ def initialize_image():
34
+ # We start from a gray image with some random noise
35
+ img = tf.random.uniform((1, IMG_WIDTH, IMG_HEIGHT, 3))
36
+ # ResNet50V2 expects inputs in the range [-1, +1].
37
+ # Here we scale our random inputs to [-0.125, +0.125]
38
+ return (img - 0.5) * 0.25
39
+
40
+
41
+ def visualize_filter(feature_extractor, filter_index):
42
+ # We run gradient ascent for 20 steps
43
+ img = initialize_image()
44
+ for _ in range(ITERATIONS):
45
+ loss, img = gradient_ascent_step(
46
+ feature_extractor, img, filter_index, LEARNING_RATE
47
+ )
48
+
49
+ # Decode the resulting input image
50
+ img = deprocess_image(img[0].numpy())
51
+ return loss, img
52
+
53
+
54
+ def deprocess_image(img):
55
+ # Normalize array: center on 0., ensure variance is 0.15
56
+ img -= img.mean()
57
+ img /= img.std() + 1e-5
58
+ img *= 0.15
59
+
60
+ # Center crop
61
+ img = img[25:-25, 25:-25, :]
62
+
63
+ # Clip to [0, 1]
64
+ img += 0.5
65
+ img = np.clip(img, 0, 1)
66
+
67
+ # Convert to RGB array
68
+ img *= 255
69
+ img = np.clip(img, 0, 255).astype("uint8")
70
+ return img
71
+
72
+
73
+ # The visualization function
74
+ def main():
75
+ # Model selector
76
+ mn_option = st.selectbox("Select the model for visualization -", AVAILABLE_MODELS)
77
+
78
+ # Check to not load the model for ever layer change
79
+ if mn_option != st.session_state.model_name:
80
+ model = getattr(keras.applications, mn_option)(
81
+ weights="imagenet", include_top=False
82
+ )
83
+ st.session_state.layer_list = ["<select layer>"] + [
84
+ layer.name for layer in model.layers
85
+ ]
86
+ st.session_state.model = model
87
+ st.session_state.model_name = mn_option
88
+
89
+ # Layer selector, saves the feature selector in case 64 filters are to be seen
90
+ if st.session_state.model_name:
91
+ ln_option = st.selectbox(
92
+ "Select the target layer (best to pick somewhere in the middle of the model) -",
93
+ st.session_state.layer_list,
94
+ )
95
+ if ln_option != "<select layer>":
96
+ if ln_option != st.session_state.layer_name:
97
+ layer = st.session_state.model.get_layer(name=ln_option)
98
+ st.session_state.feat_extract = keras.Model(
99
+ inputs=st.session_state.model.inputs, outputs=layer.output
100
+ )
101
+ st.session_state.layer_name = ln_option
102
+
103
+ # Filter index selector
104
+ if st.session_state.layer_name:
105
+ filter_select = st.selectbox("Visualize -", VIS_OPTION.keys())
106
+
107
+ if VIS_OPTION[filter_select] == 0:
108
+ loss, img = visualize_filter(st.session_state.feat_extract, 0)
109
+ st.image(img)
110
+ else:
111
+ st.warning(":exclamation: Calculating the gradients can take a while..")
112
+ prog_bar = st.progress(0)
113
+ fig, axis = plt.subplots(nrows=8, ncols=8, figsize=(14, 14))
114
+ for filter_index, ax in enumerate(axis.ravel()):
115
+ prog_bar.progress((filter_index + 1) / 64)
116
+ if filter_index < 65:
117
+ loss, img = visualize_filter(
118
+ st.session_state.feat_extract, filter_index
119
+ )
120
+ ax.imshow(img)
121
+ ax.set_title(filter_index + 1)
122
+ ax.set_axis_off()
123
+ else:
124
+ ax.set_axis_off()
125
+
126
+ st.write(fig)
127
+
128
+
129
+ if __name__ == "__main__":
130
+
131
+ with open("model_names.txt", "r") as op:
132
+ AVAILABLE_MODELS = [i.strip() for i in op.readlines()]
133
+
134
+ st.set_page_config(layout="wide")
135
+
136
+ st.title(title)
137
+ st.write(info_text)
138
+ st.info(f"{credits}\n\n{replicate}\n\n{vit_info}")
139
+
140
+ main()