import numpy as np
import torch

from extra_utils import res_to_list, res_to_seq


class AbEncoding:
    """
    Encoding utilities for AbLang: sequence- and residue-level
    representations, plus per-residue likelihoods and probabilities.

    Note: the methods below read `self.used_device`, which is not set in
    `__init__`; the class using these utilities is expected to set it (and
    to call `_initiate_abencoding`) before any encoding method is invoked.
    """

    def __init__(self, device='cpu', ncpu=1):
        self.device = device
        self.ncpu = ncpu

    def _initiate_abencoding(self, model, tokenizer):
        # Attach the loaded AbLang model and its tokenizer.
        self.AbLang = model
        self.tokenizer = tokenizer
        
    def _encode_sequences(self, seqs):
        # Tokenize and return the residue-level hidden states from AbRep.
        tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)
        with torch.no_grad():
            return self.AbLang.AbRep(tokens).last_hidden_states

    def _predict_logits(self, seqs):
        # Tokenize and return the full model's per-residue logits.
        tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)
        with torch.no_grad():
            return self.AbLang(tokens)
        
    def _predict_logits_with_step_masking(self, seqs):
        """
        Predict logits one position at a time: each position is masked before
        being predicted, so its logits are not conditioned on its own token.
        """
        tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)

        logits = []
        for single_seq_tokens in tokens:

            # Build tkn_len copies of the sequence, masking position i in copy i.
            tkn_len = len(single_seq_tokens)
            masked_tokens = single_seq_tokens.repeat(tkn_len, 1)
            for num in range(tkn_len):
                masked_tokens[num, num] = self.tokenizer.mask_token

            with torch.no_grad():
                logits_tmp = self.AbLang(masked_tokens)

            # From copy i, keep only the logits at the masked position i.
            logits_tmp = torch.stack([logits_tmp[num, num] for num in range(tkn_len)])

            logits.append(logits_tmp)

        return torch.stack(logits, dim=0)
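
    # For intuition (illustrative sketch, M = mask token): the loop above
    # turns tokens [t0, t1, t2] into
    #   [[M,  t1, t2],
    #    [t0, M,  t2],
    #    [t0, t1, M ]]
    # and the diagonal read-out keeps, for each row i, only the logits
    # predicted at the masked position i.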

    def seqcoding(self, seqs, **kwargs):
        """
        Sequence-specific representations: a single vector per sequence,
        derived from its residue encodings.
        """
        encodings = self._encode_sequences(seqs).cpu().numpy()

        # Append each sequence's length to every hidden dimension, so that
        # res_to_seq can ignore padding when collapsing the residue states
        # into one value per dimension.
        lens = np.vectorize(len)(seqs)
        lens = np.tile(lens.reshape(-1, 1, 1), (encodings.shape[2], 1))

        return np.apply_along_axis(res_to_seq, 2, np.c_[np.swapaxes(encodings, 1, 2), lens])
        
    def rescoding(self, seqs, align=False, **kwargs):
        """
        Residue-specific representations: one vector per residue.
        """
        encodings = self._encode_sequences(seqs).cpu().numpy()

        if align:
            return encodings

        return [res_to_list(state, seq) for state, seq in zip(encodings, seqs)]
        
    def likelihood(self, seqs, align=False, stepwise_masking=False, **kwargs):
        """
        Likelihood of mutations: the model's raw per-residue logits.
        """
        if stepwise_masking:
            logits = self._predict_logits_with_step_masking(seqs).cpu().numpy()
        else:
            logits = self._predict_logits(seqs).cpu().numpy()

        if align:
            return logits

        return [res_to_list(state, seq) for state, seq in zip(logits, seqs)]
        
    def probability(self, seqs, align=False, stepwise_masking=False, **kwargs):
        """
        Probability of mutations: per-residue logits mapped to probabilities
        with a softmax over the vocabulary.
        """
        if stepwise_masking:
            logits = self._predict_logits_with_step_masking(seqs)
        else:
            logits = self._predict_logits(seqs)
        probs = logits.softmax(-1).cpu().numpy()

        if align:
            return probs

        return [res_to_list(state, seq) for state, seq in zip(probs, seqs)]
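

if __name__ == "__main__":
    # Minimal smoke test of the diagonal-masking trick used in
    # _predict_logits_with_step_masking. It runs without any AbLang model;
    # the mask id 23 is an arbitrary stand-in for tokenizer.mask_token.
    tokens = torch.tensor([5, 8, 2, 9])
    masked = tokens.repeat(len(tokens), 1)
    for num in range(len(tokens)):
        masked[num, num] = 23
    print(masked)  # each row masks one position along the diagonal

    # A full run needs a loaded model and tokenizer; sketch only, with
    # hypothetical `my_model` / `my_tokenizer` placeholders:
    #
    #   encoder = AbEncoding(device="cpu")
    #   encoder.used_device = encoder.device
    #   encoder._initiate_abencoding(my_model, my_tokenizer)
    #   probs = encoder.probability(["EVQLVESGGGLVQPGGSLRLSCAAS"])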