Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is | |
| # holder of all proprietary rights on this computer program. | |
| # You can only use this computer program if you have closed | |
| # a license agreement with MPG or you get the right to use the computer | |
| # program from someone who is authorized to grant you that right. | |
| # Any use of the computer program without a valid license is prohibited and | |
| # liable to prosecution. | |
| # | |
| # Copyright©2019 Max-Planck-Gesellschaft zur Förderung | |
| # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute | |
| # for Intelligent Systems. All rights reserved. | |
| # | |
| # Contact: ps-license@tuebingen.mpg.de | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import matplotlib.pyplot as plt | |
def plot_mask2D(mask,
                title="",
                point_coords=None,
                figsize=10,
                point_marker_size=5):
    '''
    Simple plotting tool to show intermediate mask predictions and points
    where PointRend is applied.

    Args:
        mask (Tensor): mask prediction of shape HxW
        title (str): title for the plot; the resolution is appended to it
        point_coords ((Tensor, Tensor)): x and y point coordinates
        figsize (int): size (width and height) of the figure to plot
        point_marker_size (int): marker size for points
    '''
    H, W = mask.shape
    plt.figure(figsize=(figsize, figsize))
    if title:
        title += ", "
    plt.title("{}resolution {}x{}".format(title, H, W), fontsize=30)
    plt.ylabel(H, fontsize=30)
    plt.xlabel(W, fontsize=30)
    plt.xticks([], [])
    plt.yticks([], [])
    # matplotlib converts its input through NumPy, which cannot read tensors
    # living on an accelerator, so move the data to host memory first.
    plt.imshow(mask.detach().cpu(),
               interpolation="nearest",
               cmap=plt.get_cmap('gray'))
    if point_coords is not None:
        xs, ys = point_coords[0], point_coords[1]
        if torch.is_tensor(xs):
            xs = xs.detach().cpu()
        if torch.is_tensor(ys):
            ys = ys.detach().cpu()
        plt.scatter(x=xs,
                    y=ys,
                    color="red",
                    s=point_marker_size,
                    clip_on=True)
    # Pixel-centered limits; the y axis is flipped so row 0 is at the top.
    plt.xlim(-0.5, W - 0.5)
    plt.ylim(H - 0.5, -0.5)
    plt.show()
def plot_mask3D(mask=None,
                title="",
                point_coords=None,
                figsize=1500,
                point_marker_size=8,
                interactive=True):
    '''
    Simple plotting tool to show intermediate mask predictions and points
    where PointRend is applied.

    Args:
        mask (Tensor): mask prediction of shape DxHxW; its axes are reversed
            before the surface is extracted at level 0.5
        title (str): title for the plot
        point_coords ((Tensor, Tensor, Tensor)): x and y and z point coordinates
        figsize (int): size of the figure to plot
        point_marker_size (int): marker size for points
        interactive (bool): forwarded to the vtkplotter `show` call
    '''
    import trimesh
    import vtkplotter
    from skimage import measure

    vp = vtkplotter.Plotter(title=title, size=(figsize, figsize))
    vis_list = []

    if mask is not None:
        mask = mask.detach().to("cpu").numpy()
        mask = mask.transpose(2, 1, 0)

        # `marching_cubes_lewiner` is not available in newer scikit-image
        # releases; `marching_cubes` is used there instead.
        marching_cubes = getattr(measure, "marching_cubes_lewiner", None)
        if marching_cubes is None:
            marching_cubes = measure.marching_cubes

        # marching cube to find surface
        verts, faces, _, _ = marching_cubes(mask,
                                            0.5,
                                            gradient_direction='ascent')

        # create a mesh
        mesh = trimesh.Trimesh(verts, faces)
        mesh.visual.face_colors = [200, 200, 250, 100]
        vis_list.append(mesh)

    if point_coords is not None:
        # Stack the per-axis coordinate tensors into one [N, 3] array.
        point_coords = torch.stack(point_coords, 1).to("cpu").numpy()
        pc = vtkplotter.Points(point_coords, r=point_marker_size, c='red')
        vis_list.append(pc)

    vp.show(*vis_list,
            bg="white",
            axes=1,
            interactive=interactive,
            azimuth=30,
            elevation=30)
def create_grid3D(min, max, steps):
    """
    Build the integer coordinates of a regular 3D lattice.

    Args:
        min (int or tuple): lower bound, either one int for all axes or
            an (x, y, z) triple.
        max (int or tuple): upper bound, same convention as `min`.
        steps (int or tuple): number of samples per axis, same convention.

    Returns:
        Tensor: long tensor of shape [N, 3] whose columns are (x, y, z),
        with x varying fastest and z slowest.
    """
    # Broadcast plain ints to per-axis (x, y, z) triples.
    lower, upper, counts = (
        (value, value, value) if type(value) is int else value
        for value in (min, max, steps)
    )

    # Evenly spaced samples per axis, truncated to integer coordinates.
    xs = torch.linspace(lower[0], upper[0], counts[0]).long()
    ys = torch.linspace(lower[1], upper[1], counts[1]).long()
    zs = torch.linspace(lower[2], upper[2], counts[2]).long()

    grid_z, grid_y, grid_x = torch.meshgrid([zs, ys, xs])
    stacked = torch.stack([grid_x, grid_y, grid_z])  # [3, steps[2], steps[1], steps[0]]
    return stacked.view(3, -1).t()  # [N, 3]
def create_grid2D(min, max, steps):
    """
    Build the integer coordinates of a regular 2D lattice.

    Args:
        min (int or tuple): lower bound, either one int for both axes or
            an (x, y) pair.
        max (int or tuple): upper bound, same convention as `min`.
        steps (int or tuple): number of samples per axis, same convention.

    Returns:
        Tensor: long tensor of shape [N, 2] whose columns are (x, y),
        with x varying fastest.
    """
    # Broadcast plain ints to per-axis (x, y) pairs.
    lower, upper, counts = (
        (value, value) if type(value) is int else value
        for value in (min, max, steps)
    )

    # Evenly spaced samples per axis, truncated to integer coordinates.
    xs = torch.linspace(lower[0], upper[0], counts[0]).long()
    ys = torch.linspace(lower[1], upper[1], counts[1]).long()

    grid_y, grid_x = torch.meshgrid([ys, xs])
    stacked = torch.stack([grid_x, grid_y])  # [2, steps[1], steps[0]]
    return stacked.view(2, -1).t()  # [N, 2]
class SmoothConv2D(nn.Module):
    """
    Fixed (non-learnable) 2D box filter.

    Every kernel entry equals 1 / kernel_size**2, so each output channel is
    the sum over input channels of the local mean. The kernel is stored as a
    buffer, not a parameter. "Same" padding keeps the spatial size unchanged.

    Args:
        in_channels (int): number of channels of the input.
        out_channels (int): number of channels produced.
        kernel_size (int): odd side length of the square kernel.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        assert kernel_size % 2 == 1, "kernel_size for smooth_conv must be odd: {3, 5, ...}"
        self.padding = (kernel_size - 1) // 2

        # F.conv2d expects the kernel laid out as (out_channels, in_channels, kH, kW).
        weight = torch.ones(
            (out_channels, in_channels, kernel_size, kernel_size),
            dtype=torch.float32) / (kernel_size**2)
        self.register_buffer('weight', weight)

    def forward(self, input):
        """Apply the box filter to `input` of shape (N, in_channels, H, W)."""
        return F.conv2d(input, self.weight, padding=self.padding)
class SmoothConv3D(nn.Module):
    """
    Fixed (non-learnable) 3D box filter.

    Every kernel entry equals 1 / kernel_size**3, so each output channel is
    the sum over input channels of the local mean. The kernel is stored as a
    buffer, not a parameter. "Same" padding keeps the spatial size unchanged.

    Args:
        in_channels (int): number of channels of the input.
        out_channels (int): number of channels produced.
        kernel_size (int): odd side length of the cubic kernel.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        assert kernel_size % 2 == 1, "kernel_size for smooth_conv must be odd: {3, 5, ...}"
        self.padding = (kernel_size - 1) // 2

        # F.conv3d expects the kernel laid out as (out_channels, in_channels, kD, kH, kW).
        weight = torch.ones(
            (out_channels, in_channels, kernel_size, kernel_size, kernel_size),
            dtype=torch.float32) / (kernel_size**3)
        self.register_buffer('weight', weight)

    def forward(self, input):
        """Apply the box filter to `input` of shape (N, in_channels, D, H, W)."""
        return F.conv3d(input, self.weight, padding=self.padding)
def build_smooth_conv3D(in_channels=1,
                        out_channels=1,
                        kernel_size=3,
                        padding=1):
    """
    Build a `torch.nn.Conv3d` initialised as a box (mean) filter.

    All kernel entries are set to 1 / kernel_size**3 and the bias to zero.

    Args:
        in_channels (int): number of channels of the input.
        out_channels (int): number of channels produced.
        kernel_size (int): side length of the cubic kernel.
        padding (int): padding passed to the convolution.

    Returns:
        torch.nn.Conv3d: the initialised convolution layer.
    """
    smooth_conv = torch.nn.Conv3d(in_channels=in_channels,
                                  out_channels=out_channels,
                                  kernel_size=kernel_size,
                                  padding=padding)
    # Conv3d stores its kernel as (out_channels, in_channels, kD, kH, kW);
    # the replacement tensor must use the same layout.
    smooth_conv.weight.data = torch.ones(
        (out_channels, in_channels, kernel_size, kernel_size, kernel_size),
        dtype=torch.float32) / (kernel_size**3)
    smooth_conv.bias.data = torch.zeros(out_channels)
    return smooth_conv
def build_smooth_conv2D(in_channels=1,
                        out_channels=1,
                        kernel_size=3,
                        padding=1):
    """
    Build a `torch.nn.Conv2d` initialised as a box (mean) filter.

    All kernel entries are set to 1 / kernel_size**2 and the bias to zero.

    Args:
        in_channels (int): number of channels of the input.
        out_channels (int): number of channels produced.
        kernel_size (int): side length of the square kernel.
        padding (int): padding passed to the convolution.

    Returns:
        torch.nn.Conv2d: the initialised convolution layer.
    """
    smooth_conv = torch.nn.Conv2d(in_channels=in_channels,
                                  out_channels=out_channels,
                                  kernel_size=kernel_size,
                                  padding=padding)
    # Conv2d stores its kernel as (out_channels, in_channels, kH, kW);
    # the replacement tensor must use the same layout.
    smooth_conv.weight.data = torch.ones(
        (out_channels, in_channels, kernel_size, kernel_size),
        dtype=torch.float32) / (kernel_size**2)
    smooth_conv.bias.data = torch.zeros(out_channels)
    return smooth_conv
def get_uncertain_point_coords_on_grid3D(uncertainty_map, num_points,
                                         **kwargs):
    """
    Find `num_points` most uncertain points from `uncertainty_map` grid.

    Args:
        uncertainty_map (Tensor): A tensor of shape (N, 1, D, H, W) that contains uncertainty
            values for a set of points on a regular D x H x W grid.
        num_points (int): The number of points P to select; clamped to D x H x W.
        **kwargs: Accepted and ignored.

    Returns:
        point_indices (Tensor): A tensor of shape (N, P) that contains indices from
            [0, D x H x W) of the most uncertain points.
        point_coords (Tensor): A float tensor of shape (N, P, 3) that contains the
            (x, y, z) grid coordinates (in voxels, not normalized) of those points.
    """
    R, _, D, H, W = uncertainty_map.shape
    num_points = min(D * H * W, num_points)

    # reshape (rather than view) so that non-contiguous maps are accepted too.
    flat_map = uncertainty_map.reshape(R, D * H * W)
    _, point_indices = torch.topk(flat_map, k=num_points, dim=1)

    point_coords = torch.zeros(R,
                               num_points,
                               3,
                               dtype=torch.float,
                               device=uncertainty_map.device)
    # Decompose the flat index (z * H * W + y * W + x) into its components.
    point_coords[:, :, 0] = (point_indices % W).to(torch.float)  # x
    point_coords[:, :, 1] = (point_indices % (H * W) // W).to(torch.float)  # y
    point_coords[:, :, 2] = (point_indices // (H * W)).to(torch.float)  # z
    return point_indices, point_coords
def get_uncertain_point_coords_on_grid3D_faster(uncertainty_map, num_points,
                                                clip_min):
    """
    Find `num_points` most uncertain points from `uncertainty_map` grid,
    considering only points whose uncertainty is at least `clip_min`.

    Args:
        uncertainty_map (Tensor): A tensor of shape (1, 1, D, H, W) that contains uncertainty
            values for a set of points on a regular D x H x W grid. Only batch size 1
            is supported.
        num_points (int): The number of points P to select; clamped to the number of
            points that pass the `clip_min` threshold.
        clip_min (float): Points with uncertainty below this value are discarded
            before the top-k selection.

    Returns:
        point_indices (Tensor): A tensor of shape (1, P) that contains indices from
            [0, D x H x W) of the most uncertain points.
        point_coords (Tensor): A float tensor of shape (1, P, 3) that contains the
            (x, y, z) grid coordinates (in voxels, not normalized) of those points.
    """
    R, _, D, H, W = uncertainty_map.shape
    assert R == 1, "batchsize > 1 is not implemented!"

    # reshape (rather than view) so that non-contiguous maps are accepted too.
    uncertainty_map = uncertainty_map.reshape(D * H * W)

    # Pre-filter by threshold so top-k runs on the surviving candidates only.
    indices = (uncertainty_map >= clip_min).nonzero().squeeze(1)
    num_points = min(num_points, indices.size(0))
    _, top_positions = torch.topk(uncertainty_map[indices],
                                  k=num_points,
                                  dim=0)
    # Map positions within the filtered subset back to flat grid indices.
    point_indices = indices[top_positions].unsqueeze(0)

    point_coords = torch.zeros(R,
                               num_points,
                               3,
                               dtype=torch.float,
                               device=uncertainty_map.device)
    # Decompose the flat index (z * H * W + y * W + x) into its components.
    point_coords[:, :, 0] = (point_indices % W).to(torch.float)  # x
    point_coords[:, :, 1] = (point_indices % (H * W) // W).to(torch.float)  # y
    point_coords[:, :, 2] = (point_indices // (H * W)).to(torch.float)  # z
    return point_indices, point_coords
def get_uncertain_point_coords_on_grid2D(uncertainty_map, num_points,
                                         **kwargs):
    """
    Find `num_points` most uncertain points from `uncertainty_map` grid.

    Args:
        uncertainty_map (Tensor): A tensor of shape (N, 1, H, W) that contains uncertainty
            values for a set of points on a regular H x W grid.
        num_points (int): The number of points P to select; clamped to H x W.
        **kwargs: Accepted and ignored.

    Returns:
        point_indices (Tensor): A tensor of shape (N, P) that contains indices from
            [0, H x W) of the most uncertain points.
        point_coords (Tensor): A long tensor of shape (N, P, 2) that contains the
            (x, y) grid coordinates (in pixels, not normalized) of those points.
    """
    R, _, H, W = uncertainty_map.shape
    num_points = min(H * W, num_points)

    # reshape (rather than view) so that non-contiguous maps are accepted too.
    flat_map = uncertainty_map.reshape(R, H * W)
    _, point_indices = torch.topk(flat_map, k=num_points, dim=1)

    point_coords = torch.zeros(R,
                               num_points,
                               2,
                               dtype=torch.long,
                               device=uncertainty_map.device)
    # Decompose the flat index (y * W + x) into its components.
    point_coords[:, :, 0] = (point_indices % W).to(torch.long)  # x
    point_coords[:, :, 1] = (point_indices // W).to(torch.long)  # y
    return point_indices, point_coords
def get_uncertain_point_coords_on_grid2D_faster(uncertainty_map, num_points,
                                                clip_min):
    """
    Find `num_points` most uncertain points from `uncertainty_map` grid,
    considering only points whose uncertainty is at least `clip_min`.

    Args:
        uncertainty_map (Tensor): A tensor of shape (1, 1, H, W) that contains uncertainty
            values for a set of points on a regular H x W grid. Only batch size 1
            is supported.
        num_points (int): The number of points P to select; clamped to the number of
            points that pass the `clip_min` threshold.
        clip_min (float): Points with uncertainty below this value are discarded
            before the top-k selection.

    Returns:
        point_indices (Tensor): A tensor of shape (1, P) that contains indices from
            [0, H x W) of the most uncertain points.
        point_coords (Tensor): A long tensor of shape (1, P, 2) that contains the
            (x, y) grid coordinates (in pixels, not normalized) of those points.
    """
    R, _, H, W = uncertainty_map.shape
    assert R == 1, "batchsize > 1 is not implemented!"

    # reshape (rather than view) so that non-contiguous maps are accepted too.
    uncertainty_map = uncertainty_map.reshape(H * W)

    # Pre-filter by threshold so top-k runs on the surviving candidates only.
    indices = (uncertainty_map >= clip_min).nonzero().squeeze(1)
    num_points = min(num_points, indices.size(0))
    _, top_positions = torch.topk(uncertainty_map[indices],
                                  k=num_points,
                                  dim=0)
    # Map positions within the filtered subset back to flat grid indices.
    point_indices = indices[top_positions].unsqueeze(0)

    point_coords = torch.zeros(R,
                               num_points,
                               2,
                               dtype=torch.long,
                               device=uncertainty_map.device)
    # Decompose the flat index (y * W + x) into its components.
    point_coords[:, :, 0] = (point_indices % W).to(torch.long)  # x
    point_coords[:, :, 1] = (point_indices // W).to(torch.long)  # y
    return point_indices, point_coords
def calculate_uncertainty(logits, classes=None, balance_value=0.5):
    """
    We estimate uncertainty as the negated L1 distance between `balance_value` and the
    prediction in `logits` for the foreground class in `classes`.

    Args:
        logits (Tensor): A tensor of shape (R, C, ...) or (R, 1, ...) for class-specific or
            class-agnostic, where R is the total number of predicted masks in all images and C is
            the number of foreground classes.
        classes (list): A list of length R that contains either predicted or ground truth class
            for each predicted mask. Required when C > 1; unused when C == 1.
        balance_value (float): The prediction value regarded as maximally uncertain.

    Returns:
        scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with
            the most uncertain locations having the highest uncertainty score.

    Raises:
        ValueError: If `logits` has more than one channel and `classes` is None.
    """
    if logits.shape[1] == 1:
        gt_class_logits = logits
    else:
        if classes is None:
            # Indexing with None would silently insert an axis instead of
            # selecting a class, producing a wrongly shaped result.
            raise ValueError(
                "classes must be provided when logits has more than one channel")
        # Pick, for every mask r, the channel classes[r]; restore the channel axis.
        gt_class_logits = logits[
            torch.arange(logits.shape[0], device=logits.device),
            classes].unsqueeze(1)
    return -torch.abs(gt_class_logits - balance_value)