from torch import nn, Tensor
class RobustCrossEntropyLoss(nn.CrossEntropyLoss):
    """Cross-entropy loss that accepts float targets with an extra axis.

    Compatibility layer over ``nn.CrossEntropyLoss`` for callers whose target
    tensor is floating point and carries an additional dimension at index 1.
    When the target has as many dimensions as the input, that axis must have
    size 1; it is dropped, and the target is cast to ``long`` before being
    passed to the parent loss.
    """

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """Compute the cross-entropy loss between ``input`` and ``target``.

        Args:
            input: Prediction tensor, forwarded unchanged to
                ``nn.CrossEntropyLoss.forward``.
            target: Target tensor, either with one dimension fewer than
                ``input`` or with the same number of dimensions and size 1
                along axis 1. Values are cast with ``.long()``; they are
                presumably class indices stored as floats (confirm against
                the caller).

        Returns:
            Whatever ``nn.CrossEntropyLoss.forward`` returns for the
            converted target.

        Raises:
            AssertionError: If ``target`` has the same number of dimensions
                as ``input`` but axis 1 is not of size 1.
        """
        if target.ndim == input.ndim:
            # Raised explicitly rather than via ``assert`` so the check still
            # runs under ``python -O``; otherwise a multi-channel target
            # would silently be reduced to its first channel.
            if target.shape[1] != 1:
                raise AssertionError(
                    "target with the same number of dimensions as input must "
                    f"have size 1 along axis 1, got shape {tuple(target.shape)}"
                )
            target = target[:, 0]
        return super().forward(input, target.long())