mheinz committed
Commit 648ea67
1 Parent(s): a48e328

Update README.md

Files changed (1):
  1. README.md (+31, -43)

README.md CHANGED
@@ -104,8 +104,10 @@ model = AutoModelForSeq2SeqLM.from_pretrained("Rostlab/ProstT5").to(device)
 # only GPUs support half-precision currently; if you want to run on CPU use full-precision (not recommended, much slower)
 model.full() if device=='cpu' else model.half()
 
-# prepare your protein sequences/structures as a list. Amino acid sequences are expected to be upper-case ("PRTEINO" below)
-folding_example = ["PRTEINO", "SEQWENCE"]
+# prepare your protein sequences/structures as a list.
+# Amino acid sequences are expected to be upper-case ("PRTEINO" below)
+# while 3Di-sequences need to be lower-case.
+sequence_examples = ["PRTEINO", "SEQWENCE"]
 min_len = min([ len(s) for s in folding_example])
 max_len = max([ len(s) for s in folding_example])
 
@@ -116,9 +118,12 @@ sequence_examples = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequen
 sequence_examples = [ "<AA2fold>" + " " + s for s in sequence_examples]
 
 # tokenize sequences and pad up to the longest sequence in the batch
-ids = tokenizer.batch_encode_plus(sequences_example, add_special_tokens=True, padding="longest",return_tensors='pt').to(device)
+ids = tokenizer.batch_encode_plus(sequences_example,
+                                  add_special_tokens=True,
+                                  padding="longest",
+                                  return_tensors='pt').to(device))
 
-# Generation configuration
+# Generation configuration for "folding" (AA-->3Di)
 gen_kwargs_aa2fold = {
                   "do_sample": True,
                   "num_beams": 3,
@@ -128,11 +133,11 @@ gen_kwargs_aa2fold = {
                   "repetition_penalty" : 1.2,
 }
 
-# translate from AA to 3Di
+# translate from AA to 3Di (AA-->3Di)
 with torch.no_grad():
-    target = model.generate(
-              start_encoding.input_ids,
-              attention_mask=start_encoding.attention_mask,
+    translations = model.generate(
+              ids.input_ids,
+              attention_mask=ids.attention_mask,
               max_length=max_len, # max length of generated text
               min_length=min_len, # minimum length of the generated text
               early_stopping=True, # stop early if end-of-text token is generated
@@ -140,40 +145,22 @@ with torch.no_grad():
               **gen_kwargs_aa2fold
     )
 # Decode and remove white-spaces between tokens
-t_strings = tokenizer.batch_decode( target, skip_special_tokens=True )
-t_strings = [ "".join(ts.split(" ")) for ts in t_strings ]
-```
+decoded_translations = tokenizer.batch_decode( translations, skip_special_tokens=True )
+structure_sequences = [ "".join(ts.split(" ")) for ts in decoded_translations ] # predicted 3Di strings
 
-Translation ("inverse folding", i.e., 3Di to AA):
-```python
-from transformers import T5Tokenizer, AutoModelForSeq2SeqLM
-import torch
-device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+# Now we can use the same model and invert the translation logic
+# to generate an amino acid sequence from the predicted 3Di-sequence (3Di-->AA)
 
-# Load the tokenizer
-tokenizer = T5Tokenizer.from_pretrained('Rostlab/ProstT5', do_lower_case=False).to(device)
-
-# Load the model
-model = AutoModelForSeq2SeqLM.from_pretrained("Rostlab/ProstT5").to(device)
-
-# only GPUs support half-precision currently; if you want to run on CPU use full-precision (not recommended, much slower)
-model.full() if device=='cpu' else model.half()
-
-# prepare your protein sequences/structures as a list. Amino acid sequences are expected to be upper-case ("PRTEINO" below)
-folding_example = ["prtein", "strctr"]
-min_len = min([ len(s) for s in folding_example])
-max_len = max([ len(s) for s in folding_example])
-
-# replace all rare/ambiguous amino acids by X (3Di sequences does not have those) and introduce white-space between all sequences (AAs and 3Di)
-sequence_examples = [" ".join(list(sequence)) for sequence in sequence_examples]
-
-# add pre-fixes accordingly. For the translation from 3Di to AAs, you need to prepend "<fold2AA>"
-sequence_examples = [ "<fold2AA>" + " " + s for s in sequence_examples]
+# add pre-fixes accordingly. For the translation from 3Di to AA (3Di-->AA), you need to prepend "<fold2AA>"
+sequence_examples_backtranslation = [ "<fold2AA>" + " " + s for s in decoded_translations]
 
 # tokenize sequences and pad up to the longest sequence in the batch
-ids = tokenizer.batch_encode_plus(sequences_example, add_special_tokens=True, padding="longest",return_tensors='pt').to(device)
+ids_backtranslation = tokenizer.batch_encode_plus(sequence_examples_backtranslation,
+                                  add_special_tokens=True,
+                                  padding="longest",
+                                  return_tensors='pt').to(device))
 
-# Generation configuration
+# Example generation configuration for "inverse folding" (3Di-->AA)
 gen_kwargs_fold2AA = {
             "do_sample": True,
             "top_p" : 0.90,
@@ -182,11 +169,11 @@ gen_kwargs_fold2AA = {
             "repetition_penalty" : 1.2,
 }
 
-# translate from 3Di to AA
+# translate from 3Di to AA (3Di-->AA)
 with torch.no_grad():
-    target = model.generate(
-              start_encoding.input_ids,
-              attention_mask=start_encoding.attention_mask,
+    backtranslations = model.generate(
+              ids_backtranslation.input_ids,
+              attention_mask=ids_backtranslation.attention_mask,
               max_length=max_len, # max length of generated text
               min_length=min_len, # minimum length of the generated text
               early_stopping=True, # stop early if end-of-text token is generated
@@ -194,8 +181,9 @@ with torch.no_grad():
               **gen_kwargs_fold2AA
     )
 # Decode and remove white-spaces between tokens
-t_strings = tokenizer.batch_decode( target, skip_special_tokens=True )
-t_strings = [ "".join(ts.split(" ")) for ts in t_strings ]
+decoded_backtranslations = tokenizer.batch_decode( backtranslations, skip_special_tokens=True )
+aminoAcid_sequences = [ "".join(ts.split(" ")) for ts in decoded_backtranslations ] # predicted amino acid strings
+
 ```
189