import torch
import torch.nn.functional as F
import math
from einops import rearrange


# mask helpers

def lengths_to_mask(lengths, max_len):
    # Returns a bool mask of shape (b, max_len): True for valid positions, False for padding.
    # max_len = max(lengths)
    mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1)
    return mask  # (b, len)


def get_pad_mask_idx(seq, pad_idx):
    # Returns a (b, 1, s) bool mask: True where the token is not the padding index.
    return (seq != pad_idx).unsqueeze(1)


def get_subsequent_mask(seq):
    # Given seq: (b, s), returns a (1, s, s) lower-triangular bool mask for causal attention.
    # Example output for s = 3:
    # [[[ True, False, False],
    #   [ True,  True, False],
    #   [ True,  True,  True]]]
    sz_b, seq_len = seq.shape
    subsequent_mask = (1 - torch.triu(
        torch.ones((1, seq_len, seq_len)), diagonal=1)).bool()
    return subsequent_mask.to(seq.device)


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def eval_decorator(fn):
    # Runs fn with the model in eval mode, then restores the previous training mode.
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner


def l2norm(t):
    return F.normalize(t, dim=-1)


# tensor helpers

def get_mask_subset_prob(mask, prob):
    # Keeps each True position of `mask` with probability `prob`; False positions stay False.
    subset_mask = torch.bernoulli(mask, p=prob) & mask
    return subset_mask


def get_mask_special_tokens(ids, special_ids):
    # Bool mask marking every position in `ids` that holds one of the special token ids.
    mask = torch.zeros_like(ids).bool()
    for special_id in special_ids:
        mask |= (ids == special_id)
    return mask


# network builder helpers

def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu
    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))


# classifier-free guidance functions

def uniform(shape, device=None):
    return torch.zeros(shape, device=device).float().uniform_(0, 1)


def prob_mask_like(shape, prob, device=None):
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    else:
        return uniform(shape, device=device) < prob


# sampling helpers

def log(t, eps=1e-20):
    return torch.log(t.clamp(min=eps))


def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))


def gumbel_sample(t, temperature=1., dim=1):
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)


def top_k(logits, thres=0.9, dim=1):
    # Keeps only the top ceil((1 - thres) * n) logits along `dim`; the rest are set to -inf.
    # Example input (thres=0.9 keeps the top-1 per row):
    # [[ 0.3596, 0.0862, 0.9771, -1.0000, -1.0000, -1.0000],
    #  [ 0.4141, 0.1781, 0.6628,  0.5721, -1.0000, -1.0000],
    #  [ 0.9428, 0.3586, 0.1659,  0.8172,  0.9273, -1.0000]]
    # Example output:
    # [[  -inf,   -inf, 0.9771,   -inf,   -inf,   -inf],
    #  [  -inf,   -inf, 0.6628,   -inf,   -inf,   -inf],
    #  [0.9428,   -inf,   -inf,   -inf,   -inf,   -inf]]
    k = math.ceil((1 - thres) * logits.shape[dim])
    val, ind = logits.topk(k, dim=dim)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(dim, ind, val)
    return probs


# noise schedules

def cosine_schedule(t):
    # Maps t in [0, 1] to cos(t * pi / 2): close to 1 for small t, close to 0 for large t.
    return torch.cos(t * math.pi * 0.5)


def scale_cosine_schedule(t, scale):
    return torch.clip(scale * torch.cos(t * math.pi * 0.5) + 1 - scale, min=0., max=1.)
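
# Usage sketch (not part of the original utilities): how the sampling helpers above are
# typically combined, i.e. filter logits with top_k() and then draw ids with gumbel_sample().
# Assumes logits of shape (b, n_class, s) with the class dimension at dim=1, matching the
# dim=1 defaults above; the function name `sample_from_logits` is illustrative only.
def sample_from_logits(logits, temperature=1., topk_thres=0.9):
    filtered = top_k(logits, thres=topk_thres, dim=1)  # keep only the strongest classes, others -> -inf
    return gumbel_sample(filtered, temperature=temperature, dim=1)  # (b, s) sampled token ids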
# Schedule whose samples concentrate toward small values (inverse of the cosine schedule above).
def q_schedule(bs, low, high, device):
    noise = uniform((bs,), device=device)
    schedule = 1 - cosine_schedule(noise)
    return torch.round(schedule * (high - low - 1)).long() + low


def cal_performance(pred, labels, ignore_index=None, smoothing=0., tk=1):
    # pred: (b, n_class, s) logits, labels: (b, s).
    # Returns the loss, the top-1 prediction ids, and top-k accuracy over non-ignored positions.
    loss = cal_loss(pred, labels, ignore_index, smoothing=smoothing)

    pred_id_k = torch.topk(pred, k=tk, dim=1).indices  # (b, tk, s)
    pred_id = pred_id_k[:, 0]
    mask = labels.ne(ignore_index)
    n_correct = (pred_id_k == labels.unsqueeze(1)).any(dim=1).masked_select(mask)
    acc = torch.mean(n_correct.float()).item()

    return loss, pred_id, acc


def cal_loss(pred, labels, ignore_index=None, smoothing=0.):
    '''Calculate cross entropy loss, apply label smoothing if needed.'''
    # pred: (b, n_class, s), labels: (b, s)
    if smoothing:
        space = 2  # labels may contain a few extra ids (e.g. pad/end tokens) beyond n_class
        n_class = pred.size(1)
        mask = labels.ne(ignore_index)
        one_hot = rearrange(F.one_hot(labels, n_class + space), 'a ... b -> a b ...')[:, :n_class]
        sm_one_hot = one_hot * (1 - smoothing) + (1 - one_hot) * smoothing / (n_class - 1)
        neg_log_prb = -F.log_softmax(pred, dim=1)
        loss = (sm_one_hot * neg_log_prb).sum(dim=1)
        loss = torch.mean(loss.masked_select(mask))
    else:
        loss = F.cross_entropy(pred, labels, ignore_index=ignore_index)
    return loss
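
# Usage sketch (not part of the original utilities): exercising cal_performance() on dummy data.
# The shapes and the pad id are assumptions for illustration: logits (b, n_class, s), integer
# labels (b, s), with padded positions set to `pad_id` and passed as ignore_index.
if __name__ == "__main__":
    b, n_class, s, pad_id = 4, 10, 6, 0
    pred = torch.randn(b, n_class, s)
    labels = torch.randint(1, n_class, (b, s))
    labels[:, -2:] = pad_id  # pretend the last two positions are padding
    loss, pred_id, acc = cal_performance(pred, labels, ignore_index=pad_id, smoothing=0.1, tk=3)
    print(loss.item(), pred_id.shape, acc)  # scalar loss, (b, s) predictions, top-3 accuracy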