# -*- coding: utf-8 -*-

import torch
from torch import nn
from einops import rearrange
from transformers import CLIPModel

from michelangelo.models.tsal.tsal_base import AlignedShapeAsLatentModule


class CLIPAlignedShapeAsLatentModule(AlignedShapeAsLatentModule):

    def __init__(self, *,
                 shape_model,
                 clip_model_version: str = "openai/clip-vit-large-patch14"):

        super().__init__()

        # Load CLIP and freeze it; only the shape encoder and the projection below are trainable.
        self.clip_model: CLIPModel = CLIPModel.from_pretrained(clip_model_version)
        for params in self.clip_model.parameters():
            params.requires_grad = False

        self.shape_model = shape_model
        # Linear projection from the shape encoder width into CLIP's joint embedding space.
        self.shape_projection = nn.Parameter(torch.empty(self.shape_model.width, self.clip_model.projection_dim))
        nn.init.normal_(self.shape_projection, std=self.clip_model.projection_dim ** -0.5)

    def set_shape_model_only(self):
        self.clip_model = None

    def encode_shape_embed(self, surface, return_latents: bool = False):
        """

        Args:
            surface (torch.FloatTensor): [bs, n, 3 + c]
            return_latents (bool):

        Returns:
            x (torch.FloatTensor): [bs, projection_dim]
            shape_latents (torch.FloatTensor): [bs, m, d]
        """
        pc = surface[..., 0:3]
        feats = surface[..., 3:]

        shape_embed, shape_latents = self.shape_model.encode_latents(pc, feats)
        x = shape_embed @ self.shape_projection

        if return_latents:
            return x, shape_latents
        else:
            return x

    def encode_image_embed(self, image):
        """

        Args:
            image (torch.FloatTensor): [bs, 3, h, w]

        Returns:
            x (torch.FloatTensor): [bs, projection_dim]
        """

        x = self.clip_model.get_image_features(image)

        return x

    def encode_text_embed(self, text):
        x = self.clip_model.get_text_features(text)
        return x

    def forward(self, surface, image, text):
        """

        Args:
            surface (torch.FloatTensor): [bs, n, 3 + c]
            image (torch.FloatTensor): [bs, 3, 224, 224]
            text (torch.LongTensor): [bs, num_templates, 77]

        Returns:
            embed_outputs (dict): the embedding outputs, containing:
                - image_embed (torch.FloatTensor): [bs, projection_dim]
                - text_embed (torch.FloatTensor): [bs, projection_dim]
                - shape_embed (torch.FloatTensor): [bs, projection_dim]
                - logit_scale (float):
            shape_latents (torch.FloatTensor): [bs, m, d]
        """

        # # text embedding (equivalent per-sample implementation, kept for reference)
        # text_embed_all = []
        # for i in range(text.shape[0]):
        #     text_for_one_sample = text[i]
        #     text_embed = self.encode_text_embed(text_for_one_sample)
        #     text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
        #     text_embed = text_embed.mean(dim=0)
        #     text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
        #     text_embed_all.append(text_embed)
        # text_embed_all = torch.stack(text_embed_all)

        # text embedding: encode all prompt templates in one batch, average them
        # per sample, then renormalize
        b = text.shape[0]
        text_tokens = rearrange(text, "b t l -> (b t) l")
        text_embed = self.encode_text_embed(text_tokens)
        text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b)
        text_embed = text_embed.mean(dim=1)
        text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)

        # image embedding
        image_embed = self.encode_image_embed(image)

        # shape embedding
        shape_embed, shape_latents = self.encode_shape_embed(surface, return_latents=True)

        embed_outputs = {
            "image_embed": image_embed,
            "text_embed": text_embed,
            "shape_embed": shape_embed,
            "logit_scale": self.clip_model.logit_scale.exp()
        }

        return embed_outputs, shape_latents
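

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module).
# ``_DummyShapeModel`` is a hypothetical stand-in that only mimics the
# interface this wrapper relies on from the real shape encoder: a ``width``
# attribute and an ``encode_latents(pc, feats)`` method returning a pooled
# embedding plus per-token latents. All tensor sizes below are assumptions
# chosen for the demo; running it downloads the CLIP weights from the
# Hugging Face hub.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _DummyShapeModel(nn.Module):
        width = 768  # assumed encoder width; the real shape model defines its own

        def encode_latents(self, pc, feats):
            bs = pc.shape[0]
            shape_embed = torch.randn(bs, self.width)          # pooled shape embedding
            shape_latents = torch.randn(bs, 256, self.width)   # per-token latents (m = 256 here)
            return shape_embed, shape_latents

    module = CLIPAlignedShapeAsLatentModule(shape_model=_DummyShapeModel())

    surface = torch.randn(2, 4096, 3 + 3)        # xyz coordinates + (assumed) per-point normals
    image = torch.randn(2, 3, 224, 224)          # CLIP-preprocessed images
    text = torch.randint(0, 49408, (2, 4, 77))   # random ids standing in for tokenized prompt templates

    embed_outputs, shape_latents = module(surface, image, text)
    print({k: tuple(v.shape) for k, v in embed_outputs.items() if v.dim() > 0})
    print("shape_latents:", tuple(shape_latents.shape))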