Spaces:
Running
Running
File size: 4,770 Bytes
b1582cb b05c932 acf89de b05c932 b20b652 b1582cb 4b035a2 439a73e 1ec92fe acf89de 1ec92fe 439a73e c42259d acf89de c42259d acf89de ca38aeb a46c542 4b035a2 439a73e ca38aeb 4b035a2 acf89de ca38aeb b05c932 a46c542 d9d5b62 439a73e 387bf78 439a73e 387bf78 1ec92fe 387bf78 1ec92fe 387bf78 439a73e acf89de 439a73e 387bf78 acf89de 4b035a2 acf89de 4b035a2 439a73e 8936371 4b035a2 3d690df c42259d acf89de c42259d acf89de 439a73e 10251b4 cd95107 4b035a2 acf89de fe700df f54b69a acf89de f54b69a 439a73e 3d690df b05c932 10251b4 4b035a2 b05c932 1ec92fe b05c932 030fff5 3d690df d9d5b62 387bf78 b05c932 10251b4 acf89de 10251b4 030fff5 10251b4 acf89de 3d690df b05c932 10251b4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import streamlit as st
import streamlit.components.v1 as components
import shap
from datashap import DataSHAP as ds
import matplotlib.pyplot as plt
import warnings
# Ignore DeprecationWarning messages raised from here on.
warnings.simplefilter("ignore", category=DeprecationWarning)
def select_fala():
    """Mirror the row selected in the ``df_falas`` table into the other widgets.

    Reads the selection stored under ``st.session_state.df_falas``. When a
    row is selected, its position is written to ``st.session_state.num_fala``
    and the row's ``tag`` value is translated back to the model label and
    written to ``st.session_state.rotulo``. With no selected row, nothing
    is changed.
    """
    state = st.session_state
    selected_rows = state.df_falas['selection']['rows']
    if not selected_rows:
        return

    row_position = selected_rows[0]
    state.num_fala = row_position

    # Data of the company/quarter currently chosen in the selector widgets.
    quarter_data = state[state.empresa][state.trimestre - 1]
    tag = quarter_data.df.iloc[row_position]['tag']

    # The table shows Portuguese tags; the label widget stores model labels.
    label_by_tag = {'Neutro': 'NEUTRAL', 'Positivo': 'POSITIVE', 'Negativo': 'NEGATIVE'}
    state.rotulo = label_by_tag[tag]
@st.cache_resource
def get_dataSHAP(file, company, trim):
    """Build a ``ds.DataSHAP`` object and translate its tags to Portuguese.

    The function is wrapped in ``st.cache_resource``.

    Args:
        file: First argument forwarded to ``ds.DataSHAP``.
        company: Second argument forwarded to ``ds.DataSHAP``.
        trim: Third argument forwarded to ``ds.DataSHAP``.

    Returns:
        The ``ds.DataSHAP`` object, whose ``df['tag']`` column has
        'POSITIVE', 'NEGATIVE' and 'NEUTRAL' replaced by 'Positivo',
        'Negativo' and 'Neutro'.
    """
    tag_translation = {
        'POSITIVE': 'Positivo',
        'NEGATIVE': 'Negativo',
        'NEUTRAL': 'Neutro',
    }
    loaded = ds.DataSHAP(file, company, trim)
    loaded.df['tag'] = loaded.df['tag'].replace(tag_translation)
    return loaded
def init_session(key, val):
    """Load the four quarterly DataSHAP objects of one company into session state.

    Does nothing when ``key`` is already present in ``st.session_state``.
    Otherwise ``st.session_state[key]`` becomes a list with one entry per
    quarter (1 to 4), each obtained from ``get_dataSHAP``.

    Args:
        key: Company code. It is used as the session-state key, as part of
            the file name, and as the lookup key into ``empresa_dict``.
        val: Folder name of the company inside the transcription directory.
    """
    if key in st.session_state:
        return

    # Build the complete list before storing it. If one file fails to load,
    # the key stays absent and the next rerun retries, instead of finding a
    # partial list that would later break indexing by quarter.
    st.session_state[key] = [
        get_dataSHAP(
            f'spaces/marcossuzuki/TCC_PoliUSPPro/transcrição audio RI/{val}/valores_shap-{key}{i}t24.save',
            empresa_dict[key],
            i,
        )
        for i in range(1, 5)
    ]
st.set_page_config(page_title="TCCPoliUSPPro")

# Company code -> folder name passed to init_session.
pasta = {'vale': 'VALE', 'petr': 'Petrobras', 'bb': 'BB'}
# Company code -> display name (also passed to get_dataSHAP by init_session).
empresa_dict = {'petr': 'Petrobras', 'vale': 'Vale', 'bb': 'Banco do Brasil'}
# Model label -> Portuguese label shown in the interface.
option_map = {'NEUTRAL': 'Neutro', 'POSITIVE': 'Positivo', 'NEGATIVE': 'Negativo'}
# Score columns of the speech data frame.
title_score = ['positive_score', 'negative_score', 'neutral_score']

# Make sure every company has its data in session state, then expose the
# per-company lists through a plain dict.
for key, val in pasta.items():
    init_session(key, val)
shap_values = {key: st.session_state[key] for key in pasta}
st.header("Sentimento da fala e Scores")
col1, col2, col3, col4 = st.columns(
    [1.7, 1.2, 1.2, 2], gap="small", vertical_alignment="bottom"
)

# Company, quarter and speech selectors; each one is also stored in
# st.session_state under its key.
empresa = col1.selectbox(
    "**Qual empresa quer analisar:**",
    ("vale", "bb", "petr"),
    format_func=lambda option: empresa_dict[option],
    key='empresa',
)
trim = col2.number_input("**Trimestre de 2024:**", 1, max_value=4, key='trimestre')
text_num = col3.number_input(
    "**Fala número:**",
    0,
    # Last valid speech index for the selected company and quarter.
    max_value=len(shap_values[empresa][trim - 1].shap_value) - 1,
    key='num_fala',
)

df = shap_values[empresa][trim - 1].df
total_tokens, h, m, s = shap_values[empresa][trim - 1].get_performance()
# The two spaces before "\n" are an explicit Markdown hard line break, so the
# layout no longer depends on the indentation of a continued string literal.
# ".2f" prints the seconds with two decimals; a bare ".2" means two
# significant digits and turns e.g. 12.3 into "1.2e+01".
# NOTE(review): assumes `s` is a float number of seconds - confirm against
# get_performance().
col4.write(
    f"**Total tokens:** {total_tokens}  \n"
    f"**Compute time:** {h}h {m}m {s:.2f}s"
)
tab1, tab2, tab3 = st.tabs(["**Data Frame**", "**Estatística Score**", '**Gráfico Estatística**'])

with tab1:
    # Per row, highlight the largest value among the score columns.
    styled_df = df.style.highlight_max(axis=1, color='lightgreen', subset=title_score)
    columns_setup = {
        'speech': st.column_config.Column('Fala', width=100),
        'qty_tokens': st.column_config.NumberColumn("Qtde. Tokens", format='%d'),
        'positive_score': st.column_config.NumberColumn("Score Positivo"),
        'negative_score': st.column_config.NumberColumn("Score Negativo"),
        'neutral_score': st.column_config.NumberColumn("Score Neutro"),
        'tag': "Rótulo",
    }
    # Selecting a row triggers select_fala.
    st.dataframe(
        styled_df,
        selection_mode='single-row',
        key='df_falas',
        on_select=select_fala,
        column_config=columns_setup,
        height=200,
    )

with tab2:
    st.dataframe(shap_values[empresa][trim - 1].statistic)

with tab3:
    st.plotly_chart(shap_values[empresa][trim - 1].plot)
# Scores of the selected speech, unpacked in the column order of title_score.
score_positive, score_negative, score_neutral = df.loc[text_num, title_score]

# One caption per radio option, given in the order neutral, positive, negative.
score_captions = [
    f'{score:.4}' for score in (score_neutral, score_positive, score_negative)
]
rotulo = st.radio(
    "**Rótulo**",
    option_map.keys(),
    horizontal=True,
    format_func=lambda option: option_map[option],
    captions=score_captions,
    key='rotulo',
)

# Text plot for the selected speech and label, embedded as HTML.
plot_text = shap_values[empresa][trim - 1].shap_plot_text(text_num, rotulo)
components.html(plot_text, height=180, scrolling=True)
st.header("Gráfico waterfall dos termos e Valores de Shapley")
with st.expander("Expand"):
    # Token count of the selected speech bounds the slider; the initial
    # position is one third of it, plus one.
    token_count = int(df['qty_tokens'][text_num])
    max_display = st.slider(
        "**Máximo de exibição:**",
        1,
        max_value=token_count,
        value=int(token_count / 3) + 1,
    )
    plot_waterfall = shap_values[empresa][trim - 1].shap_waterfall(
        text_num, rotulo, max_display
    )
    st.pyplot(plot_waterfall)
st.header('Rank de termos do documento em Gráfico Barra')
with st.expander("Expand"):
    plot_bar, ax, rank = shap_values[empresa][trim - 1].get_plot_rank()
    # One sub-section per label: Portuguese title followed by its bar plot.
    for label, label_title in option_map.items():
        st.subheader(label_title)
        st.pyplot(plot_bar[label])
    st.dataframe(rank)