altndrr committed
Commit 2690b62 (parent: 54f2384)

Rewrite interface

Files changed (5):
  1. .gitignore +1 -0
  2. app.py +63 -43
  3. examples/log.csv +0 -11
  4. pyproject.toml +1 -1
  5. requirements.txt +5 -4
.gitignore CHANGED
```diff
@@ -149,3 +149,4 @@ dmypy.json
 
 # Template
 /artifacts/models/databases/*/
+/gradio_cached_examples/*
```
app.py CHANGED
```diff
@@ -1,12 +1,19 @@
+import os
+from glob import glob
 from typing import Optional
 
 import gradio as gr
 import torch
-from PIL import Image
+from torchvision.transforms.functional import to_pil_image
 from transformers import AutoModel, CLIPProcessor
 
 PAPER_TITLE = "Vocabulary-free Image Classification"
-PAPER_DESCRIPTION = """
+PAPER_URL = "https://arxiv.org/abs/2306.00917"
+MARKDOWN_DESCRIPTION = """
+<div style="display: flex; align-items: center; justify-content: center; margin-bottom: 1rem;">
+    <h1>Vocabulary-free Image Classification</h1>
+</div>
+
 <div style="display: flex; align-items: center; justify-content: center; margin-bottom: 1rem;">
     <a href="https://github.com/altndrr/vic" style="margin-right: 0.5rem;">
         <img src="https://img.shields.io/badge/code-github.altndrr%2Fvic-blue.svg"/>
@@ -21,31 +28,35 @@ PAPER_DESCRIPTION = """
         <img src="https://img.shields.io/badge/website-gh--pages.altndrr%2Fvic-success.svg"/>
     </a>
 </div>
+"""
 
 
-Vocabulary-free Image Classification aims to assign a class to an image *without* prior knowledge
-on the list of class names, thus operating on the semantic class space that contains all the
-possible concepts. Our proposed method CaSED finds the best matching category within the
-unconstrained semantic space by multimodal data from large vision-language databases.
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+MODEL = AutoModel.from_pretrained("altndrr/cased", trust_remote_code=True).to(DEVICE)
+PROCESSOR = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
 
-To assign a label to an image, we:
-1. extract the image features using a pre-trained Vision-Language Model (VLM);
-2. retrieve the semantically most similar captions from a textual database;
-3. extract from the captions a set of candidate categories by applying text parsing and filtering;
-4. score the candidates using the multimodal aligned representation of the pre-trained VLM to
-obtain the best-matching category.
-"""
-PAPER_URL = "https://arxiv.org/abs/2306.00917"
 
+def prepare_image(image: gr.Image):
+    if image is None:
+        return None, None
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-model = AutoModel.from_pretrained("altndrr/cased", trust_remote_code=True).to(device)
-processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+    PROCESSOR.image_processor.do_normalize = False
+    image_tensor = PROCESSOR(images=[image], return_tensors="pt", padding=True)
+    PROCESSOR.image_processor.do_normalize = True
+    image_tensor = image_tensor.pixel_values[0]
+    curr_image = to_pil_image(image_tensor)
 
+    return curr_image, image.copy()
 
-def vic(filename: str, alpha: Optional[float] = None):
-    images = processor(images=[Image.open(filename)], return_tensors="pt", padding=True)
-    outputs = model(images, alpha=alpha)
+
+def image_inference(image: gr.Image, alpha: Optional[float] = None):
+    if image is None:
+        return None
+
+    images = PROCESSOR(images=[image], return_tensors="pt", padding=True)
+
+    with torch.no_grad():
+        outputs = MODEL(images, alpha=alpha)
     vocabulary = outputs["vocabularies"][0]
     scores = outputs["scores"][0].tolist()
    confidences = dict(zip(vocabulary, scores))
@@ -53,26 +64,35 @@ def vic(filename: str, alpha: Optional[float] = None):
     return confidences
 
 
-demo = gr.Interface(
-    fn=vic,
-    inputs=[
-        gr.Image(type="filepath", label="input"),
-        gr.Slider(
-            0.0,
-            1.0,
-            value=0.5,
-            label="alpha",
-            info="trade-off between the text (left) and image (right) modality",
-        ),
-    ],
-    outputs=[gr.Label(num_top_classes=5, label="output")],
-    title=PAPER_TITLE,
-    description=PAPER_DESCRIPTION,
-    article=f"Check out <a href={PAPER_URL}>the original paper</a> for more information.",
-    examples="./examples/",
-    allow_flagging="never",
-    theme=gr.themes.Soft(),
-    thumbnail="https://altndrr.github.io/vic/assets/images/method.png",
-)
-
-demo.launch(share=False)
+with gr.Blocks(analytics_enabled=True, title=PAPER_TITLE, theme="soft") as demo:
+    gr.Markdown(MARKDOWN_DESCRIPTION)
+    with gr.Row():
+        with gr.Column():
+            curr_image = gr.Image(label="input", type="pil", height=300)
+            orig_image = gr.Image(
+                label="orig. image", type="pil", visible=False, interactive=False
+            )
+            alpha_slider = gr.Slider(0.0, 1.0, value=0.7, step=0.1, label="alpha")
+            with gr.Row():
+                clear_button = gr.ClearButton([curr_image, orig_image])
+                run_button = gr.Button(value="Submit", variant="primary")
+        with gr.Column():
+            output_label = gr.Label(label="output", num_top_classes=5)
+            examples = gr.Examples(
+                examples=glob(os.path.join(os.path.dirname(__file__), "examples", "*.jpg")),
+                inputs=[orig_image],
+                outputs=[output_label],
+                fn=image_inference,
+                cache_examples=True,
+            )
+    gr.Markdown(f"Check out the <a href={PAPER_URL}>original paper</a> for more information.")
+
+    curr_image.upload(prepare_image, [curr_image], [curr_image, orig_image])
+    curr_image.clear(lambda: None, [], [orig_image])
+    orig_image.change(prepare_image, [orig_image], [curr_image, orig_image])
+    run_button.click(image_inference, [curr_image, alpha_slider], [output_label])
+
+
+if __name__ == "__main__":
+    demo.queue()
+    demo.launch(server_name="0.0.0.0", server_port=7860)
```
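
The description removed from app.py above summarized the CaSED pipeline: extract image features with a pre-trained vision-language model, retrieve the most similar captions from a textual database, parse candidate categories from those captions, and score the candidates with the VLM's multimodal representation. For reference, here is a minimal standalone sketch of the same inference call that the new `image_inference` wraps. It reuses the model and processor identifiers from app.py; the example path and the fixed `alpha=0.7` (the new slider default) are illustrative choices, not part of the commit.

```python
# Standalone sketch mirroring image_inference() above; not part of the commit.
# Assumes the same `altndrr/cased` remote code and CLIP processor used in app.py.
import torch
from PIL import Image
from transformers import AutoModel, CLIPProcessor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModel.from_pretrained("altndrr/cased", trust_remote_code=True).to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

image = Image.open("examples/cassowary.jpg")  # any of the repository's example images
inputs = processor(images=[image], return_tensors="pt", padding=True)

with torch.no_grad():
    outputs = model(inputs, alpha=0.7)  # alpha trades off the text vs. image modality

vocabulary = outputs["vocabularies"][0]  # candidate class names generated by CaSED
scores = outputs["scores"][0].tolist()   # matching score for each candidate
for name, score in sorted(zip(vocabulary, scores), key=lambda x: x[1], reverse=True)[:5]:
    print(f"{name}: {score:.3f}")
```

As in app.py, the caller never supplies a class list: the remote `altndrr/cased` code returns both the generated vocabulary and its scores.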
examples/log.csv DELETED
```diff
@@ -1,11 +0,0 @@
-image_fp
-basketball.jpg
-cassowary.jpg
-colosseum.jpg
-desk.jpg
-kitchen.jpg
-monkey.jpg
-park.jpg
-ramen.jpg
-sagrada.jpg
-venice.jpg
```
pyproject.toml CHANGED
```diff
@@ -15,7 +15,7 @@ line_length = 99
 count = true
 ignore = ["E402"]
 per-file-ignores = ["__init__.py:F401"]
-exclude = ["data/*","logs/*"]
+exclude = []
 max-line-length = 99
 
 [tool.isort]
```
requirements.txt CHANGED
```diff
@@ -1,7 +1,8 @@
 torch==2.0.1
+torchvision==0.15.2
 faiss-cpu==1.7.4
-flair==0.12.2
-gradio==3.33.1
-inflect==6.0.4
+flair==0.13.0
+gradio==4.7.1
+inflect==7.0.0
 nltk==3.8.1
-transformers==4.29.2
+transformers==4.35.1
```
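
The pins above move the Space to Gradio 4.x and newer flair, inflect, and transformers releases, alongside the new torchvision dependency used by `prepare_image`. As a purely illustrative aid (not part of the commit), a short standard-library check that a local environment matches these pins before running `python app.py`:

```python
# Illustrative only: compare installed versions against the pins in requirements.txt.
from importlib.metadata import PackageNotFoundError, version

PINS = {
    "torch": "2.0.1",
    "torchvision": "0.15.2",
    "faiss-cpu": "1.7.4",
    "flair": "0.13.0",
    "gradio": "4.7.1",
    "inflect": "7.0.0",
    "nltk": "3.8.1",
    "transformers": "4.35.1",
}

for package, expected in PINS.items():
    try:
        installed = version(package)
    except PackageNotFoundError:
        installed = "not installed"
    status = "OK" if installed == expected else f"expected {expected}"
    print(f"{package:<14} {installed:<14} {status}")
```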