"""Deep Feature Factorization (https://arxiv.org/abs/1806.10206) utilities."""
import numpy as np
from PIL import Image
import torch
from typing import Callable, List, Tuple, Optional
from sklearn.decomposition import NMF
from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
from pytorch_grad_cam.utils.image import scale_cam_image, create_labels_legend, show_factorization_on_image
def dff(activations: np.ndarray, n_components: int = 5):
    """ Compute Deep Feature Factorization on a 2d Activations tensor.

    :param activations: A numpy array of shape batch x channels x height x width
    :param n_components: The number of components for the non negative matrix factorization
    :returns: A tuple of the concepts (a numpy array with shape channels x components),
              and the explanation heatmaps (a numpy array with shape batch x height x width)
    """
    batch_size, channels, h, w = activations.shape
    # Move channels first so each row of the flattened matrix is one channel
    # observed at every (batch, y, x) position.
    # Copy explicitly: transpose() returns a view, so without the copy the
    # NaN cleanup below would mutate the caller's array in place.
    reshaped_activations = activations.transpose((1, 0, 2, 3)).copy()
    reshaped_activations[np.isnan(reshaped_activations)] = 0
    reshaped_activations = reshaped_activations.reshape(
        reshaped_activations.shape[0], -1)
    # NMF requires non negative input: shift each channel by its minimum.
    offset = reshaped_activations.min(axis=-1)
    reshaped_activations = reshaped_activations - offset[:, None]

    model = NMF(n_components=n_components, init='random', random_state=0)
    W = model.fit_transform(reshaped_activations)
    H = model.components_

    # Undo the shift so concept magnitudes are comparable to the
    # original activations.
    concepts = W + offset[:, None]
    explanations = H.reshape(n_components, batch_size, h, w)
    explanations = explanations.transpose((1, 0, 2, 3))
    return concepts, explanations
class DeepFeatureFactorization:
    """ Deep Feature Factorization: https://arxiv.org/abs/1806.10206
        This gets a model and computes the 2D activations for a target layer,
        and computes Non Negative Matrix Factorization on the activations.

        Optionally it runs a computation on the concept embeddings,
        like running a classifier on them.

        The explanation heatmaps are scaled to the range [0, 1]
        and to the input tensor width and height.
    """

    def __init__(self,
                 model: torch.nn.Module,
                 target_layer: torch.nn.Module,
                 reshape_transform: Callable = None,
                 computation_on_concepts=None
                 ):
        """
        :param model: The model to explain.
        :param target_layer: The layer whose activations are factorized.
        :param reshape_transform: Optional callable applied to the raw
            activations (e.g. for transformer token layouts).
        :param computation_on_concepts: Optional callable (e.g. a classifier
            head) run on the concept embeddings.
        """
        self.model = model
        self.computation_on_concepts = computation_on_concepts
        # Hook registration: captures activations on the forward pass.
        self.activations_and_grads = ActivationsAndGradients(
            self.model, [target_layer], reshape_transform)

    def __call__(self,
                 input_tensor: torch.Tensor,
                 n_components: int = 16):
        """ Run the model, factorize the target-layer activations, and return
        (concepts, explanations[, concept_outputs]).

        :param input_tensor: Input batch of shape batch x channels x height x width.
        :param n_components: Number of NMF components.
        """
        batch_size, channels, h, w = input_tensor.size()
        _ = self.activations_and_grads(input_tensor)

        with torch.no_grad():
            activations = self.activations_and_grads.activations[0].cpu(
            ).numpy()

        concepts, explanations = dff(activations, n_components=n_components)

        # Rescale each image's heatmaps to [0, 1] and the input's spatial size.
        processed_explanations = []

        for batch in explanations:
            processed_explanations.append(scale_cam_image(batch, (w, h)))

        if self.computation_on_concepts:
            with torch.no_grad():
                # concepts is channels x components; transpose so each row
                # fed to the computation is one concept embedding.
                concept_tensors = torch.from_numpy(
                    np.float32(concepts).transpose((1, 0)))
                concept_outputs = self.computation_on_concepts(
                    concept_tensors).cpu().numpy()
            return concepts, processed_explanations, concept_outputs
        else:
            return concepts, processed_explanations

    def __del__(self):
        # Best-effort cleanup of the forward hooks.
        self.activations_and_grads.release()

    def __enter__(self):
        # Context-manager support: __exit__ was defined without __enter__,
        # which made `with DeepFeatureFactorization(...)` fail.
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        self.activations_and_grads.release()
        if isinstance(exc_value, IndexError):
            # Handle IndexError here...
            print(
                f"An exception occurred in ActivationSummary with block: {exc_type}. Message: {exc_value}")
            return True
def run_dff_on_image(model: torch.nn.Module,
                     target_layer: torch.nn.Module,
                     classifier: torch.nn.Module,
                     img_pil: Image.Image,
                     img_tensor: torch.Tensor,
                     reshape_transform: Optional[Callable] = None,
                     n_components: int = 5,
                     top_k: int = 2) -> np.ndarray:
    """ Helper function to create a Deep Feature Factorization visualization for a single image.
        TBD: Run this on a batch with several images.

    :param model: The model to explain; must expose `config.id2label`
        (HuggingFace-style) for the legend labels.
    :param target_layer: Layer whose activations are factorized.
    :param classifier: Head applied to the concept embeddings to get logits.
    :param img_pil: The input image (used for the visualization overlay).
    :param img_tensor: The preprocessed input tensor (channels x height x width).
    :param reshape_transform: Optional activation reshape callable.
        NOTE: the original signature read `reshape_transform=Optional[Callable]`,
        which made the *typing object* the default value; fixed to a proper
        annotation with a None default.
    :param n_components: Number of NMF components.
    :param top_k: Number of top class labels shown per concept.
    :returns: Original image and visualization stacked side by side (uint8 array).
    """
    rgb_img_float = np.array(img_pil) / 255
    # Named `factorizer` (not `dff`) to avoid shadowing the module-level
    # dff() function.
    factorizer = DeepFeatureFactorization(model=model,
                                          reshape_transform=reshape_transform,
                                          target_layer=target_layer,
                                          computation_on_concepts=classifier)

    concepts, batch_explanations, concept_outputs = factorizer(
        img_tensor[None, :], n_components)

    # Turn the classifier logits into per-concept class probabilities.
    concept_outputs = torch.softmax(
        torch.from_numpy(concept_outputs),
        axis=-1).numpy()
    concept_label_strings = create_labels_legend(concept_outputs,
                                                 labels=model.config.id2label,
                                                 top_k=top_k)
    visualization = show_factorization_on_image(
        rgb_img_float,
        batch_explanations[0],
        image_weight=0.3,
        concept_labels=concept_label_strings)

    result = np.hstack((np.array(img_pil), visualization))
    return result