Spaces:
Runtime error
Runtime error
import html

import numpy as np
import pandas as pd
import streamlit as st
import matplotlib.pyplot as plt
import seaborn as sns
#import jax
#import jax.numpy as jnp
import torch
import torch.nn.functional as F
from transformers import AlbertTokenizer, AlbertForMaskedLM

#from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLM
from skeleton_modeling_albert import SkeletonAlbertForMaskedLM
def wide_setup():
    """Inject page-level CSS into the Streamlit app.

    Widens the main block container to a 1500px max width with 2rem
    left/right and 0rem top/bottom padding, and hides ``tbody th`` and
    ``.blank`` elements (table index cells).
    """
    layout = dict(max_width=1500, padding_top=0, padding_right=2,
                  padding_bottom=0, padding_left=2)
    define_margins = """
<style>
.appview-container .main .block-container{{
max-width: {max_width}px;
padding-top: {padding_top}rem;
padding-right: {padding_right}rem;
padding-left: {padding_left}rem;
padding-bottom: {padding_bottom}rem;
}}
</style>
""".format(**layout)
    hide_table_row_index = """
<style>
tbody th {display:none}
.blank {display:none}
</style>
"""
    for css in (define_margins, hide_table_row_index):
        st.markdown(css, unsafe_allow_html=True)
def load_css(file_name):
    """Read a stylesheet and inject its contents into the page in a <style> tag.

    Args:
        file_name: Path of the CSS file (the ``__main__`` block passes
            ``'style.css'``).
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    with open(file_name, encoding='utf-8') as f:
        css = f.read()
    st.markdown(f'<style>{css}</style>', unsafe_allow_html=True)
def load_model():
    """Load the ``albert-xxlarge-v2`` tokenizer and masked-LM model.

    Returns:
        ``(tokenizer, model)`` as produced by ``from_pretrained``.

    NOTE(review): the ``__main__`` block calls this unconditionally, so the
    checkpoint is reloaded on every script execution; consider caching it.
    """
    checkpoint = 'albert-xxlarge-v2'
    albert_tokenizer = AlbertTokenizer.from_pretrained(checkpoint)
    #model = CustomFlaxAlbertForMaskedLM.from_pretrained('albert-xxlarge-v2',from_pt=True)
    albert_mlm = AlbertForMaskedLM.from_pretrained(checkpoint)
    return albert_tokenizer, albert_mlm
def clear_data():
    """Remove every key from ``st.session_state``."""
    # Snapshot the keys first: deleting entries from a mapping while
    # iterating over it directly is unsafe.
    for key in list(st.session_state.keys()):
        del st.session_state[key]
def annotate_mask(sent_id,sent):
    """Let the user toggle which tokens of one sentence are masked out.

    Renders ``sent`` as one button per token. The first and last tokenizer
    ids are dropped, so button ``k`` corresponds to ``input_ids[k + 1]``.
    The decoded tokens are stored in
    ``st.session_state['decoded_sent_<sent_id>']``. Clicking a button adds
    or removes that token's index in
    ``st.session_state['mask_locs_<sent_id>']``, and the sentence is then
    redrawn with the masked tokens highlighted.

    Uses the module-level ``tokenizer`` created in the ``__main__`` block.
    """
    locs_key = f'mask_locs_{sent_id}'
    show_instruction(f'Sentence {sent_id}',fontsize=16)
    token_ids = tokenizer(sent).input_ids[1:-1]
    words = [tokenizer.decode([tok]) for tok in token_ids]
    st.session_state[f'decoded_sent_{sent_id}'] = words
    # Column widths proportional to token length (+2 for button padding).
    columns = st.columns([len(w) + 2 for w in words])
    if locs_key not in st.session_state:
        st.session_state[locs_key] = []
    for idx, (column, w) in enumerate(zip(columns, words)):
        with column:
            if not st.button(w, key=f'word_mask_{sent_id}_{idx}'):
                continue
            selected = st.session_state[locs_key]
            if idx in selected:
                selected.remove(idx)
            else:
                selected.append(idx)
    show_annotated_sentence(words, mask_locs=st.session_state[locs_key])
def annotate_options(sent_id,sent):
    """Let the user toggle the candidate-option tokens of one sentence.

    Renders ``sent`` as one button per token (first and last tokenizer ids
    dropped). Clicking a button adds or removes that token's index in
    ``st.session_state['option_locs_<sent_id>']``. The sentence is then
    redrawn with the masked tokens (``mask_locs_<sent_id>``) and the options
    highlighted; the masked-token entry must already be in session state.
    Finally both index lists are stored back in sorted order.

    Uses the module-level ``tokenizer`` created in the ``__main__`` block.
    """
    option_key = f'option_locs_{sent_id}'
    mask_key = f'mask_locs_{sent_id}'
    show_instruction(f'Sentence {sent_id}',fontsize=16)
    token_ids = tokenizer(sent).input_ids[1:-1]
    words = [tokenizer.decode([tok]) for tok in token_ids]
    # Column widths proportional to token length (+2 for button padding).
    columns = st.columns([len(w) + 2 for w in words])
    if option_key not in st.session_state:
        st.session_state[option_key] = []
    for idx, (column, w) in enumerate(zip(columns, words)):
        with column:
            if not st.button(w, key=f'word_option_{sent_id}_{idx}'):
                continue
            selected = st.session_state[option_key]
            if idx in selected:
                selected.remove(idx)
            else:
                selected.append(idx)
    show_annotated_sentence(words,
                            option_locs=st.session_state[option_key],
                            mask_locs=st.session_state[mask_key])
    # Downstream span arithmetic (first/last index, diffs) needs sorted lists.
    for key in (option_key, mask_key):
        st.session_state[key] = list(np.sort(st.session_state[key]))
def show_annotated_sentence(sent,option_locs=None,mask_locs=None):
    """Render a token sequence as one bold line with highlighted tokens.

    Args:
        sent: Sequence of decoded tokens, joined with single spaces.
        option_locs: Indices into ``sent`` to colour blue. Defaults to none.
        mask_locs: Indices into ``sent`` to colour red. This takes precedence
            over ``option_locs`` when an index appears in both. Defaults to
            none.

    Returns:
        Whatever ``st.markdown`` returns.
    """
    option_locs = [] if option_locs is None else option_locs
    mask_locs = [] if mask_locs is None else mask_locs
    # 'sans-serif' is the generic CSS family keyword ('san serif' is not valid).
    disp_style = '"font-family:sans-serif; color:Black; font-size: 20px"'
    prefix = f'<p style={disp_style}><span style="font-weight:bold">'
    style_list = []
    for i, word in enumerate(sent):
        # Tokens come from user-typed sentences and the markup is rendered
        # with unsafe_allow_html, so escape HTML metacharacters.
        safe_word = html.escape(str(word), quote=False)
        if i in mask_locs:
            style_list.append(f'<span style="color:Red">{safe_word}</span>')
        elif i in option_locs:
            style_list.append(f'<span style="color:Blue">{safe_word}</span>')
        else:
            style_list.append(safe_word)
    disp = ' '.join(style_list)
    suffix = '</span></p>'
    return st.markdown(prefix + disp + suffix, unsafe_allow_html = True)
def show_instruction(sent,fontsize=20):
    """Render ``sent`` as a bold black paragraph at ``fontsize`` pixels.

    ``sent`` is inserted into HTML unescaped and rendered with
    ``unsafe_allow_html``. Callers in this file pass fixed instruction text.

    Returns:
        Whatever ``st.markdown`` returns.
    """
    # 'sans-serif' is the generic CSS family keyword ('san serif' is not valid).
    disp_style = f'"font-family:sans-serif; color:Black; font-size: {fontsize}px"'
    prefix = f'<p style={disp_style}><span style="font-weight:bold">'
    suffix = '</span></p>'
    return st.markdown(prefix + sent + suffix, unsafe_allow_html = True)
def create_interventions(token_id,interv_types,num_heads,multihead=False):
    """Build the intervention spec for one layer.

    Returns a dict with keys ``'lay'``, ``'qry'``, ``'key'`` and ``'val'``, in
    that order. Each key listed in ``interv_types`` maps to one
    ``(head_id, token_id, rows)`` tuple per head; every other key maps to
    ``[]``.

    ``rows`` is ``[0, 1]`` when ``multihead`` is true. Otherwise it is
    ``[head_id, head_id + num_heads]``. In ``run_intervention``'s batch layout
    (``batch_size`` copies of sentence 1 followed by ``batch_size`` copies of
    sentence 2), these are presumably the paired sentence-1/sentence-2 rows to
    intervene on. Confirm the exact semantics in
    ``SkeletonAlbertForMaskedLM``.
    """
    def rows_for(head):
        return [0, 1] if multihead else [head, head + num_heads]

    return {
        rep: ([(head, token_id, rows_for(head)) for head in range(num_heads)]
              if rep in interv_types else [])
        for rep in ('lay', 'qry', 'key', 'val')
    }
def separate_options(option_locs):
    """Split the selected option indices into two contiguous spans.

    Args:
        option_locs: Sorted annotation indices covering both options, with
            exactly one gap (a step larger than 1) between the two spans.

    Returns:
        ``(option_1_locs, option_2_locs)``: the slices of ``option_locs``
        before and after the gap.

    Raises:
        ValueError: If there is not exactly one gap, or either span is not a
            run of consecutive indices. ``ValueError`` is used instead of
            ``assert`` because this validates user selections, and asserts
            are stripped under ``python -O``.
    """
    gaps = np.flatnonzero(np.diff(option_locs) > 1)
    if len(gaps) != 1:
        raise ValueError(
            'Select exactly two options separated by at least one token; '
            f'got indices {list(option_locs)}.')
    sep = int(gaps[0]) + 1
    option_1_locs, option_2_locs = option_locs[:sep], option_locs[sep:]
    for locs in (option_1_locs, option_2_locs):
        if len(locs) > 1 and not np.all(np.diff(locs) == 1):
            raise ValueError(
                f'Each option must be a run of consecutive tokens; got {list(locs)}.')
    return option_1_locs, option_2_locs
def mask_out(input_ids,pron_locs,option_locs,mask_id):
    """Replace the annotated span with one ``mask_id`` per option token.

    Args:
        input_ids: Token-id list including the special tokens at both ends.
        pron_locs: Contiguous, ascending annotation indices of the span to
            replace. Annotations omit the leading special token, so index
            ``k`` refers to ``input_ids[k + 1]``.
        option_locs: Annotation indices of one option; only its length is
            used.
        mask_id: Token id inserted in place of the span.

    Returns:
        A new list: ``input_ids`` with the span replaced by
        ``len(option_locs)`` copies of ``mask_id``.

    Raises:
        ValueError: If ``pron_locs`` is empty or not a run of consecutive
            indices.
    """
    if len(pron_locs) == 0:
        raise ValueError('No token was selected to mask out.')
    if len(pron_locs) > 1 and not np.all(np.diff(pron_locs) == 1):
        raise ValueError(
            f'Masked tokens must be consecutive; got indices {list(pron_locs)}.')
    # Annotations are shifted by 1 because special tokens were omitted.
    start, stop = pron_locs[0] + 1, pron_locs[-1] + 2
    return input_ids[:start] + [mask_id] * len(option_locs) + input_ids[stop:]
def run_intervention(interventions,batch_size,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs):
    """Score both options in both sentences under the given interventions.

    For each option, a batch is built from ``batch_size`` copies of the
    sentence-1 ids followed by ``batch_size`` copies of the sentence-2 ids.
    It is run through ``SkeletonAlbertForMaskedLM`` with ``interventions``.
    An option's score is exp(mean log-probability) of its tokens at the
    consecutive masked positions starting at ``pron_locs[...][0] + 1`` (the
    +1 accounts for the omitted leading special token).

    Args:
        interventions: Per-layer intervention specs, passed through unchanged
            to ``SkeletonAlbertForMaskedLM``.
        batch_size: Number of copies of each sentence in the batch.
        model: The masked-LM model handed to ``SkeletonAlbertForMaskedLM``.
        masked_ids_option_1, masked_ids_option_2: Dicts with keys
            ``'sent_1'``/``'sent_2'`` holding token-id lists masked for the
            corresponding option. Both sentences must have equal length so
            they can be stacked into one tensor.
        option_1_tokens, option_2_tokens: Token ids of each option.
        pron_locs: Dict with keys ``'sent_1'``/``'sent_2'`` of the annotated
            masked-span indices.

    Returns:
        ``np.ndarray`` of shape ``(2, 2, batch_size)`` indexed
        ``[option, sentence, batch row]``.
    """
    probs = []
    for masked_ids, option_tokens in zip([masked_ids_option_1, masked_ids_option_2],
                                         [option_1_tokens, option_2_tokens]):
        input_ids = torch.tensor([
            *[masked_ids['sent_1'] for _ in range(batch_size)],
            *[masked_ids['sent_2'] for _ in range(batch_size)],
        ])
        # Inference only: skipping autograd avoids building a graph per call,
        # and ``.numpy()`` cannot be called on a tensor that requires grad.
        with torch.no_grad():
            outputs = SkeletonAlbertForMaskedLM(model,input_ids,interventions=interventions)
            logprobs = F.log_softmax(outputs['logits'], dim = -1)
        logprobs_1, logprobs_2 = logprobs[:batch_size], logprobs[batch_size:]
        start_1 = pron_locs['sent_1'][0]+1
        start_2 = pron_locs['sent_2'][0]+1
        evals_1 = [logprobs_1[:,start_1+i,token].cpu().numpy() for i,token in enumerate(option_tokens)]
        evals_2 = [logprobs_2[:,start_2+i,token].cpu().numpy() for i,token in enumerate(option_tokens)]
        # Geometric mean of the per-token probabilities of the option.
        probs.append([np.exp(np.mean(evals_1,axis=0)),np.exp(np.mean(evals_2,axis=0))])
    probs = np.array(probs)
    assert probs.shape[0]==2 and probs.shape[1]==2 and probs.shape[2]==batch_size
    return probs
if __name__=='__main__':
    # Four-step workflow keyed on st.session_state['page_status']:
    # 'type_in' -> 'annotate_mask' -> 'annotate_options' -> 'analysis'.
    # Each step stores its result in session state, advances the status and
    # calls st.experimental_rerun() so the next step is rendered.
    wide_setup()
    load_css('style.css')

    tokenizer,model = load_model()
    num_layers, num_heads = model.config.num_hidden_layers, model.config.num_attention_heads
    # Id of the [MASK] token (drop the special tokens added around it).
    mask_id = tokenizer('[MASK]').input_ids[1:-1][0]

    main_area = st.empty()

    if 'page_status' not in st.session_state:
        st.session_state['page_status'] = 'type_in'

    # Step 1: enter the two sentences.
    if st.session_state['page_status']=='type_in':
        show_instruction('1. Type in the sentences and click "Tokenize"',fontsize=16)
        sent_1 = st.text_input('Sentence 1',value="Paul tried to call George on the phone, but he wasn't successful.")
        sent_2 = st.text_input('Sentence 2',value="Paul tried to call George on the phone, but he wasn't available.")
        if st.button('Tokenize'):
            st.session_state['page_status'] = 'annotate_mask'
            st.session_state['sent_1'] = sent_1
            st.session_state['sent_2'] = sent_2
            st.experimental_rerun()

    # Step 2: select the token(s) to mask out in each sentence.
    if st.session_state['page_status']=='annotate_mask':
        sent_1 = st.session_state['sent_1']
        sent_2 = st.session_state['sent_2']
        show_instruction('2. Select sites to mask out and click "Confirm"',fontsize=16)
        show_instruction('------------------------------',fontsize=32)
        annotate_mask(1,sent_1)
        show_instruction('------------------------------',fontsize=32)
        annotate_mask(2,sent_2)
        if st.button('Confirm',key='mask'):
            st.session_state['page_status'] = 'annotate_options'
            st.experimental_rerun()

    # Step 3: select the two candidate options in each sentence.
    if st.session_state['page_status'] == 'annotate_options':
        sent_1 = st.session_state['sent_1']
        sent_2 = st.session_state['sent_2']
        show_instruction('3. Select options and click "Confirm"',fontsize=16)
        show_instruction('------------------------------',fontsize=32)
        annotate_options(1,sent_1)
        show_instruction('------------------------------',fontsize=32)
        annotate_options(2,sent_2)
        if st.button('Confirm',key='option'):
            st.session_state['page_status'] = 'analysis'
            st.experimental_rerun()

    # Step 4: score the options and map per-layer intervention effects.
    if st.session_state['page_status']=='analysis':
        with main_area.container():
            option_1_locs, option_2_locs = {}, {}
            pron_locs = {}
            input_ids_dict = {}
            masked_ids_option_1 = {}
            masked_ids_option_2 = {}
            for sent_id in [1,2]:
                sent_key = f'sent_{sent_id}'
                option_1_locs[sent_key], option_2_locs[sent_key] = separate_options(st.session_state[f'option_locs_{sent_id}'])
                pron_locs[sent_key] = st.session_state[f'mask_locs_{sent_id}']
                input_ids_dict[sent_key] = tokenizer(st.session_state[sent_key]).input_ids
                # Replace the masked span with one [MASK] per token of each option.
                masked_ids_option_1[sent_key] = mask_out(input_ids_dict[sent_key],
                                                         pron_locs[sent_key],
                                                         option_1_locs[sent_key],mask_id)
                masked_ids_option_2[sent_key] = mask_out(input_ids_dict[sent_key],
                                                         pron_locs[sent_key],
                                                         option_2_locs[sent_key],mask_id)

            # Annotation indices omit the leading special token, hence the +1.
            option_1_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_1_locs['sent_1'])+1]
            option_1_tokens_2 = np.array(input_ids_dict['sent_2'])[np.array(option_1_locs['sent_2'])+1]
            option_2_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_2_locs['sent_1'])+1]
            option_2_tokens_2 = np.array(input_ids_dict['sent_2'])[np.array(option_2_locs['sent_2'])+1]
            # Both sentences must name the same two options.
            assert np.all(option_1_tokens_1==option_1_tokens_2) and np.all(option_2_tokens_1==option_2_tokens_2)
            option_1_tokens = option_1_tokens_1
            option_2_tokens = option_2_tokens_1

            # Baseline run: no intervention in any layer.
            interventions = [{'lay':[],'qry':[],'key':[],'val':[]} for _ in range(num_layers)]
            probs_original = run_intervention(interventions,1,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
            # probs_original is indexed [option, sentence, batch row].
            df = pd.DataFrame(data=[[probs_original[0,0][0],probs_original[1,0][0]],
                                    [probs_original[0,1][0],probs_original[1,1][0]]],
                              columns=[tokenizer.decode(option_1_tokens),tokenizer.decode(option_2_tokens)],
                              index=['Sentence 1','Sentence 2'])
            cols = st.columns(3)
            with cols[1]:
                show_instruction('Probability of predicting each option in each sentence',fontsize=12)
                st.dataframe(df.style.highlight_max(axis=1),use_container_width=True)

            # Context tokens: positions where the two masked sentences differ.
            compare_1 = np.array(masked_ids_option_1['sent_1'])!=np.array(masked_ids_option_1['sent_2'])
            compare_2 = np.array(masked_ids_option_2['sent_1'])!=np.array(masked_ids_option_2['sent_2'])
            assert np.all(compare_1.astype(int)==compare_2.astype(int))
            context_locs = list(np.arange(len(masked_ids_option_1['sent_1']))[compare_1]-1) # match the indexing for annotation

            multihead = True
            # Intervention sites must line up between the two sentences.
            assert np.all(np.array(pron_locs['sent_1'])==np.array(pron_locs['sent_2']))
            assert np.all(np.array(option_1_locs['sent_1'])==np.array(option_1_locs['sent_2']))
            assert np.all(np.array(option_2_locs['sent_1'])==np.array(option_2_locs['sent_2']))
            # NOTE(review): pronoun/option indices are positions in the original
            # sentence, while context_locs are positions in the option-1 masked
            # sequence; they coincide only when the masked span and option 1 have
            # the same number of tokens.
            token_id_list = pron_locs['sent_1'] + option_1_locs['sent_1'] + option_2_locs['sent_1'] + context_locs

            batch_size = 1 if multihead else num_heads
            # effect_array[token][layer]: effect of intervening on that token at that layer.
            effect_array = []
            for token_id in token_id_list:
                # Shift the annotation index past the leading special token.
                input_pos = token_id + 1
                effect_list = []
                for layer_id in range(num_layers):
                    interventions = [create_interventions(input_pos,['lay','qry','key','val'],num_heads,multihead) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
                    probs = run_intervention(interventions,batch_size,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
                    # Mean of the drop in the [0,0] and [1,1] scores and the rise
                    # in the [0,1] and [1,0] scores relative to the baseline.
                    effect = ((probs_original-probs)[0,0] + (probs_original-probs)[1,1] + (probs-probs_original)[0,1] + (probs-probs_original)[1,0])/4
                    effect_list.append(effect)
                effect_array.append(effect_list)
            # (tokens, layers, batch) -> (layers, tokens, batch)
            effect_array = np.transpose(np.array(effect_array),(1,0,2))

            # One column per non-special token; intervened tokens get a per-layer heatmap.
            heat_cmap = sns.color_palette("light:r", as_cmap=True)
            vmin, vmax = effect_array[:,:,0].min(), effect_array[:,:,0].max()
            cols = st.columns(len(masked_ids_option_1['sent_1'])-2)
            for col_id,col in enumerate(cols):
                with col:
                    st.write(tokenizer.decode([masked_ids_option_1['sent_1'][col_id+1]]))
                    if col_id in token_id_list:
                        interv_id = token_id_list.index(col_id)
                        fig,ax = plt.subplots()
                        ax.set_box_aspect(num_layers)
                        ax.imshow(effect_array[:,interv_id:interv_id+1,0],cmap=heat_cmap,
                                  vmin=vmin,vmax=vmax)
                        ax.set_xticks([])
                        ax.set_xticklabels([])
                        ax.set_yticks([])
                        ax.set_yticklabels([])
                        for side in ['top','bottom','right','left']:
                            ax.spines[side].set_visible(False)
                        st.pyplot(fig)
                        # pyplot keeps every open figure alive; release it once rendered.
                        plt.close(fig)