# -*- coding: utf-8 -*-

import torch
from torch import nn
from einops import rearrange
from transformers import CLIPModel

from michelangelo.models.tsal.tsal_base import AlignedShapeAsLatentModule


class CLIPAlignedShapeAsLatentModule(AlignedShapeAsLatentModule):

    def __init__(self, *,
                 shape_model,
                 clip_model_version: str = "openai/clip-vit-large-patch14"):

        super().__init__()

        self.clip_model: CLIPModel = CLIPModel.from_pretrained(clip_model_version)
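        # Freeze all CLIP weights; the shape branch remains trainable.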
        for param in self.clip_model.parameters():
            param.requires_grad = False

        self.shape_model = shape_model
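        # Learnable projection from the shape encoder's width to CLIP's joint embedding dim.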
        self.shape_projection = nn.Parameter(torch.empty(self.shape_model.width, self.clip_model.projection_dim))
        nn.init.normal_(self.shape_projection, std=self.clip_model.projection_dim ** -0.5)

    def set_shape_model_only(self):
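        """Drop the frozen CLIP tower, keeping only the shape branch (e.g. for shape-only inference)."""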
        self.clip_model = None

    def encode_shape_embed(self, surface, return_latents: bool = False):
        """

        Args:
            surface (torch.FloatTensor): [bs, n, 3 + c]
            return_latents (bool):

        Returns:
            x (torch.FloatTensor): [bs, projection_dim]
            shape_latents (torch.FloatTensor): [bs, m, d]
        """

        # Split surface samples into xyz coordinates and per-point features.
        pc = surface[..., 0:3]
        feats = surface[..., 3:]

        # Encode the point cloud, then project into CLIP's embedding space.
        shape_embed, shape_latents = self.shape_model.encode_latents(pc, feats)
        x = shape_embed @ self.shape_projection

        if return_latents:
            return x, shape_latents
        else:
            return x

    def encode_image_embed(self, image):
        """

        Args:
            image (torch.FloatTensor): [bs, 3, h, w]

        Returns:
            x (torch.FloatTensor): [bs, projection_dim]
        """

        x = self.clip_model.get_image_features(image)

        return x

    def encode_text_embed(self, text):
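        """

        Args:
            text (torch.LongTensor): [N, 77] CLIP token ids

        Returns:
            x (torch.FloatTensor): [N, projection_dim]
        """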
        x = self.clip_model.get_text_features(text)
        return x

    def forward(self, surface, image, text):
        """

        Args:
            surface (torch.FloatTensor):
            image (torch.FloatTensor): [bs, 3, 224, 224]
            text (torch.LongTensor): [bs, num_templates, 77]

        Returns:
            embed_outputs (dict): the embedding outputs, and it contains:
                - image_embed (torch.FloatTensor):
                - text_embed (torch.FloatTensor):
                - shape_embed (torch.FloatTensor):
                - logit_scale (float):
        """

        # text embedding: encode all templates in one batch, average the
        # per-template embeddings, then L2-normalize.

        b = text.shape[0]
        text_tokens = rearrange(text, "b t l -> (b t) l")
        text_embed = self.encode_text_embed(text_tokens)
        text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b)
        text_embed = text_embed.mean(dim=1)
        text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)

        # image embedding
        image_embed = self.encode_image_embed(image)

        # shape embedding
        shape_embed, shape_latents = self.encode_shape_embed(surface, return_latents=True)

        embed_outputs = {
            "image_embed": image_embed,
            "text_embed": text_embed,
            "shape_embed": shape_embed,
            "logit_scale": self.clip_model.logit_scale.exp()
        }

        return embed_outputs, shape_latents
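

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# `DummyShapeModel` is a hypothetical stand-in that mimics the interface this
# module expects from `shape_model`: a `width` attribute and an
# `encode_latents(pc, feats)` method returning
# (shape_embed [bs, width], shape_latents [bs, m, d]).
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class DummyShapeModel(nn.Module):
        def __init__(self, width: int = 768, num_latents: int = 256):
            super().__init__()
            self.width = width
            self.num_latents = num_latents
            self.proj = nn.Linear(6, width)  # assumes c = 3 feature channels

        def encode_latents(self, pc, feats):
            h = self.proj(torch.cat([pc, feats], dim=-1))   # [bs, n, width]
            return h.mean(dim=1), h[:, :self.num_latents]   # embed, latents

    module = CLIPAlignedShapeAsLatentModule(shape_model=DummyShapeModel())

    surface = torch.randn(2, 1024, 6)             # [bs, n, 3 + c] with c = 3
    image = torch.randn(2, 3, 224, 224)           # CLIP-preprocessed pixels
    text = torch.randint(0, 49408, (2, 4, 77))    # [bs, num_templates, 77] token ids

    with torch.no_grad():
        embed_outputs, shape_latents = module(surface, image, text)

    print("shape_embed:", tuple(embed_outputs["shape_embed"].shape))
    print("shape_latents:", tuple(shape_latents.shape))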