山越貴耀 commited on
Commit
c510ebb
1 Parent(s): ccfb64d

added autoplay

Browse files
Files changed (1) hide show
  1. app.py +51 -78
app.py CHANGED
@@ -84,44 +84,7 @@ def clear_df():
84
 
85
  #@st.cache(show_spinner=False)
86
  def plot_fig(df,sent_id,xlims,ylims,color_list):
87
- x_tsne, y_tsne = df.x_tsne, df.y_tsne
88
- fig = plt.figure(figsize=(5,5),dpi=200)
89
- ax = fig.add_subplot(1,1,1)
90
- ax.plot(x_tsne[:sent_id+1],y_tsne[:sent_id+1],linewidth=0.2,color='gray',zorder=1)
91
- ax.scatter(x_tsne[:sent_id+1],y_tsne[:sent_id+1],s=5,color=color_list[:sent_id+1],zorder=2)
92
- ax.scatter(x_tsne[sent_id:sent_id+1],y_tsne[sent_id:sent_id+1],s=50,marker='*',color='blue',zorder=3)
93
- ax.set_xlim(*xlims)
94
- ax.set_ylim(*ylims)
95
- ax.axis('off')
96
- ax.set_title(df.cleaned_sentence.to_list()[sent_id])
97
- #fig.savefig(f'figures/{sent_id}.png')
98
- buf = io.BytesIO()
99
- fig.savefig(buf, format="png", dpi=200)
100
- buf.seek(0)
101
- img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
102
- buf.close()
103
- img = cv2.imdecode(img_arr, 1)
104
- plt.clf()
105
- plt.close()
106
- return img
107
 
108
- def pre_render_images(df,input_sent_id):
109
- sent_id_options = [min(len(df)-1,max(0,input_sent_id+increment)) for increment in [-500,-100,-10,-1,0,1,10,100,500]]
110
- x_tsne, y_tsne = df.x_tsne, df.y_tsne
111
- xscale_unit = (max(x_tsne)-min(x_tsne))/10
112
- yscale_unit = (max(y_tsne)-min(y_tsne))/10
113
- xmax,xmin = (max(x_tsne)//xscale_unit+1)*xscale_unit,(min(x_tsne)//xscale_unit-1)*xscale_unit
114
- ymax,ymin = (max(y_tsne)//yscale_unit+1)*yscale_unit,(min(y_tsne)//yscale_unit-1)*yscale_unit
115
- color_list = sns.color_palette('flare',n_colors=int(len(df)*1.2))
116
- sent_list = []
117
- fig_list = []
118
- fig_production = st.progress(0)
119
- for fig_id,sent_id in enumerate(sent_id_options):
120
- fig_production.progress(fig_id+1)
121
- img = plot_fig(df,sent_id,[xmin,xmax],[ymin,ymax],color_list)
122
- sent_list.append(df.cleaned_sentence.to_list()[sent_id])
123
- fig_list.append(img)
124
- return sent_list,fig_list
125
 
126
  def update_sent_id(increment_value=0):
127
  sent_id = st.session_state.sent_id
@@ -195,6 +158,13 @@ if __name__=='__main__':
195
 
196
  if 'df' in st.session_state:
197
  df = st.session_state.df
 
 
 
 
 
 
 
198
  st.sidebar.slider(label='2. Select a position in the chain to start exploring',
199
  min_value=0,max_value=len(df)-1,value=0,key='sent_id_from_slider',on_change=initialize_sent_id)
200
  if 'sent_id' not in st.session_state:
@@ -205,31 +175,40 @@ if __name__=='__main__':
205
  else:
206
  explore_type = st.sidebar.radio('3. Choose the way to explore',options=['In fixed increments','Click through each step'])
207
  if explore_type=='Autoplay':
208
- #if st.button('Create the video (this may take a few minutes)'):
209
- #st.write('Creating the video...')
210
- #x_tsne, y_tsne = df.x_tsne, df.y_tsne
211
- #xscale_unit = (max(x_tsne)-min(x_tsne))/10
212
- #yscale_unit = (max(y_tsne)-min(y_tsne))/10
213
- #xlims = [(min(x_tsne)//xscale_unit-1)*xscale_unit,(max(x_tsne)//xscale_unit+1)*xscale_unit]
214
- #ylims = [(min(y_tsne)//yscale_unit-1)*yscale_unit,(max(y_tsne)//yscale_unit+1)*yscale_unit]
215
- #color_list = sns.color_palette('flare',n_colors=1200)
216
- #fig_production = st.progress(0)
217
-
218
- #img = plot_fig(df,0,xlims,ylims,color_list)
219
- #img = cv2.imread('figures/0.png')
220
- #height, width, layers = img.shape
221
- #size = (width,height)
222
- #out = cv2.VideoWriter('sampling_video.mp4',cv2.VideoWriter_fourcc(*'H264'), 3, size)
223
- #for sent_id in range(1000):
224
- # fig_production.progress((sent_id+1)/1000)
225
- # img = plot_fig(df,sent_id,xlims,ylims,color_list)
226
- #img = cv2.imread(f'figures/{sent_id}.png')
227
- # out.write(img)
228
- #out.release()
229
- cols = st.columns([1,2,1])
230
  with cols[1]:
231
- with open(f'sampling_video_{sentence_num}.mp4', 'rb') as f:
232
- st.video(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
  else:
234
  if explore_type=='In fixed increments':
235
  button_labels = ['+1','+10','+100','+500']
@@ -247,30 +226,15 @@ if __name__=='__main__':
247
  elif explore_type=='Click through each step':
248
  st.session_state.sent_id = st.sidebar.number_input(label='step number',value=st.session_state.sent_id_from_slider)
249
 
250
- x_tsne, y_tsne = df.x_tsne, df.y_tsne
251
- xscale_unit = (max(x_tsne)-min(x_tsne))/10
252
- yscale_unit = (max(y_tsne)-min(y_tsne))/10
253
- xlims = [(min(x_tsne)//xscale_unit-1)*xscale_unit,(max(x_tsne)//xscale_unit+1)*xscale_unit]
254
- ylims = [(min(y_tsne)//yscale_unit-1)*yscale_unit,(max(y_tsne)//yscale_unit+1)*yscale_unit]
255
- color_list = sns.color_palette('flare',n_colors=int(len(df)*1.2))
256
-
257
  sent_id = st.session_state.sent_id
258
- fig = plt.figure(figsize=(5,5),dpi=200)
259
- ax = fig.add_subplot(1,1,1)
260
- ax.plot(x_tsne[:sent_id+1],y_tsne[:sent_id+1],linewidth=0.2,color='gray',zorder=1)
261
- ax.scatter(x_tsne[:sent_id+1],y_tsne[:sent_id+1],s=5,color=color_list[:sent_id+1],zorder=2)
262
- ax.scatter(x_tsne[sent_id:sent_id+1],y_tsne[sent_id:sent_id+1],s=50,marker='*',color='blue',zorder=3)
263
- ax.set_xlim(*xlims)
264
- ax.set_ylim(*ylims)
265
- ax.axis('off')
266
-
267
  sentence = df.cleaned_sentence.to_list()[sent_id]
268
  input_sent = tokenizer(sentence,return_tensors='pt')['input_ids']
269
  decoded_sent = [tokenizer.decode([token]) for token in input_sent[0]]
 
270
  show_candidates = st.checkbox('Show candidates')
271
  if show_candidates:
272
  st.write('Click any word to see each candidate with its probability')
273
- cols = st.columns(len(decoded_sent))
274
  with cols[0]:
275
  st.write(decoded_sent[0])
276
  with cols[-1]:
@@ -293,6 +257,15 @@ if __name__=='__main__':
293
  else:
294
  disp_step = f'<p style={disp_style}>Step {st.session_state.sent_id}&colon;&nbsp;'
295
  st.markdown(f'{disp_step}<span style="font-weight:bold">{sentence}</span></p>',unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
296
  cols = st.columns([1,2,1])
297
  with cols[1]:
298
  st.pyplot(fig)
 
84
 
85
  #@st.cache(show_spinner=False)
86
  def plot_fig(df,sent_id,xlims,ylims,color_list):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
  def update_sent_id(increment_value=0):
90
  sent_id = st.session_state.sent_id
 
158
 
159
  if 'df' in st.session_state:
160
  df = st.session_state.df
161
+ x_tsne, y_tsne = df.x_tsne, df.y_tsne
162
+ xscale_unit = (max(x_tsne)-min(x_tsne))/10
163
+ yscale_unit = (max(y_tsne)-min(y_tsne))/10
164
+ xlims = [(min(x_tsne)//xscale_unit-1)*xscale_unit,(max(x_tsne)//xscale_unit+1)*xscale_unit]
165
+ ylims = [(min(y_tsne)//yscale_unit-1)*yscale_unit,(max(y_tsne)//yscale_unit+1)*yscale_unit]
166
+ color_list = sns.color_palette('flare',n_colors=int(len(df)*1.2))
167
+
168
  st.sidebar.slider(label='2. Select a position in the chain to start exploring',
169
  min_value=0,max_value=len(df)-1,value=0,key='sent_id_from_slider',on_change=initialize_sent_id)
170
  if 'sent_id' not in st.session_state:
 
175
  else:
176
  explore_type = st.sidebar.radio('3. Choose the way to explore',options=['In fixed increments','Click through each step'])
177
  if explore_type=='Autoplay':
178
+ cols = st.columns(2)
179
+ with cols[0]:
180
+ container_0 = st.container()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  with cols[1]:
182
+ container_1 = st.container()
183
+ if container_0.button('Play',key='play'):
184
+ while not container_1.button('Stop',key='stop'):
185
+ for sent_id in range(st.session_state.sent_id_from_slider,len(st.session_state)):
186
+ sentence = df.cleaned_sentence.to_list()[sent_id]
187
+ input_sent = tokenizer(sentence,return_tensors='pt')['input_ids']
188
+ decoded_sent = [tokenizer.decode([token]) for token in input_sent[0]]
189
+ disp_style = '"font-family:san serif; color:Black; font-size: 20px"'
190
+ if explore_type=='Click through each step' and input_type=='Use your own initial sentence' and sent_id>0 and 'finished_sampling' in st.session_state:
191
+ sampled_loc = df.next_sample_loc.to_list()[sent_id-1]
192
+ disp_step = f'<p style={disp_style}>Step {st.session_state.sent_id}&colon;&nbsp;'
193
+ disp_sent_before = f'{disp_step}<span style="font-weight:bold">'+' '.join(decoded_sent[1:sampled_loc])
194
+ new_word = f'<span style="color:Red">{decoded_sent[sampled_loc]}</span>'
195
+ disp_sent_after = ' '.join(decoded_sent[sampled_loc+1:-1])+'</span></p>'
196
+ st.markdown(disp_sent_before+' '+new_word+' '+disp_sent_after,unsafe_allow_html=True)
197
+ else:
198
+ disp_step = f'<p style={disp_style}>Step {st.session_state.sent_id}&colon;&nbsp;'
199
+ st.markdown(f'{disp_step}<span style="font-weight:bold">{sentence}</span></p>',unsafe_allow_html=True)
200
+
201
+ fig = plt.figure(figsize=(5,5),dpi=200)
202
+ ax = fig.add_subplot(1,1,1)
203
+ ax.plot(x_tsne[:sent_id+1],y_tsne[:sent_id+1],linewidth=0.2,color='gray',zorder=1)
204
+ ax.scatter(x_tsne[:sent_id+1],y_tsne[:sent_id+1],s=5,color=color_list[:sent_id+1],zorder=2)
205
+ ax.scatter(x_tsne[sent_id:sent_id+1],y_tsne[sent_id:sent_id+1],s=50,marker='*',color='blue',zorder=3)
206
+ ax.set_xlim(*xlims)
207
+ ax.set_ylim(*ylims)
208
+ ax.axis('off')
209
+ cols = st.columns([1,2,1])
210
+ with cols[1]:
211
+ st.pyplot(fig)
212
  else:
213
  if explore_type=='In fixed increments':
214
  button_labels = ['+1','+10','+100','+500']
 
226
  elif explore_type=='Click through each step':
227
  st.session_state.sent_id = st.sidebar.number_input(label='step number',value=st.session_state.sent_id_from_slider)
228
 
 
 
 
 
 
 
 
229
  sent_id = st.session_state.sent_id
 
 
 
 
 
 
 
 
 
230
  sentence = df.cleaned_sentence.to_list()[sent_id]
231
  input_sent = tokenizer(sentence,return_tensors='pt')['input_ids']
232
  decoded_sent = [tokenizer.decode([token]) for token in input_sent[0]]
233
+ char_nums = [len(word)+2 word for word in decoded_sent]
234
  show_candidates = st.checkbox('Show candidates')
235
  if show_candidates:
236
  st.write('Click any word to see each candidate with its probability')
237
+ cols = st.columns(char_nums)
238
  with cols[0]:
239
  st.write(decoded_sent[0])
240
  with cols[-1]:
 
257
  else:
258
  disp_step = f'<p style={disp_style}>Step {st.session_state.sent_id}&colon;&nbsp;'
259
  st.markdown(f'{disp_step}<span style="font-weight:bold">{sentence}</span></p>',unsafe_allow_html=True)
260
+
261
+ fig = plt.figure(figsize=(5,5),dpi=200)
262
+ ax = fig.add_subplot(1,1,1)
263
+ ax.plot(x_tsne[:sent_id+1],y_tsne[:sent_id+1],linewidth=0.2,color='gray',zorder=1)
264
+ ax.scatter(x_tsne[:sent_id+1],y_tsne[:sent_id+1],s=5,color=color_list[:sent_id+1],zorder=2)
265
+ ax.scatter(x_tsne[sent_id:sent_id+1],y_tsne[sent_id:sent_id+1],s=50,marker='*',color='blue',zorder=3)
266
+ ax.set_xlim(*xlims)
267
+ ax.set_ylim(*ylims)
268
+ ax.axis('off')
269
  cols = st.columns([1,2,1])
270
  with cols[1]:
271
  st.pyplot(fig)