# Author: JekyllAndHyde8999 -- commit 96c04f1 ("Added files")
import math
import os
import sys
import tempfile
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
from PIL import Image
from tensorflow.keras import layers, models
# get default figure dimensions for matplotlib; every subplot cell in a grid
# below is given this size
fig_size = plt.rcParams['figure.figsize']
# set default color map to grayscale
plt.rcParams['image.cmap'] = 'gray'


def _grid_shape(count):
    """Return ``(nrows, ncols)`` for a near-square grid holding ``count`` cells.

    Columns are ``ceil(sqrt(count))`` and rows are only as many as needed, so
    6 items give a 2x3 grid rather than 3x3. The grid is always at least 1x1.
    """
    ncols = max(1, math.ceil(math.sqrt(count)))
    nrows = max(1, math.ceil(count / ncols))
    return nrows, ncols


def _plot_grid(images, title_prefix):
    """Draw each 2-D array in ``images`` on its own subplot; return the figure.

    Subplot ``k`` (1-based) is titled ``f"{title_prefix} {k}"``, tick labels
    are hidden, and grid cells left over after the last image are removed.
    """
    nrows, ncols = _grid_shape(len(images))
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so a layer
    # with a single filter/channel does not crash on `.flatten()`.
    fig, axes = plt.subplots(
        nrows, ncols, squeeze=False,
        figsize=(fig_size[0] * ncols, fig_size[1] * nrows))
    axes = axes.flatten()
    for i, image in enumerate(images):
        axes[i].imshow(image)
        axes[i].set(xticklabels=[], yticklabels=[],
                    title=f"{title_prefix} {i + 1}")
    # remove excess subplots
    for unused in axes[len(images):]:
        unused.remove()
    fig.tight_layout()
    return fig


def _show(fig):
    """Render ``fig`` in the app, then close it so figures do not accumulate."""
    st.pyplot(fig)
    plt.close(fig)


def _load_uploaded_model(uploaded):
    """Load a Keras model from an uploaded ``.h5`` file via a temporary file."""
    with tempfile.TemporaryDirectory() as tempdir:
        path = os.path.join(tempdir, "temp.h5")
        with open(path, mode='wb') as f:
            # write the model to a temporary file in a temporary folder
            f.write(uploaded.getvalue())
        # Load only after the file is closed (fully written) and before the
        # temporary directory is deleted when the outer `with` exits.
        return models.load_model(path)


def _prepare_image(image, input_shape):
    """Turn a PIL image into a batch of one sample for ``model.predict``.

    When ``input_shape`` is a 4-D tuple -- assumed channels-last
    ``(batch, height, width, channels)``, TODO confirm for the models used --
    the image is converted to grayscale or RGB to match ``channels`` (1 or 3)
    and resized to ``(height, width)`` when those are fixed. Pixel values are
    passed through unscaled; NOTE(review): models trained on normalized input
    may need scaling here.
    """
    if isinstance(input_shape, tuple) and len(input_shape) == 4:
        _, height, width, channels = input_shape
        if channels == 1:
            image = image.convert("L")
        elif channels == 3:
            image = image.convert("RGB")
        if height and width:
            # PIL sizes are (width, height)
            image = image.resize((width, height))
    pixels = np.asarray(image)
    # grayscale images come back 2-D; add the trailing channel axis
    if pixels.ndim == 2:
        pixels = pixels[..., np.newaxis]
    # add the leading batch axis
    return pixels[np.newaxis, ...]


# title for the app
st.title("Filters and Feature Maps Visualization")
# file uploader widget to upload h5 files
model_file = st.file_uploader(label="Upload model", type=["h5"])
if model_file:
    # NOTE(review): if Streamlit reruns this script on every widget change
    # (its default execution model), the model is reloaded each time --
    # consider caching the loaded model.
    model = _load_uploaded_model(model_file)
    # dropdown menu to visualize filters or maps
    viz_option = st.selectbox("What would you like to visualize?",
                              options=["Filters", "Feature Maps"])
    # get indices of conv layers in model
    conv_indices = [i for i, layer in enumerate(model.layers)
                    if isinstance(layer, layers.Conv2D)]
    if not conv_indices:
        # without this guard the selectbox returns None and
        # `model.layers[None]` raises TypeError
        st.warning("This model has no Conv2D layers to visualize.")
    elif viz_option.lower() == "filters":
        # filter visualization
        layer_index = st.selectbox(
            "Select a layer to see its filters", options=conv_indices)
        # kernel is indexed [:, :, channel, filter] below
        weights = model.layers[layer_index].get_weights()[0]
        num_filters = weights.shape[-1]
        num_channels = weights.shape[-2]
        st.write(
            f"This layer has {num_filters} filters and {num_channels} channels per filter.")
        channel_index = st.selectbox(
            "Which channel would you like to view?", options=range(1, num_channels + 1))
        # channel_index is 1-based in the UI
        _show(_plot_grid(
            [weights[:, :, channel_index - 1, i] for i in range(num_filters)],
            "Filter"))
    else:
        # feature map visualization
        img = st.file_uploader(label="Upload image", type=['jpg', 'png'])
        if img:
            image = Image.open(img)
            st.image(np.asarray(image))
            batch = _prepare_image(image, model.input_shape)
            # choose layer
            layer_index = st.selectbox(
                "Feature Map at which layer?", options=conv_indices)
            # sub-model that stops at the chosen conv layer
            temp_model = models.Model(
                inputs=model.inputs, outputs=model.layers[layer_index].output)
            # take the single sample with [0] rather than np.squeeze, which
            # would also drop 1-sized spatial/channel axes and break the
            # [:, :, i] indexing below
            feature_maps = temp_model.predict(batch)[0]
            num_channels = feature_maps.shape[-1]
            _show(_plot_grid(
                [feature_maps[:, :, i] for i in range(num_channels)],
                "Channel"))