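# Streamlit demo: probing BERT's priors with serial reproduction chains.
# One token at a time is masked out and resampled from bert-base-uncased,
# and the resulting chain of sentences is visualized with a t-SNE embedding.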
import pandas as pd
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn.functional as F
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sentence_transformers import SentenceTransformer
from transformers import BertTokenizer, BertForMaskedLM
import io
import time
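# Sentence encoder used to embed each chain sentence before PCA + t-SNE.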
def load_sentence_model():
    sentence_model = SentenceTransformer('paraphrase-distilroberta-base-v1')
    return sentence_model
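# Load the tokenizer and masked-LM model (BERT variants only) in eval mode.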
def load_model(model_name):
    if model_name.startswith('bert'):
        tokenizer = BertTokenizer.from_pretrained(model_name)
        model = BertForMaskedLM.from_pretrained(model_name)
    model.eval()
    return tokenizer, model
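# Pre-generated chains: read tsne_out.csv and keep the first 1000 iterations
# of the selected example sentence.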
def load_data(sentence_num):
    df = pd.read_csv('tsne_out.csv')
    df = df.loc[lambda d: (d['sentence_num'] == sentence_num) & (d['iter_num'] < 1000)]
    return df.reset_index()
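# Mask the token at `position` and return the model's (temperature-scaled)
# log-probabilities over the vocabulary for that slot.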
#@st.cache(show_spinner=False)
def mask_prob(model, mask_id, sentences, position, temp=1):
    masked_sentences = sentences.clone()
    masked_sentences[:, position] = mask_id
    with torch.no_grad():
        logits = model(masked_sentences)[0]
    return F.log_softmax(logits[:, position] / temp, dim=-1)
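# Sample a replacement token at `pos` from the masked-LM distribution and
# return the updated sentence together with a table of the top-10 candidates.
# Relies on the module-level `tokenizer` defined under __main__.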
#@st.cache(show_spinner=False)
def sample_words(probs, pos, sentences):
    candidates = [[tokenizer.decode([candidate]), torch.exp(probs)[0, candidate].item()]
                  for candidate in torch.argsort(probs[0], descending=True)[:10]]
    df = pd.DataFrame(data=candidates, columns=['word', 'prob'])
    chosen_words = torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1)
    new_sentences = sentences.clone()
    new_sentences[:, pos] = chosen_words
    return new_sentences, df
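# Serial reproduction chain: at every step, choose a random non-special token
# position, mask it, and resample it from BERT's conditional distribution.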
def run_chains(tokenizer, model, mask_id, input_text, num_steps):
    init_sent = tokenizer(input_text, return_tensors='pt')['input_ids']
    seq_len = init_sent.shape[1]
    sentence = init_sent.clone()
    data_list = []
    st.sidebar.write('Generating samples...')
    st.sidebar.write('This takes ~1 min for 1000 steps with ~10 token sentences')
    chain_progress = st.sidebar.progress(0)
    for step_id in range(num_steps):
        chain_progress.progress((step_id + 1) / num_steps)
        # Pick a random position, excluding [CLS] at the start and [SEP] at the end
        pos = torch.randint(seq_len - 2, size=(1,)).item() + 1
        #data_list.append([step_id, ' '.join([tokenizer.decode([token]) for token in sentence[0]]), pos])
        data_list.append([step_id, tokenizer.decode([token for token in sentence[0]]), pos])
        probs = mask_prob(model, mask_id, sentence, pos)
        sentence, _ = sample_words(probs, pos, sentence)
    return pd.DataFrame(data=data_list, columns=['step', 'sentence', 'next_sample_loc'])
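# Plot the chain's t-SNE trajectory up to `step_id`, marking the current step
# with a star.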
#@st.cache(show_spinner=True,allow_output_mutation=True)
def show_tsne_panel(df, step_id):
    x_tsne, y_tsne = df.x_tsne, df.y_tsne
    xscale_unit = (max(x_tsne) - min(x_tsne)) / 10
    yscale_unit = (max(y_tsne) - min(y_tsne)) / 10
    xlims = [(min(x_tsne) // xscale_unit - 1) * xscale_unit, (max(x_tsne) // xscale_unit + 1) * xscale_unit]
    ylims = [(min(y_tsne) // yscale_unit - 1) * yscale_unit, (max(y_tsne) // yscale_unit + 1) * yscale_unit]
    color_list = sns.color_palette('flare', n_colors=int(len(df) * 1.2))
    fig = plt.figure(figsize=(5, 5), dpi=200)
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(x_tsne[:step_id + 1], y_tsne[:step_id + 1], linewidth=0.2, color='gray', zorder=1)
    ax.scatter(x_tsne[:step_id + 1], y_tsne[:step_id + 1], s=5, color=color_list[:step_id + 1], zorder=2)
    ax.scatter(x_tsne[step_id:step_id + 1], y_tsne[step_id:step_id + 1], s=50, marker='*', color='blue', zorder=3)
    ax.set_xlim(*xlims)
    ax.set_ylim(*ylims)
    ax.axis('off')
    return fig
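# Embed all chain sentences, reduce to 50 dimensions with PCA, then project to
# 2D with t-SNE; the coordinates are appended as x_tsne / y_tsne columns.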
def run_tsne(chain):
    st.sidebar.write('Running t-SNE...')
    st.sidebar.write('This takes ~1 min for 1000 steps with ~10 token sentences')
    chain = chain.assign(cleaned_sentence=chain.sentence.str.replace(r'\[CLS\] ', '', regex=True).str.replace(r' \[SEP\]', '', regex=True))
    sentence_model = load_sentence_model()
    sentence_embeddings = sentence_model.encode(chain.cleaned_sentence.to_list(), show_progress_bar=False)
    tsne = TSNE(n_components=2, n_iter=2000)
    big_pca = PCA(n_components=50)
    tsne_vals = tsne.fit_transform(big_pca.fit_transform(sentence_embeddings))
    tsne = pd.concat([chain, pd.DataFrame(tsne_vals, columns=['x_tsne', 'y_tsne'], index=chain.index)], axis=1)
    return tsne
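# Step through the chain automatically, redrawing the sentence and the t-SNE
# panel every 0.25 s from the current step to the end of the chain.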
def autoplay():
    for step_id in range(st.session_state.step_id, len(st.session_state.df), 1):
        x = st.empty()
        with x.container():
            st.markdown(show_changed_site(), unsafe_allow_html=True)
            fig = show_tsne_panel(st.session_state.df, step_id)
            st.session_state.prev_step_id = st.session_state.step_id
            st.session_state.step_id = step_id
            #plt.title(f'Step {step_id}')#: {show_changed_site()}')
            cols = st.columns([1, 2, 1])
            with cols[1]:
                st.pyplot(fig)
        time.sleep(.25)
        x.empty()
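# Sidebar navigation buttons (+/-1, 10, 100, 500 steps) plus the
# 'Show candidates' checkbox; updates step_id in session state.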
def initialize_buttons():
    buttons = st.sidebar.empty()
    button_ids = []
    with buttons.container():
        row1_labels = ['+1', '+10', '+100', '+500']
        row1 = st.columns([4, 5, 6, 6])
        for col_id, col in enumerate(row1):
            button_ids.append(col.button(row1_labels[col_id], key=row1_labels[col_id]))
        row2_labels = ['-1', '-10', '-100', '-500']
        row2 = st.columns([4, 5, 6, 6])
        for col_id, col in enumerate(row2):
            button_ids.append(col.button(row2_labels[col_id], key=row2_labels[col_id]))
        show_candidates_checked = st.checkbox('Show candidates')
    # Increment the step if any of the buttons has been pressed
    increments = np.array([1, 10, 100, 500, -1, -10, -100, -500])
    if any(button_ids):
        increment_value = increments[np.array(button_ids)][0]
        st.session_state.prev_step_id = st.session_state.step_id
        new_step_id = st.session_state.step_id + increment_value
        st.session_state.step_id = min(len(st.session_state.df) - 1, max(0, new_step_id))
    if show_candidates_checked:
        st.write('Click any word to see each candidate with its probability')
        show_candidates()
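# Render the current sentence as clickable word buttons; clicking a word shows
# the top candidates (with probabilities) when that position is masked out.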
def show_candidates():
    if 'curr_table' in st.session_state:
        st.session_state.curr_table.empty()
    step_id = st.session_state.step_id
    sentence = df.cleaned_sentence.loc[step_id]
    input_sent = tokenizer(sentence, return_tensors='pt')['input_ids']
    decoded_sent = [tokenizer.decode([token]) for token in input_sent[0]]
    char_nums = [len(word) + 2 for word in decoded_sent]
    cols = st.columns(char_nums)
    with cols[0]:
        st.write(decoded_sent[0])
    with cols[-1]:
        st.write(decoded_sent[-1])
    for word_id, (col, word) in enumerate(zip(cols[1:-1], decoded_sent[1:-1])):
        with col:
            if st.button(word, key=f'word_{word_id}'):
                probs = mask_prob(model, mask_id, input_sent, word_id + 1)
                _, candidates_df = sample_words(probs, word_id + 1, input_sent)
                st.session_state.curr_table = st.table(candidates_df)
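# Format the current sentence as HTML, coloring the most recently resampled
# word (or, as a fallback, any word absent from the previous sentence) in red.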
def show_changed_site():
    df = st.session_state.df
    step_id = st.session_state.step_id
    prev_step_id = st.session_state.prev_step_id
    curr_sent = df.cleaned_sentence.loc[step_id].split(' ')
    prev_sent = df.cleaned_sentence.loc[prev_step_id].split(' ')
    locs = [df.next_sample_loc.to_list()[step_id - 1] - 1] if 'next_sample_loc' in df else (
        [i for i in range(len(curr_sent)) if curr_sent[i] not in prev_sent]
    )
    disp_style = '"font-family:sans-serif; color:Black; font-size: 20px"'
    prefix = f'<p style={disp_style}>Step {st.session_state.step_id}: <span style="font-weight:bold">'
    disp = ' '.join([f'<span style="color:Red">{word}</span>' if i in locs else f'{word}'
                     for (i, word) in enumerate(curr_sent)])
    suffix = '</span></p>'
    return prefix + disp + suffix
def clear_df():
    if 'df' in st.session_state:
        del st.session_state['df']
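# App entry point: page styling, model loading, chain selection/generation,
# and the chain-exploration UI.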
if __name__ == '__main__':

    # Config
    max_width = 1500
    padding_top = 0
    padding_right = 2
    padding_bottom = 0
    padding_left = 2
    define_margins = f"""
    <style>
        .appview-container .main .block-container{{
            max-width: {max_width}px;
            padding-top: {padding_top}rem;
            padding-right: {padding_right}rem;
            padding-left: {padding_left}rem;
            padding-bottom: {padding_bottom}rem;
        }}
    </style>
    """
    hide_table_row_index = """
    <style>
        tbody th {display:none}
        .blank {display:none}
    </style>
    """
    st.markdown(define_margins, unsafe_allow_html=True)
    st.markdown(hide_table_row_index, unsafe_allow_html=True)
    input_type = st.sidebar.radio(
        label='1. Choose the input type',
        on_change=clear_df,
        options=('Use one of the example sentences', 'Use your own initial sentence')
    )

    # Title
    st.header("Demo: Probing BERT's priors with serial reproduction chains")

    # Load BERT
    tokenizer, model = load_model('bert-base-uncased')
    mask_id = tokenizer.encode("[MASK]")[1:-1][0]
    # First step: load the dataframe containing sentences
    if input_type == 'Use one of the example sentences':
        sentence = st.sidebar.selectbox("Select the initial sentence",
                                        ('--- Please select one from below ---',
                                         'About 170 campers attend the camps each week.',
                                         "Ali marpet's mother is joy rose.",
                                         'She grew up with three brothers and ten sisters.'))
        if sentence != '--- Please select one from below ---':
            if sentence == 'About 170 campers attend the camps each week.':
                sentence_num = 6
            elif sentence == 'She grew up with three brothers and ten sisters.':
                sentence_num = 8
            elif sentence == "Ali marpet's mother is joy rose.":
                sentence_num = 2
            st.session_state.df = load_data(sentence_num)
            st.session_state.finished_sampling = True
    else:
        sentence = st.sidebar.text_input('Type your own sentence here.', on_change=clear_df)
        num_steps = st.sidebar.number_input(label='How many steps do you want to run?', value=500)
        if st.sidebar.button('Run chains'):
            chain = run_chains(tokenizer, model, mask_id, sentence, num_steps=num_steps)
            st.session_state.df = run_tsne(chain)
            st.session_state.finished_sampling = True
    st.empty().markdown(
        "Let's explore sentences from BERT's prior! "
        "Use the menu to the left to select a pre-generated chain, "
        "or start a new chain using your own initial sentence."
        if 'df' not in st.session_state else
        "Use the slider to select a step, or watch the autoplay. "
        "Click 'Show candidates' to see the top proposals when each word is masked out."
    )
    if 'df' in st.session_state:
        df = st.session_state.df
        if 'step_id' not in st.session_state:
            st.session_state.prev_step_id = 0
            st.session_state.step_id = 0

        explore_type = st.sidebar.radio(
            '2. Choose how to explore the chain',
            options=['Click through steps', 'Autoplay']
        )
        if explore_type == 'Autoplay':
            st.empty()
            st.sidebar.empty()
            autoplay()
        elif explore_type == 'Click through steps':
            initialize_buttons()
            with st.container():
                st.markdown(show_changed_site(), unsafe_allow_html=True)
                fig = show_tsne_panel(df, st.session_state.step_id)
                cols = st.columns([1, 2, 1])
                with cols[1]:
                    st.pyplot(fig)