import torch
from torch import nn
import timm

import config as CFG


class TextEncoder(nn.Module):
    """
    Text/Poem encoder used in PoemTextModel and CLIPModel
    ...
    Attributes:
    -----------
    model : a torch.nn.Module model
        The text/poem encoder model

    Methods:
    --------
    forward(x)
        returns model embeddings of x (batch of texts/poems) (of the CLS token)
    __init__()
        creates the encoder model using huggingface transformers, and freezes the model if it's not trainable.
    """
    def __init__(self, encoder_model, encoder_pretrained_name, pretrained, trainable):
        """
        creates the poem or text encoder model using transformers and loads weights from a pretrained model if needed.
        Also freezes the model if it's not trainable.

        Parameters:
        -----------
        pretrained: bool
            if pretrained=True, load the pretrained model's weights; else create a fresh untrained model.
        trainable: bool
            if trainable=False, the model's weights will be frozen.
        encoder_model: str
            text/poem encoder model name, used as a key to get the right model class from configs.
        encoder_pretrained_name: str
            pretrained model name to load weights from (not used when pretrained=False).
        """
        super().__init__()
        if pretrained:
            self.model = CFG.encoders[encoder_model].from_pretrained(encoder_pretrained_name)
        else:
            self.model = CFG.encoders[encoder_model](config=CFG.configs[encoder_model]())

        for p in self.model.parameters():
            p.requires_grad = trainable

        # Using the CLS token hidden representation as the sentence's embedding
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        """
        forwards and calculates embeddings of the input using the attention mask.

        Parameters:
        -----------
        input_ids:
            input ids (output of the tokenizer)
        attention_mask:
            input masks (e.g. for padding; pad tokens will be masked)

        Returns:
        --------
        the embedding of the CLS (or target) token of the encoder's last hidden state
        """
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = output.last_hidden_state
        return last_hidden_state[:, self.target_token_idx, :]


class ProjectionHead(nn.Module):
    """
    Projection head used to project embeddings from each encoder to a shared embedding space
    ...
    Attributes:
    -----------
    projection : torch.nn.Linear
        The main dense projection (from the encoder's embedding dim to the shared projection dim)
    gelu : torch.nn.GELU
        activation function
    fc : torch.nn.Linear
        a dense layer after projection (projection_dim to projection_dim)
    dropout : torch.nn.Dropout
        dropout after fc
    layer_norm : torch.nn.LayerNorm
        layer norm after dropout

    Methods:
    --------
    forward(x)
        returns projection embeddings from x (encoder output embeddings)
    __init__()
        creates the projection head
    """
    def __init__(
        self, embedding_dim, projection_dim=CFG.projection_dim, dropout=CFG.dropout
    ):
        """
        Creates the projection head used after an encoder.

        Parameters:
        -----------
        embedding_dim: int
            dimension of the output embeddings of the encoder.
        projection_dim: int, optional
            dimension to project embeddings to.
        dropout: float
            fraction of the output of the fc layer to be zeroed.
        """
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        """
        Forwards and calculates projected embeddings from encoder embeddings.

        Parameters:
        -----------
        x: input (of shape (batch_size, embedding_dim))
            the output embedding of this projection head's encoder

        Returns:
        --------
        the embeddings in a shared embedding space (of shape (batch_size, projection_dim))
        """
        projected = self.projection(x)  # main projection layer
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        # the projected outputs are added to x as a residual connection
        x = x + projected
        x = self.layer_norm(x)
        return x


class ImageEncoder(nn.Module):
    """
    Image encoder used in CLIPModel
    ...
    Attributes:
    -----------
    model : a torch.nn.Module model from timm (pytorch-image-models)
        The image encoder model

    Methods:
    --------
    forward(x)
        returns model embeddings of x (batch of images)
    __init__()
        creates the encoder model using timm and loads a fine-tuned model's state dict if needed.
        Also freezes the model if it's not trainable.
    """
    def __init__(
        self, pretrained, trainable, model_name=CFG.image_encoder_model
    ):
        """
        creates the encoder model using timm and loads a fine-tuned model's state dict if needed.
        Also freezes the model if it's not trainable.

        Parameters:
        -----------
        pretrained: bool
            if pretrained=True, load pretrained (SOTA) weights, or the weights saved in image_encoder_weights_load_path;
            else create a fresh untrained model.
        trainable: bool
            if trainable=False, the model's weights will be frozen.
        model_name: str
            image encoder model name used as input to timm.create_model.
        """
        super().__init__()
        self.model = timm.create_model(
            model_name, pretrained, num_classes=0, global_pool="avg"
        )
        if pretrained and CFG.image_encoder_weights_load_path:
            self.model.load_state_dict(torch.load(CFG.image_encoder_weights_load_path, map_location=CFG.device))
        for p in self.model.parameters():
            p.requires_grad = trainable

    def forward(self, x):
        """
        forwards and calculates embeddings of the input.

        Parameters:
        -----------
        x: input (batch of transformed images)

        Returns:
        --------
        embeddings of the model for the input (of shape (batch_size, image_embedding))
        """
        return self.model(x)
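

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the training pipeline): build an
# ImageEncoder and a ProjectionHead, then project a dummy batch into the
# shared embedding space. It assumes CFG.image_encoder_model names a timm
# model that accepts 224x224 RGB inputs; the random tensor and the 224 size
# are placeholders for a real transformed image batch.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    image_encoder = ImageEncoder(pretrained=False, trainable=False)
    # timm models expose their pooled feature dimension as `num_features`
    image_projection = ProjectionHead(embedding_dim=image_encoder.model.num_features)

    dummy_images = torch.randn(2, 3, 224, 224)  # batch of 2 placeholder images
    with torch.no_grad():
        image_embeddings = image_encoder(dummy_images)           # (2, num_features)
        shared_embeddings = image_projection(image_embeddings)   # (2, CFG.projection_dim)
    print(shared_embeddings.shape)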