|
import os
|
|
import sys
|
|
|
|
# Put the directory two levels above this file on the import search path;
# presumably that is where the project-local ``utils`` package imported below
# lives.
sys.path.append(
    os.path.join(
        os.path.dirname(__file__),
        "..",
        "..",
    )
)
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
import numpy as np
|
|
from PIL import Image
|
|
from pytorch_grad_cam import GradCAM
|
|
from pytorch_grad_cam.utils.image import show_cam_on_image
|
|
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
|
|
|
|
from utils import configs
|
|
from utils.functional import (
|
|
check_data_type_variable,
|
|
get_device,
|
|
image_augmentations,
|
|
normalize_image_to_zero_one,
|
|
reshape_transform,
|
|
)
|
|
|
|
|
|
class BaseModelGradCAM(ABC):
    """Abstract base that attaches a Grad-CAM explainer to a classifier model.

    Subclasses implement :meth:`init_model`. ``self.model`` is ``None`` after
    ``__init__`` and must be assigned before :meth:`set_grad_cam` is called,
    because that method reads attributes of ``self.model``. After
    :meth:`set_grad_cam`, call :meth:`get_grad_cam` or
    :meth:`get_grad_cam_with_output_target` to obtain heatmap overlays.
    """

    # Transformer backbones: GradCAM is built with ``reshape_transform`` (from
    # utils.functional) for these, presumably to map token activations back to
    # a 2-D grid.
    _TRANSFORMER_MODELS = ("vit_base_patch16_224_dino", "clip")

    def __init__(
        self,
        name_model: str,
        freeze_model: bool,
        pretrained_model: bool,
        support_set_method: str,
    ):
        """Store the configuration, resolve the device and validate arguments.

        Args:
            name_model: Backbone identifier. Must be a key of
                ``configs.NAME_MODELS``, or equal ``configs.CLIP_NAME_MODEL``
                (which is validated as ``"clip"``).
            freeze_model: Stored as-is; not used by this base class.
            pretrained_model: Stored as-is; not used by this base class.
            support_set_method: Must be one of ``configs.SUPPORT_SET_METHODS``.

        Raises:
            ValueError: If the model name or support-set method is unsupported.
            (``check_data_type_variable`` presumably raises on a type mismatch.)
        """
        self.name_model = name_model
        self.freeze_model = freeze_model
        self.pretrained_model = pretrained_model
        self.support_set_method = support_set_method
        self.model = None
        self.device = get_device()

        self.check_arguments()

    def _canonical_name_model(self) -> str:
        """Return the key used for config lookups and layer selection.

        ``configs.CLIP_NAME_MODEL`` is mapped to ``"clip"``. Every other name,
        including a literal ``"clip"``, is returned unchanged.
        """
        if self.name_model == configs.CLIP_NAME_MODEL:
            return "clip"
        return self.name_model

    def check_arguments(self):
        """Validate argument types and supported values.

        Raises:
            ValueError: If the (canonical) model name is not a key of
                ``configs.NAME_MODELS`` or the support-set method is not in
                ``configs.SUPPORT_SET_METHODS``.
        """
        check_data_type_variable(self.name_model, str)
        check_data_type_variable(self.freeze_model, bool)
        check_data_type_variable(self.pretrained_model, bool)
        check_data_type_variable(self.support_set_method, str)

        # Validate against a local canonical name instead of temporarily
        # overwriting self.name_model, so a failed check leaves the instance
        # attribute untouched.
        canonical_name = self._canonical_name_model()
        if canonical_name not in configs.NAME_MODELS:
            raise ValueError(f"Model {canonical_name} not supported")

        if self.support_set_method not in configs.SUPPORT_SET_METHODS:
            raise ValueError(
                f"Support set method {self.support_set_method} not supported"
            )

    @abstractmethod
    def init_model(self):
        """Build the underlying model; implemented by subclasses."""

    def _get_target_layers(self) -> tuple:
        """Return the layer(s) Grad-CAM hooks for the configured backbone."""
        name = self._canonical_name_model()
        model = self.model

        if name == "resnet50":
            return (model.model.layer4[-1],)
        if name in ("efficientnet_b4", "mobilenetv3_large_100"):
            return (model.model.blocks[-1],)
        if name == "vit_base_patch16_224_dino":
            return (model.model.blocks[-1].norm1,)
        if name == "clip":
            # The CLIP wrapper exposes its vision tower directly instead of
            # through a ``.model`` attribute.
            return (model.vision_model.encoder.layers[-1].layer_norm1,)
        # vgg16, inception_v4, densenet121 and any other backbone fall back to
        # the last module of ``features``.
        return (model.model.features[-1],)

    def set_grad_cam(self):
        """Select target layers for the backbone and build ``self.gradcam``.

        Must be called after ``self.model`` has been assigned. Sets
        ``self.target_layers`` and ``self.gradcam``.
        """
        self.target_layers = self._get_target_layers()

        extra_kwargs = {}
        if self._canonical_name_model() in self._TRANSFORMER_MODELS:
            extra_kwargs["reshape_transform"] = reshape_transform

        # NOTE(review): recent pytorch-grad-cam releases infer the device from
        # the model and no longer accept ``use_cuda``; confirm against the
        # pinned library version.
        self.gradcam = GradCAM(
            model=self.model,
            target_layers=self.target_layers,
            use_cuda=self.device.type == "cuda",
            **extra_kwargs,
        )

    def _compute_grad_cam(self, image: np.ndarray, targets=None) -> np.ndarray:
        """Resize, preprocess, run Grad-CAM and overlay the heatmap on *image*.

        ``targets`` is forwarded to ``self.gradcam``; ``None`` is GradCAM's
        default target selection.
        """
        resized = np.array(
            Image.fromarray(image).resize((configs.SIZE_IMAGES, configs.SIZE_IMAGES))
        )

        # Add a batch dimension of one and move the tensor to the model device.
        image_input = image_augmentations()(image=resized)["image"]
        image_input = image_input.unsqueeze(0).to(self.device)

        # Keep the single CAM in the batch.
        cam = self.gradcam(image_input, targets=targets)[0, :]

        return show_cam_on_image(
            normalize_image_to_zero_one(resized), cam, use_rgb=True
        )

    def get_grad_cam(self, image: np.ndarray) -> np.ndarray:
        """Return a Grad-CAM overlay for *image* with GradCAM's default target.

        Per the pytorch-grad-cam docs, no explicit target means the
        highest-scoring category is explained. Requires :meth:`set_grad_cam`.

        Args:
            image: Input image; assumed to be an RGB ``uint8`` array accepted by
                ``PIL.Image.fromarray``. TODO: confirm.

        Returns:
            The ``show_cam_on_image(..., use_rgb=True)`` overlay, computed at
            ``configs.SIZE_IMAGES`` x ``configs.SIZE_IMAGES``.
        """
        return self._compute_grad_cam(image)

    def get_grad_cam_with_output_target(
        self, image: np.ndarray, index_class: int
    ) -> np.ndarray:
        """Return a Grad-CAM overlay explaining output class *index_class*.

        Requires :meth:`set_grad_cam`.

        Args:
            image: Input image; same assumptions as :meth:`get_grad_cam`.
            index_class: Index of the classifier output to explain.

        Returns:
            The ``show_cam_on_image(..., use_rgb=True)`` overlay, computed at
            ``configs.SIZE_IMAGES`` x ``configs.SIZE_IMAGES``.
        """
        targets = (ClassifierOutputTarget(index_class),)
        return self._compute_grad_cam(image, targets)
|
|
|