# TCC_PoliUSPPro / src /streamlit_app.py
# marcossuzuki's picture
# using shap_plot_text and shap_waterfall methods
# 030fff5 verified
import streamlit as st
import streamlit.components.v1 as components
import shap
from datashap import DataSHAP as ds
import matplotlib.pyplot as plt
import warnings
# Process-wide filter: silence every DeprecationWarning raised after this point.
warnings.simplefilter(action="ignore", category=DeprecationWarning)
def select_fala():
    """Sync ``num_fala`` and ``rotulo`` with the row picked in the ``df_falas`` table.

    Used as the ``on_select`` callback of the speech dataframe. If a row is
    selected, its (positional) index becomes the current ``num_fala``. The
    ``tag`` of the row at ``num_fala`` holds a Portuguese label
    ('Neutro'/'Positivo'/'Negativo'). It is translated back to the English
    option key used by the ``rotulo`` radio.
    """
    linhas_escolhidas = st.session_state.df_falas['selection']['rows']
    if linhas_escolhidas:
        st.session_state.num_fala = linhas_escolhidas[0]
    posicao = st.session_state.num_fala
    # Session state holds, per company code, a list of DataSHAP objects for
    # quarters 1..4, so the quarter number is shifted to a 0-based index.
    dados_trimestre = st.session_state[st.session_state.empresa][st.session_state.trimestre - 1]
    rotulo_pt = dados_trimestre.df.iloc[posicao]['tag']
    # Inverse of the module-level option_map (Portuguese label -> radio key).
    rotulo_para_opcao = {'Neutro': 'NEUTRAL', 'Positivo': 'POSITIVE', 'Negativo': 'NEGATIVE'}
    st.session_state.rotulo = rotulo_para_opcao[rotulo_pt]
@st.cache_resource
def get_dataSHAP(file, company, trim):
    """Build the DataSHAP object for one saved SHAP file, cached by Streamlit.

    ``st.cache_resource`` keys the cache on (file, company, trim), so each file
    is loaded once and the same object is reused on later reruns. The ``tag``
    column is relabelled from the model's English labels to Portuguese for
    display.
    """
    para_portugues = {'POSITIVE': 'Positivo', 'NEGATIVE': 'Negativo', 'NEUTRAL': 'Neutro'}
    resultado = ds.DataSHAP(file, company, trim)
    tags_originais = resultado.df['tag']
    resultado.df['tag'] = tags_originais.replace(para_portugues)
    return resultado
def init_session(key, val):
    """Store the four 2024 quarterly DataSHAP objects of one company in session state.

    Parameters
    ----------
    key : str
        Company code (as iterated from ``pasta``). It serves as the
        session-state key, is part of each file name and is looked up in the
        module-level ``empresa_dict`` for the display name.
    val : str
        Folder name of the company inside the transcription directory.

    If ``st.session_state`` already contains ``key`` nothing is loaded.
    Otherwise, afterwards ``st.session_state[key]`` is a list whose item
    ``i - 1`` is the object for quarter ``i`` (1..4).
    """
    if key in st.session_state:
        return
    # Build the complete list before publishing it. If loading any quarter
    # raises, nothing is stored, so the next rerun retries. Appending straight
    # into session state would leave a partial list that blocks reloading and
    # later fails with IndexError for the missing quarters.
    trimestres = []
    for i in range(1, 5):
        arquivo = f'spaces/marcossuzuki/TCC_PoliUSPPro/transcrição audio RI/{val}/valores_shap-{key}{i}t24.save'
        trimestres.append(get_dataSHAP(arquivo, empresa_dict[key], i))
    st.session_state[key] = trimestres
# Must stay the first Streamlit command executed by the script.
st.set_page_config(page_title="TCCPoliUSPPro")

# Company code -> folder name inside the transcription directory.
pasta = {'vale': 'VALE', 'petr': 'Petrobras', 'bb': 'BB'}
# Company code -> name shown in the company selector.
empresa_dict = {'petr': 'Petrobras', 'vale': 'Vale', 'bb': 'Banco do Brasil'}
# Radio option key -> Portuguese label. Key order fixes the radio/caption order.
option_map = {'NEUTRAL': 'Neutro', 'POSITIVE': 'Positivo', 'NEGATIVE': 'Negativo'}
# Score columns; order matches the unpacking into positive/negative/neutral.
title_score = ['positive_score', 'negative_score', 'neutral_score']

for key, val in pasta.items():
    init_session(key, val)
# Company code -> list of quarterly DataSHAP objects (quarter 1 at index 0).
shap_values = {codigo: st.session_state[codigo] for codigo in pasta}
st.header("Sentimento da fala e Scores")
col1, col2, col3, col4 = st.columns([1.7, 1.2, 1.2, 2], gap="small", vertical_alignment="bottom")
empresa = col1.selectbox(
    "**Qual empresa quer analisar:**",
    ("vale", "bb", "petr"),
    format_func=lambda option: empresa_dict[option],
    key='empresa',
)
trim = col2.number_input("**Trimestre de 2024:**", 1, max_value=4, key='trimestre')

# DataSHAP object of the selected company and quarter (quarters are 1-based).
dados = shap_values[empresa][trim - 1]
max_fala = len(dados.shap_value) - 1
# Transcripts differ in length. When the user switches to a shorter one, the
# speech index kept in session state may point past its last speech, so clamp
# it before the widget is created (widget state may only be set at that point).
if st.session_state.get('num_fala', 0) > max_fala:
    st.session_state['num_fala'] = max_fala
text_num = col3.number_input(
    "**Fala número:**",
    0, max_value=max_fala,
    key='num_fala',
)

df = dados.df
total_tokens, h, m, s = dados.get_performance()
# "  \n" forces a Markdown line break. ".2f" keeps the seconds in fixed-point.
# A bare ".2" means two significant digits and prints e.g. 45.6 as "4.6e+01".
col4.write(f"**Total tokens:** {total_tokens}  \n**Compute time:** {h}h {m}m {s:.2f}s")

tab1, tab2, tab3 = st.tabs(["**Data Frame**", "**Estatística Score**", '**Gráfico Estatística**'])
with tab1:
    st.dataframe(
        df.style.highlight_max(axis=1, color='lightgreen', subset=title_score),
        selection_mode='single-row',
        key='df_falas',
        on_select=select_fala,
        column_config={
            'speech': st.column_config.Column('Fala', width=100),
            'qty_tokens': st.column_config.NumberColumn("Qtde. Tokens", format='%d'),
            'positive_score': st.column_config.NumberColumn("Score Positivo"),
            'negative_score': st.column_config.NumberColumn("Score Negativo"),
            'neutral_score': st.column_config.NumberColumn("Score Neutro"),
            'tag': "Rótulo",
        },
        height=200,
    )
with tab2:
    st.dataframe(dados.statistic)
with tab3:
    st.plotly_chart(dados.plot)

# Row of the current speech. The lookup is positional, like in select_fala:
# both the number_input range (0..max_fala) and the table selection are
# positions, not index labels.
fala = df.iloc[text_num]
score_positive, score_negative, score_neutral = fala[title_score]
rotulo = st.radio(
    "**Rótulo**",
    option_map.keys(),
    horizontal=True,
    format_func=lambda option: option_map[option],
    # Same order as option_map's keys: NEUTRAL, POSITIVE, NEGATIVE.
    captions=[f'{score_neutral:.4}', f'{score_positive:.4}', f'{score_negative:.4}'],
    key='rotulo',
)
plot_text = dados.shap_plot_text(text_num, rotulo)
components.html(plot_text, height=180, scrolling=True)

st.header("Gráfico waterfall dos termos e Valores de Shapley")
qty_tokens = int(fala['qty_tokens'])
with st.expander("Expand"):
    if qty_tokens > 1:
        max_display = st.slider(
            "**Máximo de exibição:**",
            1, max_value=qty_tokens,
            value=qty_tokens // 3 + 1,
        )
    else:
        # st.slider rejects min_value == max_value. With at most one token
        # there is nothing to choose, so display a single term.
        max_display = 1
    plot_waterfall = dados.shap_waterfall(text_num, rotulo, max_display)
    st.pyplot(plot_waterfall)

st.header('Rank de termos do documento em Gráfico Barra')
with st.expander("Expand"):
    # The second returned value (axes) is not needed here.
    plot_bar, _ax, rank = dados.get_plot_rank()
    for tag, nome_rotulo in option_map.items():
        st.subheader(nome_rotulo)
        st.pyplot(plot_bar[tag])
    st.dataframe(rank)