multimodalart's picture
Upload folder using huggingface_hub
5a510e7 verified
"""
image_proj_model.py
This module defines the ImageProjModel class, which is responsible for
projecting image embeddings into a different dimensional space. The model
leverages a linear transformation followed by a layer normalization to
reshape and normalize the input image embeddings for further processing in
cross-attention mechanisms or other downstream tasks.
Classes:
ImageProjModel
Dependencies:
torch
diffusers.ModelMixin
"""
import torch
from diffusers import ModelMixin
class ImageProjModel(ModelMixin):
    """
    Project image embeddings into a short sequence of context tokens.

    A single linear layer expands every embedding vector of size
    ``clip_embeddings_dim`` to ``clip_extra_context_tokens * cross_attention_dim``
    features. The expanded vector is then split into
    ``clip_extra_context_tokens`` tokens of width ``cross_attention_dim`` and
    each token is layer-normalized independently.

    Attributes:
        generator: Initialised to ``None`` and never read or written by this
            class; kept so that external code relying on the attribute keeps
            working.
        cross_attention_dim (int): Width of each projected token.
        clip_extra_context_tokens (int): Number of tokens produced per input
            embedding vector.
        proj (torch.nn.Linear): Maps ``clip_embeddings_dim`` features to
            ``clip_extra_context_tokens * cross_attention_dim`` features.
        norm (torch.nn.LayerNorm): Normalizes over the last
            (``cross_attention_dim``) axis.

    Note:
        ``clip_embeddings_dim`` is only used to size ``proj`` and is not stored
        as an attribute; it is available as ``proj.in_features``.

    Methods:
        forward(image_embeds): Project the embeddings and return the
            normalized tokens.
    """

    def __init__(
        self,
        cross_attention_dim: int = 1024,
        clip_embeddings_dim: int = 1024,
        clip_extra_context_tokens: int = 4,
    ):
        """
        Build the projection and normalization layers.

        Args:
            cross_attention_dim (int): Width of each output token.
            clip_embeddings_dim (int): Size of the last axis of the input
                embeddings.
            clip_extra_context_tokens (int): Number of tokens to produce per
                input embedding vector.
        """
        super().__init__()

        # Not used anywhere in this class; retained for external callers.
        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens

        # One linear layer emits all tokens at once; forward() splits them.
        self.proj = torch.nn.Linear(
            clip_embeddings_dim,
            self.clip_extra_context_tokens * cross_attention_dim,
        )
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
        """
        Project image embeddings into normalized context tokens.

        Args:
            image_embeds (torch.Tensor): Embeddings whose last axis has size
                ``clip_embeddings_dim``, e.g. ``(batch_size, clip_embeddings_dim)``.

        Returns:
            torch.Tensor: Tensor of shape
            ``(N, clip_extra_context_tokens, cross_attention_dim)``. ``N`` is
            the product of all leading axes of ``image_embeds``: the reshape
            uses ``-1`` for the first axis, so an input of shape
            ``(batch_size, num_image_tokens, clip_embeddings_dim)`` yields
            ``N = batch_size * num_image_tokens``.
        """
        projected = self.proj(image_embeds)
        # Split the flat projection into individual tokens; every leading
        # axis of the input is folded into the first output axis.
        tokens = projected.reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        return self.norm(tokens)