File size: 4,324 Bytes
20684dc
 
 
 
 
bfa0c67
 
20684dc
9372e1c
bfa0c67
20684dc
 
 
 
 
 
 
9372e1c
bfa0c67
9372e1c
 
 
bfa0c67
9372e1c
 
 
 
20684dc
 
bfa0c67
9372e1c
 
20684dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bfa0c67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20684dc
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import random

import numpy as np
import plotly.express as px
import streamlit as st
import xlsxwriter
import pandas as pd

from .lib import initialise_storytelling, set_input
import io


def run_create_statistics(gen, container_guide, container_param, container_button):
    """Render the "statistics" mode: collect parameters, run ``gen.get_stats`` and show results.

    Parameter widgets are placed in ``container_param``. The analysis only runs
    when the ``Analyse`` button in ``container_button`` is clicked; until then an
    instruction message is written to ``container_guide``.

    After the analysis, every generated story is listed sentence by sentence,
    the statistics table and two plots are shown, and the statistics can be
    downloaded as an ``.xlsx`` workbook.

    Args:
        gen: Story generator. Must provide ``get_stats(...)``; after that call
            this function reads ``gen.data`` (iterable of stories, each an
            iterable of mappings with ``turn``, ``sentence``, ``emotion`` and
            ``confidence_score`` keys) and ``gen.stats_df`` (a DataFrame).
        container_guide: Streamlit container used for guidance text.
        container_param: Streamlit container hosting the parameter widgets.
        container_button: Streamlit container hosting the ``Analyse`` button.
    """
    # The initial emotion is not needed for statistics mode.
    first_sentence, _first_emotion, length = initialise_storytelling(
        gen, container_guide, container_param, container_button)

    num_generation = set_input(
        container_param,
        label='Number of generation', min_value=1, max_value=100, value=5, step=1,
        key_slider='num_generation_slider', key_input='num_generation_input',
    )
    num_tests = set_input(
        container_param,
        label='Number of tests', min_value=1, max_value=1000, value=3, step=1,
        key_slider='num_tests_slider', key_input='num_tests_input',
    )

    reaction_weight_mode = container_param.radio(
        "Reaction Weight w:", ["Random", "Fixed"])
    if reaction_weight_mode == "Fixed":
        reaction_weight = set_input(
            container_param,
            label='Reaction Weight w', min_value=0.0, max_value=1.0, value=0.5, step=0.01,
            key_slider='w_slider', key_input='w_input',
        )
    else:
        # "Random" mode: -1 is the sentinel passed to gen.get_stats.
        # NOTE(review): presumably get_stats draws a random weight for -1 — confirm.
        # Using `else` (not `elif`) guarantees reaction_weight is always bound.
        reaction_weight = -1

    if not container_button.button('Analyse'):
        container_guide.markdown(
            '### You selected statistics. Now set your parameters and click the `Analyse` button.')
        return

    gen.get_stats(story_till_now=first_sentence,
                  num_generation=num_generation, length=length,
                  reaction_weight=reaction_weight, num_tests=num_tests)
    if len(gen.data) == 0:
        return

    _render_stories(gen.data)
    _render_statistics(gen.stats_df)
    # The workbook is fully written and closed before its bytes are handed to
    # the button, so the download always contains a complete file.
    st.download_button(
        label="Download data",
        data=_stats_to_excel_bytes(gen.stats_df),
        file_name='data.xlsx',
        # Correct MIME type for OOXML .xlsx files ('application/vnd.ms-excel' is legacy .xls).
        mime='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
    )


def _render_stories(stories):
    """Write each story as a heading followed by one row per sentence (turn | sentence | emotion + score)."""
    for si, story in enumerate(stories):
        st.markdown(f'### Story {si}:', unsafe_allow_html=False)
        for sentence in story:
            col_turn, col_sentence, col_emo = st.columns([1, 8, 2])
            col_turn.markdown(sentence['turn'], unsafe_allow_html=False)
            col_sentence.markdown(sentence['sentence'], unsafe_allow_html=False)
            col_emo.markdown(
                f'{sentence["emotion"]} {np.round(sentence["confidence_score"], 3)}',
                unsafe_allow_html=False)


def _render_statistics(stats_df, sentence_no=3):
    """Show the full statistics table, then violin and box plots of num_reactions per reaction_weight.

    Only rows whose ``sentence_no`` equals ``sentence_no`` are plotted.
    NOTE(review): the default of 3 is kept from the original code — confirm
    it is the intended sentence to analyse.
    """
    st.table(data=stats_df)
    data = stats_df[stats_df.sentence_no == sentence_no]
    fig = px.violin(data_frame=data, x="reaction_weight",
                    y="num_reactions", hover_data=data.columns)
    st.plotly_chart(fig, use_container_width=True)
    fig2 = px.box(data_frame=data, x="reaction_weight",
                  y="num_reactions", hover_data=data.columns)
    st.plotly_chart(fig2, use_container_width=True)


def _stats_to_excel_bytes(stats_df):
    """Serialise ``stats_df`` into an in-memory .xlsx workbook (sheet 'AllData') and return its bytes."""
    buffer = io.BytesIO()
    # Leaving the with-block closes the writer, which flushes the workbook into
    # `buffer`. Do not call writer.save(): it was removed in pandas 2.0, and on
    # older versions it caused a double close.
    with pd.ExcelWriter(buffer, engine='xlsxwriter') as writer:
        stats_df.to_excel(writer, sheet_name='AllData')
    return buffer.getvalue()