File size: 4,770 Bytes
b1582cb
b05c932
 
acf89de
b05c932
b20b652
 
b1582cb
4b035a2
439a73e
 
 
1ec92fe
acf89de
 
1ec92fe
 
439a73e
c42259d
acf89de
 
 
c42259d
 
acf89de
 
 
 
 
 
 
ca38aeb
a46c542
 
4b035a2
439a73e
 
ca38aeb
 
4b035a2
 
acf89de
ca38aeb
b05c932
a46c542
d9d5b62
439a73e
387bf78
 
439a73e
387bf78
 
1ec92fe
387bf78
 
1ec92fe
387bf78
 
439a73e
acf89de
439a73e
387bf78
acf89de
4b035a2
acf89de
4b035a2
439a73e
 
 
8936371
4b035a2
3d690df
 
c42259d
 
 
acf89de
 
c42259d
 
 
acf89de
439a73e
10251b4
cd95107
4b035a2
acf89de
fe700df
f54b69a
acf89de
f54b69a
439a73e
3d690df
b05c932
10251b4
 
4b035a2
 
b05c932
1ec92fe
 
b05c932
 
030fff5
3d690df
d9d5b62
387bf78
b05c932
10251b4
 
 
acf89de
 
10251b4
 
030fff5
10251b4
 
 
 
 
acf89de
3d690df
 
 
b05c932
10251b4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import streamlit as st
import streamlit.components.v1 as components
import shap
from datashap import DataSHAP as ds
import matplotlib.pyplot as plt
import warnings
warnings.simplefilter("ignore", category=DeprecationWarning)


def select_fala():
    """Sync the speech selector and label radio with the row picked in the table.

    Used as the ``on_select`` callback of the ``df_falas`` dataframe widget.
    When a row is selected, its position is copied into
    ``st.session_state.num_fala``. That row's ``tag`` (stored in Portuguese by
    ``get_dataSHAP``) is translated back to the model label used as the value
    of the ``rotulo`` radio. With no row selected, nothing is changed.
    """
    selected_rows = st.session_state.df_falas['selection']['rows']
    if not selected_rows:
        return

    # Inverse of the display mapping: Portuguese tag -> model label.
    label_for_tag = {'Neutro': 'NEUTRAL', 'Positivo': 'POSITIVE', 'Negativo': 'NEGATIVE'}

    row = selected_rows[0]
    st.session_state.num_fala = row
    # Data for the company/quarter currently chosen in the selectors.
    quarter_data = st.session_state[st.session_state.empresa][st.session_state.trimestre - 1]
    st.session_state.rotulo = label_for_tag[quarter_data.df.iloc[row]['tag']]

@st.cache_resource
def get_dataSHAP(file, company, trim):
    """Load one saved DataSHAP result and relabel its ``tag`` column in Portuguese.

    Wrapped in ``st.cache_resource``, so each (file, company, trim) combination
    is built once and the same object is handed back on later calls.

    Args:
        file: Path of the saved SHAP values file.
        company: Company display name forwarded to ``DataSHAP``.
        trim: Quarter number forwarded to ``DataSHAP``.

    Returns:
        The ``DataSHAP`` instance. Its ``df['tag']`` values POSITIVE / NEGATIVE /
        NEUTRAL are replaced by Positivo / Negativo / Neutro; any other value is
        left unchanged.
    """
    portuguese_tag = {'POSITIVE': 'Positivo', 'NEGATIVE': 'Negativo', 'NEUTRAL': 'Neutro'}
    loaded = ds.DataSHAP(file, company, trim)
    tags = loaded.df['tag']
    loaded.df['tag'] = tags.replace(portuguese_tag)
    return loaded

def init_session(key, val):
    """Load the four quarterly DataSHAP objects of one company into session state.

    Does nothing if ``st.session_state[key]`` already exists.

    Args:
        key: Company code (e.g. ``'vale'``). Used as the session-state key, in
            the file name, and to look up the display name in ``empresa_dict``.
        val: Folder name holding that company's saved SHAP files.

    After a successful call, ``st.session_state[key]`` is a list of four objects
    returned by ``get_dataSHAP``, index ``i - 1`` holding quarter ``i``.
    """
    if key in st.session_state:
        return

    loaded = []
    for i in range(1, 5):
        arquivo = f'spaces/marcossuzuki/TCC_PoliUSPPro/transcrição audio RI/{val}/valores_shap-{key}{i}t24.save'
        loaded.append(get_dataSHAP(arquivo, empresa_dict[key], i))

    # Publish only once every quarter has loaded. Storing the list up front
    # would leave a partial list behind if a load raised, and later reruns
    # would skip initialisation because the key already exists.
    st.session_state[key] = loaded

# Streamlit requires this to be the first st.* command the script runs.
st.set_page_config(page_title="TCCPoliUSPPro")

# Company code -> folder holding that company's saved SHAP files.
pasta = {
    'vale': 'VALE',
    'petr': 'Petrobras',
    'bb': 'BB',
}
# Company code -> name shown in the UI.
empresa_dict = {
    'petr': 'Petrobras',
    'vale': 'Vale',
    'bb': 'Banco do Brasil',
}
# Model label -> Portuguese caption. Key order is the order of the radio options.
option_map = {
    'NEUTRAL': 'Neutro',
    'POSITIVE': 'Positivo',
    'NEGATIVE': 'Negativo',
}
# Score columns of each quarter's dataframe.
title_score = ['positive_score', 'negative_score', 'neutral_score']

# Load every company once per session, then keep a module-level handle to each list.
for company, folder in pasta.items():
    init_session(company, folder)
shap_values = {company: st.session_state[company] for company in pasta}

st.header("Sentimento da fala e Scores")

col1, col2, col3, col4 = st.columns([1.7, 1.2, 1.2, 2], gap="small", vertical_alignment="bottom")

# Company selector. Its value is also stored in st.session_state.empresa,
# which the dataframe selection callback reads.
empresa = col1.selectbox(
    "**Qual empresa quer analisar:**",
    ("vale", "bb", "petr"),
    format_func=lambda option: empresa_dict[option],
    key='empresa',
)

# Quarter 1..4; used as a 0-based index below via `trim - 1`.
trim = col2.number_input("**Trimestre de 2024:**", 1, max_value=4, key='trimestre')

# Speech index, bounded by the number of SHAP entries for this company/quarter.
text_num = col3.number_input(
    "**Fala número:**",
    0, max_value=len(shap_values[empresa][trim - 1].shap_value) - 1,
    key='num_fala',
)

df = shap_values[empresa][trim - 1].df

total_tokens, h, m, s = shap_values[empresa][trim - 1].get_performance()

# `.2f` instead of the original `.2`. The bare-precision form keeps 2 significant
# digits, so any value >= 10 printed as e.g. "4.6e+01s", and an int raised
# ValueError. The whitespace carried over by the backslash continuation
# before "\n" acts as a markdown hard line break.
col4.write(f"**Total tokens:** {total_tokens} \
        \n**Compute time:** {h}h {m}m {s:.2f}s")

# Highlight, per row, the largest of the three class scores.
falas_styled = df.style.highlight_max(axis=1, color='lightgreen', subset=title_score)

# Display headers/formats for the per-speech table.
falas_column_config = {
    'speech': st.column_config.Column('Fala', width=100),
    'qty_tokens': st.column_config.NumberColumn("Qtde. Tokens", format='%d'),
    'positive_score': st.column_config.NumberColumn("Score Positivo"),
    'negative_score': st.column_config.NumberColumn("Score Negativo"),
    'neutral_score': st.column_config.NumberColumn("Score Neutro"),
    'tag': "Rótulo",
}

tab1, tab2, tab3 = st.tabs(["**Data Frame**", "**Estatística Score**", '**Gráfico Estatística**'])

with tab1:
    # Single-row selection; select_fala syncs the selectors with the picked row.
    st.dataframe(
        falas_styled,
        selection_mode='single-row',
        key='df_falas',
        on_select=select_fala,
        column_config=falas_column_config,
        height=200,
    )

with tab2:
    st.dataframe(shap_values[empresa][trim - 1].statistic)

with tab3:
    st.plotly_chart(shap_values[empresa][trim - 1].plot)


# Unpacked in title_score order: positive, negative, neutral.
score_positive, score_negative, score_neutral = df.loc[text_num, title_score]

# Caption for each radio option, built in the same order as option_map's keys.
score_for_label = {
    'NEUTRAL': score_neutral,
    'POSITIVE': score_positive,
    'NEGATIVE': score_negative,
}

rotulo = st.radio(
    "**Rótulo**",
    option_map.keys(),
    horizontal=True,
    format_func=lambda code: option_map[code],
    captions=[f'{score_for_label[code]:.4}' for code in option_map],
    key='rotulo',
)

# HTML text plot of SHAP contributions for the chosen speech and label.
plot_text = shap_values[empresa][trim - 1].shap_plot_text(text_num, rotulo)
components.html(plot_text, height=180, scrolling=True)

st.header("Gráfico waterfall dos termos e Valores de Shapley")

with st.expander("Expand"):
    # Token count of the selected speech bounds how many terms can be shown.
    n_tokens = int(df['qty_tokens'][text_num])
    max_display = st.slider(
        "**Máximo de exibição:**",
        1,
        max_value=n_tokens,
        value=int(n_tokens / 3) + 1,  # default: about a third of the tokens
    )

    plot_waterfall = shap_values[empresa][trim - 1].shap_waterfall(text_num, rotulo, max_display)
    st.pyplot(plot_waterfall)

st.header('Rank de termos do documento em Gráfico Barra')

with st.expander("Expand"):
    # The second return value (named `ax` originally) is not used on this page.
    plot_bar, _ax, rank = shap_values[empresa][trim - 1].get_plot_rank()
    # One bar chart per model label, titled with its Portuguese caption.
    for label_code in option_map:
        st.subheader(option_map[label_code])
        st.pyplot(plot_bar[label_code])

# NOTE(review): the rank table is rendered outside the expander, as in the
# original layout. Confirm this placement is intended.
st.dataframe(rank)