vic / app.py
altndrr's picture
Add first version
a3ee979
raw
history blame
2.89 kB
from pathlib import Path
from typing import Optional

import gradio as gr
import torch

from src.nn import CaSED
PAPER_TITLE = "Vocabulary-free Image Classification"
# arXiv link of the paper; reused below in the description badge so the URL lives in one place.
PAPER_URL = "https://arxiv.org/abs/2306.00917"
# Header shown by the Gradio interface: a row of project badges (code, paper, website)
# followed by a short summary of the method. The markdown emphasis (*without*, *alpha*)
# is intentional. Note: this is an f-string, so literal braces would need doubling.
PAPER_DESCRIPTION = f"""
<div style="display: flex; align-items: center; justify-content: center; margin-bottom: 1rem;">
<a href="https://github.com/altndrr/vic" style="margin-right: 0.5rem;">
<img src="https://img.shields.io/badge/code-github.altndrr%2Fvic-blue.svg"/>
</a>
<a href="{PAPER_URL}" style="margin-right: 0.5rem;">
<img src="https://img.shields.io/badge/paper-arXiv%3A2306.00917-B31B1B.svg"/>
</a>
<a href="https://altndrr.github.io/vic/" style="margin-right: 0.5rem;">
<img src="https://img.shields.io/badge/website-gh--pages.altndrr%2Fvic-success.svg"/>
</a>
</div>
Vocabulary-free Image Classification aims to assign a class to an image *without* prior knowledge
of the list of class names, thus operating on the semantic class space that contains all the
possible concepts. Our proposed method, CaSED, finds the best-matching category within the
unconstrained semantic space by exploiting multimodal data from large vision-language databases.
We first retrieve the semantically most similar captions from a database, from which we extract a
set of candidate categories by applying text parsing and filtering techniques. We further score the
candidates using the multimodal aligned representation of the large pre-trained VLM, *i.e.* CLIP,
to obtain the best-matching category, using *alpha* as a hyperparameter to control the trade-off
between the visual and textual similarity.
"""
# Use the GPU when PyTorch reports CUDA as available; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Build the CaSED model once at import time. `vic` reuses this single instance for
# every request. It is moved onto DEVICE and switched to eval mode via `.eval()`.
model = CaSED()
model = model.to(DEVICE).eval()
def vic(filename: str, alpha: Optional[float] = None):
    """Classify an image with CaSED, without a predefined vocabulary.

    Args:
        filename: Path of the input image. `gr.Image(type="filepath")` passes a path
            string to this function.
        alpha: Trade-off between visual and textual similarity, forwarded to the model
            unchanged. With `None`, the model presumably applies its own default
            (NOTE(review): confirm in `CaSED`).

    Returns:
        A dict mapping each candidate class name to its score as a plain `float`.
        This is the shape `gr.Label` consumes.
    """
    # Pure inference: disable autograd so no computation graph is built and kept
    # alive per request (harmless if the model already does this internally).
    with torch.no_grad():
        vocabulary, scores = model(filename, alpha=alpha)
    # Coerce every score to a Python float so the Label output receives plain
    # numbers even if the model yields tensor elements.
    return {label: float(score) for label, score in zip(vocabulary, scores)}
def resize_image(image, max_size: int = 256):
    """Resize `image` so its longer side is `max_size` pixels, keeping the aspect ratio.

    The original version set the *shorter* side to `max_size`, so the longer side
    ended up larger than the advertised maximum. This version scales by the longer
    side instead.

    Args:
        image: An image exposing `size` as `(width, height)` and a
            `resize((width, height))` method (presumably a `PIL.Image.Image`).
        max_size: Target length, in pixels, of the image's longer side.

    Returns:
        Whatever `image.resize` returns (a new image for PIL).
    """
    width, height = image.size
    scale = max_size / max(width, height)
    # Round instead of truncating (avoids e.g. 255.999... -> 255), and never
    # let a very thin image collapse to a zero-pixel side.
    new_width = max(1, round(width * scale))
    new_height = max(1, round(height * scale))
    return image.resize((new_width, new_height))
# Gradio UI: an image plus an alpha slider go in, the top-5 predicted classes come out.
# `demo` stays a module-level name so tools that import this file can still find it.
demo = gr.Interface(
    fn=vic,
    inputs=[
        # type="filepath" makes Gradio pass `vic` a path string, matching its signature.
        gr.Image(type="filepath", label="input"),
        gr.Slider(0.0, 1.0, value=0.5, label="alpha"),
    ],
    outputs=[gr.Label(num_top_classes=5, label="output")],
    title=PAPER_TITLE,
    description=PAPER_DESCRIPTION,
    article=f"Check out <a href={PAPER_URL}>the original paper</a> for more information.",
    # Resolve the examples folder relative to this file, not the current working
    # directory, so the demo also works when launched from another directory.
    examples=str(Path(__file__).resolve().parent / "artifacts" / "examples"),
    allow_flagging="never",
    theme=gr.themes.Soft(),
)

# Only start the server when run as a script; importing this module (e.g. by
# Gradio's reload mode or by tests) no longer launches a blocking web server.
if __name__ == "__main__":
    demo.launch(share=False)