taka-yamakoshi
bring back some
7c56f41
raw history blame
No virus
12.4 kB
import numpy as np
#import pandas as pd
#import time
import streamlit as st
#import matplotlib.pyplot as plt
#import seaborn as sns
#import jax
#import jax.numpy as jnp
import torch
import torch.nn.functional as F
from transformers import AlbertTokenizer, AlbertForMaskedLM
#from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLM
from skeleton_modeling_albert import SkeletonAlbertForMaskedLM
def wide_setup():
    """Widen the page layout and hide table row headers.

    Injects two raw ``<style>`` blocks through ``st.markdown`` with HTML
    enabled:

    * the first sets the maximum width (1500px) and the padding of the main
      block container;
    * the second hides ``tbody th`` and ``.blank`` cells (the table row index,
      as the variable name indicates).
    """
    max_width = 1500
    padding_top, padding_right, padding_bottom, padding_left = 0, 2, 0, 2
    define_margins = f"""
<style>
.appview-container .main .block-container{{
max-width: {max_width}px;
padding-top: {padding_top}rem;
padding-right: {padding_right}rem;
padding-left: {padding_left}rem;
padding-bottom: {padding_bottom}rem;
}}
</style>
"""
    hide_table_row_index = """
<style>
tbody th {display:none}
.blank {display:none}
</style>
"""
    for css_block in (define_margins, hide_table_row_index):
        st.markdown(css_block, unsafe_allow_html=True)
def load_css(file_name):
    """Inject the contents of the CSS file ``file_name`` as a ``<style>`` block."""
    with open(file_name) as css_file:
        css_text = css_file.read()
    st.markdown(f'<style>{css_text}</style>', unsafe_allow_html=True)
@st.cache(show_spinner=True,allow_output_mutation=True)
def load_model():
    """Load the ``albert-base-v2`` tokenizer and masked-LM model.

    Wrapped in ``st.cache`` so the (tokenizer, model) pair is reused across
    script reruns instead of being reloaded every time.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    checkpoint = 'albert-base-v2'
    albert_tokenizer = AlbertTokenizer.from_pretrained(checkpoint)
    albert_model = AlbertForMaskedLM.from_pretrained(checkpoint)
    return albert_tokenizer, albert_model
def clear_data(state=None):
    """Remove every key from the session state.

    Args:
        state: mutable mapping to clear. Defaults to ``st.session_state``;
            any other mapping (e.g. a plain dict) may be passed instead.
    """
    if state is None:
        state = st.session_state
    # Snapshot the keys before deleting: removing entries while iterating the
    # live mapping is not allowed in general (a dict raises RuntimeError).
    for key in list(state.keys()):
        del state[key]
def annotate_mask(sent_id,sent):
    """Let the user toggle mask sites in sentence ``sent_id`` via token buttons.

    Steps:

    * Tokenize ``sent`` with the module-level ``tokenizer`` (assigned in the
      ``__main__`` block) and drop the first and last ids (special tokens).
    * Store the decoded tokens in
      ``st.session_state[f'decoded_sent_{sent_id}']``.
    * Render one button per token; clicking a button toggles that token's
      index in ``st.session_state[f'mask_locs_{sent_id}']``.
    * Re-render the sentence with the selected sites highlighted.

    The selection list is kept in ascending order regardless of click order.
    """
    st.write(f'Sentence {sent_id}')
    input_sent = tokenizer(sent).input_ids
    decoded_sent = [tokenizer.decode([token]) for token in input_sent[1:-1]]
    st.session_state[f'decoded_sent_{sent_id}'] = decoded_sent

    locs_key = f'mask_locs_{sent_id}'
    if locs_key not in st.session_state:
        st.session_state[locs_key] = []
    mask_locs = st.session_state[locs_key]

    # st.columns rejects an empty width list, so an empty sentence gets no
    # button row instead of crashing the page.
    if decoded_sent:
        # Column widths proportional to each token's length.
        char_nums = [len(word)+2 for word in decoded_sent]
        cols = st.columns(char_nums)
        for word_id,(col,word) in enumerate(zip(cols,decoded_sent)):
            with col:
                if st.button(word,key=f'word_mask_{sent_id}_{word_id}'):
                    if word_id in mask_locs:
                        mask_locs.remove(word_id)
                    else:
                        mask_locs.append(word_id)
                        # Downstream masking takes the first entry as the
                        # span start and checks contiguity with np.diff, both
                        # of which assume ascending order.
                        mask_locs.sort()
    show_annotated_sentence(decoded_sent, mask_locs=mask_locs)
def annotate_options(sent_id,sent):
    """Let the user toggle the two option spans in sentence ``sent_id``.

    Steps:

    * Tokenize ``sent`` with the module-level ``tokenizer`` (assigned in the
      ``__main__`` block) and drop the first and last ids (special tokens).
    * Render one button per token; clicking a button toggles that token's
      index in ``st.session_state[f'option_locs_{sent_id}']``.
    * Re-render the sentence, showing the options together with any mask
      sites already stored under ``mask_locs_{sent_id}``.

    The selection list is kept in ascending order regardless of click order,
    because the option-splitting step looks for the single gap between two
    ascending runs.
    """
    st.write(f'Sentence {sent_id}')
    input_sent = tokenizer(sent).input_ids
    decoded_sent = [tokenizer.decode([token]) for token in input_sent[1:-1]]

    locs_key = f'option_locs_{sent_id}'
    if locs_key not in st.session_state:
        st.session_state[locs_key] = []
    option_locs = st.session_state[locs_key]

    # st.columns rejects an empty width list; skip the button row for an
    # empty sentence.
    if decoded_sent:
        # Column widths proportional to each token's length.
        char_nums = [len(word)+2 for word in decoded_sent]
        cols = st.columns(char_nums)
        for word_id,(col,word) in enumerate(zip(cols,decoded_sent)):
            with col:
                if st.button(word,key=f'word_option_{sent_id}_{word_id}'):
                    if word_id in option_locs:
                        option_locs.remove(word_id)
                    else:
                        option_locs.append(word_id)
                        option_locs.sort()
    show_annotated_sentence(decoded_sent,
                            option_locs=option_locs,
                            mask_locs=st.session_state.get(f'mask_locs_{sent_id}', []))
def show_annotated_sentence(sent,option_locs=None,mask_locs=None):
    """Render a token sequence as one bold HTML paragraph with highlights.

    Tokens whose index is in ``mask_locs`` are coloured red. Otherwise, tokens
    whose index is in ``option_locs`` are coloured blue. Tokens are joined
    with single spaces.

    Args:
        sent: sequence of token strings to display.
        option_locs: token indices to colour blue (default: none).
        mask_locs: token indices to colour red. These take precedence over
            ``option_locs`` (default: none).

    Returns:
        The value returned by ``st.markdown``.
    """
    import html

    # None defaults instead of shared mutable [] defaults. Sets give O(1)
    # membership per token.
    mask_set = set(mask_locs) if mask_locs is not None else set()
    option_set = set(option_locs) if option_locs is not None else set()

    # "sans-serif" is the CSS generic family; "san serif" named no real font,
    # so the browser fell back to its default face.
    disp_style = '"font-family:sans-serif; color:Black; font-size: 20px"'
    prefix = f'<p style={disp_style}><span style="font-weight:bold">'
    suffix = '</span></p>'
    style_list = []
    for i, word in enumerate(sent):
        # Tokens come from user-typed sentences and are rendered with
        # unsafe_allow_html, so escape them to keep "<", "&" and quotes literal.
        word = html.escape(word)
        if i in mask_set:
            style_list.append(f'<span style="color:Red">{word}</span>')
        elif i in option_set:
            style_list.append(f'<span style="color:Blue">{word}</span>')
        else:
            style_list.append(word)
    disp = ' '.join(style_list)
    return st.markdown(prefix + disp + suffix, unsafe_allow_html = True)
def show_instruction(sent,fontsize=20):
    """Render ``sent`` as a bold paragraph at ``fontsize`` px using raw HTML.

    ``sent`` is inserted without escaping, so it may contain markup. Only pass
    trusted, developer-written text.

    Returns:
        The value returned by ``st.markdown``.
    """
    # "sans-serif" is the CSS generic family name; "san serif" matched no font.
    disp_style = f'"font-family:sans-serif; color:Black; font-size: {fontsize}px"'
    prefix = f'<p style={disp_style}><span style="font-weight:bold">'
    suffix = '</span></p>'
    return st.markdown(prefix + sent + suffix, unsafe_allow_html = True)
def create_interventions(token_id,interv_types,num_heads):
    """Build the per-representation intervention spec for one layer.

    Args:
        token_id: position to intervene on, copied into every entry.
        interv_types: collection tested with ``in`` for each of
            ``'lay'``, ``'qry'``, ``'key'``, ``'val'``.
        num_heads: number of attention heads.

    Returns:
        dict: keys ``'lay'``, ``'qry'``, ``'key'``, ``'val'``, in that order.
        A key listed in ``interv_types`` maps to a fresh list of
        ``(head_id, token_id, [head_id, head_id + num_heads])``, one entry per
        head. Every other key maps to an empty list.
    """
    def per_head_entries():
        return [(head, token_id, [head, head + num_heads]) for head in range(num_heads)]

    return {
        rep: (per_head_entries() if rep in interv_types else [])
        for rep in ('lay', 'qry', 'key', 'val')
    }
def separate_options(option_locs):
    """Split the annotated option positions into the two option spans.

    The positions are sorted first, so click order does not matter. They must
    form exactly two contiguous runs separated by a gap.

    Args:
        option_locs: iterable of int token positions covering both options.

    Returns:
        tuple[list[int], list[int]]: the positions of the first (left-most)
        option and of the second option, each ascending.

    Raises:
        ValueError: if the positions do not form exactly two runs, or a run is
            not contiguous (e.g. contains a duplicate position).
    """
    locs = sorted(option_locs)
    # Indices where the next position jumps by more than one: the span break.
    breaks = np.flatnonzero(np.diff(locs) > 1)
    if len(breaks) != 1:
        raise ValueError(
            f'option locations {locs} must form exactly two separated spans')
    sep = int(breaks[0]) + 1
    option_1_locs, option_2_locs = locs[:sep], locs[sep:]
    for group in (option_1_locs, option_2_locs):
        if len(group) > 1 and not np.all(np.diff(group) == 1):
            raise ValueError(f'option span {group} is not contiguous')
    return option_1_locs, option_2_locs
def mask_out(input_ids,pron_locs,option_locs,mask_id):
    """Replace the annotated pronoun span with one ``mask_id`` per option token.

    Args:
        input_ids: token ids including the leading and trailing special tokens.
        pron_locs: annotated positions of the span to replace. They index the
            sequence *without* special tokens, hence the +1 shift below. They
            must form one contiguous run; order does not matter.
        option_locs: positions of the option. Only ``len(option_locs)`` is
            used, as the number of mask ids to insert.
        mask_id: id inserted for each option token.

    Returns:
        list: a new id list; ``input_ids`` is not modified.

    Raises:
        ValueError: if ``pron_locs`` is empty or not contiguous.
    """
    locs = sorted(pron_locs)
    if not locs:
        raise ValueError('pron_locs must contain at least one position')
    if len(locs) > 1 and not np.all(np.diff(locs) == 1):
        raise ValueError(f'pron_locs {locs} must be contiguous')
    # Annotations are shifted by 1 because special tokens were omitted.
    start = locs[0] + 1
    stop = locs[-1] + 2
    return list(input_ids[:start]) + [mask_id] * len(option_locs) + list(input_ids[stop:])
def run_intervention(interventions,batch_size,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs):
    """Score both options at the masked site of both sentences.

    For each option, the masked ``'sent_1'`` and ``'sent_2'`` id lists are
    each repeated ``batch_size`` times and passed, together with
    ``interventions``, to ``SkeletonAlbertForMaskedLM``. The option's score
    is ``exp(mean log-prob)`` of its tokens at the consecutive masked
    positions. These start one past the smallest annotated pronoun location;
    the +1 accounts for the leading special token omitted during annotation.

    Args:
        interventions: passed unchanged to ``SkeletonAlbertForMaskedLM``.
        batch_size: number of copies of each sentence in the batch.
        model: the masked-LM model passed to ``SkeletonAlbertForMaskedLM``.
        masked_ids_option_1, masked_ids_option_2: dicts with keys
            ``'sent_1'``/``'sent_2'`` holding masked id lists.
        option_1_tokens, option_2_tokens: token ids of each option.
        pron_locs: dict ``'sent_1'``/``'sent_2'`` -> annotated mask positions.

    Returns:
        np.ndarray of shape ``(2, 2, batch_size)`` indexed
        ``[option][sentence][batch]``.

    Raises:
        ValueError: if the two masked sentences differ in length, since they
            are stacked into a single un-padded tensor.
    """
    probs = []
    for masked_ids, option_tokens in zip([masked_ids_option_1, masked_ids_option_2],[option_1_tokens,option_2_tokens]):
        if len(masked_ids['sent_1']) != len(masked_ids['sent_2']):
            raise ValueError(
                'masked sentences must have the same number of tokens, got '
                f"{len(masked_ids['sent_1'])} and {len(masked_ids['sent_2'])}")
        input_ids = torch.tensor(
            [masked_ids['sent_1']] * batch_size + [masked_ids['sent_2']] * batch_size
        )
        # Inference only. Model parameters require grad by default, so if the
        # forward pass built an autograd graph, ``.numpy()`` below would raise.
        with torch.no_grad():
            outputs = SkeletonAlbertForMaskedLM(model,input_ids,interventions=interventions)
            logprobs = F.log_softmax(outputs['logits'], dim = -1)
        logprobs_1, logprobs_2 = logprobs[:batch_size], logprobs[batch_size:]
        # min() makes the span start independent of the selection order.
        start_1 = min(pron_locs['sent_1']) + 1
        start_2 = min(pron_locs['sent_2']) + 1
        evals_1 = [logprobs_1[:,start_1+i,token].numpy() for i,token in enumerate(option_tokens)]
        evals_2 = [logprobs_2[:,start_2+i,token].numpy() for i,token in enumerate(option_tokens)]
        probs.append([np.exp(np.mean(evals_1,axis=0)),np.exp(np.mean(evals_2,axis=0))])
    probs = np.array(probs)
    assert probs.shape == (2, 2, batch_size)
    return probs
if __name__=='__main__':
wide_setup()
load_css('style.css')
tokenizer,model = load_model()
num_layers, num_heads = model.config.num_hidden_layers, model.config.num_attention_heads
st.write(num_layers,num_heads)
mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
main_area = st.empty()
if 'page_status' not in st.session_state:
st.session_state['page_status'] = 'type_in'
if st.session_state['page_status']=='type_in':
show_instruction('1. Type in the sentences and click "Tokenize"')
sent_1 = st.text_input('Sentence 1',value='It is better to play a prank on Samuel than Craig because he gets angry less often.')
sent_2 = st.text_input('Sentence 2',value='It is better to play a prank on Samuel than Craig because he gets angry more often.')
if st.button('Tokenize'):
st.session_state['page_status'] = 'annotate_mask'
st.session_state['sent_1'] = sent_1
st.session_state['sent_2'] = sent_2
st.experimental_rerun()
if st.session_state['page_status']=='annotate_mask':
sent_1 = st.session_state['sent_1']
sent_2 = st.session_state['sent_2']
show_instruction('2. Select sites to mask out and click "Confirm"')
annotate_mask(1,sent_1)
annotate_mask(2,sent_2)
if st.button('Confirm',key='mask'):
st.session_state['page_status'] = 'annotate_options'
st.experimental_rerun()
if st.session_state['page_status'] == 'annotate_options':
sent_1 = st.session_state['sent_1']
sent_2 = st.session_state['sent_2']
show_instruction('3. Select options and click "Confirm"')
annotate_options(1,sent_1)
annotate_options(2,sent_2)
if st.button('Confirm',key='option'):
st.session_state['page_status'] = 'analysis'
st.experimental_rerun()
if st.session_state['page_status']=='analysis':
with main_area.container():
sent_1 = st.session_state['sent_1']
sent_2 = st.session_state['sent_2']
show_annotated_sentence(st.session_state['decoded_sent_1'],
option_locs=st.session_state['option_locs_1'],
mask_locs=st.session_state['mask_locs_1'])
show_annotated_sentence(st.session_state['decoded_sent_2'],
option_locs=st.session_state['option_locs_2'],
mask_locs=st.session_state['mask_locs_2'])
option_1_locs, option_2_locs = {}, {}
pron_locs = {}
input_ids_dict = {}
masked_ids_option_1 = {}
masked_ids_option_2 = {}
for sent_id in [1,2]:
option_1_locs[f'sent_{sent_id}'], option_2_locs[f'sent_{sent_id}'] = separate_options(st.session_state[f'option_locs_{sent_id}'])
pron_locs[f'sent_{sent_id}'] = st.session_state[f'mask_locs_{sent_id}']
input_ids_dict[f'sent_{sent_id}'] = tokenizer(st.session_state[f'sent_{sent_id}']).input_ids
masked_ids_option_1[f'sent_{sent_id}'] = mask_out(input_ids_dict[f'sent_{sent_id}'],
pron_locs[f'sent_{sent_id}'],
option_1_locs[f'sent_{sent_id}'],mask_id)
masked_ids_option_2[f'sent_{sent_id}'] = mask_out(input_ids_dict[f'sent_{sent_id}'],
pron_locs[f'sent_{sent_id}'],
option_2_locs[f'sent_{sent_id}'],mask_id)
st.write(option_1_locs)
st.write(option_2_locs)
st.write(pron_locs)
for token_ids in [masked_ids_option_1['sent_1'],masked_ids_option_1['sent_2'],masked_ids_option_2['sent_1'],masked_ids_option_2['sent_2']]:
st.write(' '.join([tokenizer.decode([token]) for token in token_ids]))
option_1_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_1_locs['sent_1'])+1]
option_1_tokens_2 = np.array(input_ids_dict['sent_2'])[np.array(option_1_locs['sent_2'])+1]
option_2_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_2_locs['sent_1'])+1]
option_2_tokens_2 = np.array(input_ids_dict['sent_2'])[np.array(option_2_locs['sent_2'])+1]
assert np.all(option_1_tokens_1==option_1_tokens_2) and np.all(option_2_tokens_1==option_2_tokens_2)
option_1_tokens = option_1_tokens_1
option_2_tokens = option_2_tokens_1
interventions = [{'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
probs_original = run_intervention(interventions,1,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
st.write(probs_original)
if st.session_state['page_status'] == 'finish_debug':
for layer_id in range(num_layers):
interventions = [create_interventions(16,['lay','qry','key','val'],num_heads) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
probs = run_intervention(interventions,num_heads,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)