"""Distributed training of a tensor-factorized FNO (TFNO) on 2D MHD data
with a physics-informed vector-potential loss (PhysicsNeMo launch utilities)."""
import os

import hydra
import torch
from omegaconf import DictConfig, OmegaConf
from physicsnemo.distributed import DistributedManager
from physicsnemo.launch.logging import LaunchLogger, PythonLogger
from physicsnemo.launch.utils import load_checkpoint, save_checkpoint
from physicsnemo.sym.hydra import to_absolute_path
from torch.nn.parallel import DistributedDataParallel
from torch.optim import AdamW

from dataloaders import Dedalus2DDataset, MHDDataloaderVecPot
from losses import LossMHDVecPot_PhysicsNeMo
from tfno import TFNO
from utils.plot_utils import plot_predictions_mhd, plot_predictions_mhd_plotly
| dtype = torch.float | |
| torch.set_default_dtype(dtype) | |
@hydra.main(version_base="1.3", config_path="config", config_name="config")
def main(cfg: DictConfig) -> None:
    """Train a tensor-factorized FNO (TFNO) on 2D MHD snapshots.

    Builds train/val dataloaders from a Dedalus 2D dataset, wraps the model in
    DistributedDataParallel when launched with multiple processes, optimizes
    with AdamW + MultiStepLR under a physics-informed MHD vector-potential
    loss, logs per-epoch metrics and prediction plots, and checkpoints from
    rank 0.

    Parameters
    ----------
    cfg : DictConfig
        Hydra-injected configuration; converted to plain containers below.
    """
    # NOTE(review): the decorator is required because the __main__ guard calls
    # ``main()`` with no arguments — hydra must supply cfg. The original file
    # imported hydra but never applied the decorator, so the bare call raised
    # TypeError. Confirm config_path/config_name against the project layout (TODO).
    DistributedManager.initialize()  # Only call this once in the entire script!
    dist = DistributedManager()  # call if required elsewhere

    # Resolve interpolations and drop to plain dict/list containers.
    cfg = OmegaConf.to_container(cfg, resolve=True)

    # initialize monitoring
    log = PythonLogger(name="mhd_pino")
    log.file_logging()

    # Load config file parameters
    log_params = cfg["log_params"]
    model_params = cfg["model_params"]
    dataset_params = cfg["dataset_params"]
    train_loader_params = cfg["train_loader_params"]
    val_loader_params = cfg["val_loader_params"]
    loss_params = cfg["loss_params"]
    optimizer_params = cfg["optimizer_params"]
    train_params = cfg["train_params"]
    load_ckpt = cfg["load_ckpt"]

    output_dir = to_absolute_path(cfg["output_dir"])
    os.makedirs(output_dir, exist_ok=True)
    data_dir = dataset_params["data_dir"]
    ckpt_path = train_params["ckpt_path"]

    # Construct datasets; train and val differ only in the use_train flag.
    # (Consistency fix: train previously indexed dataset_params["data_dir"]
    # directly while val used the data_dir local — both now use data_dir.)
    dataset_train = Dedalus2DDataset(
        data_dir,
        output_names=dataset_params["output_names"],
        field_names=dataset_params["field_names"],
        num_train=dataset_params["num_train"],
        num_test=dataset_params["num_test"],
        num=dataset_params["num"],
        use_train=True,
    )
    dataset_val = Dedalus2DDataset(
        data_dir,
        output_names=dataset_params["output_names"],
        field_names=dataset_params["field_names"],
        num_train=dataset_params["num_train"],
        num_test=dataset_params["num_test"],
        num=dataset_params["num"],
        use_train=False,
    )
    mhd_dataloader_train = MHDDataloaderVecPot(
        dataset_train,
        sub_x=dataset_params["sub_x"],
        sub_t=dataset_params["sub_t"],
        ind_x=dataset_params["ind_x"],
        ind_t=dataset_params["ind_t"],
    )
    mhd_dataloader_val = MHDDataloaderVecPot(
        dataset_val,
        sub_x=dataset_params["sub_x"],
        sub_t=dataset_params["sub_t"],
        ind_x=dataset_params["ind_x"],
        ind_t=dataset_params["ind_t"],
    )
    dataloader_train, sampler_train = mhd_dataloader_train.create_dataloader(
        batch_size=train_loader_params["batch_size"],
        shuffle=train_loader_params["shuffle"],
        num_workers=train_loader_params["num_workers"],
        pin_memory=train_loader_params["pin_memory"],
        distributed=dist.distributed,
    )
    dataloader_val, sampler_val = mhd_dataloader_val.create_dataloader(
        batch_size=val_loader_params["batch_size"],
        shuffle=val_loader_params["shuffle"],
        num_workers=val_loader_params["num_workers"],
        pin_memory=val_loader_params["pin_memory"],
        distributed=dist.distributed,
    )

    # define FNO model
    model = TFNO(
        in_channels=model_params["in_dim"],
        out_channels=model_params["out_dim"],
        decoder_layers=model_params["decoder_layers"],
        decoder_layer_size=model_params["fc_dim"],
        dimension=model_params["dimension"],
        latent_channels=model_params["layers"],
        num_fno_layers=model_params["num_fno_layers"],
        num_fno_modes=model_params["modes"],
        padding=[model_params["pad_z"], model_params["pad_y"], model_params["pad_x"]],
        rank=model_params["rank"],
        factorization=model_params["factorization"],
        fixed_rank_modes=model_params["fixed_rank_modes"],
        decomposition_kwargs=model_params["decomposition_kwargs"],
    ).to(dist.device)

    # Set up DistributedDataParallel if using more than a single process.
    # The `distributed` property of DistributedManager can be used to check
    # this; construction happens on a side CUDA stream, synchronized after.
    if dist.distributed:
        ddps = torch.cuda.Stream()
        with torch.cuda.stream(ddps):
            model = DistributedDataParallel(
                model,
                device_ids=[dist.local_rank],  # local rank of this process on this node
                output_device=dist.device,
                broadcast_buffers=dist.broadcast_buffers,
                find_unused_parameters=dist.find_unused_parameters,
            )
        torch.cuda.current_stream().wait_stream(ddps)

    # Construct optimizer and scheduler.
    # NOTE(review): weight_decay is hard-coded while lr/betas come from the
    # config — consider sourcing it from optimizer_params as well.
    optimizer = AdamW(
        model.parameters(),
        betas=optimizer_params["betas"],
        lr=optimizer_params["lr"],
        weight_decay=0.1,
    )
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=optimizer_params["milestones"],
        gamma=optimizer_params["gamma"],
    )

    # Physics-informed MHD loss (vector-potential formulation).
    mhd_loss = LossMHDVecPot_PhysicsNeMo(**loss_params)

    # Load model from checkpoint (if exists); resume from the returned epoch.
    loaded_epoch = 0
    if load_ckpt:
        loaded_epoch = load_checkpoint(
            ckpt_path, model, optimizer, scheduler, device=dist.device
        )

    # Training Loop
    epochs = train_params["epochs"]
    ckpt_freq = train_params["ckpt_freq"]
    names = dataset_params["fields"]
    # Per-channel normalization constants applied entering/leaving the model.
    input_norm = torch.tensor(model_params["input_norm"]).to(dist.device)
    output_norm = torch.tensor(model_params["output_norm"]).to(dist.device)

    for epoch in range(max(1, loaded_epoch + 1), epochs + 1):
        with LaunchLogger(
            "train",
            epoch=epoch,
            num_mini_batch=len(dataloader_train),
            epoch_alert_freq=1,
        ) as log:
            if dist.distributed:
                # Reshuffle differently each epoch under the distributed sampler.
                sampler_train.set_epoch(epoch)
            # Train Loop
            model.train()
            for i, (inputs, outputs) in enumerate(dataloader_train):
                # Consistency fix: cast with .type(dtype) as the val loop does
                # (equivalent to the previous .type(torch.FloatTensor) since
                # dtype is torch.float and loader tensors are CPU-side).
                inputs = inputs.type(dtype).to(dist.device)
                outputs = outputs.type(dtype).to(dist.device)
                # Zero Gradients
                optimizer.zero_grad()
                # Compute predictions: normalize, channels-last -> channels-first
                # for the model, permute back, then denormalize.
                pred = (
                    model((inputs / input_norm).permute(0, 4, 1, 2, 3)).permute(
                        0, 2, 3, 4, 1
                    )
                    * output_norm
                )
                # Compute Loss
                loss, loss_dict = mhd_loss(pred, outputs, inputs, return_loss_dict=True)
                # Compute Gradients for Back Propagation
                loss.backward()
                # Update Weights
                optimizer.step()
                log.log_minibatch(loss_dict)
            log.log_epoch({"Learning Rate": optimizer.param_groups[0]["lr"]})
            scheduler.step()

        with LaunchLogger("valid", epoch=epoch) as log:
            # Val loop
            model.eval()
            with torch.no_grad():
                for i, (inputs, outputs) in enumerate(dataloader_val):
                    inputs = inputs.type(dtype).to(dist.device)
                    outputs = outputs.type(dtype).to(dist.device)
                    # Compute Predictions
                    pred = (
                        model((inputs / input_norm).permute(0, 4, 1, 2, 3)).permute(
                            0, 2, 3, 4, 1
                        )
                        * output_norm
                    )
                    # Compute Loss
                    loss, loss_dict = mhd_loss(
                        pred, outputs, inputs, return_loss_dict=True
                    )
                    log.log_minibatch(loss_dict)
                    # Log interactive prediction plots for the first
                    # log_num_plots batches at the configured frequency.
                    if (i < log_params["log_num_plots"]) and (
                        epoch % log_params["log_plot_freq"] == 0
                    ):
                        # Add all predictions in batch
                        for j, _ in enumerate(pred):
                            # Make plots for each field
                            for index, name in enumerate(names):
                                # Generate figure (removed dead plot_count
                                # counter — it was incremented but never read)
                                _ = plot_predictions_mhd_plotly(
                                    pred[j].cpu(),
                                    outputs[j].cpu(),
                                    inputs[j].cpu(),
                                    index=index,
                                    name=name,
                                )
                    # Save static prediction images locally for the first
                    # two batches at the same frequency.
                    if (i < 2) and (epoch % log_params["log_plot_freq"] == 0):
                        # Add all predictions in batch
                        for j, _ in enumerate(pred):
                            # Generate figure
                            plot_predictions_mhd(
                                pred[j].cpu(),
                                outputs[j].cpu(),
                                inputs[j].cpu(),
                                names=names,
                                save_path=os.path.join(
                                    output_dir,
                                    "MHD_physicsnemo" + "_" + str(dist.rank),
                                ),
                                save_suffix=i,
                            )

        # Checkpoint from rank 0 only to avoid concurrent writes.
        if epoch % ckpt_freq == 0 and dist.rank == 0:
            save_checkpoint(ckpt_path, model, optimizer, scheduler, epoch=epoch)
| if __name__ == "__main__": | |
| main() | |