import math
import pdb
import pickle
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
from medclip import MedCLIPModel, MedCLIPVisionModelViT
from transformers import GPT2LMHeadModel

from modules.decoder import DeCap


class MedCapModel(nn.Module):
    """Wrap a DeCap decoder together with a pretrained MedCLIP alignment model.

    The MedCLIP model (``self.align_model``) produces text or image embeddings
    that are handed to the DeCap decoder (``self.model``) as its conditioning
    input. ``self.forward`` is bound at construction time to the
    dataset-specific implementation selected by ``args.dataset``.
    """

    def __init__(self, args, tokenizer):
        """Build the decoder, load MedCLIP weights and the prompt bank.

        Args:
            args: Configuration object. This class reads ``prompt``,
                ``dataset``, ``prompt_load``, ``train_mode``, ``noise_inject``
                and ``F_version`` from it; it is also forwarded to ``DeCap``.
            tokenizer: Tokenizer forwarded to ``DeCap``.
        """
        super().__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.args = args
        self.tokenizer = tokenizer
        self.model = DeCap(args, tokenizer)
        self.align_model = MedCLIPModel(vision_cls=MedCLIPVisionModelViT)
        self.align_model.from_pretrained()
        # SECURITY NOTE: torch.load unpickles the file; only point
        # ``args.prompt`` at trusted files.
        # map_location keeps the prompt bank on the same device the alignment
        # model is moved to in the forward methods, so the matmuls in
        # ``_project_onto_prompt`` do not hit a device mismatch.
        self.prompt = torch.load(args.prompt, map_location=self.device)
        if args.dataset == 'iu_xray':
            self.forward = self.forward_iu_xray
        else:
            self.forward = self.forward_mimic_cxr

    def noise_injection(self, x, variance=0.001, modality_offset=None, dont_norm=False):
        """Add Gaussian noise to ``x`` and return the L2-normalised result.

        Args:
            x: Tensor normalised along ``dim=1`` (so it must be at least 2-D).
            variance: Variance of the zero-mean Gaussian noise. ``0.0``
                disables the whole procedure and returns ``x`` untouched.
            modality_offset: Optional value added to ``x`` after the noise.
            dont_norm: When true, skip the normalisation applied *before*
                the noise is added. The final normalisation always runs.

        Returns:
            ``x`` itself when ``variance == 0.0``; otherwise a new tensor,
            normalised along ``dim=1``.
        """
        if variance == 0.0:
            return x
        std = math.sqrt(variance)
        if not dont_norm:
            x = torch.nn.functional.normalize(x, dim=1)
        # The noise is added regardless of ``dont_norm``; that flag only
        # controls the pre-normalisation above. ``randn_like`` creates the
        # noise on the device and in the dtype of ``x``.
        # TODO: by some conventions multivariate noise should be divided by
        # sqrt of dim.
        x = x + torch.randn_like(x) * std
        if modality_offset is not None:
            x = x + modality_offset
        return torch.nn.functional.normalize(x, dim=1)

    def _project_onto_prompt(self, feature):
        """Re-express ``feature`` as a softmax-weighted mix of prompt rows.

        Similarities between ``feature`` and every row of ``self.prompt`` are
        scaled by 100 (acting as an inverse softmax temperature), turned into
        weights with a softmax over the last dim, used to average the prompt
        rows, and the result is divided by its L2 norm along the last dim.
        """
        prompt = self.prompt.float()
        weights = ((feature @ prompt.T) * 100).softmax(dim=-1)
        prefix_embedding = weights @ prompt
        return prefix_embedding / prefix_embedding.norm(dim=-1, keepdim=True)

    def align_encode_images_iu_xray(self, images):
        """Encode a pair of images and average their embeddings.

        Args:
            images: Tensor whose ``dim=1`` has exactly two entries (one per
                image); unpacking fails otherwise.

        Returns:
            The element-wise mean of the two per-image embeddings. When
            ``args.prompt_load == 'yes'`` each embedding is first projected
            onto the prompt bank.
        """
        # Split the pair into its two images.
        image1, image2 = images.unbind(dim=1)
        # Encode each image independently.
        features = [
            self.align_model.encode_image(image1),
            self.align_model.encode_image(image2),
        ]
        if self.args.prompt_load == 'yes':
            features = [self._project_onto_prompt(feature) for feature in features]
        # Average (not concatenate) the two embeddings.
        return torch.mean(torch.stack(features), dim=0)

    def align_encode_images_mimic_cxr(self, images):
        """Encode images, optionally projecting onto the prompt bank.

        Returns the raw ``encode_image`` output unless
        ``args.prompt_load == 'yes'``, in which case the projected embedding
        is returned instead.
        """
        feature = self.align_model.encode_image(images)
        if self.args.prompt_load == 'yes':
            return self._project_onto_prompt(feature)
        return feature

    def forward_iu_xray(self, reports_ids, align_ids, align_masks, images, mode='train', update_opts=None):
        """Forward pass used when ``args.dataset == 'iu_xray'``.

        Args:
            reports_ids: Token ids passed through to the decoder.
            align_ids: Token ids for the alignment text encoder (cast to long).
            align_masks: Mask passed to the alignment text encoder.
            images: Image pairs, see ``align_encode_images_iu_xray``.
            mode: ``'train'`` returns logits, ``'sample'`` returns whatever
                the decoder's sampling mode returns.
            update_opts: Optional options forwarded to the decoder in sample
                mode; defaults to a fresh empty dict per call.

        Raises:
            ValueError: If ``mode`` is neither ``'train'`` nor ``'sample'``.
            AttributeError: If ``args.F_version == 'v1'`` during fine-tuning
                and no ``fc_reduce_dim`` layer has been attached.
        """
        # A fresh dict per call avoids sharing one mutable default.
        update_opts = {} if update_opts is None else update_opts
        # Re-applied on every call so the alignment model stays on the target
        # device and in eval mode.
        self.align_model.to(self.device)
        self.align_model.eval()
        align_ids = align_ids.long()

        if mode == 'train':
            fine_tuning = self.args.train_mode == 'fine-tuning'
            align_text_feature = self.align_model.encode_text(align_ids, align_masks)
            if self.args.noise_inject == 'yes':
                align_text_feature = self.noise_injection(align_text_feature)
            if fine_tuning:
                align_image_feature = self.align_encode_images_iu_xray(images)
                if self.args.F_version == 'v1':
                    # NOTE(review): ``fc_reduce_dim`` is not created in
                    # ``__init__``; fail with a clear message unless it has
                    # been attached elsewhere.
                    if not hasattr(self, 'fc_reduce_dim'):
                        raise AttributeError(
                            "F_version 'v1' requires a 'fc_reduce_dim' layer, "
                            "which is not defined on MedCapModel"
                        )
                    combined_feature = torch.cat([align_text_feature, align_image_feature], dim=-1)
                    align_text_feature = self.fc_reduce_dim(combined_feature)
                elif self.args.F_version == 'v2':
                    # v2 conditions the decoder on the image embedding only.
                    align_text_feature = align_image_feature
            outputs = self.model(align_text_feature, reports_ids, mode='forward')
            # Drop the last position along dim 1.
            return outputs.logits[:, :-1]

        if mode == 'sample':
            align_image_feature = self.align_encode_images_iu_xray(images)
            return self.model(align_image_feature, reports_ids, mode='sample', update_opts=update_opts)

        raise ValueError(f"Unsupported mode {mode!r}; expected 'train' or 'sample'")

    def forward_mimic_cxr(self, reports_ids, align_ids, align_masks, images, mode='train', update_opts=None):
        """Forward pass used for every dataset other than ``'iu_xray'``.

        Args:
            reports_ids: Token ids passed through to the decoder.
            align_ids: Token ids for the alignment text encoder (cast to long).
            align_masks: Mask passed to the alignment text encoder.
            images: Images, see ``align_encode_images_mimic_cxr``; only used
                in sample mode.
            mode: ``'train'`` returns logits, ``'sample'`` returns whatever
                the decoder's sampling mode returns.
            update_opts: Optional options forwarded to the decoder in sample
                mode; defaults to a fresh empty dict per call.

        Raises:
            ValueError: If ``mode`` is neither ``'train'`` nor ``'sample'``.
        """
        # A fresh dict per call avoids sharing one mutable default.
        update_opts = {} if update_opts is None else update_opts
        self.align_model.to(self.device)
        self.align_model.eval()
        align_ids = align_ids.long()

        if mode == 'train':
            align_text_feature = self.align_model.encode_text(align_ids, align_masks)
            if self.args.noise_inject == 'yes':
                align_text_feature = self.noise_injection(align_text_feature)
            outputs = self.model(align_text_feature, reports_ids, mode='forward')
            # Drop the last position along dim 1.
            return outputs.logits[:, :-1]

        if mode == 'sample':
            align_image_feature = self.align_encode_images_mimic_cxr(images)
            return self.model(align_image_feature, reports_ids, mode='sample', update_opts=update_opts)

        raise ValueError(f"Unsupported mode {mode!r}; expected 'train' or 'sample'")