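# NOTE: the six lines below are file-viewer page residue captured together
# with the source; they are not part of the program.
#   Fizzarolli's picture
#   aas
#   1967c03
#   raw
#   history blame
#   1.1 kB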
import gradio as gr
import spaces
from PIL import Image
from transformers import PaliGemmaForConditionalGeneration, AutoProcessor
from transformers import TextIteratorStreamer
from threading import Thread
# Text shown in the header of the Gradio page.
TITLE = "E621 Tagger"
DESCRIPTION = "Tag images with E621 tags"

# Checkpoint identifier handed to both ``from_pretrained`` calls below.
MODEL_ID = "estrogen/paligemma2-3b-e621-224"

# Loaded once at import time so every request reuses the same objects.
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = PaliGemmaForConditionalGeneration.from_pretrained(MODEL_ID)
@spaces.GPU
def tag_image(image):
    """Generate tags for ``image``, streaming the text as it is produced.

    Generation runs in a background thread that feeds a
    ``TextIteratorStreamer``; this generator drains the streamer and yields
    the accumulated text after every new chunk so the UI can update live.

    Args:
        image: The image to tag (the interface passes a PIL image).

    Yields:
        str: All text generated so far (cumulative, not just the new chunk).

    Returns:
        str: The complete generated text, as the generator's return value.
    """
    inputs = processor(images=image, text="<image>tag en", return_tensors="pt")
    # Keep the input tensors on whatever device the model lives on; this is a
    # no-op when both are already on the same device.
    inputs = inputs.to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer=processor.tokenizer,
        skip_prompt=True,
        # Extra keyword arguments are forwarded to the tokenizer's decode
        # call; without this, end-of-sequence markers leak into the tags.
        skip_special_tokens=True,
    )

    # Only pass arguments that ``generate`` understands. The former
    # ``return_full_text=False`` is a text-generation *pipeline* option:
    # ``generate`` rejects unknown model kwargs with a ValueError, which was
    # raised inside the worker thread and left the streamer (and therefore
    # the request) waiting forever. ``skip_prompt=True`` already keeps the
    # prompt out of the streamed text.
    generation_kwargs = dict(inputs, streamer=streamer)

    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    text = ""
    for new_text in streamer:
        text += new_text
        yield text

    # The streamer is exhausted only once generation has finished, so this
    # join just reaps the worker thread.
    thread.join()
    return text
# Single-page UI: one image input wired to ``tag_image`` and one text output
# that is refreshed with each value the generator yields.
demo = gr.Interface(
    fn=tag_image,
    inputs=[gr.Image(type="pil")],
    outputs=gr.Textbox(type="text"),
    title=TITLE,
    description=DESCRIPTION,
)

# Start the web server (runs as an import-time side effect of this script).
demo.launch()