altndrr committed
Commit 07a2d78
1 Parent(s): ef912f3
Files changed (5)
  1. README.md +6 -6
  2. app.py +1 -14
  3. artifacts/models/retrieval/indices.json +1 -1
  4. src/nn.py +3 -3
  5. src/retrieval.py +7 -19
README.md CHANGED
@@ -20,9 +20,9 @@ Recent advances in large vision-language models have revolutionized the image cl
 
 <div align="center">
 
-| <img src="https://altndrr.github.io/vic/assets/images/task_left.png"> | <img src="https://altndrr.github.io/vic/assets/images/task_right.png"> |
-| :----------------------------------------------: | :----------------------------------------------: |
-| Vision Language Model (VLM)-based classification | Vocabulary-free Image Classification |
+| <img src="https://altndrr.github.io/vic/assets/images/task_left.png"> | <img src="https://altndrr.github.io/vic/assets/images/task_right.png"> |
+| :-------------------------------------------------------------------: | :--------------------------------------------------------------------: |
+| Vision Language Model (VLM)-based classification | Vocabulary-free Image Classification |
 
 </div>
 
@@ -30,7 +30,7 @@ In this work, we first empirically verify that representing this semantic space
 
 <div align="center">
 
-| <img src="https://altndrr.github.io/vic/assets/images/method.png"> |
+| <img src="https://altndrr.github.io/vic/assets/images/method.png"> |
 | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
 | Overview of CaSED. Given an input image, CaSED retrieves the most relevant captions from an external database filtering them to extract candidate categories. We classify image-to-text and text-to-text, using the retrieved captions centroid as the textual counterpart of the input image. |
 
@@ -42,11 +42,11 @@ If you find this work useful, please consider citing:
 
 ```latex
 @misc{conti2023vocabularyfree,
-    title={Vocabulary-free Image Classification},
+    title={Vocabulary-free Image Classification},
     author={Alessandro Conti and Enrico Fini and Massimiliano Mancini and Paolo Rota and Yiming Wang and Elisa Ricci},
     year={2023},
     eprint={2306.00917},
     archivePrefix={arXiv},
     primaryClass={cs.CV}
 }
-```
+```
app.py CHANGED
@@ -49,19 +49,6 @@ def vic(filename: str, alpha: Optional[float] = None):
 
     return confidences
 
-def resize_image(image, max_size: int = 256):
-    """Resize image to max_size keeping the aspect ratio."""
-    width, height = image.size
-    if width > height:
-        ratio = width / height
-        new_width = max_size * ratio
-        new_height = max_size
-    else:
-        ratio = height / width
-        new_width = max_size
-        new_height = max_size * ratio
-    return image.resize((int(new_width), int(new_height)))
-
 
 demo = gr.Interface(
     fn=vic,
@@ -80,7 +67,7 @@ demo = gr.Interface(
     description=PAPER_DESCRIPTION,
     article=f"Check out <a href={PAPER_URL}>the original paper</a> for more information.",
     examples="./artifacts/examples/",
-    allow_flagging='never',
+    allow_flagging="never",
    theme=gr.themes.Soft(),
    thumbnail="https://altndrr.github.io/vic/assets/images/method.png",
 )
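
For context, a Gradio `Interface` like the `demo` object configured above is served by calling `launch()` on it. A minimal sketch, not part of this commit, assuming `app.py` defines `vic` and the constants as shown:

```python
# Minimal sketch: serve the interface defined above.
# On a Hugging Face Space, running app.py with this call is enough to expose the demo.
if __name__ == "__main__":
    demo.launch()
```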
artifacts/models/retrieval/indices.json CHANGED
@@ -1,3 +1,3 @@
 {
     "ViT-L-14_CC12M": "./artifacts/models/databases/cc12m/vit-l-14/"
-}
+}
src/nn.py CHANGED
@@ -11,7 +11,7 @@ import torch
 from open_clip.transformer import Transformer
 from PIL import Image
 
-from src.retrieval import ArrowMetadataProvider, meta_to_dict
+from src.retrieval import ArrowMetadataProvider
 from src.transforms import TextCompose, default_vocabulary_transforms
 
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -92,7 +92,7 @@ class CaSED(torch.nn.Module):
         # load faiss indices
         indices_list_dir = Path(self.hparams["artifact_dir"]) / "models" / "retrieval"
         indices_fp = indices_list_dir / "indices.json"
-        self.indices = json.load(open(indices_fp, "r"))
+        self.indices = json.load(open(indices_fp))
 
         # load faiss indices and metadata providers
         self.resources = {}
@@ -165,7 +165,7 @@ class CaSED(torch.nn.Module):
             output = {}
             meta = None if key + 1 > len(metadata) else metadata[key]
             if meta is not None:
-                output.update(meta_to_dict(meta))
+                output.update(meta)
             output["id"] = i.item()
             output["similarity"] = d.item()
             results.append(output)
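
For reference, `metadata[key]` here presumably holds a record returned by `ArrowMetadataProvider.get` (see `src/retrieval.py` below), which yields plain per-row dicts via `to_pandas().to_dict("records")`, so it can be merged directly with `dict.update`. A toy sketch with hypothetical field names:

```python
# Hypothetical record and scores, for illustration only; the field names are assumptions.
meta = {"caption": "a photo of a dog", "image_path": "00012.jpg"}

output = {}
output.update(meta)          # merge the retrieved metadata columns into the result
output["id"] = 42            # presumably the faiss neighbour index (i.item() above)
output["similarity"] = 0.87  # presumably the retrieval score (d.item() above)
```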
src/retrieval.py CHANGED
@@ -1,17 +1,17 @@
 from pathlib import Path
+from typing import Optional
 
-import pyarrow as pa
 import numpy as np
+import pyarrow as pa
 
 
 class ArrowMetadataProvider:
     """The arrow metadata provider provides metadata from contiguous ids using arrow.
 
-    Code taken from:
-    https://github.dev/rom1504/clip-retrieval
+    Code taken from: https://github.dev/rom1504/clip-retrieval
     """
 
-    def __init__(self, arrow_folder):
+    def __init__(self, arrow_folder: str):
         arrow_files = [str(a) for a in sorted(Path(arrow_folder).glob("**/*")) if a.is_file()]
         self.table = pa.concat_tables(
             [
@@ -20,23 +20,11 @@ class ArrowMetadataProvider:
             ]
         )
 
-    def get(self, ids, cols=None):
-        """implement the get method from the arrow metadata provide, get metadata from ids"""
+    def get(self, ids: np.ndarray, cols: Optional[list] = None):
+        """Implement the get method from the arrow metadata provide, get metadata from ids."""
         if cols is None:
             cols = self.table.schema.names
         else:
             cols = list(set(self.table.schema.names) & set(cols))
-        t = pa.concat_tables([self.table[i:(i + 1)] for i in ids])
+        t = pa.concat_tables([self.table[i:j] for i, j in zip(ids, ids + 1)])
         return t.select(cols).to_pandas().to_dict("records")
-
-
-def meta_to_dict(meta):
-    """Convert a metadata list to a dictionary."""
-    output = {}
-    for k, v in meta.items():
-        if isinstance(v, bytes):
-            v = v.decode()
-        elif type(v).__module__ == np.__name__:
-            v = v.item()
-        output[k] = v
-    return output
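
The updated `get` replaces per-id slicing with paired bounds from `zip(ids, ids + 1)`. A small self-contained sketch, using a toy table and an assumed column name rather than the real CC12M metadata, showing that the two formulations select the same rows:

```python
import numpy as np
import pyarrow as pa

# Toy stand-in for the metadata table; the "caption" column is an assumption.
table = pa.table({"caption": ["a dog", "a cat", "a tree", "a car"]})
ids = np.array([0, 2])  # e.g. neighbour ids returned by a faiss search

# Previous formulation: one single-row slice per id.
old_rows = pa.concat_tables([table[i:(i + 1)] for i in ids])
# New formulation: explicit (start, stop) pairs.
new_rows = pa.concat_tables([table[i:j] for i, j in zip(ids, ids + 1)])

assert old_rows.equals(new_rows)
print(new_rows.to_pandas().to_dict("records"))
# [{'caption': 'a dog'}, {'caption': 'a tree'}]
```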