"""Utility functions for videos, plotting and computing performance metrics.""" import os import typing import cv2 # pytype: disable=attribute-error import matplotlib import numpy as np import torch import tqdm from . import video from . import segmentation def loadvideo(filename: str) -> np.ndarray: """Loads a video from a file. Args: filename (str): filename of video Returns: A np.ndarray with dimensions (channels=3, frames, height, width). The values will be uint8's ranging from 0 to 255. Raises: FileNotFoundError: Could not find `filename` ValueError: An error occurred while reading the video """ if not os.path.exists(filename): raise FileNotFoundError(filename) capture = cv2.VideoCapture(filename) frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT)) frame_width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)) frame_height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)) v = np.zeros((frame_count, frame_height, frame_width, 3), np.uint8) for count in range(frame_count): ret, frame = capture.read() if not ret: raise ValueError("Failed to load frame #{} of {}.".format(count, filename)) frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) v[count, :, :] = frame v = v.transpose((3, 0, 1, 2)) return v def savevideo(filename: str, array: np.ndarray, fps: typing.Union[float, int] = 1): """Saves a video to a file. Args: filename (str): filename of video array (np.ndarray): video of uint8's with shape (channels=3, frames, height, width) fps (float or int): frames per second Returns: None """ c, _, height, width = array.shape if c != 3: raise ValueError("savevideo expects array of shape (channels=3, frames, height, width), got shape ({})".format(", ".join(map(str, array.shape)))) fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G') out = cv2.VideoWriter(filename, fourcc, fps, (width, height)) for frame in array.transpose((1, 2, 3, 0)): frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) out.write(frame) def get_mean_and_std(dataset: torch.utils.data.Dataset, samples: int = 128, batch_size: int = 8, num_workers: int = 4): """Computes mean and std from samples from a Pytorch dataset. Args: dataset (torch.utils.data.Dataset): A Pytorch dataset. ``dataset[i][0]'' is expected to be the i-th video in the dataset, which should be a ``torch.Tensor'' of dimensions (channels=3, frames, height, width) samples (int or None, optional): Number of samples to take from dataset. If ``None'', mean and standard deviation are computed over all elements. Defaults to 128. batch_size (int, optional): how many samples per batch to load Defaults to 8. num_workers (int, optional): how many subprocesses to use for data loading. If 0, the data will be loaded in the main process. Defaults to 4. Returns: A tuple of the mean and standard deviation. Both are represented as np.array's of dimension (channels,). """ if samples is not None and len(dataset) > samples: indices = np.random.choice(len(dataset), samples, replace=False) dataset = torch.utils.data.Subset(dataset, indices) dataloader = torch.utils.data.DataLoader( dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True) n = 0 # number of elements taken (should be equal to samples by end of for loop) s1 = 0. # sum of elements along channels (ends up as np.array of dimension (channels,)) s2 = 0. # sum of squares of elements along channels (ends up as np.array of dimension (channels,)) for (x, *_) in tqdm.tqdm(dataloader): x = x.transpose(0, 1).contiguous().view(3, -1) n += x.shape[1] s1 += torch.sum(x, dim=1).numpy() s2 += torch.sum(x ** 2, dim=1).numpy() mean = s1 / n # type: np.ndarray std = np.sqrt(s2 / n - mean ** 2) # type: np.ndarray mean = mean.astype(np.float32) std = std.astype(np.float32) return mean, std def bootstrap(a, b, func, samples=10000): """Computes a bootstrapped confidence intervals for ``func(a, b)''. Args: a (array_like): first argument to `func`. b (array_like): second argument to `func`. func (callable): Function to compute confidence intervals for. ``dataset[i][0]'' is expected to be the i-th video in the dataset, which should be a ``torch.Tensor'' of dimensions (channels=3, frames, height, width) samples (int, optional): Number of samples to compute. Defaults to 10000. Returns: A tuple of (`func(a, b)`, estimated 5-th percentile, estimated 95-th percentile). """ a = np.array(a) b = np.array(b) bootstraps = [] for _ in range(samples): ind = np.random.choice(len(a), len(a)) bootstraps.append(func(a[ind], b[ind])) bootstraps = sorted(bootstraps) return func(a, b), bootstraps[round(0.05 * len(bootstraps))], bootstraps[round(0.95 * len(bootstraps))] def latexify(): """Sets matplotlib params to appear more like LaTeX. Based on https://nipunbatra.github.io/blog/2014/latexify.html """ params = {'backend': 'pdf', 'axes.titlesize': 8, 'axes.labelsize': 8, 'font.size': 8, 'legend.fontsize': 8, 'xtick.labelsize': 8, 'ytick.labelsize': 8, 'font.family': 'DejaVu Serif', 'font.serif': 'Computer Modern', } matplotlib.rcParams.update(params) def dice_similarity_coefficient(inter, union): """Computes the dice similarity coefficient. Args: inter (iterable): iterable of the intersections union (iterable): iterable of the unions """ return 2 * sum(inter) / (sum(union) + sum(inter)) __all__ = ["video", "segmentation", "loadvideo", "savevideo", "get_mean_and_std", "bootstrap", "latexify", "dice_similarity_coefficient"]