# This code is licensed under a non-commercial license. | |
import torch | |
from torch.autograd import Variable | |
def to_var(x, requires_grad=False, volatile=False):
    """
    Move a tensor to the GPU when one is available and set its grad flag.

    This is the modern equivalent of the legacy
    ``Variable(x, requires_grad=..., volatile=...)`` wrapper. The result is a
    detached tensor that shares storage with the (possibly GPU-moved) input
    and is a fresh autograd leaf.

    Args:
        x: Input tensor.
        requires_grad: Whether autograd should record operations on the
            returned tensor.
        volatile: Accepted only for backward compatibility and ignored.
            PyTorch removed ``volatile`` in 0.4; wrap inference code in
            ``torch.no_grad()`` instead.

    Returns:
        The tensor on CUDA if available (otherwise on its original device),
        detached from any existing graph, with ``requires_grad`` set as
        requested.
    """
    if torch.cuda.is_available():
        x = x.cuda()
    # detach() reproduces Variable(x): a new leaf sharing x's storage.
    # requires_grad_ then sets the flag in place on that leaf.
    return x.detach().requires_grad_(requires_grad)
def top_k_logits(logits, k, probs=False):
    """
    Keep the k largest entries along the last dimension and mask the rest.

    Masked entries are set to -1e10, a large negative stand-in for
    -infinity, so that exp(-1e10) ~ 0 and they contribute nothing to a
    softmax denominator. With ``probs=True`` masked entries are set to 0.0
    instead, for inputs that are already probabilities.

    Entries tied with the k-th largest value are all kept, so more than k
    entries per row can survive.

    Args:
        logits: Tensor of scores; top-k is taken over the last dimension.
            Any rank is accepted (e.g. ``(vocab,)`` or ``(batch, vocab)``).
        k: Number of entries to keep per row. 0 disables masking; a value
            larger than the last dimension keeps every entry.
        probs: If True, fill masked entries with 0.0 instead of -1e10.

    Returns:
        ``logits`` itself when ``k == 0``; otherwise a new tensor of the same
        shape with all non-top-k entries replaced by the fill value.
    """
    if k == 0:
        return logits
    # torch.topk raises when k exceeds the dimension size. "Keep the top k of
    # fewer than k entries" means keep them all.
    k = min(k, logits.size(-1))
    values = torch.topk(logits, k)[0]
    # Smallest kept value per row, shaped (..., 1) so it broadcasts over the
    # last dimension for inputs of any rank.
    batch_mins = values[..., -1:]
    fill = 0.0 if probs else -1e10
    return torch.where(logits < batch_mins, torch.full_like(logits, fill), logits)