Spaces:
Running
Running
import pandas as pd | |
import streamlit as st | |
import numpy as np | |
import torch | |
import io | |
import time | |
def load_model(model_name):
    """Load the Hugging Face tokenizer that matches ``model_name``.

    The tokenizer class is picked from the model-name prefix ('bert', 'gpt2',
    'roberta' or 'albert', checked in that order) and instantiated with
    ``from_pretrained(model_name)``.

    Args:
        model_name: Pretrained identifier such as 'bert-base-uncased'.

    Returns:
        The tokenizer instance returned by ``<TokenizerClass>.from_pretrained``.

    Raises:
        ValueError: If ``model_name`` does not start with a supported prefix
            (previously this case surfaced as an UnboundLocalError).
    """
    # Prefix -> transformers class name, in the same order as the original if/elif chain.
    tokenizer_classes = (
        ('bert', 'BertTokenizer'),
        ('gpt2', 'GPT2Tokenizer'),
        ('roberta', 'RobertaTokenizer'),
        ('albert', 'AlbertTokenizer'),
    )
    class_name = next(
        (cls for prefix, cls in tokenizer_classes if model_name.startswith(prefix)),
        None,
    )
    if class_name is None:
        supported = ', '.join(prefix for prefix, _ in tokenizer_classes)
        raise ValueError(
            f'Unsupported tokenizer {model_name!r}; '
            f'expected a name starting with one of: {supported}'
        )
    # Imported lazily, like the original, and only after the name has been validated.
    import transformers

    tokenizer_cls = getattr(transformers, class_name)
    # NOTE(review): Streamlit re-runs this script on every interaction, so the
    # tokenizer is reloaded each rerun; consider st.cache_resource if the
    # installed Streamlit version provides it.
    return tokenizer_cls.from_pretrained(model_name)
def generate_markdown(text, color='black', font='Arial', size=20, escape=True):
    """Return a centered HTML paragraph for ``st.markdown(..., unsafe_allow_html=True)``.

    Args:
        text: Paragraph content; converted with ``str()``.
        color: CSS color of the text.
        font: CSS font-family.
        size: Font size in pixels.
        escape: HTML-escape ``text`` (default). Text derived from user input
            (typed sentences, decoded tokens) is passed here, so without
            escaping input such as ``<b>`` or ``<img ...>`` would be interpreted
            as markup rather than shown literally. Pass False only for trusted HTML.

    Returns:
        The ``<p style=...>...</p>`` HTML string.
    """
    import html  # local import, matching the file's lazy-import style

    body = html.escape(str(text)) if escape else text
    return (
        f"<p style='text-align:center; color:{color}; font-family:{font}; "
        f"font-size:{size}px;'>{body}</p>"
    )
def TokenizeText(sentence, tokenizer_name):
    """Tokenize ``sentence`` with the module-level ``tokenizer`` and render the result.

    Three centered lines are written with ``st.markdown``:
    - the token ids;
    - the text each id decodes to;
    - the token count.

    Args:
        sentence: Text typed by the user.
        tokenizer_name: Name of the selected tokenizer. Special tokens are now
            excluded with ``add_special_tokens=False`` for every model rather
            than by name-based slicing, so this argument is no longer consulted.
            It is kept so existing callers keep working.

    Returns:
        The number of tokens, or None when ``sentence`` is empty.
    """
    if not sentence:
        return None
    # Relies on the global ``tokenizer`` assigned in the ``__main__`` section.
    # add_special_tokens=False omits [CLS]/[SEP], <s>/</s>, etc. The original
    # did this by slicing [1:-1] for every model whose name did not start with 'gpt2'.
    token_ids = tokenizer(sentence, add_special_tokens=False)['input_ids']
    decoded_sent = [tokenizer.decode([token]) for token in token_ids]
    num_tokens = len(decoded_sent)

    st.markdown(generate_markdown(' '.join(map(str, token_ids)), size=16), unsafe_allow_html=True)
    st.markdown(generate_markdown(' '.join(decoded_sent), size=16), unsafe_allow_html=True)
    st.markdown(generate_markdown(f'{num_tokens} tokens'), unsafe_allow_html=True)
    return num_tokens
def DeTokenizeText(input_str):
    """Decode whitespace-separated token ids with the module-level ``tokenizer`` and render them.

    Args:
        input_str: Token ids typed by the user, e.g. ``"101 7592 102"``.

    Returns:
        The number of ids decoded. Returns None when ``input_str`` is empty, or
        when it contains a value that is not an integer (in that case an error
        message is shown instead of a traceback).
    """
    if not input_str:
        return None
    try:
        # split() with no argument tolerates repeated spaces, tabs and
        # surrounding whitespace. split(' ') yielded '' items, which made int()
        # raise on input such as "101  102" or "   ".
        token_ids = [int(element) for element in input_str.split()]
    except ValueError:
        st.error('Token IDs must be integers separated by spaces.')
        return None
    # Relies on the global ``tokenizer`` assigned in the ``__main__`` section.
    decoded_sent = [tokenizer.decode([token]) for token in token_ids]
    st.markdown(generate_markdown(' '.join(decoded_sent)), unsafe_allow_html=True)
    return len(decoded_sent)
if __name__ == '__main__':
    # Config: page width and padding for the main container.
    max_width = 1500
    padding_top = 0
    padding_right = 2
    padding_bottom = 0
    padding_left = 2
    define_margins = f"""
<style>
.appview-container .main .block-container{{
max-width: {max_width}px;
padding-top: {padding_top}rem;
padding-right: {padding_right}rem;
padding-left: {padding_left}rem;
padding-bottom: {padding_bottom}rem;
}}
</style>
"""
    hide_table_row_index = """
<style>
tbody th {display:none}
.blank {display:none}
</style>
"""
    st.markdown(define_margins, unsafe_allow_html=True)
    st.markdown(hide_table_row_index, unsafe_allow_html=True)

    # Title
    st.markdown(generate_markdown('Tokenizer Demo:', size=32), unsafe_allow_html=True)
    st.markdown(
        generate_markdown('quick and easy way to explore how tokenizers work', size=24),
        unsafe_allow_html=True,
    )

    # Select and load the tokenizer.
    tokenizer_name = st.sidebar.selectbox(
        'Choose the tokenizer from below',
        ('bert-base-uncased', 'bert-large-cased',
         'gpt2', 'gpt2-large',
         'roberta-base', 'roberta-large',
         'albert-base-v2', 'albert-xxlarge-v2'),
        index=7,
    )
    # Must stay a module-level name: TokenizeText and DeTokenizeText read the global ``tokenizer``.
    tokenizer = load_model(tokenizer_name)

    comparison_mode = st.sidebar.checkbox('Compare two texts')
    detokenize = st.sidebar.checkbox('de-tokenize (make sure to type in integers separated by single spaces)')

    if comparison_mode:
        sent_cols = st.columns(2)
        num_tokens = {}
        sents = {}
        for sent_id, sent_col in enumerate(sent_cols, start=1):
            key = f'sent_{sent_id}'
            with sent_col:
                if detokenize:
                    sentence = st.text_input(f'Tokenized IDs {sent_id}')
                    num_tokens[key] = DeTokenizeText(sentence)
                else:
                    sentence = st.text_input(f'Text {sent_id}')
                    num_tokens[key] = TokenizeText(sentence, tokenizer_name)
            sents[key] = sentence
        # Compare only when both inputs are non-empty and both helpers produced
        # a count. The helpers return None when there is nothing to count, and
        # two missing counts must never be reported as "Matched!".
        if sents['sent_1'] and sents['sent_2'] and None not in num_tokens.values():
            st.markdown(generate_markdown('Result: ', size=16), unsafe_allow_html=True)
            if num_tokens['sent_1'] == num_tokens['sent_2']:
                st.markdown(generate_markdown('Matched! ', color='MediumAquamarine'), unsafe_allow_html=True)
            else:
                st.markdown(generate_markdown('Not Matched... ', color='Salmon'), unsafe_allow_html=True)
    else:
        if detokenize:
            sentence = st.text_input('Tokenized IDs')
            num_tokens = DeTokenizeText(sentence)
        else:
            sentence = st.text_input('Text')
            num_tokens = TokenizeText(sentence, tokenizer_name)