"""
File: config.py
Author: Elena Ryumina and Dmitry Ryumin
Description: Plotting statistical information.
License: MIT License
"""

import matplotlib.pyplot as plt
import numpy as np
import cv2
import torch

# Importing necessary components for the Gradio app
from app.config import DICT_PRED


def show_cam_on_image(
    img: np.ndarray,
    mask: np.ndarray,
    use_rgb: bool = False,
    colormap: int = cv2.COLORMAP_JET,
    image_weight: float = 0.5,
) -> np.ndarray:
    """Overlay a CAM mask on an image as a colour heatmap.

    By default the heatmap is in BGR format.

    Adapted from
    https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py

    :param img: The base image in RGB or BGR format, with float values in [0, 1].
    :param mask: The CAM mask. It is scaled by 255 and cast to uint8 before colour
        mapping, so values are expected in [0, 1].
    :param use_rgb: Return an RGB heatmap instead of BGR; set to True if ``img`` is RGB.
    :param colormap: The OpenCV colormap to use.
    :param image_weight: Blend weight; the result is
        ``image_weight * img + (1 - image_weight) * heatmap``.
    :returns: The blended image as uint8, rescaled so its maximum value is 255.
    :raises ValueError: If ``img`` has values above 1 or ``image_weight`` is not in [0, 1].
    """
    # Validate before doing any colour-mapping work. ValueError subclasses Exception,
    # so existing ``except Exception`` handlers still catch these.
    if np.max(img) > 1:
        raise ValueError("The input image should be np.float32 in the range [0, 1]")
    if not 0 <= image_weight <= 1:
        raise ValueError(
            f"image_weight should be in the range [0, 1]. Got: {image_weight}"
        )

    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    cam = (1 - image_weight) * heatmap + image_weight * img
    cam = cam / np.max(cam)
    return np.uint8(255 * cam)


def get_heatmaps(
    gradients, activations, name_layer, face_image, use_rgb=True, image_weight=0.6
):
    """Build a gradient-weighted activation heatmap for one layer and overlay it on a face.

    :param gradients: Mapping from layer name to captured gradients.
        ``gradients[name_layer][0]`` is averaged over dims 0, 2 and 3, so it must be a
        4-D tensor (presumably (N, C, H, W) — confirm against the hook that fills it).
    :param activations: Mapping from layer name to the captured activation tensor,
        4-D with channels on dim 1. It is NOT modified by this function.
    :param name_layer: Key selecting the layer in both mappings.
    :param face_image: Image whose values are divided by 255 (so presumably uint8,
        0..255). It is resized to 224x224.
    :param use_rgb: Forwarded to ``show_cam_on_image``.
    :param image_weight: Forwarded to ``show_cam_on_image``.
    :returns: The 224x224 uint8 overlay produced by ``show_cam_on_image``.
    """
    gradient = gradients[name_layer]
    activation = activations[name_layer]

    # One importance weight per channel: the mean gradient for that channel.
    # Assumes the gradient and activation have the same channel count.
    pooled_gradients = torch.mean(gradient[0], dim=[0, 2, 3])
    pooled_gradients = pooled_gradients.to(device=activation.device, dtype=activation.dtype)

    # Weight every channel by broadcasting instead of multiplying in place, so the
    # caller's stored activation tensor is left untouched (in-place weighting would
    # compound on every repeated call for the same layer).
    weighted = activation * pooled_gradients.view(1, -1, 1, 1)

    heatmap = torch.mean(weighted, dim=1).squeeze().detach().cpu()
    heatmap = torch.clamp(heatmap, min=0)  # keep only positive evidence

    # Normalise to [0, 1]; an all-zero map would otherwise divide by zero and
    # produce NaN, which becomes garbage after the uint8 cast.
    max_value = torch.max(heatmap)
    if max_value > 0:
        heatmap = heatmap / max_value

    heatmap = cv2.resize(torch.unsqueeze(heatmap, -1).numpy(), (224, 224))

    face = cv2.resize(face_image, (224, 224))
    face = np.float32(face) / 255

    return show_cam_on_image(face, heatmap, use_rgb=use_rgb, image_weight=image_weight)


def plot_compound_expression_prediction(
    dict_preds: dict[str, list[float]],
    save_path: str = None,
    frame_indices: list[int] = None,
    colors: list[str] = None,
    figsize: tuple = (12, 6),
    title: str = "Confusion Matrix",
) -> plt.Figure:
    """Plot per-frame predictions of several models as dotted, vertically offset lines.

    :param dict_preds: Mapping from series label to per-frame values. Values are
        plotted at ``value + 1`` on the y-axis, whose ticks are labelled ``''``, then
        ``DICT_PRED`` values, then ``''``. So they are presumably class indices into
        ``DICT_PRED``.
    :param save_path: If given, the figure is saved there. The file format is inferred
        by matplotlib from the extension.
    :param frame_indices: 0-based frame indices to mark on the x-axis; ticks are placed
        at ``index + 1``. When ``None``, matplotlib's default x ticks are kept.
    :param colors: Line colours, cycled if there are more series than colours.
        Defaults to green, orange, red, purple, blue.
    :param figsize: Figure size in inches.
    :param title: Axes title.
    :returns: The created figure.
    """
    if colors is None:
        colors = ["green", "orange", "red", "purple", "blue"]

    fig, ax = plt.subplots(figsize=figsize)

    half = len(dict_preds) // 2
    for idx, (label, values) in enumerate(dict_preds.items()):
        # Small vertical offsets keep overlapping series distinguishable. Series 2
        # and 3 swap offset slots (idx 2 uses slot 3, idx 3 uses slot 2).
        if idx == 2:
            slot = idx + 1
        elif idx == 3:
            slot = idx - 1
        else:
            slot = idx
        offset = (slot - half) * 0.1
        shifted = [val + offset + 1 for val in values]
        ax.plot(
            range(1, len(shifted) + 1),
            shifted,
            color=colors[idx % len(colors)],
            linestyle="dotted",
            label=label,
        )

    ax.legend()
    ax.grid(True)
    ax.set_xlabel("Number of frames")
    ax.set_ylabel("Basic emotion / compound expression")
    ax.set_title(title)

    if frame_indices is not None:
        ax.set_xticks([i + 1 for i in frame_indices])

    # One tick per DICT_PRED label plus a blank tick at each end, so the tick count
    # always matches the label count (matplotlib rejects a mismatch).
    y_labels = [""] + list(DICT_PRED.values()) + [""]
    ax.set_yticks(range(len(y_labels)))
    ax.set_yticklabels(y_labels)

    fig.tight_layout()

    if save_path:
        # Without an explicit format, matplotlib infers it from the extension and
        # falls back to its default format when there is none, instead of raising
        # IndexError as ``rsplit(".", 1)[1]`` would.
        fig.savefig(save_path, bbox_inches="tight", pad_inches=0)

    return fig


def display_frame_info(img, text, margin=1.0, box_scale=1.0):
    """Return a copy of ``img`` with ``text`` drawn in a translucent white box, top right.

    :param img: Image array with shape (H, W, C). It is not modified.
    :param text: The text to draw.
    :param margin: Distance of the box from the top and right edges, in multiples of
        the text height.
    :param box_scale: Padding around the text, in multiples of twice the text height.
    :returns: The annotated copy.
    """
    img_copy = img.copy()
    img_h, img_w, _ = img_copy.shape

    line_width = int(min(img_h, img_w) * 0.001)
    thickness = max(int(line_width / 3), 1)
    font_face = cv2.FONT_HERSHEY_SIMPLEX
    font_color = (0, 0, 0)
    font_scale = thickness / 1.5

    # Measure with the same thickness used by putText so the box fits the rendered text.
    (t_w, t_h), _ = cv2.getTextSize(text, font_face, font_scale, thickness)

    margin_px = int(t_h * margin)
    padding = int(2 * t_h * box_scale)

    top = margin_px
    bottom = margin_px + t_h + padding
    right = img_w - margin_px
    # Clamp so a box wider than the image does not wrap around via negative indexing.
    left = max(img_w - t_w - margin_px - padding, 0)

    # Blend the box region 50/50 with white to make the text readable.
    region = img_copy[top:bottom, left:right]
    white = np.full(region.shape, 255, dtype=np.uint8)
    img_copy[top:bottom, left:right] = cv2.addWeighted(region, 0.5, white, 0.5, 1.0)

    # Baseline origin: offset by half the padding in each direction.
    text_x = img_w - t_w - margin_px - padding // 2
    text_y = margin_px + t_h + padding // 2
    cv2.putText(
        img=img_copy,
        text=text,
        org=(text_x, text_y),
        fontFace=font_face,
        fontScale=font_scale,
        color=font_color,
        thickness=thickness,
        lineType=cv2.LINE_AA,
        bottomLeftOrigin=False,
    )
    return img_copy


def plot_audio(time_axis, waveform, frame_indices, fps, figsize=(10, 4)) -> plt.Figure:
    """Plot the first row of ``waveform`` against ``time_axis``, with x ticks at frames.

    :param time_axis: X values for the waveform samples (same units as ``index / fps``).
    :param waveform: Indexed as ``waveform[0]``; presumably (channels, samples).
    :param frame_indices: 0-based frame indices; a tick is placed at ``index / fps``
        and labelled with the 1-based frame number.
    :param fps: Frames per second used to convert indices to times.
    :param figsize: Figure size in inches.
    :returns: The created figure.
    """
    frame_indices = np.asarray(frame_indices)
    frame_times = frame_indices / fps

    fig, ax = plt.subplots(figsize=figsize)
    ax.plot(time_axis, waveform[0])
    ax.set_xlabel("Time (frames)")
    ax.set_ylabel("Amplitude")
    ax.grid(True)
    ax.set_xticks(frame_times)
    # Label from the indices themselves: recomputing int(index / fps * fps) can
    # truncate down through float rounding (e.g. 1 / 49 * 49 == 0.999...).
    ax.set_xticklabels([f"{int(index) + 1}" for index in frame_indices])
    fig.tight_layout()
    return fig


def plot_images(image_paths):
    """Show images side by side in one row with axes hidden.

    Despite the parameter name, each element is passed directly to ``Axes.imshow``,
    so it should be image data (e.g. an array) — confirm with callers.

    :param image_paths: Sequence of images to display.
    :returns: The created figure.
    """
    images = list(image_paths)
    # squeeze=False always yields a 2-D Axes array, so a single image still iterates;
    # at least one column is requested because subplots rejects ncols=0.
    fig, axes = plt.subplots(1, max(len(images), 1), figsize=(12, 2), squeeze=False)
    for idx, ax in enumerate(axes.flat):
        if idx < len(images):
            ax.imshow(images[idx])
        ax.axis("off")
    fig.tight_layout()
    return fig