import albumentations as A
import cv2
import timm
import torch
import torch.nn.functional as F
from torch import nn
from tqdm.autonotebook import tqdm
from transformers import AutoModel, AutoTokenizer
from transformers import DistilBertConfig, DistilBertModel, DistilBertTokenizer

# Device used by the inference helpers at the bottom of this file
# (_load_clip_model / get_image_embeddings / get_text_embeddings).
# NOTE: this is independent of CFG.device, which selects CUDA when available.
device = torch.device("cpu")


class CFG:
    """Global hyper-parameters and paths shared by training and inference."""

    debug = False
    image_path = '/content/content/new_images_v5'
    captions_path = '/content/content/all_data/new_caption.csv'
    batch_size = 12
    num_workers = 2
    head_lr = 1e-3
    image_encoder_lr = 1e-4
    text_encoder_lr = 1e-5
    weight_decay = 1e-3
    patience = 1
    factor = 0.8
    epochs = 2
    saved_model_clinical = '/content/content/new_weights.pt'
    trained_model = 'clinical_bert_weights.pt'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_name = 'resnet50'
    image_embedding = 2048  # width of ImageEncoder output fed to the image ProjectionHead
    text_encoder_model = "distilbert-base-uncased"
    clinical_encoder_model = "emilyalsentzer/Bio_ClinicalBERT"
    text_embedding = 768  # width of TextEncoder output fed to the text ProjectionHead
    text_tokenizer = "distilbert-base-uncased"
    max_length = 200  # tokenizer truncation length

    pretrained = True  # for both image encoder and text encoder
    trainable = True  # for both image encoder and text encoder
    temperature = 1.0

    # image size (images are resized to size x size)
    size = 224

    # for projection head; used for both image and text encoders
    num_projection_layers = 1  # NOTE(review): not read by ProjectionHead
    projection_dim = 256
    dropout = 0.1


def build_loaders(dataframe, tokenizer, mode):
    """Build a DataLoader over ``dataframe``'s ``image`` / ``caption`` columns.

    Args:
        dataframe: table with ``"image"`` (file name under CFG.image_path)
            and ``"caption"`` columns.
        tokenizer: Hugging Face tokenizer used to encode the captions.
        mode: ``"train"`` enables shuffling; any other value keeps row order.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``CLIPDataset`` items.
    """
    transforms = get_transforms(mode=mode)
    dataset = CLIPDataset(
        dataframe["image"].values,
        dataframe["caption"].values,
        tokenizer=tokenizer,
        transforms=transforms,
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shuffle=mode == "train",
    )


class AvgMeter:
    """Track a running, count-weighted average of a scalar metric."""

    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        """Clear the accumulated sum, count and average."""
        self.avg, self.sum, self.count = [0] * 3

    def update(self, val, count=1):
        """Record ``val`` observed ``count`` times and refresh the average."""
        self.count += count
        self.sum += val * count
        self.avg = self.sum / self.count

    def __repr__(self):
        return f"{self.name}: {self.avg:.4f}"


def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Returns None if the optimizer has no parameter groups.
    """
    for param_group in optimizer.param_groups:
        return param_group["lr"]


class CLIPDataset(torch.utils.data.Dataset):
    """Dataset of (image, caption) pairs.

    Captions are tokenized once up front; images are read from
    ``CFG.image_path`` and transformed lazily in ``__getitem__``.
    """

    def __init__(self, image_filenames, captions, tokenizer, transforms):
        """
        image_filenames and captions must have the same length; so, if there
        are multiple captions for each image, image_filenames must contain
        repeated file names.
        """
        self.image_filenames = image_filenames
        self.captions = list(captions)
        # Number of unreadable images skipped so far.
        # NOTE: with DataLoader(num_workers > 0) each worker process holds its
        # own copy of the dataset, so this counter in the main process does
        # not see skips that happen inside workers.
        self.skippedImgCount = 0
        self.encoded_captions = tokenizer(
            list(captions), padding=True, truncation=True, max_length=CFG.max_length
        )
        self.transforms = transforms

    def __getitem__(self, idx):
        """Return the tokenized caption, image tensor (C, H, W) and raw caption.

        If the image for ``idx`` cannot be read, the following rows are tried
        in order (wrapping around) and the first readable row is returned
        instead — its caption and tokens, not those of ``idx``.

        Raises:
            IndexError: if ``idx`` is out of range.
            FileNotFoundError: if no image in the dataset can be read.
        """
        n_items = len(self)
        if not -n_items <= idx < n_items:
            raise IndexError(f"index {idx} out of range for dataset of size {n_items}")

        # Bounded iterative search (the previous recursive version could hit
        # RecursionError on long runs of missing files, or loop forever if
        # no image was readable at all).
        for offset in range(n_items):
            candidate = (idx + offset) % n_items
            image = cv2.imread(f"{CFG.image_path}/{self.image_filenames[candidate]}")
            if image is not None:
                break
            self.skippedImgCount += 1
        else:
            raise FileNotFoundError(f"No readable images found under {CFG.image_path}")

        item = {
            key: torch.tensor(values[candidate])
            for key, values in self.encoded_captions.items()
        }
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transforms(image=image)['image']
        # HWC numpy array -> CHW float tensor.
        item['image'] = torch.tensor(image).permute(2, 0, 1).float()
        item['caption'] = self.captions[candidate]
        return item

    def __len__(self):
        return len(self.captions)


def get_transforms(mode="train"):
    """Return the albumentations pipeline: resize to CFG.size square, then normalize.

    Train and eval pipelines are currently identical (no augmentation);
    ``mode`` is accepted so train-only augmentations can be added later.
    """
    # p=1.0 always applies each transform; it replaces the deprecated
    # always_apply=True flag.
    return A.Compose(
        [
            A.Resize(CFG.size, CFG.size, p=1.0),
            A.Normalize(max_pixel_value=255.0, p=1.0),
        ]
    )


class ImageEncoder(nn.Module):
    """Encode images to a fixed-size vector with a timm backbone.

    The classifier is removed (``num_classes=0``) and features are
    average-pooled (``global_pool="avg"``).
    """

    def __init__(
        self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable
    ):
        super().__init__()
        self.model = timm.create_model(
            model_name, pretrained=pretrained, num_classes=0, global_pool="avg"
        )
        for p in self.model.parameters():
            p.requires_grad = trainable

    def forward(self, x):
        return self.model(x)


class TextEncoder(nn.Module):
    """Encode token ids to a sentence vector (hidden state of the first token).

    When ``pretrained`` is true, ``CFG.clinical_encoder_model``
    (Bio_ClinicalBERT) is loaded; otherwise a DistilBERT built from the
    default ``DistilBertConfig`` is used.
    NOTE(review): ``model_name`` is currently not used by either branch.
    """

    def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained,
                 trainable=CFG.trainable):
        super().__init__()
        if pretrained:
            self.model = AutoModel.from_pretrained(CFG.clinical_encoder_model)
        else:
            self.model = DistilBertModel(config=DistilBertConfig())

        for p in self.model.parameters():
            p.requires_grad = trainable

        # Use the CLS token hidden representation as the sentence embedding.
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = output.last_hidden_state
        return last_hidden_state[:, self.target_token_idx, :]


class ProjectionHead(nn.Module):
    """Project encoder features into the shared ``projection_dim`` space.

    Linear -> GELU -> Linear -> Dropout, added residually to the first
    projection, then LayerNorm.
    """

    def __init__(
        self, embedding_dim, projection_dim=CFG.projection_dim, dropout=CFG.dropout
    ):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected
        x = self.layer_norm(x)
        return x


class CLIPModel(nn.Module):
    """Image encoder + text encoder, each followed by a projection head."""

    def __init__(
        self,
        temperature=CFG.temperature,
        image_embedding=CFG.image_embedding,
        text_embedding=CFG.text_embedding,
    ):
        super().__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature

    def forward(self, batch):
        """Return the scalar contrastive loss for a batch.

        ``batch`` must contain ``"image"``, ``"input_ids"`` and
        ``"attention_mask"`` (as produced by ``CLIPDataset``). Targets are
        soft: a softmax over the averaged image-image and text-text
        similarities, so near-duplicate pairs are not treated as negatives.
        """
        # Encoder features.
        image_features = self.image_encoder(batch["image"])
        text_features = self.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        # Embeddings with the same dimension for both modalities.
        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)

        # Loss.
        logits = (text_embeddings @ image_embeddings.T) / self.temperature
        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax(
            (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
        )
        texts_loss = cross_entropy(logits, targets, reduction='none')
        images_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss = (images_loss + texts_loss) / 2.0  # shape: (batch_size)
        return loss.mean()


def cross_entropy(preds, targets, reduction='none'):
    """Cross-entropy of ``preds`` (logits) against soft ``targets``.

    Log-softmax is taken over the last dim and the per-row loss is summed
    over dim 1.

    Args:
        reduction: ``"none"`` (per-row losses), ``"mean"`` or ``"sum"``.

    Raises:
        ValueError: for any other ``reduction`` (previously returned None).
    """
    loss = (-targets * F.log_softmax(preds, dim=-1)).sum(1)
    if reduction == "none":
        return loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    raise ValueError(f"Unsupported reduction: {reduction!r}")


# INFERENCE CODE

def _load_clip_model(weights_path="weights.pt"):
    """Build a CLIPModel on ``device``, load ``weights_path`` and set eval mode."""
    model = CLIPModel().to(device)
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()
    return model


def get_image_embeddings(image, model=None, weights_path="weights.pt"):
    """Return the L2-normalized projected embedding of one image.

    Args:
        image: BGR image as returned by ``cv2.imread`` (None if unreadable).
        model: an already-loaded CLIPModel to reuse. When None, a model is
            built and ``weights_path`` is loaded (the original behavior).
        weights_path: checkpoint used only when ``model`` is None.

    Returns:
        Tensor of shape (1, projection_dim), or None if ``image`` is None.
        Note: the given model is switched to eval mode.
    """
    if image is None:
        print("Image not found!")
        return None

    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = get_transforms("valid")(image=image)['image']

    if model is None:
        model = _load_clip_model(weights_path)
    model.eval()
    model_device = next(model.parameters()).device

    # HWC -> CHW via a real axis permutation. The previous reshape(3, 224, 224)
    # reinterpreted memory and scrambled pixels across channels.
    image_tensor = (
        torch.from_numpy(image).permute(2, 0, 1).float().unsqueeze(0).to(model_device)
    )

    with torch.no_grad():
        image_features = model.image_encoder(image_tensor)
        image_embeddings = model.image_projection(image_features)
        image_embeddings = F.normalize(image_embeddings, p=2, dim=-1)
    return image_embeddings


def predict_caption(image, model, text_embeddings, captions, n=2):
    """Return the ``n`` captions whose embeddings best match ``image``.

    Args:
        image: BGR image as returned by ``cv2.imread``.
        model: loaded CLIPModel (e.g. from ``get_text_embeddings``); it is
            reused for the image embedding instead of reloading weights.
        text_embeddings: (num_captions, projection_dim) tensor aligned
            row-for-row with ``captions``.
        captions: sequence of caption strings.
        n: number of matches; clamped to the number of captions.

    Returns:
        List of up to ``n`` captions ordered by cosine similarity, or None
        if the image is missing.
    """
    image_embeddings = get_image_embeddings(image, model=model)
    if image_embeddings is None:
        return None

    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
    text_embeddings_n = F.normalize(
        text_embeddings.to(image_embeddings.device), p=2, dim=-1
    )
    # Cosine similarity between the image and every caption.
    dot_similarity = image_embeddings_n @ text_embeddings_n.T

    k = min(n, text_embeddings_n.shape[0])
    _, indices = torch.topk(dot_similarity.squeeze(0), k)
    return [captions[idx] for idx in indices.cpu().tolist()]


def get_text_embeddings(valid_df, weights_path="weights.pt"):
    """Embed every caption in ``valid_df["caption"]``.

    Captions are tokenized directly in batches of ``CFG.batch_size``, so
    images are never read. Row order is preserved, which keeps the result
    aligned with ``valid_df["caption"]``. The previous image-loading
    DataLoader substituted neighbouring rows when an image was missing.

    Returns:
        ``(model, embeddings)``, where ``embeddings`` has shape
        (len(valid_df), projection_dim) and is not normalized.
    """
    tokenizer = AutoTokenizer.from_pretrained(CFG.clinical_encoder_model)
    model = _load_clip_model(weights_path)
    captions = list(valid_df["caption"].values)

    if not captions:
        return model, torch.empty((0, CFG.projection_dim), device=device)

    valid_text_embeddings = []
    with torch.no_grad():
        for start in tqdm(range(0, len(captions), CFG.batch_size)):
            encoded = tokenizer(
                captions[start:start + CFG.batch_size],
                padding=True,
                truncation=True,
                max_length=CFG.max_length,
                return_tensors="pt",
            )
            text_features = model.text_encoder(
                input_ids=encoded["input_ids"].to(device),
                attention_mask=encoded["attention_mask"].to(device),
            )
            valid_text_embeddings.append(model.text_projection(text_features))
    return model, torch.cat(valid_text_embeddings)