# SigLIP_Tagger / app.py
# Uploaded by p1atdev — commit e212637 ("feat: add code"), 2.98 kB.
import numpy as np
import torch
from transformers import (
AutoProcessor,
)
from PIL import Image
import gradio as gr
from modeling_siglip import SiglipForImageClassification
# Repo id handed to SiglipForImageClassification.from_pretrained.
MODEL_NAME = "p1atdev/siglip-tagger-test-3"
# Repo id handed to AutoProcessor.from_pretrained.
PROCESSOR_NAME = "google/siglip-so400m-patch14-384"

# Both objects are created once at import time and used as module globals
# by predict_tags.
model = SiglipForImageClassification.from_pretrained(MODEL_NAME)
# NOTE: wrapping the model with torch.compile is currently disabled.
processor = AutoProcessor.from_pretrained(PROCESSOR_NAME)
def compose_text(results: dict[str, float], threshold: float = 0.3):
    """Join the tags scoring above *threshold* into one comma-separated string.

    Args:
        results: Mapping of tag name to score.
        threshold: Only tags whose score is strictly greater than this value
            are kept.

    Returns:
        The kept tag names joined with ", ", ordered from highest to lowest
        score. An empty string when no tag passes the threshold.
    """
    # Rank first, filter second, so the output order follows descending score.
    ranked = sorted(results.items(), key=lambda pair: pair[1], reverse=True)
    return ", ".join(tag for tag, score in ranked if score > threshold)
@torch.no_grad()
def predict_tags(image: Image.Image, threshold: float):
    """Run the tagger on *image* and return its tags in two forms.

    Args:
        image: The input picture (the UI passes a PIL image).
        threshold: Minimum score, exclusive, for a tag to appear in the
            returned text. It does not filter the returned score mapping.

    Returns:
        A tuple ``(text, scores)``: ``text`` is the comma-separated string
        built by ``compose_text``; ``scores`` maps each label with a score
        above zero to that score.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        # Fail with a readable message instead of an error from the processor.
        raise gr.Error("Please provide an input image.")

    inputs = processor(images=image, return_tensors="pt")
    logits = model(**inputs.to(model.device, model.dtype)).logits.detach().cpu()

    # Clamp with torch rather than np.clip: the value is a torch tensor, and
    # going through NumPy relies on implicit tensor/array interop. The float32
    # cast keeps this independent of the model's dtype.
    scores = torch.clamp(logits.float(), 0.0, 1.0)

    id2label = model.config.id2label
    results = {}
    # One bulk tolist() replaces two .item() calls per element.
    for prediction in scores.tolist():
        for i, prob in enumerate(prediction):
            if prob > 0:
                results[id2label[i]] = prob

    return compose_text(results, threshold), results
# Stylesheet passed to gr.Blocks(css=...): a `.sticky` rule for elements
# created with elem_classes="sticky", and an overflow override on the
# Gradio container.
css = (
    ".sticky {\n"
    "position: sticky;\n"
    "top: 16px;\n"
    "}\n"
    ".gradio-container {\n"
    "overflow: clip;\n"
    "}\n"
)
def demo():
    """Build the Gradio Blocks interface for the tagger and launch it.

    The UI has an image input, a threshold slider, a start button and example
    images, plus a text box and a label component for the results. Clicking
    the button calls ``predict_tags`` with the image and the slider value.
    """
    # NOTE(review): the nesting of the layout blocks below was reconstructed
    # from a de-indented listing — confirm against the rendered UI.
    with gr.Blocks(css=css) as ui:
        gr.Markdown(
            """\
## SigLIP Tagger Test 3
An experimental model for tagging danbooru tags of images using SigLIP.
Models:
- (soon)
Example images by NovelAI and niji・journey.
"""
        )
        with gr.Row():
            # Input side: image, threshold slider, start button, examples.
            with gr.Column():
                # "sticky" matches the `.sticky` rule in the module-level css.
                with gr.Row(elem_classes="sticky"):
                    with gr.Column():
                        input_img = gr.Image(
                            label="Input image", type="pil", height=480
                        )
                        with gr.Group():
                            tag_threshold_slider = gr.Slider(
                                label="Tags threshold",
                                minimum=0.0,
                                maximum=1.0,
                                value=0.3,
                                step=0.01,
                            )
                        start_btn = gr.Button(value="Start", variant="primary")
                        # Example images only fill the image input; caching
                        # of example outputs is switched off.
                        gr.Examples(
                            examples=[["./sample.jpg"], ["./sample2.webp"]],
                            inputs=[input_img],
                            cache_examples=False,
                        )
            # Output side: receives the two values returned by predict_tags.
            with gr.Column():
                output_tags = gr.Text(label="Output text", interactive=False)
                output_label = gr.Label(label="Output tags")
        # Wire the button: (image, threshold) -> (tag text, tag scores).
        start_btn.click(
            fn=predict_tags,
            inputs=[input_img, tag_threshold_slider],
            outputs=[output_tags, output_label],
        )
    ui.launch(
        debug=True,
        # share=True
    )
# Start the UI only when this file is run as a script, not when imported.
if __name__ == "__main__":
    demo()