import numpy as np
import pandas as pd
import time
import streamlit as st
import matplotlib.pyplot as plt
import seaborn as sns

import jax
import jax.numpy as jnp

import torch
import torch.nn.functional as F

from transformers import AlbertTokenizer, AlbertForMaskedLM

from skeleton_modeling_albert import SkeletonAlbertForMaskedLM

def wide_setup():
    # Widen the Streamlit page and trim the default padding by injecting CSS.
    max_width = 1500
    padding_top = 0
    padding_right = 2
    padding_bottom = 0
    padding_left = 2

    define_margins = f"""
    <style>
        .appview-container .main .block-container{{
            max-width: {max_width}px;
            padding-top: {padding_top}rem;
            padding-right: {padding_right}rem;
            padding-left: {padding_left}rem;
            padding-bottom: {padding_bottom}rem;
        }}
    </style>
    """
    # Hide the index column that Streamlit tables render by default.
    hide_table_row_index = """
    <style>
        tbody th {display:none}
        .blank {display:none}
    </style>
    """
    st.markdown(define_margins, unsafe_allow_html=True)
    st.markdown(hide_table_row_index, unsafe_allow_html=True)

@st.cache(show_spinner=True, allow_output_mutation=True)
def load_model():
    # Cache the tokenizer and model so they are loaded only once per session.
    tokenizer = AlbertTokenizer.from_pretrained('albert-xxlarge-v2')
    model = AlbertForMaskedLM.from_pretrained('albert-xxlarge-v2')
    return tokenizer, model

def clear_data():
    # Wipe all session state when the input sentences change.
    # Iterate over a copy of the keys so the dict is not mutated while iterating.
    for key in list(st.session_state.keys()):
        del st.session_state[key]

if __name__ == '__main__':
    wide_setup()

    if 'page_status' not in st.session_state:
        st.session_state['page_status'] = 'type_in'

    if st.session_state['page_status'] == 'type_in':
        tokenizer, model = load_model()
        # The tokenizer wraps its input in [CLS] ... [SEP]; slice them off to get the bare [MASK] id.
        mask_id = tokenizer('[MASK]').input_ids[1:-1][0]

        st.write('1. Type in the sentences and click "Tokenize"')
        sent_1 = st.text_input('Sentence 1', value='It is better to play a prank on Samuel than Craig because he gets angry less often.')
        sent_2 = st.text_input('Sentence 2', value='It is better to play a prank on Samuel than Craig because he gets angry more often.')
        if st.button('Tokenize'):
            st.session_state['page_status'] = 'tokenized'
            st.session_state['sent_1'] = sent_1
            st.session_state['sent_2'] = sent_2
            st.experimental_rerun()

    if st.session_state['page_status'] == 'tokenized':
        tokenizer, model = load_model()
        mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
        sent_1 = st.session_state['sent_1']
        sent_2 = st.session_state['sent_2']
        if 'masked_pos_1' not in st.session_state:
            st.session_state['masked_pos_1'] = []
        if 'masked_pos_2' not in st.session_state:
            st.session_state['masked_pos_2'] = []

        st.write('2. Select sites to mask out and click "Confirm"')
        input_sent = tokenizer(sent_1).input_ids
        decoded_sent = [tokenizer.decode([token]) for token in input_sent]
        # Size each column roughly to the length of the token it displays.
        char_nums = [len(word) + 2 for word in decoded_sent]
        cols = st.columns(char_nums)
        # [CLS] and [SEP] are shown as plain text; every token in between becomes a clickable button.
        with cols[0]:
            st.write(decoded_sent[0])
        with cols[-1]:
            st.write(decoded_sent[-1])
        for word_id, (col, word) in enumerate(zip(cols[1:-1], decoded_sent[1:-1])):
            with col:
                if st.button(word, key=f'word_{word_id}'):
                    if word_id not in st.session_state['masked_pos_1']:
                        st.session_state['masked_pos_1'].append(word_id)
        st.write(f'Masked words: {", ".join([decoded_sent[word_id+1] for word_id in np.sort(st.session_state["masked_pos_1"])])}')

# Commented-out scratch code: run SkeletonAlbertForMaskedLM with an intervention and sample token predictions.
'''
sent_1 = st.sidebar.text_input('Sentence 1', value='It is better to play a prank on Samuel than Craig because he gets angry less often.', on_change=clear_data)
sent_2 = st.sidebar.text_input('Sentence 2', value='It is better to play a prank on Samuel than Craig because he gets angry more often.', on_change=clear_data)
input_ids_1 = tokenizer(sent_1).input_ids
input_ids_2 = tokenizer(sent_2).input_ids
input_ids = torch.tensor([input_ids_1, input_ids_2])

outputs = SkeletonAlbertForMaskedLM(model, input_ids, interventions={0: {'lay': [(8, 1, [0, 1])]}})
logprobs = F.log_softmax(outputs['logits'], dim=-1)
preds = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0]]
st.write([tokenizer.decode([token]) for token in preds])
'''