Spaces:
Runtime error
Runtime error
import os | |
import cv2 | |
import gc | |
import numpy as np | |
import pandas as pd | |
import itertools | |
from tqdm.autonotebook import tqdm | |
import albumentations as A | |
import matplotlib.pyplot as plt | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
import timm | |
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer | |
class CFG:
    """Global hyper-parameters and paths shared by the whole pipeline."""

    # --- run mode & data locations ---
    debug = False  # when True, make_train_valid_dfs only considers ids < 100
    image_path = "./Images"
    captions_path = "."

    # --- data loading ---
    batch_size = 32
    num_workers = 2

    # --- optimisation ---
    head_lr = 1e-3
    image_encoder_lr = 1e-4
    text_encoder_lr = 1e-5
    weight_decay = 1e-3
    patience = 1
    factor = 0.8
    epochs = 4

    # Use the GPU whenever PyTorch can see one.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # --- image encoder (timm backbone) ---
    model_name = 'resnet50'
    image_embedding = 2048  # input width of the image projection head

    # --- text encoder / tokenizer (DistilBERT) ---
    text_encoder_model = "distilbert-base-uncased"
    text_embedding = 768  # input width of the text projection head
    text_tokenizer = "distilbert-base-uncased"
    max_length = 200  # truncation length used when tokenizing captions

    pretrained = True  # applies to both the image encoder and the text encoder
    trainable = True  # applies to both the image encoder and the text encoder
    temperature = 1.0

    size = 224  # images are resized to size x size

    # --- projection head (used by both the image and the text branch) ---
    num_projection_layers = 1  # NOTE(review): not read by ProjectionHead in this file
    projection_dim = 256
    dropout = 0.1
class AvgMeter:
    """Track a count-weighted running average of a scalar metric."""

    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        """Zero the running average, sum and count."""
        self.avg = self.sum = self.count = 0

    def update(self, val, count=1):
        """Record ``val`` as if it had been observed ``count`` times."""
        self.count += count
        self.sum += val * count
        self.avg = self.sum / self.count

    def __repr__(self):
        return f"{self.name}: {self.avg:.4f}"
def get_lr(optimizer):
    """Return the ``"lr"`` of the optimizer's first parameter group.

    Returns None when ``optimizer.param_groups`` is empty.
    """
    first_lrs = (group["lr"] for group in optimizer.param_groups)
    return next(first_lrs, None)
class CLIPDataset(torch.utils.data.Dataset):
    """Dataset yielding one (tokenized caption, transformed image) pair per caption.

    ``image_filenames`` and ``captions`` must have the same length; if an image
    has several captions, its file name must be repeated once per caption.
    """

    def __init__(self, image_filenames, captions, tokenizer, transforms):
        """
        Args:
            image_filenames: file names relative to ``CFG.image_path``, aligned
                index-by-index with ``captions``.
            captions: caption strings.
            tokenizer: tokenizer called once on the full caption list with
                ``padding=True, truncation=True, max_length=CFG.max_length``.
            transforms: callable invoked as ``transforms(image=...)`` that returns
                a dict holding the transformed array under ``'image'``.

        Raises:
            ValueError: if ``image_filenames`` and ``captions`` differ in length.
        """
        self.image_filenames = image_filenames
        self.captions = list(captions)
        if len(self.image_filenames) != len(self.captions):
            raise ValueError(
                f"image_filenames ({len(self.image_filenames)}) and captions "
                f"({len(self.captions)}) must have the same length"
            )
        # Tokenize every caption once, up front, so __getitem__ only has to index.
        self.encoded_captions = tokenizer(
            self.captions, padding=True, truncation=True, max_length=CFG.max_length
        )
        self.transforms = transforms

    def __getitem__(self, idx):
        """Return a dict with the tokenizer outputs at ``idx`` plus 'image' and 'caption'.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        item = {
            key: torch.tensor(values[idx])
            for key, values in self.encoded_captions.items()
        }
        path = f"{CFG.image_path}/{self.image_filenames[idx]}"
        image = cv2.imread(path)
        if image is None:
            # cv2.imread reports a missing/unreadable file by returning None,
            # which would otherwise fail later inside cvtColor with a vague error.
            raise FileNotFoundError(f"Could not read image: {path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transforms(image=image)['image']
        # permute(2, 0, 1) moves the trailing channel axis to the front.
        item['image'] = torch.tensor(image).permute(2, 0, 1).float()
        item['caption'] = self.captions[idx]
        return item

    def __len__(self):
        return len(self.captions)
def get_transforms(mode="train"):
    """Return the albumentations pipeline used for images.

    The train and validation pipelines were identical (resize to
    ``CFG.size`` x ``CFG.size``, then ``Normalize(max_pixel_value=255.0)``), so a
    single pipeline is built for every ``mode``. The ``mode`` argument is kept
    for API compatibility and as the hook for future train-only augmentations.

    Args:
        mode: "train" or any other value (treated as evaluation); currently unused.

    Returns:
        An ``A.Compose`` pipeline.
    """
    # Resize and Normalize default to p=1, so they always run; the former
    # ``always_apply=True`` flag was redundant and is deprecated in recent
    # albumentations releases.
    return A.Compose(
        [
            A.Resize(CFG.size, CFG.size),
            A.Normalize(max_pixel_value=255.0),
        ]
    )
class ImageEncoder(nn.Module):
    """Encode an image batch into one fixed-size feature vector per image.

    Wraps a timm model created with ``num_classes=0`` and ``global_pool="avg"``.
    """

    def __init__(
        self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable
    ):
        super().__init__()
        self.model = timm.create_model(
            model_name, pretrained=pretrained, num_classes=0, global_pool="avg"
        )
        # Freeze or unfreeze every backbone parameter in one call.
        self.model.requires_grad_(trainable)

    def forward(self, x):
        """Run the wrapped timm model on ``x``."""
        return self.model(x)
class TextEncoder(nn.Module):
    """DistilBERT encoder that returns one token's hidden state per sequence.

    When ``pretrained`` is False, a randomly initialised model built from the
    default ``DistilBertConfig()`` is used and ``model_name`` is ignored.
    """

    def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):
        super().__init__()
        self.model = (
            DistilBertModel.from_pretrained(model_name)
            if pretrained
            else DistilBertModel(config=DistilBertConfig())
        )
        self.model.requires_grad_(trainable)
        # Position of the token whose hidden state serves as the sentence
        # embedding; index 0 is the CLS token per the original author's note.
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        """Return ``last_hidden_state[:, target_token_idx, :]`` for the batch."""
        hidden_states = self.model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        return hidden_states[:, self.target_token_idx, :]
class ProjectionHead(nn.Module):
    """Map encoder features of width ``embedding_dim`` to ``projection_dim``.

    Layout: Linear -> GELU -> Linear -> Dropout, plus a residual connection
    from the first Linear's output, followed by LayerNorm.
    """

    def __init__(
        self,
        embedding_dim,
        projection_dim=CFG.projection_dim,
        dropout=CFG.dropout
    ):
        super().__init__()
        # Attribute names (and creation order) are kept stable so saved
        # state_dicts keep loading.
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        residual = self.projection(x)
        hidden = self.dropout(self.fc(self.gelu(residual)))
        return self.layer_norm(hidden + residual)
class CLIPModel(nn.Module):
    """Image/text dual encoder trained with a soft-target cross-entropy loss.

    ``forward`` returns the scalar mean of the text->image and image->text
    cross-entropies, where the targets are a softmax over the average of the
    image-image and text-text similarity matrices.
    """

    def __init__(
        self,
        temperature=CFG.temperature,
        image_embedding=CFG.image_embedding,
        text_embedding=CFG.text_embedding,
    ):
        super().__init__()
        # Submodule names/order are kept stable for checkpoint compatibility.
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature

    def forward(self, batch):
        """Compute the training loss for ``batch``.

        ``batch`` must provide "image", "input_ids" and "attention_mask".
        """
        # Raw encoder features first, then projections (same call order as
        # before, so dropout RNG consumption is unchanged).
        img_feats = self.image_encoder(batch["image"])
        txt_feats = self.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        img_emb = self.image_projection(img_feats)
        txt_emb = self.text_projection(txt_feats)

        # Rows index texts, columns index images.
        logits = (txt_emb @ img_emb.T) / self.temperature
        # NOTE(review): targets are not detached, so gradients also flow
        # through them; confirm this is intended.
        mean_similarity = (img_emb @ img_emb.T + txt_emb @ txt_emb.T) / 2
        targets = F.softmax(mean_similarity * self.temperature, dim=-1)

        per_text = cross_entropy(logits, targets, reduction='none')
        per_image = cross_entropy(logits.T, targets.T, reduction='none')
        per_sample = (per_image + per_text) / 2.0  # shape: (batch_size)
        return per_sample.mean()
def cross_entropy(preds, targets, reduction='none'):
    """Cross-entropy between logits ``preds`` and soft probability ``targets``.

    Computes ``-(targets * log_softmax(preds)).sum(-1)``; the log-softmax and
    the sum use the same (last) dimension.

    Args:
        preds: unnormalised logits.
        targets: target probabilities, broadcastable to ``preds``.
        reduction: 'none' (per-row losses), 'mean', or 'sum'.

    Returns:
        Per-row losses for 'none', otherwise a scalar tensor.

    Raises:
        ValueError: if ``reduction`` is not one of the supported values
            (previously such calls silently returned None).
    """
    loss = (-targets * F.log_softmax(preds, dim=-1)).sum(dim=-1)
    if reduction == "none":
        return loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    raise ValueError(
        f"Unsupported reduction {reduction!r}; expected 'none', 'mean' or 'sum'"
    )
def make_train_valid_dfs(*, valid_fraction=0.2, seed=42):
    """Split ``captions.csv`` into train / validation DataFrames by image id.

    Ids ``0 .. max(id)`` (or ``0 .. 99`` when ``CFG.debug``) are considered.
    ``int(valid_fraction * n_ids)`` of them are drawn without replacement for
    validation, and the rest go to training. All caption rows of an id land in
    the same split.

    Args:
        valid_fraction: share of ids assigned to validation (default 0.2).
        seed: seed passed to ``np.random.seed`` (default 42).

    Returns:
        ``(train_dataframe, valid_dataframe)``, each with a fresh RangeIndex.

    Note:
        This reseeds NumPy's *global* RNG. That is kept deliberately so the
        default split matches what earlier versions of this function produced.
    """
    dataframe = pd.read_csv(f"{CFG.captions_path}/captions.csv")
    max_id = dataframe["id"].max() + 1 if not CFG.debug else 100
    image_ids = np.arange(0, max_id)
    np.random.seed(seed)
    valid_ids = np.random.choice(
        image_ids, size=int(valid_fraction * len(image_ids)), replace=False
    )
    # Vectorised membership test; the former per-id `in valid_ids` scan was
    # O(n_ids * n_valid). Order of the remaining ids is preserved.
    train_ids = image_ids[~np.isin(image_ids, valid_ids)]
    train_dataframe = dataframe[dataframe["id"].isin(train_ids)].reset_index(drop=True)
    valid_dataframe = dataframe[dataframe["id"].isin(valid_ids)].reset_index(drop=True)
    return train_dataframe, valid_dataframe
def build_loaders(dataframe, tokenizer, mode):
    """Wrap ``dataframe``'s "image"/"caption" columns in a CLIPDataset DataLoader.

    Shuffling is enabled only when ``mode == "train"``; batch size and worker
    count come from ``CFG``.
    """
    is_training = mode == "train"
    pipeline = get_transforms(mode=mode)
    clip_dataset = CLIPDataset(
        dataframe["image"].values,
        dataframe["caption"].values,
        tokenizer=tokenizer,
        transforms=pipeline,
    )
    return torch.utils.data.DataLoader(
        clip_dataset,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shuffle=is_training,
    )
def get_image_embeddings(valid_df, model_path):
    """Load a CLIPModel checkpoint and embed every image row of ``valid_df``.

    Returns:
        ``(model, embeddings)``: the model in eval mode on ``CFG.device`` and
        the concatenated projected image embeddings, one row per row of
        ``valid_df`` (the loader is built with mode="valid", i.e. unshuffled).
    """
    # The dataset tokenizes captions on construction, so a tokenizer is needed
    # even though only the image branch is used here.
    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
    loader = build_loaders(valid_df, tokenizer, mode="valid")

    model = CLIPModel().to(CFG.device)
    state_dict = torch.load(model_path, map_location=CFG.device)
    model.load_state_dict(state_dict)
    model.eval()

    with torch.no_grad():
        per_batch = [
            model.image_projection(model.image_encoder(batch["image"].to(CFG.device)))
            for batch in tqdm(loader)
        ]
    return model, torch.cat(per_batch)
def find_matches(model, image_embeddings, query, image_filenames, n=9):
    """Return the images whose embeddings are most cosine-similar to ``query``.

    Args:
        model: a CLIPModel; its ``text_encoder`` and ``text_projection`` embed the query.
        image_embeddings: one embedding row per entry of ``image_filenames``.
        query: free-text search string.
        image_filenames: file names (relative to ``CFG.image_path``) aligned
            with the rows of ``image_embeddings``; a file may appear several
            times (once per caption).
        n: maximum number of distinct images to return.

    Returns:
        A list of up to ``n`` RGB images (arrays from ``cv2``), best match first.

    Raises:
        FileNotFoundError: if a matched image file cannot be read.
    """
    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
    # Same truncation as training captions, so very long queries do not crash the encoder.
    encoded_query = tokenizer([query], truncation=True, max_length=CFG.max_length)
    batch = {
        key: torch.tensor(values).to(CFG.device)
        for key, values in encoded_query.items()
    }
    with torch.no_grad():
        text_features = model.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        text_embeddings = model.text_projection(text_features)

    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
    text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
    similarity = (text_embeddings_n @ image_embeddings_n.T).squeeze(0)

    # Each image appears once per caption, so top hits come in duplicate runs.
    # Over-fetch 5x (never more than we have) and keep the first occurrence of
    # each file name. This replaces the old `[::5]` stride, which was only
    # correct when every image had exactly 5 adjacent copies.
    k = min(n * 5, similarity.numel())
    _, indices = torch.topk(similarity, k)
    matches = []
    seen = set()
    for idx in indices.cpu().tolist():
        name = image_filenames[idx]
        if name in seen:
            continue
        seen.add(name)
        matches.append(name)
        if len(matches) >= n:
            break

    # The unused 4x4 matplotlib figure created here previously was never
    # closed (leaking one figure per call) and capped results at 16; removed.
    results = []
    for match in matches:
        path = f"{CFG.image_path}/{match}"
        image = cv2.imread(path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        results.append(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    return results
# Per-process cache: checkpoint path -> (model, image_embeddings, image_filenames).
# Loading the model and embedding the whole validation split is expensive, so
# it is done once rather than on every query.
_SEARCH_INDEX_CACHE = {}


def _load_search_index(model_path="best.pt"):
    """Return cached ``(model, image_embeddings, image_filenames)`` for ``model_path``, building it on first use."""
    if model_path not in _SEARCH_INDEX_CACHE:
        _, valid_df = make_train_valid_dfs()
        model, image_embeddings = get_image_embeddings(valid_df, model_path)
        _SEARCH_INDEX_CACHE[model_path] = (
            model,
            image_embeddings,
            valid_df['image'].values,
        )
    return _SEARCH_INDEX_CACHE[model_path]


def clip_image_search(model, image_embeddings,
                      query,
                      image_filenames,
                      n=16):
    """Search the validation images for ``query`` and return up to ``n`` RGB images.

    The ``model``, ``image_embeddings`` and ``image_filenames`` arguments are
    ignored (as before) and kept only for signature compatibility. The model
    is loaded from "best.pt" and the validation split embedded once per
    process, then reused; restart the process to pick up a new checkpoint.

    Args:
        query: free-text search string.
        n: number of distinct images to return (previously hard-coded to 16
            regardless of this argument).
    """
    model, image_embeddings, image_filenames = _load_search_index("best.pt")
    return find_matches(model,
                        image_embeddings,
                        query,
                        image_filenames=image_filenames,
                        n=n)