"""Streamlit page comparing human and machine ratings of item desirability."""

import streamlit as st
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px

# Raw covariate column names -> display labels used in the UI.
covariate_columns = {
    'content_domain': 'Content Domain',
    'language': 'Language',
    'rater_group': 'Rater Group',
}

# Columns carried along unchanged when the two model-score columns are
# melted into long format.
id_vars = [
    'mean_z',
    'text',
    'content_domain',
    'language',
    'rater_group',
    'study',
    'instrument',
]

# Raw category codes -> display labels.
value_labels = {
    'en': 'English',
    'de': 'German',
    'other': 'Other',
    'personality': 'Personality',
    'laypeople': 'Laypeople',
    'students': 'Students',
    'sentiment_model': 'Sentiment Model',
    'desirability_model': 'Desirability Model',
}

# Only these columns hold category codes. Restricting the relabelling to
# them keeps free-text columns (text, study, instrument) untouched even if a
# cell happens to equal a code such as 'other' or 'en'.
_labelled_columns = [*covariate_columns, 'x_group']


def _load_data():
    """Load, reshape and relabel the rating data for plotting.

    Reads ``data.feather`` and keeps rows whose ``partition`` is ``"test"``
    or ``"dev"``. It then melts ``sentiment_model`` and
    ``desirability_model`` into a long frame with columns ``x_group`` (which
    model) and ``x`` (its score). Category codes are replaced by display
    labels and columns are renamed for the UI.

    Returns:
        The long-format ``pandas.DataFrame`` consumed by ``scatter_plot``.
    """
    return (
        pd
        .read_feather(path='data.feather')
        .query('partition == "test" | partition == "dev"')
        .melt(
            value_vars=['sentiment_model', 'desirability_model'],
            var_name='x_group',
            value_name='x',
            id_vars=id_vars,
        )
        # Nested dict form: {column: {old: new}} -> only listed columns change.
        .replace(
            to_replace={column: value_labels for column in _labelled_columns}
        )
        .rename(columns=covariate_columns)
        .rename(
            columns={
                'mean_z': 'Human-ratings',
                'x': 'Machine-ratings',
            }
        )
    )


def _ensure_data():
    """Store the plotting data in ``st.session_state.df`` if it is missing."""
    if 'df' not in st.session_state:
        st.session_state.df = _load_data()


# Kept for compatibility: the session that first imports/runs this module
# gets its data immediately, as before.
_ensure_data()


def scatter_plot(df, group_var):
    """Build a faceted scatter plot of machine vs. human ratings.

    One facet per model (the ``x_group`` column), each with OLS trendlines
    fitted per trace.

    Args:
        df: Long-format frame as produced by ``_load_data``.
        group_var: Display name of a covariate column to colour points by,
            or ``None`` to draw every point in a single colour.

    Returns:
        The configured plotly figure.
    """
    # plotly cycles this sequence if there are more groups than colours.
    colors = ['#36def1', '#4361ee'] if group_var else ['#4361ee']
    plot = px.scatter(
        df,
        x='Machine-ratings',
        y='Human-ratings',
        color=group_var,
        facet_col='x_group',
        facet_col_wrap=2,
        trendline='ols',
        trendline_scope='trace',
        hover_data={
            'Text': df.text,
            'Language': False,
            'x_group': False,
            'Human-ratings': ':.2f',
            'Machine-ratings': ':.2f',
            'Study': df.study,
            'Instrument': df.instrument,
        },
        width=400,
        height=400,
        color_discrete_sequence=colors,
    )
    # Facet titles default to "x_group=<value>"; keep only the value.
    plot.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1]))
    # Horizontal legend placed below the plotting area.
    plot.update_layout(
        legend={
            'orientation': 'h',
            'yanchor': 'bottom',
            'y': -.30,
        }
    )
    plot.update_xaxes(title_standoff=0)
    return plot


def show():
    """Render the "Explore the data" section with an optional grouping control."""
    # NOTE(review): the text previously said "test-partition data only",
    # but _load_data selects both the test and dev partitions. The text now
    # matches the code; confirm which one is intended.
    st.markdown(
        '## Explore the data\n'
        '\n'
        'Figures show the accuracy of predictions of human-rated item '
        'desirability by the sentiment model (left) and the desirability '
        'model (right), using `test`- and `dev`-partition data.\n'
    )
    show_covariates = st.checkbox('Show covariates', value=True)
    if show_covariates:
        option = st.selectbox(
            'Group by', options=list(covariate_columns.values())
        )
    else:
        option = None
    # If this module is imported by another Streamlit script (as this
    # show() entry point suggests), its top-level code runs only once per
    # process. Sessions other than the first would then lack 'df' and
    # previously rendered no plot at all, so load lazily here.
    _ensure_data()
    plot = scatter_plot(st.session_state.df, option)
    st.plotly_chart(plot, theme=None, use_container_width=True)