jitesh commited on
Commit
1061dba
1 Parent(s): 5d1e3ac

cache importing the models

Browse files
Files changed (2) hide show
  1. app.py +36 -35
  2. story_gen.py +3 -1
app.py CHANGED
@@ -4,16 +4,11 @@ import plotly.figure_factory as ff
4
  import plotly.express as px
5
  import random
6
  import numpy as np
7
- gen = StoryGenerator()
8
 
9
  st.set_page_config(page_title='Storytelling ' +
10
  u'\U0001F5BC', page_icon=u'\U0001F5BC', layout="wide")
11
- if 'count' not in st.session_state or st.session_state.count == 6:
12
- st.session_state.count = 0
13
- st.session_state.chat_history_ids = None
14
- st.session_state.old_response = ''
15
- else:
16
- st.session_state.count += 1
17
  container_mode = st.sidebar.container()
18
  container_guide = st.sidebar.container()
19
  container_param = st.sidebar.container()
@@ -22,7 +17,7 @@ mode = container_mode.radio(
22
  "Select your mode",
23
  ('Create Statistics', 'Play Storytelling'), index=0)
24
  story_till_now = st.text_input(
25
- label='First Sentence',
26
  value=random.choice([
27
  'Hello, I\'m a language model,',
28
  'So I suppose you want to ask me how I did it.',
@@ -30,15 +25,15 @@ story_till_now = st.text_input(
30
  'My first tutor was a dragon with a terrible sense of humor.',
31
  'Doctors told her she could never diet again.',
32
  'Memory is all around us, as well as within.',
33
-
34
- ]))
35
 
36
  num_generation = container_param.slider(
37
  label='Number of generation', min_value=1, max_value=100, value=5, step=1)
38
  length = container_param.slider(label='Length of the generated sentence',
39
  min_value=1, max_value=100, value=10, step=1)
40
  if mode == 'Create Statistics':
41
-
42
  num_tests = container_param.slider(
43
  label='Number of tests', min_value=1, max_value=1000, value=3, step=1)
44
  reaction_weight_mode = container_param.select_slider(
@@ -64,35 +59,41 @@ if mode == 'Create Statistics':
64
  for si, story in enumerate(gen.data):
65
  st.markdown(f'### Story no. {si}:', unsafe_allow_html=False)
66
  for i, sentence in enumerate(story):
67
- col_turn, col_sentence, col_emo = st.columns([1,8,2])
68
- col_turn.markdown(sentence['turn'], unsafe_allow_html=False)
69
- col_sentence.markdown(sentence['sentence'], unsafe_allow_html=False)
70
- col_emo.markdown(f'{sentence["emotion"]} {np.round(sentence["confidence_score"], 3)}', unsafe_allow_html=False)
 
 
 
71
  st.table(data=gen.stats_df, )
72
- data=gen.stats_df[gen.stats_df.sentence_no==3]
73
- fig = px.violin(data_frame=data, x="reaction_weight", y="num_reactions", hover_data=data.columns)
 
74
  st.plotly_chart(fig, use_container_width=True)
75
- fig2 = px.box(data_frame=data, x="reaction_weight", y="num_reactions", hover_data=data.columns)
 
76
  st.plotly_chart(fig2, use_container_width=True)
77
  else:
78
- container_guide.markdown('### You selected statistics. Now set your parameters and click the `Analyse` button.')
79
- elif mode == 'Play Storytelling':
 
80
 
81
- # # , placeholder="Start writing your story...")
82
- # story_till_now = st.text_input(
83
- # label='First Sentence', value='Hello, I\'m a language model,')
84
 
85
- # num_generation = st.sidebar.slider(
86
- # label='Number of generation', min_value=1, max_value=100, value=10, step=1)
87
- # length = st.sidebar.slider(label='Length of the generated sentence',
88
- # min_value=1, max_value=100, value=20, step=1)
89
- if container_button.button('Run'):
90
- story_till_now, emotion = gen.story(
91
- story_till_now, num_generation, length)
92
- st.markdown(f'### Story')
93
- st.text(story_till_now)
94
- st.markdown(f'The last sentence has the "{emotion["label"]}" **Emotion** with a confidence score of {emotion["score"]}.')
95
- else:
96
- container_guide.markdown('### Write the first sentence and then hit the `Run` button')
97
  # elif mode == 'Analyse Emotions':
98
  # container_mode.write('Let\'s play storytelling.')
 
4
  import plotly.express as px
5
  import random
6
  import numpy as np
 
7
 
8
  st.set_page_config(page_title='Storytelling ' +
9
  u'\U0001F5BC', page_icon=u'\U0001F5BC', layout="wide")
10
+ gen = StoryGenerator()
11
+
 
 
 
 
12
  container_mode = st.sidebar.container()
13
  container_guide = st.sidebar.container()
14
  container_param = st.sidebar.container()
 
17
  "Select your mode",
18
  ('Create Statistics', 'Play Storytelling'), index=0)
19
  story_till_now = st.text_input(
20
+ label='First Sentence',
21
  value=random.choice([
22
  'Hello, I\'m a language model,',
23
  'So I suppose you want to ask me how I did it.',
 
25
  'My first tutor was a dragon with a terrible sense of humor.',
26
  'Doctors told her she could never diet again.',
27
  'Memory is all around us, as well as within.',
28
+
29
+ ]))
30
 
31
  num_generation = container_param.slider(
32
  label='Number of generation', min_value=1, max_value=100, value=5, step=1)
33
  length = container_param.slider(label='Length of the generated sentence',
34
  min_value=1, max_value=100, value=10, step=1)
35
  if mode == 'Create Statistics':
36
+
37
  num_tests = container_param.slider(
38
  label='Number of tests', min_value=1, max_value=1000, value=3, step=1)
39
  reaction_weight_mode = container_param.select_slider(
 
59
  for si, story in enumerate(gen.data):
60
  st.markdown(f'### Story no. {si}:', unsafe_allow_html=False)
61
  for i, sentence in enumerate(story):
62
+ col_turn, col_sentence, col_emo = st.columns([1, 8, 2])
63
+ col_turn.markdown(
64
+ sentence['turn'], unsafe_allow_html=False)
65
+ col_sentence.markdown(
66
+ sentence['sentence'], unsafe_allow_html=False)
67
+ col_emo.markdown(
68
+ f'{sentence["emotion"]} {np.round(sentence["confidence_score"], 3)}', unsafe_allow_html=False)
69
  st.table(data=gen.stats_df, )
70
+ data = gen.stats_df[gen.stats_df.sentence_no == 3]
71
+ fig = px.violin(data_frame=data, x="reaction_weight",
72
+ y="num_reactions", hover_data=data.columns)
73
  st.plotly_chart(fig, use_container_width=True)
74
+ fig2 = px.box(data_frame=data, x="reaction_weight",
75
+ y="num_reactions", hover_data=data.columns)
76
  st.plotly_chart(fig2, use_container_width=True)
77
  else:
78
+ container_guide.markdown(
79
+ '### You selected statistics. Now set your parameters and click the `Analyse` button.')
80
+ # elif mode == 'Play Storytelling':
81
 
82
+ # # # , placeholder="Start writing your story...")
83
+ # # story_till_now = st.text_input(
84
+ # # label='First Sentence', value='Hello, I\'m a language model,')
85
 
86
+ # # num_generation = st.sidebar.slider(
87
+ # # label='Number of generation', min_value=1, max_value=100, value=10, step=1)
88
+ # # length = st.sidebar.slider(label='Length of the generated sentence',
89
+ # # min_value=1, max_value=100, value=20, step=1)
90
+ # if container_button.button('Run'):
91
+ # story_till_now, emotion = gen.story(
92
+ # story_till_now, num_generation, length)
93
+ # st.markdown(f'### Story')
94
+ # st.text(story_till_now)
95
+ # st.markdown(f'The last sentence has the "{emotion["label"]}" **Emotion** with a confidence score of {emotion["score"]}.')
96
+ # else:
97
+ # container_guide.markdown('### Write the first sentence and then hit the `Run` button')
98
  # elif mode == 'Analyse Emotions':
99
  # container_mode.write('Let\'s play storytelling.')
story_gen.py CHANGED
@@ -8,6 +8,7 @@ import numpy as np
8
  import pandas as pd
9
  # import nltk
10
  import re
 
11
 
12
 
13
  class StoryGenerator:
@@ -17,6 +18,7 @@ class StoryGenerator:
17
  self.stories = []
18
  self.data = []
19
 
 
20
  def initialise_models(self):
21
  start = time.time()
22
  self.generator = pipeline('text-generation', model='gpt2')
@@ -156,7 +158,7 @@ class StoryGenerator:
156
  stats_dict['reaction_weight'] = None
157
  stats_df = pd.concat(
158
  [stats_df, pd.DataFrame(stats_dict, index=[f'idx_{i}'])])
159
-
160
  return stats_df, story_till_now, story_data
161
 
162
  def get_stats(self,
 
8
  import pandas as pd
9
  # import nltk
10
  import re
11
+ import streamlit as st
12
 
13
 
14
  class StoryGenerator:
 
18
  self.stories = []
19
  self.data = []
20
 
21
+ @st.cache()
22
  def initialise_models(self):
23
  start = time.time()
24
  self.generator = pipeline('text-generation', model='gpt2')
 
158
  stats_dict['reaction_weight'] = None
159
  stats_df = pd.concat(
160
  [stats_df, pd.DataFrame(stats_dict, index=[f'idx_{i}'])])
161
+
162
  return stats_df, story_till_now, story_data
163
 
164
  def get_stats(self,