# DockFormer / train.py
# Captured from the Hugging Face file viewer: commit bca3a49 ("add code") by bshor, 19.7 kB.
import json
import sys
from typing import Optional
# This import must be on top to set the environment variables before importing other modules
import env_consts
import time
import os
from lightning.pytorch import seed_everything
import lightning.pytorch as pl
import torch
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.profilers import AdvancedProfiler
from dockformer.config import model_config
from dockformer.data.data_modules import OpenFoldDataModule, DockFormerDataModule
from dockformer.model.model import AlphaFold
from dockformer.utils import residue_constants
from dockformer.utils.exponential_moving_average import ExponentialMovingAverage
from dockformer.utils.loss import AlphaFoldLoss, lddt_ca
from dockformer.utils.lr_schedulers import AlphaFoldLRScheduler
from dockformer.utils.script_utils import get_latest_checkpoint
from dockformer.utils.superimposition import superimpose
from dockformer.utils.tensor_utils import tensor_tree_map
from dockformer.utils.validation_metrics import (
drmsd,
gdt_ts,
gdt_ha,
rmsd,
)
class ModelWrapper(pl.LightningModule):
    """Lightning wrapper around the DockFormer ``AlphaFold`` model.

    Owns the model, its ``AlphaFoldLoss`` and an ``ExponentialMovingAverage``
    of the model weights. The EMA weights are swapped in for validation and
    the live weights restored afterwards; the EMA state is stored in / read
    from checkpoints under the ``"ema"`` key.
    """

    def __init__(self, config):
        super(ModelWrapper, self).__init__()
        self.config = config
        self.model = AlphaFold(config)
        self.loss = AlphaFoldLoss(config.loss)
        self.ema = ExponentialMovingAverage(
            model=self.model, decay=config.ema.decay
        )
        # Clone of the live (non-EMA) weights while EMA weights are loaded for validation.
        self.cached_weights = None
        # Global step the LR scheduler resumes from; -1 means a fresh start.
        self.last_lr_step = -1
        # metric name -> list of recent training values, flushed as an average.
        self.aggregated_metrics = {}
        self.log_agg_every_n_steps = 50  # match Trainer(log_every_n_steps=50)

    def forward(self, batch):
        return self.model(batch)

    def _log(self, loss_breakdown, batch, outputs, train=True):
        """Log loss terms and auxiliary metrics for one step.

        Training: each loss term is logged per step and as an epoch
        aggregate, and every value is also buffered in
        ``self.aggregated_metrics``. Once any buffer holds
        ``log_agg_every_n_steps`` values, all buffers are flushed as averages.
        Validation: everything is logged as epoch aggregates, and the
        superimposition-based metrics are included.
        """
        phase = "train" if train else "val"
        for loss_name, indiv_loss in loss_breakdown.items():
            self.log(
                f"{phase}/{loss_name}",
                indiv_loss,
                on_step=train, on_epoch=(not train), logger=True, sync_dist=True
            )
            if train:
                self.aggregated_metrics.setdefault(f"{phase}/{loss_name}_agg", []).append(float(indiv_loss))
                self.log(
                    f"{phase}/{loss_name}_epoch",
                    indiv_loss,
                    on_step=False, on_epoch=True, logger=True, sync_dist=True
                )

        with torch.no_grad():
            other_metrics = self._compute_validation_metrics(
                batch,
                outputs,
                superimposition_metrics=(not train)
            )

        for k, v in other_metrics.items():
            mean_v = torch.mean(v)
            if train:
                self.aggregated_metrics.setdefault(f"{phase}/{k}_agg", []).append(float(mean_v))
            self.log(
                f"{phase}/{k}",
                mean_v,
                on_step=False, on_epoch=True, logger=True, sync_dist=True
            )

        if train and any(len(v) >= self.log_agg_every_n_steps for v in self.aggregated_metrics.values()):
            self._flush_aggregated_metrics()

    def _flush_aggregated_metrics(self):
        """Log the mean of every non-empty training buffer, then reset all buffers.

        A metric that was not produced since the last flush has an empty
        buffer; it is skipped instead of raising ZeroDivisionError.
        """
        for k, values in self.aggregated_metrics.items():
            if values:
                mean_value = sum(values) / len(values)
                print("logging agg", k, len(values), mean_value, flush=True)
                self.log(k, mean_value, on_step=True, on_epoch=False, logger=True, sync_dist=True)
            self.aggregated_metrics[k] = []

    def training_step(self, batch, batch_idx):
        # Keep the EMA copy on the same device as the incoming batch.
        if self.ema.device != batch["aatype"].device:
            self.ema.to(batch["aatype"].device)

        outputs = self(batch)

        # Remove the recycling dimension
        batch = tensor_tree_map(lambda t: t[..., -1], batch)

        loss, loss_breakdown = self.loss(
            outputs, batch, _return_breakdown=True
        )
        self._log(loss_breakdown, batch, outputs)
        return loss

    def on_before_zero_grad(self, *args, **kwargs):
        self.ema.update(self.model)

    def validation_step(self, batch, batch_idx):
        # At the start of validation, swap in the EMA weights
        if self.cached_weights is None:
            # model.state_dict() contains references to model weights rather
            # than copies. Therefore, we need to clone them before calling
            # load_state_dict().
            clone_param = lambda t: t.detach().clone()
            self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
            self.model.load_state_dict(self.ema.state_dict()["params"])

        outputs = self(batch)
        # Remove the recycling dimension
        batch = tensor_tree_map(lambda t: t[..., -1], batch)
        batch["use_clamped_fape"] = 0.

        _, loss_breakdown = self.loss(
            outputs, batch, _return_breakdown=True
        )
        self._log(loss_breakdown, batch, outputs, train=False)

    def on_validation_epoch_end(self):
        """Restore the live (non-EMA) weights cached by ``validation_step``.

        If no validation step ran, nothing was cached and nothing needs
        restoring; without this guard, ``load_state_dict(None)`` would crash.
        """
        if self.cached_weights is None:
            return
        self.model.load_state_dict(self.cached_weights)
        self.cached_weights = None

    def _compute_validation_metrics(self,
        batch,
        outputs,
        superimposition_metrics=False
    ):
        """Compute structure, contact and affinity metrics for one batch.

        Returns a dict of metric name -> tensor. Protein metrics are restricted
        by ``batch["protein_mask"]`` and ligand metrics by ``batch["ligand_mask"]``,
        both expanded to the 37 atom slots per token. Ligand coordinates are
        read from the CA slot (presumably where ligand atoms are stored —
        confirm against the featurizer). With ``superimposition_metrics``, the
        prediction is aligned on protein CA atoms and GDT scores, alignment
        RMSD and ligand RMSD under that alignment are added.
        """
        metrics = {}

        all_gt_coords = batch["atom37_gt_positions"]
        all_pred_coords = outputs["final_atom_positions"]
        all_atom_mask = batch["atom37_atom_exists_in_gt"]

        # Expand per-token masks to per-atom masks (37 slots per token).
        rough_protein_atom_mask = torch.repeat_interleave(batch["protein_mask"], 37, dim=-1).view(*all_atom_mask.shape)
        protein_gt_coords = all_gt_coords * rough_protein_atom_mask[..., None]
        protein_pred_coords = all_pred_coords * rough_protein_atom_mask[..., None]
        protein_all_atom_mask = all_atom_mask * rough_protein_atom_mask

        rough_ligand_atom_mask = torch.repeat_interleave(batch["ligand_mask"], 37, dim=-1).view(*all_atom_mask.shape)
        ligand_gt_coords = all_gt_coords * rough_ligand_atom_mask[..., None]
        ligand_pred_coords = all_pred_coords * rough_ligand_atom_mask[..., None]
        ligand_all_atom_mask = all_atom_mask * rough_ligand_atom_mask

        # This is super janky for superimposition. Fix later
        protein_gt_coords_masked = protein_gt_coords * protein_all_atom_mask[..., None]
        protein_pred_coords_masked = protein_pred_coords * protein_all_atom_mask[..., None]
        ca_pos = residue_constants.atom_order["CA"]
        protein_gt_coords_masked_ca = protein_gt_coords_masked[..., ca_pos, :]
        protein_pred_coords_masked_ca = protein_pred_coords_masked[..., ca_pos, :]
        protein_atom_mask_ca = protein_all_atom_mask[..., ca_pos]

        ligand_gt_coords_single_atom = ligand_gt_coords[..., ca_pos, :]
        ligand_pred_coords_single_atom = ligand_pred_coords[..., ca_pos, :]
        ligand_gt_mask_single_atom = ligand_all_atom_mask[..., ca_pos]

        metrics["lddt_ca"] = lddt_ca(
            protein_pred_coords,
            protein_gt_coords,
            protein_all_atom_mask,
            eps=self.config.globals.eps,
            per_residue=False,
        )

        metrics["drmsd_ca"] = drmsd(
            protein_pred_coords_masked_ca,
            protein_gt_coords_masked_ca,
            mask=protein_atom_mask_ca,  # still required here to compute n
        )

        metrics["drmsd_intra_ligand"] = drmsd(
            ligand_pred_coords_single_atom,
            ligand_gt_coords_single_atom,
            mask=ligand_gt_mask_single_atom,
        )

        # --- inter contacts: threshold predicted probabilities at 0.5
        gt_contacts = batch["gt_inter_contacts"]
        pred_contacts = torch.sigmoid(outputs["inter_contact_logits"].clone().detach()).squeeze(-1)
        pred_contacts = (pred_contacts > 0.5).float()
        pred_contacts = pred_contacts * batch["inter_pair_mask"]

        tp = torch.sum((gt_contacts == 1) & (pred_contacts == 1))
        fp = torch.sum((gt_contacts == 0) & (pred_contacts == 1))
        fn = torch.sum((gt_contacts == 1) & (pred_contacts == 0))

        # With an empty denominator tp is 0, so the fallback reports 0.0.
        recall = tp / (tp + fn) if (tp + fn) > 0 else tp.float()
        precision = tp / (tp + fp) if (tp + fp) > 0 else tp.float()
        metrics["inter_contacts_recall"] = recall.clone().detach()
        metrics["inter_contacts_precision"] = precision.clone().detach()

        # --- Affinity
        # Each head predicts a distribution over 32 bins spanning [0, 15]; its
        # prediction is the expectation. Absolute errors are averaged using
        # affinity_loss_factor as per-sample weights.
        gt_affinity = batch["affinity"].squeeze(-1)
        affinity_linspace = torch.linspace(0, 15, 32, device=batch["affinity"].device)
        aff_loss_factor = batch["affinity_loss_factor"].squeeze()

        def expected_affinity(logits_key):
            probs = torch.softmax(outputs[logits_key].clone().detach(), -1)
            return torch.sum(probs * affinity_linspace, dim=-1)

        def weighted_abs_error(pred):
            # NOTE(review): NaN when every weight in the batch is 0.
            return (torch.abs(gt_affinity - pred) * aff_loss_factor).sum() / aff_loss_factor.sum()

        pred_affinity_1d = expected_affinity("affinity_1d_logits")
        pred_affinity_2d = expected_affinity("affinity_2d_logits")
        pred_affinity_cls = expected_affinity("affinity_cls_logits")
        metrics["affinity_dist_1d"] = weighted_abs_error(pred_affinity_1d)
        metrics["affinity_dist_2d"] = weighted_abs_error(pred_affinity_2d)
        metrics["affinity_dist_cls"] = weighted_abs_error(pred_affinity_cls)
        metrics["affinity_dist_avg"] = weighted_abs_error(
            (pred_affinity_cls + pred_affinity_1d + pred_affinity_2d) / 3
        )

        if superimposition_metrics:
            superimposed_pred, alignment_rmsd, rots, transs = superimpose(
                protein_gt_coords_masked_ca, protein_pred_coords_masked_ca, protein_atom_mask_ca,
            )
            gdt_ts_score = gdt_ts(
                superimposed_pred, protein_gt_coords_masked_ca, protein_atom_mask_ca
            )
            gdt_ha_score = gdt_ha(
                superimposed_pred, protein_gt_coords_masked_ca, protein_atom_mask_ca
            )

            metrics["protein_alignment_rmsd"] = alignment_rmsd
            metrics["gdt_ts"] = gdt_ts_score
            metrics["gdt_ha"] = gdt_ha_score

            # Apply the protein-derived alignment to the ligand coordinates.
            superimposed_ligand_coords = ligand_pred_coords_single_atom @ rots + transs[:, None, :]
            ligand_alignment_rmsds = rmsd(ligand_gt_coords_single_atom, superimposed_ligand_coords,
                                          mask=ligand_gt_mask_single_atom)
            metrics["ligand_alignment_rmsd"] = ligand_alignment_rmsds.mean()
            metrics["ligand_alignment_rmsd_under_2"] = torch.mean((ligand_alignment_rmsds < 2).float())
            metrics["ligand_alignment_rmsd_under_5"] = torch.mean((ligand_alignment_rmsds < 5).float())
            print("ligand rmsd:", ligand_alignment_rmsds)

        return metrics

    def configure_optimizers(self,
        learning_rate: Optional[float] = None,
        eps: float = 1e-5,
    ) -> dict:
        """Build Adam plus a per-step ``AlphaFoldLRScheduler``.

        Returns the Lightning optimizer/scheduler config dict. When resuming
        (``last_lr_step != -1``), each param group gets an ``initial_lr``
        entry, which the scheduler needs when constructed with ``last_epoch``.
        """
        if learning_rate is None:
            learning_rate = self.config.globals.max_lr

        optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=learning_rate,
            eps=eps
        )

        if self.last_lr_step != -1:
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', learning_rate)

        lr_scheduler = AlphaFoldLRScheduler(
            optimizer,
            last_epoch=self.last_lr_step,
            max_lr=self.config.globals.max_lr,
            start_decay_after_n_steps=10000,
            decay_every_n_steps=10000,
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
                "name": "AlphaFoldLRScheduler",
            }
        }

    def on_load_checkpoint(self, checkpoint):
        ema = checkpoint["ema"]
        self.ema.load_state_dict(ema)

    def on_save_checkpoint(self, checkpoint):
        checkpoint["ema"] = self.ema.state_dict()

    def resume_last_lr_step(self, lr_step):
        self.last_lr_step = lr_step
def override_config(base_config, overriding_config):
    """Recursively merge ``overriding_config`` into ``base_config`` in place.

    Dict values are merged into the matching sub-config, which must already
    exist (a missing key raises KeyError). Any other value replaces the entry
    outright. Returns ``base_config``.
    """
    for key, new_value in overriding_config.items():
        base_config[key] = (
            override_config(base_config[key], new_value)
            if isinstance(new_value, dict)
            else new_value
        )
    return base_config
def train(override_config_path: str):
    """Train DockFormer as described by a JSON run config.

    ``override_config_path`` points to a JSON file with, among others:
    ``train_output_dir`` (required), ``stage``, ``override_conf``,
    ``accumulate_grad_batches``, either ``train_input_dir``/``val_input_dir``
    or ``train_input_file``/``val_input_file``, ``ckpt_path``,
    ``check_val_every_n_epoch`` and ``multi_node``.

    Training resumes from ``ckpt_path`` if the key is present. An explicit
    null disables resuming. If the key is absent, training resumes from
    whatever ``get_latest_checkpoint`` finds in ``<train_output_dir>/checkpoint``.
    """
    with open(override_config_path, "r") as f:
        run_config = json.load(f)

    seed = 42
    seed_everything(seed, workers=True)
    output_dir = run_config["train_output_dir"]
    os.makedirs(output_dir, exist_ok=True)

    print("Starting train", time.time())
    config = model_config(
        run_config.get("stage", "initial_training"),
        train=True,
        low_prec=True
    )
    config = override_config(config, run_config.get("override_conf", {}))
    accumulate_grad_batches = run_config.get("accumulate_grad_batches", 1)
    print("config loaded", time.time())

    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    # device_name = "mps" if device_name == "cpu" and torch.backends.mps.is_available() else device_name

    model_module = ModelWrapper(config)
    print("model loaded", time.time())

    # for debugging memory:
    # torch.cuda.memory._record_memory_history()

    if "train_input_dir" in run_config:
        data_module = OpenFoldDataModule(
            config=config.data,
            batch_seed=seed,
            train_data_dir=run_config["train_input_dir"],
            val_data_dir=run_config["val_input_dir"],
            train_epoch_len=run_config.get("train_epoch_len", 1000),
        )
    else:
        data_module = DockFormerDataModule(
            config=config.data,
            batch_seed=seed,
            train_data_file=run_config["train_input_file"],
            val_data_file=run_config["val_input_file"],
        )
    print("data module loaded", time.time())

    checkpoint_dir = os.path.join(output_dir, "checkpoint")
    # Scan for the latest checkpoint only when no path is configured
    # (a dict.get default would be evaluated eagerly).
    if "ckpt_path" in run_config:
        ckpt_path = run_config["ckpt_path"]
    else:
        ckpt_path = get_latest_checkpoint(checkpoint_dir)
    if ckpt_path:
        print(f"Resuming from checkpoint: {ckpt_path}")
        # Only the step counter is needed here. Load on CPU so this peek does
        # not allocate GPU memory or require the device the checkpoint was saved from.
        sd = torch.load(ckpt_path, map_location="cpu")
        last_global_step = int(sd['global_step'])
        del sd
        model_module.resume_last_lr_step(last_global_step)

    # Do we need this?
    data_module.prepare_data()
    data_module.setup("fit")

    callbacks = []

    mc = ModelCheckpoint(
        dirpath=checkpoint_dir,
        # every_n_epochs=1,
        every_n_train_steps=250,
        auto_insert_metric_name=False,
        save_top_k=1,
        save_on_train_epoch_end=True,  # before validation
    )

    mc2 = ModelCheckpoint(
        dirpath=checkpoint_dir,  # Directory to save checkpoints
        filename="step{step}_lig_rmsd{val/ligand_alignment_rmsd:.2f}",  # Filename format for best
        monitor="val/ligand_alignment_rmsd",  # Metric to monitor
        mode="min",  # We want the lowest `ligand_rmsd`
        save_top_k=1,  # Save only the best model based on `ligand_rmsd`
        every_n_epochs=1,  # Save a checkpoint every epoch
        auto_insert_metric_name=False
    )
    callbacks.append(mc)
    callbacks.append(mc2)

    lr_monitor = LearningRateMonitor(logging_interval="step")
    callbacks.append(lr_monitor)

    loggers = []

    wandb_project_name = "EvoDocker3"
    wandb_run_id_path = os.path.join(output_dir, "wandb_run_id.txt")

    # Only the global-rank-0 process creates the wandb run and records its id;
    # every other process waits for that file and resumes the same run.
    # Under SLURM, SLURM_LOCALID is the node-local rank and SLURM_PROCID the global rank.
    local_rank = int(os.getenv('LOCAL_RANK', os.getenv("SLURM_LOCALID", '0')))
    global_rank = int(os.getenv('GLOBAL_RANK', os.getenv("SLURM_PROCID", '0')))
    print("ranks", os.getenv('LOCAL_RANK', 'd0'), os.getenv('local_rank', 'd0'), os.getenv('GLOBAL_RANK', 'd0'),
          os.getenv('global_rank', 'd0'), os.getenv("SLURM_PROCID", 'd0'), os.getenv('SLURM_LOCALID', 'd0'), flush=True)
    if local_rank == 0 and global_rank == 0 and not os.path.exists(wandb_run_id_path):
        wandb_logger = WandbLogger(project=wandb_project_name, save_dir=output_dir)
        with open(wandb_run_id_path, 'w') as f:
            f.write(wandb_logger.experiment.id)
        wandb_logger.experiment.config.update(run_config, allow_val_change=True)
    else:
        # Necessary for multi-node training https://github.com/rstrudel/segmenter/issues/22
        while not os.path.exists(wandb_run_id_path):
            print(f"Waiting for run_id file to be created ({local_rank})", flush=True)
            time.sleep(1)
        with open(wandb_run_id_path, 'r') as f:
            run_id = f.read().strip()
        wandb_logger = WandbLogger(project=wandb_project_name, save_dir=output_dir, resume='must', id=run_id)
    loggers.append(wandb_logger)

    strategy_params = {"strategy": "auto"}
    if run_config.get("multi_node", False):
        strategy_params["strategy"] = "ddp"
        # strategy_params["strategy"] = "ddp_find_unused_parameters_true" # this causes issues with checkpointing...
        strategy_params["num_nodes"] = run_config["multi_node"]["num_nodes"]
        strategy_params["devices"] = run_config["multi_node"]["devices"]

    trainer = pl.Trainer(
        accelerator=device_name,
        default_root_dir=output_dir,
        **strategy_params,
        reload_dataloaders_every_n_epochs=1,
        accumulate_grad_batches=accumulate_grad_batches,
        check_val_every_n_epoch=run_config.get("check_val_every_n_epoch", 10),
        callbacks=callbacks,
        logger=loggers,
        # profiler=AdvancedProfiler(),
    )

    print("Starting fit", time.time())
    trainer.fit(
        model_module,
        datamodule=data_module,
        ckpt_path=ckpt_path,
    )

    # profiler_results = trainer.profiler.summary()
    # print(profiler_results)

    # torch.cuda.memory._dump_snapshot("my_train_snapshot.pickle")
    # view on https://pytorch.org/memory_viz
if __name__ == "__main__":
    # The run-config path is the first CLI argument, falling back to
    # run_config.json located next to this script.
    config_path = (
        sys.argv[1]
        if len(sys.argv) > 1
        else os.path.join(os.path.dirname(__file__), "run_config.json")
    )
    train(config_path)