MemDLM / src /guidance /main.py
Shrey Goel
adding code
d04a061
#!/usr/bin/env python3
import os
import wandb
import lightning.pytorch as pl
from omegaconf import OmegaConf
from lightning.pytorch.strategies import DDPStrategy
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from src.utils.model_utils import _print
from src.guidance.solubility_module import SolubilityClassifier
from src.guidance.dataloader import MembraneDataModule, get_datasets
# Load the guidance-classifier configuration.
# NOTE(review): this is an absolute, machine-specific path, so the script only
# runs where this file exists.
config = OmegaConf.load("/scratch/sgoel/MeMDLM_v2/src/configs/guidance.yaml")

# Authenticate with Weights & Biases without embedding a secret in source.
# The key is taken from the WANDB_API_KEY environment variable; when it is
# unset, key=None lets wandb.login() use whatever credentials it already has
# configured.
# NOTE(review): the API key that used to be hard-coded on this line has been
# exposed and should be revoked and rotated.
wandb.login(key=os.environ.get("WANDB_API_KEY"))
# Datasets and data module.
# The result of get_datasets(config) is indexed by split name; each split is
# handed to the data module under the matching "<split>_dataset" keyword.
datasets = get_datasets(config)
data_module = MembraneDataModule(
    config=config,
    **{f"{split}_dataset": datasets[split] for split in ("train", "val", "test")},
)
# Experiment tracking: every entry under `wandb` in the config is forwarded
# to the logger as a keyword argument.
wandb_logger = WandbLogger(**config.wandb)

# Callbacks: a learning-rate monitor with a per-step logging interval, and a
# checkpoint callback that keeps only the single best checkpoint (lowest
# "val/loss"), saved under the configured directory with filename "best_model".
lr_monitor = LearningRateMonitor(logging_interval="step")
checkpoint_callback = ModelCheckpoint(
    dirpath=config.checkpointing.save_dir,
    filename="best_model",
    monitor="val/loss",
    mode="min",
    save_top_k=1,
)
# Lightning trainer.
# The run is pinned to one CUDA device; step budget and logging cadence come
# from the `training` section of the config.
# NOTE: multi-device / DDP execution is not enabled here.
trainer = pl.Trainer(
    accelerator="cuda",
    devices=1,
    max_steps=config.training.max_steps,
    log_every_n_steps=config.training.log_every_n_steps,
    logger=wandb_logger,
    callbacks=[checkpoint_callback, lr_monitor],
)
# Make sure the checkpoint directory exists before it is used.
ckpt_dir = config.checkpointing.save_dir
os.makedirs(ckpt_dir, exist_ok=True)

# Build the classifier from the loaded configuration.
model = SolubilityClassifier(config)

# Train or evaluate the model, depending on config.training.mode.
# Anything other than "train" or "test" is rejected.
if config.training.mode not in ("train", "test"):
    raise ValueError(f"{config.training.mode} is invalid. Must be 'train' or 'test'")

if config.training.mode == "train":
    trainer.fit(model, datamodule=data_module)
else:
    # Evaluation uses the "best_model" checkpoint from the checkpoint directory.
    ckpt_path = os.path.join(ckpt_dir, "best_model.ckpt")
    # Weights are first loaded through the model's own get_state_dict helper,
    # then the same checkpoint path is also passed to trainer.test.
    # NOTE(review): the two loads look redundant -- confirm before removing either.
    model.load_state_dict(model.get_state_dict(ckpt_path))
    trainer.test(model, datamodule=data_module, ckpt_path=ckpt_path)

# Close the Weights & Biases run.
wandb.finish()