import gradio as gr
import spaces
from PIL import Image
from transformers import PaliGemmaForConditionalGeneration, AutoProcessor
from transformers import TextIteratorStreamer
from threading import Thread
TITLE = "E621 Tagger"
DESCRIPTION = "Tag images with E621 tags"
MODEL_ID = "estrogen/paligemma2-3b-e621-224"

model = PaliGemmaForConditionalGeneration.from_pretrained(MODEL_ID)
model.to("cuda")
processor = AutoProcessor.from_pretrained(MODEL_ID)

@spaces.GPU
def tag_image(image, max_new_tokens=128, temperature=1, top_p=1, min_p=0):
    # Build the "<image>tag en" prompt and move the input tensors to the GPU.
    inputs = processor(images=image, text="<image>tag en", return_tensors="pt").to("cuda")
    # Stream generated tokens so the UI can update incrementally.
    streamer = TextIteratorStreamer(tokenizer=processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        inputs, streamer=streamer, max_new_tokens=max_new_tokens, use_cache=True,
        cache_implementation="hybrid", do_sample=True, temperature=temperature, top_p=top_p, min_p=min_p,
    )
    # Run generation in a background thread; consume the streamer on this one.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    text = ""
    thread.start()
    for new_text in streamer:
        text += new_text
        yield text

gr.Interface(
    fn=tag_image,
    inputs=[
        gr.Image(type="pil"),
        gr.Slider(label="Max new tokens", minimum=1, maximum=1024, value=128),
        gr.Slider(label="Temperature", minimum=0, maximum=2, value=1),
        gr.Slider(label="Top p", minimum=0, maximum=1, value=1),
        gr.Slider(label="Min p", minimum=0, maximum=1, value=0),
    ],
    outputs=gr.Textbox(type="text"),
    title=TITLE,
    description=DESCRIPTION,
).launch()
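
# A minimal local smoke test sketch, assuming a sample image at "sample.png"
# (the path is hypothetical); run it in place of launch() to check the streaming output:
#
#     img = Image.open("sample.png")
#     for partial in tag_image(img, max_new_tokens=64):
#         print(partial)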