Spaces:
Build error
Build error
# --- Visualization --- | |
import altair as alt | |
import streamlit as st | |
import plotly.graph_objects as go | |
from streamlit_vega_lite import altair_component | |
# --- Data --- | |
import pandas as pd | |
def base_chart(df, linked_vis=False, max_width=150, col_val=None, min_size=100, size_domain=None):
    """Build the bar chart of model performance across subsets of the data.

    Args:
        df: Data handed to ``alt.Chart``. The linked branch encodes the
            fields ``metric_value``, ``displayName``, ``metric_type``,
            ``size``, ``source`` and ``name``; the non-linked branch encodes
            ``metric_value``, ``metric_type`` and ``name``.
        linked_vis: When true, build the faceted chart with a click
            selection on ``(name, source)``. When false, build the
            deprecated single-colour chart.
        max_width: Chart width for the non-linked branch only; the linked
            branch uses a fixed width of 125.
        col_val: Constant bar colour for the non-linked branch.
        min_size: Not used by this function; kept so existing callers that
            pass it keep working.
        size_domain: Ordered labels of the ``size`` field, used as the
            domain of the three stroke scales. Defaults to an empty list.

    Returns:
        The configured Altair chart.
    """
    # Populations a row can belong to, paired index-wise with fill colours.
    pop_domain = [
        "Overall Performance",
        "Custom Slice",
        "User Custom Sentence",
        "US Protected Class",
    ]
    color_range = ["#5778a4", "#e49444", "#b8b0ac", "#85b6b2"]

    # A ``None`` sentinel avoids sharing one mutable list between calls.
    if size_domain is None:
        size_domain = []

    # Begin the chart.
    base = alt.Chart(df)

    if linked_vis:
        # Clicking a bar selects its (name, source) pair; ``empty="none"``
        # means that with nothing clicked no bar counts as selected.
        selected = alt.selection_single(
            on="click", empty="none", fields=["name", "source"]
        )
        base = base.add_selection(selected)

        base = (
            base.mark_bar()
            .encode(
                alt.X("metric_value", scale=alt.Scale(domain=(0, 1)), title=""),
                alt.Y("displayName", title=""),
                alt.Column("metric_type", title=""),
                # The three stroke channels share ``size_domain``: its first
                # label maps to an invisible outline (width 0, opacity 0,
                # white), its second to a visible red outline.
                alt.StrokeWidth(
                    "size:N",
                    scale=alt.Scale(domain=size_domain, range=[0, 1.25]),
                    title="#sentences",
                ),
                alt.StrokeOpacity(
                    "size:N",
                    scale=alt.Scale(domain=size_domain, range=[0, 1]),
                ),
                alt.Stroke(
                    "size:N",
                    scale=alt.Scale(domain=size_domain, range=["white", "red"]),
                ),
                alt.Fill(
                    "source",
                    scale=alt.Scale(domain=pop_domain, range=color_range),
                    title="Data Subpopulation",
                ),
                # Selected bars are fully opaque, all others half transparent.
                opacity=alt.condition(selected, alt.value(1), alt.value(0.5)),
                tooltip=["name", "metric_type", "metric_value"],
            )
            .properties(width=125)
            .configure_axis(labelFontSize=14)
            .configure_legend(labelFontSize=14)
        )
    else:
        # Deprecated path: callers are expected to pass ``linked_vis=True``.
        base = (
            base.mark_bar()
            .encode(
                alt.X("metric_value", scale=alt.Scale(domain=(0, 1)), title=""),
                alt.Y(
                    "metric_type",
                    title="",
                    sort=["Overall Performance", "Your Sentences"],
                ),
                color=alt.value(col_val),
                tooltip=["name", "metric_type", "metric_value"],
            )
            .properties(width=max_width)
        )

    return base
def visualize_metrics(metrics, max_width=150, linked_vis=False, col_val="#1f77b4", min_size=1000):
    """Visualize the metrics of the model.

    Args:
        metrics: Mapping from a slice name to a dict with the keys
            ``"metrics"`` (mapping of metric type to value), ``"size"``
            (compared against ``min_size``) and ``"source"``.
        max_width: Forwarded to ``base_chart``, which uses it as the chart
            width in its non-linked branch.
        linked_vis: Forwarded to ``base_chart`` to pick the chart variant.
        col_val: Forwarded to ``base_chart`` as the constant bar colour.
        min_size: Threshold splitting slices into the two ``size`` labels
            ``">={min_size} sentences"`` and ``"<{min_size} sentences"``.

    Returns:
        The chart produced by ``base_chart`` for the flattened metrics.
    """
    large_label = f">={min_size} sentences"
    small_label = f"<{min_size} sentences"

    # One small frame per slice, concatenated once at the end.
    # ``DataFrame.append`` was removed in pandas 2.0 and copied the whole
    # frame on every call.
    frames = []
    for key, entry in metrics.items():
        scores = entry["metrics"]
        n_rows = len(scores)
        size_label = large_label if entry["size"] >= min_size else small_label
        frames.append(
            pd.DataFrame(
                {
                    "name": [key] * n_rows,
                    "metric_type": list(scores.keys()),
                    "metric_value": list(scores.values()),
                    "source": [entry["source"]] * n_rows,
                    "size": [size_label] * n_rows,
                }
            )
        )

    if frames:
        metric_df = pd.concat(frames)
    else:
        # No slices: keep the expected columns so later lookups do not fail.
        metric_df = pd.DataFrame(
            columns=["name", "metric_type", "metric_value", "source", "size"]
        )

    # Add a human-friendly display name (not RG's backend name): take the
    # part after "->" when there is one, then drop everything from "@" on.
    display_names = []
    for name in metric_df["name"]:
        parts = name.split("->")
        label = parts[0] if len(parts) <= 1 else parts[1]
        display_names.append(label.split("@")[0])
    metric_df["displayName"] = display_names

    # Domain order matters: ``base_chart`` maps the first label to an
    # invisible stroke and the second to a red outline.
    size_domain = [large_label, small_label]

    # NOTE: a layered variant that drew a vertical rule at 0.5 over the bars
    # used to be sketched here but was disabled; the base chart is returned.
    return base_chart(
        metric_df,
        linked_vis,
        max_width=max_width,
        col_val=col_val,
        size_domain=size_domain,
    )
#@st.cache(allow_output_mutation=True) | |
def data_comparison(df):
    """Build a scatter plot of sentences next to a clickable legend.

    The scatter encodes the fields ``x``, ``y``, ``source``, ``sentiment``,
    ``name`` and ``sentence`` of ``df``. The legend chart lists ``name``
    against ``sentiment`` and carries a multi-selection on
    ``(name, sentiment)``; both charts colour and fade their points based
    on that selection.

    Args:
        df: Data handed to both ``alt.Chart`` instances.

    Returns:
        The horizontally concatenated chart (scatter on the left, legend on
        the right) with axis grid lines and the view border turned off.
    """
    marker_shapes = ['circle', 'diamond']

    # The selection is attached to the legend chart and shared by both.
    picked = alt.selection_multi(fields=['name', 'sentiment'])

    # Selected points keep their source colour; the rest turn light gray.
    point_color = alt.condition(
        picked,
        alt.Color('source:N', legend=None),
        alt.value('lightgray'),
    )
    point_opacity = alt.condition(picked, alt.value(0.7), alt.value(0.25))

    # Main scatter plot, pannable and zoomable via ``interactive()``.
    scatter_marks = alt.Chart(df).mark_point(size=100, filled=True)
    scatter_encoded = scatter_marks.encode(
        x=alt.X('x', axis=None),
        y=alt.Y('y', axis=None),
        color=point_color,
        shape=alt.Shape('sentiment', scale=alt.Scale(range=marker_shapes)),
        tooltip=['source', 'name', 'sentence', 'sentiment'],
        opacity=point_opacity,
    )
    scatter_plot = scatter_encoded.properties(width=600, height=700).interactive()

    # Legend-like side chart that hosts the selection.
    legend_encoded = alt.Chart(df).mark_point().encode(
        y=alt.Y('name:N', axis=alt.Axis(orient='right'), title=""),
        x=alt.X("sentiment"),
        shape=alt.Shape(
            'sentiment',
            scale=alt.Scale(range=marker_shapes),
            legend=None,
        ),
        color=point_color,
    )
    legend_plot = legend_encoded.add_selection(picked)

    combined = scatter_plot | legend_plot
    return combined.configure_axis(grid=False).configure_view(strokeOpacity=0)
def vis_table(df, userInput=False):
    """DEPRECATED: render ``df`` as a Plotly table figure.

    The header lists every column of ``df``, while the cells show only the
    ``sentence``, ``model label`` and ``probability`` columns, with
    relative column widths of 400, 50 and 50.

    Args:
        df: Frame that provides the three columns named above.
        userInput: Not used by this function.

    Returns:
        A ``go.Figure`` containing a single ``go.Table`` trace.
    """
    header_spec = {
        "values": list(df.columns),
        "fill_color": "paleturquoise",
        "align": "left",
    }
    cell_spec = {
        "values": [df["sentence"], df["model label"], df["probability"]],
        "fill_color": "lavender",
        "align": "left",
    }
    table_trace = go.Table(
        header=header_spec,
        columnwidth=[400, 50, 50],
        cells=cell_spec,
    )
    return go.Figure(data=[table_trace])