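# Gradio demo: stream E621-style tags for an uploaded image using a fine-tuned PaliGemma model.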
import gradio as gr
import spaces
from PIL import Image
from transformers import PaliGemmaForConditionalGeneration, AutoProcessor, TextIteratorStreamer
from threading import Thread

TITLE = "E621 Tagger"
DESCRIPTION = "Tag images with E621 tags"

MODEL_ID = "estrogen/paligemma2-3b-e621-224"

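# Load the fine-tuned PaliGemma model and its processor once at startup and keep the weights on the GPU.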
model = PaliGemmaForConditionalGeneration.from_pretrained(MODEL_ID)
model.to("cuda")
processor = AutoProcessor.from_pretrained(MODEL_ID)

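# @spaces.GPU allocates a GPU for the duration of each call on ZeroGPU Spaces (a no-op on other hardware).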
@spaces.GPU
def tag_image(image, max_new_tokens=128, temperature=1, top_p=1, min_p=0):
    # PaliGemma expects the <image> placeholder followed by the task prefix ("tag en").
    inputs = processor(images=image, text="<image>tag en", return_tensors="pt").to("cuda")
    # Stream decoded tokens as they arrive; drop the prompt and special tokens from the output.
    streamer = TextIteratorStreamer(tokenizer=processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        inputs, streamer=streamer, max_new_tokens=max_new_tokens,
        use_cache=True, cache_implementation="hybrid",
        do_sample=True, temperature=temperature, top_p=top_p, min_p=min_p,
    )
    # Run generation in a background thread so partial results can be yielded to Gradio as they stream in.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    text = ""
    for new_text in streamer:
        text += new_text
        yield text
    thread.join()

gr.Interface(
    fn=tag_image,
    inputs=[
        gr.Image(type="pil"),
        gr.Slider(label="Max new tokens", minimum=1, maximum=1024, value=128),
        gr.Slider(label="Temperature", minimum=0, maximum=2, value=1),
        gr.Slider(label="Top p", minimum=0, maximum=1, value=1),
        gr.Slider(label="Min p", minimum=0, maximum=1, value=0),
    ],
    outputs=gr.Textbox(type="text"),
    title=TITLE,
    description=DESCRIPTION,
).launch()