import torch from omegaconf import OmegaConf from swim.utils import instantiate_from_config from torchinfo import summary from swim.modules.dataset import SwimDataModule from lightning import Trainer from lightning.pytorch.loggers import WandbLogger torch.set_float32_matmul_precision("medium") config = OmegaConf.load("configs/autoencoder/autoencoder_kl_32x32x4.yaml") model = instantiate_from_config(config.model) model.learning_rate = config.model.base_learning_rate datamodule = SwimDataModule( root_dir="/cm/shared/ninhnq3/datasets/swim_data", batch_size=2, img_size=512 ) logger = WandbLogger(project="swim", name="autoencoder_kl") trainer = Trainer(max_epochs=10, devices=[0], logger=logger, log_every_n_steps=10) trainer.fit(model, datamodule)