import torch
from torch import nn
import torch.nn.functional as F

import config as CFG
from modules import TextEncoder, ProjectionHead, ImageEncoder


class PoemTextModel(nn.Module):
    """
    Model predicting poem and text embeddings, and their similarities.

    Attributes
    ----------
    poem_encoder : TextEncoder
        encoder used for extracting poem embeddings
    text_encoder : TextEncoder
        encoder used for extracting text embeddings
    poem_projection : ProjectionHead
        projection head used for poem embeddings (projects poem encoder
        output to the shared embedding space)
    text_projection : ProjectionHead
        projection head used for text embeddings (projects text encoder
        output to the shared embedding space)
    temperature : float
        used to scale the dot similarities

    Methods
    -------
    forward(batch):
        returns poem and text embeddings of batch
    similarity_scores(batch):
        computes dot similarities of a batch of text-poem pairs
    predict(batch):
        predicts the most similar poem idx for each text (using previous methods)
    calculate_loss(poem_embeddings, text_embeddings):
        computes contrastive (cross entropy) loss for both poems and texts
    save_current():
        saves current model's encoders (if trainable) and projection heads
    """

    def __init__(
        self,
        poem_encoder_pretrained,
        text_encoder_pretrained,
        temperature=CFG.temperature,
        poem_embedding=CFG.poem_embedding,
        text_embedding=CFG.text_embedding,
    ):
        """
        Initializes model's submodules.

        Parameters
        ----------
        poem_encoder_pretrained : bool
            whether or not to load a pretrained poem encoder
        text_encoder_pretrained : bool
            whether or not to load a pretrained text encoder
        temperature : float, optional
            used to scale the dot similarities
        poem_embedding : int, optional
            dim of poem encoder's encoding output before projection
        text_embedding : int, optional
            dim of text encoder's encoding output before projection
        """
        super().__init__()
        self.poem_encoder = TextEncoder(
            CFG.poem_encoder_model,
            CFG.poem_encoder_pretrained_name,
            pretrained=poem_encoder_pretrained,
            trainable=CFG.poem_encoder_trainable,
        )
        self.text_encoder = TextEncoder(
            CFG.text_encoder_model,
            CFG.text_encoder_pretrained_name,
            pretrained=text_encoder_pretrained,
            trainable=CFG.text_encoder_trainable,
        )
        self.poem_projection = ProjectionHead(embedding_dim=poem_embedding)
        if CFG.poem_projection_load_path:
            # if provided, load projection weights from this path
            self.poem_projection.load_state_dict(
                torch.load(CFG.poem_projection_load_path, map_location=CFG.device)
            )
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        if CFG.text_projection_load_path:
            # if provided, load projection weights from this path
            self.text_projection.load_state_dict(
                torch.load(CFG.text_projection_load_path, map_location=CFG.device)
            )
        self.temperature = temperature

    def forward(self, batch):
        """
        Returns poem and text embeddings of batch.

        Parameters
        ----------
        batch : dict
            input containing poem-text pairs (encoded using the encoder's
            tokenizer) with keys 'beyt' and 'text'

        Returns
        -------
        poem and text embeddings of batch
        (each of shape (batch_size, projection_dim))
        """
        beyts, texts = batch["beyt"], batch["text"]
        # Getting Beyt and Text Features
        poem_features = self.poem_encoder(
            input_ids=beyts["input_ids"], attention_mask=beyts["attention_mask"]
        )
        text_features = self.text_encoder(
            input_ids=texts["input_ids"], attention_mask=texts["attention_mask"]
        )
        # Getting Beyt and Text Embeddings (with same dimension)
        poem_embeddings = self.poem_projection(poem_features)
        text_embeddings = self.text_projection(text_features)
        return poem_embeddings, text_embeddings

    def similarity_scores(self, batch):
        """
        Computes dot similarities of a batch of text-poem pairs.

        Parameters
        ----------
        batch : dict
            input containing poem-text pairs (encoded using the encoder's
            tokenizer) with keys 'beyt' and 'text'

        Returns
        -------
        dot similarity of poem and text embeddings of batch
        (of shape (batch_size, batch_size))
        """
        # Getting Beyt and Text Embeddings (with same dimension)
        poem_embeddings, text_embeddings = self.forward(batch)
        # Normalizing embeddings
        poem_embeddings_n = F.normalize(poem_embeddings, p=2, dim=-1)
        text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
        # Computing dot / cosine similarity of the normalized embeddings.
        # First dim is texts, second dim is poems for each text.
        dot_similarity = text_embeddings_n @ poem_embeddings_n.T
        return dot_similarity  # (batch_size, batch_size)

    def predict(self, batch):
        """
        Predicts the most similar poem (idx) for each text (using previous methods).

        Parameters
        ----------
        batch : dict
            input containing poem-text pairs (encoded using the encoder's
            tokenizer) with keys 'beyt' and 'text'

        Returns
        -------
        index of poem predicted for each text (of shape (batch_size))
        """
        dot_similarity = self.similarity_scores(batch)
        # Argmax over the poem dimension of the dot-similarities gives the
        # index of the most similar poem for each text.
        return torch.argmax(dot_similarity, dim=1)

    def calculate_loss(self, poem_embeddings, text_embeddings):
        """
        Computes contrastive (cross entropy) loss for both poems and texts.

        Parameters
        ----------
        poem_embeddings : tensor of shape (batch_size, projection_dim)
            output embeddings of poem projection head
        text_embeddings : tensor of shape (batch_size, projection_dim)
            output embeddings of text projection head

        Returns
        -------
        average of the loss computed from inputs
        """
        # dot similarity of the embeddings scaled by temperature (logits)
        logits = (text_embeddings @ poem_embeddings.T) / self.temperature
        # Computing targets for the cross entropy loss to compare with logits:
        # each modality's self-similarity matrix is averaged, scaled by the
        # temperature parameter, and normalized into a probability
        # distribution via a softmax.
        # NOTE(review): logits are *divided* by temperature while targets are
        # *multiplied* by it — looks asymmetric; kept as-is to preserve
        # training behavior, but confirm this is intended.
        poems_similarity = poem_embeddings @ poem_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax(
            (poems_similarity + texts_similarity) / 2 * self.temperature, dim=-1
        )
        # taking cross entropy loss in both dimensions: once for texts and once for poems
        texts_loss = cross_entropy(logits, targets, reduction='none')
        poems_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss = (poems_loss + texts_loss) / 2.0  # average of losses. shape: (batch_size)
        return loss.mean()

    def save_current(self):
        """
        Saves current model's encoders (if trainable) and projection heads.
        """
        if CFG.text_encoder_trainable:
            self.text_encoder.model.save_pretrained(CFG.text_encoder_save_path)
        if CFG.poem_encoder_trainable:
            self.poem_encoder.model.save_pretrained(CFG.poem_encoder_save_path)
        torch.save(self.text_projection.state_dict(), CFG.text_projection_save_path)
        torch.save(self.poem_projection.state_dict(), CFG.poem_projection_save_path)


class CLIPModel(nn.Module):
    """
    Model predicting poem/text and image embeddings, and their similarities.

    Attributes
    ----------
    encoder : TextEncoder
        encoder used for extracting poem/text embeddings
    image_encoder : ImageEncoder
        encoder used for extracting image embeddings
    text_projection : ProjectionHead
        projection head used for poem/text embeddings (projects text encoder
        output to the shared embedding space)
    image_projection : ProjectionHead
        projection head used for image embeddings (projects image encoder
        output to the shared embedding space)
    temperature : float
        used to scale the dot similarities

    Methods
    -------
    forward(batch):
        returns poem/text and image embeddings of batch
    similarity_scores(batch):
        computes dot similarities of a batch of text-image pairs
    predict(batch):
        predicts the most similar poem/text idx for each image (using previous methods)
    calculate_loss(image_embeddings, text_embeddings):
        computes contrastive (cross entropy) loss for both poems/texts and images
    save_current():
        saves current model's encoders (if trainable) and projection heads
    """

    def __init__(
        self,
        image_encoder_pretrained,
        text_encoder_pretrained,
        text_projection_trainable,
        temperature=CFG.temperature,
        image_embedding=CFG.image_embedding,
        text_embedding=CFG.text_embedding,
        is_image_poem_pair=True,
    ):
        """
        Initializes model's submodules.

        Parameters
        ----------
        image_encoder_pretrained : bool
            whether or not to load a pretrained image encoder
        text_encoder_pretrained : bool
            whether or not to load a pretrained text encoder
        text_projection_trainable : bool
            whether or not to train text projection (since the text projection
            is frozen in our trainings unlike other projections of models)
        temperature : float, optional
            used to scale the dot similarities
        image_embedding : int, optional
            dim of image encoder's encoding output before projection
        text_embedding : int, optional
            dim of text encoder's encoding output before projection
        is_image_poem_pair : bool, optional
            if True, the text inputs to this model are poems and need one of
            the poem encoders to predict embeddings with; else it's a text
            that needs the encoders dedicated to text
        """
        super().__init__()
        # Loading the encoders and their projections using configs
        self.image_encoder = ImageEncoder(
            pretrained=image_encoder_pretrained, trainable=CFG.image_encoder_trainable
        )
        if is_image_poem_pair:
            self.encoder = TextEncoder(
                CFG.poem_encoder_model,
                CFG.poem_encoder_pretrained_name,
                pretrained=text_encoder_pretrained,
                trainable=CFG.poem_encoder_trainable,
            )
            self.text_projection = ProjectionHead(embedding_dim=text_embedding)
            if CFG.poem_projection_load_path:
                self.text_projection.load_state_dict(
                    torch.load(CFG.poem_projection_load_path, map_location=CFG.device)
                )
        else:
            self.encoder = TextEncoder(
                CFG.text_encoder_model,
                CFG.text_encoder_pretrained_name,
                pretrained=text_encoder_pretrained,
                trainable=CFG.text_encoder_trainable,
            )
            self.text_projection = ProjectionHead(embedding_dim=text_embedding)
            if CFG.text_projection_load_path:
                self.text_projection.load_state_dict(
                    torch.load(CFG.text_projection_load_path, map_location=CFG.device)
                )
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        if CFG.image_projection_load_path:
            self.image_projection.load_state_dict(
                torch.load(CFG.image_projection_load_path, map_location=CFG.device)
            )
        if not text_projection_trainable:
            # freeze the text projection head
            for p in self.text_projection.parameters():
                p.requires_grad = False
        self.text_projection_trainable = text_projection_trainable
        self.is_image_poem_pair = is_image_poem_pair
        self.temperature = temperature

    def forward(self, batch):
        """
        Returns image and text/poem embeddings of batch.

        Parameters
        ----------
        batch : dict
            input containing image-text/poem pairs (text/poem encoded using
            the encoder's tokenizer) with keys 'image' and 'text'

        Returns
        -------
        poem/text and image embeddings of batch
        (each of shape (batch_size, projection_dim))
        """
        image, texts = batch["image"], batch["text"]
        # Getting Image and Text Features
        image_features = self.image_encoder(batch["image"])
        text_features = self.encoder(
            input_ids=texts["input_ids"], attention_mask=texts["attention_mask"]
        )
        # Getting Image and Text Embeddings (with same dimension)
        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)
        return image_embeddings, text_embeddings

    def similarity_scores(self, batch):
        """
        Computes dot similarities of a batch of text/poem-image pairs.

        Parameters
        ----------
        batch : dict
            input containing image-text/poem pairs (text/poem encoded using
            the encoder's tokenizer) with keys 'image' and 'text'

        Returns
        -------
        dot similarity of poem/text and image embeddings of batch
        (of shape (batch_size, batch_size))
        """
        # Getting Image and Text Embeddings (with same dimension)
        image_embeddings, text_embeddings = self.forward(batch)
        # Normalizing embeddings
        image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
        text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
        # Computing dot / cosine similarity of the normalized embeddings.
        # First dim is images, second dim is poems/texts for each image.
        dot_similarity = image_embeddings_n @ text_embeddings_n.T
        return dot_similarity  # (batch_size, batch_size)

    def predict(self, batch):
        """
        Predicts the most similar poem/text (idx) for each image (using previous methods).

        Parameters
        ----------
        batch : dict
            input containing image-text/poem pairs (text/poem encoded using
            the encoder's tokenizer) with keys 'image' and 'text'

        Returns
        -------
        index of poem/text predicted for each image (of shape (batch_size))
        """
        dot_similarity = self.similarity_scores(batch)
        # Argmax over the text dimension of the dot-similarities gives the
        # index of the most similar poem/text for each image.
        return torch.argmax(dot_similarity, dim=1)

    def calculate_loss(self, image_embeddings, text_embeddings):
        """
        Computes contrastive (cross entropy) loss for both poems/texts and images.

        Parameters
        ----------
        image_embeddings : tensor of shape (batch_size, projection_dim)
            output embeddings of image projection head
        text_embeddings : tensor of shape (batch_size, projection_dim)
            output embeddings of text projection head

        Returns
        -------
        average of the loss computed from inputs
        """
        # dot similarity of the embeddings scaled by temperature (logits)
        logits = (text_embeddings @ image_embeddings.T) / self.temperature
        # Computing targets for the cross entropy loss to compare with logits:
        # each modality's self-similarity matrix is averaged, scaled by the
        # temperature parameter, and normalized into a probability
        # distribution via a softmax.
        # NOTE(review): logits are *divided* by temperature while targets are
        # *multiplied* by it — looks asymmetric; kept as-is to preserve
        # training behavior, but confirm this is intended.
        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax(
            (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
        )
        # taking cross entropy loss in both dimensions: once for texts and once for images
        texts_loss = cross_entropy(logits, targets, reduction='none')
        images_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss = (images_loss + texts_loss) / 2.0  # average of losses. shape: (batch_size)
        return loss.mean()

    def save_current(self):
        """
        Saves current model's encoders and projection heads (if trainable).
        """
        if self.is_image_poem_pair:
            if CFG.poem_encoder_trainable:
                self.encoder.model.save_pretrained(CFG.poem_encoder_save_path)
        else:
            if CFG.text_encoder_trainable:
                self.encoder.model.save_pretrained(CFG.text_encoder_save_path)
        if CFG.image_encoder_trainable:
            torch.save(
                self.image_encoder.model.state_dict(), CFG.image_encoder_weights_save_path
            )
        if self.text_projection_trainable:
            torch.save(self.text_projection.state_dict(), CFG.text_projection_save_path)
        torch.save(self.image_projection.state_dict(), CFG.image_projection_save_path)


def cross_entropy(preds, targets, reduction='none'):
    """
    Computes cross_entropy of logits and targets using their last dimension.

    Parameters
    ----------
    preds : tensor
        logits
    targets : tensor
        target probability distribution (rows sum to 1)
    reduction : str, optional
        if set to "mean", return loss mean across all dimensions;
        if set to "none", return loss computed using last dim

    Returns
    -------
    loss or loss average
    """
    log_softmax = nn.LogSoftmax(dim=-1)
    loss = (-targets * log_softmax(preds)).sum(1)  # cross entropy loss
    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.mean()