import math

import torch
import torch.nn.functional as F
from einops import rearrange


def lengths_to_mask(lengths, max_len):
    # (b,) sequence lengths -> (b, max_len) bool mask, True at valid (non-padded) positions.
    mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1)
    return mask


def get_pad_mask_idx(seq, pad_idx):
    # (b, t) token ids -> (b, 1, t) bool mask, True where the token is not the padding index.
    return (seq != pad_idx).unsqueeze(1)


def get_subsequent_mask(seq):
    # (b, t) -> (1, t, t) lower-triangular bool mask for causal (autoregressive) attention.
    sz_b, seq_len = seq.shape
    subsequent_mask = (1 - torch.triu(
        torch.ones((1, seq_len, seq_len)), diagonal=1)).bool()
    return subsequent_mask.to(seq.device)
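
# Illustrative usage (a sketch, not part of the original file; the shapes below are assumptions):
#   tokens = torch.zeros(2, 5).long()                  # (b, t) dummy token ids
#   causal = get_subsequent_mask(tokens)               # (1, 5, 5), True on and below the diagonal
#   valid = lengths_to_mask(torch.tensor([3, 5]), 5)   # (2, 5), True at real (non-padded) positions
# The two masks are typically combined by broadcasting inside an attention layer.

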
def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def eval_decorator(fn):
    # Runs fn with the model in eval mode, then restores the previous training mode.
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner


def l2norm(t):
    return F.normalize(t, dim=-1)


def get_mask_subset_prob(mask, prob):
    # Keeps each True position of `mask` with probability `prob`, i.e. samples a random subset.
    subset_mask = torch.bernoulli(mask, p=prob) & mask
    return subset_mask


def get_mask_special_tokens(ids, special_ids):
    # (b, t) token ids -> bool mask, True wherever the id is one of `special_ids`.
    mask = torch.zeros_like(ids).bool()
    for special_id in special_ids:
        mask |= (ids == special_id)
    return mask
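
# Illustrative usage (a sketch, not from the original file): pick a random ~30% subset of
# ordinary token positions while never touching special tokens. The ids, the special id
# values and the 0.3 rate below are assumptions made up for the example.
#   ids = torch.randint(0, 10, (2, 6))
#   ordinary = ~get_mask_special_tokens(ids, special_ids=[8, 9])   # True at non-special tokens
#   subset = get_mask_subset_prob(ordinary, 0.3)                   # ~30% of those positions

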
def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu
    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))


def uniform(shape, device=None):
    # Uniform samples in [0, 1) with the given shape.
    return torch.zeros(shape, device=device).float().uniform_(0, 1)


def prob_mask_like(shape, prob, device=None):
    # Bool mask of the given shape where each element is True with probability `prob`.
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    else:
        return uniform(shape, device=device) < prob
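
# Illustrative usage (a sketch, not from the original file): prob_mask_like gives a quick
# per-sample Bernoulli mask, e.g. deciding which batch elements to perturb. The (4,) shape
# and 0.1 rate are assumptions for the example.
#   drop = prob_mask_like((4,), 0.1, device=torch.device('cpu'))   # (4,) bool, ~10% True

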
def log(t, eps=1e-20):
    # Numerically safe log.
    return torch.log(t.clamp(min=eps))


def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))


def gumbel_sample(t, temperature=1., dim=1):
    # Samples indices from logits `t` via the Gumbel-max trick; temperature rescales the logits.
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)


def top_k(logits, thres=0.9, dim=1):
    # Keeps the top ceil((1 - thres) * num_classes) logits along `dim`
    # and fills the rest with -inf, for top-k filtered sampling.
    k = math.ceil((1 - thres) * logits.shape[dim])
    val, ind = logits.topk(k, dim=dim)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(dim, ind, val)
    return probs
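
# Illustrative usage (a sketch, not from the original file): filter logits with top_k, then
# sample with the Gumbel-max trick. The (2, 512) shape and temperature are assumptions.
#   logits = torch.randn(2, 512)                           # (batch, num_classes)
#   filtered = top_k(logits, thres=0.9, dim=1)             # keep roughly the top 10% of classes
#   ids = gumbel_sample(filtered, temperature=1., dim=1)   # (2,) sampled class indices

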
def cosine_schedule(t):
    # Maps t in [0, 1] to cos(t * pi / 2), decaying from 1 to 0.
    return torch.cos(t * math.pi * 0.5)


def scale_cosine_schedule(t, scale):
    # Cosine schedule rescaled to decay from 1 to (1 - scale), clipped to [0, 1].
    return torch.clip(scale * torch.cos(t * math.pi * 0.5) + 1 - scale, min=0., max=1.)


def q_schedule(bs, low, high, device):
    # Samples a batch of integer levels in [low, high - 1], biased toward `low` by the cosine schedule.
    noise = uniform((bs,), device=device)
    schedule = 1 - cosine_schedule(noise)
    return torch.round(schedule * (high - low - 1)).long() + low
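
# Illustrative usage (a sketch, not from the original file): draw one integer level per batch
# element. The batch size and range below are assumptions for the example.
#   q = q_schedule(bs=8, low=1, high=6, device=torch.device('cpu'))   # (8,) ints in [1, 5]

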
def cal_performance(pred, labels, ignore_index=None, smoothing=0., tk=1):
    # `pred` holds class scores along dim 1 (e.g. (n, n_class)); `labels` holds target ids.
    # Returns the (optionally label-smoothed) loss, the top-1 prediction, and top-`tk`
    # accuracy computed over positions whose label differs from `ignore_index`.
    loss = cal_loss(pred, labels, ignore_index, smoothing=smoothing)

    pred_id_k = torch.topk(pred, k=tk, dim=1).indices
    pred_id = pred_id_k[:, 0]
    mask = labels.ne(ignore_index)
    n_correct = (pred_id_k == labels.unsqueeze(1)).any(dim=1).masked_select(mask)
    acc = torch.mean(n_correct.float()).item()

    return loss, pred_id, acc


def cal_loss(pred, labels, ignore_index=None, smoothing=0.):
    '''Calculate cross entropy loss, apply label smoothing if needed.

    `pred` holds class scores along dim 1; an integer `ignore_index` is expected, and
    positions whose label equals it are excluded from the loss.
    '''
    if smoothing:
        space = 2
        n_class = pred.size(1)
        mask = labels.ne(ignore_index)
        # One-hot with `space` extra classes so that label ids just beyond the class range
        # (e.g. special/pad tokens) do not raise; the extra columns are dropped immediately.
        one_hot = rearrange(F.one_hot(labels, n_class + space), 'a ... b -> a b ...')[:, :n_class]

        sm_one_hot = one_hot * (1 - smoothing) + (1 - one_hot) * smoothing / (n_class - 1)
        neg_log_prb = -F.log_softmax(pred, dim=1)
        loss = (sm_one_hot * neg_log_prb).sum(dim=1)

        loss = torch.mean(loss.masked_select(mask))
    else:
        loss = F.cross_entropy(pred, labels, ignore_index=ignore_index)

    return loss
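

if __name__ == "__main__":
    # Minimal smoke test (illustrative only, not part of the original file; the shapes,
    # class count and ignore index below are assumptions).
    logits = torch.randn(8, 32)              # (n, n_class) scores
    labels = torch.randint(0, 30, (8,))      # targets well inside the class range
    labels[0] = 31                           # pretend position 0 is padding

    loss, pred_id, acc = cal_performance(logits, labels, ignore_index=31, smoothing=0.1, tk=3)
    print(loss.item(), pred_id.shape, acc)   # smoothed loss, (8,) argmax ids, top-3 accuracy

    plain = cal_loss(logits, labels, ignore_index=31)
    print(plain.item())                      # plain cross entropy for comparison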