|
import math |
|
import os |
|
import sys |
|
import tempfile |
|
|
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import streamlit as st |
|
from PIL import Image |
|
from tensorflow.keras import layers, models |
|
|
|
|
|
# Default single-panel figure size; the grid plots further down multiply it by
# their number of columns / rows.
fig_size = plt.rcParams['figure.figsize']

# Draw every imshow() in grayscale unless a colormap is passed explicitly.
plt.rcParams.update({'image.cmap': 'gray'})

st.title("Filters and Feature Maps Visualization")

# Uploaded HDF5 model file; None until the user picks a file.
model = st.file_uploader("Upload model", type=["h5"])
|
|
|
def _grid_shape(count):
    """Return ``(nrows, ncols)`` for a near-square grid holding ``count`` panels.

    The column count is ``ceil(sqrt(count))``; rows are only as many as needed,
    so 5 panels get a 2x3 grid instead of a 3x3 grid with an empty bottom row.
    Always returns at least a 1x1 grid.
    """
    ncols = max(1, math.ceil(math.sqrt(count)))
    nrows = max(1, math.ceil(count / ncols))
    return nrows, ncols


def _plot_grid(stack, title_prefix):
    """Render each slice ``stack[:, :, i]`` in its own subplot and show it in Streamlit.

    Args:
        stack: array whose last axis enumerates the panels (filters or
            feature-map channels); each ``stack[:, :, i]`` is drawn with imshow.
        title_prefix: subplot title text placed before the 1-based panel number.
    """
    count = stack.shape[-1]
    nrows, ncols = _grid_shape(count)
    # squeeze=False keeps a 2-D array of Axes even for a 1x1 grid; otherwise
    # plt.subplots returns a bare Axes, which has no .flatten().
    fig, axes = plt.subplots(nrows, ncols, squeeze=False,
                             figsize=(fig_size[0] * ncols, fig_size[1] * nrows))
    axes = axes.flatten()
    for i in range(count):
        axes[i].imshow(stack[:, :, i])
        axes[i].set(xticklabels=[], yticklabels=[], title=f"{title_prefix} {i + 1}")
    # Drop the unused trailing cells of the last row.
    for unused in axes[count:]:
        unused.remove()
    fig.tight_layout()
    st.pyplot(fig)
    # Streamlit re-runs this script on every widget interaction; without an
    # explicit close, pyplot keeps every figure from earlier runs open.
    plt.close(fig)


def _fit_to_model_input(pil_img, input_shape):
    """Convert/resize a PIL image so it matches a channels-last 4-D model input.

    Only a single ``input_shape`` of ``(batch, height, width, channels)`` is
    adapted, which is the layout the batching code below builds. A 1- or
    3-channel input converts the image to "L" or "RGB", so RGBA/palette PNGs
    and colour JPEGs still fit. A fixed height/width resizes the image. Any
    other ``input_shape`` returns the image unchanged.
    NOTE(review): assumes channels-last data format; confirm for models built
    with ``data_format="channels_first"``.
    """
    if not (isinstance(input_shape, tuple) and len(input_shape) == 4):
        return pil_img
    _, height, width, channels = input_shape
    if channels == 1:
        pil_img = pil_img.convert("L")
    elif channels == 3:
        pil_img = pil_img.convert("RGB")
    if height is not None and width is not None:
        # PIL sizes are (width, height); Keras shapes are (height, width).
        pil_img = pil_img.resize((width, height))
    return pil_img


if model:
    # load_model needs a real file path, so spill the upload into a temporary
    # directory that is deleted once the model has been read.
    with tempfile.TemporaryDirectory() as tempdir:
        model_path = os.path.join(tempdir, "temp.h5")
        with open(model_path, mode='wb') as f:
            f.write(model.getvalue())
        keras_model = models.load_model(model_path)

    viz_option = st.selectbox("What would you like to visualize?",
                              options=["Filters", "Feature Maps"])

    # Positions (in keras_model.layers) of every Conv2D layer.
    conv_indices = [i for i, layer in enumerate(keras_model.layers)
                    if isinstance(layer, layers.Conv2D)]

    if not conv_indices:
        # Without this guard the layer selectboxes return None and
        # keras_model.layers[None] raises a TypeError.
        st.warning("The uploaded model has no Conv2D layers to visualize.")
    elif viz_option.lower() == "filters":
        layer_index = st.selectbox(
            "Select a layer to see its filters", options=conv_indices)

        # Conv2D kernel, indexed as weights[row, col, channel, filter].
        weights = keras_model.layers[layer_index].get_weights()[0]
        num_filters = weights.shape[-1]
        num_channels = weights.shape[-2]

        st.write(
            f"This layer has {num_filters} filters and {num_channels} channels per filter.")

        channel_index = st.selectbox(
            "Which channel would you like to view?", options=range(1, num_channels + 1))

        # One panel per filter, all showing the selected (1-based) input channel.
        _plot_grid(weights[:, :, channel_index - 1, :], "Filter")
    else:
        img = st.file_uploader(label="Upload image", type=['jpg', 'jpeg', 'png'])

        if img:
            pil_img = _fit_to_model_input(Image.open(img), keras_model.input_shape)
            pixels = np.asarray(pil_img)
            st.image(pixels)

            # Grayscale images are 2-D: add a channel axis, then a batch axis.
            if pixels.ndim == 2:
                pixels = np.expand_dims(pixels, axis=-1)
            batch = np.expand_dims(pixels, axis=0)

            layer_index = st.selectbox(
                "Feature Map at which layer?", options=conv_indices)

            temp_model = models.Model(
                inputs=keras_model.inputs, outputs=keras_model.layers[layer_index].output)
            # Take the single batch element rather than np.squeeze: squeeze
            # would also drop a spatial or channel axis of size 1 and break
            # the stack[:, :, i] indexing in _plot_grid.
            output = temp_model.predict(batch)[0]

            _plot_grid(output, "Channel")
|
|