import os
from glob import glob
from typing import Optional

import gradio as gr
import torch
from torchvision.transforms.functional import resize, to_pil_image
from transformers import AutoModel, CLIPProcessor

PAPER_TITLE = "Vocabulary-free Image Classification"
PAPER_URL = "https://arxiv.org/abs/2306.00917"
MARKDOWN_DESCRIPTION = """
<div style="display: flex; align-items: center; justify-content: center; margin-bottom: 1rem;">
    <h1>Vocabulary-free Image Classification</h1>
</div>

<div style="display: flex;
            flex-wrap: wrap;
            align-items: center;
            justify-content: center;
            margin-bottom: 1rem;">
    <a href="https://github.com/altndrr/vic" style="margin-right: 0.5rem; margin-bottom: 0.5rem;">
        <img src="https://img.shields.io/badge/code-github.altndrr%2Fvic-blue.svg"/>
    </a>
    <a href="https://huggingface.co/spaces/altndrr/vic" style="margin-right: 0.5rem;
                                                               margin-bottom: 0.5rem;">
        <img src="https://img.shields.io/badge/demo-hf.altndrr%2Fvic-yellow.svg"/>
    </a>
    <a href="https://arxiv.org/abs/2306.00917" style="margin-right: 0.5rem;
                                                      margin-bottom: 0.5rem;">
        <img src="https://img.shields.io/badge/paper-arXiv.2306.00917-B31B1B.svg"/>
    </a>
    <a href="https://alessandroconti.me/papers/2306.00917" style="margin-right: 0.5rem;
                                                                  margin-bottom: 0.5rem;">
        <img src="https://img.shields.io/badge/website-gh--pages.altndrr%2Fvic-success.svg"/>
    </a>
</div>
"""


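# Load the CASED model (trust_remote_code pulls in its custom model code) and the CLIP
# processor used here for image preprocessing.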
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL = AutoModel.from_pretrained("altndrr/cased", trust_remote_code=True).to(DEVICE)
PROCESSOR = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")


def save_original_image(image: gr.Image):
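    """Resize the uploaded image to the processor's shortest-edge size and keep a copy of it."""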
    if image is None:
        return None, None

    size = PROCESSOR.image_processor.size["shortest_edge"]
    size = min(size) if isinstance(size, tuple) else size
    image = resize(image, size)

    return image, image.copy()


def prepare_image(image: gr.Image):
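    """Preprocess the stored original with the CLIP image processor (normalization off) for display."""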
    if image is None:
        return None, None

    PROCESSOR.image_processor.do_normalize = False
    image_tensor = PROCESSOR(images=[image], return_tensors="pt", padding=True)
    PROCESSOR.image_processor.do_normalize = True
    image_tensor = image_tensor.pixel_values[0]
    curr_image = to_pil_image(image_tensor)

    return curr_image, image.copy()


def image_inference(image: gr.Image, alpha: Optional[float] = None):
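    """Score the image against the model's generated vocabulary and return a {class: score} dict."""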
    if image is None:
        return None

    images = PROCESSOR(images=[image], return_tensors="pt", padding=True)
    images = images.to(DEVICE)  # keep the inputs on the same device as the model

    with torch.no_grad():
        outputs = MODEL(images, alpha=alpha)
    vocabulary = outputs["vocabularies"][0]
    scores = outputs["scores"][0].tolist()
    confidences = dict(zip(vocabulary, scores))

    return confidences


with gr.Blocks(analytics_enabled=True, title=PAPER_TITLE, theme="soft") as demo:
    # LAYOUT
    gr.Markdown(MARKDOWN_DESCRIPTION)
    with gr.Row():
        with gr.Column():
            curr_image = gr.Image(label="input", type="pil")
            _orig_image = gr.Image(
                label="orig. image", type="pil", visible=False, interactive=False
            )
            alpha_slider = gr.Slider(0.0, 1.0, value=0.7, step=0.1, label="alpha")
            with gr.Row():
                clear_button = gr.Button(value="Clear", variant="secondary")
                run_button = gr.Button(value="Submit", variant="primary")
        with gr.Column():
            output_label = gr.Label(label="output", num_top_classes=5)
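    # Precompute predictions for the bundled example images; selecting one fills _orig_image.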
    examples = gr.Examples(
        examples=glob(os.path.join(os.path.dirname(__file__), "examples", "*.jpg")),
        inputs=[_orig_image],
        outputs=[output_label],
        fn=image_inference,
        cache_examples=True,
    )
    gr.Markdown(f"Check out the <a href={PAPER_URL}>original paper</a> for more information.")

    # INTERACTIONS
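    # Uploading resizes the image and stores a copy in _orig_image; its change event then
    # runs prepare_image so the visible input matches the crop the model sees.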
    # - change
    _orig_image.change(prepare_image, [_orig_image], [curr_image, _orig_image])

    # - upload
    curr_image.upload(save_original_image, [curr_image], [curr_image, _orig_image])
    curr_image.upload(lambda: None, [], [output_label])

    # - clear
    curr_image.clear(lambda: (None, None), [], [_orig_image, output_label])

    # - click
    clear_button.click(lambda: (None, None, None), [], [curr_image, _orig_image, output_label])
    run_button.click(image_inference, [curr_image, alpha_slider], [output_label])


if __name__ == "__main__":
    demo.queue()
    demo.launch()