|
|
|
|
|
|
| import torch
|
|
|
| import os
|
| from module import WeatherForecastModule
|
| from datamodule import WeatherForecastDataModule
|
| from pytorch_lightning.cli import LightningCLI
|
| from arch import Unet, R2Unet, AttUnet,AttR2Unet
|
# Weights & Biases configuration, supplied through environment variables.
#
# setdefault() is used instead of plain assignment so that a value already
# exported by the caller (shell, CI, cluster job script) takes precedence;
# unconditional assignment silently replaced the user's own credentials.
#
# NOTE(security): the fallback API key below is a credential committed to
# source control. It is kept only so existing runs keep working; it should be
# rotated and removed, with WANDB_API_KEY provided by the environment instead.
os.environ.setdefault("WANDB_API_KEY", "f20f0d088ab8481e81a8623dcd59c22d4939fea1")
os.environ.setdefault("WANDB_ENTITY", "weatherforecast1024")
|
|
|
|
|
|
|
def _run_experiment():
    """Instantiate the experiment from CLI/config arguments, train, then test.

    The model's ``set_*`` methods are invoked in a fixed order before
    training; what each of them does is defined by ``WeatherForecastModule``
    and is not visible from this script, so the order is preserved as-is.
    """
    # run=False: the CLI only parses arguments and builds trainer, model and
    # datamodule; fit/test are triggered explicitly further down.
    cli = LightningCLI(
        model_class=WeatherForecastModule,
        datamodule_class=WeatherForecastDataModule,
        seed_everything_default=42,
        run=False,
    )

    model = cli.model
    datamodule = cli.datamodule
    settings = datamodule.hparams

    # Forward the datamodule's hyper-parameters the model is configured with.
    model.set_path(settings.dir_data)
    model.set_size(settings.rad_size, settings.sat_size)

    # Remaining argument-less set-up steps, resolved and called one by one.
    for step_name in ("set_lat", "set_clim", "set_normalize", "set_denormalize"):
        getattr(model, step_name)()

    trainer = cli.trainer
    trainer.fit(model=model, datamodule=datamodule)
    # ckpt_path="best" asks the trainer to evaluate the best checkpoint
    # recorded during the fit above rather than the final weights.
    trainer.test(model, datamodule=datamodule, ckpt_path="best")


if __name__ == "__main__":
    _run_experiment()
|
|
|
|
|