# Spaces: Running
import gradio as gr | |
from transformers import T5TokenizerFast, CLIPTokenizer | |
# Markers substituted for tokens that would otherwise render invisibly.
_SPACE_MARKER = "\u2423"  # U+2423 OPEN BOX: stands in for whitespace-only tokens
_EMPTY_MARKER = "\u2420"  # U+2420 SYMBOL FOR SPACE: stands in for tokens decoding to ""

# Lazily populated (t5_tokenizer, clip_tokenizer) pair; see _load_tokenizers().
_tokenizer_cache = None


def _load_tokenizers():
    """Return the (T5, CLIP) tokenizer pair, constructing it on first use only."""
    global _tokenizer_cache
    if _tokenizer_cache is None:
        # Build both before publishing so the cache is never half-filled.
        _tokenizer_cache = (
            T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl", legacy=False),
            CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14"),
        )
    return _tokenizer_cache


def _visible_tokens(tokenizer, token_ids):
    """Decode each ID on its own, swapping invisible results for markers."""
    pieces = []
    for token_id in token_ids:
        piece = tokenizer.decode([token_id])
        if piece.isspace():
            piece = _SPACE_MARKER
        elif piece == "":
            piece = _EMPTY_MARKER
        pieces.append(piece)
    return pieces


def _as_highlights(pieces):
    """Pair every decoded token with a positional "Token <index>" label."""
    return [(piece, f"Token {index}") for index, piece in enumerate(pieces)]


def count_tokens(text):
    """Tokenize ``text`` with the T5 and CLIP tokenizers and summarize both.

    Args:
        text: The prompt to tokenize. ``None`` is treated as an empty string.

    Returns:
        A 6-tuple, T5 results first and CLIP results second. Each half is
        ``(token_count, highlights, ids_as_string)`` where ``highlights`` is
        a list of ``(decoded_token, "Token <index>")`` tuples and
        ``ids_as_string`` is ``str()`` of the list of token IDs.
    """
    if text is None:
        text = ""

    # Reuse the tokenizers across calls instead of rebuilding them each time.
    t5_tokenizer, clip_tokenizer = _load_tokenizers()

    # encode() on a single string already yields a plain list of IDs, so no
    # tensor round-trip is needed.
    t5_ids = list(t5_tokenizer.encode(text))
    clip_ids = list(clip_tokenizer.encode(text))

    t5_pieces = _visible_tokens(t5_tokenizer, t5_ids)
    clip_pieces = _visible_tokens(clip_tokenizer, clip_ids)

    return (
        # T5 outputs
        len(t5_ids),
        _as_highlights(t5_pieces),
        str(t5_ids),
        # CLIP outputs
        len(clip_ids),
        _as_highlights(clip_pieces),
        str(clip_ids),
    )
# Build the Gradio UI: one prompt box on top, then side-by-side result
# columns for the T5 and CLIP tokenizers.
with gr.Blocks(title="Common Diffusion Model Token Counter") as iface:
    gr.Markdown("# Common Diffusion Model Token Counter")
    gr.Markdown("Enter text to count tokens using T5 and CLIP tokenizers, commonly used in diffusion models.")

    with gr.Row():
        text_input = gr.Textbox(label="Diffusion Prompt", placeholder="Enter your prompt here...")

    with gr.Row():
        # T5 Column
        with gr.Column():
            gr.Markdown("### T5 Tokenizer Results")
            t5_count = gr.Number(label="T5 Token Count")
            t5_highlights = gr.HighlightedText(label="T5 Tokens", show_legend=True)
            t5_ids = gr.Textbox(label="T5 Token IDs", lines=2)

        # CLIP Column
        with gr.Column():
            gr.Markdown("### CLIP Tokenizer Results")
            clip_count = gr.Number(label="CLIP Token Count")
            clip_highlights = gr.HighlightedText(label="CLIP Tokens", show_legend=True)
            clip_ids = gr.Textbox(label="CLIP Token IDs", lines=2)

    # Re-run the tokenizers whenever the prompt changes. The order of
    # `outputs` must match the order of the tuple returned by count_tokens.
    text_input.change(
        fn=count_tokens,
        inputs=[text_input],
        outputs=[t5_count, t5_highlights, t5_ids, clip_count, clip_highlights, clip_ids],
    )

# Launch the app only when run as a script, so importing this module
# (e.g. to reuse count_tokens) does not start a server.
if __name__ == "__main__":
    iface.launch(show_error=True)