rice_leaf_diseases / pytorch_grad_cam /metrics /perturbation_confidence.py
unknown
Add application file
f526a64
raw
history blame
No virus
3.51 kB
import torch
import numpy as np
from typing import List, Callable
import numpy as np
import cv2
class PerturbationConfidenceMetric:
    """Measure how target scores react when inputs are perturbed using CAMs.

    The ``perturbation`` callable is invoked once per sample with the sample
    moved to CPU and the matching CAM wrapped as a torch tensor; it must
    return the perturbed sample as a tensor.
    """

    def __init__(self, perturbation):
        self.perturbation = perturbation

    def __call__(self, input_tensor: torch.Tensor,
                 cams: np.ndarray,
                 targets: List[Callable],
                 model: torch.nn.Module,
                 return_visualization=False,
                 return_diff=True):
        """Score the model on perturbed inputs.

        Args:
            input_tensor: Batch of inputs; the first dimension is the batch.
            cams: NumPy array indexed by sample, one CAM per input.
            targets: One callable per sample, mapping that sample's model
                output to a score tensor.
            model: Model evaluated on the original and perturbed batches.
            return_visualization: When true, also return the perturbed batch.
            return_diff: When true, return (score after perturbation minus
                score before); otherwise return the score after perturbation.

        Returns:
            A float32 NumPy array of per-sample scores (or score differences),
            or a ``(scores, perturbed_batch)`` tuple when
            ``return_visualization`` is set.
        """
        # The baseline forward pass is only needed when a difference is
        # requested, so it is skipped otherwise.
        baseline_scores = None
        if return_diff:
            with torch.no_grad():
                predictions = model(input_tensor)
                baseline_scores = np.float32([
                    score_fn(prediction).cpu().numpy()
                    for score_fn, prediction in zip(targets, predictions)
                ])

        device = input_tensor.device
        perturbed_samples = []
        for index in range(input_tensor.size(0)):
            sample_cam = cams[index]
            # The perturbation runs on CPU; its result is moved back to the
            # device of the incoming batch.
            perturbed = self.perturbation(
                input_tensor[index, ...].cpu(), torch.from_numpy(sample_cam))
            perturbed_samples.append(perturbed.to(device).unsqueeze(0))
        perturbated_tensors = torch.cat(perturbed_samples)

        with torch.no_grad():
            perturbed_predictions = model(perturbated_tensors)
        perturbed_scores = np.float32([
            score_fn(prediction).cpu().numpy()
            for score_fn, prediction in zip(targets, perturbed_predictions)
        ])

        result = (perturbed_scores - baseline_scores
                  if return_diff else perturbed_scores)

        if return_visualization:
            return result, perturbated_tensors
        return result
class RemoveMostRelevantFirst:
    """Binarize a CAM and hand the input plus binary mask to an imputer.

    With a numeric ``percentile`` the mask is 1.0 where the CAM value is
    strictly below that percentile of the CAM and 0.0 elsewhere. With
    ``percentile == 'auto'`` the threshold is chosen by OpenCV's Otsu method.
    """

    def __init__(self, percentile, imputer):
        self.percentile = percentile
        self.imputer = imputer

    def __call__(self, input_tensor, mask):
        """Build the binary mask from ``mask`` and apply the imputer.

        Args:
            input_tensor: Passed through unchanged to the imputer.
            mask: Torch tensor holding the CAM.

        Returns:
            Whatever ``self.imputer(input_tensor, binary_mask)`` returns; the
            binary mask is a tensor placed on ``mask``'s device.
        """
        # NumPy can only read host memory, so every NumPy conversion below
        # goes through one CPU copy (a no-op when the mask is already on CPU).
        host_mask = mask.cpu()

        if self.percentile != 'auto':
            threshold = np.percentile(host_mask.numpy(), self.percentile)
            binary_mask = np.float32(host_mask < threshold)
        else:
            # NOTE(review): THRESH_BINARY yields a uint8 mask of 0/255 that is
            # 255 *above* the Otsu threshold, whereas the percentile branch
            # yields float 0/1 that is 1 *below* the threshold. Confirm the
            # imputer expects this before relying on 'auto'.
            _, binary_mask = cv2.threshold(
                np.uint8(host_mask * 255), 0, 255,
                cv2.THRESH_BINARY + cv2.THRESH_OTSU)

        binary_mask = torch.from_numpy(binary_mask).to(mask.device)
        return self.imputer(input_tensor, binary_mask)
class RemoveLeastRelevantFirst(RemoveMostRelevantFirst):
    """Variant of RemoveMostRelevantFirst that operates on ``1 - mask``."""

    def __init__(self, percentile, imputer):
        super().__init__(percentile, imputer)

    def __call__(self, input_tensor, mask):
        """Invert the mask, then defer to the parent implementation."""
        inverted_mask = 1 - mask
        return super().__call__(input_tensor, inverted_mask)
class AveragerAcrossThresholds:
    """Average a metric over several percentile thresholds.

    ``imputer`` is a factory: it is called with one percentile and must
    return a callable accepting ``(input_tensor, cams, targets, model)``.
    """

    def __init__(self, imputer, percentiles=None):
        """Store the metric factory and the percentiles to sweep.

        Args:
            imputer: Factory called as ``imputer(percentile)``.
            percentiles: Iterable of percentiles to evaluate. Defaults to
                10, 20, ..., 90.
        """
        self.imputer = imputer
        # A fresh list per instance: a list literal as the default argument
        # would be shared, so mutating one instance's percentiles would
        # silently change the default of every later instance.
        if percentiles is None:
            percentiles = [10, 20, 30, 40, 50, 60, 70, 80, 90]
        self.percentiles = percentiles

    def __call__(self,
                 input_tensor: torch.Tensor,
                 cams: np.ndarray,
                 targets: List[Callable],
                 model: torch.nn.Module):
        """Evaluate the metric at every percentile and return the mean.

        Returns:
            The per-percentile results converted to float32 and averaged
            along the percentile axis (axis 0).
        """
        scores = [
            self.imputer(percentile)(input_tensor, cams, targets, model)
            for percentile in self.percentiles
        ]
        return np.mean(np.float32(scores), axis=0)