| | import os |
| | import random |
| | import string |
| | import sys |
| | from dataclasses import dataclass |
| | from pathlib import Path |
| | from typing import Optional |
| |
|
| | import hydra |
| | import omegaconf |
| | import pytorch_lightning as pl |
| | import torch |
| | import torch.multiprocessing |
| | from omegaconf import OmegaConf, listconfig |
| | from pytorch_lightning import LightningModule |
| | from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint |
| | from pytorch_lightning.loggers import WandbLogger |
| | from pytorch_lightning.strategies import DDPStrategy |
| | from pytorch_lightning.utilities import rank_zero_only |
| |
|
| | from boltz.data.module.training import BoltzTrainingDataModule, DataConfig |
| |
|
| |
|
@dataclass
class TrainConfig:
    """Train configuration.

    Attributes
    ----------
    data : DataConfig
        The data configuration. ``train`` unpacks it with
        ``DataConfig(**data)``, so a mapping of ``DataConfig`` fields is
        expected here.
    model : LightningModule
        The model module, already instantiated from the config by hydra.
    output : str
        The output directory, used as the trainer's ``default_root_dir`` and
        as the wandb ``save_dir``.
    trainer : Optional[dict]
        Extra keyword arguments forwarded to ``pytorch_lightning.Trainer``.
    resume : Optional[str]
        Checkpoint to resume from, passed as ``ckpt_path`` to fit/validate.
    pretrained : Optional[str]
        Checkpoint whose weights initialise the model. It is ignored when
        ``resume`` is set.
    wandb : Optional[dict]
        Weights & Biases settings (``name``, ``project``, ``entity``).
        Logging is disabled when unset or in debug mode.
    disable_checkpoint : bool
        Disable checkpointing entirely.
    matmul_precision : Optional[str]
        Value passed to ``torch.set_float32_matmul_precision`` when set.
    find_unused_parameters : Optional[bool]
        Forwarded to ``DDPStrategy`` for multi-device runs.
    save_top_k : Optional[int]
        Number of best checkpoints (by ``val/lddt``) to keep.
    validation_only : bool
        Run ``trainer.validate`` instead of ``trainer.fit``.
    debug : bool
        Debug mode: a single device, no data-loader workers and no wandb.
    strict_loading : bool
        Fail on mismatched checkpoint weights.
    load_confidence_from_trunk : Optional[bool]
        Initialise the confidence module from the trunk weights of
        ``pretrained``.

    """

    data: DataConfig
    model: LightningModule
    output: str
    trainer: Optional[dict] = None
    resume: Optional[str] = None
    pretrained: Optional[str] = None
    wandb: Optional[dict] = None
    disable_checkpoint: bool = False
    matmul_precision: Optional[str] = None
    find_unused_parameters: Optional[bool] = False
    save_top_k: Optional[int] = 1
    validation_only: bool = False
    debug: bool = False
    strict_loading: bool = True
    load_confidence_from_trunk: Optional[bool] = False
| |
|
| |
|
def train(raw_config: str, args: list[str]) -> None:
    """Run training, or validation only, from a yaml config plus overrides.

    Parameters
    ----------
    raw_config : str
        Path to the input yaml configuration.
    args : list[str]
        Command line overrides in OmegaConf dotlist form (``key=value``).
        They are merged on top of the yaml values.

    """
    # Merge the dotlist overrides over the yaml file. The merged,
    # un-instantiated config is what gets uploaded to wandb.
    config = OmegaConf.merge(
        OmegaConf.load(raw_config), OmegaConf.from_dotlist(args)
    )

    # Instantiate hydra targets (e.g. the model) and wrap in the typed config.
    cfg = TrainConfig(**hydra.utils.instantiate(config))

    if cfg.matmul_precision is not None:
        torch.set_float32_matmul_precision(cfg.matmul_precision)

    # Keyword arguments for pl.Trainer; may be adjusted below in debug mode.
    trainer_kwargs = cfg.trainer if cfg.trainer is not None else {}
    devices = trainer_kwargs.get("devices", 1)

    wandb = cfg.wandb
    if cfg.debug:
        # Debug mode: one device, in-process data loading, no wandb logging.
        if isinstance(devices, int):
            devices = 1
        elif isinstance(devices, (list, listconfig.ListConfig)):
            devices = [devices[0]]
        trainer_kwargs["devices"] = devices
        cfg.data.num_workers = 0
        wandb = None

    data_module = BoltzTrainingDataModule(DataConfig(**cfg.data))

    model_module = cfg.model
    if cfg.pretrained and not cfg.resume:
        # Pretrained weights are only loaded for fresh runs; when resuming,
        # the cfg.resume checkpoint is handed to the trainer instead.
        model_module = _load_pretrained(cfg, model_module)

    callbacks = []
    if not cfg.disable_checkpoint:
        # Keep the best ``save_top_k`` checkpoints by val/lddt, plus the last.
        callbacks.append(
            ModelCheckpoint(
                monitor="val/lddt",
                save_top_k=cfg.save_top_k,
                save_last=True,
                mode="max",
                every_n_epochs=1,
            )
        )

    loggers = []
    if wandb:
        loggers.append(_make_wandb_logger(wandb, cfg.output, config))

    # Explicit DDP for multi-device runs so find_unused_parameters applies.
    # NOTE(review): device specs such as -1 or "auto" are not counted as
    # multi-device here and fall back to strategy="auto", which ignores
    # cfg.find_unused_parameters - confirm this is intended.
    multi_device = (isinstance(devices, int) and devices > 1) or (
        isinstance(devices, (list, listconfig.ListConfig)) and len(devices) > 1
    )
    strategy = (
        DDPStrategy(find_unused_parameters=cfg.find_unused_parameters)
        if multi_device
        else "auto"
    )

    trainer = pl.Trainer(
        default_root_dir=str(cfg.output),
        strategy=strategy,
        callbacks=callbacks,
        logger=loggers,
        enable_checkpointing=not cfg.disable_checkpoint,
        reload_dataloaders_every_n_epochs=1,
        **trainer_kwargs,
    )

    if not cfg.strict_loading:
        # Tolerate mismatched weights when the trainer restores cfg.resume.
        model_module.strict_loading = False

    run = trainer.validate if cfg.validation_only else trainer.fit
    run(model_module, datamodule=data_module, ckpt_path=cfg.resume)


def _load_pretrained(
    cfg: "TrainConfig", model_module: "LightningModule"
) -> "LightningModule":
    """Load ``cfg.pretrained`` into a new instance of ``model_module``'s class.

    With ``cfg.load_confidence_from_trunk`` set, the checkpoint is first
    rewritten so the confidence module receives a copy of the trunk weights.
    The rewritten checkpoint is saved to a randomly named ``.ckpt`` file next
    to ``cfg.pretrained``. It is deleted once loading finishes or fails.
    """
    if not cfg.load_confidence_from_trunk:
        return _load_checkpoint(model_module, cfg.pretrained)

    # NOTE(review): torch.load unpickles arbitrary objects; only use trusted
    # checkpoints. On torch>=2.6 the default weights_only=True may reject
    # Lightning checkpoints - confirm against the pinned torch version.
    checkpoint = torch.load(cfg.pretrained, map_location="cpu")
    checkpoint["state_dict"] = _broadcast_trunk_to_confidence(
        checkpoint["state_dict"]
    )

    random_string = "".join(
        random.choices(string.ascii_lowercase + string.digits, k=10)
    )
    # os.path.join rather than ``dirname + "/" + name``: for a bare filename
    # dirname() is "", and the concatenation would target the filesystem root.
    file_path = os.path.join(
        os.path.dirname(cfg.pretrained), random_string + ".ckpt"
    )
    print(
        f"Saving modified checkpoint to {file_path} created by broadcasting trunk of {cfg.pretrained} to confidence module."
    )
    torch.save(checkpoint, file_path)
    try:
        return _load_checkpoint(model_module, file_path)
    finally:
        # Always delete the temporary checkpoint, even if loading raised.
        os.remove(file_path)


def _load_checkpoint(model_module: "LightningModule", file_path: str):
    """Load ``file_path`` non-strictly on CPU into ``model_module``'s class.

    ``model_module.hparams`` are passed as keyword arguments to
    ``load_from_checkpoint``.
    """
    print(f"Loading model from {file_path}")
    return type(model_module).load_from_checkpoint(
        file_path, map_location="cpu", strict=False, **(model_module.hparams)
    )


def _broadcast_trunk_to_confidence(state_dict: dict) -> dict:
    """Return ``state_dict`` extended with a ``confidence_module.`` copy.

    Every key that does not start with ``structure_module`` or
    ``distogram_module`` is duplicated under a ``confidence_module.`` prefix.
    All original keys are kept and win on collision.
    """
    new_state_dict = {
        "confidence_module." + key: value
        for key, value in state_dict.items()
        if not key.startswith(("structure_module", "distogram_module"))
    }
    new_state_dict.update(state_dict)
    return new_state_dict


def _make_wandb_logger(wandb: dict, output: str, config) -> "WandbLogger":
    """Create the wandb logger and upload ``config`` to the run.

    ``config`` is uploaded as ``run.yaml``, on rank zero only.
    """
    wdb_logger = WandbLogger(
        name=wandb["name"],
        group=wandb["name"],
        save_dir=output,
        project=wandb["project"],
        entity=wandb["entity"],
        log_model=False,
    )

    @rank_zero_only
    def save_config_to_wandb() -> None:
        config_out = Path(wdb_logger.experiment.dir) / "run.yaml"
        with config_out.open("w") as f:
            OmegaConf.save(config, f)
        wdb_logger.experiment.save(str(config_out))

    save_config_to_wandb()
    return wdb_logger
| |
|
| |
|
if __name__ == "__main__":
    # Usage: <script> CONFIG_YAML [key=value ...]
    # Fail with a usage message instead of a bare IndexError.
    if len(sys.argv) < 2:
        sys.exit(f"Usage: {sys.argv[0]} CONFIG_YAML [key=value ...]")
    train(sys.argv[1], sys.argv[2:])
| |
|