"""Streamlit demo: probing BERT's priors with serial reproduction chains.

The app either loads pre-computed chains for two example sentences from
``tsne_out.csv`` or runs a new chain from a user-supplied sentence by
repeatedly masking one random non-special token and resampling it from
BERT's masked-LM distribution. Every step's sentence is embedded with a
sentence-transformer, projected to 2-D with PCA + t-SNE, and plotted.
"""
import html
import io

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import streamlit as st
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from transformers import BertForMaskedLM, BertTokenizer

# PCA below uses n_components=50, which needs at least 50 samples
# (t-SNE's default perplexity of 30 must also stay below the sample count).
MIN_CHAIN_STEPS = 50

# Example sentence -> its ``sentence_num`` in tsne_out.csv.
EXAMPLE_SENTENCES = {
    'About 170 campers attend the camps each week.': 6,
    'She grew up with three brothers and ten sisters.': 8,
}

# Step offsets offered by the "In fixed increments" explorer.
INCREMENTS = [-500, -100, -10, -1, 0, 1, 10, 100, 500]


@st.cache(allow_output_mutation=True, show_spinner=False)
def load_sentence_model():
    """Load the sentence-transformer used to embed chain sentences.

    Cached so repeated t-SNE runs do not reload the model from disk.
    """
    return SentenceTransformer('paraphrase-distilroberta-base-v1')


@st.cache(allow_output_mutation=True, show_spinner=False)
def load_model(model_name):
    """Return ``(tokenizer, model)`` for a BERT masked-LM checkpoint in eval mode.

    Raises:
        ValueError: if ``model_name`` does not start with ``'bert'`` (this
            previously fell through to an UnboundLocalError).
    """
    if not model_name.startswith('bert'):
        raise ValueError(f'Unsupported model {model_name!r}: only BERT checkpoints are supported')
    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = BertForMaskedLM.from_pretrained(model_name)
    model.eval()
    return tokenizer, model


@st.cache(show_spinner=False)
def load_data(sentence_num):
    """Load the pre-computed chain for example ``sentence_num`` (steps < 1000)."""
    df = pd.read_csv('tsne_out.csv')
    return df.loc[(df['sentence_num'] == sentence_num) & (df['iter_num'] < 1000)]


@st.cache(show_spinner=False)
def mask_prob(model, mask_id, sentences, position, temp=1):
    """Return vocabulary log-probabilities for the token at ``position``.

    Args:
        model: masked-LM whose first output is the logits tensor.
        mask_id: token id written into ``position`` of a copy of ``sentences``.
        sentences: 2-D token-id tensor (indexed as ``[:, position]``).
        position: index of the token to mask.
        temp: softmax temperature applied to the logits.

    Returns:
        ``log_softmax(logits[:, position] / temp)`` over the last dimension.
    """
    masked_sentences = sentences.clone()
    masked_sentences[:, position] = mask_id
    with torch.no_grad():
        logits = model(masked_sentences)[0]
    return F.log_softmax(logits[:, position] / temp, dim=-1)


def sample_words(probs, pos, sentences, tokenizer=None):
    """Resample the token at ``pos`` and summarise the top candidates.

    Deliberately NOT wrapped in ``st.cache``: caching would replay the same
    random draw whenever a (probs, pos, sentences) input recurs, turning the
    sampling chain deterministic on revisited states.

    Args:
        probs: log-probabilities as returned by ``mask_prob``; row 0 feeds
            the candidate table.
        pos: position whose token is replaced.
        sentences: token-id tensor; a modified copy is returned.
        tokenizer: tokenizer used to decode candidates. Defaults to a
            module-level ``tokenizer`` for backward compatibility.

    Returns:
        ``(new_sentences, df)`` where ``df`` has columns ``word`` and ``prob``
        for the 10 most probable candidates.

    Raises:
        ValueError: if no tokenizer is passed and none exists at module level.
    """
    if tokenizer is None:
        tokenizer = globals().get('tokenizer')
        if tokenizer is None:
            raise ValueError('sample_words needs a tokenizer')
    word_probs = torch.exp(probs)
    top_ids = torch.argsort(probs[0], descending=True)[:10]
    candidates = [[tokenizer.decode([candidate]), word_probs[0, candidate].item()] for candidate in top_ids]
    df = pd.DataFrame(data=candidates, columns=['word', 'prob'])
    chosen_words = torch.multinomial(word_probs, num_samples=1).squeeze(dim=-1)
    new_sentences = sentences.clone()
    new_sentences[:, pos] = chosen_words
    return new_sentences, df


def run_chains(tokenizer, model, mask_id, input_text, num_steps):
    """Run one serial reproduction chain starting from ``input_text``.

    Each step records the current sentence, then masks one uniformly chosen
    position (never the first or last token, i.e. [CLS]/[SEP]) and replaces
    it with a token sampled from the model's distribution.

    Returns:
        DataFrame with columns ``step``, ``sentence`` (space-joined decoded
        tokens *before* that step's resampling) and ``next_sample_loc``
        (the position resampled during that step).

    Raises:
        ValueError: if ``input_text`` tokenizes to no non-special tokens.
    """
    init_sent = tokenizer(input_text, return_tensors='pt')['input_ids']
    seq_len = init_sent.shape[1]
    if seq_len <= 2:
        raise ValueError('The input sentence must contain at least one token')
    sentence = init_sent.clone()
    data_list = []
    st.sidebar.write('Generating samples...')
    st.sidebar.write('This takes ~1 min for 1000 steps with ~10 token sentences')
    chain_progress = st.sidebar.progress(0)
    for step_id in range(num_steps):
        chain_progress.progress((step_id + 1) / num_steps)
        # Sample from 1..seq_len-2 so index 0 and the last index are never touched.
        pos = torch.randint(seq_len - 2, size=(1,)).item() + 1
        decoded = ' '.join(tokenizer.decode([token]) for token in sentence[0])
        data_list.append([step_id, decoded, pos])
        probs = mask_prob(model, mask_id, sentence, pos)
        sentence, _ = sample_words(probs, pos, sentence, tokenizer)
    return pd.DataFrame(data=data_list, columns=['step', 'sentence', 'next_sample_loc'])


@st.cache(suppress_st_warning=True, show_spinner=False)
def run_tsne(chain):
    """Embed every chain sentence and project the embeddings to 2-D.

    Sentences are stripped of ``[CLS] ``/`` [SEP]``, embedded with the
    sentence-transformer, reduced to 50 dims with PCA and then to 2 with
    t-SNE. Needs at least ``MIN_CHAIN_STEPS`` rows.

    Returns:
        ``chain`` with added ``cleaned_sentence``, ``x_tsne`` and ``y_tsne``.
    """
    st.sidebar.write('Running t-SNE...')
    st.sidebar.write('This takes ~1 min for 1000 steps with ~10 token sentences')
    cleaned = (chain.sentence
               .str.replace(r'\[CLS\] ', '', regex=True)
               .str.replace(r' \[SEP\]', '', regex=True))
    chain = chain.assign(cleaned_sentence=cleaned)
    sentence_model = load_sentence_model()
    sentence_embeddings = sentence_model.encode(chain.cleaned_sentence.to_list(), show_progress_bar=False)
    reduced = PCA(n_components=50).fit_transform(sentence_embeddings)
    # NOTE(review): scikit-learn >= 1.5 renames ``n_iter`` to ``max_iter``.
    tsne_vals = TSNE(n_components=2, n_iter=2000).fit_transform(reduced)
    coords = pd.DataFrame(tsne_vals, columns=['x_tsne', 'y_tsne'], index=chain.index)
    return pd.concat([chain, coords], axis=1)


def clear_df():
    """Forget the displayed chain when the custom input sentence changes.

    ``pop`` rather than ``del`` so editing the text twice before a new run
    does not raise KeyError.
    """
    st.session_state.pop('df', None)


def _axis_limits(values):
    """Return ``[lo, hi]`` padded beyond the data by one tenth of its range, snapped to that unit."""
    lo, hi = min(values), max(values)
    unit = (hi - lo) / 10 or 1.0  # fall back when every point coincides
    return [(lo // unit - 1) * unit, (hi // unit + 1) * unit]


def _draw_chain(ax, x_tsne, y_tsne, sent_id, color_list, xlims, ylims):
    """Draw the trajectory up to ``sent_id`` and mark step ``sent_id`` with a star."""
    xs, ys = x_tsne.iloc[:sent_id + 1], y_tsne.iloc[:sent_id + 1]
    ax.plot(xs, ys, linewidth=0.2, color='gray', zorder=1)
    ax.scatter(xs, ys, s=5, color=color_list[:sent_id + 1], zorder=2)
    ax.scatter(x_tsne.iloc[sent_id:sent_id + 1], y_tsne.iloc[sent_id:sent_id + 1],
               s=50, marker='*', color='blue', zorder=3)
    ax.set_xlim(*xlims)
    ax.set_ylim(*ylims)
    ax.axis('off')


@st.cache(show_spinner=False)
def plot_fig(df, sent_id, xlims, ylims, color_list):
    """Render the chain up to ``sent_id`` (titled with that sentence) to an image.

    Returns:
        The PNG decoded by ``cv2.imdecode(..., 1)``, i.e. a BGR colour array.
    """
    fig = plt.figure(figsize=(5, 5), dpi=200)
    ax = fig.add_subplot(1, 1, 1)
    _draw_chain(ax, df.x_tsne, df.y_tsne, sent_id, color_list, xlims, ylims)
    ax.set_title(df.cleaned_sentence.to_list()[sent_id])
    with io.BytesIO() as buf:
        fig.savefig(buf, format='png', dpi=200)
        img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    plt.close(fig)
    return cv2.imdecode(img_arr, 1)


def pre_render_images(df, input_sent_id):
    """Render plots for every step reachable from ``input_sent_id`` via ``INCREMENTS``.

    Returns:
        ``(sent_list, fig_list)``: the cleaned sentence and image per option.
    """
    last_id = len(df) - 1
    sent_id_options = [min(last_id, max(0, input_sent_id + increment)) for increment in INCREMENTS]
    xlims = _axis_limits(df.x_tsne)
    ylims = _axis_limits(df.y_tsne)
    color_list = sns.color_palette('flare', n_colors=int(len(df) * 1.2))
    sentences = df.cleaned_sentence.to_list()
    sent_list, fig_list = [], []
    fig_production = st.progress(0)
    for fig_id, sent_id in enumerate(sent_id_options):
        # st.progress reads ints as percent (0-100); pass a fraction instead.
        fig_production.progress((fig_id + 1) / len(sent_id_options))
        fig_list.append(plot_fig(df, sent_id, xlims, ylims, color_list))
        sent_list.append(sentences[sent_id])
    return sent_list, fig_list


def _apply_page_style():
    """Inject page-level CSS.

    NOTE(review): the CSS in these strings appears to have been stripped from
    this copy of the file (they are whitespace only). The intended settings
    were max width 1500 and padding top/right/bottom/left = 0/2/0/2.
    """
    define_margins = """ """
    hide_table_row_index = """ """
    st.markdown(define_margins, unsafe_allow_html=True)
    st.markdown(hide_table_row_index, unsafe_allow_html=True)


def _choose_chain(tokenizer, model, mask_id):
    """Render the input widgets and keep ``st.session_state.df`` up to date.

    Returns:
        ``(use_example, sentence_num)``; ``sentence_num`` is None for custom input.
    """
    input_type = st.sidebar.radio(label='1. Choose the input type',
                                  options=('Use one of the example sentences', 'Use your own initial sentence'))
    if input_type == 'Use one of the example sentences':
        sentence = st.sidebar.selectbox("Select the initial sentence", tuple(EXAMPLE_SENTENCES))
        sentence_num = EXAMPLE_SENTENCES[sentence]
        st.session_state.df = load_data(sentence_num)
        st.session_state.finished_sampling = False
        return True, sentence_num

    # Don't show (or index as custom) an example chain left over from the other mode.
    if not st.session_state.get('finished_sampling'):
        st.session_state.pop('df', None)
    sentence = st.sidebar.text_input('Type down your own sentence here.', on_change=clear_df)
    num_steps = st.sidebar.number_input(label='How many steps do you want to run?',
                                        min_value=MIN_CHAIN_STEPS, value=1000, step=1)
    if st.sidebar.button('Run chains'):
        if not sentence.strip():
            st.sidebar.warning('Please type a sentence before running the chains.')
        else:
            chain = run_chains(tokenizer, model, mask_id, sentence, num_steps=int(num_steps))
            st.session_state.df = run_tsne(chain)
            st.session_state.finished_sampling = True
    return False, None


def _show_candidates(decoded_sent, input_sent, tokenizer, model, mask_id):
    """Show one button per word; clicking it lists that position's top candidates."""
    st.write('Click any word to see each candidate with its probability')
    cols = st.columns(len(decoded_sent))
    with cols[0]:
        st.write(decoded_sent[0])
    with cols[-1]:
        st.write(decoded_sent[-1])
    for word_id, (col, word) in enumerate(zip(cols[1:-1], decoded_sent[1:-1]), start=1):
        with col:
            # Explicit key: a repeated word would otherwise be a duplicate widget ID.
            if st.button(word, key=f'candidate_{word_id}'):
                probs = mask_prob(model, mask_id, input_sent, word_id)
                _, candidates_df = sample_words(probs, word_id, input_sent, tokenizer)
                st.table(candidates_df)


def _show_sentence(df, sent_id, decoded_sent, current_sentence, highlight_change):
    """Display the current sentence, optionally highlighting the last resampled token."""
    disp_style = '"font-family:sans-serif; color:Black; font-size: 25px; font-weight:bold"'
    sampled_loc = df.next_sample_loc.to_list()[sent_id - 1] if highlight_change else None
    # Re-tokenising the cleaned sentence may not reproduce the chain's token
    # positions, so only highlight when the stored position is in range.
    if sampled_loc is not None and 0 < sampled_loc < len(decoded_sent) - 1:
        tokens = [html.escape(token) for token in decoded_sent]
        disp_sent_before = f'<p style={disp_style}>' + ' '.join(tokens[1:sampled_loc])
        new_word = f'<span style="color:Red">{tokens[sampled_loc]}</span>'
        disp_sent_after = ' '.join(tokens[sampled_loc + 1:-1]) + '</p>'
        st.markdown(disp_sent_before + ' ' + new_word + ' ' + disp_sent_after, unsafe_allow_html=True)
    else:
        st.markdown(f'<p style={disp_style}>{html.escape(current_sentence)}</p>', unsafe_allow_html=True)


def _explore_chain(df, use_example, sentence_num, tokenizer, model, mask_id):
    """Render the step selectors, the current sentence and the t-SNE plot."""
    last_id = len(df) - 1
    sent_id = st.sidebar.slider(label='2. Select a position in the chain to start exploring',
                                min_value=0, max_value=last_id, value=0)
    explore_options = ['In fixed increments', 'Click through each step']
    if use_example:
        explore_options.append('Autoplay')
    explore_type = st.sidebar.radio('3. Choose the way to explore', options=explore_options)

    if explore_type == 'Autoplay':
        # The video is a pre-rendered file on disk; it is not generated here.
        _, centre, _ = st.columns([1, 2, 1])
        with centre:
            with open(f'sampling_video_{sentence_num}.mp4', 'rb') as f:
                st.video(f)
        return

    if explore_type == 'In fixed increments':
        button_labels = ['-500', '-100', '-10', '-1', '0', '+1', '+10', '+100', '+500']
        increment = st.sidebar.radio(label='select increment', options=button_labels, index=4)
        sent_id = min(last_id, max(0, sent_id + int(increment)))
    else:  # 'Click through each step'
        # Bounded so the step can't index past the chain (or wrap negative).
        sent_id = int(st.sidebar.number_input(label='step number', min_value=0, max_value=last_id,
                                              value=sent_id, step=1))

    xlims = _axis_limits(df.x_tsne)
    ylims = _axis_limits(df.y_tsne)
    color_list = sns.color_palette('flare', n_colors=int(len(df) * 1.2))
    fig = plt.figure(figsize=(5, 5), dpi=200)
    ax = fig.add_subplot(1, 1, 1)
    _draw_chain(ax, df.x_tsne, df.y_tsne, sent_id, color_list, xlims, ylims)

    current_sentence = df.cleaned_sentence.to_list()[sent_id]
    input_sent = tokenizer(current_sentence, return_tensors='pt')['input_ids']
    decoded_sent = [tokenizer.decode([token]) for token in input_sent[0]]

    if st.checkbox('Show candidates'):
        _show_candidates(decoded_sent, input_sent, tokenizer, model, mask_id)
    else:
        highlight_change = (explore_type == 'Click through each step' and not use_example
                            and sent_id > 0 and bool(st.session_state.get('finished_sampling')))
        _show_sentence(df, sent_id, decoded_sent, current_sentence, highlight_change)

    _, centre, _ = st.columns([1, 2, 1])
    with centre:
        st.pyplot(fig)
    # Close explicitly: pyplot keeps every open figure alive across reruns.
    plt.close(fig)


def main():
    """Build the Streamlit page."""
    _apply_page_style()

    st.header("Demo: Probing BERT's priors with serial reproduction chains")
    st.text("Explore sentences in the serial reproduction chains generated by BERT!")
    st.text("Visit different positions in the chain using the widgets on the left.")
    st.text("Check 'Show candidates' to see what words are proposed when each word is masked out.")

    tokenizer, model = load_model('bert-base-uncased')
    mask_id = tokenizer.mask_token_id

    use_example, sentence_num = _choose_chain(tokenizer, model, mask_id)
    if 'df' in st.session_state:
        _explore_chain(st.session_state.df, use_example, sentence_num, tokenizer, model, mask_id)


if __name__ == '__main__':
    main()