# Hugging Face Spaces status at capture time: "Runtime error".

import html
import logging
from collections import deque
from pathlib import Path

import gradio as gr
from gradio.themes.utils import colors
from transformers import CLIPTokenizer

# Configure the root logger at INFO, then make Gradio's own named logger
# emit INFO records as well.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)
gr_logger = logging.getLogger("gradio")
gr_logger.setLevel("INFO")


class ClipUtil:
    """Gradio app that shows how the CLIP tokenizer splits its input.

    A prompt typed into the "Text Input" tab, or a list of token IDs typed into
    the "Token Input" tab, is tokenized with the
    ``openai/clip-vit-large-patch14`` tokenizer. It is rendered as one HTML span
    per token plus a paragraph listing the token IDs.
    """

    # Undecodable tokens kept in the buffer while waiting for a multi-token
    # UTF-8 character to complete. The value is presumably 4 because a UTF-8
    # character is at most 4 bytes long.
    _MAX_PENDING_TOKENS = 4
    # Shown for token bytes that never form valid UTF-8 (U+FFFD REPLACEMENT CHARACTER).
    _REPLACEMENT_GLYPH = "\ufffd"
    # Number of alternating ``tokenizer-token-N`` CSS classes the spans cycle
    # through. They are presumably styled by the companion .css file (TODO confirm).
    _TOKEN_CLASS_COUNT = 4

    def __init__(self):
        """Build the theme, load the optional CSS, the tokenizer and the Blocks container."""
        logger.info("Loading ClipUtil")
        self.theme = gr.themes.Base(
            primary_hue=colors.violet,
            secondary_hue=colors.indigo,
            neutral_hue=colors.slate,
            font=[gr.themes.GoogleFont("Fira Sans"), "ui-sans-serif", "system-ui", "sans-serif"],
            font_mono=[gr.themes.GoogleFont("Fira Code"), "ui-monospace", "Consolas", "monospace"],
        ).set(
            slider_color_dark="*primary_500",
        )
        # The stylesheet sits next to this script (same name, ".css" suffix).
        # If it cannot be read, the app falls back to no custom CSS.
        try:
            self.css = Path(__file__).with_suffix(".css").read_text(encoding="utf-8")
        except (OSError, UnicodeDecodeError):
            logger.exception("Failed to load CSS file")
            self.css = ""
        self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        # Reverse lookup table: token ID -> BPE token string.
        self.vocab = {v: k for k, v in self.tokenizer.get_vocab().items()}
        self.blocks = gr.Blocks(
            title="ClipTokenizerUtil", analytics_enabled=False, theme=self.theme, css=self.css
        )

    def tokenize(self, text: str, input_ids: bool = False):
        """Tokenize *text* and render the tokens as HTML.

        Some tokens only form valid UTF-8 together with their neighbours (for
        example, a character split across several BPE tokens). Such consecutive
        tokens are merged into a single span. The CLIP end-of-word marker
        ``</w>`` is displayed as-is.

        Args:
            text: The prompt to tokenize. When ``input_ids`` is true, it is
                instead a list of token IDs separated by commas and/or
                whitespace.
            input_ids: Interpret ``text`` as token IDs instead of a prompt.

        Returns:
            A ``(tokens_html, ids_html)`` pair:

            - ``tokens_html`` has one ``<span>`` per displayed token, with its
              IDs in the ``title`` tooltip.
            - ``ids_html`` is a paragraph with the token count and the
              comma-separated token IDs.

        Raises:
            gr.Error: If ``input_ids`` is true and an entry is not an integer.
        """
        if input_ids:
            tokens = self._parse_ids(text)
        else:
            tokens = self.tokenizer(text or "").input_ids

        byte_decoder = self.tokenizer.byte_decoder
        spans = []  # rendered <span> snippets, in order
        shown_ids = []  # every token ID, in display order
        pending = []  # IDs whose bytes do not yet form valid UTF-8
        queue = deque(tokens)

        def emit(span_ids, word):
            spans.append(self._token_span(len(spans), span_ids, word))
            shown_ids.extend(span_ids)

        while queue:
            pending.append(queue.popleft())
            word = self._decode_ids(pending, byte_decoder)
            if word is not None:
                emit(pending, word)
                pending = []
            elif len(pending) > self._MAX_PENDING_TOKENS:
                # No valid UTF-8 formed within the window. Show the oldest
                # token as a replacement glyph and re-scan the remaining ones.
                emit(pending[:1], self._REPLACEMENT_GLYPH)
                queue.extendleft(reversed(pending[1:]))
                pending = []
        if pending:
            # The input ended in the middle of an undecodable byte sequence.
            emit(pending, self._REPLACEMENT_GLYPH * len(pending))

        ids_html = (
            "\n<p>\n"
            f"Token count: {len(shown_ids)}\n"
            "<br>\n"
            f"{', '.join(str(i) for i in shown_ids)}\n"
            "</p>"
        )
        return "".join(spans), ids_html

    @staticmethod
    def _parse_ids(text):
        """Parse token IDs separated by commas and/or whitespace, skipping blank entries."""
        ids = []
        for item in (text or "").replace(",", " ").split():
            try:
                ids.append(int(item))
            except ValueError:
                # gr.Error shows a readable message in the UI instead of a traceback.
                raise gr.Error(f"Invalid token ID: {item!r}") from None
        return ids

    def _decode_ids(self, ids, byte_decoder):
        """Return the UTF-8 text of *ids*, or None if their bytes are not valid UTF-8.

        IDs missing from ``self.vocab`` contribute no bytes.
        """
        chars = "".join(self.vocab.get(i, "") for i in ids)
        try:
            return bytes(byte_decoder[c] for c in chars).decode("utf-8")
        except UnicodeDecodeError:
            return None

    def _token_span(self, index, ids, word):
        """Render one token as an escaped HTML span; *index* picks the alternating CSS class."""
        title = html.escape(", ".join(str(i) for i in ids))
        css_class = f"tokenizer-token tokenizer-token-{index % self._TOKEN_CLASS_COUNT}"
        return f"\n<span class='{css_class}' title='{title}'>\n{html.escape(word)}\n</span>\n"

    def tokenize_ids(self, text: str):
        """Render a list of token IDs; same as ``tokenize(text, input_ids=True)``."""
        return self.tokenize(text, input_ids=True)

    def create_components(self):
        """Build the UI inside ``self.blocks`` and wire both Go buttons to the tokenizer.

        Uses ``gr.Row(equal_height=...)`` rather than the ``.style()`` method,
        and omits the unsupported ``label``/``type`` Button kwargs. Both of
        those were removed in Gradio 4.
        """
        with self.blocks:
            # Title bar.
            with gr.Row(equal_height=True):
                with gr.Column(scale=12, elem_id="header_col"):
                    self.header_title = gr.Markdown(
                        "## CLIP Tokenizer Util",
                        elem_id="header_title",
                    )
                with gr.Column(scale=1, min_width=90, elem_id="button_col"):
                    with gr.Row(elem_id="button_row"):
                        # NOTE(review): no click handler is attached to this button in this file.
                        self.reload_btn = gr.Button(
                            value="\U0001F504",  # 🔄
                            elem_id="refresh_btn",
                            variant="primary",
                        )
            # Input tabs: free text, or raw token IDs.
            with gr.Tabs():
                with gr.Tab(label="Text Input", id="text_input_tab"):
                    with gr.Row(equal_height=True):
                        with gr.Column(scale=12, elem_id="text_input_col"):
                            self.text_input = gr.Textbox(
                                label="Text Input",
                                elem_id="tokenizer_prompt",
                                show_label=False,
                                lines=8,
                                placeholder="Prompt for tokenization",
                            )
                        self.text_button = gr.Button(
                            value="Go",
                            elem_id="go_button",
                            variant="primary",
                        )
                with gr.Tab(label="Token Input", id="token_input_tab"):
                    with gr.Row(equal_height=True):
                        with gr.Column(scale=12, elem_id="text_input_col"):
                            self.token_input = gr.Textbox(
                                lines=5,
                                label="Token Input",
                                elem_id="text_input",
                                placeholder="Token IDs separated by commas or spaces",
                            )
                        self.token_button = gr.Button(
                            value="Go",
                            elem_id="go_button",
                            variant="primary",
                        )
            # Output tabs: rendered tokens and the raw ID list.
            with gr.Tabs():
                with gr.Tab("Text"):
                    tokenized_text = gr.HTML(elem_id="tokenized_text")
                with gr.Tab("Tokens"):
                    tokenized_ids = gr.HTML(elem_id="tokenized_ids")
            self.text_button.click(
                fn=self.tokenize,
                inputs=[self.text_input],
                outputs=[tokenized_text, tokenized_ids],
            )
            self.token_button.click(
                fn=self.tokenize_ids,
                inputs=[self.token_input],
                outputs=[tokenized_text, tokenized_ids],
            )

    def launch(self, **kwargs):
        """Start the Gradio server for ``self.blocks`` and return ``Blocks.launch``'s result.

        Defaults to ``server_name="0.0.0.0"`` and ``show_error=True``. Any
        keyword argument overrides those defaults and is forwarded to
        ``Blocks.launch``.

        The legacy ``enable_queue`` flag (default True) is handled by calling
        ``Blocks.queue()``, because ``launch(enable_queue=...)`` was removed in
        Gradio 4.
        """
        enable_queue = kwargs.pop("enable_queue", True)
        options = {"server_name": "0.0.0.0", "show_error": True, **kwargs}
        app = self.blocks.queue() if enable_queue else self.blocks
        return app.launch(**options)


if __name__ == "__main__":
    # Script entry point: construct the app, attach widgets and handlers, then serve.
    app = ClipUtil()
    app.create_components()
    app.launch()