""" image_proj_model.py This module defines the ImageProjModel class, which is responsible for projecting image embeddings into a different dimensional space. The model leverages a linear transformation followed by a layer normalization to reshape and normalize the input image embeddings for further processing in cross-attention mechanisms or other downstream tasks. Classes: ImageProjModel Dependencies: torch diffusers.ModelMixin """ import torch from diffusers import ModelMixin class ImageProjModel(ModelMixin): """ ImageProjModel is a class that projects image embeddings into a different dimensional space. It inherits from ModelMixin, providing additional functionalities specific to image projection. Attributes: cross_attention_dim (int): The dimension of the cross attention. clip_embeddings_dim (int): The dimension of the CLIP embeddings. clip_extra_context_tokens (int): The number of extra context tokens in CLIP. Methods: forward(image_embeds): Forward pass of the ImageProjModel, which takes in image embeddings and returns the projected tokens. """ def __init__( self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4, ): super().__init__() self.generator = None self.cross_attention_dim = cross_attention_dim self.clip_extra_context_tokens = clip_extra_context_tokens self.proj = torch.nn.Linear( clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim ) self.norm = torch.nn.LayerNorm(cross_attention_dim) def forward(self, image_embeds): """ Forward pass of the ImageProjModel, which takes in image embeddings and returns the projected tokens after reshaping and normalization. Args: image_embeds (torch.Tensor): The input image embeddings, with shape batch_size x num_image_tokens x clip_embeddings_dim. Returns: clip_extra_context_tokens (torch.Tensor): The projected tokens after reshaping and normalization, with shape batch_size x (clip_extra_context_tokens * cross_attention_dim). """ embeds = image_embeds clip_extra_context_tokens = self.proj(embeds).reshape( -1, self.clip_extra_context_tokens, self.cross_attention_dim ) clip_extra_context_tokens = self.norm(clip_extra_context_tokens) return clip_extra_context_tokens