PascalNotin committed on
Commit fa87656
1 Parent(s): c7fea90

Added Tranception model

README.md CHANGED
@@ -1,3 +1,133 @@
1
- ---
2
- license: mit
3
- ---
 
1
+ # Tranception
2
+
3
+ This is the official code repository for the paper "Tranception: protein fitness prediction with autoregressive transformers and inference-time retrieval". This project is a joint collaboration between the [Marks lab](https://www.deboramarkslab.com/) and the [OATML group](https://oatml.cs.ox.ac.uk/).
4
+
5
+ ## Abstract
6
+ The ability to accurately model the fitness landscape of protein sequences is critical to a wide range of applications, from quantifying the effects of human variants on disease likelihood, to predicting immune-escape mutations in viruses and designing novel biotherapeutic proteins. Deep generative models of protein sequences trained on multiple sequence alignments have been the most successful approaches so far to address these tasks. The performance of these methods is however contingent on the availability of sufficiently deep and diverse alignments for reliable training. Their potential scope is thus limited by the fact many protein families are hard, if not impossible, to align. Large language models trained on massive quantities of non-aligned protein sequences from diverse families address these problems and show potential to eventually bridge the performance gap. We introduce Tranception, a novel transformer architecture leveraging autoregressive predictions and retrieval of homologous sequences at inference to achieve state-of-the-art fitness prediction performance. Given its markedly higher performance on multiple mutants, robustness to shallow alignments and ability to score indels, our approach offers significant gain of scope over existing approaches. To enable more rigorous model testing across a broader range of protein families, we develop ProteinGym -- an extensive set of multiplexed assays of variant effects, substantially increasing both the number and diversity of assays compared to existing benchmarks.
7
+
8
+ ## Setup
9
+ You may download the Tranception repository and create a conda environment with the proper dependencies (as listed in `tranception_env.yml`) as follows:
10
+ ```
11
+ git clone https://github.com/OATML-Markslab/Tranception.git
12
+ conda env create -f tranception_env.yml
13
+ ```
14
+
15
+ ## Tranception
16
+ Tranception is a novel autoregressive transformer architecture that was designed with two core principles in mind: 1) promoting specialization across attention heads, and 2) explicitly extracting patterns from contiguous subsequences.
17
+
18
+ To download the *Tranception Large* model checkpoint (~3.1GB unzipped):
19
+ ```
20
+ curl -o Tranception_Large.zip https://marks.hms.harvard.edu/ProteinGym/Tranception_Large.zip
21
+ unzip Tranception_Large.zip
22
+ rm Tranception_Large.zip
23
+ ```
24
+
25
+ Tranception is also made available through the [Hugging Face hub](https://huggingface.co/OATML-Markslab/Tranception).
26
+
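For convenience, here is a minimal sketch of loading the model from the hub in Python. It assumes the `tranception` package from this repository is importable and relies on the standard `from_pretrained` mechanism; tokenizer setup and retrieval configuration are omitted (see the Colab notebook mentioned below for the full procedure).
```
from tranception import model_pytorch

# Download config + weights from the Hugging Face hub and instantiate the model.
model = model_pytorch.TranceptionLMHeadModel.from_pretrained("OATML-Markslab/Tranception")
model.eval()
```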
27
+ When scoring with retrieval, we compute weighted pseudocounts at each position using sequence weights as per the procedure described in [Hopf et al.](https://www.nature.com/articles/nbt.3769).
28
+ Weights for all proteins in the ProteinGym benchmarks may be downloaded as follows (~68M unzipped):
29
+ ```
30
+ curl -o MSA_weights.zip https://marks.hms.harvard.edu/ProteinGym/MSA_weights.zip
31
+ unzip MSA_weights.zip
32
+ rm MSA_weights.zip
33
+ ```
34
+ To compute sequence weights for new proteins, you may use the MSA_processing class under `tranception/utils/msa_utils.py`.
35
+
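As an illustration, a hedged sketch of how one might invoke that class is shown below; the argument names (`MSA_location`, `theta`, `weights_location`) are assumptions about its interface rather than a verified signature, so check `msa_utils.py` for the exact parameters.
```
from tranception.utils import msa_utils

# Assumed arguments: path to the new alignment and where to store the computed sequence weights.
msa = msa_utils.MSA_processing(
    MSA_location="./MSA/my_protein.a2m",              # assumed argument name
    theta=0.2,                                        # sequence re-weighting threshold (assumed default)
    weights_location="./MSA_weights/my_protein.npy",  # assumed argument name
)
```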
36
+ The `examples` folder provides several bash scripts that may be used for scoring and evaluating Tranception on the ProteinGym benchmarks. We also provide a Colab notebook illustrating how to load Tranception from the Hugging Face hub and then score the mutated sequences in the ProteinGym benchmarks with it.
37
+
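As a condensed illustration of what those scripts do, the sketch below calls the `score_mutants` method defined in `model_pytorch.py` (part of this commit) on the model loaded above. The DMS file name and wild-type sequence are placeholders, and the real scripts additionally attach a tokenizer and optional retrieval settings to the model config.
```
import pandas as pd

DMS_data = pd.read_csv("ProteinGym_substitutions/MY_ASSAY.csv")  # placeholder file name
target_seq = "MSKGEELFTGVV..."  # placeholder wild-type sequence (truncated)

all_scores = model.score_mutants(
    DMS_data=DMS_data,
    target_seq=target_seq,
    scoring_mirror=True,     # score from both directions and average
    batch_size_inference=10,
    indel_mode=False,        # set to True when scoring the indel benchmark
)
print(all_scores[["mutant", "avg_score"]].head())
```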
38
+ ## ProteinGym
39
+ ProteinGym is an extensive set of Deep Mutational Scanning (DMS) assays curated to enable thorough comparisons of various mutation effect predictors in different regimes. ProteinGym comprises two benchmarks: 1) a substitution benchmark, which consists of the experimental characterisation of ∼1.5M missense variants across 87 DMS assays, and 2) an indel benchmark that includes ∼300k mutants across 7 DMS assays.
40
+
41
+ Each processed file in each benchmark corresponds to a single DMS assay, and contains the following three variables:
42
+ - mutant (str):
43
+ - for the substitution benchmark, it describes the set of substitutions to apply to the reference sequence to obtain the mutated sequence (e.g., A1P:D2N implies the amino acid 'A' at position 1 should be replaced by 'P', and 'D' at position 2 should be replaced by 'N'; see the short sketch after this list)
44
+ - for the indel benchmark, it corresponds to the full mutated sequence
45
+ - DMS_score (float): corresponds to the experimental measurement in the DMS assay. Across all assays, the higher the DMS_score value, the higher the fitness of the mutated protein
46
+ - DMS_score_bin (int): indicates whether the DMS_score is above the fitness cutoff (1 is fit, 0 is not fit)
47
+
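To make the substitution format concrete, here is a small illustrative helper (not part of the repository; `scoring_utils.get_mutated_sequence` plays this role internally) that applies a mutant string such as `A1P:D2N` to a reference sequence:
```
def apply_mutant(target_seq, mutant):
    """Apply a colon-separated set of substitutions (e.g. 'A1P:D2N') to target_seq."""
    seq = list(target_seq)
    for sub in mutant.split(":"):
        wt, pos, mut = sub[0], int(sub[1:-1]), sub[-1]
        assert seq[pos - 1] == wt, f"Expected {wt} at position {pos}, found {seq[pos - 1]}"
        seq[pos - 1] = mut
    return "".join(seq)

assert apply_mutant("ADGK", "A1P:D2N") == "PNGK"
```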
48
+ Additionally, we provide reference files in the [ProteinGym folder](https://github.com/OATML-Markslab/Tranception/tree/main/ProteinGym) that give further details on each assay, including in particular:
49
+ - The UniProt_ID of the corresponding protein, along with taxon and MSA depth category
50
+ - The target sequence (target_seq) used in the assay
51
+ - Details on how the DMS_score was created from the raw files and how it was binarized
52
+
53
+ To download the substitution benchmark (~224M unzipped):
54
+ ```
55
+ curl -o ProteinGym_substitutions.zip https://marks.hms.harvard.edu/ProteinGym/ProteinGym_substitutions.zip
56
+ unzip ProteinGym_substitutions.zip
57
+ rm ProteinGym_substitutions.zip
58
+ ```
59
+
60
+ Similarly, to download the indel benchmark (~86M unzipped):
61
+ ```
62
+ curl -o ProteinGym_indels.zip https://marks.hms.harvard.edu/ProteinGym/ProteinGym_indels.zip
63
+ unzip ProteinGym_indels.zip
64
+ rm ProteinGym_indels.zip
65
+ ```
66
+
67
+ ## Fitness prediction performance
68
+
69
+ The [proteingym folder](https://github.com/OATML-Markslab/Tranception/tree/main/ProteinGym) provides detailed performance files for Tranception and baselines on the two ProteinGym benchmarks.
70
+
71
+ We recommend aggregating fitness prediction performance at the UniProt ID level to avoid biasing results towards proteins for which several DMS assays are available in ProteinGym. The corresponding aggregated files are suffixed with "_Uniprot_level", while the non-aggregated performance files are suffixed with "_DMS_level".
72
+ Furthermore, to enable fair comparison with models trained on multiple sequence alignments (e.g., EVE, DeepSequence, EVmutation), we only evaluate on the subset of mutations where position coverage is deemed high enough by these models to make a prediction. The corresponding files are prefixed with "All_models_". For comprehensiveness, we also provide performance files on all possible mutants available in ProteinGym, comparing only with the baselines that are able to score all mutants.
73
+ Note that for the ProteinGym indel benchmark, baselines that are able to score indels do not have the aforementioned coverage constraints (i.e., no distinction between "All_models_" and "All_mutants_") and there is at most one DMS per Uniprot_ID (i.e., no difference between "_Uniprot_level" and "_DMS_level"). We thus only provide one set of performance metrics for that benchmark.
74
+
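A minimal sketch of that aggregation logic, assuming a concatenated score file with `UniProt_ID`, `DMS_id`, `DMS_score` and `model_score` columns (the column names and file name are assumptions, not the exact ProteinGym headers):
```
import pandas as pd
from scipy.stats import spearmanr

scores = pd.read_csv("all_model_scores.csv")  # hypothetical concatenated file, one row per mutant

# Spearman per DMS assay ("_DMS_level")
per_dms = scores.groupby(["UniProt_ID", "DMS_id"]).apply(
    lambda df: spearmanr(df["model_score"], df["DMS_score"]).correlation
)

# Average assays of the same protein first ("_Uniprot_level"), then average across proteins
per_uniprot = per_dms.groupby(level="UniProt_ID").mean()
overall = per_uniprot.mean()
```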
75
+ ### ProteinGym substitution benchmark - Leaderboard
76
+ The table below provides the average Spearman's rank correlation between DMS experimental fitness measurements and fitness predictions from Tranception or other baselines on the ProteinGym substitution benchmark. Following the terminology introduced above, we report the performance at the "Uniprot" level for "All models".
77
+
78
+ Rank | Model name | Spearman | Reference
79
+ --- | --- | --- | --- |
80
+ 1 | Ensemble Tranception & EVE | 0.476 | [Notin et al.](https://arxiv.org/abs/2205.13760)
81
+ 2 | Tranception (w/ retrieval) | 0.451 | [Notin et al.](https://arxiv.org/abs/2205.13760)
82
+ 3 | EVE | 0.448 | [Frazer et al.](https://www.nature.com/articles/s41586-021-04043-8)
83
+ 4 | EVmutation | 0.427 | [Hopf et al.](https://www.nature.com/articles/nbt.3769)
84
+ 5 | MSA Transformer | 0.422 | [Rao et al.](https://proceedings.mlr.press/v139/rao21a.html)
85
+ 6 | DeepSequence | 0.415 | [Riesselman et al.](https://www.nature.com/articles/s41592-018-0138-4)
86
+ 7 | Tranception (no retrieval) | 0.406 | [Notin et al.](https://arxiv.org/abs/2205.13760)
87
+ 8 | Wavenet | 0.398 | [Shin et al.](https://www.nature.com/articles/s41467-021-22732-w)
88
+ 9 | Site Independent | 0.397 | [Hopf et al.](https://www.nature.com/articles/nbt.3769)
89
+ 10 | ESM-1v | 0.371 | [Meier et al.](https://proceedings.neurips.cc/paper/2021/hash/f51338d736f95dd42427296047067694-Abstract.html)
90
+
91
+ ### ProteinGym indel benchmark - Leaderboard
92
+ The table below provides the average Spearman's rank correlation between DMS experimental fitness measurements and fitness predictions from Tranception or other baselines on the ProteinGym indel benchmark.
93
+
94
+ Rank | Model name | Spearman | Reference
95
+ --- | --- | --- | --- |
96
+ [Notin et al.](https://arxiv.org/abs/2205.13760)
97
+ [Notin et al.](https://arxiv.org/abs/2205.13760)
98
+ [Shin et al.](https://www.nature.com/articles/s41467-021-22732-w)
99
+
100
+ ## Aggregated model scoring files
101
+ The scores for all DMS assays in the ProteinGym substitution benchmark for Tranception and other baselines (e.g., EVE, Wavenet, ESM-1v, MSA Transformer) may be downloaded as follows:
102
+ ```
103
+ curl -o scores_all_models_proteingym_substitutions.zip https://marks.hms.harvard.edu/ProteinGym/scores_all_models_proteingym_substitutions.zip
104
+ unzip scores_all_models_proteingym_substitutions.zip
105
+ rm scores_all_models_proteingym_substitutions.zip
106
+ ```
107
+ Similarly for the indel benchmark, all scoring files may be downloaded as follows:
108
+ ```
109
+ curl -o scores_all_models_proteingym_indels.zip https://marks.hms.harvard.edu/ProteinGym/scores_all_models_proteingym_indels.zip
110
+ unzip scores_all_models_proteingym_indels.zip
111
+ rm scores_all_models_proteingym_indels.zip
112
+ ```
113
+
114
+ ## Multiple Sequence Alignments (MSAs)
115
+
116
+ The MSAs used to train alignment-based methods or used at inference in Tranception with retrieval and MSA Transformer may be downloaded as follows (~2.2GB unzipped):
117
+ ```
118
+ curl -o MSA_ProteinGym.zip https://marks.hms.harvard.edu/ProteinGym/MSA_ProteinGym.zip
119
+ unzip MSA_ProteinGym.zip
120
+ rm MSA_ProteinGym.zip
121
+ ```
122
+
123
+ ## License
124
+ This project is available under the MIT license.
125
+
126
+ ## Reference
127
+ If you use Tranception, ProteinGym, or other files provided through this repository (e.g., aggregated model scoring files) in your work, please cite the following paper:
128
+ ```
129
+ Notin, P., Dias, M., Frazer, J., Marchena-Hurtado, J., Gomez, A., Marks, D.S., Gal, Y. (2022). Tranception: Protein Fitness Prediction with Autoregressive Transformers and Inference-time Retrieval. ICML.
130
+ ```
131
+
132
+ ## Links
133
+ Pre-print: https://arxiv.org/abs/2205.13760
__init__.py ADDED
@@ -0,0 +1 @@
1
+ from . import config
activations.py ADDED
@@ -0,0 +1,114 @@
1
+ import math
2
+
3
+ import torch
4
+ from packaging import version
5
+ from torch import nn
6
+
7
+ from transformers.utils import logging
8
+
9
+
10
+ logger = logging.get_logger(__name__)
11
+
12
+
13
+ def _gelu_python(x):
14
+ """
15
+ Original Implementation of the GELU activation function in Google BERT repo when initially created. For
16
+ information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
17
+ torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
18
+ Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
19
+ """
20
+ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
21
+
22
+
23
+ def gelu_new(x):
24
+ """
25
+ Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
26
+ the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
27
+ """
28
+ return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
29
+
30
+
31
+ if version.parse(torch.__version__) < version.parse("1.4"):
32
+ gelu = _gelu_python
33
+ else:
34
+ gelu = nn.functional.gelu
35
+
36
+
37
+ def gelu_fast(x):
38
+ return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
39
+
40
+
41
+ def quick_gelu(x):
42
+ return x * torch.sigmoid(1.702 * x)
43
+
44
+
45
+ def _silu_python(x):
46
+ """
47
+ See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
48
+ Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
49
+ Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
50
+ Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
51
+ later.
52
+ """
53
+ return x * torch.sigmoid(x)
54
+
55
+
56
+ if version.parse(torch.__version__) < version.parse("1.7"):
57
+ silu = _silu_python
58
+ else:
59
+ silu = nn.functional.silu
60
+
61
+
62
+ def _mish_python(x):
63
+ """
64
+ See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
65
+ visit the official repository for the paper: https://github.com/digantamisra98/Mish
66
+ """
67
+ return x * torch.tanh(nn.functional.softplus(x))
68
+
69
+
70
+ if version.parse(torch.__version__) < version.parse("1.9"):
71
+ mish = _mish_python
72
+ else:
73
+ mish = nn.functional.mish
74
+
75
+
76
+ def linear_act(x):
77
+ return x
78
+
79
+ def squared_relu(x):
80
+ """
81
+ Squared ReLU variant that is fastest with PyTorch.
82
+ """
83
+ x = nn.functional.relu(x)
84
+ return x*x
85
+
86
+ def squared_relu_xla(x):
87
+ """
88
+ Squared ReLU variant that is fastest with JAX.
89
+ """
90
+ x = nn.functional.relu(x)
91
+ return x**2
92
+
93
+ tranception_ACT2FN = {
94
+ "relu": nn.functional.relu,
95
+ "silu": silu,
96
+ "swish": silu,
97
+ "gelu": gelu,
98
+ "tanh": torch.tanh,
99
+ "gelu_new": gelu_new,
100
+ "gelu_fast": gelu_fast,
101
+ "quick_gelu": quick_gelu,
102
+ "mish": mish,
103
+ "linear": linear_act,
104
+ "sigmoid": torch.sigmoid,
105
+ "squared_relu": squared_relu,
106
+ "squared_relu_xla": squared_relu_xla,
107
+ }
108
+
109
+
110
+ def get_activation(activation_string):
111
+ if activation_string in tranception_ACT2FN:
112
+ return tranception_ACT2FN[activation_string]
113
+ else:
114
+ raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(tranception_ACT2FN.keys())}")
config.json ADDED
@@ -0,0 +1,46 @@
1
+ {
2
+ "MSA_end": null,
3
+ "MSA_filename": null,
4
+ "MSA_start": null,
5
+ "MSA_weight_file_name": null,
6
+ "_name_or_path": "Tranception_Large",
7
+ "activation_function": "squared_relu",
8
+ "architectures": [
9
+ "TranceptionLMHeadModel"
10
+ ],
11
+ "attention_mode": "tranception",
12
+ "attn_pdrop": 0.1,
13
+ "bos_token_id": 1,
14
+ "clustal_omega_location": null,
15
+ "embd_pdrop": 0.1,
16
+ "eos_token_id": 2,
17
+ "full_protein_length": null,
18
+ "initializer_range": 0.02,
19
+ "layer_norm_epsilon": 1e-05,
20
+ "local_batch_size": 1,
21
+ "model_type": "tranception",
22
+ "n_ctx": 1024,
23
+ "n_embd": 1280,
24
+ "n_head": 20,
25
+ "n_inner": 5120,
26
+ "n_layer": 36,
27
+ "n_positions": 1024,
28
+ "position_embedding": "grouped_alibi",
29
+ "reorder_and_upcast_attn": false,
30
+ "resid_pdrop": 0.1,
31
+ "retrieval_aggregation_mode": null,
32
+ "retrieval_inference_weight": 0.6,
33
+ "scale_attn_by_inverse_layer_idx": false,
34
+ "scale_attn_weights": true,
35
+ "scoring_window": "optimal",
36
+ "summary_activation": null,
37
+ "summary_first_dropout": 0.1,
38
+ "summary_proj_to_labels": true,
39
+ "summary_type": "cls_index",
40
+ "summary_use_proj": true,
41
+ "tokenizer": null,
42
+ "torch_dtype": "float32",
43
+ "transformers_version": "4.17.0",
44
+ "use_cache": true,
45
+ "vocab_size": 25
46
+ }
config.py ADDED
@@ -0,0 +1,36 @@
1
+ from transformers import GPT2Config
2
+
3
+ class TranceptionConfig(GPT2Config):
4
+ """
5
+ Config subclass for Tranception model architecture.
6
+ """
7
+ def __init__(
8
+ self,
9
+ attention_mode="tranception",
10
+ position_embedding="grouped_alibi",
11
+ tokenizer=None,
12
+ retrieval_aggregation_mode=None,
13
+ retrieval_inference_weight=0.6,
14
+ MSA_filename=None,
15
+ MSA_weight_file_name=None,
16
+ MSA_start=None,
17
+ MSA_end=None,
18
+ full_protein_length=None,
19
+ clustal_omega_location=None,
20
+ scoring_window=None,
21
+ **kwargs
22
+ ):
23
+ super().__init__(**kwargs)
24
+ self.model_type="tranception"
25
+ self.attention_mode=attention_mode
26
+ self.position_embedding=position_embedding
27
+ self.tokenizer = tokenizer
28
+ self.retrieval_aggregation_mode = retrieval_aggregation_mode
29
+ self.retrieval_inference_weight = retrieval_inference_weight
30
+ self.MSA_filename = MSA_filename
31
+ self.MSA_weight_file_name = MSA_weight_file_name
32
+ self.MSA_start=MSA_start
33
+ self.MSA_end=MSA_end
34
+ self.full_protein_length = full_protein_length
35
+ self.clustal_omega_location = clustal_omega_location
36
+ self.scoring_window=scoring_window
model_pytorch.py ADDED
@@ -0,0 +1,917 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Tuple
3
+ import math
4
+ import os
5
+ import pandas as pd
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import CrossEntropyLoss, NLLLoss
10
+ import torch.nn.functional as F
11
+ from transformers import GPT2PreTrainedModel
12
+
13
+ from transformers.modeling_utils import (
14
+ Conv1D,
15
+ PreTrainedModel,
16
+ SequenceSummary,
17
+ find_pruneable_heads_and_indices,
18
+ prune_conv1d_layer,
19
+ )
20
+ from transformers.file_utils import (
21
+ ModelOutput,
22
+ add_code_sample_docstrings,
23
+ add_start_docstrings,
24
+ add_start_docstrings_to_model_forward,
25
+ replace_return_docstrings
26
+ )
27
+ from transformers.modeling_outputs import (
28
+ BaseModelOutputWithPastAndCrossAttentions,
29
+ CausalLMOutputWithCrossAttentions,
30
+ SequenceClassifierOutputWithPast,
31
+ TokenClassifierOutput
32
+ )
33
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
34
+
35
+ from tranception.activations import tranception_ACT2FN
36
+ from tranception.config import TranceptionConfig
37
+ from tranception.outputs import (
38
+ TranceptionCausalLMOutputWithCrossAttentions,
39
+ )
40
+ from tranception.utils import msa_utils
41
+ from tranception.utils import scoring_utils
+ from transformers.utils import logging
+
+ logger = logging.get_logger(__name__)  # used by the gradient checkpointing warning in TranceptionModel.forward
42
+
43
+ def nanmean(v, *args, inplace=False, **kwargs):
44
+ if not inplace:
45
+ v = v.clone()
46
+ is_nan = torch.isnan(v)
47
+ v[is_nan] = 0
48
+ return v.sum(*args, **kwargs) / (~is_nan).float().sum(*args, **kwargs)
49
+
50
+ def get_slopes(n, mode="standard_alibi", verbose=False):
51
+ """
52
+ Function to compute the m constant for each attention head. Code has been adapted from the official ALiBi codebase at:
53
+ https://github.com/ofirpress/attention_with_linear_biases/blob/master/fairseq/models/transformer.py
54
+ """
55
+ def get_slopes_power_of_2(n):
56
+ start = (2**(-2**-(math.log2(n)-3)))
57
+ ratio = start
58
+ return [start*ratio**i for i in range(n)]
59
+ if mode=="grouped_alibi":
60
+ n = n // 4
61
+ if math.log2(n).is_integer():
62
+ result = get_slopes_power_of_2(n)
63
+ else:
64
+ #Workaround when the number of heads is not a power of 2
65
+ closest_power_of_2 = 2**math.floor(math.log2(n))
66
+ result = get_slopes_power_of_2(closest_power_of_2) + get_slopes(2*closest_power_of_2)[0::2][:n-closest_power_of_2]
67
+ if mode=="grouped_alibi":
68
+ result = result * 4
69
+ if verbose:
70
+ print("ALiBi slopes: {}".format(result))
71
+ return result
72
+
73
+ class SpatialDepthWiseConvolution(nn.Module):
74
+ def __init__(self, head_dim: int, kernel_size: int = 3):
75
+ super().__init__()
76
+ self.kernel_size = kernel_size
77
+ self.conv = nn.Conv1d(in_channels=head_dim, out_channels=head_dim, kernel_size=(kernel_size,), padding=(kernel_size - 1,), groups=head_dim)
78
+
79
+ def forward(self, x: torch.Tensor):
80
+ batch_size, heads, seq_len, head_dim = x.shape
81
+ x = x.permute(0, 1, 3, 2).contiguous()
82
+ x = x.view(batch_size * heads, head_dim, seq_len)
83
+ x = self.conv(x)
84
+ if self.kernel_size>1:
85
+ x = x[:, :, :-(self.kernel_size - 1)]
86
+ x = x.view(batch_size, heads, head_dim, seq_len)
87
+ x = x.permute(0, 1, 3, 2)
88
+ return x
89
+
90
+ class TranceptionBlockAttention(nn.Module):
91
+ def __init__(self, config, is_cross_attention=False, SDWC_kernel_size=None):
92
+ super().__init__()
93
+
94
+ max_positions = config.max_position_embeddings
95
+ self.register_buffer(
96
+ "bias",
97
+ torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
98
+ 1, 1, max_positions, max_positions
99
+ ),
100
+ )
101
+ self.register_buffer("masked_bias", torch.tensor(-1e4))
102
+
103
+ self.embed_dim = config.hidden_size
104
+ self.num_heads = config.num_attention_heads
105
+ self.head_dim = self.embed_dim // self.num_heads
106
+ self.split_size = self.embed_dim
107
+ if self.head_dim * self.num_heads != self.embed_dim:
108
+ raise ValueError(
109
+ f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
110
+ )
111
+
112
+ self.scale_attn_weights = config.scale_attn_weights
113
+ self.is_cross_attention = is_cross_attention
114
+
115
+ if self.is_cross_attention:
116
+ self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
117
+ self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
118
+ else:
119
+ self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
120
+ self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
121
+
122
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
123
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
124
+
125
+ self.pruned_heads = set()
126
+
127
+ self.attention_mode=config.attention_mode
128
+
129
+ if self.attention_mode=="tranception":
130
+ assert self.num_heads%4==0, "Invalid number of heads. Tranception requires the number of heads to be a multiple of 4."
131
+ self.num_heads_per_kernel_size = self.num_heads // 4
132
+ self.query_depthwiseconv = nn.ModuleDict()
133
+ self.key_depthwiseconv = nn.ModuleDict()
134
+ self.value_depthwiseconv = nn.ModuleDict()
135
+ for kernel_idx, kernel in enumerate([3,5,7]):
136
+ self.query_depthwiseconv[str(kernel_idx)] = SpatialDepthWiseConvolution(self.head_dim,kernel)
137
+ self.key_depthwiseconv[str(kernel_idx)] = SpatialDepthWiseConvolution(self.head_dim,kernel)
138
+ self.value_depthwiseconv[str(kernel_idx)] = SpatialDepthWiseConvolution(self.head_dim,kernel)
139
+
140
+ def prune_heads(self, heads):
141
+ if len(heads) == 0:
142
+ return
143
+ heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
144
+ index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
145
+
146
+ # Prune conv1d layers
147
+ self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
148
+ self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
149
+
150
+ # Update hyper params
151
+ self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
152
+ self.num_heads = self.num_heads - len(heads)
153
+ self.pruned_heads = self.pruned_heads.union(heads)
154
+
155
+ def _attn(self, query, key, value, attention_mask=None, head_mask=None, alibi_bias=None):
156
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
157
+
158
+ if self.scale_attn_weights:
159
+ attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)
160
+
161
+ if not self.is_cross_attention:
162
+ # if only "normal" attention layer implements causal mask
163
+ query_length, key_length = query.size(-2), key.size(-2)
164
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
165
+ attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
166
+
167
+ if alibi_bias is not None:
168
+ attn_weights = attn_weights + alibi_bias[:,:,:attn_weights.size(-1)]
169
+
170
+ if attention_mask is not None:
171
+ # Apply the attention mask
172
+ attn_weights = attn_weights + attention_mask
173
+
174
+ attn_weights = nn.Softmax(dim=-1)(attn_weights)
175
+ attn_weights = self.attn_dropout(attn_weights)
176
+
177
+ # Mask heads if we want to
178
+ if head_mask is not None:
179
+ attn_weights = attn_weights * head_mask
180
+
181
+ attn_output = torch.matmul(attn_weights, value)
182
+
183
+ return attn_output, attn_weights
184
+
185
+ def _split_heads(self, tensor, num_heads, attn_head_size):
186
+ """
187
+ Splits hidden_size dim into attn_head_size and num_heads
188
+ """
189
+ new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
190
+ tensor = tensor.view(*new_shape)
191
+ return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
192
+
193
+ def _merge_heads(self, tensor, num_heads, attn_head_size):
194
+ """
195
+ Merges attn_head_size dim and num_attn_heads dim into hidden_size
196
+ """
197
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
198
+ new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
199
+ return tensor.view(new_shape)
200
+
201
+ def forward(
202
+ self,
203
+ hidden_states,
204
+ layer_past=None,
205
+ attention_mask=None,
206
+ head_mask=None,
207
+ encoder_hidden_states=None,
208
+ encoder_attention_mask=None,
209
+ use_cache=False,
210
+ output_attentions=False,
211
+ alibi_bias=None,
212
+ ):
213
+ if encoder_hidden_states is not None:
214
+ if not hasattr(self, "q_attn"):
215
+ raise ValueError(
216
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
217
+ "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
218
+ )
219
+
220
+ query = self.q_attn(hidden_states)
221
+ key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
222
+ attention_mask = encoder_attention_mask
223
+ else:
224
+ query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
225
+
226
+ query = self._split_heads(query, self.num_heads, self.head_dim)
227
+ key = self._split_heads(key, self.num_heads, self.head_dim)
228
+ value = self._split_heads(value, self.num_heads, self.head_dim)
229
+
230
+ if layer_past is not None:
231
+ past_key, past_value = layer_past
232
+ key = torch.cat((past_key, key), dim=-2)
233
+ value = torch.cat((past_value, value), dim=-2)
234
+
235
+ if use_cache is True:
236
+ present = (key, value)
237
+ else:
238
+ present = None
239
+
240
+ if self.attention_mode=="tranception":
241
+ # We do not do anything on the first self.num_heads_per_kernel_size heads (kernel =1)
242
+ query_list=[query[:,:self.num_heads_per_kernel_size,:,:]]
243
+ key_list=[key[:,:self.num_heads_per_kernel_size,:,:]]
244
+ value_list=[value[:,:self.num_heads_per_kernel_size,:,:]]
245
+ for kernel_idx in range(3):
246
+ query_list.append(self.query_depthwiseconv[str(kernel_idx)](query[:,(kernel_idx+1)*self.num_heads_per_kernel_size:(kernel_idx+2)*self.num_heads_per_kernel_size,:,:]))
247
+ key_list.append(self.key_depthwiseconv[str(kernel_idx)](key[:,(kernel_idx+1)*self.num_heads_per_kernel_size:(kernel_idx+2)*self.num_heads_per_kernel_size,:,:]))
248
+ value_list.append(self.value_depthwiseconv[str(kernel_idx)](value[:,(kernel_idx+1)*self.num_heads_per_kernel_size:(kernel_idx+2)*self.num_heads_per_kernel_size,:,:]))
249
+ query=torch.cat(query_list, dim=1)
250
+ key=torch.cat(key_list, dim=1)
251
+ value=torch.cat(value_list, dim=1)
252
+
253
+ attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask, alibi_bias=alibi_bias)
254
+
255
+ attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
256
+ attn_output = self.c_proj(attn_output)
257
+ attn_output = self.resid_dropout(attn_output)
258
+
259
+ outputs = (attn_output, present)
260
+ if output_attentions:
261
+ outputs += (attn_weights,)
262
+
263
+ return outputs # a, present, (attentions)
264
+
265
+ class TranceptionBlockMLP(nn.Module):
266
+ def __init__(self, intermediate_size, config):
267
+ super().__init__()
268
+ embed_dim = config.hidden_size
269
+ self.c_fc = Conv1D(intermediate_size, embed_dim)
270
+ self.c_proj = Conv1D(embed_dim, intermediate_size)
271
+ self.act = tranception_ACT2FN[config.activation_function]
272
+ self.dropout = nn.Dropout(config.resid_pdrop)
273
+
274
+ def forward(self, hidden_states):
275
+ hidden_states = self.c_fc(hidden_states)
276
+ hidden_states = self.act(hidden_states)
277
+ hidden_states = self.c_proj(hidden_states)
278
+ hidden_states = self.dropout(hidden_states)
279
+ return hidden_states
280
+
281
+ class TranceptionBlock(nn.Module):
282
+ def __init__(self, config, SDWC_kernel_size=None):
283
+ super().__init__()
284
+ hidden_size = config.hidden_size
285
+ inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
286
+
287
+ self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
288
+ self.attn = TranceptionBlockAttention(config, SDWC_kernel_size=SDWC_kernel_size)
289
+ self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
290
+
291
+ if config.add_cross_attention:
292
+ self.crossattention = TranceptionBlockAttention(config, is_cross_attention=True, SDWC_kernel_size=SDWC_kernel_size)
293
+ self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
294
+
295
+ self.mlp = TranceptionBlockMLP(inner_dim, config)
296
+
297
+ def forward(
298
+ self,
299
+ hidden_states,
300
+ layer_past=None,
301
+ attention_mask=None,
302
+ head_mask=None,
303
+ encoder_hidden_states=None,
304
+ encoder_attention_mask=None,
305
+ use_cache=False,
306
+ output_attentions=False,
307
+ alibi_bias=None,
308
+ ):
309
+ residual = hidden_states
310
+ hidden_states = self.ln_1(hidden_states)
311
+ attn_outputs = self.attn(
312
+ hidden_states,
313
+ layer_past=layer_past,
314
+ attention_mask=attention_mask,
315
+ head_mask=head_mask,
316
+ use_cache=use_cache,
317
+ output_attentions=output_attentions,
318
+ alibi_bias=alibi_bias,
319
+ )
320
+ attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
321
+ outputs = attn_outputs[1:]
322
+ # residual connection
323
+ hidden_states = attn_output + residual
324
+
325
+ if encoder_hidden_states is not None:
326
+ # add one self-attention block for cross-attention
327
+ if not hasattr(self, "crossattention"):
328
+ raise ValueError(
329
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
330
+ "cross-attention layers by setting `config.add_cross_attention=True`"
331
+ )
332
+ residual = hidden_states
333
+ hidden_states = self.ln_cross_attn(hidden_states)
334
+ cross_attn_outputs = self.crossattention(
335
+ hidden_states,
336
+ attention_mask=attention_mask,
337
+ head_mask=head_mask,
338
+ encoder_hidden_states=encoder_hidden_states,
339
+ encoder_attention_mask=encoder_attention_mask,
340
+ output_attentions=output_attentions,
341
+ )
342
+ attn_output = cross_attn_outputs[0]
343
+ # residual connection
344
+ hidden_states = residual + attn_output
345
+ outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
346
+
347
+ residual = hidden_states
348
+ hidden_states = self.ln_2(hidden_states)
349
+
350
+ feed_forward_hidden_states = self.mlp(hidden_states)
351
+
352
+ # residual connection
353
+ hidden_states = residual + feed_forward_hidden_states
354
+
355
+ if use_cache:
356
+ outputs = (hidden_states,) + outputs
357
+ else:
358
+ outputs = (hidden_states,) + outputs[1:]
359
+
360
+ return outputs # hidden_states, present, (attentions, cross_attentions)
361
+
362
+ class TranceptionModel(GPT2PreTrainedModel):
363
+ _keys_to_ignore_on_load_missing = ["attn.masked_bias"]
364
+ def __init__(self, config):
365
+ super().__init__(config)
366
+
367
+ self.embed_dim = config.hidden_size
368
+ self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
369
+ self.position_embedding = config.position_embedding if hasattr(config, "position_embedding") else "learned"
370
+ if self.position_embedding=="learned":
371
+ self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
372
+ self.alibi = None
373
+ elif self.position_embedding=="grouped_alibi":
374
+ maxpos = config.n_positions
375
+ attn_heads = config.n_head
376
+ self.slopes = torch.Tensor(get_slopes(attn_heads, mode=self.position_embedding))
377
+ #The softmax operation is invariant to translation, and bias functions used are always linear.
378
+ alibi = self.slopes.unsqueeze(1).unsqueeze(1) * torch.arange(maxpos).unsqueeze(0).unsqueeze(0).expand(attn_heads, -1, -1)
379
+ alibi = alibi.view(attn_heads, 1, maxpos)
380
+ self.register_buffer('alibi',alibi)
381
+
382
+ self.drop = nn.Dropout(config.embd_pdrop)
383
+ self.h = nn.ModuleList([TranceptionBlock(config) for _ in range(config.num_hidden_layers)])
384
+ self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
385
+
386
+ self.init_weights()
387
+
388
+ # Model parallel
389
+ self.model_parallel = False
390
+ self.device_map = None
391
+ self.gradient_checkpointing = False
392
+
393
+ def parallelize(self, device_map=None, num_cores=None):
394
+ self.device_map = (
395
+ get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
396
+ )
397
+ device_prefix="cuda:"
398
+ assert_device_map(self.device_map, len(self.h))
399
+ self.model_parallel = True
400
+ self.first_device = "cpu" if "cpu" in self.device_map.keys() else device_prefix + str(min(self.device_map.keys()))
401
+ self.last_device = device_prefix + str(max(self.device_map.keys()))
402
+ self.wte = self.wte.to(self.first_device)
403
+ if self.position_embedding=="learned":
404
+ self.wpe = self.wpe.to(self.first_device)
405
+ for k, v in self.device_map.items():
406
+ print("k,v :"+str(k)+","+str(v))
407
+ for block in v:
408
+ cuda_device = device_prefix + str(k)
409
+ self.h[block] = self.h[block].to(cuda_device)
410
+ self.ln_f = self.ln_f.to(self.last_device)
411
+
412
+ def deparallelize(self):
413
+ self.model_parallel = False
414
+ self.device_map = None
415
+ self.first_device = "cpu"
416
+ self.last_device = "cpu"
417
+ self.wte = self.wte.to("cpu")
418
+ if self.position_embedding=="learned":
419
+ self.wpe = self.wpe.to("cpu")
420
+ for index in range(len(self.h)):
421
+ self.h[index] = self.h[index].to("cpu")
422
+ self.ln_f = self.ln_f.to("cpu")
423
+ torch.cuda.empty_cache()
424
+
425
+ def get_input_embeddings(self):
426
+ return self.wte
427
+
428
+ def set_input_embeddings(self, new_embeddings):
429
+ self.wte = new_embeddings
430
+
431
+ def _prune_heads(self, heads_to_prune):
432
+ """
433
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
434
+ """
435
+ for layer, heads in heads_to_prune.items():
436
+ self.h[layer].attn.prune_heads(heads)
437
+
438
+ def forward(
439
+ self,
440
+ input_ids=None,
441
+ past_key_values=None,
442
+ attention_mask=None,
443
+ token_type_ids=None,
444
+ position_ids=None,
445
+ head_mask=None,
446
+ inputs_embeds=None,
447
+ encoder_hidden_states=None,
448
+ encoder_attention_mask=None,
449
+ use_cache=None,
450
+ output_attentions=None,
451
+ output_hidden_states=None,
452
+ return_dict=None,
453
+ ):
454
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
455
+ output_hidden_states = (
456
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
457
+ )
458
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
459
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
460
+
461
+ if input_ids is not None and inputs_embeds is not None:
462
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
463
+ elif input_ids is not None:
464
+ input_shape = input_ids.size()
465
+ input_ids = input_ids.view(-1, input_shape[-1])
466
+ batch_size = input_ids.shape[0]
467
+ elif inputs_embeds is not None:
468
+ input_shape = inputs_embeds.size()[:-1]
469
+ batch_size = inputs_embeds.shape[0]
470
+ else:
471
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
472
+
473
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
474
+
475
+ if token_type_ids is not None:
476
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
477
+ if position_ids is not None:
478
+ position_ids = position_ids.view(-1, input_shape[-1])
479
+
480
+ if past_key_values is None:
481
+ past_length = 0
482
+ past_key_values = tuple([None] * len(self.h))
483
+ else:
484
+ past_length = past_key_values[0][0].size(-2)
485
+ if position_ids is None:
486
+ position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
487
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
488
+
489
+ # GPT2Attention mask.
490
+ if attention_mask is not None:
491
+ if batch_size <= 0:
492
+ raise ValueError("batch_size has to be defined and > 0")
493
+ attention_mask = attention_mask.view(batch_size, -1)
494
+ # We create a 3D attention mask from a 2D tensor mask.
495
+ # Sizes are [batch_size, 1, 1, to_seq_length]
496
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
497
+ # this attention mask is more simple than the triangular masking of causal attention
498
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
499
+ attention_mask = attention_mask[:, None, None, :]
500
+
501
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
502
+ # masked positions, this operation will create a tensor which is 0.0 for
503
+ # positions we want to attend and -10000.0 for masked positions.
504
+ # Since we are adding it to the raw scores before the softmax, this is
505
+ # effectively the same as removing these entirely.
506
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
507
+ attention_mask = (1.0 - attention_mask) * -10000.0
508
+
509
+ # If a 2D or 3D attention mask is provided for the cross-attention
510
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
511
+ if self.config.add_cross_attention and encoder_hidden_states is not None:
512
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
513
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
514
+ if encoder_attention_mask is None:
515
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
516
+ encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
517
+ else:
518
+ encoder_attention_mask = None
519
+
520
+ # Prepare head mask if needed
521
+ # 1.0 in head_mask indicate we keep the head
522
+ # attention_probs has shape bsz x n_heads x N x N
523
+ # head_mask has shape n_layer x batch x n_heads x N x N
524
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
525
+
526
+ if inputs_embeds is None:
527
+ inputs_embeds = self.wte(input_ids)
528
+ if self.position_embedding=="learned":
529
+ position_embeds = self.wpe(position_ids)
530
+ hidden_states = inputs_embeds + position_embeds
531
+ else:
532
+ hidden_states = inputs_embeds
533
+
534
+ if token_type_ids is not None:
535
+ token_type_embeds = self.wte(token_type_ids)
536
+ hidden_states = hidden_states + token_type_embeds
537
+
538
+ hidden_states = self.drop(hidden_states)
539
+
540
+ output_shape = input_shape + (hidden_states.size(-1),)
541
+
542
+ presents = () if use_cache else None
543
+ all_self_attentions = () if output_attentions else None
544
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
545
+ all_hidden_states = () if output_hidden_states else None
546
+
547
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
548
+ # Model parallel
549
+ if self.model_parallel:
550
+ torch.cuda.set_device(hidden_states.device)
551
+ # Ensure layer_past is on same device as hidden_states (might not be correct)
552
+ if layer_past is not None:
553
+ layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
554
+ # Ensure that attention_mask is always on the same device as hidden_states
555
+ if attention_mask is not None:
556
+ attention_mask = attention_mask.to(hidden_states.device)
557
+ if isinstance(head_mask, torch.Tensor):
558
+ head_mask = head_mask.to(hidden_states.device)
559
+ if output_hidden_states:
560
+ all_hidden_states = all_hidden_states + (hidden_states,)
561
+
562
+ if self.gradient_checkpointing and self.training:
563
+ if use_cache:
564
+ logger.warning(
565
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
566
+ )
567
+ use_cache = False
568
+
569
+ def create_custom_forward(module):
570
+ def custom_forward(*inputs):
571
+ # None for past_key_value
572
+ return module(*inputs, use_cache, output_attentions)
573
+
574
+ return custom_forward
575
+
576
+ outputs = torch.utils.checkpoint.checkpoint(
577
+ create_custom_forward(block),
578
+ hidden_states,
579
+ None,
580
+ attention_mask,
581
+ head_mask[i],
582
+ encoder_hidden_states,
583
+ encoder_attention_mask,
584
+ )
585
+ else:
586
+ outputs = block(
587
+ hidden_states,
588
+ layer_past=layer_past,
589
+ attention_mask=attention_mask,
590
+ head_mask=head_mask[i],
591
+ encoder_hidden_states=encoder_hidden_states,
592
+ encoder_attention_mask=encoder_attention_mask,
593
+ use_cache=use_cache,
594
+ output_attentions=output_attentions,
595
+ alibi_bias=self.alibi if hasattr(self, "alibi") else None
596
+ )
597
+
598
+ hidden_states = outputs[0]
599
+
600
+ if use_cache is True:
601
+ presents = presents + (outputs[1],)
602
+
603
+ if output_attentions:
604
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
605
+ if self.config.add_cross_attention:
606
+ all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
607
+
608
+ if self.model_parallel:
609
+ device_prefix="cuda:"
610
+ for k, v in self.device_map.items():
611
+ if i == v[-1] and device_prefix + str(k) != self.last_device:
612
+ hidden_states = hidden_states.to(device_prefix + str(k + 1))
613
+
614
+ hidden_states = self.ln_f(hidden_states)
615
+
616
+ hidden_states = hidden_states.view(*output_shape)
617
+ # Add last hidden state
618
+ if output_hidden_states:
619
+ all_hidden_states = all_hidden_states + (hidden_states,)
620
+
621
+ if not return_dict:
622
+ return tuple(
623
+ v
624
+ for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]  # moe_loss removed: it is never defined in this model
625
+ if v is not None
626
+ )
627
+
628
+ return BaseModelOutputWithPastAndCrossAttentions(
629
+ last_hidden_state=hidden_states,
630
+ past_key_values=presents,
631
+ hidden_states=all_hidden_states,
632
+ attentions=all_self_attentions,
633
+ cross_attentions=all_cross_attentions,
634
+ )
635
+
636
+ class TranceptionLMHeadModel(GPT2PreTrainedModel):
637
+ _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]
638
+ def __init__(self, config):
639
+ super().__init__(config)
640
+ self.transformer = TranceptionModel(config)
641
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
642
+ self.config = config
643
+
644
+ self.init_weights()
645
+
646
+ self.default_model_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
647
+ # Model parallel
648
+ self.model_parallel = False
649
+ self.device_map = None
650
+
651
+ self.retrieval_aggregation_mode = config.retrieval_aggregation_mode if hasattr(config, "retrieval_aggregation_mode") else None
652
+ if self.retrieval_aggregation_mode is not None:
653
+ print("Model leverages both autoregressive and retrieval inference")
654
+ self.MSA_filename = config.MSA_filename if hasattr(config, "MSA_filename") else False
655
+ self.MSA_folder = '/'.join(self.MSA_filename.split(os.sep)[:-1])
656
+ self.MSA_name = self.MSA_filename.split(os.sep)[-1]
657
+ self.retrieval_inference_weight_LR = config.retrieval_inference_weight if hasattr(config, "retrieval_inference_weight") else 0.6
658
+ self.retrieval_inference_weight_RL = config.retrieval_inference_weight if hasattr(config, "retrieval_inference_weight") else 0.6
659
+ self.MSA_start=config.MSA_start
660
+ self.MSA_end=config.MSA_end
661
+ self.full_protein_length = config.full_protein_length if hasattr(config, "full_protein_length") else -1
662
+
663
+ self.MSA_log_prior = torch.log(torch.tensor(
664
+ msa_utils.get_msa_prior(
665
+ MSA_data_file=self.MSA_filename,
666
+ MSA_weight_file_name=config.MSA_weight_file_name,
667
+ retrieval_aggregation_mode=self.retrieval_aggregation_mode,
668
+ MSA_start=self.MSA_start,
669
+ MSA_end=self.MSA_end,
670
+ len_target_seq=self.full_protein_length,
671
+ vocab=config.tokenizer.get_vocab(),
672
+ verbose=False
673
+ )
674
+ ).float().to(self.default_model_device))
675
+ else:
676
+ print("Model only uses autoregressive inference")
677
+
678
+ def parallelize(self, device_map=None, num_cores=None, num_pipelines=1):
679
+ self.num_pipelines=num_pipelines
680
+ self.device_map = (
681
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
682
+ if device_map is None
683
+ else device_map
684
+ )
685
+ assert_device_map(self.device_map, len(self.transformer.h))
686
+ self.transformer.parallelize(self.device_map, num_cores=num_cores)
687
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
688
+ self.model_parallel = True
689
+
690
+ def deparallelize(self):
691
+ self.transformer.deparallelize()
692
+ self.transformer = self.transformer.to("cpu")
693
+ self.lm_head = self.lm_head.to("cpu")
694
+ self.model_parallel = False
695
+ torch.cuda.empty_cache()
696
+
697
+ def get_output_embeddings(self):
698
+ return self.lm_head
699
+
700
+ def set_output_embeddings(self, new_embeddings):
701
+ self.lm_head = new_embeddings
702
+
703
+ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
704
+ token_type_ids = kwargs.get("token_type_ids", None)
705
+ # only last token for inputs_ids if past is defined in kwargs
706
+ if past:
707
+ input_ids = input_ids[:, -1].unsqueeze(-1)
708
+ if token_type_ids is not None:
709
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
710
+
711
+ attention_mask = kwargs.get("attention_mask", None)
712
+ position_ids = kwargs.get("position_ids", None)
713
+
714
+ if attention_mask is not None and position_ids is None:
715
+ # create position_ids on the fly for batch generation
716
+ position_ids = attention_mask.long().cumsum(-1) - 1
717
+ position_ids.masked_fill_(attention_mask == 0, 1)
718
+ if past:
719
+ position_ids = position_ids[:, -1].unsqueeze(-1)
720
+ else:
721
+ position_ids = None
722
+
723
+ return {
724
+ "input_ids": input_ids,
725
+ "past_key_values": past,
726
+ "use_cache": kwargs.get("use_cache"),
727
+ "position_ids": position_ids,
728
+ "attention_mask": attention_mask,
729
+ "token_type_ids": token_type_ids,
730
+ "flip": kwargs.get("flip", None),
731
+ }
732
+
733
+ def forward(
734
+ self,
735
+ input_ids=None,
736
+ past_key_values=None,
737
+ attention_mask=None,
738
+ token_type_ids=None,
739
+ position_ids=None,
740
+ head_mask=None,
741
+ inputs_embeds=None,
742
+ encoder_hidden_states=None,
743
+ encoder_attention_mask=None,
744
+ labels=None,
745
+ use_cache=None,
746
+ output_attentions=None,
747
+ output_hidden_states=None,
748
+ return_dict=None,
749
+ flip=None,
750
+ start_slice=None,
751
+ end_slice=None,
752
+ full_raw_sequence=None,
753
+ ):
754
+ r"""
755
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
756
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
757
+ ``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to
758
+ ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``
759
+ """
760
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
761
+
762
+ transformer_outputs = self.transformer(
763
+ input_ids,
764
+ past_key_values=past_key_values,
765
+ attention_mask=attention_mask,
766
+ token_type_ids=token_type_ids,
767
+ position_ids=position_ids,
768
+ head_mask=head_mask,
769
+ inputs_embeds=inputs_embeds,
770
+ encoder_hidden_states=encoder_hidden_states,
771
+ encoder_attention_mask=encoder_attention_mask,
772
+ use_cache=use_cache,
773
+ output_attentions=output_attentions,
774
+ output_hidden_states=output_hidden_states,
775
+ return_dict=return_dict
776
+ )
777
+ hidden_states = transformer_outputs[0]
778
+
779
+ # Set device for model parallelism
780
+ if self.model_parallel:
781
+ torch.cuda.set_device(self.transformer.first_device)
782
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
783
+ self.MSA_log_prior = self.MSA_log_prior.to(self.lm_head.weight.device)
784
+
785
+ lm_logits = self.lm_head(hidden_states)
786
+
787
+ loss = None
788
+ if labels is not None:
789
+ # Shift so that tokens < n predict n
790
+ shift_logits = lm_logits[..., :-1, :].contiguous()
791
+ shift_labels = labels[..., 1:].contiguous()
792
+
793
+ if self.retrieval_aggregation_mode is not None:
794
+ batch_size = input_ids.size(0)
795
+
796
+ if self.retrieval_aggregation_mode=="aggregate_indel":
797
+ assert batch_size==1, "Aggregate indel is only supported for batch size of 1"
798
+ truncated_sequence_text = full_raw_sequence[0][start_slice[0]:end_slice[0]]
799
+ if len(truncated_sequence_text)!=shift_logits.shape[1]-1: # shift_logits only has one extra token compared to truncated_sequence_text (the BOS token)
800
+ print("Tokenization error -- seq length: {} and shift_logits length - 1 : {}".format(len(truncated_sequence_text),shift_logits.shape[1]-1))
801
+ MSA_log_prior, MSA_start, MSA_end = msa_utils.update_retrieved_MSA_log_prior_indel(self, self.MSA_log_prior, self.MSA_start, self.MSA_end, full_raw_sequence[0])
802
+
803
+ elif self.retrieval_aggregation_mode=="aggregate_substitution":
804
+ MSA_log_prior=self.MSA_log_prior
805
+ MSA_start=self.MSA_start
806
+ MSA_end=self.MSA_end
807
+
808
+ shift_log_probas = torch.log_softmax(shift_logits,dim=-1)
809
+ fused_shift_log_probas = shift_log_probas.clone()
810
+ if flip is None:
811
+ flip = torch.zeros(batch_size).to(fused_shift_log_probas.device)
812
+ flip = flip > 0
813
+
814
+ for seq_index in range(batch_size):
815
+ min_prior_slice = max(start_slice[seq_index], MSA_start)
816
+ max_prior_slice = min(end_slice[seq_index], MSA_end)
817
+
818
+ if max_prior_slice <= min_prior_slice:
819
+ print("Non overlapping region detected: min_prior_slice {} and max_prior_slice {}".format(min_prior_slice,max_prior_slice))
820
+ continue
821
+
822
+ slice_prior = MSA_log_prior[min_prior_slice:max_prior_slice,:].to(fused_shift_log_probas.device)
823
+ if flip[seq_index]:
824
+ slice_prior = torch.flip(slice_prior,dims=(0,))
825
+ min_logits_slice = max(0,end_slice[seq_index]-MSA_end)
826
+ max_logits_slice = min_logits_slice + (max_prior_slice-min_prior_slice)
827
+ fused_shift_log_probas[seq_index,min_logits_slice:max_logits_slice,:] = (1-self.retrieval_inference_weight_RL)*shift_log_probas[seq_index,min_logits_slice:max_logits_slice,:] + self.retrieval_inference_weight_RL*slice_prior
828
+ else:
829
+ min_logits_slice = max(0, MSA_start-start_slice[seq_index])
830
+ max_logits_slice = min_logits_slice + (max_prior_slice-min_prior_slice)
831
+ fused_shift_log_probas[seq_index,min_logits_slice:max_logits_slice,:] = (1-self.retrieval_inference_weight_LR)*shift_log_probas[seq_index,min_logits_slice:max_logits_slice,:] + self.retrieval_inference_weight_LR*slice_prior
832
+
833
+ if self.retrieval_aggregation_mode=="aggregate_indel":
834
+ try:
835
+ # If a given residue column is an added zero-column, then we overwrite prior fusion and only predict based on the autoregressive transformer inference mode.
836
+ inserted_retrieval_positions = [True if slice_prior[i].sum()==0 else False for i in range(len(slice_prior))]+[True] #Last True is for the end of sentence token
837
+ fused_shift_log_probas[:,inserted_retrieval_positions,:]=shift_log_probas[:,inserted_retrieval_positions,:]
838
+ except:
839
+ print("Error when adding zero column(s) to account for insertion mutations.")
840
+
841
+ loss_fct = NLLLoss(reduction='none')
842
+ loss = loss_fct(input=fused_shift_log_probas.view(-1, fused_shift_log_probas.size(-1)), target=shift_labels.view(-1)).view(fused_shift_log_probas.shape[0],fused_shift_log_probas.shape[1])
843
+ mask = attention_mask[..., 1:].float()
844
+ mask[mask==0]=float('nan')
845
+ loss *= mask
846
+ loss = nanmean(loss, dim=1).mean()
847
+ else:
848
+ loss_fct = CrossEntropyLoss()
849
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
850
+ fused_shift_log_probas = None
851
+
852
+ if not return_dict:
853
+ output = (lm_logits,) + transformer_outputs[1:]
854
+ return ((loss,) + output) if loss is not None else output
855
+
856
+ return TranceptionCausalLMOutputWithCrossAttentions(
857
+ loss=loss,
858
+ logits=lm_logits,
859
+ past_key_values=transformer_outputs.past_key_values,
860
+ hidden_states=transformer_outputs.hidden_states,
861
+ attentions=transformer_outputs.attentions,
862
+ cross_attentions=transformer_outputs.cross_attentions,
863
+ fused_shift_log_probas=fused_shift_log_probas
864
+ )
865
+
866
+
867
+ @staticmethod
868
+ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
869
+ """
870
+ This function is used to re-order the :obj:`past_key_values` cache if
871
+ :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
872
+ called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
873
+ """
874
+ return tuple(
875
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
876
+ for layer_past in past
877
+ )
878
+
879
+ def score_mutants(self, DMS_data, target_seq, scoring_mirror=True, batch_size_inference=10, num_workers=10, indel_mode=False):
880
+ """
881
+ Method to score mutants in an input DMS file.
882
+ DMS_data: (dataframe) Dataframe containing the list of mutant triplets (substitutions) or full mutated sequences (indels) for scoring.
883
+ target_seq: (string) Full reference sequence (wild type) that is mutated in the DMS assay.
884
+ scoring_mirror: (bool) Whether to score mutated sequences from both directions (Left->Right and Right->Left).
885
+ batch_size_inference: (int) Batch size for scoring.
886
+ num_workers: (int) Number of workers to be used in the data loader.
887
+ indel_mode: (bool) Flag to be used when scoring insertions and deletions. Otherwise assumes substitutions.
888
+ """
889
+ df = DMS_data.copy()
890
+ df['mutated_sequence'] = df['mutant'].apply(lambda x: scoring_utils.get_mutated_sequence(target_seq, x)) if not indel_mode else df['mutant']
891
+ if 'DMS_score' in df: del df['DMS_score']
892
+ if 'DMS_score_bin' in df: del df['DMS_score_bin']
893
+ df_left_to_right_slices = scoring_utils.get_sequence_slices(df, target_seq=target_seq, model_context_len = self.config.n_ctx - 2, indel_mode=indel_mode, scoring_window=self.config.scoring_window)
894
+ print("Scoring sequences from left to right")
895
+ scores_L_to_R = scoring_utils.get_tranception_scores_mutated_sequences(model=self, mutated_sequence_df=df_left_to_right_slices, batch_size_inference=batch_size_inference, score_var_name='avg_score_L_to_R', len_target_seq=len(target_seq), num_workers=num_workers, indel_mode=indel_mode)
896
+ if scoring_mirror:
897
+ print("Scoring sequences from right to left")
898
+ df_right_to_left_slices = df_left_to_right_slices.copy()
899
+ df_right_to_left_slices['mutated_sequence'] = df_right_to_left_slices['mutated_sequence'].apply(lambda x: x[::-1])
900
+ scores_R_to_L = scoring_utils.get_tranception_scores_mutated_sequences(model=self, mutated_sequence_df=df_right_to_left_slices, batch_size_inference=batch_size_inference, score_var_name='avg_score_R_to_L', len_target_seq=len(target_seq), num_workers=num_workers, reverse=True, indel_mode=indel_mode)
901
+ all_scores = pd.merge(scores_L_to_R, scores_R_to_L, on='mutant', how='left',suffixes=('','_R_to_L'))
902
+ all_scores['avg_score'] = (all_scores['avg_score_L_to_R'] + all_scores['avg_score_R_to_L']) / 2.0
903
+ else:
904
+ all_scores = scores_L_to_R
905
+ all_scores['avg_score'] = all_scores['avg_score_L_to_R']
906
+ return all_scores
907
+
908
+ def encode_batch(self, protein_sequence, sequence_name="mutated_sequence"):
909
+ """
910
+ Method to process an input AA sequence batch (protein_sequence) and return a tokenized sequence (via the tokenizer associated with the model).
911
+ """
912
+ protein_sequence[sequence_name] = scoring_utils.sequence_replace(sequences=protein_sequence[sequence_name], char_to_replace='X', char_replacements='ACDEFGHIKLMNPQRSTVWY')
913
+ protein_sequence[sequence_name] = scoring_utils.sequence_replace(sequences=protein_sequence[sequence_name], char_to_replace='B', char_replacements='DN')
914
+ protein_sequence[sequence_name] = scoring_utils.sequence_replace(sequences=protein_sequence[sequence_name], char_to_replace='J', char_replacements='IL')
915
+ protein_sequence[sequence_name] = scoring_utils.sequence_replace(sequences=protein_sequence[sequence_name], char_to_replace='Z', char_replacements='EQ')
916
+ return self.config.tokenizer(list(protein_sequence[sequence_name]), add_special_tokens=True, truncation=True, padding=True, max_length=self.config.n_ctx)
917
+
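A minimal usage sketch for `score_mutants` (not part of the commit): it assumes a Tranception model instance `model` has already been instantiated with its config (tokenizer, `n_ctx`, `scoring_window`) set up as in the rest of this repository; the wild-type sequence and mutants below are placeholders.
```
import pandas as pd

target_seq = "MTEYKLVVVGAGGVGKSALTIQLIQNHFVDEYDPTIEDSYRKQ"  # placeholder wild type
DMS_data = pd.DataFrame({"mutant": ["T2A", "E3K", "T2A:E3K"]})  # substitution triplets

all_scores = model.score_mutants(
    DMS_data=DMS_data,
    target_seq=target_seq,
    scoring_mirror=True,        # average left-to-right and right-to-left scores
    batch_size_inference=10,
    num_workers=10,
    indel_mode=False,
)
print(all_scores[["mutant", "avg_score"]])
```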
outputs.py ADDED
@@ -0,0 +1,48 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Tuple
3
+
4
+ import torch
5
+
6
+ from transformers.file_utils import ModelOutput
7
+
8
+ @dataclass
9
+ class TranceptionCausalLMOutputWithCrossAttentions(ModelOutput):
10
+ """
11
+ Class for Tranception causal language model (or autoregressive) outputs.
12
+ Args:
13
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
14
+ Language modeling loss (for next-token prediction).
15
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
16
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
17
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
18
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
19
+ shape `(batch_size, sequence_length, hidden_size)`.
20
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
21
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
22
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
23
+ sequence_length)`.
24
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
25
+ heads.
26
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
27
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
28
+ sequence_length)`.
29
+ Cross attentions weights after the attention softmax, used to compute the weighted average in the
30
+ cross-attention heads.
31
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
32
+ Tuple of `torch.FloatTensor` tuples of length `config.n_layers`, with each tuple containing the cached key,
33
+ value states of the self-attention and the cross-attention layers if model is used in encoder-decoder
34
+ setting. Only relevant if `config.is_decoder = True`.
35
+ Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
36
+ `past_key_values` input) to speed up sequential decoding.
37
+ fused_shift_log_probas (`torch.FloatTensor` of shape (batch_size, sequence_length, config.vocab_size), *optional*, returned when config.retrieval_aggregation_mode is not None.
38
+ log_probas for each residue position after aggregating autoregressive logits and retrieval logits.
39
+
40
+ """
41
+
42
+ loss: Optional[torch.FloatTensor] = None
43
+ logits: torch.FloatTensor = None
44
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
45
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
46
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
47
+ cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
48
+ fused_shift_log_probas: Optional[torch.FloatTensor] = None
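When `retrieval_aggregation_mode` is set, the forward pass exposes the fused log-probabilities through this output class. A hypothetical sketch of reading that field, assuming `model` and a tokenized `encoded_batch` prepared as in `utils/scoring_utils.py` below:
```
import torch

with torch.no_grad():
    outputs = model(**encoded_batch, return_dict=True)

fused = outputs.fused_shift_log_probas  # None unless retrieval aggregation is active
if fused is not None:
    print(fused.shape)  # (batch_size, shifted sequence length, vocab_size)
```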
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e3ee11f9fa7cca8f7859d63769a40ef788f3f32967b5a9ec32bbb2c1f35b0bea
3
+ size 2872459669
utils/.DS_Store ADDED
Binary file (6.15 kB).
utils/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from . import scoring_utils, msa_utils
utils/dms_utils.py ADDED
@@ -0,0 +1,26 @@
1
+ import pandas as pd
2
+ import numpy as np
3
+
4
+ def DMS_file_cleanup(DMS_filename, target_seq, start_idx=1, end_idx=None, DMS_mutant_column='mutant', DMS_phenotype_name='score', DMS_directionality=1, AA_vocab = "ACDEFGHIKLMNPQRSTVWY"):
5
+ """
6
+ Function to process the raw DMS assay data (e.g., removing invalid mutants, aggregating silent mutations)
7
+ """
8
+ DMS_data = pd.read_csv(DMS_filename, low_memory=False)
9
+ end_idx = start_idx + len(target_seq) - 1 if end_idx is None else end_idx
10
+ DMS_data['mutant'] = DMS_data[DMS_mutant_column]
11
+
12
+ DMS_data=DMS_data[DMS_data['mutant'].notnull()].copy()
13
+ DMS_data=DMS_data[DMS_data['mutant'].apply(lambda x: all([len(y)>=3 for y in x.split(":")]))].copy() #Mutant triplets should have at least 3 characters
14
+ DMS_data=DMS_data[DMS_data['mutant'].apply(lambda x: all([(y[0] in AA_vocab) and (y[1:-1].isnumeric()) and (y[-1] in AA_vocab) for y in x.split(":")]))].copy()
15
+ DMS_data=DMS_data[DMS_data['mutant'].apply(lambda x: all([int(y[1:-1])-start_idx >=0 and int(y[1:-1]) <= end_idx for y in x.split(":")]))].copy()
16
+ DMS_data=DMS_data[DMS_data['mutant'].apply(lambda x: all([y[0]==target_seq[int(y[1:-1])-start_idx] for y in x.split(":")]))].copy()
17
+
18
+ DMS_data[DMS_phenotype_name]=pd.to_numeric(DMS_data[DMS_phenotype_name],errors='coerce')
19
+ DMS_data=DMS_data[np.isfinite(DMS_data[DMS_phenotype_name])]
20
+ DMS_data.dropna(subset = [DMS_phenotype_name], inplace=True)
21
+ DMS_data['DMS_score'] = DMS_data[DMS_phenotype_name] * DMS_directionality
22
+ DMS_data=DMS_data[['mutant','DMS_score']]
23
+ DMS_data = DMS_data.groupby('mutant').mean().reset_index()
24
+
25
+ return DMS_data
26
+
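A hedged example of calling `DMS_file_cleanup` (the CSV path and wild-type sequence are placeholders, not files shipped with this commit):
```
target_seq = "MTEYKLVVVGAGGVGKSALTIQLIQNHFVDEYDPTIEDSYRKQ"  # placeholder wild type

clean_DMS = DMS_file_cleanup(
    DMS_filename="my_assay.csv",      # raw assay measurements (placeholder path)
    target_seq=target_seq,
    DMS_mutant_column="mutant",
    DMS_phenotype_name="score",
    DMS_directionality=1,             # use -1 if lower raw scores mean higher fitness
)
# Result has two columns, 'mutant' and 'DMS_score', with duplicate mutants averaged.
```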
utils/msa_utils.py ADDED
@@ -0,0 +1,361 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ from collections import defaultdict
4
+ import random
5
+ import os
6
+ import torch
7
+ from Bio.Align.Applications import ClustalOmegaCommandline
8
+
9
+ def filter_msa(msa_data, num_sequences_kept=3):
10
+ """
11
+ Helper function to filter an input MSA msa_data (obtained via process_msa_data) and keep only num_sequences_kept aligned sequences.
12
+ If the MSA already has fewer sequences than num_sequences_kept, we keep the MSA as is.
13
+ If filtering, we always keep the first sequence of the MSA (ie. the wild type) by default.
14
+ Sampling is done without replacement.
15
+ """
16
+ if len(list(msa_data.keys())) <= num_sequences_kept:
17
+ return msa_data
18
+ filtered_msa = {}
19
+ wt_name = next(iter(msa_data))
20
+ filtered_msa[wt_name] = msa_data[wt_name]
21
+ del msa_data[wt_name]
22
+ sequence_names = list(msa_data.keys())
23
+ sequence_names_sampled = random.sample(sequence_names,k=num_sequences_kept-1)
24
+ for seq in sequence_names_sampled:
25
+ filtered_msa[seq] = msa_data[seq]
26
+ return filtered_msa
27
+
28
+ def process_msa_data(MSA_data_file):
29
+ """
30
+ Helper function that takes as input a path to an MSA file (expects a2m format) and returns a dict mapping sequence ID to the corresponding AA sequence.
31
+ """
32
+ msa_data = defaultdict(str)
33
+ sequence_name = ""
34
+ with open(MSA_data_file, "r") as msa_file:
35
+ for i, line in enumerate(msa_file):
36
+ line = line.rstrip()
37
+ if line.startswith(">"):
38
+ sequence_name = line
39
+ else:
40
+ msa_data[sequence_name] += line.upper()
41
+ return msa_data
42
+
43
+ def get_one_hot_sequences_dict(msa_data,MSA_start,MSA_end,vocab):
44
+ vocab_size = len(vocab.keys())
45
+ num_sequences_msa = len(msa_data.keys())
46
+ one_hots = np.zeros((num_sequences_msa,MSA_end-MSA_start,vocab_size))
47
+ for i,seq_name in enumerate(msa_data.keys()):
48
+ sequence = msa_data[seq_name]
49
+ for j,letter in enumerate(sequence):
50
+ if letter in vocab:
51
+ k = vocab[letter]
52
+ one_hots[i,j,k] = 1.0
53
+ return one_hots
54
+
55
+ def one_hot(sequence_string,vocab):
56
+ one_hots = np.zeros((len(sequence_string),len(vocab.keys())))
57
+ for j,letter in enumerate(sequence_string):
58
+ if letter in vocab:
59
+ k = vocab[letter]
60
+ one_hots[j,k] = 1.0
61
+ return one_hots.flatten()
62
+
63
+ def get_msa_prior(MSA_data_file, MSA_weight_file_name, MSA_start, MSA_end, len_target_seq, vocab, retrieval_aggregation_mode="aggregate_substitution", filter_MSA=True, verbose=False):
64
+ """
65
+ Function to enable retrieval inference mode, via computation of (weighted) pseudocounts of AAs at each position of the retrieved MSA.
66
+ MSA_data_file: (string) path to MSA file (expects a2m format).
67
+ MSA_weight_file_name: (string) path to sequence weights in MSA.
68
+ MSA_start: (int) Sequence position that the MSA starts at (1-indexing).
69
+ MSA_end: (int) Sequence position that the MSA ends at (1-indexing).
70
+ len_target_seq: (int) Full length of sequence to be scored.
71
+ vocab: (dict) Vocabulary of the tokenizer.
72
+ retrieval_aggregation_mode: (string) Mode for retrieval inference (aggregate_substitution Vs aggregate_indel). If None, places a uniform prior over each token.
73
+ filter_MSA: (bool) Whether to filter out sequences with very low Hamming similarity (< 0.2) to the reference sequence in the MSA (first sequence).
74
+ verbose: (bool) Whether to print to the console processing details along the way.
75
+ """
76
+ msa_data = process_msa_data(MSA_data_file)
77
+ vocab_size = len(vocab.keys())
78
+ if verbose: print("Target seq len is {}, MSA length is {}, start position is {}, end position is {} and vocab size is {}".format(len_target_seq,MSA_end-MSA_start,MSA_start,MSA_end,vocab_size))
79
+
80
+ if filter_MSA:
81
+ if verbose: print("Num sequences in MSA pre filtering: {}".format(len(msa_data.keys())))
82
+ list_sequence_names = list(msa_data.keys())
83
+ focus_sequence_name = list(msa_data.keys())[0]
84
+ ref_sequence_hot = one_hot(msa_data[focus_sequence_name],vocab)
85
+ for sequence_name in list_sequence_names:
86
+ seq_hot = one_hot(msa_data[sequence_name],vocab)
87
+ hamming_similarity_seq_ref = np.dot(ref_sequence_hot,seq_hot) / np.dot(ref_sequence_hot,ref_sequence_hot)
88
+ if hamming_similarity_seq_ref < 0.2:
89
+ del msa_data[sequence_name]
90
+ if verbose: print("Num sequences in MSA post filtering: {}".format(len(msa_data.keys())))
91
+
92
+ if MSA_weight_file_name is not None:
93
+ if verbose: print("Using weights in {} for sequences in MSA.".format(MSA_weight_file_name))
94
+ assert os.path.exists(MSA_weight_file_name), "Weights file not located on disk."
95
+ MSA_EVE = MSA_processing(
96
+ MSA_location=MSA_data_file,
97
+ use_weights=True,
98
+ weights_location=MSA_weight_file_name
99
+ )
100
+ # We scan through all sequences to check whether a weight is available for them as per the EVE pre-processing; sequences without a weight are dropped.
101
+ dropped_sequences=0
102
+ list_sequence_names = list(msa_data.keys())
103
+ MSA_weight=[]
104
+ for sequence_name in list_sequence_names:
105
+ if sequence_name not in MSA_EVE.seq_name_to_sequence:
106
+ dropped_sequences +=1
107
+ del msa_data[sequence_name]
108
+ else:
109
+ MSA_weight.append(MSA_EVE.seq_name_to_weight[sequence_name])
110
+ if verbose: print("Dropped {} sequences from MSA due to absent sequence weights".format(dropped_sequences))
111
+ else:
112
+ MSA_weight = [1] * len(list(msa_data.keys()))
113
+
114
+ if retrieval_aggregation_mode=="aggregate_substitution" or retrieval_aggregation_mode=="aggregate_indel":
115
+ one_hots = get_one_hot_sequences_dict(msa_data,MSA_start,MSA_end,vocab)
116
+ MSA_weight = np.expand_dims(np.array(MSA_weight),axis=(1,2))
117
+ base_rate = 1e-5
118
+ base_rates = np.ones_like(one_hots) * base_rate
119
+ weighted_one_hots = (one_hots + base_rates) * MSA_weight
120
+ MSA_weight_norm_counts = weighted_one_hots.sum(axis=-1).sum(axis=0)
121
+ MSA_weight_norm_counts = np.tile(MSA_weight_norm_counts.reshape(-1,1), (1,vocab_size))
122
+ one_hots_avg = weighted_one_hots.sum(axis=0) / MSA_weight_norm_counts
123
+ msa_prior = np.zeros((len_target_seq,vocab_size))
124
+ msa_prior[MSA_start:MSA_end,:]=one_hots_avg
125
+ else:
126
+ msa_prior = np.ones((len_target_seq,vocab_size)) / vocab_size
127
+
128
+ if verbose:
129
+ for idx, position in enumerate(msa_prior):
130
+ if len(position)!=25:
131
+ print("Size error")
132
+ if not round(position.sum(),2)==1.0:
133
+ print("Position at index {} does not add up to 1: {}".format(idx, position.sum()))
134
+
135
+ return msa_prior
136
+
137
+
138
+ def update_retrieved_MSA_log_prior_indel(model, MSA_log_prior, MSA_start, MSA_end, full_raw_sequence):
139
+ """
140
+ Function to process MSA when scoring indels.
141
+ To identify positions to add / remove in the retrieved MSA, we append and align the sequence to be scored to the original MSA for that protein family with Clustal Omega.
142
+ If the original MSA is relatively deep (over 100k sequences), we sample (by default) 100k rows at random from that MSA to speed up computations.
143
+ MSA sampling is performed only once (for the first sequence to be scored). Subsequent scoring runs use the same MSA sample.
144
+ """
145
+ if not os.path.isdir(model.MSA_folder + os.sep + "Sampled"):
146
+ os.mkdir(model.MSA_folder + os.sep + "Sampled")
147
+ sampled_MSA_location = model.MSA_folder + os.sep + "Sampled" + os.sep + "Sampled_" + model.MSA_filename.split(os.sep)[-1]
148
+
149
+ if not os.path.exists(sampled_MSA_location):
150
+ msa_data = process_msa_data(model.MSA_filename)
151
+ msa_data_sampled = filter_msa(msa_data, num_sequences_kept=100000) #If the MSA has fewer than 100k sequences, the sample is identical to the original MSA
152
+ with open(sampled_MSA_location, 'w') as sampled_write_location:
153
+ for index, key in enumerate(msa_data_sampled):
154
+ key_name = ">REFERENCE_SEQUENCE" if index==0 else key
155
+ msa_data_sampled[key] = msa_data_sampled[key].upper()
156
+ msa_data_sampled[key] = msa_data_sampled[key].replace(".","-")
157
+ sampled_write_location.write(key_name+"\n"+"\n".join([msa_data_sampled[key][i:i+80] for i in range(0, len(msa_data_sampled[key]), 80)])+"\n")
158
+
159
+ seq_to_align_location = model.MSA_folder + os.sep + "Sampled" + os.sep + "Seq_to_align_" + model.MSA_filename.split(os.sep)[-1]
160
+ sequence_text_split = [full_raw_sequence[i:i+80] for i in range(0, len(full_raw_sequence), 80)]
161
+ sequence_text_split_split_join = "\n".join([">SEQ_TO_SCORE"]+sequence_text_split)
162
+ os.system("echo '"+sequence_text_split_split_join+"' > "+seq_to_align_location)
163
+
164
+ expanded_MSA_location = model.MSA_folder + os.sep + "Sampled" + os.sep + "Expanded_" + model.MSA_filename.split(os.sep)[-1]
165
+ clustalw_cline = ClustalOmegaCommandline(cmd=model.config.clustal_omega_location,
166
+ profile1=sampled_MSA_location,
167
+ profile2=seq_to_align_location,
168
+ outfile=expanded_MSA_location,
169
+ force=True)
170
+ stdout, stderr = clustalw_cline()
171
+ msa_data = process_msa_data(expanded_MSA_location)
172
+ aligned_seqA, aligned_seqB = msa_data[">SEQ_TO_SCORE"], msa_data[">REFERENCE_SEQUENCE"]
173
+ try:
174
+ keep_column=[]
175
+ for column_index_pairwise_alignment in range(len(aligned_seqA)):
176
+ if aligned_seqA[column_index_pairwise_alignment]=="-" and aligned_seqB[column_index_pairwise_alignment]=="-":
177
+ continue
178
+ elif aligned_seqA[column_index_pairwise_alignment]=="-":
179
+ keep_column.append(False)
180
+ elif aligned_seqB[column_index_pairwise_alignment]=="-":
181
+ MSA_log_prior=torch.cat((MSA_log_prior[:column_index_pairwise_alignment], torch.zeros(MSA_log_prior.shape[1]).view(1,-1).cuda(), MSA_log_prior[column_index_pairwise_alignment:]),dim=0)
182
+ keep_column.append(True) #keep the zero column we just added
183
+ else:
184
+ keep_column.append(True)
185
+ MSA_log_prior = MSA_log_prior[keep_column]
186
+ MSA_end = MSA_start + len(MSA_log_prior)
187
+ except:
188
+ print("Error when processing the following alignment: {}".format(expanded_MSA_location))
189
+ return MSA_log_prior, MSA_start, MSA_end
190
+
191
+ class MSA_processing:
192
+ def __init__(self,
193
+ MSA_location="",
194
+ theta=0.2,
195
+ use_weights=True,
196
+ weights_location="./data/weights",
197
+ preprocess_MSA=True,
198
+ threshold_sequence_frac_gaps=0.5,
199
+ threshold_focus_cols_frac_gaps=0.3,
200
+ remove_sequences_with_indeterminate_AA_in_focus_cols=True
201
+ ):
202
+
203
+ """
204
+ This MSA_processing class is directly borrowed from the EVE codebase: https://github.com/OATML-Markslab/EVE
205
+
206
+ Parameters:
207
+ - MSA_location: (path) Location of the MSA data. Constraints on input MSA format:
208
+ - focus_sequence is the first one in the MSA data
209
+ - first line is structured as follows: ">focus_seq_name/start_pos-end_pos" (e.g., >SPIKE_SARS2/310-550)
210
+ - corespondding sequence data located on following line(s)
211
+ - then all other sequences follow with ">name" on first line, corresponding data on subsequent lines
212
+ - theta: (float) Sequence weighting hyperparameter. Generally: Prokaryotic and eukaryotic families = 0.2; Viruses = 0.01
213
+ - use_weights: (bool) If False, sets all sequence weights to 1. If True, checks weights_location -- if non-empty, uses that;
214
+ otherwise compute weights from scratch and store them at weights_location
215
+ - weights_location: (path) Location to load from/save to the sequence weights
216
+ - preprocess_MSA: (bool) performs pre-processing of MSA to remove short fragments and positions that are not well covered.
217
+ - threshold_sequence_frac_gaps: (float, between 0 and 1) Threshold value to define fragments
218
+ - sequences with a fraction of gap characters above threshold_sequence_frac_gaps are removed
219
+ - default is set to 0.5 (i.e., fragments with 50% or more gaps are removed)
220
+ - threshold_focus_cols_frac_gaps: (float, between 0 and 1) Threshold value to define focus columns
221
+ - positions with a fraction of gap characters above threshold_focus_cols_frac_gaps will be set to lower case (and not included in the focus_cols)
222
+ - default is set to 0.3 (i.e., focus positions are the ones with 30% of gaps or less, i.e., 70% or more residue occupancy)
223
+ - remove_sequences_with_indeterminate_AA_in_focus_cols: (bool) Remove all sequences that have indeterminate AA (e.g., B, J, X, Z) at focus positions of the wild type
224
+ """
225
+ np.random.seed(2021)
226
+ self.MSA_location = MSA_location
227
+ self.weights_location = weights_location
228
+ self.theta = theta
229
+ self.alphabet = "ACDEFGHIKLMNPQRSTVWY"
230
+ self.use_weights = use_weights
231
+ self.preprocess_MSA = preprocess_MSA
232
+ self.threshold_sequence_frac_gaps = threshold_sequence_frac_gaps
233
+ self.threshold_focus_cols_frac_gaps = threshold_focus_cols_frac_gaps
234
+ self.remove_sequences_with_indeterminate_AA_in_focus_cols = remove_sequences_with_indeterminate_AA_in_focus_cols
235
+
236
+ self.gen_alignment()
237
+
238
+ def gen_alignment(self, verbose=False):
239
+ """ Read training alignment and store basics in class instance """
240
+ self.aa_dict = {}
241
+ for i,aa in enumerate(self.alphabet):
242
+ self.aa_dict[aa] = i
243
+
244
+ self.seq_name_to_sequence = defaultdict(str)
245
+ name = ""
246
+ with open(self.MSA_location, "r") as msa_data:
247
+ for i, line in enumerate(msa_data):
248
+ line = line.rstrip()
249
+ if line.startswith(">"):
250
+ name = line
251
+ if i==0:
252
+ self.focus_seq_name = name
253
+ else:
254
+ self.seq_name_to_sequence[name] += line
255
+
256
+
257
+ ## MSA pre-processing to remove inadequate columns and sequences
258
+ if self.preprocess_MSA:
259
+ msa_df = pd.DataFrame.from_dict(self.seq_name_to_sequence, orient='index', columns=['sequence'])
260
+ # Data clean up
261
+ msa_df.sequence = msa_df.sequence.apply(lambda x: x.replace(".","-")).apply(lambda x: ''.join([aa.upper() for aa in x]))
262
+ # Remove columns that would be gaps in the wild type
263
+ non_gap_wt_cols = [aa!='-' for aa in msa_df.sequence[self.focus_seq_name]]
264
+ msa_df['sequence'] = msa_df['sequence'].apply(lambda x: ''.join([aa for aa,non_gap_ind in zip(x, non_gap_wt_cols) if non_gap_ind]))
265
+ assert 0.0 <= self.threshold_sequence_frac_gaps <= 1.0,"Invalid fragment filtering parameter"
266
+ assert 0.0 <= self.threshold_focus_cols_frac_gaps <= 1.0,"Invalid focus position filtering parameter"
267
+ msa_array = np.array([list(seq) for seq in msa_df.sequence])
268
+ gaps_array = np.array(list(map(lambda seq: [aa=='-' for aa in seq], msa_array)))
269
+ # Identify fragments with too many gaps
270
+ seq_gaps_frac = gaps_array.mean(axis=1)
271
+ seq_below_threshold = seq_gaps_frac <= self.threshold_sequence_frac_gaps
272
+ if verbose: print("Proportion of sequences dropped due to fraction of gaps: "+str(round(float(1 - seq_below_threshold.sum()/seq_below_threshold.shape)*100,2))+"%")
273
+ # Identify focus columns
274
+ columns_gaps_frac = gaps_array[seq_below_threshold].mean(axis=0)
275
+ index_cols_below_threshold = columns_gaps_frac <= self.threshold_focus_cols_frac_gaps
276
+ if verbose: print("Proportion of non-focus columns removed: "+str(round(float(1 - index_cols_below_threshold.sum()/index_cols_below_threshold.shape)*100,2))+"%")
277
+ # Lower case non focus cols and filter fragment sequences
278
+ msa_df['sequence'] = msa_df['sequence'].apply(lambda x: ''.join([aa.upper() if upper_case_ind else aa.lower() for aa, upper_case_ind in zip(x, index_cols_below_threshold)]))
279
+ msa_df = msa_df[seq_below_threshold]
280
+ # Overwrite seq_name_to_sequence with clean version
281
+ self.seq_name_to_sequence = defaultdict(str)
282
+ for seq_idx in range(len(msa_df['sequence'])):
283
+ self.seq_name_to_sequence[msa_df.index[seq_idx]] = msa_df.sequence[seq_idx]
284
+
285
+ self.focus_seq = self.seq_name_to_sequence[self.focus_seq_name]
286
+ self.focus_cols = [ix for ix, s in enumerate(self.focus_seq) if s == s.upper() and s!='-']
287
+ self.focus_seq_trimmed = [self.focus_seq[ix] for ix in self.focus_cols]
288
+ self.seq_len = len(self.focus_cols)
289
+ self.alphabet_size = len(self.alphabet)
290
+
291
+ # Connect local sequence index with uniprot index (index shift inferred from 1st row of MSA)
292
+ focus_loc = self.focus_seq_name.split("/")[-1]
293
+ start,stop = focus_loc.split("-")
294
+ self.focus_start_loc = int(start)
295
+ self.focus_stop_loc = int(stop)
296
+ self.uniprot_focus_col_to_wt_aa_dict \
297
+ = {idx_col+int(start):self.focus_seq[idx_col] for idx_col in self.focus_cols}
298
+ self.uniprot_focus_col_to_focus_idx \
299
+ = {idx_col+int(start):idx_col for idx_col in self.focus_cols}
300
+
301
+ # Move all letters to CAPS; keeps focus columns only
302
+ self.raw_seq_name_to_sequence = self.seq_name_to_sequence.copy()
303
+ for seq_name,sequence in self.seq_name_to_sequence.items():
304
+ sequence = sequence.replace(".","-")
305
+ self.seq_name_to_sequence[seq_name] = [sequence[ix].upper() for ix in self.focus_cols]
306
+
307
+ # Remove sequences that have indeterminate AA (e.g., B, J, X, Z) in the focus columns
308
+ if self.remove_sequences_with_indeterminate_AA_in_focus_cols:
309
+ alphabet_set = set(list(self.alphabet))
310
+ seq_names_to_remove = []
311
+ for seq_name,sequence in self.seq_name_to_sequence.items():
312
+ for letter in sequence:
313
+ if letter not in alphabet_set and letter != "-":
314
+ seq_names_to_remove.append(seq_name)
315
+ continue
316
+ seq_names_to_remove = list(set(seq_names_to_remove))
317
+ for seq_name in seq_names_to_remove:
318
+ del self.seq_name_to_sequence[seq_name]
319
+
320
+ # Encode the sequences
321
+ self.one_hot_encoding = np.zeros((len(self.seq_name_to_sequence.keys()),len(self.focus_cols),len(self.alphabet)))
322
+ if verbose: print("One-hot encoded sequences shape:" + str(self.one_hot_encoding.shape))
323
+ for i,seq_name in enumerate(self.seq_name_to_sequence.keys()):
324
+ sequence = self.seq_name_to_sequence[seq_name]
325
+ for j,letter in enumerate(sequence):
326
+ if letter in self.aa_dict:
327
+ k = self.aa_dict[letter]
328
+ self.one_hot_encoding[i,j,k] = 1.0
329
+
330
+ if self.use_weights:
331
+ try:
332
+ self.weights = np.load(file=self.weights_location)
333
+ if verbose: print("Loaded sequence weights from disk")
334
+ except:
335
+ if verbose: print ("Computing sequence weights")
336
+ list_seq = self.one_hot_encoding
337
+ list_seq = list_seq.reshape((list_seq.shape[0], list_seq.shape[1] * list_seq.shape[2]))
338
+ def compute_weight(seq):
339
+ number_non_empty_positions = np.dot(seq,seq)
340
+ if number_non_empty_positions>0:
341
+ denom = np.dot(list_seq,seq) / np.dot(seq,seq)
342
+ denom = np.sum(denom > 1 - self.theta)
343
+ return 1/denom
344
+ else:
345
+ return 0.0 #return 0 weight if sequence is fully empty
346
+ self.weights = np.array(list(map(compute_weight,list_seq)))
347
+ np.save(file=self.weights_location, arr=self.weights)
348
+ else:
349
+ # If not using weights, assign a weight of 1 to every sequence
350
+ if verbose: print("Not weighting sequence data")
351
+ self.weights = np.ones(self.one_hot_encoding.shape[0])
352
+
353
+ self.Neff = np.sum(self.weights)
354
+ self.num_sequences = self.one_hot_encoding.shape[0]
355
+ self.seq_name_to_weight={}
356
+ for i,seq_name in enumerate(self.seq_name_to_sequence.keys()):
357
+ self.seq_name_to_weight[seq_name]=self.weights[i]
358
+
359
+ if verbose:
360
+ print ("Neff =",str(self.Neff))
361
+ print ("Data Shape =",self.one_hot_encoding.shape)
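A hypothetical sketch of building the retrieval prior with `get_msa_prior`; the file paths, coordinates and target length are placeholders, and `vocab` is assumed to be the tokenizer vocabulary dictionary (e.g., obtained via `get_vocab()`):
```
vocab = model.config.tokenizer.get_vocab()  # token -> index mapping

msa_prior = get_msa_prior(
    MSA_data_file="my_protein.a2m",                        # placeholder MSA path
    MSA_weight_file_name="my_protein_weights.npy",         # placeholder weights path
    MSA_start=10,
    MSA_end=260,
    len_target_seq=300,
    vocab=vocab,
    retrieval_aggregation_mode="aggregate_substitution",
)
# msa_prior has shape (len_target_seq, vocab_size); positions outside the MSA span keep zero pseudocounts.
```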
utils/scoring_utils.py ADDED
@@ -0,0 +1,192 @@
1
+ import os
2
+ import tqdm
3
+ import re
4
+ import numpy as np
5
+ import pandas as pd
6
+
7
+ import torch
8
+ from torch.nn import CrossEntropyLoss, NLLLoss
9
+ from torch.utils.data.sampler import Sampler, SequentialSampler
10
+
11
+ from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizerFast
12
+ from datasets import Dataset
13
+
14
+ AA_vocab = "ACDEFGHIKLMNPQRSTVWY"
15
+
16
+ def get_mutated_sequence(focus_seq, mutant, start_idx=1, AA_vocab=AA_vocab):
17
+ """
18
+ Helper function that mutates an input sequence (focus_seq) via an input mutation triplet (substitutions only).
19
+ Mutation triplet are typically based on 1-indexing: start_idx is used for switching to 0-indexing.
20
+ """
21
+ mutated_seq = list(focus_seq)
22
+ for mutation in mutant.split(":"):
23
+ try:
24
+ from_AA, position, to_AA = mutation[0], int(mutation[1:-1]), mutation[-1]
25
+ except:
26
+ print("Issue with mutant: "+str(mutation))
27
+ relative_position = position - start_idx
28
+ assert (from_AA==focus_seq[relative_position]), "Invalid from_AA or mutant position: "+str(mutation)+" from_AA: "+str(from_AA) + " relative pos: "+str(relative_position) + " focus_seq: "+str(focus_seq)
29
+ assert (to_AA in AA_vocab) , "Mutant to_AA is invalid: "+str(mutation)
30
+ mutated_seq[relative_position] = to_AA
31
+ return "".join(mutated_seq)
32
+
33
+ def nanmean(v, *args, inplace=False, **kwargs):
34
+ if not inplace:
35
+ v = v.clone()
36
+ is_nan = torch.isnan(v)
37
+ v[is_nan] = 0
38
+ return v.sum(*args, **kwargs) / (~is_nan).float().sum(*args, **kwargs)
39
+
40
+ def nansum(v, *args, inplace=False, **kwargs):
41
+ if not inplace:
42
+ v = v.clone()
43
+ is_nan = torch.isnan(v)
44
+ v[is_nan] = 0
45
+ return v.sum(*args, **kwargs)
46
+
47
+ def get_optimal_window(mutation_position_relative, seq_len_wo_special, model_window):
48
+ """
49
+ Helper function that selects an optimal sequence window that fits the maximum model context size.
50
+ If the sequence length is less than the maximum context size, the full sequence is returned.
51
+ """
52
+ half_model_window = model_window // 2
53
+ if seq_len_wo_special <= model_window:
54
+ return [0,seq_len_wo_special]
55
+ elif mutation_position_relative < half_model_window:
56
+ return [0,model_window]
57
+ elif mutation_position_relative >= seq_len_wo_special - half_model_window:
58
+ return [seq_len_wo_special - model_window, seq_len_wo_special]
59
+ else:
60
+ return [max(0,mutation_position_relative-half_model_window), min(seq_len_wo_special,mutation_position_relative+half_model_window)]
61
+
62
+ def sequence_replace_single(sequence, char_to_replace, char_replacements):
63
+ char_replacements = list(char_replacements)
64
+ positions = [m.start() for m in re.finditer(char_to_replace, sequence)]
65
+ replacements = np.random.choice(a=char_replacements, size=len(positions), replace=True)
66
+ sequence=list(sequence)
67
+ for idx, position in enumerate(positions):
68
+ sequence[position]=replacements[idx]
69
+ return ''.join(sequence)
70
+
71
+ def sequence_replace(sequences, char_to_replace, char_replacements):
72
+ """
73
+ Helper function that replaces all Amino Acids passed in via char_to_replace (as a string of AAs) with Amino Acids sampled from char_replacements (also a string of eligible AAs).
74
+ """
75
+ return [sequence_replace_single(sequence, char_to_replace, char_replacements) for sequence in sequences]
76
+
77
+ def get_tranception_scores_mutated_sequences(model, mutated_sequence_df, batch_size_inference, score_var_name, len_target_seq, num_workers=10, reverse=False, indel_mode=False):
78
+ """
79
+ Helper function that takes as input a set of mutated sequences (in a pandas dataframe) and returns scores for each mutation (delta log likelihood wrt wild type sequence).
80
+ """
81
+ scores = {}
82
+ scores['mutant']=[]
83
+ scores['window_start']=[]
84
+ scores['window_end']=[]
85
+ scores['score']=[]
86
+ with torch.no_grad():
87
+ ds = Dataset.from_pandas(mutated_sequence_df)
88
+ ds.set_transform(model.encode_batch)
89
+ data_collator = DataCollatorForLanguageModeling(
90
+ tokenizer=model.config.tokenizer,
91
+ mlm=False)
92
+ sampler = SequentialSampler(ds)
93
+ ds_loader = torch.utils.data.DataLoader(ds, batch_size=batch_size_inference, sampler=sampler, collate_fn=data_collator, num_workers=num_workers, pin_memory=True, drop_last=False)
94
+ mutant_index=0
95
+ for encoded_batch in tqdm.tqdm(ds_loader):
96
+ full_batch_length = len(encoded_batch['input_ids'])
97
+ scores['mutant'] += list(mutated_sequence_df['mutant'][mutant_index:mutant_index+full_batch_length])
98
+ window_start = np.array(mutated_sequence_df['window_start'][mutant_index:mutant_index+full_batch_length])
99
+ scores['window_start'] += list(window_start)
100
+ window_end = np.array(mutated_sequence_df['window_end'][mutant_index:mutant_index+full_batch_length])
101
+ scores['window_end'] += list(window_end)
102
+ full_raw_sequence = np.array(mutated_sequence_df['full_raw_sequence'][mutant_index:mutant_index+full_batch_length])
103
+ for k, v in encoded_batch.items():
104
+ if isinstance(v, torch.Tensor):
105
+ encoded_batch[k] = v.to(model.device)
106
+ shift_labels = encoded_batch['labels'][..., 1:].contiguous()
107
+ if (hasattr(model.config,"retrieval_aggregation_mode")) and (model.config.retrieval_aggregation_mode is not None):
108
+ if reverse:
109
+ encoded_batch['flip']=torch.tensor([1]*full_batch_length)
110
+ encoded_batch['start_slice']=window_start
111
+ encoded_batch['end_slice']=window_end
112
+ encoded_batch['full_raw_sequence'] = full_raw_sequence #only mutated_sequence is flipped in the scoring_mirror branch of score_mutants; no need to flip full_raw_sequence for MSA re-alignment
113
+ fused_shift_log_probas=model(**encoded_batch,return_dict=True).fused_shift_log_probas
114
+ loss_fct = NLLLoss(reduction='none')
115
+ loss = - loss_fct(input=fused_shift_log_probas.view(-1, fused_shift_log_probas.size(-1)), target=shift_labels.view(-1)).view(fused_shift_log_probas.shape[0],fused_shift_log_probas.shape[1])
116
+ else:
117
+ lm_logits=model(**encoded_batch,return_dict=True).logits
118
+ shift_logits = lm_logits[..., :-1, :].contiguous()
119
+ loss_fct = CrossEntropyLoss(reduction='none')
120
+ loss = - loss_fct(input=shift_logits.view(-1, shift_logits.size(-1)), target=shift_labels.view(-1)).view(shift_logits.shape[0],shift_logits.shape[1])
121
+ mask = encoded_batch['attention_mask'][..., 1:].float()
122
+ mask[mask==0]=float('nan')
123
+ loss *= mask
124
+ loss = nanmean(loss, dim=1)
125
+ scores_batch = list(loss.cpu().numpy())
126
+ full_batch_length = len(encoded_batch['input_ids'])
127
+ scores['score'] += scores_batch
128
+ mutant_index+=full_batch_length
129
+ scores = pd.DataFrame(scores)
130
+ scores_mutated_seq = scores[scores.mutant != 'wt']
131
+ scores_wt = scores[scores.mutant == 'wt']
132
+ delta_scores = pd.merge(scores_mutated_seq,scores_wt,how='left',on=['window_start'],suffixes=('','_wt'))
133
+ delta_scores[score_var_name] = delta_scores['score'] - delta_scores['score_wt']
134
+ delta_scores=delta_scores[['mutant',score_var_name]].groupby('mutant').mean().reset_index()
135
+ return delta_scores
136
+
137
+ def get_sequence_slices(df, target_seq, model_context_len, start_idx=1, scoring_window="optimal", indel_mode=False):
138
+ """
139
+ Helper function that takes as input a (pandas) dataframe df that contains a list of mutant triplets (substitutions) or full mutated sequences (indels) for scoring.
140
+ It returns a processed DMS in which sequences have been sliced to satisfy the maximum context window of the model.
141
+ df: (dataframe) Input dataframe to be processed
142
+ target_seq: (string) Full reference sequence (wild type) that is mutated in the DMS assay.
143
+ model_context_len: (int) Maximum context size for the model.
144
+ start_idx: (int) Integer to move to 0-indexing of positions (mutation triplet are typically based on 1-indexing).
145
+ scoring_window: (string) Method to slice sequences longer than maximum context size:
146
+ - optimal selects a single window as large as possible via the get_optimal_window function (this is the default)
147
+ - sliding splits the full sequence in contiguous (non-overlapping) chunks that are of size equal to the max context (except the last chunk which may be shorter)
148
+ indel_mode: (bool) Flag to be used when scoring insertions and deletions. Otherwise assumes substitutions.
149
+ Note: when scoring indels for sequences that would be longer than the model max context length, it is preferable to use the "sliding" scoring_window. Use "optimal" otherwise.
150
+ """
151
+ len_target_seq = len(target_seq)
152
+ num_mutants = len(df['mutant'])
153
+ df=df.reset_index(drop=True)
154
+ if scoring_window=="optimal":
155
+ df['mutation_barycenter'] = df['mutant'].apply(lambda x: int(np.array([int(mutation[1:-1]) - start_idx for mutation in x.split(':')]).mean())) if not indel_mode else df['mutant'].apply(lambda x: len(x)//2)
156
+ df['scoring_optimal_window'] = df['mutation_barycenter'].apply(lambda x: get_optimal_window(x, len_target_seq, model_context_len)) if not indel_mode else df['mutant'].apply(lambda x: (0,len(x)))
157
+ df['full_raw_sequence'] = df['mutated_sequence']
158
+ df['mutated_sequence'] = [df['mutated_sequence'][index][df['scoring_optimal_window'][index][0]:df['scoring_optimal_window'][index][1]] for index in range(num_mutants)]
159
+ df['window_start'] = df['scoring_optimal_window'].map(lambda x: x[0])
160
+ df['window_end'] = df['scoring_optimal_window'].map(lambda x: x[1])
161
+ del df['scoring_optimal_window']
162
+ df_wt=df.copy()
163
+ df_wt['mutant'] = ['wt'] * num_mutants
164
+ df_wt['full_raw_sequence'] = [target_seq] * num_mutants
165
+ if indel_mode: # For indels, we always use the same (full-length) wild-type reference sequence. We assume here that its length is below the model context size (otherwise use the "sliding" scoring window)
166
+ df_wt['mutation_barycenter'] = [len_target_seq // 2] * num_mutants
167
+ df_wt['window_end'] = df_wt['full_raw_sequence'].map(lambda x:len(x))
168
+ df_wt['mutated_sequence'] = [target_seq[df_wt['window_start'][index]:df_wt['window_end'][index]] for index in range(num_mutants)]
169
+ df = pd.concat([df,df_wt], axis=0)
170
+ df = df.drop_duplicates()
171
+ elif scoring_window=="sliding":
172
+ len_target_seq = len(target_seq)
173
+ num_windows = 1 + int( len_target_seq / model_context_len)
174
+ df_list=[]
175
+ start=0
176
+ for window_index in range(1, num_windows+1):
177
+ df_sliced = df.copy()
178
+ df_sliced['full_raw_sequence'] = df_sliced['mutated_sequence']
179
+ df_sliced['mutated_sequence'] = df_sliced['mutated_sequence'].map(lambda x: x[start:start+model_context_len])
180
+ df_sliced['window_start'] = [start] * num_mutants
181
+ df_sliced['window_end'] = df_sliced['full_raw_sequence'].map(lambda x: min(len(x), start+model_context_len))
182
+ df_sliced_wt = df_sliced.copy()
183
+ df_sliced_wt['mutant'] = ['wt'] * num_mutants
184
+ df_sliced_wt['full_raw_sequence'] = [target_seq] * num_mutants
185
+ df_sliced_wt['mutated_sequence'] = df_sliced_wt['full_raw_sequence'].map(lambda x: x[start:start+model_context_len])
186
+ df_sliced_wt['window_end'] = df_sliced_wt['full_raw_sequence'].map(lambda x: min(len(x), start+model_context_len)) #Need to adjust end index if WT and sequence are not same full length
187
+ df_list.append(df_sliced)
188
+ df_list.append(df_sliced_wt)
189
+ start += model_context_len
190
+ df_final = pd.concat(df_list,axis=0)
191
+ df = df_final.drop_duplicates()
192
+ return df.reset_index(drop=True)
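Worked examples for two of the helpers above (purely illustrative values):
```
wt = "MTEYKLVVVG"
print(get_mutated_sequence(wt, "T2A:E3K"))
# -> "MAKYKLVVVG": positions are 1-indexed, so T2A and E3K edit the 2nd and 3rd residues

print(get_optimal_window(mutation_position_relative=1000, seq_len_wo_special=2000, model_window=1022))
# -> [489, 1511]: a window of length model_window centred on the mutation, clipped to the sequence
```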
utils/tokenizers/Basic_tokenizer ADDED
@@ -0,0 +1 @@
1
+ {"version":"1.0","truncation":null,"padding":null,"added_tokens":[{"id":0,"special":true,"content":"[UNK]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":1,"special":true,"content":"[CLS]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":2,"special":true,"content":"[SEP]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":3,"special":true,"content":"[PAD]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":4,"special":true,"content":"[MASK]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false}],"normalizer":null,"pre_tokenizer":{"type":"Whitespace"},"post_processor":{"type":"TemplateProcessing","single":[{"SpecialToken":{"id":"[CLS]","type_id":0}},{"Sequence":{"id":"A","type_id":0}},{"SpecialToken":{"id":"[SEP]","type_id":0}}],"pair":[{"SpecialToken":{"id":"[CLS]","type_id":0}},{"Sequence":{"id":"A","type_id":0}},{"SpecialToken":{"id":"[SEP]","type_id":0}},{"Sequence":{"id":"B","type_id":1}},{"SpecialToken":{"id":"[SEP]","type_id":1}}],"special_tokens":{"[CLS]":{"id":"[CLS]","ids":[1],"tokens":["[CLS]"]},"[SEP]":{"id":"[SEP]","ids":[2],"tokens":["[SEP]"]}}},"decoder":null,"model":{"type":"BPE","dropout":null,"unk_token":"[UNK]","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"vocab":{"[UNK]":0,"[CLS]":1,"[SEP]":2,"[PAD]":3,"[MASK]":4,"A":5,"C":6,"D":7,"E":8,"F":9,"G":10,"H":11,"I":12,"K":13,"L":14,"M":15,"N":16,"P":17,"Q":18,"R":19,"S":20,"T":21,"V":22,"W":23,"Y":24},"merges":[]}}
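One way to load this tokenizer file for use as `config.tokenizer` is through the fast-tokenizer wrapper from `transformers`; the snippet below is a sketch, with the special-token assignments mirroring the vocabulary defined in the JSON above:
```
from transformers import PreTrainedTokenizerFast

tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="utils/tokenizers/Basic_tokenizer",
    unk_token="[UNK]",
    sep_token="[SEP]",
    pad_token="[PAD]",
    cls_token="[CLS]",
    mask_token="[MASK]",
)
print(tokenizer("MTEYK")["input_ids"])  # [CLS] id, one id per amino acid, [SEP] id
```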