# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time

import hydra
import torch
from omegaconf import DictConfig, OmegaConf
from torch.nn.parallel import DistributedDataParallel
from torch.optim import AdamW

from physicsnemo.distributed import DistributedManager
from physicsnemo.launch.logging import LaunchLogger, PythonLogger
from physicsnemo.launch.utils import load_checkpoint
from physicsnemo.sym.hydra import to_absolute_path

from dataloaders import Dedalus2DDataset, MHDDataloaderVecPot
from losses import LossMHDVecPot_PhysicsNeMo
from tfno import TFNO
from utils.plot_utils import plot_predictions_mhd, plot_predictions_mhd_plotly
| dtype = torch.float | |
| torch.set_default_dtype(dtype) | |
# TODO(review): confirm config_path/config_name against the project's conf layout;
# the decorator was missing entirely, leaving the bare ``main()`` call broken.
@hydra.main(version_base="1.3", config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
    """Evaluate a trained TFNO surrogate on the MHD (vector-potential) test set.

    Builds the test dataloader, runs inference batch by batch, logs the
    physics/data losses through ``LaunchLogger``, and writes prediction plots
    (interactive via plotly and static images) under ``cfg.output_dir``.

    Parameters
    ----------
    cfg : DictConfig
        Hydra config; must provide ``model_params``, ``dataset_params``,
        ``test_loader_params``, ``loss_params``, ``optimizer_params``,
        ``output_dir`` and ``test`` sections (see config file).
    """
    DistributedManager.initialize()  # Only call this once in the entire script!
    dist = DistributedManager()  # call if required elsewhere

    # Resolve the config to plain containers so it can be indexed like a dict.
    cfg = OmegaConf.to_container(cfg, resolve=True)

    # initialize monitoring
    log = PythonLogger(name="mhd_pino")
    log.file_logging()

    # Load config file parameters
    model_params = cfg["model_params"]
    dataset_params = cfg["dataset_params"]
    test_loader_params = cfg["test_loader_params"]
    loss_params = cfg["loss_params"]
    optimizer_params = cfg["optimizer_params"]
    output_dir = cfg["output_dir"]
    test_params = cfg["test"]
    # BUGFIX: this flag used to be bound to the name ``load_checkpoint``,
    # shadowing the checkpoint-loading helper and making the call below raise
    # ``TypeError: 'bool' object is not callable`` whenever it was truthy.
    use_ckpt = cfg.get("load_ckpt", False)

    output_dir = to_absolute_path(output_dir)
    os.makedirs(output_dir, exist_ok=True)
    data_dir = dataset_params["data_dir"]

    # Construct dataloaders (test split only: use_train=False)
    dataset_test = Dedalus2DDataset(
        data_dir,
        output_names=dataset_params["output_names"],
        field_names=dataset_params["field_names"],
        num_train=dataset_params["num_train"],
        num_test=dataset_params["num_test"],
        num=dataset_params["num"],
        use_train=False,
    )
    mhd_dataloader_test = MHDDataloaderVecPot(
        dataset_test,
        sub_x=dataset_params["sub_x"],
        sub_t=dataset_params["sub_t"],
        ind_x=dataset_params["ind_x"],
        ind_t=dataset_params["ind_t"],
    )
    dataloader_test, sampler_test = mhd_dataloader_test.create_dataloader(
        batch_size=test_loader_params["batch_size"],
        shuffle=test_loader_params["shuffle"],
        num_workers=test_loader_params["num_workers"],
        pin_memory=test_loader_params["pin_memory"],
        distributed=dist.distributed,
    )

    # define FNO model
    model = TFNO(
        in_channels=model_params["in_dim"],
        out_channels=model_params["out_dim"],
        decoder_layers=model_params["decoder_layers"],
        decoder_layer_size=model_params["fc_dim"],
        dimension=model_params["dimension"],
        latent_channels=model_params["layers"],
        num_fno_layers=model_params["num_fno_layers"],
        num_fno_modes=model_params["modes"],
        padding=[model_params["pad_z"], model_params["pad_y"], model_params["pad_x"]],
        rank=model_params["rank"],
        factorization=model_params["factorization"],
        fixed_rank_modes=model_params["fixed_rank_modes"],
    ).to(dist.device)

    # Set up DistributedDataParallel if using more than a single process.
    # The `distributed` property of DistributedManager can be used to
    # check this.
    if dist.distributed:
        # Initialize DDP on a side stream, then make the default stream wait
        # for it, so DDP's broadcast work does not race ahead of later kernels.
        ddps = torch.cuda.Stream()
        with torch.cuda.stream(ddps):
            model = DistributedDataParallel(
                model,
                device_ids=[dist.local_rank],  # Set the device_id to be
                # the local rank of this process on
                # this node
                output_device=dist.device,
                broadcast_buffers=dist.broadcast_buffers,
                find_unused_parameters=dist.find_unused_parameters,
            )
        torch.cuda.current_stream().wait_stream(ddps)

    # Construct optimizer and scheduler. They are only needed here so that
    # load_checkpoint() can restore their states alongside the model weights.
    # NOTE(review): weight_decay is hard-coded rather than read from
    # optimizer_params — confirm this matches the training script.
    optimizer = AdamW(
        model.parameters(),
        betas=optimizer_params["betas"],
        lr=optimizer_params["lr"],
        weight_decay=0.1,
    )
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=optimizer_params["milestones"],
        gamma=optimizer_params["gamma"],
    )

    # Construct Loss class
    mhd_loss = LossMHDVecPot_PhysicsNeMo(**loss_params)

    # Load model from checkpoint (if requested)
    if use_ckpt:
        _ = load_checkpoint(
            test_params["ckpt_path"], model, optimizer, scheduler, device=dist.device
        )

    # Eval Loop
    names = dataset_params["fields"]
    input_norm = torch.tensor(model_params["input_norm"]).to(dist.device)
    output_norm = torch.tensor(model_params["output_norm"]).to(dist.device)
    with LaunchLogger("test") as log:
        # Val loop
        model.eval()
        plot_count = 0
        with torch.no_grad():
            for i, (inputs, outputs) in enumerate(dataloader_test):
                inputs = inputs.type(dtype).to(dist.device)
                outputs = outputs.type(dtype).to(dist.device)
                start_time = time.time()
                # Compute Predictions: the model is channels-first, the data
                # channels-last, hence the permute round-trip; inputs/outputs
                # are scaled by the configured normalization constants.
                pred = (
                    model((inputs / input_norm).permute(0, 4, 1, 2, 3)).permute(
                        0, 2, 3, 4, 1
                    )
                    * output_norm
                )
                end_time = time.time()
                # NOTE(review): CUDA kernels launch asynchronously; without a
                # torch.cuda.synchronize() this wall-clock number may
                # under-report the true inference time — confirm if it matters.
                print(f"Inference Time: {end_time-start_time}")
                # Compute Loss
                loss, loss_dict = mhd_loss(pred, outputs, inputs, return_loss_dict=True)
                log.log_minibatch(loss_dict)
                # Get prediction plots (interactive, per sample and per field)
                for j, _ in enumerate(pred):
                    # Make plots for each field
                    for index, name in enumerate(names):
                        # Generate figure (returned figure is discarded here;
                        # the logger/backend is expected to capture it)
                        _ = plot_predictions_mhd_plotly(
                            pred[j].cpu(),
                            outputs[j].cpu(),
                            inputs[j].cpu(),
                            index=index,
                            name=name,
                        )
                        plot_count += 1
                # Get prediction plots and save images locally
                for j, _ in enumerate(pred):
                    # Generate figure; one output directory per rank avoids
                    # multi-process filename collisions.
                    plot_predictions_mhd(
                        pred[j].cpu(),
                        outputs[j].cpu(),
                        inputs[j].cpu(),
                        names=names,
                        save_path=os.path.join(
                            output_dir,
                            "MHD_eval_" + str(dist.rank),
                        ),
                        save_suffix=i,
                    )
| if __name__ == "__main__": | |
| main() | |