import torch
import torch.nn as nn
from torchtyping import TensorType
from einops import rearrange

from .image_encoders import get_image_encoder
from .config import MultimodalConfig

# ------------------------- Image prefix ----------------------------------

# for models that are fixed to a specific sequence length (i.e. CLIP models
# with no pooling), the sequence lengths are hardcoded below
ENCODER_SEQ_LENS = {
    "clip_resnet": 49,
    "clip_resnet_large": 144,
}

ENCODER_OUT_DIMS = {
    "nfresnet50": 2048,
    "clip": 512,
    "clip_resnet": 2560,
    "clip_resnet_large": 3072,
}


class ImagePrefix(nn.Module):
    """
    Takes in a batch of images and returns a batch of embeddings with the same
    dimension as the LM's word embeddings.

    :param config: MultimodalConfig object
    :param out_dim: output dimension of the embedding
    :param device: device to run the model on
    """

    def __init__(
        self,
        config: MultimodalConfig,
        out_dim: int = 2048,
        device=None,
    ):
        super().__init__()
        self.device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.config = config
        self.encoder_type = config.encoder_name

        # get image encoder backbone
        self.enc = get_image_encoder(
            config.encoder_name,
            pretrained=config.pretrained_img_encoder,
        )
        self.encoder_out_dim = ENCODER_OUT_DIMS[
            self.encoder_type
        ]  # output dim of the image encoder
        self.out_dim = out_dim  # output dim expected by the LM

        # set the output seq len to the value specified in the config, or,
        # for encoders with no pooling, the hardcoded value above
        self.out_seq_len = (
            config.image_seq_len
            if config.encoder_name not in ENCODER_SEQ_LENS
            else ENCODER_SEQ_LENS[config.encoder_name]
        )

        # output projection: pooled encoders emit a single vector per image,
        # which is projected to (out_seq_len * out_dim) and later reshaped
        # into a sequence; fixed-seq-len encoders are projected per token
        proj_out_dim = (
            (self.out_dim * self.out_seq_len)
            if self.encoder_type not in ENCODER_SEQ_LENS
            else self.out_dim
        )
        self.proj = nn.Linear(self.encoder_out_dim, proj_out_dim)
        self.dropout = nn.Dropout(config.image_embed_dropout_prob)
        self.use_layernorm = config.use_image_embed_layernorm
        if self.use_layernorm:
            self.ln = nn.LayerNorm(self.out_dim)

    def forward(
        self, x: TensorType["b", "c", "h", "w"]
    ) -> TensorType["b", "seq", "out_dim"]:

        # pass through image encoder
        logits = self.enc(x)

        # remove trailing dimensions of size 1, then pass through the linear
        # projection; 3-dim outputs are only expected from encoders with a
        # fixed sequence length (no pooling)
        if logits.ndim == 4:
            logits = rearrange(logits, "b d 1 1 -> b d")
        elif logits.ndim == 3:
            assert self.encoder_type in ENCODER_SEQ_LENS
        else:
            assert logits.ndim == 2

        logits = self.proj(logits)

        # reshape to the desired output shape; encoders with fixed seq lens /
        # no pooling already have the (b, s, d) layout and need no reshape
        if self.encoder_type not in ENCODER_SEQ_LENS:
            logits = rearrange(
                logits, "b (s d) -> b s d", d=self.out_dim, s=self.out_seq_len
            )

        # pass through dropout and layer norm
        logits = self.dropout(logits)

        if self.use_layernorm:
            logits = self.ln(logits)

        return logits
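

# ------------------------- Usage sketch ----------------------------------
# A minimal shape smoke test (an illustrative sketch, not part of the module).
# Assumptions: the real MultimodalConfig constructor is not visible in this
# file, so a SimpleNamespace stand-in exposes only the attributes ImagePrefix
# actually reads. ImagePrefix.__init__ still constructs the real backbone via
# get_image_encoder; it is swapped for a stub afterwards so the check does not
# depend on backbone weights or output conventions.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        encoder_name="clip",  # pooled encoder: not in ENCODER_SEQ_LENS
        pretrained_img_encoder=False,
        image_seq_len=2,  # used since "clip" has no hardcoded seq len
        image_embed_dropout_prob=0.1,
        use_image_embed_layernorm=True,
    )
    image_prefix = ImagePrefix(config, out_dim=2048)

    # stub backbone: any module returning (b, ENCODER_OUT_DIMS["clip"]) works
    image_prefix.enc = nn.Sequential(
        nn.Flatten(), nn.LazyLinear(ENCODER_OUT_DIMS["clip"])
    )

    out = image_prefix(torch.randn(4, 3, 224, 224))
    assert out.shape == (4, config.image_seq_len, 2048)  # (b, seq, out_dim)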