import os
from glob import glob
from typing import Optional

import gradio as gr
import torch
from PIL import Image
from torchvision.transforms.functional import resize, to_pil_image
from transformers import AutoModel, CLIPProcessor

PAPER_TITLE = "Vocabulary-free Image Classification"
PAPER_URL = "https://arxiv.org/abs/2306.00917"
MARKDOWN_DESCRIPTION = """
<div style="display: flex; align-items: center; justify-content: center; margin-bottom: 1rem;">
<h1>Vocabulary-free Image Classification</h1>
</div>
<div style="display: flex;
flex-wrap: wrap;
align-items: center;
justify-content: center;
margin-bottom: 1rem;">
<a href="https://github.com/altndrr/vic" style="margin-right: 0.5rem; margin-bottom: 0.5rem;">
<img src="https://img.shields.io/badge/code-github.altndrr%2Fvic-blue.svg"/>
</a>
<a href="https://huggingface.co/spaces/altndrr/vic" style="margin-right: 0.5rem;
margin-bottom: 0.5rem;">
<img src="https://img.shields.io/badge/demo-hf.altndrr%2Fvic-yellow.svg"/>
</a>
<a href="https://arxiv.org/abs/2306.00917" style="margin-right: 0.5rem;
margin-bottom: 0.5rem;">
<img src="https://img.shields.io/badge/paper-arXiv.2306.00917-B31B1B.svg"/>
</a>
<a href="https://alessandroconti.me/papers/2306.00917" style="margin-right: 0.5rem;
margin-bottom: 0.5rem;">
<img src="https://img.shields.io/badge/website-gh--pages.altndrr%2Fvic-success.svg"/>
</a>
</div>
"""
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL = AutoModel.from_pretrained("altndrr/cased", trust_remote_code=True).to(DEVICE)
PROCESSOR = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
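

# Resize a freshly uploaded image to the processor's expected input size and
# return it twice: once as the visible input and once as a hidden backup copy.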
def save_original_image(image: Optional[Image.Image]):
if image is None:
return None, None
size = PROCESSOR.image_processor.size["shortest_edge"]
size = min(size) if isinstance(size, tuple) else size
image = resize(image, size)
return image, image.copy()
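

# Re-process the stored original image with the CLIP processor, temporarily
# disabling normalization so the result can still be displayed as an image.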
def prepare_image(image: Optional[Image.Image]):
if image is None:
return None, None
PROCESSOR.image_processor.do_normalize = False
image_tensor = PROCESSOR(images=[image], return_tensors="pt", padding=True)
PROCESSOR.image_processor.do_normalize = True
image_tensor = image_tensor.pixel_values[0]
curr_image = to_pil_image(image_tensor)
return curr_image, image.copy()
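

# Run the model on the current image. The model builds its own candidate
# vocabulary and scores it; `alpha` balances the visual and textual scoring
# of the candidates (see the paper for details).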
def image_inference(image: Optional[Image.Image], alpha: Optional[float] = None):
if image is None:
return None
    images = PROCESSOR(images=[image], return_tensors="pt", padding=True)
    images = images.to(DEVICE)  # match the model's device (no-op on CPU)
    with torch.no_grad():
        outputs = MODEL(images, alpha=alpha)
vocabulary = outputs["vocabularies"][0]
scores = outputs["scores"][0].tolist()
confidences = dict(zip(vocabulary, scores))
return confidences
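

# Build the Gradio interface: image input and alpha slider on the left,
# top-5 predicted labels on the right.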
with gr.Blocks(analytics_enabled=True, title=PAPER_TITLE, theme="soft") as demo:
# LAYOUT
gr.Markdown(MARKDOWN_DESCRIPTION)
with gr.Row():
with gr.Column():
curr_image = gr.Image(label="input", type="pil")
_orig_image = gr.Image(
label="orig. image", type="pil", visible=False, interactive=False
)
alpha_slider = gr.Slider(0.0, 1.0, value=0.7, step=0.1, label="alpha")
with gr.Row():
clear_button = gr.Button(value="Clear", variant="secondary")
run_button = gr.Button(value="Submit", variant="primary")
with gr.Column():
output_label = gr.Label(label="output", num_top_classes=5)
examples = gr.Examples(
examples=glob(os.path.join(os.path.dirname(__file__), "examples", "*.jpg")),
inputs=[_orig_image],
outputs=[output_label],
fn=image_inference,
cache_examples=True,
)
gr.Markdown(f"Check out the <a href={PAPER_URL}>original paper</a> for more information.")
# INTERACTIONS
    # - change: re-run preprocessing whenever the hidden original image is set
_orig_image.change(prepare_image, [_orig_image], [curr_image, _orig_image])
    # - upload: store the uploaded image as the original and reset the output
curr_image.upload(save_original_image, [curr_image], [curr_image, _orig_image])
curr_image.upload(lambda: None, [], [output_label])
    # - clear: wipe the hidden original image and the output label
curr_image.clear(lambda: (None, None), [], [_orig_image, output_label])
    # - click: clear all fields, or run inference on the current image
clear_button.click(lambda: (None, None, None), [], [curr_image, _orig_image, output_label])
run_button.click(image_inference, [curr_image, alpha_slider], [output_label])


if __name__ == "__main__":
demo.queue()
demo.launch()