# NOTE: page-scrape residue, not part of the module. The hosting page's
# banner read: "Spaces: Runtime error".
import torch | |
def reconstruction(assignment, labels, hard_assignment=None):
    """Rebuild per-pixel labels from superpixel-averaged labels.

    The mean label of every superpixel is the assignment-weighted average of
    the pixel labels.  Each pixel is then reconstructed from those means,
    either as a soft mixture weighted by ``assignment`` or, when
    ``hard_assignment`` is given, by looking up the mean of the single
    superpixel that ``hard_assignment`` names for that pixel.

    Args:
        assignment: torch.Tensor
            A Tensor of shape (B, n_spixels, n_pixels)
        labels: torch.Tensor
            A Tensor of shape (B, C, n_pixels)
        hard_assignment: torch.Tensor
            A Tensor of shape (B, n_pixels), used as row indices into the
            superpixel means.  When None, the soft assignment is used.

    Returns:
        torch.Tensor of shape (B, C, n_pixels) holding the reconstructed
        labels.
    """
    # (B, C, n_pixels) -> (B, n_pixels, C) so that pixels are the contracted axis.
    pixel_labels = labels.permute(0, 2, 1).contiguous()

    # (B, n_spixels, n_pixels) @ (B, n_pixels, C) -> (B, n_spixels, C);
    # the small constant keeps superpixels with no assigned mass from
    # dividing by zero.
    weight_per_spixel = assignment.sum(2, keepdim=True) + 1e-16
    spixel_mean = torch.bmm(assignment, pixel_labels) / weight_per_spixel

    if hard_assignment is None:
        # (B, n_pixels, n_spixels) @ (B, n_spixels, C) -> (B, n_pixels, C)
        pixel_to_spixel = assignment.permute(0, 2, 1).contiguous()
        reconstructed_labels = torch.bmm(pixel_to_spixel, spixel_mean)
    else:
        # Batched lookup: pair every pixel's superpixel index with its own
        # batch index so one indexing op replaces a per-sample loop + stack.
        batch_index = torch.arange(
            spixel_mean.shape[0], device=hard_assignment.device
        ).unsqueeze(1)
        reconstructed_labels = spixel_mean[batch_index, hard_assignment]

    # Back to the caller's channel-first layout: (B, C, n_pixels).
    return reconstructed_labels.permute(0, 2, 1).contiguous()
def reconstruct_loss_with_cross_etnropy(assignment, labels, hard_assignment=None):
    """Cross-entropy style reconstruction loss.

    The labels are reconstructed through the superpixel assignment, the
    result is normalised so that it sums to one over the channel axis, and
    the negative mean log of the normalised values is returned, taken only
    at positions where ``labels`` is strictly positive.

    Args:
        assignment: torch.Tensor
            A Tensor of shape (B, n_spixels, n_pixels)
        labels: torch.Tensor
            A Tensor of shape (B, C, n_pixels)
            (presumably one-hot / non-negative class scores -- confirm
            against the caller)
        hard_assignment: torch.Tensor
            A Tensor of shape (B, n_pixels)

    Returns:
        A scalar torch.Tensor with the loss value.
    """
    rebuilt = reconstruction(assignment, labels, hard_assignment)

    # Normalise across channels; the constant guards against an all-zero column.
    channel_total = rebuilt.sum(1, keepdim=True)
    normalised = rebuilt / (1e-16 + channel_total)

    # Only positions carrying a positive label contribute to the loss.
    selected = normalised[labels > 0]

    # The constant inside the log avoids log(0).
    return -torch.log(selected + 1e-16).mean()
def reconstruct_loss_with_mse(assignment, labels, hard_assignment=None):
    """Mean-squared-error reconstruction loss.

    The labels are reconstructed through the superpixel assignment and
    compared with the original ``labels`` using
    ``torch.nn.functional.mse_loss`` with its default settings.

    Args:
        assignment: torch.Tensor
            A Tensor of shape (B, n_spixels, n_pixels)
        labels: torch.Tensor
            A Tensor of shape (B, C, n_pixels)
        hard_assignment: torch.Tensor
            A Tensor of shape (B, n_pixels)

    Returns:
        The torch.Tensor produced by ``mse_loss``.
    """
    mse = torch.nn.functional.mse_loss
    rebuilt = reconstruction(assignment, labels, hard_assignment)
    return mse(rebuilt, labels)