File size: 9,845 Bytes
19c4ddf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
from typing import Iterable, List, Optional, Union

import numpy as np
import torch
import torch.nn as nn
from PIL import Image

from shap_e.models.download import default_cache_dir

# Image inputs accepted by ImageCLIP / FrozenImageCLIP. Every value is passed
# through _image_to_pil before CLIP preprocessing: numpy arrays and tensors are
# cast to uint8 (no rescaling) and handed to PIL.Image.fromarray, so they are
# presumably expected as HxW or HxWxC with values in [0, 255] -- TODO confirm
# the layout against callers. PIL images are used as-is.
ImageType = Union[np.ndarray, torch.Tensor, Image.Image]


class ImageCLIP(nn.Module):
    """
    A wrapper around a pre-trained CLIP model that automatically handles
    batches of texts, images, and embeddings.
    """

    # Per-backbone sizes: clip_name -> (feature_dim, grid_size, grid_feature_dim).
    # A single table keeps __init__'s validation and the three size properties
    # from drifting apart when a backbone is added.
    _ARCH_DIMS = {
        "ViT-L/14": (768, 16, 1024),
        "ViT-B/32": (512, 7, 768),
    }

    def __init__(
        self,
        device: torch.device,
        dtype: Optional[torch.dtype] = torch.float32,
        ensure_used_params: bool = True,
        clip_name: str = "ViT-L/14",
        cache_dir: Optional[str] = None,
    ):
        """
        :param device: device the CLIP model is loaded onto; all outputs are
            produced on this device.
        :param dtype: if not None, the CLIP weights are cast to this dtype;
            if None, the dtype chosen by clip.load is kept.
        :param ensure_used_params: if True, forward() always runs every encoder
            and embed_images_grid() touches every parameter, so the forward
            graph looks the same on every call/rank.
        :param clip_name: CLIP backbone name; must be a key of _ARCH_DIMS.
        :param cache_dir: download directory for the CLIP weights; falls back
            to default_cache_dir() when not given.
        """
        super().__init__()

        assert clip_name in self._ARCH_DIMS, f"unsupported clip_name: {clip_name!r}"

        self.device = device
        self.ensure_used_params = ensure_used_params

        # Lazy import because of torchvision.
        import clip

        self.clip_model, self.preprocess = clip.load(
            clip_name, device=device, download_root=cache_dir or default_cache_dir()
        )
        self.clip_name = clip_name

        if dtype is not None:
            self.clip_model.to(dtype)
        self._tokenize = clip.tokenize

    @property
    def feature_dim(self) -> int:
        """Width D of the embeddings returned by forward(), embed_images() and embed_text()."""
        return self._ARCH_DIMS[self.clip_name][0]

    @property
    def grid_size(self) -> int:
        """Side length of the visual patch grid; embed_images_grid() yields grid_size**2 positions."""
        return self._ARCH_DIMS[self.clip_name][1]

    @property
    def grid_feature_dim(self) -> int:
        """Channel count C of the [N x C x L] tensor returned by embed_images_grid()."""
        return self._ARCH_DIMS[self.clip_name][2]

    def forward(
        self,
        batch_size: int,
        images: Optional[Iterable[Optional[ImageType]]] = None,
        texts: Optional[Iterable[Optional[str]]] = None,
        embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None,
    ) -> torch.Tensor:
        """
        Generate a batch of embeddings from a mixture of images, texts,
        precomputed embeddings, and possibly empty values.

        For each batch element, at most one of images, texts, and embeddings
        should have a non-None value. Embeddings from multiple modalities
        cannot be mixed for a single batch element. If no modality is provided,
        a zero embedding will be used for the batch element.

        :param batch_size: number of batch elements; every non-None modality
            sequence must have exactly this length.
        :return: a float32 [batch_size x feature_dim] tensor on self.device.
        :raises AssertionError: on a length mismatch, or when more than one
            modality is given for the same batch element.
        """
        image_seq = [None] * batch_size if images is None else list(images)
        text_seq = [None] * batch_size if texts is None else list(texts)
        embedding_seq = [None] * batch_size if embeddings is None else list(embeddings)
        assert len(image_seq) == batch_size, "number of images should match batch size"
        assert len(text_seq) == batch_size, "number of texts should match batch size"
        assert len(embedding_seq) == batch_size, "number of embeddings should match batch size"

        # Validate the one-modality-per-element contract before branching, so
        # the static path (which would otherwise sum the embeddings of every
        # provided modality) honors it exactly like the dynamic path.
        for image, text, emb in zip(image_seq, text_seq, embedding_seq):
            assert (
                sum([int(image is not None), int(text is not None), int(emb is not None)]) < 2
            ), "only one modality may be non-None per batch element"

        if self.ensure_used_params:
            return self._static_multimodal_embed(
                images=image_seq, texts=text_seq, embeddings=embedding_seq
            )

        result = torch.zeros((batch_size, self.feature_dim), device=self.device)
        index_images = []
        index_texts = []
        for i, (image, text, emb) in enumerate(zip(image_seq, text_seq, embedding_seq)):
            if image is not None:
                index_images.append((i, image))
            elif text is not None:
                index_texts.append((i, text))
            elif emb is not None:
                result[i] = emb.to(result)

        # Encode each modality with one batched call, then scatter the rows
        # back to their original batch positions.
        if index_images:
            embs = self.embed_images(img for _, img in index_images)
            for (i, _), emb in zip(index_images, embs):
                result[i] = emb.to(result)
        if index_texts:
            embs = self.embed_text(text for _, text in index_texts)
            for (i, _), emb in zip(index_texts, embs):
                result[i] = emb.to(result)

        return result

    def _static_multimodal_embed(
        self,
        images: Optional[List[Optional[ImageType]]] = None,
        texts: Optional[List[Optional[str]]] = None,
        embeddings: Optional[List[Optional[torch.Tensor]]] = None,
    ) -> torch.Tensor:
        """
        Like forward(), but always runs all encoders to ensure that
        the forward graph looks the same on every rank.

        Despite the None defaults, all three lists are required and must have
        equal length; forward() always passes them fully populated. Missing
        entries are encoded as placeholders (a blank image, an empty string,
        or a zero vector) and then masked out.
        """
        image_emb = self.embed_images(images)
        text_emb = self.embed_text(t if t else "" for t in texts)
        joined_embs = torch.stack(
            [
                emb.to(device=self.device, dtype=torch.float32)
                if emb is not None
                else torch.zeros(self.feature_dim, device=self.device)
                for emb in embeddings
            ],
            dim=0,
        )

        # Per-element 0/1 masks selecting which modality's row is kept.
        image_flag = torch.tensor([x is not None for x in images], device=self.device)[
            :, None
        ].expand_as(image_emb)
        text_flag = torch.tensor([x is not None for x in texts], device=self.device)[
            :, None
        ].expand_as(image_emb)
        emb_flag = torch.tensor([x is not None for x in embeddings], device=self.device)[
            :, None
        ].expand_as(image_emb)

        return (
            image_flag.float() * image_emb
            + text_flag.float() * text_emb
            + emb_flag.float() * joined_embs
            + self.clip_model.logit_scale * 0  # avoid unused parameters
        )

    def embed_images(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
        """
        :param xs: N images, stored as numpy arrays, tensors, or PIL images.
            None entries are replaced by a blank placeholder image.
        :return: an [N x D] float32 tensor of L2-normalized features.
        """
        clip_inputs = self.images_to_tensor(xs)
        results = self.clip_model.encode_image(clip_inputs).float()
        return results / torch.linalg.norm(results, dim=-1, keepdim=True)

    def embed_text(self, prompts: Iterable[str]) -> torch.Tensor:
        """
        Embed text prompts as an [N x D] float32 tensor of L2-normalized
        features. Prompts longer than CLIP's context are truncated.
        """
        enc = self.clip_model.encode_text(
            self._tokenize(list(prompts), truncate=True).to(self.device)
        ).float()
        return enc / torch.linalg.norm(enc, dim=-1, keepdim=True)

    def embed_images_grid(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
        """
        Embed images into latent grids.

        Mirrors CLIP's VisionTransformer.forward up to (and including) the
        transformer, but skips the final projection and drops the class token,
        returning per-patch features instead of a pooled embedding.

        :param xs: an iterable of images to embed.
        :return: a tensor of shape [N x C x L], where L = self.grid_size**2.
        """
        extra_value = 0.0
        if self.ensure_used_params:
            # Zero-weighted term that touches every parameter so all of them
            # take part in the autograd graph.
            for p in self.parameters():
                extra_value = extra_value + p.mean() * 0.0

        x = self.images_to_tensor(xs).to(self.clip_model.dtype)

        # https://github.com/openai/CLIP/blob/4d120f3ec35b30bd0f992f5d8af2d793aad98d2a/clip/model.py#L225
        vt = self.clip_model.visual
        x = vt.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat(
            [
                vt.class_embedding.to(x.dtype)
                + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
                x,
            ],
            dim=1,
        )  # shape = [*, grid ** 2 + 1, width]
        x = x + vt.positional_embedding.to(x.dtype)
        x = vt.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = vt.transformer(x)
        x = x.permute(1, 2, 0)  # LND -> NDL

        # Drop the class token (position 0), keeping only the patch positions.
        return x[..., 1:].contiguous().float() + extra_value

    def images_to_tensor(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
        """Convert each input to PIL, apply CLIP preprocessing, and stack into one batch on self.device."""
        return torch.stack([self.preprocess(_image_to_pil(x)) for x in xs], dim=0).to(self.device)


class FrozenImageCLIP:
    """
    Inference-only wrapper around ImageCLIP whose weights never require
    gradients.

    It is a plain object rather than an nn.Module, so a module that stores one
    as an attribute does not register the CLIP weights as its own submodule.
    """

    def __init__(self, device: torch.device, **kwargs):
        # dtype=None keeps the dtype produced by clip.load; ensure_used_params
        # is off, so ImageCLIP.forward only runs the encoders actually needed.
        self.model = ImageCLIP(device, dtype=None, ensure_used_params=False, **kwargs)
        self.model.requires_grad_(False)

    @property
    def feature_dim(self) -> int:
        """Embedding width of the wrapped model."""
        return self.model.feature_dim

    @property
    def grid_size(self) -> int:
        """Patch-grid side length of the wrapped model."""
        return self.model.grid_size

    @property
    def grid_feature_dim(self) -> int:
        """Per-patch channel count of the wrapped model."""
        return self.model.grid_feature_dim

    def __call__(
        self,
        batch_size: int,
        images: Optional[Iterable[Optional[ImageType]]] = None,
        texts: Optional[Iterable[Optional[str]]] = None,
        embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None,
    ) -> torch.Tensor:
        # Intentionally *not* wrapped in no_grad(): gradients may still flow
        # back into a caller-supplied `embeddings` argument (currently unused,
        # but kept possible).
        modality_inputs = dict(images=images, texts=texts, embeddings=embeddings)
        return self.model(batch_size=batch_size, **modality_inputs)

    @torch.no_grad()
    def embed_images(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
        """Gradient-free ImageCLIP.embed_images."""
        return self.model.embed_images(xs)

    @torch.no_grad()
    def embed_text(self, prompts: Iterable[str]) -> torch.Tensor:
        """Gradient-free ImageCLIP.embed_text."""
        return self.model.embed_text(prompts)

    @torch.no_grad()
    def embed_images_grid(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
        """Gradient-free ImageCLIP.embed_images_grid."""
        return self.model.embed_images_grid(xs)


def _image_to_pil(obj: Optional[ImageType]) -> Image.Image:
    """
    Convert a single image input into a PIL image for CLIP preprocessing.

    None becomes a 64x64 all-black RGB placeholder. Tensors are detached and
    moved to CPU as numpy arrays; arrays are then cast to uint8 (values are not
    rescaled) and wrapped with Image.fromarray. Any other object is assumed to
    already be a PIL image and is returned unchanged.
    """
    if obj is None:
        placeholder = np.zeros([64, 64, 3], dtype=np.uint8)
        return Image.fromarray(placeholder)

    if isinstance(obj, torch.Tensor):
        obj = obj.detach().cpu().numpy()
    elif not isinstance(obj, np.ndarray):
        return obj

    return Image.fromarray(obj.astype(np.uint8))