| import gc |
| import os |
| import math |
| from re import L |
| import torch |
|
|
| import lightning as pl |
| import torch.nn.functional as F |
|
|
| from transformers import AutoModel |
|
|
| from src.madsbm.wt_peptide.control_field import PeptideControlField |
| from src.PeptiVerse.inference import PeptiVersePredictor |
| from src.utils.model_utils import CosineWarmup, _print, compute_grad_norms |
|
|
|
|
class MadSBM(pl.LightningModule):
    """Lightning module that trains a PeptideControlField with a masked-token
    cross-entropy objective under a configurable time-sampling schedule.

    Notes
    -----
    * ``self.embed_model`` (a frozen ESM encoder loaded via ``AutoModel``) is
      kept in eval mode with all parameters frozen; it is never trained.
    * ``self.tot_steps`` is populated in :meth:`configure_optimizers`, so the
      annealed time schedules (linear / exponential) can only be sampled after
      the trainer has been set up.
    """

    def __init__(self, config, guidance=None):
        super().__init__()

        self.config = config
        self.model = PeptideControlField(config)
        self.tokenizer = self.model.tokenizer
        self.vocab_size = self.tokenizer.vocab_size

        # Special-token ids used for corruption and loss masking.
        self.mask_id = self.tokenizer.mask_token_id
        self.pad_id = self.tokenizer.pad_token_id

        # Frozen ESM embedding model (inference only, never updated).
        self.embed_model = AutoModel.from_pretrained(config.model.esm_model)
        self.embed_model.eval()
        for param in self.embed_model.parameters():
            param.requires_grad = False

        # Time-sampling configuration.
        self.time_schedule = config.time_embed.time_schedule
        self.anneal_frac = config.time_embed.anneal_frac
        self.eps = float(config.time_embed.min_time)  # minimum mask time
        self.t_max = 1.0 - self.eps                   # maximum mask time

    def forward(self, input_ids, attention_mask, t):
        """Run the control field on a (noised) sequence at time ``t``."""
        return self.model(xt=input_ids, attention_mask=attention_mask, t=t)

    def step(self, batch):
        """Shared train/val/test computation.

        Parameters
        ----------
        batch : dict with ``'input_ids'`` and ``'attention_mask'`` tensors,
            each of shape (B, L).

        Returns
        -------
        tuple
            ``(loss, ppl, max_u_logit, max_esm_logit)``. ``max_esm_logit`` is
            NaN when ``config.model.ablate`` is set, because the ESM head is
            not used in that mode.
        """
        x1 = batch['input_ids']
        attn_mask = batch['attention_mask']
        maskable = self.is_maskable(x1)

        t = self.sample_t(x1)
        xt = self.noise_seq(x1, t, maskable_mask=maskable)

        outs = self.forward(xt, attn_mask, t)
        # BUG FIX: the ablate branch previously left max_u_logit and
        # max_esm_logit unassigned, raising UnboundLocalError on return.
        max_u_logit = outs['dit'].max().item()
        if self.config.model.ablate:
            logits = outs['dit']
            max_esm_logit = float('nan')  # ESM head unused under ablation
        else:
            logits = outs['madsbm']
            max_esm_logit = outs['esm'].max().item()

        # Token-level CE; padding is excluded via ignore_index.
        loss_token = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            x1.view(-1),
            reduction='none',
            ignore_index=self.pad_id
        )
        loss_token = loss_token.view(x1.size(0), x1.size(1))

        # Average over maskable positions only; clamp guards against
        # division by zero for sequences with no maskable tokens.
        sample_loss = (loss_token * maskable.float()).sum(dim=1) / maskable.float().sum(dim=1).clamp(min=1.0)

        loss = sample_loss.mean()
        ppl = torch.exp(loss)

        return loss, ppl, max_u_logit, max_esm_logit

    def noise_seq(self, x1, t, maskable_mask):
        """Corrupt ``x1`` by masking each maskable token independently with
        probability ``t`` (per-sample mask rate, shape (B,))."""
        B, L = x1.shape
        t = t.unsqueeze(1)  # (B, 1) for broadcasting against (B, L)

        u = torch.rand((B, L), device=x1.device)
        masked = (u < t) & maskable_mask

        xt = x1.clone()
        xt = xt.masked_fill(masked, self.mask_id)

        return xt

    def sample_t(self, x1):
        """Dispatch to the configured time sampler; returns t of shape (B,)."""
        ts = self.time_schedule
        if ts == 'linear':
            return self.sample_linear_t(x1)
        elif ts == 'exponential':
            # BUG FIX: previously called self.sample_exp_t, which does not
            # exist (AttributeError); the method is sample_t_exponential.
            return self.sample_t_exponential(x1)
        elif ts == 'uniform':
            return self.sample_uni_t(x1)
        else:
            raise ValueError(f"Unrecognized time scheduler type: {ts}")

    def sample_uni_t(self, x1):
        """Sample t uniformly over the discrete grid {1/T, ..., 1}."""
        B = x1.size(0)
        T = self.config.time_embed.n_timesteps

        discrete_ts = torch.randint(1, T+1, (B,), device=x1.device)
        timesteps = discrete_ts.float() / float(T)
        _print(f'timesteps: {timesteps}')
        return timesteps.clamp(min=self.eps, max=self.t_max)

    def sample_linear_t(self, x1):
        """Sample t uniformly from a window whose lower edge anneals linearly
        from eps to 1-eps over the first ``anneal_frac`` of training; after
        that, sample uniformly over the full [eps, 1-eps] range.

        NOTE(review): relies on ``self.tot_steps`` being set by
        configure_optimizers before the first call.
        """
        B = x1.size(0)
        eps = self.eps

        frac = float(self.global_step) / float(self.tot_steps)
        t_max = 1.0 - eps

        if frac < self.anneal_frac:
            # Annealing phase: shrink the window's lower bound toward t_max.
            prog = frac / max(1e-12, self.anneal_frac)
            t_min = eps + prog * (t_max - eps)
            t = t_min + (t_max - t_min) * torch.rand(B, device=x1.device)
        else:
            # Post-anneal: uniform over the full admissible range.
            t = eps + (t_max - eps) * torch.rand(B, device=x1.device)

        return t.clamp(min=eps, max=t_max)

    def sample_t_exponential(self, x1, t_min=1e-6, t_max=1.0-1e-6):
        """
        Exponentially anneal center of t from t_min to t_max over training.

        Implement if linear schedule isn't expressive enough
        But for annealing over training steps, which can be a very large quantity,
        exponential approximates linear schedule
        """
        k = self.config.training.exp_time_k
        # BUG FIX: Trainer has no ``step`` attribute; use global_step like the
        # linear schedule does. math.exp avoids a needless scalar tensor.
        progress = float(self.global_step) / float(self.tot_steps)
        frac = 1.0 - math.exp(-k * progress)
        center = t_min + frac * (t_max - t_min)

        # Gaussian jitter around the annealed center.
        # BUG FIX: sample on the batch's device (was implicitly CPU).
        t = torch.randn(x1.size(0), device=x1.device) * self.config.training.time_sigma + center
        return t.clamp(min=t_min, max=t_max)

    def training_step(self, batch):
        # BUG FIX: step() returns four values; the previous 2-tuple unpack
        # raised ValueError on every step. Logit maxima are only logged in
        # test_step, so they are discarded here.
        loss, ppl, _, _ = self.step(batch)
        self.log("train/loss", loss, on_step=True, on_epoch=False, prog_bar=True)
        self.log("train/ppl", ppl, on_step=True, on_epoch=False, prog_bar=False)
        return loss

    def validation_step(self, batch):
        # BUG FIX: same 4-value unpack fix as training_step.
        loss, ppl, _, _ = self.step(batch)
        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("val/ppl", ppl, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        return loss

    def test_step(self, batch):
        loss, ppl, max_u, max_esm = self.step(batch)
        self.log('test/loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("test/ppl", ppl, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        self.log("test/max_madsbm_logit", max_u, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        self.log("test/max_esm_logit", max_esm, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        return loss

    def on_after_backward(self):
        """Log the pre-clipping gradient norm after each backward pass."""
        pre_norm = compute_grad_norms(self.parameters())
        self.log('train/grad_norm_PRE_clip', pre_norm, on_step=True, on_epoch=False, prog_bar=False, sync_dist=True)

    def configure_optimizers(self):
        """AdamW + cosine schedule with linear warmup, stepped per batch."""
        optimizer = torch.optim.AdamW(
            params = self.model.parameters(),
            lr = self.config.optim.lr,
            weight_decay = self.config.optim.weight_decay,
            betas = (self.config.optim.beta1, self.config.optim.beta2)
        )

        # Cached for the annealed time schedules (sample_linear_t, etc.).
        self.tot_steps = self.trainer.estimated_stepping_batches
        warmup_steps = int(self.config.optim.warmup_epochs * self.tot_steps / self.config.training.n_epochs)

        lr_scheduler = CosineWarmup(
            optimizer = optimizer,
            warmup_steps = warmup_steps,
            total_steps = self.tot_steps
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
                "frequency": 1
            }
        }

    def on_save_checkpoint(self, checkpoint: dict):
        """
        Don't save the classifier model used for FBD calculation in the ckpt
        """
        sd = checkpoint.get('state_dict', None)
        if sd is None:
            return
        keys_to_remove = [k for k in sd.keys() if k.startswith("score_model.")]
        for k in keys_to_remove:
            sd.pop(k, None)
        checkpoint['state_dict'] = sd

    def is_maskable(self, input_ids: torch.Tensor):
        """Boolean mask of positions that may be corrupted: everything except
        padding, CLS, and EOS tokens."""
        return (
            (input_ids != self.tokenizer.pad_token_id)
            & (input_ids != self.tokenizer.cls_token_id)
            & (input_ids != self.tokenizer.eos_token_id)
        )

    def validate_config(self):
        """Sanity-check config values before training.

        NOTE(review): asserts are stripped under ``python -O``; consider
        raising ValueError instead if this must run in optimized mode.
        """
        assert os.path.isdir(self.config.checkpointing.save_dir), "invalid checkpointing path"
        assert self.config.model.hidden_dim % 2 == 0, 'odd value for embedding dim'
        assert self.config.time_embed.time_dim % 2 == 0, 'odd value for time dim'
        assert self.config.time_embed.fourier_dim % 2 == 0, 'odd value for fourier dim'

    def get_state_dict(self, ckpt_path):
        """Load a checkpoint file and strip Lightning's ``model.`` key prefix.

        BUG FIX: the previous helper called ``k.replace(...)`` without using
        the returned string, so the prefix was never actually removed; it also
        matched ``"model."`` anywhere in the key and replaced every occurrence.
        The prefix is now stripped from the start of the key only.

        BUG FIX: map_location was hard-coded to ``cuda:3`` (a debugging
        leftover that crashes on hosts with fewer GPUs); state dicts are
        loaded onto CPU and callers can move tensors as needed.
        """
        def _strip_model_prefix(state_dict):
            # Drop a leading "model." from each key, leaving other keys intact.
            return {
                (k[len('model.'):] if k.startswith('model.') else k): v
                for k, v in state_dict.items()
            }

        checkpoint = torch.load(ckpt_path, map_location='cpu')
        state_dict = checkpoint.get("state_dict", checkpoint)

        if any(k.startswith("model.") for k in state_dict.keys()):
            state_dict = _strip_model_prefix(state_dict)

        return state_dict

    def cleanup(self):
        """Best-effort release of CUDA cache and Python garbage."""
        torch.cuda.empty_cache()
        gc.collect()