Rewrite interface (#1)
Commit: 2690b62e772a0b91fc15d70d26955a37557376ef
- .gitignore +1 -0
- app.py +63 -43
- examples/log.csv +0 -11
- pyproject.toml +1 -1
- requirements.txt +5 -4
.gitignore
CHANGED
@@ -149,3 +149,4 @@ dmypy.json
 
 # Template
 /artifacts/models/databases/*/
+/gradio_cached_examples/*
app.py
CHANGED
@@ -1,12 +1,19 @@
+import os
+from glob import glob
 from typing import Optional
 
 import gradio as gr
 import torch
-from
+from torchvision.transforms.functional import to_pil_image
 from transformers import AutoModel, CLIPProcessor
 
 PAPER_TITLE = "Vocabulary-free Image Classification"
-PAPER_DESCRIPTION = """
+PAPER_URL = "https://arxiv.org/abs/2306.00917"
+MARKDOWN_DESCRIPTION = """
+<div style="display: flex; align-items: center; justify-content: center; margin-bottom: 1rem;">
+<h1>Vocabulary-free Image Classification</h1>
+</div>
+
 <div style="display: flex; align-items: center; justify-content: center; margin-bottom: 1rem;">
 <a href="https://github.com/altndrr/vic" style="margin-right: 0.5rem;">
 <img src="https://img.shields.io/badge/code-github.altndrr%2Fvic-blue.svg"/>
@@ -21,31 +28,35 @@ PAPER_DESCRIPTION = """
 <img src="https://img.shields.io/badge/website-gh--pages.altndrr%2Fvic-success.svg"/>
 </a>
 </div>
+"""
 
 
-unconstrained semantic space by multimodal data from large vision-language databases.
-
-To assign a label to an image, we:
-1. extract the image features using a pre-trained Vision-Language Model (VLM);
-2. retrieve the semantically most similar captions from a textual database;
-3. extract from the captions a set of candidate categories by applying text parsing and filtering;
-4. score the candidates using the multimodal aligned representation of the pre-trained VLM to
-obtain the best-matching category.
-"""
-PAPER_URL = "https://arxiv.org/abs/2306.00917"
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+MODEL = AutoModel.from_pretrained("altndrr/cased", trust_remote_code=True).to(DEVICE)
+PROCESSOR = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
 
 
+def prepare_image(image: gr.Image):
+    if image is None:
+        return None, None
+
+    PROCESSOR.image_processor.do_normalize = False
+    image_tensor = PROCESSOR(images=[image], return_tensors="pt", padding=True)
+    PROCESSOR.image_processor.do_normalize = True
+    image_tensor = image_tensor.pixel_values[0]
+    curr_image = to_pil_image(image_tensor)
+
+    return curr_image, image.copy()
+
+
+def image_inference(image: gr.Image, alpha: Optional[float] = None):
+    if image is None:
+        return None
+
+    images = PROCESSOR(images=[image], return_tensors="pt", padding=True)
+
+    with torch.no_grad():
+        outputs = MODEL(images, alpha=alpha)
     vocabulary = outputs["vocabularies"][0]
     scores = outputs["scores"][0].tolist()
    confidences = dict(zip(vocabulary, scores))
@@ -53,26 +64,35 @@ def vic(filename: str, alpha: Optional[float] = None):
     return confidences
 
 
-gr.
-label="alpha"
-)
+with gr.Blocks(analytics_enabled=True, title=PAPER_TITLE, theme="soft") as demo:
+    gr.Markdown(MARKDOWN_DESCRIPTION)
+    with gr.Row():
+        with gr.Column():
+            curr_image = gr.Image(label="input", type="pil", height=300)
+            orig_image = gr.Image(
+                label="orig. image", type="pil", visible=False, interactive=False
+            )
+            alpha_slider = gr.Slider(0.0, 1.0, value=0.7, step=0.1, label="alpha")
+            with gr.Row():
+                clear_button = gr.ClearButton([curr_image, orig_image])
+                run_button = gr.Button(value="Submit", variant="primary")
+        with gr.Column():
+            output_label = gr.Label(label="output", num_top_classes=5)
+    examples = gr.Examples(
+        examples=glob(os.path.join(os.path.dirname(__file__), "examples", "*.jpg")),
+        inputs=[orig_image],
+        outputs=[output_label],
+        fn=image_inference,
+        cache_examples=True,
+    )
+    gr.Markdown(f"Check out the <a href={PAPER_URL}>original paper</a> for more information.")
+
+    curr_image.upload(prepare_image, [curr_image], [curr_image, orig_image])
+    curr_image.clear(lambda: None, [], [orig_image])
+    orig_image.change(prepare_image, [orig_image], [curr_image, orig_image])
+    run_button.click(image_inference, [curr_image, alpha_slider], [output_label])
+
+
+if __name__ == "__main__":
+    demo.queue()
+    demo.launch(server_name="0.0.0.0", server_port=7860)
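For reference, the rewritten app.py keeps the whole inference path inside image_inference, so the same call can be reproduced outside the Gradio UI. The sketch below mirrors that function with the checkpoints used above (altndrr/cased loaded via trust_remote_code and the openai/clip-vit-large-patch14 processor); the image path is a placeholder for one of the Space's example JPEGs, and alpha=0.7 matches the slider default in the demo.

# Minimal standalone sketch of the inference path from the rewritten app.py.
# It mirrors image_inference() above; the image path below is a placeholder
# for any of the example images shipped with the Space.
import torch
from PIL import Image
from transformers import AutoModel, CLIPProcessor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModel.from_pretrained("altndrr/cased", trust_remote_code=True).to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

image = Image.open("examples/basketball.jpg")  # placeholder example image
inputs = processor(images=[image], return_tensors="pt", padding=True)

with torch.no_grad():
    outputs = model(inputs, alpha=0.7)  # alpha matches the demo's slider default

vocabulary = outputs["vocabularies"][0]  # candidate class names for this image
scores = outputs["scores"][0].tolist()   # one score per candidate
print(max(zip(vocabulary, scores), key=lambda pair: pair[1]))  # best-matching category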
examples/log.csv
DELETED
@@ -1,11 +0,0 @@
-image_fp
-basketball.jpg
-cassowary.jpg
-colosseum.jpg
-desk.jpg
-kitchen.jpg
-monkey.jpg
-park.jpg
-ramen.jpg
-sagrada.jpg
-venice.jpg
pyproject.toml
CHANGED
@@ -15,7 +15,7 @@ line_length = 99
 count = true
 ignore = ["E402"]
 per-file-ignores = ["__init__.py:F401"]
-exclude = [
+exclude = []
 max-line-length = 99
 
 [tool.isort]
requirements.txt
CHANGED
@@ -1,7 +1,8 @@
 torch==2.0.1
+torchvision==0.15.2
 faiss-cpu==1.7.4
-flair==0.
-gradio==
-inflect==
+flair==0.13.0
+gradio==4.7.1
+inflect==7.0.0
 nltk==3.8.1
-transformers==4.
+transformers==4.35.1