import os
import sys

# Make the project root importable so `utils` resolves when this file is run
# from its own directory.
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

from abc import ABC, abstractmethod

import numpy as np
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

from utils import configs
from utils.functional import (
    check_data_type_variable,
    get_device,
    image_augmentations,
    normalize_image_to_zero_one,
    reshape_transform,
)


class BaseModelGradCAM(ABC):
    """Base class pairing a model with a ``pytorch_grad_cam.GradCAM`` explainer.

    Usage contract visible in this class: ``self.model`` starts as ``None`` and
    must be populated (presumably by the subclass's :meth:`init_model`) before
    :meth:`set_grad_cam` is called; :meth:`get_grad_cam` and
    :meth:`get_grad_cam_with_output_target` then rely on ``self.gradcam``
    created by :meth:`set_grad_cam`.
    """

    # Key under which CLIP appears in configs.NAME_MODELS and in the layer table
    # below; configs.CLIP_NAME_MODEL is normalized to this key.
    _CLIP_KEY = "clip"

    # Models whose target layer emits token sequences rather than spatial maps,
    # so GradCAM is given `reshape_transform` for them (presumably converting
    # tokens back to a 2-D grid — see utils.functional.reshape_transform).
    _TOKEN_BASED_MODELS = ("vit_base_patch16_224_dino", _CLIP_KEY)

    # model key -> function returning the layer Grad-CAM should hook.
    # Keys not listed here fall back to `model.model.features[-1]`.
    _TARGET_LAYER_GETTERS = {
        "resnet50": lambda model: model.model.layer4[-1],
        "vgg16": lambda model: model.model.features[-1],
        "inception_v4": lambda model: model.model.features[-1],
        "efficientnet_b4": lambda model: model.model.blocks[-1],
        "mobilenetv3_large_100": lambda model: model.model.blocks[-1],
        "densenet121": lambda model: model.model.features[-1],
        "vit_base_patch16_224_dino": lambda model: model.model.blocks[-1].norm1,
        _CLIP_KEY: lambda model: model.vision_model.encoder.layers[-1].layer_norm1,
    }

    def __init__(
        self,
        name_model: str,
        freeze_model: bool,
        pretrained_model: bool,
        support_set_method: str,
    ):
        """Store configuration, pick the compute device, and validate arguments.

        Args:
            name_model: Model identifier; must be a key of ``configs.NAME_MODELS``
                or equal to ``configs.CLIP_NAME_MODEL``.
            freeze_model: Stored flag; its use is up to subclasses.
            pretrained_model: Stored flag; its use is up to subclasses.
            support_set_method: Must be contained in ``configs.SUPPORT_SET_METHODS``.

        Raises:
            ValueError: If the model name or support-set method is unsupported.
        """
        self.name_model = name_model
        self.freeze_model = freeze_model
        self.pretrained_model = pretrained_model
        self.support_set_method = support_set_method
        self.model = None
        self.device = get_device()
        self.check_arguments()

    def _model_key(self) -> str:
        """Return the lookup key for ``self.name_model``.

        ``configs.CLIP_NAME_MODEL`` is mapped to ``"clip"``; every other name is
        returned unchanged. This lets validation and layer selection agree on
        CLIP without mutating ``self.name_model``.
        """
        if self.name_model == configs.CLIP_NAME_MODEL:
            return self._CLIP_KEY
        return self.name_model

    def check_arguments(self):
        """Validate constructor arguments.

        Raises:
            ValueError: If the (normalized) model name is not a key of
                ``configs.NAME_MODELS`` or the support-set method is not in
                ``configs.SUPPORT_SET_METHODS``. Type mismatches are reported by
                ``check_data_type_variable``.
        """
        check_data_type_variable(self.name_model, str)
        check_data_type_variable(self.freeze_model, bool)
        check_data_type_variable(self.pretrained_model, bool)
        check_data_type_variable(self.support_set_method, str)

        model_key = self._model_key()
        if model_key not in configs.NAME_MODELS:
            raise ValueError(f"Model {model_key} not supported")
        if self.support_set_method not in configs.SUPPORT_SET_METHODS:
            raise ValueError(
                f"Support set method {self.support_set_method} not supported"
            )

    @abstractmethod
    def init_model(self):
        """Build the underlying model; implemented by subclasses."""

    def set_grad_cam(self):
        """Choose the target layer for the current model and build ``self.gradcam``.

        Must be called after ``self.model`` has been populated. Both
        ``"clip"`` and ``configs.CLIP_NAME_MODEL`` select the CLIP layer and
        reshape transform.
        """
        model_key = self._model_key()
        get_layer = self._TARGET_LAYER_GETTERS.get(
            model_key, lambda model: model.model.features[-1]
        )
        self.target_layers = (get_layer(self.model),)

        gradcam_kwargs = {
            "model": self.model,
            "target_layers": self.target_layers,
            "use_cuda": self.device.type == "cuda",
        }
        if model_key in self._TOKEN_BASED_MODELS:
            gradcam_kwargs["reshape_transform"] = reshape_transform
        self.gradcam = GradCAM(**gradcam_kwargs)

    def _prepare_image(self, image: np.ndarray):
        """Resize *image* and build the batched model input.

        Returns:
            A ``(resized_image, image_input)`` pair: the resized array (used as
            the overlay background) and the augmented tensor with a leading
            batch dimension of 1, moved to ``self.device``.
        """
        resized = np.array(
            Image.fromarray(image).resize((configs.SIZE_IMAGES, configs.SIZE_IMAGES))
        )
        image_input = image_augmentations()(image=resized)["image"]
        return resized, image_input.unsqueeze(dim=0).to(self.device)

    def _render_cam(self, image: np.ndarray, image_input, targets=None) -> np.ndarray:
        """Run Grad-CAM on *image_input* and overlay the first map on *image*.

        ``targets=None`` lets GradCAM pick its default target (the top-scoring
        class, per pytorch_grad_cam's documented behavior).
        """
        grayscale_cam = self.gradcam(image_input, targets=targets)[0, :]
        return show_cam_on_image(
            normalize_image_to_zero_one(image), grayscale_cam, use_rgb=True
        )

    def get_grad_cam(self, image: np.ndarray) -> np.ndarray:
        """Return a Grad-CAM overlay for *image* using GradCAM's default target.

        Args:
            image: Image array accepted by ``PIL.Image.fromarray``.

        Returns:
            RGB heatmap overlay produced by ``show_cam_on_image``.
        """
        resized, image_input = self._prepare_image(image)
        return self._render_cam(resized, image_input)

    def get_grad_cam_with_output_target(
        self, image: np.ndarray, index_class: int
    ) -> np.ndarray:
        """Return a Grad-CAM overlay for *image* explaining class *index_class*.

        Args:
            image: Image array accepted by ``PIL.Image.fromarray``.
            index_class: Output index passed to ``ClassifierOutputTarget``.

        Returns:
            RGB heatmap overlay produced by ``show_cam_on_image``.
        """
        resized, image_input = self._prepare_image(image)
        targets = (ClassifierOutputTarget(index_class),)
        return self._render_cam(resized, image_input, targets=targets)