| import torch | |
def distribute_into_histogram(loss: torch.Tensor,
                              t: torch.Tensor,
                              num_bins: int = 25) -> tuple[torch.Tensor, torch.Tensor]:
    """Bucket per-element ``loss`` values by ``t`` into ``num_bins`` equal-width bins.

    Each element of ``t`` is mapped to bin ``floor(t * num_bins)``, so ``t`` is
    expected to lie in ``[0, 1]``. Values outside that range (including exactly
    ``1.0``) are clamped into the first/last bin instead of indexing out of range.

    Args:
        loss: Values to accumulate. Detached and flattened before use.
        t: Bin coordinate for each element of ``loss``; must have the same
            number of elements. Detached, flattened, and moved to ``loss.device``.
        num_bins: Number of bins (must be >= 1).

    Returns:
        ``(hist, count)``, both 1-D tensors of length ``num_bins`` on
        ``loss.device``. ``hist[i]`` is the sum of ``loss`` over elements in
        bin ``i`` and ``count[i]`` is how many elements fell in bin ``i``
        (a mean per bin is ``hist / count``). The dtype is ``loss.dtype``
        promoted with the default floating dtype (float32 for float32 input).

    Raises:
        ValueError: If ``num_bins < 1`` or ``t`` and ``loss`` differ in
            number of elements.
    """
    if num_bins < 1:
        raise ValueError(f"num_bins must be >= 1, got {num_bins}")

    loss = loss.detach().flatten()
    t = t.detach().flatten().to(loss.device)
    if t.numel() != loss.numel():
        # scatter_add_ would otherwise silently use only a prefix of `loss`
        # (or fail with an opaque shape error).
        raise ValueError(
            f"t and loss must have the same number of elements, "
            f"got {t.numel()} and {loss.numel()}"
        )

    # scatter_add_ requires src and self to share a dtype; accumulate in a
    # floating dtype at least as wide as the default so float64/half inputs work.
    acc_dtype = torch.promote_types(loss.dtype, torch.get_default_dtype())

    # t == 1.0 maps to index num_bins (one past the end) and t < 0 to negative
    # indices; clamp so boundary values land in the edge bins.
    idx = (t * num_bins).long().clamp_(0, num_bins - 1)

    hist = torch.zeros(num_bins, dtype=acc_dtype, device=loss.device)
    count = torch.zeros(num_bins, dtype=acc_dtype, device=loss.device)
    hist.scatter_add_(0, idx, loss.to(acc_dtype))
    count.scatter_add_(0, idx, torch.ones_like(loss, dtype=acc_dtype))
    return hist, count