# vic / app.py (Hugging Face Space altndrr/vic)
# Commit 8994e5d by altndrr: "Wrap badges if space is not enough"
import os
from glob import glob
from typing import Optional
import gradio as gr
import torch
from torchvision.transforms.functional import resize, to_pil_image
from transformers import AutoModel, CLIPProcessor
# Paper metadata shown in the page title, header and footer link.
PAPER_TITLE = "Vocabulary-free Image Classification"
PAPER_URL = "https://arxiv.org/abs/2306.00917"

# HTML/Markdown header rendered at the top of the demo: the paper title followed
# by a row of badges (code, demo, paper, website) that wraps on narrow screens
# (``flex-wrap: wrap``). The title and paper link are interpolated from the
# constants above so they cannot drift apart; the HTML contains no literal
# braces, so the f-string is safe.
MARKDOWN_DESCRIPTION = f"""
<div style="display: flex; align-items: center; justify-content: center; margin-bottom: 1rem;">
<h1>{PAPER_TITLE}</h1>
</div>
<div style="display: flex;
flex-wrap: wrap;
align-items: center;
justify-content: center;
margin-bottom: 1rem;">
<a href="https://github.com/altndrr/vic" style="margin-right: 0.5rem; margin-bottom: 0.5rem;">
<img src="https://img.shields.io/badge/code-github.altndrr%2Fvic-blue.svg"/>
</a>
<a href="https://huggingface.co/spaces/altndrr/vic" style="margin-right: 0.5rem;
margin-bottom: 0.5rem;">
<img src="https://img.shields.io/badge/demo-hf.altndrr%2Fvic-yellow.svg"/>
</a>
<a href="{PAPER_URL}" style="margin-right: 0.5rem;
margin-bottom: 0.5rem;">
<img src="https://img.shields.io/badge/paper-arXiv.2306.00917-B31B1B.svg"/>
</a>
<a href="https://alessandroconti.me/papers/2306.00917" style="margin-right: 0.5rem;
margin-bottom: 0.5rem;">
<img src="https://img.shields.io/badge/website-gh--pages.altndrr%2Fvic-success.svg"/>
</a>
</div>
"""
# Prefer the GPU whenever torch can see one; otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Classifier weights and code come from the ``altndrr/cased`` Hub repository
# (its model class is custom code, hence ``trust_remote_code``); the model is
# placed on DEVICE once, at import time.
MODEL = AutoModel.from_pretrained(
    "altndrr/cased",
    trust_remote_code=True,
).to(DEVICE)

# CLIP ViT-L/14 processor that turns PIL images into ``pixel_values`` tensors.
PROCESSOR = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
def save_original_image(image: gr.Image):
    """Resize a freshly uploaded image and keep a separate copy of it.

    Args:
        image: the uploaded image, or ``None``. The Gradio image components in
            this app are declared with ``type="pil"``, so this is expected to
            be a PIL image.

    Returns:
        A ``(display_image, original_copy)`` pair. Both are resized so that the
        shorter edge equals the processor's ``shortest_edge`` size, and they are
        independent copies. Returns ``(None, None)`` when no image was given.
    """
    if image is None:
        return None, None
    size = PROCESSOR.image_processor.size["shortest_edge"]
    # Defensive: the configured size may be a sequence rather than an int. Accept
    # lists as well as tuples (a JSON-loaded config yields lists) and resize to
    # the smallest entry.
    if isinstance(size, (tuple, list)):
        size = min(size)
    # With an int size, torchvision's resize matches the shorter edge and keeps
    # the aspect ratio.
    image = resize(image, size)
    return image, image.copy()
def prepare_image(image: gr.Image):
    """Show what the model will see, and keep a copy of the original image.

    The image goes through the CLIP image processor with normalization
    disabled for this call only. The resulting pixel tensor can therefore be
    turned back into a displayable PIL image, while the processor's other
    preprocessing steps still apply.

    Args:
        image: the (hidden) original image, or ``None``. Expected to be a PIL
            image, since the app's image components use ``type="pil"``.

    Returns:
        A ``(processed_preview, original_copy)`` pair, or ``(None, None)`` when
        no image was given.
    """
    if image is None:
        return None, None
    # Override ``do_normalize`` per call instead of toggling the attribute on
    # the shared PROCESSOR. Toggling it would leave normalization off for
    # ``image_inference`` if preprocessing raised, or if another queued
    # request ran while the flag was flipped.
    pixel_values = PROCESSOR.image_processor(
        [image], do_normalize=False, return_tensors="pt"
    ).pixel_values
    # First (only) image of the batch. Presumably float values in [0, 1] after
    # the processor's rescaling, which to_pil_image scales back to 8-bit
    # (TODO confirm against the processor config).
    curr_image = to_pil_image(pixel_values[0])
    return curr_image, image.copy()
def image_inference(image: gr.Image, alpha: Optional[float] = None):
    """Classify ``image`` and return label confidences for ``gr.Label``.

    Args:
        image: the image to classify, or ``None``. Expected to be a PIL image,
            since the app's image components use ``type="pil"``.
        alpha: forwarded unchanged to the model as ``alpha``. Its meaning is
            defined by the remote model code and is not visible here; ``None``
            leaves the choice to the model.

    Returns:
        A dict mapping each entry of the model's vocabulary for this image to
        its score, or ``None`` when no image was given.
    """
    if image is None:
        return None
    # The processor builds its tensors on the CPU, but MODEL lives on DEVICE
    # (possibly CUDA). BatchFeature.to moves every tensor it holds; this is a
    # no-op when DEVICE is the CPU.
    inputs = PROCESSOR(images=[image], return_tensors="pt", padding=True).to(DEVICE)
    with torch.no_grad():
        outputs = MODEL(inputs, alpha=alpha)
    # A batch of one image was sent, so take entry 0 of each output.
    vocabulary = outputs["vocabularies"][0]
    scores = outputs["scores"][0].tolist()
    return dict(zip(vocabulary, scores))
with gr.Blocks(analytics_enabled=True, title=PAPER_TITLE, theme="soft") as demo:
    # LAYOUT
    gr.Markdown(MARKDOWN_DESCRIPTION)
    with gr.Row():
        with gr.Column():
            # Visible input. Event handlers below replace its content with the
            # processed image after an upload or example selection.
            curr_image = gr.Image(label="input", type="pil")
            # Hidden, read-only holder for the unprocessed image.
            _orig_image = gr.Image(
                label="orig. image", type="pil", visible=False, interactive=False
            )
            # Value is forwarded to the model as ``alpha`` (semantics live in
            # the remote model code).
            alpha_slider = gr.Slider(0.0, 1.0, value=0.7, step=0.1, label="alpha")
            with gr.Row():
                clear_button = gr.Button(value="Clear", variant="secondary")
                run_button = gr.Button(value="Submit", variant="primary")
        with gr.Column():
            output_label = gr.Label(label="output", num_top_classes=5)
    # Examples load into the hidden original-image component. With
    # cache_examples=True, image_inference is precomputed for each of them
    # (alpha is not an input here, so it keeps its default of None). The paths
    # are sorted because glob() returns files in arbitrary, filesystem-dependent
    # order.
    examples = gr.Examples(
        examples=sorted(
            glob(os.path.join(os.path.dirname(__file__), "examples", "*.jpg"))
        ),
        inputs=[_orig_image],
        outputs=[output_label],
        fn=image_inference,
        cache_examples=True,
    )
    # Quote the attribute value so the link stays well-formed for any URL.
    gr.Markdown(
        f'Check out the <a href="{PAPER_URL}">original paper</a> for more information.'
    )

    # INTERACTIONS
    # - change: whenever the hidden original changes (upload or example click),
    #   show the processor's view of it in the visible input.
    #   NOTE(review): prepare_image also writes back into _orig_image, the
    #   component this listener watches; confirm Gradio does not re-fire
    #   .change on that update.
    _orig_image.change(prepare_image, [_orig_image], [curr_image, _orig_image])
    # - upload: resize the new image into both components, then drop any stale
    #   prediction.
    curr_image.upload(save_original_image, [curr_image], [curr_image, _orig_image])
    curr_image.upload(lambda: None, [], [output_label])
    # - clear: clearing the visible input also resets the original and output.
    curr_image.clear(lambda: (None, None), [], [_orig_image, output_label])
    # - click
    clear_button.click(
        lambda: (None, None, None), [], [curr_image, _orig_image, output_label]
    )
    run_button.click(image_inference, [curr_image, alpha_slider], [output_label])
if __name__ == "__main__":
    # Turn on request queuing, then start the web server.
    demo.queue().launch()