Ludovica Schaerf
Duplicate from taquynhnga/CNNs-interpretation-visualization
fcc16aa
raw
history blame
11.2 kB
import streamlit as st
import pickle
import io
from typing import List, Optional
import markdown
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import plotly.graph_objects as go
import streamlit as st
from plotly import express as px
from plotly.subplots import make_subplots
from tqdm import trange
import torch
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
@st.cache(allow_output_mutation=True)
# @st.cache_resource
def load_dataset(data_index):
    """Load one pickled validation shard from disk.

    Args:
        data_index: Number substituted into the shard file name
            ``./data/preprocessed_image_net/val_data_{data_index}.pkl``.

    Returns:
        Whatever object was pickled into that file.

    NOTE(review): ``pickle.load`` executes arbitrary code embedded in the
    file, so these shards must come from a trusted source.
    NOTE(review): the commented-out decorator suggests a move to
    ``st.cache_resource`` was considered -- confirm before switching.
    """
    shard_path = f'./data/preprocessed_image_net/val_data_{data_index}.pkl'
    with open(shard_path, 'rb') as shard_file:
        return pickle.load(shard_file)
@st.cache(allow_output_mutation=True)
# @st.cache_resource
def load_dataset_dict():
    """Load shards 0-4 through ``load_dataset`` while showing progress.

    Two placeholders (a progress bar and a status line) are created, updated
    after every shard and cleared once loading is finished.

    Returns:
        dict mapping each shard index (0..4) to the object returned by
        ``load_dataset`` for that index.
    """
    # Placeholders are created bar-first so the bar sits above the message.
    bar_slot = st.empty()
    message_slot = st.empty()
    message_slot.write("Loading datasets...")
    bar = bar_slot.progress(0.0)

    shards = {}
    for shard_index in trange(5):
        shards[shard_index] = load_dataset(shard_index)
        bar.progress((shard_index + 1) / 5)

    # Remove the temporary widgets now that everything is loaded.
    bar_slot.empty()
    message_slot.empty()
    return shards
# @st.cache_data
@st.cache(allow_output_mutation=True)
def load_image(image_id):
    """Fetch a single entry from the shard that contains ``image_id``.

    Args:
        image_id: Global index; the quotient by 10000 selects the shard and
            the remainder selects the entry inside that shard.

    Returns:
        The entry stored at that position of the shard.
    """
    shard_index, offset = divmod(image_id, 10000)
    return load_dataset(shard_index)[offset]
# @st.cache_data
@st.cache(allow_output_mutation=True)
def load_images(image_ids):
    """Fetch several entries, one ``load_image`` call per id.

    Args:
        image_ids: Iterable of global indices.

    Returns:
        list of the loaded entries, in the same order as ``image_ids``.
    """
    return [load_image(image_id) for image_id in image_ids]
@st.cache(allow_output_mutation=True, suppress_st_warning=True, show_spinner=False)
# @st.cache_resource
def load_model(model_name):
    """Load a pretrained image classifier and switch it to eval mode.

    Args:
        model_name: ``'ResNet'`` and ``'ConvNeXt'`` load the matching
            Hugging Face checkpoint; any other value loads MobileNetV2
            through ``torch.hub``.

    Returns:
        ``(model, feature_extractor)``. ``feature_extractor`` is ``None``
        for the MobileNetV2 fallback.
    """
    with st.spinner(f"Loading {model_name} model! This process might take 1-2 minutes..."):
        # Resolve the Hugging Face checkpoint id, if the name has one.
        checkpoint = None
        if model_name == 'ResNet':
            checkpoint = 'microsoft/resnet-50'
        elif model_name == 'ConvNeXt':
            checkpoint = 'facebook/convnext-tiny-224'

        if checkpoint is not None:
            feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint, crop_pct=1.0)
            model = AutoModelForImageClassification.from_pretrained(checkpoint)
            model.eval()
        else:
            model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
            model.eval()
            feature_extractor = None
    return model, feature_extractor
def make_grid(cols=None,rows=None):
    """Create ``rows`` containers, each split with ``st.columns(cols)``.

    Args:
        cols: Passed straight to ``st.columns`` for every row.
        rows: Number of rows to create.

    Returns:
        list with one entry per row, each being the value returned by
        ``st.columns(cols)`` for that row.
    """
    row_columns = []
    for _ in range(rows):
        with st.container():
            row_columns.append(st.columns(cols))
    return row_columns
def use_container_width_percentage(percentage_width:int = 75):
    """Inject CSS capping the main block container's width.

    Args:
        percentage_width: Value written into the ``max-width`` rule,
            followed by ``%``.
    """
    css_rule = (
        ".reportview-container .main .block-container"
        f"{{max-width: {percentage_width}%;}}"
    )
    # Leading/trailing newlines mirror the layout of the emitted <style> tag.
    st.markdown(
        "\n<style>\n" + css_rule + "\n</style>\n",
        unsafe_allow_html=True,
    )
# Select matplotlib's "Agg" backend for every figure created by this module.
matplotlib.use("Agg")
# Module-wide theme colours. select_block_container_style() rebinds both
# when the dark theme is enabled in the sidebar.
COLOR = "#31333f"
BACKGROUND_COLOR = "#ffffff"
def grid_demo():
    """Main function. Run this to run the app.

    Renders the sidebar controls, then a seven-cell CSS grid mixing markdown,
    text, a dataframe, a Plotly figure and a matplotlib figure.

    The module-level ``get_dataframe`` / ``get_plotly_fig`` /
    ``get_matplotlib_plt`` helpers are commented out in this file, so calling
    them raised ``NameError``; the demo uses the private ``_demo_*``
    equivalents defined below instead.
    """
    st.sidebar.title("Layout and Style Experiments")
    st.sidebar.header("Settings")
    st.markdown(
        """
# Layout and Style Experiments
The basic question is: Can we create a multi-column dashboard with plots, numbers and text using
the [CSS Grid](https://gridbyexample.com/examples)?
Can we do it with a nice api?
Can have a dark theme?
"""
    )
    # Must run before the grid is built: it may rebind COLOR/BACKGROUND_COLOR.
    select_block_container_style()
    add_resources_section()
    # My preliminary idea of an API for generating a grid
    # NOTE(review): "1 1 1" has no units; CSS grid tracks normally need
    # something like "1fr 1fr 1fr" -- confirm the intended column template.
    with Grid("1 1 1", color=COLOR, background_color=BACKGROUND_COLOR) as grid:
        grid.cell(
            class_="a",
            grid_column_start=2,
            grid_column_end=3,
            grid_row_start=1,
            grid_row_end=2,
        ).markdown("# This is A Markdown Cell")
        grid.cell("b", 2, 3, 2, 3).text("The cell to the left is a dataframe")
        grid.cell("c", 3, 4, 2, 3).plotly_chart(_demo_plotly_fig())
        grid.cell("d", 1, 2, 1, 3).dataframe(_demo_dataframe())
        grid.cell("e", 3, 4, 1, 2).markdown(
            "Try changing the **block container style** in the sidebar!"
        )
        grid.cell("f", 1, 3, 3, 4).text(
            "The cell to the right is a matplotlib svg image"
        )
        # The helper only draws and returns None; Cell.pyplot receives None.
        grid.cell("g", 3, 4, 3, 4).pyplot(_demo_matplotlib_plt())


def _demo_dataframe() -> pd.DataFrame:
    """Return the three-row quantity/price table used by the demo grid."""
    return pd.DataFrame(
        [
            {"quantity": 1, "price": 2},
            {"quantity": 3, "price": 5},
            {"quantity": 4, "price": 8},
        ]
    )


def _demo_plotly_fig():
    """Return a Plotly line chart of price against quantity for the demo table."""
    return px.line(data_frame=_demo_dataframe(), x="quantity", y="price")


def _demo_matplotlib_plt():
    """Draw the demo table as a line plot via ``DataFrame.plot``; returns None."""
    _demo_dataframe().plot(kind="line", x="quantity", y="price", figsize=(5, 3))
def add_resources_section():
    """Adds a resources section to the sidebar.

    Writes a header followed by a one-item Markdown list linking to
    gridbyexample.com.
    """
    st.sidebar.header("Add_resources_section")
    # No space is allowed between "]" and "(" in a Markdown inline link;
    # with the space the entry rendered as plain text instead of a link.
    st.sidebar.markdown(
        """
- [gridbyexample.com](https://gridbyexample.com/examples/)
"""
    )
class Cell:
    """A Cell can hold text, markdown, plots etc.

    A cell stores its CSS grid placement plus an HTML fragment
    (``inner_html``). The content methods (``text``, ``markdown``,
    ``dataframe``, ``plotly_chart``, ``pyplot``) overwrite that fragment and
    return ``None``; ``_to_style`` and ``_to_html`` serialise the cell.
    """

    def __init__(
        self,
        class_: Optional[str] = None,
        grid_column_start: Optional[int] = None,
        grid_column_end: Optional[int] = None,
        grid_row_start: Optional[int] = None,
        grid_row_end: Optional[int] = None,
    ):
        """Store the CSS class name and grid placement; content starts empty.

        Args:
            class_: CSS class used both as the style selector and on the
                cell's ``<div>``.
            grid_column_start: Value for ``grid-column-start``.
            grid_column_end: Value for ``grid-column-end``.
            grid_row_start: Value for ``grid-row-start``.
            grid_row_end: Value for ``grid-row-end``.
        """
        self.class_ = class_
        self.grid_column_start = grid_column_start
        self.grid_column_end = grid_column_end
        self.grid_row_start = grid_row_start
        self.grid_row_end = grid_row_end
        self.inner_html = ""

    def _to_style(self) -> str:
        """Return the CSS rule that positions this cell inside the grid."""
        return f"""
.{self.class_} {{
grid-column-start: {self.grid_column_start};
grid-column-end: {self.grid_column_end};
grid-row-start: {self.grid_row_start};
grid-row-end: {self.grid_row_end};
}}
"""

    def text(self, text: str = ""):
        """Use ``text`` verbatim (no escaping) as the cell content."""
        self.inner_html = text

    def markdown(self, text):
        """Convert Markdown ``text`` to HTML and use it as the cell content."""
        self.inner_html = markdown.markdown(text)

    def dataframe(self, dataframe: pd.DataFrame):
        """Render ``dataframe`` as an HTML table and use it as the cell content."""
        self.inner_html = dataframe.to_html()

    def plotly_chart(self, fig):
        """Embed ``fig`` as Plotly JSON inside an inline script.

        The emitted markup itself states that the script tags are removed on
        insertion, so the chart is not expected to render.
        """
        self.inner_html = f"""
<script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
<body>
<p>This should have been a plotly plot.
But since *script* tags are removed when inserting MarkDown/ HTML i cannot get it to workto work.
But I could potentially save to svg and insert that.</p>
<div id='divPlotly'></div>
<script>
var plotly_data = {fig.to_json()}
Plotly.react('divPlotly', plotly_data.data, plotly_data.layout);
</script>
</body>
"""

    def pyplot(self, fig=None, **kwargs):
        """Render a matplotlib figure to inline SVG and use it as the content.

        Args:
            fig: Figure to render; when ``None`` the current pyplot figure
                is used. The figure is closed afterwards.
            **kwargs: Accepted for call compatibility and ignored.
        """
        # Save the figure that was asked for, not whatever happens to be
        # current, and do not pass the unsupported ``fig=`` keyword on to
        # savefig.
        figure = plt.gcf() if fig is None else fig
        buffer = io.StringIO()
        figure.savefig(buffer, format="svg")
        svg = self._strip_svg_prolog(buffer.getvalue())
        plt.close(figure)
        self.inner_html = '<div height="200px">' + svg + "</div>"

    @staticmethod
    def _strip_svg_prolog(markup: str) -> str:
        """Drop everything before the opening ``<svg`` tag.

        Searching for the tag avoids relying on a fixed prolog length.
        ``markup`` is returned unchanged when it contains no ``<svg`` tag.
        """
        start = markup.find("<svg")
        return markup if start < 0 else markup[start:]

    def _to_html(self):
        """Return the cell wrapped in its ``<div class="box ...">`` element."""
        return f"""<div class="box {self.class_}">{self.inner_html}</div>"""
class Grid:
    """A (CSS) Grid

    Context manager that collects ``Cell`` objects through ``cell()`` and, on
    leaving the ``with`` block, writes the grid CSS, the per-cell CSS and the
    cells' HTML to the page with three ``st.markdown`` calls.
    """

    def __init__(
        self,
        template_columns="1 1 1",
        gap="10px",
        background_color=COLOR,
        color=BACKGROUND_COLOR,
    ):
        """Store the layout and colour settings and start with no cells.

        Args:
            template_columns: Inserted as the ``grid-template-columns`` value.
            gap: Inserted as the ``grid-gap`` value.
            background_color: Emitted as the CSS ``color`` of wrapper and boxes.
            color: Emitted as the CSS ``background-color`` of wrapper and
                boxes, and as the table ``color``.
        """
        self.cells: List[Cell] = []
        self.template_columns = template_columns
        self.gap = gap
        self.color = color
        self.background_color = background_color

    def __enter__(self):
        """Return the grid itself so cells can be added inside ``with``."""
        return self

    def __exit__(self, type, value, traceback):
        """Emit grid style, cell styles and cell HTML, in that order."""
        renderers = (
            self._get_grid_style,
            self._get_cells_style,
            self._get_cells_html,
        )
        for render in renderers:
            st.markdown(render(), unsafe_allow_html=True)

    def _get_grid_style(self):
        """Return the ``<style>`` element for the wrapper, boxes and tables."""
        # Note the deliberate-looking swap: ``self.color`` feeds
        # background-color and ``self.background_color`` feeds color.
        return f"""
<style>
.wrapper {{
display: grid;
grid-template-columns: {self.template_columns};
grid-gap: {self.gap};
background-color: {self.color};
color: {self.background_color};
}}
.box {{
background-color: {self.color};
color: {self.background_color};
border-radius: 0px;
padding: 0px;
font-size: 100%;
text-align: center;
}}
table {{
color: {self.color}
}}
</style>
"""

    def _get_cells_style(self):
        """Return one ``<style>`` element holding every cell's placement rule."""
        rules = "\n".join(member._to_style() for member in self.cells)
        return f"<style>{rules}</style>"

    def _get_cells_html(self):
        """Return every cell's HTML inside a single ``wrapper`` div."""
        boxes = "\n".join(member._to_html() for member in self.cells)
        return f'<div class="wrapper">{boxes}</div>'

    def cell(
        self,
        class_: str = None,
        grid_column_start: Optional[int] = None,
        grid_column_end: Optional[int] = None,
        grid_row_start: Optional[int] = None,
        grid_row_end: Optional[int] = None,
    ):
        """Create a ``Cell`` with the given placement, register and return it."""
        new_cell = Cell(
            class_=class_,
            grid_column_start=grid_column_start,
            grid_column_end=grid_column_end,
            grid_row_start=grid_row_start,
            grid_row_end=grid_row_end,
        )
        self.cells.append(new_cell)
        return new_cell
def select_block_container_style():
    """Add a sidebar section for setting the max-width, padding and theme
    of the main block container.

    Reads the sidebar widgets, rebinds the module-level ``COLOR`` and
    ``BACKGROUND_COLOR`` to match the "Dark Theme?" checkbox, then applies
    the result through ``_set_block_container_style``.
    """
    global COLOR
    global BACKGROUND_COLOR

    st.sidebar.header("Block Container Style")
    max_width_100_percent = st.sidebar.checkbox("Max-width: 100%?", False)
    if not max_width_100_percent:
        max_width = st.sidebar.slider("Select max-width in px", 100, 2000, 1200, 100)
    else:
        # The slider is hidden in full-width mode; the value is then unused
        # by the CSS but still has to be passed along.
        max_width = 1200
    dark_theme = st.sidebar.checkbox("Dark Theme?", False)
    padding_top = st.sidebar.number_input("Select padding top in rem", 0, 200, 5, 1)
    padding_right = st.sidebar.number_input("Select padding right in rem", 0, 200, 1, 1)
    padding_left = st.sidebar.number_input("Select padding left in rem", 0, 200, 1, 1)
    padding_bottom = st.sidebar.number_input(
        "Select padding bottom in rem", 0, 200, 10, 1
    )
    if dark_theme:
        BACKGROUND_COLOR = "rgb(17,17,17)"
        COLOR = "#fff"
    else:
        # Restore the light defaults explicitly. When this module is
        # imported rather than run as the main script its globals are not
        # reset between reruns, so without this branch the dark colours
        # would persist after the checkbox is cleared.
        BACKGROUND_COLOR = "#ffffff"
        COLOR = "#31333f"
    _set_block_container_style(
        max_width,
        max_width_100_percent,
        padding_top,
        padding_right,
        padding_left,
        padding_bottom,
    )
def _set_block_container_style(
    max_width: int = 1200,
    max_width_100_percent: bool = False,
    padding_top: int = 5,
    padding_right: int = 1,
    padding_left: int = 1,
    padding_bottom: int = 10,
):
    """Inject CSS for the main block container's width, padding and colours.

    Args:
        max_width: Maximum width in px; used only when
            ``max_width_100_percent`` is falsy.
        max_width_100_percent: When truthy, emit ``max-width: 100%``.
        padding_top: Top padding in rem.
        padding_right: Right padding in rem.
        padding_left: Left padding in rem.
        padding_bottom: Bottom padding in rem.

    The text and background colours come from the module-level ``COLOR`` and
    ``BACKGROUND_COLOR`` as they are at call time.
    """
    width_rule = (
        "max-width: 100%;" if max_width_100_percent else f"max-width: {max_width}px;"
    )
    # The empty first and last entries reproduce the newline before
    # "<style>" and after "</style>".
    css_lines = [
        "",
        "<style>",
        ".reportview-container .main .block-container{",
        width_rule,
        f"padding-top: {padding_top}rem;",
        f"padding-right: {padding_right}rem;",
        f"padding-left: {padding_left}rem;",
        f"padding-bottom: {padding_bottom}rem;",
        "}",
        ".reportview-container .main {",
        f"color: {COLOR};",
        f"background-color: {BACKGROUND_COLOR};",
        "}",
        "</style>",
        "",
    ]
    st.markdown("\n".join(css_lines), unsafe_allow_html=True)
# @st.cache
# def get_dataframe() -> pd.DataFrame:
# """Dummy DataFrame"""
# data = [
# {"quantity": 1, "price": 2},
# {"quantity": 3, "price": 5},
# {"quantity": 4, "price": 8},
# ]
# return pd.DataFrame(data)
# def get_plotly_fig():
# """Dummy Plotly Plot"""
# return px.line(data_frame=get_dataframe(), x="quantity", y="price")
# def get_matplotlib_plt():
# get_dataframe().plot(kind="line", x="quantity", y="price", figsize=(5, 3))