import streamlit as st import pickle import io from typing import List, Optional import markdown import matplotlib import matplotlib.pyplot as plt import pandas as pd import plotly.graph_objects as go import streamlit as st from plotly import express as px from plotly.subplots import make_subplots from tqdm import trange import torch from transformers import AutoFeatureExtractor, AutoModelForImageClassification @st.cache(allow_output_mutation=True) # @st.cache_resource def load_dataset(data_index): with open(f'./data/preprocessed_image_net/val_data_{data_index}.pkl', 'rb') as file: dataset = pickle.load(file) return dataset @st.cache(allow_output_mutation=True) # @st.cache_resource def load_dataset_dict(): dataset_dict = {} progress_empty = st.empty() text_empty = st.empty() text_empty.write("Loading datasets...") progress_bar = progress_empty.progress(0.0) for data_index in trange(5): dataset_dict[data_index] = load_dataset(data_index) progress_bar.progress((data_index+1)/5) progress_empty.empty() text_empty.empty() return dataset_dict # @st.cache_data @st.cache(allow_output_mutation=True) def load_image(image_id): dataset = load_dataset(image_id//10000) image = dataset[image_id%10000] return image # @st.cache_data @st.cache(allow_output_mutation=True) def load_images(image_ids): images = [] for image_id in image_ids: image = load_image(image_id) images.append(image) return images @st.cache(allow_output_mutation=True, suppress_st_warning=True, show_spinner=False) # @st.cache_resource def load_model(model_name): with st.spinner(f"Loading {model_name} model! This process might take 1-2 minutes..."): if model_name == 'ResNet': model_file_path = 'microsoft/resnet-50' feature_extractor = AutoFeatureExtractor.from_pretrained(model_file_path, crop_pct=1.0) model = AutoModelForImageClassification.from_pretrained(model_file_path) model.eval() elif model_name == 'ConvNeXt': model_file_path = 'facebook/convnext-tiny-224' feature_extractor = AutoFeatureExtractor.from_pretrained(model_file_path, crop_pct=1.0) model = AutoModelForImageClassification.from_pretrained(model_file_path) model.eval() else: model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True) model.eval() feature_extractor = None return model, feature_extractor def make_grid(cols=None,rows=None): grid = [0]*rows for i in range(rows): with st.container(): grid[i] = st.columns(cols) return grid def use_container_width_percentage(percentage_width:int = 75): max_width_str = f"max-width: {percentage_width}%;" st.markdown(f""" """, unsafe_allow_html=True, ) matplotlib.use("Agg") COLOR = "#31333f" BACKGROUND_COLOR = "#ffffff" def grid_demo(): """Main function. Run this to run the app""" st.sidebar.title("Layout and Style Experiments") st.sidebar.header("Settings") st.markdown( """ # Layout and Style Experiments The basic question is: Can we create a multi-column dashboard with plots, numbers and text using the [CSS Grid](https://gridbyexample.com/examples)? Can we do it with a nice api? Can have a dark theme? """ ) select_block_container_style() add_resources_section() # My preliminary idea of an API for generating a grid with Grid("1 1 1", color=COLOR, background_color=BACKGROUND_COLOR) as grid: grid.cell( class_="a", grid_column_start=2, grid_column_end=3, grid_row_start=1, grid_row_end=2, ).markdown("# This is A Markdown Cell") grid.cell("b", 2, 3, 2, 3).text("The cell to the left is a dataframe") grid.cell("c", 3, 4, 2, 3).plotly_chart(get_plotly_fig()) grid.cell("d", 1, 2, 1, 3).dataframe(get_dataframe()) grid.cell("e", 3, 4, 1, 2).markdown( "Try changing the **block container style** in the sidebar!" ) grid.cell("f", 1, 3, 3, 4).text( "The cell to the right is a matplotlib svg image" ) grid.cell("g", 3, 4, 3, 4).pyplot(get_matplotlib_plt()) def add_resources_section(): """Adds a resources section to the sidebar""" st.sidebar.header("Add_resources_section") st.sidebar.markdown( """ - [gridbyexample.com] (https://gridbyexample.com/examples/) """ ) class Cell: """A Cell can hold text, markdown, plots etc.""" def __init__( self, class_: str = None, grid_column_start: Optional[int] = None, grid_column_end: Optional[int] = None, grid_row_start: Optional[int] = None, grid_row_end: Optional[int] = None, ): self.class_ = class_ self.grid_column_start = grid_column_start self.grid_column_end = grid_column_end self.grid_row_start = grid_row_start self.grid_row_end = grid_row_end self.inner_html = "" def _to_style(self) -> str: return f""" .{self.class_} {{ grid-column-start: {self.grid_column_start}; grid-column-end: {self.grid_column_end}; grid-row-start: {self.grid_row_start}; grid-row-end: {self.grid_row_end}; }} """ def text(self, text: str = ""): self.inner_html = text def markdown(self, text): self.inner_html = markdown.markdown(text) def dataframe(self, dataframe: pd.DataFrame): self.inner_html = dataframe.to_html() def plotly_chart(self, fig): self.inner_html = f"""
This should have been a plotly plot. But since *script* tags are removed when inserting MarkDown/ HTML i cannot get it to workto work. But I could potentially save to svg and insert that.
""" def pyplot(self, fig=None, **kwargs): string_io = io.StringIO() plt.savefig(string_io, format="svg", fig=(2, 2)) svg = string_io.getvalue()[215:] plt.close(fig) self.inner_html = '