ho11laqe's picture
init
ecf08bc
raw
history blame
No virus
438 Bytes
from torch import nn, Tensor
class RobustCrossEntropyLoss(nn.CrossEntropyLoss):
    """Cross-entropy loss that tolerates float targets with a channel axis.

    Compatibility layer over ``nn.CrossEntropyLoss``: if ``target`` has the
    same number of dimensions as ``input``, its dimension 1 must have size 1
    and is dropped. The target is always cast to ``long`` before being handed
    to the parent implementation.
    """

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """Compute the cross-entropy loss between ``input`` and ``target``.

        Args:
            input: Predictions, passed unchanged to
                ``nn.CrossEntropyLoss.forward``.
            target: Class indices of any dtype castable to ``long``. May
                carry an extra size-1 dimension at index 1.

        Returns:
            The loss tensor produced by ``nn.CrossEntropyLoss.forward``.

        Raises:
            AssertionError: If ``target`` has as many dimensions as ``input``
                but its dimension 1 is not of size 1.
        """
        if target.ndim == input.ndim:
            # Raise explicitly instead of using ``assert`` so the check is not
            # stripped under ``python -O``; without it a multi-channel target
            # would silently be reduced to its first channel. AssertionError
            # is kept so existing callers observe the same exception type.
            if target.shape[1] != 1:
                raise AssertionError(
                    "target with the same number of dimensions as input must "
                    f"have size 1 in dimension 1, got shape {tuple(target.shape)}"
                )
            target = target[:, 0]
        return super().forward(input, target.long())