import numpy as np
import streamlit as st
import torch
import torch.nn.functional as F

from transformers import AlbertTokenizer, AlbertForMaskedLM
#from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLM
from skeleton_modeling_albert import SkeletonAlbertForMaskedLM
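
# Streamlit demo app: probe masked-token predictions of ALBERT (albert-xxlarge-v2)
# under per-head interventions. The user types two sentences, marks the tokens to
# mask out and the two candidate options, and the app tokenizes, masks, and (in a
# debug branch) runs the local SkeletonAlbertForMaskedLM wrapper.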
def wide_setup():
max_width = 1500
padding_top = 0
padding_right = 2
padding_bottom = 0
padding_left = 2
define_margins = f"""
<style>
.appview-container .main .block-container{{
max-width: {max_width}px;
padding-top: {padding_top}rem;
padding-right: {padding_right}rem;
padding-left: {padding_left}rem;
padding-bottom: {padding_bottom}rem;
}}
</style>
"""
hide_table_row_index = """
<style>
tbody th {display:none}
.blank {display:none}
</style>
"""
st.markdown(define_margins, unsafe_allow_html=True)
st.markdown(hide_table_row_index, unsafe_allow_html=True)
def load_css(file_name):
with open(file_name) as f:
st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
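
# Both helpers above style the page by injecting raw <style> blocks through
# st.markdown(..., unsafe_allow_html=True), a common Streamlit workaround for
# CSS that the widget API does not expose.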
@st.cache(show_spinner=True,allow_output_mutation=True)
def load_model():
tokenizer = AlbertTokenizer.from_pretrained('albert-xxlarge-v2')
#model = CustomFlaxAlbertForMaskedLM.from_pretrained('albert-xxlarge-v2',from_pt=True)
model = AlbertForMaskedLM.from_pretrained('albert-xxlarge-v2')
return tokenizer,model
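
# Usage sketch (st.cache keeps the tokenizer/model across reruns, so the
# checkpoint is only downloaded and loaded once per session):
#   tokenizer, model = load_model()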
def clear_data():
    # Iterate over a copied key list: deleting entries while iterating over
    # st.session_state itself would raise a RuntimeError.
    for key in list(st.session_state.keys()):
        del st.session_state[key]
def annotate_mask(sent_id,sent):
st.write(f'Sentence {sent_id}')
input_sent = tokenizer(sent).input_ids
decoded_sent = [tokenizer.decode([token]) for token in input_sent[1:-1]]
st.session_state[f'decoded_sent_{sent_id}'] = decoded_sent
char_nums = [len(word)+2 for word in decoded_sent]
cols = st.columns(char_nums)
if f'mask_locs_{sent_id}' not in st.session_state:
st.session_state[f'mask_locs_{sent_id}'] = []
for word_id,(col,word) in enumerate(zip(cols,decoded_sent)):
with col:
if st.button(word,key=f'word_mask_{sent_id}_{word_id}'):
if word_id not in st.session_state[f'mask_locs_{sent_id}']:
st.session_state[f'mask_locs_{sent_id}'].append(word_id)
else:
st.session_state[f'mask_locs_{sent_id}'].remove(word_id)
show_annotated_sentence(decoded_sent,
mask_locs=st.session_state[f'mask_locs_{sent_id}'])
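
# Each token is rendered as its own button; clicking a button toggles that
# token's index in st.session_state['mask_locs_<sent_id>'], and the running
# selection is re-displayed with the masked tokens in red.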
def annotate_options(sent_id,sent):
st.write(f'Sentence {sent_id}')
input_sent = tokenizer(sent).input_ids
decoded_sent = [tokenizer.decode([token]) for token in input_sent[1:-1]]
char_nums = [len(word)+2 for word in decoded_sent]
cols = st.columns(char_nums)
if f'option_locs_{sent_id}' not in st.session_state:
st.session_state[f'option_locs_{sent_id}'] = []
for word_id,(col,word) in enumerate(zip(cols,decoded_sent)):
with col:
if st.button(word,key=f'word_option_{sent_id}_{word_id}'):
if word_id not in st.session_state[f'option_locs_{sent_id}']:
st.session_state[f'option_locs_{sent_id}'].append(word_id)
else:
st.session_state[f'option_locs_{sent_id}'].remove(word_id)
show_annotated_sentence(decoded_sent,
option_locs=st.session_state[f'option_locs_{sent_id}'],
mask_locs=st.session_state[f'mask_locs_{sent_id}'])
def show_annotated_sentence(sent,option_locs=[],mask_locs=[]):
disp_style = '"font-family:san serif; color:Black; font-size: 20px"'
prefix = f'<p style={disp_style}><span style="font-weight:bold">'
style_list = []
for i, word in enumerate(sent):
if i in mask_locs:
style_list.append(f'<span style="color:Red">{word}</span>')
elif i in option_locs:
style_list.append(f'<span style="color:Blue">{word}</span>')
else:
style_list.append(f'{word}')
disp = ' '.join(style_list)
suffix = '</span></p>'
return st.markdown(prefix + disp + suffix, unsafe_allow_html = True)
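
# Example: show_annotated_sentence(['he', 'gets', 'angry'], mask_locs=[0])
# renders "he" in red and the remaining tokens in plain bold black.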
def show_instruction(sent,fontsize=20):
disp_style = f'"font-family:san serif; color:Black; font-size: {fontsize}px"'
prefix = f'<p style={disp_style}><span style="font-weight:bold">'
suffix = '</span></p>'
return st.markdown(prefix + sent + suffix, unsafe_allow_html = True)
def create_interventions(token_id,interv_type,num_layers,num_heads):
interventions = {}
for layer_id in range(num_layers):
interventions[layer_id] = {}
if interv_type == 'all':
for rep in ['lay','qry','key','val']:
interventions[layer_id][rep] = [(head_id,token_id,[head_id,head_id+num_heads]) for head_id in range(num_heads)]
else:
interventions[layer_id][interv_type] = [(head_id,token_id,[head_id,head_id+num_heads]) for head_id in range(num_heads)]
return interventions
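
# Example (output shape derived from the code above; the triple format is
# whatever skeleton_modeling_albert expects):
#   create_interventions(16, 'qry', num_layers=1, num_heads=2)
#   -> {0: {'qry': [(0, 16, [0, 2]), (1, 16, [1, 3])]}}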
def separate_options(option_locs):
assert np.sum(np.diff(option_locs)>1)==1
sep = list(np.diff(option_locs)>1).index(1)+1
option_1_locs, option_2_locs = option_locs[:sep], option_locs[sep:]
    assert np.all(np.diff(option_1_locs)==1) and np.all(np.diff(option_2_locs)==1)
return option_1_locs, option_2_locs
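
# Example: the selected option indices must form exactly two contiguous runs,
# e.g. separate_options([7, 8, 10, 11]) -> ([7, 8], [10, 11]).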
def mask_out(input_ids,pron_locs,option_locs,mask_id):
assert np.all(np.diff(pron_locs)==1)
return input_ids[:pron_locs[0]] + [mask_id for _ in range(len(option_locs))] + input_ids[pron_locs[-1]+1:]
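
# Example: replace a one-token pronoun at position 2 with as many [MASK]s as
# the option has tokens (only len(option_locs) matters), e.g. with mask_id=4:
#   mask_out([2, 19, 24, 21, 3], pron_locs=[2], option_locs=[5, 6], mask_id=4)
#   -> [2, 19, 4, 4, 21, 3]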
if __name__=='__main__':
wide_setup()
load_css('style.css')
tokenizer,model = load_model()
num_layers, num_heads = 12, 64
mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
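    # tokenizer('[MASK]').input_ids is [CLS, mask, SEP]; stripping the two
    # special tokens leaves the mask token id (4 in the ALBERT vocabulary).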
main_area = st.empty()
if 'page_status' not in st.session_state:
st.session_state['page_status'] = 'type_in'
if st.session_state['page_status']=='type_in':
show_instruction('1. Type in the sentences and click "Tokenize"')
sent_1 = st.text_input('Sentence 1',value='It is better to play a prank on Samuel than Craig because he gets angry less often.')
sent_2 = st.text_input('Sentence 2',value='It is better to play a prank on Samuel than Craig because he gets angry more often.')
if st.button('Tokenize'):
st.session_state['page_status'] = 'annotate_mask'
st.session_state['sent_1'] = sent_1
st.session_state['sent_2'] = sent_2
st.experimental_rerun()
if st.session_state['page_status']=='annotate_mask':
sent_1 = st.session_state['sent_1']
sent_2 = st.session_state['sent_2']
show_instruction('2. Select sites to mask out and click "Confirm"')
annotate_mask(1,sent_1)
annotate_mask(2,sent_2)
if st.button('Confirm',key='mask'):
st.session_state['page_status'] = 'annotate_options'
st.experimental_rerun()
if st.session_state['page_status'] == 'annotate_options':
sent_1 = st.session_state['sent_1']
sent_2 = st.session_state['sent_2']
show_instruction('3. Select options and click "Confirm"')
annotate_options(1,sent_1)
annotate_options(2,sent_2)
if st.button('Confirm',key='option'):
st.session_state['page_status'] = 'analysis'
st.experimental_rerun()
if st.session_state['page_status']=='analysis':
with main_area.container():
sent_1 = st.session_state['sent_1']
sent_2 = st.session_state['sent_2']
show_annotated_sentence(st.session_state['decoded_sent_1'],
option_locs=st.session_state['option_locs_1'],
mask_locs=st.session_state['mask_locs_1'])
show_annotated_sentence(st.session_state['decoded_sent_2'],
option_locs=st.session_state['option_locs_2'],
mask_locs=st.session_state['mask_locs_2'])
option_1_locs, option_2_locs = {}, {}
            pron_locs = {}
input_ids_dict = {}
masked_ids_option_1 = {}
masked_ids_option_2 = {}
for sent_id in range(2):
                option_1_locs[f'sent_{sent_id+1}'], option_2_locs[f'sent_{sent_id+1}'] = separate_options(st.session_state[f'option_locs_{sent_id+1}'])
pron_locs[f'sent_{sent_id+1}'] = st.session_state[f'mask_locs_{sent_id+1}']
input_ids_dict[f'sent_{sent_id+1}'] = tokenizer(st.session_state[f'sent_{sent_id+1}']).input_ids
masked_ids_option_1[f'sent_{sent_id+1}'] = mask_out(input_ids_dict[f'sent_{sent_id+1}'],
pron_locs[f'sent_{sent_id+1}'],
option_1_locs[f'sent_{sent_id+1}'],mask_id)
masked_ids_option_2[f'sent_{sent_id+1}'] = mask_out(input_ids_dict[f'sent_{sent_id+1}'],
pron_locs[f'sent_{sent_id+1}'],
option_2_locs[f'sent_{sent_id+1}'],mask_id)
for token_ids in [masked_ids_option_1['sent_1'],masked_ids_option_1['sent_2'],masked_ids_option_2['sent_1'],masked_ids_option_2['sent_2']]:
                st.write(' '.join([tokenizer.decode([token]) for token in token_ids]))
    if st.session_state['page_status'] == 'finish_debug':
        # Debug branch (no step above sets 'finish_debug'). Assumption: the
        # inputs to compare are the option-1 masked ids from the analysis step.
        input_ids_1 = masked_ids_option_1['sent_1']
        input_ids_2 = masked_ids_option_1['sent_2']
        try:
            assert len(input_ids_1) == len(input_ids_2)
        except AssertionError:
            show_instruction('Please make sure the number of tokens match between Sentence 1 and Sentence 2',fontsize=12)
            st.stop()  # do not build a ragged batch from mismatched inputs
input_ids = torch.tensor([*[input_ids_1 for _ in range(num_heads)],*[input_ids_2 for _ in range(num_heads)]])
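        # Token position 16 and intervention type 'all' are hard-coded for this
        # debug run: every head's lay/qry/key/val representation is intervened
        # on at every layer.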
interventions = create_interventions(16,'all',num_layers=num_layers,num_heads=num_heads)
outputs = SkeletonAlbertForMaskedLM(model,input_ids,interventions=interventions)
logprobs = F.log_softmax(outputs['logits'], dim = -1)
preds_0 = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0][1:-1]]
preds_1 = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[1][1:-1]]
st.write([tokenizer.decode([token]) for token in preds_0])
st.write([tokenizer.decode([token]) for token in preds_1])