# taka-yamakoshi — commit 633647b: "fix mask token"
# Source page metadata: raw | history | blame | No virus | 17.9 kB
import numpy as np
import pandas as pd
import streamlit as st
import matplotlib.pyplot as plt
import seaborn as sns
#import jax
#import jax.numpy as jnp
import torch
import torch.nn.functional as F
#from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLM
def wide_setup():
    """Inject page-level CSS: widen the main container and hide table row headers.

    The first stylesheet sets the max width and paddings of
    ``.appview-container .main .block-container``; the second hides
    ``tbody th`` and ``.blank`` elements.
    """
    max_width, padding_top, padding_right, padding_bottom, padding_left = 1500, 0, 2, 0, 2

    define_margins = (
        "\n"
        "<style>\n"
        ".appview-container .main .block-container{\n"
        f"max-width: {max_width}px;\n"
        f"padding-top: {padding_top}rem;\n"
        f"padding-right: {padding_right}rem;\n"
        f"padding-left: {padding_left}rem;\n"
        f"padding-bottom: {padding_bottom}rem;\n"
        "}\n"
        "</style>\n"
    )
    hide_table_row_index = (
        "\n"
        "<style>\n"
        "tbody th {display:none}\n"
        ".blank {display:none}\n"
        "</style>\n"
    )
    for css_snippet in (define_margins, hide_table_row_index):
        st.markdown(css_snippet, unsafe_allow_html=True)
def load_css(file_name):
    """Inject the stylesheet at ``file_name`` into the page as a ``<style>`` block.

    The file is decoded as UTF-8 explicitly (the CSS default) instead of
    relying on the platform's locale encoding.

    Raises:
        OSError: if the file cannot be opened (e.g. FileNotFoundError).
    """
    with open(file_name, encoding='utf-8') as f:
        st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
@st.cache(show_spinner=True,allow_output_mutation=True)
def load_model(model_name):
    """Load the tokenizer, masked-LM and matching skeleton for ``model_name``.

    The model family is picked from the name prefix ('albert', 'bert',
    'roberta'); the transformers and skeleton modules are imported inside the
    matching branch only.

    Returns:
        (tokenizer, model, skeleton_model), where skeleton_model is the
        ``Skeleton*ForMaskedLM`` object itself, returned uncalled
        (run_intervention later calls it as
        ``skeleton_model(model, input_ids, interventions=...)``).

    Raises:
        ValueError: if ``model_name`` matches none of the supported prefixes
            (previously this surfaced as an UnboundLocalError).
    """
    if model_name.startswith('albert'):
        from transformers import AlbertTokenizer, AlbertForMaskedLM
        from skeleton_modeling_albert import SkeletonAlbertForMaskedLM
        tokenizer = AlbertTokenizer.from_pretrained(model_name)
        model = AlbertForMaskedLM.from_pretrained(model_name)
        skeleton_model = SkeletonAlbertForMaskedLM
    elif model_name.startswith('bert'):
        from transformers import BertTokenizer, BertForMaskedLM
        from skeleton_modeling_bert import SkeletonBertForMaskedLM
        tokenizer = BertTokenizer.from_pretrained(model_name)
        model = BertForMaskedLM.from_pretrained(model_name)
        skeleton_model = SkeletonBertForMaskedLM
    elif model_name.startswith('roberta'):
        from transformers import RobertaTokenizer, RobertaForMaskedLM
        from skeleton_modeling_roberta import SkeletonRobertaForMaskedLM
        tokenizer = RobertaTokenizer.from_pretrained(model_name)
        model = RobertaForMaskedLM.from_pretrained(model_name)
        skeleton_model = SkeletonRobertaForMaskedLM
    else:
        raise ValueError(
            f'Unsupported model {model_name!r}: expected a name starting with '
            "'albert', 'bert' or 'roberta'"
        )
    return tokenizer,model,skeleton_model
def clear_data():
    """Remove every entry from ``st.session_state``.

    Iterates over a snapshot of the keys so the mapping is never mutated while
    it is being iterated.
    """
    for key in list(st.session_state.keys()):
        del st.session_state[key]
def annotate_mask(sent_id,sent):
    """Show sentence ``sent_id`` as clickable token buttons for choosing mask sites.

    The sentence is tokenized and the first and last ids are dropped before
    decoding each token. The decoded tokens are stored in
    ``st.session_state[f'decoded_sent_{sent_id}']``. Clicking a token toggles
    its index (into that decoded list) in
    ``st.session_state[f'mask_locs_{sent_id}']``. The annotated sentence is
    then re-rendered.
    """
    show_instruction(f'Sentence {sent_id}',fontsize=16)
    locs_key = f'mask_locs_{sent_id}'

    token_ids = tokenizer(sent).input_ids
    # drop the first and last ids, then decode each remaining token on its own
    words = [tokenizer.decode([tid]) for tid in token_ids[1:-1]]
    st.session_state[f'decoded_sent_{sent_id}'] = words

    # column widths proportional to token length (+2 for button padding)
    columns = st.columns([len(w) + 2 for w in words])
    if locs_key not in st.session_state:
        st.session_state[locs_key] = []

    for pos, (column, w) in enumerate(zip(columns, words)):
        with column:
            if st.button(w, key=f'word_mask_{sent_id}_{pos}'):
                # toggle membership of this position
                if pos in st.session_state[locs_key]:
                    st.session_state[locs_key].remove(pos)
                else:
                    st.session_state[locs_key].append(pos)

    show_annotated_sentence(words, mask_locs=st.session_state[locs_key])
def annotate_options(sent_id,sent):
    """Show sentence ``sent_id`` as clickable token buttons for choosing option sites.

    Clicking a token toggles its index in
    ``st.session_state[f'option_locs_{sent_id}']``. The index is into the
    sentence's token list with the first and last ids dropped. The sentence is
    re-rendered with both the option and the previously chosen mask positions
    highlighted.

    Afterwards both index lists are stored back sorted in ascending order.
    Requires ``mask_locs_{sent_id}`` to already exist in session state.
    """
    show_instruction(f'Sentence {sent_id}',fontsize=16)
    opt_key = f'option_locs_{sent_id}'
    mask_key = f'mask_locs_{sent_id}'

    words = [tokenizer.decode([tid]) for tid in tokenizer(sent).input_ids[1:-1]]
    columns = st.columns([len(w) + 2 for w in words])
    if opt_key not in st.session_state:
        st.session_state[opt_key] = []

    for pos, (column, w) in enumerate(zip(columns, words)):
        with column:
            if st.button(w, key=f'word_option_{sent_id}_{pos}'):
                # toggle membership of this position
                if pos in st.session_state[opt_key]:
                    st.session_state[opt_key].remove(pos)
                else:
                    st.session_state[opt_key].append(pos)

    show_annotated_sentence(words,
                            option_locs=st.session_state[opt_key],
                            mask_locs=st.session_state[mask_key])

    # later steps (separate_options, mask_out) expect ascending indices
    for key in (opt_key, mask_key):
        st.session_state[key] = list(np.sort(st.session_state[key]))
def show_annotated_sentence(sent,option_locs=None,mask_locs=None):
    """Render a list of token strings as one bold line with highlighted tokens.

    Args:
        sent: decoded token strings, joined with single spaces for display.
        option_locs: indices into ``sent`` drawn in blue. Defaults to none.
        mask_locs: indices into ``sent`` drawn in red. If an index is in both,
            red wins. Defaults to none.

    Returns:
        Whatever ``st.markdown`` returns.
    """
    import html  # local import, in the style of the branch-level imports in load_model

    # None defaults avoid the shared-mutable-default pitfall
    option_locs = [] if option_locs is None else option_locs
    mask_locs = [] if mask_locs is None else mask_locs

    # 'sans-serif' is the valid generic family ('san serif' is not)
    disp_style = '"font-family:sans-serif; color:Black; font-size: 20px"'
    prefix = f'<p style={disp_style}><span style="font-weight:bold">'
    style_list = []
    for i, word in enumerate(sent):
        # tokens come from user-typed text and are emitted as raw HTML:
        # escape &, <, > so they display literally instead of becoming markup
        word = html.escape(word, quote=False)
        if i in mask_locs:
            style_list.append(f'<span style="color:Red">{word}</span>')
        elif i in option_locs:
            style_list.append(f'<span style="color:Blue">{word}</span>')
        else:
            style_list.append(f'{word}')
    disp = ' '.join(style_list)
    suffix = '</span></p>'
    return st.markdown(prefix + disp + suffix, unsafe_allow_html = True)
def show_instruction(sent,fontsize=20):
    """Render ``sent`` as a bold black paragraph at ``fontsize`` pixels.

    ``sent`` is inserted verbatim as HTML (``unsafe_allow_html=True``), so
    only pass trusted text.

    Returns:
        Whatever ``st.markdown`` returns.
    """
    # 'sans-serif' is the valid generic family ('san serif' is not)
    disp_style = f'"font-family:sans-serif; color:Black; font-size: {fontsize}px"'
    prefix = f'<p style={disp_style}><span style="font-weight:bold">'
    suffix = '</span></p>'
    return st.markdown(prefix + sent + suffix, unsafe_allow_html = True)
def create_interventions(token_id,interv_types,num_heads,multihead=False,heads=None):
    """Build the intervention spec for one layer at a single token position.

    Args:
        token_id: position in the input ids. The caller in this file passes
            the annotation index + 1.
        interv_types: representation codes to intervene on; any of
            'lay', 'qry', 'key', 'val'.
        num_heads: number of attention heads; only used when ``multihead``.
        multihead: if True, every head ``0..num_heads-1`` gets the entry
            ``(head_id, token_id, [0, 1])``.
        heads: 0-based head indices used when ``multihead`` is False. The
            i-th listed head gets ``(head_id, token_id, [i, i + len(heads)])``.
            Defaults to no heads.

    Returns:
        dict with keys 'lay', 'qry', 'key', 'val', each mapping to a list of
        ``(head_id, token_id, [row_a, row_b])`` tuples (empty for types not
        requested).
    """
    # NOTE(review): [row_a, row_b] look like matching batch rows of sentence 1
    # and sentence 2 in run_intervention's batch (batch_size copies of each);
    # their exact use is defined in skeleton_modeling_* — confirm there.
    if heads is None:
        heads = []
    interventions = {}
    for rep in ['lay','qry','key','val']:
        if rep not in interv_types:
            interventions[rep] = []
        elif multihead:
            interventions[rep] = [(head_id,token_id,[0,1]) for head_id in range(num_heads)]
        else:
            interventions[rep] = [(head_id,token_id,[i,i+len(heads)]) for i,head_id in enumerate(heads)]
    return interventions
def separate_options(option_locs):
    """Split ascending option positions into the two contiguous option spans.

    Example: ``[1, 2, 7]`` -> ``([1, 2], [7])``.

    Args:
        option_locs: ascending token indices covering exactly two options.

    Returns:
        ``(option_1_locs, option_2_locs)``: the slices of ``option_locs``
        before and after the single gap.

    Raises:
        ValueError: if there is not exactly one gap (> 1) between consecutive
            positions, or either span is not made of consecutive indices.
            These checks were previously ``assert`` statements, which are
            stripped under ``python -O``.
    """
    gaps = np.diff(option_locs)
    breaks = np.flatnonzero(gaps > 1)
    if len(breaks) != 1:
        raise ValueError(
            f'expected exactly two separated option spans, got positions {list(option_locs)}'
        )
    sep = int(breaks[0]) + 1
    option_1_locs, option_2_locs = option_locs[:sep], option_locs[sep:]
    for span in (option_1_locs, option_2_locs):
        if len(span) > 1 and not np.all(np.diff(span) == 1):
            raise ValueError(f'option span {list(span)} is not contiguous')
    return option_1_locs, option_2_locs
def mask_out(input_ids,pron_locs,option_locs,mask_id):
    """Replace the annotated span with one mask token per option token.

    Args:
        input_ids: full token-id list, including the first and last special
            tokens that the annotation step drops.
        pron_locs: contiguous annotation indices of the span to replace.
            These omit the leading special token, hence the +1 shift below.
        option_locs: positions of the option; only its length is used.
        mask_id: token id inserted in place of the span.

    Returns:
        A new list; ``input_ids`` is not modified.

    Raises:
        ValueError: if ``pron_locs`` is empty or not contiguous.
    """
    if len(pron_locs) == 0:
        raise ValueError('pron_locs must contain at least one annotated position')
    if len(pron_locs) > 1 and not np.all(np.diff(pron_locs) == 1):
        raise ValueError(f'masked span {list(pron_locs)} is not contiguous')
    # note annotations are shifted by 1 because special tokens were omitted
    start, stop = pron_locs[0] + 1, pron_locs[-1] + 2
    return input_ids[:start] + [mask_id] * len(option_locs) + input_ids[stop:]
def run_intervention(interventions,batch_size,skeleton_model,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs):
    """Score both options in both sentences under the given interventions.

    For each option, one batch is built: ``batch_size`` copies of sentence 1
    followed by ``batch_size`` copies of sentence 2. It is passed to
    ``skeleton_model(model, input_ids, interventions=interventions)``.

    An option's score is ``exp(mean log-prob)`` of its tokens, read at the
    consecutive mask positions starting right after ``pron_locs[...][0]``.
    This is the geometric mean of the per-token probabilities.

    Args:
        interventions: per-layer intervention specs, forwarded unchanged.
        batch_size: copies of each sentence in the batch.
        skeleton_model: callable returning a mapping with a 'logits' tensor.
        model: the masked-LM passed through to ``skeleton_model``.
        masked_ids_option_1 / masked_ids_option_2: ``{'sent_1': ids, 'sent_2': ids}``
            with the masked span sized for option 1 / option 2. Both
            sentences must have equal length to be stacked into one tensor.
        option_1_tokens / option_2_tokens: token ids of each option.
        pron_locs: ``{'sent_1': locs, 'sent_2': locs}`` annotation indices of
            the masked span.

    Returns:
        np.ndarray of shape ``(2, 2, batch_size)`` indexed
        ``[option, sentence, batch row]``.
    """
    probs = []
    for masked_ids, option_tokens in zip([masked_ids_option_1, masked_ids_option_2],[option_1_tokens,option_2_tokens]):
        # rows [0, batch_size) = sentence 1, rows [batch_size, 2*batch_size) = sentence 2
        input_ids = torch.tensor([masked_ids['sent_1']] * batch_size + [masked_ids['sent_2']] * batch_size)
        # Inference only: skip building an autograd graph. This saves memory and
        # lets .numpy() be called on the results (it raises on tensors requiring grad).
        with torch.no_grad():
            outputs = skeleton_model(model,input_ids,interventions=interventions)
            logprobs = F.log_softmax(outputs['logits'], dim = -1)
        logprobs_1, logprobs_2 = logprobs[:batch_size], logprobs[batch_size:]
        # +1: annotation indices omit the leading special token
        evals_1 = [logprobs_1[:,pron_locs['sent_1'][0]+1+i,token].numpy() for i,token in enumerate(option_tokens)]
        evals_2 = [logprobs_2[:,pron_locs['sent_2'][0]+1+i,token].numpy() for i,token in enumerate(option_tokens)]
        probs.append([np.exp(np.mean(evals_1,axis=0)),np.exp(np.mean(evals_2,axis=0))])
    probs = np.array(probs)
    assert probs.shape[0]==2 and probs.shape[1]==2 and probs.shape[2]==batch_size
    return probs
def show_results(effect_array,masked_sent,token_id_list,num_layers):
    """Draw one column per token of ``masked_sent``, with a heatmap strip for intervened tokens.

    Args:
        effect_array: 2-D array whose column ``k`` holds the effects for
            ``token_id_list[k]``. The caller passes rows = layers — TODO
            confirm for other callers.
        masked_sent: token ids including the first and last special tokens,
            which get no column.
        token_id_list: annotation indices (special tokens excluded) that were
            intervened on.
        num_layers: used as the box aspect so each strip is tall and narrow.
    """
    # loop invariants: one colormap and one shared colour scale so strips in
    # different columns are directly comparable
    cmap = sns.color_palette("light:r", as_cmap=True)
    vmin, vmax = effect_array.min(), effect_array.max()
    cols = st.columns(len(masked_sent)-2)
    for col_id,col in enumerate(cols):
        with col:
            st.write(tokenizer.decode([masked_sent[col_id+1]]))
            if col_id in token_id_list:
                interv_id = token_id_list.index(col_id)
                fig,ax = plt.subplots()
                ax.set_box_aspect(num_layers)
                ax.imshow(effect_array[:,interv_id:interv_id+1],cmap=cmap,vmin=vmin,vmax=vmax)
                ax.set_xticks([])
                ax.set_xticklabels([])
                ax.set_yticks([])
                ax.set_yticklabels([])
                for side in ('top','bottom','right','left'):
                    ax.spines[side].set_visible(False)
                st.pyplot(fig)
                # pyplot keeps every figure alive until closed; without this,
                # each rerun leaks one figure per intervened token
                plt.close(fig)
if __name__=='__main__':
    # Page flow: Streamlit re-executes this script on every interaction, and
    # st.session_state['page_status'] records which step to render:
    # model_selection -> type_in -> annotate_mask -> annotate_options -> analysis -> results
    wide_setup()
    load_css('style.css')

    if 'page_status' not in st.session_state:
        st.session_state['page_status'] = 'model_selection'

    # Step 0: choose the model (default index 3 = 'roberta-large').
    if st.session_state['page_status']=='model_selection':
        show_instruction('0. Select the model and click "Confirm"',fontsize=16)
        model_name = st.selectbox('Please select the model from below.',
                                  ('bert-base-uncased','bert-large-cased',
                                   'roberta-base','roberta-large',
                                   'albert-base-v2','albert-large-v2','albert-xlarge-v2','albert-xxlarge-v2'),
                                  index=3,label_visibility='collapsed')
        st.session_state['model_name'] = model_name
        if st.button('Confirm',key='confirm_models'):
            st.session_state['page_status'] = 'type_in'
            st.experimental_rerun()

    # All later steps need the model. load_model is decorated with @st.cache.
    # These names are module globals read by annotate_mask, annotate_options and show_results.
    if st.session_state['page_status']!='model_selection':
        tokenizer,model,skeleton_model = load_model(st.session_state['model_name'])
        num_layers, num_heads = model.config.num_hidden_layers, model.config.num_attention_heads
        # Id of the mask token: tokenize the mask-token string and drop the
        # first/last ids the tokenizer adds around it.
        mask_id = tokenizer(tokenizer.mask_token).input_ids[1:-1][0]

    # Step 1: type the two sentences to compare.
    if st.session_state['page_status']=='type_in':
        show_instruction('1. Type in the sentences and click "Tokenize"',fontsize=16)
        sent_1 = st.text_input('Sentence 1',value="Paul tried to call George on the phone, but he wasn't successful.")
        sent_2 = st.text_input('Sentence 2',value="Paul tried to call George on the phone, but he wasn't available.")
        if st.button('Tokenize'):
            st.session_state['page_status'] = 'annotate_mask'
            st.session_state['sent_1'] = sent_1
            st.session_state['sent_2'] = sent_2
            st.experimental_rerun()

    # Step 2: click tokens to mark the span to mask (stored in mask_locs_{1,2}).
    if st.session_state['page_status']=='annotate_mask':
        sent_1 = st.session_state['sent_1']
        sent_2 = st.session_state['sent_2']
        show_instruction('2. Select sites to mask out and click "Confirm"',fontsize=16)
        #show_instruction('------------------------------',fontsize=32)
        annotate_mask(1,sent_1)
        show_instruction('------------------------------',fontsize=24)
        annotate_mask(2,sent_2)
        if st.button('Confirm',key='confirm_mask'):
            st.session_state['page_status'] = 'annotate_options'
            st.experimental_rerun()

    # Step 3: click tokens to mark the two candidate options (stored in option_locs_{1,2}).
    if st.session_state['page_status'] == 'annotate_options':
        sent_1 = st.session_state['sent_1']
        sent_2 = st.session_state['sent_2']
        show_instruction('3. Select options and click "Confirm"',fontsize=16)
        #show_instruction('------------------------------',fontsize=32)
        annotate_options(1,sent_1)
        show_instruction('------------------------------',fontsize=24)
        annotate_options(2,sent_2)
        if st.button('Confirm',key='confirm_option'):
            st.session_state['page_status'] = 'analysis'
            st.experimental_rerun()

    # Step 4: choose the representation types and (optionally) individual heads.
    if st.session_state['page_status']=='analysis':
        interv_reps = st.multiselect('Select the types of representations to intervene.',['layer','query','key','value'])
        # UI names -> codes used by create_interventions
        rep_dict = {'layer':'lay','query':'qry','key':'key','value':'val'}
        # unchecked box => intervene on all heads at once ("multihead" mode)
        multihead = not st.checkbox('Perform individual head analysis (takes time)')
        if not multihead:
            # heads are offered 1-based; they are converted to 0-based in the results step
            heads = st.multiselect('Select heads to intervene.',list(np.arange(1,num_heads+1)))
        else:
            heads = []
        if st.button('Run',key='run'):
            st.session_state['reps'] = [rep_dict[rep] for rep in interv_reps]
            st.session_state['multihead'] = multihead
            st.session_state['heads'] = heads
            st.session_state['page_status'] = 'results'
            st.experimental_rerun()

    # Step 5: run the interventions and plot their effects.
    if st.session_state['page_status']=='results':
        sent_1 = st.session_state['sent_1']
        sent_2 = st.session_state['sent_2']
        multihead = st.session_state['multihead']
        heads = st.session_state['heads']
        reps = st.session_state['reps']

        # Per-sentence dicts keyed 'sent_1' / 'sent_2'. Annotation indices omit the
        # leading special token, so they are shifted by +1 when indexing input ids.
        option_1_locs, option_2_locs = {}, {}
        pron_locs = {}
        input_ids_dict = {}
        masked_ids_option_1 = {}
        masked_ids_option_2 = {}
        for sent_id in [1,2]:
            option_1_locs[f'sent_{sent_id}'], option_2_locs[f'sent_{sent_id}'] = separate_options(st.session_state[f'option_locs_{sent_id}'])
            pron_locs[f'sent_{sent_id}'] = st.session_state[f'mask_locs_{sent_id}']
            input_ids_dict[f'sent_{sent_id}'] = tokenizer(st.session_state[f'sent_{sent_id}']).input_ids
            # Two masked variants per sentence: the masked span is replaced by as many
            # mask tokens as option 1 (resp. option 2) has tokens.
            masked_ids_option_1[f'sent_{sent_id}'] = mask_out(input_ids_dict[f'sent_{sent_id}'],
                                                              pron_locs[f'sent_{sent_id}'],
                                                              option_1_locs[f'sent_{sent_id}'],mask_id)
            masked_ids_option_2[f'sent_{sent_id}'] = mask_out(input_ids_dict[f'sent_{sent_id}'],
                                                              pron_locs[f'sent_{sent_id}'],
                                                              option_2_locs[f'sent_{sent_id}'],mask_id)

        # Token ids of each option, read from each sentence; both sentences must
        # contain identical option tokens.
        option_1_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_1_locs['sent_1'])+1]
        option_1_tokens_2 = np.array(input_ids_dict['sent_2'])[np.array(option_1_locs['sent_2'])+1]
        option_2_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_2_locs['sent_1'])+1]
        option_2_tokens_2 = np.array(input_ids_dict['sent_2'])[np.array(option_2_locs['sent_2'])+1]
        assert np.all(option_1_tokens_1==option_1_tokens_2) and np.all(option_2_tokens_1==option_2_tokens_2)
        option_1_tokens = option_1_tokens_1
        option_2_tokens = option_2_tokens_1

        # Baseline run: an empty intervention spec for every layer.
        interventions = [{'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
        probs_original = run_intervention(interventions,1,skeleton_model,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
        # probs_original is indexed [option, sentence, batch]; the table has
        # rows = sentences and columns = options.
        df = pd.DataFrame(data=[[probs_original[0,0][0],probs_original[1,0][0]],
                                [probs_original[0,1][0],probs_original[1,1][0]]],
                          columns=[tokenizer.decode(option_1_tokens),tokenizer.decode(option_2_tokens)],
                          index=['Sentence 1','Sentence 2'])
        cols = st.columns(3)
        with cols[1]:
            show_instruction('Probability of predicting each option in each sentence',fontsize=12)
            st.dataframe(df.style.highlight_max(axis=1),use_container_width=True)

        # Context positions = where the masked sentence 1 and sentence 2 differ,
        # converted to annotation coordinates (-1 for the leading special token).
        compare_1 = np.array(masked_ids_option_1['sent_1'])!=np.array(masked_ids_option_1['sent_2'])
        compare_2 = np.array(masked_ids_option_2['sent_1'])!=np.array(masked_ids_option_2['sent_2'])
        # NOTE(review): these elementwise comparisons assume all four masked sequences
        # have the same length. Both sentences must tokenize to the same length, and
        # options 1 and 2 must have the same number of tokens. TODO confirm this
        # restriction is intended.
        assert np.all(compare_1.astype(int)==compare_2.astype(int))
        context_locs = list(np.arange(len(masked_ids_option_1['sent_1']))[compare_1]-1) # match the indexing for annotation
        assert np.all(np.array(pron_locs['sent_1'])==np.array(pron_locs['sent_2']))
        assert np.all(np.array(option_1_locs['sent_1'])==np.array(option_1_locs['sent_2']))
        assert np.all(np.array(option_2_locs['sent_1'])==np.array(option_2_locs['sent_2']))
        # Positions to intervene on, in plotting order: masked span, option 1, option 2, context.
        # NOTE(review): pron/option indices come from the unmasked sentence, while
        # context_locs are in masked-sentence coordinates; they line up only for
        # positions unaffected by the span replacement — confirm.
        token_id_list = pron_locs['sent_1'] + option_1_locs['sent_1'] + option_2_locs['sent_1'] + context_locs
        effect_array = []
        for token_id in token_id_list:
            token_id += 1 # annotation index -> input-id position (skip leading special token)
            effect_list = []
            for layer_id in range(num_layers):
                # intervene only at layer `layer_id`; every other layer gets an empty spec
                interventions = [create_interventions(token_id,reps,num_heads,multihead,[head_id-1 for head_id in heads])
                                 if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
                if multihead:
                    probs = run_intervention(interventions,1,skeleton_model,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
                else:
                    # one batch row per selected head
                    probs = run_intervention(interventions,len(heads),skeleton_model,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
                # Mean of four terms: the drop in P(option 1 | sentence 1) and
                # P(option 2 | sentence 2), plus the rise in P(option 1 | sentence 2)
                # and P(option 2 | sentence 1). Shape: (batch,).
                effect = ((probs_original-probs)[0,0] + (probs_original-probs)[1,1] + (probs-probs_original)[0,1] + (probs-probs_original)[1,0])/4
                effect_list.append(effect)
            effect_array.append(effect_list)
        # (tokens, layers, batch) -> (layers, tokens, batch)
        effect_array = np.transpose(np.array(effect_array),(1,0,2))

        if multihead:
            show_results(effect_array[:,:,0],masked_ids_option_1['sent_1'],token_id_list,num_layers)
        else:
            # tab i shows batch row i, which create_interventions pairs with heads[i]
            tabs = st.tabs([str(head_id) for head_id in heads])
            for i,tab in enumerate(tabs):
                with tab:
                    show_results(effect_array[:,:,i],masked_ids_option_1['sent_1'],token_id_list,num_layers)