import torch


def reconstruction(assignment, labels, hard_assignment=None):
    """
    Reconstruct per-pixel labels from superpixel-wise mean labels.

    Args:
        assignment: torch.Tensor
            A Tensor of shape (B, n_spixels, n_pixels)
        labels: torch.Tensor
            A Tensor of shape (B, C, n_pixels)
        hard_assignment: torch.Tensor
            A Tensor of shape (B, n_pixels)
    """
    # (B, C, n_pixels) -> (B, n_pixels, C)
    labels = labels.permute(0, 2, 1).contiguous()
    # matrix product between (n_spixels, n_pixels) and (n_pixels, channels):
    # soft-assignment-weighted mean label of each superpixel
    spixel_mean = torch.bmm(assignment, labels) / (assignment.sum(2, keepdim=True) + 1e-16)
    if hard_assignment is None:
        # (B, n_spixels, n_pixels) -> (B, n_pixels, n_spixels)
        permuted_assignment = assignment.permute(0, 2, 1).contiguous()
        # matrix product between (n_pixels, n_spixels) and (n_spixels, channels)
        reconstructed_labels = torch.bmm(permuted_assignment, spixel_mean)
    else:
        # index sampling: each pixel copies the mean label of its assigned superpixel
        reconstructed_labels = torch.stack(
            [sm[ha, :] for sm, ha in zip(spixel_mean, hard_assignment)], 0)

    # back to (B, C, n_pixels)
    return reconstructed_labels.permute(0, 2, 1).contiguous()


def reconstruct_loss_with_cross_entropy(assignment, labels, hard_assignment=None):
    """
    Reconstruction loss with cross entropy.

    Args:
        assignment: torch.Tensor
            A Tensor of shape (B, n_spixels, n_pixels)
        labels: torch.Tensor
            A Tensor of shape (B, C, n_pixels)
        hard_assignment: torch.Tensor
            A Tensor of shape (B, n_pixels)
    """
    reconstructed_labels = reconstruction(assignment, labels, hard_assignment)
    # normalize to a per-pixel distribution over channels
    reconstructed_labels = reconstructed_labels / (1e-16 + reconstructed_labels.sum(1, keepdim=True))
    # cross entropy against one-hot labels: only the true-class entries contribute
    mask = labels > 0
    return -(reconstructed_labels[mask] + 1e-16).log().mean()


def reconstruct_loss_with_mse(assignment, labels, hard_assignment=None):
    """
    Reconstruction loss with mse.

    Args:
        assignment: torch.Tensor
            A Tensor of shape (B, n_spixels, n_pixels)
        labels: torch.Tensor
            A Tensor of shape (B, C, n_pixels)
        hard_assignment: torch.Tensor
            A Tensor of shape (B, n_pixels)
    """
    reconstructed_labels = reconstruction(assignment, labels, hard_assignment)
    return torch.nn.functional.mse_loss(reconstructed_labels, labels)
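

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module):
# the batch size, channel count, superpixel/pixel counts, and the random
# soft assignment below are assumed values chosen purely for demonstration.
if __name__ == "__main__":
    B, C, n_spixels, n_pixels = 2, 3, 16, 64

    # Soft assignment of each pixel to each superpixel, normalized over
    # superpixels so each pixel's column sums to one.
    assignment = torch.softmax(torch.randn(B, n_spixels, n_pixels), dim=1)

    # One-hot labels laid out as (B, C, n_pixels), as the functions above expect.
    hard_labels = torch.randint(0, C, (B, n_pixels))
    labels = torch.nn.functional.one_hot(hard_labels, C).permute(0, 2, 1).float()

    # Soft reconstruction (hard_assignment=None) with both losses.
    ce = reconstruct_loss_with_cross_entropy(assignment, labels)
    mse = reconstruct_loss_with_mse(assignment, labels)

    # Hard reconstruction: each pixel takes the mean label of its argmax superpixel.
    hard_assignment = assignment.argmax(1)
    ce_hard = reconstruct_loss_with_cross_entropy(assignment, labels, hard_assignment)

    print(ce.item(), mse.item(), ce_hard.item())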