tokenizer-demo / app.py
taka-yamakoshi
fix a minor bug
6751661
raw history blame
No virus
6.43 kB
import pandas as pd
import streamlit as st
import numpy as np
import torch
import io
import time
# NOTE(review): st.cache is deprecated in newer Streamlit releases in favour of
# st.cache_resource; kept unchanged for the Streamlit version this app targets
# -- confirm before upgrading Streamlit.
@st.cache(show_spinner=True, allow_output_mutation=True)
def load_model(model_name):
    """Load the Hugging Face tokenizer whose class matches *model_name*'s prefix.

    ``transformers`` is imported lazily inside each branch, and the result is
    cached by ``st.cache`` so Streamlit script reruns reuse the loaded tokenizer.

    Args:
        model_name: A model id starting with ``bert``, ``gpt2``, ``roberta``
            or ``albert`` (e.g. ``'albert-xxlarge-v2'``).

    Returns:
        The tokenizer instance returned by ``<TokenizerClass>.from_pretrained``.

    Raises:
        ValueError: if *model_name* starts with none of the supported prefixes.
    """
    # The prefixes are checked with startswith, so 'albert-...' does not match
    # 'bert' and 'roberta-...' does not match 'bert' either.
    if model_name.startswith('bert'):
        from transformers import BertTokenizer as tokenizer_cls
    elif model_name.startswith('gpt2'):
        from transformers import GPT2Tokenizer as tokenizer_cls
    elif model_name.startswith('roberta'):
        from transformers import RobertaTokenizer as tokenizer_cls
    elif model_name.startswith('albert'):
        from transformers import AlbertTokenizer as tokenizer_cls
    else:
        # Previously this fell through to `return tokenizer` with the name
        # unbound, surfacing as an opaque UnboundLocalError.
        raise ValueError(
            f'Unsupported tokenizer {model_name!r}: expected a name starting '
            "with 'bert', 'gpt2', 'roberta' or 'albert'"
        )
    return tokenizer_cls.from_pretrained(model_name)
def generate_markdown(text, color='black', font='Arial', size=20):
    """Wrap *text* in a centred, inline-styled HTML ``<p>`` element.

    The returned string is intended for ``st.markdown(..., unsafe_allow_html=True)``.
    Nothing is escaped: *text* and the style values are inserted verbatim, so
    HTML entities (e.g. ``&colon;``) are passed through for the browser to
    render, and callers must escape any user-supplied text themselves.

    Args:
        text: Paragraph content.
        color: CSS ``color`` value.
        font: CSS ``font-family`` value.
        size: Font size, emitted with a ``px`` unit.

    Returns:
        The HTML paragraph as a string.
    """
    declarations = (
        'text-align:center',
        f'color:{color}',
        f'font-family:{font}',
        f'font-size:{size}px',
    )
    style = '; '.join(declarations) + ';'
    return f"<p style='{style}'>{text}</p>"
def TokenizeText(sentence, tokenizer_name):
    """Tokenize *sentence*, render its token IDs and token strings, and count them.

    Reads the module-level ``tokenizer`` global (assigned in the ``__main__``
    section); it is not passed in.

    Args:
        sentence: Text typed by the user.
        tokenizer_name: Name of the selected tokenizer; only used to decide
            whether the first and last IDs are stripped.

    Returns:
        The number of tokens rendered, or ``None`` when *sentence* is empty
        (nothing is rendered in that case).
    """
    import html  # local import, in the same style as load_model's imports

    if not sentence:
        return None

    token_ids = tokenizer(sentence)['input_ids']
    # The code assumes every non-GPT-2 tokenizer offered here wraps the input
    # in exactly one special token at each end; drop them so only the text's
    # own tokens are shown.
    if not tokenizer_name.startswith('gpt2'):
        token_ids = token_ids[1:-1]

    encoded_sent = [str(token) for token in token_ids]
    decoded_sent = [tokenizer.decode([token]) for token in token_ids]
    num_tokens = len(decoded_sent)

    st.markdown(generate_markdown(' '.join(encoded_sent), size=16), unsafe_allow_html=True)
    # The decoded pieces come from user text and go into HTML rendered with
    # unsafe_allow_html, so escape them: '<', '>' and '&' must be shown as
    # characters, not parsed as markup.
    st.markdown(generate_markdown(html.escape(' '.join(decoded_sent)), size=16), unsafe_allow_html=True)
    st.markdown(generate_markdown(f'{num_tokens} tokens'), unsafe_allow_html=True)
    return num_tokens
def DeTokenizeText(input_str):
    """Decode a whitespace-separated list of token IDs and render the text.

    Reads the module-level ``tokenizer`` global (assigned in the ``__main__``
    section); it is not passed in.

    Args:
        input_str: Integers typed by the user, separated by whitespace
            (e.g. ``'101 2023 102'``).

    Returns:
        The number of IDs decoded, or ``None`` when *input_str* is empty
        (nothing is rendered in that case).

    Raises:
        ValueError: if any whitespace-separated piece is not an integer.
    """
    import html  # local import, in the same style as load_model's imports

    if not input_str:
        return None

    # split() with no argument ignores repeated spaces, tabs and surrounding
    # whitespace; split(' ') produced '' items that made int('') raise.
    token_ids = [int(piece) for piece in input_str.split()]
    # NOTE(review): IDs outside the tokenizer's vocabulary are passed straight
    # to decode; how that behaves depends on the tokenizer class -- confirm.
    decoded_sent = tokenizer.decode(token_ids)
    # Escape before inserting into unsafe HTML: e.g. a decoded '<s>' would
    # otherwise be parsed as a strikethrough tag instead of being displayed.
    st.markdown(generate_markdown(html.escape(decoded_sent)), unsafe_allow_html=True)
    return len(token_ids)
if __name__ == '__main__':
    # Streamlit entry point: page styling, title, sidebar controls, then either
    # a single input box or a two-column comparison. Everything here runs at
    # module scope, so `tokenizer` below is the module-level global that
    # TokenizeText / DeTokenizeText read -- do not rename it.

    # Sample sentence used as the default text and for the default token IDs.
    default_text = 'Tokenizers decompose bigger words into smaller tokens'
    # Options for the sidebar selectbox; load_model dispatches on the prefix.
    tokenizer_names = ('bert-base-uncased', 'bert-large-cased',
                       'gpt2', 'gpt2-large',
                       'roberta-base', 'roberta-large',
                       'albert-base-v2', 'albert-xxlarge-v2')

    # Config
    max_width = 1500
    padding_top = 0
    padding_right = 2
    padding_bottom = 0
    padding_left = 2
    # Widen the main container and adjust its padding; the doubled braces are
    # literal CSS braces inside the f-string.
    define_margins = f"""
<style>
.appview-container .main .block-container{{
max-width: {max_width}px;
padding-top: {padding_top}rem;
padding-right: {padding_right}rem;
padding-left: {padding_left}rem;
padding-bottom: {padding_bottom}rem;
}}
</style>
"""
    # Hide the index column of rendered tables.
    hide_table_row_index = """
<style>
tbody th {display:none}
.blank {display:none}
</style>
"""
    st.markdown(define_margins, unsafe_allow_html=True)
    st.markdown(hide_table_row_index, unsafe_allow_html=True)

    # Title
    st.markdown(generate_markdown('WordPiece Explorer', size=32), unsafe_allow_html=True)
    st.markdown(generate_markdown('- quick and easy way to explore how tokenizers work -', size=24),
                unsafe_allow_html=True)

    # Select and load the tokenizer (defaults to 'albert-xxlarge-v2').
    st.sidebar.write('1. Choose the tokenizer from below')
    tokenizer_name = st.sidebar.selectbox('', tokenizer_names,
                                          index=tokenizer_names.index('albert-xxlarge-v2'))
    tokenizer = load_model(tokenizer_name)

    st.sidebar.write('2. Optional settings')
    comparison_mode = st.sidebar.checkbox('Compare two texts')
    detokenize = st.sidebar.checkbox('de-tokenize')
    st.sidebar.write('"Compare two texts" compares # tokens for two pieces of text '
                     'and "de-tokenize" converts a list of tokenized indices back to strings.')
    st.sidebar.write('For "de-tokenize", make sure to type in integers, separated by single spaces.')

    if comparison_mode:
        # Two side-by-side inputs; report whether their token counts match.
        sent_cols = st.columns(2)
        num_tokens = {}
        sents = {}
        for sent_id, sent_col in enumerate(sent_cols):
            key = f'sent_{sent_id+1}'
            with sent_col:
                if detokenize:
                    sentence = st.text_input(f'Tokenized IDs {sent_id+1}')
                    num_tokens[key] = DeTokenizeText(sentence)
                else:
                    sentence = st.text_input(f'Text {sent_id+1}')
                    num_tokens[key] = TokenizeText(sentence, tokenizer_name)
            sents[key] = sentence
        # Only compare once both inputs have been filled in.
        if len(sents['sent_1']) > 0 and len(sents['sent_2']) > 0:
            st.markdown(generate_markdown('# Tokens&colon; ', size=16), unsafe_allow_html=True)
            if num_tokens['sent_1'] == num_tokens['sent_2']:
                st.markdown(generate_markdown('Matched! ', color='MediumAquamarine'), unsafe_allow_html=True)
            else:
                st.markdown(generate_markdown('Not Matched... ', color='Salmon'), unsafe_allow_html=True)
    elif detokenize:
        # Pre-fill with the default sentence's IDs, stripped of the first and
        # last ID for non-GPT-2 tokenizers, the same way TokenizeText does.
        default_tokens = tokenizer(default_text)['input_ids']
        if not tokenizer_name.startswith('gpt2'):
            default_tokens = default_tokens[1:-1]
        sentence = st.text_input('Tokenized IDs', value=' '.join(str(token) for token in default_tokens))
        num_tokens = DeTokenizeText(sentence)
    else:
        sentence = st.text_input('Text', value=default_text)
        num_tokens = TokenizeText(sentence, tokenizer_name)