nazneen's picture
interactive model card streamlit app
90f4ec6
# --- Visualization ---
import altair as alt
import streamlit as st
import plotly.graph_objects as go
from streamlit_vega_lite import altair_component
# --- Data ---
import pandas as pd
def base_chart(df, linked_vis=False, max_width=150, col_val=None, min_size=100, size_domain=None):
    """Build the bar chart of model performance across subsets of the data.

    Args:
        df: Data handed to ``alt.Chart``. Both chart variants encode the
            ``metric_value``, ``metric_type`` and ``name`` fields; the linked
            variant additionally encodes ``displayName``, ``size`` and
            ``source``.
        linked_vis: When true, build the click-selectable chart faceted into
            one column per ``metric_type``. When false, build the deprecated
            single-colour chart.
        max_width: Width of the deprecated chart. The linked chart ignores it
            and uses a fixed width of 125.
        col_val: Constant bar colour of the deprecated chart.
        min_size: Unused; kept so existing callers keep working.
        size_domain: Ordered categories of the ``size`` field. The first entry
            is drawn without a visible outline (stroke width 0, opacity 0,
            white) and the second with a red outline (width 1.25, opacity 1).
            Defaults to an empty domain.

    Returns:
        The Altair chart.
    """
    # ``None`` replaces the former mutable ``[]`` default; an empty domain is
    # still what the scales receive when the caller passes nothing.
    if size_domain is None:
        size_domain = []

    # Populations in the data and the fill colour assigned to each of them.
    pop_domain = ["Overall Performance", "Custom Slice", "User Custom Sentence", "US Protected Class"]
    color_range = ["#5778a4", "#e49444", "#b8b0ac", "#85b6b2"]

    # Begin the chart.
    base = alt.Chart(df)

    if linked_vis:
        # Clicking a bar selects it by (name, source); with empty="none"
        # nothing counts as selected until the first click.
        selected = alt.selection_single(
            on="click", empty="none", fields=["name", "source"]
        )
        base = base.add_selection(selected)
        return (
            base.mark_bar()
            .encode(
                alt.X("metric_value", scale=alt.Scale(domain=(0, 1)), title=""),
                alt.Y("displayName", title=""),
                alt.Column("metric_type", title=""),
                # The three stroke channels together outline the bars whose
                # ``size`` equals the second entry of ``size_domain``.
                alt.StrokeWidth(
                    "size:N",
                    scale=alt.Scale(domain=size_domain, range=[0, 1.25]),
                    title="#sentences",
                ),
                alt.StrokeOpacity(
                    "size:N",
                    scale=alt.Scale(domain=size_domain, range=[0, 1]),
                ),
                alt.Stroke(
                    "size:N",
                    scale=alt.Scale(domain=size_domain, range=["white", "red"]),
                ),
                alt.Fill(
                    "source",
                    scale=alt.Scale(domain=pop_domain, range=color_range),
                    title="Data Subpopulation",
                ),
                # Selected bars are fully opaque, all others half transparent.
                opacity=alt.condition(selected, alt.value(1), alt.value(0.5)),
                tooltip=["name", "metric_type", "metric_value"],
            )
            .properties(width=125)
            .configure_axis(labelFontSize=14)
            .configure_legend(labelFontSize=14)
        )

    # This variant is deprecated and should never be reached.
    return (
        base.mark_bar()
        .encode(
            alt.X("metric_value", scale=alt.Scale(domain=(0, 1)), title=""),
            alt.Y(
                "metric_type",
                title="",
                sort=["Overall Performance", "Your Sentences"],
            ),
            color=alt.value(col_val),
            tooltip=["name", "metric_type", "metric_value"],
        )
        .properties(width=max_width)
    )
def _display_name(name):
    """Return the human-friendly label for a backend slice name.

    Uses the second ``->``-separated segment when there is more than one
    segment (otherwise the whole name) and keeps only the text before the
    first ``@``.
    """
    segments = name.split("->")
    label = segments[0] if len(segments) <= 1 else segments[1]
    return label.split("@")[0]


@st.cache(allow_output_mutation=True)
def visualize_metrics(metrics, max_width=150, linked_vis=False, col_val="#1f77b4", min_size=1000):
    """Visualize the metrics of the model.

    Args:
        metrics: Mapping from slice name to an entry with the keys
            ``"metrics"`` (mapping of metric type to value), ``"size"``
            (compared against ``min_size``) and ``"source"``.
        max_width: Forwarded to ``base_chart``.
        linked_vis: Forwarded to ``base_chart``.
        col_val: Forwarded to ``base_chart``.
        min_size: Threshold that splits slices into the two ``size``
            categories ``">={min_size} sentences"`` and
            ``"<{min_size} sentences"``.

    Returns:
        The chart produced by ``base_chart`` for the flattened metrics table.
    """
    large_label = f">={min_size} sentences"
    small_label = f"<{min_size} sentences"

    # One row per (slice, metric type). Rows are collected first and the frame
    # is built once, instead of growing it with the removed DataFrame.append.
    rows = []
    for key, entry in metrics.items():
        slice_metrics = entry["metrics"]
        size_label = large_label if entry["size"] >= min_size else small_label
        source = entry["source"]
        # Human-friendly display name (not the backend name).
        display_name = _display_name(key)
        for metric_type, metric_value in slice_metrics.items():
            rows.append(
                {
                    "name": key,
                    "metric_type": metric_type,
                    "metric_value": metric_value,
                    "source": source,
                    "size": size_label,
                    "displayName": display_name,
                }
            )

    # Explicit columns keep the frame well-formed even when there are no rows.
    metric_df = pd.DataFrame(
        rows,
        columns=["name", "metric_type", "metric_value", "source", "size", "displayName"],
    )

    # Order matters: base_chart outlines the second category (small slices).
    size_domain = [large_label, small_label]

    # Generic metric chart. A layered variant with a vertical reference line
    # at 0.5 used to be sketched here; it is not used.
    return base_chart(
        metric_df,
        linked_vis,
        max_width=max_width,
        col_val=col_val,
        size_domain=size_domain,
    )
#@st.cache(allow_output_mutation=True)
def data_comparison(df):
    """Build a scatter plot of the data next to a clickable legend.

    The left panel plots every row of ``df`` at its ``x``/``y`` position with
    hidden axes, shaped by ``sentiment``. The right panel is a small
    ``name`` x ``sentiment`` grid that acts as the legend: clicking its marks
    drives a multi-selection on (``name``, ``sentiment``). Points matching the
    selection are coloured by ``source`` at opacity 0.7; the rest are light
    gray at opacity 0.25.

    Returns:
        The two panels concatenated side by side, with grid lines and the
        view border turned off.
    """
    # Multi-selection shared by both panels; only the legend registers it.
    picked = alt.selection_multi(fields=['name', 'sentiment'])

    # Highlight selected marks; the colour scale is left at Altair's default.
    highlight_color = alt.condition(
        picked,
        alt.Color('source:N', legend=None),
        alt.value('lightgray'),
    )
    highlight_opacity = alt.condition(picked, alt.value(0.7), alt.value(0.25))

    # Main scatter panel.
    points = alt.Chart(df).mark_point(size=100, filled=True)
    points = points.encode(
        x=alt.X('x', axis=None),
        y=alt.Y('y', axis=None),
        color=highlight_color,
        shape=alt.Shape('sentiment', scale=alt.Scale(range=['circle', 'diamond'])),
        tooltip=['source', 'name', 'sentence', 'sentiment'],
        opacity=highlight_opacity,
    )
    points = points.properties(width=600, height=700).interactive()

    # Legend panel that owns the selection.
    key_panel = alt.Chart(df).mark_point()
    key_panel = key_panel.encode(
        y=alt.Y('name:N', axis=alt.Axis(orient='right'), title=""),
        x=alt.X("sentiment"),
        shape=alt.Shape(
            'sentiment',
            scale=alt.Scale(range=['circle', 'diamond']),
            legend=None,
        ),
        color=highlight_color,
    )
    key_panel = key_panel.add_selection(picked)

    combined = points | key_panel
    return combined.configure_axis(grid=False).configure_view(strokeOpacity=0)
def vis_table(df, userInput=False):
    """DEPRECATED: render table data as a Plotly table figure.

    The header lists every column of ``df``; the cells show the
    ``sentence``, ``model label`` and ``probability`` columns, with relative
    column widths 400/50/50. ``userInput`` is accepted but not used.

    Returns:
        A ``go.Figure`` containing the single table trace.
    """
    header_spec = dict(
        values=list(df.columns), fill_color="paleturquoise", align="left"
    )
    column_widths = [400, 50, 50]
    shown_columns = [df[label] for label in ("sentence", "model label", "probability")]
    cell_spec = dict(values=shown_columns, fill_color="lavender", align="left")

    table = go.Table(header=header_spec, columnwidth=column_widths, cells=cell_spec)
    return go.Figure(data=[table])