Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +1 -0
- FusionModel.egg-info/PKG-INFO +8 -0
- FusionModel.egg-info/SOURCES.txt +11 -0
- FusionModel.egg-info/dependency_links.txt +1 -0
- FusionModel.egg-info/top_level.txt +2 -0
- checkpoint/Unet/checkpoints/epoch_003.ckpt +3 -0
- checkpoint/Unet/checkpoints/last.ckpt +3 -0
- checkpoint/Unet/csv_logs/version_0/hparams.yaml +24 -0
- checkpoint/Unet/csv_logs/version_0/metrics.csv +0 -0
- checkpoint/Unet/wandb_logs/config.yaml +157 -0
- checkpoint/Unet/wandb_logs/wandb/debug-internal.log +7 -0
- checkpoint/Unet/wandb_logs/wandb/debug.log +22 -0
- checkpoint/Unet/wandb_logs/wandb/run-20250617_085744-m5tg7yyl/files/output.log +0 -0
- checkpoint/Unet/wandb_logs/wandb/run-20250617_085744-m5tg7yyl/files/wandb-summary.json +1 -0
- checkpoint/Unet/wandb_logs/wandb/run-20250617_085744-m5tg7yyl/logs/debug-core.log +13 -0
- checkpoint/Unet/wandb_logs/wandb/run-20250617_085744-m5tg7yyl/logs/debug-internal.log +17 -0
- checkpoint/Unet/wandb_logs/wandb/run-20250617_085744-m5tg7yyl/logs/debug.log +15 -0
- checkpoint/Unet/wandb_logs/wandb/run-20250617_085744-m5tg7yyl/run-m5tg7yyl.wandb +0 -0
- checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/files/output.log +161 -0
- checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/files/requirements.txt +77 -0
- checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/files/wandb-metadata.json +85 -0
- checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/logs/debug-core.log +7 -0
- checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/logs/debug-internal.log +7 -0
- checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/logs/debug.log +22 -0
- checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/run-0nx0l2dh.wandb +3 -0
- configs/AttR2Unet.yaml +86 -0
- configs/AttUnet.yaml +86 -0
- configs/Nothing.yaml +86 -0
- configs/R2Unet.yaml +95 -0
- configs/Unet.yaml +104 -0
- pyproject.toml +21 -0
- src/__pycache__/arch.cpython-310.pyc +0 -0
- src/__pycache__/arch.cpython-312.pyc +0 -0
- src/__pycache__/arch.cpython-38.pyc +0 -0
- src/__pycache__/datamodule.cpython-310.pyc +0 -0
- src/__pycache__/datamodule.cpython-312.pyc +0 -0
- src/__pycache__/lr_scheduler.cpython-310.pyc +0 -0
- src/__pycache__/lr_scheduler.cpython-312.pyc +0 -0
- src/__pycache__/metric.cpython-310.pyc +0 -0
- src/__pycache__/metric.cpython-312.pyc +0 -0
- src/__pycache__/module.cpython-310.pyc +0 -0
- src/__pycache__/module.cpython-312.pyc +0 -0
- src/__pycache__/module.cpython-38.pyc +0 -0
- src/__pycache__/train.cpython-38.pyc +0 -0
- src/arch.py +473 -0
- src/datamodule.py +341 -0
- src/lr_scheduler.py +94 -0
- src/metric.py +44 -0
- src/module.py +168 -0
- src/rad_clim.py +23 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/run-0nx0l2dh.wandb filter=lfs diff=lfs merge=lfs -text
|
FusionModel.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: FusionModel
|
| 3 |
+
Version: 0.3.1
|
| 4 |
+
Author-email: Khanh Vinh Bui <khanhvinhbui0512@gmail.com>, Hong Trang Le <lhtrang@hcmut.edu.vn>
|
| 5 |
+
Classifier: Programming Language :: Python :: 3
|
| 6 |
+
Classifier: License :: OSI Approved :: MIT License
|
| 7 |
+
Requires-Python: >=3.10
|
| 8 |
+
Description-Content-Type: text/markdown
|
FusionModel.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pyproject.toml
|
| 2 |
+
FusionModel.egg-info/PKG-INFO
|
| 3 |
+
FusionModel.egg-info/SOURCES.txt
|
| 4 |
+
FusionModel.egg-info/dependency_links.txt
|
| 5 |
+
FusionModel.egg-info/top_level.txt
|
| 6 |
+
src/arch.py
|
| 7 |
+
src/datamodule.py
|
| 8 |
+
src/lr_scheduler.py
|
| 9 |
+
src/metric.py
|
| 10 |
+
src/module.py
|
| 11 |
+
src/train.py
|
FusionModel.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
FusionModel.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
configs
|
| 2 |
+
src
|
checkpoint/Unet/checkpoints/epoch_003.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9f005cf7c67d6259fdc39a5ccb425db8367dc96622457009fcb82a9df5123487
|
| 3 |
+
size 521087
|
checkpoint/Unet/checkpoints/last.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9f005cf7c67d6259fdc39a5ccb425db8367dc96622457009fcb82a9df5123487
|
| 3 |
+
size 521087
|
checkpoint/Unet/csv_logs/version_0/hparams.yaml
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_instantiator: pytorch_lightning.cli.instantiate_module
|
| 2 |
+
ablation: 'no'
|
| 3 |
+
batch_size: 1
|
| 4 |
+
beta_1: 0.9
|
| 5 |
+
beta_2: 0.99
|
| 6 |
+
dir_data: /data/weather2025/NhaBe/
|
| 7 |
+
eta_min: 1.0e-08
|
| 8 |
+
hours_predicted: 3
|
| 9 |
+
lr: 0.0005
|
| 10 |
+
max_epochs: 50
|
| 11 |
+
num_workers: 4
|
| 12 |
+
pin_memory: false
|
| 13 |
+
pretrained_path: ''
|
| 14 |
+
rad_inp_vars: precipitation
|
| 15 |
+
rad_out_vars: precipitation
|
| 16 |
+
rad_size: 400
|
| 17 |
+
sat_inp_vars: total_precipitation
|
| 18 |
+
sat_out_vars: total_precipitation
|
| 19 |
+
sat_size: 25
|
| 20 |
+
time_points_rad: 1
|
| 21 |
+
time_points_sat: 1
|
| 22 |
+
warmup_epochs: 10
|
| 23 |
+
warmup_start_lr: 1.0e-08
|
| 24 |
+
weight_decay: 1.0e-05
|
checkpoint/Unet/csv_logs/version_0/metrics.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
checkpoint/Unet/wandb_logs/config.yaml
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pytorch_lightning==2.5.1.post0
|
| 2 |
+
seed_everything: 42
|
| 3 |
+
trainer:
|
| 4 |
+
accelerator: cuda
|
| 5 |
+
strategy: auto
|
| 6 |
+
devices:
|
| 7 |
+
- 6
|
| 8 |
+
num_nodes: 1
|
| 9 |
+
precision: 16-mixed
|
| 10 |
+
logger:
|
| 11 |
+
- class_path: pytorch_lightning.loggers.WandbLogger
|
| 12 |
+
init_args:
|
| 13 |
+
name: UnetNhaBe
|
| 14 |
+
save_dir: checkpoint/Unet/wandb_logs
|
| 15 |
+
version: null
|
| 16 |
+
offline: false
|
| 17 |
+
dir: null
|
| 18 |
+
id: null
|
| 19 |
+
anonymous: null
|
| 20 |
+
project: NhaBe
|
| 21 |
+
log_model: false
|
| 22 |
+
experiment: null
|
| 23 |
+
prefix: ''
|
| 24 |
+
checkpoint_name: null
|
| 25 |
+
entity: null
|
| 26 |
+
notes: null
|
| 27 |
+
tags: null
|
| 28 |
+
config: null
|
| 29 |
+
config_exclude_keys: null
|
| 30 |
+
config_include_keys: null
|
| 31 |
+
allow_val_change: null
|
| 32 |
+
group: null
|
| 33 |
+
job_type: null
|
| 34 |
+
mode: null
|
| 35 |
+
force: null
|
| 36 |
+
reinit: null
|
| 37 |
+
resume: null
|
| 38 |
+
resume_from: null
|
| 39 |
+
fork_from: null
|
| 40 |
+
save_code: null
|
| 41 |
+
tensorboard: null
|
| 42 |
+
sync_tensorboard: null
|
| 43 |
+
monitor_gym: null
|
| 44 |
+
settings: null
|
| 45 |
+
- class_path: pytorch_lightning.loggers.CSVLogger
|
| 46 |
+
init_args:
|
| 47 |
+
save_dir: checkpoint/Unet/csv_logs
|
| 48 |
+
name: null
|
| 49 |
+
version: null
|
| 50 |
+
prefix: ''
|
| 51 |
+
flush_logs_every_n_steps: 100
|
| 52 |
+
callbacks:
|
| 53 |
+
- class_path: pytorch_lightning.callbacks.LearningRateMonitor
|
| 54 |
+
init_args:
|
| 55 |
+
logging_interval: step
|
| 56 |
+
log_momentum: false
|
| 57 |
+
log_weight_decay: false
|
| 58 |
+
- class_path: pytorch_lightning.callbacks.ModelCheckpoint
|
| 59 |
+
init_args:
|
| 60 |
+
dirpath: checkpoint/Unet/checkpoints
|
| 61 |
+
filename: epoch_{epoch:03d}
|
| 62 |
+
monitor: val/mse
|
| 63 |
+
verbose: false
|
| 64 |
+
save_last: true
|
| 65 |
+
save_top_k: 1
|
| 66 |
+
save_weights_only: false
|
| 67 |
+
mode: min
|
| 68 |
+
auto_insert_metric_name: false
|
| 69 |
+
every_n_train_steps: null
|
| 70 |
+
train_time_interval: null
|
| 71 |
+
every_n_epochs: null
|
| 72 |
+
save_on_train_epoch_end: null
|
| 73 |
+
enable_version_counter: true
|
| 74 |
+
- class_path: pytorch_lightning.callbacks.EarlyStopping
|
| 75 |
+
init_args:
|
| 76 |
+
monitor: val/mse
|
| 77 |
+
min_delta: 0.0
|
| 78 |
+
patience: 10
|
| 79 |
+
verbose: false
|
| 80 |
+
mode: min
|
| 81 |
+
strict: true
|
| 82 |
+
check_finite: true
|
| 83 |
+
stopping_threshold: null
|
| 84 |
+
divergence_threshold: null
|
| 85 |
+
check_on_train_epoch_end: null
|
| 86 |
+
log_rank_zero_only: false
|
| 87 |
+
- class_path: pytorch_lightning.callbacks.RichModelSummary
|
| 88 |
+
init_args:
|
| 89 |
+
max_depth: -1
|
| 90 |
+
fast_dev_run: false
|
| 91 |
+
max_epochs: 100
|
| 92 |
+
min_epochs: 1
|
| 93 |
+
max_steps: -1
|
| 94 |
+
min_steps: null
|
| 95 |
+
max_time: null
|
| 96 |
+
limit_train_batches: null
|
| 97 |
+
limit_val_batches: null
|
| 98 |
+
limit_test_batches: null
|
| 99 |
+
limit_predict_batches: null
|
| 100 |
+
overfit_batches: 0.0
|
| 101 |
+
val_check_interval: null
|
| 102 |
+
check_val_every_n_epoch: 1
|
| 103 |
+
num_sanity_val_steps: null
|
| 104 |
+
log_every_n_steps: null
|
| 105 |
+
enable_checkpointing: true
|
| 106 |
+
enable_progress_bar: true
|
| 107 |
+
enable_model_summary: null
|
| 108 |
+
accumulate_grad_batches: 1
|
| 109 |
+
gradient_clip_val: null
|
| 110 |
+
gradient_clip_algorithm: null
|
| 111 |
+
deterministic: null
|
| 112 |
+
benchmark: null
|
| 113 |
+
inference_mode: true
|
| 114 |
+
use_distributed_sampler: true
|
| 115 |
+
profiler: null
|
| 116 |
+
detect_anomaly: false
|
| 117 |
+
barebones: false
|
| 118 |
+
plugins: null
|
| 119 |
+
sync_batchnorm: true
|
| 120 |
+
reload_dataloaders_every_n_epochs: 0
|
| 121 |
+
default_root_dir: checkpoint/Unet
|
| 122 |
+
model_registry: null
|
| 123 |
+
model:
|
| 124 |
+
net:
|
| 125 |
+
class_path: arch.Network
|
| 126 |
+
init_args:
|
| 127 |
+
model_type: Unet
|
| 128 |
+
rad_channel: 1
|
| 129 |
+
sat_channel: 1
|
| 130 |
+
rad_size: 400
|
| 131 |
+
sat_size: 25
|
| 132 |
+
pretrained_path: ''
|
| 133 |
+
lr: 0.0005
|
| 134 |
+
beta_1: 0.9
|
| 135 |
+
beta_2: 0.99
|
| 136 |
+
weight_decay: 1.0e-05
|
| 137 |
+
warmup_epochs: 10
|
| 138 |
+
max_epochs: 50
|
| 139 |
+
warmup_start_lr: 1.0e-08
|
| 140 |
+
eta_min: 1.0e-08
|
| 141 |
+
data:
|
| 142 |
+
dir_data: /data/weather2025/NhaBe/
|
| 143 |
+
batch_size: 1
|
| 144 |
+
hours_predicted: 3
|
| 145 |
+
num_workers: 4
|
| 146 |
+
pin_memory: false
|
| 147 |
+
time_points_rad: 1
|
| 148 |
+
time_points_sat: 1
|
| 149 |
+
sat_inp_vars: total_precipitation
|
| 150 |
+
sat_out_vars: total_precipitation
|
| 151 |
+
sat_size: 25
|
| 152 |
+
rad_inp_vars: precipitation
|
| 153 |
+
rad_out_vars: precipitation
|
| 154 |
+
rad_size: 400
|
| 155 |
+
ablation: 'no'
|
| 156 |
+
optimizer: null
|
| 157 |
+
lr_scheduler: null
|
checkpoint/Unet/wandb_logs/wandb/debug-internal.log
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"time":"2025-06-17T09:05:28.179242652Z","level":"INFO","msg":"stream: starting","core version":"0.20.1","symlink path":"checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/logs/debug-core.log"}
|
| 2 |
+
{"time":"2025-06-17T09:05:29.423278937Z","level":"INFO","msg":"stream: created new stream","id":"0nx0l2dh"}
|
| 3 |
+
{"time":"2025-06-17T09:05:29.423321777Z","level":"INFO","msg":"stream: started","id":"0nx0l2dh"}
|
| 4 |
+
{"time":"2025-06-17T09:05:29.423393558Z","level":"INFO","msg":"sender: started","stream_id":"0nx0l2dh"}
|
| 5 |
+
{"time":"2025-06-17T09:05:29.423393088Z","level":"INFO","msg":"writer: Do: started","stream_id":"0nx0l2dh"}
|
| 6 |
+
{"time":"2025-06-17T09:05:29.423465179Z","level":"INFO","msg":"handler: started","stream_id":"0nx0l2dh"}
|
| 7 |
+
{"time":"2025-06-17T09:05:30.100696875Z","level":"INFO","msg":"Starting system monitor"}
|
checkpoint/Unet/wandb_logs/wandb/debug.log
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_setup.py:_flush():81] Current SDK version is 0.20.1
|
| 2 |
+
2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_setup.py:_flush():81] Configure stats pid to 1311468
|
| 3 |
+
2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_setup.py:_flush():81] Loading settings from /home/radaric/.config/wandb/settings
|
| 4 |
+
2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_setup.py:_flush():81] Loading settings from /home/radaric/weather_forecast/Unet/wandb/settings
|
| 5 |
+
2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_setup.py:_flush():81] Loading settings from environment variables
|
| 6 |
+
2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_init.py:setup_run_log_directory():703] Logging user logs to checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/logs/debug.log
|
| 7 |
+
2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_init.py:setup_run_log_directory():704] Logging internal logs to checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/logs/debug-internal.log
|
| 8 |
+
2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_init.py:init():831] calling init triggers
|
| 9 |
+
2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_init.py:init():836] wandb.init called with sweep_config: {}
|
| 10 |
+
config: {'_wandb': {}}
|
| 11 |
+
2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_init.py:init():872] starting backend
|
| 12 |
+
2025-06-17 09:05:28,169 INFO MainThread:1311468 [wandb_init.py:init():875] sending inform_init request
|
| 13 |
+
2025-06-17 09:05:28,174 INFO MainThread:1311468 [wandb_init.py:init():883] backend started and connected
|
| 14 |
+
2025-06-17 09:05:28,175 INFO MainThread:1311468 [wandb_init.py:init():956] updated telemetry
|
| 15 |
+
2025-06-17 09:05:28,175 INFO MainThread:1311468 [wandb_init.py:init():980] communicating run to backend with 90.0 second timeout
|
| 16 |
+
2025-06-17 09:05:30,098 INFO MainThread:1311468 [wandb_init.py:init():1032] starting run threads in backend
|
| 17 |
+
2025-06-17 09:05:30,175 INFO MainThread:1311468 [wandb_run.py:_console_start():2453] atexit reg
|
| 18 |
+
2025-06-17 09:05:30,175 INFO MainThread:1311468 [wandb_run.py:_redirect():2301] redirect: wrap_raw
|
| 19 |
+
2025-06-17 09:05:30,175 INFO MainThread:1311468 [wandb_run.py:_redirect():2370] Wrapping output streams.
|
| 20 |
+
2025-06-17 09:05:30,175 INFO MainThread:1311468 [wandb_run.py:_redirect():2393] Redirects installed.
|
| 21 |
+
2025-06-17 09:05:30,177 INFO MainThread:1311468 [wandb_init.py:init():1078] run started, returning control to user process
|
| 22 |
+
2025-06-17 09:05:31,151 INFO MainThread:1311468 [wandb_run.py:_config_callback():1358] config_cb None None {'pretrained_path': '', 'lr': 0.0005, 'beta_1': 0.9, 'beta_2': 0.99, 'weight_decay': 1e-05, 'warmup_epochs': 10, 'max_epochs': 50, 'warmup_start_lr': 1e-08, 'eta_min': 1e-08, '_instantiator': 'pytorch_lightning.cli.instantiate_module', 'dir_data': '/data/weather2025/NhaBe/', 'batch_size': 1, 'hours_predicted': 3, 'num_workers': 4, 'pin_memory': False, 'time_points_rad': 1, 'time_points_sat': 1, 'sat_inp_vars': 'total_precipitation', 'sat_out_vars': 'total_precipitation', 'sat_size': 25, 'rad_inp_vars': 'precipitation', 'rad_out_vars': 'precipitation', 'rad_size': 400, 'ablation': 'no'}
|
checkpoint/Unet/wandb_logs/wandb/run-20250617_085744-m5tg7yyl/files/output.log
ADDED
|
File without changes
|
checkpoint/Unet/wandb_logs/wandb/run-20250617_085744-m5tg7yyl/files/wandb-summary.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"_wandb":{"runtime":0}}
|
checkpoint/Unet/wandb_logs/wandb/run-20250617_085744-m5tg7yyl/logs/debug-core.log
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"time":"2025-06-17T08:57:44.288260722Z","level":"INFO","msg":"main: starting server","port-filename":"/tmp/tmpbpuchipv/port-1289333.txt","pid":1289333,"log-level":0,"disable-analytics":false,"shutdown-on-parent-exit":false,"enable-dcgm-profiling":false}
|
| 2 |
+
{"time":"2025-06-17T08:57:44.289762517Z","level":"INFO","msg":"Will exit if parent process dies.","ppid":1289333}
|
| 3 |
+
{"time":"2025-06-17T08:57:44.289701246Z","level":"INFO","msg":"server is running","addr":{"IP":"127.0.0.1","Port":43767,"Zone":""}}
|
| 4 |
+
{"time":"2025-06-17T08:57:44.468360629Z","level":"INFO","msg":"connection: ManageConnectionData: new connection created","id":"127.0.0.1:33480"}
|
| 5 |
+
{"time":"2025-06-17T08:57:44.478126346Z","level":"INFO","msg":"handleInformInit: received","streamId":"m5tg7yyl","id":"127.0.0.1:33480"}
|
| 6 |
+
{"time":"2025-06-17T08:57:45.013693012Z","level":"INFO","msg":"handleInformInit: stream started","streamId":"m5tg7yyl","id":"127.0.0.1:33480"}
|
| 7 |
+
{"time":"2025-06-17T08:57:46.227115814Z","level":"INFO","msg":"handleInformTeardown: server teardown initiated","id":"127.0.0.1:33480"}
|
| 8 |
+
{"time":"2025-06-17T08:57:46.227331796Z","level":"INFO","msg":"server is shutting down"}
|
| 9 |
+
{"time":"2025-06-17T08:57:46.227313783Z","level":"INFO","msg":"connection: closing","id":"127.0.0.1:33480"}
|
| 10 |
+
{"time":"2025-06-17T08:57:46.227453186Z","level":"INFO","msg":"connection: closed successfully","id":"127.0.0.1:33480"}
|
| 11 |
+
{"time":"2025-06-17T08:57:46.48785785Z","level":"INFO","msg":"handleInformTeardown: server shutdown complete","id":"127.0.0.1:33480"}
|
| 12 |
+
{"time":"2025-06-17T08:57:46.487909579Z","level":"INFO","msg":"connection: ManageConnectionData: connection closed","id":"127.0.0.1:33480"}
|
| 13 |
+
{"time":"2025-06-17T08:57:46.487925552Z","level":"INFO","msg":"server is closed"}
|
checkpoint/Unet/wandb_logs/wandb/run-20250617_085744-m5tg7yyl/logs/debug-internal.log
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"time":"2025-06-17T08:57:44.478779812Z","level":"INFO","msg":"stream: starting","core version":"0.20.1","symlink path":"checkpoint/Unet/wandb_logs/wandb/run-20250617_085744-m5tg7yyl/logs/debug-core.log"}
|
| 2 |
+
{"time":"2025-06-17T08:57:45.013625369Z","level":"INFO","msg":"stream: created new stream","id":"m5tg7yyl"}
|
| 3 |
+
{"time":"2025-06-17T08:57:45.013682966Z","level":"INFO","msg":"stream: started","id":"m5tg7yyl"}
|
| 4 |
+
{"time":"2025-06-17T08:57:45.013709365Z","level":"INFO","msg":"handler: started","stream_id":"m5tg7yyl"}
|
| 5 |
+
{"time":"2025-06-17T08:57:45.013744481Z","level":"INFO","msg":"sender: started","stream_id":"m5tg7yyl"}
|
| 6 |
+
{"time":"2025-06-17T08:57:45.013733645Z","level":"INFO","msg":"writer: Do: started","stream_id":"m5tg7yyl"}
|
| 7 |
+
{"time":"2025-06-17T08:57:45.224225022Z","level":"ERROR","msg":"HTTP error","status":403,"method":"POST","url":"https://api.wandb.ai/graphql"}
|
| 8 |
+
{"time":"2025-06-17T08:57:45.22437671Z","level":"ERROR","msg":"runupserter: failed to init run","error":"returned error 403: {\"data\":{\"upsertBucket\":null},\"errors\":[{\"message\":\"permission denied\",\"path\":[\"upsertBucket\"],\"extensions\":{\"code\":\"PERMISSION_ERROR\"}}]}"}
|
| 9 |
+
{"time":"2025-06-17T08:57:46.227328327Z","level":"INFO","msg":"stream: closing","id":"m5tg7yyl"}
|
| 10 |
+
{"time":"2025-06-17T08:57:46.227825345Z","level":"ERROR","msg":"sender: uploadConfigFile: stream: no run"}
|
| 11 |
+
{"time":"2025-06-17T08:57:46.486865753Z","level":"ERROR","msg":"HTTP error","status":404,"method":"POST","url":"https://api.wandb.ai/graphql"}
|
| 12 |
+
{"time":"2025-06-17T08:57:46.486986554Z","level":"ERROR","msg":"runfiles: CreateRunFiles returned error: returned error 404: {\"data\":{\"createRunFiles\":null},\"errors\":[{\"message\":\"project vinh-bui0512-hcmut/NhaBe not found during createRunFiles\",\"path\":[\"createRunFiles\"]}]}"}
|
| 13 |
+
{"time":"2025-06-17T08:57:46.487641258Z","level":"INFO","msg":"fileTransfer: Close: file transfer manager closed"}
|
| 14 |
+
{"time":"2025-06-17T08:57:46.487699694Z","level":"INFO","msg":"handler: closed","stream_id":"m5tg7yyl"}
|
| 15 |
+
{"time":"2025-06-17T08:57:46.487714658Z","level":"INFO","msg":"writer: Close: closed","stream_id":"m5tg7yyl"}
|
| 16 |
+
{"time":"2025-06-17T08:57:46.487745625Z","level":"INFO","msg":"sender: closed","stream_id":"m5tg7yyl"}
|
| 17 |
+
{"time":"2025-06-17T08:57:46.487775923Z","level":"INFO","msg":"stream: closed","id":"m5tg7yyl"}
|
checkpoint/Unet/wandb_logs/wandb/run-20250617_085744-m5tg7yyl/logs/debug.log
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2025-06-17 08:57:44,256 INFO MainThread:1289333 [wandb_setup.py:_flush():81] Current SDK version is 0.20.1
|
| 2 |
+
2025-06-17 08:57:44,256 INFO MainThread:1289333 [wandb_setup.py:_flush():81] Configure stats pid to 1289333
|
| 3 |
+
2025-06-17 08:57:44,256 INFO MainThread:1289333 [wandb_setup.py:_flush():81] Loading settings from /home/radaric/.config/wandb/settings
|
| 4 |
+
2025-06-17 08:57:44,256 INFO MainThread:1289333 [wandb_setup.py:_flush():81] Loading settings from /home/radaric/weather_forecast/Unet/wandb/settings
|
| 5 |
+
2025-06-17 08:57:44,256 INFO MainThread:1289333 [wandb_setup.py:_flush():81] Loading settings from environment variables
|
| 6 |
+
2025-06-17 08:57:44,256 INFO MainThread:1289333 [wandb_init.py:setup_run_log_directory():703] Logging user logs to checkpoint/Unet/wandb_logs/wandb/run-20250617_085744-m5tg7yyl/logs/debug.log
|
| 7 |
+
2025-06-17 08:57:44,256 INFO MainThread:1289333 [wandb_init.py:setup_run_log_directory():704] Logging internal logs to checkpoint/Unet/wandb_logs/wandb/run-20250617_085744-m5tg7yyl/logs/debug-internal.log
|
| 8 |
+
2025-06-17 08:57:44,256 INFO MainThread:1289333 [wandb_init.py:init():831] calling init triggers
|
| 9 |
+
2025-06-17 08:57:44,256 INFO MainThread:1289333 [wandb_init.py:init():836] wandb.init called with sweep_config: {}
|
| 10 |
+
config: {'_wandb': {}}
|
| 11 |
+
2025-06-17 08:57:44,256 INFO MainThread:1289333 [wandb_init.py:init():872] starting backend
|
| 12 |
+
2025-06-17 08:57:44,468 INFO MainThread:1289333 [wandb_init.py:init():875] sending inform_init request
|
| 13 |
+
2025-06-17 08:57:44,473 INFO MainThread:1289333 [wandb_init.py:init():883] backend started and connected
|
| 14 |
+
2025-06-17 08:57:44,475 INFO MainThread:1289333 [wandb_init.py:init():956] updated telemetry
|
| 15 |
+
2025-06-17 08:57:44,476 INFO MainThread:1289333 [wandb_init.py:init():980] communicating run to backend with 90.0 second timeout
|
checkpoint/Unet/wandb_logs/wandb/run-20250617_085744-m5tg7yyl/run-m5tg7yyl.wandb
ADDED
|
Binary file (366 Bytes). View file
|
|
|
checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/files/output.log
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Number of train samples: 31462
|
| 2 |
+
Number of test samples: 8077
|
| 3 |
+
Number of val samples: 1398
|
| 4 |
+
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6]
|
| 5 |
+
┏━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┓
|
| 6 |
+
┃[1;35m [0m[1;35m [0m[1;35m [0m┃[1;35m [0m[1;35mName [0m[1;35m [0m┃[1;35m [0m[1;35mType [0m[1;35m [0m┃[1;35m [0m[1;35mParams[0m[1;35m [0m┃[1;35m [0m[1;35mMode [0m[1;35m [0m┃
|
| 7 |
+
┡━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━┩
|
| 8 |
+
│[2m [0m[2m0 [0m[2m [0m│ net │ Network │ 30.0 K │ train │
|
| 9 |
+
│[2m [0m[2m1 [0m[2m [0m│ net.net │ Unet │ 30.0 K │ train │
|
| 10 |
+
│[2m [0m[2m2 [0m[2m [0m│ net.net.encoder_blocks │ ModuleList │ 4.8 K │ train │
|
| 11 |
+
│[2m [0m[2m3 [0m[2m [0m│ net.net.encoder_blocks.0 │ ConvBlock │ 66 │ train │
|
| 12 |
+
│[2m [0m[2m4 [0m[2m [0m│ net.net.encoder_blocks.0.conv │ Sequential │ 66 │ train │
|
| 13 |
+
│[2m [0m[2m5 [0m[2m [0m│ net.net.encoder_blocks.0.conv.0 │ Conv2d │ 20 │ train │
|
| 14 |
+
│[2m [0m[2m6 [0m[2m [0m│ net.net.encoder_blocks.0.conv.1 │ BatchNorm2d │ 4 │ train │
|
| 15 |
+
│[2m [0m[2m7 [0m[2m [0m│ net.net.encoder_blocks.0.conv.2 │ ReLU │ 0 │ train │
|
| 16 |
+
│[2m [0m[2m8 [0m[2m [0m│ net.net.encoder_blocks.0.conv.3 │ Conv2d │ 38 │ train │
|
| 17 |
+
│[2m [0m[2m9 [0m[2m [0m│ net.net.encoder_blocks.0.conv.4 │ BatchNorm2d │ 4 │ train │
|
| 18 |
+
│[2m [0m[2m10 [0m[2m [0m│ net.net.encoder_blocks.0.conv.5 │ ReLU │ 0 │ train │
|
| 19 |
+
│[2m [0m[2m11 [0m[2m [0m│ net.net.encoder_blocks.1 │ ConvBlock │ 240 │ train │
|
| 20 |
+
│[2m [0m[2m12 [0m[2m [0m│ net.net.encoder_blocks.1.conv │ Sequential │ 240 │ train │
|
| 21 |
+
│[2m [0m[2m13 [0m[2m [0m│ net.net.encoder_blocks.1.conv.0 │ Conv2d │ 76 │ train │
|
| 22 |
+
│[2m [0m[2m14 [0m[2m [0m│ net.net.encoder_blocks.1.conv.1 │ BatchNorm2d │ 8 │ train │
|
| 23 |
+
│[2m [0m[2m15 [0m[2m [0m│ net.net.encoder_blocks.1.conv.2 │ ReLU │ 0 │ train │
|
| 24 |
+
│[2m [0m[2m16 [0m[2m [0m│ net.net.encoder_blocks.1.conv.3 │ Conv2d │ 148 │ train │
|
| 25 |
+
│[2m [0m[2m17 [0m[2m [0m│ net.net.encoder_blocks.1.conv.4 │ BatchNorm2d │ 8 │ train │
|
| 26 |
+
│[2m [0m[2m18 [0m[2m [0m│ net.net.encoder_blocks.1.conv.5 │ ReLU │ 0 │ train │
|
| 27 |
+
│[2m [0m[2m19 [0m[2m [0m│ net.net.encoder_blocks.2 │ ConvBlock │ 912 │ train │
|
| 28 |
+
│[2m [0m[2m20 [0m[2m [0m│ net.net.encoder_blocks.2.conv │ Sequential │ 912 │ train │
|
| 29 |
+
│[2m [0m[2m21 [0m[2m [0m│ net.net.encoder_blocks.2.conv.0 │ Conv2d │ 296 │ train │
|
| 30 |
+
│[2m [0m[2m22 [0m[2m [0m│ net.net.encoder_blocks.2.conv.1 │ BatchNorm2d │ 16 │ train │
|
| 31 |
+
│[2m [0m[2m23 [0m[2m [0m│ net.net.encoder_blocks.2.conv.2 │ ReLU │ 0 │ train │
|
| 32 |
+
│[2m [0m[2m24 [0m[2m [0m│ net.net.encoder_blocks.2.conv.3 │ Conv2d │ 584 │ train │
|
| 33 |
+
│[2m [0m[2m25 [0m[2m [0m│ net.net.encoder_blocks.2.conv.4 │ BatchNorm2d │ 16 │ train │
|
| 34 |
+
│[2m [0m[2m26 [0m[2m [0m│ net.net.encoder_blocks.2.conv.5 │ ReLU │ 0 │ train │
|
| 35 |
+
│[2m [0m[2m27 [0m[2m [0m│ net.net.encoder_blocks.3 │ ConvBlock │ 3.6 K │ train │
|
| 36 |
+
│[2m [0m[2m28 [0m[2m [0m│ net.net.encoder_blocks.3.conv │ Sequential │ 3.6 K │ train │
|
| 37 |
+
│[2m [0m[2m29 [0m[2m [0m│ net.net.encoder_blocks.3.conv.0 │ Conv2d │ 1.2 K │ train │
|
| 38 |
+
│[2m [0m[2m30 [0m[2m [0m│ net.net.encoder_blocks.3.conv.1 │ BatchNorm2d │ 32 │ train │
|
| 39 |
+
│[2m [0m[2m31 [0m[2m [0m│ net.net.encoder_blocks.3.conv.2 │ ReLU │ 0 │ train │
|
| 40 |
+
│[2m [0m[2m32 [0m[2m [0m│ net.net.encoder_blocks.3.conv.3 │ Conv2d │ 2.3 K │ train │
|
| 41 |
+
│[2m [0m[2m33 [0m[2m [0m│ net.net.encoder_blocks.3.conv.4 │ BatchNorm2d │ 32 │ train │
|
| 42 |
+
│[2m [0m[2m34 [0m[2m [0m│ net.net.encoder_blocks.3.conv.5 │ ReLU │ 0 │ train │
|
| 43 |
+
│[2m [0m[2m35 [0m[2m [0m│ net.net.pools │ ModuleList │ 0 │ train │
|
| 44 |
+
│[2m [0m[2m36 [0m[2m [0m│ net.net.pools.0 │ MaxPool2d │ 0 │ train │
|
| 45 |
+
│[2m [0m[2m37 [0m[2m [0m│ net.net.pools.1 │ MaxPool2d │ 0 │ train │
|
| 46 |
+
│[2m [0m[2m38 [0m[2m [0m│ net.net.pools.2 │ MaxPool2d │ 0 │ train │
|
| 47 |
+
│[2m [0m[2m39 [0m[2m [0m│ net.net.pools.3 │ MaxPool2d │ 0 │ train │
|
| 48 |
+
│[2m [0m[2m40 [0m[2m [0m│ net.net.mid_conv_1 │ single_conv │ 2.4 K │ train │
|
| 49 |
+
│[2m [0m[2m41 [0m[2m [0m│ net.net.mid_conv_1.conv │ Sequential │ 2.4 K │ train │
|
| 50 |
+
│[2m [0m[2m42 [0m[2m [0m│ net.net.mid_conv_1.conv.0 │ Conv2d │ 2.3 K │ train │
|
| 51 |
+
│[2m [0m[2m43 [0m[2m [0m│ net.net.mid_conv_1.conv.1 │ BatchNorm2d │ 32 │ train │
|
| 52 |
+
│[2m [0m[2m44 [0m[2m [0m│ net.net.mid_conv_1.conv.2 │ ReLU │ 0 │ train │
|
| 53 |
+
│[2m [0m[2m45 [0m[2m [0m│ net.net.mid_conv_2 │ single_conv │ 192 │ train │
|
| 54 |
+
│[2m [0m[2m46 [0m[2m [0m│ net.net.mid_conv_2.conv │ Sequential │ 192 │ train │
|
| 55 |
+
│[2m [0m[2m47 [0m[2m [0m│ net.net.mid_conv_2.conv.0 │ Conv2d │ 160 │ train │
|
| 56 |
+
│[2m [0m[2m48 [0m[2m [0m│ net.net.mid_conv_2.conv.1 │ BatchNorm2d │ 32 │ train │
|
| 57 |
+
│[2m [0m[2m49 [0m[2m [0m│ net.net.mid_conv_2.conv.2 │ ReLU │ 0 │ train │
|
| 58 |
+
│[2m [0m[2m50 [0m[2m [0m│ net.net.mid_merge │ ConvBlock │ 7.0 K │ train │
|
| 59 |
+
│[2m [0m[2m51 [0m[2m [0m│ net.net.mid_merge.conv │ Sequential │ 7.0 K │ train │
|
| 60 |
+
│[2m [0m[2m52 [0m[2m [0m│ net.net.mid_merge.conv.0 │ Conv2d │ 4.6 K │ train │
|
| 61 |
+
│[2m [0m[2m53 [0m[2m [0m│ net.net.mid_merge.conv.1 │ BatchNorm2d │ 32 │ train │
|
| 62 |
+
│[2m [0m[2m54 [0m[2m [0m│ net.net.mid_merge.conv.2 │ ReLU │ 0 │ train │
|
| 63 |
+
│[2m [0m[2m55 [0m[2m [0m│ net.net.mid_merge.conv.3 │ Conv2d │ 2.3 K │ train │
|
| 64 |
+
│[2m [0m[2m56 [0m[2m [0m│ net.net.mid_merge.conv.4 │ BatchNorm2d │ 32 │ train │
|
| 65 |
+
│[2m [0m[2m57 [0m[2m [0m│ net.net.mid_merge.conv.5 │ ReLU │ 0 │ train │
|
| 66 |
+
│[2m [0m[2m58 [0m[2m [0m│ net.net.up_convs │ ModuleList │ 6.2 K │ train │
|
| 67 |
+
│[2m [0m[2m59 [0m[2m [0m│ net.net.up_convs.0 │ UpConv │ 4.7 K │ train │
|
| 68 |
+
│[2m [0m[2m60 [0m[2m [0m│ net.net.up_convs.0.up │ Sequential │ 4.7 K │ train │
|
| 69 |
+
│[2m [0m[2m61 [0m[2m [0m│ net.net.up_convs.0.up.0 │ Upsample │ 0 │ train │
|
| 70 |
+
│[2m [0m[2m62 [0m[2m [0m│ net.net.up_convs.0.up.1 │ Conv2d │ 4.6 K │ train │
|
| 71 |
+
│[2m [0m[2m63 [0m[2m [0m│ net.net.up_convs.0.up.2 │ BatchNorm2d │ 32 │ train │
|
| 72 |
+
│[2m [0m[2m64 [0m[2m [0m│ net.net.up_convs.0.up.3 │ ReLU │ 0 │ train │
|
| 73 |
+
│[2m [0m[2m65 [0m[2m [0m│ net.net.up_convs.1 │ UpConv │ 1.2 K │ train │
|
| 74 |
+
│[2m [0m[2m66 [0m[2m [0m│ net.net.up_convs.1.up │ Sequential │ 1.2 K │ train │
|
| 75 |
+
│[2m [0m[2m67 [0m[2m [0m│ net.net.up_convs.1.up.0 │ Upsample │ 0 │ train │
|
| 76 |
+
│[2m [0m[2m68 [0m[2m [0m│ net.net.up_convs.1.up.1 │ Conv2d │ 1.2 K │ train │
|
| 77 |
+
│[2m [0m[2m69 [0m[2m [0m│ net.net.up_convs.1.up.2 │ BatchNorm2d │ 16 │ train │
|
| 78 |
+
│[2m [0m[2m70 [0m[2m [0m│ net.net.up_convs.1.up.3 │ ReLU │ 0 │ train │
|
| 79 |
+
│[2m [0m[2m71 [0m[2m [0m│ net.net.up_convs.2 │ UpConv │ 300 │ train │
|
| 80 |
+
│[2m [0m[2m72 [0m[2m [0m│ net.net.up_convs.2.up │ Sequential │ 300 │ train │
|
| 81 |
+
│[2m [0m[2m73 [0m[2m [0m│ net.net.up_convs.2.up.0 │ Upsample │ 0 │ train │
|
| 82 |
+
│[2m [0m[2m74 [0m[2m [0m│ net.net.up_convs.2.up.1 │ Conv2d │ 292 │ train │
|
| 83 |
+
│[2m [0m[2m75 [0m[2m [0m│ net.net.up_convs.2.up.2 │ BatchNorm2d │ 8 │ train │
|
| 84 |
+
│[2m [0m[2m76 [0m[2m [0m│ net.net.up_convs.2.up.3 │ ReLU │ 0 │ train │
|
| 85 |
+
│[2m [0m[2m77 [0m[2m [0m│ net.net.up_convs.3 │ UpConv │ 78 │ train │
|
| 86 |
+
│[2m [0m[2m78 [0m[2m [0m│ net.net.up_convs.3.up │ Sequential │ 78 │ train │
|
| 87 |
+
│[2m [0m[2m79 [0m[2m [0m│ net.net.up_convs.3.up.0 │ Upsample │ 0 │ train │
|
| 88 |
+
│[2m [0m[2m80 [0m[2m [0m│ net.net.up_convs.3.up.1 │ Conv2d │ 74 │ train │
|
| 89 |
+
│[2m [0m[2m81 [0m[2m [0m│ net.net.up_convs.3.up.2 │ BatchNorm2d │ 4 │ train │
|
| 90 |
+
│[2m [0m[2m82 [0m[2m [0m│ net.net.up_convs.3.up.3 │ ReLU │ 0 │ train │
|
| 91 |
+
│[2m [0m[2m83 [0m[2m [0m│ net.net.decoder_blocks │ ModuleList │ 9.4 K │ train │
|
| 92 |
+
│[2m [0m[2m84 [0m[2m [0m│ net.net.decoder_blocks.0 │ ConvBlock │ 7.0 K │ train │
|
| 93 |
+
│[2m [0m[2m85 [0m[2m [0m│ net.net.decoder_blocks.0.conv │ Sequential │ 7.0 K │ train │
|
| 94 |
+
│[2m [0m[2m86 [0m[2m [0m│ net.net.decoder_blocks.0.conv.0 │ Conv2d │ 4.6 K │ train │
|
| 95 |
+
│[2m [0m[2m87 [0m[2m [0m│ net.net.decoder_blocks.0.conv.1 │ BatchNorm2d │ 32 │ train │
|
| 96 |
+
│[2m [0m[2m88 [0m[2m [0m│ net.net.decoder_blocks.0.conv.2 │ ReLU │ 0 │ train │
|
| 97 |
+
│[2m [0m[2m89 [0m[2m [0m│ net.net.decoder_blocks.0.conv.3 │ Conv2d │ 2.3 K │ train │
|
| 98 |
+
│[2m [0m[2m90 [0m[2m [0m│ net.net.decoder_blocks.0.conv.4 │ BatchNorm2d │ 32 │ train │
|
| 99 |
+
│[2m [0m[2m91 [0m[2m [0m│ net.net.decoder_blocks.0.conv.5 │ ReLU │ 0 │ train │
|
| 100 |
+
│[2m [0m[2m92 [0m[2m [0m│ net.net.decoder_blocks.1 │ ConvBlock │ 1.8 K │ train │
|
| 101 |
+
│[2m [0m[2m93 [0m[2m [0m│ net.net.decoder_blocks.1.conv │ Sequential │ 1.8 K │ train │
|
| 102 |
+
│[2m [0m[2m94 [0m[2m [0m│ net.net.decoder_blocks.1.conv.0 │ Conv2d │ 1.2 K │ train │
|
| 103 |
+
│[2m [0m[2m95 [0m[2m [0m│ net.net.decoder_blocks.1.conv.1 │ BatchNorm2d │ 16 │ train │
|
| 104 |
+
│[2m [0m[2m96 [0m[2m [0m│ net.net.decoder_blocks.1.conv.2 │ ReLU │ 0 │ train │
|
| 105 |
+
│[2m [0m[2m97 [0m[2m [0m│ net.net.decoder_blocks.1.conv.3 │ Conv2d │ 584 │ train │
|
| 106 |
+
│[2m [0m[2m98 [0m[2m [0m│ net.net.decoder_blocks.1.conv.4 │ BatchNorm2d │ 16 │ train │
|
| 107 |
+
│[2m [0m[2m99 [0m[2m [0m│ net.net.decoder_blocks.1.conv.5 │ ReLU │ 0 │ train │
|
| 108 |
+
│[2m [0m[2m100[0m[2m [0m│ net.net.decoder_blocks.2 │ ConvBlock │ 456 │ train │
|
| 109 |
+
│[2m [0m[2m101[0m[2m [0m│ net.net.decoder_blocks.2.conv │ Sequential │ 456 │ train │
|
| 110 |
+
│[2m [0m[2m102[0m[2m [0m│ net.net.decoder_blocks.2.conv.0 │ Conv2d │ 292 │ train │
|
| 111 |
+
│[2m [0m[2m103[0m[2m [0m│ net.net.decoder_blocks.2.conv.1 │ BatchNorm2d │ 8 │ train │
|
| 112 |
+
│[2m [0m[2m104[0m[2m [0m│ net.net.decoder_blocks.2.conv.2 │ ReLU │ 0 │ train │
|
| 113 |
+
│[2m [0m[2m105[0m[2m [0m│ net.net.decoder_blocks.2.conv.3 │ Conv2d │ 148 │ train │
|
| 114 |
+
│[2m [0m[2m106[0m[2m [0m│ net.net.decoder_blocks.2.conv.4 │ BatchNorm2d │ 8 │ train │
|
| 115 |
+
│[2m [0m[2m107[0m[2m [0m│ net.net.decoder_blocks.2.conv.5 │ ReLU │ 0 │ train │
|
| 116 |
+
│[2m [0m[2m108[0m[2m [0m│ net.net.decoder_blocks.3 │ ConvBlock │ 120 │ train │
|
| 117 |
+
│[2m [0m[2m109[0m[2m [0m│ net.net.decoder_blocks.3.conv │ Sequential │ 120 │ train │
|
| 118 |
+
│[2m [0m[2m110[0m[2m [0m│ net.net.decoder_blocks.3.conv.0 │ Conv2d │ 74 │ train │
|
| 119 |
+
│[2m [0m[2m111[0m[2m [0m│ net.net.decoder_blocks.3.conv.1 │ BatchNorm2d │ 4 │ train │
|
| 120 |
+
│[2m [0m[2m112[0m[2m [0m│ net.net.decoder_blocks.3.conv.2 │ ReLU │ 0 │ train │
|
| 121 |
+
│[2m [0m[2m113[0m[2m [0m│ net.net.decoder_blocks.3.conv.3 │ Conv2d │ 38 │ train │
|
| 122 |
+
│[2m [0m[2m114[0m[2m [0m│ net.net.decoder_blocks.3.conv.4 │ BatchNorm2d │ 4 │ train │
|
| 123 |
+
│[2m [0m[2m115[0m[2m [0m│ net.net.decoder_blocks.3.conv.5 │ ReLU │ 0 │ train │
|
| 124 |
+
│[2m [0m[2m116[0m[2m [0m│ net.net.final_decoder │ ConvBlock │ 120 │ train │
|
| 125 |
+
│[2m [0m[2m117[0m[2m [0m│ net.net.final_decoder.conv │ Sequential │ 120 │ train │
|
| 126 |
+
│[2m [0m[2m118[0m[2m [0m│ net.net.final_decoder.conv.0 │ Conv2d │ 74 │ train │
|
| 127 |
+
│[2m [0m[2m119[0m[2m [0m│ net.net.final_decoder.conv.1 │ BatchNorm2d │ 4 │ train │
|
| 128 |
+
│[2m [0m[2m120[0m[2m [0m│ net.net.final_decoder.conv.2 │ ReLU │ 0 │ train │
|
| 129 |
+
│[2m [0m[2m121[0m[2m [0m│ net.net.final_decoder.conv.3 │ Conv2d │ 38 │ train │
|
| 130 |
+
│[2m [0m[2m122[0m[2m [0m│ net.net.final_decoder.conv.4 │ BatchNorm2d │ 4 │ train │
|
| 131 |
+
│[2m [0m[2m123[0m[2m [0m│ net.net.final_decoder.conv.5 │ ReLU │ 0 │ train │
|
| 132 |
+
│[2m [0m[2m124[0m[2m [0m│ net.net.out_conv_R │ Conv2d │ 3 │ train │
|
| 133 |
+
│[2m [0m[2m125[0m[2m [0m│ net.net.out_conv_S │ Conv2d │ 17 │ train │
|
| 134 |
+
│[2m [0m[2m126[0m[2m [0m│ rad_denormalization │ Normalize │ 0 │ train │
|
| 135 |
+
│[2m [0m[2m127[0m[2m [0m│ sat_denormalization │ Normalize │ 0 │ train │
|
| 136 |
+
└─────┴─────────────────────────────────┴─────────────┴────────┴───────┘
|
| 137 |
+
[1mTrainable params[0m: 30.0 K
|
| 138 |
+
[1mNon-trainable params[0m: 0
|
| 139 |
+
[1mTotal params[0m: 30.0 K
|
| 140 |
+
[1mTotal estimated model params size (MB)[0m: 0
|
| 141 |
+
[1mModules in train mode[0m: 128
|
| 142 |
+
[1mModules in eval mode[0m: 0
|
| 143 |
+
Epoch 4: 17%|▏| 5205/31462 [02:33<12:54, 33.89it/s, v_num=dh_0, train/rad=0.120, train/sat=2.380, train/mse=2.500, val/rad=1.970, val/sat=1.140, val/mse
|
| 144 |
+
/home/radaric/.conda/envs/unet/lib/python3.10/site-packages/torch/optim/lr_scheduler.py:182: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`. Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
|
| 145 |
+
warnings.warn(
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
Detected KeyboardInterrupt, attempting graceful shutdown ...
|
| 149 |
+
Exception ignored in atexit callback: <function _start_and_connect_service.<locals>.teardown_atexit at 0x7fafa15b5360>
|
| 150 |
+
Traceback (most recent call last):
|
| 151 |
+
File "/home/radaric/.conda/envs/unet/lib/python3.10/site-packages/wandb/sdk/lib/service_connection.py", line 90, in teardown_atexit
|
| 152 |
+
conn.teardown(hooks.exit_code)
|
| 153 |
+
File "/home/radaric/.conda/envs/unet/lib/python3.10/site-packages/wandb/sdk/lib/service_connection.py", line 218, in teardown
|
| 154 |
+
self._router.join()
|
| 155 |
+
File "/home/radaric/.conda/envs/unet/lib/python3.10/site-packages/wandb/sdk/interface/router.py", line 75, in join
|
| 156 |
+
self._thread.join()
|
| 157 |
+
File "/home/radaric/.conda/envs/unet/lib/python3.10/threading.py", line 1096, in join
|
| 158 |
+
self._wait_for_tstate_lock()
|
| 159 |
+
File "/home/radaric/.conda/envs/unet/lib/python3.10/threading.py", line 1116, in _wait_for_tstate_lock
|
| 160 |
+
if lock.acquire(block, timeout):
|
| 161 |
+
KeyboardInterrupt:
|
checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/files/requirements.txt
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
urllib3==2.4.0
|
| 2 |
+
requests==2.32.4
|
| 3 |
+
typing-inspection==0.4.1
|
| 4 |
+
Jinja2==3.1.6
|
| 5 |
+
MarkupSafe==3.0.2
|
| 6 |
+
setuptools==78.1.1
|
| 7 |
+
frozenlist==1.7.0
|
| 8 |
+
aiosignal==1.3.2
|
| 9 |
+
docstring_parser==0.16
|
| 10 |
+
aiohappyeyeballs==2.6.1
|
| 11 |
+
ClimaX==0.3.1
|
| 12 |
+
platformdirs==4.3.8
|
| 13 |
+
async-timeout==5.0.1
|
| 14 |
+
nvidia-cusolver-cu12==11.7.1.2
|
| 15 |
+
protobuf==6.31.1
|
| 16 |
+
charset-normalizer==3.4.2
|
| 17 |
+
attrs==25.3.0
|
| 18 |
+
pip==25.1
|
| 19 |
+
nvidia-cufile-cu12==1.11.1.6
|
| 20 |
+
importlib_resources==6.5.2
|
| 21 |
+
nvidia-nvjitlink-cu12==12.6.85
|
| 22 |
+
numpy==2.2.6
|
| 23 |
+
typeshed_client==2.7.0
|
| 24 |
+
jsonargparse==4.40.0
|
| 25 |
+
nvidia-cusparselt-cu12==0.6.3
|
| 26 |
+
GitPython==3.1.44
|
| 27 |
+
nvidia-cusparse-cu12==12.5.4.2
|
| 28 |
+
mpmath==1.3.0
|
| 29 |
+
pytorch-lightning==2.5.1.post0
|
| 30 |
+
torchvision==0.22.1
|
| 31 |
+
PyYAML==6.0.2
|
| 32 |
+
nvidia-cudnn-cu12==9.5.1.17
|
| 33 |
+
markdown-it-py==3.0.0
|
| 34 |
+
typing_extensions==4.14.0
|
| 35 |
+
smmap==5.0.2
|
| 36 |
+
pydantic_core==2.33.2
|
| 37 |
+
torchsummary==1.5.1
|
| 38 |
+
nvidia-cublas-cu12==12.6.4.1
|
| 39 |
+
FusionModel==0.3.1
|
| 40 |
+
mdurl==0.1.2
|
| 41 |
+
sentry-sdk==2.30.0
|
| 42 |
+
nvidia-curand-cu12==10.3.7.77
|
| 43 |
+
idna==3.10
|
| 44 |
+
triton==3.3.1
|
| 45 |
+
multidict==6.4.4
|
| 46 |
+
Pygments==2.19.1
|
| 47 |
+
nvidia-cuda-cupti-cu12==12.6.80
|
| 48 |
+
tqdm==4.67.1
|
| 49 |
+
psutil==7.0.0
|
| 50 |
+
gitdb==4.0.12
|
| 51 |
+
fsspec==2025.5.1
|
| 52 |
+
pydantic==2.11.6
|
| 53 |
+
sympy==1.14.0
|
| 54 |
+
torchaudio==2.7.1
|
| 55 |
+
nvidia-nccl-cu12==2.26.2
|
| 56 |
+
propcache==0.3.2
|
| 57 |
+
wandb==0.20.1
|
| 58 |
+
filelock==3.18.0
|
| 59 |
+
packaging==25.0
|
| 60 |
+
nvidia-cuda-nvrtc-cu12==12.6.77
|
| 61 |
+
networkx==3.4.2
|
| 62 |
+
aiohttp==3.12.12
|
| 63 |
+
nvidia-cufft-cu12==11.3.0.4
|
| 64 |
+
nvidia-nvtx-cu12==12.6.77
|
| 65 |
+
wheel==0.45.1
|
| 66 |
+
yarl==1.20.1
|
| 67 |
+
certifi==2025.4.26
|
| 68 |
+
click==8.2.1
|
| 69 |
+
nvidia-cuda-runtime-cu12==12.6.77
|
| 70 |
+
rich==14.0.0
|
| 71 |
+
pillow==11.2.1
|
| 72 |
+
setproctitle==1.3.6
|
| 73 |
+
torchmetrics==1.7.3
|
| 74 |
+
lightning-utilities==0.14.3
|
| 75 |
+
torch==2.7.1
|
| 76 |
+
annotated-types==0.7.0
|
| 77 |
+
ClimaX==0.3.1
|
checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/files/wandb-metadata.json
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"os": "Linux-5.4.0-208-generic-x86_64-with-glibc2.31",
|
| 3 |
+
"python": "CPython 3.10.18",
|
| 4 |
+
"startedAt": "2025-06-17T09:05:28.174321Z",
|
| 5 |
+
"args": [
|
| 6 |
+
"--config",
|
| 7 |
+
"configs/Unet.yaml"
|
| 8 |
+
],
|
| 9 |
+
"program": "/home/radaric/weather_forecast/Unet/src/train.py",
|
| 10 |
+
"codePath": "src/train.py",
|
| 11 |
+
"email": "weatherforecast1024hcmut@gmail.com",
|
| 12 |
+
"root": "checkpoint/Unet/wandb_logs",
|
| 13 |
+
"host": "u116613",
|
| 14 |
+
"executable": "/home/radaric/.conda/envs/unet/bin/python",
|
| 15 |
+
"codePathLocal": "src/train.py",
|
| 16 |
+
"cpu_count": 48,
|
| 17 |
+
"cpu_count_logical": 96,
|
| 18 |
+
"gpu": "NVIDIA RTX A6000",
|
| 19 |
+
"gpu_count": 7,
|
| 20 |
+
"disk": {
|
| 21 |
+
"/": {
|
| 22 |
+
"total": "1877998821376",
|
| 23 |
+
"used": "1470173900800"
|
| 24 |
+
}
|
| 25 |
+
},
|
| 26 |
+
"memory": {
|
| 27 |
+
"total": "540953096192"
|
| 28 |
+
},
|
| 29 |
+
"cpu": {
|
| 30 |
+
"count": 48,
|
| 31 |
+
"countLogical": 96
|
| 32 |
+
},
|
| 33 |
+
"gpu_nvidia": [
|
| 34 |
+
{
|
| 35 |
+
"name": "NVIDIA RTX A6000",
|
| 36 |
+
"memoryTotal": "51527024640",
|
| 37 |
+
"cudaCores": 10752,
|
| 38 |
+
"architecture": "Ampere",
|
| 39 |
+
"uuid": "GPU-fb5a2de4-c79a-f2d0-a864-a6271ad28ae6"
|
| 40 |
+
},
|
| 41 |
+
{
|
| 42 |
+
"name": "NVIDIA RTX A6000",
|
| 43 |
+
"memoryTotal": "51527024640",
|
| 44 |
+
"cudaCores": 10752,
|
| 45 |
+
"architecture": "Ampere",
|
| 46 |
+
"uuid": "GPU-1a8c199b-93ca-3fec-6459-a5515bf1b12b"
|
| 47 |
+
},
|
| 48 |
+
{
|
| 49 |
+
"name": "NVIDIA RTX A6000",
|
| 50 |
+
"memoryTotal": "51527024640",
|
| 51 |
+
"cudaCores": 10752,
|
| 52 |
+
"architecture": "Ampere",
|
| 53 |
+
"uuid": "GPU-4d0c0cac-f72d-9dc7-9ac0-60cf8803134b"
|
| 54 |
+
},
|
| 55 |
+
{
|
| 56 |
+
"name": "NVIDIA RTX A6000",
|
| 57 |
+
"memoryTotal": "51527024640",
|
| 58 |
+
"cudaCores": 10752,
|
| 59 |
+
"architecture": "Ampere",
|
| 60 |
+
"uuid": "GPU-2887d599-b7bf-d31f-4425-84fa60413306"
|
| 61 |
+
},
|
| 62 |
+
{
|
| 63 |
+
"name": "NVIDIA RTX A6000",
|
| 64 |
+
"memoryTotal": "51527024640",
|
| 65 |
+
"cudaCores": 10752,
|
| 66 |
+
"architecture": "Ampere",
|
| 67 |
+
"uuid": "GPU-86e7c8f1-cde6-4163-dc15-52cef50545bd"
|
| 68 |
+
},
|
| 69 |
+
{
|
| 70 |
+
"name": "NVIDIA RTX A6000",
|
| 71 |
+
"memoryTotal": "51527024640",
|
| 72 |
+
"cudaCores": 10752,
|
| 73 |
+
"architecture": "Ampere",
|
| 74 |
+
"uuid": "GPU-460d754a-f551-6943-c142-b5b8f2f86236"
|
| 75 |
+
},
|
| 76 |
+
{
|
| 77 |
+
"name": "NVIDIA RTX A6000",
|
| 78 |
+
"memoryTotal": "51527024640",
|
| 79 |
+
"cudaCores": 10752,
|
| 80 |
+
"architecture": "Ampere",
|
| 81 |
+
"uuid": "GPU-553ca63b-335c-4c11-94eb-29c777adb307"
|
| 82 |
+
}
|
| 83 |
+
],
|
| 84 |
+
"cudaVersion": "12.3"
|
| 85 |
+
}
|
checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/logs/debug-core.log
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"time":"2025-06-17T09:05:27.98855733Z","level":"INFO","msg":"main: starting server","port-filename":"/tmp/tmpng05fvru/port-1311468.txt","pid":1311468,"log-level":0,"disable-analytics":false,"shutdown-on-parent-exit":false,"enable-dcgm-profiling":false}
|
| 2 |
+
{"time":"2025-06-17T09:05:27.99003933Z","level":"INFO","msg":"Will exit if parent process dies.","ppid":1311468}
|
| 3 |
+
{"time":"2025-06-17T09:05:27.99004801Z","level":"INFO","msg":"server is running","addr":{"IP":"127.0.0.1","Port":46731,"Zone":""}}
|
| 4 |
+
{"time":"2025-06-17T09:05:28.169034214Z","level":"INFO","msg":"connection: ManageConnectionData: new connection created","id":"127.0.0.1:38220"}
|
| 5 |
+
{"time":"2025-06-17T09:05:28.178996979Z","level":"INFO","msg":"handleInformInit: received","streamId":"0nx0l2dh","id":"127.0.0.1:38220"}
|
| 6 |
+
{"time":"2025-06-17T09:05:29.423327647Z","level":"INFO","msg":"handleInformInit: stream started","streamId":"0nx0l2dh","id":"127.0.0.1:38220"}
|
| 7 |
+
{"time":"2025-06-17T10:10:24.148447187Z","level":"INFO","msg":"Parent process exited, terminating service process."}
|
checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/logs/debug-internal.log
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"time":"2025-06-17T09:05:28.179242652Z","level":"INFO","msg":"stream: starting","core version":"0.20.1","symlink path":"checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/logs/debug-core.log"}
|
| 2 |
+
{"time":"2025-06-17T09:05:29.423278937Z","level":"INFO","msg":"stream: created new stream","id":"0nx0l2dh"}
|
| 3 |
+
{"time":"2025-06-17T09:05:29.423321777Z","level":"INFO","msg":"stream: started","id":"0nx0l2dh"}
|
| 4 |
+
{"time":"2025-06-17T09:05:29.423393558Z","level":"INFO","msg":"sender: started","stream_id":"0nx0l2dh"}
|
| 5 |
+
{"time":"2025-06-17T09:05:29.423393088Z","level":"INFO","msg":"writer: Do: started","stream_id":"0nx0l2dh"}
|
| 6 |
+
{"time":"2025-06-17T09:05:29.423465179Z","level":"INFO","msg":"handler: started","stream_id":"0nx0l2dh"}
|
| 7 |
+
{"time":"2025-06-17T09:05:30.100696875Z","level":"INFO","msg":"Starting system monitor"}
|
checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/logs/debug.log
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_setup.py:_flush():81] Current SDK version is 0.20.1
|
| 2 |
+
2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_setup.py:_flush():81] Configure stats pid to 1311468
|
| 3 |
+
2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_setup.py:_flush():81] Loading settings from /home/radaric/.config/wandb/settings
|
| 4 |
+
2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_setup.py:_flush():81] Loading settings from /home/radaric/weather_forecast/Unet/wandb/settings
|
| 5 |
+
2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_setup.py:_flush():81] Loading settings from environment variables
|
| 6 |
+
2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_init.py:setup_run_log_directory():703] Logging user logs to checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/logs/debug.log
|
| 7 |
+
2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_init.py:setup_run_log_directory():704] Logging internal logs to checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/logs/debug-internal.log
|
| 8 |
+
2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_init.py:init():831] calling init triggers
|
| 9 |
+
2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_init.py:init():836] wandb.init called with sweep_config: {}
|
| 10 |
+
config: {'_wandb': {}}
|
| 11 |
+
2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_init.py:init():872] starting backend
|
| 12 |
+
2025-06-17 09:05:28,169 INFO MainThread:1311468 [wandb_init.py:init():875] sending inform_init request
|
| 13 |
+
2025-06-17 09:05:28,174 INFO MainThread:1311468 [wandb_init.py:init():883] backend started and connected
|
| 14 |
+
2025-06-17 09:05:28,175 INFO MainThread:1311468 [wandb_init.py:init():956] updated telemetry
|
| 15 |
+
2025-06-17 09:05:28,175 INFO MainThread:1311468 [wandb_init.py:init():980] communicating run to backend with 90.0 second timeout
|
| 16 |
+
2025-06-17 09:05:30,098 INFO MainThread:1311468 [wandb_init.py:init():1032] starting run threads in backend
|
| 17 |
+
2025-06-17 09:05:30,175 INFO MainThread:1311468 [wandb_run.py:_console_start():2453] atexit reg
|
| 18 |
+
2025-06-17 09:05:30,175 INFO MainThread:1311468 [wandb_run.py:_redirect():2301] redirect: wrap_raw
|
| 19 |
+
2025-06-17 09:05:30,175 INFO MainThread:1311468 [wandb_run.py:_redirect():2370] Wrapping output streams.
|
| 20 |
+
2025-06-17 09:05:30,175 INFO MainThread:1311468 [wandb_run.py:_redirect():2393] Redirects installed.
|
| 21 |
+
2025-06-17 09:05:30,177 INFO MainThread:1311468 [wandb_init.py:init():1078] run started, returning control to user process
|
| 22 |
+
2025-06-17 09:05:31,151 INFO MainThread:1311468 [wandb_run.py:_config_callback():1358] config_cb None None {'pretrained_path': '', 'lr': 0.0005, 'beta_1': 0.9, 'beta_2': 0.99, 'weight_decay': 1e-05, 'warmup_epochs': 10, 'max_epochs': 50, 'warmup_start_lr': 1e-08, 'eta_min': 1e-08, '_instantiator': 'pytorch_lightning.cli.instantiate_module', 'dir_data': '/data/weather2025/NhaBe/', 'batch_size': 1, 'hours_predicted': 3, 'num_workers': 4, 'pin_memory': False, 'time_points_rad': 1, 'time_points_sat': 1, 'sat_inp_vars': 'total_precipitation', 'sat_out_vars': 'total_precipitation', 'sat_size': 25, 'rad_inp_vars': 'precipitation', 'rad_out_vars': 'precipitation', 'rad_size': 400, 'ablation': 'no'}
|
checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/run-0nx0l2dh.wandb
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:72554d0fd15b4685f86d20808dfa02fa74043afb99ece44f6b8184bd0a6f9bfc
|
| 3 |
+
size 66093056
|
configs/AttR2Unet.yaml
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
seed_everything: 42
|
| 2 |
+
|
| 3 |
+
# ---------------------------- TRAINER -------------------------------------------
|
| 4 |
+
trainer:
|
| 5 |
+
default_root_dir: "checkpoint/AttR2Unet"
|
| 6 |
+
precision: "16-mixed"
|
| 7 |
+
min_epochs: 1
|
| 8 |
+
max_epochs: 100
|
| 9 |
+
accelerator: cuda
|
| 10 |
+
# limit_train_batches: 10
|
| 11 |
+
devices: [6]
|
| 12 |
+
# strategy: ddp
|
| 13 |
+
num_nodes: 1
|
| 14 |
+
enable_progress_bar: true
|
| 15 |
+
sync_batchnorm: True
|
| 16 |
+
enable_checkpointing: True
|
| 17 |
+
# debugging
|
| 18 |
+
fast_dev_run: false
|
| 19 |
+
logger:
|
| 20 |
+
class_path: pytorch_lightning.loggers.CSVLogger
|
| 21 |
+
init_args:
|
| 22 |
+
save_dir: "checkpoint/AttR2Unet/logs"
|
| 23 |
+
name: null
|
| 24 |
+
version: null
|
| 25 |
+
|
| 26 |
+
callbacks:
|
| 27 |
+
- class_path: pytorch_lightning.callbacks.LearningRateMonitor
|
| 28 |
+
init_args:
|
| 29 |
+
logging_interval: "step"
|
| 30 |
+
|
| 31 |
+
- class_path: pytorch_lightning.callbacks.ModelCheckpoint
|
| 32 |
+
init_args:
|
| 33 |
+
dirpath: "checkpoint/AttR2Unet/checkpoints"
|
| 34 |
+
monitor: "val/mse" # name of the logged metric which determines when model is improving
|
| 35 |
+
mode: "min" # "max" means higher metric value is better, can be also "min"
|
| 36 |
+
save_top_k: 1 # save k best models (determined by above metric)
|
| 37 |
+
save_last: True # additionally always save model from last epoch
|
| 38 |
+
verbose: False
|
| 39 |
+
filename: "epoch_{epoch:03d}"
|
| 40 |
+
auto_insert_metric_name: False
|
| 41 |
+
|
| 42 |
+
- class_path: pytorch_lightning.callbacks.EarlyStopping
|
| 43 |
+
init_args:
|
| 44 |
+
monitor: "val/mse" # name of the logged metric which determines when model is improving
|
| 45 |
+
mode: "min" # "max" means higher metric value is better, can be also "min"
|
| 46 |
+
patience: 10 # how many validation epochs of not improving until training stops
|
| 47 |
+
min_delta: 0. # minimum change in the monitored metric needed to qualify as an improvement
|
| 48 |
+
|
| 49 |
+
- class_path: pytorch_lightning.callbacks.RichModelSummary
|
| 50 |
+
init_args:
|
| 51 |
+
max_depth: -1
|
| 52 |
+
|
| 53 |
+
- class_path: pytorch_lightning.callbacks.RichProgressBar
|
| 54 |
+
|
| 55 |
+
# ---------------------------- MODEL -------------------------------------------
|
| 56 |
+
model:
|
| 57 |
+
pretrained_path: ""
|
| 58 |
+
beta_1: 0.9
|
| 59 |
+
beta_2: 0.99
|
| 60 |
+
lr: 5e-4
|
| 61 |
+
weight_decay: 1e-5
|
| 62 |
+
warmup_epochs: 10
|
| 63 |
+
max_epochs: 50
|
| 64 |
+
warmup_start_lr: 1e-8
|
| 65 |
+
eta_min: 1e-8
|
| 66 |
+
net:
|
| 67 |
+
model_type: "AttR2Unet"
|
| 68 |
+
num_channel: 1
|
| 69 |
+
|
| 70 |
+
# ---------------------------- DATA -------------------------------------------
|
| 71 |
+
data:
|
| 72 |
+
dir_data: "/data/data_WF/ablation/ablation_time"
|
| 73 |
+
ablation: "time"
|
| 74 |
+
sat_size: 20
|
| 75 |
+
rad_size: 640
|
| 76 |
+
time_points_rad: 1
|
| 77 |
+
time_points_sat: 1
|
| 78 |
+
sat_inp_vars: ["total_precipitation"]
|
| 79 |
+
sat_out_vars: "total_precipitation"
|
| 80 |
+
rad_inp_vars: ["precipitation"]
|
| 81 |
+
rad_out_vars: "precipitation"
|
| 82 |
+
hours_predicted: 3
|
| 83 |
+
batch_size: 32
|
| 84 |
+
num_workers: 4
|
| 85 |
+
pin_memory: False
|
| 86 |
+
|
configs/AttUnet.yaml
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
seed_everything: 42
|
| 2 |
+
|
| 3 |
+
# ---------------------------- TRAINER -------------------------------------------
|
| 4 |
+
trainer:
|
| 5 |
+
default_root_dir: "checkpoint/AttUnet"
|
| 6 |
+
precision: "16-mixed"
|
| 7 |
+
min_epochs: 1
|
| 8 |
+
max_epochs: 100
|
| 9 |
+
accelerator: cuda
|
| 10 |
+
# limit_train_batches: 10
|
| 11 |
+
devices: [5]
|
| 12 |
+
# strategy: ddp
|
| 13 |
+
num_nodes: 1
|
| 14 |
+
enable_progress_bar: true
|
| 15 |
+
sync_batchnorm: True
|
| 16 |
+
enable_checkpointing: True
|
| 17 |
+
# debugging
|
| 18 |
+
fast_dev_run: false
|
| 19 |
+
logger:
|
| 20 |
+
class_path: pytorch_lightning.loggers.CSVLogger
|
| 21 |
+
init_args:
|
| 22 |
+
save_dir: "checkpoint/AttUnet/logs"
|
| 23 |
+
name: null
|
| 24 |
+
version: null
|
| 25 |
+
|
| 26 |
+
callbacks:
|
| 27 |
+
- class_path: pytorch_lightning.callbacks.LearningRateMonitor
|
| 28 |
+
init_args:
|
| 29 |
+
logging_interval: "step"
|
| 30 |
+
|
| 31 |
+
- class_path: pytorch_lightning.callbacks.ModelCheckpoint
|
| 32 |
+
init_args:
|
| 33 |
+
dirpath: "checkpoint/AttUnet/checkpoints"
|
| 34 |
+
monitor: "val/mse" # name of the logged metric which determines when model is improving
|
| 35 |
+
mode: "min" # "max" means higher metric value is better, can be also "min"
|
| 36 |
+
save_top_k: 1 # save k best models (determined by above metric)
|
| 37 |
+
save_last: True # additionally always save model from last epoch
|
| 38 |
+
verbose: False
|
| 39 |
+
filename: "epoch_{epoch:03d}"
|
| 40 |
+
auto_insert_metric_name: False
|
| 41 |
+
|
| 42 |
+
- class_path: pytorch_lightning.callbacks.EarlyStopping
|
| 43 |
+
init_args:
|
| 44 |
+
monitor: "val/mse" # name of the logged metric which determines when model is improving
|
| 45 |
+
mode: "min" # "max" means higher metric value is better, can be also "min"
|
| 46 |
+
patience: 10 # how many validation epochs of not improving until training stops
|
| 47 |
+
min_delta: 0. # minimum change in the monitored metric needed to qualify as an improvement
|
| 48 |
+
|
| 49 |
+
- class_path: pytorch_lightning.callbacks.RichModelSummary
|
| 50 |
+
init_args:
|
| 51 |
+
max_depth: -1
|
| 52 |
+
|
| 53 |
+
- class_path: pytorch_lightning.callbacks.RichProgressBar
|
| 54 |
+
|
| 55 |
+
# ---------------------------- MODEL -------------------------------------------
|
| 56 |
+
model:
|
| 57 |
+
pretrained_path: ""
|
| 58 |
+
beta_1: 0.9
|
| 59 |
+
beta_2: 0.99
|
| 60 |
+
lr: 5e-4
|
| 61 |
+
weight_decay: 1e-5
|
| 62 |
+
warmup_epochs: 10
|
| 63 |
+
max_epochs: 50
|
| 64 |
+
warmup_start_lr: 1e-8
|
| 65 |
+
eta_min: 1e-8
|
| 66 |
+
net:
|
| 67 |
+
model_type: "AttUnet"
|
| 68 |
+
num_channel: 1
|
| 69 |
+
|
| 70 |
+
# ---------------------------- DATA -------------------------------------------
|
| 71 |
+
data:
|
| 72 |
+
dir_data: "/data/data_WF/ablation/ablation_time"
|
| 73 |
+
ablation: "time"
|
| 74 |
+
sat_size: 20
|
| 75 |
+
rad_size: 640
|
| 76 |
+
time_points_rad: 1
|
| 77 |
+
time_points_sat: 1
|
| 78 |
+
sat_inp_vars: ["total_precipitation"]
|
| 79 |
+
sat_out_vars: "total_precipitation"
|
| 80 |
+
rad_inp_vars: ["precipitation"]
|
| 81 |
+
rad_out_vars: "precipitation"
|
| 82 |
+
hours_predicted: 3
|
| 83 |
+
batch_size: 32
|
| 84 |
+
num_workers: 4
|
| 85 |
+
pin_memory: False
|
| 86 |
+
|
configs/Nothing.yaml
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
seed_everything: 42
|
| 2 |
+
|
| 3 |
+
# ---------------------------- TRAINER -------------------------------------------
|
| 4 |
+
trainer:
|
| 5 |
+
default_root_dir: "checkpoint/Nothing"
|
| 6 |
+
precision: "16-mixed"
|
| 7 |
+
min_epochs: 1
|
| 8 |
+
max_epochs: 100
|
| 9 |
+
accelerator: cuda
|
| 10 |
+
# limit_train_batches: 10
|
| 11 |
+
devices: [4]
|
| 12 |
+
# strategy: ddp
|
| 13 |
+
num_nodes: 1
|
| 14 |
+
enable_progress_bar: true
|
| 15 |
+
sync_batchnorm: True
|
| 16 |
+
enable_checkpointing: True
|
| 17 |
+
# debugging
|
| 18 |
+
fast_dev_run: false
|
| 19 |
+
logger:
|
| 20 |
+
class_path: pytorch_lightning.loggers.CSVLogger
|
| 21 |
+
init_args:
|
| 22 |
+
save_dir: "checkpoint/Nothing/logs"
|
| 23 |
+
name: null
|
| 24 |
+
version: null
|
| 25 |
+
|
| 26 |
+
callbacks:
|
| 27 |
+
- class_path: pytorch_lightning.callbacks.LearningRateMonitor
|
| 28 |
+
init_args:
|
| 29 |
+
logging_interval: "step"
|
| 30 |
+
|
| 31 |
+
- class_path: pytorch_lightning.callbacks.ModelCheckpoint
|
| 32 |
+
init_args:
|
| 33 |
+
dirpath: "checkpoint/Nothing/checkpoints"
|
| 34 |
+
monitor: "val/sat" # name of the logged metric which determines when model is improving
|
| 35 |
+
mode: "min" # "max" means higher metric value is better, can be also "min"
|
| 36 |
+
save_top_k: 1 # save k best models (determined by above metric)
|
| 37 |
+
save_last: True # additionally always save model from last epoch
|
| 38 |
+
verbose: False
|
| 39 |
+
filename: "epoch_{epoch:03d}"
|
| 40 |
+
auto_insert_metric_name: False
|
| 41 |
+
|
| 42 |
+
- class_path: pytorch_lightning.callbacks.EarlyStopping
|
| 43 |
+
init_args:
|
| 44 |
+
monitor: "val/sat" # name of the logged metric which determines when model is improving
|
| 45 |
+
mode: "min" # "max" means higher metric value is better, can be also "min"
|
| 46 |
+
patience: 10 # how many validation epochs of not improving until training stops
|
| 47 |
+
min_delta: 0. # minimum change in the monitored metric needed to qualify as an improvement
|
| 48 |
+
|
| 49 |
+
- class_path: pytorch_lightning.callbacks.RichModelSummary
|
| 50 |
+
init_args:
|
| 51 |
+
max_depth: -1
|
| 52 |
+
|
| 53 |
+
- class_path: pytorch_lightning.callbacks.RichProgressBar
|
| 54 |
+
|
| 55 |
+
# ---------------------------- MODEL -------------------------------------------
|
| 56 |
+
model:
|
| 57 |
+
pretrained_path: ""
|
| 58 |
+
beta_1: 0.9
|
| 59 |
+
beta_2: 0.99
|
| 60 |
+
lr: 5e-4
|
| 61 |
+
weight_decay: 1e-5
|
| 62 |
+
warmup_epochs: 10
|
| 63 |
+
max_epochs: 50
|
| 64 |
+
warmup_start_lr: 1e-8
|
| 65 |
+
eta_min: 1e-8
|
| 66 |
+
net:
|
| 67 |
+
model_type: "Nothing"
|
| 68 |
+
num_channel: 1
|
| 69 |
+
|
| 70 |
+
# ---------------------------- DATA -------------------------------------------
|
| 71 |
+
data:
|
| 72 |
+
dir_data: "/data/data_WF/ablation/ablation_time"
|
| 73 |
+
ablation: "time"
|
| 74 |
+
sat_size: 20
|
| 75 |
+
rad_size: 640
|
| 76 |
+
time_points_rad: 1
|
| 77 |
+
time_points_sat: 1
|
| 78 |
+
sat_inp_vars: ["total_precipitation"]
|
| 79 |
+
sat_out_vars: "total_precipitation"
|
| 80 |
+
rad_inp_vars: ["precipitation"]
|
| 81 |
+
rad_out_vars: "precipitation"
|
| 82 |
+
hours_predicted: 3
|
| 83 |
+
batch_size: 8
|
| 84 |
+
num_workers: 4
|
| 85 |
+
pin_memory: False
|
| 86 |
+
|
configs/R2Unet.yaml
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
seed_everything: 42
|
| 2 |
+
|
| 3 |
+
# ---------------------------- TRAINER -------------------------------------------
|
| 4 |
+
trainer:
|
| 5 |
+
default_root_dir: "checkpoint/R2Unet"
|
| 6 |
+
precision: "16-mixed"
|
| 7 |
+
min_epochs: 1
|
| 8 |
+
max_epochs: 100
|
| 9 |
+
accelerator: cuda
|
| 10 |
+
# limit_train_batches: 10
|
| 11 |
+
devices: [4]
|
| 12 |
+
# strategy: ddp
|
| 13 |
+
num_nodes: 1
|
| 14 |
+
enable_progress_bar: true
|
| 15 |
+
sync_batchnorm: True
|
| 16 |
+
enable_checkpointing: True
|
| 17 |
+
# debugging
|
| 18 |
+
fast_dev_run: false
|
| 19 |
+
logger:
|
| 20 |
+
class_path: pytorch_lightning.loggers.CSVLogger
|
| 21 |
+
init_args:
|
| 22 |
+
save_dir: "checkpoint/R2Unet/logs"
|
| 23 |
+
name: null
|
| 24 |
+
version: null
|
| 25 |
+
|
| 26 |
+
callbacks:
|
| 27 |
+
- class_path: pytorch_lightning.callbacks.LearningRateMonitor
|
| 28 |
+
init_args:
|
| 29 |
+
logging_interval: "step"
|
| 30 |
+
|
| 31 |
+
- class_path: pytorch_lightning.callbacks.ModelCheckpoint
|
| 32 |
+
init_args:
|
| 33 |
+
dirpath: "checkpoint/R2Unet/checkpoints"
|
| 34 |
+
monitor: "val/mse" # name of the logged metric which determines when model is improving
|
| 35 |
+
mode: "min" # "max" means higher metric value is better, can be also "min"
|
| 36 |
+
save_top_k: 1 # save k best models (determined by above metric)
|
| 37 |
+
save_last: True # additionally always save model from last epoch
|
| 38 |
+
verbose: False
|
| 39 |
+
filename: "epoch_{epoch:03d}"
|
| 40 |
+
auto_insert_metric_name: False
|
| 41 |
+
|
| 42 |
+
- class_path: pytorch_lightning.callbacks.EarlyStopping
|
| 43 |
+
init_args:
|
| 44 |
+
monitor: "val/mse" # name of the logged metric which determines when model is improving
|
| 45 |
+
mode: "min" # "max" means higher metric value is better, can be also "min"
|
| 46 |
+
patience: 10 # how many validation epochs of not improving until training stops
|
| 47 |
+
min_delta: 0. # minimum change in the monitored metric needed to qualify as an improvement
|
| 48 |
+
|
| 49 |
+
- class_path: pytorch_lightning.callbacks.RichModelSummary
|
| 50 |
+
init_args:
|
| 51 |
+
max_depth: -1
|
| 52 |
+
|
| 53 |
+
- class_path: pytorch_lightning.callbacks.RichProgressBar
|
| 54 |
+
init_args:
|
| 55 |
+
theme:
|
| 56 |
+
description: "white"
|
| 57 |
+
progress_bar: "#6206E0"
|
| 58 |
+
progress_bar_finished: "green"
|
| 59 |
+
progress_bar_pulse: "cyan"
|
| 60 |
+
batch_progress: "white"
|
| 61 |
+
time: "grey42"
|
| 62 |
+
processing_speed: "grey70"
|
| 63 |
+
metrics: "white"
|
| 64 |
+
# ---------------------------- MODEL -------------------------------------------
|
| 65 |
+
model:
|
| 66 |
+
pretrained_path: ""
|
| 67 |
+
beta_1: 0.9
|
| 68 |
+
beta_2: 0.99
|
| 69 |
+
lr: 5e-4
|
| 70 |
+
weight_decay: 1e-5
|
| 71 |
+
warmup_epochs: 10
|
| 72 |
+
max_epochs: 50
|
| 73 |
+
warmup_start_lr: 1e-8
|
| 74 |
+
eta_min: 1e-8
|
| 75 |
+
net:
|
| 76 |
+
model_type: "R2Unet"
|
| 77 |
+
num_channel: 1
|
| 78 |
+
|
| 79 |
+
# ---------------------------- DATA -------------------------------------------
|
| 80 |
+
data:
|
| 81 |
+
dir_data: "/data/data_WF/ablation/ablation_time"
|
| 82 |
+
ablation: "time"
|
| 83 |
+
sat_size: 20
|
| 84 |
+
rad_size: 640
|
| 85 |
+
time_points_rad: 1
|
| 86 |
+
time_points_sat: 1
|
| 87 |
+
sat_inp_vars: ["total_precipitation"]
|
| 88 |
+
sat_out_vars: "total_precipitation"
|
| 89 |
+
rad_inp_vars: ["precipitation"]
|
| 90 |
+
rad_out_vars: "precipitation"
|
| 91 |
+
hours_predicted: 3
|
| 92 |
+
batch_size: 32
|
| 93 |
+
num_workers: 4
|
| 94 |
+
pin_memory: False
|
| 95 |
+
|
configs/Unet.yaml
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
seed_everything: 42
|
| 2 |
+
|
| 3 |
+
# ---------------------------- TRAINER -------------------------------------------
|
| 4 |
+
trainer:
|
| 5 |
+
default_root_dir: "checkpoint/Unet"
|
| 6 |
+
precision: "16-mixed"
|
| 7 |
+
min_epochs: 1
|
| 8 |
+
max_epochs: 100
|
| 9 |
+
accelerator: cuda
|
| 10 |
+
# limit_train_batches: 10
|
| 11 |
+
devices: [5]
|
| 12 |
+
# strategy: ddp
|
| 13 |
+
num_nodes: 1
|
| 14 |
+
enable_progress_bar: true
|
| 15 |
+
sync_batchnorm: True
|
| 16 |
+
enable_checkpointing: True
|
| 17 |
+
# debugging
|
| 18 |
+
fast_dev_run: false
|
| 19 |
+
logger:
|
| 20 |
+
- class_path: pytorch_lightning.loggers.WandbLogger
|
| 21 |
+
init_args:
|
| 22 |
+
project: "NhaBe"
|
| 23 |
+
name: "UnetNhaBe"
|
| 24 |
+
save_dir: "checkpoint/Unet/wandb_logs"
|
| 25 |
+
log_model: False
|
| 26 |
+
- class_path: pytorch_lightning.loggers.CSVLogger
|
| 27 |
+
init_args:
|
| 28 |
+
save_dir: "checkpoint/Unet/csv_logs"
|
| 29 |
+
name: null
|
| 30 |
+
version: null
|
| 31 |
+
|
| 32 |
+
callbacks:
|
| 33 |
+
- class_path: pytorch_lightning.callbacks.LearningRateMonitor
|
| 34 |
+
init_args:
|
| 35 |
+
logging_interval: "step"
|
| 36 |
+
|
| 37 |
+
- class_path: pytorch_lightning.callbacks.ModelCheckpoint
|
| 38 |
+
init_args:
|
| 39 |
+
dirpath: "checkpoint/Unet/checkpoints"
|
| 40 |
+
monitor: "val/mse" # name of the logged metric which determines when model is improving
|
| 41 |
+
mode: "min" # "max" means higher metric value is better, can be also "min"
|
| 42 |
+
save_top_k: 1 # save k best models (determined by above metric)
|
| 43 |
+
save_last: True # additionally always save model from last epoch
|
| 44 |
+
verbose: False
|
| 45 |
+
filename: "epoch_{epoch:03d}"
|
| 46 |
+
auto_insert_metric_name: False
|
| 47 |
+
|
| 48 |
+
- class_path: pytorch_lightning.callbacks.EarlyStopping
|
| 49 |
+
init_args:
|
| 50 |
+
monitor: "val/mse" # name of the logged metric which determines when model is improving
|
| 51 |
+
mode: "min" # "max" means higher metric value is better, can be also "min"
|
| 52 |
+
patience: 10 # how many validation epochs of not improving until training stops
|
| 53 |
+
min_delta: 0. # minimum change in the monitored metric needed to qualify as an improvement
|
| 54 |
+
|
| 55 |
+
- class_path: pytorch_lightning.callbacks.RichModelSummary
|
| 56 |
+
init_args:
|
| 57 |
+
max_depth: -1
|
| 58 |
+
|
| 59 |
+
# - class_path: pytorch_lightning.callbacks.RichProgressBar
|
| 60 |
+
# init_args:
|
| 61 |
+
# theme:
|
| 62 |
+
# description: "white"
|
| 63 |
+
# progress_bar: "#6206E0"
|
| 64 |
+
# progress_bar_finished: "green"
|
| 65 |
+
# progress_bar_pulse: "cyan"
|
| 66 |
+
# batch_progress: "white"
|
| 67 |
+
# time: "grey42"
|
| 68 |
+
# processing_speed: "grey70"
|
| 69 |
+
# metrics: "white"
|
| 70 |
+
# ---------------------------- MODEL -------------------------------------------
|
| 71 |
+
model:
|
| 72 |
+
pretrained_path: ""
|
| 73 |
+
beta_1: 0.9
|
| 74 |
+
beta_2: 0.99
|
| 75 |
+
lr: 5e-4
|
| 76 |
+
weight_decay: 1e-5
|
| 77 |
+
warmup_epochs: 10
|
| 78 |
+
max_epochs: 50
|
| 79 |
+
warmup_start_lr: 1e-8
|
| 80 |
+
eta_min: 1e-8
|
| 81 |
+
net:
|
| 82 |
+
model_type: "Unet"
|
| 83 |
+
rad_channel: 1
|
| 84 |
+
sat_channel: 1
|
| 85 |
+
rad_size: 400
|
| 86 |
+
sat_size: 25
|
| 87 |
+
|
| 88 |
+
# ---------------------------- DATA -------------------------------------------
|
| 89 |
+
data:
|
| 90 |
+
dir_data: "/data/weather2025/NhaBe/"
|
| 91 |
+
ablation: "no"
|
| 92 |
+
rad_size: 400
|
| 93 |
+
sat_size: 25
|
| 94 |
+
time_points_rad: 1
|
| 95 |
+
time_points_sat: 1
|
| 96 |
+
sat_inp_vars: "total_precipitation"
|
| 97 |
+
sat_out_vars: "total_precipitation"
|
| 98 |
+
rad_inp_vars: "precipitation"
|
| 99 |
+
rad_out_vars: "precipitation"
|
| 100 |
+
hours_predicted: 3
|
| 101 |
+
batch_size: 1
|
| 102 |
+
num_workers: 8
|
| 103 |
+
pin_memory: False
|
| 104 |
+
|
pyproject.toml
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools", "setuptools-scm"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "FusionModel"
|
| 7 |
+
version = "0.3.1"
|
| 8 |
+
authors =[
|
| 9 |
+
{name="Khanh Vinh Bui", email="khanhvinhbui0512@gmail.com"},
|
| 10 |
+
{name="Hong Trang Le", email="lhtrang@hcmut.edu.vn"}
|
| 11 |
+
]
|
| 12 |
+
description = ""
|
| 13 |
+
readme = "README.md"
|
| 14 |
+
requires-python = ">=3.10"
|
| 15 |
+
classifiers = [
|
| 16 |
+
"Programming Language :: Python :: 3",
|
| 17 |
+
"License :: OSI Approved :: MIT License",
|
| 18 |
+
]
|
| 19 |
+
|
| 20 |
+
[tool.setuptools.packages.find]
|
| 21 |
+
where = ["."]
|
src/__pycache__/arch.cpython-310.pyc
ADDED
|
Binary file (12.5 kB). View file
|
|
|
src/__pycache__/arch.cpython-312.pyc
ADDED
|
Binary file (27.2 kB). View file
|
|
|
src/__pycache__/arch.cpython-38.pyc
ADDED
|
Binary file (12.8 kB). View file
|
|
|
src/__pycache__/datamodule.cpython-310.pyc
ADDED
|
Binary file (9.34 kB). View file
|
|
|
src/__pycache__/datamodule.cpython-312.pyc
ADDED
|
Binary file (23.8 kB). View file
|
|
|
src/__pycache__/lr_scheduler.cpython-310.pyc
ADDED
|
Binary file (3.73 kB). View file
|
|
|
src/__pycache__/lr_scheduler.cpython-312.pyc
ADDED
|
Binary file (5.55 kB). View file
|
|
|
src/__pycache__/metric.cpython-310.pyc
ADDED
|
Binary file (1.74 kB). View file
|
|
|
src/__pycache__/metric.cpython-312.pyc
ADDED
|
Binary file (3.67 kB). View file
|
|
|
src/__pycache__/module.cpython-310.pyc
ADDED
|
Binary file (6.35 kB). View file
|
|
|
src/__pycache__/module.cpython-312.pyc
ADDED
|
Binary file (11.4 kB). View file
|
|
|
src/__pycache__/module.cpython-38.pyc
ADDED
|
Binary file (6.22 kB). View file
|
|
|
src/__pycache__/train.cpython-38.pyc
ADDED
|
Binary file (983 Bytes). View file
|
|
|
src/arch.py
ADDED
|
@@ -0,0 +1,473 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import glob
|
| 4 |
+
import math
|
| 5 |
+
import torch
|
| 6 |
+
import torchvision
|
| 7 |
+
# For everything
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from torch.nn import CrossEntropyLoss, Linear, MSELoss
|
| 12 |
+
from torch.nn import ConvTranspose2d, Conv2d, MaxPool2d, BatchNorm2d
|
| 13 |
+
# For our model
|
| 14 |
+
import torchvision.models as models
|
| 15 |
+
from torchvision import datasets, transforms
|
| 16 |
+
from torchvision.io import read_image
|
| 17 |
+
from torch.utils.data import DataLoader, Dataset
|
| 18 |
+
import torch.optim as optim
|
| 19 |
+
from torch.autograd import Variable
|
| 20 |
+
from torchsummary import summary
|
| 21 |
+
class Nothing(nn.Module):
|
| 22 |
+
def __init__(self):
|
| 23 |
+
super(Nothing,self).__init__()
|
| 24 |
+
def forward(self, radar,satellite):
|
| 25 |
+
return radar, satellite
|
| 26 |
+
|
| 27 |
+
class ConvBlock(nn.Module):
|
| 28 |
+
def __init__(self, in_channels, out_channels):
|
| 29 |
+
super(ConvBlock, self).__init__()
|
| 30 |
+
# number of input channels is a number of filters in the previous layer
|
| 31 |
+
# number of output channels is a number of filters in the current layer
|
| 32 |
+
# "same" convolutions
|
| 33 |
+
self.conv = nn.Sequential(
|
| 34 |
+
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding='same', bias=True),
|
| 35 |
+
nn.BatchNorm2d(out_channels),
|
| 36 |
+
nn.ReLU(inplace=True),
|
| 37 |
+
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding='same', bias=True),
|
| 38 |
+
nn.BatchNorm2d(out_channels),
|
| 39 |
+
nn.ReLU(inplace=True)
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
def forward(self, x):
|
| 43 |
+
x = self.conv(x)
|
| 44 |
+
return x
|
| 45 |
+
|
| 46 |
+
class UpConv(nn.Module):
|
| 47 |
+
def __init__(self, in_channels, out_channels):
|
| 48 |
+
super(UpConv, self).__init__()
|
| 49 |
+
self.up = nn.Sequential(
|
| 50 |
+
nn.Upsample(scale_factor=2),
|
| 51 |
+
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding='same', bias=True),
|
| 52 |
+
nn.BatchNorm2d(out_channels),
|
| 53 |
+
nn.ReLU(inplace=True)
|
| 54 |
+
)
|
| 55 |
+
def forward(self, x):
|
| 56 |
+
x = self.up(x)
|
| 57 |
+
return x
|
| 58 |
+
class AttentionBlock(nn.Module):
|
| 59 |
+
"""Attention block with learnable parameters"""
|
| 60 |
+
def __init__(self, F_g, F_l, n_coefficients):
|
| 61 |
+
"""
|
| 62 |
+
:param F_g: number of feature maps (channels) in previous layer
|
| 63 |
+
:param F_l: number of feature maps in corresponding encoder layer, transferred via skip connection
|
| 64 |
+
:param n_coefficients: number of learnable multi-dimensional attention coefficients
|
| 65 |
+
"""
|
| 66 |
+
|
| 67 |
+
super(AttentionBlock, self).__init__()
|
| 68 |
+
|
| 69 |
+
self.W_gate = nn.Sequential(
|
| 70 |
+
nn.Conv2d(F_g, n_coefficients, kernel_size=1, stride=1, padding=0, bias=True),
|
| 71 |
+
nn.BatchNorm2d(n_coefficients)
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
self.W_x = nn.Sequential(
|
| 75 |
+
nn.Conv2d(F_l, n_coefficients, kernel_size=1, stride=1, padding=0, bias=True),
|
| 76 |
+
nn.BatchNorm2d(n_coefficients)
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
self.psi = nn.Sequential(
|
| 80 |
+
nn.Conv2d(n_coefficients, 1, kernel_size=1, stride=1, padding=0, bias=True),
|
| 81 |
+
nn.BatchNorm2d(1),
|
| 82 |
+
nn.Sigmoid()
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
self.relu = nn.ReLU(inplace=True)
|
| 86 |
+
|
| 87 |
+
def forward(self, gate, skip_connection):
|
| 88 |
+
"""
|
| 89 |
+
:param gate: gating signal from previous layer
|
| 90 |
+
:param skip_connection: activation from corresponding encoder layer
|
| 91 |
+
:return: output activations
|
| 92 |
+
"""
|
| 93 |
+
g1 = self.W_gate(gate)
|
| 94 |
+
x1 = self.W_x(skip_connection)
|
| 95 |
+
psi = self.relu(g1 + x1)
|
| 96 |
+
psi = self.psi(psi)
|
| 97 |
+
out = skip_connection * psi
|
| 98 |
+
return out
|
| 99 |
+
|
| 100 |
+
class Recurrent_block(nn.Module):
|
| 101 |
+
def __init__(self,ch_out,t=2):
|
| 102 |
+
super(Recurrent_block,self).__init__()
|
| 103 |
+
self.t = t
|
| 104 |
+
self.ch_out = ch_out
|
| 105 |
+
self.conv = nn.Sequential(
|
| 106 |
+
nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding='same',bias=True),
|
| 107 |
+
nn.BatchNorm2d(ch_out),
|
| 108 |
+
nn.ReLU(inplace=True)
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
def forward(self,x):
|
| 112 |
+
for i in range(self.t):
|
| 113 |
+
|
| 114 |
+
if i==0:
|
| 115 |
+
x1 = self.conv(x)
|
| 116 |
+
|
| 117 |
+
x1 = self.conv(x+x1)
|
| 118 |
+
return x1
|
| 119 |
+
|
| 120 |
+
class RRCNN_block(nn.Module):
|
| 121 |
+
def __init__(self,ch_in,ch_out,t=2):
|
| 122 |
+
super(RRCNN_block,self).__init__()
|
| 123 |
+
self.RCNN = nn.Sequential(
|
| 124 |
+
Recurrent_block(ch_out,t=t),
|
| 125 |
+
Recurrent_block(ch_out,t=t)
|
| 126 |
+
)
|
| 127 |
+
self.Conv_1x1 = nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=1,padding='same')
|
| 128 |
+
|
| 129 |
+
def forward(self,x):
|
| 130 |
+
x = self.Conv_1x1(x)
|
| 131 |
+
x1 = self.RCNN(x)
|
| 132 |
+
return x+x1
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class single_conv(nn.Module):
|
| 136 |
+
def __init__(self,ch_in,ch_out):
|
| 137 |
+
super(single_conv,self).__init__()
|
| 138 |
+
self.conv = nn.Sequential(
|
| 139 |
+
nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding='same',bias=True),
|
| 140 |
+
nn.BatchNorm2d(ch_out),
|
| 141 |
+
nn.ReLU(inplace=True)
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
def forward(self,x):
|
| 145 |
+
x = self.conv(x)
|
| 146 |
+
return x
|
| 147 |
+
class Unet(nn.Module):
|
| 148 |
+
def __init__(self, rad_channel=1,sat_channel=1, rad_size=640, sat_size=20):
|
| 149 |
+
super(Unet, self).__init__()
|
| 150 |
+
assert rad_size % sat_size == 0, "rad_size must be divisible by sat_size"
|
| 151 |
+
ratio = rad_size // sat_size
|
| 152 |
+
assert (ratio & (ratio - 1)) == 0, "rad_size/sat_size must be a power of 2"
|
| 153 |
+
self.n_pool = int(math.log2(ratio))
|
| 154 |
+
# Encoder
|
| 155 |
+
self.encoder_blocks = nn.ModuleList()
|
| 156 |
+
self.pools = nn.ModuleList()
|
| 157 |
+
for i in range(self.n_pool):
|
| 158 |
+
in_c = rad_channel * (2**(i))
|
| 159 |
+
out_c = rad_channel * (2**(i+1))
|
| 160 |
+
self.encoder_blocks.append(ConvBlock(in_c, out_c))
|
| 161 |
+
if i < self.n_pool:
|
| 162 |
+
self.pools.append(nn.MaxPool2d(kernel_size=2, stride=2))
|
| 163 |
+
# Bottleneck
|
| 164 |
+
self.mid_conv_1 = single_conv(out_c, out_c)
|
| 165 |
+
self.mid_conv_2 = single_conv(sat_channel, out_c)
|
| 166 |
+
self.mid_merge = ConvBlock(2*out_c, out_c)
|
| 167 |
+
# Decoder
|
| 168 |
+
self.up_convs = nn.ModuleList()
|
| 169 |
+
self.decoder_blocks = nn.ModuleList()
|
| 170 |
+
for i in reversed(range(self.n_pool)):
|
| 171 |
+
up_in = rad_channel * (2**(i+2))
|
| 172 |
+
up_out = rad_channel * (2**(i+1))
|
| 173 |
+
self.up_convs.append(UpConv(up_in, up_out))
|
| 174 |
+
self.decoder_blocks.append(ConvBlock(up_in, up_out))
|
| 175 |
+
self.final_decoder = ConvBlock(4*rad_channel, 2*rad_channel)
|
| 176 |
+
self.out_conv_R = nn.Conv2d(2*rad_channel, rad_channel, kernel_size=1, padding='same')
|
| 177 |
+
self.out_conv_S = nn.Conv2d(out_c, sat_channel, kernel_size=1, padding='same')
|
| 178 |
+
def forward(self, radar, satellite):
|
| 179 |
+
# Encoding
|
| 180 |
+
enc_feats = []
|
| 181 |
+
x = radar
|
| 182 |
+
for i, block in enumerate(self.encoder_blocks):
|
| 183 |
+
x = block(x)
|
| 184 |
+
enc_feats.append(x)
|
| 185 |
+
if i < self.n_pool:
|
| 186 |
+
x = self.pools[i](x)
|
| 187 |
+
# Bottleneck
|
| 188 |
+
x = F.relu(self.mid_conv_1(x))
|
| 189 |
+
y = F.relu(self.mid_conv_2(satellite))
|
| 190 |
+
x = torch.cat((x, y), dim=1)
|
| 191 |
+
|
| 192 |
+
mid_out = self.mid_merge(x)
|
| 193 |
+
pred_sat = self.out_conv_S(mid_out)
|
| 194 |
+
# Decoding
|
| 195 |
+
x = x # input to decoder is original x before mid_merge
|
| 196 |
+
for i in range(self.n_pool):
|
| 197 |
+
x = self.up_convs[i](x)
|
| 198 |
+
x = torch.cat((enc_feats[self.n_pool - 1 - i], x), dim=1)
|
| 199 |
+
x = self.decoder_blocks[i](x)
|
| 200 |
+
x = torch.cat((enc_feats[0], x), dim=1)
|
| 201 |
+
x = self.final_decoder(x)
|
| 202 |
+
pred_rad = self.out_conv_R(x)
|
| 203 |
+
return pred_rad, pred_sat
|
| 204 |
+
# class Unet(nn.Module):
|
| 205 |
+
# def __init__(self,num_channel=1,rad_size=640,sat_size=20):
|
| 206 |
+
# super(Unet, self).__init__()
|
| 207 |
+
# self.MaxPool = nn.MaxPool2d(kernel_size=2, stride=2)
|
| 208 |
+
# self.Conv1 = ConvBlock(1, 2*num_channel)
|
| 209 |
+
# self.Conv2 = ConvBlock(2*num_channel, 4*num_channel)
|
| 210 |
+
# self.Conv3 = ConvBlock(4*num_channel, 8*num_channel)
|
| 211 |
+
# self.Conv4 = ConvBlock(8*num_channel, 16*num_channel)
|
| 212 |
+
# self.Conv5 = ConvBlock(16*num_channel, 32*num_channel)
|
| 213 |
+
# self.mid_conv_1 = single_conv(32*num_channel,32*num_channel)
|
| 214 |
+
# self.mid_conv_2 = single_conv(2, 32*num_channel)
|
| 215 |
+
# self.MidConv = ConvBlock(64*num_channel, 32*num_channel)
|
| 216 |
+
# self.out_conv_S = Conv2d(32*num_channel, 1, (1, 1), padding= 'same')
|
| 217 |
+
# self.Up5 = UpConv(64*num_channel, 32*num_channel)
|
| 218 |
+
# self.UpConv5 = ConvBlock(64*num_channel, 32*num_channel)
|
| 219 |
+
# self.Up4 = UpConv(32*num_channel, 16*num_channel)
|
| 220 |
+
# self.UpConv4 = ConvBlock(32*num_channel, 16*num_channel)
|
| 221 |
+
# self.Up3 = UpConv(16*num_channel, 8*num_channel)
|
| 222 |
+
# self.UpConv3 = ConvBlock(16*num_channel, 8*num_channel)
|
| 223 |
+
# self.Up2 = UpConv(8*num_channel, 4*num_channel)
|
| 224 |
+
# self.UpConv2 = ConvBlock(8*num_channel, 4*num_channel)
|
| 225 |
+
# self.Up1 = UpConv(4*num_channel, 2*num_channel)
|
| 226 |
+
# self.UpConv1 = ConvBlock(4*num_channel, 2*num_channel)
|
| 227 |
+
# self.out_conv_R = Conv2d(2*num_channel, 1, (1, 1), padding= 'same')
|
| 228 |
+
# def forward(self, radar,satellite):
|
| 229 |
+
# e1 = self.Conv1(radar)
|
| 230 |
+
# e2 = self.MaxPool(e1)
|
| 231 |
+
# e2 = self.Conv2(e2)
|
| 232 |
+
# e3 = self.MaxPool(e2)
|
| 233 |
+
# e3 = self.Conv3(e3)
|
| 234 |
+
# e4 = self.MaxPool(e3)
|
| 235 |
+
# e4 = self.Conv4(e4)
|
| 236 |
+
# e5 = self.MaxPool(e4)
|
| 237 |
+
# e5 = self.Conv5(e5)
|
| 238 |
+
# e6 = self.MaxPool(e5)
|
| 239 |
+
# X = F.relu(self.mid_conv_1(e6))
|
| 240 |
+
# Y = F.relu(self.mid_conv_2(satellite))
|
| 241 |
+
# X = torch.cat((X,Y),1)
|
| 242 |
+
# Y = self.MidConv(X)
|
| 243 |
+
# pred_satellite = self.out_conv_S(Y)
|
| 244 |
+
# d5 = self.Up5(X)
|
| 245 |
+
# d5 = torch.cat((e5, d5), dim=1)
|
| 246 |
+
# d5 = self.UpConv5(d5)
|
| 247 |
+
# d4 = self.Up4(d5)
|
| 248 |
+
# d4 = torch.cat((e4, d4), dim=1)
|
| 249 |
+
# d4 = self.UpConv4(d4)
|
| 250 |
+
# d3 = self.Up3(d4)
|
| 251 |
+
# d3 = torch.cat((e3, d3), dim=1)
|
| 252 |
+
# d3 = self.UpConv3(d3)
|
| 253 |
+
# d2 = self.Up2(d3)
|
| 254 |
+
# d2 = torch.cat((e2, d2), dim=1)
|
| 255 |
+
# d2 = self.UpConv2(d2)
|
| 256 |
+
# d1 = self.Up1(d2)
|
| 257 |
+
# d0 = torch.cat((e1, d1), dim=1)
|
| 258 |
+
# d0 = self.UpConv1(d0)
|
| 259 |
+
# pred_radar = self.out_conv_R(d0)
|
| 260 |
+
# return pred_radar, pred_satellite
|
| 261 |
+
|
| 262 |
+
class R2Unet(nn.Module):
|
| 263 |
+
def __init__(self,num_channel=1,t=2):
|
| 264 |
+
super(R2Unet, self).__init__()
|
| 265 |
+
self.MaxPool = nn.MaxPool2d(kernel_size=2, stride=2)
|
| 266 |
+
self.RRCNN1 = RRCNN_block(5,2*num_channel,t=t)
|
| 267 |
+
self.RRCNN2 = RRCNN_block(2*num_channel,4*num_channel,t=t)
|
| 268 |
+
self.RRCNN3 = RRCNN_block(4*num_channel,8*num_channel,t=t)
|
| 269 |
+
self.RRCNN4 = RRCNN_block(8*num_channel,16*num_channel,t=t)
|
| 270 |
+
self.RRCNN5 = RRCNN_block(16*num_channel,32*num_channel,t=t)
|
| 271 |
+
self.mid_conv_1 = single_conv(32*num_channel,32*num_channel)
|
| 272 |
+
self.mid_conv_2 = single_conv(2, 32*num_channel)
|
| 273 |
+
self.MidConv = RRCNN_block(64*num_channel, 32*num_channel)
|
| 274 |
+
self.out_conv_S = Conv2d(32*num_channel, 1, (1, 1), padding= 'same')
|
| 275 |
+
self.Up5 = UpConv(64*num_channel, 32*num_channel)
|
| 276 |
+
self.UpRRCNN5 = RRCNN_block(64*num_channel, 32*num_channel)
|
| 277 |
+
self.Up4 = UpConv(32*num_channel, 16*num_channel)
|
| 278 |
+
self.UpRRCNN4 = RRCNN_block(32*num_channel, 16*num_channel)
|
| 279 |
+
self.Up3 = UpConv(16*num_channel, 8*num_channel)
|
| 280 |
+
self.UpRRCNN3 = RRCNN_block(16*num_channel, 8*num_channel)
|
| 281 |
+
self.Up2 = UpConv(8*num_channel, 4*num_channel)
|
| 282 |
+
self.UpRRCNN2 = RRCNN_block(8*num_channel, 4*num_channel)
|
| 283 |
+
self.Up1 = UpConv(4*num_channel, 2*num_channel)
|
| 284 |
+
self.UpRRCNN1 = RRCNN_block(4*num_channel, 2*num_channel)
|
| 285 |
+
self.out_conv_R = Conv2d(2*num_channel, 1, (1, 1), padding= 'same')
|
| 286 |
+
def forward(self, radar,satellite):
|
| 287 |
+
e1 = self.RRCNN1(radar)
|
| 288 |
+
e2 = self.MaxPool(e1)
|
| 289 |
+
e2 = self.RRCNN2(e2)
|
| 290 |
+
e3 = self.MaxPool(e2)
|
| 291 |
+
e3 = self.RRCNN3(e3)
|
| 292 |
+
e4 = self.MaxPool(e3)
|
| 293 |
+
e4 = self.RRCNN4(e4)
|
| 294 |
+
e5 = self.MaxPool(e4)
|
| 295 |
+
e5 = self.RRCNN5(e5)
|
| 296 |
+
e6 = self.MaxPool(e5)
|
| 297 |
+
X = F.relu(self.mid_conv_1(e6))
|
| 298 |
+
Y = F.relu(self.mid_conv_2(satellite))
|
| 299 |
+
X = torch.cat((X,Y),1)
|
| 300 |
+
Y = self.MidConv(X)
|
| 301 |
+
pred_satellite = self.out_conv_S(Y)
|
| 302 |
+
d5 = self.Up5(X)
|
| 303 |
+
d5 = torch.cat((e5, d5), dim=1)
|
| 304 |
+
d5 = self.UpRRCNN5(d5)
|
| 305 |
+
d4 = self.Up4(d5)
|
| 306 |
+
d4 = torch.cat((e4, d4), dim=1)
|
| 307 |
+
d4 = self.UpRRCNN4(d4)
|
| 308 |
+
d3 = self.Up3(d4)
|
| 309 |
+
d3 = torch.cat((e3, d3), dim=1)
|
| 310 |
+
d3 = self.UpRRCNN3(d3)
|
| 311 |
+
d2 = self.Up2(d3)
|
| 312 |
+
d2 = torch.cat((e2, d2), dim=1)
|
| 313 |
+
d2 = self.UpRRCNN2(d2)
|
| 314 |
+
d1 = self.Up1(d2)
|
| 315 |
+
d0 = torch.cat((e1, d1), dim=1)
|
| 316 |
+
d0 = self.UpRRCNN1(d0)
|
| 317 |
+
pred_radar = self.out_conv_R(d0)
|
| 318 |
+
return pred_radar, pred_satellite
|
| 319 |
+
class AttUnet(nn.Module):
|
| 320 |
+
def __init__(self,num_channel=1):
|
| 321 |
+
super(AttUnet, self).__init__()
|
| 322 |
+
self.MaxPool = nn.MaxPool2d(kernel_size=2, stride=2)
|
| 323 |
+
self.Conv1 = ConvBlock(5, 2*num_channel)
|
| 324 |
+
self.Conv2 = ConvBlock(2*num_channel, 4*num_channel)
|
| 325 |
+
self.Conv3 = ConvBlock(4*num_channel, 8*num_channel)
|
| 326 |
+
self.Conv4 = ConvBlock(8*num_channel, 16*num_channel)
|
| 327 |
+
self.Conv5 = ConvBlock(16*num_channel, 32*num_channel)
|
| 328 |
+
self.mid_conv_1 = single_conv(32*num_channel,32*num_channel)
|
| 329 |
+
self.mid_conv_2 = single_conv(2, 32*num_channel)
|
| 330 |
+
self.MidConv = ConvBlock(64*num_channel, 32*num_channel)
|
| 331 |
+
self.out_conv_S = Conv2d(32*num_channel, 1, (1, 1), padding= 'same')
|
| 332 |
+
self.Up5 = UpConv(64*num_channel, 32*num_channel)
|
| 333 |
+
self.Att5 = AttentionBlock(F_g=32*num_channel, F_l=32*num_channel, n_coefficients=16*num_channel)
|
| 334 |
+
self.UpConv5 = ConvBlock(64*num_channel, 32*num_channel)
|
| 335 |
+
self.Up4 = UpConv(32*num_channel, 16*num_channel)
|
| 336 |
+
self.Att4 = AttentionBlock(F_g=16*num_channel, F_l=16*num_channel, n_coefficients=8*num_channel)
|
| 337 |
+
self.UpConv4 = ConvBlock(32*num_channel, 16*num_channel)
|
| 338 |
+
self.Up3 = UpConv(16*num_channel, 8*num_channel)
|
| 339 |
+
self.Att3 = AttentionBlock(F_g=8*num_channel, F_l=8*num_channel, n_coefficients=4*num_channel)
|
| 340 |
+
self.UpConv3 = ConvBlock(16*num_channel, 8*num_channel)
|
| 341 |
+
self.Up2 = UpConv(8*num_channel, 4*num_channel)
|
| 342 |
+
self.Att2 = AttentionBlock(F_g=4*num_channel, F_l=4*num_channel, n_coefficients=2*num_channel)
|
| 343 |
+
self.UpConv2 = ConvBlock(8*num_channel, 4*num_channel)
|
| 344 |
+
self.Up1 = UpConv(4*num_channel, 2*num_channel)
|
| 345 |
+
self.Att1 = AttentionBlock(F_g=2*num_channel, F_l=2*num_channel, n_coefficients=1*num_channel)
|
| 346 |
+
self.UpConv1 = ConvBlock(4*num_channel, 2*num_channel)
|
| 347 |
+
self.out_conv_R = Conv2d(2*num_channel, 1, (1, 1), padding= 'same')
|
| 348 |
+
def forward(self, radar,satellite):
|
| 349 |
+
e1 = self.Conv1(radar)
|
| 350 |
+
e2 = self.MaxPool(e1)
|
| 351 |
+
e2 = self.Conv2(e2)
|
| 352 |
+
e3 = self.MaxPool(e2)
|
| 353 |
+
e3 = self.Conv3(e3)
|
| 354 |
+
e4 = self.MaxPool(e3)
|
| 355 |
+
e4 = self.Conv4(e4)
|
| 356 |
+
e5 = self.MaxPool(e4)
|
| 357 |
+
e5 = self.Conv5(e5)
|
| 358 |
+
e6 = self.MaxPool(e5)
|
| 359 |
+
X = F.relu(self.mid_conv_1(e6))
|
| 360 |
+
Y = F.relu(self.mid_conv_2(satellite))
|
| 361 |
+
X = torch.cat((X,Y),1)
|
| 362 |
+
Y = self.MidConv(X)
|
| 363 |
+
pred_satellite = self.out_conv_S(Y)
|
| 364 |
+
d5 = self.Up5(X)
|
| 365 |
+
s4 = self.Att5(gate=d5, skip_connection=e5)
|
| 366 |
+
d5 = torch.cat((s4, d5), dim=1) # concatenate attention-weighted skip connection with previous layer output
|
| 367 |
+
d5 = self.UpConv5(d5)
|
| 368 |
+
d4 = self.Up4(d5)
|
| 369 |
+
s3 = self.Att4(gate=d4, skip_connection=e4)
|
| 370 |
+
d4 = torch.cat((s3, d4), dim=1)
|
| 371 |
+
d4 = self.UpConv4(d4)
|
| 372 |
+
d3 = self.Up3(d4)
|
| 373 |
+
s2 = self.Att3(gate=d3, skip_connection=e3)
|
| 374 |
+
d3 = torch.cat((s2, d3), dim=1)
|
| 375 |
+
d3 = self.UpConv3(d3)
|
| 376 |
+
d2 = self.Up2(d3)
|
| 377 |
+
s1 = self.Att2(gate=d2, skip_connection=e2)
|
| 378 |
+
d2 = torch.cat((s1, d2), dim=1)
|
| 379 |
+
d2 = self.UpConv2(d2)
|
| 380 |
+
d1 = self.Up1(d2)
|
| 381 |
+
s0 = self.Att1(gate=d1, skip_connection=e1)
|
| 382 |
+
d0 = torch.cat((s0, d1), dim=1)
|
| 383 |
+
d0 = self.UpConv1(d0)
|
| 384 |
+
pred_radar = self.out_conv_R(d0)
|
| 385 |
+
return pred_radar, pred_satellite
|
| 386 |
+
class AttR2Unet(nn.Module):
|
| 387 |
+
def __init__(self,num_channel=1,t=2):
|
| 388 |
+
super(AttR2Unet, self).__init__()
|
| 389 |
+
self.MaxPool = nn.MaxPool2d(kernel_size=2, stride=2)
|
| 390 |
+
self.RRCNN1 = RRCNN_block(5, 2*num_channel)
|
| 391 |
+
self.RRCNN2 = RRCNN_block(2*num_channel, 4*num_channel)
|
| 392 |
+
self.RRCNN3 = RRCNN_block(4*num_channel, 8*num_channel)
|
| 393 |
+
self.RRCNN4 = RRCNN_block(8*num_channel, 16*num_channel)
|
| 394 |
+
self.RRCNN5 = RRCNN_block(16*num_channel, 32*num_channel)
|
| 395 |
+
self.mid_conv_1 = single_conv(32*num_channel,32*num_channel)
|
| 396 |
+
self.mid_conv_2 = single_conv(2, 32*num_channel)
|
| 397 |
+
self.MidConv = RRCNN_block(64*num_channel, 32*num_channel)
|
| 398 |
+
self.out_conv_S = Conv2d(32*num_channel, 1, (1, 1), padding= 'same')
|
| 399 |
+
self.Up5 = UpConv(64*num_channel, 32*num_channel)
|
| 400 |
+
self.Att5 = AttentionBlock(F_g=32*num_channel, F_l=32*num_channel, n_coefficients=16*num_channel)
|
| 401 |
+
self.UpRRCNN5 = RRCNN_block(64*num_channel, 32*num_channel)
|
| 402 |
+
self.Up4 = UpConv(32*num_channel, 16*num_channel)
|
| 403 |
+
self.Att4 = AttentionBlock(F_g=16*num_channel, F_l=16*num_channel, n_coefficients=8*num_channel)
|
| 404 |
+
self.UpRRCNN4 = RRCNN_block(32*num_channel, 16*num_channel)
|
| 405 |
+
self.Up3 = UpConv(16*num_channel, 8*num_channel)
|
| 406 |
+
self.Att3 = AttentionBlock(F_g=8*num_channel, F_l=8*num_channel, n_coefficients=4*num_channel)
|
| 407 |
+
self.UpRRCNN3 = RRCNN_block(16*num_channel, 8*num_channel)
|
| 408 |
+
self.Up2 = UpConv(8*num_channel, 4*num_channel)
|
| 409 |
+
self.Att2 = AttentionBlock(F_g=4*num_channel, F_l=4*num_channel, n_coefficients=2*num_channel)
|
| 410 |
+
self.UpRRCNN2 = RRCNN_block(8*num_channel, 4*num_channel)
|
| 411 |
+
self.Up1 = UpConv(4*num_channel, 2*num_channel)
|
| 412 |
+
self.Att1 = AttentionBlock(F_g=2*num_channel, F_l=2*num_channel, n_coefficients=1*num_channel)
|
| 413 |
+
self.UpRRCNN1 = RRCNN_block(4*num_channel, 2*num_channel)
|
| 414 |
+
self.out_conv_R = Conv2d(2*num_channel, 1, (1, 1), padding= 'same')
|
| 415 |
+
def forward(self, radar,satellite):
|
| 416 |
+
e1 = self.RRCNN1(radar)
|
| 417 |
+
e2 = self.MaxPool(e1)
|
| 418 |
+
e2 = self.RRCNN2(e2)
|
| 419 |
+
e3 = self.MaxPool(e2)
|
| 420 |
+
e3 = self.RRCNN3(e3)
|
| 421 |
+
e4 = self.MaxPool(e3)
|
| 422 |
+
e4 = self.RRCNN4(e4)
|
| 423 |
+
e5 = self.MaxPool(e4)
|
| 424 |
+
e5 = self.RRCNN5(e5)
|
| 425 |
+
e6 = self.MaxPool(e5)
|
| 426 |
+
X = F.relu(self.mid_conv_1(e6))
|
| 427 |
+
Y = F.relu(self.mid_conv_2(satellite))
|
| 428 |
+
X = torch.cat((X,Y),1)
|
| 429 |
+
Y = self.MidConv(X)
|
| 430 |
+
pred_satellite = self.out_conv_S(Y)
|
| 431 |
+
d5 = self.Up5(X)
|
| 432 |
+
s4 = self.Att5(gate=d5, skip_connection=e5)
|
| 433 |
+
d5 = torch.cat((s4, d5), dim=1) # concatenate attention-weighted skip connection with previous layer output
|
| 434 |
+
d5 = self.UpRRCNN5(d5)
|
| 435 |
+
d4 = self.Up4(d5)
|
| 436 |
+
s3 = self.Att4(gate=d4, skip_connection=e4)
|
| 437 |
+
d4 = torch.cat((s3, d4), dim=1)
|
| 438 |
+
d4 = self.UpRRCNN4(d4)
|
| 439 |
+
d3 = self.Up3(d4)
|
| 440 |
+
s2 = self.Att3(gate=d3, skip_connection=e3)
|
| 441 |
+
d3 = torch.cat((s2, d3), dim=1)
|
| 442 |
+
d3 = self.UpRRCNN3(d3)
|
| 443 |
+
d2 = self.Up2(d3)
|
| 444 |
+
s1 = self.Att2(gate=d2, skip_connection=e2)
|
| 445 |
+
d2 = torch.cat((s1, d2), dim=1)
|
| 446 |
+
d2 = self.UpRRCNN2(d2)
|
| 447 |
+
d1 = self.Up1(d2)
|
| 448 |
+
s0 = self.Att1(gate=d1, skip_connection=e1)
|
| 449 |
+
d0 = torch.cat((s0, d1), dim=1)
|
| 450 |
+
d0 = self.UpRRCNN1(d0)
|
| 451 |
+
pred_radar = self.out_conv_R(d0)
|
| 452 |
+
return pred_radar, pred_satellite
|
| 453 |
+
class Network(nn.Module):
|
| 454 |
+
def __init__(self,model_type:str,rad_channel:int, sat_channel:int,rad_size:int,sat_size:int):
|
| 455 |
+
super(Network,self).__init__()
|
| 456 |
+
print(model_type)
|
| 457 |
+
if(model_type == "Nothing"):
|
| 458 |
+
self.net = Nothing()
|
| 459 |
+
elif(model_type == "Unet"):
|
| 460 |
+
self.net = Unet(rad_channel=rad_channel,sat_channel=sat_channel,rad_size=rad_size,sat_size=sat_size)
|
| 461 |
+
elif(model_type == "Unet"):
|
| 462 |
+
self.net = Unet(rad_channel=rad_channel,sat_channel=sat_channel,rad_size=rad_size,sat_size=sat_size)
|
| 463 |
+
elif(model_type == "R2Unet"):
|
| 464 |
+
self.net = R2Unet(rad_channel=rad_channel,sat_channel=sat_channel,rad_size=rad_size,sat_size=sat_size)
|
| 465 |
+
elif(model_type == "AttUnet"):
|
| 466 |
+
self.net = AttUnet(rad_channel=rad_channel,sat_channel=sat_channel,rad_size=rad_size,sat_size=sat_size)
|
| 467 |
+
elif(model_type == "AttR2Unet"):
|
| 468 |
+
self.net = AttR2Unet(rad_channel=rad_channel,sat_channel=sat_channel,rad_size=rad_size,sat_size=sat_size)
|
| 469 |
+
else:
|
| 470 |
+
raise ValueError("model_type is wrong")
|
| 471 |
+
def forward(self, radar,satellite):
|
| 472 |
+
pred_radar, pred_satellite = self.net.forward(radar,satellite)
|
| 473 |
+
return pred_radar, pred_satellite
|
src/datamodule.py
ADDED
|
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from torch.utils.data import DataLoader, Dataset, random_split
|
| 3 |
+
import numpy as np
|
| 4 |
+
from datetime import datetime, timedelta
|
| 5 |
+
from torchvision import transforms
|
| 6 |
+
from pytorch_lightning import LightningDataModule, LightningModule
|
| 7 |
+
from pytorch_lightning.cli import LightningCLI
|
| 8 |
+
from torch.utils.data import DataLoader
|
| 9 |
+
import pytorch_lightning as L
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
from typing import Tuple, Dict, List
|
| 13 |
+
|
| 14 |
+
# import optim
|
| 15 |
+
|
| 16 |
+
class DataReader(Dataset):
|
| 17 |
+
def __init__(
|
| 18 |
+
self, dir_data : str,
|
| 19 |
+
type_data : str,
|
| 20 |
+
rad_attribute : str ,
|
| 21 |
+
sat_attribute : str,
|
| 22 |
+
hours_predicted : int,
|
| 23 |
+
rad_predicted : str ,
|
| 24 |
+
sat_predicted : str ,
|
| 25 |
+
time_points_rad : int,
|
| 26 |
+
time_points_sat : int,
|
| 27 |
+
rad_size:int,
|
| 28 |
+
sat_size:int,
|
| 29 |
+
ablation = str,
|
| 30 |
+
):
|
| 31 |
+
super().__init__()
|
| 32 |
+
self.base_dir=dir_data
|
| 33 |
+
self.type_data = type_data
|
| 34 |
+
if self.type_data == "train":
|
| 35 |
+
self.dir_data=os.path.join(dir_data, "train")
|
| 36 |
+
elif self.type_data =="test":
|
| 37 |
+
self.dir_data=os.path.join(dir_data, 'test')
|
| 38 |
+
elif self.type_data =="val":
|
| 39 |
+
self.dir_data=os.path.join(dir_data, 'val')
|
| 40 |
+
else:
|
| 41 |
+
raise ValueError("Type must be train, test or val")
|
| 42 |
+
self.sat_size = sat_size
|
| 43 |
+
self.rad_size = rad_size
|
| 44 |
+
self.hours_predicted = hours_predicted
|
| 45 |
+
self.rad_attribute = rad_attribute
|
| 46 |
+
self.sat_attribute = sat_attribute
|
| 47 |
+
self.rad_predicted = rad_predicted
|
| 48 |
+
self.sat_predicted = sat_predicted
|
| 49 |
+
self.time_points_rad = time_points_rad
|
| 50 |
+
self.time_points_sat = time_points_sat
|
| 51 |
+
self.transform_rad = None
|
| 52 |
+
self.transform_sat = None
|
| 53 |
+
self.ablation = ablation
|
| 54 |
+
# Create path for img
|
| 55 |
+
self.rad_mean = np.load(os.path.join(self.base_dir,'rad_mean.npz'))[self.rad_attribute]
|
| 56 |
+
self.rad_std = np.load(os.path.join(self.base_dir,'rad_std.npz'))[self.rad_attribute]
|
| 57 |
+
self.sat_mean = np.load(os.path.join(self.base_dir,'sat_mean.npz'))[self.sat_attribute]
|
| 58 |
+
self.sat_std = np.load(os.path.join(self.base_dir,'sat_std.npz'))[self.sat_attribute]
|
| 59 |
+
#Create transform
|
| 60 |
+
self.create_transform()
|
| 61 |
+
#Get list img
|
| 62 |
+
if(self.ablation == "no"):
|
| 63 |
+
self.list_img_dir = self.gen_list_img_no(self.dir_data)
|
| 64 |
+
elif(self.ablation == "rad"):
|
| 65 |
+
self.list_img_dir = self.gen_list_img_rad(self.dir_data)
|
| 66 |
+
elif(self.ablation == "sat"):
|
| 67 |
+
self.list_img_dir = self.gen_list_img_sat(self.dir_data)
|
| 68 |
+
elif(self.ablation == "full"):
|
| 69 |
+
self.list_img_dir = self.gen_list_img_full(self.dir_data)
|
| 70 |
+
elif(self.ablation == "time"):
|
| 71 |
+
self.list_img_dir = self.gen_list_img_time(self.dir_data)
|
| 72 |
+
else:
|
| 73 |
+
raise ValueError("Ablation must be no,rad,sat,full")
|
| 74 |
+
print(f"Number of {self.type_data } samples:",len(self.list_img_dir))
|
| 75 |
+
def __len__(self):
|
| 76 |
+
return len(self.list_img_dir)
|
| 77 |
+
def __getitem__(self, idx):
|
| 78 |
+
if(self.transform_rad):
|
| 79 |
+
inp_rad = self.transform_rad(np.load(self.list_img_dir[idx][0])[self.rad_attribute])
|
| 80 |
+
out_rad = self.transform_rad(np.load(self.list_img_dir[idx][2])[self.rad_predicted])
|
| 81 |
+
if(self.transform_sat):
|
| 82 |
+
inp_sat = self.transform_sat(np.load(self.list_img_dir[idx][1])[self.sat_attribute])
|
| 83 |
+
out_sat = self.transform_sat(np.load(self.list_img_dir[idx][3])[self.sat_predicted][0])
|
| 84 |
+
return inp_rad,inp_sat.float(),out_rad, out_sat.float()
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def create_transform(self):
|
| 88 |
+
self.transform_rad = transforms.Compose([
|
| 89 |
+
transforms.ToTensor(),
|
| 90 |
+
transforms.Normalize(self.rad_mean,self.rad_std)
|
| 91 |
+
])
|
| 92 |
+
self.transform_sat = transforms.Compose([
|
| 93 |
+
transforms.ToTensor(),
|
| 94 |
+
transforms.Normalize(self.sat_mean[0],self.sat_std[0]),
|
| 95 |
+
])
|
| 96 |
+
# print("SAT_MEAN", self.sat_mean, self.sat_std)
|
| 97 |
+
|
| 98 |
+
def gen_list_img_no(self,path):
|
| 99 |
+
pred_rad_dir =os.path.join(path,"pred_rad")
|
| 100 |
+
pred_sat_dir = os.path.join(path,"pred_sat")
|
| 101 |
+
GT_rad_dir = os.path.join(path ,"rad")
|
| 102 |
+
GT_sat_dir = os.path.join(path,"sat")
|
| 103 |
+
list_dir = []
|
| 104 |
+
# print()
|
| 105 |
+
# print(len(os.listdir(pred_rad_dir)))
|
| 106 |
+
for name in os.listdir(pred_rad_dir):
|
| 107 |
+
temp = []
|
| 108 |
+
if(not name.endswith("00.npz") and not name.endswith("03.npz")):
|
| 109 |
+
continue
|
| 110 |
+
temp.append(os.path.join(pred_rad_dir,name))
|
| 111 |
+
pred_sat_path = os.path.join(pred_sat_dir,name[0:-6]+name[-4:])
|
| 112 |
+
GT_rad_path = os.path.join(GT_rad_dir, name)
|
| 113 |
+
GT_sat_path = os.path.join(GT_sat_dir, name[0:-6] + name[-4:])
|
| 114 |
+
if(os.path.isfile(pred_sat_path)):
|
| 115 |
+
temp.append(pred_sat_path)
|
| 116 |
+
if(os.path.isfile(GT_rad_path)):
|
| 117 |
+
temp.append(GT_rad_path)
|
| 118 |
+
if(os.path.isfile(GT_sat_path)):
|
| 119 |
+
temp.append(GT_sat_path)
|
| 120 |
+
if(len(temp) == 4):
|
| 121 |
+
list_dir.append(temp)
|
| 122 |
+
return list_dir
|
| 123 |
+
def gen_list_img_rad(self,path):
|
| 124 |
+
pred_rad_dir = os.path.join(path,"rad")
|
| 125 |
+
pred_sat_dir = os.path.join(path,"pred_sat")
|
| 126 |
+
GT_rad_dir = os.path.join(path ,"rad")
|
| 127 |
+
GT_sat_dir = os.path.join(path,"sat")
|
| 128 |
+
list_dir = []
|
| 129 |
+
for name in os.listdir(pred_rad_dir):
|
| 130 |
+
temp = []
|
| 131 |
+
if( not name.endswith("00.npz") and not name.endswith("03.npz")):
|
| 132 |
+
continue
|
| 133 |
+
temp_date = self.get_date_time(name)
|
| 134 |
+
temp.append(os.path.join(pred_rad_dir,name))
|
| 135 |
+
pred_sat_path = os.path.join(pred_sat_dir, (temp_date+timedelta(hours=self.hours_predicted)).strftime('%Y%m%d%H') + '.npz')
|
| 136 |
+
GT_rad_path = os.path.join(GT_rad_dir, (temp_date+timedelta(hours=self.hours_predicted)).strftime('%Y%m%d%H%M') + '.npz')
|
| 137 |
+
GT_sat_path = os.path.join(GT_sat_dir, (temp_date+timedelta(hours=self.hours_predicted)).strftime('%Y%m%d%H') + '.npz')
|
| 138 |
+
if(os.path.isfile(pred_sat_path)):
|
| 139 |
+
temp.append(pred_sat_path)
|
| 140 |
+
if(os.path.isfile(GT_rad_path)):
|
| 141 |
+
temp.append(GT_rad_path)
|
| 142 |
+
if(os.path.isfile(GT_sat_path)):
|
| 143 |
+
temp.append(GT_sat_path)
|
| 144 |
+
if(len(temp) == 4):
|
| 145 |
+
list_dir.append(temp)
|
| 146 |
+
return list_dir
|
| 147 |
+
def gen_list_img_sat(self,path):
|
| 148 |
+
pred_rad_dir = os.path.join(path,"pred_rad")
|
| 149 |
+
pred_sat_dir = os.path.join(path,"sat")
|
| 150 |
+
GT_rad_dir = os.path.join(path ,"rad")
|
| 151 |
+
GT_sat_dir = os.path.join(path,"sat")
|
| 152 |
+
list_dir = []
|
| 153 |
+
for name in os.listdir(pred_rad_dir):
|
| 154 |
+
temp = []
|
| 155 |
+
if( not name.endswith("00.npz") and not name.endswith("03.npz")):
|
| 156 |
+
continue
|
| 157 |
+
temp_date = self.get_date_time(name)
|
| 158 |
+
temp.append(os.path.join(pred_rad_dir,name))
|
| 159 |
+
pred_sat_path = os.path.join(pred_sat_dir, (temp_date-timedelta(hours=self.hours_predicted)).strftime('%Y%m%d%H') + '.npz')
|
| 160 |
+
GT_rad_path = os.path.join(GT_rad_dir, name)
|
| 161 |
+
GT_sat_path = os.path.join(GT_sat_dir, name[0:-6] + name[-4:])
|
| 162 |
+
if(os.path.isfile(pred_sat_path)):
|
| 163 |
+
temp.append(pred_sat_path)
|
| 164 |
+
if(os.path.isfile(GT_rad_path)):
|
| 165 |
+
temp.append(GT_rad_path)
|
| 166 |
+
if(os.path.isfile(GT_sat_path)):
|
| 167 |
+
temp.append(GT_sat_path)
|
| 168 |
+
if(len(temp) == 4):
|
| 169 |
+
list_dir.append(temp)
|
| 170 |
+
return list_dir
|
| 171 |
+
def gen_list_img_full(self,path):
|
| 172 |
+
pred_rad_dir = os.path.join(path,"rad")
|
| 173 |
+
pred_sat_dir = os.path.join(path,"sat")
|
| 174 |
+
GT_rad_dir = os.path.join(path ,"rad")
|
| 175 |
+
GT_sat_dir = os.path.join(path,"sat")
|
| 176 |
+
list_dir = []
|
| 177 |
+
for name in os.listdir(pred_rad_dir):
|
| 178 |
+
temp = []
|
| 179 |
+
if(not name.endswith("00.npz") and not name.endswith("03.npz")):
|
| 180 |
+
continue
|
| 181 |
+
temp_date = self.get_date_time(name)
|
| 182 |
+
temp.append(os.path.join(pred_rad_dir,name))
|
| 183 |
+
pred_sat_path = os.path.join(pred_sat_dir,temp_date.strftime('%Y%m%d%H')+'.npz')
|
| 184 |
+
GT_rad_path = os.path.join(GT_rad_dir, (temp_date+timedelta(hours=self.hours_predicted)).strftime('%Y%m%d%H%M') + '.npz')
|
| 185 |
+
GT_sat_path = os.path.join(GT_sat_dir, (temp_date+timedelta(hours=self.hours_predicted)).strftime('%Y%m%d%H') + '.npz')
|
| 186 |
+
if(os.path.isfile(pred_sat_path)):
|
| 187 |
+
temp.append(pred_sat_path)
|
| 188 |
+
if(os.path.isfile(GT_rad_path)):
|
| 189 |
+
temp.append(GT_rad_path)
|
| 190 |
+
if(os.path.isfile(GT_sat_path)):
|
| 191 |
+
temp.append(GT_sat_path)
|
| 192 |
+
if(len(temp) == 4):
|
| 193 |
+
list_dir.append(temp)
|
| 194 |
+
return list_dir
|
| 195 |
+
def gen_list_img_time(self,path):
|
| 196 |
+
pred_rad_dir =os.path.join(path,"pred_rad")
|
| 197 |
+
pred_sat_dir = os.path.join(path,"pred_sat")
|
| 198 |
+
GT_rad_dir = os.path.join(path ,"rad")
|
| 199 |
+
GT_sat_dir = os.path.join(path,"sat")
|
| 200 |
+
list_dir = []
|
| 201 |
+
for name in os.listdir(pred_rad_dir):
|
| 202 |
+
temp = [[],[],[],[]]
|
| 203 |
+
temp_date = self.get_date_time(name)
|
| 204 |
+
if(not name.endswith("00.npz") and not name.endswith("03.npz")):
|
| 205 |
+
continue
|
| 206 |
+
for i in range(4):
|
| 207 |
+
temp_path = os.path.join(GT_rad_dir, (temp_date+timedelta(minutes=-210+i*10)).strftime('%Y%m%d%H%M') + '.npz')
|
| 208 |
+
if(os.path.isfile(temp_path)): temp[0].append(temp_path)
|
| 209 |
+
for i in range(1):
|
| 210 |
+
temp_path = os.path.join(GT_sat_dir, (temp_date+timedelta(minutes=-180+i*10)).strftime('%Y%m%d%H') + '.npz')
|
| 211 |
+
if(os.path.isfile(temp_path)): temp[1].append(temp_path)
|
| 212 |
+
temp[0].append(os.path.join(pred_rad_dir,name))
|
| 213 |
+
pred_sat_path = os.path.join(pred_sat_dir,name[0:-6]+name[-4:])
|
| 214 |
+
GT_rad_path = os.path.join(GT_rad_dir, name)
|
| 215 |
+
GT_sat_path = os.path.join(GT_sat_dir, name[0:-6] + name[-4:])
|
| 216 |
+
if(os.path.isfile(pred_sat_path)):
|
| 217 |
+
temp[1].append(pred_sat_path)
|
| 218 |
+
if(os.path.isfile(GT_rad_path)):
|
| 219 |
+
temp[2].append(GT_rad_path)
|
| 220 |
+
if(os.path.isfile(GT_sat_path)):
|
| 221 |
+
temp[3].append(GT_sat_path)
|
| 222 |
+
if(len(temp[0]) == 5 and len(temp[1]) == 2 and len(temp[2]) == 1 and len(temp[3]) == 1):
|
| 223 |
+
list_dir.append(temp)
|
| 224 |
+
return list_dir
|
| 225 |
+
def get_date_time(self,name):
|
| 226 |
+
year=int(name[0:4])
|
| 227 |
+
month=int(name[4:6])
|
| 228 |
+
day=int(name[6:8])
|
| 229 |
+
hour=int(name[8:10])
|
| 230 |
+
minute = int(name[10:12])
|
| 231 |
+
return datetime(year,month,day,hour,minute)
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class WeatherForecastDataModule(LightningDataModule):
|
| 237 |
+
def __init__(
|
| 238 |
+
self,
|
| 239 |
+
dir_data: str,
|
| 240 |
+
batch_size:int ,
|
| 241 |
+
hours_predicted :int,
|
| 242 |
+
num_workers:int ,
|
| 243 |
+
pin_memory: bool ,
|
| 244 |
+
time_points_rad : int,
|
| 245 |
+
time_points_sat : int,
|
| 246 |
+
sat_inp_vars: str,
|
| 247 |
+
sat_out_vars : str,
|
| 248 |
+
sat_size: int,
|
| 249 |
+
rad_inp_vars : str,
|
| 250 |
+
rad_out_vars : str,
|
| 251 |
+
rad_size: int,
|
| 252 |
+
ablation: str,
|
| 253 |
+
):
|
| 254 |
+
|
| 255 |
+
super().__init__()
|
| 256 |
+
# this line allows to access init params with 'self.hparams' attribute
|
| 257 |
+
self.save_hyperparameters(logger=True)
|
| 258 |
+
self.data_train = None
|
| 259 |
+
self.data_test = None
|
| 260 |
+
self.data_val = None
|
| 261 |
+
self.rad_mean = np.load(os.path.join(self.hparams.dir_data,'rad_mean.npz'))[self.hparams.rad_inp_vars]
|
| 262 |
+
self.rad_std = np.load(os.path.join(self.hparams.dir_data,'rad_std.npz'))[self.hparams.rad_inp_vars]
|
| 263 |
+
self.sat_mean = np.load(os.path.join(self.hparams.dir_data,'sat_mean.npz'))[self.hparams.sat_inp_vars]
|
| 264 |
+
self.sat_std = np.load(os.path.join(self.hparams.dir_data,'sat_std.npz'))[self.hparams.sat_inp_vars]
|
| 265 |
+
def prepare_data(self):
|
| 266 |
+
pass
|
| 267 |
+
|
| 268 |
+
def setup(self, stage):
|
| 269 |
+
# print(self.hparams.dir_data)
|
| 270 |
+
self.data_train = DataReader(
|
| 271 |
+
dir_data=self.hparams.dir_data,
|
| 272 |
+
type_data= "train",
|
| 273 |
+
rad_attribute = self.hparams.rad_inp_vars,
|
| 274 |
+
sat_attribute = self.hparams.sat_inp_vars,
|
| 275 |
+
hours_predicted = self.hparams.hours_predicted,
|
| 276 |
+
rad_predicted = self.hparams.rad_out_vars,
|
| 277 |
+
sat_predicted = self.hparams.sat_out_vars,
|
| 278 |
+
time_points_rad = self.hparams.time_points_rad,
|
| 279 |
+
time_points_sat = self.hparams.time_points_sat,
|
| 280 |
+
sat_size = self.hparams.sat_size,
|
| 281 |
+
rad_size = self.hparams.rad_size,
|
| 282 |
+
ablation = self.hparams.ablation
|
| 283 |
+
)
|
| 284 |
+
self.data_test = DataReader(
|
| 285 |
+
dir_data=self.hparams.dir_data,
|
| 286 |
+
type_data ="test",
|
| 287 |
+
rad_attribute = self.hparams.rad_inp_vars,
|
| 288 |
+
sat_attribute = self.hparams.sat_inp_vars,
|
| 289 |
+
hours_predicted = self.hparams.hours_predicted,
|
| 290 |
+
rad_predicted = self.hparams.rad_out_vars,
|
| 291 |
+
sat_predicted = self.hparams.sat_out_vars,
|
| 292 |
+
time_points_rad = self.hparams.time_points_rad,
|
| 293 |
+
time_points_sat = self.hparams.time_points_sat,
|
| 294 |
+
sat_size = self.hparams.sat_size,
|
| 295 |
+
rad_size = self.hparams.rad_size,
|
| 296 |
+
ablation = self.hparams.ablation
|
| 297 |
+
)
|
| 298 |
+
self.data_val = DataReader(
|
| 299 |
+
dir_data=self.hparams.dir_data,
|
| 300 |
+
type_data = "val",
|
| 301 |
+
rad_attribute = self.hparams.rad_inp_vars,
|
| 302 |
+
sat_attribute = self.hparams.sat_inp_vars,
|
| 303 |
+
hours_predicted = self.hparams.hours_predicted,
|
| 304 |
+
rad_predicted = self.hparams.rad_out_vars,
|
| 305 |
+
sat_predicted = self.hparams.sat_out_vars,
|
| 306 |
+
time_points_rad = self.hparams.time_points_rad,
|
| 307 |
+
time_points_sat = self.hparams.time_points_sat,
|
| 308 |
+
sat_size = self.hparams.sat_size,
|
| 309 |
+
rad_size = self.hparams.rad_size,
|
| 310 |
+
ablation = self.hparams.ablation
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
def train_dataloader(self):
|
| 314 |
+
return DataLoader(
|
| 315 |
+
self.data_train,
|
| 316 |
+
batch_size=self.hparams.batch_size,
|
| 317 |
+
num_workers=self.hparams.num_workers,
|
| 318 |
+
drop_last=False,
|
| 319 |
+
pin_memory=self.hparams.pin_memory,
|
| 320 |
+
shuffle=True,
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
def val_dataloader(self):
|
| 324 |
+
return DataLoader(
|
| 325 |
+
self.data_val,
|
| 326 |
+
batch_size=self.hparams.batch_size,
|
| 327 |
+
num_workers=self.hparams.num_workers,
|
| 328 |
+
drop_last=False,
|
| 329 |
+
pin_memory=self.hparams.pin_memory,
|
| 330 |
+
shuffle=False,
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
def test_dataloader(self):
|
| 334 |
+
return DataLoader(
|
| 335 |
+
self.data_test,
|
| 336 |
+
batch_size=self.hparams.batch_size,
|
| 337 |
+
num_workers=self.hparams.num_workers,
|
| 338 |
+
drop_last=False,
|
| 339 |
+
pin_memory=self.hparams.pin_memory,
|
| 340 |
+
shuffle=False,
|
| 341 |
+
)
|
src/lr_scheduler.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation.
|
| 2 |
+
# Licensed under the MIT license.
|
| 3 |
+
|
| 4 |
+
import math
|
| 5 |
+
import warnings
|
| 6 |
+
from typing import List
|
| 7 |
+
|
| 8 |
+
from torch.optim import Optimizer
|
| 9 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class LinearWarmupCosineAnnealingLR(_LRScheduler):
|
| 13 |
+
"""Sets the learning rate of each parameter group to follow a linear warmup schedule between
|
| 14 |
+
warmup_start_lr and base_lr followed by a cosine annealing schedule between base_lr and
|
| 15 |
+
eta_min."""
|
| 16 |
+
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
optimizer: Optimizer,
|
| 20 |
+
warmup_epochs: int,
|
| 21 |
+
max_epochs: int,
|
| 22 |
+
warmup_start_lr: float = 0.0,
|
| 23 |
+
eta_min: float = 0.0,
|
| 24 |
+
last_epoch: int = -1,
|
| 25 |
+
) -> None:
|
| 26 |
+
"""
|
| 27 |
+
Args:
|
| 28 |
+
optimizer (Optimizer): Wrapped optimizer.
|
| 29 |
+
warmup_epochs (int): Maximum number of iterations for linear warmup
|
| 30 |
+
max_epochs (int): Maximum number of iterations
|
| 31 |
+
warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0.
|
| 32 |
+
eta_min (float): Minimum learning rate. Default: 0.
|
| 33 |
+
last_epoch (int): The index of last epoch. Default: -1.
|
| 34 |
+
"""
|
| 35 |
+
self.warmup_epochs = warmup_epochs
|
| 36 |
+
self.max_epochs = max_epochs
|
| 37 |
+
self.warmup_start_lr = warmup_start_lr
|
| 38 |
+
self.eta_min = eta_min
|
| 39 |
+
|
| 40 |
+
super().__init__(optimizer, last_epoch)
|
| 41 |
+
|
| 42 |
+
def get_lr(self) -> List[float]:
|
| 43 |
+
"""Compute learning rate using chainable form of the scheduler."""
|
| 44 |
+
if not self._get_lr_called_within_step:
|
| 45 |
+
warnings.warn(
|
| 46 |
+
"To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.",
|
| 47 |
+
UserWarning,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
if self.last_epoch == self.warmup_epochs:
|
| 51 |
+
return self.base_lrs
|
| 52 |
+
if self.last_epoch == 0:
|
| 53 |
+
return [self.warmup_start_lr] * len(self.base_lrs)
|
| 54 |
+
if self.last_epoch < self.warmup_epochs:
|
| 55 |
+
return [
|
| 56 |
+
group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
|
| 57 |
+
for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
|
| 58 |
+
]
|
| 59 |
+
if (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0:
|
| 60 |
+
return [
|
| 61 |
+
group["lr"]
|
| 62 |
+
+ (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2
|
| 63 |
+
for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
|
| 64 |
+
]
|
| 65 |
+
|
| 66 |
+
return [
|
| 67 |
+
(1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
|
| 68 |
+
/ (
|
| 69 |
+
1
|
| 70 |
+
+ math.cos(
|
| 71 |
+
math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs)
|
| 72 |
+
)
|
| 73 |
+
)
|
| 74 |
+
* (group["lr"] - self.eta_min)
|
| 75 |
+
+ self.eta_min
|
| 76 |
+
for group in self.optimizer.param_groups
|
| 77 |
+
]
|
| 78 |
+
|
| 79 |
+
def _get_closed_form_lr(self) -> List[float]:
|
| 80 |
+
"""Called when epoch is passed as a param to the `step` function of the scheduler."""
|
| 81 |
+
if self.last_epoch < self.warmup_epochs:
|
| 82 |
+
return [
|
| 83 |
+
self.warmup_start_lr
|
| 84 |
+
+ self.last_epoch * (base_lr - self.warmup_start_lr) / max(1, self.warmup_epochs - 1)
|
| 85 |
+
for base_lr in self.base_lrs
|
| 86 |
+
]
|
| 87 |
+
|
| 88 |
+
return [
|
| 89 |
+
self.eta_min
|
| 90 |
+
+ 0.5
|
| 91 |
+
* (base_lr - self.eta_min)
|
| 92 |
+
* (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
|
| 93 |
+
for base_lr in self.base_lrs
|
| 94 |
+
]
|
src/metric.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
def MSE(pred,GT,lat,clim):
|
| 4 |
+
return torch.mean((pred-GT)**2)
|
| 5 |
+
def RMSE(pred,GT,lat,clim):
|
| 6 |
+
return torch.sqrt(torch.mean((pred-GT)**2))
|
| 7 |
+
def MAE(pred,GT,lat,clim):
|
| 8 |
+
return torch.mean(torch.abs(pred-GT))
|
| 9 |
+
def WMSE(pred, y, lat,clim):
|
| 10 |
+
if(lat is None):return 0
|
| 11 |
+
error = (pred - y) ** 2 # [N, C, H, W]
|
| 12 |
+
# lattitude weights
|
| 13 |
+
w_lat = np.cos(np.deg2rad(lat))
|
| 14 |
+
w_lat = w_lat / w_lat.mean()
|
| 15 |
+
w_lat = torch.from_numpy(w_lat).unsqueeze(0).unsqueeze(-1).to(dtype=error.dtype, device=error.device) # (1, H, 1)
|
| 16 |
+
loss = (error * w_lat).mean()
|
| 17 |
+
return loss
|
| 18 |
+
def WRMSE(pred,GT,lat,clim):
|
| 19 |
+
if(lat is None):return 0
|
| 20 |
+
error = (pred - GT) ** 2 # [B, V, H, W]
|
| 21 |
+
# lattitude weights
|
| 22 |
+
w_lat = np.cos(np.deg2rad(lat))
|
| 23 |
+
w_lat = w_lat / w_lat.mean() # (H, )
|
| 24 |
+
w_lat = torch.from_numpy(w_lat).unsqueeze(0).unsqueeze(-1).to(dtype=error.dtype, device=error.device)
|
| 25 |
+
loss = torch.mean(
|
| 26 |
+
torch.sqrt(torch.mean(error* w_lat, dim=(-2, -1)))
|
| 27 |
+
)
|
| 28 |
+
return loss
|
| 29 |
+
def ACC(pred,GT,lat,clim):
|
| 30 |
+
if(lat is None):return 0
|
| 31 |
+
w_lat = np.cos(np.deg2rad(lat))
|
| 32 |
+
w_lat = w_lat / w_lat.mean() # (H, )
|
| 33 |
+
w_lat = torch.from_numpy(w_lat).unsqueeze(0).unsqueeze(-1).to(dtype=pred.dtype, device=pred.device) # [1, H, 1]
|
| 34 |
+
|
| 35 |
+
# = torch.mean(y, dim=(0, 1), keepdim=True)
|
| 36 |
+
clim = clim.to(device=GT.device).unsqueeze(0)
|
| 37 |
+
pred = pred - clim
|
| 38 |
+
GT = GT - clim
|
| 39 |
+
pred_prime = pred - torch.mean(pred)
|
| 40 |
+
GT_prime = GT - torch.mean(GT)
|
| 41 |
+
loss = torch.sum(w_lat * pred_prime * GT_prime) / torch.sqrt(
|
| 42 |
+
torch.sum(w_lat * pred_prime**2) * torch.sum(w_lat * GT_prime**2)
|
| 43 |
+
)
|
| 44 |
+
return loss
|
src/module.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation.
|
| 2 |
+
# Licensed under the MIT license.
|
| 3 |
+
|
| 4 |
+
# credits: https://github.com/ashleve/lightning-hydra-template/blob/main/src/models/mnist_module.py
|
| 5 |
+
from typing import Any
|
| 6 |
+
import os
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from pytorch_lightning import LightningModule
|
| 10 |
+
from torchvision.transforms import transforms
|
| 11 |
+
from lr_scheduler import LinearWarmupCosineAnnealingLR
|
| 12 |
+
from arch import Network
|
| 13 |
+
from metric import (
|
| 14 |
+
MSE,RMSE,MAE,ACC,WMSE,WRMSE
|
| 15 |
+
)
|
| 16 |
+
class WeatherForecastModule(LightningModule):
|
| 17 |
+
"""Lightning module for global forecasting with the ClimaX model.
|
| 18 |
+
Args:
|
| 19 |
+
net: Deeplearning model.
|
| 20 |
+
pretrained_path (str, optional): Path to pre-trained checkpoint.
|
| 21 |
+
lr (float, optional): Learning rate.
|
| 22 |
+
beta_1 (float, optional): Beta 1 for AdamW.
|
| 23 |
+
beta_2 (float, optional): Beta 2 for AdamW.
|
| 24 |
+
weight_decay (float, optional): Weight decay for AdamW.
|
| 25 |
+
warmup_epochs (int, optional): Number of warmup epochs.
|
| 26 |
+
max_epochs (int, optional): Number of total epochs.
|
| 27 |
+
warmup_start_lr (float, optional): Starting learning rate for warmup.
|
| 28 |
+
eta_min (float, optional): Minimum learning rate.
|
| 29 |
+
"""
|
| 30 |
+
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
net: Network,
|
| 33 |
+
pretrained_path: str = "",
|
| 34 |
+
lr: float = 5e-4,
|
| 35 |
+
beta_1: float = 0.9,
|
| 36 |
+
beta_2: float = 0.99,
|
| 37 |
+
weight_decay: float = 1e-5,
|
| 38 |
+
warmup_epochs: int = 10000,
|
| 39 |
+
max_epochs: int = 200000,
|
| 40 |
+
warmup_start_lr: float = 1e-8,
|
| 41 |
+
eta_min: float = 1e-8,
|
| 42 |
+
):
|
| 43 |
+
super().__init__()
|
| 44 |
+
self.save_hyperparameters(logger=True, ignore=["net"])
|
| 45 |
+
self.net = net
|
| 46 |
+
if len(pretrained_path) > 0:
|
| 47 |
+
self.load_pretrained_weights(pretrained_path)
|
| 48 |
+
|
| 49 |
+
def load_pretrained_weights(self, pretrained_path):
|
| 50 |
+
self.net.load_state_dict(torch.load(pretrained_path))
|
| 51 |
+
def set_path(self,path):
|
| 52 |
+
self.path = path
|
| 53 |
+
def set_size(self,rad_size,sat_size):
|
| 54 |
+
self.rad_size = rad_size
|
| 55 |
+
self.sat_size = sat_size
|
| 56 |
+
|
| 57 |
+
def set_lat(self):
|
| 58 |
+
lat = np.load(os.path.join(self.path,'sat_lat.npy'))
|
| 59 |
+
self.sat_lat = lat[lat.shape[-1]//2-self.sat_size//2:lat.shape[-1]//2+self.sat_size//2]
|
| 60 |
+
# self.sat_lat = np.load(os.path.join(self.path,'sat_lat.npy'))
|
| 61 |
+
# self.sat_clim = torch.from_numpy(np.load(os.path.join(self.path,'sat_clim.npz'))['total_precipitation'])
|
| 62 |
+
def set_clim(self):
|
| 63 |
+
##########
|
| 64 |
+
rad_clim = np.load(os.path.join(self.path,'rad_clim.npz'))['precipitation']
|
| 65 |
+
sat_clim = np.load(os.path.join(self.path,'sat_clim.npz'))['total_precipitation']
|
| 66 |
+
self.rad_clim = torch.from_numpy(rad_clim)
|
| 67 |
+
self.sat_clim = torch.from_numpy(sat_clim)
|
| 68 |
+
|
| 69 |
+
def set_normalize(self):
|
| 70 |
+
self.rad_mean = np.load(os.path.join(self.path,'rad_mean.npz'))['precipitation']
|
| 71 |
+
self.rad_std = np.load(os.path.join(self.path,'rad_std.npz'))['precipitation']
|
| 72 |
+
self.sat_mean = np.load(os.path.join(self.path,'sat_mean.npz'))['total_precipitation']
|
| 73 |
+
self.sat_std = np.load(os.path.join(self.path,'sat_std.npz'))['total_precipitation']
|
| 74 |
+
def set_denormalize(self):
|
| 75 |
+
self.rad_denormalization = transforms.Normalize(-self.rad_mean/self.rad_std,1/self.rad_std)
|
| 76 |
+
self.sat_denormalization = transforms.Normalize(-self.sat_mean/self.sat_std,1/self.sat_std)
|
| 77 |
+
def training_step(self, batch: Any, batch_idx: int):
|
| 78 |
+
inp_rad, inp_sat, out_rad, out_sat = batch
|
| 79 |
+
pred_rad,pred_sat = self.net.forward(inp_rad,inp_sat)
|
| 80 |
+
loss = torch.nn.MSELoss()
|
| 81 |
+
loss_rad = loss(pred_rad,out_rad)
|
| 82 |
+
loss_sat = loss(pred_sat,out_sat)
|
| 83 |
+
loss_tot = loss_rad + loss_sat
|
| 84 |
+
self.log("train/rad", loss_rad, prog_bar=True, logger = True)
|
| 85 |
+
self.log("train/sat", loss_sat, prog_bar=True, logger = True)
|
| 86 |
+
self.log("train/mse", loss_tot, prog_bar=True, logger = True)
|
| 87 |
+
return loss_tot
|
| 88 |
+
|
| 89 |
+
def validation_step(self, batch: Any, batch_idx: int):
|
| 90 |
+
inp_rad, inp_sat, out_rad, out_sat = batch
|
| 91 |
+
pred_rad,pred_sat = self.net.forward(inp_rad,inp_sat)
|
| 92 |
+
loss = torch.nn.MSELoss()
|
| 93 |
+
with torch.no_grad():
|
| 94 |
+
loss_rad = loss(pred_rad,out_rad)
|
| 95 |
+
loss_sat = loss(pred_sat,out_sat)
|
| 96 |
+
loss_tot = loss_rad + loss_sat
|
| 97 |
+
self.log("val/rad", loss_rad, prog_bar=True, logger = True)
|
| 98 |
+
self.log("val/sat", loss_sat, prog_bar=True, logger = True)
|
| 99 |
+
self.log("val/mse", loss_tot, prog_bar=True, logger = True)
|
| 100 |
+
return loss_tot
|
| 101 |
+
def test_step(self, batch: Any, batch_idx: int):
|
| 102 |
+
inp_rad, inp_sat, out_rad, out_sat = batch
|
| 103 |
+
pred_rad,pred_sat = self.net.forward(inp_rad,inp_sat)
|
| 104 |
+
loss = torch.nn.MSELoss()
|
| 105 |
+
self.rad_denormalization(out_rad)
|
| 106 |
+
rad_metric = [MSE,RMSE,ACC,MAE]
|
| 107 |
+
sat_metric = [MSE,WMSE,RMSE,WRMSE,ACC,MAE]
|
| 108 |
+
|
| 109 |
+
with torch.no_grad():
|
| 110 |
+
loss_rad = loss(self.rad_denormalization(pred_rad),self.rad_denormalization(out_rad))
|
| 111 |
+
loss_sat = loss(self.sat_denormalization(pred_sat),self.sat_denormalization(out_sat))
|
| 112 |
+
loss_tot = loss_rad + loss_sat
|
| 113 |
+
self.log(f"test/rad", loss_rad, prog_bar=True, logger = True)
|
| 114 |
+
self.log("test/sat", loss_sat, prog_bar=True, logger = True)
|
| 115 |
+
self.log("test/mse", loss_tot, prog_bar=True, logger = True)
|
| 116 |
+
for met in rad_metric:
|
| 117 |
+
loss_rad = met(
|
| 118 |
+
self.rad_denormalization(pred_rad),
|
| 119 |
+
self.rad_denormalization(out_rad),
|
| 120 |
+
np.ones(self.rad_size),
|
| 121 |
+
self.rad_clim
|
| 122 |
+
)
|
| 123 |
+
self.log(f"test/rad_{met.__name__}", loss_rad, prog_bar=True, logger = True)
|
| 124 |
+
for met in sat_metric:
|
| 125 |
+
loss_sat = met(
|
| 126 |
+
self.sat_denormalization(pred_sat),
|
| 127 |
+
self.sat_denormalization(out_sat),
|
| 128 |
+
self.sat_lat,
|
| 129 |
+
self.sat_clim,
|
| 130 |
+
)
|
| 131 |
+
self.log(f"test/sat_{met.__name__}", loss_sat, prog_bar=True, logger = True)
|
| 132 |
+
return loss_tot
|
| 133 |
+
def configure_optimizers(self):
|
| 134 |
+
decay = []
|
| 135 |
+
no_decay = []
|
| 136 |
+
for name, m in self.named_parameters():
|
| 137 |
+
if "var_embed" in name or "pos_embed" in name or "time_pos_embed" in name:
|
| 138 |
+
no_decay.append(m)
|
| 139 |
+
else:
|
| 140 |
+
decay.append(m)
|
| 141 |
+
optimizer = torch.optim.AdamW(
|
| 142 |
+
[
|
| 143 |
+
{
|
| 144 |
+
"params": decay,
|
| 145 |
+
"lr": self.hparams.lr,
|
| 146 |
+
"betas": (self.hparams.beta_1, self.hparams.beta_2),
|
| 147 |
+
"weight_decay": self.hparams.weight_decay,
|
| 148 |
+
},
|
| 149 |
+
{
|
| 150 |
+
"params": no_decay,
|
| 151 |
+
"lr": self.hparams.lr,
|
| 152 |
+
"betas": (self.hparams.beta_1, self.hparams.beta_2),
|
| 153 |
+
"weight_decay": 0,
|
| 154 |
+
},
|
| 155 |
+
]
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
lr_scheduler = LinearWarmupCosineAnnealingLR(
|
| 159 |
+
optimizer,
|
| 160 |
+
self.hparams.warmup_epochs,
|
| 161 |
+
self.hparams.max_epochs,
|
| 162 |
+
self.hparams.warmup_start_lr,
|
| 163 |
+
self.hparams.eta_min,
|
| 164 |
+
)
|
| 165 |
+
scheduler = {"scheduler": lr_scheduler, "interval": "step", "frequency": 1}
|
| 166 |
+
|
| 167 |
+
return {"optimizer": optimizer, "lr_scheduler": scheduler}
|
| 168 |
+
|
src/rad_clim.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import math
|
| 4 |
+
import numpy as np
|
| 5 |
+
from copy import deepcopy
|
| 6 |
+
path_load = '/data/weather2025/NhaBe/train/rad'
|
| 7 |
+
path_save = '/data/weather2025/NhaBe'
|
| 8 |
+
num = 0
|
| 9 |
+
rad_clim = {}
|
| 10 |
+
for name in os.listdir(path_load):
|
| 11 |
+
file = np.load(os.path.join(path_load,name))
|
| 12 |
+
for field in file.keys():
|
| 13 |
+
if(num == 0):
|
| 14 |
+
rad_clim[field] = file[field]
|
| 15 |
+
else:
|
| 16 |
+
rad_clim[field] = rad_clim[field] + file[field]
|
| 17 |
+
num += 1
|
| 18 |
+
print(num,end='\r')
|
| 19 |
+
for field in rad_clim.keys():
|
| 20 |
+
rad_clim[field] = rad_clim[field]/num
|
| 21 |
+
rad_clim[field] = np.expand_dims(rad_clim[field],axis =0)
|
| 22 |
+
print(rad_clim[field].shape)
|
| 23 |
+
np.savez(os.path.join(path_save,'rad_clim.npz'),**rad_clim)
|