Spaces:
Runtime error
Runtime error
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
#FIX | |
import config as CFG | |
from modules import TextEncoder, ProjectionHead, ImageEncoder | |
class PoemTextModel(nn.Module):
    """
    Model predicting poem and text embeddings, and their similarities.

    Attributes:
    -----------
    poem_encoder : TextEncoder
        encoder used for extracting poem embeddings
    text_encoder : TextEncoder
        encoder used for extracting text embeddings
    poem_projection: ProjectionHead
        projection head used for poem embeddings (projects poem encoder output to shared embedding space)
    text_projection: ProjectionHead
        projection head used for text embeddings (projects text encoder output to shared embedding space)
    temperature: float
        used to scale the dot similarities

    Methods:
    --------
    forward(batch):
        returns poem and text embeddings of batch
    similarity_scores(batch):
        computes cosine similarities of a batch of text-poem pairs
    predict(batch):
        predicts the most similar poem idx for each text (using previous methods)
    calculate_loss(poem_embeddings, text_embeddings):
        computes contrastive (cross entropy) loss for both poems and texts.
    save_current():
        saves current model's encoders (if trainable) and projection heads.
    """

    def __init__(
        self,
        poem_encoder_pretrained,
        text_encoder_pretrained,
        temperature=CFG.temperature,
        poem_embedding=CFG.poem_embedding,
        text_embedding=CFG.text_embedding,
    ):
        """
        Initializes model's submodules

        Note that the CFG-based defaults are read once, when the class is defined.

        Parameters:
        -----------
        poem_encoder_pretrained: bool
            whether or not to load a pretrained poem encoder.
        text_encoder_pretrained: bool
            whether or not to load a pretrained text encoder.
        temperature: float, optional
            used to scale the dot similarities
        poem_embedding: int, optional
            dim of poem encoder's encoding output before projection
        text_embedding: int, optional
            dim of text encoder's encoding output before projection
        """
        super().__init__()
        self.poem_encoder = TextEncoder(
            CFG.poem_encoder_model,
            CFG.poem_encoder_pretrained_name,
            pretrained=poem_encoder_pretrained,
            trainable=CFG.poem_encoder_trainable,
        )
        self.text_encoder = TextEncoder(
            CFG.text_encoder_model,
            CFG.text_encoder_pretrained_name,
            pretrained=text_encoder_pretrained,
            trainable=CFG.text_encoder_trainable,
        )

        self.poem_projection = ProjectionHead(embedding_dim=poem_embedding)
        if CFG.poem_projection_load_path:  # if provided, load projection weights from this path
            self.poem_projection.load_state_dict(
                torch.load(CFG.poem_projection_load_path, map_location=CFG.device)
            )

        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        if CFG.text_projection_load_path:  # if provided, load projection weights from this path
            self.text_projection.load_state_dict(
                torch.load(CFG.text_projection_load_path, map_location=CFG.device)
            )

        self.temperature = temperature

    def forward(self, batch):
        """
        returns poem and text embeddings of batch

        Parameters:
        -----------
        batch: dict
            input with keys 'beyt' and 'text'; each value is itself indexed by
            'input_ids' and 'attention_mask' (the poem-text pairs encoded using the encoder's tokenizer)

        Returns:
        --------
        tuple (poem_embeddings, text_embeddings), in that order
        (each documented as being of shape (batch_size, projection_dim))
        """
        beyts, texts = batch["beyt"], batch["text"]

        # Getting Beyt and Text Features
        poem_features = self.poem_encoder(
            input_ids=beyts["input_ids"], attention_mask=beyts["attention_mask"]
        )
        text_features = self.text_encoder(
            input_ids=texts["input_ids"], attention_mask=texts["attention_mask"]
        )

        # Getting Beyt and Text Embeddings (with same dimension)
        poem_embeddings = self.poem_projection(poem_features)
        text_embeddings = self.text_projection(text_features)
        return poem_embeddings, text_embeddings

    def similarity_scores(self, batch):
        """
        computes cosine similarities of a batch of text-poem pairs

        Parameters:
        -----------
        batch: dict
            same input as accepted by forward (keys 'beyt' and 'text')

        Returns:
        --------
        similarity matrix of the L2-normalized text and poem embeddings
        (rows index texts, columns index poems)
        """
        # Call the module itself (not .forward) so that registered hooks still run.
        poem_embeddings, text_embeddings = self(batch)

        # Normalizing embeddings
        poem_embeddings_n = F.normalize(poem_embeddings, p=2, dim=-1)
        text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)

        # Dot product of unit vectors == cosine similarity
        dot_similarity = text_embeddings_n @ poem_embeddings_n.T
        return dot_similarity  # first dim is texts, second dim is poems for each text

    def predict(self, batch):
        """
        predicts the most similar poem (idx) for each text (using previous methods)

        Parameters:
        -----------
        batch: dict
            same input as accepted by forward (keys 'beyt' and 'text')

        Returns:
        --------
        index of the poem predicted for each text (one index per row of the similarity matrix)
        """
        dot_similarity = self.similarity_scores(batch)
        # argmax over dim=1 (the poems axis) picks the best matching poem for each text
        return torch.argmax(dot_similarity, dim=1)

    def calculate_loss(self, poem_embeddings, text_embeddings):
        """
        computes contrastive (cross entropy) loss for both poems and texts.

        Parameters:
        -----------
        poem_embeddings: of shape (batch_size, projection_dim)
            output embeddings of poem projection head
        text_embeddings: of shape (batch_size, projection_dim)
            output embeddings of text projection head

        Returns:
        --------
        average of the loss computed from inputs (a scalar tensor)
        """
        # dot similarity of the embeddings scaled by temperature (logits); rows are texts
        logits = (text_embeddings @ poem_embeddings.T) / self.temperature

        # computing soft targets for the cross entropy loss to compare with logits:
        # the within-modality similarities (poem-poem and text-text) are averaged,
        # multiplied by the temperature, and turned into a distribution via a softmax
        poems_similarity = poem_embeddings @ poem_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax(
            (poems_similarity + texts_similarity) / 2 * self.temperature, dim=-1
        )

        # taking cross entropy loss in both dimensions: once for texts and once for poems
        texts_loss = cross_entropy(logits, targets, reduction='none')
        poems_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss = (poems_loss + texts_loss) / 2.0  # average of losses. shape: (batch_size)
        return loss.mean()

    def save_current(self):
        """
        saves current model's encoders (if trainable) and projection heads.

        Encoders are only written when the matching CFG ``*_trainable`` flag is set;
        both projection heads are always written. All destinations come from CFG.
        """
        if CFG.text_encoder_trainable:
            self.text_encoder.model.save_pretrained(CFG.text_encoder_save_path)
        if CFG.poem_encoder_trainable:
            self.poem_encoder.model.save_pretrained(CFG.poem_encoder_save_path)
        torch.save(self.text_projection.state_dict(), CFG.text_projection_save_path)
        torch.save(self.poem_projection.state_dict(), CFG.poem_projection_save_path)
class CLIPModel(nn.Module):
    """
    Model predicting poem/text and image embeddings, and their similarities.

    Attributes:
    -----------
    encoder : TextEncoder
        encoder used for extracting poem/text embeddings
    image_encoder : ImageEncoder
        encoder used for extracting image embeddings
    text_projection: ProjectionHead
        projection head used for poem/text embeddings (projects text encoder output to shared embedding space)
    image_projection: ProjectionHead
        projection head used for image embeddings (projects image encoder output to shared embedding space)
    text_projection_trainable: bool
        whether the text projection is trained (and therefore saved by save_current)
    is_image_poem_pair: bool
        whether the text side of the pairs is poems (True) or plain texts (False)
    temperature: float
        used to scale the dot similarities

    Methods:
    --------
    forward(batch):
        returns image and poem/text embeddings of batch
    similarity_scores(batch):
        computes cosine similarities of a batch of image-text pairs
    predict(batch):
        predicts the most similar poem/text idx for each image (using previous methods)
    calculate_loss(image_embeddings, text_embeddings):
        computes contrastive (cross entropy) loss for both poems/texts and images.
    save_current():
        saves current model's encoders and projection heads (if trainable).
    """

    def __init__(
        self,
        image_encoder_pretrained,
        text_encoder_pretrained,
        text_projection_trainable,
        temperature=CFG.temperature,
        image_embedding=CFG.image_embedding,
        text_embedding=CFG.text_embedding,
        is_image_poem_pair=True
    ):
        """
        Initializes model's submodules

        Note that the CFG-based defaults are read once, when the class is defined.

        Parameters:
        -----------
        image_encoder_pretrained: bool
            whether or not to load a pretrained image encoder.
        text_encoder_pretrained: bool
            whether or not to load a pretrained text encoder.
        text_projection_trainable: bool
            whether or not to train text projection
            (since the text projection is frozen in our trainings unlike other projections of models)
        temperature: float, optional
            used to scale the dot similarities
        image_embedding: int, optional
            dim of image encoder's encoding output before projection
        text_embedding: int, optional
            dim of text encoder's encoding output before projection
        is_image_poem_pair: bool, optional
            if True, the text inputs to this model is poems and needs one of the poem encoders to predict embeddings with.
            else it's a text that needs the encoders dedicated to text.
        """
        super().__init__()

        # Loading the encoders and their projections using configs
        self.image_encoder = ImageEncoder(
            pretrained=image_encoder_pretrained,
            trainable=CFG.image_encoder_trainable,
        )

        # Both modes build the same kind of encoder/projection pair; only the
        # CFG entries they read differ, so pick those first.
        if is_image_poem_pair:
            encoder_model = CFG.poem_encoder_model
            encoder_pretrained_name = CFG.poem_encoder_pretrained_name
            encoder_trainable = CFG.poem_encoder_trainable
            projection_load_path = CFG.poem_projection_load_path
        else:
            encoder_model = CFG.text_encoder_model
            encoder_pretrained_name = CFG.text_encoder_pretrained_name
            encoder_trainable = CFG.text_encoder_trainable
            projection_load_path = CFG.text_projection_load_path

        self.encoder = TextEncoder(
            encoder_model,
            encoder_pretrained_name,
            pretrained=text_encoder_pretrained,
            trainable=encoder_trainable,
        )
        # NOTE(review): the projection input dim is `text_embedding` in poem mode too;
        # confirm the poem encoder's output dim matches it.
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        if projection_load_path:  # if provided, load projection weights from this path
            self.text_projection.load_state_dict(
                torch.load(projection_load_path, map_location=CFG.device)
            )

        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        if CFG.image_projection_load_path:  # if provided, load projection weights from this path
            self.image_projection.load_state_dict(
                torch.load(CFG.image_projection_load_path, map_location=CFG.device)
            )

        # Freeze the text projection when it is not supposed to be trained
        if not text_projection_trainable:
            for p in self.text_projection.parameters():
                p.requires_grad = False

        self.text_projection_trainable = text_projection_trainable
        self.is_image_poem_pair = is_image_poem_pair
        self.temperature = temperature

    def forward(self, batch):
        """
        returns image and text/poem embeddings of batch

        Parameters:
        -----------
        batch: dict
            input with keys 'image' and 'text'; batch['text'] is itself indexed by
            'input_ids' and 'attention_mask' (text/poem encoded using the encoder's tokenizer)

        Returns:
        --------
        tuple (image_embeddings, text_embeddings), in that order
        (each documented as being of shape (batch_size, projection_dim))
        """
        images, texts = batch["image"], batch["text"]

        # Getting Image and Text Features
        image_features = self.image_encoder(images)
        text_features = self.encoder(
            input_ids=texts["input_ids"], attention_mask=texts["attention_mask"]
        )

        # Getting Image and Text Embeddings (with same dimension)
        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)
        return image_embeddings, text_embeddings

    def similarity_scores(self, batch):
        """
        computes cosine similarities of a batch of text/poem-image pairs

        Parameters:
        -----------
        batch: dict
            same input as accepted by forward (keys 'image' and 'text')

        Returns:
        --------
        similarity matrix of the L2-normalized image and poem/text embeddings
        (rows index images, columns index poems/texts)
        """
        # Call the module itself (not .forward) so that registered hooks still run.
        image_embeddings, text_embeddings = self(batch)

        # Normalizing embeddings
        image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
        text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)

        # Dot product of unit vectors == cosine similarity
        dot_similarity = image_embeddings_n @ text_embeddings_n.T
        return dot_similarity  # first dim is images, second dim is poems/texts for each image

    def predict(self, batch):
        """
        predicts the most similar poem/text (idx) for each image (using previous methods)

        Parameters:
        -----------
        batch: dict
            same input as accepted by forward (keys 'image' and 'text')

        Returns:
        --------
        index of the poem/text predicted for each image (one index per row of the similarity matrix)
        """
        dot_similarity = self.similarity_scores(batch)
        # argmax over dim=1 (the poems/texts axis) picks the best match for each image
        return torch.argmax(dot_similarity, dim=1)

    def calculate_loss(self, image_embeddings, text_embeddings):
        """
        computes contrastive (cross entropy) loss for both poems/texts and images.

        Parameters:
        -----------
        image_embeddings: of shape (batch_size, projection_dim)
            output embeddings of image projection head
        text_embeddings: of shape (batch_size, projection_dim)
            output embeddings of text projection head

        Returns:
        --------
        average of the loss computed from inputs (a scalar tensor)
        """
        # dot similarity of the embeddings scaled by temperature (logits); rows are texts
        logits = (text_embeddings @ image_embeddings.T) / self.temperature

        # computing soft targets for the cross entropy loss to compare with logits:
        # the within-modality similarities (image-image and text-text) are averaged,
        # multiplied by the temperature, and turned into a distribution via a softmax
        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax(
            (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
        )

        # taking cross entropy loss in both dimensions: once for texts and once for images
        texts_loss = cross_entropy(logits, targets, reduction='none')
        images_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss = (images_loss + texts_loss) / 2.0  # average of losses. shape: (batch_size)
        return loss.mean()

    def save_current(self):
        """
        saves current model's encoders and projection heads (if trainable).

        The poem/text encoder and the image encoder are written only when the
        matching CFG ``*_trainable`` flag is set, the text projection only when
        ``self.text_projection_trainable`` is set; the image projection is always written.
        """
        if self.is_image_poem_pair:
            if CFG.poem_encoder_trainable:
                self.encoder.model.save_pretrained(CFG.poem_encoder_save_path)
        else:
            if CFG.text_encoder_trainable:
                self.encoder.model.save_pretrained(CFG.text_encoder_save_path)

        if CFG.image_encoder_trainable:
            torch.save(self.image_encoder.model.state_dict(), CFG.image_encoder_weights_save_path)

        if self.text_projection_trainable:
            # NOTE(review): in poem mode the projection is loaded from
            # CFG.poem_projection_load_path but saved to the *text* projection path — confirm intended.
            torch.save(self.text_projection.state_dict(), CFG.text_projection_save_path)
        torch.save(self.image_projection.state_dict(), CFG.image_projection_save_path)
def cross_entropy(preds, targets, reduction='none'):
    """
    Computes cross entropy between logits and (soft) targets over their last dimension

    Parameters:
    -----------
    preds: torch.Tensor
        logits; log-softmax is applied over the last dimension
    targets: torch.Tensor
        target distribution, broadcastable against preds
    reduction: str, optional
        if set to "none", return the loss per item (last dim reduced away).
        if set to "mean", return the mean of the per-item losses.
        if set to "sum", return the sum of the per-item losses.

    Returns:
    --------
    per-item loss tensor, or a scalar tensor for "mean" / "sum"

    Raises:
    -------
    ValueError
        if reduction is not one of "none", "mean" or "sum"
    """
    # Fail loudly instead of silently returning None on an unknown reduction.
    if reduction not in ("none", "mean", "sum"):
        raise ValueError(
            f"reduction must be 'none', 'mean' or 'sum', got {reduction!r}"
        )

    # Reduce over the same (last) axis the log-softmax was taken on, so inputs
    # with any number of leading dims are handled consistently.
    loss = (-targets * F.log_softmax(preds, dim=-1)).sum(dim=-1)  # cross entropy loss

    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss