Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,843 Bytes
812b01c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
import torch
import torch.nn as nn
class TaikoEnergyLoss(nn.Module):
    """Masked MSE loss between predicted and target don / ka / drumroll energies.

    Time steps at or beyond ``batch["lengths"][b]`` are treated as padding and
    excluded from the loss.
    """

    def __init__(self, reduction="mean"):
        """
        Args:
            reduction (str): ``"mean"`` averages the squared error over every
                valid (unpadded) element; ``"sum"`` sums it; any other value
                (e.g. ``"none"``) returns the masked element-wise loss with
                the same shape as the prediction, zero at padded positions.
        """
        super().__init__()
        # Use 'none' reduction to get element-wise losses, then manually apply
        # masking and reduction.
        self.mse_loss = nn.MSELoss(reduction="none")
        self.reduction = reduction

    def forward(self, outputs, batch):
        """
        Calculates the masked MSE loss for energy-based predictions.

        Args:
            outputs (dict): Model output containing the 'presence' tensor.
                outputs['presence'] shape: (B, T, 3), channels ordered as
                don, ka, drumroll energies.
            batch (dict): Batch data from collate_fn, containing true labels
                and lengths.
                batch['don_labels'], batch['ka_labels'],
                batch['drumroll_labels'] shape: (B, T).
                batch['lengths'] shape: (B,) - valid sequence lengths along
                the time dimension T (a tensor on any device, or a sequence
                of ints).

        Returns:
            torch.Tensor: A scalar for reduction "mean" / "sum"; otherwise the
            (B, T, 3) element-wise loss with padded positions set to zero.
        """
        pred_energies = outputs["presence"]  # (B, T, 3)
        device = pred_energies.device

        # Stack labels in the same channel order as the predictions
        # (0 = don, 1 = ka, 2 = drumroll). Cast to the prediction dtype so
        # integer or double-precision labels work with MSELoss.
        true_energies = torch.stack(
            [batch["don_labels"], batch["ka_labels"], batch["drumroll_labels"]],
            dim=2,
        ).to(pred_energies.dtype)

        T = pred_energies.shape[1]
        # lengths often stays on the CPU after collation; move it next to the
        # predictions so the comparison below cannot hit a device mismatch.
        lengths = torch.as_tensor(batch["lengths"], device=device)
        # mask_2d[b, t] is True for valid (unpadded) time steps; shape (B, T).
        mask_2d = torch.arange(T, device=device).unsqueeze(0) < lengths.unsqueeze(1)
        # (B, T, 1) broadcasts across the energy channels.
        mask_3d = mask_2d.unsqueeze(2)

        loss_elementwise = self.mse_loss(pred_energies, true_energies)  # (B, T, 3)
        # torch.where rather than multiplying by the mask: NaN/Inf values in
        # padded predictions would otherwise survive (NaN * 0 == NaN) and
        # poison the reduced loss.
        masked_loss = torch.where(
            mask_3d, loss_elementwise, loss_elementwise.new_zeros(())
        )

        if self.reduction == "mean":
            # Number of valid float values = valid time steps * channels.
            # Summing mask_3d alone only counts time steps, which would
            # inflate the mean by the channel count.
            num_valid_elements = mask_2d.sum() * loss_elementwise.shape[2]
            # With no valid elements the numerator is exactly 0, so clamping
            # the denominator yields 0 without a 0/0. The result stays
            # attached to the autograd graph, and no host-side sync is needed.
            return masked_loss.sum() / num_valid_elements.clamp(min=1)
        if self.reduction == "sum":
            return masked_loss.sum()
        # 'none' or any other value: return the element-wise masked loss.
        return masked_loss
|