Lint code

Files changed:
- README.md (+6 -6)
- app.py (+1 -14)
- artifacts/models/retrieval/indices.json (+1 -1)
- src/nn.py (+3 -3)
- src/retrieval.py (+7 -19)
README.md
CHANGED
@@ -20,9 +20,9 @@ Recent advances in large vision-language models have revolutionized the image cl
 
 <div align="center">
 
-| <img src="https://altndrr.github.io/vic/assets/images/task_left.png">
-
-
+| <img src="https://altndrr.github.io/vic/assets/images/task_left.png"> | <img src="https://altndrr.github.io/vic/assets/images/task_right.png"> |
+| :-------------------------------------------------------------------: | :--------------------------------------------------------------------: |
+| Vision Language Model (VLM)-based classification | Vocabulary-free Image Classification |
 
 </div>
 
@@ -30,7 +30,7 @@ In this work, we first empirically verify that representing this semantic space
 
 <div align="center">
 
-
+| <img src="https://altndrr.github.io/vic/assets/images/method.png"> |
 | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
 | Overview of CaSED. Given an input image, CaSED retrieves the most relevant captions from an external database filtering them to extract candidate categories. We classify image-to-text and text-to-text, using the retrieved captions centroid as the textual counterpart of the input image. |
 
@@ -42,11 +42,11 @@ If you find this work useful, please consider citing:
 
 ```latex
 @misc{conti2023vocabularyfree,
-title={Vocabulary-free Image Classification},
+title={Vocabulary-free Image Classification},
 author={Alessandro Conti and Enrico Fini and Massimiliano Mancini and Paolo Rota and Yiming Wang and Elisa Ricci},
 year={2023},
 eprint={2306.00917},
 archivePrefix={arXiv},
 primaryClass={cs.CV}
 }
-```
+```
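The caption in the second table summarizes the pipeline: retrieve captions for the input image, extract candidate categories from them, and score the candidates both image-to-text and text-to-text, with the centroid of the retrieved caption embeddings standing in for the image on the text side. Below is a minimal sketch of that scoring step, assuming normalized CLIP-style embeddings and a hypothetical blending weight `alpha`; the repository's exact formulation may differ.

```python
# Hedged sketch of the scoring described in the README caption. The encoders,
# the normalization, and the alpha blend are illustrative assumptions.
import numpy as np


def cased_scores(image_z: np.ndarray, caption_z: np.ndarray, candidate_z: np.ndarray,
                 alpha: float = 0.5) -> np.ndarray:
    """Blend image-to-text and text-to-text similarities over candidate categories."""
    # Centroid of the retrieved caption embeddings acts as the textual
    # counterpart of the input image.
    centroid = caption_z.mean(axis=0)
    centroid /= np.linalg.norm(centroid)

    image_z = image_z / np.linalg.norm(image_z)
    candidate_z = candidate_z / np.linalg.norm(candidate_z, axis=1, keepdims=True)

    image_to_text = candidate_z @ image_z   # candidates vs. the image
    text_to_text = candidate_z @ centroid   # candidates vs. the caption centroid
    return alpha * image_to_text + (1 - alpha) * text_to_text


# Toy usage with random embeddings.
rng = np.random.default_rng(0)
scores = cased_scores(rng.normal(size=16), rng.normal(size=(5, 16)), rng.normal(size=(3, 16)))
```

In the app, a blending weight is exposed as the optional `alpha` argument of `vic(filename, alpha)` (see the app.py hunk header below).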
app.py
CHANGED
@@ -49,19 +49,6 @@ def vic(filename: str, alpha: Optional[float] = None):
 
     return confidences
 
-def resize_image(image, max_size: int = 256):
-    """Resize image to max_size keeping the aspect ratio."""
-    width, height = image.size
-    if width > height:
-        ratio = width / height
-        new_width = max_size * ratio
-        new_height = max_size
-    else:
-        ratio = height / width
-        new_width = max_size
-        new_height = max_size * ratio
-    return image.resize((int(new_width), int(new_height)))
-
 
 demo = gr.Interface(
     fn=vic,
@@ -80,7 +67,7 @@ demo = gr.Interface(
     description=PAPER_DESCRIPTION,
     article=f"Check out <a href={PAPER_URL}>the original paper</a> for more information.",
     examples="./artifacts/examples/",
-    allow_flagging=
+    allow_flagging="never",
     theme=gr.themes.Soft(),
     thumbnail="https://altndrr.github.io/vic/assets/images/method.png",
 )
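The second hunk pins `allow_flagging="never"` on the `gr.Interface`. For reference, here is a self-contained sketch of an interface wired the same way, with a placeholder predict function and placeholder components instead of the app's `vic`, `PAPER_DESCRIPTION`, and examples; it assumes a Gradio 3.x release where `allow_flagging` is still an `Interface` argument.

```python
# Minimal sketch, not the app itself: placeholder predict function and
# components, allow_flagging and theme wired as in the diff above.
import gradio as gr


def predict(image):
    # Stand-in for vic(filename, alpha): return a {label: confidence} mapping.
    return {"example category": 1.0}


demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    allow_flagging="never",   # disable the flagging button, as in the commit
    theme=gr.themes.Soft(),
)

if __name__ == "__main__":
    demo.launch()
```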
artifacts/models/retrieval/indices.json
CHANGED
@@ -1,3 +1,3 @@
 {
     "ViT-L-14_CC12M": "./artifacts/models/databases/cc12m/vit-l-14/"
-}
+}
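This file maps a retrieval index name (CLIP backbone plus caption database) to the folder holding its files; `src/nn.py` reads it with `json.load(open(indices_fp))`. A small illustrative loader, with the path taken from the diff and the existence check added only for the sketch:

```python
# Hedged sketch: read indices.json and resolve each named index to its folder.
import json
from pathlib import Path

indices_fp = Path("artifacts/models/retrieval/indices.json")
with open(indices_fp) as fp:
    # e.g. {"ViT-L-14_CC12M": "./artifacts/models/databases/cc12m/vit-l-14/"}
    indices = json.load(fp)

for name, folder in indices.items():
    folder = Path(folder)
    print(name, "->", folder, "(exists)" if folder.exists() else "(missing)")
```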
src/nn.py
CHANGED
@@ -11,7 +11,7 @@ import torch
 from open_clip.transformer import Transformer
 from PIL import Image
 
-from src.retrieval import ArrowMetadataProvider
+from src.retrieval import ArrowMetadataProvider
 from src.transforms import TextCompose, default_vocabulary_transforms
 
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -92,7 +92,7 @@ class CaSED(torch.nn.Module):
         # load faiss indices
         indices_list_dir = Path(self.hparams["artifact_dir"]) / "models" / "retrieval"
         indices_fp = indices_list_dir / "indices.json"
-        self.indices = json.load(open(indices_fp
+        self.indices = json.load(open(indices_fp))
 
         # load faiss indices and metadata providers
         self.resources = {}
@@ -165,7 +165,7 @@ class CaSED(torch.nn.Module):
             output = {}
             meta = None if key + 1 > len(metadata) else metadata[key]
             if meta is not None:
-                output.update(
+                output.update(meta)
             output["id"] = i.item()
             output["similarity"] = d.item()
             results.append(output)
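The last hunk fixes the result-assembly loop so each retrieved neighbour's metadata record is merged into its output dict before the id and similarity are attached. A toy reconstruction of that loop follows; the loop header and the sample `distances`, `idxs`, and `metadata` values are assumptions, since only the loop body appears in the diff.

```python
# Hedged sketch of the result assembly in src/nn.py. `distances` and `idxs`
# stand in for faiss search output; `metadata` stands in for the records
# returned by ArrowMetadataProvider.get(), and may be shorter than the hit list.
import numpy as np

distances = np.array([0.91, 0.87])
idxs = np.array([42, 7])
metadata = [{"caption": "a photo of a cat"}]

results = []
for key, (d, i) in enumerate(zip(distances, idxs)):
    output = {}
    meta = None if key + 1 > len(metadata) else metadata[key]
    if meta is not None:
        output.update(meta)   # merge the retrieved caption metadata, as in the diff
    output["id"] = i.item()
    output["similarity"] = d.item()
    results.append(output)

print(results)
```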
src/retrieval.py
CHANGED
@@ -1,17 +1,17 @@
 from pathlib import Path
+from typing import Optional
 
-import pyarrow as pa
 import numpy as np
+import pyarrow as pa
 
 
 class ArrowMetadataProvider:
     """The arrow metadata provider provides metadata from contiguous ids using arrow.
 
-    Code taken from:
-    https://github.dev/rom1504/clip-retrieval
+    Code taken from: https://github.dev/rom1504/clip-retrieval
     """
 
-    def __init__(self, arrow_folder):
+    def __init__(self, arrow_folder: str):
         arrow_files = [str(a) for a in sorted(Path(arrow_folder).glob("**/*")) if a.is_file()]
         self.table = pa.concat_tables(
             [
@@ -20,23 +20,11 @@ class ArrowMetadataProvider:
             ]
         )
 
-    def get(self, ids, cols=None):
-        """
+    def get(self, ids: np.ndarray, cols: Optional[list] = None):
+        """Implement the get method from the arrow metadata provide, get metadata from ids."""
         if cols is None:
             cols = self.table.schema.names
         else:
             cols = list(set(self.table.schema.names) & set(cols))
-        t = pa.concat_tables([self.table[i:
+        t = pa.concat_tables([self.table[i:j] for i, j in zip(ids, ids + 1)])
         return t.select(cols).to_pandas().to_dict("records")
-
-
-def meta_to_dict(meta):
-    """Convert a metadata list to a dictionary."""
-    output = {}
-    for k, v in meta.items():
-        if isinstance(v, bytes):
-            v = v.decode()
-        elif type(v).__module__ == np.__name__:
-            v = v.item()
-        output[k] = v
-    return output
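The updated `get()` fetches rows for the requested ids by concatenating one-row slices of the arrow table and converting them to records. A standalone sketch of that slicing pattern on a toy in-memory table (the real provider concatenates arrow files read from a folder):

```python
# Hedged sketch of the get() slicing pattern: take the rows at the given ids
# as one-row slices, concatenate them, and return them as a list of dicts.
# The toy table below stands in for the concatenated arrow metadata files.
import numpy as np
import pyarrow as pa

table = pa.table({"caption": ["a cat", "a dog", "a car"], "url": ["u0", "u1", "u2"]})

ids = np.array([2, 0])
t = pa.concat_tables([table[i:j] for i, j in zip(ids, ids + 1)])
records = t.select(["caption", "url"]).to_pandas().to_dict("records")
print(records)  # [{'caption': 'a car', 'url': 'u2'}, {'caption': 'a cat', 'url': 'u0'}]
```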