| import torch | |
| from torch import nn | |
| class ScaledSoftmaxCE(nn.Module): | |
| def forward(self, x, label): | |
| logits = x[..., :-10] | |
| temp_scales = x[..., -10:] | |
| logprobs = logits.softmax(-1) | |
| import torch | |
| from torch import nn | |
| class ScaledSoftmaxCE(nn.Module): | |
| def forward(self, x, label): | |
| logits = x[..., :-10] | |
| temp_scales = x[..., -10:] | |
| logprobs = logits.softmax(-1) | |