import os
import random
import string
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import hydra
import omegaconf
import pytorch_lightning as pl
import torch
import torch.multiprocessing
from omegaconf import OmegaConf, listconfig
from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.utilities import rank_zero_only
from boltz.data.module.training import BoltzTrainingDataModule, DataConfig


@dataclass
class TrainConfig:
"""Train configuration.

    Attributes
----------
data : DataConfig
The data configuration.
    model : LightningModule
        The model module to train.
output : str
The output directory.
trainer : Optional[dict]
The trainer configuration.
resume : Optional[str]
The resume checkpoint.
pretrained : Optional[str]
The pretrained model.
wandb : Optional[dict]
The wandb configuration.
    disable_checkpoint : bool
        Whether to disable checkpoint saving.
    matmul_precision : Optional[str]
        The float32 matmul precision ("highest", "high", or "medium").
    find_unused_parameters : Optional[bool]
        Whether DDP should search for unused parameters.
    save_top_k : Optional[int]
        The number of best checkpoints to keep.
validation_only : bool
Run validation only.
debug : bool
Debug mode.
strict_loading : bool
Fail on mismatched checkpoint weights.
    load_confidence_from_trunk : Optional[bool]
Load pre-trained confidence weights from trunk.
"""
data: DataConfig
model: LightningModule
output: str
trainer: Optional[dict] = None
resume: Optional[str] = None
pretrained: Optional[str] = None
wandb: Optional[dict] = None
disable_checkpoint: bool = False
matmul_precision: Optional[str] = None
find_unused_parameters: Optional[bool] = False
save_top_k: Optional[int] = 1
validation_only: bool = False
debug: bool = False
strict_loading: bool = True
load_confidence_from_trunk: Optional[bool] = False
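

# Illustrative shape of the YAML config consumed by this script. Field names
# mirror TrainConfig above; the model _target_ path and the values shown are
# examples only, not verified defaults.
#
#   output: ./outputs
#   data:                      # fields of DataConfig (e.g. num_workers, datasets)
#     num_workers: 4
#   model:
#     _target_: <python.path.to.LightningModule>   # hypothetical target
#   trainer:
#     devices: 1
#     max_epochs: 100
#   wandb:
#     name: example-run
#     project: example-project
#     entity: example-entity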
def train(raw_config: str, args: list[str]) -> None: # noqa: C901, PLR0912, PLR0915
"""Run training.

    Parameters
----------
raw_config : str
The input yaml configuration.
args : list[str]
Any command line overrides.
"""
# Load the configuration
raw_config = omegaconf.OmegaConf.load(raw_config)
# Apply input arguments
args = omegaconf.OmegaConf.from_dotlist(args)
raw_config = omegaconf.OmegaConf.merge(raw_config, args)
# Instantiate the task
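    # hydra.utils.instantiate recursively builds any entry carrying a _target_
    # key (e.g. the model LightningModule), leaving plain values untouched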
cfg = hydra.utils.instantiate(raw_config)
cfg = TrainConfig(**cfg)
# Set matmul precision
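    # Accepted values are "highest", "high", and "medium"; the lower settings
    # allow TF32 matmuls on GPUs that support them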
if cfg.matmul_precision is not None:
torch.set_float32_matmul_precision(cfg.matmul_precision)
# Create trainer dict
trainer = cfg.trainer
if trainer is None:
trainer = {}
# Flip some arguments in debug mode
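    # Debug mode restricts training to a single device, disables dataloader
    # workers, and turns off wandb logging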
devices = trainer.get("devices", 1)
wandb = cfg.wandb
if cfg.debug:
if isinstance(devices, int):
devices = 1
elif isinstance(devices, (list, listconfig.ListConfig)):
devices = [devices[0]]
trainer["devices"] = devices
cfg.data.num_workers = 0
if wandb:
wandb = None
# Create objects
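    # cfg.data is a plain config converted into the DataConfig dataclass;
    # cfg.model is already an instantiated LightningModule (via its _target_)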
data_config = DataConfig(**cfg.data)
data_module = BoltzTrainingDataModule(data_config)
model_module = cfg.model
if cfg.pretrained and not cfg.resume:
# Load the pretrained weights into the confidence module
if cfg.load_confidence_from_trunk:
checkpoint = torch.load(cfg.pretrained, map_location="cpu")
# Modify parameter names in the state_dict
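            # Every key outside the structure and distogram modules belongs to
            # the trunk; duplicate those weights under a "confidence_module."
            # prefix so the confidence head starts from the pretrained trunk,
            # while the update() below keeps the original keys as well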
new_state_dict = {}
for key, value in checkpoint["state_dict"].items():
if not key.startswith("structure_module") and not key.startswith(
"distogram_module"
):
new_key = "confidence_module." + key
new_state_dict[new_key] = value
new_state_dict.update(checkpoint["state_dict"])
# Update the checkpoint with the new state_dict
checkpoint["state_dict"] = new_state_dict
# Save the modified checkpoint
random_string = "".join(
random.choices(string.ascii_lowercase + string.digits, k=10)
)
file_path = os.path.dirname(cfg.pretrained) + "/" + random_string + ".ckpt"
            print(
                f"Saving modified checkpoint to {file_path}: trunk weights from "
                f"{cfg.pretrained} were broadcast to the confidence module."
            )
torch.save(checkpoint, file_path)
else:
file_path = cfg.pretrained
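        # Weights come from the checkpoint (strict=False tolerates missing or
        # extra keys); hyperparameters are taken from the freshly built module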
print(f"Loading model from {file_path}")
model_module = type(model_module).load_from_checkpoint(
file_path, map_location="cpu", strict=False, **(model_module.hparams)
)
if cfg.load_confidence_from_trunk:
os.remove(file_path)
# Create checkpoint callback
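    # Checkpoints are ranked by validation lDDT (higher is better); the best
    # save_top_k epochs plus the last epoch are kept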
callbacks = []
dirpath = cfg.output
if not cfg.disable_checkpoint:
mc = ModelCheckpoint(
monitor="val/lddt",
save_top_k=cfg.save_top_k,
save_last=True,
mode="max",
every_n_epochs=1,
)
callbacks = [mc]
# Create wandb logger
loggers = []
if wandb:
wdb_logger = WandbLogger(
name=wandb["name"],
group=wandb["name"],
save_dir=cfg.output,
project=wandb["project"],
entity=wandb["entity"],
log_model=False,
)
loggers.append(wdb_logger)
# Save the config to wandb
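        # rank_zero_only ensures that only the rank-0 process writes and
        # uploads the merged config when running with multiple processes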
@rank_zero_only
def save_config_to_wandb() -> None:
config_out = Path(wdb_logger.experiment.dir) / "run.yaml"
with Path.open(config_out, "w") as f:
OmegaConf.save(raw_config, f)
wdb_logger.experiment.save(str(config_out))
save_config_to_wandb()
# Set up trainer
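    # Use explicit DDP whenever more than one device is requested;
    # find_unused_parameters lets DDP tolerate parameters that receive no
    # gradient in a given step, at some extra cost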
strategy = "auto"
if (isinstance(devices, int) and devices > 1) or (
isinstance(devices, (list, listconfig.ListConfig)) and len(devices) > 1
):
strategy = DDPStrategy(find_unused_parameters=cfg.find_unused_parameters)
trainer = pl.Trainer(
default_root_dir=str(dirpath),
strategy=strategy,
callbacks=callbacks,
logger=loggers,
enable_checkpointing=not cfg.disable_checkpoint,
reload_dataloaders_every_n_epochs=1,
**trainer,
)
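    # Relax state-dict matching so resuming from a checkpoint with missing or
    # unexpected keys does not raise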
if not cfg.strict_loading:
model_module.strict_loading = False
if cfg.validation_only:
trainer.validate(
model_module,
datamodule=data_module,
ckpt_path=cfg.resume,
)
else:
trainer.fit(
model_module,
datamodule=data_module,
ckpt_path=cfg.resume,
)
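

# Example invocation (paths and overrides are illustrative):
#   python train.py configs/train.yaml trainer.devices=2 output=./outputs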
if __name__ == "__main__":
arg1 = sys.argv[1]
arg2 = sys.argv[2:]
train(arg1, arg2)