sunshineatnoon
Add application file
1b2a9b1
raw
history blame
No virus
2.31 kB
import torch
def reconstruction(assignment, labels, hard_assignment=None):
    """
    Reconstruct per-pixel labels from superpixel-averaged labels.

    Labels are first pooled into one mean vector per superpixel (weighted by
    ``assignment``), then mapped back onto the pixels either softly (weighted
    by ``assignment`` again) or by picking each pixel's hard-assigned
    superpixel.

    Args:
        assignment: torch.Tensor
            A Tensor of shape (B, n_spixels, n_pixels)
        labels: torch.Tensor
            A Tensor of shape (B, C, n_pixels)
        hard_assignment: torch.Tensor
            An integer Tensor of shape (B, n_pixels) holding a superpixel
            index per pixel; when None the soft assignment is used instead

    Returns:
        torch.Tensor
            A contiguous Tensor of shape (B, C, n_pixels)
    """
    # (B, C, n_pixels) -> (B, n_pixels, C)
    pixel_major_labels = labels.permute(0, 2, 1).contiguous()

    # matrix product between (n_spixels, n_pixels) and (n_pixels, channels),
    # divided by each superpixel's total assignment weight; the epsilon keeps
    # the division finite for superpixels that received no weight
    spixel_mean = torch.bmm(assignment, pixel_major_labels) / (
        assignment.sum(2, keepdim=True) + 1e-16
    )

    if hard_assignment is None:
        # (B, n_spixels, n_pixels) -> (B, n_pixels, n_spixels)
        permuted_assignment = assignment.permute(0, 2, 1).contiguous()
        # matrix product between (n_pixels, n_spixels) and (n_spixels, channels)
        reconstructed_labels = torch.bmm(permuted_assignment, spixel_mean)
    else:
        # Batched index sampling: out[b, p, c] = spixel_mean[b, ha[b, p], c].
        # A single gather replaces the per-sample Python loop + stack, which
        # also makes an empty batch work. gather requires an int64 index.
        n_channels = spixel_mean.size(-1)
        gather_index = hard_assignment.long().unsqueeze(-1).expand(-1, -1, n_channels)
        reconstructed_labels = spixel_mean.gather(1, gather_index)

    # (B, n_pixels, C) -> (B, C, n_pixels)
    return reconstructed_labels.permute(0, 2, 1).contiguous()
def reconstruct_loss_with_cross_etnropy(assignment, labels, hard_assignment=None):
    """
    Cross-entropy style reconstruction loss.

    The labels are reconstructed through the superpixel assignment,
    normalized so that they sum to one along dim 1, and the mean negative log
    of the reconstructed values is taken over the positions where
    ``labels > 0``.

    Args:
        assignment: torch.Tensor
            A Tensor of shape (B, n_spixels, n_pixels)
        labels: torch.Tensor
            A Tensor of shape (B, C, n_pixels)
        hard_assignment: torch.Tensor
            A Tensor of shape (B, n_pixels)

    Returns:
        torch.Tensor
            The mean negative log-value over the selected positions
    """
    recon = reconstruction(assignment, labels, hard_assignment)

    # normalize along dim 1; the epsilon avoids dividing by an all-zero column
    channel_total = recon.sum(1, keepdim=True)
    recon_probs = recon / (1e-16 + channel_total)

    # keep only the entries whose label is positive
    selected = recon_probs[labels > 0]

    # epsilon keeps log() finite when a selected value is exactly zero
    return -torch.log(selected + 1e-16).mean()
def reconstruct_loss_with_mse(assignment, labels, hard_assignment=None):
    """
    Mean-squared-error reconstruction loss.

    The labels are reconstructed through the superpixel assignment and
    compared against the original labels with ``mse_loss``.

    Args:
        assignment: torch.Tensor
            A Tensor of shape (B, n_spixels, n_pixels)
        labels: torch.Tensor
            A Tensor of shape (B, C, n_pixels)
        hard_assignment: torch.Tensor
            A Tensor of shape (B, n_pixels)

    Returns:
        torch.Tensor
            The MSE between the reconstructed labels and ``labels``
    """
    recon = reconstruction(assignment, labels, hard_assignment)
    return torch.nn.functional.mse_loss(input=recon, target=labels)