File size: 1,526 Bytes
b00b081 1967c03 b00b081 5c9b0ca b00b081 78edf01 fda67b7 1967c03 8a609d3 1967c03 b00b081 78edf01 b00b081 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import gradio as gr
import spaces
from PIL import Image
from transformers import PaliGemmaForConditionalGeneration, AutoProcessor
from transformers import TextIteratorStreamer
from threading import Thread
# Text shown in the web UI header.
TITLE = "E621 Tagger"
DESCRIPTION = "Tag images with E621 tags"

# Hugging Face Hub checkpoint used for both the model and its processor.
MODEL_ID = "estrogen/paligemma2-3b-e621-224"

# Load the weights once at import time and place them on the GPU; the
# processor is loaded from the same checkpoint.
model = PaliGemmaForConditionalGeneration.from_pretrained(MODEL_ID).to("cuda")
processor = AutoProcessor.from_pretrained(MODEL_ID)
@spaces.GPU
def tag_image(image, max_new_tokens: int = 128, temperature: float = 1, top_p: float = 1, min_p: float = 0):
    """Generate tags for ``image`` and stream the text as it is produced.

    Generation runs in a background thread that feeds a
    ``TextIteratorStreamer``; this generator consumes the streamer and
    yields the accumulated text after every new chunk.

    Args:
        image: Image passed to the processor (the UI supplies a PIL image).
        max_new_tokens: Upper bound on generated tokens. Coerced to ``int``
            because UI sliders may deliver a float.
        temperature: Sampling temperature. A value of 0 (or less) selects
            greedy decoding instead of sampling.
        top_p: Nucleus-sampling threshold (used only when sampling).
        min_p: Min-p sampling threshold (used only when sampling).

    Yields:
        The text generated so far, growing with each chunk.

    Raises:
        Exception: Any error raised by ``model.generate`` in the worker
            thread is re-raised here once streaming has stopped.
    """
    inputs = processor(images=image, text="<image>tag en", return_tensors="pt").to("cuda")
    streamer = TextIteratorStreamer(tokenizer=processor.tokenizer, skip_prompt=True)

    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        use_cache=True,
        cache_implementation="hybrid",
    )
    if temperature > 0:
        generation_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p, min_p=min_p)
    else:
        # Sampling needs a strictly positive temperature; treat 0 as a
        # request for deterministic (greedy) decoding.
        generation_kwargs.update(do_sample=False)

    errors = []

    def _run_generation():
        # Runs in the worker thread. On failure, record the error and close
        # the streamer so the consumer loop below does not block forever.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    thread = Thread(target=_run_generation)
    thread.start()

    text = ""
    for new_text in streamer:
        text += new_text
        yield text

    thread.join()
    if errors:
        raise errors[0]
    # The final yield already carried the full text; the return value is
    # kept for callers that drive the generator manually.
    return text
# Build the web UI: one image input plus four generation controls, with the
# tags streamed into a single text box.
demo = gr.Interface(
    fn=tag_image,
    inputs=[
        gr.Image(type="pil"),
        # step=1 keeps the token budget an integer and lets every value in
        # the range (including the default of 128) be selected.
        gr.Slider(label="Max new tokens", minimum=1, maximum=1024, value=128, step=1),
        gr.Slider(label="Temperature", minimum=0, maximum=2, value=1),
        gr.Slider(label="Top p", minimum=0, maximum=1, value=1),
        gr.Slider(label="Min p", minimum=0, maximum=1, value=0),
    ],
    outputs=gr.Textbox(type="text"),
    title=TITLE,
    description=DESCRIPTION,
)
demo.launch()