import gradio as gr
import pandas as pd
from diffusers import DiffusionPipeline
from transformers import AutoTokenizer, CLIPTokenizerFast, T5TokenizerFast


def load_tokenizers(model_id: str) -> list[CLIPTokenizerFast | T5TokenizerFast | None]:
    """Load every tokenizer declared in the diffusers pipeline config for `model_id`."""
    config = DiffusionPipeline.load_config(model_id)
    num_tokenizers = sum("tokenizer" in key for key in config.keys())
    if not 1 <= num_tokenizers <= 3:
        raise gr.Error(f"Invalid number of tokenizers: {num_tokenizers}")
    # Diffusers stores tokenizers in subfolders named "tokenizer",
    # "tokenizer_2", "tokenizer_3", ...
    tokenizers = [
        AutoTokenizer.from_pretrained(
            model_id, subfolder=f'tokenizer{"" if i == 0 else f"_{i + 1}"}'
        )
        for i in range(num_tokenizers)
    ]
    # Pad the list with None if there are fewer than 3 tokenizers
    tokenizers.extend([None] * (3 - num_tokenizers))
    return tokenizers
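
# A hedged illustration of the return shape: for
# "stabilityai/stable-diffusion-3-medium-diffusers" the pipeline config lists
# "tokenizer", "tokenizer_2", and "tokenizer_3", so this should yield two
# CLIPTokenizerFast instances followed by a T5TokenizerFast, while a
# single-encoder model such as "stable-diffusion-v1-5" yields
# [CLIPTokenizerFast, None, None].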


def inference(model_id: str, text: str):
    tokenizers = load_tokenizers(model_id)
    text_pairs_components = []
    special_tokens_components = []
    tokenizer_details_components = []
    for i, tokenizer in enumerate(tokenizers):
        if tokenizer:
            label_text = f"Tokenizer {i + 1}: {tokenizer.__class__.__name__}"

            # Build (token, token ID) pairs from the input text
            input_ids = tokenizer(
                text=text,
                truncation=False,
                return_length=False,
                return_overflowing_tokens=False,
            ).input_ids
            decoded_tokens = [tokenizer.decode(id_) for id_ in input_ids]
            token_pairs = [
                (str(token), str(id_)) for token, id_ in zip(decoded_tokens, input_ids)
            ]
            output_text_pair_component = gr.HighlightedText(
                label=label_text,
                value=token_pairs,
                visible=True,
            )

            # Collect the tokenizer's special tokens (BOS, EOS, pad, ...)
            special_tokens = []
            for k, v in tokenizer.special_tokens_map.items():
                if k == "additional_special_tokens":
                    continue
                special_token_map = (str(k), str(v))
                special_tokens.append(special_token_map)
            output_special_tokens_component = gr.HighlightedText(
                label=label_text,
                value=special_tokens,
                visible=True,
            )

            # Gather tokenizer detail attributes into a table
            tokenizer_details = pd.DataFrame(
                [
                    ("Type", tokenizer.__class__.__name__),
                    ("Vocab Size", tokenizer.vocab_size),
                    ("Model Max Length", tokenizer.model_max_length),
                    ("Padding Side", tokenizer.padding_side),
                    ("Truncation Side", tokenizer.truncation_side),
                ],
                columns=["Attribute", "Value"],
            )
            output_tokenizer_details = gr.Dataframe(
                headers=["Attribute", "Value"],
                value=tokenizer_details,
                label=label_text,
                visible=True,
            )
        else:
            # No tokenizer in this slot: hide the corresponding components
            output_text_pair_component = gr.HighlightedText(visible=False)
            output_special_tokens_component = gr.HighlightedText(visible=False)
            output_tokenizer_details = gr.Dataframe(visible=False)
        text_pairs_components.append(output_text_pair_component)
        special_tokens_components.append(output_special_tokens_component)
        tokenizer_details_components.append(output_tokenizer_details)
    return text_pairs_components + special_tokens_components + tokenizer_details_components
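
# `inference` returns new component instances rather than raw values; Gradio
# applies them positionally to the `all_output` list wired up below, so the
# `visible` flag on each returned component shows or hides the matching slot
# in the UI.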


if __name__ == "__main__":
    theme = gr.themes.Soft(
        primary_hue=gr.themes.colors.emerald,
        secondary_hue=gr.themes.colors.emerald,
    )

    with gr.Blocks(theme=theme) as demo:
        with gr.Column():
            input_model_id = gr.Dropdown(
                label="Model ID",
                choices=[
                    "black-forest-labs/FLUX.1-dev",
                    "black-forest-labs/FLUX.1-schnell",
                    "stabilityai/stable-diffusion-3-medium-diffusers",
                    "stabilityai/stable-diffusion-xl-base-1.0",
                    "stable-diffusion-v1-5/stable-diffusion-v1-5",
                    "stabilityai/japanese-stable-diffusion-xl",
                    "rinna/japanese-stable-diffusion",
                ],
                value="black-forest-labs/FLUX.1-dev",
            )
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Enter text here",
            )

            with gr.Tab(label="Tokenization Outputs"):
                with gr.Column():
                    output_highlighted_text_1 = gr.HighlightedText()
                    output_highlighted_text_2 = gr.HighlightedText()
                    output_highlighted_text_3 = gr.HighlightedText()
            with gr.Tab(label="Special Tokens"):
                with gr.Column():
                    output_special_tokens_1 = gr.HighlightedText()
                    output_special_tokens_2 = gr.HighlightedText()
                    output_special_tokens_3 = gr.HighlightedText()
            with gr.Tab(label="Tokenizer Details"):
                with gr.Column():
                    output_tokenizer_details_1 = gr.Dataframe(headers=["Attribute", "Value"])
                    output_tokenizer_details_2 = gr.Dataframe(headers=["Attribute", "Value"])
                    output_tokenizer_details_3 = gr.Dataframe(headers=["Attribute", "Value"])

            with gr.Row():
                clear_button = gr.ClearButton(components=[input_text])
                submit_button = gr.Button("Run", variant="primary")

        all_inputs = [input_model_id, input_text]
        all_output = [
            output_highlighted_text_1,
            output_highlighted_text_2,
            output_highlighted_text_3,
            output_special_tokens_1,
            output_special_tokens_2,
            output_special_tokens_3,
            output_tokenizer_details_1,
            output_tokenizer_details_2,
            output_tokenizer_details_3,
        ]
        submit_button.click(fn=inference, inputs=all_inputs, outputs=all_output)
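
        # With cache_examples=True below, Gradio runs `inference` over each
        # example ahead of time and replays the cached outputs when an example
        # is clicked (an assumption based on Gradio's default caching behavior;
        # the exact timing can vary across Gradio versions).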
        examples = gr.Examples(
            fn=inference,
            inputs=all_inputs,
            outputs=all_output,
            examples=[
                ["black-forest-labs/FLUX.1-dev", "a photo of cat"],
                [
                    "stabilityai/stable-diffusion-3-medium-diffusers",
                    'cat holding sign saying "I am a cat"',
                ],
                ["rinna/japanese-stable-diffusion", "空を飛んでいるネコの写真 油絵"],
            ],
            cache_examples=True,
        )

    demo.queue().launch()