# Spaces: Running on A10G
"""
image_proj_model.py

This module defines the ImageProjModel class, which is responsible for
projecting image embeddings into a different dimensional space. The model
leverages a linear transformation followed by a layer normalization to
reshape and normalize the input image embeddings for further processing in
cross-attention mechanisms or other downstream tasks.

Classes:
    ImageProjModel

Dependencies:
    torch
    diffusers.ModelMixin
"""
import torch | |
from diffusers import ModelMixin | |
class ImageProjModel(ModelMixin):
    """
    Project image embeddings into a fixed number of cross-attention tokens.

    A single linear layer maps each embedding vector of size
    ``clip_embeddings_dim`` to ``clip_extra_context_tokens * cross_attention_dim``
    features. The result is reshaped into ``clip_extra_context_tokens`` separate
    tokens of width ``cross_attention_dim``, and each token is layer-normalized.

    Args:
        cross_attention_dim (int): Width of each output token. Defaults to 1024.
        clip_embeddings_dim (int): Size of the last dimension of the input
            embeddings. Only used to size ``proj``; it is not stored on the
            instance. Defaults to 1024.
        clip_extra_context_tokens (int): Number of tokens produced per input
            embedding vector. Defaults to 4.

    Attributes:
        generator: Initialised to ``None``; not referenced elsewhere in this class.
        cross_attention_dim (int): Width of each output token.
        clip_extra_context_tokens (int): Number of tokens produced per embedding.
        proj (torch.nn.Linear): Linear layer from ``clip_embeddings_dim`` to
            ``clip_extra_context_tokens * cross_attention_dim``.
        norm (torch.nn.LayerNorm): Layer normalization over the last dimension
            (``cross_attention_dim``).
    """

    def __init__(
        self,
        cross_attention_dim: int = 1024,
        clip_embeddings_dim: int = 1024,
        clip_extra_context_tokens: int = 4,
    ):
        super().__init__()

        # Kept for compatibility with external code that may read or set it.
        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens

        # One linear layer emits all tokens at once; forward() splits them apart.
        self.proj = torch.nn.Linear(
            clip_embeddings_dim,
            clip_extra_context_tokens * cross_attention_dim,
        )
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
        """
        Project image embeddings and return the normalized context tokens.

        Args:
            image_embeds (torch.Tensor): Input embeddings whose last dimension
                is ``clip_embeddings_dim``, e.g. shape
                ``(batch_size, clip_embeddings_dim)``.

        Returns:
            torch.Tensor: Tensor of shape
            ``(N, clip_extra_context_tokens, cross_attention_dim)``. ``N`` is
            inferred by the reshape: it equals ``batch_size`` for a 2-D input,
            and the product of all leading dimensions when the input has more
            than two dimensions.
        """
        projected = self.proj(image_embeds)
        # Split the flat projection into separate tokens; -1 folds every
        # leading dimension of the input into the first output dimension.
        tokens = projected.reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        return self.norm(tokens)