import torch
from torch import nn
from einops import rearrange
from transformers import CLIPModel

from michelangelo.models.tsal.tsal_base import AlignedShapeAsLatentModule


class CLIPAlignedShapeAsLatentModule(AlignedShapeAsLatentModule):

    def __init__(self, *,
                 shape_model,
                 clip_model_version: str = "openai/clip-vit-large-patch14"):

        super().__init__()

        # Load a frozen CLIP backbone; only the shape encoder and its projection are trainable.
        self.clip_model: CLIPModel = CLIPModel.from_pretrained(clip_model_version)
        for params in self.clip_model.parameters():
            params.requires_grad = False

        self.shape_model = shape_model
        # Linear projection from the shape encoder's width into CLIP's joint embedding
        # space, initialized like CLIP's own projections (std = projection_dim ** -0.5).
        self.shape_projection = nn.Parameter(torch.empty(self.shape_model.width, self.clip_model.projection_dim))
        nn.init.normal_(self.shape_projection, std=self.clip_model.projection_dim ** -0.5)

    def set_shape_model_only(self):
        # Release the CLIP backbone, e.g. for shape-only inference.
        self.clip_model = None

    def encode_shape_embed(self, surface, return_latents: bool = False):
        """
        Args:
            surface (torch.FloatTensor): [bs, n, 3 + c]
            return_latents (bool):

        Returns:
            x (torch.FloatTensor): [bs, projection_dim]
            shape_latents (torch.FloatTensor): [bs, m, d]
        """
        # Split surface samples into xyz coordinates and per-point features.
        pc = surface[..., 0:3]
        feats = surface[..., 3:]

        shape_embed, shape_latents = self.shape_model.encode_latents(pc, feats)
        # Project the pooled shape embedding into CLIP's joint embedding space.
        x = shape_embed @ self.shape_projection

        if return_latents:
            return x, shape_latents
        else:
            return x

    def encode_image_embed(self, image):
        """
        Args:
            image (torch.FloatTensor): [bs, 3, h, w]

        Returns:
            x (torch.FloatTensor): [bs, projection_dim]
        """
        x = self.clip_model.get_image_features(image)
        return x

    def encode_text_embed(self, text):
        """
        Args:
            text (torch.LongTensor): [bs, 77] tokenized prompts

        Returns:
            x (torch.FloatTensor): [bs, projection_dim]
        """
        x = self.clip_model.get_text_features(text)
        return x

    def forward(self, surface, image, text):
        """
        Args:
            surface (torch.FloatTensor): [bs, n, 3 + c]
            image (torch.FloatTensor): [bs, 3, 224, 224]
            text (torch.LongTensor): [bs, num_templates, 77]

        Returns:
            embed_outputs (dict): the embedding outputs, containing:
                - image_embed (torch.FloatTensor): [bs, projection_dim]
                - text_embed (torch.FloatTensor): [bs, projection_dim]
                - shape_embed (torch.FloatTensor): [bs, projection_dim]
                - logit_scale (float): CLIP's learned temperature, exponentiated
            shape_latents (torch.FloatTensor): [bs, m, d]
        """
        # Text embedding: encode all prompt templates in a single batch, average the
        # per-template embeddings, and re-normalize to unit length.
        b = text.shape[0]
        text_tokens = rearrange(text, "b t l -> (b t) l")
        text_embed = self.encode_text_embed(text_tokens)
        text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b)
        text_embed = text_embed.mean(dim=1)
        text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)

        # Image embedding
        image_embed = self.encode_image_embed(image)

        # Shape embedding
        shape_embed, shape_latents = self.encode_shape_embed(surface, return_latents=True)

        embed_outputs = {
            "image_embed": image_embed,
            "text_embed": text_embed,
            "shape_embed": shape_embed,
            "logit_scale": self.clip_model.logit_scale.exp()
        }

        return embed_outputs, shape_latents
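
# Minimal usage sketch (illustration only, not part of the module). The stub shape
# encoder below is hypothetical: any module exposing `width` and
# `encode_latents(pc, feats) -> (shape_embed, shape_latents)` fits the interface
# the class above assumes. Running this downloads the CLIP weights.
if __name__ == "__main__":
    class _StubShapeModel(nn.Module):
        width = 768  # matches ViT-L/14's projection_dim so shapes line up

        def encode_latents(self, pc, feats):
            bs = pc.shape[0]
            # Pooled embedding [bs, width] and latent set [bs, m, width].
            return torch.randn(bs, self.width), torch.randn(bs, 256, self.width)

    module = CLIPAlignedShapeAsLatentModule(shape_model=_StubShapeModel())
    surface = torch.randn(2, 4096, 6)             # xyz + per-point normals
    image = torch.randn(2, 3, 224, 224)           # CLIP-preprocessed images
    text = torch.randint(0, 49408, (2, 4, 77))    # 4 tokenized templates per sample
    embed_outputs, shape_latents = module(surface, image, text)
    print(embed_outputs["shape_embed"].shape, shape_latents.shape)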