# File size: 4,814 Bytes
# 49bceed
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
from abc import ABC, abstractmethod
import numpy as np
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from utils import configs
from utils.functional import (
check_data_type_variable,
get_device,
image_augmentations,
normalize_image_to_zero_one,
reshape_transform,
)
class BaseModelGradCAM(ABC):
    """Abstract base for models that can render Grad-CAM overlays.

    Subclasses implement :meth:`init_model` (expected to populate
    ``self.model``) and then call :meth:`set_grad_cam` once before using
    :meth:`get_grad_cam` or :meth:`get_grad_cam_with_output_target`.
    """

    # Name of the ViT backbone that, like CLIP, needs ``reshape_transform``.
    _VIT_NAME_MODEL = "vit_base_patch16_224_dino"

    def __init__(
        self,
        name_model: str,
        freeze_model: bool,
        pretrained_model: bool,
        support_set_method: str,
    ):
        """Store the configuration and validate it.

        Args:
            name_model: Backbone identifier; either a key of
                ``configs.NAME_MODELS`` or ``configs.CLIP_NAME_MODEL``.
            freeze_model: Stored as-is; not used by this base class.
            pretrained_model: Stored as-is; not used by this base class.
            support_set_method: Must be one of ``configs.SUPPORT_SET_METHODS``.

        Raises:
            ValueError: If the model name or support set method is unknown.
        """
        self.name_model = name_model
        self.freeze_model = freeze_model
        self.pretrained_model = pretrained_model
        self.support_set_method = support_set_method
        self.model = None
        self.device = get_device()
        self.check_arguments()

    def check_arguments(self):
        """Validate the types and values of the constructor arguments.

        ``self.name_model`` is never modified here: the CLIP alias is resolved
        into a local variable, so a failed validation cannot leave the
        instance with a rewritten model name.

        Raises:
            ValueError: If the model name or support set method is unknown.
        """
        check_data_type_variable(self.name_model, str)
        check_data_type_variable(self.freeze_model, bool)
        check_data_type_variable(self.pretrained_model, bool)
        check_data_type_variable(self.support_set_method, str)

        # configs.NAME_MODELS is looked up with the short "clip" key, while
        # callers pass the full configs.CLIP_NAME_MODEL identifier.
        lookup_name = (
            "clip"
            if self.name_model == configs.CLIP_NAME_MODEL
            else self.name_model
        )
        if lookup_name not in configs.NAME_MODELS:
            raise ValueError(f"Model {lookup_name} not supported")
        if self.support_set_method not in configs.SUPPORT_SET_METHODS:
            raise ValueError(
                f"Support set method {self.support_set_method} not supported"
            )

    @abstractmethod
    def init_model(self):
        """Build the underlying network; implemented by subclasses."""

    def _is_clip_model(self) -> bool:
        """Return True when ``name_model`` designates the CLIP backbone.

        Both spellings are accepted because ``check_arguments`` keeps the
        caller-supplied ``configs.CLIP_NAME_MODEL`` value in ``name_model``
        rather than the short ``"clip"`` key.
        """
        return self.name_model in ("clip", configs.CLIP_NAME_MODEL)

    def _select_target_layers(self) -> tuple:
        """Return the 1-tuple of layers Grad-CAM should hook for this model."""
        if self._is_clip_model():
            # CLIP exposes its vision tower directly on ``self.model``.
            return (self.model.vision_model.encoder.layers[-1].layer_norm1,)

        backbone = self.model.model
        if self.name_model == "resnet50":
            return (backbone.layer4[-1],)
        if self.name_model in ("efficientnet_b4", "mobilenetv3_large_100"):
            return (backbone.blocks[-1],)
        if self.name_model == self._VIT_NAME_MODEL:
            return (backbone.blocks[-1].norm1,)
        # vgg16, inception_v4, densenet121 and any other name fall back to
        # the last entry of ``features``.
        return (backbone.features[-1],)

    def set_grad_cam(self):
        """Create ``self.target_layers`` and ``self.gradcam`` for ``self.model``.

        Must be called after ``init_model`` has populated ``self.model``.
        """
        self.target_layers = self._select_target_layers()

        # NOTE(review): newer pytorch-grad-cam releases reportedly dropped the
        # ``use_cuda`` argument -- confirm against the pinned version.
        cam_kwargs = {
            "model": self.model,
            "target_layers": self.target_layers,
            "use_cuda": self.device.type == "cuda",
        }
        if self._is_clip_model() or self.name_model == self._VIT_NAME_MODEL:
            # Only the transformer backbones are given ``reshape_transform``.
            cam_kwargs["reshape_transform"] = reshape_transform
        self.gradcam = GradCAM(**cam_kwargs)

    def _prepare_inputs(self, image: np.ndarray):
        """Resize ``image`` and build the batched model input.

        Returns:
            A ``(resized_image, image_input)`` pair: the resized array and the
            augmented input with a leading batch axis, moved to
            ``self.device``.
        """
        # assumes ``image`` has a layout/dtype PIL's ``fromarray`` accepts
        # (e.g. HxWx3 uint8) -- TODO confirm against callers.
        resized_image = np.array(
            Image.fromarray(image).resize(
                (configs.SIZE_IMAGES, configs.SIZE_IMAGES)
            )
        )
        image_input = image_augmentations()(image=resized_image)["image"]
        image_input = image_input.unsqueeze(axis=0).to(self.device)
        return resized_image, image_input

    @staticmethod
    def _overlay(resized_image: np.ndarray, cam_batch) -> np.ndarray:
        """Blend the first map of ``cam_batch`` onto ``resized_image``."""
        heatmap = cam_batch[0, :]
        return show_cam_on_image(
            normalize_image_to_zero_one(resized_image), heatmap, use_rgb=True
        )

    def get_grad_cam(self, image: np.ndarray) -> np.ndarray:
        """Return the Grad-CAM overlay for ``image`` with no explicit target.

        Args:
            image: Input image array; it is resized to
                ``configs.SIZE_IMAGES`` square before processing.

        Returns:
            The array produced by ``show_cam_on_image`` for the resized image.
        """
        resized_image, image_input = self._prepare_inputs(image)
        return self._overlay(resized_image, self.gradcam(image_input))

    def get_grad_cam_with_output_target(
        self, image: np.ndarray, index_class: int
    ) -> np.ndarray:
        """Return the Grad-CAM overlay for ``image`` targeting one class.

        Args:
            image: Input image array; it is resized to
                ``configs.SIZE_IMAGES`` square before processing.
            index_class: Output index wrapped in ``ClassifierOutputTarget``.

        Returns:
            The array produced by ``show_cam_on_image`` for the resized image.
        """
        resized_image, image_input = self._prepare_inputs(image)
        targets = (ClassifierOutputTarget(index_class),)
        return self._overlay(
            resized_image, self.gradcam(image_input, targets=targets)
        )
|