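"""Train or evaluate the solubility guidance classifier.

Behavior is driven by src/configs/guidance.yaml: set `training.mode` to
"train" to fit a new model, or "test" to evaluate the saved best checkpoint.
"""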
import os

import wandb
import lightning.pytorch as pl
from omegaconf import OmegaConf
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor

from src.guidance.solubility_module import SolubilityClassifier
from src.guidance.dataloader import MembraneDataModule, get_datasets

config = OmegaConf.load("/scratch/sgoel/MeMDLM_v2/src/configs/guidance.yaml")

# wandb.login() picks up the WANDB_API_KEY environment variable, so no
# API key needs to be hardcoded in source control.
wandb.login()
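
# Build the train/val/test splits and wrap them in a LightningDataModule.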
datasets = get_datasets(config)
data_module = MembraneDataModule(
    config=config,
    train_dataset=datasets['train'],
    val_dataset=datasets['val'],
    test_dataset=datasets['test'],
)
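
# Stream metrics to W&B; callbacks track the learning rate each step and
# keep only the checkpoint with the lowest validation loss.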
wandb_logger = WandbLogger(**config.wandb)

lr_monitor = LearningRateMonitor(logging_interval="step")
checkpoint_callback = ModelCheckpoint(
    monitor="val/loss",
    save_top_k=1,
    mode="min",
    dirpath=config.checkpointing.save_dir,
    filename="best_model",
)
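
# Single-GPU trainer; the step budget and logging cadence come from the config.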
trainer = pl.Trainer(
    max_steps=config.training.max_steps,
    accelerator="cuda",
    devices=1,
    callbacks=[checkpoint_callback, lr_monitor],
    logger=wandb_logger,
    log_every_n_steps=config.training.log_every_n_steps,
)
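
# Make sure the checkpoint directory exists before saving or loading weights.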
ckpt_dir = config.checkpointing.save_dir
os.makedirs(ckpt_dir, exist_ok=True)

model = SolubilityClassifier(config)
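
# Dispatch on the configured mode: fit a new classifier, or evaluate the
# saved best checkpoint on the test split.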
if config.training.mode == "train":
    trainer.fit(model, datamodule=data_module)
elif config.training.mode == "test":
    # Restore the best checkpoint weights before evaluating.
    ckpt_path = os.path.join(ckpt_dir, "best_model.ckpt")
    state_dict = model.get_state_dict(ckpt_path)
    model.load_state_dict(state_dict)
    trainer.test(model, datamodule=data_module, ckpt_path=ckpt_path)
else:
    raise ValueError(f"{config.training.mode} is invalid. Must be 'train' or 'test'")

wandb.finish()