File size: 438 Bytes
ecf08bc |
1 2 3 4 5 6 7 8 9 10 11 12 |
from torch import nn, Tensor
class RobustCrossEntropyLoss(nn.CrossEntropyLoss):
    """Cross-entropy loss that accepts float targets with a singleton axis 1.

    Compatibility layer around ``nn.CrossEntropyLoss`` for targets that are
    stored as floats and carry an extra dimension at axis 1 (for example a
    label map shaped ``(B, 1, ...)`` next to logits shaped ``(B, C, ...)``).
    Such a target has axis 1 removed and is cast to ``long`` before being
    handed to the parent, which then treats it as class indices.
    """

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """Compute the cross-entropy loss.

        Args:
            input: Logits, passed unchanged to ``nn.CrossEntropyLoss``.
            target: Class labels. If it has as many dimensions as ``input``,
                axis 1 must have size 1 and is removed. Any dtype is
                accepted; values are converted with ``.long()``.

        Returns:
            The loss as computed by ``nn.CrossEntropyLoss.forward``.

        Raises:
            ValueError: If ``target`` has as many dimensions as ``input`` but
                its size along axis 1 is not 1.
        """
        if target.ndim == input.ndim:
            # Explicit check instead of ``assert``: asserts are removed under
            # ``python -O``, which would silently keep only the first channel
            # of a multi-channel target below.
            if target.shape[1] != 1:
                raise ValueError(
                    "expected target to have size 1 along dim 1 when it has "
                    "the same number of dimensions as input, got target shape "
                    f"{tuple(target.shape)} for input shape {tuple(input.shape)}"
                )
            target = target[:, 0]
        return super().forward(input, target.long())