Spaces:
Running
Running
import streamlit as st | |
import streamlit.components.v1 as components | |
import shap | |
from datashap import DataSHAP as ds | |
import matplotlib.pyplot as plt | |
import warnings | |
warnings.simplefilter("ignore", category=DeprecationWarning) | |
def select_fala():
    """Sync the speech number and label with the row picked in the table.

    Reads the selection stored under the ``df_falas`` session-state key. When
    at least one row is selected, its index becomes the current speech number
    (``num_fala``). The ``rotulo`` entry is then set to the internal label
    that corresponds to the 'tag' column of the current speech, looked up in
    the data frame of the currently selected company and quarter.
    """
    state = st.session_state

    selected_rows = state.df_falas['selection']['rows']
    if selected_rows:
        state.num_fala = selected_rows[0]

    # Read the speech number back from session state: it may not have been
    # updated above when the selection is empty.
    speech_index = state.num_fala
    quarter_data = state[state.empresa][state.trimestre - 1]
    displayed_tag = quarter_data.df.iloc[speech_index]['tag']

    # Displayed (Portuguese) tag -> internal label.
    to_internal_label = {
        'Neutro': 'NEUTRAL',
        'Positivo': 'POSITIVE',
        'Negativo': 'NEGATIVE',
    }
    state.rotulo = to_internal_label[displayed_tag]
def get_dataSHAP(file, company, trim):
    """Create a DataSHAP object and translate its 'tag' column to Portuguese.

    Args:
        file: Forwarded unchanged to ``ds.DataSHAP`` as the first argument.
        company: Forwarded unchanged to ``ds.DataSHAP`` as the second argument.
        trim: Forwarded unchanged to ``ds.DataSHAP`` as the third argument.

    Returns:
        The DataSHAP object, whose ``df['tag']`` values 'POSITIVE',
        'NEGATIVE' and 'NEUTRAL' have been replaced by 'Positivo',
        'Negativo' and 'Neutro'.
    """
    # Internal label -> tag text stored in the data frame.
    tag_translation = {
        'POSITIVE': 'Positivo',
        'NEGATIVE': 'Negativo',
        'NEUTRAL': 'Neutro',
    }
    data_shap = ds.DataSHAP(file, company, trim)
    translated_tags = data_shap.df['tag'].replace(tag_translation)
    data_shap.df['tag'] = translated_tags
    return data_shap
def init_session(key, val):
    """Load one company's four quarterly SHAP results into session state once.

    Does nothing when ``key`` is already present in ``st.session_state``.
    Otherwise loads quarters 1 to 4 with ``get_dataSHAP`` and stores the
    resulting list under ``st.session_state[key]`` (index 0 is quarter 1).

    Args:
        key: Company short code. Used as the session-state key, as the lookup
            key into ``empresa_dict`` and as part of the file name.
        val: Folder name of the company inside the transcription directory.
    """
    if key in st.session_state:
        return

    # Build the list locally and publish it only when all four quarters have
    # loaded. Storing an empty list first would leave a partial entry behind
    # if a load raised, and later reruns would then skip reloading it.
    loaded = []
    for i in range(1, 5):
        arquivo = f'spaces/marcossuzuki/TCC_PoliUSPPro/transcrição audio RI/{val}/valores_shap-{key}{i}t24.save'
        loaded.append(get_dataSHAP(arquivo, empresa_dict[key], i))
    st.session_state[key] = loaded
st.set_page_config(page_title="TCCPoliUSPPro", )

# Company short code -> folder name used in the saved-results path.
pasta = {'vale':'VALE', 'petr':'Petrobras', 'bb':'BB'}
# Company short code -> display name (also passed to get_dataSHAP).
empresa_dict = {'petr':'Petrobras', 'vale':'Vale', 'bb':'Banco do Brasil'}
# Internal label -> label shown in the UI.
option_map = {'NEUTRAL':'Neutro', 'POSITIVE':'Positivo', 'NEGATIVE':'Negativo',}
shap_values = {}
title_score = ['positive_score', 'negative_score', 'neutral_score']

# Load (once per session) the four quarterly results of every company.
for key, val in pasta.items():
    init_session(key, val)
    shap_values[key] = st.session_state[key]

st.header("Sentimento da fala e Scores")

col1, col2, col3, col4 = st.columns([1.7,1.2,1.2,2], gap="small", vertical_alignment="bottom")
empresa = col1.selectbox(
    "**Qual empresa quer analisar:**",
    ("vale", "bb", "petr"),
    format_func=lambda option: empresa_dict[option],
    key='empresa',
)
trim = col2.number_input("**Trimestre de 2024:**", 1, max_value = 4, key='trimestre')

# Results for the selected company and quarter; used by everything below.
atual = shap_values[empresa][trim-1]

# A speech number kept from another company/quarter may be past the end of
# the current one; clamp it before the widget is created so it never exceeds
# the widget's max_value.
ultima_fala = len(atual.shap_value)-1
if st.session_state.get('num_fala', 0) > ultima_fala:
    st.session_state.num_fala = ultima_fala

text_num = col3.number_input(
    "**Fala número:**",
    0, max_value = ultima_fala,
    key='num_fala',)

df = atual.df
total_tokens, h, m, s = atual.get_performance()
# Two trailing spaces before the newline make a Markdown hard line break.
# Seconds use fixed-point formatting: a bare ".2" means two significant
# digits and would render e.g. 12.3 as "1.2e+01".
col4.write(
    f"**Total tokens:** {total_tokens}  \n"
    f"**Compute time:** {h}h {m}m {s:.2f}s"
)

tab1, tab2, tab3 = st.tabs(["**Data Frame**", "**Estatística Score**", '**Gráfico Estatística**'])
with tab1:
    # Selecting a row triggers select_fala, which updates the speech number
    # and the label held in session state.
    st.dataframe(df.style.highlight_max(axis = 1, color ='lightgreen',
                                        subset = title_score),
                 selection_mode = 'single-row',
                 key='df_falas',
                 on_select=select_fala,
                 column_config={'speech':st.column_config.Column('Fala', width=100),
                                'qty_tokens':st.column_config.NumberColumn("Qtde. Tokens", format='%d'),
                                'positive_score':st.column_config.NumberColumn("Score Positivo",),
                                'negative_score':st.column_config.NumberColumn("Score Negativo",),
                                'neutral_score':st.column_config.NumberColumn("Score Neutro",),
                                'tag':"Rótulo",
                                },
                 height=200,)
with tab2:
    st.dataframe(atual.statistic, )
with tab3:
    st.plotly_chart(atual.plot)

# Unpacking order follows title_score: positive, negative, neutral.
score_positive, score_negative, score_neutral = df.loc[text_num, title_score]
rotulo = st.radio(
    "**Rótulo**",
    option_map.keys(),
    horizontal=True,
    format_func=lambda option: option_map[option],
    # Caption order follows option_map: neutral, positive, negative.
    captions = [f'{score_neutral:.4}', f'{score_positive:.4}', f'{score_negative:.4}'],
    key='rotulo'
)

plot_text = atual.shap_plot_text(text_num, rotulo)
components.html(plot_text, height = 180, scrolling = True)

st.header("Gráfico waterfall dos termos e Valores de Shapley")
with st.expander("Expand"):
    qtd_tokens = int(df['qty_tokens'][text_num])
    max_display = st.slider(
        "**Máximo de exibição:**",
        1, max_value = qtd_tokens,
        # Default to roughly one third of the tokens, at least 1.
        value=qtd_tokens // 3 + 1
    )
    plot_waterfall = atual.shap_waterfall(text_num, rotulo, max_display)
    st.pyplot(plot_waterfall)

st.header('Rank de termos do documento em Gráfico Barra')
with st.expander("Expand"):
    plot_bar, ax, rank = atual.get_plot_rank()
    for key, val in option_map.items():
        st.subheader(val)
        st.pyplot(plot_bar[key])
    # NOTE(review): assumed to sit after the loop, inside the expander --
    # confirm against the original indentation.
    st.dataframe(rank)