PascalNotin commited on
Commit
1335bda
1 Parent(s): 6590011

Implemented first version of design app

Browse files
README.md CHANGED
@@ -1,7 +1,7 @@
1
  ---
2
  title: Tranception Design
3
  emoji: 🐨
4
- colorFrom: green
5
  colorTo: gray
6
  sdk: gradio
7
  sdk_version: 3.1.7
 
1
  ---
2
  title: Tranception Design
3
  emoji: 🐨
4
+ colorFrom: blue
5
  colorTo: gray
6
  sdk: gradio
7
  sdk_version: 3.1.7
app.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import transformers
3
+ from transformers import PreTrainedTokenizerFast
4
+ import tranception
5
+ import datasets
6
+ from tranception import config, model_pytorch
7
+ import pandas as pd
8
+ import matplotlib.pyplot as plt
9
+ import seaborn as sns
10
+ import gradio as gr
11
+
12
+ tokenizer = PreTrainedTokenizerFast(tokenizer_file="./tranception/utils/tokenizers/Basic_tokenizer",
13
+ unk_token="[UNK]",
14
+ sep_token="[SEP]",
15
+ pad_token="[PAD]",
16
+ cls_token="[CLS]",
17
+ mask_token="[MASK]"
18
+ )
19
+ #######################################################################################################################################
20
+ ############################################### HELPER FUNCTIONS ####################################################################
21
+ #######################################################################################################################################
22
+
23
+ AA_vocab = "ACDEFGHIKLMNPQRSTVWY"
24
+ def create_all_single_mutants(sequence,AA_vocab=AA_vocab,mutation_range_start=None,mutation_range_end=None):
25
+ all_single_mutants={}
26
+ sequence_list=list(sequence)
27
+ if mutation_range_start is None: mutation_range_start=1
28
+ if mutation_range_end is None: mutation_range_end=len(sequence)
29
+ for position,current_AA in enumerate(sequence[mutation_range_start-1:mutation_range_end]):
30
+ for mutated_AA in AA_vocab:
31
+ if current_AA!=mutated_AA:
32
+ mutated_sequence = sequence_list.copy()
33
+ mutated_sequence[position] = mutated_AA
34
+ all_single_mutants[current_AA+str(position+1)+mutated_AA]="".join(mutated_sequence)
35
+ all_single_mutants = pd.DataFrame.from_dict(all_single_mutants,columns=['mutated_sequence'],orient='index')
36
+ all_single_mutants.reset_index(inplace=True)
37
+ all_single_mutants.columns = ['mutant','mutated_sequence']
38
+ return all_single_mutants
39
+
40
+ def create_scoring_matrix_visual(scores,sequence,AA_vocab=AA_vocab,mutation_range_start=None,mutation_range_end=None):
41
+ piv=scores.pivot(index='position',columns='target_AA',values='avg_score').transpose().round(4)
42
+ fig, ax = plt.subplots(figsize=(len(sequence)*1.2,20))
43
+ scores_dict = {}
44
+ valid_mutant_set=set(scores.mutant)
45
+ if mutation_range_start is None: mutation_range_start=1
46
+ if mutation_range_end is None: mutation_range_start=len(sequence)
47
+ for target_AA in list(AA_vocab):
48
+ for position in range(mutation_range_start,mutation_range_end+1):
49
+ mutant = sequence[position-1]+str(position)+target_AA
50
+ if mutant in valid_mutant_set:
51
+ scores_dict[mutant]= float(scores.loc[scores.mutant==mutant,'avg_score'])
52
+ else:
53
+ scores_dict[mutant]=0.0
54
+ labels = (np.asarray(["{} \n {:.4f}".format(symb,value) for symb, value in scores_dict.items() ])).reshape(len(AA_vocab),mutation_range_end-mutation_range_start+1)
55
+ heat = sns.heatmap(piv,annot=labels,fmt="",cmap='RdYlGn',linewidths=0.30,vmin=np.percentile(scores.avg_score,2),vmax=np.percentile(scores.avg_score,98),\
56
+ cbar_kws={'label': 'Log likelihood ratio (mutant / starting sequence)'})
57
+ heat.figure.axes[-1].yaxis.label.set_size(20)
58
+ #heat.set_title("Fitness scores for all single amino acid substitutions",fontsize=30)
59
+ heat.set_title("Higher predicted scores (green) imply higher protein fitness",fontsize=30, pad=40)
60
+ heat.set_xlabel("Sequence position", fontsize = 20)
61
+ heat.set_ylabel("Amino Acid mutation", fontsize = 20)
62
+ plt.savefig('fitness_scoring_substitution_matrix.png')
63
+ return plt
64
+
65
+ def suggest_mutations(scores):
66
+ intro_message = "The following mutations may be sensible options to improve fitness: \n\n"
67
+ #Best mutants
68
+ top_mutants=list(scores.sort_values(by=['avg_score'],ascending=False).head(5).mutant)
69
+ mutant_recos = "The 5 single mutants with highest predicted fitness are:\n {} \n\n".format(", ".join(top_mutants))
70
+ #Best positions
71
+ positive_scores = scores[scores.avg_score > 0]
72
+ positive_scores_position_avg = positive_scores.groupby(['position']).mean()
73
+ top_positions=list(positive_scores_position_avg.sort_values(by=['avg_score'],ascending=False).head(5).index.astype(str))
74
+ print(top_positions)
75
+ position_recos = "The 5 positions with the highest average fitness increase are:\n {}".format(", ".join(top_positions))
76
+ return intro_message+mutant_recos+position_recos
77
+
78
+ def get_mutated_protein(sequence,mutant):
79
+ mutated_sequence = list(sequence)
80
+ mutated_sequence[int(mutant[1:-1])-1]=mutant[-1]
81
+ return ''.join(mutated_sequence)
82
+
83
+ def score_and_create_matrix_all_singles(sequence,mutation_range_start=None,mutation_range_end=None,model_type="Small",scoring_mirror=False,batch_size_inference=20,num_workers=0,AA_vocab=AA_vocab):
84
+ if model_type=="Small":
85
+ model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path="PascalNotin/Tranception_Small",use_auth_token=True)
86
+ elif model_type=="Medium":
87
+ model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path="PascalNotin/Tranception_Medium",use_auth_token=True)
88
+ elif model_type=="Large":
89
+ model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path="PascalNotin/Tranception_Large",use_auth_token=True)
90
+ model.config.tokenizer = tokenizer
91
+ all_single_mutants = create_all_single_mutants(sequence,AA_vocab,mutation_range_start,mutation_range_end)
92
+ scores = model.score_mutants(DMS_data=all_single_mutants,
93
+ target_seq=sequence,
94
+ scoring_mirror=scoring_mirror,
95
+ batch_size_inference=batch_size_inference,
96
+ num_workers=num_workers,
97
+ indel_mode=False
98
+ )
99
+ scores = pd.merge(scores,all_single_mutants,on="mutated_sequence",how="left")
100
+ scores["position"]=scores["mutant"].map(lambda x: int(x[1:-1]))
101
+ scores["target_AA"] = scores["mutant"].map(lambda x: x[-1])
102
+ score_heatmap = create_scoring_matrix_visual(scores,sequence,AA_vocab,mutation_range_start,mutation_range_end)
103
+ return score_heatmap,suggest_mutations(scores)
104
+
105
+ #######################################################################################################################################
106
+ ############################################### GRADIO INTERFACE ####################################################################
107
+ #######################################################################################################################################
108
+
109
+ title = "Interactive in silico directed evolution with Tranception"
110
+ description = "Perform in silico directed evolution with Tranception to iteratively improve the fitness of a starting protein sequence one mutation at a time. At each step, the Tranception model computes the log likelihood ratios of all possible single amino acid substitution Vs the starting sequence, and outputs a fitness heatmap and recommandations to guide the selection of the mutation to apply. Note: The current version does not currently leverage homologs retrieval at inference time to boost fitness prediction performance."
111
+ article = "<p style='text-align: center'><a href='https://proceedings.mlr.press/v162/notin22a.html' target='_blank'>Tranception: Protein Fitness Prediction with Autoregressive Transformers and Inference-time Retrieval</a></p>"
112
+ examples=[
113
+ ['A4_HUMAN: MLPGLALLLLAAWTARALEVPTDGNAGLLAEPQIAMFCGRLNMHMNVQNGKWDSDPSGTKTCIDTKEGILQYCQEVYPELQITNVVEANQPVTIQNWCKRGRKQCKTHPHFVIPYRCLVGEFVSDALLVPDKCKFLHQERMDVCETHLHWHTVAKETCSEKSTNLHDYGMLLPCGIDKFRGVEFVCCPLAEESDNVDSADAEEDDSDVWWGGADTDYADGSEDKVVEVAEEEEVAEVEEEEADDDEDDEDGDEVEEEAEEPYEEATERTTSIATTTTTTTESVEEVVREVCSEQAETGPCRAMISRWYFDVTEGKCAPFFYGGCGGNRNNFDTEEYCMAVCGSAMSQSLLKTTQEPLARDPVKLPTTAASTPDAVDKYLETPGDENEHAHFQKAKERLEAKHRERMSQVMREWEEAERQAKNLPKADKKAVIQHFQEKVESLEQEAANERQQLVETHMARVEAMLNDRRRLALENYITALQAVPPRPRHVFNMLKKYVRAEQKDRQHTLKHFEHVRMVDPKKAAQIRSQVMTHLRVIYERMNQSLSLLYNVPAVAEEIQDEVDELLQKEQNYSDDVLANMISEPRISYGNDALMPSLTETKTTVELLPVNGEFSLDDLQPWHSFGADSVPANTENEVEPVDARPAADRGLTTRPGSGLTNIKTEEISEVKMDAEFRHDSGYEVHHQKLVFFAEDVGSNKGAIIGLMVGGVVIATVIVITLVMLKKKQYTSIHHGVVEVDAAVTPEERHLSKMQQNGYENPTYKFFEQMQN'],
114
+ ['ADRB2_HUMAN: MGQPGNGSAFLLAPNGSHAPDHDVTQERDEVWVVGMGIVMSLIVLAIVFGNVLVITAIAKFERLQTVTNYFITSLACADLVMGLAVVPFGAAHILMKMWTFGNFWCEFWTSIDVLCVTASIETLCVIAVDRYFAITSPFKYQSLLTKNKARVIILMVWIVSGLTSFLPIQMHWYRATHQEAINCYANETCCDFFTNQAYAIASSIVSFYVPLVIMVFVYSRVFQEAKRQLQKIDKSEGRFHVQNLSQVEQDGRTGHGLRRSSKFCLKEHKALKTLGIIMGTFTLCWLPFFIVNIVHVIQDNLIRKEVYILLNWIGYVNSGFNPLIYCRSPDFRIAFQELLCLRRSSLKAYGNGYSSNGNTGEQSGYHVEQEKENKLLCEDLPGTEDFVGHQGTVPSDNIDSQGRNCSTNDSLL'],
115
+ ['AMIE_PSEAE: MRHGDISSSNDTVGVAVVNYKMPRLHTAAEVLDNARKIAEMIVGMKQGLPGMDLVVFPEYSLQGIMYDPAEMMETAVAIPGEETEIFSRACRKANVWGVFSLTGERHEEHPRKAPYNTLVLIDNNGEIVQKYRKIIPWCPIEGWYPGGQTYVSEGPKGMKISLIICDDGNYPEIWRDCAMKGAELIVRCQGYMYPAKDQQVMMAKAMAWANNCYVAVANAAGFDGVYSYFGHSAIIGFDGRTLGECGEEEMGIQYAQLSLSQIRDARANDQSQNHLFKILHRGYSGLQASGDGDRGLAECPFEFYRTWVTDAEKARENVERLTRSTTGVAQCPVGRLPYEGLEKEA'],
116
+ ['P53_HUMAN: MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGPDEAPRMPEAAPRVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYQGSYGFRLGFLHSGTAKSVTCTYSPALNKMFCQLAKTCPVQLWVDSTPPPGTRVRAMAIYKQSQHMTEVVRRCPHHERCSDSDGLAPPQHLIRVEGNLRVEYLDDRNTFRHSVVVPYEPPEVGSDCTTIHYNYMCNSSCMGGMNRRPILTIITLEDSSGNLLGRNSFEVRVCACPGRDRRTEEENLRKKGEPHHELPPGSTKRALPNNTSSSPQPKKKPLDGEYFTLQIRGRERFEMFRELNEALELKDAQAGKEPGGSRAHSSHLKSKKGQSTSRHKKLMFKTEGPDSD']
117
+ ]
118
+
119
+ model_size_selection = gr.Radio(label="Tranception model size", choices=["Small","Medium","Large"], value="Small")
120
+ protein_sequence_input = gr.Textbox(lines=1, label="Input protein sequence (see below for examples; default = RL40A_YEAST)",value="MQIFVKTLTGKTITLEVESSDTIDNVKSKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGGIIEPSLKALASKYNCDKSVCRKCYARLPPRATNCRKRKCGHTNQLRPKKKLK")
121
+ mutation_range_start = gr.Number(label="Start of mutation range (min value = 1)",value=1,precision=0)
122
+ mutation_range_end = gr.Number(label="End of mutation range (leave empty for full lenth)",value=10,precision=0)
123
+ scoring_mirror = gr.Checkbox(label="Score protein from both directions (leads to more robust fitness predictions, but doubles inference time)")
124
+
125
+ #output ==> find a way to make scroallable
126
+ output_plot = gr.Plot(label="Fitness scores for all single amino acid substitutions in mutation range")
127
+ output_recommendations = gr.Textbox(label="Mutation recommendations")
128
+
129
+ gr.Interface(
130
+ fn=score_and_create_matrix_all_singles,
131
+ inputs=[protein_sequence_input,mutation_range_start,mutation_range_end,model_size_selection,scoring_mirror],
132
+ outputs=["plot","text"],
133
+ title=title,
134
+ description=description,
135
+ article=article,
136
+ examples=examples,
137
+ enable_queue=True,
138
+ allow_flagging="never"
139
+ ).launch(debug=True)
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio
2
+ transformers==4.17
3
+ datasets==1.18.3
4
+ biopython==1.78
tranception/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from . import config
tranception/activations.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ from packaging import version
5
+ from torch import nn
6
+
7
+ from transformers.utils import logging
8
+
9
+
10
+ logger = logging.get_logger(__name__)
11
+
12
+
13
+ def _gelu_python(x):
14
+ """
15
+ Original Implementation of the GELU activation function in Google BERT repo when initially created. For
16
+ information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
17
+ torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
18
+ Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
19
+ """
20
+ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
21
+
22
+
23
+ def gelu_new(x):
24
+ """
25
+ Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
26
+ the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
27
+ """
28
+ return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
29
+
30
+
31
+ if version.parse(torch.__version__) < version.parse("1.4"):
32
+ gelu = _gelu_python
33
+ else:
34
+ gelu = nn.functional.gelu
35
+
36
+
37
+ def gelu_fast(x):
38
+ return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
39
+
40
+
41
+ def quick_gelu(x):
42
+ return x * torch.sigmoid(1.702 * x)
43
+
44
+
45
+ def _silu_python(x):
46
+ """
47
+ See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
48
+ Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
49
+ Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
50
+ Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
51
+ later.
52
+ """
53
+ return x * torch.sigmoid(x)
54
+
55
+
56
+ if version.parse(torch.__version__) < version.parse("1.7"):
57
+ silu = _silu_python
58
+ else:
59
+ silu = nn.functional.silu
60
+
61
+
62
+ def _mish_python(x):
63
+ """
64
+ See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
65
+ visit the official repository for the paper: https://github.com/digantamisra98/Mish
66
+ """
67
+ return x * torch.tanh(nn.functional.softplus(x))
68
+
69
+
70
+ if version.parse(torch.__version__) < version.parse("1.9"):
71
+ mish = _mish_python
72
+ else:
73
+ mish = nn.functional.mish
74
+
75
+
76
+ def linear_act(x):
77
+ return x
78
+
79
+ def squared_relu(x):
80
+ """
81
+ Squared ReLU variant that is fastest with Pytorch.
82
+ """
83
+ x = nn.functional.relu(x)
84
+ return x*x
85
+
86
+ def squared_relu_xla(x):
87
+ """
88
+ Squared ReLU variant that is fastest with JAX.
89
+ """
90
+ x = nn.functional.relu(x)
91
+ return x**2
92
+
93
+ tranception_ACT2FN = {
94
+ "relu": nn.functional.relu,
95
+ "silu": silu,
96
+ "swish": silu,
97
+ "gelu": gelu,
98
+ "tanh": torch.tanh,
99
+ "gelu_new": gelu_new,
100
+ "gelu_fast": gelu_fast,
101
+ "quick_gelu": quick_gelu,
102
+ "mish": mish,
103
+ "linear": linear_act,
104
+ "sigmoid": torch.sigmoid,
105
+ "squared_relu": squared_relu,
106
+ "squared_relu_xla": squared_relu_xla,
107
+ }
108
+
109
+
110
+ def get_activation(activation_string):
111
+ if activation_string in tranception_ACT2FN:
112
+ return tranception_ACT2FN[activation_string]
113
+ else:
114
+ raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(tranception_ACT2FN.keys())}")
tranception/config.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import GPT2Config
2
+
3
+ class TranceptionConfig(GPT2Config):
4
+ """
5
+ Config subclass for Tranception model architecture.
6
+ """
7
+ def __init__(
8
+ self,
9
+ attention_mode="tranception",
10
+ position_embedding="grouped_alibi",
11
+ tokenizer=None,
12
+ retrieval_aggregation_mode=None,
13
+ retrieval_inference_weight=0.6,
14
+ MSA_filename=None,
15
+ MSA_weight_file_name=None,
16
+ MSA_start=None,
17
+ MSA_end=None,
18
+ full_protein_length=None,
19
+ clustal_omega_location=None,
20
+ scoring_window=None,
21
+ **kwargs
22
+ ):
23
+ super().__init__(**kwargs)
24
+ self.model_type="tranception"
25
+ self.attention_mode=attention_mode
26
+ self.position_embedding=position_embedding
27
+ self.tokenizer = tokenizer
28
+ self.retrieval_aggregation_mode = retrieval_aggregation_mode
29
+ self.retrieval_inference_weight = retrieval_inference_weight
30
+ self.MSA_filename = MSA_filename
31
+ self.MSA_weight_file_name = MSA_weight_file_name
32
+ self.MSA_start=MSA_start
33
+ self.MSA_end=MSA_end
34
+ self.full_protein_length = full_protein_length
35
+ self.clustal_omega_location = clustal_omega_location
36
+ self.scoring_window=scoring_window
tranception/model_pytorch.py ADDED
@@ -0,0 +1,930 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Tuple
3
+ import math
4
+ import os
5
+ import pandas as pd
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import CrossEntropyLoss, NLLLoss
10
+ import torch.nn.functional as F
11
+ from transformers import GPT2PreTrainedModel
12
+
13
+ from transformers.modeling_utils import (
14
+ Conv1D,
15
+ PreTrainedModel,
16
+ SequenceSummary,
17
+ find_pruneable_heads_and_indices,
18
+ prune_conv1d_layer,
19
+ )
20
+ from transformers.file_utils import (
21
+ ModelOutput,
22
+ add_code_sample_docstrings,
23
+ add_start_docstrings,
24
+ add_start_docstrings_to_model_forward,
25
+ replace_return_docstrings
26
+ )
27
+ from transformers.modeling_outputs import (
28
+ BaseModelOutputWithPastAndCrossAttentions,
29
+ CausalLMOutputWithCrossAttentions,
30
+ SequenceClassifierOutputWithPast,
31
+ TokenClassifierOutput
32
+ )
33
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
34
+
35
+ from tranception.activations import tranception_ACT2FN
36
+ from tranception.config import TranceptionConfig
37
+ from tranception.outputs import (
38
+ TranceptionCausalLMOutputWithCrossAttentions,
39
+ )
40
+ from tranception.utils import msa_utils
41
+ from tranception.utils import scoring_utils
42
+
43
+ def nanmean(v, *args, inplace=False, **kwargs):
44
+ if not inplace:
45
+ v = v.clone()
46
+ is_nan = torch.isnan(v)
47
+ v[is_nan] = 0
48
+ return v.sum(*args, **kwargs) / (~is_nan).float().sum(*args, **kwargs)
49
+
50
+ def get_slopes(n, mode="standard_alibi", verbose=False):
51
+ """
52
+ Function to compute the m constant for each attention head. Code has been adapted from the official ALiBi codebase at:
53
+ https://github.com/ofirpress/attention_with_linear_biases/blob/master/fairseq/models/transformer.py
54
+ """
55
+ def get_slopes_power_of_2(n):
56
+ start = (2**(-2**-(math.log2(n)-3)))
57
+ ratio = start
58
+ return [start*ratio**i for i in range(n)]
59
+ if mode=="grouped_alibi":
60
+ n = n // 4
61
+ if math.log2(n).is_integer():
62
+ result = get_slopes_power_of_2(n)
63
+ else:
64
+ #Workaround when the number of heads is not a power of 2
65
+ closest_power_of_2 = 2**math.floor(math.log2(n))
66
+ result = get_slopes_power_of_2(closest_power_of_2) + get_slopes(2*closest_power_of_2)[0::2][:n-closest_power_of_2]
67
+ if mode=="grouped_alibi":
68
+ result = result * 4
69
+ if verbose:
70
+ print("ALiBi slopes: {}".format(result))
71
+ return result
72
+
73
+ class SpatialDepthWiseConvolution(nn.Module):
74
+ def __init__(self, head_dim: int, kernel_size: int = 3):
75
+ super().__init__()
76
+ self.kernel_size = kernel_size
77
+ self.conv = nn.Conv1d(in_channels=head_dim, out_channels=head_dim, kernel_size=(kernel_size,), padding=(kernel_size - 1,), groups=head_dim)
78
+
79
+ def forward(self, x: torch.Tensor):
80
+ batch_size, heads, seq_len, head_dim = x.shape
81
+ x = x.permute(0, 1, 3, 2).contiguous()
82
+ x = x.view(batch_size * heads, head_dim, seq_len)
83
+ x = self.conv(x)
84
+ if self.kernel_size>1:
85
+ x = x[:, :, :-(self.kernel_size - 1)]
86
+ x = x.view(batch_size, heads, head_dim, seq_len)
87
+ x = x.permute(0, 1, 3, 2)
88
+ return x
89
+
90
+ class TranceptionBlockAttention(nn.Module):
91
+ def __init__(self, config, is_cross_attention=False, SDWC_kernel_size=None):
92
+ super().__init__()
93
+
94
+ max_positions = config.max_position_embeddings
95
+ self.register_buffer(
96
+ "bias",
97
+ torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
98
+ 1, 1, max_positions, max_positions
99
+ ),
100
+ )
101
+ self.register_buffer("masked_bias", torch.tensor(-1e4))
102
+
103
+ self.embed_dim = config.hidden_size
104
+ self.num_heads = config.num_attention_heads
105
+ self.head_dim = self.embed_dim // self.num_heads
106
+ self.split_size = self.embed_dim
107
+ if self.head_dim * self.num_heads != self.embed_dim:
108
+ raise ValueError(
109
+ f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
110
+ )
111
+
112
+ self.scale_attn_weights = config.scale_attn_weights
113
+ self.is_cross_attention = is_cross_attention
114
+
115
+ if self.is_cross_attention:
116
+ self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
117
+ self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
118
+ else:
119
+ self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
120
+ self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
121
+
122
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
123
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
124
+
125
+ self.pruned_heads = set()
126
+
127
+ self.attention_mode=config.attention_mode
128
+
129
+ if self.attention_mode=="tranception":
130
+ assert self.num_heads%4==0, "Invalid number of heads. Tranception requires the number of heads to be a multiple of 4."
131
+ self.num_heads_per_kernel_size = self.num_heads // 4
132
+ self.query_depthwiseconv = nn.ModuleDict()
133
+ self.key_depthwiseconv = nn.ModuleDict()
134
+ self.value_depthwiseconv = nn.ModuleDict()
135
+ for kernel_idx, kernel in enumerate([3,5,7]):
136
+ self.query_depthwiseconv[str(kernel_idx)] = SpatialDepthWiseConvolution(self.head_dim,kernel)
137
+ self.key_depthwiseconv[str(kernel_idx)] = SpatialDepthWiseConvolution(self.head_dim,kernel)
138
+ self.value_depthwiseconv[str(kernel_idx)] = SpatialDepthWiseConvolution(self.head_dim,kernel)
139
+
140
+ def prune_heads(self, heads):
141
+ if len(heads) == 0:
142
+ return
143
+ heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
144
+ index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
145
+
146
+ # Prune conv1d layers
147
+ self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
148
+ self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
149
+
150
+ # Update hyper params
151
+ self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
152
+ self.num_heads = self.num_heads - len(heads)
153
+ self.pruned_heads = self.pruned_heads.union(heads)
154
+
155
+ def _attn(self, query, key, value, attention_mask=None, head_mask=None, alibi_bias=None):
156
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
157
+
158
+ if self.scale_attn_weights:
159
+ attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)
160
+
161
+ if not self.is_cross_attention:
162
+ # if only "normal" attention layer implements causal mask
163
+ query_length, key_length = query.size(-2), key.size(-2)
164
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
165
+ attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
166
+
167
+ if alibi_bias is not None:
168
+ attn_weights = attn_weights + alibi_bias[:,:,:attn_weights.size(-1)]
169
+
170
+ if attention_mask is not None:
171
+ # Apply the attention mask
172
+ attn_weights = attn_weights + attention_mask
173
+
174
+ attn_weights = nn.Softmax(dim=-1)(attn_weights)
175
+ attn_weights = self.attn_dropout(attn_weights)
176
+
177
+ # Mask heads if we want to
178
+ if head_mask is not None:
179
+ attn_weights = attn_weights * head_mask
180
+
181
+ attn_output = torch.matmul(attn_weights, value)
182
+
183
+ return attn_output, attn_weights
184
+
185
+ def _split_heads(self, tensor, num_heads, attn_head_size):
186
+ """
187
+ Splits hidden_size dim into attn_head_size and num_heads
188
+ """
189
+ new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
190
+ tensor = tensor.view(*new_shape)
191
+ return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
192
+
193
+ def _merge_heads(self, tensor, num_heads, attn_head_size):
194
+ """
195
+ Merges attn_head_size dim and num_attn_heads dim into hidden_size
196
+ """
197
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
198
+ new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
199
+ return tensor.view(new_shape)
200
+
201
+ def forward(
202
+ self,
203
+ hidden_states,
204
+ layer_past=None,
205
+ attention_mask=None,
206
+ head_mask=None,
207
+ encoder_hidden_states=None,
208
+ encoder_attention_mask=None,
209
+ use_cache=False,
210
+ output_attentions=False,
211
+ alibi_bias=None,
212
+ ):
213
+ if encoder_hidden_states is not None:
214
+ if not hasattr(self, "q_attn"):
215
+ raise ValueError(
216
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
217
+ "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
218
+ )
219
+
220
+ query = self.q_attn(hidden_states)
221
+ key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
222
+ attention_mask = encoder_attention_mask
223
+ else:
224
+ query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
225
+
226
+ query = self._split_heads(query, self.num_heads, self.head_dim)
227
+ key = self._split_heads(key, self.num_heads, self.head_dim)
228
+ value = self._split_heads(value, self.num_heads, self.head_dim)
229
+
230
+ if layer_past is not None:
231
+ past_key, past_value = layer_past
232
+ key = torch.cat((past_key, key), dim=-2)
233
+ value = torch.cat((past_value, value), dim=-2)
234
+
235
+ if use_cache is True:
236
+ present = (key, value)
237
+ else:
238
+ present = None
239
+
240
+ if self.attention_mode=="tranception":
241
+ # We do not do anything on the first self.num_heads_per_kernel_size heads (kernel =1)
242
+ query_list=[query[:,:self.num_heads_per_kernel_size,:,:]]
243
+ key_list=[key[:,:self.num_heads_per_kernel_size,:,:]]
244
+ value_list=[value[:,:self.num_heads_per_kernel_size,:,:]]
245
+ for kernel_idx in range(3):
246
+ query_list.append(self.query_depthwiseconv[str(kernel_idx)](query[:,(kernel_idx+1)*self.num_heads_per_kernel_size:(kernel_idx+2)*self.num_heads_per_kernel_size,:,:]))
247
+ key_list.append(self.key_depthwiseconv[str(kernel_idx)](key[:,(kernel_idx+1)*self.num_heads_per_kernel_size:(kernel_idx+2)*self.num_heads_per_kernel_size,:,:]))
248
+ value_list.append(self.value_depthwiseconv[str(kernel_idx)](value[:,(kernel_idx+1)*self.num_heads_per_kernel_size:(kernel_idx+2)*self.num_heads_per_kernel_size,:,:]))
249
+ query=torch.cat(query_list, dim=1)
250
+ key=torch.cat(key_list, dim=1)
251
+ value=torch.cat(value_list, dim=1)
252
+
253
+ attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask, alibi_bias=alibi_bias)
254
+
255
+ attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
256
+ attn_output = self.c_proj(attn_output)
257
+ attn_output = self.resid_dropout(attn_output)
258
+
259
+ outputs = (attn_output, present)
260
+ if output_attentions:
261
+ outputs += (attn_weights,)
262
+
263
+ return outputs # a, present, (attentions)
264
+
265
+ class TranceptionBlockMLP(nn.Module):
266
+ def __init__(self, intermediate_size, config):
267
+ super().__init__()
268
+ embed_dim = config.hidden_size
269
+ self.c_fc = Conv1D(intermediate_size, embed_dim)
270
+ self.c_proj = Conv1D(embed_dim, intermediate_size)
271
+ self.act = tranception_ACT2FN[config.activation_function]
272
+ self.dropout = nn.Dropout(config.resid_pdrop)
273
+
274
+ def forward(self, hidden_states):
275
+ hidden_states = self.c_fc(hidden_states)
276
+ hidden_states = self.act(hidden_states)
277
+ hidden_states = self.c_proj(hidden_states)
278
+ hidden_states = self.dropout(hidden_states)
279
+ return hidden_states
280
+
281
+ class TranceptionBlock(nn.Module):
282
+ def __init__(self, config, SDWC_kernel_size=None):
283
+ super().__init__()
284
+ hidden_size = config.hidden_size
285
+ inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
286
+
287
+ self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
288
+ self.attn = TranceptionBlockAttention(config, SDWC_kernel_size=SDWC_kernel_size)
289
+ self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
290
+
291
+ if config.add_cross_attention:
292
+ self.crossattention = TranceptionBlockAttention(config, is_cross_attention=True, SDWC_kernel_size=SDWC_kernel_size)
293
+ self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
294
+
295
+ self.mlp = TranceptionBlockMLP(inner_dim, config)
296
+
297
+ def forward(
298
+ self,
299
+ hidden_states,
300
+ layer_past=None,
301
+ attention_mask=None,
302
+ head_mask=None,
303
+ encoder_hidden_states=None,
304
+ encoder_attention_mask=None,
305
+ use_cache=False,
306
+ output_attentions=False,
307
+ alibi_bias=None,
308
+ ):
309
+ residual = hidden_states
310
+ hidden_states = self.ln_1(hidden_states)
311
+ attn_outputs = self.attn(
312
+ hidden_states,
313
+ layer_past=layer_past,
314
+ attention_mask=attention_mask,
315
+ head_mask=head_mask,
316
+ use_cache=use_cache,
317
+ output_attentions=output_attentions,
318
+ alibi_bias=alibi_bias,
319
+ )
320
+ attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
321
+ outputs = attn_outputs[1:]
322
+ # residual connection
323
+ hidden_states = attn_output + residual
324
+
325
+ if encoder_hidden_states is not None:
326
+ # add one self-attention block for cross-attention
327
+ if not hasattr(self, "crossattention"):
328
+ raise ValueError(
329
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
330
+ "cross-attention layers by setting `config.add_cross_attention=True`"
331
+ )
332
+ residual = hidden_states
333
+ hidden_states = self.ln_cross_attn(hidden_states)
334
+ cross_attn_outputs = self.crossattention(
335
+ hidden_states,
336
+ attention_mask=attention_mask,
337
+ head_mask=head_mask,
338
+ encoder_hidden_states=encoder_hidden_states,
339
+ encoder_attention_mask=encoder_attention_mask,
340
+ output_attentions=output_attentions,
341
+ )
342
+ attn_output = cross_attn_outputs[0]
343
+ # residual connection
344
+ hidden_states = residual + attn_output
345
+ outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
346
+
347
+ residual = hidden_states
348
+ hidden_states = self.ln_2(hidden_states)
349
+
350
+ feed_forward_hidden_states = self.mlp(hidden_states)
351
+
352
+ # residual connection
353
+ hidden_states = residual + feed_forward_hidden_states
354
+
355
+ if use_cache:
356
+ outputs = (hidden_states,) + outputs
357
+ else:
358
+ outputs = (hidden_states,) + outputs[1:]
359
+
360
+ return outputs # hidden_states, present, (attentions, cross_attentions)
361
+
362
+ class TranceptionModel(GPT2PreTrainedModel):
363
+ _keys_to_ignore_on_load_missing = ["attn.masked_bias"]
364
+ def __init__(self, config):
365
+ super().__init__(config)
366
+
367
+ self.embed_dim = config.hidden_size
368
+ self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
369
+ self.position_embedding = config.position_embedding if hasattr(config, "position_embedding") else "learned"
370
+ if self.position_embedding=="learned":
371
+ self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
372
+ self.alibi = None
373
+ elif self.position_embedding=="grouped_alibi":
374
+ maxpos = config.n_positions
375
+ attn_heads = config.n_head
376
+ self.slopes = torch.Tensor(get_slopes(attn_heads, mode=self.position_embedding))
377
+ #The softmax operation is invariant to translation, and bias functions used are always linear.
378
+ alibi = self.slopes.unsqueeze(1).unsqueeze(1) * torch.arange(maxpos).unsqueeze(0).unsqueeze(0).expand(attn_heads, -1, -1)
379
+ alibi = alibi.view(attn_heads, 1, maxpos)
380
+ self.register_buffer('alibi',alibi)
381
+
382
+ self.drop = nn.Dropout(config.embd_pdrop)
383
+ self.h = nn.ModuleList([TranceptionBlock(config) for _ in range(config.num_hidden_layers)])
384
+ self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
385
+
386
+ self.init_weights()
387
+
388
+ # Model parallel
389
+ self.model_parallel = False
390
+ self.device_map = None
391
+ self.gradient_checkpointing = False
392
+
393
+ def parallelize(self, device_map=None, num_cores=None):
394
+ self.device_map = (
395
+ get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
396
+ )
397
+ device_prefix="cuda:"
398
+ assert_device_map(self.device_map, len(self.h))
399
+ self.model_parallel = True
400
+ self.first_device = "cpu" if "cpu" in self.device_map.keys() else device_prefix + str(min(self.device_map.keys()))
401
+ self.last_device = device_prefix + str(max(self.device_map.keys()))
402
+ self.wte = self.wte.to(self.first_device)
403
+ if self.position_embedding=="learned":
404
+ self.wpe = self.wpe.to(self.first_device)
405
+ for k, v in self.device_map.items():
406
+ print("k,v :"+str(k)+","+str(v))
407
+ for block in v:
408
+ cuda_device = device_prefix + str(k)
409
+ self.h[block] = self.h[block].to(cuda_device)
410
+ self.ln_f = self.ln_f.to(self.last_device)
411
+
412
+ def deparallelize(self):
413
+ self.model_parallel = False
414
+ self.device_map = None
415
+ self.first_device = "cpu"
416
+ self.last_device = "cpu"
417
+ self.wte = self.wte.to("cpu")
418
+ if self.position_embedding=="learned":
419
+ self.wpe = self.wpe.to("cpu")
420
+ for index in range(len(self.h)):
421
+ self.h[index] = self.h[index].to("cpu")
422
+ self.ln_f = self.ln_f.to("cpu")
423
+ torch.cuda.empty_cache()
424
+
425
+ def get_input_embeddings(self):
426
+ return self.wte
427
+
428
+ def set_input_embeddings(self, new_embeddings):
429
+ self.wte = new_embeddings
430
+
431
+ def _prune_heads(self, heads_to_prune):
432
+ """
433
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
434
+ """
435
+ for layer, heads in heads_to_prune.items():
436
+ self.h[layer].attn.prune_heads(heads)
437
+
438
+ def forward(
439
+ self,
440
+ input_ids=None,
441
+ past_key_values=None,
442
+ attention_mask=None,
443
+ token_type_ids=None,
444
+ position_ids=None,
445
+ head_mask=None,
446
+ inputs_embeds=None,
447
+ encoder_hidden_states=None,
448
+ encoder_attention_mask=None,
449
+ use_cache=None,
450
+ output_attentions=None,
451
+ output_hidden_states=None,
452
+ return_dict=None,
453
+ ):
454
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
455
+ output_hidden_states = (
456
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
457
+ )
458
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
459
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
460
+
461
+ if input_ids is not None and inputs_embeds is not None:
462
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
463
+ elif input_ids is not None:
464
+ input_shape = input_ids.size()
465
+ input_ids = input_ids.view(-1, input_shape[-1])
466
+ batch_size = input_ids.shape[0]
467
+ elif inputs_embeds is not None:
468
+ input_shape = inputs_embeds.size()[:-1]
469
+ batch_size = inputs_embeds.shape[0]
470
+ else:
471
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
472
+
473
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
474
+
475
+ if token_type_ids is not None:
476
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
477
+ if position_ids is not None:
478
+ position_ids = position_ids.view(-1, input_shape[-1])
479
+
480
+ if past_key_values is None:
481
+ past_length = 0
482
+ past_key_values = tuple([None] * len(self.h))
483
+ else:
484
+ past_length = past_key_values[0][0].size(-2)
485
+ if position_ids is None:
486
+ position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
487
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
488
+
489
+ # GPT2Attention mask.
490
+ if attention_mask is not None:
491
+ if batch_size <= 0:
492
+ raise ValueError("batch_size has to be defined and > 0")
493
+ attention_mask = attention_mask.view(batch_size, -1)
494
+ # We create a 3D attention mask from a 2D tensor mask.
495
+ # Sizes are [batch_size, 1, 1, to_seq_length]
496
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
497
+ # this attention mask is more simple than the triangular masking of causal attention
498
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
499
+ attention_mask = attention_mask[:, None, None, :]
500
+
501
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
502
+ # masked positions, this operation will create a tensor which is 0.0 for
503
+ # positions we want to attend and -10000.0 for masked positions.
504
+ # Since we are adding it to the raw scores before the softmax, this is
505
+ # effectively the same as removing these entirely.
506
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
507
+ attention_mask = (1.0 - attention_mask) * -10000.0
508
+
509
+ # If a 2D ou 3D attention mask is provided for the cross-attention
510
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
511
+ if self.config.add_cross_attention and encoder_hidden_states is not None:
512
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
513
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
514
+ if encoder_attention_mask is None:
515
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
516
+ encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
517
+ else:
518
+ encoder_attention_mask = None
519
+
520
+ # Prepare head mask if needed
521
+ # 1.0 in head_mask indicate we keep the head
522
+ # attention_probs has shape bsz x n_heads x N x N
523
+ # head_mask has shape n_layer x batch x n_heads x N x N
524
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
525
+
526
+ if inputs_embeds is None:
527
+ inputs_embeds = self.wte(input_ids)
528
+ if self.position_embedding=="learned":
529
+ position_embeds = self.wpe(position_ids)
530
+ hidden_states = inputs_embeds + position_embeds
531
+ else:
532
+ hidden_states = inputs_embeds
533
+
534
+ if token_type_ids is not None:
535
+ token_type_embeds = self.wte(token_type_ids)
536
+ hidden_states = hidden_states + token_type_embeds
537
+
538
+ hidden_states = self.drop(hidden_states)
539
+
540
+ output_shape = input_shape + (hidden_states.size(-1),)
541
+
542
+ presents = () if use_cache else None
543
+ all_self_attentions = () if output_attentions else None
544
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
545
+ all_hidden_states = () if output_hidden_states else None
546
+
547
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
548
+ # Model parallel
549
+ if self.model_parallel:
550
+ torch.cuda.set_device(hidden_states.device)
551
+ # Ensure layer_past is on same device as hidden_states (might not be correct)
552
+ if layer_past is not None:
553
+ layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
554
+ # Ensure that attention_mask is always on the same device as hidden_states
555
+ if attention_mask is not None:
556
+ attention_mask = attention_mask.to(hidden_states.device)
557
+ if isinstance(head_mask, torch.Tensor):
558
+ head_mask = head_mask.to(hidden_states.device)
559
+ if output_hidden_states:
560
+ all_hidden_states = all_hidden_states + (hidden_states,)
561
+
562
+ if self.gradient_checkpointing and self.training:
563
+ if use_cache:
564
+ logger.warning(
565
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
566
+ )
567
+ use_cache = False
568
+
569
+ def create_custom_forward(module):
570
+ def custom_forward(*inputs):
571
+ # None for past_key_value
572
+ return module(*inputs, use_cache, output_attentions)
573
+
574
+ return custom_forward
575
+
576
+ outputs = torch.utils.checkpoint.checkpoint(
577
+ create_custom_forward(block),
578
+ hidden_states,
579
+ None,
580
+ attention_mask,
581
+ head_mask[i],
582
+ encoder_hidden_states,
583
+ encoder_attention_mask,
584
+ )
585
+ else:
586
+ outputs = block(
587
+ hidden_states,
588
+ layer_past=layer_past,
589
+ attention_mask=attention_mask,
590
+ head_mask=head_mask[i],
591
+ encoder_hidden_states=encoder_hidden_states,
592
+ encoder_attention_mask=encoder_attention_mask,
593
+ use_cache=use_cache,
594
+ output_attentions=output_attentions,
595
+ alibi_bias=self.alibi if hasattr(self, "alibi") else None
596
+ )
597
+
598
+ hidden_states = outputs[0]
599
+
600
+ if use_cache is True:
601
+ presents = presents + (outputs[1],)
602
+
603
+ if output_attentions:
604
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
605
+ if self.config.add_cross_attention:
606
+ all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
607
+
608
+ if self.model_parallel:
609
+ device_prefix="cuda:"
610
+ for k, v in self.device_map.items():
611
+ if i == v[-1] and device_prefix + str(k) != self.last_device:
612
+ hidden_states = hidden_states.to(device_prefix + str(k + 1))
613
+
614
+ hidden_states = self.ln_f(hidden_states)
615
+
616
+ hidden_states = hidden_states.view(*output_shape)
617
+ # Add last hidden state
618
+ if output_hidden_states:
619
+ all_hidden_states = all_hidden_states + (hidden_states,)
620
+
621
+ if not return_dict:
622
+ return tuple(
623
+ v
624
+ for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions, moe_loss]
625
+ if v is not None
626
+ )
627
+
628
+ return BaseModelOutputWithPastAndCrossAttentions(
629
+ last_hidden_state=hidden_states,
630
+ past_key_values=presents,
631
+ hidden_states=all_hidden_states,
632
+ attentions=all_self_attentions,
633
+ cross_attentions=all_cross_attentions,
634
+ )
635
+
636
+ class TranceptionLMHeadModel(GPT2PreTrainedModel):
637
+ _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]
638
+ def __init__(self, config):
639
+ super().__init__(config)
640
+ self.transformer = TranceptionModel(config)
641
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
642
+ self.config = config
643
+
644
+ self.init_weights()
645
+
646
+ self.default_model_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
647
+ # Model parallel
648
+ self.model_parallel = False
649
+ self.device_map = None
650
+
651
+ self.retrieval_aggregation_mode = config.retrieval_aggregation_mode if hasattr(config, "retrieval_aggregation_mode") else None
652
+ if self.retrieval_aggregation_mode is not None:
653
+ print("Model leverages both autoregressive and retrieval inference")
654
+ self.MSA_filename = config.MSA_filename if hasattr(config, "MSA_filename") else False
655
+ self.MSA_folder = '/'.join(self.MSA_filename.split(os.sep)[:-1])
656
+ self.MSA_name = self.MSA_filename.split(os.sep)[-1]
657
+ self.retrieval_inference_weight_LR = config.retrieval_inference_weight if hasattr(config, "retrieval_inference_weight") else 0.6
658
+ self.retrieval_inference_weight_RL = config.retrieval_inference_weight if hasattr(config, "retrieval_inference_weight") else 0.6
659
+ self.MSA_start=config.MSA_start
660
+ self.MSA_end=config.MSA_end
661
+ self.full_protein_length = config.full_protein_length if hasattr(config, "full_protein_length") else -1
662
+
663
+ self.MSA_log_prior = torch.log(torch.tensor(
664
+ msa_utils.get_msa_prior(
665
+ MSA_data_file=self.MSA_filename,
666
+ MSA_weight_file_name=config.MSA_weight_file_name,
667
+ retrieval_aggregation_mode=self.retrieval_aggregation_mode,
668
+ MSA_start=self.MSA_start,
669
+ MSA_end=self.MSA_end,
670
+ len_target_seq=self.full_protein_length,
671
+ vocab=config.tokenizer.get_vocab(),
672
+ verbose=False
673
+ )
674
+ ).float().to(self.default_model_device))
675
+ else:
676
+ print("Model only uses autoregressive inference")
677
+
678
+ def parallelize(self, device_map=None, num_cores=None, num_pipelines=1):
679
+ self.num_pipelines=num_pipelines
680
+ self.device_map = (
681
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
682
+ if device_map is None
683
+ else device_map
684
+ )
685
+ assert_device_map(self.device_map, len(self.transformer.h))
686
+ self.transformer.parallelize(self.device_map, num_cores=num_cores)
687
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
688
+ self.model_parallel = True
689
+
690
+ def deparallelize(self):
691
+ self.transformer.deparallelize()
692
+ self.transformer = self.transformer.to("cpu")
693
+ self.lm_head = self.lm_head.to("cpu")
694
+ self.model_parallel = False
695
+ torch.cuda.empty_cache()
696
+
697
+ def get_output_embeddings(self):
698
+ return self.lm_head
699
+
700
+ def set_output_embeddings(self, new_embeddings):
701
+ self.lm_head = new_embeddings
702
+
703
+ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
704
+ token_type_ids = kwargs.get("token_type_ids", None)
705
+ # only last token for inputs_ids if past is defined in kwargs
706
+ if past:
707
+ input_ids = input_ids[:, -1].unsqueeze(-1)
708
+ if token_type_ids is not None:
709
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
710
+
711
+ attention_mask = kwargs.get("attention_mask", None)
712
+ position_ids = kwargs.get("position_ids", None)
713
+
714
+ if attention_mask is not None and position_ids is None:
715
+ # create position_ids on the fly for batch generation
716
+ position_ids = attention_mask.long().cumsum(-1) - 1
717
+ position_ids.masked_fill_(attention_mask == 0, 1)
718
+ if past:
719
+ position_ids = position_ids[:, -1].unsqueeze(-1)
720
+ else:
721
+ position_ids = None
722
+
723
+ return {
724
+ "input_ids": input_ids,
725
+ "past_key_values": past,
726
+ "use_cache": kwargs.get("use_cache"),
727
+ "position_ids": position_ids,
728
+ "attention_mask": attention_mask,
729
+ "token_type_ids": token_type_ids,
730
+ "flip": kwargs.get("flip", None),
731
+ }
732
+
733
+ def forward(
734
+ self,
735
+ input_ids=None,
736
+ past_key_values=None,
737
+ attention_mask=None,
738
+ token_type_ids=None,
739
+ position_ids=None,
740
+ head_mask=None,
741
+ inputs_embeds=None,
742
+ encoder_hidden_states=None,
743
+ encoder_attention_mask=None,
744
+ labels=None,
745
+ use_cache=None,
746
+ output_attentions=None,
747
+ output_hidden_states=None,
748
+ return_dict=None,
749
+ flip=None,
750
+ start_slice=None,
751
+ end_slice=None,
752
+ mutated_sequence=None,
753
+ ):
754
+ r"""
755
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
756
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
757
+ ``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to
758
+ ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``
759
+ """
760
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
761
+
762
+ transformer_outputs = self.transformer(
763
+ input_ids,
764
+ past_key_values=past_key_values,
765
+ attention_mask=attention_mask,
766
+ token_type_ids=token_type_ids,
767
+ position_ids=position_ids,
768
+ head_mask=head_mask,
769
+ inputs_embeds=inputs_embeds,
770
+ encoder_hidden_states=encoder_hidden_states,
771
+ encoder_attention_mask=encoder_attention_mask,
772
+ use_cache=use_cache,
773
+ output_attentions=output_attentions,
774
+ output_hidden_states=output_hidden_states,
775
+ return_dict=return_dict
776
+ )
777
+ hidden_states = transformer_outputs[0]
778
+
779
+ # Set device for model parallelism
780
+ if self.model_parallel:
781
+ torch.cuda.set_device(self.transformer.first_device)
782
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
783
+ self.MSA_log_prior = self.MSA_log_prior.to(self.lm_head.weight.device)
784
+
785
+ lm_logits = self.lm_head(hidden_states)
786
+
787
+ loss = None
788
+ if labels is not None:
789
+ # Shift so that tokens < n predict n
790
+ shift_logits = lm_logits[..., :-1, :].contiguous()
791
+ shift_labels = labels[..., 1:].contiguous()
792
+
793
+ if self.retrieval_aggregation_mode is not None:
794
+ batch_size = input_ids.size(0)
795
+
796
+ if self.retrieval_aggregation_mode=="aggregate_indel":
797
+ assert batch_size==1, "Aggregate indel is only supported for batch size of 1"
798
+ truncated_sequence_text = mutated_sequence[0][start_slice[0]:end_slice[0]]
799
+ if len(truncated_sequence_text)!=shift_logits.shape[1]-1: # shift_logits only has one extra token compared to truncated_sequence_text (the BOS token)
800
+ print("Tokenization error -- seq length: {} and shift_logits length - 1 : {}".format(len(mutated_sequence),shift_logits.shape[1]-1))
801
+ MSA_log_prior, MSA_start, MSA_end = msa_utils.update_retrieved_MSA_log_prior_indel(self, self.MSA_log_prior, self.MSA_start, self.MSA_end, mutated_sequence[0])
802
+
803
+ elif self.retrieval_aggregation_mode=="aggregate_substitution":
804
+ MSA_log_prior=self.MSA_log_prior
805
+ MSA_start=self.MSA_start
806
+ MSA_end=self.MSA_end
807
+
808
+ shift_log_probas = torch.log_softmax(shift_logits,dim=-1)
809
+ fused_shift_log_probas = shift_log_probas.clone()
810
+ if flip is None:
811
+ flip = torch.zeros(batch_size).to(fused_shift_log_probas.device)
812
+ flip = flip > 0
813
+
814
+ for seq_index in range(batch_size):
815
+ min_prior_slice = max(start_slice[seq_index], MSA_start)
816
+ max_prior_slice = min(end_slice[seq_index], MSA_end)
817
+
818
+ if max_prior_slice <= min_prior_slice:
819
+ print("Non overlapping region detected: min_prior_slice {} and max_prior_slice {}".format(min_prior_slice,max_prior_slice))
820
+ continue
821
+
822
+ slice_prior = MSA_log_prior[min_prior_slice:max_prior_slice,:].to(fused_shift_log_probas.device)
823
+ if flip[seq_index]:
824
+ slice_prior = torch.flip(slice_prior,dims=(0,))
825
+ min_logits_slice = max(0,end_slice[seq_index]-MSA_end)
826
+ max_logits_slice = min_logits_slice + (max_prior_slice-min_prior_slice)
827
+ fused_shift_log_probas[seq_index,min_logits_slice:max_logits_slice,:] = (1-self.retrieval_inference_weight_RL)*shift_log_probas[seq_index,min_logits_slice:max_logits_slice,:] + self.retrieval_inference_weight_RL*slice_prior
828
+ else:
829
+ min_logits_slice = max(0, MSA_start-start_slice[seq_index])
830
+ max_logits_slice = min_logits_slice + (max_prior_slice-min_prior_slice)
831
+ fused_shift_log_probas[seq_index,min_logits_slice:max_logits_slice,:] = (1-self.retrieval_inference_weight_LR)*shift_log_probas[seq_index,min_logits_slice:max_logits_slice,:] + self.retrieval_inference_weight_LR*slice_prior
832
+
833
+ if self.retrieval_aggregation_mode=="aggregate_indel":
834
+ try:
835
+ # If a given residue colume is an added zero-column, then we overwrite prior fusion and only predict based on the autoregressive transformer inference mode.
836
+ inserted_retrieval_positions = [True if slice_prior[i].sum()==0 else False for i in range(len(slice_prior))]+[True] #Last True is for the end of sentence token
837
+ fused_shift_log_probas[:,inserted_retrieval_positions,:]=shift_log_probas[:,inserted_retrieval_positions,:]
838
+ except:
839
+ print("Error when adding zero column(s) to account for insertion mutations.")
840
+
841
+ loss_fct = NLLLoss(reduction='none')
842
+ loss = loss_fct(input=fused_shift_log_probas.view(-1, fused_shift_log_probas.size(-1)), target=shift_labels.view(-1)).view(fused_shift_log_probas.shape[0],fused_shift_log_probas.shape[1])
843
+ mask = attention_mask[..., 1:].float()
844
+ mask[mask==0]=float('nan')
845
+ loss *= mask
846
+ loss = nanmean(loss, dim=1).mean()
847
+ else:
848
+ loss_fct = CrossEntropyLoss()
849
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
850
+ fused_shift_log_probas = None
851
+
852
+ if not return_dict:
853
+ output = (lm_logits,) + transformer_outputs[1:]
854
+ return ((loss,) + output) if loss is not None else output
855
+
856
+ return TranceptionCausalLMOutputWithCrossAttentions(
857
+ loss=loss,
858
+ logits=lm_logits,
859
+ past_key_values=transformer_outputs.past_key_values,
860
+ hidden_states=transformer_outputs.hidden_states,
861
+ attentions=transformer_outputs.attentions,
862
+ cross_attentions=transformer_outputs.cross_attentions,
863
+ fused_shift_log_probas=fused_shift_log_probas
864
+ )
865
+
866
+
867
+ @staticmethod
868
+ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
869
+ """
870
+ This function is used to re-order the :obj:`past_key_values` cache if
871
+ :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
872
+ called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
873
+ """
874
+ return tuple(
875
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
876
+ for layer_past in past
877
+ )
878
+
879
+ def score_mutants(self, DMS_data, target_seq=None, scoring_mirror=True, batch_size_inference=10, num_workers=10, indel_mode=False):
880
+ """
881
+ Method to score mutants in an input DMS file.
882
+ DMS_data: (dataframe) Dataframe containing the list of mutated sequences for scoring.
883
+ target_seq: (string) Full reference sequence (wild type) that is mutated in the DMS assay. If not None, returned scores are delta log likelihood wrt that sequence.
884
+ scoring_mirror: (bool) Whether to score mutated sequences from both directions (Left->Right and Right->Left).
885
+ batch_size_inference: (int) Batch size for scoring.
886
+ num_workers: (int) Number of workers to be used in the data loader.
887
+ indel_mode: (bool) Flag to be used when scoring insertions and deletions. Otherwise assumes substitutions.
888
+ """
889
+ df = DMS_data.copy()
890
+ if ('mutated_sequence' not in df) and (not indel_mode): df['mutated_sequence'] = df['mutant'].apply(lambda x: scoring_utils.get_mutated_sequence(target_seq, x))
891
+ assert ('mutated_sequence' in df), "DMS file to score does not have mutated_sequence column"
892
+ #if 'mutant' not in df: df['mutant'] = df['mutated_sequence'] #if mutant not in DMS file we default to mutated_sequence
893
+ if 'DMS_score' in df: del df['DMS_score']
894
+ if 'DMS_score_bin' in df: del df['DMS_score_bin']
895
+ if target_seq is not None:
896
+ df_left_to_right_slices = scoring_utils.get_sequence_slices(df, target_seq=target_seq, model_context_len = self.config.n_ctx - 2, indel_mode=indel_mode, scoring_window=self.config.scoring_window)
897
+ else:
898
+ df_left_to_right_slices = scoring_utils.get_sequence_slices(df, target_seq=list(df['mutated_sequence'])[0], model_context_len = self.config.n_ctx - 2, indel_mode=indel_mode, scoring_window='sliding')
899
+ print("Scoring sequences from left to right")
900
+ scores_L_to_R = scoring_utils.get_tranception_scores_mutated_sequences(model=self, mutated_sequence_df=df_left_to_right_slices, batch_size_inference=batch_size_inference, score_var_name='avg_score_L_to_R', target_seq=target_seq, num_workers=num_workers, indel_mode=indel_mode)
901
+ if scoring_mirror:
902
+ print("Scoring sequences from right to left")
903
+ df_right_to_left_slices = df_left_to_right_slices.copy()
904
+ df_right_to_left_slices['sliced_mutated_sequence'] = df_right_to_left_slices['sliced_mutated_sequence'].apply(lambda x: x[::-1])
905
+ scores_R_to_L = scoring_utils.get_tranception_scores_mutated_sequences(model=self, mutated_sequence_df=df_right_to_left_slices, batch_size_inference=batch_size_inference, score_var_name='avg_score_R_to_L', target_seq=target_seq, num_workers=num_workers, reverse=True, indel_mode=indel_mode)
906
+ all_scores = pd.merge(scores_L_to_R, scores_R_to_L, on='mutated_sequence', how='left', suffixes=('','_R_to_L'))
907
+ all_scores['avg_score'] = (all_scores['avg_score_L_to_R'] + all_scores['avg_score_R_to_L']) / 2.0
908
+ else:
909
+ all_scores = scores_L_to_R
910
+ all_scores['avg_score'] = all_scores['avg_score_L_to_R']
911
+ #By design "get_tranception_scores_mutated_sequences" drops the WT from the output. We add it back if that was one of the sequences to score in the DMS (score=0 by definition)
912
+ if target_seq in DMS_data.mutated_sequence.values:
913
+ print("LEMON")
914
+ if scoring_mirror:
915
+ wt_row = pd.DataFrame([[target_seq,0,0,0]], columns=['mutated_sequence','avg_score_L_to_R','avg_score_R_to_L','avg_score'])
916
+ else:
917
+ wt_row = pd.DataFrame([[target_seq,0,0]], columns=['mutated_sequence','avg_score_L_to_R','avg_score'])
918
+ all_scores = pd.concat([all_scores,wt_row], ignore_index=True)
919
+ return all_scores
920
+
921
+ def encode_batch(self, protein_sequence, sequence_name="sliced_mutated_sequence"):
922
+ """
923
+ Method to process an input AA sequence batch (protein_sequence) and return a tokenized sequence (via the tokenizer associated to the model).
924
+ """
925
+ protein_sequence[sequence_name] = scoring_utils.sequence_replace(sequences=protein_sequence[sequence_name], char_to_replace='X', char_replacements='ACDEFGHIKLMNPQRSTVWY')
926
+ protein_sequence[sequence_name] = scoring_utils.sequence_replace(sequences=protein_sequence[sequence_name], char_to_replace='B', char_replacements='DN')
927
+ protein_sequence[sequence_name] = scoring_utils.sequence_replace(sequences=protein_sequence[sequence_name], char_to_replace='J', char_replacements='IL')
928
+ protein_sequence[sequence_name] = scoring_utils.sequence_replace(sequences=protein_sequence[sequence_name], char_to_replace='Z', char_replacements='EQ')
929
+ return self.config.tokenizer(list(protein_sequence[sequence_name]), add_special_tokens=True, truncation=True, padding=True, max_length=self.config.n_ctx)
930
+
tranception/outputs.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Tuple
3
+
4
+ import torch
5
+
6
+ from transformers.file_utils import ModelOutput
7
+
8
+ @dataclass
9
+ class TranceptionCausalLMOutputWithCrossAttentions(ModelOutput):
10
+ """
11
+ Class for Tranception causal language model (or autoregressive) outputs.
12
+ Args:
13
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
14
+ Language modeling loss (for next-token prediction).
15
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
16
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
17
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
18
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
19
+ shape `(batch_size, sequence_length, hidden_size)`.
20
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
21
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
22
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
23
+ sequence_length)`.
24
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
25
+ heads.
26
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
27
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
28
+ sequence_length)`.
29
+ Cross attentions weights after the attention softmax, used to compute the weighted average in the
30
+ cross-attention heads.
31
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
32
+ Tuple of `torch.FloatTensor` tuples of length `config.n_layers`, with each tuple containing the cached key,
33
+ value states of the self-attention and the cross-attention layers if model is used in encoder-decoder
34
+ setting. Only relevant if `config.is_decoder = True`.
35
+ Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
36
+ `past_key_values` input) to speed up sequential decoding.
37
+ fused_shift_log_probas (`torch.FloatTensor` of shape (batch_size, sequence_length, config.vocab_size), *optional*, returned when config.retrieval_aggregation_mode is not None.
38
+ log_probas for each residue position after aggregating autoregressive logits and retrieval logits.
39
+
40
+ """
41
+
42
+ loss: Optional[torch.FloatTensor] = None
43
+ logits: torch.FloatTensor = None
44
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
45
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
46
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
47
+ cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
48
+ fused_shift_log_probas: Optional[torch.FloatTensor] = None
tranception/utils/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from . import scoring_utils, msa_utils
tranception/utils/dms_utils.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ from tranception.utils import scoring_utils
4
+
5
+ def DMS_file_cleanup(DMS_filename, target_seq, start_idx=1, end_idx=None, DMS_mutant_column='mutant', DMS_phenotype_name='score', DMS_directionality=1, AA_vocab = "ACDEFGHIKLMNPQRSTVWY"):
6
+ """
7
+ Function to process the raw substitution DMS assay data (eg., removing invalid mutants, aggregate silent mutations).
8
+ """
9
+ DMS_data = pd.read_csv(DMS_filename, low_memory=False)
10
+ end_idx = start_idx + len(target_seq) - 1 if end_idx is None else end_idx
11
+ DMS_data['mutant'] = DMS_data[DMS_mutant_column]
12
+
13
+ DMS_data=DMS_data[DMS_data['mutant'].notnull()].copy()
14
+ DMS_data=DMS_data[DMS_data['mutant'].apply(lambda x: all([len(y)>=3 for y in x.split(":")]))].copy() #Mutant triplets should have at least 3 or more characters
15
+ DMS_data=DMS_data[DMS_data['mutant'].apply(lambda x: all([(y[0] in AA_vocab) and (y[1:-1].isnumeric()) and (y[-1] in AA_vocab) for y in x.split(":")]))].copy()
16
+ DMS_data=DMS_data[DMS_data['mutant'].apply(lambda x: all([int(y[1:-1])-start_idx >=0 and int(y[1:-1]) <= end_idx for y in x.split(":")]))].copy()
17
+ DMS_data=DMS_data[DMS_data['mutant'].apply(lambda x: all([y[0]==target_seq[int(y[1:-1])-start_idx] for y in x.split(":")]))].copy()
18
+
19
+ DMS_data[DMS_phenotype_name]=pd.to_numeric(DMS_data[DMS_phenotype_name],errors='coerce')
20
+ DMS_data=DMS_data[np.isfinite(DMS_data[DMS_phenotype_name])]
21
+ DMS_data.dropna(subset = [DMS_phenotype_name], inplace=True)
22
+ DMS_data['DMS_score'] = DMS_data[DMS_phenotype_name] * DMS_directionality
23
+ DMS_data=DMS_data[['mutant','DMS_score']]
24
+ DMS_data=DMS_data.groupby('mutant').mean().reset_index()
25
+
26
+ DMS_data['mutated_sequence'] = DMS_data['mutant'].apply(lambda x: scoring_utils.get_mutated_sequence(target_seq, x))
27
+ DMS_data=DMS_data[['mutant','mutated_sequence','DMS_score']]
28
+
29
+ return DMS_data
30
+
tranception/utils/msa_utils.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ from collections import defaultdict
4
+ import random
5
+ import os
6
+ import torch
7
+ from Bio.Align.Applications import ClustalOmegaCommandline
8
+
9
+ def filter_msa(msa_data, num_sequences_kept=3):
10
+ """
11
+ Helper function to filter an input MSA msa_data (obtained via process_msa_data) and keep only num_sequences_kept aligned sequences.
12
+ If the MSA already has fewer sequences than num_sequences_kept, we keep the MSA as is.
13
+ If filtering, we always keep the first sequence of the MSA (ie. the wild type) by default.
14
+ Sampling is done without replacement.
15
+ """
16
+ if len(list(msa_data.keys())) <= num_sequences_kept:
17
+ return msa_data
18
+ filtered_msa = {}
19
+ wt_name = next(iter(msa_data))
20
+ filtered_msa[wt_name] = msa_data[wt_name]
21
+ del msa_data[wt_name]
22
+ sequence_names = list(msa_data.keys())
23
+ sequence_names_sampled = random.sample(sequence_names,k=num_sequences_kept-1)
24
+ for seq in sequence_names_sampled:
25
+ filtered_msa[seq] = msa_data[seq]
26
+ return filtered_msa
27
+
28
+ def process_msa_data(MSA_data_file):
29
+ """
30
+ Helper function that takes as input a path to a MSA file (expects a2m format) and returns a dict mapping sequence ID to the corresponding AA sequence.
31
+ """
32
+ msa_data = defaultdict(str)
33
+ sequence_name = ""
34
+ with open(MSA_data_file, "r") as msa_file:
35
+ for i, line in enumerate(msa_file):
36
+ line = line.rstrip()
37
+ if line.startswith(">"):
38
+ sequence_name = line
39
+ else:
40
+ msa_data[sequence_name] += line.upper()
41
+ return msa_data
42
+
43
+ def get_one_hot_sequences_dict(msa_data,MSA_start,MSA_end,vocab):
44
+ vocab_size = len(vocab.keys())
45
+ num_sequences_msa = len(msa_data.keys())
46
+ one_hots = np.zeros((num_sequences_msa,MSA_end-MSA_start,vocab_size))
47
+ for i,seq_name in enumerate(msa_data.keys()):
48
+ sequence = msa_data[seq_name]
49
+ for j,letter in enumerate(sequence):
50
+ if letter in vocab:
51
+ k = vocab[letter]
52
+ one_hots[i,j,k] = 1.0
53
+ return one_hots
54
+
55
+ def one_hot(sequence_string,vocab):
56
+ one_hots = np.zeros((len(sequence_string),len(vocab.keys())))
57
+ for j,letter in enumerate(sequence_string):
58
+ if letter in vocab:
59
+ k = vocab[letter]
60
+ one_hots[j,k] = 1.0
61
+ return one_hots.flatten()
62
+
63
+ def get_msa_prior(MSA_data_file, MSA_weight_file_name, MSA_start, MSA_end, len_target_seq, vocab, retrieval_aggregation_mode="aggregate_substitution", filter_MSA=True, verbose=False):
64
+ """
65
+ Function to enable retrieval inference mode, via computation of (weighted) pseudocounts of AAs at each position of the retrieved MSA.
66
+ MSA_data_file: (string) path to MSA file (expects a2m format).
67
+ MSA_weight_file_name: (string) path to sequence weights in MSA.
68
+ MSA_start: (int) Sequence position that the MSA starts at (1-indexing).
69
+ MSA_end: (int) Sequence position that the MSA ends at (1-indexing).
70
+ len_target_seq: (int) Full length of sequence to be scored.
71
+ vocab: (dict) Vocabulary of the tokenizer.
72
+ retrieval_aggregation_mode: (string) Mode for retrieval inference (aggregate_substitution Vs aggregate_indel). If None, places a uniform prior over each token.
73
+ filter_MSA: (bool) Whether to filter out sequences with very low hamming similarity (< 0.2) to the reference sequence in the MSA (first sequence).
74
+ verbose: (bool) Whether to print to the console processing details along the way.
75
+ """
76
+ msa_data = process_msa_data(MSA_data_file)
77
+ vocab_size = len(vocab.keys())
78
+ if verbose: print("Target seq len is {}, MSA length is {}, start position is {}, end position is {} and vocab size is {}".format(len_target_seq,MSA_end-MSA_start,MSA_start,MSA_end,vocab_size))
79
+
80
+ if filter_MSA:
81
+ if verbose: print("Num sequences in MSA pre filtering: {}".format(len(msa_data.keys())))
82
+ list_sequence_names = list(msa_data.keys())
83
+ focus_sequence_name = list(msa_data.keys())[0]
84
+ ref_sequence_hot = one_hot(msa_data[focus_sequence_name],vocab)
85
+ for sequence_name in list_sequence_names:
86
+ seq_hot = one_hot(msa_data[sequence_name],vocab)
87
+ hamming_similarity_seq_ref = np.dot(ref_sequence_hot,seq_hot) / np.dot(ref_sequence_hot,ref_sequence_hot)
88
+ if hamming_similarity_seq_ref < 0.2:
89
+ del msa_data[sequence_name]
90
+ if verbose: print("Num sequences in MSA post filtering: {}".format(len(msa_data.keys())))
91
+
92
+ if MSA_weight_file_name is not None:
93
+ if verbose: print("Using weights in {} for sequences in MSA.".format(MSA_weight_file_name))
94
+ assert os.path.exists(MSA_weight_file_name), "Weights file not located on disk."
95
+ MSA_EVE = MSA_processing(
96
+ MSA_location=MSA_data_file,
97
+ use_weights=True,
98
+ weights_location=MSA_weight_file_name
99
+ )
100
+ #We scan through all sequences to see if we have a weight for them as per EVE pre-processing. We drop them otherwise.
101
+ dropped_sequences=0
102
+ list_sequence_names = list(msa_data.keys())
103
+ MSA_weight=[]
104
+ for sequence_name in list_sequence_names:
105
+ if sequence_name not in MSA_EVE.seq_name_to_sequence:
106
+ dropped_sequences +=1
107
+ del msa_data[sequence_name]
108
+ else:
109
+ MSA_weight.append(MSA_EVE.seq_name_to_weight[sequence_name])
110
+ if verbose: print("Dropped {} sequences from MSA due to absent sequence weights".format(dropped_sequences))
111
+ else:
112
+ MSA_weight = [1] * len(list(msa_data.keys()))
113
+
114
+ if retrieval_aggregation_mode=="aggregate_substitution" or retrieval_aggregation_mode=="aggregate_indel":
115
+ one_hots = get_one_hot_sequences_dict(msa_data,MSA_start,MSA_end,vocab)
116
+ MSA_weight = np.expand_dims(np.array(MSA_weight),axis=(1,2))
117
+ base_rate = 1e-5
118
+ base_rates = np.ones_like(one_hots) * base_rate
119
+ weighted_one_hots = (one_hots + base_rates) * MSA_weight
120
+ MSA_weight_norm_counts = weighted_one_hots.sum(axis=-1).sum(axis=0)
121
+ MSA_weight_norm_counts = np.tile(MSA_weight_norm_counts.reshape(-1,1), (1,vocab_size))
122
+ one_hots_avg = weighted_one_hots.sum(axis=0) / MSA_weight_norm_counts
123
+ msa_prior = np.zeros((len_target_seq,vocab_size))
124
+ msa_prior[MSA_start:MSA_end,:]=one_hots_avg
125
+ else:
126
+ msa_prior = np.ones((len_target_seq,vocab_size)) / vocab_size
127
+
128
+ if verbose:
129
+ for idx, position in enumerate(msa_prior):
130
+ if len(position)!=25:
131
+ print("Size error")
132
+ if not round(position.sum(),2)==1.0:
133
+ print("Position at index {} does not add up to 1: {}".format(idx, position.sum()))
134
+
135
+ return msa_prior
136
+
137
+
138
+ def update_retrieved_MSA_log_prior_indel(model, MSA_log_prior, MSA_start, MSA_end, mutated_sequence):
139
+ """
140
+ Function to process MSA when scoring indels.
141
+ To identify positions to add / remove in the retrieved MSA, we append and align the sequence to be scored to the original MSA for that protein family with Clustal Omega.
142
+ If the original MSA is relatively deep (over 100k sequences), we sample (by default) 100k rows at random from that MSA to speed computations.
143
+ MSA sampling is performed only once (for the first sequence to be scored). Subsequent scoring use the same MSA sample.
144
+ """
145
+ if not os.path.isdir(model.MSA_folder + os.sep + "Sampled"):
146
+ os.mkdir(model.MSA_folder + os.sep + "Sampled")
147
+ sampled_MSA_location = model.MSA_folder + os.sep + "Sampled" + os.sep + "Sampled_" + model.MSA_filename.split(os.sep)[-1]
148
+
149
+ if not os.path.exists(sampled_MSA_location):
150
+ msa_data = process_msa_data(model.MSA_filename)
151
+ msa_data_sampled = filter_msa(msa_data, num_sequences_kept=100000) #If MSA has less than 100k sequences, the sample is identical to original MSA
152
+ with open(sampled_MSA_location, 'w') as sampled_write_location:
153
+ for index, key in enumerate(msa_data_sampled):
154
+ key_name = ">REFERENCE_SEQUENCE" if index==0 else key
155
+ msa_data_sampled[key] = msa_data_sampled[key].upper()
156
+ msa_data_sampled[key] = msa_data_sampled[key].replace(".","-")
157
+ sampled_write_location.write(key_name+"\n"+"\n".join([msa_data_sampled[key][i:i+80] for i in range(0, len(msa_data_sampled[key]), 80)])+"\n")
158
+
159
+ seq_to_align_location = model.MSA_folder + os.sep + "Sampled" + os.sep + "Seq_to_align_" + model.MSA_filename.split(os.sep)[-1]
160
+ sequence_text_split = [mutated_sequence[i:i+80] for i in range(0, len(mutated_sequence), 80)]
161
+ sequence_text_split_split_join = "\n".join([">SEQ_TO_SCORE"]+sequence_text_split)
162
+ os.system("echo '"+sequence_text_split_split_join+"' > "+seq_to_align_location)
163
+
164
+ expanded_MSA_location = model.MSA_folder + os.sep + "Sampled" + os.sep + "Expanded_" + model.MSA_filename.split(os.sep)[-1]
165
+ clustalw_cline = ClustalOmegaCommandline(cmd=model.config.clustal_omega_location,
166
+ profile1=sampled_MSA_location,
167
+ profile2=seq_to_align_location,
168
+ outfile=expanded_MSA_location,
169
+ force=True)
170
+ stdout, stderr = clustalw_cline()
171
+ msa_data = process_msa_data(expanded_MSA_location)
172
+ aligned_seqA, aligned_seqB = msa_data[">SEQ_TO_SCORE"], msa_data[">REFERENCE_SEQUENCE"]
173
+ try:
174
+ keep_column=[]
175
+ for column_index_pairwise_alignment in range(len(aligned_seqA)):
176
+ if aligned_seqA[column_index_pairwise_alignment]=="-" and aligned_seqB[column_index_pairwise_alignment]=="-":
177
+ continue
178
+ elif aligned_seqA[column_index_pairwise_alignment]=="-":
179
+ keep_column.append(False)
180
+ elif aligned_seqB[column_index_pairwise_alignment]=="-":
181
+ MSA_log_prior=torch.cat((MSA_log_prior[:column_index_pairwise_alignment], torch.zeros(MSA_log_prior.shape[1]).view(1,-1).cuda(), MSA_log_prior[column_index_pairwise_alignment:]),dim=0)
182
+ keep_column.append(True) #keep the zero column we just added
183
+ else:
184
+ keep_column.append(True)
185
+ MSA_log_prior = MSA_log_prior[keep_column]
186
+ MSA_end = MSA_start + len(MSA_log_prior)
187
+ except:
188
+ print("Error when processing the following alignment: {}".format(expanded_MSA_location))
189
+ return MSA_log_prior, MSA_start, MSA_end
190
+
191
+ class MSA_processing:
192
+ def __init__(self,
193
+ MSA_location="",
194
+ theta=0.2,
195
+ use_weights=True,
196
+ weights_location="./data/weights",
197
+ preprocess_MSA=True,
198
+ threshold_sequence_frac_gaps=0.5,
199
+ threshold_focus_cols_frac_gaps=0.3,
200
+ remove_sequences_with_indeterminate_AA_in_focus_cols=True
201
+ ):
202
+
203
+ """
204
+ This MSA_processing class is directly borrowed from the EVE codebase: https://github.com/OATML-Markslab/EVE
205
+
206
+ Parameters:
207
+ - msa_location: (path) Location of the MSA data. Constraints on input MSA format:
208
+ - focus_sequence is the first one in the MSA data
209
+ - first line is structured as follows: ">focus_seq_name/start_pos-end_pos" (e.g., >SPIKE_SARS2/310-550)
210
+ - corespondding sequence data located on following line(s)
211
+ - then all other sequences follow with ">name" on first line, corresponding data on subsequent lines
212
+ - theta: (float) Sequence weighting hyperparameter. Generally: Prokaryotic and eukaryotic families = 0.2; Viruses = 0.01
213
+ - use_weights: (bool) If False, sets all sequence weights to 1. If True, checks weights_location -- if non empty uses that;
214
+ otherwise compute weights from scratch and store them at weights_location
215
+ - weights_location: (path) Location to load from/save to the sequence weights
216
+ - preprocess_MSA: (bool) performs pre-processing of MSA to remove short fragments and positions that are not well covered.
217
+ - threshold_sequence_frac_gaps: (float, between 0 and 1) Threshold value to define fragments
218
+ - sequences with a fraction of gap characters above threshold_sequence_frac_gaps are removed
219
+ - default is set to 0.5 (i.e., fragments with 50% or more gaps are removed)
220
+ - threshold_focus_cols_frac_gaps: (float, between 0 and 1) Threshold value to define focus columns
221
+ - positions with a fraction of gap characters above threshold_focus_cols_pct_gaps will be set to lower case (and not included in the focus_cols)
222
+ - default is set to 0.3 (i.e., focus positions are the ones with 30% of gaps or less, i.e., 70% or more residue occupancy)
223
+ - remove_sequences_with_indeterminate_AA_in_focus_cols: (bool) Remove all sequences that have indeterminate AA (e.g., B, J, X, Z) at focus positions of the wild type
224
+ """
225
+ np.random.seed(2021)
226
+ self.MSA_location = MSA_location
227
+ self.weights_location = weights_location
228
+ self.theta = theta
229
+ self.alphabet = "ACDEFGHIKLMNPQRSTVWY"
230
+ self.use_weights = use_weights
231
+ self.preprocess_MSA = preprocess_MSA
232
+ self.threshold_sequence_frac_gaps = threshold_sequence_frac_gaps
233
+ self.threshold_focus_cols_frac_gaps = threshold_focus_cols_frac_gaps
234
+ self.remove_sequences_with_indeterminate_AA_in_focus_cols = remove_sequences_with_indeterminate_AA_in_focus_cols
235
+
236
+ self.gen_alignment()
237
+
238
+ def gen_alignment(self, verbose=False):
239
+ """ Read training alignment and store basics in class instance """
240
+ self.aa_dict = {}
241
+ for i,aa in enumerate(self.alphabet):
242
+ self.aa_dict[aa] = i
243
+
244
+ self.seq_name_to_sequence = defaultdict(str)
245
+ name = ""
246
+ with open(self.MSA_location, "r") as msa_data:
247
+ for i, line in enumerate(msa_data):
248
+ line = line.rstrip()
249
+ if line.startswith(">"):
250
+ name = line
251
+ if i==0:
252
+ self.focus_seq_name = name
253
+ else:
254
+ self.seq_name_to_sequence[name] += line
255
+
256
+
257
+ ## MSA pre-processing to remove inadequate columns and sequences
258
+ if self.preprocess_MSA:
259
+ msa_df = pd.DataFrame.from_dict(self.seq_name_to_sequence, orient='index', columns=['sequence'])
260
+ # Data clean up
261
+ msa_df.sequence = msa_df.sequence.apply(lambda x: x.replace(".","-")).apply(lambda x: ''.join([aa.upper() for aa in x]))
262
+ # Remove columns that would be gaps in the wild type
263
+ non_gap_wt_cols = [aa!='-' for aa in msa_df.sequence[self.focus_seq_name]]
264
+ msa_df['sequence'] = msa_df['sequence'].apply(lambda x: ''.join([aa for aa,non_gap_ind in zip(x, non_gap_wt_cols) if non_gap_ind]))
265
+ assert 0.0 <= self.threshold_sequence_frac_gaps <= 1.0,"Invalid fragment filtering parameter"
266
+ assert 0.0 <= self.threshold_focus_cols_frac_gaps <= 1.0,"Invalid focus position filtering parameter"
267
+ msa_array = np.array([list(seq) for seq in msa_df.sequence])
268
+ gaps_array = np.array(list(map(lambda seq: [aa=='-' for aa in seq], msa_array)))
269
+ # Identify fragments with too many gaps
270
+ seq_gaps_frac = gaps_array.mean(axis=1)
271
+ seq_below_threshold = seq_gaps_frac <= self.threshold_sequence_frac_gaps
272
+ if verbose: print("Proportion of sequences dropped due to fraction of gaps: "+str(round(float(1 - seq_below_threshold.sum()/seq_below_threshold.shape)*100,2))+"%")
273
+ # Identify focus columns
274
+ columns_gaps_frac = gaps_array[seq_below_threshold].mean(axis=0)
275
+ index_cols_below_threshold = columns_gaps_frac <= self.threshold_focus_cols_frac_gaps
276
+ if verbose: print("Proportion of non-focus columns removed: "+str(round(float(1 - index_cols_below_threshold.sum()/index_cols_below_threshold.shape)*100,2))+"%")
277
+ # Lower case non focus cols and filter fragment sequences
278
+ msa_df['sequence'] = msa_df['sequence'].apply(lambda x: ''.join([aa.upper() if upper_case_ind else aa.lower() for aa, upper_case_ind in zip(x, index_cols_below_threshold)]))
279
+ msa_df = msa_df[seq_below_threshold]
280
+ # Overwrite seq_name_to_sequence with clean version
281
+ self.seq_name_to_sequence = defaultdict(str)
282
+ for seq_idx in range(len(msa_df['sequence'])):
283
+ self.seq_name_to_sequence[msa_df.index[seq_idx]] = msa_df.sequence[seq_idx]
284
+
285
+ self.focus_seq = self.seq_name_to_sequence[self.focus_seq_name]
286
+ self.focus_cols = [ix for ix, s in enumerate(self.focus_seq) if s == s.upper() and s!='-']
287
+ self.focus_seq_trimmed = [self.focus_seq[ix] for ix in self.focus_cols]
288
+ self.seq_len = len(self.focus_cols)
289
+ self.alphabet_size = len(self.alphabet)
290
+
291
+ # Connect local sequence index with uniprot index (index shift inferred from 1st row of MSA)
292
+ focus_loc = self.focus_seq_name.split("/")[-1]
293
+ start,stop = focus_loc.split("-")
294
+ self.focus_start_loc = int(start)
295
+ self.focus_stop_loc = int(stop)
296
+ self.uniprot_focus_col_to_wt_aa_dict \
297
+ = {idx_col+int(start):self.focus_seq[idx_col] for idx_col in self.focus_cols}
298
+ self.uniprot_focus_col_to_focus_idx \
299
+ = {idx_col+int(start):idx_col for idx_col in self.focus_cols}
300
+
301
+ # Move all letters to CAPS; keeps focus columns only
302
+ self.raw_seq_name_to_sequence = self.seq_name_to_sequence.copy()
303
+ for seq_name,sequence in self.seq_name_to_sequence.items():
304
+ sequence = sequence.replace(".","-")
305
+ self.seq_name_to_sequence[seq_name] = [sequence[ix].upper() for ix in self.focus_cols]
306
+
307
+ # Remove sequences that have indeterminate AA (e.g., B, J, X, Z) in the focus columns
308
+ if self.remove_sequences_with_indeterminate_AA_in_focus_cols:
309
+ alphabet_set = set(list(self.alphabet))
310
+ seq_names_to_remove = []
311
+ for seq_name,sequence in self.seq_name_to_sequence.items():
312
+ for letter in sequence:
313
+ if letter not in alphabet_set and letter != "-":
314
+ seq_names_to_remove.append(seq_name)
315
+ continue
316
+ seq_names_to_remove = list(set(seq_names_to_remove))
317
+ for seq_name in seq_names_to_remove:
318
+ del self.seq_name_to_sequence[seq_name]
319
+
320
+ # Encode the sequences
321
+ self.one_hot_encoding = np.zeros((len(self.seq_name_to_sequence.keys()),len(self.focus_cols),len(self.alphabet)))
322
+ if verbose: print("One-hot encoded sequences shape:" + str(self.one_hot_encoding.shape))
323
+ for i,seq_name in enumerate(self.seq_name_to_sequence.keys()):
324
+ sequence = self.seq_name_to_sequence[seq_name]
325
+ for j,letter in enumerate(sequence):
326
+ if letter in self.aa_dict:
327
+ k = self.aa_dict[letter]
328
+ self.one_hot_encoding[i,j,k] = 1.0
329
+
330
+ if self.use_weights:
331
+ try:
332
+ self.weights = np.load(file=self.weights_location)
333
+ if verbose: print("Loaded sequence weights from disk")
334
+ except:
335
+ if verbose: print ("Computing sequence weights")
336
+ list_seq = self.one_hot_encoding
337
+ list_seq = list_seq.reshape((list_seq.shape[0], list_seq.shape[1] * list_seq.shape[2]))
338
+ def compute_weight(seq):
339
+ number_non_empty_positions = np.dot(seq,seq)
340
+ if number_non_empty_positions>0:
341
+ denom = np.dot(list_seq,seq) / np.dot(seq,seq)
342
+ denom = np.sum(denom > 1 - self.theta)
343
+ return 1/denom
344
+ else:
345
+ return 0.0 #return 0 weight if sequence is fully empty
346
+ self.weights = np.array(list(map(compute_weight,list_seq)))
347
+ np.save(file=self.weights_location, arr=self.weights)
348
+ else:
349
+ # If not using weights, use an isotropic weight matrix
350
+ if verbose: print("Not weighting sequence data")
351
+ self.weights = np.ones(self.one_hot_encoding.shape[0])
352
+
353
+ self.Neff = np.sum(self.weights)
354
+ self.num_sequences = self.one_hot_encoding.shape[0]
355
+ self.seq_name_to_weight={}
356
+ for i,seq_name in enumerate(self.seq_name_to_sequence.keys()):
357
+ self.seq_name_to_weight[seq_name]=self.weights[i]
358
+
359
+ if verbose:
360
+ print ("Neff =",str(self.Neff))
361
+ print ("Data Shape =",self.one_hot_encoding.shape)
tranception/utils/scoring_utils.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tqdm
3
+ import re
4
+ import numpy as np
5
+ import pandas as pd
6
+
7
+ import torch
8
+ from torch.nn import CrossEntropyLoss, NLLLoss
9
+ from torch.utils.data.sampler import Sampler, SequentialSampler
10
+
11
+ from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizerFast
12
+ from datasets import Dataset
13
+
14
+ AA_vocab = "ACDEFGHIKLMNPQRSTVWY"
15
+
16
+ def get_mutated_sequence(focus_seq, mutant, start_idx=1, AA_vocab=AA_vocab):
17
+ """
18
+ Helper function that mutates an input sequence (focus_seq) via an input mutation triplet (substitutions only).
19
+ Mutation triplet are typically based on 1-indexing: start_idx is used for switching to 0-indexing.
20
+ """
21
+ mutated_seq = list(focus_seq)
22
+ for mutation in mutant.split(":"):
23
+ try:
24
+ from_AA, position, to_AA = mutation[0], int(mutation[1:-1]), mutation[-1]
25
+ except:
26
+ print("Issue with mutant: "+str(mutation))
27
+ relative_position = position - start_idx
28
+ assert (from_AA==focus_seq[relative_position]), "Invalid from_AA or mutant position: "+str(mutation)+" from_AA: "+str(from_AA) + " relative pos: "+str(relative_position) + " focus_seq: "+str(focus_seq)
29
+ assert (to_AA in AA_vocab) , "Mutant to_AA is invalid: "+str(mutation)
30
+ mutated_seq[relative_position] = to_AA
31
+ return "".join(mutated_seq)
32
+
33
+ def nanmean(v, *args, inplace=False, **kwargs):
34
+ if not inplace:
35
+ v = v.clone()
36
+ is_nan = torch.isnan(v)
37
+ v[is_nan] = 0
38
+ return v.sum(*args, **kwargs) / (~is_nan).float().sum(*args, **kwargs)
39
+
40
+ def nansum(v, *args, inplace=False, **kwargs):
41
+ if not inplace:
42
+ v = v.clone()
43
+ is_nan = torch.isnan(v)
44
+ v[is_nan] = 0
45
+ return v.sum(*args, **kwargs)
46
+
47
+ def get_optimal_window(mutation_position_relative, seq_len_wo_special, model_window):
48
+ """
49
+ Helper function that selects an optimal sequence window that fits the maximum model context size.
50
+ If the sequence length is less than the maximum context size, the full sequence is returned.
51
+ """
52
+ half_model_window = model_window // 2
53
+ if seq_len_wo_special <= model_window:
54
+ return [0,seq_len_wo_special]
55
+ elif mutation_position_relative < half_model_window:
56
+ return [0,model_window]
57
+ elif mutation_position_relative >= seq_len_wo_special - half_model_window:
58
+ return [seq_len_wo_special - model_window, seq_len_wo_special]
59
+ else:
60
+ return [max(0,mutation_position_relative-half_model_window), min(seq_len_wo_special,mutation_position_relative+half_model_window)]
61
+
62
+ def sequence_replace_single(sequence, char_to_replace, char_replacements):
63
+ char_replacements = list(char_replacements)
64
+ positions = [m.start() for m in re.finditer(char_to_replace, sequence)]
65
+ replacements = np.random.choice(a=char_replacements, size=len(positions), replace=True)
66
+ sequence=list(sequence)
67
+ for idx, position in enumerate(positions):
68
+ sequence[position]=replacements[idx]
69
+ return ''.join(sequence)
70
+
71
+ def sequence_replace(sequences, char_to_replace, char_replacements):
72
+ """
73
+ Helper function that replaces all Amino Acids passsed in via char_to_replace (as a string of AAs) with Amino Acids sampled from char_replacements (also a string of eligible AAs).
74
+ """
75
+ return [sequence_replace_single(sequence, char_to_replace, char_replacements) for sequence in sequences]
76
+
77
+ def get_tranception_scores_mutated_sequences(model, mutated_sequence_df, batch_size_inference, score_var_name, target_seq, num_workers=10, reverse=False, indel_mode=False):
78
+ """
79
+ Helper function that takes as input a set of mutated sequences (in a pandas dataframe) and returns scores for each mutation.
80
+ If target_seq is not None, returns the delta log likelihood wrt that target sequence -- otherwise returns the log likelihood of the protein sequences.
81
+ """
82
+ scores = {}
83
+ scores['mutated_sequence']=[]
84
+ scores['sliced_mutated_sequence']=[]
85
+ scores['window_start']=[]
86
+ scores['window_end']=[]
87
+ scores['score']=[]
88
+ with torch.no_grad():
89
+ ds = Dataset.from_pandas(mutated_sequence_df)
90
+ ds.set_transform(model.encode_batch)
91
+ data_collator = DataCollatorForLanguageModeling(
92
+ tokenizer=model.config.tokenizer,
93
+ mlm=False)
94
+ sampler = SequentialSampler(ds)
95
+ ds_loader = torch.utils.data.DataLoader(ds, batch_size=batch_size_inference, sampler=sampler, collate_fn=data_collator, num_workers=num_workers, pin_memory=True, drop_last=False)
96
+ mutant_index=0
97
+ for encoded_batch in tqdm.tqdm(ds_loader):
98
+ full_batch_length = len(encoded_batch['input_ids'])
99
+ mutated_sequence = np.array(mutated_sequence_df['mutated_sequence'][mutant_index:mutant_index+full_batch_length])
100
+ scores['mutated_sequence'] += list(mutated_sequence)
101
+ sliced_mutated_sequence = np.array(mutated_sequence_df['sliced_mutated_sequence'][mutant_index:mutant_index+full_batch_length])
102
+ scores['sliced_mutated_sequence'] += list(sliced_mutated_sequence)
103
+ window_start = np.array(mutated_sequence_df['window_start'][mutant_index:mutant_index+full_batch_length])
104
+ scores['window_start'] += list(window_start)
105
+ window_end = np.array(mutated_sequence_df['window_end'][mutant_index:mutant_index+full_batch_length])
106
+ scores['window_end'] += list(window_end)
107
+ for k, v in encoded_batch.items():
108
+ if isinstance(v, torch.Tensor):
109
+ encoded_batch[k] = v.to(model.device)
110
+ shift_labels = encoded_batch['labels'][..., 1:].contiguous()
111
+ if (hasattr(model.config,"retrieval_aggregation_mode")) and (model.config.retrieval_aggregation_mode is not None):
112
+ if reverse:
113
+ encoded_batch['flip']=torch.tensor([1]*full_batch_length)
114
+ encoded_batch['start_slice']=window_start
115
+ encoded_batch['end_slice']=window_end
116
+ encoded_batch['mutated_sequence'] = mutated_sequence #only mutated_sequence is flipped if the scoring_mirror branch of score_mutants. No need to flip mutated_sequence for MSA re-aligning
117
+ fused_shift_log_probas=model(**encoded_batch,return_dict=True).fused_shift_log_probas
118
+ loss_fct = NLLLoss(reduction='none')
119
+ loss = - loss_fct(input=fused_shift_log_probas.view(-1, fused_shift_log_probas.size(-1)), target=shift_labels.view(-1)).view(fused_shift_log_probas.shape[0],fused_shift_log_probas.shape[1])
120
+ else:
121
+ lm_logits=model(**encoded_batch,return_dict=True).logits
122
+ shift_logits = lm_logits[..., :-1, :].contiguous()
123
+ loss_fct = CrossEntropyLoss(reduction='none')
124
+ loss = - loss_fct(input=shift_logits.view(-1, shift_logits.size(-1)), target=shift_labels.view(-1)).view(shift_logits.shape[0],shift_logits.shape[1])
125
+ mask = encoded_batch['attention_mask'][..., 1:].float()
126
+ mask[mask==0]=float('nan')
127
+ loss *= mask
128
+ loss = nansum(loss, dim=1)
129
+ scores_batch = list(loss.cpu().numpy())
130
+ full_batch_length = len(encoded_batch['input_ids'])
131
+ scores['score'] += scores_batch
132
+ mutant_index+=full_batch_length
133
+ scores = pd.DataFrame(scores)
134
+ if model.config.scoring_window=="sliding":
135
+ scores = scores[['mutated_sequence','score']].groupby('mutated_sequence').sum().reset_index() #We need to aggregate scores when using sliding mode
136
+ scores['score'] = scores['score'] / scores['mutated_sequence'].map(lambda x: len(x))
137
+ if target_seq is not None:
138
+ scores_mutated_seq = scores[scores.mutated_sequence != target_seq]
139
+ scores_wt = scores[scores.mutated_sequence == target_seq]
140
+ merge_delta = 'mutated_sequence' if model.config.scoring_window=="sliding" else 'window_start'
141
+ if model.config.scoring_window=="optimal":
142
+ delta_scores = pd.merge(scores_mutated_seq,scores_wt,how='left',on=[merge_delta],suffixes=('','_wt'))
143
+ delta_scores[score_var_name] = delta_scores['score'] - delta_scores['score_wt']
144
+ elif model.config.scoring_window=="sliding":
145
+ delta_scores = scores_mutated_seq.copy()
146
+ delta_scores[score_var_name] = delta_scores['score'] - list(scores_wt['score'])[0] # In sliding mode there is a single reference window for the WT
147
+ return delta_scores[['mutated_sequence',score_var_name]]
148
+ else:
149
+ scores[score_var_name] = scores['score']
150
+ return scores[['mutated_sequence',score_var_name]]
151
+
152
+ def get_sequence_slices(df, target_seq, model_context_len, start_idx=1, scoring_window="optimal", indel_mode=False):
153
+ """
154
+ Helper function that takes as input a (pandas) dataframe df that contains a list of mutant triplets (substitutions) or full mutated sequences (indels) for scoring.
155
+ It returns a processed DMS in which sequences have been sliced to satisfy the maximum context window of the model.
156
+ df: (dataframe) Input dataframe to be processed
157
+ target_seq: (string) Full reference sequence (wild type) that is mutated in the DMS assay.
158
+ model_context_len: (int) Maximum context size for the model.
159
+ start_idx: (int) Integer to move to 0-indexing of positions (mutation triplet are typically based on 1-indexing).
160
+ scoring_window: (string) Method to slice sequences longer than maximum context size:
161
+ - optimal selects a single window as large as possible via the get_optimal_window function (this is the default)
162
+ - sliding splits the full sequence in contiguous (non-overlapping) chunks that are of size equal to the max context (except the last chunk which may be shorter)
163
+ indel_mode: (bool) Flag to be used when scoring insertions and deletions. Otherwise assumes substitutions.
164
+ Note: when scoring indels for sequences that would be longer than the model max context length, it is preferable to use the "sliding" scoring_window. Use "optimal" otherwise.
165
+ """
166
+ len_target_seq = len(target_seq)
167
+ num_mutants = len(df['mutated_sequence'])
168
+ df=df.reset_index(drop=True)
169
+ if scoring_window=="optimal":
170
+ df['mutation_barycenter'] = df['mutant'].apply(lambda x: int(np.array([int(mutation[1:-1]) - start_idx for mutation in x.split(':')]).mean())) if not indel_mode else df['mutated_sequence'].apply(lambda x: len(x)//2)
171
+ df['scoring_optimal_window'] = df['mutation_barycenter'].apply(lambda x: get_optimal_window(x, len_target_seq, model_context_len)) if not indel_mode else df['mutated_sequence'].apply(lambda x: (0,len(x)))
172
+ df['sliced_mutated_sequence'] = [df['mutated_sequence'][index][df['scoring_optimal_window'][index][0]:df['scoring_optimal_window'][index][1]] for index in range(num_mutants)]
173
+ df['window_start'] = df['scoring_optimal_window'].map(lambda x: x[0])
174
+ df['window_end'] = df['scoring_optimal_window'].map(lambda x: x[1])
175
+ del df['scoring_optimal_window'], df['mutation_barycenter']
176
+ if 'mutant' in df: del df['mutant']
177
+ df_wt=df.copy()
178
+ df_wt['mutated_sequence'] = [target_seq] * num_mutants
179
+ if indel_mode: # For indels, we set the wild type reference to be always the same (full length) sequence. We assume here that the length is lower than model context size (otherwise "Sliding" mode should be used)
180
+ df_wt['window_end'] = df_wt['mutated_sequence'].map(lambda x:len(x))
181
+ df_wt['sliced_mutated_sequence'] = [target_seq[df_wt['window_start'][index]:df_wt['window_end'][index]] for index in range(num_mutants)]
182
+ df = pd.concat([df,df_wt], axis=0)
183
+ df = df.drop_duplicates()
184
+ elif scoring_window=="sliding":
185
+ num_windows = 1 + int( len_target_seq / model_context_len)
186
+ df_list=[]
187
+ start=0
188
+ for window_index in range(1, num_windows+1):
189
+ df_sliced = df.copy()
190
+ df_sliced['sliced_mutated_sequence'] = df_sliced['mutated_sequence'].map(lambda x: x[start:start+model_context_len])
191
+ df_sliced['window_start'] = [start] * num_mutants
192
+ df_sliced['window_end'] = df_sliced['mutated_sequence'].map(lambda x: min(len(x), start+model_context_len))
193
+ df_sliced_wt = df_sliced.copy()
194
+ df_sliced_wt['mutated_sequence'] = [target_seq] * num_mutants
195
+ df_sliced_wt['sliced_mutated_sequence'] = df_sliced_wt['mutated_sequence'].map(lambda x: x[start:start+model_context_len])
196
+ df_sliced_wt['window_end'] = df_sliced_wt['mutated_sequence'].map(lambda x: min(len(x), start+model_context_len)) #Need to adjust end index if WT and sequence are not same full length
197
+ df_list.append(df_sliced)
198
+ df_list.append(df_sliced_wt)
199
+ start += model_context_len
200
+ df_final = pd.concat(df_list,axis=0)
201
+ if 'mutant' in df_final: del df_final['mutant']
202
+ df = df_final.drop_duplicates()
203
+ return df.reset_index(drop=True)
tranception/utils/tokenizers/Basic_tokenizer ADDED
@@ -0,0 +1 @@
 
 
1
+ {"version":"1.0","truncation":null,"padding":null,"added_tokens":[{"id":0,"special":true,"content":"[UNK]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":1,"special":true,"content":"[CLS]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":2,"special":true,"content":"[SEP]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":3,"special":true,"content":"[PAD]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":4,"special":true,"content":"[MASK]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false}],"normalizer":null,"pre_tokenizer":{"type":"Whitespace"},"post_processor":{"type":"TemplateProcessing","single":[{"SpecialToken":{"id":"[CLS]","type_id":0}},{"Sequence":{"id":"A","type_id":0}},{"SpecialToken":{"id":"[SEP]","type_id":0}}],"pair":[{"SpecialToken":{"id":"[CLS]","type_id":0}},{"Sequence":{"id":"A","type_id":0}},{"SpecialToken":{"id":"[SEP]","type_id":0}},{"Sequence":{"id":"B","type_id":1}},{"SpecialToken":{"id":"[SEP]","type_id":1}}],"special_tokens":{"[CLS]":{"id":"[CLS]","ids":[1],"tokens":["[CLS]"]},"[SEP]":{"id":"[SEP]","ids":[2],"tokens":["[SEP]"]}}},"decoder":null,"model":{"type":"BPE","dropout":null,"unk_token":"[UNK]","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"vocab":{"[UNK]":0,"[CLS]":1,"[SEP]":2,"[PAD]":3,"[MASK]":4,"A":5,"C":6,"D":7,"E":8,"F":9,"G":10,"H":11,"I":12,"K":13,"L":14,"M":15,"N":16,"P":17,"Q":18,"R":19,"S":20,"T":21,"V":22,"W":23,"Y":24},"merges":[]}}