Spaces:
Runtime error
Runtime error
File size: 2,308 Bytes
1b2a9b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import torch
def reconstruction(assignment, labels, hard_assignment=None):
    """
    Rebuild per-pixel labels from superpixel-averaged labels.

    Each superpixel's label is the assignment-weighted mean of the labels of
    the pixels it covers; those means are then mapped back onto the pixels.

    Args:
        assignment: torch.Tensor
            A Tensor of shape (B, n_spixels, n_pixels)
        labels: torch.Tensor
            A Tensor of shape (B, C, n_pixels)
        hard_assignment: torch.Tensor, optional
            A Tensor of shape (B, n_pixels) holding, for every pixel, the row
            index of its superpixel. When given, each pixel copies that
            superpixel's mean directly instead of the soft, assignment-weighted
            mix of all superpixel means.

    Returns:
        torch.Tensor: contiguous Tensor of shape (B, C, n_pixels).
    """
    # (B, C, n_pixels) -> (B, n_pixels, C)
    pixel_labels = labels.transpose(1, 2).contiguous()
    # Total assignment weight of every superpixel, shape (B, n_spixels, 1);
    # the epsilon keeps superpixels with zero weight from dividing by zero.
    weight = assignment.sum(2, keepdim=True) + 1e-16
    # (B, n_spixels, n_pixels) x (B, n_pixels, C) -> (B, n_spixels, C)
    spixel_mean = torch.bmm(assignment, pixel_labels) / weight

    if hard_assignment is not None:
        # Look up each pixel's superpixel mean, one batch element at a time.
        per_sample = [
            means[index, :] for means, index in zip(spixel_mean, hard_assignment)
        ]
        pixel_recon = torch.stack(per_sample, 0)
    else:
        # (B, n_pixels, n_spixels) x (B, n_spixels, C) -> (B, n_pixels, C)
        pixel_to_spixel = assignment.transpose(1, 2).contiguous()
        pixel_recon = torch.bmm(pixel_to_spixel, spixel_mean)

    # Back to the channel-first layout of `labels`: (B, C, n_pixels).
    return pixel_recon.transpose(1, 2).contiguous()
def reconstruct_loss_with_cross_etnropy(assignment, labels, hard_assignment=None):
    """
    Reconstruction loss with cross entropy.

    The labels are reconstructed through the superpixels (via
    ``reconstruction``), normalized over the channel dimension, and the loss
    is the mean negative log of the reconstructed value at every position
    where ``labels`` is positive. The label values themselves are only used
    as a mask, not as weights.

    (The misspelling "etnropy" in the function name is kept so existing
    callers keep working.)

    Args:
        assignment: torch.Tensor
            A Tensor of shape (B, n_spixels, n_pixels)
        labels: torch.Tensor
            A Tensor of shape (B, C, n_pixels)
        hard_assignment: torch.Tensor, optional
            A Tensor of shape (B, n_pixels)

    Returns:
        torch.Tensor: scalar loss. If ``labels`` has no positive entry, a zero
        scalar connected to the reconstruction is returned instead of the NaN
        that the mean of an empty selection would produce.
    """
    reconstructed = reconstruction(assignment, labels, hard_assignment)
    # Normalize over channels (dim 1) so each pixel's values sum to about 1
    # (pixels whose values are all zero stay zero thanks to the epsilon).
    reconstructed = reconstructed / (1e-16 + reconstructed.sum(1, keepdim=True))
    mask = labels > 0
    selected = reconstructed[mask]
    if selected.numel() == 0:
        # mean() over an empty tensor is NaN; return a graph-connected zero so
        # callers that sum or back-propagate this loss stay finite.
        return reconstructed.sum() * 0.0
    # The epsilon keeps log() finite where the reconstruction is exactly zero.
    return -(selected + 1e-16).log().mean()
def reconstruct_loss_with_mse(assignment, labels, hard_assignment=None):
    """
    Reconstruction loss with mean squared error.

    Args:
        assignment: torch.Tensor
            A Tensor of shape (B, n_spixels, n_pixels)
        labels: torch.Tensor
            A Tensor of shape (B, C, n_pixels)
        hard_assignment: torch.Tensor, optional
            A Tensor of shape (B, n_pixels)

    Returns:
        torch.Tensor: mean squared error (``mse_loss`` default ``'mean'``
        reduction) between the labels reconstructed through the superpixels
        (via ``reconstruction``) and ``labels``.
    """
    reconstructed = reconstruction(assignment, labels, hard_assignment)
    return torch.nn.functional.mse_loss(reconstructed, labels)