unknown
Add application file
f526a64
raw
history blame
No virus
5.41 kB
import numpy as np
from PIL import Image
import torch
from typing import Callable, List, Tuple, Optional
from sklearn.decomposition import NMF
from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
from pytorch_grad_cam.utils.image import scale_cam_image, create_labels_legend, show_factorization_on_image
def dff(activations: np.ndarray, n_components: int = 5):
    """Compute Deep Feature Factorization on a batch of 2D activation maps.

    The activations are rearranged into a
    ``channels x (batch * height * width)`` matrix, shifted per channel so
    that every entry is non-negative, and factorized with Non-negative
    Matrix Factorization (NMF).

    :param activations: A numpy array of shape
        batch x channels x height x width. It is not modified; NaN entries
        are treated as 0.
    :param n_components: The number of components for the non-negative
        matrix factorization.
    :returns: A tuple of the concepts (a numpy array with shape
        channels x components) and the explanation heatmaps (a numpy array
        with shape batch x components x height x width).
    """
    batch_size, channels, h, w = activations.shape
    # Channels first, so each matrix row gathers one channel's values over
    # every image and every spatial position.
    flat = activations.transpose((1, 0, 2, 3)).reshape(channels, -1)
    # np.where builds a new array. The transpose (and, for some memory
    # layouts, the reshape as well) is only a view of the caller's array, so
    # zeroing NaNs in place would silently overwrite the caller's data.
    flat = np.where(np.isnan(flat), 0, flat)
    # NMF needs non-negative input: subtract each channel's minimum.
    offset = flat.min(axis=-1)
    flat = flat - offset[:, None]

    model = NMF(n_components=n_components, init='random', random_state=0)
    W = model.fit_transform(flat)
    H = model.components_
    # Add the per-channel minimum back onto the channel loadings.
    concepts = W + offset[:, None]
    explanations = H.reshape(n_components, batch_size, h, w)
    explanations = explanations.transpose((1, 0, 2, 3))
    return concepts, explanations
class DeepFeatureFactorization:
    """Deep Feature Factorization: https://arxiv.org/abs/1806.10206

    This gets a model and computes the 2D activations for a target layer,
    and computes Non Negative Matrix Factorization on the activations.

    Optionally it runs a computation on the concept embeddings,
    like running a classifier on them.

    The explanation heatmaps are scaled to the range [0, 1]
    and to the input tensor width and height.

    Instances can be used as context managers; leaving the ``with`` block
    releases the activation recorder attached to the model.
    """

    def __init__(self,
                 model: torch.nn.Module,
                 target_layer: torch.nn.Module,
                 reshape_transform: Optional[Callable] = None,
                 computation_on_concepts=None
                 ):
        """Attach an activation recorder for ``target_layer`` to ``model``.

        :param model: The model whose activations are factorized.
        :param target_layer: The layer whose output is recorded.
        :param reshape_transform: Optional callable handed to
            ``ActivationsAndGradients`` together with the target layer.
        :param computation_on_concepts: Optional callable applied to the
            concept embeddings (one row per concept), e.g. a classifier head.
        """
        self.model = model
        self.computation_on_concepts = computation_on_concepts
        self.activations_and_grads = ActivationsAndGradients(
            self.model, [target_layer], reshape_transform)

    def __call__(self,
                 input_tensor: torch.Tensor,
                 n_components: int = 16):
        """Factorize the target layer activations for ``input_tensor``.

        :param input_tensor: A 4D tensor; its last two dimensions are used as
            the height and width the heatmaps are resized to.
        :param n_components: Number of NMF components (concepts).
        :returns: ``(concepts, processed_explanations)``, plus a third
            element ``concept_outputs`` when ``computation_on_concepts`` was
            supplied. ``processed_explanations`` holds one entry per image
            in the batch.
        """
        # The unpacking also rejects inputs that are not 4-dimensional.
        _, _, h, w = input_tensor.size()
        # Run the wrapped model; only the recorded activations are needed.
        _ = self.activations_and_grads(input_tensor)

        with torch.no_grad():
            activations = self.activations_and_grads.activations[0].cpu(
            ).numpy()

        concepts, explanations = dff(activations, n_components=n_components)

        processed_explanations = [
            scale_cam_image(per_image, (w, h)) for per_image in explanations
        ]

        # Compare against None instead of relying on truthiness: a callable
        # that defines __len__ (for example an empty nn.Sequential) is falsy
        # and would otherwise be skipped silently.
        if self.computation_on_concepts is not None:
            with torch.no_grad():
                # concepts is channels x components; the computation gets one
                # row per concept.
                concept_tensors = torch.from_numpy(
                    np.float32(concepts).transpose((1, 0)))
                concept_outputs = self.computation_on_concepts(
                    concept_tensors).cpu().numpy()
            return concepts, processed_explanations, concept_outputs

        return concepts, processed_explanations

    def __del__(self):
        # __init__ may have failed before the recorder was created; in that
        # case there is nothing to release and no attribute to look up.
        recorder = getattr(self, "activations_and_grads", None)
        if recorder is not None:
            recorder.release()

    def __enter__(self):
        """Return the instance so it can be bound by ``with ... as``."""
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        """Release the recorder; report and suppress an ``IndexError``.

        Any other exception raised inside the ``with`` block propagates.
        """
        self.activations_and_grads.release()
        if isinstance(exc_value, IndexError):
            # Handle IndexError here...
            print(
                f"An exception occurred in DeepFeatureFactorization with block: {exc_type}. Message: {exc_value}")
            return True
        return False
def run_dff_on_image(model: torch.nn.Module,
                     target_layer: torch.nn.Module,
                     classifier: torch.nn.Module,
                     img_pil: "Image.Image",
                     img_tensor: torch.Tensor,
                     reshape_transform: Optional[Callable] = None,
                     n_components: int = 5,
                     top_k: int = 2) -> np.ndarray:
    """Helper function to create a Deep Feature Factorization visualization for a single image.

    TBD: Run this on a batch with several images.

    :param model: The model to explain. Its ``config.id2label`` mapping is
        used to label the concepts.
    :param target_layer: The layer whose activations are factorized.
    :param classifier: Callable applied to the concept embeddings; its output
        is passed through a softmax over the last dimension.
    :param img_pil: The image to draw on. Pixel values are divided by 255
        for the overlay.
    :param img_tensor: The model input for that single image, without a
        batch dimension (one is added here).
    :param reshape_transform: Optional callable forwarded to
        ``DeepFeatureFactorization``. Defaults to ``None``.
    :param n_components: Number of concepts to extract.
    :param top_k: Forwarded to ``create_labels_legend`` as ``top_k``.
    :returns: The original image and the visualization stacked horizontally.
    """
    rgb_img_float = np.array(img_pil) / 255
    # Named so that it does not shadow the module-level dff() function.
    factorizer = DeepFeatureFactorization(model=model,
                                          reshape_transform=reshape_transform,
                                          target_layer=target_layer,
                                          computation_on_concepts=classifier)

    # img_tensor[None, :] adds the leading batch dimension.
    concepts, batch_explanations, concept_outputs = factorizer(
        img_tensor[None, :], n_components)

    concept_outputs = torch.softmax(
        torch.from_numpy(concept_outputs),
        dim=-1).numpy()
    concept_label_strings = create_labels_legend(concept_outputs,
                                                 labels=model.config.id2label,
                                                 top_k=top_k)
    # Only one image was submitted, so the first batch entry is used.
    visualization = show_factorization_on_image(
        rgb_img_float,
        batch_explanations[0],
        image_weight=0.3,
        concept_labels=concept_label_strings)

    result = np.hstack((np.array(img_pil), visualization))
    return result