mshukor
init
3eb682b
raw
history blame
No virus
12.6 kB
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import itertools
import numpy as np
import matplotlib.pyplot as plt
import torch
from sklearn.metrics import confusion_matrix
import timesformer.utils.logging as logging
from timesformer.datasets.utils import pack_pathway_output, tensor_normalize
logger = logging.get_logger(__name__)
def get_confusion_matrix(preds, labels, num_classes, normalize="true"):
"""
Calculate confusion matrix on the provided preds and labels.
Args:
preds (tensor or lists of tensors): predictions. Each tensor is in
in the shape of (n_batch, num_classes). Tensor(s) must be on CPU.
labels (tensor or lists of tensors): corresponding labels. Each tensor is
in the shape of either (n_batch,) or (n_batch, num_classes).
num_classes (int): number of classes. Tensor(s) must be on CPU.
normalize (Optional[str]) : {‘true’, ‘pred’, ‘all’}, default="true"
Normalizes confusion matrix over the true (rows), predicted (columns)
conditions or all the population. If None, confusion matrix
will not be normalized.
Returns:
cmtx (ndarray): confusion matrix of size (num_classes x num_classes)
"""
if isinstance(preds, list):
preds = torch.cat(preds, dim=0)
if isinstance(labels, list):
labels = torch.cat(labels, dim=0)
# If labels are one-hot encoded, get their indices.
if labels.ndim == preds.ndim:
labels = torch.argmax(labels, dim=-1)
# Get the predicted class indices for examples.
preds = torch.flatten(torch.argmax(preds, dim=-1))
labels = torch.flatten(labels)
cmtx = confusion_matrix(
labels, preds, labels=list(range(num_classes)), normalize=normalize
)
return cmtx
def plot_confusion_matrix(cmtx, num_classes, class_names=None, figsize=None):
"""
A function to create a colored and labeled confusion matrix matplotlib figure
given true labels and preds.
Args:
cmtx (ndarray): confusion matrix.
num_classes (int): total number of classes.
class_names (Optional[list of strs]): a list of class names.
figsize (Optional[float, float]): the figure size of the confusion matrix.
If None, default to [6.4, 4.8].
Returns:
img (figure): matplotlib figure.
"""
if class_names is None or type(class_names) != list:
class_names = [str(i) for i in range(num_classes)]
figure = plt.figure(figsize=figsize)
plt.imshow(cmtx, interpolation="nearest", cmap=plt.cm.Blues)
plt.title("Confusion matrix")
plt.colorbar()
tick_marks = np.arange(len(class_names))
plt.xticks(tick_marks, class_names, rotation=45)
plt.yticks(tick_marks, class_names)
# Use white text if squares are dark; otherwise black.
threshold = cmtx.max() / 2.0
for i, j in itertools.product(range(cmtx.shape[0]), range(cmtx.shape[1])):
color = "white" if cmtx[i, j] > threshold else "black"
plt.text(
j,
i,
format(cmtx[i, j], ".2f") if cmtx[i, j] != 0 else ".",
horizontalalignment="center",
color=color,
)
plt.tight_layout()
plt.ylabel("True label")
plt.xlabel("Predicted label")
return figure
def plot_topk_histogram(tag, array, k=10, class_names=None, figsize=None):
"""
Plot histogram of top-k value from the given array.
Args:
tag (str): histogram title.
array (tensor): a tensor to draw top k value from.
k (int): number of top values to draw from array.
Defaut to 10.
class_names (list of strings, optional):
a list of names for values in array.
figsize (Optional[float, float]): the figure size of the confusion matrix.
If None, default to [6.4, 4.8].
Returns:
fig (matplotlib figure): a matplotlib figure of the histogram.
"""
val, ind = torch.topk(array, k)
fig = plt.Figure(figsize=figsize, facecolor="w", edgecolor="k")
ax = fig.add_subplot(1, 1, 1)
if class_names is None:
class_names = [str(i) for i in ind]
else:
class_names = [class_names[i] for i in ind]
tick_marks = np.arange(k)
width = 0.75
ax.bar(
tick_marks,
val,
width,
color="orange",
tick_label=class_names,
edgecolor="w",
linewidth=1,
)
ax.set_xlabel("Candidates")
ax.set_xticks(tick_marks)
ax.set_xticklabels(class_names, rotation=-45, ha="center")
ax.xaxis.set_label_position("bottom")
ax.xaxis.tick_bottom()
y_tick = np.linspace(0, 1, num=10)
ax.set_ylabel("Frequency")
ax.set_yticks(y_tick)
y_labels = [format(i, ".1f") for i in y_tick]
ax.set_yticklabels(y_labels, ha="center")
for i, v in enumerate(val.numpy()):
ax.text(
i - 0.1,
v + 0.03,
format(v, ".2f"),
color="orange",
fontweight="bold",
)
ax.set_title(tag)
fig.set_tight_layout(True)
return fig
class GetWeightAndActivation:
"""
A class used to get weights and activations from specified layers from a Pytorch model.
"""
def __init__(self, model, layers):
"""
Args:
model (nn.Module): the model containing layers to obtain weights and activations from.
layers (list of strings): a list of layer names to obtain weights and activations from.
Names are hierarchical, separated by /. For example, If a layer follow a path
"s1" ---> "pathway0_stem" ---> "conv", the layer path is "s1/pathway0_stem/conv".
"""
self.model = model
self.hooks = {}
self.layers_names = layers
# eval mode
self.model.eval()
self._register_hooks()
def _get_layer(self, layer_name):
"""
Return a layer (nn.Module Object) given a hierarchical layer name, separated by /.
Args:
layer_name (str): the name of the layer.
"""
layer_ls = layer_name.split("/")
prev_module = self.model
for layer in layer_ls:
prev_module = prev_module._modules[layer]
return prev_module
def _register_single_hook(self, layer_name):
"""
Register hook to a layer, given layer_name, to obtain activations.
Args:
layer_name (str): name of the layer.
"""
def hook_fn(module, input, output):
self.hooks[layer_name] = output.clone().detach()
layer = get_layer(self.model, layer_name)
layer.register_forward_hook(hook_fn)
def _register_hooks(self):
"""
Register hooks to layers in `self.layers_names`.
"""
for layer_name in self.layers_names:
self._register_single_hook(layer_name)
def get_activations(self, input, bboxes=None):
"""
Obtain all activations from layers that we register hooks for.
Args:
input (tensors, list of tensors): the model input.
bboxes (Optional): Bouding boxes data that might be required
by the model.
Returns:
activation_dict (Python dictionary): a dictionary of the pair
{layer_name: list of activations}, where activations are outputs returned
by the layer.
"""
input_clone = [inp.clone() for inp in input]
if bboxes is not None:
preds = self.model(input_clone, bboxes)
else:
preds = self.model(input_clone)
activation_dict = {}
for layer_name, hook in self.hooks.items():
# list of activations for each instance.
activation_dict[layer_name] = hook
return activation_dict, preds
def get_weights(self):
"""
Returns weights from registered layers.
Returns:
weights (Python dictionary): a dictionary of the pair
{layer_name: weight}, where weight is the weight tensor.
"""
weights = {}
for layer in self.layers_names:
cur_layer = get_layer(self.model, layer)
if hasattr(cur_layer, "weight"):
weights[layer] = cur_layer.weight.clone().detach()
else:
logger.error(
"Layer {} does not have weight attribute.".format(layer)
)
return weights
def get_indexing(string):
"""
Parse numpy-like fancy indexing from a string.
Args:
string (str): string represent the indices to take
a subset of from array. Indices for each dimension
are separated by `,`; indices for different dimensions
are separated by `;`.
e.g.: For a numpy array `arr` of shape (3,3,3), the string "1,2;1,2"
means taking the sub-array `arr[[1,2], [1,2]]
Returns:
final_indexing (tuple): the parsed indexing.
"""
index_ls = string.strip().split(";")
final_indexing = []
for index in index_ls:
index_single_dim = index.split(",")
index_single_dim = [int(i) for i in index_single_dim]
final_indexing.append(index_single_dim)
return tuple(final_indexing)
def process_layer_index_data(layer_ls, layer_name_prefix=""):
"""
Extract layer names and numpy-like fancy indexing from a string.
Args:
layer_ls (list of strs): list of strings containing data about layer names
and their indexing. For each string, layer name and indexing is separated by whitespaces.
e.g.: [layer1 1,2;2, layer2, layer3 150;3,4]
layer_name_prefix (Optional[str]): prefix to be added to each layer name.
Returns:
layer_name (list of strings): a list of layer names.
indexing_dict (Python dict): a dictionary of the pair
{one_layer_name: indexing_for_that_layer}
"""
layer_name, indexing_dict = [], {}
for layer in layer_ls:
ls = layer.split()
name = layer_name_prefix + ls[0]
layer_name.append(name)
if len(ls) == 2:
indexing_dict[name] = get_indexing(ls[1])
else:
indexing_dict[name] = ()
return layer_name, indexing_dict
def process_cv2_inputs(frames, cfg):
"""
Normalize and prepare inputs as a list of tensors. Each tensor
correspond to a unique pathway.
Args:
frames (list of array): list of input images (correspond to one clip) in range [0, 255].
cfg (CfgNode): configs. Details can be found in
slowfast/config/defaults.py
"""
inputs = torch.from_numpy(np.array(frames)).float() / 255
inputs = tensor_normalize(inputs, cfg.DATA.MEAN, cfg.DATA.STD)
# T H W C -> C T H W.
inputs = inputs.permute(3, 0, 1, 2)
# Sample frames for num_frames specified.
index = torch.linspace(0, inputs.shape[1] - 1, cfg.DATA.NUM_FRAMES).long()
inputs = torch.index_select(inputs, 1, index)
inputs = pack_pathway_output(cfg, inputs)
inputs = [inp.unsqueeze(0) for inp in inputs]
return inputs
def get_layer(model, layer_name):
"""
Return the targeted layer (nn.Module Object) given a hierarchical layer name,
separated by /.
Args:
model (model): model to get layers from.
layer_name (str): name of the layer.
Returns:
prev_module (nn.Module): the layer from the model with `layer_name` name.
"""
layer_ls = layer_name.split("/")
prev_module = model
for layer in layer_ls:
prev_module = prev_module._modules[layer]
return prev_module
class TaskInfo:
def __init__(self):
self.frames = None
self.id = -1
self.bboxes = None
self.action_preds = None
self.num_buffer_frames = 0
self.img_height = -1
self.img_width = -1
self.crop_size = -1
self.clip_vis_size = -1
def add_frames(self, idx, frames):
"""
Add the clip and corresponding id.
Args:
idx (int): the current index of the clip.
frames (list[ndarray]): list of images in "BGR" format.
"""
self.frames = frames
self.id = idx
def add_bboxes(self, bboxes):
"""
Add correspondding bounding boxes.
"""
self.bboxes = bboxes
def add_action_preds(self, preds):
"""
Add the corresponding action predictions.
"""
self.action_preds = preds