import os
from typing import Callable, Optional

import numpy as np
import torch
from transformers import CLIPModel, CLIPProcessor
from transformers.modeling_utils import PreTrainedModel

from .configuration_cased import CaSEDConfig
from .retrieval_cased import RetrievalDatabase, download_retrieval_databases
from .transforms_cased import default_vocabulary_transforms


class CaSEDModel(PreTrainedModel):
    """Transformers module for Category Search from External Databases (CaSED).

    Reference:
        - Conti et al. Vocabulary-free Image Classification. NeurIPS 2023.

    Args:
        config (CaSEDConfig): Configuration class for CaSED.
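
    Example (illustrative sketch, not from the paper; the Hub id below is an
    assumption, point it at the checkpoint you actually use):

        >>> from PIL import Image
        >>> from transformers import AutoModel
        >>> model = AutoModel.from_pretrained("altndrr/cased", trust_remote_code=True)
        >>> images = model.processor(images=Image.open("cat.jpg"), return_tensors="pt")
        >>> outputs = model(images)
        >>> sorted(outputs.keys())
        ['scores', 'vocabularies', 'words']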
    """

    config_class = CaSEDConfig

    def __init__(self, config: CaSEDConfig):
        super().__init__(config)

        # load CLIP
        model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
        self.vision_encoder = model.vision_model
        self.vision_proj = model.visual_projection
        self.language_encoder = model.text_model
        self.language_proj = model.text_projection
        self.logit_scale = model.logit_scale.exp()
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

        # set hparams
        self.hparams = {}
        self.hparams["alpha"] = config.alpha
        self.hparams["index_name"] = config.index_name
        self.hparams["retrieval_num_results"] = config.retrieval_num_results
        self.hparams["cache_dir"] = config.cache_dir

        # create cache dir
        os.makedirs(self.hparams["cache_dir"], exist_ok=True)

        # download data
        download_retrieval_databases(cache_dir=self.hparams["cache_dir"])

        # setup vocabulary
        self.vocabulary = RetrievalDatabase("cc12m", self.hparams["cache_dir"])
        self._vocab_transform = default_vocabulary_transforms()

    @property
    def vocab_transform(self) -> Callable:
        """Get image preprocess transform.

        The getter wraps the transform in a map_reduce function and applies it to a list of images.
        If interested in the transform itself, use `self._vocab_transform`.
        """
        vocab_transform = self._vocab_transform

        def vocabs_transforms(vocabularies: list[list[str]]) -> list[list[str]]:
            return [vocab_transform(vocabulary) for vocabulary in vocabularies]

        return vocabs_transforms

    def get_vocabulary(self, images_z: Optional[torch.Tensor] = None) -> list[list[str]]:
        """Get the vocabulary for a batch of images.

        Args:
            images_z (torch.Tensor): Batch of image embeddings.
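
        Example (illustrative sketch; assumes the retrieval database has already been
        prepared in `__init__` and uses random embeddings only to show the shapes):

            >>> images_z = torch.randn(2, 768)  # 768 is the CLIP ViT-L/14 projection dim
            >>> vocabularies = model.get_vocabulary(images_z=images_z)
            >>> len(vocabularies)  # one vocabulary of retrieved captions per image
            2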
        """
        num_samples = self.hparams["retrieval_num_results"]

        assert images_z is not None

        images_z = images_z / images_z.norm(dim=-1, keepdim=True)
        images_z = images_z.cpu().detach().numpy().tolist()

        if isinstance(images_z[0], float):
            images_z = [images_z]

        query = np.asarray(images_z, dtype="float32")
        results = self.vocabulary.query(query, modality="text", num_samples=num_samples)

        vocabularies = [[r["caption"] for r in result] for result in results]
        return vocabularies

    def forward(self, images: dict, alpha: Optional[float] = None) -> dict:
        """Forward pass.

        Args:
            images (dict): Dictionary with the images. The expected keys are:
                - pixel_values (torch.Tensor): Pixel values of the images.
            alpha (Optional[float]): Coefficient balancing the image-based and the
                caption-based predictions. Defaults to `config.alpha` when not given.
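
        Returns:
            dict: A dictionary with the following keys:
                - scores (torch.Tensor): Predictions over the candidate words, computed
                    as alpha * softmax(scaled image-word similarities) + (1 - alpha) *
                    softmax(scaled caption centroid-word similarities), with words from
                    other images masked out.
                - words (list[str]): Candidate words for the whole batch, flattened.
                - vocabularies (list[list[str]]): Filtered vocabulary of each image.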
        """
        alpha = alpha if alpha is not None else self.hparams["alpha"]

        # forward the images
        images["pixel_values"] = images["pixel_values"].to(self.device)
        images_z = self.vision_proj(self.vision_encoder(**images)[1])
        images_z = images_z / images_z.norm(dim=-1, keepdim=True)
        vocabularies = self.get_vocabulary(images_z=images_z)

        # encode unfiltered words
        unfiltered_words = sum(vocabularies, [])
        texts_z = self.processor(unfiltered_words, return_tensors="pt", padding=True)
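        # CLIP's text encoder accepts at most 77 tokens, so truncate longer captions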
        texts_z["input_ids"] = texts_z["input_ids"][:, :77].to(self.device)
        texts_z["attention_mask"] = texts_z["attention_mask"][:, :77].to(self.device)
        texts_z = self.language_encoder(**texts_z)[1]
        texts_z = self.language_proj(texts_z)
        texts_z = texts_z / texts_z.norm(dim=-1, keepdim=True)

        # generate a text embedding for each image from their unfiltered words
        unfiltered_words_per_image = [len(vocab) for vocab in vocabularies]
        texts_z = torch.split(texts_z, unfiltered_words_per_image)
        texts_z = torch.stack([text_z.mean(dim=0) for text_z in texts_z])
        texts_z = texts_z / texts_z.norm(dim=-1, keepdim=True)

        # filter the words and embed them
        vocabularies = self.vocab_transform(vocabularies)
        vocabularies = [vocab or ["object"] for vocab in vocabularies]
        words = sum(vocabularies, [])
        words_z = self.processor(words, return_tensors="pt", padding=True)
        words_z = {k: v.to(self.device) for k, v in words_z.items()}
        words_z = self.language_encoder(**words_z)[1]
        words_z = self.language_proj(words_z)
        words_z = words_z / words_z.norm(dim=-1, keepdim=True)

        # create a one-hot relation mask between images and words
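        # mask[i, j] = 1 iff word j comes from the vocabulary of image i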
        words_per_image = [len(vocab) for vocab in vocabularies]
        col_indices = torch.arange(sum(words_per_image))
        row_indices = torch.arange(len(images_z)).repeat_interleave(torch.tensor(words_per_image))
        mask = torch.zeros(len(images_z), sum(words_per_image), device=self.device)
        mask[row_indices, col_indices] = 1

        # get the image and text similarities
        images_z = images_z / images_z.norm(dim=-1, keepdim=True)
        texts_z = texts_z / texts_z.norm(dim=-1, keepdim=True)
        words_z = words_z / words_z.norm(dim=-1, keepdim=True)
        images_sim = self.logit_scale * images_z @ words_z.T
        texts_sim = self.logit_scale * texts_z @ words_z.T

        # mask unrelated words
        images_sim = torch.masked_fill(images_sim, mask == 0, float("-inf"))
        texts_sim = torch.masked_fill(texts_sim, mask == 0, float("-inf"))

        # get the image and text predictions
        images_p = images_sim.softmax(dim=-1)
        texts_p = texts_sim.softmax(dim=-1)

        # average the image and text predictions
        samples_p = alpha * images_p + (1 - alpha) * texts_p

        return {"scores": samples_p, "words": words, "vocabularies": vocabularies}