import html

import numpy as np
#import pandas as pd
#import time
import streamlit as st
#import matplotlib.pyplot as plt
#import seaborn as sns
#import jax
#import jax.numpy as jnp
import torch
import torch.nn.functional as F
from transformers import AlbertTokenizer, AlbertForMaskedLM
#from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLM
from skeleton_modeling_albert import SkeletonAlbertForMaskedLM


def wide_setup():
    """Widen Streamlit's main content area and hide the row index of rendered tables.

    NOTE(review): the CSS inside both literals was missing from the reviewed copy
    of this file (the width/padding locals were computed but never interpolated).
    The markup below is a reconstruction; the container selector depends on the
    installed Streamlit version -- confirm.
    """
    max_width = 1500
    padding_top = 0
    padding_right = 2
    padding_bottom = 0
    padding_left = 2

    define_margins = f"""
<style>
    .appview-container .main .block-container {{
        max-width: {max_width}px;
        padding-top: {padding_top}rem;
        padding-right: {padding_right}rem;
        padding-left: {padding_left}rem;
        padding-bottom: {padding_bottom}rem;
    }}
</style>
"""
    hide_table_row_index = """
<style>
    tbody th {display:none}
    .blank {display:none}
</style>
"""
    st.markdown(define_margins, unsafe_allow_html=True)
    st.markdown(hide_table_row_index, unsafe_allow_html=True)


def load_css(file_name):
    """Inject the contents of the CSS file *file_name* into the page as a <style> block.

    NOTE(review): the <style> wrapper was missing from the reviewed copy (the file
    handle `f` was opened but never read); reconstructed.
    """
    with open(file_name, encoding='utf-8') as f:
        st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)


@st.cache(show_spinner=True, allow_output_mutation=True)
def load_model():
    """Return the ALBERT-base-v2 tokenizer and masked-LM model (memoized by ``st.cache``)."""
    tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
    #model = CustomFlaxAlbertForMaskedLM.from_pretrained('albert-xxlarge-v2',from_pt=True)
    model = AlbertForMaskedLM.from_pretrained('albert-base-v2')
    return tokenizer, model


def clear_data():
    """Delete every key from ``st.session_state``."""
    # Snapshot the keys first so deleting entries cannot disturb the iteration.
    for key in list(st.session_state.keys()):
        del st.session_state[key]


def _decode_tokens(sent):
    """Tokenize *sent* with the global ``tokenizer``; return one string per token, special tokens dropped."""
    input_ids = tokenizer(sent).input_ids
    # input_ids[0] / input_ids[-1] are the special tokens added by the tokenizer.
    return [tokenizer.decode([token]) for token in input_ids[1:-1]]


def _toggle_selection(decoded_sent, state_key, button_prefix):
    """Render one button per token; a click toggles that token's index in ``session_state[state_key]``.

    Returns the selected-index list stored in session state (mutated in place,
    kept in click order).
    """
    if state_key not in st.session_state:
        st.session_state[state_key] = []
    selected = st.session_state[state_key]
    if not decoded_sent:
        # st.columns cannot be called with an empty spec.
        return selected
    # Column widths proportional to the token length so the buttons read as a sentence.
    cols = st.columns([len(word) + 2 for word in decoded_sent])
    for word_id, (col, word) in enumerate(zip(cols, decoded_sent)):
        with col:
            if st.button(word, key=f'{button_prefix}_{word_id}'):
                if word_id in selected:
                    selected.remove(word_id)
                else:
                    selected.append(word_id)
    return selected


def annotate_mask(sent_id, sent):
    """Let the user toggle which tokens of *sent* to mask out.

    Stores the decoded tokens in ``session_state['decoded_sent_{sent_id}']`` and the
    selected token indices (counted without the leading special token) in
    ``session_state['mask_locs_{sent_id}']``, then renders the annotated sentence.
    """
    st.write(f'Sentence {sent_id}')
    decoded_sent = _decode_tokens(sent)
    st.session_state[f'decoded_sent_{sent_id}'] = decoded_sent
    mask_locs = _toggle_selection(decoded_sent, f'mask_locs_{sent_id}', f'word_mask_{sent_id}')
    show_annotated_sentence(decoded_sent, mask_locs=mask_locs)


def annotate_options(sent_id, sent):
    """Let the user toggle which tokens of *sent* form the two answer options.

    Selected indices are stored in ``session_state['option_locs_{sent_id}']``; the
    sentence is rendered with both the options and the previously chosen masks.
    """
    st.write(f'Sentence {sent_id}')
    decoded_sent = _decode_tokens(sent)
    option_locs = _toggle_selection(decoded_sent, f'option_locs_{sent_id}', f'word_option_{sent_id}')
    show_annotated_sentence(decoded_sent,
                            option_locs=option_locs,
                            mask_locs=st.session_state.get(f'mask_locs_{sent_id}', []))


def show_annotated_sentence(sent, option_locs=None, mask_locs=None):
    """Render the token list *sent* with masked and option tokens highlighted.

    Args:
        sent: list of token strings.
        option_locs: indices of tokens to highlight as answer options.
        mask_locs: indices of tokens to highlight as masked; takes precedence
            over ``option_locs`` when an index is in both.

    Returns:
        The value returned by ``st.markdown``.

    NOTE(review): the <p>/<span> markup was missing from the reviewed copy (all
    three branches were identical and ``disp_style`` was unused); the styling
    below is a reconstruction.
    """
    option_locs = set(option_locs or ())
    mask_locs = set(mask_locs or ())
    disp_style = '"font-family:sans-serif; color:Black; font-size: 20px"'
    prefix = f'<p style={disp_style}>'
    style_list = []
    for i, word in enumerate(sent):
        # Tokens come from user-typed text and are rendered as raw HTML: escape them.
        word = html.escape(word, quote=False)
        if i in mask_locs:
            style_list.append(f'<span style="color:White; background-color:Black">{word}</span>')
        elif i in option_locs:
            style_list.append(f'<span style="color:Red; font-weight:bold">{word}</span>')
        else:
            style_list.append(f'{word}')
    disp = ' '.join(style_list)
    suffix = '</p>'
    return st.markdown(prefix + disp + suffix, unsafe_allow_html=True)


def show_instruction(sent, fontsize=20):
    """Render the (trusted, developer-supplied) HTML/text *sent* as a paragraph of size *fontsize* px."""
    disp_style = f'"font-family:sans-serif; color:Black; font-size: {fontsize}px"'
    prefix = f'<p style={disp_style}>'
    suffix = '</p>'
    return st.markdown(prefix + sent + suffix, unsafe_allow_html=True)


def create_interventions(token_id, interv_types, num_heads):
    """Build the per-representation intervention spec for one layer.

    For each representation name in ('lay', 'qry', 'key', 'val') that appears in
    *interv_types*, the value is one tuple per head:
    ``(head_id, token_id, [head_id, head_id + num_heads])``. The pair lists batch
    rows ``head_id`` and ``head_id + num_heads``, i.e. the same replicate of
    sentence 1 and sentence 2 when ``run_intervention`` is called with
    ``batch_size == num_heads``. How the pair is used is decided by
    ``SkeletonAlbertForMaskedLM`` (presumably a swap between the two rows -- confirm).
    Representations not requested map to ``[]``.
    """
    interventions = {}
    for rep in ['lay', 'qry', 'key', 'val']:
        if rep in interv_types:
            interventions[rep] = [(head_id, token_id, [head_id, head_id + num_heads])
                                  for head_id in range(num_heads)]
        else:
            interventions[rep] = []
    return interventions


def separate_options(option_locs):
    """Split selected option token indices into the two contiguous answer options.

    Args:
        option_locs: token indices in any order (e.g. click order) that form
            exactly two contiguous runs separated by at least one token.

    Returns:
        ``(option_1_locs, option_2_locs)`` as sorted lists, first run first.

    Raises:
        ValueError: if the indices do not form exactly two contiguous runs.
    """
    # Sort (and de-duplicate) so selection order does not matter; after this,
    # any step other than the single gap is exactly 1, so each run is contiguous.
    option_locs = sorted(set(option_locs))
    gaps = np.flatnonzero(np.diff(option_locs) > 1)
    if len(gaps) != 1:
        raise ValueError('select exactly two options, each a contiguous run of tokens, '
                         'separated by at least one unselected token')
    sep = int(gaps[0]) + 1
    return option_locs[:sep], option_locs[sep:]


def mask_out(input_ids, pron_locs, option_locs, mask_id):
    """Replace the pronoun span with one ``mask_id`` per option token.

    Args:
        input_ids: full token id list including the leading special token.
        pron_locs: indices of the pronoun tokens, counted WITHOUT the leading
            special token (hence the +1 shift below); any order.
        option_locs: indices of one option's tokens; only its length is used.
        mask_id: id of the mask token.

    Raises:
        ValueError: if ``pron_locs`` is empty or not contiguous.
    """
    pron_locs = sorted(pron_locs)
    if not pron_locs:
        raise ValueError('select at least one token to mask out')
    if np.any(np.diff(pron_locs) != 1):
        raise ValueError('the masked tokens must be contiguous')
    # annotations are shifted by 1 because special tokens were omitted
    start = pron_locs[0] + 1
    end = pron_locs[-1] + 2
    return input_ids[:start] + [mask_id] * len(option_locs) + input_ids[end:]


def run_intervention(interventions, batch_size, model, masked_ids_option_1, masked_ids_option_2,
                     option_1_tokens, option_2_tokens, pron_locs):
    """Score both answer options in both sentences under the given interventions.

    For each option, the masked ids of sentence 1 and sentence 2 are each
    repeated ``batch_size`` times (rows ``[0, batch_size)`` are sentence 1,
    rows ``[batch_size, 2*batch_size)`` are sentence 2) and run through
    ``SkeletonAlbertForMaskedLM``. An option's score is ``exp`` of the mean
    log-probability of its tokens at the consecutive masked positions.

    Returns:
        np.ndarray of shape ``(2, 2, batch_size)`` indexed
        ``[option][sentence][batch row]``.
    """
    probs = []
    for masked_ids, option_tokens in ((masked_ids_option_1, option_1_tokens),
                                      (masked_ids_option_2, option_2_tokens)):
        # NOTE: torch.tensor needs both masked sentences to have the same number
        # of token ids; it raises otherwise (no padding is applied).
        input_ids = torch.tensor([masked_ids['sent_1']] * batch_size
                                 + [masked_ids['sent_2']] * batch_size)
        # Inference only: if the model output tracks gradients, `.numpy()` below
        # would raise, so disable autograd here.
        with torch.no_grad():
            outputs = SkeletonAlbertForMaskedLM(model, input_ids, interventions=interventions)
            logprobs = F.log_softmax(outputs['logits'], dim=-1)
        logprobs_1, logprobs_2 = logprobs[:batch_size], logprobs[batch_size:]
        # Masks start right after the first pronoun index (+1 for the leading special token).
        start_1 = min(pron_locs['sent_1']) + 1
        start_2 = min(pron_locs['sent_2']) + 1
        evals_1 = [logprobs_1[:, start_1 + i, int(token)].numpy() for i, token in enumerate(option_tokens)]
        evals_2 = [logprobs_2[:, start_2 + i, int(token)].numpy() for i, token in enumerate(option_tokens)]
        probs.append([np.exp(np.mean(evals_1, axis=0)), np.exp(np.mean(evals_2, axis=0))])
    probs = np.array(probs)
    assert probs.shape == (2, 2, batch_size)
    return probs


if __name__ == '__main__':
    wide_setup()
    #load_css('style.css')
    # tokenizer / model / mask_id are read by the annotate_* helpers and the
    # analysis page; without them the app fails with NameError at step 2.
    tokenizer, model = load_model()
    num_layers, num_heads = model.config.num_hidden_layers, model.config.num_attention_heads
    mask_id = tokenizer.mask_token_id

    main_area = st.empty()

    if 'page_status' not in st.session_state:
        st.session_state['page_status'] = 'type_in'

    if st.session_state['page_status'] == 'type_in':
        show_instruction('1. Type in the sentences and click "Tokenize"')
        sent_1 = st.text_input('Sentence 1', value='It is better to play a prank on Samuel than Craig because he gets angry less often.')
        sent_2 = st.text_input('Sentence 2', value='It is better to play a prank on Samuel than Craig because he gets angry more often.')
        if st.button('Tokenize'):
            st.session_state['page_status'] = 'annotate_mask'
            st.session_state['sent_1'] = sent_1
            st.session_state['sent_2'] = sent_2
            st.experimental_rerun()

    if st.session_state['page_status'] == 'annotate_mask':
        sent_1 = st.session_state['sent_1']
        sent_2 = st.session_state['sent_2']
        show_instruction('2. Select sites to mask out and click "Confirm"')
        annotate_mask(1, sent_1)
        annotate_mask(2, sent_2)
        if st.button('Confirm', key='mask'):
            st.session_state['page_status'] = 'annotate_options'
            st.experimental_rerun()

    if st.session_state['page_status'] == 'annotate_options':
        sent_1 = st.session_state['sent_1']
        sent_2 = st.session_state['sent_2']
        show_instruction('3. Select options and click "Confirm"')
        annotate_options(1, sent_1)
        annotate_options(2, sent_2)
        if st.button('Confirm', key='option'):
            st.session_state['page_status'] = 'analysis'
            st.experimental_rerun()

    if st.session_state['page_status'] == 'analysis':
        with main_area.container():
            for sent_id in (1, 2):
                show_annotated_sentence(st.session_state[f'decoded_sent_{sent_id}'],
                                        option_locs=st.session_state[f'option_locs_{sent_id}'],
                                        mask_locs=st.session_state[f'mask_locs_{sent_id}'])

            option_1_locs, option_2_locs = {}, {}
            pron_locs = {}
            input_ids_dict = {}
            masked_ids_option_1 = {}
            masked_ids_option_2 = {}
            try:
                for sent_id in (1, 2):
                    key = f'sent_{sent_id}'
                    option_1_locs[key], option_2_locs[key] = separate_options(st.session_state[f'option_locs_{sent_id}'])
                    # Selections are stored in click order; sort so [0] is the leftmost token.
                    pron_locs[key] = sorted(st.session_state[f'mask_locs_{sent_id}'])
                    input_ids_dict[key] = tokenizer(st.session_state[key]).input_ids
                    masked_ids_option_1[key] = mask_out(input_ids_dict[key], pron_locs[key], option_1_locs[key], mask_id)
                    masked_ids_option_2[key] = mask_out(input_ids_dict[key], pron_locs[key], option_2_locs[key], mask_id)
            except ValueError as err:
                st.error(f'Invalid annotation: {err}')
                st.stop()

            st.write(option_1_locs)
            st.write(option_2_locs)
            st.write(pron_locs)
            for token_ids in [masked_ids_option_1['sent_1'], masked_ids_option_1['sent_2'],
                              masked_ids_option_2['sent_1'], masked_ids_option_2['sent_2']]:
                st.write(' '.join([tokenizer.decode([token]) for token in token_ids]))

            # Option token ids (+1 for the leading special token); each option must be
            # the same token sequence in both sentences.
            option_tokens = []
            for locs in (option_1_locs, option_2_locs):
                tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(locs['sent_1']) + 1]
                tokens_2 = np.array(input_ids_dict['sent_2'])[np.array(locs['sent_2']) + 1]
                if not np.array_equal(tokens_1, tokens_2):
                    st.error('Each option must consist of the same tokens in both sentences.')
                    st.stop()
                option_tokens.append(tokens_1)
            option_1_tokens, option_2_tokens = option_tokens

            interventions = [{'lay': [], 'qry': [], 'key': [], 'val': []} for _ in range(num_layers)]
            probs_original = run_intervention(interventions, 1, model, masked_ids_option_1, masked_ids_option_2,
                                              option_1_tokens, option_2_tokens, pron_locs)
            st.write(probs_original)

            # NOTE(review): work in progress -- nothing in this file sets page_status to
            # 'finish_debug', so this per-layer sweep never runs. The token position 16
            # is hard-coded and `probs` is not used yet.
            if st.session_state['page_status'] == 'finish_debug':
                for layer_id in range(num_layers):
                    interventions = [create_interventions(16, ['lay', 'qry', 'key', 'val'], num_heads)
                                     if i == layer_id else {'lay': [], 'qry': [], 'key': [], 'val': []}
                                     for i in range(num_layers)]
                    probs = run_intervention(interventions, num_heads, model, masked_ids_option_1, masked_ids_option_2,
                                             option_1_tokens, option_2_tokens, pron_locs)