AVCER / app /plot.py
ElenaRyumina's picture
Summary
47aeb66
raw
history blame
No virus
5.82 kB
"""
File: plot.py
Author: Elena Ryumina and Dmitry Ryumin
Description: Plotting statistical information.
License: MIT License
"""
import matplotlib.pyplot as plt
import numpy as np
import cv2
import torch
# Importing necessary components for the Gradio app
from app.config import DICT_PRED
def show_cam_on_image(
    img: np.ndarray,
    mask: np.ndarray,
    use_rgb: bool = False,
    colormap: int = cv2.COLORMAP_JET,
    image_weight: float = 0.5,
) -> np.ndarray:
    """Overlay the cam mask on the image as a heatmap.

    By default the heatmap is in BGR format.

    :param img: The base image in RGB or BGR format, with values in [0, 1].
    :param mask: The cam mask; it is scaled by 255 and colour-mapped.
    :param use_rgb: Whether to use an RGB or BGR heatmap, this should be set to
        True if 'img' is in RGB format.
    :param colormap: The OpenCV colormap to be used.
    :param image_weight: The final result is
        image_weight * img + (1-image_weight) * mask.
    :returns: The image with the cam overlay, as uint8 values in [0, 255].
    :raises ValueError: If ``img`` contains values above 1 or ``image_weight``
        is outside [0, 1].

    Implemented by https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py
    """
    # Validate up front so bad arguments fail before any OpenCV work is done.
    # ValueError is a subclass of Exception, so existing handlers still match.
    if np.max(img) > 1:
        raise ValueError("The input image should be np.float32 in the range [0, 1]")
    if image_weight < 0 or image_weight > 1:
        raise ValueError(
            f"image_weight should be in the range [0, 1]. Got: {image_weight}"
        )

    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    cam = (1 - image_weight) * heatmap + image_weight * img

    # Normalise to a peak of 1; skip it for an all-zero blend, where dividing
    # by the maximum would produce NaNs.
    peak = np.max(cam)
    if peak > 0:
        cam = cam / peak
    return np.uint8(255 * cam)
def get_heatmaps(
    gradients, activations, name_layer, face_image, use_rgb=True, image_weight=0.6
):
    """Build a gradient-weighted activation heatmap and overlay it on a face image.

    :param gradients: Mapping from layer name to captured gradients. Element
        ``[0]`` of the selected entry is averaged over dims 0, 2 and 3
        (presumably a (batch, channels, H, W) tensor -- confirm with the caller).
    :param activations: Mapping from layer name to the activation tensor of
        that layer; dim 1 is treated as the channel axis.
    :param name_layer: Key selecting the layer in both mappings.
    :param face_image: Image to draw on; it is resized to 224x224 and divided
        by 255 (assumes values in [0, 255] -- TODO confirm).
    :param use_rgb: Forwarded to ``show_cam_on_image``.
    :param image_weight: Forwarded to ``show_cam_on_image``.
    :returns: The uint8 overlay produced by ``show_cam_on_image``.
    """
    gradient = gradients[name_layer]
    activation = activations[name_layer]

    # One weight per channel: the gradient averaged over every other axis.
    pooled_gradients = torch.mean(gradient[0], dim=[0, 2, 3])

    # Weight each channel with a broadcast product. This builds a new tensor,
    # so the activations stored in the caller's dict are left untouched and
    # repeated calls on the same activations give the same heatmap.
    weighted = activation * pooled_gradients[None, :, None, None]

    heatmap = torch.mean(weighted, dim=1).squeeze().detach().cpu()
    # Keep only positive evidence.
    heatmap = torch.clamp(heatmap, min=0)

    # Scale to [0, 1]; an all-zero map is left as is rather than turned into
    # NaNs by a 0/0 division.
    peak = torch.max(heatmap)
    if peak > 0:
        heatmap = heatmap / peak

    heatmap = torch.unsqueeze(heatmap, -1)
    heatmap = cv2.resize(heatmap.numpy(), (224, 224))

    cur_face_hm = cv2.resize(face_image, (224, 224))
    cur_face_hm = np.float32(cur_face_hm) / 255

    return show_cam_on_image(
        cur_face_hm, heatmap, use_rgb=use_rgb, image_weight=image_weight
    )
def plot_compound_expression_prediction(
    dict_preds: dict[str, list[float]],
    save_path: str = None,
    frame_indices: list[int] = None,
    colors: list[str] = None,
    figsize: tuple = (12, 6),
    title: str = "Confusion Matrix",
) -> plt.Figure:
    """Plot per-frame predictions of several models as dotted lines.

    :param dict_preds: Mapping from series label to its sequence of predicted
        values, one value per frame.
    :param save_path: If given, the figure is also saved to this path; the
        file format is inferred from the extension.
    :param frame_indices: Zero-based frame indices to mark on the x axis
        (shown one-based). When None the default x ticks are kept.
    :param colors: Line colors, one per series; reused cyclically when there
        are more series than colors. Defaults to green, orange, red, purple
        and blue.
    :param figsize: Figure size passed to ``plt.subplots``.
    :param title: Title of the axes.
    :returns: The created figure.
    """
    # Avoid a mutable default argument; an empty list also falls back here.
    palette = list(colors) if colors else ["green", "orange", "red", "purple", "blue"]

    fig, ax = plt.subplots(figsize=figsize)

    half = len(dict_preds) // 2
    for idx, (label, values) in enumerate(dict_preds.items()):
        # Each series is nudged vertically by a multiple of 0.1 so that lines
        # with equal predictions stay distinguishable. The third and fourth
        # series use each other's slot.
        if idx == 2:
            slot = idx + 1
        elif idx == 3:
            slot = idx - 1
        else:
            slot = idx
        offset = (slot - half) * 0.1
        # The extra +1 lifts the values past the empty first y label.
        shifted = [val + offset + 1 for val in values]
        ax.plot(
            range(1, len(shifted) + 1),
            shifted,
            color=palette[idx % len(palette)],
            linestyle="dotted",
            label=label,
        )

    ax.legend()
    ax.grid(True)
    ax.set_xlabel("Number of frames")
    ax.set_ylabel("Basic emotion / compound expression")
    ax.set_title(title)

    if frame_indices is not None:
        ax.set_xticks([i + 1 for i in frame_indices])

    # An empty label pads each end of the class names. The tick count comes
    # from the labels so positions and labels always agree.
    y_labels = [""] + list(DICT_PRED.values()) + [""]
    ax.set_yticks(range(len(y_labels)))
    ax.set_yticklabels(y_labels)

    fig.tight_layout()

    if save_path:
        # No explicit format: savefig infers it from the extension, which also
        # copes with a path that has no extension.
        fig.savefig(save_path, bbox_inches="tight", pad_inches=0)

    return fig
def display_frame_info(img, text, margin=1.0, box_scale=1.0):
    """Return a copy of ``img`` with ``text`` drawn on a lightened box.

    The box sits near the top-right corner; the input image is not modified.

    :param img: Image array with three dimensions (height, width, channels).
    :param text: Text to draw.
    :param margin: Distance of the box from the top and right edges, as a
        multiple of the text height.
    :param box_scale: Padding of the box around the text, as a multiple of
        the text height.
    :returns: The annotated copy of the image.
    """
    canvas = img.copy()
    img_h, img_w, _ = canvas.shape

    # Stroke thickness and font scale follow the shorter image side.
    font_face = cv2.FONT_HERSHEY_SIMPLEX
    font_color = (0, 0, 0)
    line_width = int(min(img_h, img_w) * 0.001)
    thickness = max(int(line_width / 3), 1)
    font_scale = thickness / 1.5

    # NOTE(review): thickness is passed as None here, as in the original code;
    # confirm whether the measured size should use `thickness` instead.
    text_size = cv2.getTextSize(text, font_face, font_scale, None)[0]
    t_w, t_h = text_size

    margin_n = int(t_h * margin)
    padding = int(2 * t_h * box_scale)

    # Box geometry, computed once and reused for reading and writing.
    top = 0 + margin_n
    bottom = 0 + margin_n + t_h + padding
    left = img_w - t_w - margin_n - padding
    right = img_w - margin_n

    # Blend the region 50/50 with white so the black text stays readable.
    region = canvas[top:bottom, left:right]
    white_rect = np.ones(region.shape, dtype=np.uint8) * 255
    canvas[top:bottom, left:right] = cv2.addWeighted(
        region, 0.5, white_rect, .5, 1.0
    )

    # Text origin is its bottom-left corner, inset by half the padding.
    text_x = img_w - t_w - margin_n - padding // 2
    text_y = 0 + margin_n + t_h + padding // 2
    cv2.putText(
        img=canvas,
        text=text,
        org=(text_x, text_y),
        fontFace=font_face,
        fontScale=font_scale,
        color=font_color,
        thickness=thickness,
        lineType=cv2.LINE_AA,
        bottomLeftOrigin=False,
    )
    return canvas
def plot_audio(time_axis, waveform, frame_indices, fps, figsize=(10, 4)) -> plt.Figure:
    """Plot a waveform with x ticks placed at the given video frames.

    :param time_axis: X values for the waveform samples.
    :param waveform: Indexable container whose first element is the sequence
        of samples to plot (presumably channel 0 -- confirm with the caller).
    :param frame_indices: Zero-based frame indices to mark on the x axis.
    :param fps: Frames per second, used to convert frame indices to times.
    :param figsize: Figure size passed to ``plt.subplots``.
    :returns: The created figure.
    """
    frame_indices = np.array(frame_indices)
    frame_times = frame_indices / fps

    fig, ax = plt.subplots(figsize=figsize)
    ax.plot(time_axis, waveform[0])
    ax.set_xlabel('Time (frames)')
    ax.set_ylabel('Amplitude')
    ax.grid(True)

    ax.set_xticks(frame_times)
    # Label ticks from the frame indices themselves (one-based). Rebuilding
    # them as int(frame_time * fps) can truncate to the previous frame because
    # of floating-point rounding (e.g. 1 / 49 * 49 < 1).
    ax.set_xticklabels([f'{int(frame_idx) + 1}' for frame_idx in frame_indices])

    fig.tight_layout()
    return fig
def plot_images(image_paths, figsize=(12, 2)):
    """Show the given images side by side in a single row, without axes.

    :param image_paths: Sequence of items passed directly to ``Axes.imshow``
        (despite the name, presumably already-loaded images -- confirm with
        the caller).
    :param figsize: Figure size passed to ``plt.subplots``.
    :returns: The created figure.
    """
    # squeeze=False always yields a 2-D array of axes, so a single image works
    # too; by default plt.subplots returns a bare Axes in that case, which
    # cannot be iterated.
    fig, axes = plt.subplots(1, len(image_paths), figsize=figsize, squeeze=False)

    for ax, image in zip(axes.ravel(), image_paths):
        ax.imshow(image)
        ax.axis('off')

    fig.tight_layout()
    return fig