import math
import os
import sys
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
from PIL import Image
from tensorflow.keras import layers, models

# get default figure dimensions for matplotlib
fig_size = plt.rcParams['figure.figsize']
# set default color map to grayscale
plt.rcParams['image.cmap'] = 'gray'


def _plot_grid(stack, title_prefix):
    """Plot every channel of a channels-last stack on a near-square grid in Streamlit.

    Args:
        stack: array whose last axis indexes the panels; ``stack[:, :, i]``
            is drawn in panel ``i``.
        title_prefix: text placed before the 1-based panel number in each title.
    """
    count = stack.shape[-1]
    if count == 0:
        st.info("Nothing to plot.")
        return
    ncols = math.ceil(math.sqrt(count))
    # only as many rows as are needed to hold `count` panels
    nrows = math.ceil(count / ncols)
    # squeeze=False keeps `ax` a 2-D array even for a 1x1 grid, so
    # `.flatten()` always works (a bare Axes has no `.flatten`)
    fig, ax = plt.subplots(nrows, ncols, squeeze=False, figsize=(
        fig_size[0] * ncols, fig_size[1] * nrows))
    axes = ax.flatten()
    for i in range(count):
        axes[i].imshow(stack[:, :, i])
        axes[i].set(xticklabels=[], yticklabels=[],
                    title=f"{title_prefix} {i + 1}")
    # remove the unused subplots at the end of the grid
    for unused in axes[count:]:
        unused.remove()
    fig.tight_layout()
    # show plot
    st.pyplot(fig)
    # Streamlit reruns the script on every interaction; close the figure so
    # pyplot does not keep accumulating them in memory
    plt.close(fig)


def _prepare_image(pil_image, input_shape):
    """Turn an uploaded PIL image into a batch of one for ``model.predict``.

    When ``input_shape`` is a rank-4 shape (batch, height, width, channels) --
    channels-last is assumed, TODO confirm for channels-first models -- the
    image is converted to 1-channel (L) or 3-channel (RGB) mode to match, and
    resized to the model's fixed height/width when those are known.

    Returns:
        numpy array with a leading batch axis of size 1 and, for grayscale
        images, a trailing channel axis of size 1.
    """
    if len(input_shape) == 4 and all(
            dim is None or isinstance(dim, int) for dim in input_shape):
        _, height, width, channels = input_shape
        if channels == 1:
            pil_image = pil_image.convert("L")
        elif channels == 3:
            pil_image = pil_image.convert("RGB")
        if height and width:
            # PIL expects (width, height)
            pil_image = pil_image.resize((width, height))
    array = np.asarray(pil_image)
    if array.ndim == 2:
        # grayscale image: add the channel axis
        array = np.expand_dims(array, axis=-1)
    # add the batch axis
    return np.expand_dims(array, axis=0)


# title for the app
st.title("Filters and Feature Maps Visualization")

# file uploader widget to upload h5 files
uploaded_model = st.file_uploader(label="Upload model", type=["h5"])

if uploaded_model:
    # start a temporary directory
    with tempfile.TemporaryDirectory() as tempdir:
        model_path = os.path.join(tempdir, "temp.h5")
        # write the model to a temporary file in a temporary folder
        with open(model_path, mode='wb') as f:
            f.write(uploaded_model.getvalue())
        # load only after the file is closed (and therefore flushed).
        # NOTE(review): loading an untrusted .h5 model can execute arbitrary
        # code (e.g. via Lambda layers); only upload models you trust.
        model = models.load_model(model_path)

    # dropdown menu to visualize filters or maps
    viz_option = st.selectbox("What would you like to visualize?",
                              options=["Filters", "Feature Maps"])

    # get indices of conv layers in model
    conv_indices = [i for i, layer in enumerate(model.layers)
                    if isinstance(layer, layers.Conv2D)]
    if not conv_indices:
        # selectbox would return None and model.layers[None] would raise
        st.warning("The uploaded model has no Conv2D layers to visualize.")
        st.stop()

    if viz_option.lower() == "filters":
        # filter visualization
        # dropdown menu to select a convolutional layer
        layer_index = st.selectbox(
            "Select a layer to see its filters", options=conv_indices)

        # Keras Conv2D kernels are (kernel_h, kernel_w, in_channels, filters)
        weights = model.layers[layer_index].get_weights()[0]
        num_filters = weights.shape[-1]
        num_channels = weights.shape[-2]
        st.write(
            f"This layer has {num_filters} filters and {num_channels} channels per filter.")

        channel_index = st.selectbox(
            "Which channel would you like to view?",
            options=range(1, num_channels + 1))

        # plot filters: one panel per filter for the chosen input channel
        _plot_grid(weights[:, :, channel_index - 1, :], "Filter")
    else:
        # feature map visualization
        # upload image
        img = st.file_uploader(label="Upload image",
                               type=['jpg', 'jpeg', 'png'])

        if img:
            # read image
            pil_image = Image.open(img)
            st.image(pil_image)

            # adjust mode, size and shape of the image to the model's input
            batch = _prepare_image(pil_image, model.input_shape)

            # choose layer
            layer_index = st.selectbox(
                "Feature Map at which layer?", options=conv_indices)

            # create temp model ending at the chosen conv layer
            temp_model = models.Model(
                inputs=model.inputs, outputs=model.layers[layer_index].output)
            # take batch element 0 instead of np.squeeze, so maps with a single
            # channel or a 1-pixel spatial dim keep their channel axis
            feature_maps = temp_model.predict(batch)[0]

            # plot feature maps: one panel per output channel
            _plot_grid(feature_maps, "Channel")