import numpy as np
import pandas as pd
import time
import streamlit as st
import matplotlib.pyplot as plt
import seaborn as sns
import jax
import jax.numpy as jnp
import torch
import torch.nn.functional as F
from transformers import AlbertTokenizer, AlbertForMaskedLM
#from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLM
from skeleton_modeling_albert import SkeletonAlbertForMaskedLM


def wide_setup():
    """Widen the page body and hide table row indices by injecting CSS.

    The margin settings below are interpolated into the CSS; max width is in
    pixels and paddings are in rem.
    """
    max_width = 1500
    padding_top = 0
    padding_right = 2
    padding_bottom = 0
    padding_left = 2

    # Doubled braces ({{ }}) are literal CSS braces inside the f-string.
    define_margins = f"""
<style>
    .main .block-container {{
        max-width: {max_width}px;
        padding-top: {padding_top}rem;
        padding-right: {padding_right}rem;
        padding-left: {padding_left}rem;
        padding-bottom: {padding_bottom}rem;
    }}
</style>
"""
    hide_table_row_index = """
<style>
    thead tr th:first-child {display:none}
    tbody th {display:none}
</style>
"""
    st.markdown(define_margins, unsafe_allow_html=True)
    st.markdown(hide_table_row_index, unsafe_allow_html=True)


def load_css(file_name):
    """Inject the stylesheet stored in *file_name* into the page.

    Args:
        file_name: path to a CSS file, read as UTF-8.
    """
    with open(file_name, encoding='utf-8') as f:
        css = f.read()
    st.markdown(f'<style>{css}</style>', unsafe_allow_html=True)


@st.cache(show_spinner=True,allow_output_mutation=True)
def load_model():
    """Load the 'albert-xxlarge-v2' tokenizer and masked-LM model.

    Wrapped in st.cache so the model is loaded once and reused across reruns.

    Returns:
        (tokenizer, model) tuple.
    """
    tokenizer = AlbertTokenizer.from_pretrained('albert-xxlarge-v2')
    #model = CustomFlaxAlbertForMaskedLM.from_pretrained('albert-xxlarge-v2',from_pt=True)
    model = AlbertForMaskedLM.from_pretrained('albert-xxlarge-v2')
    return tokenizer,model


def clear_data():
    """Remove every entry from st.session_state, resetting the app's state.

    Used as the on_change callback of the sidebar sentence inputs.
    """
    # Snapshot the keys first: deleting from a mapping while iterating it is unsafe.
    for key in list(st.session_state.keys()):
        del st.session_state[key]


if __name__=='__main__':
    wide_setup()
    load_css('style.css')

    tokenizer,model = load_model()
    # Token id of '[MASK]': tokenize it and drop the leading/trailing special tokens.
    mask_id = tokenizer('[MASK]').input_ids[1:-1][0]

    main_area = st.empty()

    if 'page_status' not in st.session_state:
        st.session_state['page_status'] = 'type_in'

    # Page 1: enter the two sentences.
    if st.session_state['page_status']=='type_in':
        with main_area.container():
            st.write('1. Type in the sentences and click "Tokenize"')
            sent_1 = st.text_input('Sentence 1',value='It is better to play a prank on Samuel than Craig because he gets angry less often.')
            sent_2 = st.text_input('Sentence 2',value='It is better to play a prank on Samuel than Craig because he gets angry more often.')
            if st.button('Tokenize'):
                st.session_state['page_status'] = 'tokenized'
                st.session_state['sent_1'] = sent_1
                st.session_state['sent_2'] = sent_2
                # Clear page 1 so the 'tokenized' branch below renders in this same run.
                main_area.empty()

    # Page 2: click tokens of sentence 1 to toggle them as masked.
    # TODO: no "Confirm" button exists yet, nothing moves page_status to
    # 'analysis', and only sentence 1 is offered for masking.
    if st.session_state['page_status']=='tokenized':
        with main_area.container():
            sent_1 = st.session_state['sent_1']
            sent_2 = st.session_state['sent_2']
            if 'masked_pos_1' not in st.session_state:
                st.session_state['masked_pos_1'] = []
            if 'masked_pos_2' not in st.session_state:
                st.session_state['masked_pos_2'] = []

            st.write('2. Select sites to mask out and click "Confirm"')
            input_sent = tokenizer(sent_1).input_ids
            # Decode each token except the leading/trailing special tokens.
            decoded_sent = [tokenizer.decode([token]) for token in input_sent[1:-1]]
            # Relative column widths proportional to token length (+3 padding).
            char_nums = [len(word)+3 for word in decoded_sent]
            st.write(char_nums)
            cols = st.columns(char_nums)
            for word_id,(col,word) in enumerate(zip(cols,decoded_sent)):
                with col:
                    if st.button(word,key=f'word_{word_id}'):
                        # Each click toggles the token's masked state.
                        if word_id not in st.session_state['masked_pos_1']:
                            st.session_state['masked_pos_1'].append(word_id)
                        else:
                            st.session_state['masked_pos_1'].remove(word_id)
            masked_words = [decoded_sent[pos] for pos in sorted(st.session_state['masked_pos_1'])]
            st.write(f'Masked words: {", ".join(masked_words)}')

    # Page 3: run the model with an intervention and sample predictions.
    if st.session_state['page_status']=='analysis':
        sent_1 = st.sidebar.text_input('Sentence 1',value='It is better to play a prank on Samuel than Craig because he gets angry less often.',on_change=clear_data)
        sent_2 = st.sidebar.text_input('Sentence 2',value='It is better to play a prank on Samuel than Craig because he gets angry more often.',on_change=clear_data)
        input_ids_1 = tokenizer(sent_1).input_ids
        input_ids_2 = tokenizer(sent_2).input_ids
        # Both sentences are stacked into one rectangular batch, so their token
        # counts must match; fail with a readable message instead of a tensor error.
        if len(input_ids_1) != len(input_ids_2):
            st.error(
                'Sentence 1 and Sentence 2 must tokenize to the same number of tokens '
                f'(got {len(input_ids_1)} and {len(input_ids_2)}).'
            )
            st.stop()
        input_ids = torch.tensor([input_ids_1,input_ids_2])
        # Intervention spec is passed through to SkeletonAlbertForMaskedLM;
        # its format is defined in skeleton_modeling_albert (not visible here).
        outputs = SkeletonAlbertForMaskedLM(model,input_ids,interventions = {0:{'lay':[(8,1,[0,1])]}})
        logprobs = F.log_softmax(outputs['logits'], dim = -1)
        # Sample one token per position of the first batch item (sentence 1).
        preds = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0]]
        st.write([tokenizer.decode([token]) for token in preds])