Spaces (status: Sleeping)

山越貴耀 committed
Commit 4c1fd66 · 1 Parent(s): 22a211b
added app

Browse files:
- app.py +253 -0
- requirements.txt +6 -0
app.py
ADDED
@@ -0,0 +1,253 @@
import pandas as pd
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn.functional as F
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sentence_transformers import SentenceTransformer
from transformers import BertTokenizer, BertForMaskedLM
import cv2

def load_sentence_model():
    # Sentence encoder used to embed each chain state before t-SNE
    sentence_model = SentenceTransformer('paraphrase-distilroberta-base-v1')
    return sentence_model

@st.cache(show_spinner=False)
def load_model(model_name):
    if model_name.startswith('bert'):
        tokenizer = BertTokenizer.from_pretrained(model_name)
        model = BertForMaskedLM.from_pretrained(model_name)
        model.eval()
        return tokenizer, model

@st.cache
def load_data(sentence_num):
    # Pre-computed chains: keep one example sentence and its first 1000 iterations
    df = pd.read_csv('tsne_out.csv')
    df = df.loc[lambda d: (d['sentence_num']==sentence_num)&(d['iter_num']<1000)]
    return df

@st.cache
def mask_prob(model, mask_id, sentences, position, temp=1):
    # Mask one position and return the model's log-probabilities for that slot
    masked_sentences = sentences.clone()
    masked_sentences[:, position] = mask_id
    with torch.no_grad():
        logits = model(masked_sentences)[0]
    return F.log_softmax(logits[:, position] / temp, dim=-1)

@st.cache
def sample_words(probs, pos, sentences):
    # Top-10 candidates for display, plus one word sampled to fill the masked slot
    candidates = [[tokenizer.decode([candidate]), torch.exp(probs)[0, candidate].item()]
                  for candidate in torch.argsort(probs[0], descending=True)[:10]]
    df = pd.DataFrame(data=candidates, columns=['word', 'prob'])
    chosen_words = torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1)
    new_sentences = sentences.clone()
    new_sentences[:, pos] = chosen_words
    return new_sentences, df

def run_chains(tokenizer, model, mask_id, input_text, num_steps):
    # Serial reproduction chain: repeatedly mask a random position and resample it
    init_sent = tokenizer(input_text, return_tensors='pt')['input_ids']
    seq_len = init_sent.shape[1]
    sentence = init_sent.clone()
    data_list = []
    st.sidebar.write('Generating samples...')
    st.sidebar.write('This takes ~30 seconds for 1000 steps with ~10 token sentences')
    chain_progress = st.sidebar.progress(0)
    for step_id in range(num_steps):
        chain_progress.progress((step_id+1)/num_steps)
        pos = torch.randint(seq_len-2, size=(1,)).item()+1  # never resample [CLS]/[SEP]
        data_list.append([step_id, ' '.join([tokenizer.decode([token]) for token in sentence[0]]), pos])
        probs = mask_prob(model, mask_id, sentence, pos)
        sentence, _ = sample_words(probs, pos, sentence)
    return pd.DataFrame(data=data_list, columns=['step', 'sentence', 'next_sample_loc'])
@st.cache(suppress_st_warning=True, show_spinner=False)
def run_tsne(chain):
    st.sidebar.write('Running t-SNE...')
    chain = chain.assign(cleaned_sentence=chain.sentence.str.replace(r'\[CLS\] ', '', regex=True).str.replace(r' \[SEP\]', '', regex=True))
    sentence_model = load_sentence_model()
    sentence_embeddings = sentence_model.encode(chain.cleaned_sentence.to_list(), show_progress_bar=False)

    tsne = TSNE(n_components=2, n_iter=2000)
    big_pca = PCA(n_components=50)
    tsne_vals = tsne.fit_transform(big_pca.fit_transform(sentence_embeddings))
    tsne = pd.concat([chain, pd.DataFrame(tsne_vals, columns=['x_tsne', 'y_tsne'], index=chain.index)], axis=1)
    return tsne

def clear_df():
    del st.session_state['df']

@st.cache(show_spinner=False)
def plot_fig(df, sent_id, xlims, ylims, color_list):
    # Render one frame of the chain trajectory and save it to figures/<sent_id>.png
    # (assumes a figures/ directory exists in the working directory)
    x_tsne, y_tsne = df.x_tsne, df.y_tsne
    fig = plt.figure(figsize=(5,5), dpi=200)
    ax = fig.add_subplot(1,1,1)
    ax.plot(x_tsne[:sent_id+1], y_tsne[:sent_id+1], linewidth=0.2, color='gray', zorder=1)
    ax.scatter(x_tsne[:sent_id+1], y_tsne[:sent_id+1], s=5, color=color_list[:sent_id+1], zorder=2)
    ax.scatter(x_tsne[sent_id:sent_id+1], y_tsne[sent_id:sent_id+1], s=50, marker='*', color='blue', zorder=3)
    ax.set_xlim(*xlims)
    ax.set_ylim(*ylims)
    ax.axis('off')
    ax.set_title(df.cleaned_sentence.to_list()[sent_id])
    fig.savefig(f'figures/{sent_id}.png')
    plt.clf()
    plt.close()

def pre_render_images(df, input_sent_id):
    # Pre-render the frames reachable from the current position via the increment buttons
    sent_id_options = [min(len(df)-1, max(0, input_sent_id+increment)) for increment in [-500,-100,-10,-1,0,1,10,100,500]]
    x_tsne, y_tsne = df.x_tsne, df.y_tsne
    xmax, xmin = (max(x_tsne)//30+1)*30, (min(x_tsne)//30-1)*30
    ymax, ymin = (max(y_tsne)//30+1)*30, (min(y_tsne)//30-1)*30
    color_list = sns.color_palette('flare', n_colors=int(len(df)*1.2))
    sent_list = []
    fig_production = st.progress(0)
    for fig_id, sent_id in enumerate(sent_id_options):
        fig_production.progress(fig_id+1)
        plot_fig(df, sent_id, [xmin, xmax], [ymin, ymax], color_list)  # fixed: arguments now match plot_fig's signature
        sent_list.append(df.cleaned_sentence.to_list()[sent_id])
    return sent_list


if __name__=='__main__':
    # Config
    max_width = 1500
    padding_top = 2
    padding_right = 5
    padding_bottom = 0
    padding_left = 5

    define_margins = f"""
    <style>
        .appview-container .main .block-container{{
            max-width: {max_width}px;
            padding-top: {padding_top}rem;
            padding-right: {padding_right}rem;
            padding-left: {padding_left}rem;
            padding-bottom: {padding_bottom}rem;
        }}
    </style>
    """
    hide_table_row_index = """
    <style>
        tbody th {display:none}
        .blank {display:none}
    </style>
    """
    st.markdown(define_margins, unsafe_allow_html=True)
    st.markdown(hide_table_row_index, unsafe_allow_html=True)

    # Title
    st.header("Demo: Probing BERT's priors with serial reproduction chains")

    # Load BERT
    tokenizer, model = load_model('bert-base-uncased')
    mask_id = tokenizer.encode("[MASK]")[1:-1][0]

    # First step: load the dataframe containing sentences
    input_type = st.sidebar.radio(label='1. Choose the input type', options=('Use one of our example sentences', 'Use your own initial sentence'))

    if input_type=='Use one of our example sentences':
        sentence = st.sidebar.selectbox("Select the initial sentence",
                                        ('About 170 campers attend the camps each week.',
                                         'She grew up with three brothers and ten sisters.'))
        if sentence=='About 170 campers attend the camps each week.':
            sentence_num = 6
        else:
            sentence_num = 8

        st.session_state.df = load_data(sentence_num)

    else:
        sentence = st.sidebar.text_input('Type your own sentence here', on_change=clear_df)
        num_steps = st.sidebar.number_input(label='How many steps do you want to run?', value=1000)
        if st.sidebar.button('Run chains'):
            chain = run_chains(tokenizer, model, mask_id, sentence, num_steps=num_steps)
            st.session_state.df = run_tsne(chain)
            st.session_state.finished_sampling = True

    if 'df' in st.session_state:
        df = st.session_state.df
        sent_id = st.sidebar.slider(label='2. Select the position in a chain to start exploring',
                                    min_value=0, max_value=len(df)-1, value=0)

        explore_type = st.sidebar.radio('3. Choose the way to explore', options=['In fixed increments', 'Click through each step', 'Autoplay'])
        if explore_type=='Autoplay':
            if st.button('Create the video (this may take a few minutes)'):
                st.write('Creating the video...')
                x_tsne, y_tsne = df.x_tsne, df.y_tsne
                xmax, xmin = (max(x_tsne)//30+1)*30, (min(x_tsne)//30-1)*30
                ymax, ymin = (max(y_tsne)//30+1)*30, (min(y_tsne)//30-1)*30
                color_list = sns.color_palette('flare', n_colors=1200)
                fig_production = st.progress(0)

                plot_fig(df, 0, [xmin, xmax], [ymin, ymax], color_list)
                img = cv2.imread('figures/0.png')
                height, width, layers = img.shape
                size = (width, height)
                out = cv2.VideoWriter('sampling_video.mp4', cv2.VideoWriter_fourcc(*'H264'), 3, size)
                # Render one frame per chain step; len(df) rather than a hard-coded 1000,
                # so user-defined chain lengths also work
                for sent_id in range(len(df)):
                    fig_production.progress((sent_id+1)/len(df))
                    plot_fig(df, sent_id, [xmin, xmax], [ymin, ymax], color_list)
                    img = cv2.imread(f'figures/{sent_id}.png')
                    out.write(img)
                out.release()

                cols = st.columns([1,2,1])
                with cols[1]:
                    with open('sampling_video.mp4', 'rb') as f:
                        st.video(f)
        else:
            if explore_type=='In fixed increments':
                button_labels = ['-500','-100','-10','-1','0','+1','+10','+100','+500']
                increment = st.sidebar.radio(label='select increment', options=button_labels, index=4)
                sent_id += int(increment.replace('+',''))
                sent_id = min(len(df)-1, max(0, sent_id))
            elif explore_type=='Click through each step':
                sent_id = st.sidebar.number_input(label='step number', value=sent_id)

            x_tsne, y_tsne = df.x_tsne, df.y_tsne
            xlims = [(min(x_tsne)//30-1)*30, (max(x_tsne)//30+1)*30]
            ylims = [(min(y_tsne)//30-1)*30, (max(y_tsne)//30+1)*30]
            color_list = sns.color_palette('flare', n_colors=int(len(df)*1.2))

            fig = plt.figure(figsize=(5,5), dpi=200)
            ax = fig.add_subplot(1,1,1)
            ax.plot(x_tsne[:sent_id+1], y_tsne[:sent_id+1], linewidth=0.2, color='gray', zorder=1)
            ax.scatter(x_tsne[:sent_id+1], y_tsne[:sent_id+1], s=5, color=color_list[:sent_id+1], zorder=2)
            ax.scatter(x_tsne[sent_id:sent_id+1], y_tsne[sent_id:sent_id+1], s=50, marker='*', color='blue', zorder=3)
            ax.set_xlim(*xlims)
            ax.set_ylim(*ylims)
            ax.axis('off')

            sentence = df.cleaned_sentence.to_list()[sent_id]
            input_sent = tokenizer(sentence, return_tensors='pt')['input_ids']
            decoded_sent = [tokenizer.decode([token]) for token in input_sent[0]]
            show_candidates = st.checkbox('Show candidates')
            if show_candidates:
                st.write('Click any word to see each candidate with its probability')
                cols = st.columns(len(decoded_sent))
                with cols[0]:
                    st.write(decoded_sent[0])
                with cols[-1]:
                    st.write(decoded_sent[-1])
                for word_id, (col, word) in enumerate(zip(cols[1:-1], decoded_sent[1:-1])):
                    with col:
                        if st.button(word):
                            probs = mask_prob(model, mask_id, input_sent, word_id+1)
                            _, candidates_df = sample_words(probs, word_id+1, input_sent)
                            st.table(candidates_df)
            else:
                disp_style = '"font-family:sans-serif; color:Black; font-size: 25px; font-weight:bold"'
                if explore_type=='Click through each step' and input_type=='Use your own initial sentence' and sent_id>0 and 'finished_sampling' in st.session_state:
                    # Highlight the word that was resampled at the previous step
                    sampled_loc = df.next_sample_loc.to_list()[sent_id-1]
                    disp_sent_before = f'<p style={disp_style}>'+' '.join(decoded_sent[1:sampled_loc])
                    new_word = f'<span style="color:Red">{decoded_sent[sampled_loc]}</span>'
                    disp_sent_after = ' '.join(decoded_sent[sampled_loc+1:-1])+'</p>'
                    st.markdown(disp_sent_before+' '+new_word+' '+disp_sent_after, unsafe_allow_html=True)
                else:
                    st.markdown(f'<p style={disp_style}>{sentence}</p>', unsafe_allow_html=True)
            cols = st.columns([1,2,1])
            with cols[1]:
                st.pyplot(fig)
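For reference, a minimal sketch of how the same masked-word resampling chain could be exercised outside Streamlit. This is not part of the commit; the seed sentence and step count are arbitrary, and it simply mirrors the mask_prob / sample_words / run_chains logic above with bert-base-uncased.

import torch
import torch.nn.functional as F
from transformers import BertTokenizer, BertForMaskedLM

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval()

mask_id = tokenizer.mask_token_id
sentence = tokenizer('She grew up with three brothers.', return_tensors='pt')['input_ids']
seq_len = sentence.shape[1]

for step in range(20):
    pos = torch.randint(seq_len - 2, size=(1,)).item() + 1   # never resample [CLS]/[SEP]
    masked = sentence.clone()
    masked[:, pos] = mask_id                                  # mask one random position
    with torch.no_grad():
        logits = model(masked)[0]
    probs = F.softmax(logits[:, pos], dim=-1)                 # BERT's distribution for that slot
    sentence[:, pos] = torch.multinomial(probs, num_samples=1).squeeze(-1)
    print(step, tokenizer.decode(sentence[0][1:-1]))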
requirements.txt
ADDED
@@ -0,0 +1,6 @@
# pip-installable package names (cv2 and sklearn are import names, not PyPI packages)
torch
transformers
sentence_transformers
opencv-python
seaborn
scikit-learn