Spaces:
Paused
Paused
| import torch | |
| import torch.nn.functional as F | |
| from collections import defaultdict | |
def make_positions(tensor, padding_idx):
    """Map every non-padding token to its 1-based position, offset by ``padding_idx``.

    Positions are counted along ``dim=1`` and start at ``padding_idx + 1``.
    Entries equal to ``padding_idx`` are not counted and come out as
    ``padding_idx`` themselves. The result has dtype ``torch.long``.
    """
    # Cast order matters: XLA prefers int tensors, cumsum promotes to long
    # by default, and ONNX export cannot handle cumsum's ``dtype`` kwarg,
    # so we cast explicitly around cumsum instead of passing ``dtype``.
    non_pad = tensor.ne(padding_idx).int()
    running_count = torch.cumsum(non_pad, dim=1).type_as(non_pad)
    # Zero out the running count at padding slots so they map to padding_idx.
    masked_count = running_count * non_pad
    return masked_count.long() + padding_idx
def fill_with_neg_inf2(t):
    """Return a tensor like ``t`` filled with a large negative number.

    Despite the name, the fill value is the finite constant ``-1e8``, not
    ``-inf``. The fill is performed in float32 and the result is cast back
    to ``t``'s dtype with ``type_as``, so the stored value depends on that
    dtype:

    * float32 / float64: exactly ``-1e8``.
    * float16: ``-inf``. ``-1e8`` lies outside the float16 range, so the
      cast back overflows to infinity.

    Aliasing: ``t.float()`` returns ``t`` itself when ``t`` is already
    float32. In that case ``t`` is overwritten in place and returned. For
    any other dtype a new tensor is returned and ``t`` is left unchanged.
    """
    # A finite sentinel (rather than -inf) is used deliberately; see docstring.
    fill_value = -1e8
    return t.float().fill_(fill_value).type_as(t)
def softmax(x, dim):
    """Softmax of ``x`` along ``dim``, always computed and returned in float32.

    Passing ``dtype`` to ``F.softmax`` upcasts the input before the
    reduction. Low-precision inputs (e.g. float16) therefore get a float32
    result rather than one in their own dtype.
    """
    compute_dtype = torch.float32
    return F.softmax(x, dim=dim, dtype=compute_dtype)