import numpy as np
import torch

from extra_utils import res_to_list, res_to_seq
class AbScores:
    """
    Sequence scoring (pseudo log-likelihood and confidence) for an antibody
    language model and its matching tokenizer.
    """

    def __init__(self, device='cpu', ncpu=1):
        self.used_device = device  # device used for tokenization and model forward passes
        self.ncpu = ncpu

    def _initiate_abencoding(self, model, tokenizer):
        self.AbLang = model
        self.tokenizer = tokenizer

    def _encode_sequences(self, seqs):
        tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)
        with torch.no_grad():
            # Residue-level encodings from the representation module
            return self.AbLang.AbRep(tokens).last_hidden_states.cpu().numpy()

    def _predict_logits(self, seqs):
        tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)
        with torch.no_grad():
            return self.AbLang(tokens), tokens

    def pseudo_log_likelihood(self, seqs, **kwargs):
        """
        Pseudo log-likelihood of sequences.

        Each residue position is masked in turn, the model predicts the masked
        position, and the per-position negative log-likelihoods are averaged,
        so every sequence costs one forward pass over all of its masked copies.
        """
        plls = []
        for seq in seqs:
            labels = self.tokenizer(
                seq, pad=True, w_extra_tkns=False, device=self.used_device
            )
            # Indices of actual residues, i.e. positions that are not special tokens
            idxs = (
                ~torch.isin(labels, torch.tensor(self.tokenizer.all_special_tokens, device=self.used_device))
            ).nonzero()

            # One copy of the sequence per residue, with that residue masked out
            masked_tokens = labels.repeat(len(idxs), 1)
            for num, idx in enumerate(idxs):
                masked_tokens[num, idx[1]] = self.tokenizer.mask_token

            with torch.no_grad():
                logits = self.AbLang(masked_tokens)

            # Exclude special tokens from the prediction and keep only the logits
            # at the masked position of each copy
            logits[:, :, self.tokenizer.all_special_tokens] = -float("inf")
            logits = torch.stack([logits[num, idx[1]] for num, idx in enumerate(idxs)])

            labels = labels[:, idxs[:, 1:]].squeeze(2)[0]

            nll = torch.nn.functional.cross_entropy(
                logits,
                labels,
                reduction="mean",
            )
            pll = -nll
            plls.append(pll)

        plls = torch.stack(plls, dim=0).cpu().numpy()
        return plls

    def confidence(self, seqs, **kwargs):
        """
        Log-likelihood of sequences without masking.

        A single unmasked forward pass per batch; for each sequence, the
        negative log-likelihoods of the observed residues are averaged.
        """
        labels = self.tokenizer(
            seqs, pad=True, w_extra_tkns=False, device=self.used_device
        )
        with torch.no_grad():
            logits = self.AbLang(labels)
        # Exclude special tokens from the prediction
        logits[:, :, self.tokenizer.all_special_tokens] = -float("inf")

        plls = []
        for label, logit in zip(labels, logits):
            # Indices of actual residues, i.e. positions that are not special tokens
            idxs = (
                ~torch.isin(label, torch.tensor(self.tokenizer.all_special_tokens, device=self.used_device))
            ).nonzero().squeeze(1)

            nll = torch.nn.functional.cross_entropy(
                logit[idxs],
                label[idxs],
                reduction="mean",
            )
            pll = -nll
            plls.append(pll)

        return torch.stack(plls, dim=0).cpu().numpy()
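

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the package API).
# AbScores only relies on the interface exercised above: a tokenizer that maps
# sequences to a LongTensor of token IDs and exposes `all_special_tokens` and
# `mask_token` (as IDs), and a model that returns per-position logits, with an
# `AbRep` sub-module providing `last_hidden_states`. The toy tokenizer and
# model below are stand-ins for real pretrained weights, just to show the call
# pattern; swap in the actual model and tokenizer in practice.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    from types import SimpleNamespace

    AA = "ACDEFGHIKLMNPQRSTVWY"
    PAD, MASK = 20, 21          # toy special-token IDs
    VOCAB = len(AA) + 2

    class ToyTokenizer:
        all_special_tokens = [PAD, MASK]
        mask_token = MASK

        def __call__(self, seqs, pad=True, w_extra_tkns=False, device="cpu"):
            seqs = [seqs] if isinstance(seqs, str) else seqs
            max_len = max(len(s) for s in seqs)
            ids = [[AA.index(aa) for aa in s] + [PAD] * (max_len - len(s)) for s in seqs]
            return torch.tensor(ids, dtype=torch.long, device=device)

    class ToyModel(torch.nn.Module):
        def __init__(self, dim=16):
            super().__init__()
            self.emb = torch.nn.Embedding(VOCAB, dim)
            self.head = torch.nn.Linear(dim, VOCAB)
            # Mimic the `.AbRep(tokens).last_hidden_states` access used above
            self.AbRep = lambda tokens: SimpleNamespace(last_hidden_states=self.emb(tokens))

        def forward(self, tokens):
            return self.head(self.emb(tokens))

    scorer = AbScores(device="cpu")
    scorer._initiate_abencoding(ToyModel(), ToyTokenizer())

    seqs = ["EVQLVESGGGLVQPGG", "QVQLQQSGAELARPGA"]   # toy heavy-chain fragments
    print("pseudo log-likelihood:", scorer.pseudo_log_likelihood(seqs))
    print("confidence:", scorer.confidence(seqs))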