# Spaces status at capture time: Runtime error
from typing import Optional

import gradio as gr
import torch

from src.nn import CaSED
# Title displayed by the demo interface.
PAPER_TITLE = "Vocabulary-free Image Classification"

# Description displayed by the demo interface: a centered row of badge links
# (code, paper, website) followed by a short summary of the method.
PAPER_DESCRIPTION = """
<div style="display: flex; align-items: center; justify-content: center; margin-bottom: 1rem;">
<a href="https://github.com/altndrr/vic" style="margin-right: 0.5rem;">
<img src="https://img.shields.io/badge/code-github.altndrr%2Fvic-blue.svg"/>
</a>
<a href="https://arxiv.org/abs/2306.00917" style="margin-right: 0.5rem;">
<img src="https://img.shields.io/badge/paper-arXiv%3A2306.00917-B31B1B.svg"/>
</a>
<a href="https://altndrr.github.io/vic/" style="margin-right: 0.5rem;">
<img src="https://img.shields.io/badge/website-gh--pages.altndrr%2Fvic-success.svg"/>
</a>
</div>
Vocabulary-free Image Classification aims to assign a class to an image *without* prior knowledge
on the list of class names, thus operating on the semantic class space that contains all the
possible concepts. Our proposed method CaSED finds the best matching category within the
unconstrained semantic space by multimodal data from large vision-language databases. We first
retrieve the semantically most similar captions from a database, from which we extract a set of
candidate categories by applying text parsing and filtering techniques. We further score the
candidates using the multimodal aligned representation of the large pre-trained VLM, *i.e.* CLIP,
to obtain the best-matching category, using *alpha* as a hyperparameter to control the trade-off
between the visual and textual similarity.
"""

# Link to the paper; interpolated into the article text under the demo.
PAPER_URL = "https://arxiv.org/abs/2306.00917"
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One shared CaSED instance, built at import time: moved to DEVICE and then
# `.eval()` is called on it. `vic` invokes this object for every request.
model = CaSED().to(DEVICE).eval()
def vic(filename: str, alpha: Optional[float] = None):
    """Classify an image with the module-level CaSED ``model``.

    Args:
        filename: Path of the image file; forwarded to the model unchanged.
        alpha: Forwarded to the model as the ``alpha`` keyword argument.
            ``None`` is passed through as-is (presumably the model then uses
            its own default -- confirm in ``src.nn.CaSED``).

    Returns:
        A dict mapping each vocabulary entry returned by the model to the
        score found at the same position.
    """
    # The model returns two aligned iterables: candidate names and scores.
    vocabulary, scores = model(filename, alpha=alpha)
    return dict(zip(vocabulary, scores))
def resize_image(image, max_size: int = 256):
    """Resize ``image`` so its shorter side equals ``max_size``.

    The aspect ratio is kept: the longer side is multiplied by the same
    factor, so it ends up greater than or equal to ``max_size``. Fractional
    pixel sizes are truncated with ``int``.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``resize((width, height))`` method (presumably a PIL image --
            confirm against callers).
        max_size: Target length of the shorter side.

    Returns:
        Whatever ``image.resize`` returns for the computed size.

    Raises:
        ZeroDivisionError: If the shorter side of ``image`` is 0.
    """
    width, height = image.size
    if width > height:
        # Landscape: height is the shorter side and is pinned to max_size.
        scaled = max_size * (width / height)
        dims = (scaled, max_size)
    else:
        # Portrait or square: width is pinned to max_size.
        scaled = max_size * (height / width)
        dims = (max_size, scaled)
    return image.resize(tuple(int(d) for d in dims))
# Gradio UI around `vic`: an uploaded image (handed over as a file path) and
# an alpha slider in [0, 1] defaulting to 0.5 go in; the returned
# name -> score dict is shown as a label widget limited to the top 5 classes.
demo = gr.Interface(
    fn=vic,
    inputs=[
        gr.Image(type="filepath", label="input"),
        gr.Slider(0.0, 1.0, value=0.5, label="alpha"),
    ],
    outputs=[gr.Label(num_top_classes=5, label="output")],
    title=PAPER_TITLE,
    description=PAPER_DESCRIPTION,
    article=f"Check out <a href={PAPER_URL}>the original paper</a> for more information.",
    # Example inputs are read from this directory.
    examples="./artifacts/examples/",
    allow_flagging="never",
    theme=gr.themes.Soft(),
)

# Start the app at import/run time, without a public share link.
demo.launch(share=False)