# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG), acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt


def plot_mask2D(mask, title="", point_coords=None, figsize=10,
                point_marker_size=5):
    '''
    Simple plotting tool to show intermediate mask predictions and points
    where PointRend is applied.

    Args:
        mask (Tensor): mask prediction of shape HxW
        title (str): title for the plot
        point_coords ((Tensor, Tensor)): x and y point coordinates
        figsize (int): size of the figure to plot
        point_marker_size (int): marker size for points
    '''
    H, W = mask.shape
    plt.figure(figsize=(figsize, figsize))
    if title:
        title += ", "
    plt.title("{}resolution {}x{}".format(title, H, W), fontsize=30)
    plt.ylabel(H, fontsize=30)
    plt.xlabel(W, fontsize=30)
    plt.xticks([], [])
    plt.yticks([], [])
    plt.imshow(mask.detach(),
               interpolation="nearest",
               cmap=plt.get_cmap('gray'))
    if point_coords is not None:
        plt.scatter(x=point_coords[0], y=point_coords[1],
                    color="red", s=point_marker_size, clip_on=True)
    plt.xlim(-0.5, W - 0.5)
    plt.ylim(H - 0.5, -0.5)
    plt.show()
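# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): a minimal example of
# plot_mask2D on a random probability grid. The helper name and the sizes used
# below are illustrative assumptions only.
def _example_plot_mask2D():
    coarse_mask = torch.rand(32, 32)          # stand-in for an H x W prediction
    xs = torch.randint(0, 32, (16,))          # x coordinates of sampled points
    ys = torch.randint(0, 32, (16,))          # y coordinates of sampled points
    plot_mask2D(coarse_mask, title="coarse", point_coords=(xs, ys))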
def plot_mask3D(mask=None, title="", point_coords=None, figsize=1500,
                point_marker_size=8, interactive=True):
    '''
    Simple plotting tool to show intermediate mask predictions and points
    where PointRend is applied.

    Args:
        mask (Tensor): mask prediction of shape DxHxW
        title (str): title for the plot
        point_coords ((Tensor, Tensor, Tensor)): x, y and z point coordinates
        figsize (int): size of the figure to plot
        point_marker_size (int): marker size for points
    '''
    import trimesh
    import vtkplotter
    from skimage import measure

    vp = vtkplotter.Plotter(title=title, size=(figsize, figsize))
    vis_list = []

    if mask is not None:
        mask = mask.detach().to("cpu").numpy()
        mask = mask.transpose(2, 1, 0)

        # marching cubes to find the surface
        # (renamed to measure.marching_cubes in newer scikit-image releases)
        verts, faces, normals, values = measure.marching_cubes_lewiner(
            mask, 0.5, gradient_direction='ascent')

        # create a mesh
        mesh = trimesh.Trimesh(verts, faces)
        mesh.visual.face_colors = [200, 200, 250, 100]
        vis_list.append(mesh)

    if point_coords is not None:
        point_coords = torch.stack(point_coords, 1).to("cpu").numpy()

        # import numpy as np
        # select_x = np.logical_and(point_coords[:, 0] >= 16, point_coords[:, 0] <= 112)
        # select_y = np.logical_and(point_coords[:, 1] >= 48, point_coords[:, 1] <= 272)
        # select_z = np.logical_and(point_coords[:, 2] >= 16, point_coords[:, 2] <= 112)
        # select = np.logical_and(np.logical_and(select_x, select_y), select_z)
        # point_coords = point_coords[select, :]

        pc = vtkplotter.Points(point_coords, r=point_marker_size, c='red')
        vis_list.append(pc)

    vp.show(*vis_list, bg="white", axes=1, interactive=interactive,
            azimuth=30, elevation=30)


def create_grid3D(min, max, steps):
    if type(min) is int:
        min = (min, min, min)  # (x, y, z)
    if type(max) is int:
        max = (max, max, max)  # (x, y, z)
    if type(steps) is int:
        steps = (steps, steps, steps)  # (x, y, z)
    arrangeX = torch.linspace(min[0], max[0], steps[0]).long()
    arrangeY = torch.linspace(min[1], max[1], steps[1]).long()
    arrangeZ = torch.linspace(min[2], max[2], steps[2]).long()
    gridD, gridH, gridW = torch.meshgrid([arrangeZ, arrangeY, arrangeX])
    coords = torch.stack([gridW, gridH, gridD])  # [3, steps[2], steps[1], steps[0]]
    coords = coords.view(3, -1).t()  # [N, 3]
    return coords


def create_grid2D(min, max, steps):
    if type(min) is int:
        min = (min, min)  # (x, y)
    if type(max) is int:
        max = (max, max)  # (x, y)
    if type(steps) is int:
        steps = (steps, steps)  # (x, y)
    arrangeX = torch.linspace(min[0], max[0], steps[0]).long()
    arrangeY = torch.linspace(min[1], max[1], steps[1]).long()
    gridH, gridW = torch.meshgrid([arrangeY, arrangeX])
    coords = torch.stack([gridW, gridH])  # [2, steps[1], steps[0]]
    coords = coords.view(2, -1).t()  # [N, 2]
    return coords


class SmoothConv2D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        assert kernel_size % 2 == 1, \
            "kernel_size for smooth_conv must be odd: {3, 5, ...}"
        self.padding = (kernel_size - 1) // 2

        # fixed box-filter weights; conv weight layout is
        # (out_channels, in_channels, kH, kW)
        weight = torch.ones(
            (out_channels, in_channels, kernel_size, kernel_size),
            dtype=torch.float32) / (kernel_size**2)
        self.register_buffer('weight', weight)

    def forward(self, input):
        return F.conv2d(input, self.weight, padding=self.padding)


class SmoothConv3D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        assert kernel_size % 2 == 1, \
            "kernel_size for smooth_conv must be odd: {3, 5, ...}"
        self.padding = (kernel_size - 1) // 2

        # fixed box-filter weights; conv weight layout is
        # (out_channels, in_channels, kD, kH, kW)
        weight = torch.ones(
            (out_channels, in_channels, kernel_size, kernel_size, kernel_size),
            dtype=torch.float32) / (kernel_size**3)
        self.register_buffer('weight', weight)

    def forward(self, input):
        return F.conv3d(input, self.weight, padding=self.padding)
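# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): SmoothConv3D is a
# fixed box filter, so applying it to a binary occupancy volume yields the
# local mean occupancy, which becomes fractional near the surface. The sizes
# below are illustrative assumptions only.
def _example_smooth_occupancy3D():
    occupancy = (torch.rand(1, 1, 16, 16, 16) > 0.5).float()  # fake B x 1 x D x H x W volume
    smoother = SmoothConv3D(in_channels=1, out_channels=1, kernel_size=3)
    local_mean = smoother(occupancy)                          # same shape as the input
    coords = create_grid3D(0, 15, steps=16)                   # [16**3, 3] integer (x, y, z) coords
    return local_mean, coords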
def build_smooth_conv3D(in_channels=1, out_channels=1, kernel_size=3, padding=1):
    smooth_conv = torch.nn.Conv3d(
        in_channels=in_channels, out_channels=out_channels,
        kernel_size=kernel_size, padding=padding)
    # overwrite the initial parameter values with a fixed averaging (box) filter;
    # conv weight layout is (out_channels, in_channels, kD, kH, kW)
    smooth_conv.weight.data = torch.ones(
        (out_channels, in_channels, kernel_size, kernel_size, kernel_size),
        dtype=torch.float32) / (kernel_size**3)
    smooth_conv.bias.data = torch.zeros(out_channels)
    return smooth_conv


def build_smooth_conv2D(in_channels=1, out_channels=1, kernel_size=3, padding=1):
    smooth_conv = torch.nn.Conv2d(
        in_channels=in_channels, out_channels=out_channels,
        kernel_size=kernel_size, padding=padding)
    # overwrite the initial parameter values with a fixed averaging (box) filter;
    # conv weight layout is (out_channels, in_channels, kH, kW)
    smooth_conv.weight.data = torch.ones(
        (out_channels, in_channels, kernel_size, kernel_size),
        dtype=torch.float32) / (kernel_size**2)
    smooth_conv.bias.data = torch.zeros(out_channels)
    return smooth_conv


def get_uncertain_point_coords_on_grid3D(uncertainty_map, num_points, **kwargs):
    """
    Find `num_points` most uncertain points from `uncertainty_map` grid.

    Args:
        uncertainty_map (Tensor): A tensor of shape (N, 1, D, H, W) that contains
            uncertainty values for a set of points on a regular D x H x W grid.
        num_points (int): The number of points P to select.

    Returns:
        point_indices (Tensor): A tensor of shape (N, P) that contains indices from
            [0, D x H x W) of the most uncertain points.
        point_coords (Tensor): A tensor of shape (N, P, 3) that contains integer
            (x, y, z) voxel coordinates of the most uncertain points on the
            D x H x W grid.
    """
    R, _, D, H, W = uncertainty_map.shape
    # h_step = 1.0 / float(H)
    # w_step = 1.0 / float(W)
    # d_step = 1.0 / float(D)

    num_points = min(D * H * W, num_points)
    point_scores, point_indices = torch.topk(
        uncertainty_map.view(R, D * H * W), k=num_points, dim=1)
    point_coords = torch.zeros(R, num_points, 3, dtype=torch.float,
                               device=uncertainty_map.device)
    # point_coords[:, :, 0] = h_step / 2.0 + (point_indices // (W * D)).to(torch.float) * h_step
    # point_coords[:, :, 1] = w_step / 2.0 + (point_indices % (W * D) // D).to(torch.float) * w_step
    # point_coords[:, :, 2] = d_step / 2.0 + (point_indices % D).to(torch.float) * d_step
    point_coords[:, :, 0] = (point_indices % W).to(torch.float)  # x
    point_coords[:, :, 1] = (point_indices % (H * W) // W).to(torch.float)  # y
    point_coords[:, :, 2] = (point_indices // (H * W)).to(torch.float)  # z
    print(f"resolution {D} x {H} x {W}", point_scores.min(), point_scores.max())
    return point_indices, point_coords
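# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): select the voxels of
# a coarse occupancy prediction that are closest to 0.5, i.e. the most
# uncertain ones. The uncertainty measure mirrors calculate_uncertainty below;
# the sizes and point budget are illustrative assumptions only.
def _example_select_uncertain_voxels():
    pred = torch.rand(1, 1, 8, 16, 16)        # fake N x 1 x D x H x W probabilities
    uncertainty = -torch.abs(pred - 0.5)      # highest where pred is closest to 0.5
    indices, coords = get_uncertain_point_coords_on_grid3D(uncertainty, num_points=128)
    return indices, coords                    # shapes [1, 128] and [1, 128, 3]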
def get_uncertain_point_coords_on_grid3D_faster(uncertainty_map, num_points, clip_min):
    """
    Find `num_points` most uncertain points from `uncertainty_map` grid,
    considering only voxels whose uncertainty is at least `clip_min`.

    Args:
        uncertainty_map (Tensor): A tensor of shape (N, 1, D, H, W) that contains
            uncertainty values for a set of points on a regular D x H x W grid.
        num_points (int): The number of points P to select.
        clip_min (float): Voxels with uncertainty below this value are ignored.

    Returns:
        point_indices (Tensor): A tensor of shape (N, P) that contains indices from
            [0, D x H x W) of the most uncertain points.
        point_coords (Tensor): A tensor of shape (N, P, 3) that contains integer
            (x, y, z) voxel coordinates of the most uncertain points on the
            D x H x W grid.
    """
    R, _, D, H, W = uncertainty_map.shape
    # h_step = 1.0 / float(H)
    # w_step = 1.0 / float(W)
    # d_step = 1.0 / float(D)
    assert R == 1, "batchsize > 1 is not implemented!"
    uncertainty_map = uncertainty_map.view(D * H * W)
    indices = (uncertainty_map >= clip_min).nonzero().squeeze(1)
    num_points = min(num_points, indices.size(0))

    point_scores, point_indices = torch.topk(uncertainty_map[indices],
                                             k=num_points, dim=0)
    point_indices = indices[point_indices].unsqueeze(0)

    point_coords = torch.zeros(R, num_points, 3, dtype=torch.float,
                               device=uncertainty_map.device)
    # point_coords[:, :, 0] = h_step / 2.0 + (point_indices // (W * D)).to(torch.float) * h_step
    # point_coords[:, :, 1] = w_step / 2.0 + (point_indices % (W * D) // D).to(torch.float) * w_step
    # point_coords[:, :, 2] = d_step / 2.0 + (point_indices % D).to(torch.float) * d_step
    point_coords[:, :, 0] = (point_indices % W).to(torch.float)  # x
    point_coords[:, :, 1] = (point_indices % (H * W) // W).to(torch.float)  # y
    point_coords[:, :, 2] = (point_indices // (H * W)).to(torch.float)  # z
    # print(f"resolution {D} x {H} x {W}", point_scores.min(), point_scores.max())
    return point_indices, point_coords


def get_uncertain_point_coords_on_grid2D(uncertainty_map, num_points, **kwargs):
    """
    Find `num_points` most uncertain points from `uncertainty_map` grid.

    Args:
        uncertainty_map (Tensor): A tensor of shape (N, 1, H, W) that contains
            uncertainty values for a set of points on a regular H x W grid.
        num_points (int): The number of points P to select.

    Returns:
        point_indices (Tensor): A tensor of shape (N, P) that contains indices from
            [0, H x W) of the most uncertain points.
        point_coords (Tensor): A tensor of shape (N, P, 2) that contains integer
            (x, y) pixel coordinates of the most uncertain points on the H x W grid.
    """
    R, _, H, W = uncertainty_map.shape
    # h_step = 1.0 / float(H)
    # w_step = 1.0 / float(W)

    num_points = min(H * W, num_points)
    point_scores, point_indices = torch.topk(uncertainty_map.view(R, H * W),
                                             k=num_points, dim=1)
    point_coords = torch.zeros(R, num_points, 2, dtype=torch.long,
                               device=uncertainty_map.device)
    # point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step
    # point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step
    point_coords[:, :, 0] = (point_indices % W).to(torch.long)  # x
    point_coords[:, :, 1] = (point_indices // W).to(torch.long)  # y
    # print(point_scores.min(), point_scores.max())
    return point_indices, point_coords
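# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): the 2D counterpart of
# the voxel selection above, optionally visualised with plot_mask2D. The sizes,
# point budget, and helper name are illustrative assumptions only.
def _example_select_uncertain_pixels(show=False):
    pred = torch.rand(1, 1, 32, 32)           # fake N x 1 x H x W probabilities
    uncertainty = -torch.abs(pred - 0.5)      # highest where pred is closest to 0.5
    indices, coords = get_uncertain_point_coords_on_grid2D(uncertainty, num_points=64)
    if show:
        plot_mask2D(pred[0, 0], title="uncertain pixels",
                    point_coords=(coords[0, :, 0], coords[0, :, 1]))
    return indices, coords                    # shapes [1, 64] and [1, 64, 2]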
def get_uncertain_point_coords_on_grid2D_faster(uncertainty_map, num_points, clip_min):
    """
    Find `num_points` most uncertain points from `uncertainty_map` grid,
    considering only pixels whose uncertainty is at least `clip_min`.

    Args:
        uncertainty_map (Tensor): A tensor of shape (N, 1, H, W) that contains
            uncertainty values for a set of points on a regular H x W grid.
        num_points (int): The number of points P to select.
        clip_min (float): Pixels with uncertainty below this value are ignored.

    Returns:
        point_indices (Tensor): A tensor of shape (N, P) that contains indices from
            [0, H x W) of the most uncertain points.
        point_coords (Tensor): A tensor of shape (N, P, 2) that contains integer
            (x, y) pixel coordinates of the most uncertain points on the H x W grid.
    """
    R, _, H, W = uncertainty_map.shape
    # h_step = 1.0 / float(H)
    # w_step = 1.0 / float(W)
    assert R == 1, "batchsize > 1 is not implemented!"
    uncertainty_map = uncertainty_map.view(H * W)
    indices = (uncertainty_map >= clip_min).nonzero().squeeze(1)
    num_points = min(num_points, indices.size(0))

    point_scores, point_indices = torch.topk(uncertainty_map[indices],
                                             k=num_points, dim=0)
    point_indices = indices[point_indices].unsqueeze(0)

    point_coords = torch.zeros(R, num_points, 2, dtype=torch.long,
                               device=uncertainty_map.device)
    # point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step
    # point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step
    point_coords[:, :, 0] = (point_indices % W).to(torch.long)  # x
    point_coords[:, :, 1] = (point_indices // W).to(torch.long)  # y
    # print(point_scores.min(), point_scores.max())
    return point_indices, point_coords


def calculate_uncertainty(logits, classes=None, balance_value=0.5):
    """
    We estimate uncertainty as the L1 distance between `balance_value` and the
    prediction in `logits` for the foreground class in `classes`.

    Args:
        logits (Tensor): A tensor of shape (R, C, ...) or (R, 1, ...) for
            class-specific or class-agnostic predictions, where R is the total
            number of predicted masks in all images and C is the number of
            foreground classes. The values are logits.
        classes (list): A list of length R that contains either the predicted
            or the ground-truth class for each predicted mask.
        balance_value (float): The decision boundary; predictions closest to
            this value are considered the most uncertain.

    Returns:
        scores (Tensor): A tensor of shape (R, 1, ...) that contains
            uncertainty scores, with the most uncertain locations having the
            highest uncertainty score.
    """
    if logits.shape[1] == 1:
        # class-agnostic prediction
        gt_class_logits = logits
    else:
        # pick the channel of the (predicted or ground-truth) class per mask
        gt_class_logits = logits[
            torch.arange(logits.shape[0], device=logits.device),
            classes].unsqueeze(1)
    return -torch.abs(gt_class_logits - balance_value)
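# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): with the default
# balance_value of 0.5, calculate_uncertainty is most meaningful on values in
# [0, 1]; the random tensors and the clip_min threshold below are illustrative
# stand-ins only.
if __name__ == "__main__":
    probs = torch.rand(2, 1, 16, 16)          # class-agnostic R x 1 x H x W predictions
    scores = calculate_uncertainty(probs)     # peaks where the prediction is closest to 0.5
    # single-batch "faster" variant: only pixels with score >= clip_min compete
    idx, xy = get_uncertain_point_coords_on_grid2D_faster(
        scores[:1], num_points=32, clip_min=-0.25)
    print(scores.shape, idx.shape, xy.shape)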