Spaces:
Build error
Build error
PascalNotin
commited on
Commit
•
1335bda
1
Parent(s):
6590011
Implemented first version of design app
Browse files- README.md +1 -1
- app.py +139 -0
- requirements.txt +4 -0
- tranception/__init__.py +1 -0
- tranception/activations.py +114 -0
- tranception/config.py +36 -0
- tranception/model_pytorch.py +930 -0
- tranception/outputs.py +48 -0
- tranception/utils/__init__.py +1 -0
- tranception/utils/dms_utils.py +30 -0
- tranception/utils/msa_utils.py +361 -0
- tranception/utils/scoring_utils.py +203 -0
- tranception/utils/tokenizers/Basic_tokenizer +1 -0
README.md
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
---
|
2 |
title: Tranception Design
|
3 |
emoji: 🐨
|
4 |
-
colorFrom:
|
5 |
colorTo: gray
|
6 |
sdk: gradio
|
7 |
sdk_version: 3.1.7
|
|
|
1 |
---
|
2 |
title: Tranception Design
|
3 |
emoji: 🐨
|
4 |
+
colorFrom: blue
|
5 |
colorTo: gray
|
6 |
sdk: gradio
|
7 |
sdk_version: 3.1.7
|
app.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import transformers
|
3 |
+
from transformers import PreTrainedTokenizerFast
|
4 |
+
import tranception
|
5 |
+
import datasets
|
6 |
+
from tranception import config, model_pytorch
|
7 |
+
import pandas as pd
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
import seaborn as sns
|
10 |
+
import gradio as gr
|
11 |
+
|
12 |
+
tokenizer = PreTrainedTokenizerFast(tokenizer_file="./tranception/utils/tokenizers/Basic_tokenizer",
|
13 |
+
unk_token="[UNK]",
|
14 |
+
sep_token="[SEP]",
|
15 |
+
pad_token="[PAD]",
|
16 |
+
cls_token="[CLS]",
|
17 |
+
mask_token="[MASK]"
|
18 |
+
)
|
19 |
+
#######################################################################################################################################
|
20 |
+
############################################### HELPER FUNCTIONS ####################################################################
|
21 |
+
#######################################################################################################################################
|
22 |
+
|
23 |
+
AA_vocab = "ACDEFGHIKLMNPQRSTVWY"
|
24 |
+
def create_all_single_mutants(sequence,AA_vocab=AA_vocab,mutation_range_start=None,mutation_range_end=None):
|
25 |
+
all_single_mutants={}
|
26 |
+
sequence_list=list(sequence)
|
27 |
+
if mutation_range_start is None: mutation_range_start=1
|
28 |
+
if mutation_range_end is None: mutation_range_end=len(sequence)
|
29 |
+
for position,current_AA in enumerate(sequence[mutation_range_start-1:mutation_range_end]):
|
30 |
+
for mutated_AA in AA_vocab:
|
31 |
+
if current_AA!=mutated_AA:
|
32 |
+
mutated_sequence = sequence_list.copy()
|
33 |
+
mutated_sequence[position] = mutated_AA
|
34 |
+
all_single_mutants[current_AA+str(position+1)+mutated_AA]="".join(mutated_sequence)
|
35 |
+
all_single_mutants = pd.DataFrame.from_dict(all_single_mutants,columns=['mutated_sequence'],orient='index')
|
36 |
+
all_single_mutants.reset_index(inplace=True)
|
37 |
+
all_single_mutants.columns = ['mutant','mutated_sequence']
|
38 |
+
return all_single_mutants
|
39 |
+
|
40 |
+
def create_scoring_matrix_visual(scores,sequence,AA_vocab=AA_vocab,mutation_range_start=None,mutation_range_end=None):
|
41 |
+
piv=scores.pivot(index='position',columns='target_AA',values='avg_score').transpose().round(4)
|
42 |
+
fig, ax = plt.subplots(figsize=(len(sequence)*1.2,20))
|
43 |
+
scores_dict = {}
|
44 |
+
valid_mutant_set=set(scores.mutant)
|
45 |
+
if mutation_range_start is None: mutation_range_start=1
|
46 |
+
if mutation_range_end is None: mutation_range_start=len(sequence)
|
47 |
+
for target_AA in list(AA_vocab):
|
48 |
+
for position in range(mutation_range_start,mutation_range_end+1):
|
49 |
+
mutant = sequence[position-1]+str(position)+target_AA
|
50 |
+
if mutant in valid_mutant_set:
|
51 |
+
scores_dict[mutant]= float(scores.loc[scores.mutant==mutant,'avg_score'])
|
52 |
+
else:
|
53 |
+
scores_dict[mutant]=0.0
|
54 |
+
labels = (np.asarray(["{} \n {:.4f}".format(symb,value) for symb, value in scores_dict.items() ])).reshape(len(AA_vocab),mutation_range_end-mutation_range_start+1)
|
55 |
+
heat = sns.heatmap(piv,annot=labels,fmt="",cmap='RdYlGn',linewidths=0.30,vmin=np.percentile(scores.avg_score,2),vmax=np.percentile(scores.avg_score,98),\
|
56 |
+
cbar_kws={'label': 'Log likelihood ratio (mutant / starting sequence)'})
|
57 |
+
heat.figure.axes[-1].yaxis.label.set_size(20)
|
58 |
+
#heat.set_title("Fitness scores for all single amino acid substitutions",fontsize=30)
|
59 |
+
heat.set_title("Higher predicted scores (green) imply higher protein fitness",fontsize=30, pad=40)
|
60 |
+
heat.set_xlabel("Sequence position", fontsize = 20)
|
61 |
+
heat.set_ylabel("Amino Acid mutation", fontsize = 20)
|
62 |
+
plt.savefig('fitness_scoring_substitution_matrix.png')
|
63 |
+
return plt
|
64 |
+
|
65 |
+
def suggest_mutations(scores):
|
66 |
+
intro_message = "The following mutations may be sensible options to improve fitness: \n\n"
|
67 |
+
#Best mutants
|
68 |
+
top_mutants=list(scores.sort_values(by=['avg_score'],ascending=False).head(5).mutant)
|
69 |
+
mutant_recos = "The 5 single mutants with highest predicted fitness are:\n {} \n\n".format(", ".join(top_mutants))
|
70 |
+
#Best positions
|
71 |
+
positive_scores = scores[scores.avg_score > 0]
|
72 |
+
positive_scores_position_avg = positive_scores.groupby(['position']).mean()
|
73 |
+
top_positions=list(positive_scores_position_avg.sort_values(by=['avg_score'],ascending=False).head(5).index.astype(str))
|
74 |
+
print(top_positions)
|
75 |
+
position_recos = "The 5 positions with the highest average fitness increase are:\n {}".format(", ".join(top_positions))
|
76 |
+
return intro_message+mutant_recos+position_recos
|
77 |
+
|
78 |
+
def get_mutated_protein(sequence,mutant):
|
79 |
+
mutated_sequence = list(sequence)
|
80 |
+
mutated_sequence[int(mutant[1:-1])-1]=mutant[-1]
|
81 |
+
return ''.join(mutated_sequence)
|
82 |
+
|
83 |
+
def score_and_create_matrix_all_singles(sequence,mutation_range_start=None,mutation_range_end=None,model_type="Small",scoring_mirror=False,batch_size_inference=20,num_workers=0,AA_vocab=AA_vocab):
|
84 |
+
if model_type=="Small":
|
85 |
+
model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path="PascalNotin/Tranception_Small",use_auth_token=True)
|
86 |
+
elif model_type=="Medium":
|
87 |
+
model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path="PascalNotin/Tranception_Medium",use_auth_token=True)
|
88 |
+
elif model_type=="Large":
|
89 |
+
model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path="PascalNotin/Tranception_Large",use_auth_token=True)
|
90 |
+
model.config.tokenizer = tokenizer
|
91 |
+
all_single_mutants = create_all_single_mutants(sequence,AA_vocab,mutation_range_start,mutation_range_end)
|
92 |
+
scores = model.score_mutants(DMS_data=all_single_mutants,
|
93 |
+
target_seq=sequence,
|
94 |
+
scoring_mirror=scoring_mirror,
|
95 |
+
batch_size_inference=batch_size_inference,
|
96 |
+
num_workers=num_workers,
|
97 |
+
indel_mode=False
|
98 |
+
)
|
99 |
+
scores = pd.merge(scores,all_single_mutants,on="mutated_sequence",how="left")
|
100 |
+
scores["position"]=scores["mutant"].map(lambda x: int(x[1:-1]))
|
101 |
+
scores["target_AA"] = scores["mutant"].map(lambda x: x[-1])
|
102 |
+
score_heatmap = create_scoring_matrix_visual(scores,sequence,AA_vocab,mutation_range_start,mutation_range_end)
|
103 |
+
return score_heatmap,suggest_mutations(scores)
|
104 |
+
|
105 |
+
#######################################################################################################################################
|
106 |
+
############################################### GRADIO INTERFACE ####################################################################
|
107 |
+
#######################################################################################################################################
|
108 |
+
|
109 |
+
title = "Interactive in silico directed evolution with Tranception"
|
110 |
+
description = "Perform in silico directed evolution with Tranception to iteratively improve the fitness of a starting protein sequence one mutation at a time. At each step, the Tranception model computes the log likelihood ratios of all possible single amino acid substitution Vs the starting sequence, and outputs a fitness heatmap and recommandations to guide the selection of the mutation to apply. Note: The current version does not currently leverage homologs retrieval at inference time to boost fitness prediction performance."
|
111 |
+
article = "<p style='text-align: center'><a href='https://proceedings.mlr.press/v162/notin22a.html' target='_blank'>Tranception: Protein Fitness Prediction with Autoregressive Transformers and Inference-time Retrieval</a></p>"
|
112 |
+
examples=[
|
113 |
+
['A4_HUMAN: MLPGLALLLLAAWTARALEVPTDGNAGLLAEPQIAMFCGRLNMHMNVQNGKWDSDPSGTKTCIDTKEGILQYCQEVYPELQITNVVEANQPVTIQNWCKRGRKQCKTHPHFVIPYRCLVGEFVSDALLVPDKCKFLHQERMDVCETHLHWHTVAKETCSEKSTNLHDYGMLLPCGIDKFRGVEFVCCPLAEESDNVDSADAEEDDSDVWWGGADTDYADGSEDKVVEVAEEEEVAEVEEEEADDDEDDEDGDEVEEEAEEPYEEATERTTSIATTTTTTTESVEEVVREVCSEQAETGPCRAMISRWYFDVTEGKCAPFFYGGCGGNRNNFDTEEYCMAVCGSAMSQSLLKTTQEPLARDPVKLPTTAASTPDAVDKYLETPGDENEHAHFQKAKERLEAKHRERMSQVMREWEEAERQAKNLPKADKKAVIQHFQEKVESLEQEAANERQQLVETHMARVEAMLNDRRRLALENYITALQAVPPRPRHVFNMLKKYVRAEQKDRQHTLKHFEHVRMVDPKKAAQIRSQVMTHLRVIYERMNQSLSLLYNVPAVAEEIQDEVDELLQKEQNYSDDVLANMISEPRISYGNDALMPSLTETKTTVELLPVNGEFSLDDLQPWHSFGADSVPANTENEVEPVDARPAADRGLTTRPGSGLTNIKTEEISEVKMDAEFRHDSGYEVHHQKLVFFAEDVGSNKGAIIGLMVGGVVIATVIVITLVMLKKKQYTSIHHGVVEVDAAVTPEERHLSKMQQNGYENPTYKFFEQMQN'],
|
114 |
+
['ADRB2_HUMAN: MGQPGNGSAFLLAPNGSHAPDHDVTQERDEVWVVGMGIVMSLIVLAIVFGNVLVITAIAKFERLQTVTNYFITSLACADLVMGLAVVPFGAAHILMKMWTFGNFWCEFWTSIDVLCVTASIETLCVIAVDRYFAITSPFKYQSLLTKNKARVIILMVWIVSGLTSFLPIQMHWYRATHQEAINCYANETCCDFFTNQAYAIASSIVSFYVPLVIMVFVYSRVFQEAKRQLQKIDKSEGRFHVQNLSQVEQDGRTGHGLRRSSKFCLKEHKALKTLGIIMGTFTLCWLPFFIVNIVHVIQDNLIRKEVYILLNWIGYVNSGFNPLIYCRSPDFRIAFQELLCLRRSSLKAYGNGYSSNGNTGEQSGYHVEQEKENKLLCEDLPGTEDFVGHQGTVPSDNIDSQGRNCSTNDSLL'],
|
115 |
+
['AMIE_PSEAE: MRHGDISSSNDTVGVAVVNYKMPRLHTAAEVLDNARKIAEMIVGMKQGLPGMDLVVFPEYSLQGIMYDPAEMMETAVAIPGEETEIFSRACRKANVWGVFSLTGERHEEHPRKAPYNTLVLIDNNGEIVQKYRKIIPWCPIEGWYPGGQTYVSEGPKGMKISLIICDDGNYPEIWRDCAMKGAELIVRCQGYMYPAKDQQVMMAKAMAWANNCYVAVANAAGFDGVYSYFGHSAIIGFDGRTLGECGEEEMGIQYAQLSLSQIRDARANDQSQNHLFKILHRGYSGLQASGDGDRGLAECPFEFYRTWVTDAEKARENVERLTRSTTGVAQCPVGRLPYEGLEKEA'],
|
116 |
+
['P53_HUMAN: MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGPDEAPRMPEAAPRVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYQGSYGFRLGFLHSGTAKSVTCTYSPALNKMFCQLAKTCPVQLWVDSTPPPGTRVRAMAIYKQSQHMTEVVRRCPHHERCSDSDGLAPPQHLIRVEGNLRVEYLDDRNTFRHSVVVPYEPPEVGSDCTTIHYNYMCNSSCMGGMNRRPILTIITLEDSSGNLLGRNSFEVRVCACPGRDRRTEEENLRKKGEPHHELPPGSTKRALPNNTSSSPQPKKKPLDGEYFTLQIRGRERFEMFRELNEALELKDAQAGKEPGGSRAHSSHLKSKKGQSTSRHKKLMFKTEGPDSD']
|
117 |
+
]
|
118 |
+
|
119 |
+
model_size_selection = gr.Radio(label="Tranception model size", choices=["Small","Medium","Large"], value="Small")
|
120 |
+
protein_sequence_input = gr.Textbox(lines=1, label="Input protein sequence (see below for examples; default = RL40A_YEAST)",value="MQIFVKTLTGKTITLEVESSDTIDNVKSKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGGIIEPSLKALASKYNCDKSVCRKCYARLPPRATNCRKRKCGHTNQLRPKKKLK")
|
121 |
+
mutation_range_start = gr.Number(label="Start of mutation range (min value = 1)",value=1,precision=0)
|
122 |
+
mutation_range_end = gr.Number(label="End of mutation range (leave empty for full lenth)",value=10,precision=0)
|
123 |
+
scoring_mirror = gr.Checkbox(label="Score protein from both directions (leads to more robust fitness predictions, but doubles inference time)")
|
124 |
+
|
125 |
+
#output ==> find a way to make scroallable
|
126 |
+
output_plot = gr.Plot(label="Fitness scores for all single amino acid substitutions in mutation range")
|
127 |
+
output_recommendations = gr.Textbox(label="Mutation recommendations")
|
128 |
+
|
129 |
+
gr.Interface(
|
130 |
+
fn=score_and_create_matrix_all_singles,
|
131 |
+
inputs=[protein_sequence_input,mutation_range_start,mutation_range_end,model_size_selection,scoring_mirror],
|
132 |
+
outputs=["plot","text"],
|
133 |
+
title=title,
|
134 |
+
description=description,
|
135 |
+
article=article,
|
136 |
+
examples=examples,
|
137 |
+
enable_queue=True,
|
138 |
+
allow_flagging="never"
|
139 |
+
).launch(debug=True)
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
transformers==4.17
|
3 |
+
datasets==1.18.3
|
4 |
+
biopython==1.78
|
tranception/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from . import config
|
tranception/activations.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from packaging import version
|
5 |
+
from torch import nn
|
6 |
+
|
7 |
+
from transformers.utils import logging
|
8 |
+
|
9 |
+
|
10 |
+
logger = logging.get_logger(__name__)
|
11 |
+
|
12 |
+
|
13 |
+
def _gelu_python(x):
|
14 |
+
"""
|
15 |
+
Original Implementation of the GELU activation function in Google BERT repo when initially created. For
|
16 |
+
information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
|
17 |
+
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
|
18 |
+
Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
|
19 |
+
"""
|
20 |
+
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
|
21 |
+
|
22 |
+
|
23 |
+
def gelu_new(x):
|
24 |
+
"""
|
25 |
+
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
|
26 |
+
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
|
27 |
+
"""
|
28 |
+
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
|
29 |
+
|
30 |
+
|
31 |
+
if version.parse(torch.__version__) < version.parse("1.4"):
|
32 |
+
gelu = _gelu_python
|
33 |
+
else:
|
34 |
+
gelu = nn.functional.gelu
|
35 |
+
|
36 |
+
|
37 |
+
def gelu_fast(x):
|
38 |
+
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
|
39 |
+
|
40 |
+
|
41 |
+
def quick_gelu(x):
|
42 |
+
return x * torch.sigmoid(1.702 * x)
|
43 |
+
|
44 |
+
|
45 |
+
def _silu_python(x):
|
46 |
+
"""
|
47 |
+
See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
|
48 |
+
Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
|
49 |
+
Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
|
50 |
+
Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
|
51 |
+
later.
|
52 |
+
"""
|
53 |
+
return x * torch.sigmoid(x)
|
54 |
+
|
55 |
+
|
56 |
+
if version.parse(torch.__version__) < version.parse("1.7"):
|
57 |
+
silu = _silu_python
|
58 |
+
else:
|
59 |
+
silu = nn.functional.silu
|
60 |
+
|
61 |
+
|
62 |
+
def _mish_python(x):
|
63 |
+
"""
|
64 |
+
See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
|
65 |
+
visit the official repository for the paper: https://github.com/digantamisra98/Mish
|
66 |
+
"""
|
67 |
+
return x * torch.tanh(nn.functional.softplus(x))
|
68 |
+
|
69 |
+
|
70 |
+
if version.parse(torch.__version__) < version.parse("1.9"):
|
71 |
+
mish = _mish_python
|
72 |
+
else:
|
73 |
+
mish = nn.functional.mish
|
74 |
+
|
75 |
+
|
76 |
+
def linear_act(x):
|
77 |
+
return x
|
78 |
+
|
79 |
+
def squared_relu(x):
|
80 |
+
"""
|
81 |
+
Squared ReLU variant that is fastest with Pytorch.
|
82 |
+
"""
|
83 |
+
x = nn.functional.relu(x)
|
84 |
+
return x*x
|
85 |
+
|
86 |
+
def squared_relu_xla(x):
|
87 |
+
"""
|
88 |
+
Squared ReLU variant that is fastest with JAX.
|
89 |
+
"""
|
90 |
+
x = nn.functional.relu(x)
|
91 |
+
return x**2
|
92 |
+
|
93 |
+
tranception_ACT2FN = {
|
94 |
+
"relu": nn.functional.relu,
|
95 |
+
"silu": silu,
|
96 |
+
"swish": silu,
|
97 |
+
"gelu": gelu,
|
98 |
+
"tanh": torch.tanh,
|
99 |
+
"gelu_new": gelu_new,
|
100 |
+
"gelu_fast": gelu_fast,
|
101 |
+
"quick_gelu": quick_gelu,
|
102 |
+
"mish": mish,
|
103 |
+
"linear": linear_act,
|
104 |
+
"sigmoid": torch.sigmoid,
|
105 |
+
"squared_relu": squared_relu,
|
106 |
+
"squared_relu_xla": squared_relu_xla,
|
107 |
+
}
|
108 |
+
|
109 |
+
|
110 |
+
def get_activation(activation_string):
|
111 |
+
if activation_string in tranception_ACT2FN:
|
112 |
+
return tranception_ACT2FN[activation_string]
|
113 |
+
else:
|
114 |
+
raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(tranception_ACT2FN.keys())}")
|
tranception/config.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import GPT2Config
|
2 |
+
|
3 |
+
class TranceptionConfig(GPT2Config):
|
4 |
+
"""
|
5 |
+
Config subclass for Tranception model architecture.
|
6 |
+
"""
|
7 |
+
def __init__(
|
8 |
+
self,
|
9 |
+
attention_mode="tranception",
|
10 |
+
position_embedding="grouped_alibi",
|
11 |
+
tokenizer=None,
|
12 |
+
retrieval_aggregation_mode=None,
|
13 |
+
retrieval_inference_weight=0.6,
|
14 |
+
MSA_filename=None,
|
15 |
+
MSA_weight_file_name=None,
|
16 |
+
MSA_start=None,
|
17 |
+
MSA_end=None,
|
18 |
+
full_protein_length=None,
|
19 |
+
clustal_omega_location=None,
|
20 |
+
scoring_window=None,
|
21 |
+
**kwargs
|
22 |
+
):
|
23 |
+
super().__init__(**kwargs)
|
24 |
+
self.model_type="tranception"
|
25 |
+
self.attention_mode=attention_mode
|
26 |
+
self.position_embedding=position_embedding
|
27 |
+
self.tokenizer = tokenizer
|
28 |
+
self.retrieval_aggregation_mode = retrieval_aggregation_mode
|
29 |
+
self.retrieval_inference_weight = retrieval_inference_weight
|
30 |
+
self.MSA_filename = MSA_filename
|
31 |
+
self.MSA_weight_file_name = MSA_weight_file_name
|
32 |
+
self.MSA_start=MSA_start
|
33 |
+
self.MSA_end=MSA_end
|
34 |
+
self.full_protein_length = full_protein_length
|
35 |
+
self.clustal_omega_location = clustal_omega_location
|
36 |
+
self.scoring_window=scoring_window
|
tranception/model_pytorch.py
ADDED
@@ -0,0 +1,930 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Optional, Tuple
|
3 |
+
import math
|
4 |
+
import os
|
5 |
+
import pandas as pd
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch import nn
|
9 |
+
from torch.nn import CrossEntropyLoss, NLLLoss
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from transformers import GPT2PreTrainedModel
|
12 |
+
|
13 |
+
from transformers.modeling_utils import (
|
14 |
+
Conv1D,
|
15 |
+
PreTrainedModel,
|
16 |
+
SequenceSummary,
|
17 |
+
find_pruneable_heads_and_indices,
|
18 |
+
prune_conv1d_layer,
|
19 |
+
)
|
20 |
+
from transformers.file_utils import (
|
21 |
+
ModelOutput,
|
22 |
+
add_code_sample_docstrings,
|
23 |
+
add_start_docstrings,
|
24 |
+
add_start_docstrings_to_model_forward,
|
25 |
+
replace_return_docstrings
|
26 |
+
)
|
27 |
+
from transformers.modeling_outputs import (
|
28 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
29 |
+
CausalLMOutputWithCrossAttentions,
|
30 |
+
SequenceClassifierOutputWithPast,
|
31 |
+
TokenClassifierOutput
|
32 |
+
)
|
33 |
+
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
|
34 |
+
|
35 |
+
from tranception.activations import tranception_ACT2FN
|
36 |
+
from tranception.config import TranceptionConfig
|
37 |
+
from tranception.outputs import (
|
38 |
+
TranceptionCausalLMOutputWithCrossAttentions,
|
39 |
+
)
|
40 |
+
from tranception.utils import msa_utils
|
41 |
+
from tranception.utils import scoring_utils
|
42 |
+
|
43 |
+
def nanmean(v, *args, inplace=False, **kwargs):
|
44 |
+
if not inplace:
|
45 |
+
v = v.clone()
|
46 |
+
is_nan = torch.isnan(v)
|
47 |
+
v[is_nan] = 0
|
48 |
+
return v.sum(*args, **kwargs) / (~is_nan).float().sum(*args, **kwargs)
|
49 |
+
|
50 |
+
def get_slopes(n, mode="standard_alibi", verbose=False):
|
51 |
+
"""
|
52 |
+
Function to compute the m constant for each attention head. Code has been adapted from the official ALiBi codebase at:
|
53 |
+
https://github.com/ofirpress/attention_with_linear_biases/blob/master/fairseq/models/transformer.py
|
54 |
+
"""
|
55 |
+
def get_slopes_power_of_2(n):
|
56 |
+
start = (2**(-2**-(math.log2(n)-3)))
|
57 |
+
ratio = start
|
58 |
+
return [start*ratio**i for i in range(n)]
|
59 |
+
if mode=="grouped_alibi":
|
60 |
+
n = n // 4
|
61 |
+
if math.log2(n).is_integer():
|
62 |
+
result = get_slopes_power_of_2(n)
|
63 |
+
else:
|
64 |
+
#Workaround when the number of heads is not a power of 2
|
65 |
+
closest_power_of_2 = 2**math.floor(math.log2(n))
|
66 |
+
result = get_slopes_power_of_2(closest_power_of_2) + get_slopes(2*closest_power_of_2)[0::2][:n-closest_power_of_2]
|
67 |
+
if mode=="grouped_alibi":
|
68 |
+
result = result * 4
|
69 |
+
if verbose:
|
70 |
+
print("ALiBi slopes: {}".format(result))
|
71 |
+
return result
|
72 |
+
|
73 |
+
class SpatialDepthWiseConvolution(nn.Module):
|
74 |
+
def __init__(self, head_dim: int, kernel_size: int = 3):
|
75 |
+
super().__init__()
|
76 |
+
self.kernel_size = kernel_size
|
77 |
+
self.conv = nn.Conv1d(in_channels=head_dim, out_channels=head_dim, kernel_size=(kernel_size,), padding=(kernel_size - 1,), groups=head_dim)
|
78 |
+
|
79 |
+
def forward(self, x: torch.Tensor):
|
80 |
+
batch_size, heads, seq_len, head_dim = x.shape
|
81 |
+
x = x.permute(0, 1, 3, 2).contiguous()
|
82 |
+
x = x.view(batch_size * heads, head_dim, seq_len)
|
83 |
+
x = self.conv(x)
|
84 |
+
if self.kernel_size>1:
|
85 |
+
x = x[:, :, :-(self.kernel_size - 1)]
|
86 |
+
x = x.view(batch_size, heads, head_dim, seq_len)
|
87 |
+
x = x.permute(0, 1, 3, 2)
|
88 |
+
return x
|
89 |
+
|
90 |
+
class TranceptionBlockAttention(nn.Module):
|
91 |
+
def __init__(self, config, is_cross_attention=False, SDWC_kernel_size=None):
|
92 |
+
super().__init__()
|
93 |
+
|
94 |
+
max_positions = config.max_position_embeddings
|
95 |
+
self.register_buffer(
|
96 |
+
"bias",
|
97 |
+
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
|
98 |
+
1, 1, max_positions, max_positions
|
99 |
+
),
|
100 |
+
)
|
101 |
+
self.register_buffer("masked_bias", torch.tensor(-1e4))
|
102 |
+
|
103 |
+
self.embed_dim = config.hidden_size
|
104 |
+
self.num_heads = config.num_attention_heads
|
105 |
+
self.head_dim = self.embed_dim // self.num_heads
|
106 |
+
self.split_size = self.embed_dim
|
107 |
+
if self.head_dim * self.num_heads != self.embed_dim:
|
108 |
+
raise ValueError(
|
109 |
+
f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
|
110 |
+
)
|
111 |
+
|
112 |
+
self.scale_attn_weights = config.scale_attn_weights
|
113 |
+
self.is_cross_attention = is_cross_attention
|
114 |
+
|
115 |
+
if self.is_cross_attention:
|
116 |
+
self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
|
117 |
+
self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
|
118 |
+
else:
|
119 |
+
self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
|
120 |
+
self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
|
121 |
+
|
122 |
+
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
123 |
+
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
124 |
+
|
125 |
+
self.pruned_heads = set()
|
126 |
+
|
127 |
+
self.attention_mode=config.attention_mode
|
128 |
+
|
129 |
+
if self.attention_mode=="tranception":
|
130 |
+
assert self.num_heads%4==0, "Invalid number of heads. Tranception requires the number of heads to be a multiple of 4."
|
131 |
+
self.num_heads_per_kernel_size = self.num_heads // 4
|
132 |
+
self.query_depthwiseconv = nn.ModuleDict()
|
133 |
+
self.key_depthwiseconv = nn.ModuleDict()
|
134 |
+
self.value_depthwiseconv = nn.ModuleDict()
|
135 |
+
for kernel_idx, kernel in enumerate([3,5,7]):
|
136 |
+
self.query_depthwiseconv[str(kernel_idx)] = SpatialDepthWiseConvolution(self.head_dim,kernel)
|
137 |
+
self.key_depthwiseconv[str(kernel_idx)] = SpatialDepthWiseConvolution(self.head_dim,kernel)
|
138 |
+
self.value_depthwiseconv[str(kernel_idx)] = SpatialDepthWiseConvolution(self.head_dim,kernel)
|
139 |
+
|
140 |
+
def prune_heads(self, heads):
|
141 |
+
if len(heads) == 0:
|
142 |
+
return
|
143 |
+
heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
|
144 |
+
index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
|
145 |
+
|
146 |
+
# Prune conv1d layers
|
147 |
+
self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
|
148 |
+
self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
|
149 |
+
|
150 |
+
# Update hyper params
|
151 |
+
self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
|
152 |
+
self.num_heads = self.num_heads - len(heads)
|
153 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
154 |
+
|
155 |
+
def _attn(self, query, key, value, attention_mask=None, head_mask=None, alibi_bias=None):
|
156 |
+
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
157 |
+
|
158 |
+
if self.scale_attn_weights:
|
159 |
+
attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)
|
160 |
+
|
161 |
+
if not self.is_cross_attention:
|
162 |
+
# if only "normal" attention layer implements causal mask
|
163 |
+
query_length, key_length = query.size(-2), key.size(-2)
|
164 |
+
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
|
165 |
+
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
|
166 |
+
|
167 |
+
if alibi_bias is not None:
|
168 |
+
attn_weights = attn_weights + alibi_bias[:,:,:attn_weights.size(-1)]
|
169 |
+
|
170 |
+
if attention_mask is not None:
|
171 |
+
# Apply the attention mask
|
172 |
+
attn_weights = attn_weights + attention_mask
|
173 |
+
|
174 |
+
attn_weights = nn.Softmax(dim=-1)(attn_weights)
|
175 |
+
attn_weights = self.attn_dropout(attn_weights)
|
176 |
+
|
177 |
+
# Mask heads if we want to
|
178 |
+
if head_mask is not None:
|
179 |
+
attn_weights = attn_weights * head_mask
|
180 |
+
|
181 |
+
attn_output = torch.matmul(attn_weights, value)
|
182 |
+
|
183 |
+
return attn_output, attn_weights
|
184 |
+
|
185 |
+
def _split_heads(self, tensor, num_heads, attn_head_size):
|
186 |
+
"""
|
187 |
+
Splits hidden_size dim into attn_head_size and num_heads
|
188 |
+
"""
|
189 |
+
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
|
190 |
+
tensor = tensor.view(*new_shape)
|
191 |
+
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
|
192 |
+
|
193 |
+
def _merge_heads(self, tensor, num_heads, attn_head_size):
|
194 |
+
"""
|
195 |
+
Merges attn_head_size dim and num_attn_heads dim into hidden_size
|
196 |
+
"""
|
197 |
+
tensor = tensor.permute(0, 2, 1, 3).contiguous()
|
198 |
+
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
|
199 |
+
return tensor.view(new_shape)
|
200 |
+
|
201 |
+
def forward(
|
202 |
+
self,
|
203 |
+
hidden_states,
|
204 |
+
layer_past=None,
|
205 |
+
attention_mask=None,
|
206 |
+
head_mask=None,
|
207 |
+
encoder_hidden_states=None,
|
208 |
+
encoder_attention_mask=None,
|
209 |
+
use_cache=False,
|
210 |
+
output_attentions=False,
|
211 |
+
alibi_bias=None,
|
212 |
+
):
|
213 |
+
if encoder_hidden_states is not None:
|
214 |
+
if not hasattr(self, "q_attn"):
|
215 |
+
raise ValueError(
|
216 |
+
"If class is used as cross attention, the weights `q_attn` have to be defined. "
|
217 |
+
"Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
|
218 |
+
)
|
219 |
+
|
220 |
+
query = self.q_attn(hidden_states)
|
221 |
+
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
|
222 |
+
attention_mask = encoder_attention_mask
|
223 |
+
else:
|
224 |
+
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
|
225 |
+
|
226 |
+
query = self._split_heads(query, self.num_heads, self.head_dim)
|
227 |
+
key = self._split_heads(key, self.num_heads, self.head_dim)
|
228 |
+
value = self._split_heads(value, self.num_heads, self.head_dim)
|
229 |
+
|
230 |
+
if layer_past is not None:
|
231 |
+
past_key, past_value = layer_past
|
232 |
+
key = torch.cat((past_key, key), dim=-2)
|
233 |
+
value = torch.cat((past_value, value), dim=-2)
|
234 |
+
|
235 |
+
if use_cache is True:
|
236 |
+
present = (key, value)
|
237 |
+
else:
|
238 |
+
present = None
|
239 |
+
|
240 |
+
if self.attention_mode=="tranception":
|
241 |
+
# We do not do anything on the first self.num_heads_per_kernel_size heads (kernel =1)
|
242 |
+
query_list=[query[:,:self.num_heads_per_kernel_size,:,:]]
|
243 |
+
key_list=[key[:,:self.num_heads_per_kernel_size,:,:]]
|
244 |
+
value_list=[value[:,:self.num_heads_per_kernel_size,:,:]]
|
245 |
+
for kernel_idx in range(3):
|
246 |
+
query_list.append(self.query_depthwiseconv[str(kernel_idx)](query[:,(kernel_idx+1)*self.num_heads_per_kernel_size:(kernel_idx+2)*self.num_heads_per_kernel_size,:,:]))
|
247 |
+
key_list.append(self.key_depthwiseconv[str(kernel_idx)](key[:,(kernel_idx+1)*self.num_heads_per_kernel_size:(kernel_idx+2)*self.num_heads_per_kernel_size,:,:]))
|
248 |
+
value_list.append(self.value_depthwiseconv[str(kernel_idx)](value[:,(kernel_idx+1)*self.num_heads_per_kernel_size:(kernel_idx+2)*self.num_heads_per_kernel_size,:,:]))
|
249 |
+
query=torch.cat(query_list, dim=1)
|
250 |
+
key=torch.cat(key_list, dim=1)
|
251 |
+
value=torch.cat(value_list, dim=1)
|
252 |
+
|
253 |
+
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask, alibi_bias=alibi_bias)
|
254 |
+
|
255 |
+
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
|
256 |
+
attn_output = self.c_proj(attn_output)
|
257 |
+
attn_output = self.resid_dropout(attn_output)
|
258 |
+
|
259 |
+
outputs = (attn_output, present)
|
260 |
+
if output_attentions:
|
261 |
+
outputs += (attn_weights,)
|
262 |
+
|
263 |
+
return outputs # a, present, (attentions)
|
264 |
+
|
265 |
+
class TranceptionBlockMLP(nn.Module):
|
266 |
+
def __init__(self, intermediate_size, config):
|
267 |
+
super().__init__()
|
268 |
+
embed_dim = config.hidden_size
|
269 |
+
self.c_fc = Conv1D(intermediate_size, embed_dim)
|
270 |
+
self.c_proj = Conv1D(embed_dim, intermediate_size)
|
271 |
+
self.act = tranception_ACT2FN[config.activation_function]
|
272 |
+
self.dropout = nn.Dropout(config.resid_pdrop)
|
273 |
+
|
274 |
+
def forward(self, hidden_states):
|
275 |
+
hidden_states = self.c_fc(hidden_states)
|
276 |
+
hidden_states = self.act(hidden_states)
|
277 |
+
hidden_states = self.c_proj(hidden_states)
|
278 |
+
hidden_states = self.dropout(hidden_states)
|
279 |
+
return hidden_states
|
280 |
+
|
281 |
+
class TranceptionBlock(nn.Module):
|
282 |
+
def __init__(self, config, SDWC_kernel_size=None):
|
283 |
+
super().__init__()
|
284 |
+
hidden_size = config.hidden_size
|
285 |
+
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
|
286 |
+
|
287 |
+
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
288 |
+
self.attn = TranceptionBlockAttention(config, SDWC_kernel_size=SDWC_kernel_size)
|
289 |
+
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
290 |
+
|
291 |
+
if config.add_cross_attention:
|
292 |
+
self.crossattention = TranceptionBlockAttention(config, is_cross_attention=True, SDWC_kernel_size=SDWC_kernel_size)
|
293 |
+
self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
294 |
+
|
295 |
+
self.mlp = TranceptionBlockMLP(inner_dim, config)
|
296 |
+
|
297 |
+
def forward(
|
298 |
+
self,
|
299 |
+
hidden_states,
|
300 |
+
layer_past=None,
|
301 |
+
attention_mask=None,
|
302 |
+
head_mask=None,
|
303 |
+
encoder_hidden_states=None,
|
304 |
+
encoder_attention_mask=None,
|
305 |
+
use_cache=False,
|
306 |
+
output_attentions=False,
|
307 |
+
alibi_bias=None,
|
308 |
+
):
|
309 |
+
residual = hidden_states
|
310 |
+
hidden_states = self.ln_1(hidden_states)
|
311 |
+
attn_outputs = self.attn(
|
312 |
+
hidden_states,
|
313 |
+
layer_past=layer_past,
|
314 |
+
attention_mask=attention_mask,
|
315 |
+
head_mask=head_mask,
|
316 |
+
use_cache=use_cache,
|
317 |
+
output_attentions=output_attentions,
|
318 |
+
alibi_bias=alibi_bias,
|
319 |
+
)
|
320 |
+
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
|
321 |
+
outputs = attn_outputs[1:]
|
322 |
+
# residual connection
|
323 |
+
hidden_states = attn_output + residual
|
324 |
+
|
325 |
+
if encoder_hidden_states is not None:
|
326 |
+
# add one self-attention block for cross-attention
|
327 |
+
if not hasattr(self, "crossattention"):
|
328 |
+
raise ValueError(
|
329 |
+
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
|
330 |
+
"cross-attention layers by setting `config.add_cross_attention=True`"
|
331 |
+
)
|
332 |
+
residual = hidden_states
|
333 |
+
hidden_states = self.ln_cross_attn(hidden_states)
|
334 |
+
cross_attn_outputs = self.crossattention(
|
335 |
+
hidden_states,
|
336 |
+
attention_mask=attention_mask,
|
337 |
+
head_mask=head_mask,
|
338 |
+
encoder_hidden_states=encoder_hidden_states,
|
339 |
+
encoder_attention_mask=encoder_attention_mask,
|
340 |
+
output_attentions=output_attentions,
|
341 |
+
)
|
342 |
+
attn_output = cross_attn_outputs[0]
|
343 |
+
# residual connection
|
344 |
+
hidden_states = residual + attn_output
|
345 |
+
outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
|
346 |
+
|
347 |
+
residual = hidden_states
|
348 |
+
hidden_states = self.ln_2(hidden_states)
|
349 |
+
|
350 |
+
feed_forward_hidden_states = self.mlp(hidden_states)
|
351 |
+
|
352 |
+
# residual connection
|
353 |
+
hidden_states = residual + feed_forward_hidden_states
|
354 |
+
|
355 |
+
if use_cache:
|
356 |
+
outputs = (hidden_states,) + outputs
|
357 |
+
else:
|
358 |
+
outputs = (hidden_states,) + outputs[1:]
|
359 |
+
|
360 |
+
return outputs # hidden_states, present, (attentions, cross_attentions)
|
361 |
+
|
362 |
+
class TranceptionModel(GPT2PreTrainedModel):
|
363 |
+
_keys_to_ignore_on_load_missing = ["attn.masked_bias"]
|
364 |
+
def __init__(self, config):
|
365 |
+
super().__init__(config)
|
366 |
+
|
367 |
+
self.embed_dim = config.hidden_size
|
368 |
+
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
|
369 |
+
self.position_embedding = config.position_embedding if hasattr(config, "position_embedding") else "learned"
|
370 |
+
if self.position_embedding=="learned":
|
371 |
+
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
372 |
+
self.alibi = None
|
373 |
+
elif self.position_embedding=="grouped_alibi":
|
374 |
+
maxpos = config.n_positions
|
375 |
+
attn_heads = config.n_head
|
376 |
+
self.slopes = torch.Tensor(get_slopes(attn_heads, mode=self.position_embedding))
|
377 |
+
#The softmax operation is invariant to translation, and bias functions used are always linear.
|
378 |
+
alibi = self.slopes.unsqueeze(1).unsqueeze(1) * torch.arange(maxpos).unsqueeze(0).unsqueeze(0).expand(attn_heads, -1, -1)
|
379 |
+
alibi = alibi.view(attn_heads, 1, maxpos)
|
380 |
+
self.register_buffer('alibi',alibi)
|
381 |
+
|
382 |
+
self.drop = nn.Dropout(config.embd_pdrop)
|
383 |
+
self.h = nn.ModuleList([TranceptionBlock(config) for _ in range(config.num_hidden_layers)])
|
384 |
+
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
385 |
+
|
386 |
+
self.init_weights()
|
387 |
+
|
388 |
+
# Model parallel
|
389 |
+
self.model_parallel = False
|
390 |
+
self.device_map = None
|
391 |
+
self.gradient_checkpointing = False
|
392 |
+
|
393 |
+
def parallelize(self, device_map=None, num_cores=None):
|
394 |
+
self.device_map = (
|
395 |
+
get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
|
396 |
+
)
|
397 |
+
device_prefix="cuda:"
|
398 |
+
assert_device_map(self.device_map, len(self.h))
|
399 |
+
self.model_parallel = True
|
400 |
+
self.first_device = "cpu" if "cpu" in self.device_map.keys() else device_prefix + str(min(self.device_map.keys()))
|
401 |
+
self.last_device = device_prefix + str(max(self.device_map.keys()))
|
402 |
+
self.wte = self.wte.to(self.first_device)
|
403 |
+
if self.position_embedding=="learned":
|
404 |
+
self.wpe = self.wpe.to(self.first_device)
|
405 |
+
for k, v in self.device_map.items():
|
406 |
+
print("k,v :"+str(k)+","+str(v))
|
407 |
+
for block in v:
|
408 |
+
cuda_device = device_prefix + str(k)
|
409 |
+
self.h[block] = self.h[block].to(cuda_device)
|
410 |
+
self.ln_f = self.ln_f.to(self.last_device)
|
411 |
+
|
412 |
+
def deparallelize(self):
|
413 |
+
self.model_parallel = False
|
414 |
+
self.device_map = None
|
415 |
+
self.first_device = "cpu"
|
416 |
+
self.last_device = "cpu"
|
417 |
+
self.wte = self.wte.to("cpu")
|
418 |
+
if self.position_embedding=="learned":
|
419 |
+
self.wpe = self.wpe.to("cpu")
|
420 |
+
for index in range(len(self.h)):
|
421 |
+
self.h[index] = self.h[index].to("cpu")
|
422 |
+
self.ln_f = self.ln_f.to("cpu")
|
423 |
+
torch.cuda.empty_cache()
|
424 |
+
|
425 |
+
def get_input_embeddings(self):
|
426 |
+
return self.wte
|
427 |
+
|
428 |
+
def set_input_embeddings(self, new_embeddings):
|
429 |
+
self.wte = new_embeddings
|
430 |
+
|
431 |
+
def _prune_heads(self, heads_to_prune):
|
432 |
+
"""
|
433 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
434 |
+
"""
|
435 |
+
for layer, heads in heads_to_prune.items():
|
436 |
+
self.h[layer].attn.prune_heads(heads)
|
437 |
+
|
438 |
+
def forward(
|
439 |
+
self,
|
440 |
+
input_ids=None,
|
441 |
+
past_key_values=None,
|
442 |
+
attention_mask=None,
|
443 |
+
token_type_ids=None,
|
444 |
+
position_ids=None,
|
445 |
+
head_mask=None,
|
446 |
+
inputs_embeds=None,
|
447 |
+
encoder_hidden_states=None,
|
448 |
+
encoder_attention_mask=None,
|
449 |
+
use_cache=None,
|
450 |
+
output_attentions=None,
|
451 |
+
output_hidden_states=None,
|
452 |
+
return_dict=None,
|
453 |
+
):
|
454 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
455 |
+
output_hidden_states = (
|
456 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
457 |
+
)
|
458 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
459 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
460 |
+
|
461 |
+
if input_ids is not None and inputs_embeds is not None:
|
462 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
463 |
+
elif input_ids is not None:
|
464 |
+
input_shape = input_ids.size()
|
465 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
466 |
+
batch_size = input_ids.shape[0]
|
467 |
+
elif inputs_embeds is not None:
|
468 |
+
input_shape = inputs_embeds.size()[:-1]
|
469 |
+
batch_size = inputs_embeds.shape[0]
|
470 |
+
else:
|
471 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
472 |
+
|
473 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
474 |
+
|
475 |
+
if token_type_ids is not None:
|
476 |
+
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
477 |
+
if position_ids is not None:
|
478 |
+
position_ids = position_ids.view(-1, input_shape[-1])
|
479 |
+
|
480 |
+
if past_key_values is None:
|
481 |
+
past_length = 0
|
482 |
+
past_key_values = tuple([None] * len(self.h))
|
483 |
+
else:
|
484 |
+
past_length = past_key_values[0][0].size(-2)
|
485 |
+
if position_ids is None:
|
486 |
+
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
487 |
+
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
488 |
+
|
489 |
+
# GPT2Attention mask.
|
490 |
+
if attention_mask is not None:
|
491 |
+
if batch_size <= 0:
|
492 |
+
raise ValueError("batch_size has to be defined and > 0")
|
493 |
+
attention_mask = attention_mask.view(batch_size, -1)
|
494 |
+
# We create a 3D attention mask from a 2D tensor mask.
|
495 |
+
# Sizes are [batch_size, 1, 1, to_seq_length]
|
496 |
+
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
497 |
+
# this attention mask is more simple than the triangular masking of causal attention
|
498 |
+
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
499 |
+
attention_mask = attention_mask[:, None, None, :]
|
500 |
+
|
501 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
502 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
503 |
+
# positions we want to attend and -10000.0 for masked positions.
|
504 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
505 |
+
# effectively the same as removing these entirely.
|
506 |
+
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
507 |
+
attention_mask = (1.0 - attention_mask) * -10000.0
|
508 |
+
|
509 |
+
# If a 2D ou 3D attention mask is provided for the cross-attention
|
510 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
511 |
+
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
512 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
513 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
514 |
+
if encoder_attention_mask is None:
|
515 |
+
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
516 |
+
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
517 |
+
else:
|
518 |
+
encoder_attention_mask = None
|
519 |
+
|
520 |
+
# Prepare head mask if needed
|
521 |
+
# 1.0 in head_mask indicate we keep the head
|
522 |
+
# attention_probs has shape bsz x n_heads x N x N
|
523 |
+
# head_mask has shape n_layer x batch x n_heads x N x N
|
524 |
+
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
525 |
+
|
526 |
+
if inputs_embeds is None:
|
527 |
+
inputs_embeds = self.wte(input_ids)
|
528 |
+
if self.position_embedding=="learned":
|
529 |
+
position_embeds = self.wpe(position_ids)
|
530 |
+
hidden_states = inputs_embeds + position_embeds
|
531 |
+
else:
|
532 |
+
hidden_states = inputs_embeds
|
533 |
+
|
534 |
+
if token_type_ids is not None:
|
535 |
+
token_type_embeds = self.wte(token_type_ids)
|
536 |
+
hidden_states = hidden_states + token_type_embeds
|
537 |
+
|
538 |
+
hidden_states = self.drop(hidden_states)
|
539 |
+
|
540 |
+
output_shape = input_shape + (hidden_states.size(-1),)
|
541 |
+
|
542 |
+
presents = () if use_cache else None
|
543 |
+
all_self_attentions = () if output_attentions else None
|
544 |
+
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
545 |
+
all_hidden_states = () if output_hidden_states else None
|
546 |
+
|
547 |
+
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
548 |
+
# Model parallel
|
549 |
+
if self.model_parallel:
|
550 |
+
torch.cuda.set_device(hidden_states.device)
|
551 |
+
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
552 |
+
if layer_past is not None:
|
553 |
+
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
|
554 |
+
# Ensure that attention_mask is always on the same device as hidden_states
|
555 |
+
if attention_mask is not None:
|
556 |
+
attention_mask = attention_mask.to(hidden_states.device)
|
557 |
+
if isinstance(head_mask, torch.Tensor):
|
558 |
+
head_mask = head_mask.to(hidden_states.device)
|
559 |
+
if output_hidden_states:
|
560 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
561 |
+
|
562 |
+
if self.gradient_checkpointing and self.training:
|
563 |
+
if use_cache:
|
564 |
+
logger.warning(
|
565 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
566 |
+
)
|
567 |
+
use_cache = False
|
568 |
+
|
569 |
+
def create_custom_forward(module):
|
570 |
+
def custom_forward(*inputs):
|
571 |
+
# None for past_key_value
|
572 |
+
return module(*inputs, use_cache, output_attentions)
|
573 |
+
|
574 |
+
return custom_forward
|
575 |
+
|
576 |
+
outputs = torch.utils.checkpoint.checkpoint(
|
577 |
+
create_custom_forward(block),
|
578 |
+
hidden_states,
|
579 |
+
None,
|
580 |
+
attention_mask,
|
581 |
+
head_mask[i],
|
582 |
+
encoder_hidden_states,
|
583 |
+
encoder_attention_mask,
|
584 |
+
)
|
585 |
+
else:
|
586 |
+
outputs = block(
|
587 |
+
hidden_states,
|
588 |
+
layer_past=layer_past,
|
589 |
+
attention_mask=attention_mask,
|
590 |
+
head_mask=head_mask[i],
|
591 |
+
encoder_hidden_states=encoder_hidden_states,
|
592 |
+
encoder_attention_mask=encoder_attention_mask,
|
593 |
+
use_cache=use_cache,
|
594 |
+
output_attentions=output_attentions,
|
595 |
+
alibi_bias=self.alibi if hasattr(self, "alibi") else None
|
596 |
+
)
|
597 |
+
|
598 |
+
hidden_states = outputs[0]
|
599 |
+
|
600 |
+
if use_cache is True:
|
601 |
+
presents = presents + (outputs[1],)
|
602 |
+
|
603 |
+
if output_attentions:
|
604 |
+
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
605 |
+
if self.config.add_cross_attention:
|
606 |
+
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
607 |
+
|
608 |
+
if self.model_parallel:
|
609 |
+
device_prefix="cuda:"
|
610 |
+
for k, v in self.device_map.items():
|
611 |
+
if i == v[-1] and device_prefix + str(k) != self.last_device:
|
612 |
+
hidden_states = hidden_states.to(device_prefix + str(k + 1))
|
613 |
+
|
614 |
+
hidden_states = self.ln_f(hidden_states)
|
615 |
+
|
616 |
+
hidden_states = hidden_states.view(*output_shape)
|
617 |
+
# Add last hidden state
|
618 |
+
if output_hidden_states:
|
619 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
620 |
+
|
621 |
+
if not return_dict:
|
622 |
+
return tuple(
|
623 |
+
v
|
624 |
+
for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions, moe_loss]
|
625 |
+
if v is not None
|
626 |
+
)
|
627 |
+
|
628 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
629 |
+
last_hidden_state=hidden_states,
|
630 |
+
past_key_values=presents,
|
631 |
+
hidden_states=all_hidden_states,
|
632 |
+
attentions=all_self_attentions,
|
633 |
+
cross_attentions=all_cross_attentions,
|
634 |
+
)
|
635 |
+
|
636 |
+
class TranceptionLMHeadModel(GPT2PreTrainedModel):
|
637 |
+
_keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]
|
638 |
+
def __init__(self, config):
|
639 |
+
super().__init__(config)
|
640 |
+
self.transformer = TranceptionModel(config)
|
641 |
+
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
642 |
+
self.config = config
|
643 |
+
|
644 |
+
self.init_weights()
|
645 |
+
|
646 |
+
self.default_model_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
647 |
+
# Model parallel
|
648 |
+
self.model_parallel = False
|
649 |
+
self.device_map = None
|
650 |
+
|
651 |
+
self.retrieval_aggregation_mode = config.retrieval_aggregation_mode if hasattr(config, "retrieval_aggregation_mode") else None
|
652 |
+
if self.retrieval_aggregation_mode is not None:
|
653 |
+
print("Model leverages both autoregressive and retrieval inference")
|
654 |
+
self.MSA_filename = config.MSA_filename if hasattr(config, "MSA_filename") else False
|
655 |
+
self.MSA_folder = '/'.join(self.MSA_filename.split(os.sep)[:-1])
|
656 |
+
self.MSA_name = self.MSA_filename.split(os.sep)[-1]
|
657 |
+
self.retrieval_inference_weight_LR = config.retrieval_inference_weight if hasattr(config, "retrieval_inference_weight") else 0.6
|
658 |
+
self.retrieval_inference_weight_RL = config.retrieval_inference_weight if hasattr(config, "retrieval_inference_weight") else 0.6
|
659 |
+
self.MSA_start=config.MSA_start
|
660 |
+
self.MSA_end=config.MSA_end
|
661 |
+
self.full_protein_length = config.full_protein_length if hasattr(config, "full_protein_length") else -1
|
662 |
+
|
663 |
+
self.MSA_log_prior = torch.log(torch.tensor(
|
664 |
+
msa_utils.get_msa_prior(
|
665 |
+
MSA_data_file=self.MSA_filename,
|
666 |
+
MSA_weight_file_name=config.MSA_weight_file_name,
|
667 |
+
retrieval_aggregation_mode=self.retrieval_aggregation_mode,
|
668 |
+
MSA_start=self.MSA_start,
|
669 |
+
MSA_end=self.MSA_end,
|
670 |
+
len_target_seq=self.full_protein_length,
|
671 |
+
vocab=config.tokenizer.get_vocab(),
|
672 |
+
verbose=False
|
673 |
+
)
|
674 |
+
).float().to(self.default_model_device))
|
675 |
+
else:
|
676 |
+
print("Model only uses autoregressive inference")
|
677 |
+
|
678 |
+
def parallelize(self, device_map=None, num_cores=None, num_pipelines=1):
|
679 |
+
self.num_pipelines=num_pipelines
|
680 |
+
self.device_map = (
|
681 |
+
get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
|
682 |
+
if device_map is None
|
683 |
+
else device_map
|
684 |
+
)
|
685 |
+
assert_device_map(self.device_map, len(self.transformer.h))
|
686 |
+
self.transformer.parallelize(self.device_map, num_cores=num_cores)
|
687 |
+
self.lm_head = self.lm_head.to(self.transformer.first_device)
|
688 |
+
self.model_parallel = True
|
689 |
+
|
690 |
+
def deparallelize(self):
|
691 |
+
self.transformer.deparallelize()
|
692 |
+
self.transformer = self.transformer.to("cpu")
|
693 |
+
self.lm_head = self.lm_head.to("cpu")
|
694 |
+
self.model_parallel = False
|
695 |
+
torch.cuda.empty_cache()
|
696 |
+
|
697 |
+
def get_output_embeddings(self):
|
698 |
+
return self.lm_head
|
699 |
+
|
700 |
+
def set_output_embeddings(self, new_embeddings):
|
701 |
+
self.lm_head = new_embeddings
|
702 |
+
|
703 |
+
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
704 |
+
token_type_ids = kwargs.get("token_type_ids", None)
|
705 |
+
# only last token for inputs_ids if past is defined in kwargs
|
706 |
+
if past:
|
707 |
+
input_ids = input_ids[:, -1].unsqueeze(-1)
|
708 |
+
if token_type_ids is not None:
|
709 |
+
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
|
710 |
+
|
711 |
+
attention_mask = kwargs.get("attention_mask", None)
|
712 |
+
position_ids = kwargs.get("position_ids", None)
|
713 |
+
|
714 |
+
if attention_mask is not None and position_ids is None:
|
715 |
+
# create position_ids on the fly for batch generation
|
716 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
717 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
718 |
+
if past:
|
719 |
+
position_ids = position_ids[:, -1].unsqueeze(-1)
|
720 |
+
else:
|
721 |
+
position_ids = None
|
722 |
+
|
723 |
+
return {
|
724 |
+
"input_ids": input_ids,
|
725 |
+
"past_key_values": past,
|
726 |
+
"use_cache": kwargs.get("use_cache"),
|
727 |
+
"position_ids": position_ids,
|
728 |
+
"attention_mask": attention_mask,
|
729 |
+
"token_type_ids": token_type_ids,
|
730 |
+
"flip": kwargs.get("flip", None),
|
731 |
+
}
|
732 |
+
|
733 |
+
def forward(
|
734 |
+
self,
|
735 |
+
input_ids=None,
|
736 |
+
past_key_values=None,
|
737 |
+
attention_mask=None,
|
738 |
+
token_type_ids=None,
|
739 |
+
position_ids=None,
|
740 |
+
head_mask=None,
|
741 |
+
inputs_embeds=None,
|
742 |
+
encoder_hidden_states=None,
|
743 |
+
encoder_attention_mask=None,
|
744 |
+
labels=None,
|
745 |
+
use_cache=None,
|
746 |
+
output_attentions=None,
|
747 |
+
output_hidden_states=None,
|
748 |
+
return_dict=None,
|
749 |
+
flip=None,
|
750 |
+
start_slice=None,
|
751 |
+
end_slice=None,
|
752 |
+
mutated_sequence=None,
|
753 |
+
):
|
754 |
+
r"""
|
755 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
756 |
+
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
757 |
+
``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to
|
758 |
+
``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``
|
759 |
+
"""
|
760 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
761 |
+
|
762 |
+
transformer_outputs = self.transformer(
|
763 |
+
input_ids,
|
764 |
+
past_key_values=past_key_values,
|
765 |
+
attention_mask=attention_mask,
|
766 |
+
token_type_ids=token_type_ids,
|
767 |
+
position_ids=position_ids,
|
768 |
+
head_mask=head_mask,
|
769 |
+
inputs_embeds=inputs_embeds,
|
770 |
+
encoder_hidden_states=encoder_hidden_states,
|
771 |
+
encoder_attention_mask=encoder_attention_mask,
|
772 |
+
use_cache=use_cache,
|
773 |
+
output_attentions=output_attentions,
|
774 |
+
output_hidden_states=output_hidden_states,
|
775 |
+
return_dict=return_dict
|
776 |
+
)
|
777 |
+
hidden_states = transformer_outputs[0]
|
778 |
+
|
779 |
+
# Set device for model parallelism
|
780 |
+
if self.model_parallel:
|
781 |
+
torch.cuda.set_device(self.transformer.first_device)
|
782 |
+
hidden_states = hidden_states.to(self.lm_head.weight.device)
|
783 |
+
self.MSA_log_prior = self.MSA_log_prior.to(self.lm_head.weight.device)
|
784 |
+
|
785 |
+
lm_logits = self.lm_head(hidden_states)
|
786 |
+
|
787 |
+
loss = None
|
788 |
+
if labels is not None:
|
789 |
+
# Shift so that tokens < n predict n
|
790 |
+
shift_logits = lm_logits[..., :-1, :].contiguous()
|
791 |
+
shift_labels = labels[..., 1:].contiguous()
|
792 |
+
|
793 |
+
if self.retrieval_aggregation_mode is not None:
|
794 |
+
batch_size = input_ids.size(0)
|
795 |
+
|
796 |
+
if self.retrieval_aggregation_mode=="aggregate_indel":
|
797 |
+
assert batch_size==1, "Aggregate indel is only supported for batch size of 1"
|
798 |
+
truncated_sequence_text = mutated_sequence[0][start_slice[0]:end_slice[0]]
|
799 |
+
if len(truncated_sequence_text)!=shift_logits.shape[1]-1: # shift_logits only has one extra token compared to truncated_sequence_text (the BOS token)
|
800 |
+
print("Tokenization error -- seq length: {} and shift_logits length - 1 : {}".format(len(mutated_sequence),shift_logits.shape[1]-1))
|
801 |
+
MSA_log_prior, MSA_start, MSA_end = msa_utils.update_retrieved_MSA_log_prior_indel(self, self.MSA_log_prior, self.MSA_start, self.MSA_end, mutated_sequence[0])
|
802 |
+
|
803 |
+
elif self.retrieval_aggregation_mode=="aggregate_substitution":
|
804 |
+
MSA_log_prior=self.MSA_log_prior
|
805 |
+
MSA_start=self.MSA_start
|
806 |
+
MSA_end=self.MSA_end
|
807 |
+
|
808 |
+
shift_log_probas = torch.log_softmax(shift_logits,dim=-1)
|
809 |
+
fused_shift_log_probas = shift_log_probas.clone()
|
810 |
+
if flip is None:
|
811 |
+
flip = torch.zeros(batch_size).to(fused_shift_log_probas.device)
|
812 |
+
flip = flip > 0
|
813 |
+
|
814 |
+
for seq_index in range(batch_size):
|
815 |
+
min_prior_slice = max(start_slice[seq_index], MSA_start)
|
816 |
+
max_prior_slice = min(end_slice[seq_index], MSA_end)
|
817 |
+
|
818 |
+
if max_prior_slice <= min_prior_slice:
|
819 |
+
print("Non overlapping region detected: min_prior_slice {} and max_prior_slice {}".format(min_prior_slice,max_prior_slice))
|
820 |
+
continue
|
821 |
+
|
822 |
+
slice_prior = MSA_log_prior[min_prior_slice:max_prior_slice,:].to(fused_shift_log_probas.device)
|
823 |
+
if flip[seq_index]:
|
824 |
+
slice_prior = torch.flip(slice_prior,dims=(0,))
|
825 |
+
min_logits_slice = max(0,end_slice[seq_index]-MSA_end)
|
826 |
+
max_logits_slice = min_logits_slice + (max_prior_slice-min_prior_slice)
|
827 |
+
fused_shift_log_probas[seq_index,min_logits_slice:max_logits_slice,:] = (1-self.retrieval_inference_weight_RL)*shift_log_probas[seq_index,min_logits_slice:max_logits_slice,:] + self.retrieval_inference_weight_RL*slice_prior
|
828 |
+
else:
|
829 |
+
min_logits_slice = max(0, MSA_start-start_slice[seq_index])
|
830 |
+
max_logits_slice = min_logits_slice + (max_prior_slice-min_prior_slice)
|
831 |
+
fused_shift_log_probas[seq_index,min_logits_slice:max_logits_slice,:] = (1-self.retrieval_inference_weight_LR)*shift_log_probas[seq_index,min_logits_slice:max_logits_slice,:] + self.retrieval_inference_weight_LR*slice_prior
|
832 |
+
|
833 |
+
if self.retrieval_aggregation_mode=="aggregate_indel":
|
834 |
+
try:
|
835 |
+
# If a given residue colume is an added zero-column, then we overwrite prior fusion and only predict based on the autoregressive transformer inference mode.
|
836 |
+
inserted_retrieval_positions = [True if slice_prior[i].sum()==0 else False for i in range(len(slice_prior))]+[True] #Last True is for the end of sentence token
|
837 |
+
fused_shift_log_probas[:,inserted_retrieval_positions,:]=shift_log_probas[:,inserted_retrieval_positions,:]
|
838 |
+
except:
|
839 |
+
print("Error when adding zero column(s) to account for insertion mutations.")
|
840 |
+
|
841 |
+
loss_fct = NLLLoss(reduction='none')
|
842 |
+
loss = loss_fct(input=fused_shift_log_probas.view(-1, fused_shift_log_probas.size(-1)), target=shift_labels.view(-1)).view(fused_shift_log_probas.shape[0],fused_shift_log_probas.shape[1])
|
843 |
+
mask = attention_mask[..., 1:].float()
|
844 |
+
mask[mask==0]=float('nan')
|
845 |
+
loss *= mask
|
846 |
+
loss = nanmean(loss, dim=1).mean()
|
847 |
+
else:
|
848 |
+
loss_fct = CrossEntropyLoss()
|
849 |
+
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
850 |
+
fused_shift_log_probas = None
|
851 |
+
|
852 |
+
if not return_dict:
|
853 |
+
output = (lm_logits,) + transformer_outputs[1:]
|
854 |
+
return ((loss,) + output) if loss is not None else output
|
855 |
+
|
856 |
+
return TranceptionCausalLMOutputWithCrossAttentions(
|
857 |
+
loss=loss,
|
858 |
+
logits=lm_logits,
|
859 |
+
past_key_values=transformer_outputs.past_key_values,
|
860 |
+
hidden_states=transformer_outputs.hidden_states,
|
861 |
+
attentions=transformer_outputs.attentions,
|
862 |
+
cross_attentions=transformer_outputs.cross_attentions,
|
863 |
+
fused_shift_log_probas=fused_shift_log_probas
|
864 |
+
)
|
865 |
+
|
866 |
+
|
867 |
+
@staticmethod
|
868 |
+
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
|
869 |
+
"""
|
870 |
+
This function is used to re-order the :obj:`past_key_values` cache if
|
871 |
+
:meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
|
872 |
+
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
|
873 |
+
"""
|
874 |
+
return tuple(
|
875 |
+
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
876 |
+
for layer_past in past
|
877 |
+
)
|
878 |
+
|
879 |
+
def score_mutants(self, DMS_data, target_seq=None, scoring_mirror=True, batch_size_inference=10, num_workers=10, indel_mode=False):
|
880 |
+
"""
|
881 |
+
Method to score mutants in an input DMS file.
|
882 |
+
DMS_data: (dataframe) Dataframe containing the list of mutated sequences for scoring.
|
883 |
+
target_seq: (string) Full reference sequence (wild type) that is mutated in the DMS assay. If not None, returned scores are delta log likelihood wrt that sequence.
|
884 |
+
scoring_mirror: (bool) Whether to score mutated sequences from both directions (Left->Right and Right->Left).
|
885 |
+
batch_size_inference: (int) Batch size for scoring.
|
886 |
+
num_workers: (int) Number of workers to be used in the data loader.
|
887 |
+
indel_mode: (bool) Flag to be used when scoring insertions and deletions. Otherwise assumes substitutions.
|
888 |
+
"""
|
889 |
+
df = DMS_data.copy()
|
890 |
+
if ('mutated_sequence' not in df) and (not indel_mode): df['mutated_sequence'] = df['mutant'].apply(lambda x: scoring_utils.get_mutated_sequence(target_seq, x))
|
891 |
+
assert ('mutated_sequence' in df), "DMS file to score does not have mutated_sequence column"
|
892 |
+
#if 'mutant' not in df: df['mutant'] = df['mutated_sequence'] #if mutant not in DMS file we default to mutated_sequence
|
893 |
+
if 'DMS_score' in df: del df['DMS_score']
|
894 |
+
if 'DMS_score_bin' in df: del df['DMS_score_bin']
|
895 |
+
if target_seq is not None:
|
896 |
+
df_left_to_right_slices = scoring_utils.get_sequence_slices(df, target_seq=target_seq, model_context_len = self.config.n_ctx - 2, indel_mode=indel_mode, scoring_window=self.config.scoring_window)
|
897 |
+
else:
|
898 |
+
df_left_to_right_slices = scoring_utils.get_sequence_slices(df, target_seq=list(df['mutated_sequence'])[0], model_context_len = self.config.n_ctx - 2, indel_mode=indel_mode, scoring_window='sliding')
|
899 |
+
print("Scoring sequences from left to right")
|
900 |
+
scores_L_to_R = scoring_utils.get_tranception_scores_mutated_sequences(model=self, mutated_sequence_df=df_left_to_right_slices, batch_size_inference=batch_size_inference, score_var_name='avg_score_L_to_R', target_seq=target_seq, num_workers=num_workers, indel_mode=indel_mode)
|
901 |
+
if scoring_mirror:
|
902 |
+
print("Scoring sequences from right to left")
|
903 |
+
df_right_to_left_slices = df_left_to_right_slices.copy()
|
904 |
+
df_right_to_left_slices['sliced_mutated_sequence'] = df_right_to_left_slices['sliced_mutated_sequence'].apply(lambda x: x[::-1])
|
905 |
+
scores_R_to_L = scoring_utils.get_tranception_scores_mutated_sequences(model=self, mutated_sequence_df=df_right_to_left_slices, batch_size_inference=batch_size_inference, score_var_name='avg_score_R_to_L', target_seq=target_seq, num_workers=num_workers, reverse=True, indel_mode=indel_mode)
|
906 |
+
all_scores = pd.merge(scores_L_to_R, scores_R_to_L, on='mutated_sequence', how='left', suffixes=('','_R_to_L'))
|
907 |
+
all_scores['avg_score'] = (all_scores['avg_score_L_to_R'] + all_scores['avg_score_R_to_L']) / 2.0
|
908 |
+
else:
|
909 |
+
all_scores = scores_L_to_R
|
910 |
+
all_scores['avg_score'] = all_scores['avg_score_L_to_R']
|
911 |
+
#By design "get_tranception_scores_mutated_sequences" drops the WT from the output. We add it back if that was one of the sequences to score in the DMS (score=0 by definition)
|
912 |
+
if target_seq in DMS_data.mutated_sequence.values:
|
913 |
+
print("LEMON")
|
914 |
+
if scoring_mirror:
|
915 |
+
wt_row = pd.DataFrame([[target_seq,0,0,0]], columns=['mutated_sequence','avg_score_L_to_R','avg_score_R_to_L','avg_score'])
|
916 |
+
else:
|
917 |
+
wt_row = pd.DataFrame([[target_seq,0,0]], columns=['mutated_sequence','avg_score_L_to_R','avg_score'])
|
918 |
+
all_scores = pd.concat([all_scores,wt_row], ignore_index=True)
|
919 |
+
return all_scores
|
920 |
+
|
921 |
+
def encode_batch(self, protein_sequence, sequence_name="sliced_mutated_sequence"):
|
922 |
+
"""
|
923 |
+
Method to process an input AA sequence batch (protein_sequence) and return a tokenized sequence (via the tokenizer associated to the model).
|
924 |
+
"""
|
925 |
+
protein_sequence[sequence_name] = scoring_utils.sequence_replace(sequences=protein_sequence[sequence_name], char_to_replace='X', char_replacements='ACDEFGHIKLMNPQRSTVWY')
|
926 |
+
protein_sequence[sequence_name] = scoring_utils.sequence_replace(sequences=protein_sequence[sequence_name], char_to_replace='B', char_replacements='DN')
|
927 |
+
protein_sequence[sequence_name] = scoring_utils.sequence_replace(sequences=protein_sequence[sequence_name], char_to_replace='J', char_replacements='IL')
|
928 |
+
protein_sequence[sequence_name] = scoring_utils.sequence_replace(sequences=protein_sequence[sequence_name], char_to_replace='Z', char_replacements='EQ')
|
929 |
+
return self.config.tokenizer(list(protein_sequence[sequence_name]), add_special_tokens=True, truncation=True, padding=True, max_length=self.config.n_ctx)
|
930 |
+
|
tranception/outputs.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Optional, Tuple
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from transformers.file_utils import ModelOutput
|
7 |
+
|
8 |
+
@dataclass
|
9 |
+
class TranceptionCausalLMOutputWithCrossAttentions(ModelOutput):
|
10 |
+
"""
|
11 |
+
Class for Tranception causal language model (or autoregressive) outputs.
|
12 |
+
Args:
|
13 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
14 |
+
Language modeling loss (for next-token prediction).
|
15 |
+
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
16 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
17 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
18 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
19 |
+
shape `(batch_size, sequence_length, hidden_size)`.
|
20 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
21 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
22 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
23 |
+
sequence_length)`.
|
24 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
25 |
+
heads.
|
26 |
+
cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
27 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
28 |
+
sequence_length)`.
|
29 |
+
Cross attentions weights after the attention softmax, used to compute the weighted average in the
|
30 |
+
cross-attention heads.
|
31 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
32 |
+
Tuple of `torch.FloatTensor` tuples of length `config.n_layers`, with each tuple containing the cached key,
|
33 |
+
value states of the self-attention and the cross-attention layers if model is used in encoder-decoder
|
34 |
+
setting. Only relevant if `config.is_decoder = True`.
|
35 |
+
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
|
36 |
+
`past_key_values` input) to speed up sequential decoding.
|
37 |
+
fused_shift_log_probas (`torch.FloatTensor` of shape (batch_size, sequence_length, config.vocab_size), *optional*, returned when config.retrieval_aggregation_mode is not None.
|
38 |
+
log_probas for each residue position after aggregating autoregressive logits and retrieval logits.
|
39 |
+
|
40 |
+
"""
|
41 |
+
|
42 |
+
loss: Optional[torch.FloatTensor] = None
|
43 |
+
logits: torch.FloatTensor = None
|
44 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
45 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
46 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
47 |
+
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
48 |
+
fused_shift_log_probas: Optional[torch.FloatTensor] = None
|
tranception/utils/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from . import scoring_utils, msa_utils
|
tranception/utils/dms_utils.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
from tranception.utils import scoring_utils
|
4 |
+
|
5 |
+
def DMS_file_cleanup(DMS_filename, target_seq, start_idx=1, end_idx=None, DMS_mutant_column='mutant', DMS_phenotype_name='score', DMS_directionality=1, AA_vocab = "ACDEFGHIKLMNPQRSTVWY"):
|
6 |
+
"""
|
7 |
+
Function to process the raw substitution DMS assay data (eg., removing invalid mutants, aggregate silent mutations).
|
8 |
+
"""
|
9 |
+
DMS_data = pd.read_csv(DMS_filename, low_memory=False)
|
10 |
+
end_idx = start_idx + len(target_seq) - 1 if end_idx is None else end_idx
|
11 |
+
DMS_data['mutant'] = DMS_data[DMS_mutant_column]
|
12 |
+
|
13 |
+
DMS_data=DMS_data[DMS_data['mutant'].notnull()].copy()
|
14 |
+
DMS_data=DMS_data[DMS_data['mutant'].apply(lambda x: all([len(y)>=3 for y in x.split(":")]))].copy() #Mutant triplets should have at least 3 or more characters
|
15 |
+
DMS_data=DMS_data[DMS_data['mutant'].apply(lambda x: all([(y[0] in AA_vocab) and (y[1:-1].isnumeric()) and (y[-1] in AA_vocab) for y in x.split(":")]))].copy()
|
16 |
+
DMS_data=DMS_data[DMS_data['mutant'].apply(lambda x: all([int(y[1:-1])-start_idx >=0 and int(y[1:-1]) <= end_idx for y in x.split(":")]))].copy()
|
17 |
+
DMS_data=DMS_data[DMS_data['mutant'].apply(lambda x: all([y[0]==target_seq[int(y[1:-1])-start_idx] for y in x.split(":")]))].copy()
|
18 |
+
|
19 |
+
DMS_data[DMS_phenotype_name]=pd.to_numeric(DMS_data[DMS_phenotype_name],errors='coerce')
|
20 |
+
DMS_data=DMS_data[np.isfinite(DMS_data[DMS_phenotype_name])]
|
21 |
+
DMS_data.dropna(subset = [DMS_phenotype_name], inplace=True)
|
22 |
+
DMS_data['DMS_score'] = DMS_data[DMS_phenotype_name] * DMS_directionality
|
23 |
+
DMS_data=DMS_data[['mutant','DMS_score']]
|
24 |
+
DMS_data=DMS_data.groupby('mutant').mean().reset_index()
|
25 |
+
|
26 |
+
DMS_data['mutated_sequence'] = DMS_data['mutant'].apply(lambda x: scoring_utils.get_mutated_sequence(target_seq, x))
|
27 |
+
DMS_data=DMS_data[['mutant','mutated_sequence','DMS_score']]
|
28 |
+
|
29 |
+
return DMS_data
|
30 |
+
|
tranception/utils/msa_utils.py
ADDED
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import pandas as pd
|
3 |
+
from collections import defaultdict
|
4 |
+
import random
|
5 |
+
import os
|
6 |
+
import torch
|
7 |
+
from Bio.Align.Applications import ClustalOmegaCommandline
|
8 |
+
|
9 |
+
def filter_msa(msa_data, num_sequences_kept=3):
|
10 |
+
"""
|
11 |
+
Helper function to filter an input MSA msa_data (obtained via process_msa_data) and keep only num_sequences_kept aligned sequences.
|
12 |
+
If the MSA already has fewer sequences than num_sequences_kept, we keep the MSA as is.
|
13 |
+
If filtering, we always keep the first sequence of the MSA (ie. the wild type) by default.
|
14 |
+
Sampling is done without replacement.
|
15 |
+
"""
|
16 |
+
if len(list(msa_data.keys())) <= num_sequences_kept:
|
17 |
+
return msa_data
|
18 |
+
filtered_msa = {}
|
19 |
+
wt_name = next(iter(msa_data))
|
20 |
+
filtered_msa[wt_name] = msa_data[wt_name]
|
21 |
+
del msa_data[wt_name]
|
22 |
+
sequence_names = list(msa_data.keys())
|
23 |
+
sequence_names_sampled = random.sample(sequence_names,k=num_sequences_kept-1)
|
24 |
+
for seq in sequence_names_sampled:
|
25 |
+
filtered_msa[seq] = msa_data[seq]
|
26 |
+
return filtered_msa
|
27 |
+
|
28 |
+
def process_msa_data(MSA_data_file):
|
29 |
+
"""
|
30 |
+
Helper function that takes as input a path to a MSA file (expects a2m format) and returns a dict mapping sequence ID to the corresponding AA sequence.
|
31 |
+
"""
|
32 |
+
msa_data = defaultdict(str)
|
33 |
+
sequence_name = ""
|
34 |
+
with open(MSA_data_file, "r") as msa_file:
|
35 |
+
for i, line in enumerate(msa_file):
|
36 |
+
line = line.rstrip()
|
37 |
+
if line.startswith(">"):
|
38 |
+
sequence_name = line
|
39 |
+
else:
|
40 |
+
msa_data[sequence_name] += line.upper()
|
41 |
+
return msa_data
|
42 |
+
|
43 |
+
def get_one_hot_sequences_dict(msa_data,MSA_start,MSA_end,vocab):
|
44 |
+
vocab_size = len(vocab.keys())
|
45 |
+
num_sequences_msa = len(msa_data.keys())
|
46 |
+
one_hots = np.zeros((num_sequences_msa,MSA_end-MSA_start,vocab_size))
|
47 |
+
for i,seq_name in enumerate(msa_data.keys()):
|
48 |
+
sequence = msa_data[seq_name]
|
49 |
+
for j,letter in enumerate(sequence):
|
50 |
+
if letter in vocab:
|
51 |
+
k = vocab[letter]
|
52 |
+
one_hots[i,j,k] = 1.0
|
53 |
+
return one_hots
|
54 |
+
|
55 |
+
def one_hot(sequence_string,vocab):
|
56 |
+
one_hots = np.zeros((len(sequence_string),len(vocab.keys())))
|
57 |
+
for j,letter in enumerate(sequence_string):
|
58 |
+
if letter in vocab:
|
59 |
+
k = vocab[letter]
|
60 |
+
one_hots[j,k] = 1.0
|
61 |
+
return one_hots.flatten()
|
62 |
+
|
63 |
+
def get_msa_prior(MSA_data_file, MSA_weight_file_name, MSA_start, MSA_end, len_target_seq, vocab, retrieval_aggregation_mode="aggregate_substitution", filter_MSA=True, verbose=False):
|
64 |
+
"""
|
65 |
+
Function to enable retrieval inference mode, via computation of (weighted) pseudocounts of AAs at each position of the retrieved MSA.
|
66 |
+
MSA_data_file: (string) path to MSA file (expects a2m format).
|
67 |
+
MSA_weight_file_name: (string) path to sequence weights in MSA.
|
68 |
+
MSA_start: (int) Sequence position that the MSA starts at (1-indexing).
|
69 |
+
MSA_end: (int) Sequence position that the MSA ends at (1-indexing).
|
70 |
+
len_target_seq: (int) Full length of sequence to be scored.
|
71 |
+
vocab: (dict) Vocabulary of the tokenizer.
|
72 |
+
retrieval_aggregation_mode: (string) Mode for retrieval inference (aggregate_substitution Vs aggregate_indel). If None, places a uniform prior over each token.
|
73 |
+
filter_MSA: (bool) Whether to filter out sequences with very low hamming similarity (< 0.2) to the reference sequence in the MSA (first sequence).
|
74 |
+
verbose: (bool) Whether to print to the console processing details along the way.
|
75 |
+
"""
|
76 |
+
msa_data = process_msa_data(MSA_data_file)
|
77 |
+
vocab_size = len(vocab.keys())
|
78 |
+
if verbose: print("Target seq len is {}, MSA length is {}, start position is {}, end position is {} and vocab size is {}".format(len_target_seq,MSA_end-MSA_start,MSA_start,MSA_end,vocab_size))
|
79 |
+
|
80 |
+
if filter_MSA:
|
81 |
+
if verbose: print("Num sequences in MSA pre filtering: {}".format(len(msa_data.keys())))
|
82 |
+
list_sequence_names = list(msa_data.keys())
|
83 |
+
focus_sequence_name = list(msa_data.keys())[0]
|
84 |
+
ref_sequence_hot = one_hot(msa_data[focus_sequence_name],vocab)
|
85 |
+
for sequence_name in list_sequence_names:
|
86 |
+
seq_hot = one_hot(msa_data[sequence_name],vocab)
|
87 |
+
hamming_similarity_seq_ref = np.dot(ref_sequence_hot,seq_hot) / np.dot(ref_sequence_hot,ref_sequence_hot)
|
88 |
+
if hamming_similarity_seq_ref < 0.2:
|
89 |
+
del msa_data[sequence_name]
|
90 |
+
if verbose: print("Num sequences in MSA post filtering: {}".format(len(msa_data.keys())))
|
91 |
+
|
92 |
+
if MSA_weight_file_name is not None:
|
93 |
+
if verbose: print("Using weights in {} for sequences in MSA.".format(MSA_weight_file_name))
|
94 |
+
assert os.path.exists(MSA_weight_file_name), "Weights file not located on disk."
|
95 |
+
MSA_EVE = MSA_processing(
|
96 |
+
MSA_location=MSA_data_file,
|
97 |
+
use_weights=True,
|
98 |
+
weights_location=MSA_weight_file_name
|
99 |
+
)
|
100 |
+
#We scan through all sequences to see if we have a weight for them as per EVE pre-processing. We drop them otherwise.
|
101 |
+
dropped_sequences=0
|
102 |
+
list_sequence_names = list(msa_data.keys())
|
103 |
+
MSA_weight=[]
|
104 |
+
for sequence_name in list_sequence_names:
|
105 |
+
if sequence_name not in MSA_EVE.seq_name_to_sequence:
|
106 |
+
dropped_sequences +=1
|
107 |
+
del msa_data[sequence_name]
|
108 |
+
else:
|
109 |
+
MSA_weight.append(MSA_EVE.seq_name_to_weight[sequence_name])
|
110 |
+
if verbose: print("Dropped {} sequences from MSA due to absent sequence weights".format(dropped_sequences))
|
111 |
+
else:
|
112 |
+
MSA_weight = [1] * len(list(msa_data.keys()))
|
113 |
+
|
114 |
+
if retrieval_aggregation_mode=="aggregate_substitution" or retrieval_aggregation_mode=="aggregate_indel":
|
115 |
+
one_hots = get_one_hot_sequences_dict(msa_data,MSA_start,MSA_end,vocab)
|
116 |
+
MSA_weight = np.expand_dims(np.array(MSA_weight),axis=(1,2))
|
117 |
+
base_rate = 1e-5
|
118 |
+
base_rates = np.ones_like(one_hots) * base_rate
|
119 |
+
weighted_one_hots = (one_hots + base_rates) * MSA_weight
|
120 |
+
MSA_weight_norm_counts = weighted_one_hots.sum(axis=-1).sum(axis=0)
|
121 |
+
MSA_weight_norm_counts = np.tile(MSA_weight_norm_counts.reshape(-1,1), (1,vocab_size))
|
122 |
+
one_hots_avg = weighted_one_hots.sum(axis=0) / MSA_weight_norm_counts
|
123 |
+
msa_prior = np.zeros((len_target_seq,vocab_size))
|
124 |
+
msa_prior[MSA_start:MSA_end,:]=one_hots_avg
|
125 |
+
else:
|
126 |
+
msa_prior = np.ones((len_target_seq,vocab_size)) / vocab_size
|
127 |
+
|
128 |
+
if verbose:
|
129 |
+
for idx, position in enumerate(msa_prior):
|
130 |
+
if len(position)!=25:
|
131 |
+
print("Size error")
|
132 |
+
if not round(position.sum(),2)==1.0:
|
133 |
+
print("Position at index {} does not add up to 1: {}".format(idx, position.sum()))
|
134 |
+
|
135 |
+
return msa_prior
|
136 |
+
|
137 |
+
|
138 |
+
def update_retrieved_MSA_log_prior_indel(model, MSA_log_prior, MSA_start, MSA_end, mutated_sequence):
|
139 |
+
"""
|
140 |
+
Function to process MSA when scoring indels.
|
141 |
+
To identify positions to add / remove in the retrieved MSA, we append and align the sequence to be scored to the original MSA for that protein family with Clustal Omega.
|
142 |
+
If the original MSA is relatively deep (over 100k sequences), we sample (by default) 100k rows at random from that MSA to speed computations.
|
143 |
+
MSA sampling is performed only once (for the first sequence to be scored). Subsequent scoring use the same MSA sample.
|
144 |
+
"""
|
145 |
+
if not os.path.isdir(model.MSA_folder + os.sep + "Sampled"):
|
146 |
+
os.mkdir(model.MSA_folder + os.sep + "Sampled")
|
147 |
+
sampled_MSA_location = model.MSA_folder + os.sep + "Sampled" + os.sep + "Sampled_" + model.MSA_filename.split(os.sep)[-1]
|
148 |
+
|
149 |
+
if not os.path.exists(sampled_MSA_location):
|
150 |
+
msa_data = process_msa_data(model.MSA_filename)
|
151 |
+
msa_data_sampled = filter_msa(msa_data, num_sequences_kept=100000) #If MSA has less than 100k sequences, the sample is identical to original MSA
|
152 |
+
with open(sampled_MSA_location, 'w') as sampled_write_location:
|
153 |
+
for index, key in enumerate(msa_data_sampled):
|
154 |
+
key_name = ">REFERENCE_SEQUENCE" if index==0 else key
|
155 |
+
msa_data_sampled[key] = msa_data_sampled[key].upper()
|
156 |
+
msa_data_sampled[key] = msa_data_sampled[key].replace(".","-")
|
157 |
+
sampled_write_location.write(key_name+"\n"+"\n".join([msa_data_sampled[key][i:i+80] for i in range(0, len(msa_data_sampled[key]), 80)])+"\n")
|
158 |
+
|
159 |
+
seq_to_align_location = model.MSA_folder + os.sep + "Sampled" + os.sep + "Seq_to_align_" + model.MSA_filename.split(os.sep)[-1]
|
160 |
+
sequence_text_split = [mutated_sequence[i:i+80] for i in range(0, len(mutated_sequence), 80)]
|
161 |
+
sequence_text_split_split_join = "\n".join([">SEQ_TO_SCORE"]+sequence_text_split)
|
162 |
+
os.system("echo '"+sequence_text_split_split_join+"' > "+seq_to_align_location)
|
163 |
+
|
164 |
+
expanded_MSA_location = model.MSA_folder + os.sep + "Sampled" + os.sep + "Expanded_" + model.MSA_filename.split(os.sep)[-1]
|
165 |
+
clustalw_cline = ClustalOmegaCommandline(cmd=model.config.clustal_omega_location,
|
166 |
+
profile1=sampled_MSA_location,
|
167 |
+
profile2=seq_to_align_location,
|
168 |
+
outfile=expanded_MSA_location,
|
169 |
+
force=True)
|
170 |
+
stdout, stderr = clustalw_cline()
|
171 |
+
msa_data = process_msa_data(expanded_MSA_location)
|
172 |
+
aligned_seqA, aligned_seqB = msa_data[">SEQ_TO_SCORE"], msa_data[">REFERENCE_SEQUENCE"]
|
173 |
+
try:
|
174 |
+
keep_column=[]
|
175 |
+
for column_index_pairwise_alignment in range(len(aligned_seqA)):
|
176 |
+
if aligned_seqA[column_index_pairwise_alignment]=="-" and aligned_seqB[column_index_pairwise_alignment]=="-":
|
177 |
+
continue
|
178 |
+
elif aligned_seqA[column_index_pairwise_alignment]=="-":
|
179 |
+
keep_column.append(False)
|
180 |
+
elif aligned_seqB[column_index_pairwise_alignment]=="-":
|
181 |
+
MSA_log_prior=torch.cat((MSA_log_prior[:column_index_pairwise_alignment], torch.zeros(MSA_log_prior.shape[1]).view(1,-1).cuda(), MSA_log_prior[column_index_pairwise_alignment:]),dim=0)
|
182 |
+
keep_column.append(True) #keep the zero column we just added
|
183 |
+
else:
|
184 |
+
keep_column.append(True)
|
185 |
+
MSA_log_prior = MSA_log_prior[keep_column]
|
186 |
+
MSA_end = MSA_start + len(MSA_log_prior)
|
187 |
+
except:
|
188 |
+
print("Error when processing the following alignment: {}".format(expanded_MSA_location))
|
189 |
+
return MSA_log_prior, MSA_start, MSA_end
|
190 |
+
|
191 |
+
class MSA_processing:
|
192 |
+
def __init__(self,
|
193 |
+
MSA_location="",
|
194 |
+
theta=0.2,
|
195 |
+
use_weights=True,
|
196 |
+
weights_location="./data/weights",
|
197 |
+
preprocess_MSA=True,
|
198 |
+
threshold_sequence_frac_gaps=0.5,
|
199 |
+
threshold_focus_cols_frac_gaps=0.3,
|
200 |
+
remove_sequences_with_indeterminate_AA_in_focus_cols=True
|
201 |
+
):
|
202 |
+
|
203 |
+
"""
|
204 |
+
This MSA_processing class is directly borrowed from the EVE codebase: https://github.com/OATML-Markslab/EVE
|
205 |
+
|
206 |
+
Parameters:
|
207 |
+
- msa_location: (path) Location of the MSA data. Constraints on input MSA format:
|
208 |
+
- focus_sequence is the first one in the MSA data
|
209 |
+
- first line is structured as follows: ">focus_seq_name/start_pos-end_pos" (e.g., >SPIKE_SARS2/310-550)
|
210 |
+
- corespondding sequence data located on following line(s)
|
211 |
+
- then all other sequences follow with ">name" on first line, corresponding data on subsequent lines
|
212 |
+
- theta: (float) Sequence weighting hyperparameter. Generally: Prokaryotic and eukaryotic families = 0.2; Viruses = 0.01
|
213 |
+
- use_weights: (bool) If False, sets all sequence weights to 1. If True, checks weights_location -- if non empty uses that;
|
214 |
+
otherwise compute weights from scratch and store them at weights_location
|
215 |
+
- weights_location: (path) Location to load from/save to the sequence weights
|
216 |
+
- preprocess_MSA: (bool) performs pre-processing of MSA to remove short fragments and positions that are not well covered.
|
217 |
+
- threshold_sequence_frac_gaps: (float, between 0 and 1) Threshold value to define fragments
|
218 |
+
- sequences with a fraction of gap characters above threshold_sequence_frac_gaps are removed
|
219 |
+
- default is set to 0.5 (i.e., fragments with 50% or more gaps are removed)
|
220 |
+
- threshold_focus_cols_frac_gaps: (float, between 0 and 1) Threshold value to define focus columns
|
221 |
+
- positions with a fraction of gap characters above threshold_focus_cols_pct_gaps will be set to lower case (and not included in the focus_cols)
|
222 |
+
- default is set to 0.3 (i.e., focus positions are the ones with 30% of gaps or less, i.e., 70% or more residue occupancy)
|
223 |
+
- remove_sequences_with_indeterminate_AA_in_focus_cols: (bool) Remove all sequences that have indeterminate AA (e.g., B, J, X, Z) at focus positions of the wild type
|
224 |
+
"""
|
225 |
+
np.random.seed(2021)
|
226 |
+
self.MSA_location = MSA_location
|
227 |
+
self.weights_location = weights_location
|
228 |
+
self.theta = theta
|
229 |
+
self.alphabet = "ACDEFGHIKLMNPQRSTVWY"
|
230 |
+
self.use_weights = use_weights
|
231 |
+
self.preprocess_MSA = preprocess_MSA
|
232 |
+
self.threshold_sequence_frac_gaps = threshold_sequence_frac_gaps
|
233 |
+
self.threshold_focus_cols_frac_gaps = threshold_focus_cols_frac_gaps
|
234 |
+
self.remove_sequences_with_indeterminate_AA_in_focus_cols = remove_sequences_with_indeterminate_AA_in_focus_cols
|
235 |
+
|
236 |
+
self.gen_alignment()
|
237 |
+
|
238 |
+
def gen_alignment(self, verbose=False):
|
239 |
+
""" Read training alignment and store basics in class instance """
|
240 |
+
self.aa_dict = {}
|
241 |
+
for i,aa in enumerate(self.alphabet):
|
242 |
+
self.aa_dict[aa] = i
|
243 |
+
|
244 |
+
self.seq_name_to_sequence = defaultdict(str)
|
245 |
+
name = ""
|
246 |
+
with open(self.MSA_location, "r") as msa_data:
|
247 |
+
for i, line in enumerate(msa_data):
|
248 |
+
line = line.rstrip()
|
249 |
+
if line.startswith(">"):
|
250 |
+
name = line
|
251 |
+
if i==0:
|
252 |
+
self.focus_seq_name = name
|
253 |
+
else:
|
254 |
+
self.seq_name_to_sequence[name] += line
|
255 |
+
|
256 |
+
|
257 |
+
## MSA pre-processing to remove inadequate columns and sequences
|
258 |
+
if self.preprocess_MSA:
|
259 |
+
msa_df = pd.DataFrame.from_dict(self.seq_name_to_sequence, orient='index', columns=['sequence'])
|
260 |
+
# Data clean up
|
261 |
+
msa_df.sequence = msa_df.sequence.apply(lambda x: x.replace(".","-")).apply(lambda x: ''.join([aa.upper() for aa in x]))
|
262 |
+
# Remove columns that would be gaps in the wild type
|
263 |
+
non_gap_wt_cols = [aa!='-' for aa in msa_df.sequence[self.focus_seq_name]]
|
264 |
+
msa_df['sequence'] = msa_df['sequence'].apply(lambda x: ''.join([aa for aa,non_gap_ind in zip(x, non_gap_wt_cols) if non_gap_ind]))
|
265 |
+
assert 0.0 <= self.threshold_sequence_frac_gaps <= 1.0,"Invalid fragment filtering parameter"
|
266 |
+
assert 0.0 <= self.threshold_focus_cols_frac_gaps <= 1.0,"Invalid focus position filtering parameter"
|
267 |
+
msa_array = np.array([list(seq) for seq in msa_df.sequence])
|
268 |
+
gaps_array = np.array(list(map(lambda seq: [aa=='-' for aa in seq], msa_array)))
|
269 |
+
# Identify fragments with too many gaps
|
270 |
+
seq_gaps_frac = gaps_array.mean(axis=1)
|
271 |
+
seq_below_threshold = seq_gaps_frac <= self.threshold_sequence_frac_gaps
|
272 |
+
if verbose: print("Proportion of sequences dropped due to fraction of gaps: "+str(round(float(1 - seq_below_threshold.sum()/seq_below_threshold.shape)*100,2))+"%")
|
273 |
+
# Identify focus columns
|
274 |
+
columns_gaps_frac = gaps_array[seq_below_threshold].mean(axis=0)
|
275 |
+
index_cols_below_threshold = columns_gaps_frac <= self.threshold_focus_cols_frac_gaps
|
276 |
+
if verbose: print("Proportion of non-focus columns removed: "+str(round(float(1 - index_cols_below_threshold.sum()/index_cols_below_threshold.shape)*100,2))+"%")
|
277 |
+
# Lower case non focus cols and filter fragment sequences
|
278 |
+
msa_df['sequence'] = msa_df['sequence'].apply(lambda x: ''.join([aa.upper() if upper_case_ind else aa.lower() for aa, upper_case_ind in zip(x, index_cols_below_threshold)]))
|
279 |
+
msa_df = msa_df[seq_below_threshold]
|
280 |
+
# Overwrite seq_name_to_sequence with clean version
|
281 |
+
self.seq_name_to_sequence = defaultdict(str)
|
282 |
+
for seq_idx in range(len(msa_df['sequence'])):
|
283 |
+
self.seq_name_to_sequence[msa_df.index[seq_idx]] = msa_df.sequence[seq_idx]
|
284 |
+
|
285 |
+
self.focus_seq = self.seq_name_to_sequence[self.focus_seq_name]
|
286 |
+
self.focus_cols = [ix for ix, s in enumerate(self.focus_seq) if s == s.upper() and s!='-']
|
287 |
+
self.focus_seq_trimmed = [self.focus_seq[ix] for ix in self.focus_cols]
|
288 |
+
self.seq_len = len(self.focus_cols)
|
289 |
+
self.alphabet_size = len(self.alphabet)
|
290 |
+
|
291 |
+
# Connect local sequence index with uniprot index (index shift inferred from 1st row of MSA)
|
292 |
+
focus_loc = self.focus_seq_name.split("/")[-1]
|
293 |
+
start,stop = focus_loc.split("-")
|
294 |
+
self.focus_start_loc = int(start)
|
295 |
+
self.focus_stop_loc = int(stop)
|
296 |
+
self.uniprot_focus_col_to_wt_aa_dict \
|
297 |
+
= {idx_col+int(start):self.focus_seq[idx_col] for idx_col in self.focus_cols}
|
298 |
+
self.uniprot_focus_col_to_focus_idx \
|
299 |
+
= {idx_col+int(start):idx_col for idx_col in self.focus_cols}
|
300 |
+
|
301 |
+
# Move all letters to CAPS; keeps focus columns only
|
302 |
+
self.raw_seq_name_to_sequence = self.seq_name_to_sequence.copy()
|
303 |
+
for seq_name,sequence in self.seq_name_to_sequence.items():
|
304 |
+
sequence = sequence.replace(".","-")
|
305 |
+
self.seq_name_to_sequence[seq_name] = [sequence[ix].upper() for ix in self.focus_cols]
|
306 |
+
|
307 |
+
# Remove sequences that have indeterminate AA (e.g., B, J, X, Z) in the focus columns
|
308 |
+
if self.remove_sequences_with_indeterminate_AA_in_focus_cols:
|
309 |
+
alphabet_set = set(list(self.alphabet))
|
310 |
+
seq_names_to_remove = []
|
311 |
+
for seq_name,sequence in self.seq_name_to_sequence.items():
|
312 |
+
for letter in sequence:
|
313 |
+
if letter not in alphabet_set and letter != "-":
|
314 |
+
seq_names_to_remove.append(seq_name)
|
315 |
+
continue
|
316 |
+
seq_names_to_remove = list(set(seq_names_to_remove))
|
317 |
+
for seq_name in seq_names_to_remove:
|
318 |
+
del self.seq_name_to_sequence[seq_name]
|
319 |
+
|
320 |
+
# Encode the sequences
|
321 |
+
self.one_hot_encoding = np.zeros((len(self.seq_name_to_sequence.keys()),len(self.focus_cols),len(self.alphabet)))
|
322 |
+
if verbose: print("One-hot encoded sequences shape:" + str(self.one_hot_encoding.shape))
|
323 |
+
for i,seq_name in enumerate(self.seq_name_to_sequence.keys()):
|
324 |
+
sequence = self.seq_name_to_sequence[seq_name]
|
325 |
+
for j,letter in enumerate(sequence):
|
326 |
+
if letter in self.aa_dict:
|
327 |
+
k = self.aa_dict[letter]
|
328 |
+
self.one_hot_encoding[i,j,k] = 1.0
|
329 |
+
|
330 |
+
if self.use_weights:
|
331 |
+
try:
|
332 |
+
self.weights = np.load(file=self.weights_location)
|
333 |
+
if verbose: print("Loaded sequence weights from disk")
|
334 |
+
except:
|
335 |
+
if verbose: print ("Computing sequence weights")
|
336 |
+
list_seq = self.one_hot_encoding
|
337 |
+
list_seq = list_seq.reshape((list_seq.shape[0], list_seq.shape[1] * list_seq.shape[2]))
|
338 |
+
def compute_weight(seq):
|
339 |
+
number_non_empty_positions = np.dot(seq,seq)
|
340 |
+
if number_non_empty_positions>0:
|
341 |
+
denom = np.dot(list_seq,seq) / np.dot(seq,seq)
|
342 |
+
denom = np.sum(denom > 1 - self.theta)
|
343 |
+
return 1/denom
|
344 |
+
else:
|
345 |
+
return 0.0 #return 0 weight if sequence is fully empty
|
346 |
+
self.weights = np.array(list(map(compute_weight,list_seq)))
|
347 |
+
np.save(file=self.weights_location, arr=self.weights)
|
348 |
+
else:
|
349 |
+
# If not using weights, use an isotropic weight matrix
|
350 |
+
if verbose: print("Not weighting sequence data")
|
351 |
+
self.weights = np.ones(self.one_hot_encoding.shape[0])
|
352 |
+
|
353 |
+
self.Neff = np.sum(self.weights)
|
354 |
+
self.num_sequences = self.one_hot_encoding.shape[0]
|
355 |
+
self.seq_name_to_weight={}
|
356 |
+
for i,seq_name in enumerate(self.seq_name_to_sequence.keys()):
|
357 |
+
self.seq_name_to_weight[seq_name]=self.weights[i]
|
358 |
+
|
359 |
+
if verbose:
|
360 |
+
print ("Neff =",str(self.Neff))
|
361 |
+
print ("Data Shape =",self.one_hot_encoding.shape)
|
tranception/utils/scoring_utils.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import tqdm
|
3 |
+
import re
|
4 |
+
import numpy as np
|
5 |
+
import pandas as pd
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch.nn import CrossEntropyLoss, NLLLoss
|
9 |
+
from torch.utils.data.sampler import Sampler, SequentialSampler
|
10 |
+
|
11 |
+
from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizerFast
|
12 |
+
from datasets import Dataset
|
13 |
+
|
14 |
+
AA_vocab = "ACDEFGHIKLMNPQRSTVWY"
|
15 |
+
|
16 |
+
def get_mutated_sequence(focus_seq, mutant, start_idx=1, AA_vocab=AA_vocab):
|
17 |
+
"""
|
18 |
+
Helper function that mutates an input sequence (focus_seq) via an input mutation triplet (substitutions only).
|
19 |
+
Mutation triplet are typically based on 1-indexing: start_idx is used for switching to 0-indexing.
|
20 |
+
"""
|
21 |
+
mutated_seq = list(focus_seq)
|
22 |
+
for mutation in mutant.split(":"):
|
23 |
+
try:
|
24 |
+
from_AA, position, to_AA = mutation[0], int(mutation[1:-1]), mutation[-1]
|
25 |
+
except:
|
26 |
+
print("Issue with mutant: "+str(mutation))
|
27 |
+
relative_position = position - start_idx
|
28 |
+
assert (from_AA==focus_seq[relative_position]), "Invalid from_AA or mutant position: "+str(mutation)+" from_AA: "+str(from_AA) + " relative pos: "+str(relative_position) + " focus_seq: "+str(focus_seq)
|
29 |
+
assert (to_AA in AA_vocab) , "Mutant to_AA is invalid: "+str(mutation)
|
30 |
+
mutated_seq[relative_position] = to_AA
|
31 |
+
return "".join(mutated_seq)
|
32 |
+
|
33 |
+
def nanmean(v, *args, inplace=False, **kwargs):
|
34 |
+
if not inplace:
|
35 |
+
v = v.clone()
|
36 |
+
is_nan = torch.isnan(v)
|
37 |
+
v[is_nan] = 0
|
38 |
+
return v.sum(*args, **kwargs) / (~is_nan).float().sum(*args, **kwargs)
|
39 |
+
|
40 |
+
def nansum(v, *args, inplace=False, **kwargs):
|
41 |
+
if not inplace:
|
42 |
+
v = v.clone()
|
43 |
+
is_nan = torch.isnan(v)
|
44 |
+
v[is_nan] = 0
|
45 |
+
return v.sum(*args, **kwargs)
|
46 |
+
|
47 |
+
def get_optimal_window(mutation_position_relative, seq_len_wo_special, model_window):
|
48 |
+
"""
|
49 |
+
Helper function that selects an optimal sequence window that fits the maximum model context size.
|
50 |
+
If the sequence length is less than the maximum context size, the full sequence is returned.
|
51 |
+
"""
|
52 |
+
half_model_window = model_window // 2
|
53 |
+
if seq_len_wo_special <= model_window:
|
54 |
+
return [0,seq_len_wo_special]
|
55 |
+
elif mutation_position_relative < half_model_window:
|
56 |
+
return [0,model_window]
|
57 |
+
elif mutation_position_relative >= seq_len_wo_special - half_model_window:
|
58 |
+
return [seq_len_wo_special - model_window, seq_len_wo_special]
|
59 |
+
else:
|
60 |
+
return [max(0,mutation_position_relative-half_model_window), min(seq_len_wo_special,mutation_position_relative+half_model_window)]
|
61 |
+
|
62 |
+
def sequence_replace_single(sequence, char_to_replace, char_replacements):
|
63 |
+
char_replacements = list(char_replacements)
|
64 |
+
positions = [m.start() for m in re.finditer(char_to_replace, sequence)]
|
65 |
+
replacements = np.random.choice(a=char_replacements, size=len(positions), replace=True)
|
66 |
+
sequence=list(sequence)
|
67 |
+
for idx, position in enumerate(positions):
|
68 |
+
sequence[position]=replacements[idx]
|
69 |
+
return ''.join(sequence)
|
70 |
+
|
71 |
+
def sequence_replace(sequences, char_to_replace, char_replacements):
|
72 |
+
"""
|
73 |
+
Helper function that replaces all Amino Acids passsed in via char_to_replace (as a string of AAs) with Amino Acids sampled from char_replacements (also a string of eligible AAs).
|
74 |
+
"""
|
75 |
+
return [sequence_replace_single(sequence, char_to_replace, char_replacements) for sequence in sequences]
|
76 |
+
|
77 |
+
def get_tranception_scores_mutated_sequences(model, mutated_sequence_df, batch_size_inference, score_var_name, target_seq, num_workers=10, reverse=False, indel_mode=False):
|
78 |
+
"""
|
79 |
+
Helper function that takes as input a set of mutated sequences (in a pandas dataframe) and returns scores for each mutation.
|
80 |
+
If target_seq is not None, returns the delta log likelihood wrt that target sequence -- otherwise returns the log likelihood of the protein sequences.
|
81 |
+
"""
|
82 |
+
scores = {}
|
83 |
+
scores['mutated_sequence']=[]
|
84 |
+
scores['sliced_mutated_sequence']=[]
|
85 |
+
scores['window_start']=[]
|
86 |
+
scores['window_end']=[]
|
87 |
+
scores['score']=[]
|
88 |
+
with torch.no_grad():
|
89 |
+
ds = Dataset.from_pandas(mutated_sequence_df)
|
90 |
+
ds.set_transform(model.encode_batch)
|
91 |
+
data_collator = DataCollatorForLanguageModeling(
|
92 |
+
tokenizer=model.config.tokenizer,
|
93 |
+
mlm=False)
|
94 |
+
sampler = SequentialSampler(ds)
|
95 |
+
ds_loader = torch.utils.data.DataLoader(ds, batch_size=batch_size_inference, sampler=sampler, collate_fn=data_collator, num_workers=num_workers, pin_memory=True, drop_last=False)
|
96 |
+
mutant_index=0
|
97 |
+
for encoded_batch in tqdm.tqdm(ds_loader):
|
98 |
+
full_batch_length = len(encoded_batch['input_ids'])
|
99 |
+
mutated_sequence = np.array(mutated_sequence_df['mutated_sequence'][mutant_index:mutant_index+full_batch_length])
|
100 |
+
scores['mutated_sequence'] += list(mutated_sequence)
|
101 |
+
sliced_mutated_sequence = np.array(mutated_sequence_df['sliced_mutated_sequence'][mutant_index:mutant_index+full_batch_length])
|
102 |
+
scores['sliced_mutated_sequence'] += list(sliced_mutated_sequence)
|
103 |
+
window_start = np.array(mutated_sequence_df['window_start'][mutant_index:mutant_index+full_batch_length])
|
104 |
+
scores['window_start'] += list(window_start)
|
105 |
+
window_end = np.array(mutated_sequence_df['window_end'][mutant_index:mutant_index+full_batch_length])
|
106 |
+
scores['window_end'] += list(window_end)
|
107 |
+
for k, v in encoded_batch.items():
|
108 |
+
if isinstance(v, torch.Tensor):
|
109 |
+
encoded_batch[k] = v.to(model.device)
|
110 |
+
shift_labels = encoded_batch['labels'][..., 1:].contiguous()
|
111 |
+
if (hasattr(model.config,"retrieval_aggregation_mode")) and (model.config.retrieval_aggregation_mode is not None):
|
112 |
+
if reverse:
|
113 |
+
encoded_batch['flip']=torch.tensor([1]*full_batch_length)
|
114 |
+
encoded_batch['start_slice']=window_start
|
115 |
+
encoded_batch['end_slice']=window_end
|
116 |
+
encoded_batch['mutated_sequence'] = mutated_sequence #only mutated_sequence is flipped if the scoring_mirror branch of score_mutants. No need to flip mutated_sequence for MSA re-aligning
|
117 |
+
fused_shift_log_probas=model(**encoded_batch,return_dict=True).fused_shift_log_probas
|
118 |
+
loss_fct = NLLLoss(reduction='none')
|
119 |
+
loss = - loss_fct(input=fused_shift_log_probas.view(-1, fused_shift_log_probas.size(-1)), target=shift_labels.view(-1)).view(fused_shift_log_probas.shape[0],fused_shift_log_probas.shape[1])
|
120 |
+
else:
|
121 |
+
lm_logits=model(**encoded_batch,return_dict=True).logits
|
122 |
+
shift_logits = lm_logits[..., :-1, :].contiguous()
|
123 |
+
loss_fct = CrossEntropyLoss(reduction='none')
|
124 |
+
loss = - loss_fct(input=shift_logits.view(-1, shift_logits.size(-1)), target=shift_labels.view(-1)).view(shift_logits.shape[0],shift_logits.shape[1])
|
125 |
+
mask = encoded_batch['attention_mask'][..., 1:].float()
|
126 |
+
mask[mask==0]=float('nan')
|
127 |
+
loss *= mask
|
128 |
+
loss = nansum(loss, dim=1)
|
129 |
+
scores_batch = list(loss.cpu().numpy())
|
130 |
+
full_batch_length = len(encoded_batch['input_ids'])
|
131 |
+
scores['score'] += scores_batch
|
132 |
+
mutant_index+=full_batch_length
|
133 |
+
scores = pd.DataFrame(scores)
|
134 |
+
if model.config.scoring_window=="sliding":
|
135 |
+
scores = scores[['mutated_sequence','score']].groupby('mutated_sequence').sum().reset_index() #We need to aggregate scores when using sliding mode
|
136 |
+
scores['score'] = scores['score'] / scores['mutated_sequence'].map(lambda x: len(x))
|
137 |
+
if target_seq is not None:
|
138 |
+
scores_mutated_seq = scores[scores.mutated_sequence != target_seq]
|
139 |
+
scores_wt = scores[scores.mutated_sequence == target_seq]
|
140 |
+
merge_delta = 'mutated_sequence' if model.config.scoring_window=="sliding" else 'window_start'
|
141 |
+
if model.config.scoring_window=="optimal":
|
142 |
+
delta_scores = pd.merge(scores_mutated_seq,scores_wt,how='left',on=[merge_delta],suffixes=('','_wt'))
|
143 |
+
delta_scores[score_var_name] = delta_scores['score'] - delta_scores['score_wt']
|
144 |
+
elif model.config.scoring_window=="sliding":
|
145 |
+
delta_scores = scores_mutated_seq.copy()
|
146 |
+
delta_scores[score_var_name] = delta_scores['score'] - list(scores_wt['score'])[0] # In sliding mode there is a single reference window for the WT
|
147 |
+
return delta_scores[['mutated_sequence',score_var_name]]
|
148 |
+
else:
|
149 |
+
scores[score_var_name] = scores['score']
|
150 |
+
return scores[['mutated_sequence',score_var_name]]
|
151 |
+
|
152 |
+
def get_sequence_slices(df, target_seq, model_context_len, start_idx=1, scoring_window="optimal", indel_mode=False):
|
153 |
+
"""
|
154 |
+
Helper function that takes as input a (pandas) dataframe df that contains a list of mutant triplets (substitutions) or full mutated sequences (indels) for scoring.
|
155 |
+
It returns a processed DMS in which sequences have been sliced to satisfy the maximum context window of the model.
|
156 |
+
df: (dataframe) Input dataframe to be processed
|
157 |
+
target_seq: (string) Full reference sequence (wild type) that is mutated in the DMS assay.
|
158 |
+
model_context_len: (int) Maximum context size for the model.
|
159 |
+
start_idx: (int) Integer to move to 0-indexing of positions (mutation triplet are typically based on 1-indexing).
|
160 |
+
scoring_window: (string) Method to slice sequences longer than maximum context size:
|
161 |
+
- optimal selects a single window as large as possible via the get_optimal_window function (this is the default)
|
162 |
+
- sliding splits the full sequence in contiguous (non-overlapping) chunks that are of size equal to the max context (except the last chunk which may be shorter)
|
163 |
+
indel_mode: (bool) Flag to be used when scoring insertions and deletions. Otherwise assumes substitutions.
|
164 |
+
Note: when scoring indels for sequences that would be longer than the model max context length, it is preferable to use the "sliding" scoring_window. Use "optimal" otherwise.
|
165 |
+
"""
|
166 |
+
len_target_seq = len(target_seq)
|
167 |
+
num_mutants = len(df['mutated_sequence'])
|
168 |
+
df=df.reset_index(drop=True)
|
169 |
+
if scoring_window=="optimal":
|
170 |
+
df['mutation_barycenter'] = df['mutant'].apply(lambda x: int(np.array([int(mutation[1:-1]) - start_idx for mutation in x.split(':')]).mean())) if not indel_mode else df['mutated_sequence'].apply(lambda x: len(x)//2)
|
171 |
+
df['scoring_optimal_window'] = df['mutation_barycenter'].apply(lambda x: get_optimal_window(x, len_target_seq, model_context_len)) if not indel_mode else df['mutated_sequence'].apply(lambda x: (0,len(x)))
|
172 |
+
df['sliced_mutated_sequence'] = [df['mutated_sequence'][index][df['scoring_optimal_window'][index][0]:df['scoring_optimal_window'][index][1]] for index in range(num_mutants)]
|
173 |
+
df['window_start'] = df['scoring_optimal_window'].map(lambda x: x[0])
|
174 |
+
df['window_end'] = df['scoring_optimal_window'].map(lambda x: x[1])
|
175 |
+
del df['scoring_optimal_window'], df['mutation_barycenter']
|
176 |
+
if 'mutant' in df: del df['mutant']
|
177 |
+
df_wt=df.copy()
|
178 |
+
df_wt['mutated_sequence'] = [target_seq] * num_mutants
|
179 |
+
if indel_mode: # For indels, we set the wild type reference to be always the same (full length) sequence. We assume here that the length is lower than model context size (otherwise "Sliding" mode should be used)
|
180 |
+
df_wt['window_end'] = df_wt['mutated_sequence'].map(lambda x:len(x))
|
181 |
+
df_wt['sliced_mutated_sequence'] = [target_seq[df_wt['window_start'][index]:df_wt['window_end'][index]] for index in range(num_mutants)]
|
182 |
+
df = pd.concat([df,df_wt], axis=0)
|
183 |
+
df = df.drop_duplicates()
|
184 |
+
elif scoring_window=="sliding":
|
185 |
+
num_windows = 1 + int( len_target_seq / model_context_len)
|
186 |
+
df_list=[]
|
187 |
+
start=0
|
188 |
+
for window_index in range(1, num_windows+1):
|
189 |
+
df_sliced = df.copy()
|
190 |
+
df_sliced['sliced_mutated_sequence'] = df_sliced['mutated_sequence'].map(lambda x: x[start:start+model_context_len])
|
191 |
+
df_sliced['window_start'] = [start] * num_mutants
|
192 |
+
df_sliced['window_end'] = df_sliced['mutated_sequence'].map(lambda x: min(len(x), start+model_context_len))
|
193 |
+
df_sliced_wt = df_sliced.copy()
|
194 |
+
df_sliced_wt['mutated_sequence'] = [target_seq] * num_mutants
|
195 |
+
df_sliced_wt['sliced_mutated_sequence'] = df_sliced_wt['mutated_sequence'].map(lambda x: x[start:start+model_context_len])
|
196 |
+
df_sliced_wt['window_end'] = df_sliced_wt['mutated_sequence'].map(lambda x: min(len(x), start+model_context_len)) #Need to adjust end index if WT and sequence are not same full length
|
197 |
+
df_list.append(df_sliced)
|
198 |
+
df_list.append(df_sliced_wt)
|
199 |
+
start += model_context_len
|
200 |
+
df_final = pd.concat(df_list,axis=0)
|
201 |
+
if 'mutant' in df_final: del df_final['mutant']
|
202 |
+
df = df_final.drop_duplicates()
|
203 |
+
return df.reset_index(drop=True)
|
tranception/utils/tokenizers/Basic_tokenizer
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"version":"1.0","truncation":null,"padding":null,"added_tokens":[{"id":0,"special":true,"content":"[UNK]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":1,"special":true,"content":"[CLS]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":2,"special":true,"content":"[SEP]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":3,"special":true,"content":"[PAD]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":4,"special":true,"content":"[MASK]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false}],"normalizer":null,"pre_tokenizer":{"type":"Whitespace"},"post_processor":{"type":"TemplateProcessing","single":[{"SpecialToken":{"id":"[CLS]","type_id":0}},{"Sequence":{"id":"A","type_id":0}},{"SpecialToken":{"id":"[SEP]","type_id":0}}],"pair":[{"SpecialToken":{"id":"[CLS]","type_id":0}},{"Sequence":{"id":"A","type_id":0}},{"SpecialToken":{"id":"[SEP]","type_id":0}},{"Sequence":{"id":"B","type_id":1}},{"SpecialToken":{"id":"[SEP]","type_id":1}}],"special_tokens":{"[CLS]":{"id":"[CLS]","ids":[1],"tokens":["[CLS]"]},"[SEP]":{"id":"[SEP]","ids":[2],"tokens":["[SEP]"]}}},"decoder":null,"model":{"type":"BPE","dropout":null,"unk_token":"[UNK]","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"vocab":{"[UNK]":0,"[CLS]":1,"[SEP]":2,"[PAD]":3,"[MASK]":4,"A":5,"C":6,"D":7,"E":8,"F":9,"G":10,"H":11,"I":12,"K":13,"L":14,"M":15,"N":16,"P":17,"Q":18,"R":19,"S":20,"T":21,"V":22,"W":23,"Y":24},"merges":[]}}
|