Niksa Praljak committed
Commit c865888 · 1 Parent(s): 03b411e

Add scripts for ProteoScribe Sampling
Stage3_source/DSEma.py ADDED
@@ -0,0 +1,43 @@
+ from torch import nn
+ import torch
+ from deepspeed.runtime.zero import GatheredParameters
+ import deepspeed
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+
+ def _z3_params_to_fetch(param_list):
+     return [
+         p for p in param_list
+         if hasattr(p, 'ds_id') and p.ds_status == ZeroParamStatus.NOT_AVAILABLE
+     ]
+
+
+ def moving_average(model, model_ema, beta=0.9999, device=None, zero_stage=3):
+     zero_stage_3 = (zero_stage == 3)
+     with torch.no_grad():
+         for param, param_ema in zip(model.parameters(),
+                                     model_ema.parameters()):
+             # TODO: use prefiltering for efficiency
+             params_to_fetch = _z3_params_to_fetch([param, param_ema
+                                                    ]) if zero_stage_3 else []
+             should_gather_param = len(params_to_fetch) > 0
+             with deepspeed.zero.GatheredParameters(
+                     params_to_fetch, enabled=should_gather_param):
+                 data = param.data
+                 if device is not None:
+                     data = data.to(device)
+                 #print('real model', data.shape, data)
+                 #print('ema model', param_ema.shape, param_ema.data)
+                 param_ema.data.copy_(torch.lerp(data, param_ema.data, beta))
+                 #print('after ema copy', param_ema.shape, param_ema.data)
+
+
+ def clone_zero_model(src_model, dst_model, zero_stage=0):
+     zero_stage_3 = (zero_stage == 3)
+     with torch.no_grad():
+         for src_param, dst_param in zip(src_model.parameters(), dst_model.parameters()):
+             # TODO: use prefiltering for efficiency
+             params_to_fetch = _z3_params_to_fetch([src_param, dst_param
+                                                    ]) if zero_stage_3 else []
+             should_gather_param = len(params_to_fetch) > 0
+             with deepspeed.zero.GatheredParameters(params_to_fetch, enabled=should_gather_param):
+                 dst_param.data.copy_(src_param.data)
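These two helpers are meant to be driven from a training loop: clone_zero_model initializes a frozen copy of the network once, and moving_average then blends that copy toward the live weights after each optimizer step (with ZeRO stage 3 the partitioned parameters are gathered first so every rank sees the full tensors). A minimal usage sketch, assuming model and ema_model are two instances of the same architecture already wrapped by DeepSpeed, and train_loader / train_one_step are placeholder names:

import torch
from Stage3_source.DSEma import clone_zero_model, moving_average

# illustrative only: start the EMA copy from the current weights
clone_zero_model(model, ema_model, zero_stage=3)

for batch in train_loader:
    loss = train_one_step(model, batch)   # hypothetical: forward + backward + optimizer.step()
    # torch.lerp(data, ema, beta) gives new_ema = beta * old_ema + (1 - beta) * live_weights
    moving_average(model, ema_model, beta=0.9999, zero_stage=3)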
Stage3_source/PL_wrapper.py ADDED
@@ -0,0 +1,433 @@
1
+ import torch
2
+ from torch import nn, optim
3
+ from torch.nn import functional as F
4
+ from torch.distributions import OneHotCategorical
5
+ from transformers.optimization import Adafactor
6
+
7
+ # PL functions
8
+ import pytorch_lightning as pl
9
+ from pytorch_lightning import Trainer, seed_everything
10
+ from pytorch_lightning.callbacks import EarlyStopping
11
+
12
+ import functools
13
+ import math
14
+ #from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
15
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
16
+ from torch.distributed.fsdp.wrap import (
17
+ size_based_auto_wrap_policy,
18
+ enable_wrap,
19
+ wrap
20
+ )
21
+
22
+ import deepspeed
23
+ from deepspeed.ops.adam import DeepSpeedCPUAdam
24
+
25
+ from sklearn.model_selection import train_test_split
26
+
27
+ from Stage3_source.DSEma import moving_average, clone_zero_model
28
+ import Stage3_source.transformer_training_helper as trainer_tools
29
+ import Stage3_source.helper_funcs as helper_tools
30
+ import Stage3_source.eval_metrics as eval_funcs
31
+ import Stage3_source.preprocess as prep
32
+
33
+ import copy
34
+
35
+ from torch.utils.data import DataLoader
36
+ import pandas as pd
37
+
38
+ from transformers import get_cosine_schedule_with_warmup
39
+
40
+ class PL_ProtARDM(pl.LightningModule):
41
+
42
+
43
+ def __init__(
44
+ self,
45
+ args: any,
46
+ model: nn.Module,
47
+ #ema_model: nn.Module,
48
+ ):
49
+
50
+ super().__init__()
51
+ #self.save_hyperparameters()
52
+
53
+ # arguments
54
+ self.script_args = args
55
+
56
+ # the whole model
57
+ self.model = model
58
+ #self.ema_model = ema_model
59
+
60
+ #clone_zero_model(self.model, self.ema_model, zero_stage=3)
61
+ ##self.ema_model = copy.deepcopy(self.model)
62
+
63
+ def forward(
64
+ self,
65
+ x: torch.Tensor,
66
+ t: torch.Tensor,
67
+ y_c: torch.Tensor,
68
+ ema=False,
69
+ ) -> torch.Tensor:
70
+
71
+ if ema:
72
+ logits = self.ema_model(x=x, t=t.view(-1,), y_c=y_c)
73
+ else:
74
+ logits = self.model(x=x, t=t.view(-1,), y_c=y_c)
75
+ return logits
76
+ #return F.softmax(logits, dim=1)
77
+
78
+
79
+ #def on_train_batch_end(self, *args, **kwargs):
80
+ # clone_zero_model(self.model, self.ema_model, zero_stage=3)
81
+ # #moving_average(self.model, self.ema_model, beta=0.0, zero_stage=3)
82
+
83
+
84
+ def configure_optimizers(self, ):
85
+
86
+ if self.script_args.choose_optim == 'AdamW':
87
+
88
+ if isinstance(self, FSDP):
89
+ print("Enter FSDP")
90
+ optimizer = torch.optim.AdamW(self.parameters(), lr=self.script_args.lr, weight_decay=self.script_args.weight_decay)
91
+
92
+ else:
93
+ optimizer = torch.optim.AdamW(self.parameters(), lr=self.script_args.lr, weight_decay=self.script_args.weight_decay)
94
+
95
+ elif self.script_args.choose_optim == 'AdaFactor':
96
+ optimizer = Adafactor(self.parameters(), lr=self.script_args.lr, weight_decay=self.script_args.weight_decay, relative_step=False)
97
+
98
+ elif self.script_args.choose_optim == 'Adam':
99
+ optimizer = torch.optim.Adam(self.parameters(), lr=self.script_args.lr)
100
+
101
+ elif self.script_args.choose_optim == 'DeepSpeedCPUAdam':
102
+ optimizer = DeepSpeedCPUAdam(self.parameters(), lr=self.script_args.lr, weight_decay=self.script_args.weight_decay)
103
+
104
+ if self.script_args.scheduler_gamma is not None:
105
+ if isinstance(self.script_args.scheduler_gamma, str):
106
+ if 'coswarmup' == self.script_args.scheduler_gamma.lower():
107
+ print(f'Using cosine warmup scheduler with decay')
108
+ num_warmup_steps=self.script_args.traindata_len
109
+ num_training_steps=self.script_args.traindata_len*self.script_args.epochs
110
+ print(f'Num_warmup_steps={num_warmup_steps}')
111
+ print(f'Num_training_steps={num_training_steps}')
112
+
113
+ def _get_cosine_schedule_with_warmup_lr_lambda(
114
+ current_step: int, num_warmup_steps: int, num_training_steps: int, num_cycles: float
115
+ ):
116
+ if current_step < num_warmup_steps:
117
+ return float(current_step) / float(max(1, num_warmup_steps))
118
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
119
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
120
+
121
+ lr_lambda = functools.partial(
122
+ _get_cosine_schedule_with_warmup_lr_lambda,
123
+ num_warmup_steps=num_warmup_steps,
124
+ num_training_steps=num_training_steps,
125
+ num_cycles=0.5,
126
+ )
127
+ return {
128
+ "optimizer": optimizer,
129
+ "lr_scheduler": {
130
+ "scheduler": optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1),
131
+ "interval": "step",
132
+ },
133
+ }
134
+
135
+ #return {
136
+ # "optimizer": optimizer,
137
+ # "lr_scheduler": {
138
+ # "scheduler": get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps),
139
+ # "interval": "step",
140
+ # },
141
+ #}
142
+ else:
143
+ print(f'Using Exponential learning rate decay / epoch with factor: {self.script_args.scheduler_gamma}')
144
+ return {
145
+ "optimizer": optimizer,
146
+ "lr_scheduler": {
147
+ "scheduler": optim.lr_scheduler.ExponentialLR(optimizer, gamma=self.script_args.scheduler_gamma, verbose=True),
148
+ "interval": "epoch",
149
+ },
150
+ }
151
+ else:
152
+ return optimizer
153
+
154
+ #else:
155
+ # print("Please make choose_option variable from these options: 'AdamW', 'AdaFactor', 'Adam', 'DeepSpeedCPUAdam'")
156
+
157
+ def common_step(
158
+ self,
159
+ realization: torch.Tensor,
160
+ realization_idx: any,
161
+ stage: str) -> dict:
162
+
163
+ if isinstance(realization, list):
164
+
165
+ # class labels
166
+ y_c = realization[1]#.long()
167
+
168
+ # input samples
169
+ realization = realization[0]
170
+ batch_size, seq_length = realization.size()
171
+
172
+ realization = realization.reshape(batch_size, 1, seq_length).long()
173
+
174
+ train_tuple = self.cond_elbo_objective(
175
+ realization=realization,
176
+ y_c=y_c,
177
+ realization_idx=realization_idx,
178
+ stage=stage,
179
+ ema=True if 'ema' in stage.lower() else False,
180
+ )
181
+
182
+ if len(train_tuple) == 1:
183
+ loss = train_tuple[0]
184
+ else:
185
+ loss = train_tuple[0]
186
+ metrics = train_tuple[1]
187
+
188
+ if realization_idx == 0:
189
+ gpu_memory_usage = helper_tools.print_gpu_initialization()
190
+ self.log(f"{stage}_gpu_memory_usage", gpu_memory_usage, sync_dist=True)
191
+
192
+ sync_dist = True if 'val' in stage else False
193
+ # track loss
194
+ self.log(f"{stage}_loss", loss, prog_bar=True, on_step=True, on_epoch=True, sync_dist=sync_dist)
195
+ # track performance metrics
196
+ if len(train_tuple) > 1:
197
+ self.log(f"{stage}_prev_hard_acc", metrics[0], prog_bar=True, on_step=True, on_epoch=True, sync_dist=sync_dist)
198
+ self.log(f"{stage}_prev_soft_acc", metrics[1], on_step=True, on_epoch=True, sync_dist=sync_dist)
199
+ self.log(f"{stage}_fut_hard_acc", metrics[2], prog_bar=True, on_step=True, on_epoch=True, sync_dist=sync_dist)
200
+ self.log(f"{stage}_fut_soft_acc", metrics[3], on_step=True, on_epoch=True, sync_dist=sync_dist)
201
+ self.log(f"{stage}_current_hard_acc", metrics[4], prog_bar=True, on_step=True, on_epoch=True, sync_dist=sync_dist)
202
+ self.log(f"{stage}_current_soft_acc", metrics[5], on_step=True, on_epoch=True, sync_dist=sync_dist)
203
+ self.log(f"{stage}_current_ppl", metrics[6], on_step=True, on_epoch=True, sync_dist=sync_dist)
204
+ self.log(f"{stage}_prev_ppl", metrics[7], on_step=True, on_epoch=True, sync_dist=sync_dist)
205
+ self.log(f"{stage}_fut_ppl", metrics[8], on_step=True, on_epoch=True, sync_dist=sync_dist)
206
+ self.log(f"{stage}_pos_entropy", metrics[9], on_step=True, on_epoch=True, sync_dist=sync_dist)
207
+
208
+ torch.cuda.empty_cache()
209
+ return {'loss': loss}
210
+
211
+ def training_step(
212
+ self,
213
+ realization: torch.Tensor,
214
+ realization_idx: any):
215
+ return self.common_step(realization, realization_idx, stage='train')
216
+
217
+ def validation_step(
218
+ self,
219
+ realization: torch.Tensor,
220
+ realization_idx: any):
221
+ self.common_step(realization, realization_idx, stage='val')
222
+ #self.common_step(realization, realization_idx, stage='EMA_val')
223
+
224
+ def apply_OneHotCat(self, probs: torch.Tensor) -> any:
225
+ return OneHotCategorical(probs=probs.permute(0,2,1))
226
+ #return OneHotCategorical(probs=F.softmax(probs.permute(0,2,1), dim=-1))
227
+
228
+ def cond_elbo_objective(
229
+ self,
230
+ realization: torch.Tensor,
231
+ y_c: torch.Tensor,
232
+ realization_idx: any,
233
+ stage: str,
234
+ ema=False,
235
+ ):
236
+
237
+ bs, channel, seq_length = realization.size()
238
+
239
+ # get a batch of random sampling paths
240
+ sampled_random_path = trainer_tools.sample_random_path(bs, seq_length, device=self.script_args.device)
241
+ # sample a set of random sampling steps for each individual training sequence in the current batch
242
+ idx = trainer_tools.sample_random_index_for_sampling(bs, seq_length, device=self.script_args.device, option='random')
243
+ # we create a mask that masks the locations where we've already sampled
244
+ random_path_mask = trainer_tools.create_mask_at_random_path_index(sampled_random_path, idx, bs, seq_length)
245
+ # create a mask that masks the location where we are currently sampling
246
+ current_path_mask = trainer_tools.create_sampling_location_mask(sampled_random_path, idx, bs, seq_length)
247
+ # future sampling locations (i.e. >t)
248
+ future_path_mask = trainer_tools.create_mask_at_future_path_index(sampled_random_path, idx, bs, seq_length)
249
+ # tokenize realization
250
+ real_tokens, bs, seq_length = trainer_tools.create_token_labels(self.script_args, realization)
251
+ #real_tokens = realization.clone().squeeze(1)
252
+ # mask realizations
253
+ real_token_masked = trainer_tools.mask_realizations(real_tokens, random_path_mask)
254
+ # conditional probs
255
+ #probs = self(x=real_token_masked, t=idx, y_c=y_c, ema=ema)
256
+ logits = self(x=real_token_masked, t=idx, y_c=y_c, ema=ema)
257
+
258
+ conditional_prob = OneHotCategorical(logits=logits.permute(0,2,1))
259
+ #conditional_prob = self.apply_OneHotCat(probs=probs)
260
+ # evaluate the value of the log prob for the given realization
261
+ log_prob = trainer_tools.log_prob_of_realization(self.script_args, conditional_prob, real_tokens)
262
+
263
+ # compute an average over all the unsampled locations
264
+ #log_prob_unsampled = trainer_tools.log_prob_of_unsampled_locations(log_prob.to(self.script_args.device), real_token_masked.to(self.script_args.device))
265
+ log_prob_unsampled = trainer_tools.log_prob_of_unsampled_locations(log_prob, real_token_masked)
266
+ #log_prob_unsampled = trainer_tools.log_prob_of_unsampled_locations(log_prob, real_token_masked, real_tokens)
267
+
268
+
269
+ # weight the unsampled log-probability by the current sampling step
270
+ log_prob_weighted = trainer_tools.weight_log_prob(log_prob_unsampled, idx, seq_length)
271
+ # compute an average loss i.e. negative average log-likelihood over the batch elements
272
+ loss = trainer_tools.compute_average_loss_for_batch(log_prob_weighted)
273
+
274
+ #if 'val' in stage:
275
+ probs = F.softmax(logits, dim=1)
276
+ metrics = self.performance_step(
277
+ real_tokens=real_tokens.cpu(),
278
+ idx=idx.cpu(),
279
+ sampled_random_path=sampled_random_path.cpu().float(),
280
+ probs=probs.cpu().float(),
281
+ conditional_prob=conditional_prob)
282
+
283
+ return loss, metrics
284
+
285
+
286
+ # return loss,
287
+
288
+ @torch.no_grad()
289
+ def performance_step(
290
+ self,
291
+ real_tokens: torch.Tensor,
292
+ idx: torch.Tensor,
293
+ sampled_random_path: torch.Tensor,
294
+ probs: torch.Tensor,
295
+ conditional_prob: torch.Tensor
296
+ ) -> tuple:
297
+
298
+
299
+ # create numerical token sequence
300
+ sample_seq = torch.argmax(trainer_tools.sample_from_conditional(conditional_prob).cpu(), dim=1)
301
+
302
+ # eval prev positions in terms of time
303
+ prev_B_hard_acc, prev_B_soft_acc, fut_B_hard_acc, fut_B_soft_acc, current_B_hard_acc, current_B_soft_acc = eval_funcs.compute_acc_given_time_pos(
304
+ real_tokens=real_tokens,
305
+ sample_seq=sample_seq,
306
+ sample_path=sampled_random_path,
307
+ idx=idx
308
+ )
309
+
310
+ # compute ppl given time position
311
+ current_ppl, prev_ppl, fut_ppl = eval_funcs.compute_ppl_given_time_pos(
312
+ probs=probs,
313
+ sample_path=sampled_random_path,
314
+ idx=idx
315
+ )
316
+
317
+ # average positional entropy
318
+ pos_entropy = trainer_tools.compute_pos_entropy(probs=probs).mean().item()
319
+
320
+ metric_evals = (
321
+ prev_B_hard_acc,
322
+ prev_B_soft_acc,
323
+ fut_B_hard_acc,
324
+ fut_B_soft_acc,
325
+ current_B_hard_acc,
326
+ current_B_soft_acc,
327
+ current_ppl,
328
+ prev_ppl,
329
+ fut_ppl,
330
+ pos_entropy
331
+ )
332
+
333
+ return metric_evals
334
+
335
+
336
+
337
+ class PFamDataModule(pl.LightningDataModule):
338
+ def __init__(self, args):
339
+ super().__init__()
340
+ self.args = args
341
+
342
+ #df = pd.read_csv(args.data_root)
343
+ #data = torch.load(args.data_root)
344
+ data = self.load_data()
345
+
346
+ num_seq_list, text_emb_list = prep.prepare_protein_data(
347
+ args=args,
348
+ data_dict=data
349
+ )
350
+
351
+ print('Performing 80/20 random train/val split')
352
+ num_seq_list_train, num_seq_list_val, text_emb_train, text_emb_val = train_test_split(num_seq_list,
353
+ text_emb_list,
354
+ test_size=args.valid_size,
355
+ #stratify=class_label_list,
356
+ random_state=args.seed)
357
+ print(f'Number of training samples: {len(num_seq_list_train)}')
358
+ print(f'Number of validation samples: {len(num_seq_list_val)}')
359
+
360
+ self.train_dataset = prep.protein_dataset(
361
+ num_seq_list=num_seq_list_train,
362
+ text_emb=text_emb_train
363
+ )
364
+
365
+ self.val_dataset = prep.protein_dataset(
366
+ num_seq_list=num_seq_list_val,
367
+ text_emb=text_emb_val
368
+ )
369
+
370
+ def load_data(self):
371
+
372
+ try:
373
+
374
+ print(self.args.swissprot_data_root, self.args.pfam_data_root)
375
+
376
+ if self.args.swissprot_data_root != "None":
377
+ swissprot_data = torch.load(self.args.swissprot_data_root)
378
+ else:
379
+ swissprot_data=None
380
+
381
+ if self.args.pfam_data_root != "None":
382
+ pfam_data = torch.load(self.args.pfam_data_root)
383
+ else:
384
+ pfam_data=None
385
+
386
+ if (self.args.swissprot_data_root != "None") and (self.args.pfam_data_root != "None"):
387
+ return self.merge_and_append_values(dict1=swissprot_data, dict2=pfam_data)
388
+ elif self.args.swissprot_data_root == "None":
389
+ return pfam_data
390
+ elif self.args.pfam_data_root == "None":
391
+ return swissprot_data
392
+ else:
393
+ raise ValueError('Both SwissProt and Pfam datasets are unavailable.')
394
+
395
+ except FileNotFoundError as e:
396
+ raise FileNotFoundError(f"Data file not found: {e}")
397
+
398
+
399
+ def merge_and_append_values(self, dict1, dict2):
400
+
401
+ merged_dict = {}
402
+
403
+ # Combine all keys from both dictionaries
404
+ all_keys = set(dict1) | set(dict2)
405
+
406
+ for key in all_keys:
407
+ values = []
408
+ if key in dict1:
409
+ values.append(dict1[key])
410
+ if key in dict2:
411
+ values.append(dict2[key])
412
+
413
+ # Merge values for each key
414
+ # This merges lists or appends non-list values
415
+ merged_dict[key] = [item for sublist in values for item in (sublist if isinstance(sublist, list) else [sublist])]
416
+
417
+ return merged_dict
418
+
419
+ def train_dataloader(self):
420
+ return DataLoader(
421
+ self.train_dataset,
422
+ batch_size=self.args.batch_size,
423
+ num_workers=self.args.num_workers,
424
+ shuffle=True
425
+ )
426
+
427
+ def val_dataloader(self):
428
+ return DataLoader(
429
+ self.val_dataset,
430
+ batch_size=self.args.batch_size,
431
+ num_workers=self.args.num_workers,
432
+ shuffle=False
433
+ )
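A rough sketch of how this LightningModule and DataModule could be wired to a pytorch_lightning Trainer. The attribute names on args follow the ones referenced above, but the launch script itself is not part of this commit, so every value below (paths, sizes, num_classes) is an assumption for illustration only:

from argparse import Namespace
import pytorch_lightning as pl
import Stage3_source.cond_diff_transformer_layer as mod
from Stage3_source.PL_wrapper import PL_ProtARDM, PFamDataModule

# hypothetical configuration; real values come from the training config
args = Namespace(
    lr=1e-4, weight_decay=0.0, choose_optim='AdamW', scheduler_gamma=None,
    batch_size=32, num_workers=4, valid_size=0.2, seed=42, epochs=10,
    traindata_len=1000, device='cuda',
    swissprot_data_root='swissprot_emb.pt', pfam_data_root='None',   # hypothetical paths
    facilitator='Default', sequence_keyname='sequence',
    diffusion_steps=1024, image_size=32, text_emb_dim=512,
    transformer_dim=512, transformer_heads=16, transformer_depth=16,
    transformer_blocks=1, transformer_local_heads=8, transformer_local_size=128,
    transformer_reversible=False, input_dp_rate=0.0,
)

pl.seed_everything(args.seed)
data_module = PFamDataModule(args)     # expects the .pt data files above to exist
# 28 classes = 23 canonical tokens + 5 rare amino-acid codes used in preprocess.create_num_seqs
network = mod.get_model(args, data_shape=(args.image_size, args.image_size), num_classes=28)
pl_model = PL_ProtARDM(args=args, model=network)
trainer = pl.Trainer(max_epochs=args.epochs, accelerator='gpu', devices=1)
trainer.fit(pl_model, datamodule=data_module)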
Stage3_source/__init__.py ADDED
File without changes
Stage3_source/animation_tools.py ADDED
@@ -0,0 +1,65 @@
+ import textwrap
+ from PIL import Image, ImageDraw, ImageFont
+ import imageio
+ import os
+
+ # convert the numerical labels to characters
+ def convert_num_to_char(
+         tokens: list,
+         char_tokens: any
+ ) -> str:
+     return "".join([tokens[num] for num in char_tokens.tolist()])
+
+ # draw text onto white page
+ def draw_text(
+         image: any,
+         text: any,
+         font: any,
+         position: tuple=(0,0),
+         max_width: any=None,
+         fill: tuple=(0,0,0)
+ ) -> None:
+
+     draw = ImageDraw.Draw(image)
+     if max_width:
+         wrapped_text = textwrap.fill(text, width=max_width)
+     else:
+         wrapped_text = text
+     draw.multiline_text(position, wrapped_text, font=font, fill=fill)
+
+
+ # create gif animation
+ def generate_text_animation(
+         text_list: list,
+         text_animation_path: str,
+         output_temp_path: str='./outputs/temp_files'
+ ) -> None:
+
+     # create images with text
+     image_files = []
+     for index, text in enumerate(text_list):
+
+         img = Image.new('RGB', (600, 159), color=(255, 255, 255))  # Create a white image
+         font = ImageFont.load_default()
+         draw_text(img, text, font, position=(10, 10), max_width=80, fill=(0, 0, 0))
+
+         # Save image to a temporary file
+         os.makedirs(output_temp_path, exist_ok=True)
+         # temp_file = f'./outputs/temp_image_{index}.png'
+         temp_file = output_temp_path + f'/temp_image_{index}.png'
+         img.save(temp_file)
+         image_files.append(temp_file)
+
+     # Read saved images and create a GIF
+     images = [imageio.imread(file) for file in image_files]
+     imageio.mimsave(
+         text_animation_path,
+         images,
+         format='GIF',
+         duration=0.2,
+     )
+
+     # clean up temp image files
+     for file in image_files:
+         os.remove(file)
+     return
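For reference, a small sketch of how these utilities fit together (the token vocabulary and sequences below are made up): convert numerical tokens to a string with convert_num_to_char, collect one string per sampling step, then render the trajectory as a GIF.

import torch
import Stage3_source.animation_tools as ani_tools

tokens = ['-', '<START>', 'A', 'C', 'D', '<END>']          # toy vocabulary, not the real one
frames = []
for step_tokens in [torch.tensor([1, 2, 3, 0, 5]), torch.tensor([1, 2, 3, 4, 5])]:
    frames.append(ani_tools.convert_num_to_char(tokens, step_tokens))

ani_tools.generate_text_animation(frames, text_animation_path='./outputs/sampling_trajectory.gif')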
Stage3_source/cond_diff_transformer_layer.py ADDED
@@ -0,0 +1,260 @@
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ from axial_positional_embedding import AxialPositionalEmbedding
6
+ from linear_attention_transformer import LinearAttentionTransformer
7
+
8
+ #Adapted from ehoogeboom github repo ...
9
+
10
+ class SinusoidalPosEmb(nn.Module):
11
+
12
+ """
13
+ Time embeddings
14
+ """
15
+
16
+ def __init__(
17
+ self,
18
+ dim,
19
+ num_steps,
20
+ rescale_steps=4000
21
+ ):
22
+
23
+ super().__init__()
24
+
25
+ self.dim = dim
26
+ self.num_steps = float(num_steps)
27
+ self.rescale_steps = float(rescale_steps)
28
+
29
+
30
+ def forward(
31
+ self,
32
+ x
33
+ ):
34
+
35
+ x = x/self.num_steps * self.rescale_steps
36
+ device=x.device
37
+ half_dim = self.dim // 2
38
+ emb = math.log(10000) / (half_dim - 1)
39
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
40
+ emb = x[:,None] * emb[None,:]
41
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
42
+ return emb
43
+
44
+
45
+
46
+
47
+ class LinearAttentionTransformerEmbedding(nn.Module):
48
+
49
+ def __init__(
50
+ self,
51
+ args,
52
+ input_dim,
53
+ output_dim,
54
+ dim,
55
+ depth,
56
+ n_blocks,
57
+ max_seq_len,
58
+ num_timesteps,
59
+ heads=8,
60
+ dim_head=None,
61
+ causal=False,
62
+ reversible=False,
63
+ ff_chunks=1,
64
+ ff_glu=False,
65
+ ff_dropout=0.,
66
+ attn_layer_dropout=0.,
67
+ attn_dropout=0.,
68
+ blindspot_size=1,
69
+ n_local_attn_heads=0,
70
+ local_attn_window_size=128,
71
+ return_embeddings=False,
72
+ recieves_context=False,
73
+ pkm_layers=tuple(),
74
+ pkm_num_keys=128,
75
+ attend_axially=False,
76
+ linformer_settings=None,
77
+ context_linformer_settings=None
78
+ ):
79
+ assert (max_seq_len % local_attn_window_size) == 0, 'max sequence length must be divisible by the window size, to calculate number of kmeans cluster'
80
+ super().__init__()
81
+
82
+ self.max_seq_len = max_seq_len
83
+ self.depth = depth
84
+ self.emb_dim = dim
85
+ self.n_blocks = n_blocks
86
+
87
+
88
+ # token embeddings
89
+ self.x_emb_NN = nn.Embedding(input_dim, self.emb_dim)
90
+
91
+ # class label embedding
92
+ #self.class_emb_NN = nn.Embedding(args.num_y_class_labels, self.emb_dim)
93
+ self.y_mlp = nn.Sequential(
94
+ nn.Linear(args.text_emb_dim, self.emb_dim*4),
95
+ nn.Softplus(),
96
+ nn.Linear(self.emb_dim*4, self.emb_dim*n_blocks*depth)
97
+ )
98
+
99
+ # time embeddings
100
+ self.time_pos_emb = SinusoidalPosEmb(self.emb_dim, num_timesteps)
101
+ self.mlp = nn.Sequential(
102
+ nn.Linear(self.emb_dim, self.emb_dim*4),
103
+ nn.Softplus(),
104
+ nn.Linear(self.emb_dim*4, self.emb_dim*n_blocks*depth)
105
+ )
106
+
107
+ # token positional embeddings
108
+ self.axial_pos_emb = AxialPositionalEmbedding(
109
+ dim = self.emb_dim,
110
+ axial_shape=(
111
+ max_seq_len // local_attn_window_size,
112
+ local_attn_window_size)
113
+ )
114
+
115
+ self.transformer_blocks = torch.nn.ModuleList()
116
+
117
+ for ii in range(n_blocks):
118
+
119
+ self.transformer_blocks.append(torch.nn.ModuleList())
120
+
121
+ for jj in range(depth):
122
+
123
+ self.transformer_blocks[-1].append(
124
+ LinearAttentionTransformer(
125
+ self.emb_dim,
126
+ 1,
127
+ max_seq_len,
128
+ heads=heads,
129
+ dim_head=dim_head,
130
+ causal=causal,
131
+ ff_chunks=ff_chunks,
132
+ ff_glu=ff_glu,
133
+ ff_dropout=ff_dropout,
134
+ attn_layer_dropout=attn_layer_dropout,
135
+ reversible=reversible,
136
+ blindspot_size=blindspot_size,
137
+ n_local_attn_heads=n_local_attn_heads,
138
+ local_attn_window_size=local_attn_window_size,
139
+ attend_axially=attend_axially,
140
+ linformer_settings=linformer_settings,
141
+ context_linformer_settings=context_linformer_settings
142
+ )
143
+ )
144
+
145
+ self.norm = nn.LayerNorm(dim)
146
+ self.out = nn.Linear(self.emb_dim, output_dim) if not return_embeddings else nn.Identity()
147
+
148
+
149
+ def forward(self, x, t, y_c, **kwargs):
150
+
151
+ # time embeddings
152
+ t = self.time_pos_emb(t).type([p.dtype for p in self.mlp.parameters()][0])
153
+ t = self.mlp(t)
154
+ time_embed = t.reshape(x.size(0), 1, self.emb_dim, self.n_blocks, self.depth)
155
+ # token embeddings
156
+ x = self.x_emb_NN(x.long()) # final shape (batch_size, timelength, model_emb_dim)
157
+ # positional embeddings
158
+ x_pos = self.axial_pos_emb(x).type(x.type())
159
+ x_embed_axial = x + x_pos
160
+ h = torch.zeros_like(x_embed_axial)
161
+ # z_t embedding
162
+ #y_emb = self.class_emb_NN(y_c)
163
+ y_emb = self.y_mlp(y_c)
164
+ y_emb = y_emb.reshape(x.size(0), 1, self.emb_dim, self.n_blocks, self.depth)
165
+
166
+ for i, block in enumerate(self.transformer_blocks):
167
+
168
+ h = h+x_embed_axial
169
+ for j, transformer in enumerate(block):
170
+
171
+ h = transformer(h + time_embed[...,i,j] + y_emb[...,i,j])
172
+
173
+ h = self.norm(h)
174
+ output = self.out(h)
175
+
176
+ return output.permute(0,2,1)
177
+
178
+
179
+ def add_model_args(parser):
180
+
181
+ # Flow params
182
+ parser.add_argument('--num_steps', type=int, default=1)
183
+ parser.add_argument('--actnorm', type=eval, default=False)
184
+ parser.add_argument('--perm_channel', type=str, default='none', choices={'conv', 'shuffle', 'none'})
185
+ parser.add_argument('--perm_length', type=str, default='reverse', choices={'reverse', 'none'})
186
+ parser.add_argument('--input_dp_rate', type=float, default=0.0)
187
+
188
+ # Transformer params.
189
+ parser.add_argument('--transformer_dim', type=int, default=512)
190
+ parser.add_argument('--transformer_heads', type=int, default=16)
191
+ parser.add_argument('--transformer_depth', type=int, default=16)
192
+ parser.add_argument('--transformer_blocks', type=int, default=1)
193
+ parser.add_argument('--transformer_dropout', type=float, default=0.1)
194
+ parser.add_argument('--transformer_reversible', type=eval, default=False)
195
+ parser.add_argument('--transformer_local_heads', type=int, default=8)
196
+ parser.add_argument('--transformer_local_size', type=int, default=128)
197
+
198
+ def get_model(args, data_shape, num_classes):
199
+
200
+ data_shape = data_shape
201
+ num_classes = num_classes
202
+ input_dp_rate = args.input_dp_rate
203
+ transformer_dim = args.transformer_dim
204
+ transformer_heads = args.transformer_heads
205
+ transformer_depth = args.transformer_depth
206
+ transformer_blocks = args.transformer_blocks
207
+ transformer_local_heads = args.transformer_local_heads
208
+ transformer_local_size = args.transformer_local_size
209
+ transformer_reversible = args.transformer_reversible
210
+ diffusion_steps = args.diffusion_steps
211
+
212
+ C, _ = num_classes, data_shape[0]*data_shape[1]
213
+ L = args.diffusion_steps
214
+
215
+ print('Data shape index 0:', L)
216
+ current_shape = (L,)
217
+
218
+ class DiffTransformer(nn.Module):
219
+
220
+ def __init__(self,):
221
+
222
+ super(DiffTransformer, self).__init__()
223
+
224
+ self.transformer = LinearAttentionTransformerEmbedding(
225
+ args=args,
226
+ input_dim=num_classes,
227
+ output_dim=num_classes,
228
+ dim=transformer_dim,
229
+ heads=transformer_heads,
230
+ depth=transformer_depth,
231
+ n_blocks=transformer_blocks,
232
+ max_seq_len=L,
233
+ num_timesteps=diffusion_steps,
234
+ causal=False, # no autoregression
235
+ ff_dropout=0, # dropout for feedforward NN
236
+ attn_layer_dropout=input_dp_rate, # dropout right after self-att layer
237
+ attn_dropout=0, # dropout post-attention
238
+ n_local_attn_heads=transformer_local_heads,
239
+ # number of local attention heads for (QK)*V attention.
240
+ # this can be a tuple specifying the exact number of local
241
+ # attention heads at that depth
242
+ local_attn_window_size=transformer_local_size,
243
+ # receptive field of the local attention
244
+ reversible=transformer_reversible,
245
+ # use reversible nets, from reformer paper
246
+ )
247
+
248
+
249
+ def forward(self, x, t, y_c):
250
+ x = self.transformer(x,t,y_c)
251
+ return x
252
+
253
+
254
+ model = DiffTransformer()
255
+
256
+ return model
257
+
258
+
259
+
260
+
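A quick shape check of the conditional transformer, as a sketch with illustrative values (num_classes=28 matches the token list in Stage3_source/preprocess.py): the model maps partially masked token sequences, a sampling-step index, and a text-conditioning embedding to per-position logits of shape (batch, num_classes, seq_len).

import torch
from argparse import Namespace
import Stage3_source.cond_diff_transformer_layer as mod

# hypothetical configuration; diffusion_steps must be divisible by the local attention window
args = Namespace(diffusion_steps=256, input_dp_rate=0.0, text_emb_dim=512,
                 transformer_dim=256, transformer_heads=8, transformer_depth=2,
                 transformer_blocks=1, transformer_local_heads=4, transformer_local_size=128,
                 transformer_reversible=False)

model = mod.get_model(args, data_shape=(16, 16), num_classes=28)
x = torch.randint(0, 28, (2, 256))          # masked/partial token sequences
t = torch.randint(0, 256, (2,))             # current sampling step per sequence
y_c = torch.randn(2, 512)                   # text (conditioning) embedding
logits = model(x, t, y_c)
print(logits.shape)                         # expected: torch.Size([2, 28, 256])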
Stage3_source/diff_transformer_layer.py ADDED
@@ -0,0 +1,263 @@
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ from axial_positional_embedding import AxialPositionalEmbedding
6
+ from linear_attention_transformer import LinearAttentionTransformer
7
+
8
+ #Adapted from ehoogeboom github repo ...
9
+
10
+ class SinusoidalPosEmb(nn.Module):
11
+
12
+ """
13
+ Time embeddings
14
+ """
15
+
16
+ def __init__(
17
+ self,
18
+ dim,
19
+ num_steps,
20
+ rescale_steps=4000
21
+ ):
22
+
23
+ super().__init__()
24
+
25
+ self.dim = dim
26
+ self.num_steps = float(num_steps)
27
+ self.rescale_steps = float(rescale_steps)
28
+
29
+
30
+ def forward(
31
+ self,
32
+ x
33
+ ):
34
+
35
+ x = x/self.num_steps * self.rescale_steps
36
+ device=x.device
37
+ half_dim = self.dim // 2
38
+ emb = math.log(10000) / (half_dim - 1)
39
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
40
+ emb = x[:,None] * emb[None,:]
41
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
42
+ return emb
43
+
44
+
45
+
46
+
47
+ class LinearAttentionTransformerEmbedding(nn.Module):
48
+
49
+ def __init__(
50
+ self,
51
+ input_dim,
52
+ output_dim,
53
+ dim,
54
+ depth,
55
+ n_blocks,
56
+ max_seq_len,
57
+ num_timesteps,
58
+ heads=8,
59
+ dim_head=None,
60
+ causal=False,
61
+ reversible=False,
62
+ ff_chunks=1,
63
+ ff_glu=False,
64
+ ff_dropout=0.,
65
+ attn_layer_dropout=0.,
66
+ attn_dropout=0.,
67
+ blindspot_size=1,
68
+ n_local_attn_heads=0,
69
+ local_attn_window_size=128,
70
+ return_embeddings=False,
71
+ recieves_context=False,
72
+ pkm_layers=tuple(),
73
+ pkm_num_keys=128,
74
+ attend_axially=False,
75
+ linformer_settings=None,
76
+ context_linformer_settings=None
77
+ ):
78
+ assert (max_seq_len % local_attn_window_size) == 0, 'max sequence length must be divisible by the window size, to calculate number of kmeans cluster'
79
+ super().__init__()
80
+
81
+ self.max_seq_len = max_seq_len
82
+ self.depth = depth
83
+ self.emb_dim = dim
84
+ self.n_blocks = n_blocks
85
+
86
+ print('Input dimension', input_dim)
87
+ print('Output dimension', output_dim)
88
+
89
+ # token embeddings
90
+ self.x_emb_NN = nn.Embedding(input_dim, self.emb_dim)
91
+
92
+ # time embeddings
93
+ self.time_pos_emb = SinusoidalPosEmb(self.emb_dim, num_timesteps)
94
+ self.mlp = nn.Sequential(
95
+ nn.Linear(self.emb_dim, self.emb_dim*4),
96
+ nn.Softplus(),
97
+ nn.Linear(self.emb_dim*4, self.emb_dim*n_blocks*depth)
98
+ )
99
+
100
+ # token positional embeddings
101
+ self.axial_pos_emb = AxialPositionalEmbedding(
102
+ dim = self.emb_dim,
103
+ axial_shape=(
104
+ max_seq_len // local_attn_window_size,
105
+ local_attn_window_size)
106
+ )
107
+
108
+ self.pos_emb = nn.Embedding(1, self.emb_dim)
109
+
110
+ self.transformer_blocks = torch.nn.ModuleList()
111
+
112
+ for ii in range(n_blocks):
113
+
114
+ self.transformer_blocks.append(torch.nn.ModuleList())
115
+
116
+ for jj in range(depth):
117
+
118
+ self.transformer_blocks[-1].append(
119
+ LinearAttentionTransformer(
120
+ self.emb_dim,
121
+ 1,
122
+ max_seq_len,
123
+ heads=heads,
124
+ dim_head=dim_head,
125
+ causal=causal,
126
+ ff_chunks=ff_chunks,
127
+ ff_glu=ff_glu,
128
+ ff_dropout=ff_dropout,
129
+ attn_layer_dropout=attn_layer_dropout,
130
+ reversible=reversible,
131
+ blindspot_size=blindspot_size,
132
+ n_local_attn_heads=n_local_attn_heads,
133
+ local_attn_window_size=local_attn_window_size,
134
+ attend_axially=attend_axially,
135
+ linformer_settings=linformer_settings,
136
+ context_linformer_settings=context_linformer_settings
137
+ )
138
+ )
139
+
140
+ self.norm = nn.LayerNorm(dim)
141
+ self.out = nn.Linear(self.emb_dim, output_dim) if not return_embeddings else nn.Identity()
142
+ # self.out = nn.Conv1d(self.emb_dim, output_dim, kernel_size=1,stride=1)
143
+
144
+
145
+ def forward(self, x, t, **kwargs):
146
+
147
+
148
+ t = self.time_pos_emb(t)
149
+ t = self.mlp(t)
150
+
151
+ time_embed = t.reshape(x.size(0), 1, self.emb_dim, self.n_blocks, self.depth)
152
+ x = self.x_emb_NN(x.long()) # final shape (batch_size, timelength, model_emb_dim)
153
+ x_pos = self.axial_pos_emb(x).type(x.type())
154
+ # x_pos = self.pos_emb( self._create_pos_vec(x=x)).type(x.type())
155
+ x_embed_axial = x + x_pos
156
+ h = torch.zeros_like(x_embed_axial)
157
+
158
+ for i, block in enumerate(self.transformer_blocks):
159
+
160
+ h = h+x_embed_axial
161
+ for j, transformer in enumerate(block):
162
+
163
+ h = transformer(h+time_embed[...,i,j])
164
+
165
+ h = self.norm(h)
166
+ output = self.out(h)
167
+
168
+ return output.permute(0,2,1)
169
+
170
+ class Rezero(nn.Module):
171
+
172
+ def __init__(self):
173
+ super(Rezero, self).__init__()
174
+ self.alpha = torch.nn.Parameter(torch.zeros(size=(1,)))
175
+
176
+ def forward(self,x):
177
+ return self.alpha * x
178
+
179
+
180
+ def add_model_args(parser):
181
+
182
+ # Flow params
183
+ parser.add_argument('--num_steps', type=int, default=1)
184
+ parser.add_argument('--actnorm', type=eval, default=False)
185
+ parser.add_argument('--perm_channel', type=str, default='none', choices={'conv', 'shuffle', 'none'})
186
+ parser.add_argument('--perm_length', type=str, default='reverse', choices={'reverse', 'none'})
187
+
188
+ parser.add_argument('--input_dp_rate', type=float, default=0.0)
189
+
190
+ # Transformer params.
191
+ parser.add_argument('--transformer_dim', type=int, default=512)
192
+ parser.add_argument('--transformer_heads', type=int, default=16)
193
+ parser.add_argument('--transformer_depth', type=int, default=16)
194
+ parser.add_argument('--transformer_blocks', type=int, default=1)
195
+ parser.add_argument('--transformer_dropout', type=float, default=0.1)
196
+ parser.add_argument('--transformer_reversible', type=eval, default=False)
197
+ parser.add_argument('--transformer_local_heads', type=int, default=8)
198
+ parser.add_argument('--transformer_local_size', type=int, default=128)
199
+
200
+ def get_model(args, data_shape, num_classes):
201
+
202
+ data_shape = data_shape
203
+ num_classes = num_classes
204
+ input_dp_rate = args.input_dp_rate
205
+ transformer_dim = args.transformer_dim
206
+ transformer_heads = args.transformer_heads
207
+ transformer_depth = args.transformer_depth
208
+ transformer_blocks = args.transformer_blocks
209
+ transformer_local_heads = args.transformer_local_heads
210
+ transformer_local_size = args.transformer_local_size
211
+ transformer_reversible = args.transformer_reversible
212
+ diffusion_steps = args.diffusion_steps
213
+
214
+ C, L = num_classes, data_shape[0]*data_shape[1]
215
+
216
+ print('Data shape index 0:', L)
217
+ current_shape = (L,)
218
+
219
+ class DiffTransformer(nn.Module):
220
+
221
+ def __init__(self,):
222
+
223
+ super(DiffTransformer, self).__init__()
224
+
225
+ self.transformer = LinearAttentionTransformerEmbedding(
226
+ input_dim=num_classes,
227
+ output_dim=num_classes,
228
+ dim=transformer_dim,
229
+ heads=transformer_heads,
230
+ depth=transformer_depth,
231
+ n_blocks=transformer_blocks,
232
+ max_seq_len=L,
233
+ num_timesteps=diffusion_steps,
234
+ causal=False, # no autoregression
235
+ ff_dropout=0, # dropout for feedforward NN
236
+ attn_layer_dropout=input_dp_rate, # dropout right after self-att layer
237
+ attn_dropout=0, # dropout post-attention
238
+ n_local_attn_heads=transformer_local_heads,
239
+ # number of local attention heads for (QK)*V attention.
240
+ # this can be a tuple specifying the exact number of local
241
+ # attention heads at that depth
242
+ local_attn_window_size=transformer_local_size,
243
+ # receptive field of the local attention
244
+ reversible=transformer_reversible,
245
+ # use reversible nets, from reformer paper
246
+ )
247
+
248
+ self.rezero = Rezero()
249
+
250
+ def forward(self, x, t):
251
+ x = self.transformer(x,t)
252
+ # x = x.permute(0,2,1)
253
+ # x = self.rezero(x)
254
+ return x
255
+
256
+
257
+ model = DiffTransformer()
258
+
259
+ return model
260
+
261
+
262
+
263
+
Stage3_source/eval_metrics.py ADDED
@@ -0,0 +1,412 @@
1
+ """
2
+ description:
3
+ metrics to compute model performance
4
+ """
5
+
6
+ import Bio
7
+ from Bio.Align import substitution_matrices
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+ import torch
11
+ import re
12
+
13
+ import Stage3_source.animation_tools as ani_tools
14
+
15
+
16
+ ' compute Blosum62 soft accuracy '
17
+ class blosum_soft_accuracy:
18
+
19
+ def __init__(self, ):
20
+
21
+ self.blosum62 = substitution_matrices.load("BLOSUM62")
22
+ self.alphabet = self.blosum62.alphabet
23
+
24
+ def blosum_acc(
25
+ self,
26
+ aa1: str,
27
+ aa2: str
28
+ ) -> np.single:
29
+
30
+ row = self.blosum62.alphabet.index(aa1)
31
+ col = self.blosum62.alphabet.index(aa2)
32
+ substitution_scores = self.blosum62[row, :].values()
33
+
34
+ # Apply the softmax function to the substitution scores to get a prob dist.
35
+ probs = np.exp(substitution_scores)/np.sum(np.exp(substitution_scores))
36
+
37
+ # compute the soft acc. as the dot product of the prob dist. with a one-hot encoding
38
+ # of the amino acid ...
39
+ correct_aa = aa2
40
+ correct_index = self.alphabet.index(correct_aa)
41
+ one_hot = np.zeros_like(probs)
42
+ one_hot[correct_index] = 1
43
+
44
+ # normalize acc.
45
+ soft_acc = np.dot(probs, one_hot) / np.max(probs)
46
+
47
+ return soft_acc
48
+
49
+ def split_seq(self, seq: str) ->list:
50
+ # no_pads = seq.count("<PAD>")
51
+ # split_seq = ["<START>"] + list(seq.replace("<START>","").replace("<END>","").replace("<PAD>","")) + ["<END>"] + ["<PAD>"] * no_pads
52
+ split_seq = re.split(r'(-|<START>|<END>|<PAD>|(?<=\w)(?=\w))', seq)
53
+ #split_seq = re.findall(r'<START>|<END>|<PAD>|[A-Z]|-|\*', seq)
54
+
55
+ # remove empty strings and whitespace-only elements
56
+ split_seq = [char for char in split_seq if char and char.strip()]
57
+ return split_seq
58
+
59
+
60
+
61
+ def compute_soft_accuracy(
62
+ self,
63
+ seq1_list: list,
64
+ seq2_list: list
65
+ ) -> float:
66
+
67
+ # make sure batch size matches
68
+ if len(seq1_list) == len(seq2_list):
69
+ self.batch_size = len(seq1_list)
70
+
71
+ else:
72
+ print("Please make sequence batch size equivalent...")
73
+
74
+ # make sure sequence length matches
75
+ if len(seq1_list[0]) == len(seq2_list[0]):
76
+ self.L = len(seq1_list[0])
77
+
78
+ else:
79
+ #print("Please make sequence length match...")
80
+ pass
81
+
82
+ avg_soft_acc_per_batch = 0
83
+ # loop over the batch of sequence
84
+ for seq1, seq2 in zip(seq1_list, seq2_list):
85
+
86
+ # split sequence into individual tokens
87
+ seq1 = self.split_seq(seq1)
88
+ seq2 = self.split_seq(seq2)
89
+ # set number of positions
90
+ self.L = len(seq2)
91
+ self.L_h = 0
92
+ self.L_s = 0
93
+ avg_soft_acc_per_seq = 0
94
+ avg_hard_acc_per_seq = 0
95
+
96
+ # loop over the amino acid positions
97
+ for aa1, aa2 in zip(seq1, seq2):
98
+
99
+ if (aa1 not in ['-', '<START>', '<END>', '<PAD>']) and (aa2 not in ['-', '<START>', '<END>', '<PAD>']):
100
+ self.L_s += 1
101
+ soft_acc = self.blosum_acc(aa1=aa1, aa2=aa2)
102
+ avg_soft_acc_per_seq += soft_acc
103
+ else:
104
+ self.L_h += 1
105
+ acc = 1*(aa1==aa2)
106
+ avg_hard_acc_per_seq += acc
107
+
108
+ # compute accuracy for soft positions
109
+ try:
110
+ avg_soft_acc_per_seq *= 1/self.L_s
111
+ except ZeroDivisionError:
112
+ #print("L_s cannot be zero. Setting avg_soft_acc_per_seq to zero.")
113
+ avg_soft_acc_per_seq = 0
114
+
115
+ # compute accuracy for hard positions
116
+ try:
117
+ avg_hard_acc_per_seq *= 1/self.L_h
118
+ except ZeroDivisionError:
119
+ #print("L_h cannot be zero. Setting avg_hard_acc_per_seq to zero.")
120
+ avg_hard_acc_per_seq = 0
121
+
122
+
123
+ # compute the average accuracy between soft and hard
124
+ if self.L_s == 0:
125
+ avg_soft_acc_per_batch += avg_hard_acc_per_seq
126
+ elif self.L_h == 0:
127
+ avg_soft_acc_per_batch += avg_soft_acc_per_seq
128
+ else:
129
+ avg_soft_acc_per_batch += (avg_soft_acc_per_seq + avg_hard_acc_per_seq)/2
130
+
131
+ avg_soft_acc_per_batch *= 1/self.batch_size
132
+ return avg_soft_acc_per_batch
133
+
134
+
135
+ def compute_ppl(probs: torch.Tensor) -> float:
136
+
137
+ batch_size, sequence_length, class_labels = probs.shape
138
+
139
+ # flatten batch and sequence dimensions into a single dimension
140
+ flattened_probs = probs.reshape(batch_size * sequence_length, class_labels)
141
+
142
+ # calc. perplexity for each sequence independently
143
+ ppl = []
144
+ for i in range(batch_size * sequence_length):
145
+ sequence_probs = flattened_probs[i]
146
+ # compute ppl per seq
147
+ sequence_ppl = torch.exp(-torch.sum(
148
+ sequence_probs * torch.log(sequence_probs)
149
+ )
150
+ )
151
+ ppl.append(sequence_ppl.item())
152
+
153
+ ppl = torch.tensor(ppl).view(batch_size, sequence_length) # ppl per sequence in a given batch
154
+ avg_ppl = ppl.mean().item() # average ppl per batch
155
+
156
+ return avg_ppl
157
+
158
+ def batch_compute_ppl(probs_list: list) -> float:
159
+
160
+ batch_prob = sum([
161
+ compute_ppl(probs=probs.unsqueeze(0).permute(0,2,1)) for probs in probs_list
162
+ ]) / len(probs_list)
163
+
164
+ return batch_prob
165
+
166
+
167
+ def compute_hard_acc(
168
+ seq1: str,
169
+ seq2: str
170
+ ) -> float:
171
+
172
+
173
+ hard_acc = sum([aa1 == aa2 for (aa1 ,aa2) in zip(seq1, seq2) if aa2 != '<PAD>'])
174
+ valid_length = len([aa2 for aa2 in seq2 if aa2 != '<PAD>'])
175
+ if valid_length == 0:
176
+ return 1.0
177
+
178
+ hard_acc /= valid_length
179
+
180
+ return hard_acc
181
+
182
+ #def compute_hard_acc(
183
+ # seq1: str,
184
+ # seq2: str
185
+ # ) -> float:
186
+ #
187
+ # hard_acc = sum([aa1 == aa2 for (aa1 ,aa2) in zip(seq1, seq2)])
188
+ # hard_acc *= 1/len(seq2)
189
+ # return hard_acc
190
+
191
+ def batch_hard_acc(seq1_list: list, seq2_list: list) -> float:
192
+
193
+ hard_acc = sum([
194
+ compute_hard_acc(seq1=seq1, seq2=seq2) for (seq1,seq2) in zip(seq1_list, seq2_list)
195
+ ]) / len(seq2_list)
196
+
197
+ return hard_acc
198
+
199
+
200
+ def time_split_on_seq(
201
+ seq: torch.Tensor,
202
+ sample_seq_path: torch.Tensor,
203
+ idx: torch.Tensor
204
+ ) -> (
205
+ list,
206
+ list,
207
+ list
208
+ ):
209
+
210
+
211
+ if len(seq.shape) != 2:
212
+ batch_size, class_labels, _ = seq.shape
213
+
214
+ # collect list
215
+ current_seq, prev_seq, fut_seq = [], [], []
216
+
217
+ for ii in range(batch_size):
218
+ current_stack_probs, prev_stack_probs, fut_stack_probs = [], [], []
219
+
220
+ for jj in range(class_labels):
221
+
222
+ # current probs
223
+ current_stack_probs.append(
224
+ seq[ii,jj][
225
+ (sample_seq_path.cpu()[ii] == idx.cpu()[ii])
226
+ ]
227
+ )
228
+
229
+ # prev probs
230
+ prev_stack_probs.append(
231
+ seq[ii,jj][
232
+ (sample_seq_path.cpu()[ii] < idx.cpu()[ii])
233
+ ]
234
+ )
235
+
236
+ # future probs
237
+ fut_stack_probs.append(
238
+ seq[ii,jj][
239
+ (sample_seq_path.cpu()[ii] > idx.cpu()[ii])
240
+ ]
241
+ )
242
+
243
+ current_seq.append(torch.stack(current_stack_probs))
244
+ prev_seq.append(torch.stack(prev_stack_probs))
245
+ fut_seq.append(torch.stack(fut_stack_probs))
246
+
247
+ else:
248
+ # split the sequences based on time indices
249
+ current_seq = [seq[ii][sample_seq_path[ii] == idx[ii]] for ii in range(seq.shape[0])]
250
+ prev_seq = [seq[ii][sample_seq_path[ii] < idx[ii]] for ii in range(seq.shape[0])]
251
+ fut_seq = [seq[ii][sample_seq_path[ii] > idx[ii]] for ii in range(seq.shape[0])]
252
+
253
+ return (
254
+ current_seq,
255
+ prev_seq,
256
+ fut_seq
257
+ )
258
+
259
+ @torch.no_grad()
260
+ def compute_acc_given_time_pos(
261
+ real_tokens: torch.Tensor,
262
+ sample_seq: torch.Tensor,
263
+ sample_path: torch.Tensor,
264
+ idx: torch.Tensor
265
+ ) -> (
266
+ float,
267
+ float,
268
+ float,
269
+ float,
270
+ float,
271
+ float
272
+ ):
273
+
274
+ # tokenizer
275
+ tokens = ['-', '<START>', 'A','C','D','E','F','G','H','I','K','L','M','N','P','Q','R','S','T','V','W','Y','<END>','<PAD>']
276
+ #tokens = ['<START>', 'A','C','D','E','F','G','H','I','K','L','M','N','P','Q','R','S','T','V','W','Y','<END>','<PAD>']
277
+ tokens = tokens + ['X', 'U', 'Z', 'B', 'O']
278
+
279
+
280
+ # split real tokens based on time indices
281
+ current_real_tokens, prev_real_tokens, fut_real_tokens = time_split_on_seq(
282
+ seq=real_tokens.cpu(),
283
+ sample_seq_path=sample_path.cpu(),
284
+ idx=idx.cpu()
285
+ )
286
+
287
+ # split sampled tokens based on time indices
288
+ current_sample_tokens, prev_sample_tokens, fut_sample_tokens = time_split_on_seq(
289
+ seq=sample_seq.cpu(),
290
+ sample_seq_path=sample_path.cpu(),
291
+ idx=idx.cpu()
292
+ )
293
+
294
+ # convert real sequences to characters
295
+ current_real_chars = [ani_tools.convert_num_to_char(tokens,seq_tokens) for seq_tokens in current_real_tokens]
296
+ prev_real_chars = [ani_tools.convert_num_to_char(tokens,seq_tokens) for seq_tokens in prev_real_tokens]
297
+ fut_real_chars = [ani_tools.convert_num_to_char(tokens,seq_tokens) for seq_tokens in fut_real_tokens]
298
+
299
+ # convert sample sequences to characters
300
+ current_sample_chars = [ani_tools.convert_num_to_char(tokens,seq_tokens) for seq_tokens in current_sample_tokens]
301
+ prev_sample_chars = [ani_tools.convert_num_to_char(tokens,seq_tokens) for seq_tokens in prev_sample_tokens]
302
+ fut_sample_chars = [ani_tools.convert_num_to_char(tokens,seq_tokens) for seq_tokens in fut_sample_tokens]
303
+
304
+
305
+
306
+ # drop empty entries in list (happens if t=0 or t=256)
307
+ # prev string sequences
308
+ prev_sample_chars = [item for item in prev_sample_chars if item]
309
+ prev_real_chars = [item for item in prev_real_chars if item]
310
+ # fut string sequences
311
+ fut_real_chars = [item for item in fut_real_chars if item]
312
+ fut_sample_chars = [item for item in fut_sample_chars if item]
313
+
314
+ # class object to compute blosum62 soft acc.
315
+ soft_acc_tool = blosum_soft_accuracy()
316
+
317
+ # split real sequence
318
+ prev_real_split_chars = [
319
+ soft_acc_tool.split_seq(sample) for sample in prev_real_chars
320
+ ]
321
+ fut_real_split_chars = [
322
+ soft_acc_tool.split_seq(sample) for sample in fut_real_chars
323
+ ]
324
+
325
+ # split sample sequence
326
+ prev_sample_split_chars = [
327
+ soft_acc_tool.split_seq(sample) for sample in prev_sample_chars
328
+ ]
329
+ fut_sample_split_chars = [
330
+ soft_acc_tool.split_seq(sample) for sample in fut_sample_chars
331
+ ]
332
+
333
+ # compute hard and soft accuracy
334
+ ' soft accuracy: '
335
+ # positions < t ( aa positions)
336
+ #prev_batch_soft_acc = soft_acc_tool.compute_soft_accuracy(
337
+ # seq1_list=prev_sample_chars,
338
+ # seq2_list=prev_real_chars
339
+ #)
340
+
341
+ # positions > t ( aa positions)
342
+ #fut_batch_soft_acc = soft_acc_tool.compute_soft_accuracy(
343
+ # seq1_list=fut_sample_chars,
344
+ # seq2_list=fut_real_chars
345
+ #)
346
+
347
+ # positions = t (aa positions)
348
+ #current_soft_acc = soft_acc_tool.compute_soft_accuracy(
349
+ #seq1_list=current_sample_chars,
350
+ #seq2_list=current_real_chars
351
+ #)
352
+
353
+ prev_batch_soft_acc, fut_batch_soft_acc, current_soft_acc = 0, 0, 0
354
+
355
+ ' hard accuracy: '
356
+ # positions < t ( aa positions)
357
+ prev_batch_hard_acc = batch_hard_acc(
358
+ seq1_list=prev_sample_split_chars,
359
+ seq2_list=prev_real_split_chars
360
+ )
361
+
362
+ # positions > t ( aa positions)
363
+ fut_batch_hard_acc = batch_hard_acc(
364
+ seq1_list=fut_sample_split_chars,
365
+ seq2_list=fut_real_split_chars
366
+ )
367
+
368
+ # positions = t (aa positions)
369
+ current_hard_acc = compute_hard_acc(
370
+ seq1=current_sample_chars,
371
+ seq2=current_real_chars
372
+ )
373
+
374
+ return (
375
+ prev_batch_hard_acc,
376
+ prev_batch_soft_acc,
377
+ fut_batch_hard_acc,
378
+ fut_batch_soft_acc,
379
+ current_hard_acc,
380
+ current_soft_acc
381
+ )
382
+
383
+
384
+ @torch.no_grad()
385
+ def compute_ppl_given_time_pos(
386
+ probs: torch.Tensor,
387
+ sample_path: torch.Tensor,
388
+ idx: torch.Tensor
389
+ ) -> (
390
+ float,
391
+ float,
392
+ float
393
+ ):
394
+
395
+ current_probs, prev_probs, fut_probs = time_split_on_seq(
396
+ probs.cpu(),
397
+ sample_seq_path=sample_path.cpu(),
398
+ idx=idx.cpu()
399
+ )
400
+
401
+ # ppl at the current time position (aa_i = t)
402
+ # current_ppl = compute_ppl(probs=torch.stack(current_probs).permute(0,2,1))
403
+ current_ppl = batch_compute_ppl(probs_list=current_probs)
404
+ # ppl at the prev and fut time positions (aa_i < t and aa_i > t)
405
+ prev_ppl = batch_compute_ppl(probs_list=prev_probs)
406
+ fut_ppl = batch_compute_ppl(probs_list=fut_probs)
407
+
408
+ return (
409
+ current_ppl,
410
+ prev_ppl,
411
+ fut_ppl
412
+ )
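As a sanity check, the accuracy and perplexity helpers can be exercised directly on toy inputs (a sketch; the sequences and probabilities below are made up):

import torch
from Stage3_source.eval_metrics import compute_hard_acc, compute_ppl

# hard accuracy over aligned token lists (positions where the reference is '<PAD>' are skipped)
print(compute_hard_acc(seq1=['A', 'C', 'D'], seq2=['A', 'C', 'D']))   # 1.0
print(compute_hard_acc(seq1=['A', 'C', 'E'], seq2=['A', 'C', 'D']))   # ~0.667

# per-position perplexity of a predicted distribution over 4 classes at 3 positions
probs = torch.softmax(torch.randn(2, 3, 4), dim=-1)   # (batch, positions, classes)
print(compute_ppl(probs))                             # between 1.0 (certain) and 4.0 (uniform)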
Stage3_source/helper_funcs.py ADDED
@@ -0,0 +1,32 @@
+ from pynvml import *
+
+
+ """
+ To track memory allocation, let's take advantage of the nvidia-ml-py3 package to query GPU memory allocation from Python.
+
+ ref: https://huggingface.co/docs/transformers/v4.20.1/en/perf_train_gpu_one
+ """
+
+
+ def print_gpu_initialization():
+     nvmlInit()
+     handle = nvmlDeviceGetHandleByIndex(0)
+     info = nvmlDeviceGetMemoryInfo(handle)
+     print(f"GPU memory occupied: {info.used//1024**2} MB.")
+     return info.used // 1024**2
+
+
+ def print_summary(result):
+     print(f"Time: {result.metrics['train_runtime']:.2f}")
+     print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
+     print_gpu_initialization()
Stage3_source/preprocess.py ADDED
@@ -0,0 +1,200 @@
1
+
2
+
3
+ import torch
4
+ import torch.optim as optim
5
+ from torch.utils.data import DataLoader, Dataset
6
+ from torchvision.datasets import MNIST
7
+ from torchvision.transforms import Compose, ToTensor, Resize
8
+ import torchvision.transforms as T
9
+
10
+
11
+ #from numba import jit
12
+ import numpy as np
13
+ import pandas as pd
14
+
15
+
16
+ def get_mnist_dataset(args:any) -> DataLoader:
17
+
18
+
19
+ if args.dataset == 'normal':
20
+
21
+ print(args.download)
22
+ transform = Compose([ToTensor(), Resize(args.image_size), lambda x: x > 0.5])
23
+ train_dataset = MNIST(root=args.data_root, download=True, transform=transform, train=True)
24
+ train_dataloader = DataLoader(
25
+ train_dataset,
26
+ num_workers=args.workers,
27
+ batch_size=args.batch_size,
28
+ shuffle=True,
29
+ pin_memory=True,
30
+ drop_last=True
31
+ )
32
+
33
+ elif args.dataset == 'sequence':
34
+
35
+ transform = Compose([ToTensor(), Resize(args.image_size), lambda x: x > 0.5, T.Lambda(lambda x: torch.flatten(x).unsqueeze(0))])
36
+ train_dataset = MNIST(root=args.data_root, download=True, transform=transform, train=True)
37
+ train_dataloader = DataLoader(
38
+ train_dataset,
39
+ num_workers=args.workers,
40
+ batch_size=args.batch_size,
41
+ shuffle=True,
42
+ pin_memory=True,
43
+ drop_last=True
44
+ )
45
+
46
+ else:
47
+ print('Please picker either normal or sequence')
48
+ quit()
49
+
50
+ return train_dataloader
51
+
52
+
53
+
54
+
55
+ ' Protein preprocessing tools '
56
+
57
+ #@jit(nopython=True)
58
+ def pad_ends(
59
+ seqs: list,
60
+ max_seq_length: int
61
+ ) -> list:
62
+
63
+ padded_seqs = [] # add padded gaps at the end of each sequence
64
+ for seq in seqs:
65
+
66
+ seq_length = len(seq)
67
+ # number of padded tokens
68
+ pad_need = max_seq_length - seq_length
69
+ # add number of padded tokens to the end
70
+ seq += '-'*pad_need
71
+
72
+ padded_seqs.append(seq)
73
+
74
+ return padded_seqs
75
+
76
+
77
+ # create numerical represented sqeuences
78
+ def create_num_seqs(seq_list: list) -> list:
79
+
80
+ # tokenizer
81
+ #tokens = ['*', '<START>', 'A','C','D','E','F','G','H','I','K','L','M','N','P','Q','R','S','T','V','W','Y','<END>', '-']
82
+ tokens = [ '<START>', 'A','C','D','E','F','G','H','I','K','L','M','N','P','Q','R','S','T','V','W','Y','<END>', '-']
83
+ # needed to lose these to the token list
84
+ tokens = tokens + ['X', 'U', 'Z', 'B', 'O']
85
+ token2int = {x:ii for ii, x in enumerate(tokens)}
86
+
87
+ # empty list to hold num rep. seqs.
88
+ num_seq_list = []
89
+ for seq in seq_list:
90
+ num_seq_list.append([token2int[aa] for aa in seq])
91
+
92
+ return num_seq_list
93
+
94
+ # prepare the protein sequences
95
+ def prepare_protein_data(
96
+ args: any,
97
+ data_dict: dict
98
+ ) -> (
99
+ list,
100
+ list
101
+ ):
102
+
103
+ print([key for key in data_dict.keys()])
104
+
105
+ print('Prepare dataset')
106
+ # prepare sequences
107
+ seq_list = [seq.replace('-','') for seq in data_dict[args.sequence_keyname]]
108
+ seq_list = [['<START>'] + list(seq) + ['<END>'] for seq in seq_list]
109
+ seq_lens = [len(seq) for seq in seq_list]
110
+
111
+ # Determine the maximum sequence length based on context window size
112
+ max_seq_len = int(args.diffusion_steps)
113
+
114
+ # Get indices of sequences that meet the criteria
115
+ valid_indices = [i for i, seq in enumerate(seq_list) if len(seq) <= max_seq_len]
116
+
117
+ # Filter num_seq_list based on these indices
118
+ filter_seq_list = [seq_list[i] for i in valid_indices]
119
+
120
+ max_seq_len = int(args.image_size * args.image_size)
121
+ padded_seq_list = pad_ends(
122
+ seqs=filter_seq_list,
123
+ max_seq_length=max_seq_len
124
+ )
125
+ num_seq_list = create_num_seqs(padded_seq_list) # numerical representations
126
+
127
+ # prepare class labels
128
+ #class_label_list = df.label.values.tolist()
129
+ if args.facilitator in ['MSE', 'MMD']:
130
+ text_emb = data_dict['text_to_protein_embedding']
131
+ elif args.facilitator in ['Default']:
132
+ text_emb = data_dict['text_embedding']
133
+ else:
134
+ raise ValueError(f"Unexpected value for 'facilitator': {args.facilitator}")
135
+
136
+ text_emb = [text_emb[i] for i in valid_indices]
137
+ # prune sequence and texts out based on length
138
+
139
+ print('Finished preparing dataset')
140
+ #
141
+
142
+
143
+ return (
144
+ num_seq_list,
145
+ text_emb
146
+ )
147
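+ # Hedged sketch of the inputs prepare_protein_data expects; the args fields and
+ # dictionary keys mirror the ones referenced above, but the values are toy
+ # placeholders (a real run would pass the Stage 2 output dictionary).
+ def _example_prepare_protein_data():
+     from argparse import Namespace
+     toy_args = Namespace(sequence_keyname='sequence', diffusion_steps=1024,
+                          image_size=32, facilitator='MMD')
+     toy_data = {
+         'sequence': ['ACDW', 'WYC'],
+         'text_to_protein_embedding': [torch.zeros(512), torch.zeros(512)],
+     }
+     return prepare_protein_data(args=toy_args, data_dict=toy_data)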
+
148
+
149
+ class protein_dataset(Dataset):
150
+ """
151
+
152
+ Sequence dataloader
153
+
154
+ """
155
+
156
+ def __init__(
157
+ self,
158
+ num_seq_list: list,
159
+ text_emb: torch.Tensor
160
+ ):
161
+
162
+ if not torch.is_tensor(num_seq_list):
163
+ self.num_seqs = torch.tensor(num_seq_list).float()
164
+
165
+ else:
166
+ self.num_seqs = num_seq_list.float() # input is already a tensor
167
+
168
+ self.text_emb = text_emb
169
+
170
+ #if not torch.is_tensor(class_label_list):
171
+ # self.class_label = torch.tensor(class_label_list).float()
172
+
173
+ def __len__(self):
174
+ """
175
+ number of samples total
176
+ """
177
+ return len(self.num_seqs)
178
+
179
+ def __getitem__(self, idx: any) -> (
180
+ torch.FloatTensor,
181
+ torch.FloatTensor
182
+ ):
183
+
184
+ """
185
+ extract and return the data batch samples
186
+ """
187
+
188
+ # convert and return the data batch samples
189
+ if torch.is_tensor(idx):
190
+ idx = idx.tolist()
191
+
192
+ # sequences
193
+ num_seqs = self.num_seqs[idx]
194
+ # class labels
195
+ text_emb = self.text_emb[idx]
196
+
197
+ return (
198
+ num_seqs,
199
+ text_emb
200
+ )
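+
+ # Hedged usage sketch: wrapping the outputs of prepare_protein_data in the dataset
+ # and a dataloader; the batch size and the zero text embeddings are toy values.
+ def _example_protein_dataloader() -> DataLoader:
+     num_seqs = [[0, 1, 2, 21, 22, 22], [0, 19, 20, 2, 21, 22]]
+     text_emb = torch.zeros(2, 512)
+     dataset = protein_dataset(num_seq_list=num_seqs, text_emb=text_emb)
+     return DataLoader(dataset, batch_size=2, shuffle=True, drop_last=True)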
Stage3_source/sampling_analysis.py ADDED
@@ -0,0 +1,276 @@
1
+ import os
2
+ import numpy as np
3
+ import random
4
+ import pandas as pd
5
+ import math
6
+ from tqdm import tqdm
7
+ import time
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from torch.utils.data import DataLoader
13
+
14
+ import Stage3_source.preprocess as prep
15
+ import Stage3_source.cond_diff_transformer_layer as mod
16
+ import Stage3_source.transformer_training_helper as train_helper
17
+
18
+
19
+
20
+ # generate missing pixels with one shot
21
+ @torch.no_grad()
22
+ def cond_autocomplete_real_samples(
23
+ model: nn.Module,
24
+ args: any,
25
+ realization: torch.Tensor,
26
+ y_c: torch.Tensor,
27
+ idx: torch.Tensor
28
+ ) -> (
29
+ any,
30
+ torch.Tensor,
31
+ torch.Tensor,
32
+ torch.Tensor,
33
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor
34
+ ):
35
+
36
+ model.eval()
37
+ bs, channel, seq_length = realization.size()
38
+ # get a batch of random sampling paths
39
+ sampled_random_path = train_helper.sample_random_path(bs, seq_length, device=args.device)
40
+ # create a mask that masks the locations where we've already sampled
41
+ random_path_mask = train_helper.create_mask_at_random_path_index(sampled_random_path, idx, bs, seq_length)
42
+ # tokenize realizations
43
+ real_tokens, bs, seq_length= train_helper.create_token_labels(args, realization)
44
+ #real_tokens = realization.clone().squeeze(1)
45
+
46
+ # mask realizations
47
+ real_token_masked = train_helper.mask_realizations(real_tokens, random_path_mask)
48
+ # conditional probability
49
+ conditional_prob, probs = train_helper.cond_predict_conditional_prob(model, real_token_masked, y_c, idx, args)
50
+ # evaluate the value of the log probability for the given realization:
51
+ log_prob = train_helper.log_prob_of_realization(args, conditional_prob, real_tokens)
52
+
53
+ return (
54
+ conditional_prob,
55
+ probs.cpu(),
56
+ real_token_masked.cpu(),
57
+ real_tokens.cpu(),
58
+ log_prob.cpu(),
59
+ sampled_random_path.cpu(),
60
+ random_path_mask.cpu()
61
+ )
62
+
63
+
64
+ # get the label for the corresponding sequence in the dataloader
65
+ def extract_samples_with_labels(
66
+ dataloader: DataLoader,
67
+ target_labels: int,
68
+ total_num: int,
69
+ pad_included: bool=False
70
+ ) -> dict:
71
+
72
+ extracted_sampled = {
73
+ 'sample': [],
74
+ 'label': []
75
+ }
76
+
77
+ for data, labels in dataloader:
78
+ for i, label in enumerate(labels):
79
+
80
+ if label.item() == target_labels:
81
+
82
+ if pad_included:
83
+ pass
84
+ else:
85
+ data[i] += 1 # account for the absorbing state (i.e. make room)
86
+
87
+ extracted_sampled['sample'].append(data[i]) # account for absorbed state
88
+ extracted_sampled['label'].append(label)
89
+ if len(extracted_sampled['label']) == total_num:
90
+ return extracted_sampled
91
+
92
+ return extracted_sampled
93
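+ # Hedged sketch: pulling the first few samples with a given class label from a
+ # labeled dataloader (e.g. the MNIST loader built in Stage3_source.preprocess);
+ # the label value 3 and the count 4 are arbitrary toy choices.
+ def _example_extract_samples(dataloader: DataLoader) -> dict:
+     return extract_samples_with_labels(dataloader, target_labels=3, total_num=4)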
+
94
+
95
+ # mask a given percentage of the sample
96
+ def corrupt_samples(
97
+ args: any,
98
+ realization: torch.Tensor,
99
+ perc: float
100
+ ) -> torch.Tensor:
101
+
102
+ bs, channels, seq_length = realization.size()
103
+
104
+ # number of samples to corrupt (i.e. idx)
105
+ idx = (args.diffusion_steps * torch.Tensor([perc])).to(int).to(args.device)
106
+ # get a batch of random sampling paths
107
+ sampled_random_path = train_helper.sample_random_path(bs, seq_length, device=args.device)
108
+ # we create a mask that masks the locations where we've already sampled
109
+ random_path_mask = train_helper.create_mask_at_random_path_index(sampled_random_path, idx, bs, seq_length)
110
+ # tokenize realizations
111
+ real_tokens, bs, seq_length= train_helper.create_token_labels(args, realization)
112
+ # mask realizations
113
+ real_token_masked = train_helper.mask_realizations(real_tokens, random_path_mask)
114
+
115
+ return (
116
+ real_token_masked,
117
+ sampled_random_path,
118
+ idx
119
+ )
120
+
121
+ # inpaint missing regions by predicting the next position
122
+ @torch.no_grad()
123
+ def predict_next_index(
124
+ model: nn.Module,
125
+ args: any,
126
+ mask_realization: torch.Tensor,
127
+ y_c: torch.Tensor,
128
+ idx: torch.Tensor
129
+ ) -> (
130
+ any,
131
+ torch.Tensor
136
+ ):
137
+
138
+ model.eval()
139
+ bs, channel, seq_length = mask_realization.size()
140
+
141
+ # conditional prob
142
+ conditional_prob, probs = train_helper.cond_predict_conditional_prob(model, mask_realization.squeeze(1), y_c, idx, args)
143
+
144
+ return (
145
+ conditional_prob,
146
+ probs.cpu(),
147
+ )
148
+
149
+
150
+
151
+
152
+ def generate_denoised_sampled(
153
+ args: any,
154
+ model: nn.Module,
155
+ extract_digit_samples: torch.Tensor,
156
+ extract_time: torch.Tensor,
157
+ extract_digit_label: torch.Tensor,
158
+ sampling_path: torch.Tensor
159
+ ) -> (
160
+ list,
161
+ list
162
+ ):
163
+
164
+ mask_realization_list, time_idx_list = [], []
165
+
166
+ # prepare data
167
+ temp_y_c = extract_digit_label.to(args.device)
168
+ temp_mask_realization = extract_digit_samples.unsqueeze(1).long().to(args.device)
169
+ temp_idx = torch.Tensor([extract_time]).to(args.device).squeeze(0)
170
+ temp_sampling_path = sampling_path.to(args.device)
171
+
172
+ for ii in tqdm(range(int(temp_idx.item()), args.diffusion_steps)):
173
+
174
+ # where we need to sample next
175
+ current_location = temp_sampling_path == temp_idx
176
+ print(current_location.shape)
177
+
178
+ # make position prediction
179
+ conditional_prob, prob = predict_next_index(
180
+ model=model,
181
+ args=args,
182
+ mask_realization=temp_mask_realization,
183
+ y_c=temp_y_c,
184
+ idx=temp_idx
185
+ )
186
+
187
+ # get the label for the next token position
188
+ next_temp_realization = torch.argmax(
189
+ conditional_prob.sample(), dim=-1
190
+ )
191
+
192
+ temp_mask_realization[0, current_location] = next_temp_realization[current_location]
193
+ mask_realization_list.append(temp_mask_realization.cpu().numpy())
194
+ time_idx_list.append(temp_idx.cpu().numpy())
195
+ temp_idx+=1
196
+
197
+
198
+ return (
199
+ mask_realization_list,
200
+ time_idx_list
201
+ )
202
+
203
+
204
+ def batch_generate_denoised_sampled(
205
+ args: any,
206
+ model: nn.Module,
207
+ extract_digit_samples: torch.Tensor,
208
+ extract_time: torch.Tensor,
209
+ extract_digit_label: torch.Tensor,
210
+ sampling_path: torch.Tensor
211
+ ) -> (list, list):
212
+
213
+ # Ensure batch dimension consistency across input tensors
214
+ assert extract_digit_samples.size(0) == extract_digit_label.size(0) == sampling_path.size(0) == extract_time.size(0), "Mismatched batch dimensions"
215
+
216
+ batch_size = extract_digit_samples.size(0)
217
+ mask_realization_list, time_idx_list = [], []
218
+ print('batch_size:', batch_size)
219
+
220
+ # Prepare data
221
+ temp_y_c = extract_digit_label.to(args.device)
222
+ temp_mask_realization = extract_digit_samples.unsqueeze(1).long().to(args.device)
223
+ temp_idx = extract_time.unsqueeze(-1).to(args.device) # Adding an extra dimension for batch processing
224
+ temp_sampling_path = sampling_path.to(args.device)
225
+ print(f"Starting temp_idx: {temp_idx[0].item()}")
226
+
227
+ start_time_index = temp_idx[0].item() # assume all temp_idx is the same values
228
+ max_diffusion_step = args.diffusion_steps # max number of timesteps
229
+
230
+
231
+ for ii in tqdm(range(start_time_index, max_diffusion_step), initial=start_time_index, total=max_diffusion_step):
232
+
233
+ # Check if any temp_idx has reached or exceeded diffusion_steps
234
+ if torch.any(temp_idx >= args.diffusion_steps):
235
+ break
236
+
237
+ # Broadcast ii to match the batch size
238
+ current_ii = torch.full((batch_size,), ii, dtype=torch.long, device=args.device)
239
+
240
+ # Make position prediction
241
+ conditional_prob, prob = predict_next_index(
242
+ model=model,
243
+ args=args,
244
+ mask_realization=temp_mask_realization,
245
+ y_c=temp_y_c,
246
+ idx=temp_idx
247
+ )
248
+
249
+
250
+ # Get the label for the next token position
251
+ next_temp_realization = torch.argmax(conditional_prob.sample(), dim=-1)
252
+
253
+ # Update temp_mask_realization for each item in the batch
254
+ current_location = temp_sampling_path == temp_idx # Adding an extra dimension for comparison
255
+ current_location = torch.argmax(current_location.detach().cpu()*1, dim=-1)
256
+ temp_mask_realization[:, 0, current_location] = next_temp_realization[:,current_location]
257
+
258
+ # Append results for each item in the batch
259
+ mask_realization_list.append(temp_mask_realization.cpu().numpy())
260
+ time_idx_list.append(temp_idx.cpu().numpy())
261
+
262
+ # Increment temp_idx for the next iteration
263
+ temp_idx += 1
264
+
265
+ return mask_realization_list, time_idx_list
266
+
267
+
268
+
269
+ # convert sequence with numerical variables into character letters
270
+ def convert_num_to_chars(
271
+ tokenizer: any,
272
+ num_seq: list
273
+ ) -> list:
274
+
275
+ char_seq = [tokenizer[num] for num in num_seq]
276
+ return "".join(char_seq)
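+
+ # Hedged sketch: decoding a toy numeric sequence with a list-style tokenizer
+ # (index -> character); the short token list below is illustrative only.
+ def _example_convert_num_to_chars() -> str:
+     toy_tokenizer = ['-', '<START>', 'A', 'C', '<END>']
+     return convert_num_to_chars(toy_tokenizer, [1, 2, 3, 4])   # '<START>AC<END>'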
Stage3_source/transformer_sampling_helper.py ADDED
@@ -0,0 +1,12 @@
1
+ import itertools
2
+ from pathlib import Path
3
+ import numpy as np
4
+ from tqdm.auto import tqdm
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import torch.nn as nn
8
+ from torch.distributions import OneHotCategorical
9
+ from torch.distributions import Categorical
10
+
11
+
12
+
Stage3_source/transformer_training_helper.py ADDED
@@ -0,0 +1,557 @@
1
+ import itertools
2
+ from pathlib import Path
3
+ import numpy as np
4
+ from tqdm.auto import tqdm
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import torch.nn as nn
8
+ from torch.distributions import OneHotCategorical
9
+ from torch.distributions import Categorical
10
+
11
+ import Stage3_source.eval_metrics as eval_funcs
12
+
13
+ # functions adapted for token-based transformers instead of Unet images (hat tip to author: LukasMosser)
14
+
15
+ ' sample random paths '
16
+ def sample_random_path(
17
+ batch_size: int,
18
+ seq_length: int,
19
+ device: str='cuda'
20
+ ) -> torch.Tensor:
21
+
22
+ # create a batch of random sampling paths
23
+ random_paths = torch.stack(
24
+ [torch.randperm(seq_length, device=device) for _ in range(batch_size)],
25
+ axis=0
26
+ )
27
+ # sequential paths
28
+ #random_paths = torch.stack(
29
+ # [torch.arange(seq_length, device=device) for _ in range(batch_size)],
30
+ # axis=0
31
+ #)
32
+ return random_paths
33
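+ # Hedged sketch: a batch of two random generation orders over an 8-token
+ # sequence, drawn on CPU; each row is a permutation of 0..7.
+ def _example_sample_random_path() -> torch.Tensor:
+     return sample_random_path(batch_size=2, seq_length=8, device='cpu')   # shape (2, 8)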
+
34
+ ' create masks to indicate positions that we have sampled already '
35
+ def create_mask_at_random_path_index(
36
+ sample_random_path: torch.Tensor,
37
+ idx: any,
38
+ batch_size: int,
39
+ seq_length: int
40
+ ) -> torch.Tensor:
41
+
42
+ # create a mask that has 1s everywhere we've sampled and 0's everywhere else
43
+ mask = (sample_random_path < idx)
44
+ return mask
45
+
46
+ ' create a (batched) mask of where we are now sampling '
47
+ def create_sampling_location_mask(
48
+ sampled_random_path: torch.Tensor,
49
+ idx: any,
50
+ batch_size: int,
51
+ seq_length: int
52
+ ) -> torch.Tensor:
53
+
54
+ # create a binary mask that has 1 at the current location for us to sample
55
+ sampling_location_mask = (sampled_random_path == idx).long()
56
+ return sampling_location_mask
57
+
58
+ ' create masks to indicate positions beyond the current sampling position '
59
+ def create_mask_at_future_path_index(
60
+ sampled_random_path: torch.Tensor,
61
+ idx: any,
62
+ batch_size: int,
63
+ seq_length: int
64
+ ) -> torch.Tensor:
65
+
66
+ # create a mask that has 1s everywhere we have not yet sampled and
67
+ # 0's everywhere we previously and currently sampled
68
+ sampling_future_mask = (sampled_random_path > idx).long()
69
+ return sampling_future_mask
70
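+ # Hedged sketch: for one random path and a step index t, the three masks above
+ # split positions into already sampled (< t), currently sampled (== t), and
+ # future (> t); the batch size, length, and t values below are toy choices.
+ def _example_path_masks() -> torch.Tensor:
+     bs, seq_length = 2, 8
+     path = sample_random_path(bs, seq_length, device='cpu')
+     idx = torch.tensor([[3], [5]])   # current step per batch element
+     past = create_mask_at_random_path_index(path, idx, bs, seq_length)
+     current = create_sampling_location_mask(path, idx, bs, seq_length)
+     future = create_mask_at_future_path_index(path, idx, bs, seq_length)
+     return past.long() + current + future   # all ones: the masks partition positions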
+
71
+ ' sampling from the probability distribution '
72
+ def sample_from_conditional(conditional_prob: any) -> torch.Tensor:
73
+ # sample from the categorical dist.
74
+ return conditional_prob.sample().permute(0,2,1)
75
+
76
+ ' compute entropy of the model predicted probability distribution '
77
+ def compute_entropy(conditional_prob: any) -> torch.Tensor:
78
+ # we can directly compute the entropy of the categorical distribution
79
+ return conditional_prob.entropy()
80
+
81
+ ' sampling the time trajectory '
82
+ class exp_weight_time_sample:
83
+
84
+ def __init__(self, timesteps: int, decay_rate: float):
85
+
86
+ self.timesteps = timesteps
87
+ self.decay_rate = decay_rate
88
+ # compute the weight based on the exp function
89
+ self.weights = torch.tensor(
90
+ [torch.exp(-torch.tensor([i])*decay_rate) for i in range(self.timesteps)]
91
+ )
92
+
93
+ # normalize weights
94
+ self.weights /= self.weights.sum()
95
+
96
+ def sample(self, batch_size: int) -> torch.Tensor:
97
+ # generate random samples
98
+ samples = torch.multinomial(self.weights, batch_size, replacement=True)
99
+ return samples
100
+
101
+ def sample_random_index_for_sampling(
102
+ batch_size: int,
103
+ seq_length: int,
104
+ device: str='cuda',
105
+ option: str='random'
106
+ ) -> any:
107
+
108
+ if option == 'random':
109
+ # sample a random index where we want to sample next
110
+ idx = torch.randint(
111
+ low=0,
112
+ high=seq_length+1,
113
+ size=(batch_size,1),
114
+ device=device,
115
+ requires_grad=False
116
+ )
117
+
118
+ elif option == 'weighted':
119
+ time_sampler = exp_weight_time_sample(timesteps=seq_length+1, decay_rate=0.005)
120
+ # sample a weighted random index where we want to sample next
121
+ idx = time_sampler.sample(batch_size=batch_size).unsqueeze(1).to(device)
122
+
123
+ return idx
124
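+ # Hedged sketch: one random time index per batch element, drawn uniformly from
+ # [0, seq_length]; the returned tensor has shape (batch_size, 1).
+ def _example_sample_idx() -> torch.Tensor:
+     return sample_random_index_for_sampling(batch_size=4, seq_length=16, device='cpu', option='random')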
+
125
+ #' log probs from realization '
126
+ def log_prob_of_realization(
127
+ args: any,
128
+ conditional_prob: any,
129
+ real_tokens: torch.Tensor
130
+ ) -> torch.Tensor:
131
+ # compute the log-prob of a given realization
132
+ #log_prob = conditional_prob._categorical.log_prob(real_tokens.to(args.device))
133
+ log_prob = conditional_prob._categorical.log_prob(real_tokens)
134
+ # log_prob = conditional_prob.log_prob(real_tokens.to(args.device))
135
+ return log_prob
136
+
137
+
138
+ #' get the log probabilities of the unsampled locations '
139
+ #def log_prob_of_unsampled_locations(
140
+ # log_prob: torch.Tensor,
141
+ # token_mask: torch.Tensor,
142
+ # real_tokens: torch.Tensor
143
+ # ) -> torch.Tensor:
144
+ #
145
+ # # unsampled token positions (i.e. absorbing states)
146
+ # unsampled_mask = (token_mask == 0) * 1
147
+ # # non-padded tokens
148
+ # non_padded_mask = (real_tokens != 23) * 1
149
+ # # final mask is absorbing states that do not belong to padded tokens
150
+ # final_unsampled_mask = unsampled_mask & non_padded_mask
151
+ # # compute the total log prob of the unsampled locations, taking sum over log-probs
152
+ # log_prob_unsampled = ( final_unsampled_mask * log_prob)
153
+ # # sum log probs at absorbing positions
154
+ # summed_log_prob_unsampled = log_prob_unsampled.sum(1)
155
+ #
156
+ # return summed_log_prob_unsampled
157
+
158
+
159
+ ' get the log probabilities of the unsampled locations '
160
+ def log_prob_of_unsampled_locations(
161
+ log_prob: torch.Tensor,
162
+ token_mask: torch.Tensor
163
+ ) -> torch.Tensor:
164
+
165
+ # compute the total log prob of the unsampled locations, taking sum over log-probs
166
+ log_prob_unsampled = ((token_mask == 0)*1 * log_prob)
167
+
168
+ return log_prob_unsampled.sum(1)
169
+
170
+ ' weight the unsampled log probs '
171
+ def weight_log_prob(
172
+ log_prob_unsampled: torch.Tensor,
173
+ idx: any,
174
+ seq_length
175
+ ) -> torch.Tensor:
176
+ # compute the average log-prob over the unsampled locations
177
+ log_prob_weighted = 1/(seq_length - idx.squeeze(1) + 1) * log_prob_unsampled
178
+ return log_prob_weighted
179
+
180
+ ' get mean log prob over the batch '
181
+ def compute_average_loss_for_batch(log_prob_weighted: torch.Tensor) -> torch.Tensor:
182
+ # compute a (negative) average over the batch elements to obtain an unbiased estimator of the loss
183
+ loss = -log_prob_weighted.mean()
184
+ return loss
185
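+ # Hedged sketch of the loss pipeline above: sum the log-probs over positions that
+ # are still masked, rescale by 1/(seq_length - t + 1), and negate the batch mean;
+ # the log-probs, mask, and time indices below are toy values.
+ def _example_unsampled_loss() -> torch.Tensor:
+     log_prob = torch.randn(2, 8)            # per-position log-probabilities
+     token_mask = torch.zeros(2, 8).long()   # 0 = still masked / unsampled
+     idx = torch.tensor([[3], [5]])          # current step t per batch element
+     unsampled = log_prob_of_unsampled_locations(log_prob, token_mask)
+     weighted = weight_log_prob(unsampled, idx, seq_length=8)
+     return compute_average_loss_for_batch(weighted)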
+
186
+ ' create the numerical tokenized input data for transformer '
187
+ def create_token_labels(args, realization) -> (
188
+ torch.Tensor,
189
+ int,
190
+ int
191
+ ):
192
+
193
+ bs, channel, seq_length = realization.size()
194
+ temp_real = realization.reshape(bs, channel, seq_length)*1
195
+
196
+ if args.task == 'MNIST':
197
+ real_tokens = (temp_real == 1)*2 + (temp_real == 0)*1 # numerical token labels for MNIST
198
+
199
+ elif args.task == 'proteins':
200
+ real_tokens = temp_real + 1
201
+ # background --> label 1
202
+ # foreground --> label 2
203
+ # mask (absorbing state) --> label 0
204
+ return (
205
+ real_tokens.squeeze(1),
206
+ bs,
207
+ seq_length
208
+ )
209
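+ # Hedged sketch: for args.task == 'proteins', token labels are shifted by +1 so that
+ # 0 is free to act as the mask / absorbing state; input shape is (batch, channel,
+ # seq_length) and the toy values below are illustrative.
+ def _example_create_token_labels(args) -> torch.Tensor:
+     realization = torch.tensor([[[0, 1, 2, 21]]])   # e.g. output of create_num_seqs
+     real_tokens, bs, seq_length = create_token_labels(args, realization)
+     return real_tokens   # tensor([[1, 2, 3, 22]]) for the proteins task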
+
210
+ ' mask the positions for predictions/denoising '
211
+ def mask_realizations(
212
+ real_tokens: torch.Tensor,
213
+ random_path_mask: torch.Tensor
214
+ ) -> torch.Tensor:
215
+
216
+ out_real_tokens = real_tokens.clone()
217
+ # batch size
218
+ bs = random_path_mask.shape[0]
219
+ # convert random path to boolean
220
+ bool_rand_path_mask = random_path_mask.to(dtype=torch.bool)
221
+ # positional masks
222
+ # mask the future sample positions
223
+ future_mask_positions = ((~bool_rand_path_mask)*1).squeeze(1)
224
+
225
+ for ii in range(bs):
226
+
227
+ mask_positions = future_mask_positions[ii].nonzero().tolist()
228
+ # insert mask tokens
229
+ out_real_tokens[ii, mask_positions] = 0
230
+
231
+ return out_real_tokens
232
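+ # Hedged sketch: positions whose path mask is 0 (not yet reached on the sampling
+ # path) are overwritten with the absorbing token 0; the toy tensors are illustrative.
+ def _example_mask_realizations() -> torch.Tensor:
+     real_tokens = torch.tensor([[1, 2, 3, 22]])
+     path_mask = torch.tensor([[1, 0, 1, 0]])   # 1 = already sampled
+     return mask_realizations(real_tokens, path_mask)   # tensor([[1, 0, 3, 0]])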
+
233
+
234
+ ' model prediction '
235
+ def predict_conditional_prob(
236
+ model: nn.Module,
237
+ real_token_masked: torch.Tensor,
238
+ idx: any,
239
+ args: any
240
+ ) -> (
241
+ any,
242
+ torch.Tensor
243
+ ):
244
+ #logits = model(x=real_token_masked.to(args.device), t=idx.view(-1,))
245
+ logits = model(x=real_token_masked, t=idx.view(-1,))
246
+ probs = F.softmax(
247
+ logits,
248
+ dim=1
249
+ )
250
+
251
+ conditional_prob = OneHotCategorical(probs=probs.permute(0,2,1))
252
+
253
+ return (
254
+ conditional_prob,
255
+ probs
256
+ )
257
+
258
+
259
+ """
260
+ Here, we compute the previous position tokens, current token position, and future token positions, where
261
+ past, current, and future are defined by the time trajectory.
262
+ """
263
+
264
+ ' sample from model '
265
+ @torch.no_grad()
266
+ def sample_from_conditional(conditional_prob: any) -> torch.Tensor:
267
+ # draw a sample from the categorical dist.
268
+ cond_prob_sample = conditional_prob.sample().permute(0,2,1)
269
+ return cond_prob_sample
270
+
271
+ ' compute the accuracy at the current sampling location '
272
+ @torch.no_grad()
273
+ def sample_recover(
274
+ real_tokens: torch.Tensor,
275
+ cond_prob_sample: torch.Tensor,
276
+ current_path_mask: torch.Tensor
277
+ ) -> float:
278
+
279
+ # remove from gpu
280
+ real_tokens.cpu()
281
+ cond_prob_sample.cpu()
282
+ current_path_mask.cpu()
283
+
284
+ # current sampling index
285
+ current_tensor_pos = torch.argmax((current_path_mask == 1)*1, dim=-1)
286
+
287
+ # model predictions match the ground truth label at current sampling index
288
+ match_preds = [(
289
+ real_tokens[seq_idx, ii] == torch.argmax(cond_prob_sample, dim=1)[seq_idx, ii]
290
+ ).item()*1 for seq_idx, ii in enumerate(current_tensor_pos.cpu().numpy())
291
+ ]
292
+
293
+ return sum(match_preds)/len(match_preds)
294
+
295
+
296
+ ' compute the accuracy of previous conditionally sampled locations '
297
+ @torch.no_grad()
298
+ def compute_prev_token_acc(
299
+ cond_real_tokens: torch.Tensor,
300
+ cond_prob_sample: torch.Tensor,
301
+ path_mask: torch.Tensor
302
+ ) -> np.ndarray:
303
+
304
+ # remove from gpu
305
+ cond_real_tokens.cpu()
306
+ cond_prob_sample.cpu()
307
+ path_mask.cpu()
308
+
309
+ # class labels of the sampled model prediction
310
+ cond_sample_tokens = torch.argmax(cond_prob_sample, dim=1)
311
+ matches = []
312
+ for ii , sample_pos in enumerate(path_mask):
313
+
314
+ temp_real_tokens = cond_real_tokens[ii, sample_pos.nonzero()].squeeze(1)
315
+ temp_sample_tokens = cond_sample_tokens[ii, sample_pos.nonzero()].squeeze(1)
316
+ matches.append(
317
+ (temp_real_tokens == temp_sample_tokens).tolist()
318
+ )
319
+
320
+ acc = []
321
+ for match in matches:
322
+
323
+ try:
324
+ acc.append(sum(match*1)/len(match))
325
+
326
+ except ZeroDivisionError:
327
+ acc.append(0)
328
+
329
+ return np.mean(acc)
330
+
331
+
332
+
333
+ ' compute the accuracy of previous conditionally sampled locations '
334
+ @torch.no_grad()
335
+ def compute_future_token_acc(
336
+ cond_real_tokens: torch.Tensor,
337
+ cond_prob_sample: torch.Tensor,
338
+ path_mask: torch.Tensor
339
+ ) -> np.ndarray:
340
+
341
+ # remove from gpu
342
+ cond_real_tokens.cpu()
343
+ cond_prob_sample.cpu()
344
+ path_mask.cpu()
345
+
346
+ # class labels of the sampled model prediction
347
+ cond_sample_tokens = torch.argmax(cond_prob_sample, dim=1)
348
+ matches = []
349
+ for ii, sample_pos in enumerate(path_mask):
350
+
351
+ temp_real_tokens = cond_real_tokens[ii, sample_pos.nonzero()].squeeze(1)
352
+ temp_sample_tokens = cond_sample_tokens[ii, sample_pos.nonzero()].squeeze(1)
353
+ matches.append(
354
+ (temp_real_tokens == temp_sample_tokens).tolist()
355
+ )
356
+
357
+ acc = []
358
+ for match in matches:
359
+ try:
360
+ acc.append(sum(match*1)/len(match))
361
+ except ZeroDivisionError:
362
+ acc.append(0)
363
+ return np.mean(acc)
364
+
365
+ @torch.no_grad()
366
+ def compute_pos_entropy(probs: torch.Tensor) -> torch.Tensor:
367
+
368
+ # average positional entropy
369
+ pos_entropy = torch.mean(torch.mean(-probs * torch.log(probs), dim = 1), dim = 0)
370
+ return pos_entropy
371
+
372
+
373
+ def elbo_objective(
374
+ model: nn.Module,
375
+ realization: torch.Tensor,
376
+ args: any
377
+ ) -> (
378
+ torch.Tensor,
379
+ float,
380
+ float,
381
+ float,
382
+ torch.Tensor
383
+ ):
384
+
385
+ bs, channel, seq_length = realization.size()
386
+
387
+ # get a batch of random sampling paths
388
+ sampled_random_path = sample_random_path(bs, seq_length, device=args.device)
389
+ # sample a set of random sampling steps for each individual training image in the current batch
390
+ idx = sample_random_index_for_sampling(bs, seq_length, device=args.device, option='random')
391
+ # we create a mask that masks the locations where we've already sampled
392
+ random_path_mask = create_mask_at_random_path_index(sampled_random_path, idx, bs, seq_length)
393
+ # create a mask that masks the locations where we are currently sampling
394
+ current_path_mask = create_sampling_location_mask(sampled_random_path, idx, bs, seq_length)
395
+ # future sampling locations (i.e. >t)
396
+ future_path_mask = create_mask_at_future_path_index(sampled_random_path, idx, bs, seq_length)
397
+ # tokenize realizations
398
+ real_tokens, bs, seq_length = create_token_labels(args, realization)
399
+ # mask realizations
400
+ real_token_masked = mask_realizations(real_tokens, random_path_mask)
401
+ # conditional probs
402
+ conditional_prob, probs = predict_conditional_prob(model, real_token_masked, idx, args)
403
+ # evaluate the value of the log prob for the given realization
404
+ log_prob = log_prob_of_realization(args, conditional_prob, real_tokens)
405
+ # compute an average over all the unsampled locations for each image in the batch
406
+ #log_prob_unsampled = log_prob_of_unsampled_locations(log_prob.to(args.device), real_token_masked.to(args.device))
407
+ log_prob_unsampled = log_prob_of_unsampled_locations(log_prob, real_token_masked)
408
+ # weight the summed log-probs by 1/(seq_length - t + 1)
409
+ log_prob_weighted = weight_log_prob(log_prob_unsampled, idx, seq_length)
410
+ # compute an average loss i.e. negative average log likelihood over the batch elements
411
+ loss = compute_average_loss_for_batch(log_prob_weighted)
412
+
413
+
414
+ # compute metrics
415
+ cond_prob_sample = sample_from_conditional(conditional_prob)
416
+ acc = sample_recover(real_tokens, cond_prob_sample, current_path_mask)
417
+ prev_acc = compute_prev_token_acc(real_tokens, cond_prob_sample, random_path_mask)
418
+ future_acc = compute_future_token_acc(real_tokens, cond_prob_sample, future_path_mask)
419
+ # average positional entropy
420
+ pos_entropy = compute_pos_entropy(probs=probs)
421
+
422
+ return (
423
+ loss,
424
+ acc,
425
+ prev_acc,
426
+ future_acc,
427
+ pos_entropy
428
+ )
429
+
430
+
431
+ ' model prediction with class conditional '
432
+ def cond_predict_conditional_prob(
433
+ model: nn.Module,
434
+ real_token_masked: torch.Tensor,
435
+ y_c: torch.Tensor,
436
+ idx: any,
437
+ args: any
438
+ ) -> (
439
+ any,
440
+ torch.Tensor
441
+ ):
442
+ #logits = model(x=real_token_masked.to(args.device), t=idx.view(-1,), y_c=y_c)
443
+ logits = model(x=real_token_masked, t=idx.view(-1,), y_c=y_c)
444
+ probs = F.softmax(
445
+ logits,
446
+ dim=1
447
+ )
448
+
449
+ conditional_prob = OneHotCategorical(probs=probs.permute(0,2,1))
450
+ # conditional_prob = Categorical(probs=probs.permute(0,2,1))
451
+
452
+ return (
453
+ conditional_prob,
454
+ probs
455
+ )
456
+
457
+
458
+ def cond_elbo_objective(
459
+ model: nn.Module,
460
+ realization: torch.Tensor,
461
+ y_c: torch.Tensor,
462
+ args: any,
463
+ iteration: int
464
+ ) -> (
465
+ torch.Tensor,
466
+ tuple
467
+ ):
468
+
469
+ bs, channel, seq_length = realization.size()
470
+
471
+ # get a batch of random sampling paths
472
+ sampled_random_path = sample_random_path(bs, seq_length, device=args.device)
473
+ # sample a set of random sampling steps for each individual training samples in the current batch
474
+ idx = sample_random_index_for_sampling(bs, seq_length, device=args.device, option='random')
475
+ # we create a mask that masks the locations where we've already sampled
476
+ random_path_mask = create_mask_at_random_path_index(sampled_random_path, idx, bs, seq_length)
477
+ # create a mask that masks the locations where we are currently sampling
478
+ current_path_mask = create_sampling_location_mask(sampled_random_path, idx, bs, seq_length)
479
+ # future sampling locations (i.e. >t)
480
+ future_path_mask = create_mask_at_future_path_index(sampled_random_path, idx, bs, seq_length)
481
+ # tokenize realizations
482
+ real_tokens, bs, seq_length = create_token_labels(args,realization)
483
+ #real_tokens = realizations.clone().squeeze(1)
484
+ # mask realizations
485
+ real_token_masked = mask_realizations(real_tokens, random_path_mask)
486
+ # conditional probs
487
+ conditional_prob, probs = cond_predict_conditional_prob(model, real_token_masked, y_c, idx, args)
488
+ # evaluate the value of the log prob for the given realization
489
+ log_prob = log_prob_of_realization(args, conditional_prob, real_tokens)
490
+ # compute an average over all the unsampled locations for each image in the batch
491
+ #log_prob_unsampled = log_prob_of_unsampled_locations(log_prob.to(args.device), real_token_masked.to(args.device))
492
+ log_prob_unsampled = log_prob_of_unsampled_locations(log_prob, real_token_masked)
493
+ #log_prob_unsampled = log_prob_of_unsampled_locations(log_prob, real_token_masked, real_tokens)
494
+
495
+ # weight the summed log-probs by 1/(seq_length - t + 1)
496
+ log_prob_weighted = weight_log_prob(log_prob_unsampled, idx, seq_length)
497
+ # compute an average loss i.e. negative average log likelihood over the batch elements
498
+ loss = compute_average_loss_for_batch(log_prob_weighted)
499
+
500
+ # compute metrics
501
+ if iteration % args.enter_eval == 0:
502
+
503
+
504
+ with torch.no_grad():
505
+
506
+ # compute accuracy given time position
507
+ sample_seq = torch.argmax(sample_from_conditional(conditional_prob), dim=1) # create numerical token sequences
508
+
509
+ # convert to cpu
510
+ real_tokens = real_tokens.cpu()
511
+ sample_seq = sample_seq.cpu()
512
+ idx = idx.cpu()
513
+ sampled_random_path = sampled_random_path.cpu()
514
+ probs = probs.cpu()
515
+
516
+
517
+ prev_B_hard_acc, prev_B_soft_acc, fut_B_hard_acc, fut_B_soft_acc, current_B_hard_acc, current_B_soft_acc = eval_funcs.compute_acc_given_time_pos(
518
+ real_tokens=real_tokens,
519
+ sample_seq=sample_seq,
520
+ sample_path=sampled_random_path,
521
+ idx=idx
522
+ )
523
+
524
+ # compute ppl given time position
525
+ current_ppl, prev_ppl, fut_ppl = eval_funcs.compute_ppl_given_time_pos(
526
+ probs=probs,
527
+ sample_path=sampled_random_path,
528
+ idx=idx
529
+ )
530
+
531
+ # average positional entropy
532
+ pos_entropy = compute_pos_entropy(probs=probs).mean().item()
533
+
534
+
535
+ metric_evals = (
536
+ prev_B_hard_acc,
537
+ prev_B_soft_acc,
538
+ fut_B_hard_acc,
539
+ fut_B_soft_acc,
540
+ current_B_hard_acc,
541
+ current_B_soft_acc,
542
+ current_ppl,
543
+ prev_ppl,
544
+ fut_ppl,
545
+ pos_entropy
546
+ )
547
+
548
+ else:
549
+ metric_evals = (None)
550
+
551
+ return (
552
+ loss,
553
+ metric_evals
554
+ )
555
+
556
+
557
+
run_ProteoScribe_sample.py ADDED
@@ -0,0 +1,167 @@
1
+ from argparse import Namespace
2
+ import json
3
+ import pandas as pd
4
+ import argparse
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import pytorch_lightning as pl
10
+ import Stage3_source.PL_wrapper as Stage3_PL_mod
11
+ import Stage3_source.cond_diff_transformer_layer as Stage3_mod
12
+ import Stage3_source.sampling_analysis as Stage3_sample_tools
13
+ import Stage3_source.animation_tools as Stage3_ani_tools
14
+
15
+
16
+ # Step 1: Load JSON configuration
17
+ def load_json_config(json_path):
18
+ """
19
+ Load JSON configuration file.
20
+ """
21
+ with open(json_path, "r") as f:
22
+ config = json.load(f)
23
+ # print("Loaded JSON config:", config)
24
+ return config
25
+
26
+ # Step 2: Convert JSON dictionary to Namespace
27
+ def convert_to_namespace(config_dict):
28
+ """
29
+ Recursively convert a dictionary to an argparse Namespace.
30
+ """
31
+ for key, value in config_dict.items():
32
+ if isinstance(value, dict): # Recursively handle nested dictionaries
33
+ config_dict[key] = convert_to_namespace(value)
34
+ return Namespace(**config_dict)
35
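+ # Hedged sketch: nested dictionaries become nested Namespaces, so configuration
+ # values can be read with attribute access; the toy config below is illustrative.
+ def _example_convert_to_namespace() -> int:
+     cfg = convert_to_namespace({"device": "cpu", "model": {"dim": 512}})
+     return cfg.model.dim   # 512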
+
36
+
37
+ # Step 3: load model with pretrained weights
38
+ def prepare_model(args, config_args) -> nn.Module:
39
+ """
40
+ Prepare the Stage 3 model and load its pretrained weights using a flat args object.
41
+ """
42
+
43
+ # Initialize the model graph
44
+ model = Stage3_mod.get_model(
45
+ args=config_args,
46
+ data_shape=(config_args.image_size, config_args.image_size),
47
+ num_classes=config_args.num_classes
48
+ )
49
+
50
+ # Load state_dict into the model with map_location="cpu"
51
+ model.load_state_dict(torch.load(args.model_path, map_location=config_args.device))
52
+ model.eval()
53
+
54
+ print(f"Stage 3 model loaded from: {args.model_path} (loaded on {config_args.device})")
55
+ return model
56
+
57
+
58
+
59
+ # Step 4: Sample sequences from the model
60
+ @torch.no_grad()
61
+ def batch_stage3_generate_sequences(
62
+ args: any,
63
+ model: nn.Module,
64
+ z_t: torch.Tensor
65
+ ) -> dict:
66
+ """
67
+ Generates protein sequences in batches using a denoising model.
68
+
69
+ Args:
70
+ args (any): Configuration object containing model and sampling parameters.
71
+ model (nn.Module): The pre-trained model used for denoising and generation.
72
+ z_t (torch.Tensor): Input tensor representing initial samples for sequence generation.
73
+
74
+ Returns:
75
+ dict: A dictionary mapping each replica name to its generated sequences.
76
+ """
77
+
78
+ # Handle z_t if passed as a list of tensors
79
+ if isinstance(z_t, list) and all(isinstance(item, torch.Tensor) for item in z_t):
80
+ print(f"z_t is a list of tensors with {len(z_t)} tensors.")
81
+ z_t = torch.stack(z_t)
82
+
83
+ # Move model and inputs to the target device (CPU or CUDA)
84
+ model.to(args.device)
85
+ z_t = z_t.to(args.device)
86
+
87
+ # Amino acid tokenization including special characters
88
+ tokens = [
89
+ '-', '<START>', 'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M',
90
+ 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y', '<END>', '<PAD>',
91
+ 'X', 'U', 'Z', 'B', 'O' # Special characters
92
+ ]
93
+
94
+ # Initialize a dictionary to store generated sequences for each replica
95
+ design_sequence_dict = {f'replica_{ii}': [] for ii in range(args.num_replicas)}
96
+
97
+ # Loop over input samples (each z_t) and generate sequences
98
+ for idx_sample, z_text_sample in enumerate(z_t):
99
+
100
+ # Process in batches to optimize memory and speed
101
+ for batch_start in range(0, args.num_replicas, args.batch_size_sample):
102
+ current_batch_size = min(args.batch_size_sample, args.num_replicas - batch_start)
103
+
104
+ # Prepare batched input for current batch
105
+ batched_z_text_sample = z_text_sample.unsqueeze(0).repeat(current_batch_size, 1)
106
+
107
+ # Generate random permutations for each sample in the batch
108
+ batch_perms = torch.stack([torch.randperm(args.diffusion_steps) for _ in range(current_batch_size)])
109
+
110
+ # Generate denoised samples using the model
111
+ mask_realization_list, _ = Stage3_sample_tools.batch_generate_denoised_sampled(
112
+ args=args,
113
+ model=model,
114
+ extract_digit_samples=torch.zeros(current_batch_size, args.diffusion_steps),
115
+ extract_time=torch.zeros(current_batch_size).long(),
116
+ extract_digit_label=batched_z_text_sample,
117
+ sampling_path=batch_perms
118
+ )
119
+
120
+ # Convert generated numeric sequences to amino acid sequences
121
+ for i, mask_realization in enumerate(mask_realization_list[-1]):
122
+ design_sequence = Stage3_ani_tools.convert_num_to_char(tokens, mask_realization[0])
123
+ clean_sequence = design_sequence.replace('<START>', '').replace('<END>', '').replace('<PAD>', '')
124
+ design_sequence_dict[f'replica_{batch_start + i}'].append(clean_sequence)
125
+
126
+ return design_sequence_dict
127
+
128
+
129
+
130
+ # Step 5: Argument Parser Function
131
+ def parse_arguments():
132
+
133
+ parser = argparse.ArgumentParser(description="BioM3 Inference Script (Stage 3: ProteoScribe sampling)")
134
+ parser.add_argument('--json_path', type=str, required=True,
135
+ help="Path to the JSON configuration file (stage3_config.json)")
136
+ parser.add_argument('--model_path', type=str, required=True,
137
+ help="Path to the pre-trained model weights (pytorch_model.bin)")
138
+ parser.add_argument('--input_path', type=str, required=True,
139
+ help="Path to the input embedding dictionary containing 'z_c'")
140
+ parser.add_argument('--output_path', type=str, required=True,
141
+ help="Path to save output embeddings")
142
+ return parser.parse_args()
143
+
144
+
145
+ if __name__ == '__main__':
146
+
147
+ # Parse arguments
148
+ config_args_parser = parse_arguments()
149
+
150
+ # Load and convert JSON config
151
+ config_dict = load_json_config(config_args_parser.json_path)
152
+ config_args = convert_to_namespace(config_dict)
153
+
154
+ # load test dataset
155
+ embedding_dataset = torch.load(config_args_parser.input_path)
156
+
157
+ # load model
158
+ model = prepare_model(args=config_args_parser, config_args=config_args)
159
+
160
+ # sample sequences
161
+ design_sequence_dict = batch_stage3_generate_sequences(
162
+ args=config_args,
163
+ model=model,
164
+ z_t=embedding_dataset['z_c']
165
+ )
166
+
167
+ print(f'{design_sequence_dict=}')
stage3_config.json ADDED
@@ -0,0 +1,62 @@
1
+ {
2
+ "device": "cuda",
3
+ "output_hist_folder": "None",
4
+ "version_name": "None",
5
+ "output_folder": "./",
6
+ "save_hist_path": "None",
7
+ "tb_logger_path": "None",
8
+ "tb_logger_folder": "None",
9
+ "model_option": "transformer",
10
+ "model_path_checkpoint": "/project/ranganathanr/niksapraljak/HF_repo/HF_BioM3_project/V20240805_final_phase8/last-v2.ckpt",
11
+ "stage3_model_path": "/project/ranganathanr/niksapraljak/HF_repo/HF_BioM3_project/V20240805_final_phase8/last-v2.ckpt",
12
+ "stage2_data_path": "None",
13
+ "stage3_output_data_path": "None",
14
+ "data_root": "None",
15
+ "num_replicas": 5,
16
+ "batch_size_sample": 32,
17
+ "diffusion_steps": 1024,
18
+ "seed": 42,
19
+ "batch_size": 16,
20
+ "warmup_steps": 500,
21
+ "image_size": 32,
22
+ "learning_rate": 1e-4,
23
+ "weight_decay": 1e-6,
24
+ "ema_inv_gamma": 1.0,
25
+ "ema_power": 0.75,
26
+ "ema_max_value": 0.95,
27
+ "precision": "fp16",
28
+ "num_classes": 29,
29
+ "num_y_class_labels": 6,
30
+ "task": "proteins",
31
+ "enter_eval": 1000,
32
+ "choose_optim": "DeepSpeedCPUAdam",
33
+ "epochs": 1000,
34
+ "acc_grad_batches": 1,
35
+ "gpu_devices": 1,
36
+ "scheduler_gamma": "coswarmup",
37
+ "text_emb_dim": 512,
38
+ "facilitator": "MMD",
39
+ "context_window_size": 1024,
40
+ "sequence_keyname": "sequence",
41
+ "valid_size": 0.1,
42
+ "num_workers": 12,
43
+ "transformer_dim": 512,
44
+ "transformer_heads": 16,
45
+ "transformer_depth": 16,
46
+ "model_checkpoint": "/project/ranganathanr/niksapraljak/HF_repo/HF_BioM3_project/V20240805_final_phase8/last-v2.ckpt",
47
+ "data_path": "None",
48
+ "output_dict_path": "None",
49
+
50
+ "num_steps": 1,
51
+ "actnorm": false,
52
+ "perm_channel": "none",
53
+ "perm_length": "reverse",
54
+ "input_dp_rate": 0.0,
55
+
56
+ "transformer_blocks": 1,
57
+ "transformer_dropout": 0.1,
58
+ "transformer_reversible": false,
59
+ "transformer_local_heads": 8,
60
+ "transformer_local_size": 128
61
+ }
62
+