"""Perturbation-based confidence metrics for CAM-style attribution maps."""
from typing import Callable, List

import numpy as np
import torch

try:
    import cv2
except ImportError:  # OpenCV is only needed for ``percentile='auto'``.
    cv2 = None

# Immutable source for the default percentiles; copied per instance so that
# no two ``AveragerAcrossThresholds`` objects ever share one list.
_DEFAULT_PERCENTILES = (10, 20, 30, 40, 50, 60, 70, 80, 90)


class PerturbationConfidenceMetric:
    """Score a model on inputs perturbed according to their CAMs.

    ``perturbation`` is a callable ``(image, cam) -> perturbed image``; it is
    invoked once per batch element with the image moved to the CPU and the
    CAM wrapped as a torch tensor.
    """

    def __init__(self, perturbation):
        self.perturbation = perturbation

    @staticmethod
    def _target_scores(model, batch, targets):
        """Run ``model`` on ``batch`` and apply each target to its output.

        Everything happens under ``torch.no_grad()`` so that the ``.numpy()``
        conversion cannot fail on a tensor that tracks gradients.
        """
        with torch.no_grad():
            outputs = model(batch)
            scores = [target(output).cpu().numpy()
                      for target, output in zip(targets, outputs)]
        return np.float32(scores)

    def __call__(self, input_tensor: torch.Tensor,
                 cams: np.ndarray,
                 targets: List[Callable],
                 model: torch.nn.Module,
                 return_visualization=False,
                 return_diff=True):
        """Compute target scores on the perturbed batch.

        Args:
            input_tensor: Batch of inputs, indexed along dimension 0.
            cams: One CAM per batch element; ``cams[i]`` must be a NumPy
                array accepted by ``torch.from_numpy``.
            targets: One callable per batch element, applied to that
                element's model output.
            model: The model being evaluated.
            return_visualization: If true, also return the perturbed batch.
            return_diff: If true, return ``scores_after - scores_before``;
                otherwise return the scores after perturbation only.

        Returns:
            A float32 NumPy array of scores, or ``(scores, perturbed_batch)``
            when ``return_visualization`` is true.
        """
        # The baseline is only needed (and only computed) for the diff.
        baseline = None
        if return_diff:
            baseline = self._target_scores(model, input_tensor, targets)

        device = input_tensor.device
        perturbed = []
        for index, sample in enumerate(input_tensor):
            changed = self.perturbation(sample.cpu(),
                                        torch.from_numpy(cams[index]))
            perturbed.append(changed.to(device).unsqueeze(0))
        perturbated_tensors = torch.cat(perturbed)

        result = self._target_scores(model, perturbated_tensors, targets)
        if return_diff:
            result = result - baseline

        if return_visualization:
            return result, perturbated_tensors
        return result


class RemoveMostRelevantFirst:
    """Binarise a CAM and hand the image plus binary mask to an imputer.

    With a numeric ``percentile`` the binary mask is float32, 1.0 where the
    CAM value is strictly below that percentile of the CAM and 0.0 elsewhere.
    With ``percentile='auto'`` the mask comes from OpenCV's Otsu threshold.
    """

    def __init__(self, percentile, imputer):
        self.percentile = percentile
        self.imputer = imputer

    def __call__(self, input_tensor, mask):
        """Return ``self.imputer(input_tensor, binary_mask)``.

        Args:
            input_tensor: The image tensor forwarded unchanged to the imputer.
            mask: The CAM as a torch tensor (any device).

        Raises:
            ImportError: If ``percentile == 'auto'`` and OpenCV is missing.
        """
        if self.percentile != 'auto':
            threshold = np.percentile(mask.cpu().numpy(), self.percentile)
            # Explicit .cpu() so the NumPy conversion also works when the
            # mask lives on an accelerator.
            binary_mask = np.float32((mask < threshold).cpu().numpy())
        else:
            if cv2 is None:
                raise ImportError(
                    "OpenCV (cv2) is required for percentile='auto'")
            # NOTE(review): THRESH_BINARY yields a uint8 mask that is 255
            # where the CAM is above the Otsu threshold, i.e. the opposite
            # polarity and a different scale/dtype from the percentile
            # branch. Behaviour is kept as-is; confirm this is intended.
            _, binary_mask = cv2.threshold(
                np.uint8((mask * 255).cpu().numpy()), 0, 255,
                cv2.THRESH_BINARY + cv2.THRESH_OTSU)

        binary_mask = torch.from_numpy(binary_mask).to(mask.device)
        return self.imputer(input_tensor, binary_mask)


class RemoveLeastRelevantFirst(RemoveMostRelevantFirst):
    """Same as ``RemoveMostRelevantFirst`` but applied to ``1 - mask``."""

    def __init__(self, percentile, imputer):
        super().__init__(percentile, imputer)

    def __call__(self, input_tensor, mask):
        """Invert the CAM and delegate to the parent implementation."""
        return super().__call__(input_tensor, 1 - mask)


class AveragerAcrossThresholds:
    """Average a metric over several percentile thresholds.

    ``imputer`` is a factory: ``imputer(percentile)`` must return a callable
    with the signature ``(input_tensor, cams, targets, model) -> scores``.
    """

    def __init__(self, imputer, percentiles=None):
        self.imputer = imputer
        # A fresh list per instance; a caller-supplied sequence is stored
        # as given.
        if percentiles is None:
            percentiles = list(_DEFAULT_PERCENTILES)
        self.percentiles = percentiles

    def __call__(self,
                 input_tensor: torch.Tensor,
                 cams: np.ndarray,
                 targets: List[Callable],
                 model: torch.nn.Module):
        """Return the element-wise mean of the scores over all percentiles."""
        scores = [
            self.imputer(percentile)(input_tensor, cams, targets, model)
            for percentile in self.percentiles
        ]
        return np.mean(np.float32(scores), axis=0)