Spaces:
Sleeping
Sleeping
| from torch import nn | |
| from tqdm.autonotebook import tqdm | |
| from transformers import AutoTokenizer, AutoModel | |
| from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer | |
| import albumentations as A | |
| import cv2 | |
| import timm | |
| import torch | |
| import torch.nn.functional as F | |
# Device used by the inference helpers (get_image_embeddings, predict_caption,
# get_text_embeddings). It is pinned to CPU, unlike CFG.device, which picks
# CUDA when it is available.
_INFERENCE_DEVICE_TYPE = "cpu"
device = torch.device(_INFERENCE_DEVICE_TYPE)
class CFG:
    """Central configuration, read as class attributes (``CFG.<name>``).

    Groups data paths, loader/optimiser settings, model identifiers,
    preprocessing sizes and projection-head hyper-parameters.
    """

    # --- run / data locations ---
    debug = False
    image_path = '/content/content/new_images_v5'
    captions_path = '/content/content/all_data/new_caption.csv'

    # --- data loading and optimisation ---
    batch_size = 12
    num_workers = 2
    head_lr = 1e-3
    image_encoder_lr = 1e-4
    text_encoder_lr = 1e-5
    weight_decay = 1e-3
    patience = 1
    factor = 0.8
    epochs = 2

    # --- checkpoint file names ---
    # NOTE(review): not read by any of the code visible in this file.
    saved_model_clinical = '/content/content/new_weights.pt'
    trained_model = 'clinical_bert_weights.pt'

    # Training device: CUDA when available, otherwise CPU. Evaluated once, when
    # the class body runs. The module-level `device` used for inference is CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- image encoder (timm backbone) ---
    model_name = 'resnet50'
    image_embedding = 2048  # feature width fed to the image ProjectionHead

    # --- text encoder ---
    # NOTE(review): TextEncoder takes text_encoder_model as its default
    # model_name but, when pretrained, always loads clinical_encoder_model.
    text_encoder_model = "distilbert-base-uncased"
    clinical_encoder_model = "emilyalsentzer/Bio_ClinicalBERT"
    text_embedding = 768  # feature width fed to the text ProjectionHead
    # NOTE(review): the inference tokenizer is loaded from
    # clinical_encoder_model, not from this value.
    text_tokenizer = "distilbert-base-uncased"
    max_length = 200  # tokenizer truncation length

    pretrained = True  # for both image encoder and text encoder
    trainable = True  # for both image encoder and text encoder
    temperature = 1.0

    # --- image preprocessing ---
    size = 224  # square side length images are resized to

    # --- projection head; used for both image and text encoders ---
    # NOTE(review): ProjectionHead does not read num_projection_layers.
    num_projection_layers = 1
    projection_dim = 256
    dropout = 0.1
def build_loaders(dataframe, tokenizer, mode):
    """Wrap ``dataframe``'s "image" and "caption" columns in a DataLoader.

    ``mode`` selects the preprocessing pipeline via ``get_transforms`` and
    enables shuffling only when it equals "train". Batch size and worker
    count come from CFG.
    """
    pipeline = get_transforms(mode=mode)
    is_training = mode == "train"
    clip_dataset = CLIPDataset(
        dataframe["image"].values,
        dataframe["caption"].values,
        tokenizer=tokenizer,
        transforms=pipeline,
    )
    return torch.utils.data.DataLoader(
        clip_dataset,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shuffle=is_training,
    )
class AvgMeter:
    """Count-weighted running average of a scalar metric, printable by name."""

    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        """Zero the running average, sum and count."""
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, count=1):
        """Fold in ``val`` with weight ``count`` and recompute the average."""
        self.count += count
        self.sum += val * count
        self.avg = self.sum / self.count

    def __repr__(self):
        return self.name + ": " + format(self.avg, ".4f")
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first param group.

    Returns None when ``optimizer.param_groups`` is empty.
    """
    return next((group["lr"] for group in optimizer.param_groups), None)
# Custom dataset object. Tokenizes all captions up front and applies transforms
# to each image as it is fetched.
class CLIPDataset(torch.utils.data.Dataset):
    """Image/caption pairs for CLIP-style training and evaluation.

    Each item is a dict with the tokenizer outputs for one caption (as
    tensors), the transformed image as a channels-first float tensor under
    "image", and the raw caption string under "caption".
    """

    def __init__(self, image_filenames, captions, tokenizer, transforms):
        """
        image_filenames and captions must have the same length; so, if there are
        multiple captions for each image, image_filenames must contain repeated
        file names.

        Captions are tokenized once here (padded, truncated to CFG.max_length);
        images are read lazily from CFG.image_path in __getitem__.
        """
        self.image_filenames = image_filenames
        self.captions = list(captions)
        # Number of unreadable images skipped by __getitem__.
        # NOTE(review): DataLoader worker processes presumably each hold their
        # own copy of the dataset, so this is likely not aggregated when
        # num_workers > 0.
        self.skippedImgCount = 0
        self.encoded_captions = tokenizer(
            self.captions, padding=True, truncation=True, max_length=CFG.max_length
        )
        self.transforms = transforms

    def __getitem__(self, idx):
        """Return the item at ``idx``, skipping forward past unreadable images.

        If the image for ``idx`` cannot be read, the following index (wrapping
        around) is tried, and so on. The returned item then belongs entirely
        to that substitute index, so its caption and image still match.

        Raises:
            IndexError: ``idx`` is out of range (also ends plain iteration).
            RuntimeError: no image in the whole dataset could be read.
        """
        n = len(self)
        if not -n <= idx < n:
            raise IndexError(f"index {idx} out of range for dataset of size {n}")

        # Iterate rather than recurse: long runs of missing files cannot
        # exhaust the recursion limit, and one full pass bounds the search.
        cur = idx
        for _ in range(n):
            image = cv2.imread(f"{CFG.image_path}/{self.image_filenames[cur]}")
            if image is not None:
                break
            self.skippedImgCount += 1
            cur = (cur + 1) % n
        else:
            raise RuntimeError(
                f"no readable image found under {CFG.image_path} after {n} attempts"
            )

        item = {
            key: torch.tensor(values[cur])
            for key, values in self.encoded_captions.items()
        }
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transforms(image=image)['image']
        # Move the channel axis (last, as read by cv2) to the front.
        item['image'] = torch.tensor(image).permute(2, 0, 1).float()
        item['caption'] = self.captions[cur]
        return item

    def __len__(self):
        return len(self.captions)
def get_transforms(mode="train"):
    """Return the albumentations preprocessing pipeline for ``mode``.

    Every mode currently uses the same deterministic pipeline: resize to
    CFG.size x CFG.size, then Normalize with max_pixel_value=255.0. ``mode``
    is kept for interface compatibility with existing callers.
    """
    # p=1.0 always applies the transform. It replaces the deprecated
    # `always_apply=True`, with the same effect.
    return A.Compose(
        [
            A.Resize(CFG.size, CFG.size, p=1.0),
            A.Normalize(max_pixel_value=255.0, p=1.0),
        ]
    )
class ImageEncoder(nn.Module):
    """Encode images to a fixed-size vector with a timm backbone.

    The backbone is created with ``num_classes=0`` and ``global_pool="avg"``,
    so its forward pass returns pooled features instead of class logits.
    ``trainable`` sets ``requires_grad`` on every backbone parameter.
    """

    def __init__(
        self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable
    ):
        super().__init__()
        self.model = timm.create_model(
            model_name, pretrained=pretrained, num_classes=0, global_pool="avg"
        )
        self.model.requires_grad_(trainable)

    def forward(self, x):
        return self.model(x)
class TextEncoder(nn.Module):
    """Encode token sequences to one vector each: the first token's hidden state."""

    def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):
        super().__init__()
        # NOTE(review): `model_name` is not used. With pretrained=True this
        # always loads Bio-ClinicalBERT (CFG.clinical_encoder_model); with
        # pretrained=False it builds a randomly initialised DistilBERT.
        # Confirm both are intended.
        self.model = (
            AutoModel.from_pretrained(CFG.clinical_encoder_model)
            if pretrained
            else DistilBertModel(config=DistilBertConfig())
        )
        self.model.requires_grad_(trainable)
        # Position whose hidden state is the sentence embedding (the CLS token).
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        hidden = self.model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        return hidden[:, self.target_token_idx, :]
# Get both image and text encodings into matrices of the same width.
class ProjectionHead(nn.Module):
    """Project encoder features to ``projection_dim``.

    Applies linear -> GELU -> linear -> dropout, adds the first linear output
    back as a residual, then applies LayerNorm.
    """

    def __init__(
        self,
        embedding_dim,
        projection_dim=CFG.projection_dim,
        dropout=CFG.dropout
    ):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        residual = self.projection(x)
        out = self.dropout(self.fc(self.gelu(residual)))
        return self.layer_norm(out + residual)
class CLIPModel(nn.Module):
    """Image/text dual encoder trained with a soft-target contrastive loss.

    Submodule attribute names (image_encoder, text_encoder, image_projection,
    text_projection) are the state_dict key prefixes that saved weights are
    loaded against.
    """

    def __init__(
        self,
        temperature=CFG.temperature,
        image_embedding=CFG.image_embedding,
        text_embedding=CFG.text_embedding,
    ):
        super().__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature

    def forward(self, batch):
        """Return the mean symmetric contrastive loss for ``batch``.

        ``batch`` must provide "image", "input_ids" and "attention_mask".
        """
        # Both encoders run before either projection head.
        img_feats = self.image_encoder(batch["image"])
        txt_feats = self.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        img_emb = self.image_projection(img_feats)
        txt_emb = self.text_projection(txt_feats)

        # logits[i, j]: similarity of text i with image j.
        logits = (txt_emb @ img_emb.T) / self.temperature
        # Soft targets: softmax over the mean of image-image and text-text
        # similarity.
        # NOTE(review): this multiplies by temperature while the logits divide
        # by it. That is a no-op at the default 1.0; confirm it is intended for
        # other values.
        pairwise = img_emb @ img_emb.T + txt_emb @ txt_emb.T
        soft_targets = F.softmax(pairwise / 2 * self.temperature, dim=-1)

        loss_per_text = cross_entropy(logits, soft_targets, reduction='none')
        loss_per_image = cross_entropy(logits.T, soft_targets.T, reduction='none')
        return ((loss_per_image + loss_per_text) / 2.0).mean()
def cross_entropy(preds, targets, reduction='none'):
    """Cross-entropy of logits ``preds`` against soft (probability) ``targets``.

    Computes ``-sum(targets * log_softmax(preds))`` over the last axis, the
    same axis the softmax normalises.

    Args:
        preds: logits tensor.
        targets: tensor of the same shape holding target probabilities.
        reduction: "none" (per-row losses), "mean" or "sum".

    Raises:
        ValueError: ``reduction`` is not one of the supported values.
    """
    loss = (-targets * F.log_softmax(preds, dim=-1)).sum(dim=-1)
    if reduction == "none":
        return loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    # Fail loudly instead of silently returning None.
    raise ValueError(
        f"unsupported reduction {reduction!r}; expected 'none', 'mean' or 'sum'"
    )
# INFERENCE CODE

# Lazily built CLIPModel shared by get_image_embeddings calls.
_inference_model = None


def _get_inference_model():
    """Build CLIPModel with "weights.pt" loaded onto `device`, once, and reuse it.

    NOTE(review): the model is cached for the process lifetime, so a
    replaced weights.pt is not picked up until restart.
    """
    global _inference_model
    if _inference_model is None:
        model = CLIPModel().to(device)
        model.load_state_dict(torch.load('weights.pt', map_location=device))
        model.eval()
        _inference_model = model
    return _inference_model


def get_image_embeddings(image):
    """Return the L2-normalised CLIP image embedding for one image.

    Args:
        image: image array in BGR channel order (it is converted with
            cv2.COLOR_BGR2RGB), or None.

    Returns:
        A tensor of shape (1, projection_dim), or None (after printing a
        message) when ``image`` is None.
    """
    if image is None:
        print("Image not found!")
        return None
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = get_transforms("valid")(image=image)['image']
    # Channels-last -> channels-first by permuting axes, the same conversion
    # CLIPDataset applies during training. reshape() would reinterpret the
    # memory instead and scramble the pixels.
    image_tensor = torch.tensor(image).permute(2, 0, 1).float()
    model = _get_inference_model()
    with torch.no_grad():
        image_features = model.image_encoder(image_tensor.unsqueeze(0).to(device))
        image_embeddings = model.image_projection(image_features)
        image_embeddings = F.normalize(image_embeddings, p=2, dim=-1)
    return image_embeddings
def predict_caption(image, model, text_embeddings, captions, n=2):
    """Return the captions whose text embeddings best match ``image``.

    Args:
        image: image array passed to get_image_embeddings, or None.
        model: accepted for API compatibility but not used here;
            get_image_embeddings obtains its own model.
        text_embeddings: tensor with one row per candidate caption;
            row i must correspond to ``captions[i]``.
        captions: sequence indexed by row of ``text_embeddings``.
        n: number of matches to return; clamped to the number of candidates.

    Returns:
        A list of up to ``n`` captions ordered by cosine similarity, best
        first, or None when no image embedding could be computed.
    """
    image_embeddings = get_image_embeddings(image)
    if image_embeddings is None:
        return None
    # Normalise both sides so the dot product is cosine similarity.
    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
    text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
    similarity = (image_embeddings_n @ text_embeddings_n.T).squeeze(0)
    # topk raises if k exceeds the number of candidates, so clamp it.
    k = min(n, similarity.shape[-1])
    _, indices = torch.topk(similarity, k)
    return [captions[idx] for idx in indices.cpu().tolist()]
def get_text_embeddings(valid_df):
    """Load the trained CLIP model and embed every caption in ``valid_df``.

    Captions are tokenized directly in batches of CFG.batch_size; no images
    are read. Row i of the result therefore always corresponds to row i of
    ``valid_df["caption"]``. Routing text through CLIPDataset would re-read
    every image and replace an item whose image is missing with the next one.

    Returns:
        (model, text_embeddings): the loaded CLIPModel in eval mode, and the
        projected (not L2-normalised) caption embeddings, one row per
        caption. The tensor is empty (0 rows) when ``valid_df`` has none.
    """
    tokenizer = AutoTokenizer.from_pretrained(CFG.clinical_encoder_model)
    model = CLIPModel().to(device)
    model.load_state_dict(torch.load("weights.pt", map_location=device))
    model.eval()

    captions = list(valid_df["caption"].values)
    batch_embeddings = []
    with torch.no_grad():
        for start in tqdm(range(0, len(captions), CFG.batch_size)):
            encoded = tokenizer(
                captions[start:start + CFG.batch_size],
                padding=True,
                truncation=True,
                max_length=CFG.max_length,
                return_tensors="pt",
            )
            text_features = model.text_encoder(
                input_ids=encoded["input_ids"].to(device),
                attention_mask=encoded["attention_mask"].to(device),
            )
            batch_embeddings.append(model.text_projection(text_features))

    if not batch_embeddings:
        # torch.cat([]) raises; return an empty embedding matrix instead.
        width = model.text_projection.projection.out_features
        return model, torch.empty((0, width), device=device)
    return model, torch.cat(batch_embeddings)