KeXing commited on
Commit
212111c
1 Parent(s): 5f1e767

Upload 26 files

Browse files
tape/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from . import datasets # noqa: F401
2
+ from . import metrics # noqa: F401
3
+ from .tokenizers import TAPETokenizer # noqa: F401
4
+ from .models.modeling_utils import ProteinModel
5
+ from .models.modeling_utils import ProteinConfig
6
+
7
+ import sys
8
+ from pathlib import Path
9
+ import importlib
10
+ import pkgutil
11
+
12
+ __version__ = '0.4'
13
+
14
+
15
+ # Import all the models and configs
16
+ for _, name, _ in pkgutil.iter_modules([str(Path(__file__).parent / 'models')]):
17
+ imported_module = importlib.import_module('.models.' + name, package=__name__)
18
+ for name, cls in imported_module.__dict__.items():
19
+ if isinstance(cls, type) and \
20
+ (issubclass(cls, ProteinModel) or issubclass(cls, ProteinConfig)):
21
+ setattr(sys.modules[__name__], name, cls)
tape/datasets.py ADDED
@@ -0,0 +1,926 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union, List, Tuple, Sequence, Dict, Any, Optional, Collection
2
+ from copy import copy
3
+ from pathlib import Path
4
+ import pickle as pkl
5
+ import logging
6
+ import random
7
+
8
+ import lmdb
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from torch.utils.data import Dataset
13
+ from scipy.spatial.distance import pdist, squareform
14
+
15
+ from .tokenizers import TAPETokenizer
16
+ from .registry import registry
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ def dataset_factory(data_file: Union[str, Path], *args, **kwargs) -> Dataset:
22
+ data_file = Path(data_file)
23
+ if not data_file.exists():
24
+ raise FileNotFoundError(data_file)
25
+ if data_file.suffix == '.lmdb':
26
+ return LMDBDataset(data_file, *args, **kwargs)
27
+ elif data_file.suffix in {'.fasta', '.fna', '.ffn', '.faa', '.frn'}:
28
+ return FastaDataset(data_file, *args, **kwargs)
29
+ elif data_file.suffix == '.json':
30
+ return JSONDataset(data_file, *args, **kwargs)
31
+ elif data_file.is_dir():
32
+ return NPZDataset(data_file, *args, **kwargs)
33
+ else:
34
+ raise ValueError(f"Unrecognized datafile type {data_file.suffix}")
35
+
36
+
37
+ def pad_sequences(sequences: Sequence, constant_value=0, dtype=None) -> np.ndarray:
38
+ batch_size = len(sequences)
39
+ shape = [batch_size] + np.max([seq.shape for seq in sequences], 0).tolist()
40
+
41
+ if dtype is None:
42
+ dtype = sequences[0].dtype
43
+
44
+ if isinstance(sequences[0], np.ndarray):
45
+ array = np.full(shape, constant_value, dtype=dtype)
46
+ elif isinstance(sequences[0], torch.Tensor):
47
+ array = torch.full(shape, constant_value, dtype=dtype)
48
+
49
+ for arr, seq in zip(array, sequences):
50
+ arrslice = tuple(slice(dim) for dim in seq.shape)
51
+ arr[arrslice] = seq
52
+
53
+ return array
54
+
55
+
56
+ class FastaDataset(Dataset):
57
+ """Creates a dataset from a fasta file.
58
+ Args:
59
+ data_file (Union[str, Path]): Path to fasta file.
60
+ in_memory (bool, optional): Whether to load the full dataset into memory.
61
+ Default: False.
62
+ """
63
+
64
+ def __init__(self,
65
+ data_file: Union[str, Path],
66
+ in_memory: bool = False):
67
+
68
+ from Bio import SeqIO
69
+ data_file = Path(data_file)
70
+ if not data_file.exists():
71
+ raise FileNotFoundError(data_file)
72
+
73
+ # if in_memory:
74
+ cache = list(SeqIO.parse(str(data_file), 'fasta'))
75
+ num_examples = len(cache)
76
+ self._cache = cache
77
+ # else:
78
+ # records = SeqIO.index(str(data_file), 'fasta')
79
+ # num_examples = len(records)
80
+ #
81
+ # if num_examples < 10000:
82
+ # logger.info("Reading full fasta file into memory because number of examples "
83
+ # "is very low. This loads data approximately 20x faster.")
84
+ # in_memory = True
85
+ # cache = list(records.values())
86
+ # self._cache = cache
87
+ # else:
88
+ # self._records = records
89
+ # self._keys = list(records.keys())
90
+
91
+ self._in_memory = in_memory
92
+ self._num_examples = num_examples
93
+
94
+ def __len__(self) -> int:
95
+ return self._num_examples
96
+
97
+ def __getitem__(self, index: int):
98
+ if not 0 <= index < self._num_examples:
99
+ raise IndexError(index)
100
+
101
+ # if self._in_memory and self._cache[index] is not None:
102
+ record = self._cache[index]
103
+ # else:
104
+ # key = self._keys[index]
105
+ # record = self._records[key]
106
+ # if self._in_memory:
107
+ # self._cache[index] = record
108
+
109
+ item = {'id': record.id,
110
+ 'primary': str(record.seq),
111
+ 'protein_length': len(record.seq)}
112
+ return item
113
+
114
+
115
+ class LMDBDataset(Dataset):
116
+ """Creates a dataset from an lmdb file.
117
+ Args:
118
+ data_file (Union[str, Path]): Path to lmdb file.
119
+ in_memory (bool, optional): Whether to load the full dataset into memory.
120
+ Default: False.
121
+ """
122
+
123
+ def __init__(self,
124
+ data_file: Union[str, Path],
125
+ in_memory: bool = False):
126
+
127
+ data_file = Path(data_file)
128
+ if not data_file.exists():
129
+ raise FileNotFoundError(data_file)
130
+
131
+ env = lmdb.open(str(data_file), max_readers=1, readonly=True,
132
+ lock=False, readahead=False, meminit=False)
133
+
134
+ with env.begin(write=False) as txn:
135
+ num_examples = pkl.loads(txn.get(b'num_examples'))
136
+
137
+ if in_memory:
138
+ cache = [None] * num_examples
139
+ self._cache = cache
140
+
141
+ self._env = env
142
+ self._in_memory = in_memory
143
+ self._num_examples = num_examples
144
+
145
+ def __len__(self) -> int:
146
+ return self._num_examples
147
+
148
+ def __getitem__(self, index: int):
149
+ if not 0 <= index < self._num_examples:
150
+ raise IndexError(index)
151
+
152
+ if self._in_memory and self._cache[index] is not None:
153
+ item = self._cache[index]
154
+ else:
155
+ with self._env.begin(write=False) as txn:
156
+ item = pkl.loads(txn.get(str(index).encode()))
157
+ if 'id' not in item:
158
+ item['id'] = str(index)
159
+ if self._in_memory:
160
+ self._cache[index] = item
161
+ return item
162
+
163
+
164
+ class JSONDataset(Dataset):
165
+ """Creates a dataset from a json file. Assumes that data is
166
+ a JSON serialized list of record, where each record is
167
+ a dictionary.
168
+ Args:
169
+ data_file (Union[str, Path]): Path to json file.
170
+ in_memory (bool): Dummy variable to match API of other datasets
171
+ """
172
+
173
+ def __init__(self, data_file: Union[str, Path], in_memory: bool = True):
174
+ import json
175
+ data_file = Path(data_file)
176
+ if not data_file.exists():
177
+ raise FileNotFoundError(data_file)
178
+ records = json.loads(data_file.read_text())
179
+
180
+ if not isinstance(records, list):
181
+ raise TypeError(f"TAPE JSONDataset requires a json serialized list, "
182
+ f"received {type(records)}")
183
+ self._records = records
184
+ self._num_examples = len(records)
185
+
186
+ def __len__(self) -> int:
187
+ return self._num_examples
188
+
189
+ def __getitem__(self, index: int):
190
+ if not 0 <= index < self._num_examples:
191
+ raise IndexError(index)
192
+
193
+ item = self._records[index]
194
+ if not isinstance(item, dict):
195
+ raise TypeError(f"Expected dataset to contain a list of dictionary "
196
+ f"records, received record of type {type(item)}")
197
+ if 'id' not in item:
198
+ item['id'] = str(index)
199
+ return item
200
+
201
+
202
+ class NPZDataset(Dataset):
203
+ """Creates a dataset from a directory of npz files.
204
+ Args:
205
+ data_file (Union[str, Path]): Path to directory of npz files
206
+ in_memory (bool): Dummy variable to match API of other datasets
207
+ """
208
+
209
+ def __init__(self,
210
+ data_file: Union[str, Path],
211
+ in_memory: bool = True,
212
+ split_files: Optional[Collection[str]] = None):
213
+ data_file = Path(data_file)
214
+ if not data_file.exists():
215
+ raise FileNotFoundError(data_file)
216
+ if not data_file.is_dir():
217
+ raise NotADirectoryError(data_file)
218
+ file_glob = data_file.glob('*.npz')
219
+ if split_files is None:
220
+ file_list = list(file_glob)
221
+ else:
222
+ split_files = set(split_files)
223
+ if len(split_files) == 0:
224
+ raise ValueError("Passed an empty split file set")
225
+
226
+ file_list = [f for f in file_glob if f.name in split_files]
227
+ if len(file_list) != len(split_files):
228
+ num_missing = len(split_files) - len(file_list)
229
+ raise FileNotFoundError(
230
+ f"{num_missing} specified split files not found in directory")
231
+
232
+ if len(file_list) == 0:
233
+ raise FileNotFoundError(f"No .npz files found in {data_file}")
234
+
235
+ self._file_list = file_list
236
+
237
+ def __len__(self) -> int:
238
+ return len(self._file_list)
239
+
240
+ def __getitem__(self, index: int):
241
+ if not 0 <= index < len(self):
242
+ raise IndexError(index)
243
+
244
+ item = dict(np.load(self._file_list[index]))
245
+ if not isinstance(item, dict):
246
+ raise TypeError(f"Expected dataset to contain a list of dictionary "
247
+ f"records, received record of type {type(item)}")
248
+ if 'id' not in item:
249
+ item['id'] = self._file_list[index].stem
250
+ return item
251
+
252
+
253
+ @registry.register_task('embed')
254
+ class EmbedDataset(Dataset):
255
+
256
+ def __init__(self,
257
+ data_file: Union[str, Path],
258
+ tokenizer: Union[str, TAPETokenizer] = 'iupac',
259
+ in_memory: bool = False,
260
+ convert_tokens_to_ids: bool = True):
261
+ super().__init__()
262
+
263
+ if isinstance(tokenizer, str):
264
+ tokenizer = TAPETokenizer(vocab=tokenizer)
265
+ self.tokenizer = tokenizer
266
+ self.data = dataset_factory(data_file)
267
+
268
+ def __len__(self) -> int:
269
+ return len(self.data)
270
+
271
+ def __getitem__(self, index: int):
272
+ item = self.data[index]
273
+ token_ids = self.tokenizer.encode(item['primary'])
274
+ input_mask = np.ones_like(token_ids)
275
+ return item['id'], token_ids, input_mask
276
+
277
+ def collate_fn(self, batch: List[Tuple[Any, ...]]) -> Dict[str, torch.Tensor]:
278
+ ids, tokens, input_mask = zip(*batch)
279
+ ids = list(ids)
280
+ tokens = torch.from_numpy(pad_sequences(tokens))
281
+ input_mask = torch.from_numpy(pad_sequences(input_mask))
282
+ return {'ids': ids, 'input_ids': tokens, 'input_mask': input_mask} # type: ignore
283
+
284
+
285
+ @registry.register_task('masked_language_modeling')
286
+ class MaskedLanguageModelingDataset(Dataset):
287
+ """Creates the Masked Language Modeling Pfam Dataset
288
+ Args:
289
+ data_path (Union[str, Path]): Path to tape data root.
290
+ split (str): One of ['train', 'valid', 'holdout'], specifies which data file to load.
291
+ in_memory (bool, optional): Whether to load the full dataset into memory.
292
+ Default: False.
293
+ """
294
+
295
+ def __init__(self,
296
+ data_path: Union[str, Path],
297
+ split: str,
298
+ tokenizer: Union[str, TAPETokenizer] = 'iupac',
299
+ in_memory: bool = False):
300
+ super().__init__()
301
+ if split not in ('train', 'valid', 'holdout'):
302
+ raise ValueError(
303
+ f"Unrecognized split: {split}. "
304
+ f"Must be one of ['train', 'valid', 'holdout']")
305
+ if isinstance(tokenizer, str):
306
+ tokenizer = TAPETokenizer(vocab=tokenizer)
307
+ self.tokenizer = tokenizer
308
+
309
+ data_path = Path(data_path)
310
+ data_file = f'pfam/pfam_{split}.lmdb'
311
+ self.data = dataset_factory(data_path / data_file, in_memory)
312
+
313
+ def __len__(self) -> int:
314
+ return len(self.data)
315
+
316
+ def __getitem__(self, index):
317
+ item = self.data[index]
318
+ tokens = self.tokenizer.tokenize(item['primary'])
319
+ tokens = self.tokenizer.add_special_tokens(tokens)
320
+ masked_tokens, labels = self._apply_bert_mask(tokens)
321
+ masked_token_ids = np.array(
322
+ self.tokenizer.convert_tokens_to_ids(masked_tokens), np.int64)
323
+ input_mask = np.ones_like(masked_token_ids)
324
+
325
+ masked_token_ids = np.array(
326
+ self.tokenizer.convert_tokens_to_ids(masked_tokens), np.int64)
327
+
328
+ return masked_token_ids, input_mask, labels, item['clan'], item['family']
329
+
330
+ def collate_fn(self, batch: List[Any]) -> Dict[str, torch.Tensor]:
331
+ input_ids, input_mask, lm_label_ids, clan, family = tuple(zip(*batch))
332
+
333
+ input_ids = torch.from_numpy(pad_sequences(input_ids, 0))
334
+ input_mask = torch.from_numpy(pad_sequences(input_mask, 0))
335
+ # ignore_index is -1
336
+ lm_label_ids = torch.from_numpy(pad_sequences(lm_label_ids, -1))
337
+ clan = torch.LongTensor(clan) # type: ignore
338
+ family = torch.LongTensor(family) # type: ignore
339
+
340
+ return {'input_ids': input_ids,
341
+ 'input_mask': input_mask,
342
+ 'targets': lm_label_ids}
343
+
344
+ def _apply_bert_mask(self, tokens: List[str]) -> Tuple[List[str], List[int]]:
345
+ masked_tokens = copy(tokens)
346
+ labels = np.zeros([len(tokens)], np.int64) - 1
347
+
348
+ for i, token in enumerate(tokens):
349
+ # Tokens begin and end with start_token and stop_token, ignore these
350
+ if token in (self.tokenizer.start_token, self.tokenizer.stop_token):
351
+ pass
352
+
353
+ prob = random.random()
354
+ if prob < 0.15:
355
+ prob /= 0.15
356
+ labels[i] = self.tokenizer.convert_token_to_id(token)
357
+
358
+ if prob < 0.8:
359
+ # 80% random change to mask token
360
+ token = self.tokenizer.mask_token
361
+ elif prob < 0.9:
362
+ # 10% chance to change to random token
363
+ token = self.tokenizer.convert_id_to_token(
364
+ random.randint(0, self.tokenizer.vocab_size - 1))
365
+ else:
366
+ # 10% chance to keep current token
367
+ pass
368
+
369
+ masked_tokens[i] = token
370
+
371
+ return masked_tokens, labels
372
+
373
+
374
+ @registry.register_task('beta_lactamase')
375
+ class BetaModelingDataset(MaskedLanguageModelingDataset):
376
+
377
+ def __init__(self,
378
+ data_path: Union[str, Path],
379
+ split: str,
380
+ tokenizer: Union[str, TAPETokenizer] = 'iupac',
381
+ in_memory: bool = False):
382
+ super().__init__(data_path, split, tokenizer, in_memory)
383
+ data_path = Path(data_path)
384
+ data_file = f'unilanguage/{split}_combined.fasta'
385
+ self.data = dataset_factory(data_path / data_file, in_memory)
386
+
387
+ def collate_fn(self, batch: List[Any]) -> Dict[str, torch.Tensor]:
388
+ input_ids, input_mask, lm_label_ids = tuple(zip(*batch))
389
+
390
+ input_ids = torch.from_numpy(pad_sequences(input_ids, 0))
391
+ input_mask = torch.from_numpy(pad_sequences(input_mask, 0))
392
+ # ignore_index is -1
393
+ lm_label_ids = torch.from_numpy(pad_sequences(lm_label_ids, -1))
394
+
395
+ return {'input_ids': input_ids,
396
+ 'input_mask': input_mask,
397
+ 'targets': lm_label_ids}
398
+
399
+ def __getitem__(self, index):
400
+ item = self.data[index]
401
+ tokens = self.tokenizer.tokenize(item['primary'])
402
+ tokens = self.tokenizer.add_special_tokens(tokens)
403
+ masked_tokens, labels = self._apply_bert_mask(tokens)
404
+ masked_token_ids = np.array(
405
+ self.tokenizer.convert_tokens_to_ids(masked_tokens), np.int64)
406
+ input_mask = np.ones_like(masked_token_ids)
407
+
408
+ masked_token_ids = np.array(
409
+ self.tokenizer.convert_tokens_to_ids(masked_tokens), np.int64)
410
+
411
+ return masked_token_ids, input_mask, labels
412
+
413
+
414
+ @registry.register_task('unilanguage')
415
+ class UniModelingDataset(MaskedLanguageModelingDataset):
416
+
417
+ def __init__(self,
418
+ data_path: Union[str, Path],
419
+ split: str,
420
+ tokenizer: Union[str, TAPETokenizer] = 'iupac',
421
+ in_memory: bool = False):
422
+ super().__init__(data_path, split, tokenizer, in_memory)
423
+ data_path = Path(data_path)
424
+ data_file = f'unilanguage/PF00144_full_length_sequences_labeled.fasta'
425
+ self.data = dataset_factory(data_path / data_file, in_memory)
426
+
427
+ def collate_fn(self, batch: List[Any]) -> Dict[str, torch.Tensor]:
428
+ input_ids, input_mask, lm_label_ids = tuple(zip(*batch))
429
+
430
+ input_ids = torch.from_numpy(pad_sequences(input_ids, 0))
431
+ input_mask = torch.from_numpy(pad_sequences(input_mask, 0))
432
+ # ignore_index is -1
433
+ lm_label_ids = torch.from_numpy(pad_sequences(lm_label_ids, -1))
434
+
435
+ return {'input_ids': input_ids,
436
+ 'input_mask': input_mask,
437
+ 'targets': lm_label_ids}
438
+
439
+ def __getitem__(self, index):
440
+ item = self.data[index]
441
+ tokens = self.tokenizer.tokenize(item['primary'])
442
+ tokens = self.tokenizer.add_special_tokens(tokens)
443
+ masked_tokens, labels = self._apply_bert_mask(tokens)
444
+ masked_token_ids = np.array(
445
+ self.tokenizer.convert_tokens_to_ids(masked_tokens), np.int64)
446
+ input_mask = np.ones_like(masked_token_ids)
447
+
448
+ masked_token_ids = np.array(
449
+ self.tokenizer.convert_tokens_to_ids(masked_tokens), np.int64)
450
+
451
+ return masked_token_ids, input_mask, labels
452
+
453
+
454
+ @registry.register_task('language_modeling')
455
+ class LanguageModelingDataset(Dataset):
456
+ """Creates the Language Modeling Pfam Dataset
457
+ Args:
458
+ data_path (Union[str, Path]): Path to tape data root.
459
+ split (str): One of ['train', 'valid', 'holdout'], specifies which data file to load.
460
+ in_memory (bool, optional): Whether to load the full dataset into memory.
461
+ Default: False.
462
+ """
463
+
464
+ def __init__(self,
465
+ data_path: Union[str, Path],
466
+ split: str,
467
+ tokenizer: Union[str, TAPETokenizer] = 'iupac',
468
+ in_memory: bool = False):
469
+ super().__init__()
470
+ if split not in ('train', 'valid', 'holdout'):
471
+ raise ValueError(
472
+ f"Unrecognized split: {split}. "
473
+ f"Must be one of ['train', 'valid', 'holdout']")
474
+ if isinstance(tokenizer, str):
475
+ tokenizer = TAPETokenizer(vocab=tokenizer)
476
+ self.tokenizer = tokenizer
477
+
478
+ data_path = Path(data_path)
479
+ data_file = f'pfam/pfam_{split}.lmdb'
480
+ self.data = dataset_factory(data_path / data_file, in_memory)
481
+
482
+ def __len__(self) -> int:
483
+ return len(self.data)
484
+
485
+ def __getitem__(self, index):
486
+ item = self.data[index]
487
+ token_ids = self.tokenizer.encode(item['primary'])
488
+ input_mask = np.ones_like(token_ids)
489
+
490
+ return token_ids, input_mask, item['clan'], item['family']
491
+
492
+ def collate_fn(self, batch: List[Any]) -> Dict[str, torch.Tensor]:
493
+ input_ids, input_mask, clan, family = tuple(zip(*batch))
494
+
495
+ torch_inputs = torch.from_numpy(pad_sequences(input_ids, 0))
496
+ input_mask = torch.from_numpy(pad_sequences(input_mask, 0))
497
+ # ignore_index is -1
498
+ torch_labels = torch.from_numpy(pad_sequences(input_ids, -1))
499
+ clan = torch.LongTensor(clan) # type: ignore
500
+ family = torch.LongTensor(family) # type: ignore
501
+
502
+ return {'input_ids': torch_inputs,
503
+ 'input_mask': input_mask,
504
+ 'targets': torch_labels}
505
+
506
+
507
+ @registry.register_task('fluorescence')
508
+ class FluorescenceDataset(Dataset):
509
+
510
+ def __init__(self,
511
+ data_path: Union[str, Path],
512
+ split: str,
513
+ tokenizer: Union[str, TAPETokenizer] = 'iupac',
514
+ in_memory: bool = False):
515
+
516
+ if split not in ('train', 'valid', 'test'):
517
+ raise ValueError(f"Unrecognized split: {split}. "
518
+ f"Must be one of ['train', 'valid', 'test']")
519
+ if isinstance(tokenizer, str):
520
+ tokenizer = TAPETokenizer(vocab=tokenizer)
521
+ self.tokenizer = tokenizer
522
+
523
+ data_path = Path(data_path)
524
+ data_file = f'fluorescence/fluorescence_{split}.lmdb'
525
+ self.data = dataset_factory(data_path / data_file, in_memory)
526
+
527
+ def __len__(self) -> int:
528
+ return len(self.data)
529
+
530
+ def __getitem__(self, index: int):
531
+ item = self.data[index]
532
+ token_ids = self.tokenizer.encode(item['primary'])
533
+ input_mask = np.ones_like(token_ids)
534
+ return token_ids, input_mask, float(item['log_fluorescence'][0])
535
+
536
+ def collate_fn(self, batch: List[Tuple[Any, ...]]) -> Dict[str, torch.Tensor]:
537
+ input_ids, input_mask, fluorescence_true_value = tuple(zip(*batch))
538
+ input_ids = torch.from_numpy(pad_sequences(input_ids, 0))
539
+ input_mask = torch.from_numpy(pad_sequences(input_mask, 0))
540
+ fluorescence_true_value = torch.FloatTensor(fluorescence_true_value) # type: ignore
541
+ fluorescence_true_value = fluorescence_true_value.unsqueeze(1)
542
+
543
+ return {'input_ids': input_ids,
544
+ 'input_mask': input_mask,
545
+ 'targets': fluorescence_true_value}
546
+
547
+
548
+ @registry.register_task('stability')
549
+ class StabilityDataset(Dataset):
550
+
551
+ def __init__(self,
552
+ data_path: Union[str, Path],
553
+ split: str,
554
+ tokenizer: Union[str, TAPETokenizer] = 'iupac',
555
+ in_memory: bool = False):
556
+
557
+ if split not in ('train', 'valid', 'test'):
558
+ raise ValueError(f"Unrecognized split: {split}. "
559
+ f"Must be one of ['train', 'valid', 'test']")
560
+ if isinstance(tokenizer, str):
561
+ tokenizer = TAPETokenizer(vocab=tokenizer)
562
+ self.tokenizer = tokenizer
563
+
564
+ data_path = Path(data_path)
565
+ data_file = f'stability/stability_{split}.lmdb'
566
+
567
+ self.data = dataset_factory(data_path / data_file, in_memory)
568
+
569
+ def __len__(self) -> int:
570
+ return len(self.data)
571
+
572
+ def __getitem__(self, index: int):
573
+ item = self.data[index]
574
+ token_ids = self.tokenizer.encode(item['primary'])
575
+ input_mask = np.ones_like(token_ids)
576
+ return token_ids, input_mask, float(item['stability_score'][0])
577
+
578
+ def collate_fn(self, batch: List[Tuple[Any, ...]]) -> Dict[str, torch.Tensor]:
579
+ input_ids, input_mask, stability_true_value = tuple(zip(*batch))
580
+ input_ids = torch.from_numpy(pad_sequences(input_ids, 0))
581
+ input_mask = torch.from_numpy(pad_sequences(input_mask, 0))
582
+ stability_true_value = torch.FloatTensor(stability_true_value) # type: ignore
583
+ stability_true_value = stability_true_value.unsqueeze(1)
584
+
585
+ return {'input_ids': input_ids,
586
+ 'input_mask': input_mask,
587
+ 'targets': stability_true_value}
588
+
589
+
590
+ @registry.register_task('remote_homology', num_labels=1195)
591
+ class RemoteHomologyDataset(Dataset):
592
+
593
+ def __init__(self,
594
+ data_path: Union[str, Path],
595
+ split: str,
596
+ tokenizer: Union[str, TAPETokenizer] = 'iupac',
597
+ in_memory: bool = False):
598
+
599
+ if split not in ('train', 'valid', 'test_fold_holdout',
600
+ 'test_family_holdout', 'test_superfamily_holdout'):
601
+ raise ValueError(f"Unrecognized split: {split}. Must be one of "
602
+ f"['train', 'valid', 'test_fold_holdout', "
603
+ f"'test_family_holdout', 'test_superfamily_holdout']")
604
+ if isinstance(tokenizer, str):
605
+ tokenizer = TAPETokenizer(vocab=tokenizer)
606
+ self.tokenizer = tokenizer
607
+
608
+ data_path = Path(data_path)
609
+ data_file = f'remote_homology/remote_homology_{split}.lmdb'
610
+ self.data = dataset_factory(data_path / data_file, in_memory)
611
+
612
+ def __len__(self) -> int:
613
+ return len(self.data)
614
+
615
+ def __getitem__(self, index: int):
616
+ item = self.data[index]
617
+ token_ids = self.tokenizer.encode(item['primary'])
618
+ input_mask = np.ones_like(token_ids)
619
+ return token_ids, input_mask, item['fold_label']
620
+
621
+ def collate_fn(self, batch: List[Tuple[Any, ...]]) -> Dict[str, torch.Tensor]:
622
+ input_ids, input_mask, fold_label = tuple(zip(*batch))
623
+ input_ids = torch.from_numpy(pad_sequences(input_ids, 0))
624
+ input_mask = torch.from_numpy(pad_sequences(input_mask, 0))
625
+ fold_label = torch.LongTensor(fold_label) # type: ignore
626
+
627
+ return {'input_ids': input_ids,
628
+ 'input_mask': input_mask,
629
+ 'targets': fold_label}
630
+
631
+
632
+ @registry.register_task('contact_prediction')
633
+ class ProteinnetDataset(Dataset):
634
+
635
+ def __init__(self,
636
+ data_path: Union[str, Path],
637
+ split: str,
638
+ tokenizer: Union[str, TAPETokenizer] = 'iupac',
639
+ in_memory: bool = False):
640
+
641
+ if split not in ('train', 'train_unfiltered', 'valid', 'test'):
642
+ raise ValueError(f"Unrecognized split: {split}. Must be one of "
643
+ f"['train', 'train_unfiltered', 'valid', 'test']")
644
+
645
+ if isinstance(tokenizer, str):
646
+ tokenizer = TAPETokenizer(vocab=tokenizer)
647
+ self.tokenizer = tokenizer
648
+
649
+ data_path = Path(data_path)
650
+ data_file = f'proteinnet/proteinnet_{split}.lmdb'
651
+ self.data = dataset_factory(data_path / data_file, in_memory)
652
+
653
+ def __len__(self) -> int:
654
+ return len(self.data)
655
+
656
+ def __getitem__(self, index: int):
657
+ item = self.data[index]
658
+ protein_length = len(item['primary'])
659
+ token_ids = self.tokenizer.encode(item['primary'])
660
+ input_mask = np.ones_like(token_ids)
661
+
662
+ valid_mask = item['valid_mask']
663
+ contact_map = np.less(squareform(pdist(item['tertiary'])), 8.0).astype(np.int64)
664
+
665
+ yind, xind = np.indices(contact_map.shape)
666
+ invalid_mask = ~(valid_mask[:, None] & valid_mask[None, :])
667
+ invalid_mask |= np.abs(yind - xind) < 6
668
+ contact_map[invalid_mask] = -1
669
+
670
+ return token_ids, input_mask, contact_map, protein_length
671
+
672
+ def collate_fn(self, batch: List[Tuple[Any, ...]]) -> Dict[str, torch.Tensor]:
673
+ input_ids, input_mask, contact_labels, protein_length = tuple(zip(*batch))
674
+ input_ids = torch.from_numpy(pad_sequences(input_ids, 0))
675
+ input_mask = torch.from_numpy(pad_sequences(input_mask, 0))
676
+ contact_labels = torch.from_numpy(pad_sequences(contact_labels, -1))
677
+ protein_length = torch.LongTensor(protein_length) # type: ignore
678
+
679
+ return {'input_ids': input_ids,
680
+ 'input_mask': input_mask,
681
+ 'targets': contact_labels,
682
+ 'protein_length': protein_length}
683
+
684
+
685
+ @registry.register_task('secondary_structure', num_labels=3)
686
+ class SecondaryStructureDataset(Dataset):
687
+
688
+ def __init__(self,
689
+ data_path: Union[str, Path],
690
+ split: str,
691
+ tokenizer: Union[str, TAPETokenizer] = 'iupac',
692
+ in_memory: bool = False):
693
+
694
+ if split not in ('train', 'valid', 'casp12', 'ts115', 'cb513'):
695
+ raise ValueError(f"Unrecognized split: {split}. Must be one of "
696
+ f"['train', 'valid', 'casp12', "
697
+ f"'ts115', 'cb513']")
698
+ if isinstance(tokenizer, str):
699
+ tokenizer = TAPETokenizer(vocab=tokenizer)
700
+ self.tokenizer = tokenizer
701
+
702
+ data_path = Path(data_path)
703
+ data_file = f'secondary_structure/secondary_structure_{split}.lmdb'
704
+ self.data = dataset_factory(data_path / data_file, in_memory)
705
+
706
+ def __len__(self) -> int:
707
+ return len(self.data)
708
+
709
+ def __getitem__(self, index: int):
710
+ item = self.data[index]
711
+ token_ids = self.tokenizer.encode(item['primary'])
712
+ input_mask = np.ones_like(token_ids)
713
+
714
+ # pad with -1s because of cls/sep tokens
715
+ labels = np.asarray(item['ss3'], np.int64)
716
+ labels = np.pad(labels, (1, 1), 'constant', constant_values=-1)
717
+
718
+ return token_ids, input_mask, labels
719
+
720
+ def collate_fn(self, batch: List[Tuple[Any, ...]]) -> Dict[str, torch.Tensor]:
721
+ input_ids, input_mask, ss_label = tuple(zip(*batch))
722
+ input_ids = torch.from_numpy(pad_sequences(input_ids, 0))
723
+ input_mask = torch.from_numpy(pad_sequences(input_mask, 0))
724
+ ss_label = torch.from_numpy(pad_sequences(ss_label, -1))
725
+
726
+ output = {'input_ids': input_ids,
727
+ 'input_mask': input_mask,
728
+ 'targets': ss_label}
729
+
730
+ return output
731
+
732
+
733
+ @registry.register_task('trrosetta')
734
+ class TRRosettaDataset(Dataset):
735
+
736
+ def __init__(self,
737
+ data_path: Union[str, Path],
738
+ split: str,
739
+ tokenizer: Union[str, TAPETokenizer] = 'iupac',
740
+ in_memory: bool = False,
741
+ max_seqlen: int = 300):
742
+ if split not in ('train', 'valid'):
743
+ raise ValueError(
744
+ f"Unrecognized split: {split}. "
745
+ f"Must be one of ['train', 'valid']")
746
+ if isinstance(tokenizer, str):
747
+ tokenizer = TAPETokenizer(vocab=tokenizer)
748
+ self.tokenizer = tokenizer
749
+
750
+ data_path = Path(data_path)
751
+ data_path = data_path / 'trrosetta'
752
+ split_files = (data_path / f'{split}_files.txt').read_text().split()
753
+ self.data = NPZDataset(data_path / 'npz', in_memory, split_files=split_files)
754
+
755
+ self._dist_bins = np.arange(2, 20.1, 0.5)
756
+ self._dihedral_bins = (15 + np.arange(-180, 180, 15)) / 180 * np.pi
757
+ self._planar_bins = (15 + np.arange(0, 180, 15)) / 180 * np.pi
758
+ self._split = split
759
+ self.max_seqlen = max_seqlen
760
+ self.msa_cutoff = 0.8
761
+ self.penalty_coeff = 4.5
762
+
763
+ def __len__(self) -> int:
764
+ return len(self.data)
765
+
766
+ def __getitem__(self, index):
767
+ item = self.data[index]
768
+
769
+ msa = item['msa']
770
+ dist = item['dist6d']
771
+ omega = item['omega6d']
772
+ theta = item['theta6d']
773
+ phi = item['phi6d']
774
+
775
+ if self._split == 'train':
776
+ msa = self._subsample_msa(msa)
777
+ elif self._split == 'valid':
778
+ msa = msa[:20000] # runs out of memory if msa is way too big
779
+ msa, dist, omega, theta, phi = self._slice_long_sequences(
780
+ msa, dist, omega, theta, phi)
781
+
782
+ mask = dist == 0
783
+
784
+ dist_bins = np.digitize(dist, self._dist_bins)
785
+ omega_bins = np.digitize(omega, self._dihedral_bins) + 1
786
+ theta_bins = np.digitize(theta, self._dihedral_bins) + 1
787
+ phi_bins = np.digitize(phi, self._planar_bins) + 1
788
+
789
+ dist_bins[mask] = 0
790
+ omega_bins[mask] = 0
791
+ theta_bins[mask] = 0
792
+ phi_bins[mask] = 0
793
+
794
+ dist_bins[np.diag_indices_from(dist_bins)] = -1
795
+
796
+ # input_mask = np.ones_like(msa[0])
797
+
798
+ return msa, dist_bins, omega_bins, theta_bins, phi_bins
799
+
800
+ def _slice_long_sequences(self, msa, dist, omega, theta, phi):
801
+ seqlen = msa.shape[1]
802
+ if self.max_seqlen > 0 and seqlen > self.max_seqlen:
803
+ start = np.random.randint(seqlen - self.max_seqlen + 1)
804
+ end = start + self.max_seqlen
805
+
806
+ msa = msa[:, start:end]
807
+ dist = dist[start:end, start:end]
808
+ omega = omega[start:end, start:end]
809
+ theta = theta[start:end, start:end]
810
+ phi = phi[start:end, start:end]
811
+
812
+ return msa, dist, omega, theta, phi
813
+
814
+ def _subsample_msa(self, msa):
815
+ num_alignments, seqlen = msa.shape
816
+
817
+ if num_alignments < 10:
818
+ return msa
819
+
820
+ num_sample = int(10 ** np.random.uniform(np.log10(num_alignments)) - 10)
821
+
822
+ if num_sample <= 0:
823
+ return msa[0][None, :]
824
+ elif num_sample > 20000:
825
+ num_sample = 20000
826
+
827
+ indices = np.random.choice(
828
+ msa.shape[0] - 1, size=num_sample, replace=False) + 1
829
+ indices = np.pad(indices, [1, 0], 'constant') # add the sequence back in
830
+ return msa[indices]
831
+
832
+ def collate_fn(self, batch):
833
+ msa, dist_bins, omega_bins, theta_bins, phi_bins = tuple(zip(*batch))
834
+ # features = pad_sequences([self.featurize(msa_) for msa_ in msa], 0)
835
+ msa1hot = pad_sequences(
836
+ [F.one_hot(torch.LongTensor(msa_), 21) for msa_ in msa], 0, torch.float)
837
+ # input_mask = torch.FloatTensor(pad_sequences(input_mask, 0))
838
+ dist_bins = torch.LongTensor(pad_sequences(dist_bins, -1))
839
+ omega_bins = torch.LongTensor(pad_sequences(omega_bins, 0))
840
+ theta_bins = torch.LongTensor(pad_sequences(theta_bins, 0))
841
+ phi_bins = torch.LongTensor(pad_sequences(phi_bins, 0))
842
+
843
+ return {'msa1hot': msa1hot,
844
+ # 'input_mask': input_mask,
845
+ 'dist': dist_bins,
846
+ 'omega': omega_bins,
847
+ 'theta': theta_bins,
848
+ 'phi': phi_bins}
849
+
850
+ def featurize(self, msa):
851
+ msa = torch.LongTensor(msa)
852
+ msa1hot = F.one_hot(msa, 21).float()
853
+
854
+ seqlen = msa1hot.size(1)
855
+
856
+ weights = self.reweight(msa1hot)
857
+ features_1d = self.extract_features_1d(msa1hot, weights)
858
+ features_2d = self.extract_features_2d(msa1hot, weights)
859
+
860
+ features = torch.cat((
861
+ features_1d.unsqueeze(1).repeat(1, seqlen, 1),
862
+ features_1d.unsqueeze(0).repeat(seqlen, 1, 1),
863
+ features_2d), -1)
864
+
865
+ features = features.permute(2, 0, 1)
866
+
867
+ return features
868
+
869
+ def reweight(self, msa1hot):
870
+ # Reweight
871
+ seqlen = msa1hot.size(1)
872
+ id_min = seqlen * self.msa_cutoff
873
+ id_mtx = torch.tensordot(msa1hot, msa1hot, [[1, 2], [1, 2]])
874
+ id_mask = id_mtx > id_min
875
+ weights = 1.0 / id_mask.float().sum(-1)
876
+ return weights
877
+
878
+ def extract_features_1d(self, msa1hot, weights):
879
+ # 1D Features
880
+ seqlen = msa1hot.size(1)
881
+ f1d_seq = msa1hot[0, :, :20]
882
+
883
+ # msa2pssm
884
+ beff = weights.sum()
885
+ f_i = (weights[:, None, None] * msa1hot).sum(0) / beff + 1e-9
886
+ h_i = (-f_i * f_i.log()).sum(1, keepdims=True)
887
+ f1d_pssm = torch.cat((f_i, h_i), dim=1)
888
+
889
+ f1d = torch.cat((f1d_seq, f1d_pssm), dim=1)
890
+ f1d = f1d.view(seqlen, 42)
891
+ return f1d
892
+
893
+ def extract_features_2d(self, msa1hot, weights):
894
+ # 2D Features
895
+ num_alignments = msa1hot.size(0)
896
+ seqlen = msa1hot.size(1)
897
+ num_symbols = 21
898
+ if num_alignments == 1:
899
+ # No alignments, predict from sequence alone
900
+ f2d_dca = torch.zeros(seqlen, seqlen, 442, dtype=torch.float)
901
+ else:
902
+ # fast_dca
903
+
904
+ # covariance
905
+ x = msa1hot.view(num_alignments, seqlen * num_symbols)
906
+ num_points = weights.sum() - weights.mean().sqrt()
907
+ mean = (x * weights[:, None]).sum(0, keepdims=True) / num_points
908
+ x = (x - mean) * weights[:, None].sqrt()
909
+ cov = torch.matmul(x.transpose(-1, -2), x) / num_points
910
+
911
+ # inverse covariance
912
+ reg = torch.eye(seqlen * num_symbols) * self.penalty_coeff / weights.sum().sqrt()
913
+ cov_reg = cov + reg
914
+ inv_cov = torch.inverse(cov_reg)
915
+
916
+ x1 = inv_cov.view(seqlen, num_symbols, seqlen, num_symbols)
917
+ x2 = x1.permute(0, 2, 1, 3)
918
+ features = x2.reshape(seqlen, seqlen, num_symbols * num_symbols)
919
+
920
+ x3 = (x1[:, :-1, :, :-1] ** 2).sum((1, 3)).sqrt() * (1 - torch.eye(seqlen))
921
+ apc = x3.sum(0, keepdims=True) * x3.sum(1, keepdims=True) / x3.sum()
922
+ contacts = (x3 - apc) * (1 - torch.eye(seqlen))
923
+
924
+ f2d_dca = torch.cat([features, contacts[:, :, None]], axis=2)
925
+
926
+ return f2d_dca
tape/errors.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ class EarlyStopping(Exception):
2
+ """Raised when stopping training b/c no improvement in validation loss"""
3
+ pass
tape/main.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import typing
2
+ import os
3
+ import logging
4
+ import argparse
5
+ import warnings
6
+ import inspect
7
+
8
+
9
+ try:
10
+ import apex # noqa: F401
11
+ APEX_FOUND = True
12
+ except ImportError:
13
+ APEX_FOUND = False
14
+
15
+ from .registry import registry
16
+ from . import training
17
+ from . import utils
18
+
19
+ CallbackList = typing.Sequence[typing.Callable]
20
+ OutputDict = typing.Dict[str, typing.List[typing.Any]]
21
+
22
+
23
+ logger = logging.getLogger(__name__)
24
+ warnings.filterwarnings( # Ignore pytorch warning about loss gathering
25
+ 'ignore', message='Was asked to gather along dimension 0', module='torch.nn.parallel')
26
+
27
+
28
+ def create_base_parser() -> argparse.ArgumentParser:
29
+ parser = argparse.ArgumentParser(description='Parent parser for tape functions',
30
+ add_help=False)
31
+ parser.add_argument('model_type', help='Base model class to run')
32
+ parser.add_argument('--model_config_file', default=None, type=utils.check_is_file,
33
+ help='Config file for model')
34
+ parser.add_argument('--vocab_file', default=None,
35
+ help='Pretrained tokenizer vocab file')
36
+ parser.add_argument('--output_dir', default='./results', type=str)
37
+ parser.add_argument('--no_cuda', action='store_true', help='CPU-only flag')
38
+ parser.add_argument('--seed', default=42, type=int, help='Random seed to use')
39
+ parser.add_argument('--local_rank', type=int, default=-1,
40
+ help='Local rank of process in distributed training. '
41
+ 'Set by launch script.')
42
+ parser.add_argument('--tokenizer', choices=['iupac', 'unirep'],
43
+ default='iupac', help='Tokenizes to use on the amino acid sequences')
44
+ parser.add_argument('--num_workers', default=8, type=int,
45
+ help='Number of workers to use for multi-threaded data loading')
46
+ parser.add_argument('--log_level', default=logging.INFO,
47
+ choices=['DEBUG', 'INFO', 'WARN', 'WARNING', 'ERROR',
48
+ logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR],
49
+ help="log level for the experiment")
50
+ parser.add_argument('--debug', action='store_true', help='Run in debug mode')
51
+
52
+ return parser
53
+
54
+
55
+ def create_train_parser(base_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
56
+ parser = argparse.ArgumentParser(description='Run Training on the TAPE datasets',
57
+ parents=[base_parser])
58
+ parser.add_argument('task', choices=list(registry.task_name_mapping.keys()),
59
+ help='TAPE Task to train/eval on')
60
+ parser.add_argument('--learning_rate', default=1e-4, type=float,
61
+ help='Learning rate')
62
+ parser.add_argument('--batch_size', default=1024, type=int,
63
+ help='Batch size')
64
+ parser.add_argument('--data_dir', default='./data', type=utils.check_is_dir,
65
+ help='Directory from which to load task data')
66
+ parser.add_argument('--num_train_epochs', default=10, type=int,
67
+ help='Number of training epochs')
68
+ parser.add_argument('--num_steps_per_epoch', default=-1, type=int,
69
+ help='Number of steps per epoch')
70
+ parser.add_argument('--num_log_iter', default=20, type=int,
71
+ help='Number of training steps per log iteration')
72
+ parser.add_argument('--fp16', action='store_true', help='Whether to use fp16 weights')
73
+ parser.add_argument('--warmup_steps', default=10000, type=int,
74
+ help='Number of learning rate warmup steps')
75
+ parser.add_argument('--gradient_accumulation_steps', default=1, type=int,
76
+ help='Number of forward passes to make for each backwards pass')
77
+ parser.add_argument('--loss_scale', default=0, type=int,
78
+ help='Loss scaling. Only used during fp16 training.')
79
+ parser.add_argument('--max_grad_norm', default=1.0, type=float,
80
+ help='Maximum gradient norm')
81
+ parser.add_argument('--exp_name', default=None, type=str,
82
+ help='Name to give to this experiment')
83
+ parser.add_argument('--from_pretrained', default=None, type=str,
84
+ help='Directory containing config and pretrained model weights')
85
+ parser.add_argument('--log_dir', default='./logs', type=str)
86
+ parser.add_argument('--eval_freq', type=int, default=1,
87
+ help="Frequency of eval pass. A value <= 0 means the eval pass is "
88
+ "not run")
89
+ parser.add_argument('--save_freq', default='improvement', type=utils.int_or_str,
90
+ help="How often to save the model during training. Either an integer "
91
+ "frequency or the string 'improvement'")
92
+ parser.add_argument('--patience', default=-1, type=int,
93
+ help="How many epochs without improvement to wait before ending "
94
+ "training")
95
+ parser.add_argument('--resume_from_checkpoint', action='store_true',
96
+ help="whether to resume training from the checkpoint")
97
+ parser.add_argument('--val_check_frac', default=1.0, type=float,
98
+ help="Fraction of validation to check")
99
+ return parser
100
+
101
+
102
+ def create_eval_parser(base_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
103
+ parser = argparse.ArgumentParser(description='Run Eval on the TAPE Datasets',
104
+ parents=[base_parser])
105
+ parser.add_argument('task', choices=list(registry.task_name_mapping.keys()),
106
+ help='TAPE Task to train/eval on')
107
+ parser.add_argument('from_pretrained', type=str,
108
+ help='Directory containing config and pretrained model weights')
109
+ parser.add_argument('--batch_size', default=1024, type=int,
110
+ help='Batch size')
111
+ parser.add_argument('--data_dir', default='./data', type=utils.check_is_dir,
112
+ help='Directory from which to load task data')
113
+ parser.add_argument('--metrics', default=[],
114
+ help=f'Metrics to run on the result. '
115
+ f'Choices: {list(registry.metric_name_mapping.keys())}',
116
+ nargs='*')
117
+ parser.add_argument('--split', default='test', type=str,
118
+ help='Which split to run on')
119
+ return parser
120
+
121
+
122
+ def create_embed_parser(base_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
123
+ parser = argparse.ArgumentParser(
124
+ description='Embed a set of proteins with a pretrained model',
125
+ parents=[base_parser])
126
+ parser.add_argument('data_file', type=str,
127
+ help='File containing set of proteins to embed')
128
+ parser.add_argument('out_file', type=str,
129
+ help='Name of output file')
130
+ parser.add_argument('from_pretrained', type=str,
131
+ help='Directory containing config and pretrained model weights')
132
+ parser.add_argument('--batch_size', default=1024, type=int,
133
+ help='Batch size')
134
+ parser.add_argument('--full_sequence_embed', action='store_true',
135
+ help='If true, saves an embedding at every amino acid position '
136
+ 'in the sequence. Note that this can take a large amount '
137
+ 'of disk space.')
138
+ parser.set_defaults(task='embed')
139
+ return parser
140
+
141
+
142
+ def create_distributed_parser(base_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
143
+ parser = argparse.ArgumentParser(add_help=False, parents=[base_parser])
144
+ # typing.Optional arguments for the launch helper
145
+ parser.add_argument("--nnodes", type=int, default=1,
146
+ help="The number of nodes to use for distributed "
147
+ "training")
148
+ parser.add_argument("--node_rank", type=int, default=0,
149
+ help="The rank of the node for multi-node distributed "
150
+ "training")
151
+ parser.add_argument("--nproc_per_node", type=int, default=1,
152
+ help="The number of processes to launch on each node, "
153
+ "for GPU training, this is recommended to be set "
154
+ "to the number of GPUs in your system so that "
155
+ "each process can be bound to a single GPU.")
156
+ parser.add_argument("--master_addr", default="127.0.0.1", type=str,
157
+ help="Master node (rank 0)'s address, should be either "
158
+ "the IP address or the hostname of node 0, for "
159
+ "single node multi-proc training, the "
160
+ "--master_addr can simply be 127.0.0.1")
161
+ parser.add_argument("--master_port", default=29500, type=int,
162
+ help="Master node (rank 0)'s free port that needs to "
163
+ "be used for communciation during distributed "
164
+ "training")
165
+ return parser
166
+
167
+
168
+ def create_model_parser(base_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
169
+ parser = argparse.ArgumentParser(add_help=False, parents=[base_parser])
170
+ parser.add_argument('--model_args', nargs=argparse.REMAINDER, default=None)
171
+ return parser
172
+
173
+ def run_train(args: typing.Optional[argparse.Namespace] = None, env=None) -> None:
174
+ if env is not None:
175
+ os.environ = env
176
+
177
+ if args is None:
178
+ base_parser = create_base_parser()
179
+ train_parser = create_train_parser(base_parser)
180
+ model_parser = create_model_parser(train_parser)
181
+ args = model_parser.parse_args()
182
+
183
+ if args.gradient_accumulation_steps < 1:
184
+ raise ValueError(
185
+ f"Invalid gradient_accumulation_steps parameter: "
186
+ f"{args.gradient_accumulation_steps}, should be >= 1")
187
+
188
+ if (args.fp16 or args.local_rank != -1) and not APEX_FOUND:
189
+ raise ImportError(
190
+ "Please install apex from https://www.github.com/nvidia/apex "
191
+ "to use distributed and fp16 training.")
192
+
193
+ arg_dict = vars(args)
194
+ arg_names = inspect.getfullargspec(training.run_train).args
195
+
196
+ missing = set(arg_names) - set(arg_dict.keys())
197
+ if missing:
198
+ raise RuntimeError(f"Missing arguments: {missing}")
199
+ train_args = {name: arg_dict[name] for name in arg_names}
200
+
201
+ training.run_train(**train_args)
202
+
203
+
204
+ def run_eval(args: typing.Optional[argparse.Namespace] = None) -> typing.Dict[str, float]:
205
+ if args is None:
206
+ base_parser = create_base_parser()
207
+ parser = create_eval_parser(base_parser)
208
+ parser = create_model_parser(parser)
209
+ args = parser.parse_args()
210
+
211
+ if args.from_pretrained is None:
212
+ raise ValueError("Must specify pretrained model")
213
+ if args.local_rank != -1:
214
+ raise ValueError("TAPE does not support distributed validation pass")
215
+
216
+ arg_dict = vars(args)
217
+ arg_names = inspect.getfullargspec(training.run_eval).args
218
+
219
+ missing = set(arg_names) - set(arg_dict.keys())
220
+ if missing:
221
+ raise RuntimeError(f"Missing arguments: {missing}")
222
+ eval_args = {name: arg_dict[name] for name in arg_names}
223
+
224
+ return training.run_eval(**eval_args)
225
+
226
+
227
+ def run_embed(args: typing.Optional[argparse.Namespace] = None) -> None:
228
+ if args is None:
229
+ base_parser = create_base_parser()
230
+ parser = create_embed_parser(base_parser)
231
+ parser = create_model_parser(parser)
232
+ args = parser.parse_args()
233
+ if args.from_pretrained is None:
234
+ raise ValueError("Must specify pretrained model")
235
+ if args.local_rank != -1:
236
+ raise ValueError("TAPE does not support distributed validation pass")
237
+
238
+ arg_dict = vars(args)
239
+ arg_names = inspect.getfullargspec(training.run_embed).args
240
+
241
+ missing = set(arg_names) - set(arg_dict.keys())
242
+ if missing:
243
+ raise RuntimeError(f"Missing arguments: {missing}")
244
+ embed_args = {name: arg_dict[name] for name in arg_names}
245
+
246
+ training.run_embed(**embed_args)
247
+
248
+
249
+ def run_train_distributed(args: typing.Optional[argparse.Namespace] = None) -> None:
250
+ """Runs distributed training via multiprocessing.
251
+ """
252
+ if args is None:
253
+ base_parser = create_base_parser()
254
+ distributed_parser = create_distributed_parser(base_parser)
255
+ distributed_train_parser = create_train_parser(distributed_parser)
256
+ parser = create_model_parser(distributed_train_parser)
257
+ args = parser.parse_args()
258
+
259
+ # Define the experiment name here, instead of dealing with barriers and communication
260
+ # when getting the experiment name
261
+ exp_name = utils.get_expname(args.exp_name, args.task, args.model_type)
262
+ args.exp_name = exp_name
263
+ utils.launch_process_group(
264
+ run_train, args, args.nproc_per_node, args.nnodes,
265
+ args.node_rank, args.master_addr, args.master_port)
266
+
267
+
268
+ if __name__ == '__main__':
269
+ run_train_distributed()
tape/metrics.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Sequence, Union
2
+ import numpy as np
3
+ import scipy.stats
4
+
5
+ from .registry import registry
6
+
7
+
8
+ @registry.register_metric('mse')
9
+ def mean_squared_error(target: Sequence[float],
10
+ prediction: Sequence[float]) -> float:
11
+ target_array = np.asarray(target)
12
+ prediction_array = np.asarray(prediction)
13
+ return np.mean(np.square(target_array - prediction_array))
14
+
15
+
16
+ @registry.register_metric('mae')
17
+ def mean_absolute_error(target: Sequence[float],
18
+ prediction: Sequence[float]) -> float:
19
+ target_array = np.asarray(target)
20
+ prediction_array = np.asarray(prediction)
21
+ return np.mean(np.abs(target_array - prediction_array))
22
+
23
+
24
+ @registry.register_metric('spearmanr')
25
+ def spearmanr(target: Sequence[float],
26
+ prediction: Sequence[float]) -> float:
27
+ target_array = np.asarray(target)
28
+ prediction_array = np.asarray(prediction)
29
+ return scipy.stats.spearmanr(target_array, prediction_array).correlation
30
+
31
+
32
+ @registry.register_metric('accuracy')
33
+ def accuracy(target: Union[Sequence[int], Sequence[Sequence[int]]],
34
+ prediction: Union[Sequence[float], Sequence[Sequence[float]]]) -> float:
35
+ if isinstance(target[0], int):
36
+ # non-sequence case
37
+ return np.mean(np.asarray(target) == np.asarray(prediction).argmax(-1))
38
+ else:
39
+ correct = 0
40
+ total = 0
41
+ for label, score in zip(target, prediction):
42
+ label_array = np.asarray(label)
43
+ pred_array = np.asarray(score).argmax(-1)
44
+ mask = label_array != -1
45
+ is_correct = label_array[mask] == pred_array[mask]
46
+ correct += is_correct.sum()
47
+ total += is_correct.size
48
+ return correct / total
tape/models/__init__.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from .modeling_utils import ProteinConfig # noqa: F401
2
+ # from .modeling_utils import ProteinModel # noqa: F401
3
+
4
+ # from .modeling_bert import ProteinBertModel # noqa: F401
5
+ # from .modeling_bert import ProteinBertForMaskedLM # noqa: F401
6
+ # from .modeling_bert import ProteinBertForValuePrediction # noqa: F401
7
+ # from .modeling_bert import ProteinBertForSequenceClassification # noqa: F401
8
+ # from .modeling_bert import ProteinBertForSequenceToSequenceClassification # noqa: F401
9
+ # # TODO: ProteinBertForContactPrediction
10
+ # from .modeling_resnet import ProteinResNetModel # noqa: F401
11
+ # from .modeling_resnet import ProteinResNetForMaskedLM # noqa: F401
12
+ # from .modeling_resnet import ProteinResNetForValuePrediction # noqa: F401
13
+ # from .modeling_resnet import ProteinResNetForSequenceClassification # noqa: F401
14
+ # from .modeling_resnet import ProteinResNetForSequenceToSequenceClassification # noqa: F401
15
+ # # TODO: ProteinResNetForContactPrediction
16
+ # # TODO: ProteinLSTM*
17
+ # from .modeling_unirep import UniRepModel # noqa: F401
18
+ # from .modeling_unirep import UniRepForLM # noqa: F401
19
+ # from .modeling_unirep import UniRepForValuePrediction # noqa: F401
20
+ # from .modeling_unirep import UniRepForSequenceClassification # noqa: F401
21
+ # from .modeling_unirep import UniRepForSequenceToSequenceClassification # noqa: F401
22
+ # # TODO: UniRepForContactPrediction
23
+ # # TODO: Bepler*
24
+ # from .modeling_onehot import OneHotModel # noqa: F401
25
+ # from .modeling_onehot import OneHotForValuePrediction # noqa: F401
26
+ # from .modeling_onehot import OneHotForSequenceClassification # noqa: F401
27
+ # from .modeling_onehot import OneHotForSequenceToSequenceClassification # noqa: F401
28
+ # TODO: OneHotForContactPrediction
tape/models/file_utils.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utilities for working with the local dataset cache.
3
+ This file is adapted from the huggingface transformers library at
4
+ https://github.com/huggingface/transformers, which in turn is adapted from the AllenNLP
5
+ library at https://github.com/allenai/allennlp
6
+ Copyright by the AllenNLP authors.
7
+ Note - this file goes to effort to support Python 2, but the rest of this repository does not.
8
+ """
9
+ from __future__ import (absolute_import, division, print_function, unicode_literals)
10
+
11
+ import typing
12
+ import sys
13
+ import json
14
+ import logging
15
+ import os
16
+ import tempfile
17
+ import fnmatch
18
+ from io import open
19
+
20
+ import boto3
21
+ import requests
22
+ from botocore.exceptions import ClientError
23
+ from tqdm import tqdm
24
+
25
+ from contextlib import contextmanager
26
+ from functools import partial, wraps
27
+ from hashlib import sha256
28
+
29
+ from filelock import FileLock
30
+ # from tqdm.auto import tqdm
31
+
32
+ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
33
+
34
+
35
+ try:
36
+ from torch.hub import _get_torch_home
37
+ torch_cache_home = _get_torch_home()
38
+ except ImportError:
39
+ torch_cache_home = os.path.expanduser(
40
+ os.getenv('TORCH_HOME', os.path.join(
41
+ os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')))
42
+ default_cache_path = os.path.join(torch_cache_home, 'protein_models')
43
+
44
+ try:
45
+ from urllib.parse import urlparse
46
+ except ImportError:
47
+ from urlparse import urlparse # type: ignore
48
+
49
+ try:
50
+ from pathlib import Path
51
+ PYTORCH_PRETRAINED_BERT_CACHE: typing.Union[str, Path] = Path(
52
+ os.getenv('PROTEIN_MODELS_CACHE', os.getenv(
53
+ 'PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path)))
54
+ except (AttributeError, ImportError):
55
+ PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PROTEIN_MODELS_CACHE',
56
+ os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
57
+ default_cache_path))
58
+
59
+ PROTEIN_MODELS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
60
+
61
+ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
62
+
63
+
64
+ def get_cache():
65
+ return PROTEIN_MODELS_CACHE
66
+
67
+
68
+ def get_etag(url):
69
+ # Get eTag to add to filename, if it exists.
70
+ if url.startswith("s3://"):
71
+ etag = s3_etag(url)
72
+ else:
73
+ try:
74
+ response = requests.head(url, allow_redirects=True)
75
+ if response.status_code != 200:
76
+ etag = None
77
+ else:
78
+ etag = response.headers.get("ETag")
79
+ except EnvironmentError:
80
+ etag = None
81
+
82
+ if sys.version_info[0] == 2 and etag is not None:
83
+ etag = etag.decode('utf-8')
84
+
85
+ return etag
86
+
87
+
88
+ def url_to_filename(url, etag=None):
89
+ """
90
+ Convert `url` into a hashed filename in a repeatable way.
91
+ If `etag` is specified, append its hash to the url's, delimited
92
+ by a period.
93
+ """
94
+ url_bytes = url.encode('utf-8')
95
+ url_hash = sha256(url_bytes)
96
+ filename = url_hash.hexdigest()
97
+
98
+ if etag:
99
+ etag_bytes = etag.encode('utf-8')
100
+ etag_hash = sha256(etag_bytes)
101
+ filename += '.' + etag_hash.hexdigest()
102
+
103
+ return filename
104
+
105
+
106
+ def filename_to_url(filename, cache_dir=None):
107
+ """
108
+ Return the url and etag (which may be ``None``) stored for `filename`.
109
+ Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
110
+ """
111
+ if cache_dir is None:
112
+ cache_dir = PROTEIN_MODELS_CACHE
113
+ if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
114
+ cache_dir = str(cache_dir)
115
+
116
+ cache_path = os.path.join(cache_dir, filename)
117
+ if not os.path.exists(cache_path):
118
+ raise EnvironmentError("file {} not found".format(cache_path))
119
+
120
+ meta_path = cache_path + '.json'
121
+ if not os.path.exists(meta_path):
122
+ raise EnvironmentError("file {} not found".format(meta_path))
123
+
124
+ with open(meta_path, encoding="utf-8") as meta_file:
125
+ metadata = json.load(meta_file)
126
+ url = metadata['url']
127
+ etag = metadata['etag']
128
+
129
+ return url, etag
130
+
131
+
132
+ def cached_path(url_or_filename, force_download=False, cache_dir=None):
133
+ """
134
+ Given something that might be a URL (or might be a local path),
135
+ determine which. If it's a URL, download the file and cache it, and
136
+ return the path to the cached file. If it's already a local path,
137
+ make sure the file exists and then return the path.
138
+
139
+ Args:
140
+ cache_dir: specify a cache directory to save the file to
141
+ (overwrite the default cache dir).
142
+ force_download: if True, re-dowload the file even if it's
143
+ already cached in the cache dir.
144
+ """
145
+ if cache_dir is None:
146
+ cache_dir = PROTEIN_MODELS_CACHE
147
+ if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
148
+ url_or_filename = str(url_or_filename)
149
+ if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
150
+ cache_dir = str(cache_dir)
151
+
152
+ parsed = urlparse(url_or_filename)
153
+
154
+ if parsed.scheme in ('http', 'https', 's3'):
155
+ # URL, so get it from the cache (downloading if necessary)
156
+ output_path = get_from_cache(url_or_filename, cache_dir, force_download)
157
+ elif os.path.exists(url_or_filename):
158
+ # File, and it exists.
159
+ output_path = url_or_filename
160
+ elif parsed.scheme == '':
161
+ # File, but it doesn't exist.
162
+ raise EnvironmentError("file {} not found".format(url_or_filename))
163
+ else:
164
+ # Something unknown
165
+ raise ValueError("unable to parse {} as a URL or as a local path".format(
166
+ url_or_filename))
167
+
168
+ return output_path
169
+
170
+
171
+ def split_s3_path(url):
172
+ """Split a full s3 path into the bucket name and path."""
173
+ parsed = urlparse(url)
174
+ if not parsed.netloc or not parsed.path:
175
+ raise ValueError("bad s3 path {}".format(url))
176
+ bucket_name = parsed.netloc
177
+ s3_path = parsed.path
178
+ # Remove '/' at beginning of path.
179
+ if s3_path.startswith("/"):
180
+ s3_path = s3_path[1:]
181
+ return bucket_name, s3_path
182
+
183
+
184
+ def s3_request(func):
185
+ """
186
+ Wrapper function for s3 requests in order to create more helpful error
187
+ messages.
188
+ """
189
+
190
+ @wraps(func)
191
+ def wrapper(url, *args, **kwargs):
192
+ try:
193
+ return func(url, *args, **kwargs)
194
+ except ClientError as exc:
195
+ if int(exc.response["Error"]["Code"]) == 404:
196
+ raise EnvironmentError("file {} not found".format(url))
197
+ else:
198
+ raise
199
+
200
+ return wrapper
201
+
202
+
203
+ @s3_request
204
+ def s3_etag(url):
205
+ """Check ETag on S3 object."""
206
+ s3_resource = boto3.resource("s3")
207
+ bucket_name, s3_path = split_s3_path(url)
208
+ s3_object = s3_resource.Object(bucket_name, s3_path)
209
+ return s3_object.e_tag
210
+
211
+
212
+ @s3_request
213
+ def s3_get(url, temp_file):
214
+ """Pull a file directly from S3."""
215
+ s3_resource = boto3.resource("s3")
216
+ bucket_name, s3_path = split_s3_path(url)
217
+ s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
218
+
219
+
220
+ def http_get(url, temp_file):
221
+ req = requests.get(url, stream=True)
222
+ content_length = req.headers.get('Content-Length')
223
+ total = int(content_length) if content_length is not None else None
224
+ progress = tqdm(unit="B", total=total)
225
+ for chunk in req.iter_content(chunk_size=1024):
226
+ if chunk: # filter out keep-alive new chunks
227
+ progress.update(len(chunk))
228
+ temp_file.write(chunk)
229
+ progress.close()
230
+
231
+
232
+ def get_from_cache(url, cache_dir=None, force_download=False, resume_download=False):
233
+ """
234
+ Given a URL, look for the corresponding dataset in the local cache.
235
+ If it's not there, download it. Then return the path to the cached file.
236
+ """
237
+ if cache_dir is None:
238
+ cache_dir = PROTEIN_MODELS_CACHE
239
+ if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
240
+ cache_dir = str(cache_dir)
241
+ if sys.version_info[0] == 2 and not isinstance(cache_dir, str):
242
+ cache_dir = str(cache_dir)
243
+
244
+ if not os.path.exists(cache_dir):
245
+ os.makedirs(cache_dir)
246
+
247
+ # Get eTag to add to filename, if it exists.
248
+ if url.startswith("s3://"):
249
+ etag = s3_etag(url)
250
+ else:
251
+ try:
252
+ response = requests.head(url, allow_redirects=True)
253
+ if response.status_code != 200:
254
+ etag = None
255
+ else:
256
+ etag = response.headers.get("ETag")
257
+ except EnvironmentError:
258
+ etag = None
259
+
260
+ if sys.version_info[0] == 2 and etag is not None:
261
+ etag = etag.decode('utf-8')
262
+ filename = url_to_filename(url, etag)
263
+
264
+ # get cache path to put the file
265
+ cache_path = os.path.join(cache_dir, filename)
266
+
267
+ if os.path.exists(cache_path) and etag is None:
268
+ return cache_path
269
+
270
+ # If we don't have a connection (etag is None) and can't identify the file
271
+ # try to get the last downloaded one
272
+ if not os.path.exists(cache_path) and etag is None:
273
+ matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*')
274
+ matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files))
275
+ if matching_files:
276
+ cache_path = os.path.join(cache_dir, matching_files[-1])
277
+
278
+ # From now on, etag is not None
279
+ if os.path.exists(cache_path) and not force_download:
280
+ return cache_path
281
+
282
+ # Prevent parallel downloads of the same file with a lock.
283
+ lock_path = cache_path + ".lock"
284
+ with FileLock(lock_path):
285
+
286
+ # If the download just completed while the lock was activated.
287
+ if os.path.exists(cache_path) and not force_download:
288
+ # Even if returning early like here, the lock will be released.
289
+ return cache_path
290
+
291
+ if resume_download:
292
+ incomplete_path = cache_path + ".incomplete"
293
+
294
+ @contextmanager
295
+ def _resumable_file_manager():
296
+ with open(incomplete_path, "a+b") as f:
297
+ yield f
298
+
299
+ temp_file_manager = _resumable_file_manager
300
+ else:
301
+ temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir,
302
+ delete=False)
303
+ # Download to temporary file, then copy to cache dir once finished.
304
+ # Otherwise you get corrupt cache entries if the download gets interrupted.
305
+ with temp_file_manager() as temp_file:
306
+ logger.info("%s not in cache or force_download=True, download to %s",
307
+ url, temp_file.name)
308
+
309
+ http_get(url, temp_file)
310
+
311
+ logger.info("storing %s in cache at %s", url, cache_path)
312
+ os.replace(temp_file.name, cache_path)
313
+
314
+ logger.info("creating metadata file for %s", cache_path)
315
+ meta = {"url": url, "etag": etag}
316
+ meta_path = cache_path + ".json"
317
+ with open(meta_path, "w") as meta_file:
318
+ json.dump(meta, meta_file)
319
+ '''
320
+ if not os.path.exists(cache_path):
321
+ # Download to temporary file, then copy to cache dir once finished.
322
+ # Otherwise you get corrupt cache entries if the download gets interrupted.
323
+ with tempfile.NamedTemporaryFile() as temp_file:
324
+ logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
325
+
326
+ # GET file object
327
+ if url.startswith("s3://"):
328
+ s3_get(url, temp_file)
329
+ else:
330
+ http_get(url, temp_file)
331
+
332
+ # we are copying the file before closing it, so flush to avoid truncation
333
+ temp_file.flush()
334
+ # shutil.copyfileobj() starts at the current position, so go to the start
335
+ temp_file.seek(0)
336
+
337
+ logger.info("copying %s to cache at %s", temp_file.name, cache_path)
338
+ with open(cache_path, 'wb') as cache_file:
339
+ shutil.copyfileobj(temp_file, cache_file)
340
+
341
+ logger.info("creating metadata file for %s", cache_path)
342
+ meta = {'url': url, 'etag': etag}
343
+ meta_path = cache_path + '.json'
344
+ with open(meta_path, 'w') as meta_file:
345
+ output_string = json.dumps(meta)
346
+ if sys.version_info[0] == 2 and isinstance(output_string, str):
347
+ # The beauty of python 2
348
+ output_string = unicode(output_string, 'utf-8') # noqa: F821
349
+ meta_file.write(output_string)
350
+
351
+ logger.info("removing temp file %s", temp_file.name)
352
+ '''
353
+ return cache_path
tape/models/modeling_autoencoder.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import typing
2
+ import logging
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from .modeling_utils import ProteinConfig
8
+ from .modeling_utils import ProteinModel
9
+ from .modeling_utils import get_activation_fn
10
+ from .modeling_utils import MLMHead
11
+ from .modeling_utils import LayerNorm
12
+ from .modeling_utils import ValuePredictionHead
13
+ from .modeling_utils import SequenceClassificationHead
14
+ from .modeling_utils import SequenceToSequenceClassificationHead
15
+ from .modeling_utils import PairwiseContactPredictionHead
16
+ from ..registry import registry
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP: typing.Dict[str, str] = {}
21
+ RESNET_PRETRAINED_MODEL_ARCHIVE_MAP: typing.Dict[str, str] = {}
22
+
23
+
24
+ class ProteinAEConfig(ProteinConfig):
25
+ pretrained_config_archive_map = RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP
26
+
27
+ def __init__(self,
28
+ vocab_size: int = 30,
29
+ hidden_size: int = 512,
30
+ num_hidden_layers: int = 30,
31
+ hidden_act: str = "gelu",
32
+ hidden_dropout_prob: float = 0.1,
33
+ initializer_range: float = 0.02,
34
+ layer_norm_eps: float = 1e-12,
35
+ temporal_pooling: str = 'attention',
36
+ freeze_embedding: bool = False,
37
+ max_size: int = 3000,
38
+ latent_size: int = 1024,
39
+ **kwargs):
40
+ super().__init__(**kwargs)
41
+ self.vocab_size = vocab_size
42
+ self.num_hidden_layers = num_hidden_layers
43
+ self.hidden_size = hidden_size
44
+ self.hidden_act = hidden_act
45
+ self.hidden_dropout_prob = hidden_dropout_prob
46
+ self.initializer_range = initializer_range
47
+ self.layer_norm_eps = layer_norm_eps
48
+ self.temporal_pooling = temporal_pooling
49
+ self.freeze_embedding = freeze_embedding
50
+ self.max_size = max_size
51
+ self.latent_size = latent_size
52
+
53
+
54
+ class MaskedConv1d(nn.Conv1d):
55
+
56
+ def forward(self, x, input_mask=None):
57
+ if input_mask is not None:
58
+ x = x * input_mask
59
+ return super().forward(x)
60
+
61
+
62
+ class ProteinResNetLayerNorm(nn.Module):
63
+
64
+ def __init__(self, config):
65
+ super().__init__()
66
+ self.norm = LayerNorm(config.hidden_size)
67
+
68
+ def forward(self, x):
69
+ return self.norm(x.transpose(1, 2)).transpose(1, 2)
70
+
71
+
72
+ class ProteinResNetBlock(nn.Module):
73
+
74
+ def __init__(self, config):
75
+ super().__init__()
76
+ self.conv1 = MaskedConv1d(
77
+ config.hidden_size, config.hidden_size, 3, padding=1, bias=False)
78
+ # self.bn1 = nn.BatchNorm1d(config.hidden_size)
79
+ self.bn1 = ProteinResNetLayerNorm(config)
80
+ self.conv2 = MaskedConv1d(
81
+ config.hidden_size, config.hidden_size, 3, padding=1, bias=False)
82
+ # self.bn2 = nn.BatchNorm1d(config.hidden_size)
83
+ self.bn2 = ProteinResNetLayerNorm(config)
84
+ self.activation_fn = get_activation_fn(config.hidden_act)
85
+
86
+ def forward(self, x, input_mask=None):
87
+ identity = x
88
+
89
+ out = self.conv1(x, input_mask)
90
+ out = self.bn1(out)
91
+ out = self.activation_fn(out)
92
+
93
+ out = self.conv2(out, input_mask)
94
+ out = self.bn2(out)
95
+
96
+ out += identity
97
+ out = self.activation_fn(out)
98
+
99
+ return out
100
+
101
+
102
+ class ProteinResNetEmbeddings(nn.Module):
103
+ """Construct the embeddings from word, position and token_type embeddings.
104
+ """
105
+ def __init__(self, config):
106
+ super().__init__()
107
+ embed_dim = config.hidden_size
108
+ self.word_embeddings = nn.Embedding(config.vocab_size, embed_dim, padding_idx=0)
109
+ inverse_frequency = 1 / (10000 ** (torch.arange(0.0, embed_dim, 2.0) / embed_dim))
110
+ self.register_buffer('inverse_frequency', inverse_frequency)
111
+
112
+ self.layer_norm = LayerNorm(embed_dim, eps=config.layer_norm_eps)
113
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
114
+
115
+ def forward(self, input_ids):
116
+ words_embeddings = self.word_embeddings(input_ids)
117
+
118
+ seq_length = input_ids.size(1)
119
+ position_ids = torch.arange(
120
+ seq_length - 1, -1, -1.0,
121
+ dtype=words_embeddings.dtype,
122
+ device=words_embeddings.device)
123
+ sinusoidal_input = torch.ger(position_ids, self.inverse_frequency)
124
+ position_embeddings = torch.cat([sinusoidal_input.sin(), sinusoidal_input.cos()], -1)
125
+ position_embeddings = position_embeddings.unsqueeze(0)
126
+
127
+ embeddings = words_embeddings + position_embeddings
128
+ embeddings = self.layer_norm(embeddings)
129
+ embeddings = self.dropout(embeddings)
130
+ return embeddings
131
+
132
+
133
+ class ResNetEncoder(nn.Module):
134
+
135
+ def __init__(self, config):
136
+ super().__init__()
137
+ self.config = config
138
+ self.output_hidden_states = config.output_hidden_states
139
+ self.encoder = nn.ModuleList(
140
+ [ProteinResNetBlock(config) for _ in range(config.num_hidden_layers)])
141
+
142
+ self.decoder = nn.ModuleList(
143
+ [ProteinResNetBlock(config) for _ in range(config.num_hidden_layers)])
144
+
145
+ self.bottleneck1 = nn.Linear(93*config.hidden_size, config.latent_size)
146
+ self.bottleneck2 = nn.Linear(config.latent_size, 94*config.hidden_size)
147
+
148
+ def forward(self, hidden_states, input_mask=None):
149
+ for i, layer_module in enumerate(self.encoder):
150
+ hidden_states = layer_module(hidden_states)
151
+ if i != 0 and i % 5 == 0:
152
+ hidden_states = nn.functional.avg_pool1d(hidden_states, 2, stride=2)
153
+
154
+ bs = hidden_states.shape[0]
155
+ latents = self.bottleneck1(hidden_states.reshape(bs, -1))
156
+ hidden_states = self.bottleneck2(latents).reshape(bs, -1, 94)
157
+
158
+
159
+ for i, layer_module in enumerate(self.decoder):
160
+ if i != 0 and i % 5 == 0:
161
+ hidden_states = nn.functional.interpolate(hidden_states, scale_factor=2)
162
+ hidden_states = layer_module(hidden_states)
163
+
164
+ hidden_states = hidden_states[:,:,:self.config.max_size]
165
+ outputs = (hidden_states, latents)
166
+
167
+ return outputs
168
+
169
+
170
+ class ProteinAEAbstractModel(ProteinModel):
171
+ """ An abstract class to handle weights initialization and
172
+ a simple interface for dowloading and loading pretrained models.
173
+ """
174
+ config_class = ProteinAEConfig
175
+ base_model_prefix = "ae"
176
+
177
+ def __init__(self, config):
178
+ super().__init__(config)
179
+
180
+ def _init_weights(self, module):
181
+ """ Initialize the weights """
182
+ if isinstance(module, nn.Embedding):
183
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
184
+ elif isinstance(module, nn.Linear):
185
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
186
+ if module.bias is not None:
187
+ module.bias.data.zero_()
188
+ elif isinstance(module, nn.Conv1d):
189
+ nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
190
+ if module.bias is not None:
191
+ module.bias.data.zero_()
192
+
193
+
194
+ @registry.register_task_model('embed', 'autoencoder')
195
+ class ProteinResNetModel(ProteinAEAbstractModel):
196
+
197
+ def __init__(self, config):
198
+ super().__init__(config)
199
+
200
+ self.embeddings = ProteinResNetEmbeddings(config)
201
+ self.encoder = ResNetEncoder(config)
202
+
203
+ self.init_weights()
204
+
205
+ def forward(self,
206
+ input_ids,
207
+ input_mask=None):
208
+ pre_pad_shape = input_ids.shape[1]
209
+ if pre_pad_shape >= self.config.max_size:
210
+ input_ids = input_ids[:,:self.config.max_size]
211
+ if not input_mask is None:
212
+ input_mask = input_mask[:,:self.config.max_size]
213
+ else:
214
+ input_ids = F.pad(input_ids, (0, self.config.max_size - pre_pad_shape))
215
+ if not input_mask is None:
216
+ input_mask = F.pad(input_mask, (0, self.config.max_size - pre_pad_shape))
217
+ assert input_ids.shape[1] == self.config.max_size
218
+
219
+ if input_mask is not None and torch.any(input_mask != 1):
220
+ extended_input_mask = input_mask.unsqueeze(2)
221
+ # fp16 compatibility
222
+ extended_input_mask = extended_input_mask.to(
223
+ dtype=next(self.parameters()).dtype)
224
+ else:
225
+ extended_input_mask = None
226
+
227
+ embedding_output = self.embeddings(input_ids)
228
+ embedding_output = embedding_output.transpose(1, 2)
229
+ if extended_input_mask is not None:
230
+ extended_input_mask = extended_input_mask.transpose(1, 2)
231
+ sequence_output, pooled_output = self.encoder(embedding_output, extended_input_mask)
232
+ sequence_output = sequence_output.transpose(1, 2).contiguous()
233
+ return sequence_output, pooled_output
234
+
235
+ @registry.register_task_model('beta_lactamase', 'autoencoder')
236
+ @registry.register_task_model('language_modeling', 'autoencoder')
237
+ class ProteinResNetForMaskedLM(ProteinAEAbstractModel):
238
+
239
+ def __init__(self, config):
240
+ super().__init__(config)
241
+
242
+ self.resnet = ProteinResNetModel(config)
243
+ self.mlm = MLMHead(
244
+ config.hidden_size, config.vocab_size, config.hidden_act, config.layer_norm_eps,
245
+ ignore_index=-1)
246
+
247
+ self.init_weights()
248
+ self.tie_weights()
249
+
250
+ def tie_weights(self):
251
+ """ Make sure we are sharing the input and output embeddings.
252
+ Export to TorchScript can't handle parameter sharing so we are cloning them instead.
253
+ """
254
+ self._tie_or_clone_weights(self.mlm.decoder,
255
+ self.resnet.embeddings.word_embeddings)
256
+
257
+ def forward(self,
258
+ input_ids,
259
+ input_mask=None,
260
+ targets=None):
261
+ pre_pad_shape = input_ids.shape[1]
262
+ if targets is not None:
263
+ targets = targets[:,:self.config.max_size]
264
+
265
+ outputs = self.resnet(input_ids, input_mask=input_mask)
266
+ outputs = self.mlm(outputs[0][:,:pre_pad_shape,:], targets) + (outputs[1],)
267
+ # (loss), prediction_scores, (hidden_states), (attentions)
268
+ return outputs
269
+
270
+
271
+ @registry.register_task_model('fluorescence', 'autoencoder')
272
+ @registry.register_task_model('stability', 'autoencoder')
273
+ class ProteinResNetForValuePrediction(ProteinAEAbstractModel):
274
+
275
+ def __init__(self, config):
276
+ super().__init__(config)
277
+
278
+ self.resnet = ProteinResNetModel(config)
279
+ self.predict = ValuePredictionHead(config.hidden_size)
280
+ self.freeze_embedding = config.freeze_embedding
281
+ self.init_weights()
282
+
283
+ def forward(self, input_ids, input_mask=None, targets=None):
284
+ if self.freeze_embedding:
285
+ self.resnet.train(False)
286
+
287
+ outputs = self.resnet(input_ids, input_mask=input_mask)
288
+
289
+ sequence_output, pooled_output = outputs[:2]
290
+ outputs = self.predict(pooled_output, targets) + outputs[2:]
291
+ # (loss), prediction_scores, (hidden_states), (attentions)
292
+ return outputs
293
+
294
+
295
+ @registry.register_task_model('remote_homology', 'autoencoder')
296
+ class ProteinResNetForSequenceClassification(ProteinAEAbstractModel):
297
+
298
+ def __init__(self, config):
299
+ super().__init__(config)
300
+
301
+ self.resnet = ProteinResNetModel(config)
302
+ self.classify = SequenceClassificationHead(config.hidden_size, config.num_labels)
303
+ self.freeze_embedding = config.freeze_embedding
304
+
305
+ self.init_weights()
306
+
307
+ def forward(self, input_ids, input_mask=None, targets=None):
308
+ if self.freeze_embedding:
309
+ self.resnet.train(False)
310
+
311
+ outputs = self.resnet(input_ids, input_mask=input_mask)
312
+
313
+ sequence_output, pooled_output = outputs[:2]
314
+ outputs = self.classify(pooled_output, targets) + outputs[2:]
315
+ # (loss), prediction_scores, (hidden_states), (attentions)
316
+ return outputs
tape/models/modeling_bert.py ADDED
@@ -0,0 +1,606 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ # Modified by Roshan Rao
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """PyTorch BERT model. """
18
+
19
+ from __future__ import absolute_import, division, print_function, unicode_literals
20
+
21
+ import logging
22
+ import math
23
+
24
+ import torch
25
+ from torch import nn
26
+ from torch.utils.checkpoint import checkpoint
27
+
28
+ from .modeling_utils import ProteinConfig
29
+ from .modeling_utils import ProteinModel
30
+ from .modeling_utils import prune_linear_layer
31
+ from .modeling_utils import get_activation_fn
32
+ from .modeling_utils import LayerNorm
33
+ from .modeling_utils import MLMHead
34
+ from .modeling_utils import ValuePredictionHead
35
+ from .modeling_utils import SequenceClassificationHead
36
+ from .modeling_utils import SequenceToSequenceClassificationHead
37
+ from .modeling_utils import PairwiseContactPredictionHead
38
+ from ..registry import registry
39
+
40
+ logger = logging.getLogger(__name__)
41
+
42
+ URL_PREFIX = "https://s3.amazonaws.com/proteindata/pytorch-models/"
43
+ BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
44
+ 'bert-base': URL_PREFIX + "bert-base-pytorch_model.bin",
45
+ }
46
+ BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
47
+ 'bert-base': URL_PREFIX + "bert-base-config.json"
48
+ }
49
+
50
+
51
+ class ProteinBertConfig(ProteinConfig):
52
+ r"""
53
+ :class:`~pytorch_transformers.ProteinBertConfig` is the configuration class to store the
54
+ configuration of a `ProteinBertModel`.
55
+
56
+
57
+ Arguments:
58
+ vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in
59
+ `ProteinBertModel`.
60
+ hidden_size: Size of the encoder layers and the pooler layer.
61
+ num_hidden_layers: Number of hidden layers in the ProteinBert encoder.
62
+ num_attention_heads: Number of attention heads for each attention layer in
63
+ the ProteinBert encoder.
64
+ intermediate_size: The size of the "intermediate" (i.e., feed-forward)
65
+ layer in the ProteinBert encoder.
66
+ hidden_act: The non-linear activation function (function or string) in the
67
+ encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
68
+ hidden_dropout_prob: The dropout probabilitiy for all fully connected
69
+ layers in the embeddings, encoder, and pooler.
70
+ attention_probs_dropout_prob: The dropout ratio for the attention
71
+ probabilities.
72
+ max_position_embeddings: The maximum sequence length that this model might
73
+ ever be used with. Typically set this to something large just in case
74
+ (e.g., 512 or 1024 or 2048).
75
+ type_vocab_size: The vocabulary size of the `token_type_ids` passed into
76
+ `ProteinBertModel`.
77
+ initializer_range: The sttdev of the truncated_normal_initializer for
78
+ initializing all weight matrices.
79
+ layer_norm_eps: The epsilon used by LayerNorm.
80
+ """
81
+ pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
82
+
83
+ def __init__(self,
84
+ vocab_size: int = 30,
85
+ hidden_size: int = 768,
86
+ num_hidden_layers: int = 12,
87
+ num_attention_heads: int = 12,
88
+ intermediate_size: int = 3072,
89
+ hidden_act: str = "gelu",
90
+ hidden_dropout_prob: float = 0.1,
91
+ attention_probs_dropout_prob: float = 0.1,
92
+ max_position_embeddings: int = 8096,
93
+ type_vocab_size: int = 2,
94
+ initializer_range: float = 0.02,
95
+ layer_norm_eps: float = 1e-12,
96
+ temporal_pooling: str = 'attention',
97
+ freeze_embedding: bool = False,
98
+ **kwargs):
99
+ super().__init__(**kwargs)
100
+ self.vocab_size = vocab_size
101
+ self.hidden_size = hidden_size
102
+ self.num_hidden_layers = num_hidden_layers
103
+ self.num_attention_heads = num_attention_heads
104
+ self.hidden_act = hidden_act
105
+ self.intermediate_size = intermediate_size
106
+ self.hidden_dropout_prob = hidden_dropout_prob
107
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
108
+ self.max_position_embeddings = max_position_embeddings
109
+ self.type_vocab_size = type_vocab_size
110
+ self.initializer_range = initializer_range
111
+ self.layer_norm_eps = layer_norm_eps
112
+ self.temporal_pooling = temporal_pooling
113
+ self.freeze_embedding = freeze_embedding
114
+
115
+
116
+ class ProteinBertEmbeddings(nn.Module):
117
+ """Construct the embeddings from word, position and token_type embeddings.
118
+ """
119
+ def __init__(self, config):
120
+ super().__init__()
121
+ self.word_embeddings = nn.Embedding(
122
+ config.vocab_size, config.hidden_size, padding_idx=0)
123
+ self.position_embeddings = nn.Embedding(
124
+ config.max_position_embeddings, config.hidden_size)
125
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
126
+
127
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be
128
+ # able to load any TensorFlow checkpoint file
129
+ self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
130
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
131
+
132
+ def forward(self, input_ids, token_type_ids=None, position_ids=None):
133
+ seq_length = input_ids.size(1)
134
+ if position_ids is None:
135
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
136
+ position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
137
+ if token_type_ids is None:
138
+ token_type_ids = torch.zeros_like(input_ids)
139
+
140
+ words_embeddings = self.word_embeddings(input_ids)
141
+ position_embeddings = self.position_embeddings(position_ids)
142
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
143
+
144
+ embeddings = words_embeddings + position_embeddings + token_type_embeddings
145
+ embeddings = self.LayerNorm(embeddings)
146
+ embeddings = self.dropout(embeddings)
147
+ return embeddings
148
+
149
+
150
+ class ProteinBertSelfAttention(nn.Module):
151
+ def __init__(self, config):
152
+ super().__init__()
153
+ if config.hidden_size % config.num_attention_heads != 0:
154
+ raise ValueError(
155
+ "The hidden size (%d) is not a multiple of the number of attention "
156
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads))
157
+ self.output_attentions = config.output_attentions
158
+
159
+ self.num_attention_heads = config.num_attention_heads
160
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
161
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
162
+
163
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
164
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
165
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
166
+
167
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
168
+
169
+ def transpose_for_scores(self, x):
170
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
171
+ x = x.view(*new_x_shape)
172
+ return x.permute(0, 2, 1, 3)
173
+
174
+ def forward(self, hidden_states, attention_mask):
175
+ mixed_query_layer = self.query(hidden_states)
176
+ mixed_key_layer = self.key(hidden_states)
177
+ mixed_value_layer = self.value(hidden_states)
178
+
179
+ query_layer = self.transpose_for_scores(mixed_query_layer)
180
+ key_layer = self.transpose_for_scores(mixed_key_layer)
181
+ value_layer = self.transpose_for_scores(mixed_value_layer)
182
+
183
+ # Take the dot product between "query" and "key" to get the raw attention scores.
184
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
185
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
186
+ # Apply the attention mask is (precomputed for all layers in
187
+ # ProteinBertModel forward() function)
188
+ attention_scores = attention_scores + attention_mask
189
+
190
+ # Normalize the attention scores to probabilities.
191
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
192
+
193
+ # This is actually dropping out entire tokens to attend to, which might
194
+ # seem a bit unusual, but is taken from the original ProteinBert paper.
195
+ attention_probs = self.dropout(attention_probs)
196
+
197
+ context_layer = torch.matmul(attention_probs, value_layer)
198
+
199
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
200
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
201
+ context_layer = context_layer.view(*new_context_layer_shape)
202
+
203
+ outputs = (context_layer, attention_probs) \
204
+ if self.output_attentions else (context_layer,)
205
+ return outputs
206
+
207
+
208
+ class ProteinBertSelfOutput(nn.Module):
209
+ def __init__(self, config):
210
+ super().__init__()
211
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
212
+ self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
213
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
214
+
215
+ def forward(self, hidden_states, input_tensor):
216
+ hidden_states = self.dense(hidden_states)
217
+ hidden_states = self.dropout(hidden_states)
218
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
219
+ return hidden_states
220
+
221
+
222
+ class ProteinBertAttention(nn.Module):
223
+ def __init__(self, config):
224
+ super().__init__()
225
+ self.self = ProteinBertSelfAttention(config)
226
+ self.output = ProteinBertSelfOutput(config)
227
+
228
+ def prune_heads(self, heads):
229
+ if len(heads) == 0:
230
+ return
231
+ mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
232
+ for head in heads:
233
+ mask[head] = 0
234
+ mask = mask.view(-1).contiguous().eq(1)
235
+ index = torch.arange(len(mask))[mask].long()
236
+ # Prune linear layers
237
+ self.self.query = prune_linear_layer(self.self.query, index)
238
+ self.self.key = prune_linear_layer(self.self.key, index)
239
+ self.self.value = prune_linear_layer(self.self.value, index)
240
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
241
+ # Update hyper params
242
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
243
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
244
+
245
+ def forward(self, input_tensor, attention_mask):
246
+ self_outputs = self.self(input_tensor, attention_mask)
247
+ attention_output = self.output(self_outputs[0], input_tensor)
248
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
249
+ return outputs
250
+
251
+
252
+ class ProteinBertIntermediate(nn.Module):
253
+ def __init__(self, config):
254
+ super().__init__()
255
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
256
+ if isinstance(config.hidden_act, str):
257
+ self.intermediate_act_fn = get_activation_fn(config.hidden_act)
258
+ else:
259
+ self.intermediate_act_fn = config.hidden_act
260
+
261
+ def forward(self, hidden_states):
262
+ hidden_states = self.dense(hidden_states)
263
+ hidden_states = self.intermediate_act_fn(hidden_states)
264
+ return hidden_states
265
+
266
+
267
+ class ProteinBertOutput(nn.Module):
268
+ def __init__(self, config):
269
+ super().__init__()
270
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
271
+ self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
272
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
273
+
274
+ def forward(self, hidden_states, input_tensor):
275
+ hidden_states = self.dense(hidden_states)
276
+ hidden_states = self.dropout(hidden_states)
277
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
278
+ return hidden_states
279
+
280
+
281
+ class ProteinBertLayer(nn.Module):
282
+ def __init__(self, config):
283
+ super().__init__()
284
+ self.attention = ProteinBertAttention(config)
285
+ self.intermediate = ProteinBertIntermediate(config)
286
+ self.output = ProteinBertOutput(config)
287
+
288
+ def forward(self, hidden_states, attention_mask):
289
+ attention_outputs = self.attention(hidden_states, attention_mask)
290
+ attention_output = attention_outputs[0]
291
+ intermediate_output = self.intermediate(attention_output)
292
+ layer_output = self.output(intermediate_output, attention_output)
293
+ outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
294
+ return outputs
295
+
296
+
297
+ class ProteinBertEncoder(nn.Module):
298
+ def __init__(self, config):
299
+ super().__init__()
300
+ self.output_attentions = config.output_attentions
301
+ self.output_hidden_states = config.output_hidden_states
302
+ self.layer = nn.ModuleList(
303
+ [ProteinBertLayer(config) for _ in range(config.num_hidden_layers)])
304
+
305
+ def run_function(self, start, chunk_size):
306
+ def custom_forward(hidden_states, attention_mask):
307
+ all_hidden_states = ()
308
+ all_attentions = ()
309
+ chunk_slice = slice(start, start + chunk_size)
310
+ for layer in self.layer[chunk_slice]:
311
+ if self.output_hidden_states:
312
+ all_hidden_states = all_hidden_states + (hidden_states,)
313
+ layer_outputs = layer(hidden_states, attention_mask)
314
+ hidden_states = layer_outputs[0]
315
+
316
+ if self.output_attentions:
317
+ all_attentions = all_attentions + (layer_outputs[1],)
318
+
319
+ if self.output_hidden_states:
320
+ all_hidden_states = all_hidden_states + (hidden_states,)
321
+ outputs = (hidden_states,)
322
+ if self.output_hidden_states:
323
+ outputs = outputs + (all_hidden_states,)
324
+ if self.output_attentions:
325
+ outputs = outputs + (all_attentions,)
326
+ return outputs
327
+
328
+ return custom_forward
329
+
330
+ def forward(self, hidden_states, attention_mask, chunks=None):
331
+ all_hidden_states = ()
332
+ all_attentions = ()
333
+
334
+ if chunks is not None:
335
+ assert isinstance(chunks, int)
336
+ chunk_size = (len(self.layer) + chunks - 1) // chunks
337
+ for start in range(0, len(self.layer), chunk_size):
338
+ outputs = checkpoint(self.run_function(start, chunk_size),
339
+ hidden_states, attention_mask)
340
+ if self.output_hidden_states:
341
+ all_hidden_states = all_hidden_states + outputs[1]
342
+ if self.output_attentions:
343
+ all_attentions = all_attentions + outputs[-1]
344
+ hidden_states = outputs[0]
345
+ else:
346
+ for i, layer_module in enumerate(self.layer):
347
+ if self.output_hidden_states:
348
+ all_hidden_states = all_hidden_states + (hidden_states,)
349
+
350
+ layer_outputs = layer_module(hidden_states, attention_mask)
351
+ hidden_states = layer_outputs[0]
352
+
353
+ if self.output_attentions:
354
+ all_attentions = all_attentions + (layer_outputs[1],)
355
+
356
+ # Add last layer
357
+ if self.output_hidden_states:
358
+ all_hidden_states = all_hidden_states + (hidden_states,)
359
+
360
+ outputs = (hidden_states,)
361
+ if self.output_hidden_states:
362
+ outputs = outputs + (all_hidden_states,)
363
+ if self.output_attentions:
364
+ outputs = outputs + (all_attentions,)
365
+ return outputs # outputs, (hidden states), (attentions)
366
+
367
+
368
+ class ProteinBertPooler(nn.Module):
369
+ def __init__(self, config):
370
+ super().__init__()
371
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
372
+ self.activation = nn.Tanh()
373
+ self.temporal_pooling = config.temporal_pooling
374
+ self._la_w1 = nn.Conv1d(config.hidden_size, int(config.hidden_size/2), 5, padding=2)
375
+ self._la_w2 = nn.Conv1d(config.hidden_size, int(config.hidden_size/2), 5, padding=2)
376
+ self._la_mlp = nn.Linear(config.hidden_size, config.hidden_size)
377
+
378
+ def forward(self, hidden_states):
379
+ # We "pool" the model by simply taking the hidden state corresponding
380
+ # to the first token.
381
+ if self.temporal_pooling == 'mean':
382
+ return hidden_states.mean(dim=1)
383
+ if self.temporal_pooling == 'max':
384
+ return hidden_states.max(dim=1)
385
+ if self.temporal_pooling == 'concat':
386
+ _temp = hidden_states.reshape(hidden_states.shape[0], -1)
387
+ return torch.nn.functional.pad(_temp, (0, 2048 - _temp.shape[1]))
388
+ if self.temporal_pooling == 'topmax':
389
+ val, _ = torch.topk(hidden_states, k=5, dim=1)
390
+ return val.mean(dim=1)
391
+ if self.temporal_pooling == 'light_attention':
392
+ _temp = hidden_states.permute(0,2,1)
393
+ a = self._la_w1(_temp).softmax(dim=-1)
394
+ v = self._la_w2(_temp)
395
+ v_max = v.max(dim=-1).values
396
+ v_sum = (a * v).sum(dim=-1)
397
+ return self._la_mlp(torch.cat([v_max, v_sum], dim=1))
398
+
399
+ first_token_tensor = hidden_states[:, 0]
400
+ pooled_output = self.dense(first_token_tensor)
401
+ pooled_output = self.activation(pooled_output)
402
+ return pooled_output
403
+
404
+
405
+ class ProteinBertAbstractModel(ProteinModel):
406
+ """ An abstract class to handle weights initialization and
407
+ a simple interface for dowloading and loading pretrained models.
408
+ """
409
+ config_class = ProteinBertConfig
410
+ pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP
411
+ base_model_prefix = "bert"
412
+
413
+ def _init_weights(self, module):
414
+ """ Initialize the weights """
415
+ if isinstance(module, (nn.Linear, nn.Embedding)):
416
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
417
+ elif isinstance(module, LayerNorm):
418
+ module.bias.data.zero_()
419
+ module.weight.data.fill_(1.0)
420
+ if isinstance(module, nn.Linear) and module.bias is not None:
421
+ module.bias.data.zero_()
422
+
423
+
424
+ @registry.register_task_model('embed', 'transformer')
425
+ class ProteinBertModel(ProteinBertAbstractModel):
426
+
427
+ def __init__(self, config):
428
+ super().__init__(config)
429
+
430
+ self.embeddings = ProteinBertEmbeddings(config)
431
+ self.encoder = ProteinBertEncoder(config)
432
+ self.pooler = ProteinBertPooler(config)
433
+
434
+ self.init_weights()
435
+
436
+ def _resize_token_embeddings(self, new_num_tokens):
437
+ old_embeddings = self.embeddings.word_embeddings
438
+ new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
439
+ self.embeddings.word_embeddings = new_embeddings
440
+ return self.embeddings.word_embeddings
441
+
442
+ def _prune_heads(self, heads_to_prune):
443
+ """ Prunes heads of the model.
444
+ heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
445
+ See base class ProteinModel
446
+ """
447
+ for layer, heads in heads_to_prune.items():
448
+ self.encoder.layer[layer].attention.prune_heads(heads)
449
+
450
+ def forward(self,
451
+ input_ids,
452
+ input_mask=None):
453
+ if input_mask is None:
454
+ input_mask = torch.ones_like(input_ids)
455
+
456
+ # We create a 3D attention mask from a 2D tensor mask.
457
+ # Sizes are [batch_size, 1, 1, to_seq_length]
458
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
459
+ # this attention mask is more simple than the triangular masking of causal attention
460
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
461
+ extended_attention_mask = input_mask.unsqueeze(1).unsqueeze(2)
462
+
463
+ # Since input_mask is 1.0 for positions we want to attend and 0.0 for
464
+ # masked positions, this operation will create a tensor which is 0.0 for
465
+ # positions we want to attend and -10000.0 for masked positions.
466
+ # Since we are adding it to the raw scores before the softmax, this is
467
+ # effectively the same as removing these entirely.
468
+ extended_attention_mask = extended_attention_mask.to(
469
+ dtype=torch.float32) # fp16 compatibility
470
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
471
+
472
+ embedding_output = self.embeddings(input_ids)
473
+ encoder_outputs = self.encoder(embedding_output,
474
+ extended_attention_mask,
475
+ chunks=None)
476
+ sequence_output = encoder_outputs[0]
477
+ pooled_output = self.pooler(sequence_output)
478
+
479
+ # add hidden_states and attentions if they are here
480
+ outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]
481
+ return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
482
+
483
+
484
+ @registry.register_task_model('masked_language_modeling', 'transformer')
485
+ class ProteinBertForMaskedLM(ProteinBertAbstractModel):
486
+
487
+ def __init__(self, config):
488
+ super().__init__(config)
489
+
490
+ self.bert = ProteinBertModel(config)
491
+ self.mlm = MLMHead(
492
+ config.hidden_size, config.vocab_size, config.hidden_act, config.layer_norm_eps,
493
+ ignore_index=-1)
494
+
495
+ self.init_weights()
496
+ self.tie_weights()
497
+
498
+ def tie_weights(self):
499
+ """ Make sure we are sharing the input and output embeddings.
500
+ Export to TorchScript can't handle parameter sharing so we are cloning them instead.
501
+ """
502
+ self._tie_or_clone_weights(self.mlm.decoder,
503
+ self.bert.embeddings.word_embeddings)
504
+
505
+ def forward(self,
506
+ input_ids,
507
+ input_mask=None,
508
+ targets=None):
509
+
510
+ outputs = self.bert(input_ids, input_mask=input_mask)
511
+
512
+ sequence_output, pooled_output = outputs[:2]
513
+ # add hidden states and attention if they are here
514
+ outputs = self.mlm(sequence_output, targets) + outputs[:2]
515
+ # (loss), prediction_scores, (hidden_states), (attentions)
516
+ return outputs
517
+
518
+
519
+ @registry.register_task_model('fluorescence', 'transformer')
520
+ @registry.register_task_model('stability', 'transformer')
521
+ class ProteinBertForValuePrediction(ProteinBertAbstractModel):
522
+
523
+ def __init__(self, config):
524
+ super().__init__(config)
525
+
526
+ self.bert = ProteinBertModel(config)
527
+ self.predict = ValuePredictionHead(config.hidden_size)
528
+ self.freeze_embedding = config.freeze_embedding
529
+ self.init_weights()
530
+
531
+ def forward(self, input_ids, input_mask=None, targets=None):
532
+ if self.freeze_embedding:
533
+ self.bert.train(False)
534
+ outputs = self.bert(input_ids, input_mask=input_mask)
535
+
536
+ sequence_output, pooled_output = outputs[:2]
537
+ outputs = self.predict(pooled_output, targets) + outputs[2:]
538
+ # (loss), prediction_scores, (hidden_states), (attentions)
539
+ return outputs
540
+
541
+
542
+ @registry.register_task_model('remote_homology', 'transformer')
543
+ class ProteinBertForSequenceClassification(ProteinBertAbstractModel):
544
+
545
+ def __init__(self, config):
546
+ super().__init__(config)
547
+
548
+ self.bert = ProteinBertModel(config)
549
+ self.classify = SequenceClassificationHead(
550
+ config.hidden_size, config.num_labels)
551
+ self.freeze_embedding = config.freeze_embedding
552
+ self.init_weights()
553
+
554
+ def forward(self, input_ids, input_mask=None, targets=None):
555
+ if self.freeze_embedding:
556
+ self.bert.train(False)
557
+ outputs = self.bert(input_ids, input_mask=input_mask)
558
+
559
+ sequence_output, pooled_output = outputs[:2]
560
+
561
+ outputs = self.classify(pooled_output, targets) + outputs[2:]
562
+ # (loss), prediction_scores, (hidden_states), (attentions)
563
+ return outputs
564
+
565
+
566
+ @registry.register_task_model('secondary_structure', 'transformer')
567
+ class ProteinBertForSequenceToSequenceClassification(ProteinBertAbstractModel):
568
+
569
+ def __init__(self, config):
570
+ super().__init__(config)
571
+
572
+ self.bert = ProteinBertModel(config)
573
+ self.classify = SequenceToSequenceClassificationHead(
574
+ config.hidden_size, config.num_labels, ignore_index=-1)
575
+
576
+ self.init_weights()
577
+
578
+ def forward(self, input_ids, input_mask=None, targets=None):
579
+
580
+ outputs = self.bert(input_ids, input_mask=input_mask)
581
+
582
+ sequence_output, pooled_output = outputs[:2]
583
+ outputs = self.classify(sequence_output, targets) + outputs[2:]
584
+ # (loss), prediction_scores, (hidden_states), (attentions)
585
+ return outputs
586
+
587
+
588
+ @registry.register_task_model('contact_prediction', 'transformer')
589
+ class ProteinBertForContactPrediction(ProteinBertAbstractModel):
590
+
591
+ def __init__(self, config):
592
+ super().__init__(config)
593
+
594
+ self.bert = ProteinBertModel(config)
595
+ self.predict = PairwiseContactPredictionHead(config.hidden_size, ignore_index=-1)
596
+
597
+ self.init_weights()
598
+
599
+ def forward(self, input_ids, protein_length, input_mask=None, targets=None):
600
+
601
+ outputs = self.bert(input_ids, input_mask=input_mask)
602
+
603
+ sequence_output, pooled_output = outputs[:2]
604
+ outputs = self.predict(sequence_output, protein_length, targets) + outputs[2:]
605
+ # (loss), prediction_scores, (hidden_states), (attentions)
606
+ return outputs
tape/models/modeling_bottleneck.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from tape import ProteinModel, ProteinConfig
5
+ from tape.models.modeling_utils import SequenceToSequenceClassificationHead
6
+ from tape.registry import registry
7
+ from .modeling_utils import LayerNorm, MLMHead
8
+ from .modeling_bert import ProteinBertModel, ProteinBertConfig
9
+ from .modeling_lstm import ProteinLSTMModel, ProteinLSTMConfig
10
+ from .modeling_resnet import ProteinResNetModel, ProteinResNetConfig
11
+
12
+
13
+ class BottleneckConfig(ProteinConfig):
14
+ def __init__(self,
15
+ hidden_size: int = 1024,
16
+ max_size: int = 300,
17
+ backend_name: str = 'resnet',
18
+ **kwargs):
19
+ super().__init__(**kwargs)
20
+ self.hidden_size = hidden_size
21
+ self.max_size = max_size
22
+ self.backend_name = backend_name
23
+
24
+
25
+ class BottleneckAbstractModel(ProteinModel):
26
+ """ All your models will inherit from this one - it's used to define the
27
+ config_class of the model set and also to define the base_model_prefix.
28
+ This is used to allow easy loading/saving into different models.
29
+ """
30
+ config_class = BottleneckConfig
31
+ base_model_prefix = 'bottleneck'
32
+
33
+ def __init__(self, config):
34
+ super().__init__(config)
35
+
36
+ def _init_weights(self, module):
37
+ """ Initialize the weights """
38
+ if isinstance(module, nn.Embedding):
39
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
40
+ elif isinstance(module, nn.Linear):
41
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
42
+ if module.bias is not None:
43
+ module.bias.data.zero_()
44
+ elif isinstance(module, LayerNorm):
45
+ module.bias.data.zero_()
46
+ module.weight.data.fill_(1.0)
47
+ elif isinstance(module, nn.Conv1d):
48
+ nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
49
+ if module.bias is not None:
50
+ module.bias.data.zero_()
51
+ # elif isinstance(module, ProteinResNetBlock):
52
+ # nn.init.constant_(module.bn2.weight, 0)
53
+
54
+ @registry.register_task_model('embed', 'bottleneck')
55
+ class ProteinBottleneckModel(BottleneckAbstractModel):
56
+
57
+ def __init__(self, config):
58
+ super().__init__(config)
59
+ if config.backend_name == 'resnet':
60
+ config = ProteinResNetConfig()
61
+ self.backbone1 = ProteinResNetModel(config)
62
+ elif config.backend_name == 'transformer':
63
+ config = ProteinBertConfig()
64
+ self.backbone1 = ProteinBertModel(config)
65
+ elif config.backend_name == 'lstm':
66
+ config = ProteinLSTMConfig(hidden_size=256)
67
+ self.backbone1 = ProteinLSTMModel(config)
68
+ config.hidden_size = config.hidden_size * 2
69
+ else:
70
+ raise ValueError('Somethings wrong')
71
+ self.linear1 = nn.Linear(self.config.max_size*config.hidden_size, self.config.hidden_size)
72
+ self.linear2 = nn.Linear(self.config.hidden_size, self.config.max_size*config.hidden_size)
73
+
74
+ def forward(self, input_ids, input_mask=None):
75
+ pre_pad_shape = input_ids.shape[1]
76
+ if pre_pad_shape >= self.config.max_size:
77
+ input_ids = input_ids[:,:self.config.max_size]
78
+ if not input_mask is None:
79
+ input_mask = input_mask[:,:self.config.max_size]
80
+ else:
81
+ input_ids = F.pad(input_ids, (0, self.config.max_size - pre_pad_shape))
82
+ if not input_mask is None:
83
+ input_mask = F.pad(input_mask, (0, self.config.max_size - pre_pad_shape))
84
+ assert input_ids.shape[1] == self.config.max_size
85
+
86
+ output = self.backbone1(input_ids, input_mask)
87
+ sequence_output = output[0]
88
+ pre_shape = sequence_output.shape
89
+ embeddings = self.linear1(sequence_output.reshape(sequence_output.shape[0], -1))
90
+ sequence_output = self.linear2(embeddings).reshape(*pre_shape)
91
+ sequence_output = sequence_output[:,:pre_pad_shape]
92
+ outputs = (sequence_output, embeddings) + output[2:]
93
+ return outputs
94
+
95
+ @registry.register_task_model('beta_lactamase', 'bottleneck')
96
+ @registry.register_task_model('masked_language_modeling', 'bottleneck')
97
+ @registry.register_task_model('language_modeling', 'bottleneck')
98
+ class ProteinBottleneckForPretraining(BottleneckAbstractModel):
99
+
100
+ def __init__(self, config):
101
+ super().__init__(config)
102
+ self.backbone1 = ProteinBottleneckModel(config)
103
+
104
+ if config.backend_name == 'resnet':
105
+ config = ProteinResNetConfig()
106
+ self.backbone2 = MLMHead(config.hidden_size, config.vocab_size, config.hidden_act,
107
+ config.layer_norm_eps, ignore_index=-1)
108
+ elif config.backend_name == 'transformer':
109
+ config = ProteinBertConfig()
110
+ self.backbone2 = MLMHead(config.hidden_size, config.vocab_size, config.hidden_act,
111
+ config.layer_norm_eps, ignore_index=-1)
112
+ elif config.backend_name == 'lstm':
113
+ config = ProteinLSTMConfig(hidden_size=256)
114
+ self.backbone2 = nn.Linear(config.hidden_size, config.vocab_size)
115
+ config.hidden_size = config.hidden_size * 2
116
+ else:
117
+ raise ValueError('Somethings wrong')
118
+
119
+ def forward(self,
120
+ input_ids,
121
+ input_mask=None,
122
+ targets=None):
123
+ if input_ids.shape[1]>self.config.max_size:
124
+ targets = targets[:,:self.config.max_size]
125
+
126
+ outputs = self.backbone1(input_ids, input_mask)
127
+ sequence_output = outputs[0]
128
+ if self.config.backend_name == 'resnet' or self.config.backend_name == 'transformer':
129
+ outputs = self.backbone2(sequence_output, targets) + outputs[2:]
130
+ elif self.config.backend_name == 'lstm':
131
+ sequence_output, pooled_output = outputs[:2]
132
+
133
+ forward_prediction, reverse_prediction = sequence_output.chunk(2, -1)
134
+ forward_prediction = F.pad(forward_prediction[:, :-1], [0, 0, 1, 0])
135
+ reverse_prediction = F.pad(reverse_prediction[:, 1:], [0, 0, 0, 1])
136
+ prediction_scores = \
137
+ self.backbone2(forward_prediction) + self.backbone2(reverse_prediction)
138
+ prediction_scores = prediction_scores.contiguous()
139
+
140
+ # add hidden states and if they are here
141
+ outputs = (prediction_scores,) + outputs[2:]
142
+
143
+ if targets is not None:
144
+ loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
145
+ lm_loss = loss_fct(
146
+ prediction_scores.view(-1, 30), targets.view(-1))
147
+ outputs = (lm_loss,) + outputs
148
+
149
+ # (loss), prediction_scores, (hidden_states), (attentions)
150
+ return outputs
tape/models/modeling_lstm.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import typing
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from .modeling_utils import ProteinConfig
8
+ from .modeling_utils import ProteinModel
9
+ from .modeling_utils import ValuePredictionHead
10
+ from .modeling_utils import SequenceClassificationHead
11
+ from .modeling_utils import SequenceToSequenceClassificationHead
12
+ from .modeling_utils import PairwiseContactPredictionHead
13
+ from ..registry import registry
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ URL_PREFIX = "https://s3.amazonaws.com/proteindata/pytorch-models/"
19
+ LSTM_PRETRAINED_CONFIG_ARCHIVE_MAP: typing.Dict[str, str] = {}
20
+ LSTM_PRETRAINED_MODEL_ARCHIVE_MAP: typing.Dict[str, str] = {}
21
+
22
+
23
+ class ProteinLSTMConfig(ProteinConfig):
24
+ pretrained_config_archive_map = LSTM_PRETRAINED_CONFIG_ARCHIVE_MAP
25
+
26
+ def __init__(self,
27
+ vocab_size: int = 30,
28
+ input_size: int = 128,
29
+ hidden_size: int = 1024,
30
+ num_hidden_layers: int = 3,
31
+ hidden_dropout_prob: float = 0.1,
32
+ initializer_range: float = 0.02,
33
+ temporal_pooling: str = 'attention',
34
+ freeze_embedding: bool = False,
35
+ **kwargs):
36
+ super().__init__(**kwargs)
37
+ self.vocab_size = vocab_size
38
+ self.input_size = input_size
39
+ self.hidden_size = hidden_size
40
+ self.num_hidden_layers = num_hidden_layers
41
+ self.hidden_dropout_prob = hidden_dropout_prob
42
+ self.initializer_range = initializer_range
43
+ self.temporal_pooling = temporal_pooling
44
+ self.freeze_embedding = freeze_embedding
45
+
46
+
47
+ class ProteinLSTMLayer(nn.Module):
48
+
49
+ def __init__(self, input_size: int, hidden_size: int, dropout: float = 0.):
50
+ super().__init__()
51
+ self.dropout = nn.Dropout(dropout)
52
+ self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
53
+
54
+ def forward(self, inputs):
55
+ inputs = self.dropout(inputs)
56
+ self.lstm.flatten_parameters()
57
+ return self.lstm(inputs)
58
+
59
+
60
+ class ProteinLSTMPooler(nn.Module):
61
+ def __init__(self, config):
62
+ super().__init__()
63
+ self.scalar_reweighting = nn.Linear(2 * config.num_hidden_layers, 1)
64
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
65
+ self.activation = nn.Tanh()
66
+ self.temporal_pooling = config.temporal_pooling
67
+ self._la_w1 = nn.Conv1d(config.hidden_size, int(config.hidden_size/2), 5, padding=2)
68
+ self._la_w2 = nn.Conv1d(config.hidden_size, int(config.hidden_size/2), 5, padding=2)
69
+ self._la_mlp = nn.Linear(config.hidden_size, config.hidden_size)
70
+
71
+ def forward(self, hidden_states):
72
+ # We "pool" the model by simply taking the hidden state corresponding
73
+ # to the first token.
74
+ if self.temporal_pooling == 'mean':
75
+ return hidden_states.mean(dim=1)
76
+ if self.temporal_pooling == 'max':
77
+ return hidden_states.max(dim=1)
78
+ if self.temporal_pooling == 'concat':
79
+ _temp = hidden_states.reshape(hidden_states.shape[0], -1)
80
+ return torch.nn.functional.pad(_temp, (0, 2048 - _temp.shape[1]))
81
+ if self.temporal_pooling == 'topmax':
82
+ val, _ = torch.topk(hidden_states, k=5, dim=1)
83
+ return val.mean(dim=1)
84
+ if self.temporal_pooling == 'light_attention':
85
+ _temp = hidden_states.permute(0,2,1)
86
+ a = self._la_w1(_temp).softmax(dim=-1)
87
+ v = self._la_w2(_temp)
88
+ v_max = v.max(dim=-1).values
89
+ v_sum = (a * v).sum(dim=-1)
90
+ return self._la_mlp(torch.cat([v_max, v_sum], dim=1))
91
+
92
+ pooled_output = self.scalar_reweighting(hidden_states).squeeze(2)
93
+ pooled_output = self.dense(pooled_output)
94
+ pooled_output = self.activation(pooled_output)
95
+ return pooled_output
96
+
97
+
98
+ class ProteinLSTMEncoder(nn.Module):
99
+
100
+ def __init__(self, config: ProteinLSTMConfig):
101
+ super().__init__()
102
+ forward_lstm = [ProteinLSTMLayer(config.input_size, config.hidden_size)]
103
+ reverse_lstm = [ProteinLSTMLayer(config.input_size, config.hidden_size)]
104
+ for _ in range(config.num_hidden_layers - 1):
105
+ forward_lstm.append(ProteinLSTMLayer(
106
+ config.hidden_size, config.hidden_size, config.hidden_dropout_prob))
107
+ reverse_lstm.append(ProteinLSTMLayer(
108
+ config.hidden_size, config.hidden_size, config.hidden_dropout_prob))
109
+ self.forward_lstm = nn.ModuleList(forward_lstm)
110
+ self.reverse_lstm = nn.ModuleList(reverse_lstm)
111
+ self.output_hidden_states = config.output_hidden_states
112
+
113
+ def forward(self, inputs, input_mask=None):
114
+ all_forward_pooled = ()
115
+ all_reverse_pooled = ()
116
+ all_hidden_states = (inputs,)
117
+ forward_output = inputs
118
+ for layer in self.forward_lstm:
119
+ forward_output, forward_pooled = layer(forward_output)
120
+ all_forward_pooled = all_forward_pooled + (forward_pooled[0],)
121
+ all_hidden_states = all_hidden_states + (forward_output,)
122
+
123
+ reversed_sequence = self.reverse_sequence(inputs, input_mask)
124
+ reverse_output = reversed_sequence
125
+ for layer in self.reverse_lstm:
126
+ reverse_output, reverse_pooled = layer(reverse_output)
127
+ all_reverse_pooled = all_reverse_pooled + (reverse_pooled[0],)
128
+ all_hidden_states = all_hidden_states + (reverse_output,)
129
+ reverse_output = self.reverse_sequence(reverse_output, input_mask)
130
+
131
+ output = torch.cat((forward_output, reverse_output), dim=2)
132
+
133
+ pooled = all_forward_pooled + all_reverse_pooled
134
+ pooled = torch.stack(pooled, 3).squeeze(0)
135
+ outputs = (output, pooled)
136
+ if self.output_hidden_states:
137
+ outputs = outputs + (all_hidden_states,)
138
+
139
+ return outputs # sequence_embedding, pooled_embedding, (hidden_states)
140
+
141
+ def reverse_sequence(self, sequence, input_mask):
142
+ if input_mask is None:
143
+ idx = torch.arange(sequence.size(1) - 1, -1, -1)
144
+ reversed_sequence = sequence.index_select(1, idx, device=sequence.device)
145
+ else:
146
+ sequence_lengths = input_mask.sum(1)
147
+ reversed_sequence = []
148
+ for seq, seqlen in zip(sequence, sequence_lengths):
149
+ idx = torch.arange(seqlen - 1, -1, -1, device=seq.device)
150
+ seq = seq.index_select(0, idx)
151
+ seq = F.pad(seq, [0, 0, 0, sequence.size(1) - seqlen])
152
+ reversed_sequence.append(seq)
153
+ reversed_sequence = torch.stack(reversed_sequence, 0)
154
+ return reversed_sequence
155
+
156
+
157
+ class ProteinLSTMAbstractModel(ProteinModel):
158
+
159
+ config_class = ProteinLSTMConfig
160
+ pretrained_model_archive_map = LSTM_PRETRAINED_MODEL_ARCHIVE_MAP
161
+ base_model_prefix = "lstm"
162
+
163
+ def _init_weights(self, module):
164
+ """ Initialize the weights """
165
+ if isinstance(module, (nn.Linear, nn.Embedding)):
166
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
167
+ if isinstance(module, nn.Linear) and module.bias is not None:
168
+ module.bias.data.zero_()
169
+
170
+
171
+ @registry.register_task_model('embed', 'lstm')
172
+ class ProteinLSTMModel(ProteinLSTMAbstractModel):
173
+
174
+ def __init__(self, config: ProteinLSTMConfig):
175
+ super().__init__(config)
176
+ self.embed_matrix = nn.Embedding(config.vocab_size, config.input_size)
177
+ self.encoder = ProteinLSTMEncoder(config)
178
+ self.pooler = ProteinLSTMPooler(config)
179
+ self.output_hidden_states = config.output_hidden_states
180
+ self.init_weights()
181
+
182
+ def forward(self, input_ids, input_mask=None):
183
+ if input_mask is None:
184
+ input_mask = torch.ones_like(input_ids)
185
+
186
+ # fp16 compatibility
187
+ embedding_output = self.embed_matrix(input_ids)
188
+ outputs = self.encoder(embedding_output, input_mask=input_mask)
189
+ sequence_output = outputs[0]
190
+ pooled_outputs = self.pooler(outputs[1])
191
+
192
+ outputs = (sequence_output, pooled_outputs) + outputs[2:]
193
+ return outputs # sequence_output, pooled_output, (hidden_states)
194
+
195
+
196
+ @registry.register_task_model('language_modeling', 'lstm')
197
+ class ProteinLSTMForLM(ProteinLSTMAbstractModel):
198
+
199
+ def __init__(self, config):
200
+ super().__init__(config)
201
+
202
+ self.lstm = ProteinLSTMModel(config)
203
+ self.feedforward = nn.Linear(config.hidden_size, config.vocab_size)
204
+
205
+ self.init_weights()
206
+
207
+ def forward(self,
208
+ input_ids,
209
+ input_mask=None,
210
+ targets=None):
211
+
212
+ outputs = self.lstm(input_ids, input_mask=input_mask)
213
+
214
+ sequence_output, pooled_output = outputs[:2]
215
+
216
+ forward_prediction, reverse_prediction = sequence_output.chunk(2, -1)
217
+ forward_prediction = F.pad(forward_prediction[:, :-1], [0, 0, 1, 0])
218
+ reverse_prediction = F.pad(reverse_prediction[:, 1:], [0, 0, 0, 1])
219
+ prediction_scores = \
220
+ self.feedforward(forward_prediction) + self.feedforward(reverse_prediction)
221
+ prediction_scores = prediction_scores.contiguous()
222
+
223
+ # add hidden states and if they are here
224
+ outputs = (prediction_scores,) + outputs[:2]
225
+
226
+ if targets is not None:
227
+ loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
228
+ lm_loss = loss_fct(
229
+ prediction_scores.view(-1, self.config.vocab_size), targets.view(-1))
230
+ outputs = (lm_loss,) + outputs
231
+
232
+ # (loss), prediction_scores, seq_relationship_score, (hidden_states)
233
+ return outputs
234
+
235
+
236
+ @registry.register_task_model('fluorescence', 'lstm')
237
+ @registry.register_task_model('stability', 'lstm')
238
+ class ProteinLSTMForValuePrediction(ProteinLSTMAbstractModel):
239
+
240
+ def __init__(self, config):
241
+ super().__init__(config)
242
+
243
+ self.lstm = ProteinLSTMModel(config)
244
+ self.predict = ValuePredictionHead(config.hidden_size)
245
+ self.freeze_embedding = config.freeze_embedding
246
+ self.init_weights()
247
+
248
+ def forward(self, input_ids, input_mask=None, targets=None):
249
+ if self.freeze_embedding:
250
+ self.lstm.train(False)
251
+
252
+ outputs = self.lstm(input_ids, input_mask=input_mask)
253
+
254
+ sequence_output, pooled_output = outputs[:2]
255
+ outputs = self.predict(pooled_output, targets) + outputs[2:]
256
+ # (loss), prediction_scores, (hidden_states)
257
+ return outputs
258
+
259
+
260
+ @registry.register_task_model('remote_homology', 'lstm')
261
+ class ProteinLSTMForSequenceClassification(ProteinLSTMAbstractModel):
262
+
263
+ def __init__(self, config):
264
+ super().__init__(config)
265
+
266
+ self.lstm = ProteinLSTMModel(config)
267
+ self.classify = SequenceClassificationHead(
268
+ config.hidden_size, config.num_labels)
269
+ self.freeze_embedding = config.freeze_embedding
270
+ self.init_weights()
271
+
272
+ def forward(self, input_ids, input_mask=None, targets=None):
273
+ if self.freeze_embedding:
274
+ self.lstm.train(False)
275
+
276
+ outputs = self.lstm(input_ids, input_mask=input_mask)
277
+
278
+ sequence_output, pooled_output = outputs[:2]
279
+ outputs = self.classify(pooled_output, targets) + outputs[2:]
280
+ # (loss), prediction_scores, (hidden_states)
281
+ return outputs
282
+
283
+
284
+ @registry.register_task_model('secondary_structure', 'lstm')
285
+ class ProteinLSTMForSequenceToSequenceClassification(ProteinLSTMAbstractModel):
286
+
287
+ def __init__(self, config):
288
+ super().__init__(config)
289
+
290
+ self.lstm = ProteinLSTMModel(config)
291
+ self.classify = SequenceToSequenceClassificationHead(
292
+ config.hidden_size * 2, config.num_labels, ignore_index=-1)
293
+
294
+ self.init_weights()
295
+
296
+ def forward(self, input_ids, input_mask=None, targets=None):
297
+
298
+ outputs = self.lstm(input_ids, input_mask=input_mask)
299
+
300
+ sequence_output, pooled_output = outputs[:2]
301
+ amino_acid_class_scores = self.classify(sequence_output.contiguous())
302
+
303
+ # add hidden states and if they are here
304
+ outputs = (amino_acid_class_scores,) + outputs[2:]
305
+
306
+ if targets is not None:
307
+ loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
308
+ classification_loss = loss_fct(
309
+ amino_acid_class_scores.view(-1, self.config.num_labels),
310
+ targets.view(-1))
311
+ outputs = (classification_loss,) + outputs
312
+
313
+ # (loss), prediction_scores, seq_relationship_score, (hidden_states)
314
+ return outputs
315
+
316
+
317
+ @registry.register_task_model('contact_prediction', 'lstm')
318
+ class ProteinLSTMForContactPrediction(ProteinLSTMAbstractModel):
319
+
320
+ def __init__(self, config):
321
+ super().__init__(config)
322
+
323
+ self.lstm = ProteinLSTMModel(config)
324
+ self.predict = PairwiseContactPredictionHead(config.hidden_size, ignore_index=-1)
325
+
326
+ self.init_weights()
327
+
328
+ def forward(self, input_ids, protein_length, input_mask=None, targets=None):
329
+
330
+ outputs = self.lstm(input_ids, input_mask=input_mask)
331
+
332
+ sequence_output, pooled_output = outputs[:2]
333
+ outputs = self.predict(sequence_output, protein_length, targets) + outputs[2:]
334
+ # (loss), prediction_scores, (hidden_states), (attentions)
335
+ return outputs
tape/models/modeling_onehot.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import typing
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from .modeling_utils import ProteinConfig
8
+ from .modeling_utils import ProteinModel
9
+ from .modeling_utils import ValuePredictionHead
10
+ from .modeling_utils import SequenceClassificationHead
11
+ from .modeling_utils import SequenceToSequenceClassificationHead
12
+ from .modeling_utils import PairwiseContactPredictionHead
13
+ from ..registry import registry
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class ProteinOneHotConfig(ProteinConfig):
19
+ pretrained_config_archive_map: typing.Dict[str, str] = {}
20
+
21
+ def __init__(self,
22
+ vocab_size: int,
23
+ initializer_range: float = 0.02,
24
+ use_evolutionary: bool = False,
25
+ **kwargs):
26
+ super().__init__(**kwargs)
27
+ self.vocab_size = vocab_size
28
+ self.use_evolutionary = use_evolutionary
29
+ self.initializer_range = initializer_range
30
+
31
+
32
+ class ProteinOneHotAbstractModel(ProteinModel):
33
+
34
+ config_class = ProteinOneHotConfig
35
+ pretrained_model_archive_map: typing.Dict[str, str] = {}
36
+ base_model_prefix = "onehot"
37
+
38
+ def _init_weights(self, module):
39
+ """ Initialize the weights """
40
+ if isinstance(module, (nn.Linear, nn.Embedding)):
41
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
42
+ if isinstance(module, nn.Linear) and module.bias is not None:
43
+ module.bias.data.zero_()
44
+
45
+
46
+ class ProteinOneHotModel(ProteinOneHotAbstractModel):
47
+
48
+ def __init__(self, config: ProteinOneHotConfig):
49
+ super().__init__(config)
50
+ self.vocab_size = config.vocab_size
51
+
52
+ # Note: this exists *solely* for fp16 support
53
+ # There doesn't seem to be an easier way to check whether to use fp16 or fp32 training
54
+ buffer = torch.tensor([0.])
55
+ self.register_buffer('_buffer', buffer)
56
+
57
+ def forward(self, input_ids, input_mask=None):
58
+ if input_mask is None:
59
+ input_mask = torch.ones_like(input_ids)
60
+
61
+ sequence_output = F.one_hot(input_ids, num_classes=self.vocab_size)
62
+ # fp16 compatibility
63
+ sequence_output = sequence_output.type_as(self._buffer)
64
+ input_mask = input_mask.unsqueeze(2).type_as(sequence_output)
65
+ # just a bag-of-words for amino acids
66
+ pooled_outputs = (sequence_output * input_mask).sum(1) / input_mask.sum(1)
67
+
68
+ outputs = (sequence_output, pooled_outputs)
69
+ return outputs
70
+
71
+
72
+ @registry.register_task_model('fluorescence', 'onehot')
73
+ @registry.register_task_model('stability', 'onehot')
74
+ class ProteinOneHotForValuePrediction(ProteinOneHotAbstractModel):
75
+
76
+ def __init__(self, config):
77
+ super().__init__(config)
78
+
79
+ self.onehot = ProteinOneHotModel(config)
80
+ self.predict = ValuePredictionHead(config.vocab_size)
81
+
82
+ self.init_weights()
83
+
84
+ def forward(self, input_ids, input_mask=None, targets=None):
85
+
86
+ outputs = self.onehot(input_ids, input_mask=input_mask)
87
+
88
+ sequence_output, pooled_output = outputs[:2]
89
+ outputs = self.predict(pooled_output, targets) + outputs[2:]
90
+ # (loss), prediction_scores, (hidden_states)
91
+ return outputs
92
+
93
+
94
+ @registry.register_task_model('remote_homology', 'onehot')
95
+ class ProteinOneHotForSequenceClassification(ProteinOneHotAbstractModel):
96
+
97
+ def __init__(self, config):
98
+ super().__init__(config)
99
+
100
+ self.onehot = ProteinOneHotModel(config)
101
+ self.classify = SequenceClassificationHead(config.vocab_size, config.num_labels)
102
+
103
+ self.init_weights()
104
+
105
+ def forward(self, input_ids, input_mask=None, targets=None):
106
+
107
+ outputs = self.onehot(input_ids, input_mask=input_mask)
108
+
109
+ sequence_output, pooled_output = outputs[:2]
110
+ outputs = self.classify(pooled_output, targets) + outputs[2:]
111
+ # (loss), prediction_scores, (hidden_states)
112
+ return outputs
113
+
114
+
115
+ @registry.register_task_model('secondary_structure', 'onehot')
116
+ class ProteinOneHotForSequenceToSequenceClassification(ProteinOneHotAbstractModel):
117
+
118
+ def __init__(self, config):
119
+ super().__init__(config)
120
+
121
+ self.onehot = ProteinOneHotModel(config)
122
+ self.classify = SequenceToSequenceClassificationHead(
123
+ config.vocab_size, config.num_labels, ignore_index=-1)
124
+
125
+ self.init_weights()
126
+
127
+ def forward(self, input_ids, input_mask=None, targets=None):
128
+
129
+ outputs = self.onehot(input_ids, input_mask=input_mask)
130
+
131
+ sequence_output, pooled_output = outputs[:2]
132
+ outputs = self.classify(sequence_output, targets) + outputs[2:]
133
+ # (loss), prediction_scores, (hidden_states)
134
+ return outputs
135
+
136
+
137
+ @registry.register_task_model('contact_prediction', 'onehot')
138
+ class ProteinOneHotForContactPrediction(ProteinOneHotAbstractModel):
139
+
140
+ def __init__(self, config):
141
+ super().__init__(config)
142
+
143
+ self.onehot = ProteinOneHotModel(config)
144
+ self.predict = PairwiseContactPredictionHead(config.hidden_size, ignore_index=-1)
145
+
146
+ self.init_weights()
147
+
148
+ def forward(self, input_ids, protein_length, input_mask=None, targets=None):
149
+
150
+ outputs = self.onehot(input_ids, input_mask=input_mask)
151
+
152
+ sequence_output, pooled_output = outputs[:2]
153
+ outputs = self.predict(sequence_output, protein_length, targets) + outputs[2:]
154
+ # (loss), prediction_scores, (hidden_states), (attentions)
155
+ return outputs
tape/models/modeling_resnet.py ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import typing
2
+ import logging
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from .modeling_utils import ProteinConfig
7
+ from .modeling_utils import ProteinModel
8
+ from .modeling_utils import get_activation_fn
9
+ from .modeling_utils import MLMHead
10
+ from .modeling_utils import LayerNorm
11
+ from .modeling_utils import ValuePredictionHead
12
+ from .modeling_utils import SequenceClassificationHead
13
+ from .modeling_utils import SequenceToSequenceClassificationHead
14
+ from .modeling_utils import PairwiseContactPredictionHead
15
+ from ..registry import registry
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+ RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP: typing.Dict[str, str] = {}
20
+ RESNET_PRETRAINED_MODEL_ARCHIVE_MAP: typing.Dict[str, str] = {}
21
+
22
+
23
+ class ProteinResNetConfig(ProteinConfig):
24
+ pretrained_config_archive_map = RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP
25
+
26
+ def __init__(self,
27
+ vocab_size: int = 30,
28
+ hidden_size: int = 512,
29
+ num_hidden_layers: int = 30,
30
+ hidden_act: str = "gelu",
31
+ hidden_dropout_prob: float = 0.1,
32
+ initializer_range: float = 0.02,
33
+ layer_norm_eps: float = 1e-12,
34
+ temporal_pooling: str = 'attention',
35
+ freeze_embedding: bool = False,
36
+ **kwargs):
37
+ super().__init__(**kwargs)
38
+ self.vocab_size = vocab_size
39
+ self.num_hidden_layers = num_hidden_layers
40
+ self.hidden_size = hidden_size
41
+ self.hidden_act = hidden_act
42
+ self.hidden_dropout_prob = hidden_dropout_prob
43
+ self.initializer_range = initializer_range
44
+ self.layer_norm_eps = layer_norm_eps
45
+ self.temporal_pooling = temporal_pooling
46
+ self.freeze_embedding = freeze_embedding
47
+
48
+
49
+ class MaskedConv1d(nn.Conv1d):
50
+
51
+ def forward(self, x, input_mask=None):
52
+ if input_mask is not None:
53
+ x = x * input_mask
54
+ return super().forward(x)
55
+
56
+
57
+ class ProteinResNetLayerNorm(nn.Module):
58
+
59
+ def __init__(self, config):
60
+ super().__init__()
61
+ self.norm = LayerNorm(config.hidden_size)
62
+
63
+ def forward(self, x):
64
+ return self.norm(x.transpose(1, 2)).transpose(1, 2)
65
+
66
+
67
+ class ProteinResNetBlock(nn.Module):
68
+
69
+ def __init__(self, config):
70
+ super().__init__()
71
+ self.conv1 = MaskedConv1d(
72
+ config.hidden_size, config.hidden_size, 3, padding=1, bias=False)
73
+ # self.bn1 = nn.BatchNorm1d(config.hidden_size)
74
+ self.bn1 = ProteinResNetLayerNorm(config)
75
+ self.conv2 = MaskedConv1d(
76
+ config.hidden_size, config.hidden_size, 3, padding=1, bias=False)
77
+ # self.bn2 = nn.BatchNorm1d(config.hidden_size)
78
+ self.bn2 = ProteinResNetLayerNorm(config)
79
+ self.activation_fn = get_activation_fn(config.hidden_act)
80
+
81
+ def forward(self, x, input_mask=None):
82
+ identity = x
83
+
84
+ out = self.conv1(x, input_mask)
85
+ out = self.bn1(out)
86
+ out = self.activation_fn(out)
87
+
88
+ out = self.conv2(out, input_mask)
89
+ out = self.bn2(out)
90
+
91
+ out += identity
92
+ out = self.activation_fn(out)
93
+
94
+ return out
95
+
96
+
97
+ class ProteinResNetEmbeddings(nn.Module):
98
+ """Construct the embeddings from word, position and token_type embeddings.
99
+ """
100
+ def __init__(self, config):
101
+ super().__init__()
102
+ embed_dim = config.hidden_size
103
+ self.word_embeddings = nn.Embedding(config.vocab_size, embed_dim, padding_idx=0)
104
+ inverse_frequency = 1 / (10000 ** (torch.arange(0.0, embed_dim, 2.0) / embed_dim))
105
+ self.register_buffer('inverse_frequency', inverse_frequency)
106
+
107
+ self.layer_norm = LayerNorm(embed_dim, eps=config.layer_norm_eps)
108
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
109
+
110
+ def forward(self, input_ids):
111
+ words_embeddings = self.word_embeddings(input_ids)
112
+
113
+ seq_length = input_ids.size(1)
114
+ position_ids = torch.arange(
115
+ seq_length - 1, -1, -1.0,
116
+ dtype=words_embeddings.dtype,
117
+ device=words_embeddings.device)
118
+ sinusoidal_input = torch.ger(position_ids, self.inverse_frequency)
119
+ position_embeddings = torch.cat([sinusoidal_input.sin(), sinusoidal_input.cos()], -1)
120
+ position_embeddings = position_embeddings.unsqueeze(0)
121
+
122
+ embeddings = words_embeddings + position_embeddings
123
+ embeddings = self.layer_norm(embeddings)
124
+ embeddings = self.dropout(embeddings)
125
+ return embeddings
126
+
127
+
128
+ class ProteinResNetPooler(nn.Module):
129
+ def __init__(self, config):
130
+ super().__init__()
131
+ self.attention_weights = nn.Linear(config.hidden_size, 1)
132
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
133
+ self.activation = nn.Tanh()
134
+ self.temporal_pooling = config.temporal_pooling
135
+ self._la_w1 = nn.Conv1d(config.hidden_size, int(config.hidden_size/2), 5, padding=2)
136
+ self._la_w2 = nn.Conv1d(config.hidden_size, int(config.hidden_size/2), 5, padding=2)
137
+ self._la_mlp = nn.Linear(config.hidden_size, config.hidden_size)
138
+
139
+ def forward(self, hidden_states, mask=None):
140
+ # We "pool" the model by simply taking the hidden state corresponding
141
+ # to the first token.
142
+ if self.temporal_pooling == 'mean':
143
+ return hidden_states.mean(dim=1)
144
+ if self.temporal_pooling == 'max':
145
+ return hidden_states.max(dim=1)
146
+ if self.temporal_pooling == 'concat':
147
+ _temp = hidden_states.reshape(hidden_states.shape[0], -1)
148
+ return torch.nn.functional.pad(_temp, (0, 2048 - _temp.shape[1]))
149
+ if self.temporal_pooling == 'meanmax':
150
+ _mean = hidden_states.mean(dim=1)
151
+ _max = hidden_states.max(dim=1)
152
+ return torch.cat([_mean, _max])
153
+ if self.temporal_pooling == 'topmax':
154
+ val, _ = torch.topk(hidden_states, k=5, dim=1)
155
+ return val.mean(dim=1)
156
+ if self.temporal_pooling == 'light_attention':
157
+ _temp = hidden_states.permute(0,2,1)
158
+ a = self._la_w1(_temp).softmax(dim=-1)
159
+ v = self._la_w2(_temp)
160
+ v_max = v.max(dim=-1).values
161
+ v_sum = (a * v).sum(dim=-1)
162
+ return self._la_mlp(torch.cat([v_max, v_sum], dim=1))
163
+
164
+ attention_scores = self.attention_weights(hidden_states)
165
+ if mask is not None:
166
+ attention_scores += -10000. * (1 - mask)
167
+ attention_weights = torch.softmax(attention_scores, -1)
168
+ weighted_mean_embedding = torch.matmul(
169
+ hidden_states.transpose(1, 2), attention_weights).squeeze(2)
170
+ pooled_output = self.dense(weighted_mean_embedding)
171
+ pooled_output = self.activation(pooled_output)
172
+ return pooled_output
173
+
174
+
175
+ class ResNetEncoder(nn.Module):
176
+
177
+ def __init__(self, config):
178
+ super().__init__()
179
+ self.output_hidden_states = config.output_hidden_states
180
+ self.layer = nn.ModuleList(
181
+ [ProteinResNetBlock(config) for _ in range(config.num_hidden_layers)])
182
+
183
+ def forward(self, hidden_states, input_mask=None):
184
+ all_hidden_states = ()
185
+ for layer_module in self.layer:
186
+ if self.output_hidden_states:
187
+ all_hidden_states = all_hidden_states + (hidden_states,)
188
+ hidden_states = layer_module(hidden_states, input_mask)
189
+
190
+ if self.output_hidden_states:
191
+ all_hidden_states = all_hidden_states + (hidden_states,)
192
+
193
+ outputs = (hidden_states,)
194
+ if self.output_hidden_states:
195
+ outputs = outputs + (all_hidden_states,)
196
+
197
+ return outputs
198
+
199
+
200
+ class ProteinResNetAbstractModel(ProteinModel):
201
+ """ An abstract class to handle weights initialization and
202
+ a simple interface for dowloading and loading pretrained models.
203
+ """
204
+ config_class = ProteinResNetConfig
205
+ pretrained_model_archive_map = RESNET_PRETRAINED_MODEL_ARCHIVE_MAP
206
+ base_model_prefix = "resnet"
207
+
208
+ def __init__(self, config):
209
+ super().__init__(config)
210
+
211
+ def _init_weights(self, module):
212
+ """ Initialize the weights """
213
+ if isinstance(module, nn.Embedding):
214
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
215
+ elif isinstance(module, nn.Linear):
216
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
217
+ if module.bias is not None:
218
+ module.bias.data.zero_()
219
+ elif isinstance(module, nn.Conv1d):
220
+ nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
221
+ if module.bias is not None:
222
+ module.bias.data.zero_()
223
+ # elif isinstance(module, ProteinResNetBlock):
224
+ # nn.init.constant_(module.bn2.weight, 0)
225
+
226
+
227
+ @registry.register_task_model('embed', 'resnet')
228
+ class ProteinResNetModel(ProteinResNetAbstractModel):
229
+
230
+ def __init__(self, config):
231
+ super().__init__(config)
232
+
233
+ self.embeddings = ProteinResNetEmbeddings(config)
234
+ self.encoder = ResNetEncoder(config)
235
+ self.pooler = ProteinResNetPooler(config)
236
+
237
+ self.init_weights()
238
+
239
+ def forward(self,
240
+ input_ids,
241
+ input_mask=None):
242
+ if input_mask is not None and torch.any(input_mask != 1):
243
+ extended_input_mask = input_mask.unsqueeze(2)
244
+ # fp16 compatibility
245
+ extended_input_mask = extended_input_mask.to(
246
+ dtype=next(self.parameters()).dtype)
247
+ else:
248
+ extended_input_mask = None
249
+
250
+ embedding_output = self.embeddings(input_ids)
251
+ embedding_output = embedding_output.transpose(1, 2)
252
+ if extended_input_mask is not None:
253
+ extended_input_mask = extended_input_mask.transpose(1, 2)
254
+ encoder_outputs = self.encoder(embedding_output, extended_input_mask)
255
+ sequence_output = encoder_outputs[0]
256
+ sequence_output = sequence_output.transpose(1, 2).contiguous()
257
+ # sequence_output = encoder_outputs[0]
258
+ if extended_input_mask is not None:
259
+ extended_input_mask = extended_input_mask.transpose(1, 2)
260
+ pooled_output = self.pooler(sequence_output, extended_input_mask)
261
+
262
+ # add hidden_states and attentions if they are here
263
+ outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]
264
+ return outputs # sequence_output, pooled_output, (hidden_states)
265
+
266
+
267
+ @registry.register_task_model('masked_language_modeling', 'resnet')
268
+ class ProteinResNetForMaskedLM(ProteinResNetAbstractModel):
269
+
270
+ def __init__(self, config):
271
+ super().__init__(config)
272
+
273
+ self.resnet = ProteinResNetModel(config)
274
+ self.mlm = MLMHead(
275
+ config.hidden_size, config.vocab_size, config.hidden_act, config.layer_norm_eps,
276
+ ignore_index=-1)
277
+
278
+ self.init_weights()
279
+ self.tie_weights()
280
+
281
+ def tie_weights(self):
282
+ """ Make sure we are sharing the input and output embeddings.
283
+ Export to TorchScript can't handle parameter sharing so we are cloning them instead.
284
+ """
285
+ self._tie_or_clone_weights(self.mlm.decoder,
286
+ self.resnet.embeddings.word_embeddings)
287
+
288
+ def forward(self,
289
+ input_ids,
290
+ input_mask=None,
291
+ targets=None):
292
+
293
+ outputs = self.resnet(input_ids, input_mask=input_mask)
294
+
295
+ sequence_output, pooled_output = outputs[:2]
296
+ outputs = self.mlm(sequence_output, targets) + outputs[:2]
297
+ # (loss), prediction_scores, (hidden_states), (attentions)
298
+ return outputs
299
+
300
+
301
+ @registry.register_task_model('fluorescence', 'resnet')
302
+ @registry.register_task_model('stability', 'resnet')
303
+ class ProteinResNetForValuePrediction(ProteinResNetAbstractModel):
304
+
305
+ def __init__(self, config):
306
+ super().__init__(config)
307
+
308
+ self.resnet = ProteinResNetModel(config)
309
+ self.predict = ValuePredictionHead(config.hidden_size)
310
+ self.freeze_embedding = config.freeze_embedding
311
+ self.init_weights()
312
+
313
+ def forward(self, input_ids, input_mask=None, targets=None):
314
+ if self.freeze_embedding:
315
+ self.resnet.train(False)
316
+
317
+ outputs = self.resnet(input_ids, input_mask=input_mask)
318
+
319
+ sequence_output, pooled_output = outputs[:2]
320
+ outputs = self.predict(pooled_output, targets) + outputs[2:]
321
+ # (loss), prediction_scores, (hidden_states), (attentions)
322
+ return outputs
323
+
324
+
325
+ @registry.register_task_model('remote_homology', 'resnet')
326
+ class ProteinResNetForSequenceClassification(ProteinResNetAbstractModel):
327
+
328
+ def __init__(self, config):
329
+ super().__init__(config)
330
+
331
+ self.resnet = ProteinResNetModel(config)
332
+ self.classify = SequenceClassificationHead(config.hidden_size, config.num_labels)
333
+ self.freeze_embedding = config.freeze_embedding
334
+
335
+ self.init_weights()
336
+
337
+ def forward(self, input_ids, input_mask=None, targets=None):
338
+ if self.freeze_embedding:
339
+ self.resnet.train(False)
340
+
341
+ outputs = self.resnet(input_ids, input_mask=input_mask)
342
+
343
+ sequence_output, pooled_output = outputs[:2]
344
+ outputs = self.classify(pooled_output, targets) + outputs[2:]
345
+ # (loss), prediction_scores, (hidden_states), (attentions)
346
+ return outputs
347
+
348
+
349
+ @registry.register_task_model('secondary_structure', 'resnet')
350
+ class ProteinResNetForSequenceToSequenceClassification(ProteinResNetAbstractModel):
351
+
352
+ def __init__(self, config):
353
+ super().__init__(config)
354
+
355
+ self.resnet = ProteinResNetModel(config)
356
+ self.classify = SequenceToSequenceClassificationHead(
357
+ config.hidden_size, config.num_labels, ignore_index=-1)
358
+
359
+ self.init_weights()
360
+
361
+ def forward(self, input_ids, input_mask=None, targets=None):
362
+
363
+ outputs = self.resnet(input_ids, input_mask=input_mask)
364
+
365
+ sequence_output, pooled_output = outputs[:2]
366
+ outputs = self.classify(sequence_output, targets) + outputs[2:]
367
+ # (loss), prediction_scores, (hidden_states), (attentions)
368
+ return outputs
369
+
370
+
371
+ @registry.register_task_model('contact_prediction', 'resnet')
372
+ class ProteinResNetForContactPrediction(ProteinResNetAbstractModel):
373
+
374
+ def __init__(self, config):
375
+ super().__init__(config)
376
+
377
+ self.resnet = ProteinResNetModel(config)
378
+ self.predict = PairwiseContactPredictionHead(config.hidden_size, ignore_index=-1)
379
+
380
+ self.init_weights()
381
+
382
+ def forward(self, input_ids, protein_length, input_mask=None, targets=None):
383
+
384
+ outputs = self.resnet(input_ids, input_mask=input_mask)
385
+
386
+ sequence_output, pooled_output = outputs[:2]
387
+ outputs = self.predict(sequence_output, protein_length, targets) + outputs[2:]
388
+ # (loss), prediction_scores, (hidden_states), (attentions)
389
+ return outputs
tape/models/modeling_trrosetta.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from ..registry import registry
5
+ from .modeling_utils import ProteinConfig
6
+ from .modeling_utils import ProteinModel
7
+
8
+ URL_PREFIX = "https://s3.amazonaws.com/proteindata/pytorch-models/"
9
+ TRROSETTA_PRETRAINED_MODEL_ARCHIVE_MAP = {
10
+ 'xaa': URL_PREFIX + "trRosetta-xaa-pytorch_model.bin",
11
+ 'xab': URL_PREFIX + "trRosetta-xab-pytorch_model.bin",
12
+ 'xac': URL_PREFIX + "trRosetta-xac-pytorch_model.bin",
13
+ 'xad': URL_PREFIX + "trRosetta-xad-pytorch_model.bin",
14
+ 'xae': URL_PREFIX + "trRosetta-xae-pytorch_model.bin",
15
+ }
16
+ TRROSETTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
17
+ 'xaa': URL_PREFIX + "trRosetta-xaa-config.json",
18
+ 'xab': URL_PREFIX + "trRosetta-xab-config.json",
19
+ 'xac': URL_PREFIX + "trRosetta-xac-config.json",
20
+ 'xad': URL_PREFIX + "trRosetta-xad-config.json",
21
+ 'xae': URL_PREFIX + "trRosetta-xae-config.json",
22
+ }
23
+
24
+
25
+ class TRRosettaConfig(ProteinConfig):
26
+
27
+ pretrained_config_archive_map = TRROSETTA_PRETRAINED_CONFIG_ARCHIVE_MAP
28
+
29
+ def __init__(self,
30
+ num_features: int = 64,
31
+ kernel_size: int = 3,
32
+ num_layers: int = 61,
33
+ dropout: float = 0.15,
34
+ msa_cutoff: float = 0.8,
35
+ penalty_coeff: float = 4.5,
36
+ initializer_range: float = 0.02,
37
+ **kwargs):
38
+ super().__init__(**kwargs)
39
+ self.num_features = num_features
40
+ self.kernel_size = kernel_size
41
+ self.num_layers = num_layers
42
+ self.dropout = dropout
43
+ self.msa_cutoff = msa_cutoff
44
+ self.penalty_coeff = penalty_coeff
45
+ self.initializer_range = initializer_range
46
+
47
+
48
+ class MSAFeatureExtractor(nn.Module):
49
+
50
+ def __init__(self, config: TRRosettaConfig):
51
+ super().__init__()
52
+ self.msa_cutoff = config.msa_cutoff
53
+ self.penalty_coeff = config.penalty_coeff
54
+
55
+ def forward(self, msa1hot):
56
+ # Convert to float, then potentially back to half
57
+ # These transforms aren't well suited to half-precision
58
+ initial_type = msa1hot.dtype
59
+
60
+ msa1hot = msa1hot.float()
61
+ seqlen = msa1hot.size(2)
62
+
63
+ weights = self.reweight(msa1hot)
64
+ features_1d = self.extract_features_1d(msa1hot, weights)
65
+ features_2d = self.extract_features_2d(msa1hot, weights)
66
+
67
+ left = features_1d.unsqueeze(2).repeat(1, 1, seqlen, 1)
68
+ right = features_1d.unsqueeze(1).repeat(1, seqlen, 1, 1)
69
+ features = torch.cat((left, right, features_2d), -1)
70
+ features = features.type(initial_type)
71
+ features = features.permute(0, 3, 1, 2)
72
+ features = features.contiguous()
73
+ return features
74
+
75
+ def reweight(self, msa1hot, eps=1e-9):
76
+ # Reweight
77
+ seqlen = msa1hot.size(2)
78
+ id_min = seqlen * self.msa_cutoff
79
+ id_mtx = torch.stack([torch.tensordot(el, el, [[1, 2], [1, 2]]) for el in msa1hot], 0)
80
+ id_mask = id_mtx > id_min
81
+ weights = 1.0 / (id_mask.type_as(msa1hot).sum(-1) + eps)
82
+ return weights
83
+
84
+ def extract_features_1d(self, msa1hot, weights):
85
+ # 1D Features
86
+ f1d_seq = msa1hot[:, 0, :, :20]
87
+ batch_size = msa1hot.size(0)
88
+ seqlen = msa1hot.size(2)
89
+
90
+ # msa2pssm
91
+ beff = weights.sum()
92
+ f_i = (weights[:, :, None, None] * msa1hot).sum(1) / beff + 1e-9
93
+ h_i = (-f_i * f_i.log()).sum(2, keepdims=True)
94
+ f1d_pssm = torch.cat((f_i, h_i), dim=2)
95
+ f1d = torch.cat((f1d_seq, f1d_pssm), dim=2)
96
+ f1d = f1d.view(batch_size, seqlen, 42)
97
+ return f1d
98
+
99
+ def extract_features_2d(self, msa1hot, weights):
100
+ # 2D Features
101
+ batch_size = msa1hot.size(0)
102
+ num_alignments = msa1hot.size(1)
103
+ seqlen = msa1hot.size(2)
104
+ num_symbols = 21
105
+
106
+ if num_alignments == 1:
107
+ # No alignments, predict from sequence alone
108
+ f2d_dca = torch.zeros(
109
+ batch_size, seqlen, seqlen, 442,
110
+ dtype=torch.float,
111
+ device=msa1hot.device)
112
+ return f2d_dca
113
+
114
+ # compute fast_dca
115
+ # covariance
116
+ x = msa1hot.view(batch_size, num_alignments, seqlen * num_symbols)
117
+ num_points = weights.sum(1) - weights.mean(1).sqrt()
118
+ mean = (x * weights.unsqueeze(2)).sum(1, keepdims=True) / num_points[:, None, None]
119
+ x = (x - mean) * weights[:, :, None].sqrt()
120
+ cov = torch.matmul(x.transpose(-1, -2), x) / num_points[:, None, None]
121
+
122
+ # inverse covariance
123
+ reg = torch.eye(seqlen * num_symbols,
124
+ device=weights.device,
125
+ dtype=weights.dtype)[None]
126
+ reg = reg * self.penalty_coeff / weights.sum(1, keepdims=True).sqrt().unsqueeze(2)
127
+ cov_reg = cov + reg
128
+ inv_cov = torch.stack([torch.inverse(cr) for cr in cov_reg.unbind(0)], 0)
129
+
130
+ x1 = inv_cov.view(batch_size, seqlen, num_symbols, seqlen, num_symbols)
131
+ x2 = x1.permute(0, 1, 3, 2, 4)
132
+ features = x2.reshape(batch_size, seqlen, seqlen, num_symbols * num_symbols)
133
+
134
+ x3 = (x1[:, :, :-1, :, :-1] ** 2).sum((2, 4)).sqrt() * (
135
+ 1 - torch.eye(seqlen, device=weights.device, dtype=weights.dtype)[None])
136
+ apc = x3.sum(1, keepdims=True) * x3.sum(2, keepdims=True) / x3.sum(
137
+ (1, 2), keepdims=True)
138
+ contacts = (x3 - apc) * (1 - torch.eye(
139
+ seqlen, device=x3.device, dtype=x3.dtype).unsqueeze(0))
140
+
141
+ f2d_dca = torch.cat([features, contacts[:, :, :, None]], axis=3)
142
+ return f2d_dca
143
+
144
+ @property
145
+ def feature_size(self) -> int:
146
+ return 526
147
+
148
+
149
+ class DilatedResidualBlock(nn.Module):
150
+
151
+ def __init__(self, num_features: int, kernel_size: int, dilation: int, dropout: float):
152
+ super().__init__()
153
+ padding = self._get_padding(kernel_size, dilation)
154
+ self.conv1 = nn.Conv2d(
155
+ num_features, num_features, kernel_size, padding=padding, dilation=dilation)
156
+ self.norm1 = nn.InstanceNorm2d(num_features, affine=True, eps=1e-6)
157
+ self.actv1 = nn.ELU(inplace=True)
158
+ self.dropout = nn.Dropout(dropout)
159
+ self.conv2 = nn.Conv2d(
160
+ num_features, num_features, kernel_size, padding=padding, dilation=dilation)
161
+ self.norm2 = nn.InstanceNorm2d(num_features, affine=True, eps=1e-6)
162
+ self.actv2 = nn.ELU(inplace=True)
163
+ self.apply(self._init_weights)
164
+ nn.init.constant_(self.norm2.weight, 0)
165
+
166
+ def _get_padding(self, kernel_size: int, dilation: int) -> int:
167
+ return (kernel_size + (kernel_size - 1) * (dilation - 1) - 1) // 2
168
+
169
+ def _init_weights(self, module):
170
+ """ Initialize the weights """
171
+ if isinstance(module, nn.Conv2d):
172
+ nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
173
+ if module.bias is not None:
174
+ module.bias.data.zero_()
175
+
176
+ # elif isinstance(module, DilatedResidualBlock):
177
+ # nn.init.constant_(module.norm2.weight, 0)
178
+
179
+ def forward(self, features):
180
+ shortcut = features
181
+ features = self.conv1(features)
182
+ features = self.norm1(features)
183
+ features = self.actv1(features)
184
+ features = self.dropout(features)
185
+ features = self.conv2(features)
186
+ features = self.norm2(features)
187
+ features = self.actv2(features + shortcut)
188
+ return features
189
+
190
+
191
+ class TRRosettaAbstractModel(ProteinModel):
192
+
193
+ config_class = TRRosettaConfig
194
+ base_model_prefix = 'trrosetta'
195
+ pretrained_model_archive_map = TRROSETTA_PRETRAINED_MODEL_ARCHIVE_MAP
196
+
197
+ def __init__(self, config: TRRosettaConfig):
198
+ super().__init__(config)
199
+
200
+ def _init_weights(self, module):
201
+ """ Initialize the weights """
202
+ if isinstance(module, nn.Linear):
203
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
204
+ if module.bias is not None:
205
+ module.bias.data.zero_()
206
+ elif isinstance(module, nn.Conv2d):
207
+ nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
208
+ if module.bias is not None:
209
+ module.bias.data.zero_()
210
+ elif isinstance(module, DilatedResidualBlock):
211
+ nn.init.constant_(module.norm2.weight, 0)
212
+
213
+
214
+ class TRRosettaPredictor(TRRosettaAbstractModel):
215
+
216
+ def __init__(self, config: TRRosettaConfig):
217
+ super().__init__(config)
218
+ layers = [
219
+ nn.Conv2d(526, config.num_features, 1),
220
+ nn.InstanceNorm2d(config.num_features, affine=True, eps=1e-6),
221
+ nn.ELU(),
222
+ nn.Dropout(config.dropout)]
223
+
224
+ dilation = 1
225
+ for _ in range(config.num_layers):
226
+ block = DilatedResidualBlock(
227
+ config.num_features, config.kernel_size, dilation, config.dropout)
228
+ layers.append(block)
229
+
230
+ dilation *= 2
231
+ if dilation > 16:
232
+ dilation = 1
233
+
234
+ self.resnet = nn.Sequential(*layers)
235
+ self.predict_theta = nn.Conv2d(config.num_features, 25, 1)
236
+ self.predict_phi = nn.Conv2d(config.num_features, 13, 1)
237
+ self.predict_dist = nn.Conv2d(config.num_features, 37, 1)
238
+ self.predict_bb = nn.Conv2d(config.num_features, 3, 1)
239
+ self.predict_omega = nn.Conv2d(config.num_features, 25, 1)
240
+
241
+ self.init_weights()
242
+
243
+ def init_weights(self):
244
+ self.apply(self._init_weights)
245
+ nn.init.constant_(self.predict_theta.weight, 0)
246
+ nn.init.constant_(self.predict_phi.weight, 0)
247
+ nn.init.constant_(self.predict_dist.weight, 0)
248
+ nn.init.constant_(self.predict_bb.weight, 0)
249
+ nn.init.constant_(self.predict_omega.weight, 0)
250
+
251
+ def forward(self,
252
+ features,
253
+ theta=None,
254
+ phi=None,
255
+ dist=None,
256
+ omega=None):
257
+ batch_size = features.size(0)
258
+ seqlen = features.size(2)
259
+ embedding = self.resnet(features)
260
+
261
+ # anglegrams for theta
262
+ logits_theta = self.predict_theta(embedding)
263
+
264
+ # anglegrams for phi
265
+ logits_phi = self.predict_phi(embedding)
266
+
267
+ # symmetrize
268
+ sym_embedding = 0.5 * (embedding + embedding.transpose(-1, -2))
269
+
270
+ # distograms
271
+ logits_dist = self.predict_dist(sym_embedding)
272
+
273
+ # beta-strand pairings (not used)
274
+ # logits_bb = self.predict_bb(sym_embedding)
275
+
276
+ # anglegrams for omega
277
+ logits_omega = self.predict_omega(sym_embedding)
278
+
279
+ logits_dist = logits_dist.permute(0, 2, 3, 1).contiguous()
280
+ logits_theta = logits_theta.permute(0, 2, 3, 1).contiguous()
281
+ logits_omega = logits_omega.permute(0, 2, 3, 1).contiguous()
282
+ logits_phi = logits_phi.permute(0, 2, 3, 1).contiguous()
283
+
284
+ probs = {}
285
+ probs['p_dist'] = nn.Softmax(-1)(logits_dist)
286
+ probs['p_theta'] = nn.Softmax(-1)(logits_theta)
287
+ probs['p_omega'] = nn.Softmax(-1)(logits_omega)
288
+ probs['p_phi'] = nn.Softmax(-1)(logits_phi)
289
+ outputs = (probs,)
290
+
291
+ metrics = {}
292
+ total_loss = 0
293
+
294
+ if dist is not None:
295
+ logits_dist = logits_dist.reshape(batch_size * seqlen * seqlen, 37)
296
+ loss_dist = nn.CrossEntropyLoss(ignore_index=-1)(logits_dist, dist.view(-1))
297
+ metrics['dist'] = loss_dist
298
+ total_loss += loss_dist
299
+ if theta is not None:
300
+ logits_theta = logits_theta.reshape(batch_size * seqlen * seqlen, 25)
301
+ loss_theta = nn.CrossEntropyLoss(ignore_index=0)(logits_theta, theta.view(-1))
302
+ metrics['theta'] = loss_theta
303
+ total_loss += loss_theta
304
+ if omega is not None:
305
+ logits_omega = logits_omega.reshape(batch_size * seqlen * seqlen, 25)
306
+ loss_omega = nn.CrossEntropyLoss(ignore_index=0)(logits_omega, omega.view(-1))
307
+ metrics['omega'] = loss_omega
308
+ total_loss += loss_omega
309
+ if phi is not None:
310
+ logits_phi = logits_phi.reshape(batch_size * seqlen * seqlen, 13)
311
+ loss_phi = nn.CrossEntropyLoss(ignore_index=0)(logits_phi, phi.view(-1))
312
+ metrics['phi'] = loss_phi
313
+ total_loss += loss_phi
314
+
315
+ if len(metrics) > 0:
316
+ outputs = ((total_loss, metrics),) + outputs
317
+
318
+ return outputs
319
+
320
+
321
+ @registry.register_task_model('trrosetta', 'trrosetta')
322
+ class TRRosetta(TRRosettaAbstractModel):
323
+
324
+ def __init__(self, config: TRRosettaConfig):
325
+ super().__init__(config)
326
+ self.extract_features = MSAFeatureExtractor(config)
327
+ self.trrosetta = TRRosettaPredictor(config)
328
+
329
+ def forward(self,
330
+ msa1hot,
331
+ theta=None,
332
+ phi=None,
333
+ dist=None,
334
+ omega=None):
335
+ features = self.extract_features(msa1hot)
336
+ return self.trrosetta(features, theta, phi, dist, omega)
tape/models/modeling_unirep.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import typing
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn.utils import weight_norm
6
+
7
+ from .modeling_utils import ProteinConfig
8
+ from .modeling_utils import ProteinModel
9
+ from .modeling_utils import ValuePredictionHead
10
+ from .modeling_utils import SequenceClassificationHead
11
+ from .modeling_utils import SequenceToSequenceClassificationHead
12
+ from .modeling_utils import PairwiseContactPredictionHead
13
+ from ..registry import registry
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ URL_PREFIX = "https://s3.amazonaws.com/proteindata/pytorch-models/"
19
+ UNIREP_PRETRAINED_CONFIG_ARCHIVE_MAP: typing.Dict[str, str] = {
20
+ 'babbler-1900': URL_PREFIX + 'unirep-base-config.json'}
21
+ UNIREP_PRETRAINED_MODEL_ARCHIVE_MAP: typing.Dict[str, str] = {
22
+ 'babbler-1900': URL_PREFIX + 'unirep-base-pytorch_model.bin'}
23
+
24
+
25
+ class UniRepConfig(ProteinConfig):
26
+ pretrained_config_archive_map = UNIREP_PRETRAINED_CONFIG_ARCHIVE_MAP
27
+
28
+ def __init__(self,
29
+ vocab_size: int = 26,
30
+ input_size: int = 10,
31
+ hidden_size: int = 1900,
32
+ hidden_dropout_prob: float = 0.1,
33
+ layer_norm_eps: float = 1e-12,
34
+ initializer_range: float = 0.02,
35
+ **kwargs):
36
+ super().__init__(**kwargs)
37
+ self.vocab_size = vocab_size
38
+ self.input_size = input_size
39
+ self.hidden_size = hidden_size
40
+ self.hidden_dropout_prob = hidden_dropout_prob
41
+ self.layer_norm_eps = layer_norm_eps
42
+ self.initializer_range = initializer_range
43
+
44
+
45
+ class mLSTMCell(nn.Module):
46
+ def __init__(self, config):
47
+ super().__init__()
48
+ project_size = config.hidden_size * 4
49
+ self.wmx = weight_norm(
50
+ nn.Linear(config.input_size, config.hidden_size, bias=False))
51
+ self.wmh = weight_norm(
52
+ nn.Linear(config.hidden_size, config.hidden_size, bias=False))
53
+ self.wx = weight_norm(
54
+ nn.Linear(config.input_size, project_size, bias=False))
55
+ self.wh = weight_norm(
56
+ nn.Linear(config.hidden_size, project_size, bias=True))
57
+
58
+ def forward(self, inputs, state):
59
+ h_prev, c_prev = state
60
+ m = self.wmx(inputs) * self.wmh(h_prev)
61
+ z = self.wx(inputs) + self.wh(m)
62
+ i, f, o, u = torch.chunk(z, 4, 1)
63
+ i = torch.sigmoid(i)
64
+ f = torch.sigmoid(f)
65
+ o = torch.sigmoid(o)
66
+ u = torch.tanh(u)
67
+ c = f * c_prev + i * u
68
+ h = o * torch.tanh(c)
69
+
70
+ return h, c
71
+
72
+
73
+ class mLSTM(nn.Module):
74
+
75
+ def __init__(self, config):
76
+ super().__init__()
77
+ self.mlstm_cell = mLSTMCell(config)
78
+ self.hidden_size = config.hidden_size
79
+
80
+ def forward(self, inputs, state=None, mask=None):
81
+ batch_size = inputs.size(0)
82
+ seqlen = inputs.size(1)
83
+
84
+ if mask is None:
85
+ mask = torch.ones(batch_size, seqlen, 1, dtype=inputs.dtype, device=inputs.device)
86
+ elif mask.dim() == 2:
87
+ mask = mask.unsqueeze(2)
88
+
89
+ if state is None:
90
+ zeros = torch.zeros(batch_size, self.hidden_size,
91
+ dtype=inputs.dtype, device=inputs.device)
92
+ state = (zeros, zeros)
93
+
94
+ steps = []
95
+ for seq in range(seqlen):
96
+ prev = state
97
+ seq_input = inputs[:, seq, :]
98
+ hx, cx = self.mlstm_cell(seq_input, state)
99
+ seqmask = mask[:, seq]
100
+ hx = seqmask * hx + (1 - seqmask) * prev[0]
101
+ cx = seqmask * cx + (1 - seqmask) * prev[1]
102
+ state = (hx, cx)
103
+ steps.append(hx)
104
+
105
+ return torch.stack(steps, 1), (hx, cx)
106
+
107
+
108
+ class UniRepAbstractModel(ProteinModel):
109
+
110
+ config_class = UniRepConfig
111
+ pretrained_model_archive_map = UNIREP_PRETRAINED_MODEL_ARCHIVE_MAP
112
+ base_model_prefix = "unirep"
113
+
114
+ def _init_weights(self, module):
115
+ """ Initialize the weights """
116
+ if isinstance(module, (nn.Linear, nn.Embedding)):
117
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
118
+ if isinstance(module, nn.Linear) and module.bias is not None:
119
+ module.bias.data.zero_()
120
+
121
+
122
+ @registry.register_task_model('embed', 'unirep')
123
+ class UniRepModel(UniRepAbstractModel):
124
+
125
+ def __init__(self, config: UniRepConfig):
126
+ super().__init__(config)
127
+ self.embed_matrix = nn.Embedding(config.vocab_size, config.input_size)
128
+ self.encoder = mLSTM(config)
129
+ self.output_hidden_states = config.output_hidden_states
130
+ self.init_weights()
131
+
132
+ def forward(self, input_ids, input_mask=None):
133
+ if input_mask is None:
134
+ input_mask = torch.ones_like(input_ids)
135
+
136
+ # fp16 compatibility
137
+ input_mask = input_mask.to(dtype=next(self.parameters()).dtype)
138
+ embedding_output = self.embed_matrix(input_ids)
139
+
140
+ encoder_outputs = self.encoder(embedding_output, mask=input_mask)
141
+ sequence_output = encoder_outputs[0]
142
+ hidden_states = encoder_outputs[1]
143
+ pooled_outputs = torch.cat(hidden_states, 1)
144
+
145
+ outputs = (sequence_output, pooled_outputs)
146
+ return outputs
147
+
148
+
149
+ @registry.register_task_model('language_modeling', 'unirep')
150
+ class UniRepForLM(UniRepAbstractModel):
151
+ # TODO: Fix this for UniRep - UniRep changes the size of the targets
152
+
153
+ def __init__(self, config):
154
+ super().__init__(config)
155
+
156
+ self.unirep = UniRepModel(config)
157
+ self.feedforward = nn.Linear(config.hidden_size, config.vocab_size - 1)
158
+
159
+ self.init_weights()
160
+
161
+ def forward(self,
162
+ input_ids,
163
+ input_mask=None,
164
+ targets=None):
165
+
166
+ outputs = self.unirep(input_ids, input_mask=input_mask)
167
+
168
+ sequence_output, pooled_output = outputs[:2]
169
+ prediction_scores = self.feedforward(sequence_output)
170
+
171
+ # add hidden states and if they are here
172
+ outputs = (prediction_scores,) + outputs[2:]
173
+
174
+ if targets is not None:
175
+ targets = targets[:, 1:]
176
+ prediction_scores = prediction_scores[:, :-1]
177
+ loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
178
+ lm_loss = loss_fct(
179
+ prediction_scores.view(-1, self.config.vocab_size), targets.view(-1))
180
+ outputs = (lm_loss,) + outputs
181
+
182
+ # (loss), prediction_scores, (hidden_states)
183
+ return outputs
184
+
185
+
186
+ @registry.register_task_model('fluorescence', 'unirep')
187
+ @registry.register_task_model('stability', 'unirep')
188
+ class UniRepForValuePrediction(UniRepAbstractModel):
189
+
190
+ def __init__(self, config):
191
+ super().__init__(config)
192
+
193
+ self.unirep = UniRepModel(config)
194
+ self.predict = ValuePredictionHead(config.hidden_size * 2)
195
+
196
+ self.init_weights()
197
+
198
+ def forward(self, input_ids, input_mask=None, targets=None):
199
+
200
+ outputs = self.unirep(input_ids, input_mask=input_mask)
201
+
202
+ sequence_output, pooled_output = outputs[:2]
203
+ outputs = self.predict(pooled_output, targets) + outputs[2:]
204
+ # (loss), prediction_scores, (hidden_states)
205
+ return outputs
206
+
207
+
208
+ @registry.register_task_model('remote_homology', 'unirep')
209
+ class UniRepForSequenceClassification(UniRepAbstractModel):
210
+
211
+ def __init__(self, config):
212
+ super().__init__(config)
213
+
214
+ self.unirep = UniRepModel(config)
215
+ self.classify = SequenceClassificationHead(
216
+ config.hidden_size * 2, config.num_labels)
217
+
218
+ self.init_weights()
219
+
220
+ def forward(self, input_ids, input_mask=None, targets=None):
221
+
222
+ outputs = self.unirep(input_ids, input_mask=input_mask)
223
+
224
+ sequence_output, pooled_output = outputs[:2]
225
+ outputs = self.classify(pooled_output, targets) + outputs[2:]
226
+ # (loss), prediction_scores, (hidden_states)
227
+ return outputs
228
+
229
+
230
+ @registry.register_task_model('secondary_structure', 'unirep')
231
+ class UniRepForSequenceToSequenceClassification(UniRepAbstractModel):
232
+
233
+ def __init__(self, config):
234
+ super().__init__(config)
235
+
236
+ self.unirep = UniRepModel(config)
237
+ self.classify = SequenceToSequenceClassificationHead(
238
+ config.hidden_size, config.num_labels, ignore_index=-1)
239
+
240
+ self.init_weights()
241
+
242
+ def forward(self, input_ids, input_mask=None, targets=None):
243
+
244
+ outputs = self.unirep(input_ids, input_mask=input_mask)
245
+
246
+ sequence_output, pooled_output = outputs[:2]
247
+ outputs = self.classify(sequence_output, targets) + outputs[2:]
248
+ # (loss), prediction_scores, (hidden_states)
249
+ return outputs
250
+
251
+
252
+ @registry.register_task_model('contact_prediction', 'unirep')
253
+ class UniRepForContactPrediction(UniRepAbstractModel):
254
+
255
+ def __init__(self, config):
256
+ super().__init__(config)
257
+
258
+ self.unirep = UniRepModel(config)
259
+ self.predict = PairwiseContactPredictionHead(config.hidden_size, ignore_index=-1)
260
+
261
+ self.init_weights()
262
+
263
+ def forward(self, input_ids, protein_length, input_mask=None, targets=None):
264
+
265
+ outputs = self.unirep(input_ids, input_mask=input_mask)
266
+
267
+ sequence_output, pooled_output = outputs[:2]
268
+ outputs = self.predict(sequence_output, protein_length, targets) + outputs[2:]
269
+ # (loss), prediction_scores, (hidden_states), (attentions)
270
+ return outputs
tape/models/modeling_utils.py ADDED
@@ -0,0 +1,887 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ # Modified by Roshan Rao
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """PyTorch Protein models."""
18
+ from __future__ import (absolute_import, division, print_function,
19
+ unicode_literals)
20
+ import typing
21
+ import copy
22
+ import json
23
+ import logging
24
+ import os
25
+ from io import open
26
+ import math
27
+ from torch.nn.utils.weight_norm import weight_norm
28
+
29
+ import torch
30
+ from torch import nn
31
+ import torch.nn.functional as F
32
+
33
+ from .file_utils import cached_path
34
+
35
+ CONFIG_NAME = "config.json"
36
+ WEIGHTS_NAME = "pytorch_model.bin"
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+
41
+ class ProteinConfig(object):
42
+ """ Base class for all configuration classes.
43
+ Handles a few parameters common to all models' configurations as well as methods
44
+ for loading/downloading/saving configurations.
45
+
46
+ Class attributes (overridden by derived classes):
47
+ - ``pretrained_config_archive_map``: a python ``dict`` of with `short-cut-names`
48
+ (string) as keys and `url` (string) of associated pretrained model
49
+ configurations as values.
50
+
51
+ Parameters:
52
+ ``finetuning_task``: string, default `None`. Name of the task used to fine-tune
53
+ the model.
54
+ ``num_labels``: integer, default `2`. Number of classes to use when the model is
55
+ a classification model (sequences/tokens)
56
+ ``output_attentions``: boolean, default `False`. Should the model returns
57
+ attentions weights.
58
+ ``output_hidden_states``: string, default `False`. Should the model returns all
59
+ hidden-states.
60
+ ``torchscript``: string, default `False`. Is the model used with Torchscript.
61
+ """
62
+ pretrained_config_archive_map: typing.Dict[str, str] = {}
63
+
64
+ def __init__(self, **kwargs):
65
+ self.finetuning_task = kwargs.pop('finetuning_task', None)
66
+ self.num_labels = kwargs.pop('num_labels', 2)
67
+ self.output_attentions = kwargs.pop('output_attentions', False)
68
+ self.output_hidden_states = kwargs.pop('output_hidden_states', False)
69
+ self.torchscript = kwargs.pop('torchscript', False)
70
+
71
+ def save_pretrained(self, save_directory):
72
+ """ Save a configuration object to the directory `save_directory`, so that it
73
+ can be re-loaded using the :func:`~ProteinConfig.from_pretrained`
74
+ class method.
75
+ """
76
+ assert os.path.isdir(save_directory), "Saving path should be a directory where the " \
77
+ "model and configuration can be saved"
78
+
79
+ # If we save using the predefined names, we can load using `from_pretrained`
80
+ output_config_file = os.path.join(save_directory, CONFIG_NAME)
81
+
82
+ self.to_json_file(output_config_file)
83
+
84
+ @classmethod
85
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
86
+ r""" Instantiate a :class:`~ProteinConfig`
87
+ (or a derived class) from a pre-trained model configuration.
88
+
89
+ Parameters:
90
+ pretrained_model_name_or_path: either:
91
+
92
+ - a string with the `shortcut name` of a pre-trained model configuration to
93
+ load from cache or download, e.g.: ``bert-base-uncased``.
94
+ - a path to a `directory` containing a configuration file saved using the
95
+ :func:`~ProteinConfig.save_pretrained` method,
96
+ e.g.: ``./my_model_directory/``.
97
+ - a path or url to a saved configuration JSON `file`,
98
+ e.g.: ``./my_model_directory/configuration.json``.
99
+
100
+ cache_dir: (`optional`) string:
101
+ Path to a directory in which a downloaded pre-trained model
102
+ configuration should be cached if the standard cache should not be used.
103
+
104
+ kwargs: (`optional`) dict:
105
+ key/value pairs with which to update the configuration object after loading.
106
+
107
+ - The values in kwargs of any keys which are configuration attributes will
108
+ be used to override the loaded values.
109
+ - Behavior concerning key/value pairs whose keys are *not* configuration
110
+ attributes is controlled by the `return_unused_kwargs` keyword parameter.
111
+
112
+ return_unused_kwargs: (`optional`) bool:
113
+
114
+ - If False, then this function returns just the final configuration object.
115
+ - If True, then this functions returns a tuple `(config, unused_kwargs)`
116
+ where `unused_kwargs` is a dictionary consisting of the key/value pairs
117
+ whose keys are not configuration attributes: ie the part of kwargs which
118
+ has not been used to update `config` and is otherwise ignored.
119
+
120
+ Examples::
121
+
122
+ # We can't instantiate directly the base class `ProteinConfig` so let's
123
+ show the examples on a derived class: ProteinBertConfig
124
+ # Download configuration from S3 and cache.
125
+ config = ProteinBertConfig.from_pretrained('bert-base-uncased')
126
+ # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
127
+ config = ProteinBertConfig.from_pretrained('./test/saved_model/')
128
+ config = ProteinBertConfig.from_pretrained(
129
+ './test/saved_model/my_configuration.json')
130
+ config = ProteinBertConfig.from_pretrained(
131
+ 'bert-base-uncased', output_attention=True, foo=False)
132
+ assert config.output_attention == True
133
+ config, unused_kwargs = BertConfig.from_pretrained(
134
+ 'bert-base-uncased', output_attention=True,
135
+ foo=False, return_unused_kwargs=True)
136
+ assert config.output_attention == True
137
+ assert unused_kwargs == {'foo': False}
138
+
139
+ """
140
+ cache_dir = kwargs.pop('cache_dir', None)
141
+ return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
142
+
143
+ if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
144
+ config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
145
+ elif os.path.isdir(pretrained_model_name_or_path):
146
+ config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
147
+ else:
148
+ config_file = pretrained_model_name_or_path
149
+ # redirect to the cache, if necessary
150
+ try:
151
+ resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
152
+ except EnvironmentError:
153
+ if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
154
+ logger.error("Couldn't reach server at '{}' to download pretrained model "
155
+ "configuration file.".format(config_file))
156
+ else:
157
+ logger.error(
158
+ "Model name '{}' was not found in model name list ({}). "
159
+ "We assumed '{}' was a path or url but couldn't find any file "
160
+ "associated to this path or url.".format(
161
+ pretrained_model_name_or_path,
162
+ ', '.join(cls.pretrained_config_archive_map.keys()),
163
+ config_file))
164
+ return None
165
+ if resolved_config_file == config_file:
166
+ logger.info("loading configuration file {}".format(config_file))
167
+ else:
168
+ logger.info("loading configuration file {} from cache at {}".format(
169
+ config_file, resolved_config_file))
170
+
171
+ # Load config
172
+ config = cls.from_json_file(resolved_config_file)
173
+
174
+ # Update config with kwargs if needed
175
+ to_remove = []
176
+ for key, value in kwargs.items():
177
+ if hasattr(config, key):
178
+ setattr(config, key, value)
179
+ to_remove.append(key)
180
+ for key in to_remove:
181
+ kwargs.pop(key, None)
182
+
183
+ logger.info("Model config %s", config)
184
+ if return_unused_kwargs:
185
+ return config, kwargs
186
+ else:
187
+ return config
188
+
189
+ @classmethod
190
+ def from_dict(cls, json_object):
191
+ """Constructs a `Config` from a Python dictionary of parameters."""
192
+ config = cls(vocab_size_or_config_json_file=-1)
193
+ for key, value in json_object.items():
194
+ config.__dict__[key] = value
195
+ return config
196
+
197
+ @classmethod
198
+ def from_json_file(cls, json_file):
199
+ """Constructs a `BertConfig` from a json file of parameters."""
200
+ with open(json_file, "r", encoding='utf-8') as reader:
201
+ text = reader.read()
202
+ return cls.from_dict(json.loads(text))
203
+
204
+ def __eq__(self, other):
205
+ return self.__dict__ == other.__dict__
206
+
207
+ def __repr__(self):
208
+ return str(self.to_json_string())
209
+
210
+ def to_dict(self):
211
+ """Serializes this instance to a Python dictionary."""
212
+ output = copy.deepcopy(self.__dict__)
213
+ return output
214
+
215
+ def to_json_string(self):
216
+ """Serializes this instance to a JSON string."""
217
+ return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
218
+
219
+ def to_json_file(self, json_file_path):
220
+ """ Save this instance to a json file."""
221
+ with open(json_file_path, "w", encoding='utf-8') as writer:
222
+ writer.write(self.to_json_string())
223
+
224
+
225
+ class ProteinModel(nn.Module):
226
+ r""" Base class for all models.
227
+
228
+ :class:`~ProteinModel` takes care of storing the configuration of
229
+ the models and handles methods for loading/downloading/saving models as well as a
230
+ few methods commons to all models to (i) resize the input embeddings and (ii) prune
231
+ heads in the self-attention heads.
232
+
233
+ Class attributes (overridden by derived classes):
234
+ - ``config_class``: a class derived from :class:`~ProteinConfig`
235
+ to use as configuration class for this model architecture.
236
+ - ``pretrained_model_archive_map``: a python ``dict`` of with `short-cut-names`
237
+ (string) as keys and `url` (string) of associated pretrained weights as values.
238
+
239
+ - ``base_model_prefix``: a string indicating the attribute associated to the
240
+ base model in derived classes of the same architecture adding modules on top
241
+ of the base model.
242
+ """
243
+ config_class: typing.Type[ProteinConfig] = ProteinConfig
244
+ pretrained_model_archive_map: typing.Dict[str, str] = {}
245
+ base_model_prefix = ""
246
+
247
+ def __init__(self, config, *inputs, **kwargs):
248
+ super().__init__()
249
+ if not isinstance(config, ProteinConfig):
250
+ raise ValueError(
251
+ "Parameter config in `{}(config)` should be an instance of class "
252
+ "`ProteinConfig`. To create a model from a pretrained model use "
253
+ "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
254
+ self.__class__.__name__, self.__class__.__name__
255
+ ))
256
+ # Save config in model
257
+ self.config = config
258
+
259
+ def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
260
+ """ Build a resized Embedding Module from a provided token Embedding Module.
261
+ Increasing the size will add newly initialized vectors at the end
262
+ Reducing the size will remove vectors from the end
263
+
264
+ Args:
265
+ new_num_tokens: (`optional`) int
266
+ New number of tokens in the embedding matrix.
267
+ Increasing the size will add newly initialized vectors at the end
268
+ Reducing the size will remove vectors from the end
269
+ If not provided or None: return the provided token Embedding Module.
270
+ Return: ``torch.nn.Embeddings``
271
+ Pointer to the resized Embedding Module or the old Embedding Module if
272
+ new_num_tokens is None
273
+ """
274
+ if new_num_tokens is None:
275
+ return old_embeddings
276
+
277
+ old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
278
+ if old_num_tokens == new_num_tokens:
279
+ return old_embeddings
280
+
281
+ # Build new embeddings
282
+ new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
283
+ new_embeddings.to(old_embeddings.weight.device)
284
+
285
+ # initialize all new embeddings (in particular added tokens)
286
+ self.init_weights(new_embeddings)
287
+
288
+ # Copy word embeddings from the previous weights
289
+ num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
290
+ new_embeddings.weight.data[:num_tokens_to_copy, :] = \
291
+ old_embeddings.weight.data[:num_tokens_to_copy, :]
292
+
293
+ return new_embeddings
294
+
295
+ def _tie_or_clone_weights(self, first_module, second_module):
296
+ """ Tie or clone module weights depending of weither we are using TorchScript or not
297
+ """
298
+ if self.config.torchscript:
299
+ first_module.weight = nn.Parameter(second_module.weight.clone())
300
+ else:
301
+ first_module.weight = second_module.weight
302
+
303
+ def resize_token_embeddings(self, new_num_tokens=None):
304
+ """ Resize input token embeddings matrix of the model if
305
+ new_num_tokens != config.vocab_size. Take care of tying weights embeddings
306
+ afterwards if the model class has a `tie_weights()` method.
307
+
308
+ Arguments:
309
+
310
+ new_num_tokens: (`optional`) int:
311
+ New number of tokens in the embedding matrix. Increasing the size will add
312
+ newly initialized vectors at the end. Reducing the size will remove vectors
313
+ from the end. If not provided or None: does nothing and just returns a
314
+ pointer to the input tokens ``torch.nn.Embeddings`` Module of the model.
315
+
316
+ Return: ``torch.nn.Embeddings``
317
+ Pointer to the input tokens Embeddings Module of the model
318
+ """
319
+ base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
320
+ model_embeds = base_model._resize_token_embeddings(new_num_tokens)
321
+ if new_num_tokens is None:
322
+ return model_embeds
323
+
324
+ # Update base model and current model config
325
+ self.config.vocab_size = new_num_tokens
326
+ base_model.vocab_size = new_num_tokens
327
+
328
+ # Tie weights again if needed
329
+ if hasattr(self, 'tie_weights'):
330
+ self.tie_weights()
331
+
332
+ return model_embeds
333
+
334
+ def init_weights(self):
335
+ """ Initialize and prunes weights if needed. """
336
+ # Initialize weights
337
+ self.apply(self._init_weights)
338
+
339
+ # Prune heads if needed
340
+ if getattr(self.config, 'pruned_heads', False):
341
+ self.prune_heads(self.config.pruned_heads)
342
+
343
+ def prune_heads(self, heads_to_prune):
344
+ """ Prunes heads of the base model.
345
+
346
+ Arguments:
347
+
348
+ heads_to_prune: dict with keys being selected layer indices (`int`) and
349
+ associated values being the list of heads to prune in said layer
350
+ (list of `int`).
351
+ """
352
+ base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
353
+ base_model._prune_heads(heads_to_prune)
354
+
355
+ def save_pretrained(self, save_directory):
356
+ """ Save a model and its configuration file to a directory, so that it
357
+ can be re-loaded using the `:func:`~ProteinModel.from_pretrained`
358
+ ` class method.
359
+ """
360
+ assert os.path.isdir(save_directory), "Saving path should be a directory where "\
361
+ "the model and configuration can be saved"
362
+
363
+ # Only save the model it-self if we are using distributed training
364
+ model_to_save = self.module if hasattr(self, 'module') else self
365
+
366
+ # Save configuration file
367
+ model_to_save.config.save_pretrained(save_directory)
368
+
369
+ # If we save using the predefined names, we can load using `from_pretrained`
370
+ output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
371
+
372
+ torch.save(model_to_save.state_dict(), output_model_file)
373
+
374
+ @classmethod
375
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
376
+ r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.
377
+
378
+ The model is set in evaluation mode by default using ``model.eval()``
379
+ (Dropout modules are deactivated)
380
+ To train the model, you should first set it back in training mode with ``model.train()``
381
+
382
+ The warning ``Weights from XXX not initialized from pretrained model`` means that
383
+ the weights of XXX do not come pre-trained with the rest of the model.
384
+ It is up to you to train those weights with a downstream fine-tuning task.
385
+
386
+ The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used
387
+ by YYY, therefore those weights are discarded.
388
+
389
+ Parameters:
390
+ pretrained_model_name_or_path: either:
391
+
392
+ - a string with the `shortcut name` of a pre-trained model to load from cache
393
+ or download, e.g.: ``bert-base-uncased``.
394
+ - a path to a `directory` containing model weights saved using
395
+ :func:`~ProteinModel.save_pretrained`,
396
+ e.g.: ``./my_model_directory/``.
397
+
398
+ model_args: (`optional`) Sequence of positional arguments:
399
+ All remaning positional arguments will be passed to the underlying model's
400
+ ``__init__`` method
401
+
402
+ config: (`optional`) instance of a class derived from
403
+ :class:`~ProteinConfig`: Configuration for the model to
404
+ use instead of an automatically loaded configuation. Configuration can be
405
+ automatically loaded when:
406
+
407
+ - the model is a model provided by the library (loaded with the
408
+ ``shortcut-name`` string of a pretrained model), or
409
+ - the model was saved using
410
+ :func:`~ProteinModel.save_pretrained` and is reloaded
411
+ by suppling the save directory.
412
+ - the model is loaded by suppling a local directory as
413
+ ``pretrained_model_name_or_path`` and a configuration JSON file named
414
+ `config.json` is found in the directory.
415
+
416
+ state_dict: (`optional`) dict:
417
+ an optional state dictionnary for the model to use instead of a state
418
+ dictionary loaded from saved weights file. This option can be used if you
419
+ want to create a model from a pretrained configuration but load your own
420
+ weights. In this case though, you should check if using
421
+ :func:`~ProteinModel.save_pretrained` and
422
+ :func:`~ProteinModel.from_pretrained` is not a
423
+ simpler option.
424
+
425
+ cache_dir: (`optional`) string:
426
+ Path to a directory in which a downloaded pre-trained model
427
+ configuration should be cached if the standard cache should not be used.
428
+
429
+ force_download: (`optional`) boolean, default False:
430
+ Force to (re-)download the model weights and configuration files and override
431
+ the cached versions if they exists.
432
+
433
+ resume_download: (`optional`) boolean, default False:
434
+ Do not delete incompletely recieved file. Attempt to resume the download if
435
+ such a file exists.
436
+
437
+ output_loading_info: (`optional`) boolean:
438
+ Set to ``True`` to also return a dictionnary containing missing keys,
439
+ unexpected keys and error messages.
440
+
441
+ kwargs: (`optional`) Remaining dictionary of keyword arguments:
442
+ Can be used to update the configuration object (after it being loaded) and
443
+ initiate the model. (e.g. ``output_attention=True``). Behave differently
444
+ depending on whether a `config` is provided or automatically loaded:
445
+
446
+ - If a configuration is provided with ``config``, ``**kwarg
447
+ directly passed to the underlying model's ``__init__`` method (we assume
448
+ all relevant updates to the configuration have already been done)
449
+ - If a configuration is not provided, ``kwargs`` will be first passed to the
450
+ configuration class initialization function
451
+ (:func:`~ProteinConfig.from_pretrained`). Each key of
452
+ ``kwargs`` that corresponds to a configuration attribute will be used to
453
+ override said attribute with the supplied ``kwargs`` value. Remaining keys
454
+ that do not correspond to any configuration attribute will be passed to the
455
+ underlying model's ``__init__`` function.
456
+
457
+ Examples::
458
+
459
+ # Download model and configuration from S3 and cache.
460
+ model = ProteinBertModel.from_pretrained('bert-base-uncased')
461
+ # E.g. model was saved using `save_pretrained('./test/saved_model/')`
462
+ model = ProteinBertModel.from_pretrained('./test/saved_model/')
463
+ # Update configuration during loading
464
+ model = ProteinBertModel.from_pretrained('bert-base-uncased', output_attention=True)
465
+ assert model.config.output_attention == True
466
+
467
+ """
468
+ config = kwargs.pop('config', None)
469
+ state_dict = kwargs.pop('state_dict', None)
470
+ cache_dir = kwargs.pop('cache_dir', None)
471
+ output_loading_info = kwargs.pop('output_loading_info', False)
472
+
473
+ force_download = kwargs.pop("force_download", False)
474
+ kwargs.pop("resume_download", False)
475
+
476
+ # Load config
477
+ if config is None:
478
+ config, model_kwargs = cls.config_class.from_pretrained(
479
+ pretrained_model_name_or_path, *model_args,
480
+ cache_dir=cache_dir, return_unused_kwargs=True,
481
+ # force_download=force_download,
482
+ # resume_download=resume_download,
483
+ **kwargs
484
+ )
485
+ else:
486
+ model_kwargs = kwargs
487
+
488
+ # Load model
489
+ if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
490
+ archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
491
+ elif os.path.isdir(pretrained_model_name_or_path):
492
+ archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
493
+ else:
494
+ archive_file = pretrained_model_name_or_path
495
+ # redirect to the cache, if necessary
496
+ try:
497
+ resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir,
498
+ force_download=force_download)
499
+ except EnvironmentError:
500
+ if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
501
+ logger.error(
502
+ "Couldn't reach server at '{}' to download pretrained weights.".format(
503
+ archive_file))
504
+ else:
505
+ logger.error(
506
+ "Model name '{}' was not found in model name list ({}). "
507
+ "We assumed '{}' was a path or url but couldn't find any file "
508
+ "associated to this path or url.".format(
509
+ pretrained_model_name_or_path,
510
+ ', '.join(cls.pretrained_model_archive_map.keys()),
511
+ archive_file))
512
+ return None
513
+ if resolved_archive_file == archive_file:
514
+ logger.info("loading weights file {}".format(archive_file))
515
+ else:
516
+ logger.info("loading weights file {} from cache at {}".format(
517
+ archive_file, resolved_archive_file))
518
+
519
+ # Instantiate model.
520
+ model = cls(config, *model_args, **model_kwargs)
521
+
522
+ if state_dict is None:
523
+ state_dict = torch.load(resolved_archive_file, map_location='cpu')
524
+
525
+ # Convert old format to new format if needed from a PyTorch state_dict
526
+ old_keys = []
527
+ new_keys = []
528
+ for key in state_dict.keys():
529
+ new_key = None
530
+ if 'gamma' in key:
531
+ new_key = key.replace('gamma', 'weight')
532
+ if 'beta' in key:
533
+ new_key = key.replace('beta', 'bias')
534
+ if new_key:
535
+ old_keys.append(key)
536
+ new_keys.append(new_key)
537
+ for old_key, new_key in zip(old_keys, new_keys):
538
+ state_dict[new_key] = state_dict.pop(old_key)
539
+
540
+ # Load from a PyTorch state_dict
541
+ missing_keys = []
542
+ unexpected_keys = []
543
+ error_msgs = []
544
+ # copy state_dict so _load_from_state_dict can modify it
545
+ metadata = getattr(state_dict, '_metadata', None)
546
+ state_dict = state_dict.copy()
547
+ if metadata is not None:
548
+ state_dict._metadata = metadata
549
+
550
+ def load(module, prefix=''):
551
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
552
+ module._load_from_state_dict(
553
+ state_dict, prefix, local_metadata, True, missing_keys,
554
+ unexpected_keys, error_msgs)
555
+ for name, child in module._modules.items():
556
+ if child is not None:
557
+ load(child, prefix + name + '.')
558
+
559
+ # Make sure we are able to load base models as well as derived models (with heads)
560
+ start_prefix = ''
561
+ model_to_load = model
562
+ if cls.base_model_prefix not in (None, ''):
563
+ if not hasattr(model, cls.base_model_prefix) and \
564
+ any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
565
+ start_prefix = cls.base_model_prefix + '.'
566
+ if hasattr(model, cls.base_model_prefix) and \
567
+ not any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
568
+ model_to_load = getattr(model, cls.base_model_prefix)
569
+
570
+ load(model_to_load, prefix=start_prefix)
571
+ if len(missing_keys) > 0:
572
+ logger.info("Weights of {} not initialized from pretrained model: {}".format(
573
+ model.__class__.__name__, missing_keys))
574
+ if len(unexpected_keys) > 0:
575
+ logger.info("Weights from pretrained model not used in {}: {}".format(
576
+ model.__class__.__name__, unexpected_keys))
577
+ if len(error_msgs) > 0:
578
+ raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
579
+ model.__class__.__name__, "\n\t".join(error_msgs)))
580
+
581
+ if hasattr(model, 'tie_weights'):
582
+ model.tie_weights() # make sure word embedding weights are still tied
583
+
584
+ # Set model in evaluation mode to desactivate DropOut modules by default
585
+ model.eval()
586
+
587
+ if output_loading_info:
588
+ loading_info = {
589
+ "missing_keys": missing_keys,
590
+ "unexpected_keys": unexpected_keys,
591
+ "error_msgs": error_msgs}
592
+ return model, loading_info
593
+
594
+ return model
595
+
596
+
597
+ def prune_linear_layer(layer, index, dim=0):
598
+ """ Prune a linear layer (a model parameters) to keep only entries in index.
599
+ Return the pruned layer as a new layer with requires_grad=True.
600
+ Used to remove heads.
601
+ """
602
+ index = index.to(layer.weight.device)
603
+ W = layer.weight.index_select(dim, index).clone().detach()
604
+ if layer.bias is not None:
605
+ if dim == 1:
606
+ b = layer.bias.clone().detach()
607
+ else:
608
+ b = layer.bias[index].clone().detach()
609
+ new_size = list(layer.weight.size())
610
+ new_size[dim] = len(index)
611
+ new_layer = nn.Linear(
612
+ new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
613
+ new_layer.weight.requires_grad = False
614
+ new_layer.weight.copy_(W.contiguous())
615
+ new_layer.weight.requires_grad = True
616
+ if layer.bias is not None:
617
+ new_layer.bias.requires_grad = False
618
+ new_layer.bias.copy_(b.contiguous())
619
+ new_layer.bias.requires_grad = True
620
+ return new_layer
621
+
622
+
623
+ def accuracy(logits, labels, ignore_index: int = -100):
624
+ with torch.no_grad():
625
+ valid_mask = (labels != ignore_index)
626
+ predictions = logits.float().argmax(-1)
627
+ correct = (predictions == labels) * valid_mask
628
+ return correct.sum().float() / valid_mask.sum().float()
629
+
630
+
631
+ def gelu(x):
632
+ """Implementation of the gelu activation function.
633
+ For information: OpenAI GPT's gelu is slightly different
634
+ (and gives slightly different results):
635
+ 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
636
+ Also see https://arxiv.org/abs/1606.08415
637
+ """
638
+ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
639
+
640
+
641
+ def swish(x):
642
+ return x * torch.sigmoid(x)
643
+
644
+
645
+ def get_activation_fn(name: str) -> typing.Callable:
646
+ if name == 'gelu':
647
+ return gelu
648
+ elif name == 'relu':
649
+ return torch.nn.functional.relu
650
+ elif name == 'swish':
651
+ return swish
652
+ else:
653
+ raise ValueError(f"Unrecognized activation fn: {name}")
654
+
655
+
656
+ try:
657
+ from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm # type: ignore
658
+ except (ImportError, AttributeError):
659
+ logger.info("Better speed can be achieved with apex installed from "
660
+ "https://www.github.com/nvidia/apex .")
661
+
662
+ class LayerNorm(nn.Module): # type: ignore
663
+ def __init__(self, hidden_size, eps=1e-12):
664
+ """Construct a layernorm module in the TF style (epsilon inside the square root).
665
+ """
666
+ super().__init__()
667
+ self.weight = nn.Parameter(torch.ones(hidden_size))
668
+ self.bias = nn.Parameter(torch.zeros(hidden_size))
669
+ self.variance_epsilon = eps
670
+
671
+ def forward(self, x):
672
+ u = x.mean(-1, keepdim=True)
673
+ s = (x - u).pow(2).mean(-1, keepdim=True)
674
+ x = (x - u) / torch.sqrt(s + self.variance_epsilon)
675
+ return self.weight * x + self.bias
676
+
677
+
678
+ class SimpleMLP(nn.Module):
679
+
680
+ def __init__(self,
681
+ in_dim: int,
682
+ hid_dim: int,
683
+ out_dim: int,
684
+ dropout: float = 0.):
685
+ super().__init__()
686
+ self.main = nn.Sequential(
687
+ weight_norm(nn.Linear(in_dim, hid_dim), dim=None),
688
+ nn.ReLU(),
689
+ nn.Dropout(dropout, inplace=True),
690
+ weight_norm(nn.Linear(hid_dim, out_dim), dim=None))
691
+
692
+ def forward(self, x):
693
+ return self.main(x)
694
+
695
+
696
+ class SimpleConv(nn.Module):
697
+
698
+ def __init__(self,
699
+ in_dim: int,
700
+ hid_dim: int,
701
+ out_dim: int,
702
+ dropout: float = 0.):
703
+ super().__init__()
704
+ self.main = nn.Sequential(
705
+ nn.BatchNorm1d(in_dim), # Added this
706
+ weight_norm(nn.Conv1d(in_dim, hid_dim, 5, padding=2), dim=None),
707
+ nn.ReLU(),
708
+ nn.Dropout(dropout, inplace=True),
709
+ weight_norm(nn.Conv1d(hid_dim, out_dim, 3, padding=1), dim=None))
710
+
711
+ def forward(self, x):
712
+ x = x.transpose(1, 2)
713
+ x = self.main(x)
714
+ x = x.transpose(1, 2).contiguous()
715
+ return x
716
+
717
+
718
+ class Accuracy(nn.Module):
719
+
720
+ def __init__(self, ignore_index: int = -100):
721
+ super().__init__()
722
+ self.ignore_index = ignore_index
723
+
724
+ def forward(self, inputs, target):
725
+ return accuracy(inputs, target, self.ignore_index)
726
+
727
+
728
+ class PredictionHeadTransform(nn.Module):
729
+
730
+ def __init__(self,
731
+ hidden_size: int,
732
+ hidden_act: typing.Union[str, typing.Callable] = 'gelu',
733
+ layer_norm_eps: float = 1e-12):
734
+ super().__init__()
735
+ self.dense = nn.Linear(hidden_size, hidden_size)
736
+ if isinstance(hidden_act, str):
737
+ self.transform_act_fn = get_activation_fn(hidden_act)
738
+ else:
739
+ self.transform_act_fn = hidden_act
740
+ self.LayerNorm = LayerNorm(hidden_size, eps=layer_norm_eps)
741
+
742
+ def forward(self, hidden_states):
743
+ hidden_states = self.dense(hidden_states)
744
+ hidden_states = self.transform_act_fn(hidden_states)
745
+ hidden_states = self.LayerNorm(hidden_states)
746
+ return hidden_states
747
+
748
+
749
+ class MLMHead(nn.Module):
750
+
751
+ def __init__(self,
752
+ hidden_size: int,
753
+ vocab_size: int,
754
+ hidden_act: typing.Union[str, typing.Callable] = 'gelu',
755
+ layer_norm_eps: float = 1e-12,
756
+ ignore_index: int = -100):
757
+ super().__init__()
758
+ self.transform = PredictionHeadTransform(hidden_size, hidden_act, layer_norm_eps)
759
+
760
+ # The output weights are the same as the input embeddings, but there is
761
+ # an output-only bias for each token.
762
+ self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
763
+ self.bias = nn.Parameter(data=torch.zeros(vocab_size)) # type: ignore
764
+ self.vocab_size = vocab_size
765
+ self._ignore_index = ignore_index
766
+
767
+ def forward(self, hidden_states, targets=None):
768
+ hidden_states = self.transform(hidden_states)
769
+ hidden_states = self.decoder(hidden_states) + self.bias
770
+ outputs = (hidden_states,)
771
+ if targets is not None:
772
+ loss_fct = nn.CrossEntropyLoss(ignore_index=self._ignore_index)
773
+ masked_lm_loss = loss_fct(
774
+ hidden_states.reshape(-1, self.vocab_size), targets.reshape(-1))
775
+ metrics = {'perplexity': torch.exp(masked_lm_loss)}
776
+ loss_and_metrics = (masked_lm_loss, metrics)
777
+ outputs = (loss_and_metrics,) + outputs
778
+ return outputs # (loss), prediction_scores
779
+
780
+
781
+ class ValuePredictionHead(nn.Module):
782
+ def __init__(self, hidden_size: int, dropout: float = 0.):
783
+ super().__init__()
784
+ self.value_prediction = SimpleMLP(hidden_size, 512, 1, dropout)
785
+
786
+ def forward(self, pooled_output, targets=None):
787
+ value_pred = self.value_prediction(pooled_output)
788
+ outputs = (value_pred,)
789
+
790
+ if targets is not None:
791
+ loss_fct = nn.MSELoss()
792
+ value_pred_loss = loss_fct(value_pred, targets)
793
+ outputs = (value_pred_loss,) + outputs
794
+ return outputs # (loss), value_prediction
795
+
796
+
797
+ class SequenceClassificationHead(nn.Module):
798
+ def __init__(self, hidden_size: int, num_labels: int):
799
+ super().__init__()
800
+ self.classify = SimpleMLP(hidden_size, 512, num_labels)
801
+
802
+ def forward(self, pooled_output, targets=None):
803
+ logits = self.classify(pooled_output)
804
+ outputs = (logits,)
805
+
806
+ if targets is not None:
807
+ loss_fct = nn.CrossEntropyLoss()
808
+ classification_loss = loss_fct(logits, targets)
809
+ metrics = {'accuracy': accuracy(logits, targets)}
810
+ loss_and_metrics = (classification_loss, metrics)
811
+ outputs = (loss_and_metrics,) + outputs
812
+
813
+ return outputs # (loss), logits
814
+
815
+
816
+ class SequenceToSequenceClassificationHead(nn.Module):
817
+
818
+ def __init__(self,
819
+ hidden_size: int,
820
+ num_labels: int,
821
+ ignore_index: int = -100):
822
+ super().__init__()
823
+ self.classify = SimpleConv(
824
+ hidden_size, 512, num_labels)
825
+ self.num_labels = num_labels
826
+ self._ignore_index = ignore_index
827
+
828
+ def forward(self, sequence_output, targets=None):
829
+ sequence_logits = self.classify(sequence_output)
830
+ outputs = (sequence_logits,)
831
+ if targets is not None:
832
+ loss_fct = nn.CrossEntropyLoss(ignore_index=self._ignore_index)
833
+ classification_loss = loss_fct(
834
+ sequence_logits.view(-1, self.num_labels), targets.view(-1))
835
+ acc_fct = Accuracy(ignore_index=self._ignore_index)
836
+ metrics = {'accuracy':
837
+ acc_fct(sequence_logits.view(-1, self.num_labels), targets.view(-1))}
838
+ loss_and_metrics = (classification_loss, metrics)
839
+ outputs = (loss_and_metrics,) + outputs
840
+ return outputs # (loss), sequence_logits
841
+
842
+
843
+ class PairwiseContactPredictionHead(nn.Module):
844
+
845
+ def __init__(self, hidden_size: int, ignore_index=-100):
846
+ super().__init__()
847
+ self.predict = nn.Sequential(
848
+ nn.Dropout(), nn.Linear(2 * hidden_size, 2))
849
+ self._ignore_index = ignore_index
850
+
851
+ def forward(self, inputs, sequence_lengths, targets=None):
852
+ prod = inputs[:, :, None, :] * inputs[:, None, :, :]
853
+ diff = inputs[:, :, None, :] - inputs[:, None, :, :]
854
+ pairwise_features = torch.cat((prod, diff), -1)
855
+ prediction = self.predict(pairwise_features)
856
+ prediction = (prediction + prediction.transpose(1, 2)) / 2
857
+ prediction = prediction[:, 1:-1, 1:-1].contiguous() # remove start/stop tokens
858
+ outputs = (prediction,)
859
+
860
+ if targets is not None:
861
+ loss_fct = nn.CrossEntropyLoss(ignore_index=self._ignore_index)
862
+ contact_loss = loss_fct(
863
+ prediction.view(-1, 2), targets.view(-1))
864
+ metrics = {'precision_at_l5':
865
+ self.compute_precision_at_l5(sequence_lengths, prediction, targets)}
866
+ loss_and_metrics = (contact_loss, metrics)
867
+ outputs = (loss_and_metrics,) + outputs
868
+
869
+ return outputs
870
+
871
+ def compute_precision_at_l5(self, sequence_lengths, prediction, labels):
872
+ with torch.no_grad():
873
+ valid_mask = labels != self._ignore_index
874
+ seqpos = torch.arange(valid_mask.size(1), device=sequence_lengths.device)
875
+ x_ind, y_ind = torch.meshgrid(seqpos, seqpos)
876
+ valid_mask &= ((y_ind - x_ind) >= 6).unsqueeze(0)
877
+ probs = F.softmax(prediction, 3)[:, :, :, 1]
878
+ valid_mask = valid_mask.type_as(probs)
879
+ correct = 0
880
+ total = 0
881
+ for length, prob, label, mask in zip(sequence_lengths, probs, labels, valid_mask):
882
+ masked_prob = (prob * mask).view(-1)
883
+ most_likely = masked_prob.topk(length // 5, sorted=False)
884
+ selected = label.view(-1).gather(0, most_likely.indices)
885
+ correct += selected.sum().float()
886
+ total += selected.numel()
887
+ return correct / total
tape/optimization.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Modifications by Roshan Rao
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch optimization for BERT model."""
17
+
18
+ import logging
19
+ import math
20
+
21
+ import torch
22
+ from torch.optim import Optimizer # type: ignore
23
+ from torch.optim.lr_scheduler import LambdaLR
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ class ConstantLRSchedule(LambdaLR):
29
+ """ Constant learning rate schedule.
30
+ """
31
+ def __init__(self, optimizer, last_epoch=-1):
32
+ super(ConstantLRSchedule, self).__init__(
33
+ optimizer, lambda _: 1.0, last_epoch=last_epoch)
34
+
35
+
36
+ class WarmupConstantSchedule(LambdaLR):
37
+ """ Linear warmup and then constant.
38
+ Linearly increases learning rate schedule from 0 to 1 over `warmup_steps`
39
+ training steps. Keeps learning rate schedule equal to 1. after warmup_steps.
40
+ """
41
+ def __init__(self, optimizer, warmup_steps, last_epoch=-1):
42
+ self.warmup_steps = warmup_steps
43
+ super(WarmupConstantSchedule, self).__init__(
44
+ optimizer, self.lr_lambda, last_epoch=last_epoch)
45
+
46
+ def lr_lambda(self, step):
47
+ if step < self.warmup_steps:
48
+ return float(step) / float(max(1.0, self.warmup_steps))
49
+ return 1.
50
+
51
+
52
+ class WarmupLinearSchedule(LambdaLR):
53
+ """ Linear warmup and then linear decay.
54
+ Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps.
55
+ Linearly decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps`
56
+ steps.
57
+ """
58
+ def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1):
59
+ self.warmup_steps = warmup_steps
60
+ self.t_total = t_total
61
+ super(WarmupLinearSchedule, self).__init__(
62
+ optimizer, self.lr_lambda, last_epoch=last_epoch)
63
+
64
+ def lr_lambda(self, step):
65
+ if step < self.warmup_steps:
66
+ return float(step) / float(max(1, self.warmup_steps))
67
+ return max(0.0, float(self.t_total - step) / float(
68
+ max(1.0, self.t_total - self.warmup_steps)))
69
+
70
+
71
+ class WarmupCosineSchedule(LambdaLR):
72
+ """ Linear warmup and then cosine decay.
73
+ Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps.
74
+ Decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps
75
+ following a cosine curve. If `cycles` (default=0.5) is different from default, learning
76
+ rate follows cosine function after warmup.
77
+ """
78
+ def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1):
79
+ self.warmup_steps = warmup_steps
80
+ self.t_total = t_total
81
+ self.cycles = cycles
82
+ super(WarmupCosineSchedule, self).__init__(
83
+ optimizer, self.lr_lambda, last_epoch=last_epoch)
84
+
85
+ def lr_lambda(self, step):
86
+ if step < self.warmup_steps:
87
+ return float(step) / float(max(1.0, self.warmup_steps))
88
+ # progress after warmup
89
+ progress = float(step - self.warmup_steps) / float(
90
+ max(1, self.t_total - self.warmup_steps))
91
+ return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))
92
+
93
+
94
+ class WarmupCosineWithHardRestartsSchedule(LambdaLR):
95
+ """ Linear warmup and then cosine cycles with hard restarts.
96
+ Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps.
97
+ If `cycles` (default=1.) is different from default, learning rate follows `cycles` times
98
+ a cosine decaying learning rate (with hard restarts).
99
+ """
100
+ def __init__(self, optimizer, warmup_steps, t_total, cycles=1., last_epoch=-1):
101
+ self.warmup_steps = warmup_steps
102
+ self.t_total = t_total
103
+ self.cycles = cycles
104
+ super(WarmupCosineWithHardRestartsSchedule, self).__init__(
105
+ optimizer, self.lr_lambda, last_epoch=last_epoch)
106
+
107
+ def lr_lambda(self, step):
108
+ if step < self.warmup_steps:
109
+ return float(step) / float(max(1, self.warmup_steps))
110
+ # progress after warmup
111
+ progress = float(step - self.warmup_steps) / float(
112
+ max(1, self.t_total - self.warmup_steps))
113
+ if progress >= 1.0:
114
+ return 0.0
115
+ return max(0.0, 0.5 * (1. + math.cos(
116
+ math.pi * ((float(self.cycles) * progress) % 1.0))))
117
+
118
+
119
+ class AdamW(Optimizer):
120
+ """ Implements Adam algorithm with weight decay fix.
121
+
122
+ Parameters:
123
+ lr (float): learning rate. Default 1e-3.
124
+ betas (tuple of 2 floats): Adams beta parameters (b1, b2). Default: (0.9, 0.999)
125
+ eps (float): Adams epsilon. Default: 1e-6
126
+ weight_decay (float): Weight decay. Default: 0.0
127
+ correct_bias (bool): can be set to False to avoid correcting bias in Adam
128
+ (e.g. like in Bert TF repository). Default True.
129
+ """
130
+ def __init__(self,
131
+ params,
132
+ lr=1e-3,
133
+ betas=(0.9, 0.999),
134
+ eps=1e-6,
135
+ weight_decay=0.0,
136
+ correct_bias=True):
137
+ if lr < 0.0:
138
+ raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
139
+ if not 0.0 <= betas[0] < 1.0:
140
+ raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
141
+ if not 0.0 <= betas[1] < 1.0:
142
+ raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
143
+ if not 0.0 <= eps:
144
+ raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
145
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
146
+ correct_bias=correct_bias)
147
+ super(AdamW, self).__init__(params, defaults)
148
+
149
+ def step(self, closure=None):
150
+ """Performs a single optimization step.
151
+
152
+ Arguments:
153
+ closure (callable, optional): A closure that reevaluates the model
154
+ and returns the loss.
155
+ """
156
+ loss = None
157
+ if closure is not None:
158
+ loss = closure()
159
+
160
+ for group in self.param_groups:
161
+ for p in group['params']:
162
+ if p.grad is None:
163
+ continue
164
+ grad = p.grad.data
165
+ if grad.is_sparse:
166
+ raise RuntimeError('Adam does not support sparse gradients, '
167
+ 'please consider SparseAdam instead')
168
+
169
+ state = self.state[p]
170
+
171
+ # State initialization
172
+ if len(state) == 0:
173
+ state['step'] = 0
174
+ # Exponential moving average of gradient values
175
+ state['exp_avg'] = torch.zeros_like(p.data)
176
+ # Exponential moving average of squared gradient values
177
+ state['exp_avg_sq'] = torch.zeros_like(p.data)
178
+
179
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
180
+ beta1, beta2 = group['betas']
181
+
182
+ state['step'] += 1
183
+
184
+ # Decay the first and second moment running average coefficient
185
+ # In-place operations to update the averages at the same time
186
+ exp_avg.mul_(beta1).add_(1.0 - beta1, grad)
187
+ exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad)
188
+ denom = exp_avg_sq.sqrt().add_(group['eps'])
189
+
190
+ step_size = group['lr']
191
+ if group['correct_bias']: # No bias correction for Bert
192
+ bias_correction1 = 1.0 - beta1 ** state['step']
193
+ bias_correction2 = 1.0 - beta2 ** state['step']
194
+ step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
195
+
196
+ p.data.addcdiv_(-step_size, exp_avg, denom)
197
+
198
+ # Just adding the square of the weights to the loss function is *not*
199
+ # the correct way of using L2 regularization/weight decay with Adam,
200
+ # since that will interact with the m and v parameters in strange ways.
201
+ #
202
+ # Instead we want to decay the weights in a manner that doesn't interact
203
+ # with the m/v parameters. This is equivalent to adding the square
204
+ # of the weights to the loss with plain (non-momentum) SGD.
205
+ # Add weight decay at the end (fixed version)
206
+ if group['weight_decay'] > 0.0:
207
+ p.data.add_(-group['lr'] * group['weight_decay'], p.data)
208
+
209
+ return loss
tape/registry.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Type, Callable, Optional, Union
2
+ from torch.utils.data import Dataset
3
+ from .models.modeling_utils import ProteinModel
4
+ from pathlib import Path
5
+
6
+ PathType = Union[str, Path]
7
+
8
+
9
+ def convert_model_args(model_args):
10
+ d = {}
11
+ for e in model_args:
12
+ k, v = e.split("=")
13
+ try:
14
+ v = int(v)
15
+ except:
16
+ try:
17
+ v = float(v)
18
+ except:
19
+ v = str(v)
20
+ d[k] = v
21
+ return d
22
+
23
+
24
+ class TAPETaskSpec:
25
+ """
26
+ Attributes
27
+ ----------
28
+ name (str):
29
+ The name of the TAPE task
30
+ dataset (Type[Dataset]):
31
+ The dataset used in the TAPE task
32
+ num_labels (int):
33
+ number of labels used if this is a classification task
34
+ models (Dict[str, ProteinModel]):
35
+ The set of models that can be used for this task. Default: {}.
36
+ """
37
+
38
+ def __init__(self,
39
+ name: str,
40
+ dataset: Type[Dataset],
41
+ num_labels: int = -1,
42
+ models: Optional[Dict[str, Type[ProteinModel]]] = None):
43
+ self.name = name
44
+ self.dataset = dataset
45
+ self.num_labels = num_labels
46
+ self.models = models if models is not None else {}
47
+
48
+ def register_model(self, model_name: str, model_cls: Optional[Type[ProteinModel]] = None):
49
+ if model_cls is not None:
50
+ if model_name in self.models:
51
+ raise KeyError(
52
+ f"A model with name '{model_name}' is already registered for this task")
53
+ self.models[model_name] = model_cls
54
+ return model_cls
55
+ else:
56
+ return lambda model_cls: self.register_model(model_name, model_cls)
57
+
58
+ def get_model(self, model_name: str) -> Type[ProteinModel]:
59
+ return self.models[model_name]
60
+
61
+
62
+ class Registry:
63
+ r"""Class for registry object which acts as the
64
+ central repository for TAPE."""
65
+
66
+ task_name_mapping: Dict[str, TAPETaskSpec] = {}
67
+ metric_name_mapping: Dict[str, Callable] = {}
68
+
69
+ @classmethod
70
+ def register_task(cls,
71
+ task_name: str,
72
+ num_labels: int = -1,
73
+ dataset: Optional[Type[Dataset]] = None,
74
+ models: Optional[Dict[str, Type[ProteinModel]]] = None):
75
+ """ Register a a new TAPE task. This creates a new TAPETaskSpec.
76
+
77
+ Args:
78
+
79
+ task_name (str): The name of the TAPE task.
80
+ num_labels (int): Number of labels used if this is a classification task. If this
81
+ is not a classification task, simply leave the default as -1.
82
+ dataset (Type[Dataset]): The dataset used in the TAPE task.
83
+ models (Optional[Dict[str, ProteinModel]]): The set of models that can be used for
84
+ this task. If you do not pass this argument, you can register models to the task
85
+ later by using `registry.register_task_model`. Default: {}.
86
+
87
+ Examples:
88
+
89
+ There are two ways of registering a new task. First, one can define the task by simply
90
+ declaring all the components, and then calling the register method, like so:
91
+
92
+ class SecondaryStructureDataset(Dataset):
93
+ ...
94
+
95
+ class ProteinBertForSequenceToSequenceClassification():
96
+ ...
97
+
98
+ registry.register_task(
99
+ 'secondary_structure', 3, SecondaryStructureDataset,
100
+ {'transformer': ProteinBertForSequenceToSequenceClassification})
101
+
102
+ This will register a new task, 'secondary_structure', with a single model. More models
103
+ can be added with `registry.register_task_model`. Alternatively, this can be used as a
104
+ decorator:
105
+
106
+ @registry.regsiter_task('secondary_structure', 3)
107
+ class SecondaryStructureDataset(Dataset):
108
+ ...
109
+
110
+ @registry.register_task_model('secondary_structure', 'transformer')
111
+ class ProteinBertForSequenceToSequenceClassification():
112
+ ...
113
+
114
+ These two pieces of code are exactly equivalent, in terms of the resulting registry
115
+ state.
116
+
117
+ """
118
+ if dataset is not None:
119
+ if models is None:
120
+ models = {}
121
+ task_spec = TAPETaskSpec(task_name, dataset, num_labels, models)
122
+ return cls.register_task_spec(task_name, task_spec).dataset
123
+ else:
124
+ return lambda dataset: cls.register_task(task_name, num_labels, dataset, models)
125
+
126
+ @classmethod
127
+ def register_task_spec(cls, task_name: str, task_spec: Optional[TAPETaskSpec] = None):
128
+ """ Registers a task_spec directly. If you find it easier to actually create a
129
+ TAPETaskSpec manually, and then register it, feel free to use this method,
130
+ but otherwise it is likely easier to use `registry.register_task`.
131
+ """
132
+ if task_spec is not None:
133
+ if task_name in cls.task_name_mapping:
134
+ raise KeyError(f"A task with name '{task_name}' is already registered")
135
+ cls.task_name_mapping[task_name] = task_spec
136
+ return task_spec
137
+ else:
138
+ return lambda task_spec: cls.register_task_spec(task_name, task_spec)
139
+
140
+ @classmethod
141
+ def register_task_model(cls,
142
+ task_name: str,
143
+ model_name: str,
144
+ model_cls: Optional[Type[ProteinModel]] = None):
145
+ r"""Register a specific model to a task with the provided model name.
146
+ The task must already be in the registry - you cannot register a
147
+ model to an unregistered task.
148
+
149
+ Args:
150
+ task_name (str): Name of task to which to register the model.
151
+ model_name (str): Name of model to use when registering task, this
152
+ is the name that you will use to refer to the model on the
153
+ command line.
154
+ model_cls (Type[ProteinModel]): The model to register.
155
+
156
+ Examples:
157
+
158
+ As with `registry.register_task`, this can both be used as a regular
159
+ python function, and as a decorator. For example this:
160
+
161
+ class ProteinBertForSequenceToSequenceClassification():
162
+ ...
163
+ registry.register_task_model(
164
+ 'secondary_structure', 'transformer',
165
+ ProteinBertForSequenceToSequenceClassification)
166
+
167
+ and as a decorator:
168
+
169
+ @registry.register_task_model('secondary_structure', 'transformer')
170
+ class ProteinBertForSequenceToSequenceClassification():
171
+ ...
172
+
173
+ are both equivalent.
174
+ """
175
+ if task_name not in cls.task_name_mapping:
176
+ raise KeyError(
177
+ f"Tried to register a task model for an unregistered task: {task_name}. "
178
+ f"Make sure to register the task {task_name} first.")
179
+ return cls.task_name_mapping[task_name].register_model(model_name, model_cls)
180
+
181
+ @classmethod
182
+ def register_metric(cls, name: str) -> Callable[[Callable], Callable]:
183
+ r"""Register a metric to registry with key 'name'
184
+
185
+ Args:
186
+ name: Key with which the metric will be registered.
187
+
188
+ Usage::
189
+ from tape.registry import registry
190
+
191
+ @registry.register_metric('mse')
192
+ def mean_squred_error(inputs, outputs):
193
+ ...
194
+ """
195
+
196
+ def wrap(fn: Callable) -> Callable:
197
+ assert callable(fn), "All metrics must be callable"
198
+ cls.metric_name_mapping[name] = fn
199
+ return fn
200
+
201
+ return wrap
202
+
203
+ @classmethod
204
+ def get_task_spec(cls, name: str) -> TAPETaskSpec:
205
+ return cls.task_name_mapping[name]
206
+
207
+ @classmethod
208
+ def get_metric(cls, name: str) -> Callable:
209
+ return cls.metric_name_mapping[name]
210
+
211
+ @classmethod
212
+ def get_task_model(cls,
213
+ model_name: str,
214
+ task_name: str,
215
+ config_file: Optional[PathType] = None,
216
+ load_dir: Optional[PathType] = None,
217
+ model_args = None) -> ProteinModel:
218
+ """ Create a TAPE task model, either from scratch or from a pretrained model.
219
+ This is mostly a helper function that evaluates the if statements in a
220
+ sensible order if you pass all three of the arguments.
221
+ Args:
222
+ model_name (str): Which type of model to create (e.g. transformer, unirep, ...)
223
+ task_name (str): The TAPE task for which to create a model
224
+ config_file (str, optional): A json config file that specifies hyperparameters
225
+ load_dir (str, optional): A save directory for a pretrained model
226
+ Returns:
227
+ model (ProteinModel): A TAPE task model
228
+ """
229
+ task_spec = registry.get_task_spec(task_name)
230
+ model_cls = task_spec.get_model(model_name)
231
+
232
+ if load_dir is not None:
233
+ model = model_cls.from_pretrained(load_dir, num_labels=task_spec.num_labels)
234
+ else:
235
+ config_class = model_cls.config_class
236
+ if config_file is not None:
237
+ config = config_class.from_json_file(config_file)
238
+ else:
239
+ config = config_class()
240
+
241
+ if model_args:
242
+ model_args = convert_model_args(model_args)
243
+ for k,v in model_args.items():
244
+ if k in config.__dict__ and type(config.__dict__[k])==type(v):
245
+ setattr(config, k, v)
246
+ else:
247
+ raise ValueError(f"model arg {k} not in config or of the same type as default")
248
+
249
+ config.num_labels = task_spec.num_labels
250
+ model = model_cls(config)
251
+ return model
252
+
253
+
254
+ registry = Registry()
tape/tokenizers.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ import logging
3
+ from collections import OrderedDict
4
+ import numpy as np
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+ IUPAC_CODES = OrderedDict([
9
+ ('Ala', 'A'),
10
+ ('Asx', 'B'),
11
+ ('Cys', 'C'),
12
+ ('Asp', 'D'),
13
+ ('Glu', 'E'),
14
+ ('Phe', 'F'),
15
+ ('Gly', 'G'),
16
+ ('His', 'H'),
17
+ ('Ile', 'I'),
18
+ ('Lys', 'K'),
19
+ ('Leu', 'L'),
20
+ ('Met', 'M'),
21
+ ('Asn', 'N'),
22
+ ('Pro', 'P'),
23
+ ('Gln', 'Q'),
24
+ ('Arg', 'R'),
25
+ ('Ser', 'S'),
26
+ ('Thr', 'T'),
27
+ ('Sec', 'U'),
28
+ ('Val', 'V'),
29
+ ('Trp', 'W'),
30
+ ('Xaa', 'X'),
31
+ ('Tyr', 'Y'),
32
+ ('Glx', 'Z')])
33
+
34
+ IUPAC_VOCAB = OrderedDict([
35
+ ("<pad>", 0),
36
+ ("<mask>", 1),
37
+ ("<cls>", 2),
38
+ ("<sep>", 3),
39
+ ("<unk>", 4),
40
+ ("A", 5),
41
+ ("B", 6),
42
+ ("C", 7),
43
+ ("D", 8),
44
+ ("E", 9),
45
+ ("F", 10),
46
+ ("G", 11),
47
+ ("H", 12),
48
+ ("I", 13),
49
+ ("K", 14),
50
+ ("L", 15),
51
+ ("M", 16),
52
+ ("N", 17),
53
+ ("O", 18),
54
+ ("P", 19),
55
+ ("Q", 20),
56
+ ("R", 21),
57
+ ("S", 22),
58
+ ("T", 23),
59
+ ("U", 24),
60
+ ("V", 25),
61
+ ("W", 26),
62
+ ("X", 27),
63
+ ("Y", 28),
64
+ ("Z", 29)])
65
+
66
+ UNIREP_VOCAB = OrderedDict([
67
+ ("<pad>", 0),
68
+ ("M", 1),
69
+ ("R", 2),
70
+ ("H", 3),
71
+ ("K", 4),
72
+ ("D", 5),
73
+ ("E", 6),
74
+ ("S", 7),
75
+ ("T", 8),
76
+ ("N", 9),
77
+ ("Q", 10),
78
+ ("C", 11),
79
+ ("U", 12),
80
+ ("G", 13),
81
+ ("P", 14),
82
+ ("A", 15),
83
+ ("V", 16),
84
+ ("I", 17),
85
+ ("F", 18),
86
+ ("Y", 19),
87
+ ("W", 20),
88
+ ("L", 21),
89
+ ("O", 22),
90
+ ("X", 23),
91
+ ("Z", 23),
92
+ ("B", 23),
93
+ ("J", 23),
94
+ ("<cls>", 24),
95
+ ("<sep>", 25)])
96
+
97
+
98
+ class TAPETokenizer():
99
+ r"""TAPE Tokenizer. Can use different vocabs depending on the model.
100
+ """
101
+
102
+ def __init__(self, vocab: str = 'iupac'):
103
+ if vocab == 'iupac':
104
+ self.vocab = IUPAC_VOCAB
105
+ elif vocab == 'unirep':
106
+ self.vocab = UNIREP_VOCAB
107
+ self.tokens = list(self.vocab.keys())
108
+ self._vocab_type = vocab
109
+ assert self.start_token in self.vocab and self.stop_token in self.vocab
110
+
111
+ @property
112
+ def vocab_size(self) -> int:
113
+ return len(self.vocab)
114
+
115
+ @property
116
+ def start_token(self) -> str:
117
+ return "<cls>"
118
+
119
+ @property
120
+ def stop_token(self) -> str:
121
+ return "<sep>"
122
+
123
+ @property
124
+ def mask_token(self) -> str:
125
+ if "<mask>" in self.vocab:
126
+ return "<mask>"
127
+ else:
128
+ raise RuntimeError(f"{self._vocab_type} vocab does not support masking")
129
+
130
+ def tokenize(self, text: str) -> List[str]:
131
+ return [x for x in text]
132
+
133
+ def convert_token_to_id(self, token: str) -> int:
134
+ """ Converts a token (str/unicode) in an id using the vocab. """
135
+ try:
136
+ return self.vocab[token]
137
+ except KeyError:
138
+ raise KeyError(f"Unrecognized token: '{token}'")
139
+
140
+ def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]:
141
+ return [self.convert_token_to_id(token) for token in tokens]
142
+
143
+ def convert_id_to_token(self, index: int) -> str:
144
+ """Converts an index (integer) in a token (string/unicode) using the vocab."""
145
+ try:
146
+ return self.tokens[index]
147
+ except IndexError:
148
+ raise IndexError(f"Unrecognized index: '{index}'")
149
+
150
+ def convert_ids_to_tokens(self, indices: List[int]) -> List[str]:
151
+ return [self.convert_id_to_token(id_) for id_ in indices]
152
+
153
+ def convert_tokens_to_string(self, tokens: str) -> str:
154
+ """ Converts a sequence of tokens (string) in a single string. """
155
+ return ''.join(tokens)
156
+
157
+ def add_special_tokens(self, token_ids: List[str]) -> List[str]:
158
+ """
159
+ Adds special tokens to the a sequence for sequence classification tasks.
160
+ A BERT sequence has the following format: [CLS] X [SEP]
161
+ """
162
+ cls_token = [self.start_token]
163
+ sep_token = [self.stop_token]
164
+ return cls_token + token_ids + sep_token
165
+
166
+ def encode(self, text: str) -> np.ndarray:
167
+ tokens = self.tokenize(text)
168
+ tokens = self.add_special_tokens(tokens)
169
+ token_ids = self.convert_tokens_to_ids(tokens)
170
+ return np.array(token_ids, np.int64)
171
+
172
+ @classmethod
173
+ def from_pretrained(cls, **kwargs):
174
+ return cls()
tape/training.py ADDED
@@ -0,0 +1,659 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import typing
2
+ import os
3
+ import logging
4
+ from timeit import default_timer as timer
5
+ import json
6
+ from pathlib import Path
7
+ import inspect
8
+ import pickle as pkl
9
+
10
+ from tqdm import tqdm
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.optim as optim
14
+ from torch.utils.data import DataLoader
15
+ from .optimization import WarmupLinearSchedule
16
+
17
+ from . import utils
18
+ from . import errors
19
+ from . import visualization
20
+ from .registry import registry
21
+ from .models.modeling_utils import ProteinModel
22
+
23
+ try:
24
+ from apex import amp
25
+ import amp_C
26
+ import apex_C
27
+ from apex.amp import _amp_state
28
+ from apex.parallel.distributed import flat_dist_call
29
+ from apex.parallel.distributed import DistributedDataParallel as DDP
30
+ APEX_FOUND = True
31
+ except ImportError:
32
+ APEX_FOUND = False
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+ MetricsDict = typing.Dict[str, float]
37
+ LossAndMetrics = typing.Tuple[float, MetricsDict]
38
+ OutputDict = typing.Dict[str, typing.Any]
39
+
40
+
41
+ class ForwardRunner:
42
+
43
+ def __init__(self,
44
+ model: ProteinModel,
45
+ device: torch.device = torch.device('cuda:0'),
46
+ n_gpu: int = 1,
47
+ fp16: bool = False,
48
+ local_rank: int = -1):
49
+
50
+ self.model = model
51
+ self.device = device
52
+ self.n_gpu = n_gpu
53
+ self.fp16 = fp16
54
+ self.local_rank = local_rank
55
+
56
+ forward_arg_keys = inspect.getfullargspec(model.forward).args
57
+ forward_arg_keys = forward_arg_keys[1:] # remove self argument
58
+ self._forward_arg_keys = forward_arg_keys
59
+ assert 'input_ids' in self._forward_arg_keys
60
+
61
+ def initialize_distributed_model(self):
62
+ if self.local_rank != -1:
63
+ if not self.fp16:
64
+ self.model = DDP(self.model)
65
+ else:
66
+ flat_dist_call([param.data for param in self.model.parameters()],
67
+ torch.distributed.broadcast, (0,))
68
+ elif self.n_gpu > 1:
69
+ self.model = nn.DataParallel(self.model)
70
+
71
+ def forward(self,
72
+ batch: typing.Dict[str, torch.Tensor],
73
+ return_outputs: bool = False,
74
+ no_loss: bool = False):
75
+ # Filter out batch items that aren't used in this model
76
+ # Requires that dataset keys match the forward args of the model
77
+ # Useful if some elements of the data are only used by certain models
78
+ # e.g. PSSMs / MSAs and other evolutionary data
79
+ batch = {name: tensor for name, tensor in batch.items()
80
+ if name in self._forward_arg_keys}
81
+ if self.device.type == 'cuda':
82
+ batch = {name: tensor.cuda(device=self.device, non_blocking=True)
83
+ for name, tensor in batch.items()}
84
+
85
+ outputs = self.model(**batch)
86
+
87
+ if no_loss:
88
+ return outputs
89
+
90
+ if isinstance(outputs[0], tuple):
91
+ # model also returned metrics
92
+ loss, metrics = outputs[0]
93
+ else:
94
+ # no metrics
95
+ loss = outputs[0]
96
+ metrics = {}
97
+
98
+ if self.n_gpu > 1: # pytorch DataDistributed doesn't mean scalars
99
+ loss = loss.mean()
100
+ metrics = {name: metric.mean() for name, metric in metrics.items()}
101
+
102
+ if return_outputs:
103
+ return loss, metrics, outputs
104
+ else:
105
+ return loss, metrics
106
+
107
+ def train(self):
108
+ self.model.train()
109
+ return self
110
+
111
+ def eval(self):
112
+ self.model.eval()
113
+ return self
114
+
115
+
116
+ class BackwardRunner(ForwardRunner):
117
+
118
+ def __init__(self,
119
+ model: ProteinModel,
120
+ optimizer: optim.Optimizer, # type: ignore
121
+ gradient_accumulation_steps: int = 1,
122
+ device: torch.device = torch.device('cuda:0'),
123
+ n_gpu: int = 1,
124
+ fp16: bool = False,
125
+ local_rank: int = -1,
126
+ max_grad_norm: float = 1.0,
127
+ warmup_steps: int = 0,
128
+ num_train_optimization_steps: int = 1000000):
129
+
130
+ super().__init__(model, device, n_gpu, fp16, local_rank)
131
+ self.optimizer = optimizer
132
+ self.max_grad_norm = max_grad_norm
133
+ self._global_step = 0
134
+ self._local_rank = local_rank
135
+ self._overflow_buf = torch.cuda.IntTensor([0]) # type: ignore
136
+ self.gradient_accumulation_steps = gradient_accumulation_steps
137
+ self._delay_accumulation = fp16 and local_rank != -1
138
+
139
+ self.scheduler = WarmupLinearSchedule(
140
+ self.optimizer, warmup_steps, num_train_optimization_steps)
141
+
142
+ def initialize_fp16(self):
143
+ if self.fp16:
144
+ self.model, self.optimizer = amp.initialize(
145
+ self.model, self.optimizer, opt_level="O2", loss_scale="dynamic",
146
+ master_weights=True)
147
+ _amp_state.loss_scalers[0]._loss_scale = 2 ** 20
148
+
149
+ def resume_from_checkpoint(self, checkpoint_dir: str) -> int:
150
+ checkpoint = torch.load(
151
+ os.path.join(checkpoint_dir, 'checkpoint.bin'), map_location=self.device)
152
+ self.optimizer.load_state_dict(checkpoint['optimizer'])
153
+ if self.fp16:
154
+ self.optimizer._lazy_init_maybe_master_weights()
155
+ self.optimizer._amp_stash.lazy_init_called = True
156
+ self.optimizer.load_state_dict(checkpoint['optimizer'])
157
+ for param, saved in zip(
158
+ amp.master_params(self.optimizer), checkpoint['master params']):
159
+ param.data.copy_(saved.data)
160
+ amp.load_state_dict(checkpoint['amp'])
161
+ self.scheduler.load_state_dict(checkpoint['scheduler'])
162
+ start_epoch = checkpoint['epoch'] + 1
163
+ return start_epoch
164
+
165
+ def save_state(self, save_directory: typing.Union[str, Path], epoch_id: int):
166
+ save_directory = Path(save_directory)
167
+ if not save_directory.exists():
168
+ save_directory.mkdir()
169
+ else:
170
+ assert save_directory.is_dir(), "Save path should be a directory"
171
+ model_to_save = getattr(self.model, 'module', self.model)
172
+ model_to_save.save_pretrained(save_directory)
173
+ optimizer_state: typing.Dict[str, typing.Any] = {
174
+ 'optimizer': self.optimizer.state_dict(),
175
+ 'scheduler': self.scheduler.state_dict(),
176
+ 'epoch': epoch_id}
177
+ if APEX_FOUND:
178
+ optimizer_state['master params'] = list(amp.master_params(self.optimizer))
179
+ try:
180
+ optimizer_state['amp'] = amp.state_dict()
181
+ except AttributeError:
182
+ pass
183
+ torch.save(optimizer_state, save_directory / 'checkpoint.bin')
184
+
185
+ def backward(self, loss) -> None:
186
+ if not self._delay_accumulation:
187
+ loss = loss / self.gradient_accumulation_steps
188
+ if self.fp16:
189
+ with amp.scale_loss(loss, self.optimizer,
190
+ delay_overflow_check=self._delay_accumulation) as scaled_loss:
191
+ scaled_loss.backward()
192
+ else:
193
+ loss.backward()
194
+
195
+ def step(self) -> None:
196
+ nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
197
+ if self._local_rank == -1:
198
+ self._step()
199
+ elif not self.fp16:
200
+ # TODO: Can you do this allreduce after accumulation also?
201
+ self._step()
202
+ else:
203
+ self._step_distributed_fp16()
204
+
205
+ def _step(self) -> None:
206
+ self.optimizer.step()
207
+ if self.scheduler is not None:
208
+ self.scheduler.step() # type: ignore
209
+ self._global_step += 1
210
+
211
+ def _step_distributed_fp16(self) -> None:
212
+ # manually allreduce gradients after all accumulation steps
213
+ # check for Inf/NaN
214
+ # 1. allocate an uninitialized buffer for flattened gradient
215
+ scaler = _amp_state.loss_scalers[0]
216
+ master_grads = [p.grad for p in amp.master_params(self.optimizer) if p.grad is not None]
217
+ flat_grad_size = sum(p.numel() for p in master_grads)
218
+ # allreduce_dtype = torch.float16 if args.allreduce_post_accumulation_fp16 else \
219
+ # torch.float32
220
+ allreduce_dtype = torch.float16
221
+ flat_raw = torch.empty(flat_grad_size, device='cuda', dtype=allreduce_dtype)
222
+ # 2. combine unflattening and predivision of unscaled 'raw' gradient
223
+ allreduced_views = apex_C.unflatten(flat_raw, master_grads)
224
+ self._overflow_buf.zero_()
225
+ amp_C.multi_tensor_scale(
226
+ 65536,
227
+ self._overflow_buf,
228
+ [master_grads, allreduced_views],
229
+ scaler.loss_scale() / (
230
+ torch.distributed.get_world_size() * self.gradient_accumulation_steps))
231
+ # 3. sum gradient across ranks. Because of the predivision, this averages the gradient
232
+ torch.distributed.all_reduce(flat_raw)
233
+ # 4. combine unscaling and unflattening of allreduced gradient
234
+ self._overflow_buf.zero_()
235
+ amp_C.multi_tensor_scale(
236
+ 65536,
237
+ self._overflow_buf,
238
+ [allreduced_views, master_grads],
239
+ 1. / scaler.loss_scale())
240
+ # 5. update loss scale
241
+ scaler = _amp_state.loss_scalers[0]
242
+ old_overflow_buf = scaler._overflow_buf
243
+ scaler._overflow_buf = self._overflow_buf
244
+ had_overflow = scaler.update_scale()
245
+ scaler._overfloat_buf = old_overflow_buf
246
+ # 6. call optimizer step function
247
+ if had_overflow == 0:
248
+ self._step()
249
+ else:
250
+ # Overflow detected, print message and clear gradients
251
+ logger.info(f"Gradient overflow. Skipping step, reducing loss scale to "
252
+ f"{scaler.loss_scale()}")
253
+ if _amp_state.opt_properties.master_weights:
254
+ for param in self.optimizer._amp_stash.all_fp32_from_fp16_params:
255
+ param.grad = None
256
+ for param in self.model.parameters():
257
+ param.grad = None
258
+
259
+ @property
260
+ def global_step(self) -> int:
261
+ return self._global_step
262
+
263
+
264
+ def run_train_epoch(epoch_id: int,
265
+ train_loader: DataLoader,
266
+ runner: BackwardRunner,
267
+ viz: typing.Optional[visualization.TAPEVisualizer] = None,
268
+ num_log_iter: int = 20,
269
+ gradient_accumulation_steps: int = 1,
270
+ num_steps_per_epoch: int = -1) -> LossAndMetrics:
271
+ if viz is None:
272
+ viz = visualization.DummyVisualizer()
273
+ smoothing = 1 - 1 / num_log_iter
274
+ accumulator = utils.MetricsAccumulator(smoothing)
275
+
276
+ torch.set_grad_enabled(True)
277
+ runner.train()
278
+
279
+ def make_log_str(step: int, time: float) -> str:
280
+ ep_percent = epoch_id + step / len(train_loader)
281
+ if runner.scheduler is not None:
282
+ curr_lr = runner.scheduler.get_lr()[0] # type: ignore
283
+ else:
284
+ curr_lr = runner.optimizer.param_groups[0]['lr']
285
+
286
+ print_str = []
287
+ print_str.append(f"[Ep: {ep_percent:.2f}]")
288
+ print_str.append(f"[Iter: {runner.global_step}]")
289
+ print_str.append(f"[Time: {time:5.2f}s]")
290
+ print_str.append(f"[Loss: {accumulator.loss():.5g}]")
291
+
292
+ for name, value in accumulator.metrics().items():
293
+ print_str.append(f"[{name.capitalize()}: {value:.5g}]")
294
+
295
+ print_str.append(f"[LR: {curr_lr:.5g}]")
296
+ return ''.join(print_str)
297
+
298
+ start_t = timer()
299
+ for step, batch in enumerate(train_loader):
300
+ loss, metrics = runner.forward(batch) # type: ignore
301
+ runner.backward(loss)
302
+ accumulator.update(loss, metrics, step=False)
303
+ if (step + 1) % gradient_accumulation_steps == 0:
304
+ runner.step()
305
+ viz.log_metrics(accumulator.step(), "train", runner.global_step)
306
+ if runner.global_step % num_log_iter == 0:
307
+ end_t = timer()
308
+ logger.info(make_log_str(step, end_t - start_t))
309
+ start_t = end_t
310
+ if num_steps_per_epoch != -1 and (step + 1) > num_steps_per_epoch:
311
+ break
312
+
313
+ final_print_str = f"Train: [Loss: {accumulator.final_loss():.5g}]"
314
+ for name, value in accumulator.final_metrics().items():
315
+ final_print_str += f"[{name.capitalize()}: {value:.5g}]"
316
+ logger.info(final_print_str)
317
+ return accumulator.final_loss(), accumulator.final_metrics()
318
+
319
+
320
+ def run_valid_epoch(epoch_id: int,
321
+ valid_loader: DataLoader,
322
+ runner: ForwardRunner,
323
+ viz: typing.Optional[visualization.TAPEVisualizer] = None,
324
+ is_master: bool = True,
325
+ val_check_frac: float = 1.0) -> typing.Tuple[float, typing.Dict[str, float]]:
326
+ num_batches = len(valid_loader)
327
+ num_batches_to_run = int(num_batches * val_check_frac)
328
+ accumulator = utils.MetricsAccumulator()
329
+
330
+ torch.set_grad_enabled(False)
331
+ runner.eval()
332
+
333
+ for idx, batch in enumerate(tqdm(valid_loader, desc='Running Eval', total=num_batches_to_run,
334
+ disable=not is_master, leave=False)):
335
+ loss, metrics = runner.forward(batch) # type: ignore
336
+ accumulator.update(loss, metrics)
337
+ if idx>num_batches_to_run:
338
+ break
339
+
340
+ # Reduce loss across all processes if multiprocessing
341
+ eval_loss = utils.reduce_scalar(accumulator.final_loss())
342
+ metrics = {name: utils.reduce_scalar(value)
343
+ for name, value in accumulator.final_metrics().items()}
344
+
345
+ print_str = f"Evaluation: [Loss: {eval_loss:.5g}]"
346
+ for name, value in metrics.items():
347
+ print_str += f"[{name.capitalize()}: {value:.5g}]"
348
+
349
+ metrics['loss'] = eval_loss
350
+ if viz is not None:
351
+ viz.log_metrics(metrics, "val", getattr(runner, 'global_step', epoch_id))
352
+
353
+ logger.info(print_str)
354
+
355
+ return eval_loss, metrics
356
+
357
+
358
+ def _get_outputs_to_save(batch, outputs):
359
+ targets = batch['targets'].cpu().numpy()
360
+ outputs = outputs.cpu().numpy()
361
+ protein_length = batch['protein_length'].sum(1).cpu().numpy()
362
+
363
+ reshaped_output = []
364
+ for target, output, plength in zip(targets, outputs, protein_length):
365
+ output_slices = tuple(slice(1, plength - 1) if dim == protein_length.max() else
366
+ slice(0, dim) for dim in output.shape)
367
+ output = output[output_slices]
368
+ target = target[output_slices]
369
+
370
+ reshaped_output.append((target, output))
371
+ reshaped_output
372
+
373
+
374
+ def run_eval_epoch(eval_loader: DataLoader,
375
+ runner: ForwardRunner,
376
+ is_master: bool = True) -> typing.List[typing.Dict[str, typing.Any]]:
377
+ torch.set_grad_enabled(False)
378
+ runner.eval()
379
+
380
+ save_outputs = []
381
+
382
+ for batch in tqdm(eval_loader, desc='Evaluation', total=len(eval_loader),
383
+ disable=not is_master):
384
+ loss, metrics, outputs = runner.forward(batch, return_outputs=True) # type: ignore
385
+ predictions = outputs[1].cpu().numpy()
386
+ targets = batch['targets'].cpu().numpy()
387
+ for pred, target in zip(predictions, targets):
388
+ save_outputs.append({'prediction': pred, 'target': target})
389
+
390
+ return save_outputs
391
+
392
+
393
+ def run_train(model_type: str,
394
+ task: str,
395
+ learning_rate: float = 1e-4,
396
+ batch_size: int = 1024,
397
+ num_train_epochs: int = 10,
398
+ num_log_iter: int = 20,
399
+ fp16: bool = False,
400
+ warmup_steps: int = 10000,
401
+ gradient_accumulation_steps: int = 1,
402
+ loss_scale: int = 0,
403
+ max_grad_norm: float = 1.0,
404
+ exp_name: typing.Optional[str] = None,
405
+ from_pretrained: typing.Optional[str] = None,
406
+ log_dir: str = './logs',
407
+ eval_freq: int = 1,
408
+ save_freq: typing.Union[int, str] = 1,
409
+ model_config_file: typing.Optional[str] = None,
410
+ data_dir: str = './data',
411
+ output_dir: str = './results',
412
+ no_cuda: bool = False,
413
+ seed: int = 42,
414
+ local_rank: int = -1,
415
+ tokenizer: str = 'iupac',
416
+ num_workers: int = 8,
417
+ debug: bool = False,
418
+ log_level: typing.Union[str, int] = logging.INFO,
419
+ patience: int = -1,
420
+ resume_from_checkpoint: bool = False,
421
+ model_args = None,
422
+ num_steps_per_epoch: int = -1,
423
+ val_check_frac: float = 1.0) -> None:
424
+
425
+ # SETUP AND LOGGING CODE #
426
+ input_args = locals()
427
+ device, n_gpu, is_master = utils.setup_distributed(
428
+ local_rank, no_cuda)
429
+
430
+ exp_dir = utils.get_expname(exp_name, task, model_type)
431
+ save_path = Path(output_dir) / exp_dir
432
+
433
+ if is_master:
434
+ # save all the hidden parameters.
435
+ save_path.mkdir(parents=True, exist_ok=True)
436
+ with (save_path / 'args.json').open('w') as f:
437
+ json.dump(input_args, f)
438
+
439
+ utils.barrier_if_distributed()
440
+ utils.setup_logging(local_rank, save_path, log_level)
441
+ utils.set_random_seeds(seed, n_gpu)
442
+
443
+ train_dataset = utils.setup_dataset(task, data_dir, 'train', tokenizer)
444
+ valid_dataset = utils.setup_dataset(task, data_dir, 'valid', tokenizer)
445
+ train_loader = utils.setup_loader(
446
+ train_dataset, batch_size, local_rank, n_gpu,
447
+ gradient_accumulation_steps, num_workers)
448
+ valid_loader = utils.setup_loader(
449
+ valid_dataset, batch_size, local_rank, n_gpu,
450
+ gradient_accumulation_steps, num_workers)
451
+
452
+ num_train_optimization_steps = utils.get_num_train_optimization_steps(
453
+ train_dataset, batch_size, num_train_epochs)
454
+
455
+ model = registry.get_task_model(model_type, task, model_config_file, from_pretrained, model_args)
456
+ model = model.to(device)
457
+ optimizer = utils.setup_optimizer(model, learning_rate)
458
+ viz = visualization.get(log_dir, exp_dir, local_rank, debug=debug)
459
+ viz.log_config(input_args)
460
+ viz.log_config(model.config.to_dict())
461
+ viz.watch(model)
462
+
463
+ logger.info(
464
+ f"device: {device} "
465
+ f"n_gpu: {n_gpu}, "
466
+ f"distributed_training: {local_rank != -1}, "
467
+ f"16-bits training: {fp16}")
468
+
469
+ runner = BackwardRunner(
470
+ model, optimizer, gradient_accumulation_steps, device, n_gpu,
471
+ fp16, local_rank, max_grad_norm, warmup_steps, num_train_optimization_steps)
472
+
473
+ runner.initialize_fp16()
474
+ if resume_from_checkpoint:
475
+ assert from_pretrained is not None
476
+ start_epoch = runner.resume_from_checkpoint(from_pretrained)
477
+ else:
478
+ start_epoch = 0
479
+ runner.initialize_distributed_model()
480
+
481
+ num_train_optimization_steps = utils.get_num_train_optimization_steps(
482
+ train_dataset, batch_size, num_train_epochs)
483
+ is_master = local_rank in (-1, 0)
484
+
485
+ if isinstance(save_freq, str) and save_freq != 'improvement':
486
+ raise ValueError(
487
+ f"Only recongized string value for save_freq is 'improvement'"
488
+ f", received: {save_freq}")
489
+
490
+ if save_freq == 'improvement' and eval_freq <= 0:
491
+ raise ValueError("Cannot set save_freq to 'improvement' and eval_freq < 0")
492
+
493
+ num_trainable_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
494
+ logger.info("***** Running training *****")
495
+ logger.info(" Num examples = %d", len(train_dataset))
496
+ logger.info(" Batch size = %d", batch_size)
497
+ logger.info(" Num epochs = %d", num_train_epochs)
498
+ logger.info(" Num train steps = %d", num_train_optimization_steps)
499
+ logger.info(" Num parameters = %d", num_trainable_parameters)
500
+
501
+ best_val_loss = float('inf')
502
+ num_evals_no_improvement = 0
503
+
504
+ def do_save(epoch_id: int, num_evals_no_improvement: int) -> bool:
505
+ if not is_master:
506
+ return False
507
+ if isinstance(save_freq, int):
508
+ return ((epoch_id + 1) % save_freq == 0) or ((epoch_id + 1) == num_train_epochs)
509
+ else:
510
+ return num_evals_no_improvement == 0
511
+
512
+ utils.barrier_if_distributed()
513
+
514
+ # ACTUAL TRAIN/EVAL LOOP #
515
+ with utils.wrap_cuda_oom_error(local_rank, batch_size, n_gpu, gradient_accumulation_steps):
516
+ for epoch_id in range(start_epoch, num_train_epochs):
517
+ run_train_epoch(epoch_id, train_loader, runner,
518
+ viz, num_log_iter, gradient_accumulation_steps, num_steps_per_epoch)
519
+ if eval_freq > 0 and (epoch_id + 1) % eval_freq == 0:
520
+ val_loss, _ = run_valid_epoch(epoch_id, valid_loader, runner, viz, is_master, val_check_frac)
521
+ if val_loss < best_val_loss:
522
+ best_val_loss = val_loss
523
+ num_evals_no_improvement = 0
524
+ else:
525
+ num_evals_no_improvement += 1
526
+
527
+ # Save trained model
528
+ if do_save(epoch_id, num_evals_no_improvement):
529
+ logger.info("** ** * Saving trained model ** ** * ")
530
+ # Only save the model itself
531
+ runner.save_state(save_path, epoch_id)
532
+ logger.info(f"Saving model checkpoint to {save_path}")
533
+
534
+ utils.barrier_if_distributed()
535
+ if patience > 0 and num_evals_no_improvement >= patience:
536
+ logger.info(f"Finished training at epoch {epoch_id} because no "
537
+ f"improvement for {num_evals_no_improvement} epochs.")
538
+ logger.log(35, f"Best Val Loss: {best_val_loss}")
539
+ if local_rank != -1:
540
+ # If you're distributed, raise this error. It sends a signal to
541
+ # the master process which lets it kill other processes and terminate
542
+ # without actually reporting an error. See utils/distributed_utils.py
543
+ # for the signal handling code.
544
+ raise errors.EarlyStopping
545
+ else:
546
+ break
547
+ logger.info(f"Finished training after {num_train_epochs} epochs.")
548
+ if best_val_loss != float('inf'):
549
+ logger.log(35, f"Best Val Loss: {best_val_loss}")
550
+
551
+
552
+ def run_eval(model_type: str,
553
+ task: str,
554
+ from_pretrained: str,
555
+ split: str = 'test',
556
+ batch_size: int = 1024,
557
+ model_config_file: typing.Optional[str] = None,
558
+ data_dir: str = './data',
559
+ no_cuda: bool = False,
560
+ seed: int = 42,
561
+ tokenizer: str = 'iupac',
562
+ num_workers: int = 8,
563
+ debug: bool = False,
564
+ metrics: typing.Tuple[str, ...] = (),
565
+ log_level: typing.Union[str, int] = logging.INFO) -> typing.Dict[str, float]:
566
+
567
+ local_rank = -1 # TAPE does not support torch.distributed.launch for evaluation
568
+ device, n_gpu, is_master = utils.setup_distributed(local_rank, no_cuda)
569
+ utils.setup_logging(local_rank, save_path=None, log_level=log_level)
570
+ utils.set_random_seeds(seed, n_gpu)
571
+
572
+ pretrained_dir = Path(from_pretrained)
573
+
574
+ logger.info(
575
+ f"device: {device} "
576
+ f"n_gpu: {n_gpu}")
577
+
578
+ model = registry.get_task_model(model_type, task, model_config_file, from_pretrained)
579
+ model = model.to(device)
580
+
581
+ runner = ForwardRunner(model, device, n_gpu)
582
+ runner.initialize_distributed_model()
583
+ valid_dataset = utils.setup_dataset(task, data_dir, split, tokenizer)
584
+ valid_loader = utils.setup_loader(
585
+ valid_dataset, batch_size, local_rank, n_gpu,
586
+ 1, num_workers)
587
+
588
+ metric_functions = [registry.get_metric(name) for name in metrics]
589
+ save_outputs = run_eval_epoch(valid_loader, runner, is_master)
590
+ target = [el['target'] for el in save_outputs]
591
+ prediction = [el['prediction'] for el in save_outputs]
592
+
593
+ metrics_to_save = {name: metric(target, prediction)
594
+ for name, metric in zip(metrics, metric_functions)}
595
+ logger.info(''.join(f'{name}: {val}' for name, val in metrics_to_save.items()))
596
+
597
+ with (pretrained_dir / 'results.pkl').open('wb') as f:
598
+ pkl.dump((metrics_to_save, save_outputs), f)
599
+
600
+ return metrics_to_save
601
+
602
+
603
+ def run_embed(model_type: str,
604
+ data_file: str,
605
+ out_file: str,
606
+ from_pretrained: str,
607
+ batch_size: int = 1024,
608
+ model_config_file: typing.Optional[str] = None,
609
+ full_sequence_embed: bool = False,
610
+ no_cuda: bool = False,
611
+ seed: int = 42,
612
+ tokenizer: str = 'iupac',
613
+ num_workers: int = 8,
614
+ log_level: typing.Union[str, int] = logging.INFO) -> None:
615
+
616
+ local_rank = -1 # TAPE does not support torch.distributed.launch for embedding
617
+ device, n_gpu, is_master = utils.setup_distributed(local_rank, no_cuda)
618
+ utils.setup_logging(local_rank, save_path=None, log_level=log_level)
619
+ utils.set_random_seeds(seed, n_gpu)
620
+
621
+ logger.info(
622
+ f"device: {device} "
623
+ f"n_gpu: {n_gpu}")
624
+
625
+ task_spec = registry.get_task_spec('embed')
626
+ model = registry.get_task_model(
627
+ model_type, task_spec.name, model_config_file, from_pretrained)
628
+ model = model.to(device)
629
+ runner = ForwardRunner(model, device, n_gpu)
630
+ runner.initialize_distributed_model()
631
+ runner.eval()
632
+ torch.set_grad_enabled(False)
633
+
634
+ dataset = task_spec.dataset(data_file, tokenizer=tokenizer) # type: ignore
635
+ valid_loader = utils.setup_loader(dataset, batch_size, local_rank, n_gpu, 1, num_workers)
636
+
637
+ with utils.IncrementalNPZ(out_file) as npzfile:
638
+ with utils.wrap_cuda_oom_error(local_rank, batch_size, n_gpu):
639
+ for batch in tqdm(valid_loader, total=len(valid_loader)):
640
+ outputs = runner.forward(batch, no_loss=True)
641
+ ids = batch['ids']
642
+ sequence_embed = outputs[0]
643
+ pooled_embed = outputs[1]
644
+ sequence_lengths = batch['input_mask'].sum(1)
645
+ sequence_embed = sequence_embed.cpu().numpy()
646
+ pooled_embed = pooled_embed.cpu().numpy()
647
+ sequence_lengths = sequence_lengths.cpu().numpy()
648
+
649
+ for seqembed, poolembed, length, protein_id in zip(
650
+ sequence_embed, pooled_embed, sequence_lengths, ids):
651
+ seqembed = seqembed[:length]
652
+ arrays = {'pooled': poolembed}
653
+ if not full_sequence_embed:
654
+ # avgpool across the sequence
655
+ arrays['avg'] = seqembed.mean(0)
656
+ else:
657
+ arrays['seq'] = seqembed
658
+ to_save = {protein_id: arrays}
659
+ npzfile.savez(**to_save)
tape/utils/__init__.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .utils import int_or_str # noqa: F401
2
+ from .utils import check_is_file # noqa: F401
3
+ from .utils import check_is_dir # noqa: F401
4
+ from .utils import path_to_datetime # noqa: F401
5
+ from .utils import get_expname # noqa: F401
6
+ from .utils import get_effective_num_gpus # noqa: F401
7
+ from .utils import get_effective_batch_size # noqa: F401
8
+ from .utils import get_num_train_optimization_steps # noqa: F401
9
+ from .utils import set_random_seeds # noqa: F401
10
+ from .utils import MetricsAccumulator # noqa: F401
11
+ from .utils import wrap_cuda_oom_error # noqa: F401
12
+ from .utils import write_lmdb # noqa: F401
13
+ from .utils import IncrementalNPZ # noqa: F401
14
+
15
+ from .setup_utils import setup_logging # noqa: F401
16
+ from .setup_utils import setup_optimizer # noqa: F401
17
+ from .setup_utils import setup_dataset # noqa: F401
18
+ from .setup_utils import setup_loader # noqa: F401
19
+ from .setup_utils import setup_distributed # noqa: F401
20
+
21
+ from .distributed_utils import barrier_if_distributed # noqa: F401
22
+ from .distributed_utils import reduce_scalar # noqa: F401
23
+ from .distributed_utils import launch_process_group # noqa: F401
tape/utils/_sampler.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Implementation of a bucketed data sampler from PyTorch-NLP.
2
+ Modified by Roshan Rao.
3
+
4
+ See https://github.com/PetrochukM/PyTorch-NLP/
5
+ """
6
+ import typing
7
+ import math
8
+ import operator
9
+ from torch.utils.data.sampler import Sampler
10
+ from torch.utils.data.sampler import BatchSampler
11
+ from torch.utils.data.sampler import SubsetRandomSampler
12
+
13
+
14
+ class SortedSampler(Sampler):
15
+ """ Samples elements sequentially, always in the same order.
16
+ Args:
17
+ data (iterable): Iterable data.
18
+ sort_key (callable): Specifies a function of one argument that is used to extract a
19
+ numerical comparison key from each list element.
20
+ Example:
21
+ >>> list(SortedSampler(range(10), sort_key=lambda i: -i))
22
+ [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
23
+ """
24
+
25
+ def __init__(self,
26
+ dataset,
27
+ sort_key: typing.Callable[[int], typing.Any],
28
+ indices: typing.Optional[typing.Iterable[int]] = None):
29
+ super().__init__(dataset)
30
+ self.dataset = dataset
31
+ self.sort_key = sort_key
32
+ if indices is None:
33
+ sort_keys = map(sort_key, dataset)
34
+ else:
35
+ sort_keys = ((i, sort_key(dataset[i])) for i in indices)
36
+ self.sorted_indices = [i for i, _ in sorted(sort_keys, key=operator.itemgetter(1))]
37
+
38
+ def __iter__(self):
39
+ return iter(self.sorted_indices)
40
+
41
+ def __len__(self):
42
+ return len(self.dataset)
43
+
44
+
45
+ class BucketBatchSampler(BatchSampler):
46
+ """ `BucketBatchSampler` toggles between `sampler` batches and sorted batches.
47
+ Typically, the `sampler` will be a `RandomSampler` allowing the user to toggle between
48
+ random batches and sorted batches. A larger `bucket_size_multiplier` is more sorted
49
+ and vice versa. Provides ~10-25 percent speedup.
50
+
51
+ Background:
52
+ ``BucketBatchSampler`` is similar to a ``BucketIterator`` found in popular
53
+ libraries like ``AllenNLP`` and ``torchtext``. A ``BucketIterator`` pools together
54
+ examples with a similar size length to reduce the padding required for each batch
55
+ while maintaining some noise through bucketing.
56
+
57
+ Args:
58
+ sampler (torch.data.utils.sampler.Sampler):
59
+ batch_size (int): Size of mini-batch.
60
+ drop_last (bool): If `True` the sampler will drop the last batch if its size
61
+ would be less than `batch_size`.
62
+ sort_key (callable, optional): Callable to specify a comparison key for sorting.
63
+ bucket_size_multiplier (int, optional): Buckets are of size
64
+ `batch_size * bucket_size_multiplier`.
65
+ Example:
66
+ >>> from torch.utils.data.sampler import SequentialSampler
67
+ >>> sampler = SequentialSampler(list(range(10)))
68
+ >>> list(BucketBatchSampler(sampler, batch_size=3, drop_last=False))
69
+ [[6, 7, 8], [0, 1, 2], [3, 4, 5], [9]]
70
+ >>> list(BucketBatchSampler(sampler, batch_size=3, drop_last=True))
71
+ [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
72
+ """
73
+
74
+ def __init__(self,
75
+ sampler,
76
+ batch_size,
77
+ drop_last,
78
+ sort_key,
79
+ dataset,
80
+ bucket_size_multiplier=100):
81
+ super().__init__(sampler, batch_size, drop_last)
82
+ self.sort_key = sort_key
83
+ self.dataset = dataset
84
+ self.bucket_sampler = BatchSampler(
85
+ sampler, min(batch_size * bucket_size_multiplier, len(sampler)), False)
86
+
87
+ def __iter__(self):
88
+ for bucket in self.bucket_sampler:
89
+ sorted_sampler = SortedSampler(self.dataset, self.sort_key, indices=bucket)
90
+ for batch in SubsetRandomSampler(
91
+ list(BatchSampler(sorted_sampler, self.batch_size, self.drop_last))):
92
+ yield batch
93
+
94
+ def __len__(self):
95
+ if self.drop_last:
96
+ return len(self.sampler) // self.batch_size
97
+ else:
98
+ return math.ceil(len(self.sampler) / self.batch_size)
tape/utils/distributed_utils.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import typing
2
+ import argparse
3
+ import os
4
+ import multiprocessing as mp
5
+ import sys
6
+ import signal
7
+
8
+ import torch
9
+ import torch.distributed as dist
10
+ from torch.multiprocessing import _prctl_pr_set_pdeathsig # type: ignore
11
+
12
+ from ..errors import EarlyStopping
13
+
14
+
15
+ def reduce_scalar(scalar: float) -> float:
16
+ if dist.is_available() and dist.is_initialized():
17
+ float_tensor = torch.cuda.FloatTensor([scalar]) # type: ignore
18
+ dist.all_reduce(float_tensor)
19
+ float_tensor /= dist.get_world_size()
20
+ scalar = float_tensor.item()
21
+ return scalar
22
+
23
+
24
+ def barrier_if_distributed() -> None:
25
+ """Raises a barrier if in a distributed context, otherwise does nothing."""
26
+ if dist.is_available() and dist.is_initialized():
27
+ dist.barrier()
28
+
29
+
30
+ def _wrap(fn, kwargs, error_queue):
31
+ # prctl(2) is a Linux specific system call.
32
+ # On other systems the following function call has no effect.
33
+ # This is set to ensure that non-daemonic child processes can
34
+ # terminate if their parent terminates before they do.
35
+ _prctl_pr_set_pdeathsig(signal.SIGINT)
36
+
37
+ try:
38
+ fn(**kwargs)
39
+ except KeyboardInterrupt:
40
+ pass # SIGINT; Killed by parent, do nothing
41
+ except EarlyStopping:
42
+ sys.exit(signal.SIGUSR1) # tape early stop exception
43
+ except Exception:
44
+ # Propagate exception to parent process, keeping original traceback
45
+ import traceback
46
+ error_queue.put(traceback.format_exc())
47
+ sys.exit(1)
48
+
49
+
50
+ class ProcessContext:
51
+ def __init__(self, processes, error_queues):
52
+ self.error_queues = error_queues
53
+ self.processes = processes
54
+ self.sentinels = {
55
+ process.sentinel: index
56
+ for index, process in enumerate(processes)
57
+ }
58
+
59
+ def pids(self):
60
+ return [int(process.pid) for process in self.processes]
61
+
62
+ def join(self, timeout=None):
63
+ r"""
64
+ Tries to join one or more processes in this process context.
65
+ If one of them exited with a non-zero exit status, this function
66
+ kills the remaining processes and raises an exception with the cause
67
+ of the first process exiting.
68
+
69
+ Returns ``True`` if all processes have been joined successfully,
70
+ ``False`` if there are more processes that need to be joined.
71
+
72
+ Arguments:
73
+ timeout (float): Wait this long before giving up on waiting.
74
+ """
75
+ # Ensure this function can be called even when we're done.
76
+ if len(self.sentinels) == 0:
77
+ return True
78
+
79
+ # Wait for any process to fail or all of them to succeed.
80
+ ready = mp.connection.wait(
81
+ self.sentinels.keys(),
82
+ timeout=timeout,
83
+ )
84
+ error_index = None
85
+ for sentinel in ready:
86
+ index = self.sentinels.pop(sentinel)
87
+ process = self.processes[index]
88
+ process.join()
89
+ if process.exitcode != 0:
90
+ error_index = index
91
+ break
92
+ # Return if there was no error.
93
+ if error_index is None:
94
+ # Return whether or not all processes have been joined.
95
+ return len(self.sentinels) == 0
96
+ # Assume failure. Terminate processes that are still alive.
97
+ for process in self.processes:
98
+ if process.is_alive():
99
+ process.terminate()
100
+ process.join()
101
+
102
+ # There won't be an error on the queue if the process crashed.
103
+ if self.error_queues[error_index].empty():
104
+ exitcode = self.processes[error_index].exitcode
105
+ if exitcode == signal.SIGUSR1:
106
+ return True
107
+ elif exitcode < 0:
108
+ name = signal.Signals(-exitcode).name
109
+ raise Exception(
110
+ "process %d terminated with signal %s" %
111
+ (error_index, name)
112
+ )
113
+ else:
114
+ raise Exception(
115
+ "process %d terminated with exit code %d" %
116
+ (error_index, exitcode)
117
+ )
118
+
119
+ original_trace = self.error_queues[error_index].get()
120
+ msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
121
+ msg += original_trace
122
+ raise Exception(msg)
123
+
124
+
125
+ def launch_process_group(func: typing.Callable,
126
+ args: argparse.Namespace,
127
+ num_processes: int,
128
+ num_nodes: int = 1,
129
+ node_rank: int = 0,
130
+ master_addr: str = "127.0.0.1",
131
+ master_port: int = 29500,
132
+ join: bool = True,
133
+ daemon: bool = False):
134
+ # world size in terms of number of processes
135
+ dist_world_size = num_processes * num_nodes
136
+
137
+ # set PyTorch distributed related environmental variables
138
+ current_env = os.environ.copy()
139
+ current_env["MASTER_ADDR"] = master_addr
140
+ current_env["MASTER_PORT"] = str(master_port)
141
+ current_env["WORLD_SIZE"] = str(dist_world_size)
142
+ if 'OMP_NUM_THREADS' not in os.environ and num_processes > 1:
143
+ current_env["OMP_NUM_THREADS"] = str(4)
144
+
145
+ error_queues = []
146
+ processes = []
147
+
148
+ for local_rank in range(num_processes):
149
+ # each process's rank
150
+ dist_rank = num_processes * node_rank + local_rank
151
+ current_env["RANK"] = str(dist_rank)
152
+ current_env["LOCAL_RANK"] = str(local_rank)
153
+ args.local_rank = local_rank
154
+
155
+ error_queue: mp.SimpleQueue[Exception] = mp.SimpleQueue()
156
+ kwargs = {'args': args, 'env': current_env}
157
+ process = mp.Process(
158
+ target=_wrap,
159
+ args=(func, kwargs, error_queue),
160
+ daemon=daemon)
161
+ process.start()
162
+ error_queues.append(error_queue)
163
+ processes.append(process)
164
+
165
+ process_context = ProcessContext(processes, error_queues)
166
+ if not join:
167
+ return process_context
168
+
169
+ while not process_context.join():
170
+ pass
tape/utils/setup_utils.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utility functions to help setup the model, optimizer, distributed compute, etc.
2
+ """
3
+ import typing
4
+ import logging
5
+ from pathlib import Path
6
+ import sys
7
+
8
+ import torch
9
+ import torch.distributed as dist
10
+ from torch.utils.data import DataLoader, RandomSampler, Dataset
11
+ from torch.utils.data.distributed import DistributedSampler
12
+ from ..optimization import AdamW
13
+
14
+ from ..registry import registry
15
+
16
+ from .utils import get_effective_batch_size
17
+ from ._sampler import BucketBatchSampler
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ def setup_logging(local_rank: int,
23
+ save_path: typing.Optional[Path] = None,
24
+ log_level: typing.Union[str, int] = None) -> None:
25
+ if log_level is None:
26
+ level = logging.INFO
27
+ elif isinstance(log_level, str):
28
+ level = getattr(logging, log_level.upper())
29
+ elif isinstance(log_level, int):
30
+ level = log_level
31
+
32
+ if local_rank not in (-1, 0):
33
+ level = max(level, logging.WARN)
34
+
35
+ root_logger = logging.getLogger()
36
+ root_logger.setLevel(level)
37
+
38
+ formatter = logging.Formatter(
39
+ "%(asctime)s - %(levelname)s - %(name)s - %(message)s",
40
+ datefmt="%y/%m/%d %H:%M:%S")
41
+
42
+ if not root_logger.hasHandlers():
43
+ console_handler = logging.StreamHandler(sys.stdout)
44
+ console_handler.setLevel(level)
45
+ console_handler.setFormatter(formatter)
46
+ root_logger.addHandler(console_handler)
47
+
48
+ if save_path is not None:
49
+ file_handler = logging.FileHandler(save_path / 'log')
50
+ file_handler.setLevel(level)
51
+ file_handler.setFormatter(formatter)
52
+ root_logger.addHandler(file_handler)
53
+
54
+
55
+ def setup_optimizer(model,
56
+ learning_rate: float):
57
+ """Create the AdamW optimizer for the given model with the specified learning rate. Based on
58
+ creation in the pytorch_transformers repository.
59
+
60
+ Args:
61
+ model (PreTrainedModel): The model for which to create an optimizer
62
+ learning_rate (float): Default learning rate to use when creating the optimizer
63
+
64
+ Returns:
65
+ optimizer (AdamW): An AdamW optimizer
66
+
67
+ """
68
+ param_optimizer = list(model.named_parameters())
69
+ no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']
70
+ optimizer_grouped_parameters = [
71
+ {
72
+ "params": [
73
+ p for n, p in param_optimizer if not any(nd in n for nd in no_decay)
74
+ ],
75
+ "weight_decay": 0.01,
76
+ },
77
+ {
78
+ "params": [
79
+ p for n, p in param_optimizer if any(nd in n for nd in no_decay)
80
+ ],
81
+ "weight_decay": 0.0,
82
+ },
83
+ ]
84
+
85
+ optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
86
+ return optimizer
87
+
88
+
89
+ def setup_dataset(task: str,
90
+ data_dir: typing.Union[str, Path],
91
+ split: str,
92
+ tokenizer: str) -> Dataset:
93
+ task_spec = registry.get_task_spec(task)
94
+ return task_spec.dataset(data_dir, split, tokenizer) # type: ignore
95
+
96
+
97
+ def setup_loader(dataset: Dataset,
98
+ batch_size: int,
99
+ local_rank: int,
100
+ n_gpu: int,
101
+ gradient_accumulation_steps: int,
102
+ num_workers: int) -> DataLoader:
103
+ sampler = DistributedSampler(dataset) if local_rank != -1 else RandomSampler(dataset)
104
+ batch_size = get_effective_batch_size(
105
+ batch_size, local_rank, n_gpu, gradient_accumulation_steps) * n_gpu
106
+ # WARNING: this will fail if the primary sequence is not the first thing the dataset returns
107
+ batch_sampler = BucketBatchSampler(
108
+ sampler, batch_size, False, lambda x: len(x[0]), dataset)
109
+
110
+ loader = DataLoader(
111
+ dataset,
112
+ num_workers=num_workers,
113
+ collate_fn=dataset.collate_fn, # type: ignore
114
+ batch_sampler=batch_sampler)
115
+
116
+ return loader
117
+
118
+
119
+ def setup_distributed(local_rank: int,
120
+ no_cuda: bool) -> typing.Tuple[torch.device, int, bool]:
121
+ if local_rank != -1 and not no_cuda:
122
+ torch.cuda.set_device(local_rank)
123
+ device: torch.device = torch.device("cuda", local_rank)
124
+ n_gpu = 1
125
+ dist.init_process_group(backend="nccl")
126
+ elif not torch.cuda.is_available() or no_cuda:
127
+ device = torch.device("cpu")
128
+ n_gpu = 1
129
+ else:
130
+ device = torch.device("cuda")
131
+ n_gpu = torch.cuda.device_count()
132
+
133
+ is_master = local_rank in (-1, 0)
134
+
135
+ return device, n_gpu, is_master
tape/utils/utils.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import typing
2
+ import random
3
+ from pathlib import Path
4
+ import logging
5
+ from time import strftime, gmtime
6
+ from datetime import datetime
7
+ import os
8
+ import argparse
9
+ import contextlib
10
+ from collections import defaultdict
11
+
12
+ import numpy as np
13
+ import torch
14
+ from torch.utils.data import Dataset
15
+ import torch.distributed as dist
16
+
17
+ logger = logging.getLogger(__name__)
18
+ FloatOrTensor = typing.Union[float, torch.Tensor]
19
+
20
+
21
+ def int_or_str(arg: str) -> typing.Union[int, str]:
22
+ try:
23
+ return int(arg)
24
+ except ValueError:
25
+ return arg
26
+
27
+
28
+ def check_is_file(file_path: str) -> str:
29
+ if file_path is None or os.path.isfile(file_path):
30
+ return file_path
31
+ else:
32
+ raise argparse.ArgumentTypeError(f"File path: {file_path} is not a valid file")
33
+
34
+
35
+ def check_is_dir(dir_path: str) -> str:
36
+ if dir_path is None or os.path.isdir(dir_path):
37
+ return dir_path
38
+ else:
39
+ raise argparse.ArgumentTypeError(f"Directory path: {dir_path} is not a valid directory")
40
+
41
+
42
+ def path_to_datetime(path: Path) -> datetime:
43
+ name = path.name
44
+ datetime_string = name.split('_')[0]
45
+ try:
46
+ year, month, day, hour, minute, second = datetime_string.split('-')
47
+ except ValueError:
48
+ try:
49
+ # Deprecated datetime strings
50
+ year, month, day, time_str = datetime_string.split('-')
51
+ hour, minute, second = time_str.split(':')
52
+ except ValueError:
53
+ return datetime(1, 1, 1)
54
+
55
+ pathdatetime = datetime(
56
+ int(year), int(month), int(day), int(hour), int(minute), int(second))
57
+ return pathdatetime
58
+
59
+
60
+ def get_expname(exp_name: typing.Optional[str],
61
+ task: typing.Optional[str] = None,
62
+ model_type: typing.Optional[str] = None) -> str:
63
+ if exp_name is None:
64
+ time_stamp = strftime("%y-%m-%d-%H-%M-%S", gmtime())
65
+ exp_name = f"{task}_{model_type}_{time_stamp}_{random.randint(0, int(1e6)):0>6d}"
66
+ return exp_name
67
+
68
+
69
+ def set_random_seeds(seed: int, n_gpu: int) -> None:
70
+ random.seed(seed)
71
+ np.random.seed(seed)
72
+ torch.manual_seed(seed)
73
+ if n_gpu > 0:
74
+ torch.cuda.manual_seed_all(seed) # type: ignore
75
+
76
+
77
+ def get_effective_num_gpus(local_rank: int, n_gpu: int) -> int:
78
+ if local_rank == -1:
79
+ num_gpus = n_gpu
80
+ else:
81
+ num_gpus = dist.get_world_size()
82
+ return num_gpus
83
+
84
+
85
+ def get_effective_batch_size(batch_size: int,
86
+ local_rank: int,
87
+ n_gpu: int,
88
+ gradient_accumulation_steps: int = 1) -> int:
89
+ eff_batch_size = float(batch_size)
90
+ eff_batch_size /= gradient_accumulation_steps
91
+ eff_batch_size /= get_effective_num_gpus(local_rank, n_gpu)
92
+ return int(eff_batch_size)
93
+
94
+
95
+ def get_num_train_optimization_steps(dataset: Dataset,
96
+ batch_size: int,
97
+ num_train_epochs: int) -> int:
98
+ return int(len(dataset) / batch_size * num_train_epochs)
99
+
100
+
101
+ class MetricsAccumulator:
102
+
103
+ def __init__(self, smoothing: float = 0.95):
104
+ self._loss_tmp = 0.
105
+ self._smoothloss: typing.Optional[float] = None
106
+ self._totalloss = 0.
107
+ self._metricstmp: typing.Dict[str, float] = defaultdict(lambda: 0.0)
108
+ self._smoothmetrics: typing.Dict[str, float] = {}
109
+ self._totalmetrics: typing.Dict[str, float] = defaultdict(lambda: 0.0)
110
+
111
+ self._nacc_steps = 0
112
+ self._nupdates = 0
113
+ self._smoothing = smoothing
114
+
115
+ def update(self,
116
+ loss: FloatOrTensor,
117
+ metrics: typing.Dict[str, FloatOrTensor],
118
+ step: bool = True) -> None:
119
+ if isinstance(loss, torch.Tensor):
120
+ loss = loss.item()
121
+
122
+ self._loss_tmp += loss
123
+ for name, value in metrics.items():
124
+ if isinstance(value, torch.Tensor):
125
+ value = value.item()
126
+ self._metricstmp[name] += value
127
+ self._nacc_steps += 1
128
+
129
+ if step:
130
+ self.step()
131
+
132
+ def step(self) -> typing.Dict[str, float]:
133
+ loss_tmp = self._loss_tmp / self._nacc_steps
134
+ metricstmp = {name: value / self._nacc_steps
135
+ for name, value in self._metricstmp.items()}
136
+
137
+ if self._smoothloss is None:
138
+ self._smoothloss = loss_tmp
139
+ else:
140
+ self._smoothloss *= self._smoothing
141
+ self._smoothloss += (1 - self._smoothing) * loss_tmp
142
+ self._totalloss += loss_tmp
143
+
144
+ for name, value in metricstmp.items():
145
+ if name in self._smoothmetrics:
146
+ currvalue = self._smoothmetrics[name]
147
+ newvalue = currvalue * self._smoothing + value * (1 - self._smoothing)
148
+ else:
149
+ newvalue = value
150
+
151
+ self._smoothmetrics[name] = newvalue
152
+ self._totalmetrics[name] += value
153
+
154
+ self._nupdates += 1
155
+
156
+ self._nacc_steps = 0
157
+ self._loss_tmp = 0
158
+ self._metricstmp = defaultdict(lambda: 0.0)
159
+
160
+ metricstmp['loss'] = loss_tmp
161
+ return metricstmp
162
+
163
+ def loss(self) -> float:
164
+ if self._smoothloss is None:
165
+ raise RuntimeError("Trying to get the loss without any updates")
166
+ return self._smoothloss
167
+
168
+ def metrics(self) -> typing.Dict[str, float]:
169
+ if self._nupdates == 0:
170
+ raise RuntimeError("Trying to get metrics without any updates")
171
+ return dict(self._smoothmetrics)
172
+
173
+ def final_loss(self) -> float:
174
+ return self._totalloss / self._nupdates
175
+
176
+ def final_metrics(self) -> typing.Dict[str, float]:
177
+ return {name: value / self._nupdates
178
+ for name, value in self._totalmetrics.items()}
179
+
180
+
181
+ class wrap_cuda_oom_error(contextlib.ContextDecorator):
182
+ """A context manager that wraps the Cuda OOM message so that you get some more helpful
183
+ context as to what you can/should change. Can also be used as a decorator.
184
+
185
+ Examples:
186
+ 1) As a context manager:
187
+
188
+ with wrap_cuda_oom_error(local_rank, batch_size, n_gpu, gradient_accumulation):
189
+ loss = model.forward(batch)
190
+ loss.backward()
191
+ optimizer.step()
192
+ optimizer.zero_grad
193
+
194
+ 2) As a decorator:
195
+
196
+ @wrap_cuda_oom_error(local_rank, batch_size, n_gpu, gradient_accumulation)
197
+ def run_train_epoch(args):
198
+ ...
199
+ <code to run training epoch>
200
+ ...
201
+ """
202
+
203
+ def __init__(self,
204
+ local_rank: int,
205
+ batch_size: int,
206
+ n_gpu: int = 1,
207
+ gradient_accumulation_steps: typing.Optional[int] = None):
208
+ self._local_rank = local_rank
209
+ self._batch_size = batch_size
210
+ self._n_gpu = n_gpu
211
+ self._gradient_accumulation_steps = gradient_accumulation_steps
212
+
213
+ def __enter__(self):
214
+ return self
215
+
216
+ def __exit__(self, exc_type, exc_value, traceback):
217
+ exc_args = exc_value.args if exc_value is not None else None
218
+ if exc_args and 'CUDA out of memory' in exc_args[0]:
219
+ eff_ngpu = get_effective_num_gpus(self._local_rank, self._n_gpu)
220
+ if self._gradient_accumulation_steps is not None:
221
+ eff_batch_size = get_effective_batch_size(
222
+ self._batch_size, self._local_rank, self._n_gpu,
223
+ self._gradient_accumulation_steps)
224
+ message = (f"CUDA out of memory. Reduce batch size or increase "
225
+ f"gradient_accumulation_steps to divide each batch over more "
226
+ f"forward passes.\n\n"
227
+ f"\tHyperparameters:\n"
228
+ f"\t\tbatch_size per backward-pass: {self._batch_size}\n"
229
+ f"\t\tgradient_accumulation_steps: "
230
+ f"{self._gradient_accumulation_steps}\n"
231
+ f"\t\tn_gpu: {eff_ngpu}\n"
232
+ f"\t\tbatch_size per (gpu * forward-pass): "
233
+ f"{eff_batch_size}")
234
+ else:
235
+ eff_batch_size = get_effective_batch_size(
236
+ self._batch_size, self._local_rank, self._n_gpu)
237
+ message = (f"CUDA out of memory. Reduce batch size to fit each "
238
+ f"iteration in memory.\n\n"
239
+ f"\tHyperparameters:\n"
240
+ f"\t\tbatch_size per forward-pass: {self._batch_size}\n"
241
+ f"\t\tn_gpu: {eff_ngpu}\n"
242
+ f"\t\tbatch_size per (gpu * forward-pass): "
243
+ f"{eff_batch_size}")
244
+ raise RuntimeError(message)
245
+ return False
246
+
247
+
248
+ def write_lmdb(filename: str, iterable: typing.Iterable, map_size: int = 2 ** 20):
249
+ """Utility for writing a dataset to an LMDB file.
250
+
251
+ Args:
252
+ filename (str): Output filename to write to
253
+ iterable (Iterable): An iterable dataset to write to. Entries must be pickleable.
254
+ map_size (int, optional): Maximum allowable size of database in bytes. Required by LMDB.
255
+ You will likely have to increase this. Default: 1MB.
256
+ """
257
+ import lmdb
258
+ import pickle as pkl
259
+ env = lmdb.open(filename, map_size=map_size)
260
+
261
+ with env.begin(write=True) as txn:
262
+ for i, entry in enumerate(iterable):
263
+ txn.put(str(i).encode(), pkl.dumps(entry))
264
+ txn.put(b'num_examples', pkl.dumps(i + 1))
265
+ env.close()
266
+
267
+
268
+ class IncrementalNPZ(object):
269
+ # Modified npz that allows incremental saving, from https://stackoverflow.com/questions/22712292/how-to-use-numpy-savez-in-a-loop-for-save-more-than-one-array # noqa: E501
270
+ def __init__(self, file):
271
+ import tempfile
272
+ import zipfile
273
+ import os
274
+
275
+ if isinstance(file, str):
276
+ if not file.endswith('.npz'):
277
+ file = file + '.npz'
278
+
279
+ compression = zipfile.ZIP_STORED
280
+
281
+ zipfile = self.zipfile_factory(file, mode="a", compression=compression)
282
+
283
+ # Stage arrays in a temporary file on disk, before writing to zip.
284
+ fd, tmpfile = tempfile.mkstemp(suffix='-numpy.npy')
285
+ os.close(fd)
286
+
287
+ self.tmpfile = tmpfile
288
+ self.zip = zipfile
289
+ self._i = 0
290
+
291
+ def zipfile_factory(self, *args, **kwargs):
292
+ import zipfile
293
+ import sys
294
+ if sys.version_info >= (2, 5):
295
+ kwargs['allowZip64'] = True
296
+ return zipfile.ZipFile(*args, **kwargs)
297
+
298
+ def savez(self, *args, **kwds):
299
+ import os
300
+ import numpy.lib.format as fmt
301
+
302
+ namedict = kwds
303
+ for val in args:
304
+ key = 'arr_%d' % self._i
305
+ if key in namedict.keys():
306
+ raise ValueError("Cannot use un-named variables and keyword %s" % key)
307
+ namedict[key] = val
308
+ self._i += 1
309
+
310
+ try:
311
+ for key, val in namedict.items():
312
+ fname = key + '.npy'
313
+ fid = open(self.tmpfile, 'wb')
314
+ with open(self.tmpfile, 'wb') as fid:
315
+ fmt.write_array(fid, np.asanyarray(val), allow_pickle=True)
316
+ self.zip.write(self.tmpfile, arcname=fname)
317
+ finally:
318
+ os.remove(self.tmpfile)
319
+
320
+ def close(self):
321
+ self.zip.close()
322
+
323
+ def __enter__(self):
324
+ return self
325
+
326
+ def __exit__(self, exc_type, exc_value, traceback):
327
+ self.close()
tape/visualization.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import typing
2
+ import os
3
+ import logging
4
+ from abc import ABC, abstractmethod
5
+ from pathlib import Path
6
+ import torch.nn as nn
7
+
8
+ from tensorboardX import SummaryWriter
9
+
10
+ try:
11
+ import wandb
12
+ WANDB_FOUND = True
13
+ except ImportError:
14
+ WANDB_FOUND = False
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ class TAPEVisualizer(ABC):
20
+ """Base class for visualization in TAPE"""
21
+
22
+ @abstractmethod
23
+ def __init__(self, log_dir: typing.Union[str, Path], exp_name: str, debug: bool = False):
24
+ raise NotImplementedError
25
+
26
+ @abstractmethod
27
+ def log_config(self, config: typing.Dict[str, typing.Any]) -> None:
28
+ raise NotImplementedError
29
+
30
+ @abstractmethod
31
+ def watch(self, model: nn.Module) -> None:
32
+ raise NotImplementedError
33
+
34
+ @abstractmethod
35
+ def log_metrics(self,
36
+ metrics_dict: typing.Dict[str, float],
37
+ split: str,
38
+ step: int):
39
+ raise NotImplementedError
40
+
41
+
42
+ class DummyVisualizer(TAPEVisualizer):
43
+ """Dummy class that doesn't do anything. Used for non-master branches."""
44
+
45
+ def __init__(self,
46
+ log_dir: typing.Union[str, Path] = '',
47
+ exp_name: str = '',
48
+ debug: bool = False):
49
+ pass
50
+
51
+ def log_config(self, config: typing.Dict[str, typing.Any]) -> None:
52
+ pass
53
+
54
+ def watch(self, model: nn.Module) -> None:
55
+ pass
56
+
57
+ def log_metrics(self,
58
+ metrics_dict: typing.Dict[str, float],
59
+ split: str,
60
+ step: int):
61
+ pass
62
+
63
+
64
+ class TBVisualizer(TAPEVisualizer):
65
+
66
+ def __init__(self, log_dir: typing.Union[str, Path], exp_name: str, debug: bool = False):
67
+ log_dir = Path(log_dir) / exp_name
68
+ logger.info(f"tensorboard file at: {log_dir}")
69
+ self.logger = SummaryWriter(log_dir=str(log_dir))
70
+
71
+ def log_config(self, config: typing.Dict[str, typing.Any]) -> None:
72
+ logger.warn("Cannot log config when using a TBVisualizer. "
73
+ "Configure wandb for this functionality")
74
+
75
+ def watch(self, model: nn.Module) -> None:
76
+ logger.warn("Cannot watch models when using a TBVisualizer. "
77
+ "Configure wandb for this functionality")
78
+
79
+ def log_metrics(self,
80
+ metrics_dict: typing.Dict[str, float],
81
+ split: str,
82
+ step: int):
83
+ for name, value in metrics_dict.items():
84
+ self.logger.add_scalar(split + "/" + name, value, step)
85
+
86
+
87
+ class WandBVisualizer(TAPEVisualizer):
88
+
89
+ def __init__(self, log_dir: typing.Union[str, Path], exp_name: str, debug: bool = False):
90
+ if not WANDB_FOUND:
91
+ raise ImportError("wandb module not available")
92
+ #if debug:
93
+ # os.environ['WANDB_MODE'] = 'dryrun'
94
+ #if 'WANDB_PROJECT' not in os.environ:
95
+ # # Want the user to set the WANDB_PROJECT.
96
+ # logger.warning("WANDB_PROJECT environment variable not found, "
97
+ # "not logging to app.wandb.ai")
98
+ # os.environ['WANDB_MODE'] = 'dryrun'
99
+ wandb.init(dir=log_dir, name=exp_name)
100
+
101
+ def log_config(self, config: typing.Dict[str, typing.Any]) -> None:
102
+ wandb.config.update(config)
103
+
104
+ def watch(self, model: nn.Module):
105
+ wandb.watch(model)
106
+
107
+ def log_metrics(self,
108
+ metrics_dict: typing.Dict[str, float],
109
+ split: str,
110
+ step: int):
111
+ wandb.log({f"{split.capitalize()} {name.capitalize()}": value
112
+ for name, value in metrics_dict.items()}, step=step)
113
+
114
+
115
+ def get(log_dir: typing.Union[str, Path],
116
+ exp_name: str,
117
+ local_rank: int,
118
+ debug: bool = False) -> TAPEVisualizer:
119
+ if local_rank not in (-1, 0):
120
+ return DummyVisualizer(log_dir, exp_name, debug)
121
+ elif WANDB_FOUND:
122
+ return WandBVisualizer(log_dir, exp_name, debug)
123
+ else:
124
+ return TBVisualizer(log_dir, exp_name, debug)