Fizzarolli's picture
blehhhhh
fda67b7
raw
history blame
1.23 kB
import gradio as gr
import spaces
from PIL import Image
from transformers import PaliGemmaForConditionalGeneration, AutoProcessor
from transformers import TextIteratorStreamer
from threading import Thread
# --- Text shown in the Gradio UI -------------------------------------------
TITLE = "E621 Tagger"
DESCRIPTION = "Tag images with E621 tags"

# --- Model setup (executed once, when the module is imported) --------------
# Hub repository id; both the weights and the processor are loaded from it.
MODEL_ID = "estrogen/paligemma2-3b-e621-224"

# `Module.to` returns the module itself, so loading and device placement
# can be chained into a single expression.
model = PaliGemmaForConditionalGeneration.from_pretrained(MODEL_ID).to("cuda")

# The processor prepares the image + prompt and exposes the tokenizer used
# for decoding the generated ids.
processor = AutoProcessor.from_pretrained(MODEL_ID)
@spaces.GPU
def tag_image(image, max_new_tokens=128):
    """Stream generated tags for ``image`` as they are produced.

    Generation runs on a background thread and feeds a
    ``TextIteratorStreamer``; this generator consumes the streamer and
    yields the text accumulated so far after every new chunk, so the UI can
    update incrementally.

    Args:
        image: The image to tag (the interface passes a PIL image).
        max_new_tokens: Upper bound on the number of generated tokens.
            Coerced to ``int`` because UI/API callers may send a float.

    Yields:
        str: All text generated so far.

    Raises:
        gr.Error: If no image was supplied.
        Exception: Any error raised by ``model.generate`` is re-raised here
            once the stream has been drained.
    """
    if image is None:
        # Fail early with a readable message instead of an opaque processor
        # error when the user submits without an image.
        raise gr.Error("Please provide an image to tag.")

    inputs = processor(images=image, text="<image>tag en", return_tensors="pt").to("cuda")
    # NOTE(review): special tokens are not skipped while decoding, so an
    # end-of-sequence marker may show up at the end of the output -- confirm
    # whether `skip_special_tokens=True` is wanted here.
    streamer = TextIteratorStreamer(tokenizer=processor.tokenizer, skip_prompt=True)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        use_cache=True,
        cache_implementation="hybrid",
    )

    failures = []

    def _generate():
        # Runs in the worker thread. If generation fails, the streamer would
        # never receive its stop signal and the consumer loop below would
        # block forever, so record the error and close the stream explicitly.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:  # thread boundary: re-raised by the caller
            failures.append(exc)
            streamer.end()

    thread = Thread(target=_generate)
    text = ""
    thread.start()
    try:
        for new_text in streamer:
            text += new_text
            yield text
    finally:
        # Never leave the worker running after this generator is finished
        # or closed.
        thread.join()

    if failures:
        raise failures[0]
    # Not used by plain iteration; kept so StopIteration.value still carries
    # the final text, as in the original implementation.
    return text
# Build the UI: an image upload plus a token-budget slider, streaming the
# tagger output into a text box.
demo = gr.Interface(
    fn=tag_image,
    inputs=[
        gr.Image(type="pil"),
        # Gradio derives a default step from the slider range when none is
        # given; pin it to 1 so the token budget moves in whole tokens and
        # the default value of 128 is actually selectable.
        gr.Slider(minimum=1, maximum=1024, value=128, step=1),
    ],
    outputs=gr.Textbox(type="text"),
    title=TITLE,
    description=DESCRIPTION,
)
demo.launch()