PascalNotin
commited on
Commit
•
fa87656
1
Parent(s):
c7fea90
Added Tranception model
Browse files- README.md +133 -3
- __init__.py +1 -0
- activations.py +114 -0
- config.json +46 -0
- config.py +36 -0
- model_pytorch.py +917 -0
- outputs.py +48 -0
- pytorch_model.bin +3 -0
- utils/.DS_Store +0 -0
- utils/__init__.py +1 -0
- utils/dms_utils.py +26 -0
- utils/msa_utils.py +361 -0
- utils/scoring_utils.py +192 -0
- utils/tokenizers/Basic_tokenizer +1 -0
README.md
CHANGED
@@ -1,3 +1,133 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Tranception
|
2 |
+
|
3 |
+
This is the official code repository for the paper "Tranception: protein fitness prediction with autoregressive transformers and inference-time retrieval". This project is a joint collaboration between the [Marks lab](https://www.deboramarkslab.com/) and the [OATML group](https://oatml.cs.ox.ac.uk/).
|
4 |
+
|
5 |
+
## Abstract
|
6 |
+
The ability to accurately model the fitness landscape of protein sequences is critical to a wide range of applications, from quantifying the effects of human variants on disease likelihood, to predicting immune-escape mutations in viruses and designing novel biotherapeutic proteins. Deep generative models of protein sequences trained on multiple sequence alignments have been the most successful approaches so far to address these tasks. The performance of these methods is however contingent on the availability of sufficiently deep and diverse alignments for reliable training. Their potential scope is thus limited by the fact many protein families are hard, if not impossible, to align. Large language models trained on massive quantities of non-aligned protein sequences from diverse families address these problems and show potential to eventually bridge the performance gap. We introduce Tranception, a novel transformer architecture leveraging autoregressive predictions and retrieval of homologous sequences at inference to achieve state-of-the-art fitness prediction performance. Given its markedly higher performance on multiple mutants, robustness to shallow alignments and ability to score indels, our approach offers significant gain of scope over existing approaches. To enable more rigorous model testing across a broader range of protein families, we develop ProteinGym -- an extensive set of multiplexed assays of variant effects, substantially increasing both the number and diversity of assays compared to existing benchmarks.
|
7 |
+
|
8 |
+
## Setup
|
9 |
+
You may download the Tranception repository and create a conda environment with the proper dependencies (as listed in `tranception_env.yml`) as follows:
|
10 |
+
```
|
11 |
+
git clone https://github.com/OATML-Markslab/Tranception.git
|
12 |
+
conda env create -f tranception_env.yml
|
13 |
+
```
|
14 |
+
|
15 |
+
## Tranception
|
16 |
+
Tranception is a novel autoregressive transformer architecture that was designed with two core principles in mind: 1) promoting specialization across attention heads 2) explicitly extracting patterns from contiguous subsequences.
|
17 |
+
|
18 |
+
To download the *Tranception Large* model checkpoint (~3.1GB unzipped):
|
19 |
+
```
|
20 |
+
curl -o Tranception_Large.zip https://marks.hms.harvard.edu/ProteinGym/Tranception_Large.zip
|
21 |
+
unzip Tranception_Large.zip
|
22 |
+
rm Tranception_Large.zip
|
23 |
+
```
|
24 |
+
|
25 |
+
Tranception is also made available through the [Huggging Face hub](https://huggingface.co/OATML-Markslab/Tranception).
|
26 |
+
|
27 |
+
When scoring with retrieval, we compute weighted pseudocounts at each position using sequence weights as per the procedure described in [Hopf et al.](https://www.nature.com/articles/nbt.3769).
|
28 |
+
Weights for all proteins in the ProteinGym benchmarks may be downloaded as follows (~68M unzipped):
|
29 |
+
```
|
30 |
+
curl -o MSA_weights.zip https://marks.hms.harvard.edu/ProteinGym/MSA_weights.zip
|
31 |
+
unzip MSA_weights.zip
|
32 |
+
rm MSA_weights.zip
|
33 |
+
```
|
34 |
+
To compute sequence weights for new proteins, you may use the MSA_processing class under `tranception/utils/msa_utils.py`.
|
35 |
+
|
36 |
+
The `examples` folder provides several bash scripts that may be used for scoring and evaluating Tranception on the ProteinGym benchmarks. We also provide a colab notebook illustrating how to load Tranception from the Hugging Face hub and then score the mutated sequences in the ProteinGym benchmarks with it.
|
37 |
+
|
38 |
+
## ProteinGym
|
39 |
+
ProteinGym is an extensive set of Deep Mutational Scanning (DMS) assays curated to enable thorough comparisons of various mutation effect predictors indifferent regimes. ProteinGym is comprised of two benchmarks: 1) a substitution benchmark which consists of the experimental characterisation of ∼1.5M missense variants across 87 DMS assays 2) an indel benchmark that includes ∼300k mutants across 7 DMS assays.
|
40 |
+
|
41 |
+
Each processed file in each benchmark corresponds to a single DMS assay, and contains the following three variables:
|
42 |
+
- mutant (str):
|
43 |
+
- for the substitution benchmark, it describes the set of substitutions to apply on the reference sequence to obtain the mutated sequence (eg., A1P:D2N implies the amino acid 'A' at position 1 should be replaced by 'P', and 'D' at position 2 should be replaced by 'N')
|
44 |
+
- for the indel benchmark, it corresponds to the full mutated sequence
|
45 |
+
- DMS_score (float): corresponds to the experimental measurement in the DMS assay. Across all assays, the higher the DMS_score value, the higher the fitness of the mutated protein
|
46 |
+
- DMS_score_bin (int): indicates whether the DMS_score is above the fitness cutoff (1 is fit, 0 is not fit)
|
47 |
+
|
48 |
+
Additionally, we provide reference files in the [ProteinGym folder](https://github.com/OATML-Markslab/Tranception/tree/main/ProteinGym) that give further details on each assay and contain in particular:
|
49 |
+
- The UniProt_ID of the corresponding protein, along with taxon and MSA depth category
|
50 |
+
- The target sequence (target_seq) used in the assay
|
51 |
+
- Details on how the DMS_score was created from the raw files and how it was binarized
|
52 |
+
|
53 |
+
To download the substitution benchmark (~224M unzipped):
|
54 |
+
```
|
55 |
+
curl -o ProteinGym_substitutions.zip https://marks.hms.harvard.edu/ProteinGym/ProteinGym_substitutions.zip
|
56 |
+
unzip ProteinGym_substitutions.zip
|
57 |
+
rm ProteinGym_substitutions.zip
|
58 |
+
```
|
59 |
+
|
60 |
+
Similarly, to download the indel benchmark (~86M unzipped):
|
61 |
+
```
|
62 |
+
curl -o ProteinGym_indels.zip https://marks.hms.harvard.edu/ProteinGym/ProteinGym_indels.zip
|
63 |
+
unzip ProteinGym_indels.zip
|
64 |
+
rm ProteinGym_indels.zip
|
65 |
+
```
|
66 |
+
|
67 |
+
## Fitness prediction performance
|
68 |
+
|
69 |
+
The [proteingym folder](https://github.com/OATML-Markslab/Tranception/tree/main/ProteinGym) provides detailed performance files for Tranception and baselines on the two ProteinGym benchmarks.
|
70 |
+
|
71 |
+
We recommand to aggregate fitness pediction performance at the Uniprot ID level to avoid biasing results towards proteins for which several DMS assays are available in ProteinGym. The corresponding aggregated files are suffixed with "_Uniprot_level", while the non aggregated performance files are suffixed with "_DMS_level".
|
72 |
+
Furthermore, to enable fair comparison with models trained multiple-sequence alignments (eg., EVE, DeepSequence, EVmutation), we only evaluate on the subset of mutations where position coverage is deemed high enough by these models to make a prediction. The corresponding files are preffixed with "All_models_". For comprehensiveness, we also provide performance files on all possible mutants available in ProteinGym, comparing only with the baselines that are able to score all mutants.
|
73 |
+
Note that for the ProteinGym indel benchmark, baselines that are able to score indels do not have the aforementionned coverage constraints (ie., no distinction between "All_models_" and "All_mutants_") and there is at most one DMS per Uniprot_ID (ie., no difference between "_Uniprot_level" and "_DMS_level"). We thus only provide one set of performance metrics for that benchmark.
|
74 |
+
|
75 |
+
### ProteinGym substitution benchmark - Leaderboard
|
76 |
+
The table below provides the average Spearman's rank correlation between DMS experimental fitness measurements and fitness predictions from Tranception or other baselines on the ProteinGym substitution benchmark. Following the terminology introduced above, we report the performance at the "Uniprot" level for "All models".
|
77 |
+
|
78 |
+
Rank | Model name | Spearman | Reference
|
79 |
+
--- | --- | --- | --- |
|
80 |
+
1 | Ensemble Tranception & EVE | 0.476 | [Notin et al.](https://arxiv.org/abs/2205.13760)
|
81 |
+
2 | Tranception (w/ retrieval) | 0.451 | [Notin et al.](https://arxiv.org/abs/2205.13760)
|
82 |
+
3 | EVE | 0.448 | [Frazer et al.](https://www.nature.com/articles/s41586-021-04043-8)
|
83 |
+
4 | EVmutation | 0.427 | [Hopf et al.](https://www.nature.com/articles/nbt.3769)
|
84 |
+
5 | MSA Transformer | 0.422 | [Rao et al.](https://proceedings.mlr.press/v139/rao21a.html)
|
85 |
+
6 | DeepSequence | 0.415 | [Riesselman et al.](https://www.nature.com/articles/s41592-018-0138-4)
|
86 |
+
7 | Tranception (no retrieval) | 0.406 | [Notin et al.](https://arxiv.org/abs/2205.13760)
|
87 |
+
8 | Wavenet | 0.398 | [Shin et al.](https://www.nature.com/articles/s41467-021-22732-w)
|
88 |
+
9 | Site Independent | 0.397 | [Hopf et al.](https://www.nature.com/articles/nbt.3769)
|
89 |
+
10 | ESM-1v | 0.371 | [Meier et al.](https://proceedings.neurips.cc/paper/2021/hash/f51338d736f95dd42427296047067694-Abstract.html)
|
90 |
+
|
91 |
+
### ProteinGym indel benchmark - Leaderboard
|
92 |
+
The table below provides the average Spearman's rank correlation between DMS experimental fitness measurements and fitness predictions from Tranception or other baselines on the ProteinGym indel benchmark.
|
93 |
+
|
94 |
+
Rank | Model name | Spearman | Reference
|
95 |
+
--- | --- | --- | --- |
|
96 |
+
[Notin et al.](https://arxiv.org/abs/2205.13760)
|
97 |
+
[Notin et al.](https://arxiv.org/abs/2205.13760)
|
98 |
+
[Shin et al.](https://www.nature.com/articles/s41467-021-22732-w)
|
99 |
+
|
100 |
+
## Aggregated model scoring files
|
101 |
+
The scores for all DMS assays in the ProteinGym substitution benchmark for Tranception and other baselines (eg., EVE, Wavenet, ESM-1v, MSA Transformer) may be downloaded as follows;
|
102 |
+
```
|
103 |
+
curl -o scores_all_models_proteingym_substitutions.zip https://marks.hms.harvard.edu/ProteinGym/scores_all_models_proteingym_substitutions.zip
|
104 |
+
unzip scores_all_models_proteingym_substitutions.zip
|
105 |
+
rm scores_all_models_proteingym_substitutions.zip
|
106 |
+
```
|
107 |
+
Similarly for the indel benchmark, all scoring files may be downloaded as follows:
|
108 |
+
```
|
109 |
+
curl -o scores_all_models_proteingym_indels.zip https://marks.hms.harvard.edu/ProteinGym/scores_all_models_proteingym_indels.zip
|
110 |
+
unzip scores_all_models_proteingym_indels.zip
|
111 |
+
rm scores_all_models_proteingym_indels.zip
|
112 |
+
```
|
113 |
+
|
114 |
+
## Multiple Sequence Alignments (MSAs)
|
115 |
+
|
116 |
+
The MSAs used to train alignment-based methods or used at inference in Tranception with retrieval and MSA Transformer may be downloaded as follows (~2.2GB unzipped):
|
117 |
+
```
|
118 |
+
curl -o MSA_ProteinGym.zip https://marks.hms.harvard.edu/ProteinGym/MSA_ProteinGym.zip
|
119 |
+
unzip MSA_ProteinGym.zip
|
120 |
+
rm MSA_ProteinGym.zip
|
121 |
+
```
|
122 |
+
|
123 |
+
## License
|
124 |
+
This project is available under the MIT license.
|
125 |
+
|
126 |
+
## Reference
|
127 |
+
If you use Tranception, ProteinGym or other files provided through this repository (eg., aggregated model scoring files) in your work, please cite the following paper:
|
128 |
+
```
|
129 |
+
Notin, P., Dias, M., Frazer, J., Marchena-Hurtado, J., Gomez, A., Marks, D.S., Gal, Y. (2022). Tranception: Protein Fitness Prediction with Autoregressive Transformers and Inference-time Retrieval. ICML.
|
130 |
+
```
|
131 |
+
|
132 |
+
## Links
|
133 |
+
Pre-print: https://arxiv.org/abs/2205.13760
|
__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from . import config
|
activations.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from packaging import version
|
5 |
+
from torch import nn
|
6 |
+
|
7 |
+
from transformers.utils import logging
|
8 |
+
|
9 |
+
|
10 |
+
logger = logging.get_logger(__name__)
|
11 |
+
|
12 |
+
|
13 |
+
def _gelu_python(x):
|
14 |
+
"""
|
15 |
+
Original Implementation of the GELU activation function in Google BERT repo when initially created. For
|
16 |
+
information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
|
17 |
+
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
|
18 |
+
Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
|
19 |
+
"""
|
20 |
+
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
|
21 |
+
|
22 |
+
|
23 |
+
def gelu_new(x):
|
24 |
+
"""
|
25 |
+
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
|
26 |
+
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
|
27 |
+
"""
|
28 |
+
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
|
29 |
+
|
30 |
+
|
31 |
+
if version.parse(torch.__version__) < version.parse("1.4"):
|
32 |
+
gelu = _gelu_python
|
33 |
+
else:
|
34 |
+
gelu = nn.functional.gelu
|
35 |
+
|
36 |
+
|
37 |
+
def gelu_fast(x):
|
38 |
+
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
|
39 |
+
|
40 |
+
|
41 |
+
def quick_gelu(x):
|
42 |
+
return x * torch.sigmoid(1.702 * x)
|
43 |
+
|
44 |
+
|
45 |
+
def _silu_python(x):
|
46 |
+
"""
|
47 |
+
See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
|
48 |
+
Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
|
49 |
+
Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
|
50 |
+
Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
|
51 |
+
later.
|
52 |
+
"""
|
53 |
+
return x * torch.sigmoid(x)
|
54 |
+
|
55 |
+
|
56 |
+
if version.parse(torch.__version__) < version.parse("1.7"):
|
57 |
+
silu = _silu_python
|
58 |
+
else:
|
59 |
+
silu = nn.functional.silu
|
60 |
+
|
61 |
+
|
62 |
+
def _mish_python(x):
|
63 |
+
"""
|
64 |
+
See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
|
65 |
+
visit the official repository for the paper: https://github.com/digantamisra98/Mish
|
66 |
+
"""
|
67 |
+
return x * torch.tanh(nn.functional.softplus(x))
|
68 |
+
|
69 |
+
|
70 |
+
if version.parse(torch.__version__) < version.parse("1.9"):
|
71 |
+
mish = _mish_python
|
72 |
+
else:
|
73 |
+
mish = nn.functional.mish
|
74 |
+
|
75 |
+
|
76 |
+
def linear_act(x):
|
77 |
+
return x
|
78 |
+
|
79 |
+
def squared_relu(x):
|
80 |
+
"""
|
81 |
+
Squared ReLU variant that is fastest with Pytorch.
|
82 |
+
"""
|
83 |
+
x = nn.functional.relu(x)
|
84 |
+
return x*x
|
85 |
+
|
86 |
+
def squared_relu_xla(x):
|
87 |
+
"""
|
88 |
+
Squared ReLU variant that is fastest with JAX.
|
89 |
+
"""
|
90 |
+
x = nn.functional.relu(x)
|
91 |
+
return x**2
|
92 |
+
|
93 |
+
tranception_ACT2FN = {
|
94 |
+
"relu": nn.functional.relu,
|
95 |
+
"silu": silu,
|
96 |
+
"swish": silu,
|
97 |
+
"gelu": gelu,
|
98 |
+
"tanh": torch.tanh,
|
99 |
+
"gelu_new": gelu_new,
|
100 |
+
"gelu_fast": gelu_fast,
|
101 |
+
"quick_gelu": quick_gelu,
|
102 |
+
"mish": mish,
|
103 |
+
"linear": linear_act,
|
104 |
+
"sigmoid": torch.sigmoid,
|
105 |
+
"squared_relu": squared_relu,
|
106 |
+
"squared_relu_xla": squared_relu_xla,
|
107 |
+
}
|
108 |
+
|
109 |
+
|
110 |
+
def get_activation(activation_string):
|
111 |
+
if activation_string in tranception_ACT2FN:
|
112 |
+
return tranception_ACT2FN[activation_string]
|
113 |
+
else:
|
114 |
+
raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(tranception_ACT2FN.keys())}")
|
config.json
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"MSA_end": null,
|
3 |
+
"MSA_filename": null,
|
4 |
+
"MSA_start": null,
|
5 |
+
"MSA_weight_file_name": null,
|
6 |
+
"_name_or_path": "Tranception_Large",
|
7 |
+
"activation_function": "squared_relu",
|
8 |
+
"architectures": [
|
9 |
+
"TranceptionLMHeadModel"
|
10 |
+
],
|
11 |
+
"attention_mode": "tranception",
|
12 |
+
"attn_pdrop": 0.1,
|
13 |
+
"bos_token_id": 1,
|
14 |
+
"clustal_omega_location": null,
|
15 |
+
"embd_pdrop": 0.1,
|
16 |
+
"eos_token_id": 2,
|
17 |
+
"full_protein_length": null,
|
18 |
+
"initializer_range": 0.02,
|
19 |
+
"layer_norm_epsilon": 1e-05,
|
20 |
+
"local_batch_size": 1,
|
21 |
+
"model_type": "tranception",
|
22 |
+
"n_ctx": 1024,
|
23 |
+
"n_embd": 1280,
|
24 |
+
"n_head": 20,
|
25 |
+
"n_inner": 5120,
|
26 |
+
"n_layer": 36,
|
27 |
+
"n_positions": 1024,
|
28 |
+
"position_embedding": "grouped_alibi",
|
29 |
+
"reorder_and_upcast_attn": false,
|
30 |
+
"resid_pdrop": 0.1,
|
31 |
+
"retrieval_aggregation_mode": null,
|
32 |
+
"retrieval_inference_weight": 0.6,
|
33 |
+
"scale_attn_by_inverse_layer_idx": false,
|
34 |
+
"scale_attn_weights": true,
|
35 |
+
"scoring_window": "optimal",
|
36 |
+
"summary_activation": null,
|
37 |
+
"summary_first_dropout": 0.1,
|
38 |
+
"summary_proj_to_labels": true,
|
39 |
+
"summary_type": "cls_index",
|
40 |
+
"summary_use_proj": true,
|
41 |
+
"tokenizer": null,
|
42 |
+
"torch_dtype": "float32",
|
43 |
+
"transformers_version": "4.17.0",
|
44 |
+
"use_cache": true,
|
45 |
+
"vocab_size": 25
|
46 |
+
}
|
config.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import GPT2Config
|
2 |
+
|
3 |
+
class TranceptionConfig(GPT2Config):
|
4 |
+
"""
|
5 |
+
Config subclass for Tranception model architecture.
|
6 |
+
"""
|
7 |
+
def __init__(
|
8 |
+
self,
|
9 |
+
attention_mode="tranception",
|
10 |
+
position_embedding="grouped_alibi",
|
11 |
+
tokenizer=None,
|
12 |
+
retrieval_aggregation_mode=None,
|
13 |
+
retrieval_inference_weight=0.6,
|
14 |
+
MSA_filename=None,
|
15 |
+
MSA_weight_file_name=None,
|
16 |
+
MSA_start=None,
|
17 |
+
MSA_end=None,
|
18 |
+
full_protein_length=None,
|
19 |
+
clustal_omega_location=None,
|
20 |
+
scoring_window=None,
|
21 |
+
**kwargs
|
22 |
+
):
|
23 |
+
super().__init__(**kwargs)
|
24 |
+
self.model_type="tranception"
|
25 |
+
self.attention_mode=attention_mode
|
26 |
+
self.position_embedding=position_embedding
|
27 |
+
self.tokenizer = tokenizer
|
28 |
+
self.retrieval_aggregation_mode = retrieval_aggregation_mode
|
29 |
+
self.retrieval_inference_weight = retrieval_inference_weight
|
30 |
+
self.MSA_filename = MSA_filename
|
31 |
+
self.MSA_weight_file_name = MSA_weight_file_name
|
32 |
+
self.MSA_start=MSA_start
|
33 |
+
self.MSA_end=MSA_end
|
34 |
+
self.full_protein_length = full_protein_length
|
35 |
+
self.clustal_omega_location = clustal_omega_location
|
36 |
+
self.scoring_window=scoring_window
|
model_pytorch.py
ADDED
@@ -0,0 +1,917 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Optional, Tuple
|
3 |
+
import math
|
4 |
+
import os
|
5 |
+
import pandas as pd
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch import nn
|
9 |
+
from torch.nn import CrossEntropyLoss, NLLLoss
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from transformers import GPT2PreTrainedModel
|
12 |
+
|
13 |
+
from transformers.modeling_utils import (
|
14 |
+
Conv1D,
|
15 |
+
PreTrainedModel,
|
16 |
+
SequenceSummary,
|
17 |
+
find_pruneable_heads_and_indices,
|
18 |
+
prune_conv1d_layer,
|
19 |
+
)
|
20 |
+
from transformers.file_utils import (
|
21 |
+
ModelOutput,
|
22 |
+
add_code_sample_docstrings,
|
23 |
+
add_start_docstrings,
|
24 |
+
add_start_docstrings_to_model_forward,
|
25 |
+
replace_return_docstrings
|
26 |
+
)
|
27 |
+
from transformers.modeling_outputs import (
|
28 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
29 |
+
CausalLMOutputWithCrossAttentions,
|
30 |
+
SequenceClassifierOutputWithPast,
|
31 |
+
TokenClassifierOutput
|
32 |
+
)
|
33 |
+
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
|
34 |
+
|
35 |
+
from tranception.activations import tranception_ACT2FN
|
36 |
+
from tranception.config import TranceptionConfig
|
37 |
+
from tranception.outputs import (
|
38 |
+
TranceptionCausalLMOutputWithCrossAttentions,
|
39 |
+
)
|
40 |
+
from tranception.utils import msa_utils
|
41 |
+
from tranception.utils import scoring_utils
|
42 |
+
|
43 |
+
def nanmean(v, *args, inplace=False, **kwargs):
|
44 |
+
if not inplace:
|
45 |
+
v = v.clone()
|
46 |
+
is_nan = torch.isnan(v)
|
47 |
+
v[is_nan] = 0
|
48 |
+
return v.sum(*args, **kwargs) / (~is_nan).float().sum(*args, **kwargs)
|
49 |
+
|
50 |
+
def get_slopes(n, mode="standard_alibi", verbose=False):
|
51 |
+
"""
|
52 |
+
Function to compute the m constant for each attention head. Code has been adapted from the official ALiBi codebase at:
|
53 |
+
https://github.com/ofirpress/attention_with_linear_biases/blob/master/fairseq/models/transformer.py
|
54 |
+
"""
|
55 |
+
def get_slopes_power_of_2(n):
|
56 |
+
start = (2**(-2**-(math.log2(n)-3)))
|
57 |
+
ratio = start
|
58 |
+
return [start*ratio**i for i in range(n)]
|
59 |
+
if mode=="grouped_alibi":
|
60 |
+
n = n // 4
|
61 |
+
if math.log2(n).is_integer():
|
62 |
+
result = get_slopes_power_of_2(n)
|
63 |
+
else:
|
64 |
+
#Workaround when the number of heads is not a power of 2
|
65 |
+
closest_power_of_2 = 2**math.floor(math.log2(n))
|
66 |
+
result = get_slopes_power_of_2(closest_power_of_2) + get_slopes(2*closest_power_of_2)[0::2][:n-closest_power_of_2]
|
67 |
+
if mode=="grouped_alibi":
|
68 |
+
result = result * 4
|
69 |
+
if verbose:
|
70 |
+
print("ALiBi slopes: {}".format(result))
|
71 |
+
return result
|
72 |
+
|
73 |
+
class SpatialDepthWiseConvolution(nn.Module):
|
74 |
+
def __init__(self, head_dim: int, kernel_size: int = 3):
|
75 |
+
super().__init__()
|
76 |
+
self.kernel_size = kernel_size
|
77 |
+
self.conv = nn.Conv1d(in_channels=head_dim, out_channels=head_dim, kernel_size=(kernel_size,), padding=(kernel_size - 1,), groups=head_dim)
|
78 |
+
|
79 |
+
def forward(self, x: torch.Tensor):
|
80 |
+
batch_size, heads, seq_len, head_dim = x.shape
|
81 |
+
x = x.permute(0, 1, 3, 2).contiguous()
|
82 |
+
x = x.view(batch_size * heads, head_dim, seq_len)
|
83 |
+
x = self.conv(x)
|
84 |
+
if self.kernel_size>1:
|
85 |
+
x = x[:, :, :-(self.kernel_size - 1)]
|
86 |
+
x = x.view(batch_size, heads, head_dim, seq_len)
|
87 |
+
x = x.permute(0, 1, 3, 2)
|
88 |
+
return x
|
89 |
+
|
90 |
+
class TranceptionBlockAttention(nn.Module):
|
91 |
+
def __init__(self, config, is_cross_attention=False, SDWC_kernel_size=None):
|
92 |
+
super().__init__()
|
93 |
+
|
94 |
+
max_positions = config.max_position_embeddings
|
95 |
+
self.register_buffer(
|
96 |
+
"bias",
|
97 |
+
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
|
98 |
+
1, 1, max_positions, max_positions
|
99 |
+
),
|
100 |
+
)
|
101 |
+
self.register_buffer("masked_bias", torch.tensor(-1e4))
|
102 |
+
|
103 |
+
self.embed_dim = config.hidden_size
|
104 |
+
self.num_heads = config.num_attention_heads
|
105 |
+
self.head_dim = self.embed_dim // self.num_heads
|
106 |
+
self.split_size = self.embed_dim
|
107 |
+
if self.head_dim * self.num_heads != self.embed_dim:
|
108 |
+
raise ValueError(
|
109 |
+
f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
|
110 |
+
)
|
111 |
+
|
112 |
+
self.scale_attn_weights = config.scale_attn_weights
|
113 |
+
self.is_cross_attention = is_cross_attention
|
114 |
+
|
115 |
+
if self.is_cross_attention:
|
116 |
+
self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
|
117 |
+
self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
|
118 |
+
else:
|
119 |
+
self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
|
120 |
+
self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
|
121 |
+
|
122 |
+
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
123 |
+
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
124 |
+
|
125 |
+
self.pruned_heads = set()
|
126 |
+
|
127 |
+
self.attention_mode=config.attention_mode
|
128 |
+
|
129 |
+
if self.attention_mode=="tranception":
|
130 |
+
assert self.num_heads%4==0, "Invalid number of heads. Tranception requires the number of heads to be a multiple of 4."
|
131 |
+
self.num_heads_per_kernel_size = self.num_heads // 4
|
132 |
+
self.query_depthwiseconv = nn.ModuleDict()
|
133 |
+
self.key_depthwiseconv = nn.ModuleDict()
|
134 |
+
self.value_depthwiseconv = nn.ModuleDict()
|
135 |
+
for kernel_idx, kernel in enumerate([3,5,7]):
|
136 |
+
self.query_depthwiseconv[str(kernel_idx)] = SpatialDepthWiseConvolution(self.head_dim,kernel)
|
137 |
+
self.key_depthwiseconv[str(kernel_idx)] = SpatialDepthWiseConvolution(self.head_dim,kernel)
|
138 |
+
self.value_depthwiseconv[str(kernel_idx)] = SpatialDepthWiseConvolution(self.head_dim,kernel)
|
139 |
+
|
140 |
+
def prune_heads(self, heads):
|
141 |
+
if len(heads) == 0:
|
142 |
+
return
|
143 |
+
heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
|
144 |
+
index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
|
145 |
+
|
146 |
+
# Prune conv1d layers
|
147 |
+
self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
|
148 |
+
self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
|
149 |
+
|
150 |
+
# Update hyper params
|
151 |
+
self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
|
152 |
+
self.num_heads = self.num_heads - len(heads)
|
153 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
154 |
+
|
155 |
+
def _attn(self, query, key, value, attention_mask=None, head_mask=None, alibi_bias=None):
|
156 |
+
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
157 |
+
|
158 |
+
if self.scale_attn_weights:
|
159 |
+
attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)
|
160 |
+
|
161 |
+
if not self.is_cross_attention:
|
162 |
+
# if only "normal" attention layer implements causal mask
|
163 |
+
query_length, key_length = query.size(-2), key.size(-2)
|
164 |
+
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
|
165 |
+
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
|
166 |
+
|
167 |
+
if alibi_bias is not None:
|
168 |
+
attn_weights = attn_weights + alibi_bias[:,:,:attn_weights.size(-1)]
|
169 |
+
|
170 |
+
if attention_mask is not None:
|
171 |
+
# Apply the attention mask
|
172 |
+
attn_weights = attn_weights + attention_mask
|
173 |
+
|
174 |
+
attn_weights = nn.Softmax(dim=-1)(attn_weights)
|
175 |
+
attn_weights = self.attn_dropout(attn_weights)
|
176 |
+
|
177 |
+
# Mask heads if we want to
|
178 |
+
if head_mask is not None:
|
179 |
+
attn_weights = attn_weights * head_mask
|
180 |
+
|
181 |
+
attn_output = torch.matmul(attn_weights, value)
|
182 |
+
|
183 |
+
return attn_output, attn_weights
|
184 |
+
|
185 |
+
def _split_heads(self, tensor, num_heads, attn_head_size):
|
186 |
+
"""
|
187 |
+
Splits hidden_size dim into attn_head_size and num_heads
|
188 |
+
"""
|
189 |
+
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
|
190 |
+
tensor = tensor.view(*new_shape)
|
191 |
+
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
|
192 |
+
|
193 |
+
def _merge_heads(self, tensor, num_heads, attn_head_size):
|
194 |
+
"""
|
195 |
+
Merges attn_head_size dim and num_attn_heads dim into hidden_size
|
196 |
+
"""
|
197 |
+
tensor = tensor.permute(0, 2, 1, 3).contiguous()
|
198 |
+
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
|
199 |
+
return tensor.view(new_shape)
|
200 |
+
|
201 |
+
def forward(
|
202 |
+
self,
|
203 |
+
hidden_states,
|
204 |
+
layer_past=None,
|
205 |
+
attention_mask=None,
|
206 |
+
head_mask=None,
|
207 |
+
encoder_hidden_states=None,
|
208 |
+
encoder_attention_mask=None,
|
209 |
+
use_cache=False,
|
210 |
+
output_attentions=False,
|
211 |
+
alibi_bias=None,
|
212 |
+
):
|
213 |
+
if encoder_hidden_states is not None:
|
214 |
+
if not hasattr(self, "q_attn"):
|
215 |
+
raise ValueError(
|
216 |
+
"If class is used as cross attention, the weights `q_attn` have to be defined. "
|
217 |
+
"Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
|
218 |
+
)
|
219 |
+
|
220 |
+
query = self.q_attn(hidden_states)
|
221 |
+
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
|
222 |
+
attention_mask = encoder_attention_mask
|
223 |
+
else:
|
224 |
+
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
|
225 |
+
|
226 |
+
query = self._split_heads(query, self.num_heads, self.head_dim)
|
227 |
+
key = self._split_heads(key, self.num_heads, self.head_dim)
|
228 |
+
value = self._split_heads(value, self.num_heads, self.head_dim)
|
229 |
+
|
230 |
+
if layer_past is not None:
|
231 |
+
past_key, past_value = layer_past
|
232 |
+
key = torch.cat((past_key, key), dim=-2)
|
233 |
+
value = torch.cat((past_value, value), dim=-2)
|
234 |
+
|
235 |
+
if use_cache is True:
|
236 |
+
present = (key, value)
|
237 |
+
else:
|
238 |
+
present = None
|
239 |
+
|
240 |
+
if self.attention_mode=="tranception":
|
241 |
+
# We do not do anything on the first self.num_heads_per_kernel_size heads (kernel =1)
|
242 |
+
query_list=[query[:,:self.num_heads_per_kernel_size,:,:]]
|
243 |
+
key_list=[key[:,:self.num_heads_per_kernel_size,:,:]]
|
244 |
+
value_list=[value[:,:self.num_heads_per_kernel_size,:,:]]
|
245 |
+
for kernel_idx in range(3):
|
246 |
+
query_list.append(self.query_depthwiseconv[str(kernel_idx)](query[:,(kernel_idx+1)*self.num_heads_per_kernel_size:(kernel_idx+2)*self.num_heads_per_kernel_size,:,:]))
|
247 |
+
key_list.append(self.key_depthwiseconv[str(kernel_idx)](key[:,(kernel_idx+1)*self.num_heads_per_kernel_size:(kernel_idx+2)*self.num_heads_per_kernel_size,:,:]))
|
248 |
+
value_list.append(self.value_depthwiseconv[str(kernel_idx)](value[:,(kernel_idx+1)*self.num_heads_per_kernel_size:(kernel_idx+2)*self.num_heads_per_kernel_size,:,:]))
|
249 |
+
query=torch.cat(query_list, dim=1)
|
250 |
+
key=torch.cat(key_list, dim=1)
|
251 |
+
value=torch.cat(value_list, dim=1)
|
252 |
+
|
253 |
+
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask, alibi_bias=alibi_bias)
|
254 |
+
|
255 |
+
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
|
256 |
+
attn_output = self.c_proj(attn_output)
|
257 |
+
attn_output = self.resid_dropout(attn_output)
|
258 |
+
|
259 |
+
outputs = (attn_output, present)
|
260 |
+
if output_attentions:
|
261 |
+
outputs += (attn_weights,)
|
262 |
+
|
263 |
+
return outputs # a, present, (attentions)
|
264 |
+
|
265 |
+
class TranceptionBlockMLP(nn.Module):
|
266 |
+
def __init__(self, intermediate_size, config):
|
267 |
+
super().__init__()
|
268 |
+
embed_dim = config.hidden_size
|
269 |
+
self.c_fc = Conv1D(intermediate_size, embed_dim)
|
270 |
+
self.c_proj = Conv1D(embed_dim, intermediate_size)
|
271 |
+
self.act = tranception_ACT2FN[config.activation_function]
|
272 |
+
self.dropout = nn.Dropout(config.resid_pdrop)
|
273 |
+
|
274 |
+
def forward(self, hidden_states):
|
275 |
+
hidden_states = self.c_fc(hidden_states)
|
276 |
+
hidden_states = self.act(hidden_states)
|
277 |
+
hidden_states = self.c_proj(hidden_states)
|
278 |
+
hidden_states = self.dropout(hidden_states)
|
279 |
+
return hidden_states
|
280 |
+
|
281 |
+
class TranceptionBlock(nn.Module):
|
282 |
+
def __init__(self, config, SDWC_kernel_size=None):
|
283 |
+
super().__init__()
|
284 |
+
hidden_size = config.hidden_size
|
285 |
+
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
|
286 |
+
|
287 |
+
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
288 |
+
self.attn = TranceptionBlockAttention(config, SDWC_kernel_size=SDWC_kernel_size)
|
289 |
+
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
290 |
+
|
291 |
+
if config.add_cross_attention:
|
292 |
+
self.crossattention = TranceptionBlockAttention(config, is_cross_attention=True, SDWC_kernel_size=SDWC_kernel_size)
|
293 |
+
self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
294 |
+
|
295 |
+
self.mlp = TranceptionBlockMLP(inner_dim, config)
|
296 |
+
|
297 |
+
def forward(
|
298 |
+
self,
|
299 |
+
hidden_states,
|
300 |
+
layer_past=None,
|
301 |
+
attention_mask=None,
|
302 |
+
head_mask=None,
|
303 |
+
encoder_hidden_states=None,
|
304 |
+
encoder_attention_mask=None,
|
305 |
+
use_cache=False,
|
306 |
+
output_attentions=False,
|
307 |
+
alibi_bias=None,
|
308 |
+
):
|
309 |
+
residual = hidden_states
|
310 |
+
hidden_states = self.ln_1(hidden_states)
|
311 |
+
attn_outputs = self.attn(
|
312 |
+
hidden_states,
|
313 |
+
layer_past=layer_past,
|
314 |
+
attention_mask=attention_mask,
|
315 |
+
head_mask=head_mask,
|
316 |
+
use_cache=use_cache,
|
317 |
+
output_attentions=output_attentions,
|
318 |
+
alibi_bias=alibi_bias,
|
319 |
+
)
|
320 |
+
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
|
321 |
+
outputs = attn_outputs[1:]
|
322 |
+
# residual connection
|
323 |
+
hidden_states = attn_output + residual
|
324 |
+
|
325 |
+
if encoder_hidden_states is not None:
|
326 |
+
# add one self-attention block for cross-attention
|
327 |
+
if not hasattr(self, "crossattention"):
|
328 |
+
raise ValueError(
|
329 |
+
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
|
330 |
+
"cross-attention layers by setting `config.add_cross_attention=True`"
|
331 |
+
)
|
332 |
+
residual = hidden_states
|
333 |
+
hidden_states = self.ln_cross_attn(hidden_states)
|
334 |
+
cross_attn_outputs = self.crossattention(
|
335 |
+
hidden_states,
|
336 |
+
attention_mask=attention_mask,
|
337 |
+
head_mask=head_mask,
|
338 |
+
encoder_hidden_states=encoder_hidden_states,
|
339 |
+
encoder_attention_mask=encoder_attention_mask,
|
340 |
+
output_attentions=output_attentions,
|
341 |
+
)
|
342 |
+
attn_output = cross_attn_outputs[0]
|
343 |
+
# residual connection
|
344 |
+
hidden_states = residual + attn_output
|
345 |
+
outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
|
346 |
+
|
347 |
+
residual = hidden_states
|
348 |
+
hidden_states = self.ln_2(hidden_states)
|
349 |
+
|
350 |
+
feed_forward_hidden_states = self.mlp(hidden_states)
|
351 |
+
|
352 |
+
# residual connection
|
353 |
+
hidden_states = residual + feed_forward_hidden_states
|
354 |
+
|
355 |
+
if use_cache:
|
356 |
+
outputs = (hidden_states,) + outputs
|
357 |
+
else:
|
358 |
+
outputs = (hidden_states,) + outputs[1:]
|
359 |
+
|
360 |
+
return outputs # hidden_states, present, (attentions, cross_attentions)
|
361 |
+
|
362 |
+
class TranceptionModel(GPT2PreTrainedModel):
|
363 |
+
_keys_to_ignore_on_load_missing = ["attn.masked_bias"]
|
364 |
+
def __init__(self, config):
|
365 |
+
super().__init__(config)
|
366 |
+
|
367 |
+
self.embed_dim = config.hidden_size
|
368 |
+
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
|
369 |
+
self.position_embedding = config.position_embedding if hasattr(config, "position_embedding") else "learned"
|
370 |
+
if self.position_embedding=="learned":
|
371 |
+
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
372 |
+
self.alibi = None
|
373 |
+
elif self.position_embedding=="grouped_alibi":
|
374 |
+
maxpos = config.n_positions
|
375 |
+
attn_heads = config.n_head
|
376 |
+
self.slopes = torch.Tensor(get_slopes(attn_heads, mode=self.position_embedding))
|
377 |
+
#The softmax operation is invariant to translation, and bias functions used are always linear.
|
378 |
+
alibi = self.slopes.unsqueeze(1).unsqueeze(1) * torch.arange(maxpos).unsqueeze(0).unsqueeze(0).expand(attn_heads, -1, -1)
|
379 |
+
alibi = alibi.view(attn_heads, 1, maxpos)
|
380 |
+
self.register_buffer('alibi',alibi)
|
381 |
+
|
382 |
+
self.drop = nn.Dropout(config.embd_pdrop)
|
383 |
+
self.h = nn.ModuleList([TranceptionBlock(config) for _ in range(config.num_hidden_layers)])
|
384 |
+
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
385 |
+
|
386 |
+
self.init_weights()
|
387 |
+
|
388 |
+
# Model parallel
|
389 |
+
self.model_parallel = False
|
390 |
+
self.device_map = None
|
391 |
+
self.gradient_checkpointing = False
|
392 |
+
|
393 |
+
def parallelize(self, device_map=None, num_cores=None):
|
394 |
+
self.device_map = (
|
395 |
+
get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
|
396 |
+
)
|
397 |
+
device_prefix="cuda:"
|
398 |
+
assert_device_map(self.device_map, len(self.h))
|
399 |
+
self.model_parallel = True
|
400 |
+
self.first_device = "cpu" if "cpu" in self.device_map.keys() else device_prefix + str(min(self.device_map.keys()))
|
401 |
+
self.last_device = device_prefix + str(max(self.device_map.keys()))
|
402 |
+
self.wte = self.wte.to(self.first_device)
|
403 |
+
if self.position_embedding=="learned":
|
404 |
+
self.wpe = self.wpe.to(self.first_device)
|
405 |
+
for k, v in self.device_map.items():
|
406 |
+
print("k,v :"+str(k)+","+str(v))
|
407 |
+
for block in v:
|
408 |
+
cuda_device = device_prefix + str(k)
|
409 |
+
self.h[block] = self.h[block].to(cuda_device)
|
410 |
+
self.ln_f = self.ln_f.to(self.last_device)
|
411 |
+
|
412 |
+
def deparallelize(self):
|
413 |
+
self.model_parallel = False
|
414 |
+
self.device_map = None
|
415 |
+
self.first_device = "cpu"
|
416 |
+
self.last_device = "cpu"
|
417 |
+
self.wte = self.wte.to("cpu")
|
418 |
+
if self.position_embedding=="learned":
|
419 |
+
self.wpe = self.wpe.to("cpu")
|
420 |
+
for index in range(len(self.h)):
|
421 |
+
self.h[index] = self.h[index].to("cpu")
|
422 |
+
self.ln_f = self.ln_f.to("cpu")
|
423 |
+
torch.cuda.empty_cache()
|
424 |
+
|
425 |
+
def get_input_embeddings(self):
|
426 |
+
return self.wte
|
427 |
+
|
428 |
+
def set_input_embeddings(self, new_embeddings):
|
429 |
+
self.wte = new_embeddings
|
430 |
+
|
431 |
+
def _prune_heads(self, heads_to_prune):
|
432 |
+
"""
|
433 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
434 |
+
"""
|
435 |
+
for layer, heads in heads_to_prune.items():
|
436 |
+
self.h[layer].attn.prune_heads(heads)
|
437 |
+
|
438 |
+
def forward(
|
439 |
+
self,
|
440 |
+
input_ids=None,
|
441 |
+
past_key_values=None,
|
442 |
+
attention_mask=None,
|
443 |
+
token_type_ids=None,
|
444 |
+
position_ids=None,
|
445 |
+
head_mask=None,
|
446 |
+
inputs_embeds=None,
|
447 |
+
encoder_hidden_states=None,
|
448 |
+
encoder_attention_mask=None,
|
449 |
+
use_cache=None,
|
450 |
+
output_attentions=None,
|
451 |
+
output_hidden_states=None,
|
452 |
+
return_dict=None,
|
453 |
+
):
|
454 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
455 |
+
output_hidden_states = (
|
456 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
457 |
+
)
|
458 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
459 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
460 |
+
|
461 |
+
if input_ids is not None and inputs_embeds is not None:
|
462 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
463 |
+
elif input_ids is not None:
|
464 |
+
input_shape = input_ids.size()
|
465 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
466 |
+
batch_size = input_ids.shape[0]
|
467 |
+
elif inputs_embeds is not None:
|
468 |
+
input_shape = inputs_embeds.size()[:-1]
|
469 |
+
batch_size = inputs_embeds.shape[0]
|
470 |
+
else:
|
471 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
472 |
+
|
473 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
474 |
+
|
475 |
+
if token_type_ids is not None:
|
476 |
+
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
477 |
+
if position_ids is not None:
|
478 |
+
position_ids = position_ids.view(-1, input_shape[-1])
|
479 |
+
|
480 |
+
if past_key_values is None:
|
481 |
+
past_length = 0
|
482 |
+
past_key_values = tuple([None] * len(self.h))
|
483 |
+
else:
|
484 |
+
past_length = past_key_values[0][0].size(-2)
|
485 |
+
if position_ids is None:
|
486 |
+
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
487 |
+
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
488 |
+
|
489 |
+
# GPT2Attention mask.
|
490 |
+
if attention_mask is not None:
|
491 |
+
if batch_size <= 0:
|
492 |
+
raise ValueError("batch_size has to be defined and > 0")
|
493 |
+
attention_mask = attention_mask.view(batch_size, -1)
|
494 |
+
# We create a 3D attention mask from a 2D tensor mask.
|
495 |
+
# Sizes are [batch_size, 1, 1, to_seq_length]
|
496 |
+
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
497 |
+
# this attention mask is more simple than the triangular masking of causal attention
|
498 |
+
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
499 |
+
attention_mask = attention_mask[:, None, None, :]
|
500 |
+
|
501 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
502 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
503 |
+
# positions we want to attend and -10000.0 for masked positions.
|
504 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
505 |
+
# effectively the same as removing these entirely.
|
506 |
+
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
507 |
+
attention_mask = (1.0 - attention_mask) * -10000.0
|
508 |
+
|
509 |
+
# If a 2D ou 3D attention mask is provided for the cross-attention
|
510 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
511 |
+
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
512 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
513 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
514 |
+
if encoder_attention_mask is None:
|
515 |
+
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
516 |
+
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
517 |
+
else:
|
518 |
+
encoder_attention_mask = None
|
519 |
+
|
520 |
+
# Prepare head mask if needed
|
521 |
+
# 1.0 in head_mask indicate we keep the head
|
522 |
+
# attention_probs has shape bsz x n_heads x N x N
|
523 |
+
# head_mask has shape n_layer x batch x n_heads x N x N
|
524 |
+
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
525 |
+
|
526 |
+
if inputs_embeds is None:
|
527 |
+
inputs_embeds = self.wte(input_ids)
|
528 |
+
if self.position_embedding=="learned":
|
529 |
+
position_embeds = self.wpe(position_ids)
|
530 |
+
hidden_states = inputs_embeds + position_embeds
|
531 |
+
else:
|
532 |
+
hidden_states = inputs_embeds
|
533 |
+
|
534 |
+
if token_type_ids is not None:
|
535 |
+
token_type_embeds = self.wte(token_type_ids)
|
536 |
+
hidden_states = hidden_states + token_type_embeds
|
537 |
+
|
538 |
+
hidden_states = self.drop(hidden_states)
|
539 |
+
|
540 |
+
output_shape = input_shape + (hidden_states.size(-1),)
|
541 |
+
|
542 |
+
presents = () if use_cache else None
|
543 |
+
all_self_attentions = () if output_attentions else None
|
544 |
+
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
545 |
+
all_hidden_states = () if output_hidden_states else None
|
546 |
+
|
547 |
+
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
548 |
+
# Model parallel
|
549 |
+
if self.model_parallel:
|
550 |
+
torch.cuda.set_device(hidden_states.device)
|
551 |
+
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
552 |
+
if layer_past is not None:
|
553 |
+
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
|
554 |
+
# Ensure that attention_mask is always on the same device as hidden_states
|
555 |
+
if attention_mask is not None:
|
556 |
+
attention_mask = attention_mask.to(hidden_states.device)
|
557 |
+
if isinstance(head_mask, torch.Tensor):
|
558 |
+
head_mask = head_mask.to(hidden_states.device)
|
559 |
+
if output_hidden_states:
|
560 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
561 |
+
|
562 |
+
if self.gradient_checkpointing and self.training:
|
563 |
+
if use_cache:
|
564 |
+
logger.warning(
|
565 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
566 |
+
)
|
567 |
+
use_cache = False
|
568 |
+
|
569 |
+
def create_custom_forward(module):
|
570 |
+
def custom_forward(*inputs):
|
571 |
+
# None for past_key_value
|
572 |
+
return module(*inputs, use_cache, output_attentions)
|
573 |
+
|
574 |
+
return custom_forward
|
575 |
+
|
576 |
+
outputs = torch.utils.checkpoint.checkpoint(
|
577 |
+
create_custom_forward(block),
|
578 |
+
hidden_states,
|
579 |
+
None,
|
580 |
+
attention_mask,
|
581 |
+
head_mask[i],
|
582 |
+
encoder_hidden_states,
|
583 |
+
encoder_attention_mask,
|
584 |
+
)
|
585 |
+
else:
|
586 |
+
outputs = block(
|
587 |
+
hidden_states,
|
588 |
+
layer_past=layer_past,
|
589 |
+
attention_mask=attention_mask,
|
590 |
+
head_mask=head_mask[i],
|
591 |
+
encoder_hidden_states=encoder_hidden_states,
|
592 |
+
encoder_attention_mask=encoder_attention_mask,
|
593 |
+
use_cache=use_cache,
|
594 |
+
output_attentions=output_attentions,
|
595 |
+
alibi_bias=self.alibi if hasattr(self, "alibi") else None
|
596 |
+
)
|
597 |
+
|
598 |
+
hidden_states = outputs[0]
|
599 |
+
|
600 |
+
if use_cache is True:
|
601 |
+
presents = presents + (outputs[1],)
|
602 |
+
|
603 |
+
if output_attentions:
|
604 |
+
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
605 |
+
if self.config.add_cross_attention:
|
606 |
+
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
607 |
+
|
608 |
+
if self.model_parallel:
|
609 |
+
device_prefix="cuda:"
|
610 |
+
for k, v in self.device_map.items():
|
611 |
+
if i == v[-1] and device_prefix + str(k) != self.last_device:
|
612 |
+
hidden_states = hidden_states.to(device_prefix + str(k + 1))
|
613 |
+
|
614 |
+
hidden_states = self.ln_f(hidden_states)
|
615 |
+
|
616 |
+
hidden_states = hidden_states.view(*output_shape)
|
617 |
+
# Add last hidden state
|
618 |
+
if output_hidden_states:
|
619 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
620 |
+
|
621 |
+
if not return_dict:
|
622 |
+
return tuple(
|
623 |
+
v
|
624 |
+
for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions, moe_loss]
|
625 |
+
if v is not None
|
626 |
+
)
|
627 |
+
|
628 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
629 |
+
last_hidden_state=hidden_states,
|
630 |
+
past_key_values=presents,
|
631 |
+
hidden_states=all_hidden_states,
|
632 |
+
attentions=all_self_attentions,
|
633 |
+
cross_attentions=all_cross_attentions,
|
634 |
+
)
|
635 |
+
|
636 |
+
class TranceptionLMHeadModel(GPT2PreTrainedModel):
|
637 |
+
_keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]
|
638 |
+
def __init__(self, config):
|
639 |
+
super().__init__(config)
|
640 |
+
self.transformer = TranceptionModel(config)
|
641 |
+
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
642 |
+
self.config = config
|
643 |
+
|
644 |
+
self.init_weights()
|
645 |
+
|
646 |
+
self.default_model_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
647 |
+
# Model parallel
|
648 |
+
self.model_parallel = False
|
649 |
+
self.device_map = None
|
650 |
+
|
651 |
+
self.retrieval_aggregation_mode = config.retrieval_aggregation_mode if hasattr(config, "retrieval_aggregation_mode") else None
|
652 |
+
if self.retrieval_aggregation_mode is not None:
|
653 |
+
print("Model leverages both autoregressive and retrieval inference")
|
654 |
+
self.MSA_filename = config.MSA_filename if hasattr(config, "MSA_filename") else False
|
655 |
+
self.MSA_folder = '/'.join(self.MSA_filename.split(os.sep)[:-1])
|
656 |
+
self.MSA_name = self.MSA_filename.split(os.sep)[-1]
|
657 |
+
self.retrieval_inference_weight_LR = config.retrieval_inference_weight if hasattr(config, "retrieval_inference_weight") else 0.6
|
658 |
+
self.retrieval_inference_weight_RL = config.retrieval_inference_weight if hasattr(config, "retrieval_inference_weight") else 0.6
|
659 |
+
self.MSA_start=config.MSA_start
|
660 |
+
self.MSA_end=config.MSA_end
|
661 |
+
self.full_protein_length = config.full_protein_length if hasattr(config, "full_protein_length") else -1
|
662 |
+
|
663 |
+
self.MSA_log_prior = torch.log(torch.tensor(
|
664 |
+
msa_utils.get_msa_prior(
|
665 |
+
MSA_data_file=self.MSA_filename,
|
666 |
+
MSA_weight_file_name=config.MSA_weight_file_name,
|
667 |
+
retrieval_aggregation_mode=self.retrieval_aggregation_mode,
|
668 |
+
MSA_start=self.MSA_start,
|
669 |
+
MSA_end=self.MSA_end,
|
670 |
+
len_target_seq=self.full_protein_length,
|
671 |
+
vocab=config.tokenizer.get_vocab(),
|
672 |
+
verbose=False
|
673 |
+
)
|
674 |
+
).float().to(self.default_model_device))
|
675 |
+
else:
|
676 |
+
print("Model only uses autoregressive inference")
|
677 |
+
|
678 |
+
def parallelize(self, device_map=None, num_cores=None, num_pipelines=1):
|
679 |
+
self.num_pipelines=num_pipelines
|
680 |
+
self.device_map = (
|
681 |
+
get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
|
682 |
+
if device_map is None
|
683 |
+
else device_map
|
684 |
+
)
|
685 |
+
assert_device_map(self.device_map, len(self.transformer.h))
|
686 |
+
self.transformer.parallelize(self.device_map, num_cores=num_cores)
|
687 |
+
self.lm_head = self.lm_head.to(self.transformer.first_device)
|
688 |
+
self.model_parallel = True
|
689 |
+
|
690 |
+
def deparallelize(self):
|
691 |
+
self.transformer.deparallelize()
|
692 |
+
self.transformer = self.transformer.to("cpu")
|
693 |
+
self.lm_head = self.lm_head.to("cpu")
|
694 |
+
self.model_parallel = False
|
695 |
+
torch.cuda.empty_cache()
|
696 |
+
|
697 |
+
def get_output_embeddings(self):
|
698 |
+
return self.lm_head
|
699 |
+
|
700 |
+
def set_output_embeddings(self, new_embeddings):
|
701 |
+
self.lm_head = new_embeddings
|
702 |
+
|
703 |
+
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
704 |
+
token_type_ids = kwargs.get("token_type_ids", None)
|
705 |
+
# only last token for inputs_ids if past is defined in kwargs
|
706 |
+
if past:
|
707 |
+
input_ids = input_ids[:, -1].unsqueeze(-1)
|
708 |
+
if token_type_ids is not None:
|
709 |
+
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
|
710 |
+
|
711 |
+
attention_mask = kwargs.get("attention_mask", None)
|
712 |
+
position_ids = kwargs.get("position_ids", None)
|
713 |
+
|
714 |
+
if attention_mask is not None and position_ids is None:
|
715 |
+
# create position_ids on the fly for batch generation
|
716 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
717 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
718 |
+
if past:
|
719 |
+
position_ids = position_ids[:, -1].unsqueeze(-1)
|
720 |
+
else:
|
721 |
+
position_ids = None
|
722 |
+
|
723 |
+
return {
|
724 |
+
"input_ids": input_ids,
|
725 |
+
"past_key_values": past,
|
726 |
+
"use_cache": kwargs.get("use_cache"),
|
727 |
+
"position_ids": position_ids,
|
728 |
+
"attention_mask": attention_mask,
|
729 |
+
"token_type_ids": token_type_ids,
|
730 |
+
"flip": kwargs.get("flip", None),
|
731 |
+
}
|
732 |
+
|
733 |
+
def forward(
|
734 |
+
self,
|
735 |
+
input_ids=None,
|
736 |
+
past_key_values=None,
|
737 |
+
attention_mask=None,
|
738 |
+
token_type_ids=None,
|
739 |
+
position_ids=None,
|
740 |
+
head_mask=None,
|
741 |
+
inputs_embeds=None,
|
742 |
+
encoder_hidden_states=None,
|
743 |
+
encoder_attention_mask=None,
|
744 |
+
labels=None,
|
745 |
+
use_cache=None,
|
746 |
+
output_attentions=None,
|
747 |
+
output_hidden_states=None,
|
748 |
+
return_dict=None,
|
749 |
+
flip=None,
|
750 |
+
start_slice=None,
|
751 |
+
end_slice=None,
|
752 |
+
full_raw_sequence=None,
|
753 |
+
):
|
754 |
+
r"""
|
755 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
756 |
+
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
757 |
+
``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to
|
758 |
+
``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``
|
759 |
+
"""
|
760 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
761 |
+
|
762 |
+
transformer_outputs = self.transformer(
|
763 |
+
input_ids,
|
764 |
+
past_key_values=past_key_values,
|
765 |
+
attention_mask=attention_mask,
|
766 |
+
token_type_ids=token_type_ids,
|
767 |
+
position_ids=position_ids,
|
768 |
+
head_mask=head_mask,
|
769 |
+
inputs_embeds=inputs_embeds,
|
770 |
+
encoder_hidden_states=encoder_hidden_states,
|
771 |
+
encoder_attention_mask=encoder_attention_mask,
|
772 |
+
use_cache=use_cache,
|
773 |
+
output_attentions=output_attentions,
|
774 |
+
output_hidden_states=output_hidden_states,
|
775 |
+
return_dict=return_dict
|
776 |
+
)
|
777 |
+
hidden_states = transformer_outputs[0]
|
778 |
+
|
779 |
+
# Set device for model parallelism
|
780 |
+
if self.model_parallel:
|
781 |
+
torch.cuda.set_device(self.transformer.first_device)
|
782 |
+
hidden_states = hidden_states.to(self.lm_head.weight.device)
|
783 |
+
self.MSA_log_prior = self.MSA_log_prior.to(self.lm_head.weight.device)
|
784 |
+
|
785 |
+
lm_logits = self.lm_head(hidden_states)
|
786 |
+
|
787 |
+
loss = None
|
788 |
+
if labels is not None:
|
789 |
+
# Shift so that tokens < n predict n
|
790 |
+
shift_logits = lm_logits[..., :-1, :].contiguous()
|
791 |
+
shift_labels = labels[..., 1:].contiguous()
|
792 |
+
|
793 |
+
if self.retrieval_aggregation_mode is not None:
|
794 |
+
batch_size = input_ids.size(0)
|
795 |
+
|
796 |
+
if self.retrieval_aggregation_mode=="aggregate_indel":
|
797 |
+
assert batch_size==1, "Aggregate indel is only supported for batch size of 1"
|
798 |
+
truncated_sequence_text = full_raw_sequence[0][start_slice[0]:end_slice[0]]
|
799 |
+
if len(truncated_sequence_text)!=shift_logits.shape[1]-1: # shift_logits only has one extra token compared to truncated_sequence_text (the BOS token)
|
800 |
+
print("Tokenization error -- seq length: {} and shift_logits length - 1 : {}".format(len(full_raw_sequence),shift_logits.shape[1]-1))
|
801 |
+
MSA_log_prior, MSA_start, MSA_end = msa_utils.update_retrieved_MSA_log_prior_indel(self, self.MSA_log_prior, self.MSA_start, self.MSA_end, full_raw_sequence[0])
|
802 |
+
|
803 |
+
elif self.retrieval_aggregation_mode=="aggregate_substitution":
|
804 |
+
MSA_log_prior=self.MSA_log_prior
|
805 |
+
MSA_start=self.MSA_start
|
806 |
+
MSA_end=self.MSA_end
|
807 |
+
|
808 |
+
shift_log_probas = torch.log_softmax(shift_logits,dim=-1)
|
809 |
+
fused_shift_log_probas = shift_log_probas.clone()
|
810 |
+
if flip is None:
|
811 |
+
flip = torch.zeros(batch_size).to(fused_shift_log_probas.device)
|
812 |
+
flip = flip > 0
|
813 |
+
|
814 |
+
for seq_index in range(batch_size):
|
815 |
+
min_prior_slice = max(start_slice[seq_index], MSA_start)
|
816 |
+
max_prior_slice = min(end_slice[seq_index], MSA_end)
|
817 |
+
|
818 |
+
if max_prior_slice <= min_prior_slice:
|
819 |
+
print("Non overlapping region detected: min_prior_slice {} and max_prior_slice {}".format(min_prior_slice,max_prior_slice))
|
820 |
+
continue
|
821 |
+
|
822 |
+
slice_prior = MSA_log_prior[min_prior_slice:max_prior_slice,:].to(fused_shift_log_probas.device)
|
823 |
+
if flip[seq_index]:
|
824 |
+
slice_prior = torch.flip(slice_prior,dims=(0,))
|
825 |
+
min_logits_slice = max(0,end_slice[seq_index]-MSA_end)
|
826 |
+
max_logits_slice = min_logits_slice + (max_prior_slice-min_prior_slice)
|
827 |
+
fused_shift_log_probas[seq_index,min_logits_slice:max_logits_slice,:] = (1-self.retrieval_inference_weight_RL)*shift_log_probas[seq_index,min_logits_slice:max_logits_slice,:] + self.retrieval_inference_weight_RL*slice_prior
|
828 |
+
else:
|
829 |
+
min_logits_slice = max(0, MSA_start-start_slice[seq_index])
|
830 |
+
max_logits_slice = min_logits_slice + (max_prior_slice-min_prior_slice)
|
831 |
+
fused_shift_log_probas[seq_index,min_logits_slice:max_logits_slice,:] = (1-self.retrieval_inference_weight_LR)*shift_log_probas[seq_index,min_logits_slice:max_logits_slice,:] + self.retrieval_inference_weight_LR*slice_prior
|
832 |
+
|
833 |
+
if self.retrieval_aggregation_mode=="aggregate_indel":
|
834 |
+
try:
|
835 |
+
# If a given residue colume is an added zero-column, then we overwrite prior fusion and only predict based on the autoregressive transformer inference mode.
|
836 |
+
inserted_retrieval_positions = [True if slice_prior[i].sum()==0 else False for i in range(len(slice_prior))]+[True] #Last True is for the end of sentence token
|
837 |
+
fused_shift_log_probas[:,inserted_retrieval_positions,:]=shift_log_probas[:,inserted_retrieval_positions,:]
|
838 |
+
except:
|
839 |
+
print("Error when adding zero column(s) to account for insertion mutations.")
|
840 |
+
|
841 |
+
loss_fct = NLLLoss(reduction='none')
|
842 |
+
loss = loss_fct(input=fused_shift_log_probas.view(-1, fused_shift_log_probas.size(-1)), target=shift_labels.view(-1)).view(fused_shift_log_probas.shape[0],fused_shift_log_probas.shape[1])
|
843 |
+
mask = attention_mask[..., 1:].float()
|
844 |
+
mask[mask==0]=float('nan')
|
845 |
+
loss *= mask
|
846 |
+
loss = nanmean(loss, dim=1).mean()
|
847 |
+
else:
|
848 |
+
loss_fct = CrossEntropyLoss()
|
849 |
+
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
850 |
+
fused_shift_log_probas = None
|
851 |
+
|
852 |
+
if not return_dict:
|
853 |
+
output = (lm_logits,) + transformer_outputs[1:]
|
854 |
+
return ((loss,) + output) if loss is not None else output
|
855 |
+
|
856 |
+
return TranceptionCausalLMOutputWithCrossAttentions(
|
857 |
+
loss=loss,
|
858 |
+
logits=lm_logits,
|
859 |
+
past_key_values=transformer_outputs.past_key_values,
|
860 |
+
hidden_states=transformer_outputs.hidden_states,
|
861 |
+
attentions=transformer_outputs.attentions,
|
862 |
+
cross_attentions=transformer_outputs.cross_attentions,
|
863 |
+
fused_shift_log_probas=fused_shift_log_probas
|
864 |
+
)
|
865 |
+
|
866 |
+
|
867 |
+
@staticmethod
|
868 |
+
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
|
869 |
+
"""
|
870 |
+
This function is used to re-order the :obj:`past_key_values` cache if
|
871 |
+
:meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
|
872 |
+
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
|
873 |
+
"""
|
874 |
+
return tuple(
|
875 |
+
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
876 |
+
for layer_past in past
|
877 |
+
)
|
878 |
+
|
879 |
+
def score_mutants(self, DMS_data, target_seq, scoring_mirror=True, batch_size_inference=10, num_workers=10, indel_mode=False):
|
880 |
+
"""
|
881 |
+
Method to score mutants in an input DMS file.
|
882 |
+
DMS_data: (dataframe) Dataframe containing the list of mutant triplets (substitutions) or full mutated sequences (indels) for scoring.
|
883 |
+
target_seq: (string) Full reference sequence (wild type) that is mutated in the DMS assay.
|
884 |
+
scoring_mirror: (bool) Whether to score mutated sequences from both directions (Left->Right and Right->Left).
|
885 |
+
batch_size_inference: (int) Batch size for scoring.
|
886 |
+
num_workers: (int) Number of workers to be used in the data loader.
|
887 |
+
indel_mode: (bool) Flag to be used when scoring insertions and deletions. Otherwise assumes substitutions.
|
888 |
+
"""
|
889 |
+
df = DMS_data.copy()
|
890 |
+
df['mutated_sequence'] = df['mutant'].apply(lambda x: scoring_utils.get_mutated_sequence(target_seq, x)) if not indel_mode else df['mutant']
|
891 |
+
if 'DMS_score' in df: del df['DMS_score']
|
892 |
+
if 'DMS_score_bin' in df: del df['DMS_score_bin']
|
893 |
+
df_left_to_right_slices = scoring_utils.get_sequence_slices(df, target_seq=target_seq, model_context_len = self.config.n_ctx - 2, indel_mode=indel_mode, scoring_window=self.config.scoring_window)
|
894 |
+
print("Scoring sequences from left to right")
|
895 |
+
scores_L_to_R = scoring_utils.get_tranception_scores_mutated_sequences(model=self, mutated_sequence_df=df_left_to_right_slices, batch_size_inference=batch_size_inference, score_var_name='avg_score_L_to_R', len_target_seq=len(target_seq), num_workers=num_workers, indel_mode=indel_mode)
|
896 |
+
if scoring_mirror:
|
897 |
+
print("Scoring sequences from right to left")
|
898 |
+
df_right_to_left_slices = df_left_to_right_slices.copy()
|
899 |
+
df_right_to_left_slices['mutated_sequence'] = df_right_to_left_slices['mutated_sequence'].apply(lambda x: x[::-1])
|
900 |
+
scores_R_to_L = scoring_utils.get_tranception_scores_mutated_sequences(model=self, mutated_sequence_df=df_right_to_left_slices, batch_size_inference=batch_size_inference, score_var_name='avg_score_R_to_L', len_target_seq=len(target_seq), num_workers=num_workers, reverse=True, indel_mode=indel_mode)
|
901 |
+
all_scores = pd.merge(scores_L_to_R, scores_R_to_L, on='mutant', how='left',suffixes=('','_R_to_L'))
|
902 |
+
all_scores['avg_score'] = (all_scores['avg_score_L_to_R'] + all_scores['avg_score_R_to_L']) / 2.0
|
903 |
+
else:
|
904 |
+
all_scores = scores_L_to_R
|
905 |
+
all_scores['avg_score'] = all_scores['avg_score_L_to_R']
|
906 |
+
return all_scores
|
907 |
+
|
908 |
+
def encode_batch(self, protein_sequence, sequence_name="mutated_sequence"):
|
909 |
+
"""
|
910 |
+
Method to process an input AA sequence batch (protein_sequence) and return a tokenized sequence (via the tokenizer associated to the model).
|
911 |
+
"""
|
912 |
+
protein_sequence[sequence_name] = scoring_utils.sequence_replace(sequences=protein_sequence[sequence_name], char_to_replace='X', char_replacements='ACDEFGHIKLMNPQRSTVWY')
|
913 |
+
protein_sequence[sequence_name] = scoring_utils.sequence_replace(sequences=protein_sequence[sequence_name], char_to_replace='B', char_replacements='DN')
|
914 |
+
protein_sequence[sequence_name] = scoring_utils.sequence_replace(sequences=protein_sequence[sequence_name], char_to_replace='J', char_replacements='IL')
|
915 |
+
protein_sequence[sequence_name] = scoring_utils.sequence_replace(sequences=protein_sequence[sequence_name], char_to_replace='Z', char_replacements='EQ')
|
916 |
+
return self.config.tokenizer(list(protein_sequence[sequence_name]), add_special_tokens=True, truncation=True, padding=True, max_length=self.config.n_ctx)
|
917 |
+
|
outputs.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Optional, Tuple
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from transformers.file_utils import ModelOutput
|
7 |
+
|
8 |
+
@dataclass
|
9 |
+
class TranceptionCausalLMOutputWithCrossAttentions(ModelOutput):
|
10 |
+
"""
|
11 |
+
Class for Tranception causal language model (or autoregressive) outputs.
|
12 |
+
Args:
|
13 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
14 |
+
Language modeling loss (for next-token prediction).
|
15 |
+
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
16 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
17 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
18 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
19 |
+
shape `(batch_size, sequence_length, hidden_size)`.
|
20 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
21 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
22 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
23 |
+
sequence_length)`.
|
24 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
25 |
+
heads.
|
26 |
+
cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
27 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
28 |
+
sequence_length)`.
|
29 |
+
Cross attentions weights after the attention softmax, used to compute the weighted average in the
|
30 |
+
cross-attention heads.
|
31 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
32 |
+
Tuple of `torch.FloatTensor` tuples of length `config.n_layers`, with each tuple containing the cached key,
|
33 |
+
value states of the self-attention and the cross-attention layers if model is used in encoder-decoder
|
34 |
+
setting. Only relevant if `config.is_decoder = True`.
|
35 |
+
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
|
36 |
+
`past_key_values` input) to speed up sequential decoding.
|
37 |
+
fused_shift_log_probas (`torch.FloatTensor` of shape (batch_size, sequence_length, config.vocab_size), *optional*, returned when config.retrieval_aggregation_mode is not None.
|
38 |
+
log_probas for each residue position after aggregating autoregressive logits and retrieval logits.
|
39 |
+
|
40 |
+
"""
|
41 |
+
|
42 |
+
loss: Optional[torch.FloatTensor] = None
|
43 |
+
logits: torch.FloatTensor = None
|
44 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
45 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
46 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
47 |
+
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
48 |
+
fused_shift_log_probas: Optional[torch.FloatTensor] = None
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e3ee11f9fa7cca8f7859d63769a40ef788f3f32967b5a9ec32bbb2c1f35b0bea
|
3 |
+
size 2872459669
|
utils/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
utils/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from . import scoring_utils, msa_utils
|
utils/dms_utils.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
def DMS_file_cleanup(DMS_filename, target_seq, start_idx=1, end_idx=None, DMS_mutant_column='mutant', DMS_phenotype_name='score', DMS_directionality=1, AA_vocab = "ACDEFGHIKLMNPQRSTVWY"):
|
5 |
+
"""
|
6 |
+
Function to process the raw DMS assay data (eg., removing invalid mutants, aggregate silent mutations)
|
7 |
+
"""
|
8 |
+
DMS_data = pd.read_csv(DMS_filename, low_memory=False)
|
9 |
+
end_idx = start_idx + len(target_seq) - 1 if end_idx is None else end_idx
|
10 |
+
DMS_data['mutant'] = DMS_data[DMS_mutant_column]
|
11 |
+
|
12 |
+
DMS_data=DMS_data[DMS_data['mutant'].notnull()].copy()
|
13 |
+
DMS_data=DMS_data[DMS_data['mutant'].apply(lambda x: all([len(y)>=3 for y in x.split(":")]))].copy() #Mutant triplets should have at least 3 or more characters
|
14 |
+
DMS_data=DMS_data[DMS_data['mutant'].apply(lambda x: all([(y[0] in AA_vocab) and (y[1:-1].isnumeric()) and (y[-1] in AA_vocab) for y in x.split(":")]))].copy()
|
15 |
+
DMS_data=DMS_data[DMS_data['mutant'].apply(lambda x: all([int(y[1:-1])-start_idx >=0 and int(y[1:-1]) <= end_idx for y in x.split(":")]))].copy()
|
16 |
+
DMS_data=DMS_data[DMS_data['mutant'].apply(lambda x: all([y[0]==target_seq[int(y[1:-1])-start_idx] for y in x.split(":")]))].copy()
|
17 |
+
|
18 |
+
DMS_data[DMS_phenotype_name]=pd.to_numeric(DMS_data[DMS_phenotype_name],errors='coerce')
|
19 |
+
DMS_data=DMS_data[np.isfinite(DMS_data[DMS_phenotype_name])]
|
20 |
+
DMS_data.dropna(subset = [DMS_phenotype_name], inplace=True)
|
21 |
+
DMS_data['DMS_score'] = DMS_data[DMS_phenotype_name] * DMS_directionality
|
22 |
+
DMS_data=DMS_data[['mutant','DMS_score']]
|
23 |
+
DMS_data = DMS_data.groupby('mutant').mean().reset_index()
|
24 |
+
|
25 |
+
return DMS_data
|
26 |
+
|
utils/msa_utils.py
ADDED
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import pandas as pd
|
3 |
+
from collections import defaultdict
|
4 |
+
import random
|
5 |
+
import os
|
6 |
+
import torch
|
7 |
+
from Bio.Align.Applications import ClustalOmegaCommandline
|
8 |
+
|
9 |
+
def filter_msa(msa_data, num_sequences_kept=3):
|
10 |
+
"""
|
11 |
+
Helper function to filter an input MSA msa_data (obtained via process_msa_data) and keep only num_sequences_kept aligned sequences.
|
12 |
+
If the MSA already has fewer sequences than num_sequences_kept, we keep the MSA as is.
|
13 |
+
If filtering, we always keep the first sequence of the MSA (ie. the wild type) by default.
|
14 |
+
Sampling is done without replacement.
|
15 |
+
"""
|
16 |
+
if len(list(msa_data.keys())) <= num_sequences_kept:
|
17 |
+
return msa_data
|
18 |
+
filtered_msa = {}
|
19 |
+
wt_name = next(iter(msa_data))
|
20 |
+
filtered_msa[wt_name] = msa_data[wt_name]
|
21 |
+
del msa_data[wt_name]
|
22 |
+
sequence_names = list(msa_data.keys())
|
23 |
+
sequence_names_sampled = random.sample(sequence_names,k=num_sequences_kept-1)
|
24 |
+
for seq in sequence_names_sampled:
|
25 |
+
filtered_msa[seq] = msa_data[seq]
|
26 |
+
return filtered_msa
|
27 |
+
|
28 |
+
def process_msa_data(MSA_data_file):
|
29 |
+
"""
|
30 |
+
Helper function that takes as input a path to a MSA file (expects a2m format) and returns a dict mapping sequence ID to the corresponding AA sequence.
|
31 |
+
"""
|
32 |
+
msa_data = defaultdict(str)
|
33 |
+
sequence_name = ""
|
34 |
+
with open(MSA_data_file, "r") as msa_file:
|
35 |
+
for i, line in enumerate(msa_file):
|
36 |
+
line = line.rstrip()
|
37 |
+
if line.startswith(">"):
|
38 |
+
sequence_name = line
|
39 |
+
else:
|
40 |
+
msa_data[sequence_name] += line.upper()
|
41 |
+
return msa_data
|
42 |
+
|
43 |
+
def get_one_hot_sequences_dict(msa_data,MSA_start,MSA_end,vocab):
|
44 |
+
vocab_size = len(vocab.keys())
|
45 |
+
num_sequences_msa = len(msa_data.keys())
|
46 |
+
one_hots = np.zeros((num_sequences_msa,MSA_end-MSA_start,vocab_size))
|
47 |
+
for i,seq_name in enumerate(msa_data.keys()):
|
48 |
+
sequence = msa_data[seq_name]
|
49 |
+
for j,letter in enumerate(sequence):
|
50 |
+
if letter in vocab:
|
51 |
+
k = vocab[letter]
|
52 |
+
one_hots[i,j,k] = 1.0
|
53 |
+
return one_hots
|
54 |
+
|
55 |
+
def one_hot(sequence_string,vocab):
|
56 |
+
one_hots = np.zeros((len(sequence_string),len(vocab.keys())))
|
57 |
+
for j,letter in enumerate(sequence_string):
|
58 |
+
if letter in vocab:
|
59 |
+
k = vocab[letter]
|
60 |
+
one_hots[j,k] = 1.0
|
61 |
+
return one_hots.flatten()
|
62 |
+
|
63 |
+
def get_msa_prior(MSA_data_file, MSA_weight_file_name, MSA_start, MSA_end, len_target_seq, vocab, retrieval_aggregation_mode="aggregate_substitution", filter_MSA=True, verbose=False):
|
64 |
+
"""
|
65 |
+
Function to enable retrieval inference mode, via computation of (weighted) pseudocounts of AAs at each position of the retrieved MSA.
|
66 |
+
MSA_data_file: (string) path to MSA file (expects a2m format).
|
67 |
+
MSA_weight_file_name: (string) path to sequence weights in MSA.
|
68 |
+
MSA_start: (int) Sequence position that the MSA starts at (1-indexing).
|
69 |
+
MSA_end: (int) Sequence position that the MSA ends at (1-indexing).
|
70 |
+
len_target_seq: (int) Full length of sequence to be scored.
|
71 |
+
vocab: (dict) Vocabulary of the tokenizer.
|
72 |
+
retrieval_aggregation_mode: (string) Mode for retrieval inference (aggregate_substitution Vs aggregate_indel). If None, places a uniform prior over each token.
|
73 |
+
filter_MSA: (bool) Whether to filter out sequences with very low hamming similarity (< 0.2) to the reference sequence in the MSA (first sequence).
|
74 |
+
verbose: (bool) Whether to print to the console processing details along the way.
|
75 |
+
"""
|
76 |
+
msa_data = process_msa_data(MSA_data_file)
|
77 |
+
vocab_size = len(vocab.keys())
|
78 |
+
if verbose: print("Target seq len is {}, MSA length is {}, start position is {}, end position is {} and vocab size is {}".format(len_target_seq,MSA_end-MSA_start,MSA_start,MSA_end,vocab_size))
|
79 |
+
|
80 |
+
if filter_MSA:
|
81 |
+
if verbose: print("Num sequences in MSA pre filtering: {}".format(len(msa_data.keys())))
|
82 |
+
list_sequence_names = list(msa_data.keys())
|
83 |
+
focus_sequence_name = list(msa_data.keys())[0]
|
84 |
+
ref_sequence_hot = one_hot(msa_data[focus_sequence_name],vocab)
|
85 |
+
for sequence_name in list_sequence_names:
|
86 |
+
seq_hot = one_hot(msa_data[sequence_name],vocab)
|
87 |
+
hamming_similarity_seq_ref = np.dot(ref_sequence_hot,seq_hot) / np.dot(ref_sequence_hot,ref_sequence_hot)
|
88 |
+
if hamming_similarity_seq_ref < 0.2:
|
89 |
+
del msa_data[sequence_name]
|
90 |
+
if verbose: print("Num sequences in MSA post filtering: {}".format(len(msa_data.keys())))
|
91 |
+
|
92 |
+
if MSA_weight_file_name is not None:
|
93 |
+
if verbose: print("Using weights in {} for sequences in MSA.".format(MSA_weight_file_name))
|
94 |
+
assert os.path.exists(MSA_weight_file_name), "Weights file not located on disk."
|
95 |
+
MSA_EVE = MSA_processing(
|
96 |
+
MSA_location=MSA_data_file,
|
97 |
+
use_weights=True,
|
98 |
+
weights_location=MSA_weight_file_name
|
99 |
+
)
|
100 |
+
#We scan through all sequences to see if we have a weight for them as per EVE pre-processing. We drop them otherwise.
|
101 |
+
dropped_sequences=0
|
102 |
+
list_sequence_names = list(msa_data.keys())
|
103 |
+
MSA_weight=[]
|
104 |
+
for sequence_name in list_sequence_names:
|
105 |
+
if sequence_name not in MSA_EVE.seq_name_to_sequence:
|
106 |
+
dropped_sequences +=1
|
107 |
+
del msa_data[sequence_name]
|
108 |
+
else:
|
109 |
+
MSA_weight.append(MSA_EVE.seq_name_to_weight[sequence_name])
|
110 |
+
if verbose: print("Dropped {} sequences from MSA due to absent sequence weights".format(dropped_sequences))
|
111 |
+
else:
|
112 |
+
MSA_weight = [1] * len(list(msa_data.keys()))
|
113 |
+
|
114 |
+
if retrieval_aggregation_mode=="aggregate_substitution" or retrieval_aggregation_mode=="aggregate_indel":
|
115 |
+
one_hots = get_one_hot_sequences_dict(msa_data,MSA_start,MSA_end,vocab)
|
116 |
+
MSA_weight = np.expand_dims(np.array(MSA_weight),axis=(1,2))
|
117 |
+
base_rate = 1e-5
|
118 |
+
base_rates = np.ones_like(one_hots) * base_rate
|
119 |
+
weighted_one_hots = (one_hots + base_rates) * MSA_weight
|
120 |
+
MSA_weight_norm_counts = weighted_one_hots.sum(axis=-1).sum(axis=0)
|
121 |
+
MSA_weight_norm_counts = np.tile(MSA_weight_norm_counts.reshape(-1,1), (1,vocab_size))
|
122 |
+
one_hots_avg = weighted_one_hots.sum(axis=0) / MSA_weight_norm_counts
|
123 |
+
msa_prior = np.zeros((len_target_seq,vocab_size))
|
124 |
+
msa_prior[MSA_start:MSA_end,:]=one_hots_avg
|
125 |
+
else:
|
126 |
+
msa_prior = np.ones((len_target_seq,vocab_size)) / vocab_size
|
127 |
+
|
128 |
+
if verbose:
|
129 |
+
for idx, position in enumerate(msa_prior):
|
130 |
+
if len(position)!=25:
|
131 |
+
print("Size error")
|
132 |
+
if not round(position.sum(),2)==1.0:
|
133 |
+
print("Position at index {} does not add up to 1: {}".format(idx, position.sum()))
|
134 |
+
|
135 |
+
return msa_prior
|
136 |
+
|
137 |
+
|
138 |
+
def update_retrieved_MSA_log_prior_indel(model, MSA_log_prior, MSA_start, MSA_end, full_raw_sequence):
|
139 |
+
"""
|
140 |
+
Function to process MSA when scoring indels.
|
141 |
+
To identify positions to add / remove in the retrieved MSA, we append and align the sequence to be scored to the original MSA for that protein family with Clustal Omega.
|
142 |
+
If the original MSA is relatively deep (over 100k sequences), we sample (by default) 100k rows at random from that MSA to speed computations.
|
143 |
+
MSA sampling is performed only once (for the first sequence to be scored). Subsequent scoring use the same MSA sample.
|
144 |
+
"""
|
145 |
+
if not os.path.isdir(model.MSA_folder + os.sep + "Sampled"):
|
146 |
+
os.mkdir(model.MSA_folder + os.sep + "Sampled")
|
147 |
+
sampled_MSA_location = model.MSA_folder + os.sep + "Sampled" + os.sep + "Sampled_" + model.MSA_filename.split(os.sep)[-1]
|
148 |
+
|
149 |
+
if not os.path.exists(sampled_MSA_location):
|
150 |
+
msa_data = process_msa_data(model.MSA_filename)
|
151 |
+
msa_data_sampled = filter_msa(msa_data, num_sequences_kept=100000) #If MSA has less than 100k sequences, the sample is identical to original MSA
|
152 |
+
with open(sampled_MSA_location, 'w') as sampled_write_location:
|
153 |
+
for index, key in enumerate(msa_data_sampled):
|
154 |
+
key_name = ">REFERENCE_SEQUENCE" if index==0 else key
|
155 |
+
msa_data_sampled[key] = msa_data_sampled[key].upper()
|
156 |
+
msa_data_sampled[key] = msa_data_sampled[key].replace(".","-")
|
157 |
+
sampled_write_location.write(key_name+"\n"+"\n".join([msa_data_sampled[key][i:i+80] for i in range(0, len(msa_data_sampled[key]), 80)])+"\n")
|
158 |
+
|
159 |
+
seq_to_align_location = model.MSA_folder + os.sep + "Sampled" + os.sep + "Seq_to_align_" + model.MSA_filename.split(os.sep)[-1]
|
160 |
+
sequence_text_split = [full_raw_sequence[i:i+80] for i in range(0, len(full_raw_sequence), 80)]
|
161 |
+
sequence_text_split_split_join = "\n".join([">SEQ_TO_SCORE"]+sequence_text_split)
|
162 |
+
os.system("echo '"+sequence_text_split_split_join+"' > "+seq_to_align_location)
|
163 |
+
|
164 |
+
expanded_MSA_location = model.MSA_folder + os.sep + "Sampled" + os.sep + "Expanded_" + model.MSA_filename.split(os.sep)[-1]
|
165 |
+
clustalw_cline = ClustalOmegaCommandline(cmd=model.config.clustal_omega_location,
|
166 |
+
profile1=sampled_MSA_location,
|
167 |
+
profile2=seq_to_align_location,
|
168 |
+
outfile=expanded_MSA_location,
|
169 |
+
force=True)
|
170 |
+
stdout, stderr = clustalw_cline()
|
171 |
+
msa_data = process_msa_data(expanded_MSA_location)
|
172 |
+
aligned_seqA, aligned_seqB = msa_data[">SEQ_TO_SCORE"], msa_data[">REFERENCE_SEQUENCE"]
|
173 |
+
try:
|
174 |
+
keep_column=[]
|
175 |
+
for column_index_pairwise_alignment in range(len(aligned_seqA)):
|
176 |
+
if aligned_seqA[column_index_pairwise_alignment]=="-" and aligned_seqB[column_index_pairwise_alignment]=="-":
|
177 |
+
continue
|
178 |
+
elif aligned_seqA[column_index_pairwise_alignment]=="-":
|
179 |
+
keep_column.append(False)
|
180 |
+
elif aligned_seqB[column_index_pairwise_alignment]=="-":
|
181 |
+
MSA_log_prior=torch.cat((MSA_log_prior[:column_index_pairwise_alignment], torch.zeros(MSA_log_prior.shape[1]).view(1,-1).cuda(), MSA_log_prior[column_index_pairwise_alignment:]),dim=0)
|
182 |
+
keep_column.append(True) #keep the zero column we just added
|
183 |
+
else:
|
184 |
+
keep_column.append(True)
|
185 |
+
MSA_log_prior = MSA_log_prior[keep_column]
|
186 |
+
MSA_end = MSA_start + len(MSA_log_prior)
|
187 |
+
except:
|
188 |
+
print("Error when processing the following alignment: {}".format(expanded_MSA_location))
|
189 |
+
return MSA_log_prior, MSA_start, MSA_end
|
190 |
+
|
191 |
+
class MSA_processing:
|
192 |
+
def __init__(self,
|
193 |
+
MSA_location="",
|
194 |
+
theta=0.2,
|
195 |
+
use_weights=True,
|
196 |
+
weights_location="./data/weights",
|
197 |
+
preprocess_MSA=True,
|
198 |
+
threshold_sequence_frac_gaps=0.5,
|
199 |
+
threshold_focus_cols_frac_gaps=0.3,
|
200 |
+
remove_sequences_with_indeterminate_AA_in_focus_cols=True
|
201 |
+
):
|
202 |
+
|
203 |
+
"""
|
204 |
+
This MSA_processing class is directly borrowed from the EVE codebase: https://github.com/OATML-Markslab/EVE
|
205 |
+
|
206 |
+
Parameters:
|
207 |
+
- msa_location: (path) Location of the MSA data. Constraints on input MSA format:
|
208 |
+
- focus_sequence is the first one in the MSA data
|
209 |
+
- first line is structured as follows: ">focus_seq_name/start_pos-end_pos" (e.g., >SPIKE_SARS2/310-550)
|
210 |
+
- corespondding sequence data located on following line(s)
|
211 |
+
- then all other sequences follow with ">name" on first line, corresponding data on subsequent lines
|
212 |
+
- theta: (float) Sequence weighting hyperparameter. Generally: Prokaryotic and eukaryotic families = 0.2; Viruses = 0.01
|
213 |
+
- use_weights: (bool) If False, sets all sequence weights to 1. If True, checks weights_location -- if non empty uses that;
|
214 |
+
otherwise compute weights from scratch and store them at weights_location
|
215 |
+
- weights_location: (path) Location to load from/save to the sequence weights
|
216 |
+
- preprocess_MSA: (bool) performs pre-processing of MSA to remove short fragments and positions that are not well covered.
|
217 |
+
- threshold_sequence_frac_gaps: (float, between 0 and 1) Threshold value to define fragments
|
218 |
+
- sequences with a fraction of gap characters above threshold_sequence_frac_gaps are removed
|
219 |
+
- default is set to 0.5 (i.e., fragments with 50% or more gaps are removed)
|
220 |
+
- threshold_focus_cols_frac_gaps: (float, between 0 and 1) Threshold value to define focus columns
|
221 |
+
- positions with a fraction of gap characters above threshold_focus_cols_pct_gaps will be set to lower case (and not included in the focus_cols)
|
222 |
+
- default is set to 0.3 (i.e., focus positions are the ones with 30% of gaps or less, i.e., 70% or more residue occupancy)
|
223 |
+
- remove_sequences_with_indeterminate_AA_in_focus_cols: (bool) Remove all sequences that have indeterminate AA (e.g., B, J, X, Z) at focus positions of the wild type
|
224 |
+
"""
|
225 |
+
np.random.seed(2021)
|
226 |
+
self.MSA_location = MSA_location
|
227 |
+
self.weights_location = weights_location
|
228 |
+
self.theta = theta
|
229 |
+
self.alphabet = "ACDEFGHIKLMNPQRSTVWY"
|
230 |
+
self.use_weights = use_weights
|
231 |
+
self.preprocess_MSA = preprocess_MSA
|
232 |
+
self.threshold_sequence_frac_gaps = threshold_sequence_frac_gaps
|
233 |
+
self.threshold_focus_cols_frac_gaps = threshold_focus_cols_frac_gaps
|
234 |
+
self.remove_sequences_with_indeterminate_AA_in_focus_cols = remove_sequences_with_indeterminate_AA_in_focus_cols
|
235 |
+
|
236 |
+
self.gen_alignment()
|
237 |
+
|
238 |
+
def gen_alignment(self, verbose=False):
|
239 |
+
""" Read training alignment and store basics in class instance """
|
240 |
+
self.aa_dict = {}
|
241 |
+
for i,aa in enumerate(self.alphabet):
|
242 |
+
self.aa_dict[aa] = i
|
243 |
+
|
244 |
+
self.seq_name_to_sequence = defaultdict(str)
|
245 |
+
name = ""
|
246 |
+
with open(self.MSA_location, "r") as msa_data:
|
247 |
+
for i, line in enumerate(msa_data):
|
248 |
+
line = line.rstrip()
|
249 |
+
if line.startswith(">"):
|
250 |
+
name = line
|
251 |
+
if i==0:
|
252 |
+
self.focus_seq_name = name
|
253 |
+
else:
|
254 |
+
self.seq_name_to_sequence[name] += line
|
255 |
+
|
256 |
+
|
257 |
+
## MSA pre-processing to remove inadequate columns and sequences
|
258 |
+
if self.preprocess_MSA:
|
259 |
+
msa_df = pd.DataFrame.from_dict(self.seq_name_to_sequence, orient='index', columns=['sequence'])
|
260 |
+
# Data clean up
|
261 |
+
msa_df.sequence = msa_df.sequence.apply(lambda x: x.replace(".","-")).apply(lambda x: ''.join([aa.upper() for aa in x]))
|
262 |
+
# Remove columns that would be gaps in the wild type
|
263 |
+
non_gap_wt_cols = [aa!='-' for aa in msa_df.sequence[self.focus_seq_name]]
|
264 |
+
msa_df['sequence'] = msa_df['sequence'].apply(lambda x: ''.join([aa for aa,non_gap_ind in zip(x, non_gap_wt_cols) if non_gap_ind]))
|
265 |
+
assert 0.0 <= self.threshold_sequence_frac_gaps <= 1.0,"Invalid fragment filtering parameter"
|
266 |
+
assert 0.0 <= self.threshold_focus_cols_frac_gaps <= 1.0,"Invalid focus position filtering parameter"
|
267 |
+
msa_array = np.array([list(seq) for seq in msa_df.sequence])
|
268 |
+
gaps_array = np.array(list(map(lambda seq: [aa=='-' for aa in seq], msa_array)))
|
269 |
+
# Identify fragments with too many gaps
|
270 |
+
seq_gaps_frac = gaps_array.mean(axis=1)
|
271 |
+
seq_below_threshold = seq_gaps_frac <= self.threshold_sequence_frac_gaps
|
272 |
+
if verbose: print("Proportion of sequences dropped due to fraction of gaps: "+str(round(float(1 - seq_below_threshold.sum()/seq_below_threshold.shape)*100,2))+"%")
|
273 |
+
# Identify focus columns
|
274 |
+
columns_gaps_frac = gaps_array[seq_below_threshold].mean(axis=0)
|
275 |
+
index_cols_below_threshold = columns_gaps_frac <= self.threshold_focus_cols_frac_gaps
|
276 |
+
if verbose: print("Proportion of non-focus columns removed: "+str(round(float(1 - index_cols_below_threshold.sum()/index_cols_below_threshold.shape)*100,2))+"%")
|
277 |
+
# Lower case non focus cols and filter fragment sequences
|
278 |
+
msa_df['sequence'] = msa_df['sequence'].apply(lambda x: ''.join([aa.upper() if upper_case_ind else aa.lower() for aa, upper_case_ind in zip(x, index_cols_below_threshold)]))
|
279 |
+
msa_df = msa_df[seq_below_threshold]
|
280 |
+
# Overwrite seq_name_to_sequence with clean version
|
281 |
+
self.seq_name_to_sequence = defaultdict(str)
|
282 |
+
for seq_idx in range(len(msa_df['sequence'])):
|
283 |
+
self.seq_name_to_sequence[msa_df.index[seq_idx]] = msa_df.sequence[seq_idx]
|
284 |
+
|
285 |
+
self.focus_seq = self.seq_name_to_sequence[self.focus_seq_name]
|
286 |
+
self.focus_cols = [ix for ix, s in enumerate(self.focus_seq) if s == s.upper() and s!='-']
|
287 |
+
self.focus_seq_trimmed = [self.focus_seq[ix] for ix in self.focus_cols]
|
288 |
+
self.seq_len = len(self.focus_cols)
|
289 |
+
self.alphabet_size = len(self.alphabet)
|
290 |
+
|
291 |
+
# Connect local sequence index with uniprot index (index shift inferred from 1st row of MSA)
|
292 |
+
focus_loc = self.focus_seq_name.split("/")[-1]
|
293 |
+
start,stop = focus_loc.split("-")
|
294 |
+
self.focus_start_loc = int(start)
|
295 |
+
self.focus_stop_loc = int(stop)
|
296 |
+
self.uniprot_focus_col_to_wt_aa_dict \
|
297 |
+
= {idx_col+int(start):self.focus_seq[idx_col] for idx_col in self.focus_cols}
|
298 |
+
self.uniprot_focus_col_to_focus_idx \
|
299 |
+
= {idx_col+int(start):idx_col for idx_col in self.focus_cols}
|
300 |
+
|
301 |
+
# Move all letters to CAPS; keeps focus columns only
|
302 |
+
self.raw_seq_name_to_sequence = self.seq_name_to_sequence.copy()
|
303 |
+
for seq_name,sequence in self.seq_name_to_sequence.items():
|
304 |
+
sequence = sequence.replace(".","-")
|
305 |
+
self.seq_name_to_sequence[seq_name] = [sequence[ix].upper() for ix in self.focus_cols]
|
306 |
+
|
307 |
+
# Remove sequences that have indeterminate AA (e.g., B, J, X, Z) in the focus columns
|
308 |
+
if self.remove_sequences_with_indeterminate_AA_in_focus_cols:
|
309 |
+
alphabet_set = set(list(self.alphabet))
|
310 |
+
seq_names_to_remove = []
|
311 |
+
for seq_name,sequence in self.seq_name_to_sequence.items():
|
312 |
+
for letter in sequence:
|
313 |
+
if letter not in alphabet_set and letter != "-":
|
314 |
+
seq_names_to_remove.append(seq_name)
|
315 |
+
continue
|
316 |
+
seq_names_to_remove = list(set(seq_names_to_remove))
|
317 |
+
for seq_name in seq_names_to_remove:
|
318 |
+
del self.seq_name_to_sequence[seq_name]
|
319 |
+
|
320 |
+
# Encode the sequences
|
321 |
+
self.one_hot_encoding = np.zeros((len(self.seq_name_to_sequence.keys()),len(self.focus_cols),len(self.alphabet)))
|
322 |
+
if verbose: print("One-hot encoded sequences shape:" + str(self.one_hot_encoding.shape))
|
323 |
+
for i,seq_name in enumerate(self.seq_name_to_sequence.keys()):
|
324 |
+
sequence = self.seq_name_to_sequence[seq_name]
|
325 |
+
for j,letter in enumerate(sequence):
|
326 |
+
if letter in self.aa_dict:
|
327 |
+
k = self.aa_dict[letter]
|
328 |
+
self.one_hot_encoding[i,j,k] = 1.0
|
329 |
+
|
330 |
+
if self.use_weights:
|
331 |
+
try:
|
332 |
+
self.weights = np.load(file=self.weights_location)
|
333 |
+
if verbose: print("Loaded sequence weights from disk")
|
334 |
+
except:
|
335 |
+
if verbose: print ("Computing sequence weights")
|
336 |
+
list_seq = self.one_hot_encoding
|
337 |
+
list_seq = list_seq.reshape((list_seq.shape[0], list_seq.shape[1] * list_seq.shape[2]))
|
338 |
+
def compute_weight(seq):
|
339 |
+
number_non_empty_positions = np.dot(seq,seq)
|
340 |
+
if number_non_empty_positions>0:
|
341 |
+
denom = np.dot(list_seq,seq) / np.dot(seq,seq)
|
342 |
+
denom = np.sum(denom > 1 - self.theta)
|
343 |
+
return 1/denom
|
344 |
+
else:
|
345 |
+
return 0.0 #return 0 weight if sequence is fully empty
|
346 |
+
self.weights = np.array(list(map(compute_weight,list_seq)))
|
347 |
+
np.save(file=self.weights_location, arr=self.weights)
|
348 |
+
else:
|
349 |
+
# If not using weights, use an isotropic weight matrix
|
350 |
+
if verbose: print("Not weighting sequence data")
|
351 |
+
self.weights = np.ones(self.one_hot_encoding.shape[0])
|
352 |
+
|
353 |
+
self.Neff = np.sum(self.weights)
|
354 |
+
self.num_sequences = self.one_hot_encoding.shape[0]
|
355 |
+
self.seq_name_to_weight={}
|
356 |
+
for i,seq_name in enumerate(self.seq_name_to_sequence.keys()):
|
357 |
+
self.seq_name_to_weight[seq_name]=self.weights[i]
|
358 |
+
|
359 |
+
if verbose:
|
360 |
+
print ("Neff =",str(self.Neff))
|
361 |
+
print ("Data Shape =",self.one_hot_encoding.shape)
|
utils/scoring_utils.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import tqdm
|
3 |
+
import re
|
4 |
+
import numpy as np
|
5 |
+
import pandas as pd
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch.nn import CrossEntropyLoss, NLLLoss
|
9 |
+
from torch.utils.data.sampler import Sampler, SequentialSampler
|
10 |
+
|
11 |
+
from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizerFast
|
12 |
+
from datasets import Dataset
|
13 |
+
|
14 |
+
AA_vocab = "ACDEFGHIKLMNPQRSTVWY"
|
15 |
+
|
16 |
+
def get_mutated_sequence(focus_seq, mutant, start_idx=1, AA_vocab=AA_vocab):
|
17 |
+
"""
|
18 |
+
Helper function that mutates an input sequence (focus_seq) via an input mutation triplet (substitutions only).
|
19 |
+
Mutation triplet are typically based on 1-indexing: start_idx is used for switching to 0-indexing.
|
20 |
+
"""
|
21 |
+
mutated_seq = list(focus_seq)
|
22 |
+
for mutation in mutant.split(":"):
|
23 |
+
try:
|
24 |
+
from_AA, position, to_AA = mutation[0], int(mutation[1:-1]), mutation[-1]
|
25 |
+
except:
|
26 |
+
print("Issue with mutant: "+str(mutation))
|
27 |
+
relative_position = position - start_idx
|
28 |
+
assert (from_AA==focus_seq[relative_position]), "Invalid from_AA or mutant position: "+str(mutation)+" from_AA: "+str(from_AA) + " relative pos: "+str(relative_position) + " focus_seq: "+str(focus_seq)
|
29 |
+
assert (to_AA in AA_vocab) , "Mutant to_AA is invalid: "+str(mutation)
|
30 |
+
mutated_seq[relative_position] = to_AA
|
31 |
+
return "".join(mutated_seq)
|
32 |
+
|
33 |
+
def nanmean(v, *args, inplace=False, **kwargs):
|
34 |
+
if not inplace:
|
35 |
+
v = v.clone()
|
36 |
+
is_nan = torch.isnan(v)
|
37 |
+
v[is_nan] = 0
|
38 |
+
return v.sum(*args, **kwargs) / (~is_nan).float().sum(*args, **kwargs)
|
39 |
+
|
40 |
+
def nansum(v, *args, inplace=False, **kwargs):
|
41 |
+
if not inplace:
|
42 |
+
v = v.clone()
|
43 |
+
is_nan = torch.isnan(v)
|
44 |
+
v[is_nan] = 0
|
45 |
+
return v.sum(*args, **kwargs)
|
46 |
+
|
47 |
+
def get_optimal_window(mutation_position_relative, seq_len_wo_special, model_window):
|
48 |
+
"""
|
49 |
+
Helper function that selects an optimal sequence window that fits the maximum model context size.
|
50 |
+
If the sequence length is less than the maximum context size, the full sequence is returned.
|
51 |
+
"""
|
52 |
+
half_model_window = model_window // 2
|
53 |
+
if seq_len_wo_special <= model_window:
|
54 |
+
return [0,seq_len_wo_special]
|
55 |
+
elif mutation_position_relative < half_model_window:
|
56 |
+
return [0,model_window]
|
57 |
+
elif mutation_position_relative >= seq_len_wo_special - half_model_window:
|
58 |
+
return [seq_len_wo_special - model_window, seq_len_wo_special]
|
59 |
+
else:
|
60 |
+
return [max(0,mutation_position_relative-half_model_window), min(seq_len_wo_special,mutation_position_relative+half_model_window)]
|
61 |
+
|
62 |
+
def sequence_replace_single(sequence, char_to_replace, char_replacements):
|
63 |
+
char_replacements = list(char_replacements)
|
64 |
+
positions = [m.start() for m in re.finditer(char_to_replace, sequence)]
|
65 |
+
replacements = np.random.choice(a=char_replacements, size=len(positions), replace=True)
|
66 |
+
sequence=list(sequence)
|
67 |
+
for idx, position in enumerate(positions):
|
68 |
+
sequence[position]=replacements[idx]
|
69 |
+
return ''.join(sequence)
|
70 |
+
|
71 |
+
def sequence_replace(sequences, char_to_replace, char_replacements):
|
72 |
+
"""
|
73 |
+
Helper function that replaces all Amino Acids passsed in via char_to_replace (as a string of AAs) with Amino Acids sampled from char_replacements (also a string of eligible AAs).
|
74 |
+
"""
|
75 |
+
return [sequence_replace_single(sequence, char_to_replace, char_replacements) for sequence in sequences]
|
76 |
+
|
77 |
+
def get_tranception_scores_mutated_sequences(model, mutated_sequence_df, batch_size_inference, score_var_name, len_target_seq, num_workers=10, reverse=False, indel_mode=False):
|
78 |
+
"""
|
79 |
+
Helper function that takes as input a set of mutated sequences (in a pandas dataframe) and returns scores for each mutation (delta log likelihood wrt wild type sequence).
|
80 |
+
"""
|
81 |
+
scores = {}
|
82 |
+
scores['mutant']=[]
|
83 |
+
scores['window_start']=[]
|
84 |
+
scores['window_end']=[]
|
85 |
+
scores['score']=[]
|
86 |
+
with torch.no_grad():
|
87 |
+
ds = Dataset.from_pandas(mutated_sequence_df)
|
88 |
+
ds.set_transform(model.encode_batch)
|
89 |
+
data_collator = DataCollatorForLanguageModeling(
|
90 |
+
tokenizer=model.config.tokenizer,
|
91 |
+
mlm=False)
|
92 |
+
sampler = SequentialSampler(ds)
|
93 |
+
ds_loader = torch.utils.data.DataLoader(ds, batch_size=batch_size_inference, sampler=sampler, collate_fn=data_collator, num_workers=num_workers, pin_memory=True, drop_last=False)
|
94 |
+
mutant_index=0
|
95 |
+
for encoded_batch in tqdm.tqdm(ds_loader):
|
96 |
+
full_batch_length = len(encoded_batch['input_ids'])
|
97 |
+
scores['mutant'] += list(mutated_sequence_df['mutant'][mutant_index:mutant_index+full_batch_length])
|
98 |
+
window_start = np.array(mutated_sequence_df['window_start'][mutant_index:mutant_index+full_batch_length])
|
99 |
+
scores['window_start'] += list(window_start)
|
100 |
+
window_end = np.array(mutated_sequence_df['window_end'][mutant_index:mutant_index+full_batch_length])
|
101 |
+
scores['window_end'] += list(window_end)
|
102 |
+
full_raw_sequence = np.array(mutated_sequence_df['full_raw_sequence'][mutant_index:mutant_index+full_batch_length])
|
103 |
+
for k, v in encoded_batch.items():
|
104 |
+
if isinstance(v, torch.Tensor):
|
105 |
+
encoded_batch[k] = v.to(model.device)
|
106 |
+
shift_labels = encoded_batch['labels'][..., 1:].contiguous()
|
107 |
+
if (hasattr(model.config,"retrieval_aggregation_mode")) and (model.config.retrieval_aggregation_mode is not None):
|
108 |
+
if reverse:
|
109 |
+
encoded_batch['flip']=torch.tensor([1]*full_batch_length)
|
110 |
+
encoded_batch['start_slice']=window_start
|
111 |
+
encoded_batch['end_slice']=window_end
|
112 |
+
encoded_batch['full_raw_sequence'] = full_raw_sequence #only mutated_sequence is flipped if the scoring_mirror branch of score_mutants. No need to flip full_raw_sequence for MSA re-aligning
|
113 |
+
fused_shift_log_probas=model(**encoded_batch,return_dict=True).fused_shift_log_probas
|
114 |
+
loss_fct = NLLLoss(reduction='none')
|
115 |
+
loss = - loss_fct(input=fused_shift_log_probas.view(-1, fused_shift_log_probas.size(-1)), target=shift_labels.view(-1)).view(fused_shift_log_probas.shape[0],fused_shift_log_probas.shape[1])
|
116 |
+
else:
|
117 |
+
lm_logits=model(**encoded_batch,return_dict=True).logits
|
118 |
+
shift_logits = lm_logits[..., :-1, :].contiguous()
|
119 |
+
loss_fct = CrossEntropyLoss(reduction='none')
|
120 |
+
loss = - loss_fct(input=shift_logits.view(-1, shift_logits.size(-1)), target=shift_labels.view(-1)).view(shift_logits.shape[0],shift_logits.shape[1])
|
121 |
+
mask = encoded_batch['attention_mask'][..., 1:].float()
|
122 |
+
mask[mask==0]=float('nan')
|
123 |
+
loss *= mask
|
124 |
+
loss = nanmean(loss, dim=1)
|
125 |
+
scores_batch = list(loss.cpu().numpy())
|
126 |
+
full_batch_length = len(encoded_batch['input_ids'])
|
127 |
+
scores['score'] += scores_batch
|
128 |
+
mutant_index+=full_batch_length
|
129 |
+
scores = pd.DataFrame(scores)
|
130 |
+
scores_mutated_seq = scores[scores.mutant != 'wt']
|
131 |
+
scores_wt = scores[scores.mutant == 'wt']
|
132 |
+
delta_scores = pd.merge(scores_mutated_seq,scores_wt,how='left',on=['window_start'],suffixes=('','_wt'))
|
133 |
+
delta_scores[score_var_name] = delta_scores['score'] - delta_scores['score_wt']
|
134 |
+
delta_scores=delta_scores[['mutant',score_var_name]].groupby('mutant').mean().reset_index()
|
135 |
+
return delta_scores
|
136 |
+
|
137 |
+
def get_sequence_slices(df, target_seq, model_context_len, start_idx=1, scoring_window="optimal", indel_mode=False):
|
138 |
+
"""
|
139 |
+
Helper function that takes as input a (pandas) dataframe df that contains a list of mutant triplets (substitutions) or full mutated sequences (indels) for scoring.
|
140 |
+
It returns a processed DMS in which sequences have been sliced to satisfy the maximum context window of the model.
|
141 |
+
df: (dataframe) Input dataframe to be processed
|
142 |
+
target_seq: (string) Full reference sequence (wild type) that is mutated in the DMS assay.
|
143 |
+
model_context_len: (int) Maximum context size for the model.
|
144 |
+
start_idx: (int) Integer to move to 0-indexing of positions (mutation triplet are typically based on 1-indexing).
|
145 |
+
scoring_window: (string) Method to slice sequences longer than maximum context size:
|
146 |
+
- optimal selects a single window as large as possible via the get_optimal_window function (this is the default)
|
147 |
+
- sliding splits the full sequence in contiguous (non-overlapping) chunks that are of size equal to the max context (except the last chunk which may be shorter)
|
148 |
+
indel_mode: (bool) Flag to be used when scoring insertions and deletions. Otherwise assumes substitutions.
|
149 |
+
Note: when scoring indels for sequences that would be longer than the model max context length, it is preferable to use the "sliding" scoring_window. Use "optimal" otherwise.
|
150 |
+
"""
|
151 |
+
len_target_seq = len(target_seq)
|
152 |
+
num_mutants = len(df['mutant'])
|
153 |
+
df=df.reset_index(drop=True)
|
154 |
+
if scoring_window=="optimal":
|
155 |
+
df['mutation_barycenter'] = df['mutant'].apply(lambda x: int(np.array([int(mutation[1:-1]) - start_idx for mutation in x.split(':')]).mean())) if not indel_mode else df['mutant'].apply(lambda x: len(x)//2)
|
156 |
+
df['scoring_optimal_window'] = df['mutation_barycenter'].apply(lambda x: get_optimal_window(x, len_target_seq, model_context_len)) if not indel_mode else df['mutant'].apply(lambda x: (0,len(x)))
|
157 |
+
df['full_raw_sequence'] = df['mutated_sequence']
|
158 |
+
df['mutated_sequence'] = [df['mutated_sequence'][index][df['scoring_optimal_window'][index][0]:df['scoring_optimal_window'][index][1]] for index in range(num_mutants)]
|
159 |
+
df['window_start'] = df['scoring_optimal_window'].map(lambda x: x[0])
|
160 |
+
df['window_end'] = df['scoring_optimal_window'].map(lambda x: x[1])
|
161 |
+
del df['scoring_optimal_window']
|
162 |
+
df_wt=df.copy()
|
163 |
+
df_wt['mutant'] = ['wt'] * num_mutants
|
164 |
+
df_wt['full_raw_sequence'] = [target_seq] * num_mutants
|
165 |
+
if indel_mode: # For indels, we set the wild type reference to be always the same (full length) sequence. We assume here that the length is lower than model context size (otherwise use "Sliding")
|
166 |
+
df_wt['mutation_barycenter'] = [len_target_seq // 2] * num_mutants
|
167 |
+
df_wt['window_end'] = df_wt['full_raw_sequence'].map(lambda x:len(x))
|
168 |
+
df_wt['mutated_sequence'] = [target_seq[df_wt['window_start'][index]:df_wt['window_end'][index]] for index in range(num_mutants)]
|
169 |
+
df = pd.concat([df,df_wt], axis=0)
|
170 |
+
df = df.drop_duplicates()
|
171 |
+
elif scoring_window=="sliding":
|
172 |
+
len_target_seq = len(target_seq)
|
173 |
+
num_windows = 1 + int( len_target_seq / model_context_len)
|
174 |
+
df_list=[]
|
175 |
+
start=0
|
176 |
+
for window_index in range(1, num_windows+1):
|
177 |
+
df_sliced = df.copy()
|
178 |
+
df_sliced['full_raw_sequence'] = df_sliced['mutated_sequence']
|
179 |
+
df_sliced['mutated_sequence'] = df_sliced['mutated_sequence'].map(lambda x: x[start:start+model_context_len])
|
180 |
+
df_sliced['window_start'] = [start] * num_mutants
|
181 |
+
df_sliced['window_end'] = df_sliced['full_raw_sequence'].map(lambda x: min(len(x), start+model_context_len))
|
182 |
+
df_sliced_wt = df_sliced.copy()
|
183 |
+
df_sliced_wt['mutant'] = ['wt'] * num_mutants
|
184 |
+
df_sliced_wt['full_raw_sequence'] = [target_seq] * num_mutants
|
185 |
+
df_sliced_wt['mutated_sequence'] = df_sliced_wt['full_raw_sequence'].map(lambda x: x[start:start+model_context_len])
|
186 |
+
df_sliced_wt['window_end'] = df_sliced_wt['full_raw_sequence'].map(lambda x: min(len(x), start+model_context_len)) #Need to adjust end index if WT and sequence are not same full length
|
187 |
+
df_list.append(df_sliced)
|
188 |
+
df_list.append(df_sliced_wt)
|
189 |
+
start += model_context_len
|
190 |
+
df_final = pd.concat(df_list,axis=0)
|
191 |
+
df = df_final.drop_duplicates()
|
192 |
+
return df.reset_index(drop=True)
|
utils/tokenizers/Basic_tokenizer
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"version":"1.0","truncation":null,"padding":null,"added_tokens":[{"id":0,"special":true,"content":"[UNK]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":1,"special":true,"content":"[CLS]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":2,"special":true,"content":"[SEP]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":3,"special":true,"content":"[PAD]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":4,"special":true,"content":"[MASK]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false}],"normalizer":null,"pre_tokenizer":{"type":"Whitespace"},"post_processor":{"type":"TemplateProcessing","single":[{"SpecialToken":{"id":"[CLS]","type_id":0}},{"Sequence":{"id":"A","type_id":0}},{"SpecialToken":{"id":"[SEP]","type_id":0}}],"pair":[{"SpecialToken":{"id":"[CLS]","type_id":0}},{"Sequence":{"id":"A","type_id":0}},{"SpecialToken":{"id":"[SEP]","type_id":0}},{"Sequence":{"id":"B","type_id":1}},{"SpecialToken":{"id":"[SEP]","type_id":1}}],"special_tokens":{"[CLS]":{"id":"[CLS]","ids":[1],"tokens":["[CLS]"]},"[SEP]":{"id":"[SEP]","ids":[2],"tokens":["[SEP]"]}}},"decoder":null,"model":{"type":"BPE","dropout":null,"unk_token":"[UNK]","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"vocab":{"[UNK]":0,"[CLS]":1,"[SEP]":2,"[PAD]":3,"[MASK]":4,"A":5,"C":6,"D":7,"E":8,"F":9,"G":10,"H":11,"I":12,"K":13,"L":14,"M":15,"N":16,"P":17,"Q":18,"R":19,"S":20,"T":21,"V":22,"W":23,"Y":24},"merges":[]}}
|