Update README.md
This update merges what were previously two separate, self-contained examples ("folding", AA-->3Di, and "inverse folding", 3Di-->AA, which started from a hard-coded `folding_example = ["prtein", "strctr"]` and repeated the imports and model loading) into one continuous script that first folds a batch of amino-acid sequences and then back-translates its own 3Di predictions. The updated section of README.md reads:

```python
from transformers import T5Tokenizer, AutoModelForSeq2SeqLM
import torch
import re

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Load the tokenizer
tokenizer = T5Tokenizer.from_pretrained('Rostlab/ProstT5', do_lower_case=False)

# Load the model
model = AutoModelForSeq2SeqLM.from_pretrained("Rostlab/ProstT5").to(device)

# only GPUs support half-precision currently; if you want to run on CPU use full-precision (not recommended, much slower)
model.float() if device == 'cpu' else model.half()

# prepare your protein sequences/structures as a list.
# Amino acid sequences are expected to be upper-case ("PRTEINO" below)
# while 3Di-sequences need to be lower-case.
sequence_examples = ["PRTEINO", "SEQWENCE"]
min_len = min([len(s) for s in sequence_examples])
max_len = max([len(s) for s in sequence_examples])

# replace all rare/ambiguous amino acids by X (3Di sequences do not have those)
# and introduce white-space between all residues (AAs and 3Di)
sequence_examples = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in sequence_examples]

# add pre-fixes accordingly. For the translation from AAs to 3Di, you need to prepend "<AA2fold>"
sequence_examples = ["<AA2fold>" + " " + s for s in sequence_examples]

# tokenize sequences and pad up to the longest sequence in the batch
ids = tokenizer.batch_encode_plus(sequence_examples,
                                  add_special_tokens=True,
                                  padding="longest",
                                  return_tensors='pt').to(device)
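# (note: ids is a BatchEncoding; ids.input_ids and ids.attention_mask are
#  tensors padded to the longest sequence in the batch, incl. special tokens)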

# Generation configuration for "folding" (AA-->3Di)
gen_kwargs_aa2fold = {
    "do_sample": True,
    "num_beams": 3,
    "top_p": 0.95,
    "temperature": 1.2,
    "top_k": 6,
    "repetition_penalty": 1.2,
}
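# (note: "do_sample" combined with "num_beams" > 1 makes model.generate() run
#  beam-search multinomial sampling, so repeated runs yield different 3Di strings)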

# translate from AA to 3Di (AA-->3Di)
with torch.no_grad():
    translations = model.generate(
        ids.input_ids,
        attention_mask=ids.attention_mask,
        max_length=max_len,      # max length of generated text
        min_length=min_len,      # minimum length of the generated text
        early_stopping=True,     # stop early if end-of-text token is generated
        num_return_sequences=1,  # return only a single sequence
        **gen_kwargs_aa2fold
    )
# Decode and remove white-spaces between tokens
decoded_translations = tokenizer.batch_decode(translations, skip_special_tokens=True)
structure_sequences = ["".join(ts.split(" ")) for ts in decoded_translations]  # predicted 3Di strings

# Now we can use the same model and invert the translation logic
# to generate an amino acid sequence from the predicted 3Di-sequence (3Di-->AA)

# add pre-fixes accordingly. For the translation from 3Di to AA (3Di-->AA), you need to prepend "<fold2AA>"
sequence_examples_backtranslation = ["<fold2AA>" + " " + s for s in decoded_translations]

# tokenize sequences and pad up to the longest sequence in the batch
ids_backtranslation = tokenizer.batch_encode_plus(sequence_examples_backtranslation,
                                                  add_special_tokens=True,
                                                  padding="longest",
                                                  return_tensors='pt').to(device)

# Example generation configuration for "inverse folding" (3Di-->AA)
gen_kwargs_fold2AA = {
    "do_sample": True,
    "top_p": 0.90,
    "temperature": 1.1,
    "top_k": 6,
    "repetition_penalty": 1.2,
}

# translate from 3Di to AA (3Di-->AA)
with torch.no_grad():
    backtranslations = model.generate(
        ids_backtranslation.input_ids,
        attention_mask=ids_backtranslation.attention_mask,
        max_length=max_len,      # max length of generated text
        min_length=min_len,      # minimum length of the generated text
        early_stopping=True,     # stop early if end-of-text token is generated
        num_return_sequences=1,  # return only a single sequence
        **gen_kwargs_fold2AA
    )
# Decode and remove white-spaces between tokens
decoded_backtranslations = tokenizer.batch_decode(backtranslations, skip_special_tokens=True)
aminoAcid_sequences = ["".join(ts.split(" ")) for ts in decoded_backtranslations]  # predicted amino acid strings
```
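A quick way to eyeball the results is to print the round trip. The sketch below is an addition, not part of the README: it assumes the snippet above has just run, re-lists the two input literals (the original `sequence_examples` list was overwritten during preprocessing), and since generation samples stochastically, each run prints different predictions.

```python
# Print the AA-->3Di-->AA round trip for each input protein.
# "PRTEINO" and "SEQWENCE" are repeated here because the original
# sequence_examples list was overwritten during preprocessing.
for original, pred_3di, back_aa in zip(["PRTEINO", "SEQWENCE"],
                                       structure_sequences,
                                       aminoAcid_sequences):
    print(f"input AA        : {original}")
    print(f"predicted 3Di   : {pred_3di}")  # 3Di states come back lower-case
    print(f"back-translation: {back_aa}")   # sampled, so it will usually differ from the input
```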