weatherforecast1024 committed on
Commit
f3b050a
·
verified ·
1 Parent(s): ee6b828

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. FusionModel.egg-info/PKG-INFO +8 -0
  3. FusionModel.egg-info/SOURCES.txt +11 -0
  4. FusionModel.egg-info/dependency_links.txt +1 -0
  5. FusionModel.egg-info/top_level.txt +2 -0
  6. checkpoint/Unet/checkpoints/epoch_003.ckpt +3 -0
  7. checkpoint/Unet/checkpoints/last.ckpt +3 -0
  8. checkpoint/Unet/csv_logs/version_0/hparams.yaml +24 -0
  9. checkpoint/Unet/csv_logs/version_0/metrics.csv +0 -0
  10. checkpoint/Unet/wandb_logs/config.yaml +157 -0
  11. checkpoint/Unet/wandb_logs/wandb/debug-internal.log +7 -0
  12. checkpoint/Unet/wandb_logs/wandb/debug.log +22 -0
  13. checkpoint/Unet/wandb_logs/wandb/run-20250617_085744-m5tg7yyl/files/output.log +0 -0
  14. checkpoint/Unet/wandb_logs/wandb/run-20250617_085744-m5tg7yyl/files/wandb-summary.json +1 -0
  15. checkpoint/Unet/wandb_logs/wandb/run-20250617_085744-m5tg7yyl/logs/debug-core.log +13 -0
  16. checkpoint/Unet/wandb_logs/wandb/run-20250617_085744-m5tg7yyl/logs/debug-internal.log +17 -0
  17. checkpoint/Unet/wandb_logs/wandb/run-20250617_085744-m5tg7yyl/logs/debug.log +15 -0
  18. checkpoint/Unet/wandb_logs/wandb/run-20250617_085744-m5tg7yyl/run-m5tg7yyl.wandb +0 -0
  19. checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/files/output.log +161 -0
  20. checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/files/requirements.txt +77 -0
  21. checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/files/wandb-metadata.json +85 -0
  22. checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/logs/debug-core.log +7 -0
  23. checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/logs/debug-internal.log +7 -0
  24. checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/logs/debug.log +22 -0
  25. checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/run-0nx0l2dh.wandb +3 -0
  26. configs/AttR2Unet.yaml +86 -0
  27. configs/AttUnet.yaml +86 -0
  28. configs/Nothing.yaml +86 -0
  29. configs/R2Unet.yaml +95 -0
  30. configs/Unet.yaml +104 -0
  31. pyproject.toml +21 -0
  32. src/__pycache__/arch.cpython-310.pyc +0 -0
  33. src/__pycache__/arch.cpython-312.pyc +0 -0
  34. src/__pycache__/arch.cpython-38.pyc +0 -0
  35. src/__pycache__/datamodule.cpython-310.pyc +0 -0
  36. src/__pycache__/datamodule.cpython-312.pyc +0 -0
  37. src/__pycache__/lr_scheduler.cpython-310.pyc +0 -0
  38. src/__pycache__/lr_scheduler.cpython-312.pyc +0 -0
  39. src/__pycache__/metric.cpython-310.pyc +0 -0
  40. src/__pycache__/metric.cpython-312.pyc +0 -0
  41. src/__pycache__/module.cpython-310.pyc +0 -0
  42. src/__pycache__/module.cpython-312.pyc +0 -0
  43. src/__pycache__/module.cpython-38.pyc +0 -0
  44. src/__pycache__/train.cpython-38.pyc +0 -0
  45. src/arch.py +473 -0
  46. src/datamodule.py +341 -0
  47. src/lr_scheduler.py +94 -0
  48. src/metric.py +44 -0
  49. src/module.py +168 -0
  50. src/rad_clim.py +23 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/run-0nx0l2dh.wandb filter=lfs diff=lfs merge=lfs -text
FusionModel.egg-info/PKG-INFO ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.4
2
+ Name: FusionModel
3
+ Version: 0.3.1
4
+ Author-email: Khanh Vinh Bui <khanhvinhbui0512@gmail.com>, Hong Trang Le <lhtrang@hcmut.edu.vn>
5
+ Classifier: Programming Language :: Python :: 3
6
+ Classifier: License :: OSI Approved :: MIT License
7
+ Requires-Python: >=3.10
8
+ Description-Content-Type: text/markdown
FusionModel.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pyproject.toml
2
+ FusionModel.egg-info/PKG-INFO
3
+ FusionModel.egg-info/SOURCES.txt
4
+ FusionModel.egg-info/dependency_links.txt
5
+ FusionModel.egg-info/top_level.txt
6
+ src/arch.py
7
+ src/datamodule.py
8
+ src/lr_scheduler.py
9
+ src/metric.py
10
+ src/module.py
11
+ src/train.py
FusionModel.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
FusionModel.egg-info/top_level.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ configs
2
+ src
checkpoint/Unet/checkpoints/epoch_003.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9f005cf7c67d6259fdc39a5ccb425db8367dc96622457009fcb82a9df5123487
3
+ size 521087
checkpoint/Unet/checkpoints/last.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9f005cf7c67d6259fdc39a5ccb425db8367dc96622457009fcb82a9df5123487
3
+ size 521087
checkpoint/Unet/csv_logs/version_0/hparams.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _instantiator: pytorch_lightning.cli.instantiate_module
2
+ ablation: 'no'
3
+ batch_size: 1
4
+ beta_1: 0.9
5
+ beta_2: 0.99
6
+ dir_data: /data/weather2025/NhaBe/
7
+ eta_min: 1.0e-08
8
+ hours_predicted: 3
9
+ lr: 0.0005
10
+ max_epochs: 50
11
+ num_workers: 4
12
+ pin_memory: false
13
+ pretrained_path: ''
14
+ rad_inp_vars: precipitation
15
+ rad_out_vars: precipitation
16
+ rad_size: 400
17
+ sat_inp_vars: total_precipitation
18
+ sat_out_vars: total_precipitation
19
+ sat_size: 25
20
+ time_points_rad: 1
21
+ time_points_sat: 1
22
+ warmup_epochs: 10
23
+ warmup_start_lr: 1.0e-08
24
+ weight_decay: 1.0e-05
checkpoint/Unet/csv_logs/version_0/metrics.csv ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint/Unet/wandb_logs/config.yaml ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pytorch_lightning==2.5.1.post0
2
+ seed_everything: 42
3
+ trainer:
4
+ accelerator: cuda
5
+ strategy: auto
6
+ devices:
7
+ - 6
8
+ num_nodes: 1
9
+ precision: 16-mixed
10
+ logger:
11
+ - class_path: pytorch_lightning.loggers.WandbLogger
12
+ init_args:
13
+ name: UnetNhaBe
14
+ save_dir: checkpoint/Unet/wandb_logs
15
+ version: null
16
+ offline: false
17
+ dir: null
18
+ id: null
19
+ anonymous: null
20
+ project: NhaBe
21
+ log_model: false
22
+ experiment: null
23
+ prefix: ''
24
+ checkpoint_name: null
25
+ entity: null
26
+ notes: null
27
+ tags: null
28
+ config: null
29
+ config_exclude_keys: null
30
+ config_include_keys: null
31
+ allow_val_change: null
32
+ group: null
33
+ job_type: null
34
+ mode: null
35
+ force: null
36
+ reinit: null
37
+ resume: null
38
+ resume_from: null
39
+ fork_from: null
40
+ save_code: null
41
+ tensorboard: null
42
+ sync_tensorboard: null
43
+ monitor_gym: null
44
+ settings: null
45
+ - class_path: pytorch_lightning.loggers.CSVLogger
46
+ init_args:
47
+ save_dir: checkpoint/Unet/csv_logs
48
+ name: null
49
+ version: null
50
+ prefix: ''
51
+ flush_logs_every_n_steps: 100
52
+ callbacks:
53
+ - class_path: pytorch_lightning.callbacks.LearningRateMonitor
54
+ init_args:
55
+ logging_interval: step
56
+ log_momentum: false
57
+ log_weight_decay: false
58
+ - class_path: pytorch_lightning.callbacks.ModelCheckpoint
59
+ init_args:
60
+ dirpath: checkpoint/Unet/checkpoints
61
+ filename: epoch_{epoch:03d}
62
+ monitor: val/mse
63
+ verbose: false
64
+ save_last: true
65
+ save_top_k: 1
66
+ save_weights_only: false
67
+ mode: min
68
+ auto_insert_metric_name: false
69
+ every_n_train_steps: null
70
+ train_time_interval: null
71
+ every_n_epochs: null
72
+ save_on_train_epoch_end: null
73
+ enable_version_counter: true
74
+ - class_path: pytorch_lightning.callbacks.EarlyStopping
75
+ init_args:
76
+ monitor: val/mse
77
+ min_delta: 0.0
78
+ patience: 10
79
+ verbose: false
80
+ mode: min
81
+ strict: true
82
+ check_finite: true
83
+ stopping_threshold: null
84
+ divergence_threshold: null
85
+ check_on_train_epoch_end: null
86
+ log_rank_zero_only: false
87
+ - class_path: pytorch_lightning.callbacks.RichModelSummary
88
+ init_args:
89
+ max_depth: -1
90
+ fast_dev_run: false
91
+ max_epochs: 100
92
+ min_epochs: 1
93
+ max_steps: -1
94
+ min_steps: null
95
+ max_time: null
96
+ limit_train_batches: null
97
+ limit_val_batches: null
98
+ limit_test_batches: null
99
+ limit_predict_batches: null
100
+ overfit_batches: 0.0
101
+ val_check_interval: null
102
+ check_val_every_n_epoch: 1
103
+ num_sanity_val_steps: null
104
+ log_every_n_steps: null
105
+ enable_checkpointing: true
106
+ enable_progress_bar: true
107
+ enable_model_summary: null
108
+ accumulate_grad_batches: 1
109
+ gradient_clip_val: null
110
+ gradient_clip_algorithm: null
111
+ deterministic: null
112
+ benchmark: null
113
+ inference_mode: true
114
+ use_distributed_sampler: true
115
+ profiler: null
116
+ detect_anomaly: false
117
+ barebones: false
118
+ plugins: null
119
+ sync_batchnorm: true
120
+ reload_dataloaders_every_n_epochs: 0
121
+ default_root_dir: checkpoint/Unet
122
+ model_registry: null
123
+ model:
124
+ net:
125
+ class_path: arch.Network
126
+ init_args:
127
+ model_type: Unet
128
+ rad_channel: 1
129
+ sat_channel: 1
130
+ rad_size: 400
131
+ sat_size: 25
132
+ pretrained_path: ''
133
+ lr: 0.0005
134
+ beta_1: 0.9
135
+ beta_2: 0.99
136
+ weight_decay: 1.0e-05
137
+ warmup_epochs: 10
138
+ max_epochs: 50
139
+ warmup_start_lr: 1.0e-08
140
+ eta_min: 1.0e-08
141
+ data:
142
+ dir_data: /data/weather2025/NhaBe/
143
+ batch_size: 1
144
+ hours_predicted: 3
145
+ num_workers: 4
146
+ pin_memory: false
147
+ time_points_rad: 1
148
+ time_points_sat: 1
149
+ sat_inp_vars: total_precipitation
150
+ sat_out_vars: total_precipitation
151
+ sat_size: 25
152
+ rad_inp_vars: precipitation
153
+ rad_out_vars: precipitation
154
+ rad_size: 400
155
+ ablation: 'no'
156
+ optimizer: null
157
+ lr_scheduler: null
checkpoint/Unet/wandb_logs/wandb/debug-internal.log ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {"time":"2025-06-17T09:05:28.179242652Z","level":"INFO","msg":"stream: starting","core version":"0.20.1","symlink path":"checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/logs/debug-core.log"}
2
+ {"time":"2025-06-17T09:05:29.423278937Z","level":"INFO","msg":"stream: created new stream","id":"0nx0l2dh"}
3
+ {"time":"2025-06-17T09:05:29.423321777Z","level":"INFO","msg":"stream: started","id":"0nx0l2dh"}
4
+ {"time":"2025-06-17T09:05:29.423393558Z","level":"INFO","msg":"sender: started","stream_id":"0nx0l2dh"}
5
+ {"time":"2025-06-17T09:05:29.423393088Z","level":"INFO","msg":"writer: Do: started","stream_id":"0nx0l2dh"}
6
+ {"time":"2025-06-17T09:05:29.423465179Z","level":"INFO","msg":"handler: started","stream_id":"0nx0l2dh"}
7
+ {"time":"2025-06-17T09:05:30.100696875Z","level":"INFO","msg":"Starting system monitor"}
checkpoint/Unet/wandb_logs/wandb/debug.log ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_setup.py:_flush():81] Current SDK version is 0.20.1
2
+ 2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_setup.py:_flush():81] Configure stats pid to 1311468
3
+ 2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_setup.py:_flush():81] Loading settings from /home/radaric/.config/wandb/settings
4
+ 2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_setup.py:_flush():81] Loading settings from /home/radaric/weather_forecast/Unet/wandb/settings
5
+ 2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_setup.py:_flush():81] Loading settings from environment variables
6
+ 2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_init.py:setup_run_log_directory():703] Logging user logs to checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/logs/debug.log
7
+ 2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_init.py:setup_run_log_directory():704] Logging internal logs to checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/logs/debug-internal.log
8
+ 2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_init.py:init():831] calling init triggers
9
+ 2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_init.py:init():836] wandb.init called with sweep_config: {}
10
+ config: {'_wandb': {}}
11
+ 2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_init.py:init():872] starting backend
12
+ 2025-06-17 09:05:28,169 INFO MainThread:1311468 [wandb_init.py:init():875] sending inform_init request
13
+ 2025-06-17 09:05:28,174 INFO MainThread:1311468 [wandb_init.py:init():883] backend started and connected
14
+ 2025-06-17 09:05:28,175 INFO MainThread:1311468 [wandb_init.py:init():956] updated telemetry
15
+ 2025-06-17 09:05:28,175 INFO MainThread:1311468 [wandb_init.py:init():980] communicating run to backend with 90.0 second timeout
16
+ 2025-06-17 09:05:30,098 INFO MainThread:1311468 [wandb_init.py:init():1032] starting run threads in backend
17
+ 2025-06-17 09:05:30,175 INFO MainThread:1311468 [wandb_run.py:_console_start():2453] atexit reg
18
+ 2025-06-17 09:05:30,175 INFO MainThread:1311468 [wandb_run.py:_redirect():2301] redirect: wrap_raw
19
+ 2025-06-17 09:05:30,175 INFO MainThread:1311468 [wandb_run.py:_redirect():2370] Wrapping output streams.
20
+ 2025-06-17 09:05:30,175 INFO MainThread:1311468 [wandb_run.py:_redirect():2393] Redirects installed.
21
+ 2025-06-17 09:05:30,177 INFO MainThread:1311468 [wandb_init.py:init():1078] run started, returning control to user process
22
+ 2025-06-17 09:05:31,151 INFO MainThread:1311468 [wandb_run.py:_config_callback():1358] config_cb None None {'pretrained_path': '', 'lr': 0.0005, 'beta_1': 0.9, 'beta_2': 0.99, 'weight_decay': 1e-05, 'warmup_epochs': 10, 'max_epochs': 50, 'warmup_start_lr': 1e-08, 'eta_min': 1e-08, '_instantiator': 'pytorch_lightning.cli.instantiate_module', 'dir_data': '/data/weather2025/NhaBe/', 'batch_size': 1, 'hours_predicted': 3, 'num_workers': 4, 'pin_memory': False, 'time_points_rad': 1, 'time_points_sat': 1, 'sat_inp_vars': 'total_precipitation', 'sat_out_vars': 'total_precipitation', 'sat_size': 25, 'rad_inp_vars': 'precipitation', 'rad_out_vars': 'precipitation', 'rad_size': 400, 'ablation': 'no'}
checkpoint/Unet/wandb_logs/wandb/run-20250617_085744-m5tg7yyl/files/output.log ADDED
File without changes
checkpoint/Unet/wandb_logs/wandb/run-20250617_085744-m5tg7yyl/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"_wandb":{"runtime":0}}
checkpoint/Unet/wandb_logs/wandb/run-20250617_085744-m5tg7yyl/logs/debug-core.log ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"time":"2025-06-17T08:57:44.288260722Z","level":"INFO","msg":"main: starting server","port-filename":"/tmp/tmpbpuchipv/port-1289333.txt","pid":1289333,"log-level":0,"disable-analytics":false,"shutdown-on-parent-exit":false,"enable-dcgm-profiling":false}
2
+ {"time":"2025-06-17T08:57:44.289762517Z","level":"INFO","msg":"Will exit if parent process dies.","ppid":1289333}
3
+ {"time":"2025-06-17T08:57:44.289701246Z","level":"INFO","msg":"server is running","addr":{"IP":"127.0.0.1","Port":43767,"Zone":""}}
4
+ {"time":"2025-06-17T08:57:44.468360629Z","level":"INFO","msg":"connection: ManageConnectionData: new connection created","id":"127.0.0.1:33480"}
5
+ {"time":"2025-06-17T08:57:44.478126346Z","level":"INFO","msg":"handleInformInit: received","streamId":"m5tg7yyl","id":"127.0.0.1:33480"}
6
+ {"time":"2025-06-17T08:57:45.013693012Z","level":"INFO","msg":"handleInformInit: stream started","streamId":"m5tg7yyl","id":"127.0.0.1:33480"}
7
+ {"time":"2025-06-17T08:57:46.227115814Z","level":"INFO","msg":"handleInformTeardown: server teardown initiated","id":"127.0.0.1:33480"}
8
+ {"time":"2025-06-17T08:57:46.227331796Z","level":"INFO","msg":"server is shutting down"}
9
+ {"time":"2025-06-17T08:57:46.227313783Z","level":"INFO","msg":"connection: closing","id":"127.0.0.1:33480"}
10
+ {"time":"2025-06-17T08:57:46.227453186Z","level":"INFO","msg":"connection: closed successfully","id":"127.0.0.1:33480"}
11
+ {"time":"2025-06-17T08:57:46.48785785Z","level":"INFO","msg":"handleInformTeardown: server shutdown complete","id":"127.0.0.1:33480"}
12
+ {"time":"2025-06-17T08:57:46.487909579Z","level":"INFO","msg":"connection: ManageConnectionData: connection closed","id":"127.0.0.1:33480"}
13
+ {"time":"2025-06-17T08:57:46.487925552Z","level":"INFO","msg":"server is closed"}
checkpoint/Unet/wandb_logs/wandb/run-20250617_085744-m5tg7yyl/logs/debug-internal.log ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"time":"2025-06-17T08:57:44.478779812Z","level":"INFO","msg":"stream: starting","core version":"0.20.1","symlink path":"checkpoint/Unet/wandb_logs/wandb/run-20250617_085744-m5tg7yyl/logs/debug-core.log"}
2
+ {"time":"2025-06-17T08:57:45.013625369Z","level":"INFO","msg":"stream: created new stream","id":"m5tg7yyl"}
3
+ {"time":"2025-06-17T08:57:45.013682966Z","level":"INFO","msg":"stream: started","id":"m5tg7yyl"}
4
+ {"time":"2025-06-17T08:57:45.013709365Z","level":"INFO","msg":"handler: started","stream_id":"m5tg7yyl"}
5
+ {"time":"2025-06-17T08:57:45.013744481Z","level":"INFO","msg":"sender: started","stream_id":"m5tg7yyl"}
6
+ {"time":"2025-06-17T08:57:45.013733645Z","level":"INFO","msg":"writer: Do: started","stream_id":"m5tg7yyl"}
7
+ {"time":"2025-06-17T08:57:45.224225022Z","level":"ERROR","msg":"HTTP error","status":403,"method":"POST","url":"https://api.wandb.ai/graphql"}
8
+ {"time":"2025-06-17T08:57:45.22437671Z","level":"ERROR","msg":"runupserter: failed to init run","error":"returned error 403: {\"data\":{\"upsertBucket\":null},\"errors\":[{\"message\":\"permission denied\",\"path\":[\"upsertBucket\"],\"extensions\":{\"code\":\"PERMISSION_ERROR\"}}]}"}
9
+ {"time":"2025-06-17T08:57:46.227328327Z","level":"INFO","msg":"stream: closing","id":"m5tg7yyl"}
10
+ {"time":"2025-06-17T08:57:46.227825345Z","level":"ERROR","msg":"sender: uploadConfigFile: stream: no run"}
11
+ {"time":"2025-06-17T08:57:46.486865753Z","level":"ERROR","msg":"HTTP error","status":404,"method":"POST","url":"https://api.wandb.ai/graphql"}
12
+ {"time":"2025-06-17T08:57:46.486986554Z","level":"ERROR","msg":"runfiles: CreateRunFiles returned error: returned error 404: {\"data\":{\"createRunFiles\":null},\"errors\":[{\"message\":\"project vinh-bui0512-hcmut/NhaBe not found during createRunFiles\",\"path\":[\"createRunFiles\"]}]}"}
13
+ {"time":"2025-06-17T08:57:46.487641258Z","level":"INFO","msg":"fileTransfer: Close: file transfer manager closed"}
14
+ {"time":"2025-06-17T08:57:46.487699694Z","level":"INFO","msg":"handler: closed","stream_id":"m5tg7yyl"}
15
+ {"time":"2025-06-17T08:57:46.487714658Z","level":"INFO","msg":"writer: Close: closed","stream_id":"m5tg7yyl"}
16
+ {"time":"2025-06-17T08:57:46.487745625Z","level":"INFO","msg":"sender: closed","stream_id":"m5tg7yyl"}
17
+ {"time":"2025-06-17T08:57:46.487775923Z","level":"INFO","msg":"stream: closed","id":"m5tg7yyl"}
checkpoint/Unet/wandb_logs/wandb/run-20250617_085744-m5tg7yyl/logs/debug.log ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2025-06-17 08:57:44,256 INFO MainThread:1289333 [wandb_setup.py:_flush():81] Current SDK version is 0.20.1
2
+ 2025-06-17 08:57:44,256 INFO MainThread:1289333 [wandb_setup.py:_flush():81] Configure stats pid to 1289333
3
+ 2025-06-17 08:57:44,256 INFO MainThread:1289333 [wandb_setup.py:_flush():81] Loading settings from /home/radaric/.config/wandb/settings
4
+ 2025-06-17 08:57:44,256 INFO MainThread:1289333 [wandb_setup.py:_flush():81] Loading settings from /home/radaric/weather_forecast/Unet/wandb/settings
5
+ 2025-06-17 08:57:44,256 INFO MainThread:1289333 [wandb_setup.py:_flush():81] Loading settings from environment variables
6
+ 2025-06-17 08:57:44,256 INFO MainThread:1289333 [wandb_init.py:setup_run_log_directory():703] Logging user logs to checkpoint/Unet/wandb_logs/wandb/run-20250617_085744-m5tg7yyl/logs/debug.log
7
+ 2025-06-17 08:57:44,256 INFO MainThread:1289333 [wandb_init.py:setup_run_log_directory():704] Logging internal logs to checkpoint/Unet/wandb_logs/wandb/run-20250617_085744-m5tg7yyl/logs/debug-internal.log
8
+ 2025-06-17 08:57:44,256 INFO MainThread:1289333 [wandb_init.py:init():831] calling init triggers
9
+ 2025-06-17 08:57:44,256 INFO MainThread:1289333 [wandb_init.py:init():836] wandb.init called with sweep_config: {}
10
+ config: {'_wandb': {}}
11
+ 2025-06-17 08:57:44,256 INFO MainThread:1289333 [wandb_init.py:init():872] starting backend
12
+ 2025-06-17 08:57:44,468 INFO MainThread:1289333 [wandb_init.py:init():875] sending inform_init request
13
+ 2025-06-17 08:57:44,473 INFO MainThread:1289333 [wandb_init.py:init():883] backend started and connected
14
+ 2025-06-17 08:57:44,475 INFO MainThread:1289333 [wandb_init.py:init():956] updated telemetry
15
+ 2025-06-17 08:57:44,476 INFO MainThread:1289333 [wandb_init.py:init():980] communicating run to backend with 90.0 second timeout
checkpoint/Unet/wandb_logs/wandb/run-20250617_085744-m5tg7yyl/run-m5tg7yyl.wandb ADDED
Binary file (366 Bytes). View file
 
checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/files/output.log ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Number of train samples: 31462
2
+ Number of test samples: 8077
3
+ Number of val samples: 1398
4
+ LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6]
5
+ ┏━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┓
6
+ ┃   ┃ Name  ┃ Type  ┃ Params ┃ Mode  ┃
7
+ ┡━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━┩
8
+ │ 0  │ net │ Network │ 30.0 K │ train │
9
+ │ 1  │ net.net │ Unet │ 30.0 K │ train │
10
+ │ 2  │ net.net.encoder_blocks │ ModuleList │ 4.8 K │ train │
11
+ │ 3  │ net.net.encoder_blocks.0 │ ConvBlock │ 66 │ train │
12
+ │ 4  │ net.net.encoder_blocks.0.conv │ Sequential │ 66 │ train │
13
+ │ 5  │ net.net.encoder_blocks.0.conv.0 │ Conv2d │ 20 │ train │
14
+ │ 6  │ net.net.encoder_blocks.0.conv.1 │ BatchNorm2d │ 4 │ train │
15
+ │ 7  │ net.net.encoder_blocks.0.conv.2 │ ReLU │ 0 │ train │
16
+ │ 8  │ net.net.encoder_blocks.0.conv.3 │ Conv2d │ 38 │ train │
17
+ │ 9  │ net.net.encoder_blocks.0.conv.4 │ BatchNorm2d │ 4 │ train │
18
+ │ 10  │ net.net.encoder_blocks.0.conv.5 │ ReLU │ 0 │ train │
19
+ │ 11  │ net.net.encoder_blocks.1 │ ConvBlock │ 240 │ train │
20
+ │ 12  │ net.net.encoder_blocks.1.conv │ Sequential │ 240 │ train │
21
+ │ 13  │ net.net.encoder_blocks.1.conv.0 │ Conv2d │ 76 │ train │
22
+ │ 14  │ net.net.encoder_blocks.1.conv.1 │ BatchNorm2d │ 8 │ train │
23
+ │ 15  │ net.net.encoder_blocks.1.conv.2 │ ReLU │ 0 │ train │
24
+ │ 16  │ net.net.encoder_blocks.1.conv.3 │ Conv2d │ 148 │ train │
25
+ │ 17  │ net.net.encoder_blocks.1.conv.4 │ BatchNorm2d │ 8 │ train │
26
+ │ 18  │ net.net.encoder_blocks.1.conv.5 │ ReLU │ 0 │ train │
27
+ │ 19  │ net.net.encoder_blocks.2 │ ConvBlock │ 912 │ train │
28
+ │ 20  │ net.net.encoder_blocks.2.conv │ Sequential │ 912 │ train │
29
+ │ 21  │ net.net.encoder_blocks.2.conv.0 │ Conv2d │ 296 │ train │
30
+ │ 22  │ net.net.encoder_blocks.2.conv.1 │ BatchNorm2d │ 16 │ train │
31
+ │ 23  │ net.net.encoder_blocks.2.conv.2 │ ReLU │ 0 │ train │
32
+ │ 24  │ net.net.encoder_blocks.2.conv.3 │ Conv2d │ 584 │ train │
33
+ │ 25  │ net.net.encoder_blocks.2.conv.4 │ BatchNorm2d │ 16 │ train │
34
+ │ 26  │ net.net.encoder_blocks.2.conv.5 │ ReLU │ 0 │ train │
35
+ │ 27  │ net.net.encoder_blocks.3 │ ConvBlock │ 3.6 K │ train │
36
+ │ 28  │ net.net.encoder_blocks.3.conv │ Sequential │ 3.6 K │ train │
37
+ │ 29  │ net.net.encoder_blocks.3.conv.0 │ Conv2d │ 1.2 K │ train │
38
+ │ 30  │ net.net.encoder_blocks.3.conv.1 │ BatchNorm2d │ 32 │ train │
39
+ │ 31  │ net.net.encoder_blocks.3.conv.2 │ ReLU │ 0 │ train │
40
+ │ 32  │ net.net.encoder_blocks.3.conv.3 │ Conv2d │ 2.3 K │ train │
41
+ │ 33  │ net.net.encoder_blocks.3.conv.4 │ BatchNorm2d │ 32 │ train │
42
+ │ 34  │ net.net.encoder_blocks.3.conv.5 │ ReLU │ 0 │ train │
43
+ │ 35  │ net.net.pools │ ModuleList │ 0 │ train │
44
+ │ 36  │ net.net.pools.0 │ MaxPool2d │ 0 │ train │
45
+ │ 37  │ net.net.pools.1 │ MaxPool2d │ 0 │ train │
46
+ │ 38  │ net.net.pools.2 │ MaxPool2d │ 0 │ train │
47
+ │ 39  │ net.net.pools.3 │ MaxPool2d │ 0 │ train │
48
+ │ 40  │ net.net.mid_conv_1 │ single_conv │ 2.4 K │ train │
49
+ │ 41  │ net.net.mid_conv_1.conv │ Sequential │ 2.4 K │ train │
50
+ │ 42  │ net.net.mid_conv_1.conv.0 │ Conv2d │ 2.3 K │ train │
51
+ │ 43  │ net.net.mid_conv_1.conv.1 │ BatchNorm2d │ 32 │ train │
52
+ │ 44  │ net.net.mid_conv_1.conv.2 │ ReLU │ 0 │ train │
53
+ │ 45  │ net.net.mid_conv_2 │ single_conv │ 192 │ train │
54
+ │ 46  │ net.net.mid_conv_2.conv │ Sequential │ 192 │ train │
55
+ │ 47  │ net.net.mid_conv_2.conv.0 │ Conv2d │ 160 │ train │
56
+ │ 48  │ net.net.mid_conv_2.conv.1 │ BatchNorm2d │ 32 │ train │
57
+ │ 49  │ net.net.mid_conv_2.conv.2 │ ReLU │ 0 │ train │
58
+ │ 50  │ net.net.mid_merge │ ConvBlock │ 7.0 K │ train │
59
+ │ 51  │ net.net.mid_merge.conv │ Sequential │ 7.0 K │ train │
60
+ │ 52  │ net.net.mid_merge.conv.0 │ Conv2d │ 4.6 K │ train │
61
+ │ 53  │ net.net.mid_merge.conv.1 │ BatchNorm2d │ 32 │ train │
62
+ │ 54  │ net.net.mid_merge.conv.2 │ ReLU │ 0 │ train │
63
+ │ 55  │ net.net.mid_merge.conv.3 │ Conv2d │ 2.3 K │ train │
64
+ │ 56  │ net.net.mid_merge.conv.4 │ BatchNorm2d │ 32 │ train │
65
+ │ 57  │ net.net.mid_merge.conv.5 │ ReLU │ 0 │ train │
66
+ │ 58  │ net.net.up_convs │ ModuleList │ 6.2 K │ train │
67
+ │ 59  │ net.net.up_convs.0 │ UpConv │ 4.7 K │ train │
68
+ │ 60  │ net.net.up_convs.0.up │ Sequential │ 4.7 K │ train │
69
+ │ 61  │ net.net.up_convs.0.up.0 │ Upsample │ 0 │ train │
70
+ │ 62  │ net.net.up_convs.0.up.1 │ Conv2d │ 4.6 K │ train │
71
+ │ 63  │ net.net.up_convs.0.up.2 │ BatchNorm2d │ 32 │ train │
72
+ │ 64  │ net.net.up_convs.0.up.3 │ ReLU │ 0 │ train │
73
+ │ 65  │ net.net.up_convs.1 │ UpConv │ 1.2 K │ train │
74
+ │ 66  │ net.net.up_convs.1.up │ Sequential │ 1.2 K │ train │
75
+ │ 67  │ net.net.up_convs.1.up.0 │ Upsample │ 0 │ train │
76
+ │ 68  │ net.net.up_convs.1.up.1 │ Conv2d │ 1.2 K │ train │
77
+ │ 69  │ net.net.up_convs.1.up.2 │ BatchNorm2d │ 16 │ train │
78
+ │ 70  │ net.net.up_convs.1.up.3 │ ReLU │ 0 │ train │
79
+ │ 71  │ net.net.up_convs.2 │ UpConv │ 300 │ train │
80
+ │ 72  │ net.net.up_convs.2.up │ Sequential │ 300 │ train │
81
+ │ 73  │ net.net.up_convs.2.up.0 │ Upsample │ 0 │ train │
82
+ │ 74  │ net.net.up_convs.2.up.1 │ Conv2d │ 292 │ train │
83
+ │ 75  │ net.net.up_convs.2.up.2 │ BatchNorm2d │ 8 │ train │
84
+ │ 76  │ net.net.up_convs.2.up.3 │ ReLU │ 0 │ train │
85
+ │ 77  │ net.net.up_convs.3 │ UpConv │ 78 │ train │
86
+ │ 78  │ net.net.up_convs.3.up │ Sequential │ 78 │ train │
87
+ │ 79  │ net.net.up_convs.3.up.0 │ Upsample │ 0 │ train │
88
+ │ 80  │ net.net.up_convs.3.up.1 │ Conv2d │ 74 │ train │
89
+ │ 81  │ net.net.up_convs.3.up.2 │ BatchNorm2d │ 4 │ train │
90
+ │ 82  │ net.net.up_convs.3.up.3 │ ReLU │ 0 │ train │
91
+ │ 83  │ net.net.decoder_blocks │ ModuleList │ 9.4 K │ train │
92
+ │ 84  │ net.net.decoder_blocks.0 │ ConvBlock │ 7.0 K │ train │
93
+ │ 85  │ net.net.decoder_blocks.0.conv │ Sequential │ 7.0 K │ train │
94
+ │ 86  │ net.net.decoder_blocks.0.conv.0 │ Conv2d │ 4.6 K │ train │
95
+ │ 87  │ net.net.decoder_blocks.0.conv.1 │ BatchNorm2d │ 32 │ train │
96
+ │ 88  │ net.net.decoder_blocks.0.conv.2 │ ReLU │ 0 │ train │
97
+ │ 89  │ net.net.decoder_blocks.0.conv.3 │ Conv2d │ 2.3 K │ train │
98
+ │ 90  │ net.net.decoder_blocks.0.conv.4 │ BatchNorm2d │ 32 │ train │
99
+ │ 91  │ net.net.decoder_blocks.0.conv.5 │ ReLU │ 0 │ train │
100
+ │ 92  │ net.net.decoder_blocks.1 │ ConvBlock │ 1.8 K │ train │
101
+ │ 93  │ net.net.decoder_blocks.1.conv │ Sequential │ 1.8 K │ train │
102
+ │ 94  │ net.net.decoder_blocks.1.conv.0 │ Conv2d │ 1.2 K │ train │
103
+ │ 95  │ net.net.decoder_blocks.1.conv.1 │ BatchNorm2d │ 16 │ train │
104
+ │ 96  │ net.net.decoder_blocks.1.conv.2 │ ReLU │ 0 │ train │
105
+ │ 97  │ net.net.decoder_blocks.1.conv.3 │ Conv2d │ 584 │ train │
106
+ │ 98  │ net.net.decoder_blocks.1.conv.4 │ BatchNorm2d │ 16 │ train │
107
+ │ 99  │ net.net.decoder_blocks.1.conv.5 │ ReLU │ 0 │ train │
108
+ │ 100 │ net.net.decoder_blocks.2 │ ConvBlock │ 456 │ train │
109
+ │ 101 │ net.net.decoder_blocks.2.conv │ Sequential │ 456 │ train │
110
+ │ 102 │ net.net.decoder_blocks.2.conv.0 │ Conv2d │ 292 │ train │
111
+ │ 103 │ net.net.decoder_blocks.2.conv.1 │ BatchNorm2d │ 8 │ train │
112
+ │ 104 │ net.net.decoder_blocks.2.conv.2 │ ReLU │ 0 │ train │
113
+ │ 105 │ net.net.decoder_blocks.2.conv.3 │ Conv2d │ 148 │ train │
114
+ │ 106 │ net.net.decoder_blocks.2.conv.4 │ BatchNorm2d │ 8 │ train │
115
+ │ 107 │ net.net.decoder_blocks.2.conv.5 │ ReLU │ 0 │ train │
116
+ │ 108 │ net.net.decoder_blocks.3 │ ConvBlock │ 120 │ train │
117
+ │ 109 │ net.net.decoder_blocks.3.conv │ Sequential │ 120 │ train │
118
+ │ 110 │ net.net.decoder_blocks.3.conv.0 │ Conv2d │ 74 │ train │
119
+ │ 111 │ net.net.decoder_blocks.3.conv.1 │ BatchNorm2d │ 4 │ train │
120
+ │ 112 │ net.net.decoder_blocks.3.conv.2 │ ReLU │ 0 │ train │
121
+ │ 113 │ net.net.decoder_blocks.3.conv.3 │ Conv2d │ 38 │ train │
122
+ │ 114 │ net.net.decoder_blocks.3.conv.4 │ BatchNorm2d │ 4 │ train │
123
+ │ 115 │ net.net.decoder_blocks.3.conv.5 │ ReLU │ 0 │ train │
124
+ │ 116 │ net.net.final_decoder │ ConvBlock │ 120 │ train │
125
+ │ 117 │ net.net.final_decoder.conv │ Sequential │ 120 │ train │
126
+ │ 118 │ net.net.final_decoder.conv.0 │ Conv2d │ 74 │ train │
127
+ │ 119 │ net.net.final_decoder.conv.1 │ BatchNorm2d │ 4 │ train │
128
+ │ 120 │ net.net.final_decoder.conv.2 │ ReLU │ 0 │ train │
129
+ │ 121 │ net.net.final_decoder.conv.3 │ Conv2d │ 38 │ train │
130
+ │ 122 │ net.net.final_decoder.conv.4 │ BatchNorm2d │ 4 │ train │
131
+ │ 123 │ net.net.final_decoder.conv.5 │ ReLU │ 0 │ train │
132
+ │ 124 │ net.net.out_conv_R │ Conv2d │ 3 │ train │
133
+ │ 125 │ net.net.out_conv_S │ Conv2d │ 17 │ train │
134
+ │ 126 │ rad_denormalization │ Normalize │ 0 │ train │
135
+ │ 127 │ sat_denormalization │ Normalize │ 0 │ train │
136
+ └─────┴─────────────────────────────────┴─────────────┴────────┴───────┘
137
+ Trainable params: 30.0 K
138
+ Non-trainable params: 0
139
+ Total params: 30.0 K
140
+ Total estimated model params size (MB): 0
141
+ Modules in train mode: 128
142
+ Modules in eval mode: 0
143
+ Epoch 4: 17%|▏| 5205/31462 [02:33<12:54, 33.89it/s, v_num=dh_0, train/rad=0.120, train/sat=2.380, train/mse=2.500, val/rad=1.970, val/sat=1.140, val/mse
144
+ /home/radaric/.conda/envs/unet/lib/python3.10/site-packages/torch/optim/lr_scheduler.py:182: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`. Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
145
+ warnings.warn(
146
+
147
+
148
+ Detected KeyboardInterrupt, attempting graceful shutdown ...
149
+ Exception ignored in atexit callback: <function _start_and_connect_service.<locals>.teardown_atexit at 0x7fafa15b5360>
150
+ Traceback (most recent call last):
151
+ File "/home/radaric/.conda/envs/unet/lib/python3.10/site-packages/wandb/sdk/lib/service_connection.py", line 90, in teardown_atexit
152
+ conn.teardown(hooks.exit_code)
153
+ File "/home/radaric/.conda/envs/unet/lib/python3.10/site-packages/wandb/sdk/lib/service_connection.py", line 218, in teardown
154
+ self._router.join()
155
+ File "/home/radaric/.conda/envs/unet/lib/python3.10/site-packages/wandb/sdk/interface/router.py", line 75, in join
156
+ self._thread.join()
157
+ File "/home/radaric/.conda/envs/unet/lib/python3.10/threading.py", line 1096, in join
158
+ self._wait_for_tstate_lock()
159
+ File "/home/radaric/.conda/envs/unet/lib/python3.10/threading.py", line 1116, in _wait_for_tstate_lock
160
+ if lock.acquire(block, timeout):
161
+ KeyboardInterrupt:
checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/files/requirements.txt ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ urllib3==2.4.0
2
+ requests==2.32.4
3
+ typing-inspection==0.4.1
4
+ Jinja2==3.1.6
5
+ MarkupSafe==3.0.2
6
+ setuptools==78.1.1
7
+ frozenlist==1.7.0
8
+ aiosignal==1.3.2
9
+ docstring_parser==0.16
10
+ aiohappyeyeballs==2.6.1
11
+ ClimaX==0.3.1
12
+ platformdirs==4.3.8
13
+ async-timeout==5.0.1
14
+ nvidia-cusolver-cu12==11.7.1.2
15
+ protobuf==6.31.1
16
+ charset-normalizer==3.4.2
17
+ attrs==25.3.0
18
+ pip==25.1
19
+ nvidia-cufile-cu12==1.11.1.6
20
+ importlib_resources==6.5.2
21
+ nvidia-nvjitlink-cu12==12.6.85
22
+ numpy==2.2.6
23
+ typeshed_client==2.7.0
24
+ jsonargparse==4.40.0
25
+ nvidia-cusparselt-cu12==0.6.3
26
+ GitPython==3.1.44
27
+ nvidia-cusparse-cu12==12.5.4.2
28
+ mpmath==1.3.0
29
+ pytorch-lightning==2.5.1.post0
30
+ torchvision==0.22.1
31
+ PyYAML==6.0.2
32
+ nvidia-cudnn-cu12==9.5.1.17
33
+ markdown-it-py==3.0.0
34
+ typing_extensions==4.14.0
35
+ smmap==5.0.2
36
+ pydantic_core==2.33.2
37
+ torchsummary==1.5.1
38
+ nvidia-cublas-cu12==12.6.4.1
39
+ FusionModel==0.3.1
40
+ mdurl==0.1.2
41
+ sentry-sdk==2.30.0
42
+ nvidia-curand-cu12==10.3.7.77
43
+ idna==3.10
44
+ triton==3.3.1
45
+ multidict==6.4.4
46
+ Pygments==2.19.1
47
+ nvidia-cuda-cupti-cu12==12.6.80
48
+ tqdm==4.67.1
49
+ psutil==7.0.0
50
+ gitdb==4.0.12
51
+ fsspec==2025.5.1
52
+ pydantic==2.11.6
53
+ sympy==1.14.0
54
+ torchaudio==2.7.1
55
+ nvidia-nccl-cu12==2.26.2
56
+ propcache==0.3.2
57
+ wandb==0.20.1
58
+ filelock==3.18.0
59
+ packaging==25.0
60
+ nvidia-cuda-nvrtc-cu12==12.6.77
61
+ networkx==3.4.2
62
+ aiohttp==3.12.12
63
+ nvidia-cufft-cu12==11.3.0.4
64
+ nvidia-nvtx-cu12==12.6.77
65
+ wheel==0.45.1
66
+ yarl==1.20.1
67
+ certifi==2025.4.26
68
+ click==8.2.1
69
+ nvidia-cuda-runtime-cu12==12.6.77
70
+ rich==14.0.0
71
+ pillow==11.2.1
72
+ setproctitle==1.3.6
73
+ torchmetrics==1.7.3
74
+ lightning-utilities==0.14.3
75
+ torch==2.7.1
76
+ annotated-types==0.7.0
77
+ ClimaX==0.3.1
checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/files/wandb-metadata.json ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "os": "Linux-5.4.0-208-generic-x86_64-with-glibc2.31",
3
+ "python": "CPython 3.10.18",
4
+ "startedAt": "2025-06-17T09:05:28.174321Z",
5
+ "args": [
6
+ "--config",
7
+ "configs/Unet.yaml"
8
+ ],
9
+ "program": "/home/radaric/weather_forecast/Unet/src/train.py",
10
+ "codePath": "src/train.py",
11
+ "email": "weatherforecast1024hcmut@gmail.com",
12
+ "root": "checkpoint/Unet/wandb_logs",
13
+ "host": "u116613",
14
+ "executable": "/home/radaric/.conda/envs/unet/bin/python",
15
+ "codePathLocal": "src/train.py",
16
+ "cpu_count": 48,
17
+ "cpu_count_logical": 96,
18
+ "gpu": "NVIDIA RTX A6000",
19
+ "gpu_count": 7,
20
+ "disk": {
21
+ "/": {
22
+ "total": "1877998821376",
23
+ "used": "1470173900800"
24
+ }
25
+ },
26
+ "memory": {
27
+ "total": "540953096192"
28
+ },
29
+ "cpu": {
30
+ "count": 48,
31
+ "countLogical": 96
32
+ },
33
+ "gpu_nvidia": [
34
+ {
35
+ "name": "NVIDIA RTX A6000",
36
+ "memoryTotal": "51527024640",
37
+ "cudaCores": 10752,
38
+ "architecture": "Ampere",
39
+ "uuid": "GPU-fb5a2de4-c79a-f2d0-a864-a6271ad28ae6"
40
+ },
41
+ {
42
+ "name": "NVIDIA RTX A6000",
43
+ "memoryTotal": "51527024640",
44
+ "cudaCores": 10752,
45
+ "architecture": "Ampere",
46
+ "uuid": "GPU-1a8c199b-93ca-3fec-6459-a5515bf1b12b"
47
+ },
48
+ {
49
+ "name": "NVIDIA RTX A6000",
50
+ "memoryTotal": "51527024640",
51
+ "cudaCores": 10752,
52
+ "architecture": "Ampere",
53
+ "uuid": "GPU-4d0c0cac-f72d-9dc7-9ac0-60cf8803134b"
54
+ },
55
+ {
56
+ "name": "NVIDIA RTX A6000",
57
+ "memoryTotal": "51527024640",
58
+ "cudaCores": 10752,
59
+ "architecture": "Ampere",
60
+ "uuid": "GPU-2887d599-b7bf-d31f-4425-84fa60413306"
61
+ },
62
+ {
63
+ "name": "NVIDIA RTX A6000",
64
+ "memoryTotal": "51527024640",
65
+ "cudaCores": 10752,
66
+ "architecture": "Ampere",
67
+ "uuid": "GPU-86e7c8f1-cde6-4163-dc15-52cef50545bd"
68
+ },
69
+ {
70
+ "name": "NVIDIA RTX A6000",
71
+ "memoryTotal": "51527024640",
72
+ "cudaCores": 10752,
73
+ "architecture": "Ampere",
74
+ "uuid": "GPU-460d754a-f551-6943-c142-b5b8f2f86236"
75
+ },
76
+ {
77
+ "name": "NVIDIA RTX A6000",
78
+ "memoryTotal": "51527024640",
79
+ "cudaCores": 10752,
80
+ "architecture": "Ampere",
81
+ "uuid": "GPU-553ca63b-335c-4c11-94eb-29c777adb307"
82
+ }
83
+ ],
84
+ "cudaVersion": "12.3"
85
+ }
checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/logs/debug-core.log ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {"time":"2025-06-17T09:05:27.98855733Z","level":"INFO","msg":"main: starting server","port-filename":"/tmp/tmpng05fvru/port-1311468.txt","pid":1311468,"log-level":0,"disable-analytics":false,"shutdown-on-parent-exit":false,"enable-dcgm-profiling":false}
2
+ {"time":"2025-06-17T09:05:27.99003933Z","level":"INFO","msg":"Will exit if parent process dies.","ppid":1311468}
3
+ {"time":"2025-06-17T09:05:27.99004801Z","level":"INFO","msg":"server is running","addr":{"IP":"127.0.0.1","Port":46731,"Zone":""}}
4
+ {"time":"2025-06-17T09:05:28.169034214Z","level":"INFO","msg":"connection: ManageConnectionData: new connection created","id":"127.0.0.1:38220"}
5
+ {"time":"2025-06-17T09:05:28.178996979Z","level":"INFO","msg":"handleInformInit: received","streamId":"0nx0l2dh","id":"127.0.0.1:38220"}
6
+ {"time":"2025-06-17T09:05:29.423327647Z","level":"INFO","msg":"handleInformInit: stream started","streamId":"0nx0l2dh","id":"127.0.0.1:38220"}
7
+ {"time":"2025-06-17T10:10:24.148447187Z","level":"INFO","msg":"Parent process exited, terminating service process."}
checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/logs/debug-internal.log ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {"time":"2025-06-17T09:05:28.179242652Z","level":"INFO","msg":"stream: starting","core version":"0.20.1","symlink path":"checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/logs/debug-core.log"}
2
+ {"time":"2025-06-17T09:05:29.423278937Z","level":"INFO","msg":"stream: created new stream","id":"0nx0l2dh"}
3
+ {"time":"2025-06-17T09:05:29.423321777Z","level":"INFO","msg":"stream: started","id":"0nx0l2dh"}
4
+ {"time":"2025-06-17T09:05:29.423393558Z","level":"INFO","msg":"sender: started","stream_id":"0nx0l2dh"}
5
+ {"time":"2025-06-17T09:05:29.423393088Z","level":"INFO","msg":"writer: Do: started","stream_id":"0nx0l2dh"}
6
+ {"time":"2025-06-17T09:05:29.423465179Z","level":"INFO","msg":"handler: started","stream_id":"0nx0l2dh"}
7
+ {"time":"2025-06-17T09:05:30.100696875Z","level":"INFO","msg":"Starting system monitor"}
checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/logs/debug.log ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_setup.py:_flush():81] Current SDK version is 0.20.1
2
+ 2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_setup.py:_flush():81] Configure stats pid to 1311468
3
+ 2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_setup.py:_flush():81] Loading settings from /home/radaric/.config/wandb/settings
4
+ 2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_setup.py:_flush():81] Loading settings from /home/radaric/weather_forecast/Unet/wandb/settings
5
+ 2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_setup.py:_flush():81] Loading settings from environment variables
6
+ 2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_init.py:setup_run_log_directory():703] Logging user logs to checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/logs/debug.log
7
+ 2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_init.py:setup_run_log_directory():704] Logging internal logs to checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/logs/debug-internal.log
8
+ 2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_init.py:init():831] calling init triggers
9
+ 2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_init.py:init():836] wandb.init called with sweep_config: {}
10
+ config: {'_wandb': {}}
11
+ 2025-06-17 09:05:27,961 INFO MainThread:1311468 [wandb_init.py:init():872] starting backend
12
+ 2025-06-17 09:05:28,169 INFO MainThread:1311468 [wandb_init.py:init():875] sending inform_init request
13
+ 2025-06-17 09:05:28,174 INFO MainThread:1311468 [wandb_init.py:init():883] backend started and connected
14
+ 2025-06-17 09:05:28,175 INFO MainThread:1311468 [wandb_init.py:init():956] updated telemetry
15
+ 2025-06-17 09:05:28,175 INFO MainThread:1311468 [wandb_init.py:init():980] communicating run to backend with 90.0 second timeout
16
+ 2025-06-17 09:05:30,098 INFO MainThread:1311468 [wandb_init.py:init():1032] starting run threads in backend
17
+ 2025-06-17 09:05:30,175 INFO MainThread:1311468 [wandb_run.py:_console_start():2453] atexit reg
18
+ 2025-06-17 09:05:30,175 INFO MainThread:1311468 [wandb_run.py:_redirect():2301] redirect: wrap_raw
19
+ 2025-06-17 09:05:30,175 INFO MainThread:1311468 [wandb_run.py:_redirect():2370] Wrapping output streams.
20
+ 2025-06-17 09:05:30,175 INFO MainThread:1311468 [wandb_run.py:_redirect():2393] Redirects installed.
21
+ 2025-06-17 09:05:30,177 INFO MainThread:1311468 [wandb_init.py:init():1078] run started, returning control to user process
22
+ 2025-06-17 09:05:31,151 INFO MainThread:1311468 [wandb_run.py:_config_callback():1358] config_cb None None {'pretrained_path': '', 'lr': 0.0005, 'beta_1': 0.9, 'beta_2': 0.99, 'weight_decay': 1e-05, 'warmup_epochs': 10, 'max_epochs': 50, 'warmup_start_lr': 1e-08, 'eta_min': 1e-08, '_instantiator': 'pytorch_lightning.cli.instantiate_module', 'dir_data': '/data/weather2025/NhaBe/', 'batch_size': 1, 'hours_predicted': 3, 'num_workers': 4, 'pin_memory': False, 'time_points_rad': 1, 'time_points_sat': 1, 'sat_inp_vars': 'total_precipitation', 'sat_out_vars': 'total_precipitation', 'sat_size': 25, 'rad_inp_vars': 'precipitation', 'rad_out_vars': 'precipitation', 'rad_size': 400, 'ablation': 'no'}
checkpoint/Unet/wandb_logs/wandb/run-20250617_090527-0nx0l2dh/run-0nx0l2dh.wandb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:72554d0fd15b4685f86d20808dfa02fa74043afb99ece44f6b8184bd0a6f9bfc
3
+ size 66093056
configs/AttR2Unet.yaml ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ seed_everything: 42
2
+
3
+ # ---------------------------- TRAINER -------------------------------------------
4
+ trainer:
5
+ default_root_dir: "checkpoint/AttR2Unet"
6
+ precision: "16-mixed"
7
+ min_epochs: 1
8
+ max_epochs: 100
9
+ accelerator: cuda
10
+ # limit_train_batches: 10
11
+ devices: [6]
12
+ # strategy: ddp
13
+ num_nodes: 1
14
+ enable_progress_bar: true
15
+ sync_batchnorm: True
16
+ enable_checkpointing: True
17
+ # debugging
18
+ fast_dev_run: false
19
+ logger:
20
+ class_path: pytorch_lightning.loggers.CSVLogger
21
+ init_args:
22
+ save_dir: "checkpoint/AttR2Unet/logs"
23
+ name: null
24
+ version: null
25
+
26
+ callbacks:
27
+ - class_path: pytorch_lightning.callbacks.LearningRateMonitor
28
+ init_args:
29
+ logging_interval: "step"
30
+
31
+ - class_path: pytorch_lightning.callbacks.ModelCheckpoint
32
+ init_args:
33
+ dirpath: "checkpoint/AttR2Unet/checkpoints"
34
+ monitor: "val/mse" # name of the logged metric which determines when model is improving
35
+ mode: "min" # "max" means higher metric value is better, can be also "min"
36
+ save_top_k: 1 # save k best models (determined by above metric)
37
+ save_last: True # additionally always save model from last epoch
38
+ verbose: False
39
+ filename: "epoch_{epoch:03d}"
40
+ auto_insert_metric_name: False
41
+
42
+ - class_path: pytorch_lightning.callbacks.EarlyStopping
43
+ init_args:
44
+ monitor: "val/mse" # name of the logged metric which determines when model is improving
45
+ mode: "min" # "max" means higher metric value is better, can be also "min"
46
+ patience: 10 # how many validation epochs of not improving until training stops
47
+ min_delta: 0. # minimum change in the monitored metric needed to qualify as an improvement
48
+
49
+ - class_path: pytorch_lightning.callbacks.RichModelSummary
50
+ init_args:
51
+ max_depth: -1
52
+
53
+ - class_path: pytorch_lightning.callbacks.RichProgressBar
54
+
55
+ # ---------------------------- MODEL -------------------------------------------
56
+ model:
57
+ pretrained_path: ""
58
+ beta_1: 0.9
59
+ beta_2: 0.99
60
+ lr: 5e-4
61
+ weight_decay: 1e-5
62
+ warmup_epochs: 10
63
+ max_epochs: 50
64
+ warmup_start_lr: 1e-8
65
+ eta_min: 1e-8
66
+ net:
67
+ model_type: "AttR2Unet"
68
+ num_channel: 1
69
+
70
+ # ---------------------------- DATA -------------------------------------------
71
+ data:
72
+ dir_data: "/data/data_WF/ablation/ablation_time"
73
+ ablation: "time"
74
+ sat_size: 20
75
+ rad_size: 640
76
+ time_points_rad: 1
77
+ time_points_sat: 1
78
+ sat_inp_vars: ["total_precipitation"]
79
+ sat_out_vars: "total_precipitation"
80
+ rad_inp_vars: ["precipitation"]
81
+ rad_out_vars: "precipitation"
82
+ hours_predicted: 3
83
+ batch_size: 32
84
+ num_workers: 4
85
+ pin_memory: False
86
+
configs/AttUnet.yaml ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ seed_everything: 42
2
+
3
+ # ---------------------------- TRAINER -------------------------------------------
4
+ trainer:
5
+ default_root_dir: "checkpoint/AttUnet"
6
+ precision: "16-mixed"
7
+ min_epochs: 1
8
+ max_epochs: 100
9
+ accelerator: cuda
10
+ # limit_train_batches: 10
11
+ devices: [5]
12
+ # strategy: ddp
13
+ num_nodes: 1
14
+ enable_progress_bar: true
15
+ sync_batchnorm: True
16
+ enable_checkpointing: True
17
+ # debugging
18
+ fast_dev_run: false
19
+ logger:
20
+ class_path: pytorch_lightning.loggers.CSVLogger
21
+ init_args:
22
+ save_dir: "checkpoint/AttUnet/logs"
23
+ name: null
24
+ version: null
25
+
26
+ callbacks:
27
+ - class_path: pytorch_lightning.callbacks.LearningRateMonitor
28
+ init_args:
29
+ logging_interval: "step"
30
+
31
+ - class_path: pytorch_lightning.callbacks.ModelCheckpoint
32
+ init_args:
33
+ dirpath: "checkpoint/AttUnet/checkpoints"
34
+ monitor: "val/mse" # name of the logged metric which determines when model is improving
35
+ mode: "min" # "max" means higher metric value is better, can be also "min"
36
+ save_top_k: 1 # save k best models (determined by above metric)
37
+ save_last: True # additionally always save model from last epoch
38
+ verbose: False
39
+ filename: "epoch_{epoch:03d}"
40
+ auto_insert_metric_name: False
41
+
42
+ - class_path: pytorch_lightning.callbacks.EarlyStopping
43
+ init_args:
44
+ monitor: "val/mse" # name of the logged metric which determines when model is improving
45
+ mode: "min" # "max" means higher metric value is better, can be also "min"
46
+ patience: 10 # how many validation epochs of not improving until training stops
47
+ min_delta: 0. # minimum change in the monitored metric needed to qualify as an improvement
48
+
49
+ - class_path: pytorch_lightning.callbacks.RichModelSummary
50
+ init_args:
51
+ max_depth: -1
52
+
53
+ - class_path: pytorch_lightning.callbacks.RichProgressBar
54
+
55
+ # ---------------------------- MODEL -------------------------------------------
56
+ model:
57
+ pretrained_path: ""
58
+ beta_1: 0.9
59
+ beta_2: 0.99
60
+ lr: 5e-4
61
+ weight_decay: 1e-5
62
+ warmup_epochs: 10
63
+ max_epochs: 50
64
+ warmup_start_lr: 1e-8
65
+ eta_min: 1e-8
66
+ net:
67
+ model_type: "AttUnet"
68
+ num_channel: 1
69
+
70
+ # ---------------------------- DATA -------------------------------------------
71
+ data:
72
+ dir_data: "/data/data_WF/ablation/ablation_time"
73
+ ablation: "time"
74
+ sat_size: 20
75
+ rad_size: 640
76
+ time_points_rad: 1
77
+ time_points_sat: 1
78
+ sat_inp_vars: ["total_precipitation"]
79
+ sat_out_vars: "total_precipitation"
80
+ rad_inp_vars: ["precipitation"]
81
+ rad_out_vars: "precipitation"
82
+ hours_predicted: 3
83
+ batch_size: 32
84
+ num_workers: 4
85
+ pin_memory: False
86
+
configs/Nothing.yaml ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ seed_everything: 42
2
+
3
+ # ---------------------------- TRAINER -------------------------------------------
4
+ trainer:
5
+ default_root_dir: "checkpoint/Nothing"
6
+ precision: "16-mixed"
7
+ min_epochs: 1
8
+ max_epochs: 100
9
+ accelerator: cuda
10
+ # limit_train_batches: 10
11
+ devices: [4]
12
+ # strategy: ddp
13
+ num_nodes: 1
14
+ enable_progress_bar: true
15
+ sync_batchnorm: True
16
+ enable_checkpointing: True
17
+ # debugging
18
+ fast_dev_run: false
19
+ logger:
20
+ class_path: pytorch_lightning.loggers.CSVLogger
21
+ init_args:
22
+ save_dir: "checkpoint/Nothing/logs"
23
+ name: null
24
+ version: null
25
+
26
+ callbacks:
27
+ - class_path: pytorch_lightning.callbacks.LearningRateMonitor
28
+ init_args:
29
+ logging_interval: "step"
30
+
31
+ - class_path: pytorch_lightning.callbacks.ModelCheckpoint
32
+ init_args:
33
+ dirpath: "checkpoint/Nothing/checkpoints"
34
+ monitor: "val/sat" # name of the logged metric which determines when model is improving
35
+ mode: "min" # "max" means higher metric value is better, can be also "min"
36
+ save_top_k: 1 # save k best models (determined by above metric)
37
+ save_last: True # additionally always save model from last epoch
38
+ verbose: False
39
+ filename: "epoch_{epoch:03d}"
40
+ auto_insert_metric_name: False
41
+
42
+ - class_path: pytorch_lightning.callbacks.EarlyStopping
43
+ init_args:
44
+ monitor: "val/sat" # name of the logged metric which determines when model is improving
45
+ mode: "min" # "max" means higher metric value is better, can be also "min"
46
+ patience: 10 # how many validation epochs of not improving until training stops
47
+ min_delta: 0. # minimum change in the monitored metric needed to qualify as an improvement
48
+
49
+ - class_path: pytorch_lightning.callbacks.RichModelSummary
50
+ init_args:
51
+ max_depth: -1
52
+
53
+ - class_path: pytorch_lightning.callbacks.RichProgressBar
54
+
55
+ # ---------------------------- MODEL -------------------------------------------
56
+ model:
57
+ pretrained_path: ""
58
+ beta_1: 0.9
59
+ beta_2: 0.99
60
+ lr: 5e-4
61
+ weight_decay: 1e-5
62
+ warmup_epochs: 10
63
+ max_epochs: 50
64
+ warmup_start_lr: 1e-8
65
+ eta_min: 1e-8
66
+ net:
67
+ model_type: "Nothing"
68
+ num_channel: 1
69
+
70
+ # ---------------------------- DATA -------------------------------------------
71
+ data:
72
+ dir_data: "/data/data_WF/ablation/ablation_time"
73
+ ablation: "time"
74
+ sat_size: 20
75
+ rad_size: 640
76
+ time_points_rad: 1
77
+ time_points_sat: 1
78
+ sat_inp_vars: ["total_precipitation"]
79
+ sat_out_vars: "total_precipitation"
80
+ rad_inp_vars: ["precipitation"]
81
+ rad_out_vars: "precipitation"
82
+ hours_predicted: 3
83
+ batch_size: 8
84
+ num_workers: 4
85
+ pin_memory: False
86
+
configs/R2Unet.yaml ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ seed_everything: 42
2
+
3
+ # ---------------------------- TRAINER -------------------------------------------
4
+ trainer:
5
+ default_root_dir: "checkpoint/R2Unet"
6
+ precision: "16-mixed"
7
+ min_epochs: 1
8
+ max_epochs: 100
9
+ accelerator: cuda
10
+ # limit_train_batches: 10
11
+ devices: [4]
12
+ # strategy: ddp
13
+ num_nodes: 1
14
+ enable_progress_bar: true
15
+ sync_batchnorm: True
16
+ enable_checkpointing: True
17
+ # debugging
18
+ fast_dev_run: false
19
+ logger:
20
+ class_path: pytorch_lightning.loggers.CSVLogger
21
+ init_args:
22
+ save_dir: "checkpoint/R2Unet/logs"
23
+ name: null
24
+ version: null
25
+
26
+ callbacks:
27
+ - class_path: pytorch_lightning.callbacks.LearningRateMonitor
28
+ init_args:
29
+ logging_interval: "step"
30
+
31
+ - class_path: pytorch_lightning.callbacks.ModelCheckpoint
32
+ init_args:
33
+ dirpath: "checkpoint/R2Unet/checkpoints"
34
+ monitor: "val/mse" # name of the logged metric which determines when model is improving
35
+ mode: "min" # "max" means higher metric value is better, can be also "min"
36
+ save_top_k: 1 # save k best models (determined by above metric)
37
+ save_last: True # additionally always save model from last epoch
38
+ verbose: False
39
+ filename: "epoch_{epoch:03d}"
40
+ auto_insert_metric_name: False
41
+
42
+ - class_path: pytorch_lightning.callbacks.EarlyStopping
43
+ init_args:
44
+ monitor: "val/mse" # name of the logged metric which determines when model is improving
45
+ mode: "min" # "max" means higher metric value is better, can be also "min"
46
+ patience: 10 # how many validation epochs of not improving until training stops
47
+ min_delta: 0. # minimum change in the monitored metric needed to qualify as an improvement
48
+
49
+ - class_path: pytorch_lightning.callbacks.RichModelSummary
50
+ init_args:
51
+ max_depth: -1
52
+
53
+ - class_path: pytorch_lightning.callbacks.RichProgressBar
54
+ init_args:
55
+ theme:
56
+ description: "white"
57
+ progress_bar: "#6206E0"
58
+ progress_bar_finished: "green"
59
+ progress_bar_pulse: "cyan"
60
+ batch_progress: "white"
61
+ time: "grey42"
62
+ processing_speed: "grey70"
63
+ metrics: "white"
64
+ # ---------------------------- MODEL -------------------------------------------
65
+ model:
66
+ pretrained_path: ""
67
+ beta_1: 0.9
68
+ beta_2: 0.99
69
+ lr: 5e-4
70
+ weight_decay: 1e-5
71
+ warmup_epochs: 10
72
+ max_epochs: 50
73
+ warmup_start_lr: 1e-8
74
+ eta_min: 1e-8
75
+ net:
76
+ model_type: "R2Unet"
77
+ num_channel: 1
78
+
79
+ # ---------------------------- DATA -------------------------------------------
80
+ data:
81
+ dir_data: "/data/data_WF/ablation/ablation_time"
82
+ ablation: "time"
83
+ sat_size: 20
84
+ rad_size: 640
85
+ time_points_rad: 1
86
+ time_points_sat: 1
87
+ sat_inp_vars: ["total_precipitation"]
88
+ sat_out_vars: "total_precipitation"
89
+ rad_inp_vars: ["precipitation"]
90
+ rad_out_vars: "precipitation"
91
+ hours_predicted: 3
92
+ batch_size: 32
93
+ num_workers: 4
94
+ pin_memory: False
95
+
configs/Unet.yaml ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ seed_everything: 42
2
+
3
+ # ---------------------------- TRAINER -------------------------------------------
4
+ trainer:
5
+ default_root_dir: "checkpoint/Unet"
6
+ precision: "16-mixed"
7
+ min_epochs: 1
8
+ max_epochs: 100
9
+ accelerator: cuda
10
+ # limit_train_batches: 10
11
+ devices: [5]
12
+ # strategy: ddp
13
+ num_nodes: 1
14
+ enable_progress_bar: true
15
+ sync_batchnorm: True
16
+ enable_checkpointing: True
17
+ # debugging
18
+ fast_dev_run: false
19
+ logger:
20
+ - class_path: pytorch_lightning.loggers.WandbLogger
21
+ init_args:
22
+ project: "NhaBe"
23
+ name: "UnetNhaBe"
24
+ save_dir: "checkpoint/Unet/wandb_logs"
25
+ log_model: False
26
+ - class_path: pytorch_lightning.loggers.CSVLogger
27
+ init_args:
28
+ save_dir: "checkpoint/Unet/csv_logs"
29
+ name: null
30
+ version: null
31
+
32
+ callbacks:
33
+ - class_path: pytorch_lightning.callbacks.LearningRateMonitor
34
+ init_args:
35
+ logging_interval: "step"
36
+
37
+ - class_path: pytorch_lightning.callbacks.ModelCheckpoint
38
+ init_args:
39
+ dirpath: "checkpoint/Unet/checkpoints"
40
+ monitor: "val/mse" # name of the logged metric which determines when model is improving
41
+ mode: "min" # "max" means higher metric value is better, can be also "min"
42
+ save_top_k: 1 # save k best models (determined by above metric)
43
+ save_last: True # additionally always save model from last epoch
44
+ verbose: False
45
+ filename: "epoch_{epoch:03d}"
46
+ auto_insert_metric_name: False
47
+
48
+ - class_path: pytorch_lightning.callbacks.EarlyStopping
49
+ init_args:
50
+ monitor: "val/mse" # name of the logged metric which determines when model is improving
51
+ mode: "min" # "max" means higher metric value is better, can be also "min"
52
+ patience: 10 # how many validation epochs of not improving until training stops
53
+ min_delta: 0. # minimum change in the monitored metric needed to qualify as an improvement
54
+
55
+ - class_path: pytorch_lightning.callbacks.RichModelSummary
56
+ init_args:
57
+ max_depth: -1
58
+
59
+ # - class_path: pytorch_lightning.callbacks.RichProgressBar
60
+ # init_args:
61
+ # theme:
62
+ # description: "white"
63
+ # progress_bar: "#6206E0"
64
+ # progress_bar_finished: "green"
65
+ # progress_bar_pulse: "cyan"
66
+ # batch_progress: "white"
67
+ # time: "grey42"
68
+ # processing_speed: "grey70"
69
+ # metrics: "white"
70
+ # ---------------------------- MODEL -------------------------------------------
71
+ model:
72
+ pretrained_path: ""
73
+ beta_1: 0.9
74
+ beta_2: 0.99
75
+ lr: 5e-4
76
+ weight_decay: 1e-5
77
+ warmup_epochs: 10
78
+ max_epochs: 50
79
+ warmup_start_lr: 1e-8
80
+ eta_min: 1e-8
81
+ net:
82
+ model_type: "Unet"
83
+ rad_channel: 1
84
+ sat_channel: 1
85
+ rad_size: 400
86
+ sat_size: 25
87
+
88
+ # ---------------------------- DATA -------------------------------------------
89
+ data:
90
+ dir_data: "/data/weather2025/NhaBe/"
91
+ ablation: "no"
92
+ rad_size: 400
93
+ sat_size: 25
94
+ time_points_rad: 1
95
+ time_points_sat: 1
96
+ sat_inp_vars: "total_precipitation"
97
+ sat_out_vars: "total_precipitation"
98
+ rad_inp_vars: "precipitation"
99
+ rad_out_vars: "precipitation"
100
+ hours_predicted: 3
101
+ batch_size: 1
102
+ num_workers: 8
103
+ pin_memory: False
104
+
pyproject.toml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools", "setuptools-scm"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "FusionModel"
7
+ version = "0.3.1"
8
+ authors =[
9
+ {name="Khanh Vinh Bui", email="khanhvinhbui0512@gmail.com"},
10
+ {name="Hong Trang Le", email="lhtrang@hcmut.edu.vn"}
11
+ ]
12
+ description = ""
13
+ readme = "README.md"
14
+ requires-python = ">=3.10"
15
+ classifiers = [
16
+ "Programming Language :: Python :: 3",
17
+ "License :: OSI Approved :: MIT License",
18
+ ]
19
+
20
+ [tool.setuptools.packages.find]
21
+ where = ["."]
src/__pycache__/arch.cpython-310.pyc ADDED
Binary file (12.5 kB). View file
 
src/__pycache__/arch.cpython-312.pyc ADDED
Binary file (27.2 kB). View file
 
src/__pycache__/arch.cpython-38.pyc ADDED
Binary file (12.8 kB). View file
 
src/__pycache__/datamodule.cpython-310.pyc ADDED
Binary file (9.34 kB). View file
 
src/__pycache__/datamodule.cpython-312.pyc ADDED
Binary file (23.8 kB). View file
 
src/__pycache__/lr_scheduler.cpython-310.pyc ADDED
Binary file (3.73 kB). View file
 
src/__pycache__/lr_scheduler.cpython-312.pyc ADDED
Binary file (5.55 kB). View file
 
src/__pycache__/metric.cpython-310.pyc ADDED
Binary file (1.74 kB). View file
 
src/__pycache__/metric.cpython-312.pyc ADDED
Binary file (3.67 kB). View file
 
src/__pycache__/module.cpython-310.pyc ADDED
Binary file (6.35 kB). View file
 
src/__pycache__/module.cpython-312.pyc ADDED
Binary file (11.4 kB). View file
 
src/__pycache__/module.cpython-38.pyc ADDED
Binary file (6.22 kB). View file
 
src/__pycache__/train.cpython-38.pyc ADDED
Binary file (983 Bytes). View file
 
src/arch.py ADDED
@@ -0,0 +1,473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import glob
4
+ import math
5
+ import torch
6
+ import torchvision
7
+ # For everything
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from torch.nn import CrossEntropyLoss, Linear, MSELoss
12
+ from torch.nn import ConvTranspose2d, Conv2d, MaxPool2d, BatchNorm2d
13
+ # For our model
14
+ import torchvision.models as models
15
+ from torchvision import datasets, transforms
16
+ from torchvision.io import read_image
17
+ from torch.utils.data import DataLoader, Dataset
18
+ import torch.optim as optim
19
+ from torch.autograd import Variable
20
+ from torchsummary import summary
21
class Nothing(nn.Module):
    """Parameter-free module whose forward pass hands both inputs straight back."""

    def __init__(self):
        super().__init__()

    def forward(self, radar, satellite):
        """Return ``(radar, satellite)`` exactly as received; nothing is computed."""
        passthrough = (radar, satellite)
        return passthrough
26
+
27
class ConvBlock(nn.Module):
    """Two consecutive 3x3 convolution -> batch-norm -> ReLU stages."""

    def __init__(self, in_channels, out_channels):
        """
        :param in_channels: number of channels of the incoming feature map
        :param out_channels: number of channels produced by each of the two convolutions
        """
        super().__init__()
        # padding='same' with stride 1 leaves height and width untouched
        conv_kwargs = dict(kernel_size=3, stride=1, padding='same', bias=True)
        stages = [
            nn.Conv2d(in_channels, out_channels, **conv_kwargs),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, **conv_kwargs),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both convolution stages and return the result."""
        return self.conv(x)
45
+
46
class UpConv(nn.Module):
    """Upsample by a factor of 2, then apply 3x3 convolution -> batch-norm -> ReLU."""

    def __init__(self, in_channels, out_channels):
        """
        :param in_channels: number of channels of the incoming feature map
        :param out_channels: number of channels produced by the convolution
        """
        super().__init__()
        upsample = nn.Upsample(scale_factor=2)
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding='same', bias=True)
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.ReLU(inplace=True)
        self.up = nn.Sequential(upsample, conv, norm, activation)

    def forward(self, x):
        """Return the upsampled and convolved version of ``x``."""
        return self.up(x)
58
class AttentionBlock(nn.Module):
    """Attention gate that rescales a skip connection with learned coefficients"""

    def __init__(self, F_g, F_l, n_coefficients):
        """
        :param F_g: number of feature maps (channels) of the gating signal
        :param F_l: number of feature maps of the skip connection coming from the encoder
        :param n_coefficients: number of channels both inputs are projected to before being summed
        """
        super().__init__()

        def projection(channels_in):
            # 1x1 convolution + batch-norm mapping an input onto n_coefficients channels
            return nn.Sequential(
                nn.Conv2d(channels_in, n_coefficients, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm2d(n_coefficients),
            )

        self.W_gate = projection(F_g)
        self.W_x = projection(F_l)

        # collapses the summed projections to a single-channel map squashed into (0, 1)
        self.psi = nn.Sequential(
            nn.Conv2d(n_coefficients, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, gate, skip_connection):
        """
        :param gate: gating signal, projected by ``W_gate``
        :param skip_connection: encoder activation, projected by ``W_x`` and rescaled on output
        :return: ``skip_connection`` multiplied by the attention coefficients
        """
        summed = self.W_gate(gate) + self.W_x(skip_connection)
        coefficients = self.psi(self.relu(summed))
        return skip_connection * coefficients
99
+
100
class Recurrent_block(nn.Module):
    """Recurrent convolution: one conv/BN/ReLU unit applied repeatedly to ``x + state``."""

    def __init__(self, ch_out, t=2):
        """
        :param ch_out: number of channels (input and output widths are identical)
        :param t: number of recurrent refinement steps; must be at least 1
        :raises ValueError: if ``t`` is smaller than 1
        """
        super(Recurrent_block, self).__init__()
        if t < 1:
            # With t < 1 the recurrence never runs and forward() has nothing to return.
            raise ValueError("t must be >= 1")
        self.t = t
        self.ch_out = ch_out
        self.conv = nn.Sequential(
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding='same', bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """Return the state after one initial convolution and ``t`` recurrent steps."""
        # Initial state from the raw input, then t passes over (input + state).
        state = self.conv(x)
        for _ in range(self.t):
            state = self.conv(x + state)
        return state
119
+
120
class RRCNN_block(nn.Module):
    """Residual block: 1x1 channel projection plus two stacked Recurrent_block units."""

    def __init__(self, ch_in, ch_out, t=2):
        """
        :param ch_in: channels of the incoming feature map
        :param ch_out: channels after the 1x1 projection (and of the output)
        :param t: recurrent step count forwarded to each Recurrent_block
        """
        super().__init__()
        recurrent_units = [Recurrent_block(ch_out, t=t) for _ in range(2)]
        self.RCNN = nn.Sequential(*recurrent_units)
        self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding='same')

    def forward(self, x):
        """Project ``x`` to ``ch_out`` channels and add the recurrent branch as a residual."""
        projected = self.Conv_1x1(x)
        refined = self.RCNN(projected)
        return projected + refined
133
+
134
+
135
class single_conv(nn.Module):
    """A single 3x3 'same' convolution followed by batch-norm and ReLU."""

    def __init__(self, ch_in, ch_out):
        """
        :param ch_in: channels of the incoming feature map
        :param ch_out: channels produced by the convolution
        """
        super().__init__()
        layers = (
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding='same', bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        )
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Return the convolved, normalised and activated feature map."""
        return self.conv(x)
147
class Unet(nn.Module):
    """U-Net that fuses a radar input with a coarser satellite input at the bottleneck.

    The radar branch is pooled ``log2(rad_size / sat_size)`` times so that its
    bottleneck has the same spatial size as the satellite input; both are then
    concatenated. One head predicts the satellite field from the merged
    bottleneck, the decoder predicts the radar field at full resolution.
    """

    def __init__(self, rad_channel=1, sat_channel=1, rad_size=640, sat_size=20):
        """
        :param rad_channel: channels of the radar input; also the base width of the network
        :param sat_channel: channels of the satellite input and of the satellite prediction
        :param rad_size: side length of the radar input
        :param sat_size: side length of the satellite input
        :raises AssertionError: if rad_size / sat_size is not an integer power of two
        :raises ValueError: if rad_size equals sat_size (no pooling stage could be built)
        """
        super(Unet, self).__init__()
        assert rad_size % sat_size == 0, "rad_size must be divisible by sat_size"
        ratio = rad_size // sat_size
        assert (ratio & (ratio - 1)) == 0, "rad_size/sat_size must be a power of 2"
        if ratio < 2:
            # A ratio of 1 would yield zero encoder stages and an undefined bottleneck width.
            raise ValueError("rad_size must be at least twice sat_size")
        # ratio is a power of two, so its bit length gives log2 exactly without float rounding.
        self.n_pool = ratio.bit_length() - 1

        # Encoder: every stage doubles the channel count and is followed by one 2x pooling.
        self.encoder_blocks = nn.ModuleList()
        self.pools = nn.ModuleList()
        for level in range(self.n_pool):
            stage_in = rad_channel * (2 ** level)
            stage_out = rad_channel * (2 ** (level + 1))
            self.encoder_blocks.append(ConvBlock(stage_in, stage_out))
            self.pools.append(nn.MaxPool2d(kernel_size=2, stride=2))
        bottleneck_c = rad_channel * (2 ** self.n_pool)

        # Bottleneck: bring both modalities to the same width before concatenation.
        self.mid_conv_1 = single_conv(bottleneck_c, bottleneck_c)
        self.mid_conv_2 = single_conv(sat_channel, bottleneck_c)
        self.mid_merge = ConvBlock(2 * bottleneck_c, bottleneck_c)

        # Decoder: mirror of the encoder, deepest stage first.
        self.up_convs = nn.ModuleList()
        self.decoder_blocks = nn.ModuleList()
        for level in reversed(range(self.n_pool)):
            up_in = rad_channel * (2 ** (level + 2))
            up_out = rad_channel * (2 ** (level + 1))
            self.up_convs.append(UpConv(up_in, up_out))
            self.decoder_blocks.append(ConvBlock(up_in, up_out))
        self.final_decoder = ConvBlock(4 * rad_channel, 2 * rad_channel)
        self.out_conv_R = nn.Conv2d(2 * rad_channel, rad_channel, kernel_size=1, padding='same')
        self.out_conv_S = nn.Conv2d(bottleneck_c, sat_channel, kernel_size=1, padding='same')

    def forward(self, radar, satellite):
        """
        :param radar: radar input tensor
        :param satellite: satellite input tensor; must match the pooled radar spatially
        :return: tuple ``(pred_rad, pred_sat)``
        """
        # Encoding: keep every pre-pooling activation for the skip connections.
        enc_feats = []
        x = radar
        for block, pool in zip(self.encoder_blocks, self.pools):
            x = block(x)
            enc_feats.append(x)
            x = pool(x)

        # Bottleneck: concatenate radar and satellite features along the channel axis.
        x = F.relu(self.mid_conv_1(x))
        y = F.relu(self.mid_conv_2(satellite))
        x = torch.cat((x, y), dim=1)

        # Satellite head works on the merged bottleneck ...
        mid_out = self.mid_merge(x)
        pred_sat = self.out_conv_S(mid_out)

        # ... while the decoder starts from the concatenated tensor before mid_merge.
        for i in range(self.n_pool):
            x = self.up_convs[i](x)
            x = torch.cat((enc_feats[self.n_pool - 1 - i], x), dim=1)
            x = self.decoder_blocks[i](x)
        # The shallowest encoder feature is concatenated once more before the output head.
        x = torch.cat((enc_feats[0], x), dim=1)
        x = self.final_decoder(x)
        pred_rad = self.out_conv_R(x)
        return pred_rad, pred_sat
204
+ # class Unet(nn.Module):
205
+ # def __init__(self,num_channel=1,rad_size=640,sat_size=20):
206
+ # super(Unet, self).__init__()
207
+ # self.MaxPool = nn.MaxPool2d(kernel_size=2, stride=2)
208
+ # self.Conv1 = ConvBlock(1, 2*num_channel)
209
+ # self.Conv2 = ConvBlock(2*num_channel, 4*num_channel)
210
+ # self.Conv3 = ConvBlock(4*num_channel, 8*num_channel)
211
+ # self.Conv4 = ConvBlock(8*num_channel, 16*num_channel)
212
+ # self.Conv5 = ConvBlock(16*num_channel, 32*num_channel)
213
+ # self.mid_conv_1 = single_conv(32*num_channel,32*num_channel)
214
+ # self.mid_conv_2 = single_conv(2, 32*num_channel)
215
+ # self.MidConv = ConvBlock(64*num_channel, 32*num_channel)
216
+ # self.out_conv_S = Conv2d(32*num_channel, 1, (1, 1), padding= 'same')
217
+ # self.Up5 = UpConv(64*num_channel, 32*num_channel)
218
+ # self.UpConv5 = ConvBlock(64*num_channel, 32*num_channel)
219
+ # self.Up4 = UpConv(32*num_channel, 16*num_channel)
220
+ # self.UpConv4 = ConvBlock(32*num_channel, 16*num_channel)
221
+ # self.Up3 = UpConv(16*num_channel, 8*num_channel)
222
+ # self.UpConv3 = ConvBlock(16*num_channel, 8*num_channel)
223
+ # self.Up2 = UpConv(8*num_channel, 4*num_channel)
224
+ # self.UpConv2 = ConvBlock(8*num_channel, 4*num_channel)
225
+ # self.Up1 = UpConv(4*num_channel, 2*num_channel)
226
+ # self.UpConv1 = ConvBlock(4*num_channel, 2*num_channel)
227
+ # self.out_conv_R = Conv2d(2*num_channel, 1, (1, 1), padding= 'same')
228
+ # def forward(self, radar,satellite):
229
+ # e1 = self.Conv1(radar)
230
+ # e2 = self.MaxPool(e1)
231
+ # e2 = self.Conv2(e2)
232
+ # e3 = self.MaxPool(e2)
233
+ # e3 = self.Conv3(e3)
234
+ # e4 = self.MaxPool(e3)
235
+ # e4 = self.Conv4(e4)
236
+ # e5 = self.MaxPool(e4)
237
+ # e5 = self.Conv5(e5)
238
+ # e6 = self.MaxPool(e5)
239
+ # X = F.relu(self.mid_conv_1(e6))
240
+ # Y = F.relu(self.mid_conv_2(satellite))
241
+ # X = torch.cat((X,Y),1)
242
+ # Y = self.MidConv(X)
243
+ # pred_satellite = self.out_conv_S(Y)
244
+ # d5 = self.Up5(X)
245
+ # d5 = torch.cat((e5, d5), dim=1)
246
+ # d5 = self.UpConv5(d5)
247
+ # d4 = self.Up4(d5)
248
+ # d4 = torch.cat((e4, d4), dim=1)
249
+ # d4 = self.UpConv4(d4)
250
+ # d3 = self.Up3(d4)
251
+ # d3 = torch.cat((e3, d3), dim=1)
252
+ # d3 = self.UpConv3(d3)
253
+ # d2 = self.Up2(d3)
254
+ # d2 = torch.cat((e2, d2), dim=1)
255
+ # d2 = self.UpConv2(d2)
256
+ # d1 = self.Up1(d2)
257
+ # d0 = torch.cat((e1, d1), dim=1)
258
+ # d0 = self.UpConv1(d0)
259
+ # pred_radar = self.out_conv_R(d0)
260
+ # return pred_radar, pred_satellite
261
+
262
class R2Unet(nn.Module):
    """Recurrent-residual U-Net fusing a radar input with a coarser satellite input.

    The radar branch is pooled five times (overall factor 32), concatenated with
    the satellite features at the bottleneck, and decoded back to full
    resolution. Both prediction heads emit a single channel.
    """

    def __init__(self, num_channel=1, t=2, rad_channel=5, sat_channel=2, rad_size=None, sat_size=None):
        """
        :param num_channel: width multiplier; encoder widths are 2, 4, 8, 16, 32 times this value
        :param t: number of recurrent steps used by every RRCNN_block
        :param rad_channel: channels of the radar input (default 5)
        :param sat_channel: channels of the satellite input (default 2)
        :param rad_size: optional radar side length, used only for validation
        :param sat_size: optional satellite side length, used only for validation
        :raises ValueError: if both sizes are given and ``rad_size != 32 * sat_size``
        """
        super(R2Unet, self).__init__()
        if rad_size is not None and sat_size is not None and rad_size != 32 * sat_size:
            # Five stride-2 poolings are hard-wired, so the bottleneck concat needs a 32x ratio.
            raise ValueError("rad_size must equal 32 * sat_size for the five pooling stages")
        self.MaxPool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Encoder
        self.RRCNN1 = RRCNN_block(rad_channel, 2*num_channel, t=t)
        self.RRCNN2 = RRCNN_block(2*num_channel, 4*num_channel, t=t)
        self.RRCNN3 = RRCNN_block(4*num_channel, 8*num_channel, t=t)
        self.RRCNN4 = RRCNN_block(8*num_channel, 16*num_channel, t=t)
        self.RRCNN5 = RRCNN_block(16*num_channel, 32*num_channel, t=t)
        # Bottleneck and satellite head
        self.mid_conv_1 = single_conv(32*num_channel, 32*num_channel)
        self.mid_conv_2 = single_conv(sat_channel, 32*num_channel)
        self.MidConv = RRCNN_block(64*num_channel, 32*num_channel, t=t)
        self.out_conv_S = nn.Conv2d(32*num_channel, 1, (1, 1), padding='same')
        # Decoder
        self.Up5 = UpConv(64*num_channel, 32*num_channel)
        self.UpRRCNN5 = RRCNN_block(64*num_channel, 32*num_channel, t=t)
        self.Up4 = UpConv(32*num_channel, 16*num_channel)
        self.UpRRCNN4 = RRCNN_block(32*num_channel, 16*num_channel, t=t)
        self.Up3 = UpConv(16*num_channel, 8*num_channel)
        self.UpRRCNN3 = RRCNN_block(16*num_channel, 8*num_channel, t=t)
        self.Up2 = UpConv(8*num_channel, 4*num_channel)
        self.UpRRCNN2 = RRCNN_block(8*num_channel, 4*num_channel, t=t)
        self.Up1 = UpConv(4*num_channel, 2*num_channel)
        self.UpRRCNN1 = RRCNN_block(4*num_channel, 2*num_channel, t=t)
        self.out_conv_R = nn.Conv2d(2*num_channel, 1, (1, 1), padding='same')

    def forward(self, radar, satellite):
        """
        :param radar: radar input tensor
        :param satellite: satellite input tensor; must match the 32x-pooled radar spatially
        :return: tuple ``(pred_radar, pred_satellite)``
        """
        # Encoder path; e1..e5 are kept for the skip connections.
        e1 = self.RRCNN1(radar)
        e2 = self.RRCNN2(self.MaxPool(e1))
        e3 = self.RRCNN3(self.MaxPool(e2))
        e4 = self.RRCNN4(self.MaxPool(e3))
        e5 = self.RRCNN5(self.MaxPool(e4))
        e6 = self.MaxPool(e5)

        # Bottleneck: fuse both modalities along the channel axis.
        X = F.relu(self.mid_conv_1(e6))
        Y = F.relu(self.mid_conv_2(satellite))
        X = torch.cat((X, Y), 1)
        pred_satellite = self.out_conv_S(self.MidConv(X))

        # Decoder path starts from the fused (pre-MidConv) tensor.
        d5 = self.UpRRCNN5(torch.cat((e5, self.Up5(X)), dim=1))
        d4 = self.UpRRCNN4(torch.cat((e4, self.Up4(d5)), dim=1))
        d3 = self.UpRRCNN3(torch.cat((e3, self.Up3(d4)), dim=1))
        d2 = self.UpRRCNN2(torch.cat((e2, self.Up2(d3)), dim=1))
        d0 = self.UpRRCNN1(torch.cat((e1, self.Up1(d2)), dim=1))
        pred_radar = self.out_conv_R(d0)
        return pred_radar, pred_satellite
319
class AttUnet(nn.Module):
    """Attention U-Net fusing a radar input with a coarser satellite input.

    Same layout as the plain encoder/decoder in this file, but every skip
    connection is first re-weighted by an AttentionBlock gated with the
    upsampled decoder signal. Both prediction heads emit a single channel.
    """

    def __init__(self, num_channel=1, rad_channel=5, sat_channel=2, rad_size=None, sat_size=None):
        """
        :param num_channel: width multiplier; encoder widths are 2, 4, 8, 16, 32 times this value
        :param rad_channel: channels of the radar input (default 5)
        :param sat_channel: channels of the satellite input (default 2)
        :param rad_size: optional radar side length, used only for validation
        :param sat_size: optional satellite side length, used only for validation
        :raises ValueError: if both sizes are given and ``rad_size != 32 * sat_size``
        """
        super(AttUnet, self).__init__()
        if rad_size is not None and sat_size is not None and rad_size != 32 * sat_size:
            # Five stride-2 poolings are hard-wired, so the bottleneck concat needs a 32x ratio.
            raise ValueError("rad_size must equal 32 * sat_size for the five pooling stages")
        self.MaxPool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Encoder
        self.Conv1 = ConvBlock(rad_channel, 2*num_channel)
        self.Conv2 = ConvBlock(2*num_channel, 4*num_channel)
        self.Conv3 = ConvBlock(4*num_channel, 8*num_channel)
        self.Conv4 = ConvBlock(8*num_channel, 16*num_channel)
        self.Conv5 = ConvBlock(16*num_channel, 32*num_channel)
        # Bottleneck and satellite head
        self.mid_conv_1 = single_conv(32*num_channel, 32*num_channel)
        self.mid_conv_2 = single_conv(sat_channel, 32*num_channel)
        self.MidConv = ConvBlock(64*num_channel, 32*num_channel)
        self.out_conv_S = nn.Conv2d(32*num_channel, 1, (1, 1), padding='same')
        # Decoder: upsample, gate the skip connection, then fuse
        self.Up5 = UpConv(64*num_channel, 32*num_channel)
        self.Att5 = AttentionBlock(F_g=32*num_channel, F_l=32*num_channel, n_coefficients=16*num_channel)
        self.UpConv5 = ConvBlock(64*num_channel, 32*num_channel)
        self.Up4 = UpConv(32*num_channel, 16*num_channel)
        self.Att4 = AttentionBlock(F_g=16*num_channel, F_l=16*num_channel, n_coefficients=8*num_channel)
        self.UpConv4 = ConvBlock(32*num_channel, 16*num_channel)
        self.Up3 = UpConv(16*num_channel, 8*num_channel)
        self.Att3 = AttentionBlock(F_g=8*num_channel, F_l=8*num_channel, n_coefficients=4*num_channel)
        self.UpConv3 = ConvBlock(16*num_channel, 8*num_channel)
        self.Up2 = UpConv(8*num_channel, 4*num_channel)
        self.Att2 = AttentionBlock(F_g=4*num_channel, F_l=4*num_channel, n_coefficients=2*num_channel)
        self.UpConv2 = ConvBlock(8*num_channel, 4*num_channel)
        self.Up1 = UpConv(4*num_channel, 2*num_channel)
        self.Att1 = AttentionBlock(F_g=2*num_channel, F_l=2*num_channel, n_coefficients=1*num_channel)
        self.UpConv1 = ConvBlock(4*num_channel, 2*num_channel)
        self.out_conv_R = nn.Conv2d(2*num_channel, 1, (1, 1), padding='same')

    def forward(self, radar, satellite):
        """
        :param radar: radar input tensor
        :param satellite: satellite input tensor; must match the 32x-pooled radar spatially
        :return: tuple ``(pred_radar, pred_satellite)``
        """
        # Encoder path; e1..e5 are kept for the skip connections.
        e1 = self.Conv1(radar)
        e2 = self.Conv2(self.MaxPool(e1))
        e3 = self.Conv3(self.MaxPool(e2))
        e4 = self.Conv4(self.MaxPool(e3))
        e5 = self.Conv5(self.MaxPool(e4))
        e6 = self.MaxPool(e5)

        # Bottleneck: fuse both modalities along the channel axis.
        X = F.relu(self.mid_conv_1(e6))
        Y = F.relu(self.mid_conv_2(satellite))
        X = torch.cat((X, Y), 1)
        pred_satellite = self.out_conv_S(self.MidConv(X))

        # Decoder: each stage concatenates the attention-weighted skip with the upsampled signal.
        d5 = self.Up5(X)
        s4 = self.Att5(gate=d5, skip_connection=e5)
        d5 = self.UpConv5(torch.cat((s4, d5), dim=1))

        d4 = self.Up4(d5)
        s3 = self.Att4(gate=d4, skip_connection=e4)
        d4 = self.UpConv4(torch.cat((s3, d4), dim=1))

        d3 = self.Up3(d4)
        s2 = self.Att3(gate=d3, skip_connection=e3)
        d3 = self.UpConv3(torch.cat((s2, d3), dim=1))

        d2 = self.Up2(d3)
        s1 = self.Att2(gate=d2, skip_connection=e2)
        d2 = self.UpConv2(torch.cat((s1, d2), dim=1))

        d1 = self.Up1(d2)
        s0 = self.Att1(gate=d1, skip_connection=e1)
        d0 = self.UpConv1(torch.cat((s0, d1), dim=1))

        pred_radar = self.out_conv_R(d0)
        return pred_radar, pred_satellite
386
class AttR2Unet(nn.Module):
    """Attention + recurrent-residual U-Net fusing radar with coarser satellite input.

    Encoder and decoder stages are RRCNN_block units; every skip connection is
    re-weighted by an AttentionBlock gated with the upsampled decoder signal.
    Both prediction heads emit a single channel.
    """

    def __init__(self, num_channel=1, t=2, rad_channel=5, sat_channel=2, rad_size=None, sat_size=None):
        """
        :param num_channel: width multiplier; encoder widths are 2, 4, 8, 16, 32 times this value
        :param t: number of recurrent steps used by every RRCNN_block
        :param rad_channel: channels of the radar input (default 5)
        :param sat_channel: channels of the satellite input (default 2)
        :param rad_size: optional radar side length, used only for validation
        :param sat_size: optional satellite side length, used only for validation
        :raises ValueError: if both sizes are given and ``rad_size != 32 * sat_size``
        """
        super(AttR2Unet, self).__init__()
        if rad_size is not None and sat_size is not None and rad_size != 32 * sat_size:
            # Five stride-2 poolings are hard-wired, so the bottleneck concat needs a 32x ratio.
            raise ValueError("rad_size must equal 32 * sat_size for the five pooling stages")
        self.MaxPool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Encoder (t is forwarded so the constructor argument actually takes effect)
        self.RRCNN1 = RRCNN_block(rad_channel, 2*num_channel, t=t)
        self.RRCNN2 = RRCNN_block(2*num_channel, 4*num_channel, t=t)
        self.RRCNN3 = RRCNN_block(4*num_channel, 8*num_channel, t=t)
        self.RRCNN4 = RRCNN_block(8*num_channel, 16*num_channel, t=t)
        self.RRCNN5 = RRCNN_block(16*num_channel, 32*num_channel, t=t)
        # Bottleneck and satellite head
        self.mid_conv_1 = single_conv(32*num_channel, 32*num_channel)
        self.mid_conv_2 = single_conv(sat_channel, 32*num_channel)
        self.MidConv = RRCNN_block(64*num_channel, 32*num_channel, t=t)
        self.out_conv_S = nn.Conv2d(32*num_channel, 1, (1, 1), padding='same')
        # Decoder: upsample, gate the skip connection, then fuse
        self.Up5 = UpConv(64*num_channel, 32*num_channel)
        self.Att5 = AttentionBlock(F_g=32*num_channel, F_l=32*num_channel, n_coefficients=16*num_channel)
        self.UpRRCNN5 = RRCNN_block(64*num_channel, 32*num_channel, t=t)
        self.Up4 = UpConv(32*num_channel, 16*num_channel)
        self.Att4 = AttentionBlock(F_g=16*num_channel, F_l=16*num_channel, n_coefficients=8*num_channel)
        self.UpRRCNN4 = RRCNN_block(32*num_channel, 16*num_channel, t=t)
        self.Up3 = UpConv(16*num_channel, 8*num_channel)
        self.Att3 = AttentionBlock(F_g=8*num_channel, F_l=8*num_channel, n_coefficients=4*num_channel)
        self.UpRRCNN3 = RRCNN_block(16*num_channel, 8*num_channel, t=t)
        self.Up2 = UpConv(8*num_channel, 4*num_channel)
        self.Att2 = AttentionBlock(F_g=4*num_channel, F_l=4*num_channel, n_coefficients=2*num_channel)
        self.UpRRCNN2 = RRCNN_block(8*num_channel, 4*num_channel, t=t)
        self.Up1 = UpConv(4*num_channel, 2*num_channel)
        self.Att1 = AttentionBlock(F_g=2*num_channel, F_l=2*num_channel, n_coefficients=1*num_channel)
        self.UpRRCNN1 = RRCNN_block(4*num_channel, 2*num_channel, t=t)
        self.out_conv_R = nn.Conv2d(2*num_channel, 1, (1, 1), padding='same')

    def forward(self, radar, satellite):
        """
        :param radar: radar input tensor
        :param satellite: satellite input tensor; must match the 32x-pooled radar spatially
        :return: tuple ``(pred_radar, pred_satellite)``
        """
        # Encoder path; e1..e5 are kept for the skip connections.
        e1 = self.RRCNN1(radar)
        e2 = self.RRCNN2(self.MaxPool(e1))
        e3 = self.RRCNN3(self.MaxPool(e2))
        e4 = self.RRCNN4(self.MaxPool(e3))
        e5 = self.RRCNN5(self.MaxPool(e4))
        e6 = self.MaxPool(e5)

        # Bottleneck: fuse both modalities along the channel axis.
        X = F.relu(self.mid_conv_1(e6))
        Y = F.relu(self.mid_conv_2(satellite))
        X = torch.cat((X, Y), 1)
        pred_satellite = self.out_conv_S(self.MidConv(X))

        # Decoder: each stage concatenates the attention-weighted skip with the upsampled signal.
        d5 = self.Up5(X)
        s4 = self.Att5(gate=d5, skip_connection=e5)
        d5 = self.UpRRCNN5(torch.cat((s4, d5), dim=1))

        d4 = self.Up4(d5)
        s3 = self.Att4(gate=d4, skip_connection=e4)
        d4 = self.UpRRCNN4(torch.cat((s3, d4), dim=1))

        d3 = self.Up3(d4)
        s2 = self.Att3(gate=d3, skip_connection=e3)
        d3 = self.UpRRCNN3(torch.cat((s2, d3), dim=1))

        d2 = self.Up2(d3)
        s1 = self.Att2(gate=d2, skip_connection=e2)
        d2 = self.UpRRCNN2(torch.cat((s1, d2), dim=1))

        d1 = self.Up1(d2)
        s0 = self.Att1(gate=d1, skip_connection=e1)
        d0 = self.UpRRCNN1(torch.cat((s0, d1), dim=1))

        pred_radar = self.out_conv_R(d0)
        return pred_radar, pred_satellite
453
class Network(nn.Module):
    """Thin wrapper that instantiates one of the fusion architectures by name."""

    def __init__(self, model_type: str, rad_channel: int, sat_channel: int, rad_size: int, sat_size: int):
        """
        :param model_type: one of "Nothing", "Unet", "R2Unet", "AttUnet", "AttR2Unet"
        :param rad_channel: forwarded to the selected model as ``rad_channel``
        :param sat_channel: forwarded to the selected model as ``sat_channel``
        :param rad_size: forwarded to the selected model as ``rad_size``
        :param sat_size: forwarded to the selected model as ``sat_size``
        :raises ValueError: if ``model_type`` is not one of the supported names
        """
        super(Network, self).__init__()
        print(model_type)
        # NOTE(review): every model except "Nothing" receives the same four keyword
        # arguments; verify that each constructor accepts them.
        model_kwargs = dict(
            rad_channel=rad_channel,
            sat_channel=sat_channel,
            rad_size=rad_size,
            sat_size=sat_size,
        )
        # An if/elif chain (not a lookup table) keeps class names resolved lazily,
        # so only the selected architecture has to exist.
        if model_type == "Nothing":
            self.net = Nothing()
        elif model_type == "Unet":
            self.net = Unet(**model_kwargs)
        elif model_type == "R2Unet":
            self.net = R2Unet(**model_kwargs)
        elif model_type == "AttUnet":
            self.net = AttUnet(**model_kwargs)
        elif model_type == "AttR2Unet":
            self.net = AttR2Unet(**model_kwargs)
        else:
            raise ValueError("model_type is wrong")

    def forward(self, radar, satellite):
        """Run the wrapped model and return ``(pred_radar, pred_satellite)``."""
        # Call the module itself rather than .forward() so registered hooks are honoured.
        pred_radar, pred_satellite = self.net(radar, satellite)
        return pred_radar, pred_satellite
src/datamodule.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from torch.utils.data import DataLoader, Dataset, random_split
3
+ import numpy as np
4
+ from datetime import datetime, timedelta
5
+ from torchvision import transforms
6
+ from pytorch_lightning import LightningDataModule, LightningModule
7
+ from pytorch_lightning.cli import LightningCLI
8
+ from torch.utils.data import DataLoader
9
+ import pytorch_lightning as L
10
+ import torch
11
+ import torch.nn as nn
12
+ from typing import Tuple, Dict, List
13
+
14
+ # import optim
15
+
16
class DataReader(Dataset):
    """Dataset of (input radar, input satellite, target radar, target satellite) samples.

    Samples are discovered on disk under ``<dir_data>/<split>/`` in the
    sub-directories ``pred_rad``, ``pred_sat``, ``rad`` and ``sat``. Which
    directories provide the inputs depends on ``ablation``; targets always come
    from ``rad`` and ``sat``. Normalisation statistics are read from
    ``rad_mean.npz``, ``rad_std.npz``, ``sat_mean.npz`` and ``sat_std.npz`` in
    ``dir_data``.
    """

    def __init__(
        self,
        dir_data: str,
        type_data: str,
        rad_attribute: str,
        sat_attribute: str,
        hours_predicted: int,
        rad_predicted: str,
        sat_predicted: str,
        time_points_rad: int,
        time_points_sat: int,
        rad_size: int,
        sat_size: int,
        ablation: str,
    ):
        """
        :param dir_data: root directory holding the split folders and the statistics files
        :param type_data: split name, one of "train", "test" or "val"
        :param rad_attribute: key of the input radar array inside each radar .npz file
        :param sat_attribute: key of the input satellite array inside each satellite .npz file
        :param hours_predicted: lead time in hours used to locate target (or past) files
        :param rad_predicted: key of the target radar array
        :param sat_predicted: key of the target satellite array
        :param time_points_rad: stored on the instance; not used inside this class
        :param time_points_sat: stored on the instance; not used inside this class
        :param rad_size: stored on the instance; not used inside this class
        :param sat_size: stored on the instance; not used inside this class
        :param ablation: sample layout, one of "no", "rad", "sat", "full" or "time"
        :raises ValueError: on an unknown ``type_data`` or ``ablation``
        """
        super().__init__()
        self.base_dir = dir_data
        self.type_data = type_data
        if self.type_data not in ("train", "test", "val"):
            raise ValueError("Type must be train, test or val")
        # The split folder carries the same name as the split itself.
        self.dir_data = os.path.join(dir_data, self.type_data)
        self.sat_size = sat_size
        self.rad_size = rad_size
        self.hours_predicted = hours_predicted
        self.rad_attribute = rad_attribute
        self.sat_attribute = sat_attribute
        self.rad_predicted = rad_predicted
        self.sat_predicted = sat_predicted
        self.time_points_rad = time_points_rad
        self.time_points_sat = time_points_sat
        self.transform_rad = None
        self.transform_sat = None
        self.ablation = ablation
        # Normalisation statistics, selected by the input attribute names.
        self.rad_mean = self._load_npz_entry(os.path.join(self.base_dir, 'rad_mean.npz'), self.rad_attribute)
        self.rad_std = self._load_npz_entry(os.path.join(self.base_dir, 'rad_std.npz'), self.rad_attribute)
        self.sat_mean = self._load_npz_entry(os.path.join(self.base_dir, 'sat_mean.npz'), self.sat_attribute)
        self.sat_std = self._load_npz_entry(os.path.join(self.base_dir, 'sat_std.npz'), self.sat_attribute)
        self.create_transform()
        # Pick the sample-list builder matching the requested ablation.
        builders = {
            "no": self.gen_list_img_no,
            "rad": self.gen_list_img_rad,
            "sat": self.gen_list_img_sat,
            "full": self.gen_list_img_full,
            "time": self.gen_list_img_time,
        }
        builder = builders.get(self.ablation) if isinstance(self.ablation, str) else None
        if builder is None:
            raise ValueError("Ablation must be no,rad,sat,full,time")
        self.list_img_dir = builder(self.dir_data)
        print(f"Number of {self.type_data } samples:", len(self.list_img_dir))

    @staticmethod
    def _load_npz_entry(path, key):
        """Return array ``key`` from the .npz archive at ``path``, closing the archive afterwards."""
        with np.load(path) as archive:
            return archive[key]

    @staticmethod
    def _is_anchor_file(name):
        """True for file names ending in ``00.npz`` or ``03.npz`` (the only ones used as samples)."""
        return name.endswith(("00.npz", "03.npz"))

    @staticmethod
    def _hour_stamp(name):
        """Drop the two characters before the 4-character extension (the minute digits)."""
        return name[0:-6] + name[-4:]

    @staticmethod
    def _complete_sample(inp_rad_path, inp_sat_path, GT_rad_path, GT_sat_path):
        """Return the four paths as a list, or None if any of the last three files is missing."""
        companions = (inp_sat_path, GT_rad_path, GT_sat_path)
        if all(os.path.isfile(p) for p in companions):
            return [inp_rad_path, *companions]
        return None

    def __len__(self):
        """Number of discovered samples."""
        return len(self.list_img_dir)

    def __getitem__(self, idx):
        """Load, normalise and return ``(inp_rad, inp_sat, out_rad, out_sat)`` for sample ``idx``.

        NOTE(review): samples built by gen_list_img_time hold a *list* of paths per
        slot, while this method hands each slot directly to np.load; it appears to
        support only the single-path layouts - confirm before using ablation="time".
        """
        inp_rad_path, inp_sat_path, out_rad_path, out_sat_path = self.list_img_dir[idx]
        inp_rad = self.transform_rad(self._load_npz_entry(inp_rad_path, self.rad_attribute))
        out_rad = self.transform_rad(self._load_npz_entry(out_rad_path, self.rad_predicted))
        inp_sat = self.transform_sat(self._load_npz_entry(inp_sat_path, self.sat_attribute))
        # Only the first entry along the leading axis of the target satellite array is used.
        out_sat = self.transform_sat(self._load_npz_entry(out_sat_path, self.sat_predicted)[0])
        return inp_rad, inp_sat.float(), out_rad, out_sat.float()

    def create_transform(self):
        """Build the tensor-conversion + normalisation pipelines for radar and satellite data."""
        self.transform_rad = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self.rad_mean, self.rad_std)
        ])
        # Satellite statistics are indexed with [0] before being handed to Normalize.
        self.transform_sat = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self.sat_mean[0], self.sat_std[0]),
        ])

    def gen_list_img_no(self, path):
        """Inputs from ``pred_rad``/``pred_sat``, targets from ``rad``/``sat`` at the same timestamp."""
        pred_rad_dir = os.path.join(path, "pred_rad")
        pred_sat_dir = os.path.join(path, "pred_sat")
        GT_rad_dir = os.path.join(path, "rad")
        GT_sat_dir = os.path.join(path, "sat")
        list_dir = []
        for name in os.listdir(pred_rad_dir):
            if not self._is_anchor_file(name):
                continue
            hourly_name = self._hour_stamp(name)
            sample = self._complete_sample(
                os.path.join(pred_rad_dir, name),
                os.path.join(pred_sat_dir, hourly_name),
                os.path.join(GT_rad_dir, name),
                os.path.join(GT_sat_dir, hourly_name),
            )
            if sample is not None:
                list_dir.append(sample)
        return list_dir

    def gen_list_img_rad(self, path):
        """Radar input from ``rad`` at time T; ``pred_sat`` input and both targets at T + lead time."""
        pred_rad_dir = os.path.join(path, "rad")
        pred_sat_dir = os.path.join(path, "pred_sat")
        GT_rad_dir = os.path.join(path, "rad")
        GT_sat_dir = os.path.join(path, "sat")
        lead = timedelta(hours=self.hours_predicted)
        list_dir = []
        for name in os.listdir(pred_rad_dir):
            if not self._is_anchor_file(name):
                continue
            target_time = self.get_date_time(name) + lead
            sample = self._complete_sample(
                os.path.join(pred_rad_dir, name),
                os.path.join(pred_sat_dir, target_time.strftime('%Y%m%d%H') + '.npz'),
                os.path.join(GT_rad_dir, target_time.strftime('%Y%m%d%H%M') + '.npz'),
                os.path.join(GT_sat_dir, target_time.strftime('%Y%m%d%H') + '.npz'),
            )
            if sample is not None:
                list_dir.append(sample)
        return list_dir

    def gen_list_img_sat(self, path):
        """Radar input from ``pred_rad`` at time T; satellite input from ``sat`` at T - lead time."""
        pred_rad_dir = os.path.join(path, "pred_rad")
        pred_sat_dir = os.path.join(path, "sat")
        GT_rad_dir = os.path.join(path, "rad")
        GT_sat_dir = os.path.join(path, "sat")
        lead = timedelta(hours=self.hours_predicted)
        list_dir = []
        for name in os.listdir(pred_rad_dir):
            if not self._is_anchor_file(name):
                continue
            past_time = self.get_date_time(name) - lead
            sample = self._complete_sample(
                os.path.join(pred_rad_dir, name),
                os.path.join(pred_sat_dir, past_time.strftime('%Y%m%d%H') + '.npz'),
                os.path.join(GT_rad_dir, name),
                os.path.join(GT_sat_dir, self._hour_stamp(name)),
            )
            if sample is not None:
                list_dir.append(sample)
        return list_dir

    def gen_list_img_full(self, path):
        """Both inputs from ``rad``/``sat`` at time T; both targets at T + lead time."""
        pred_rad_dir = os.path.join(path, "rad")
        pred_sat_dir = os.path.join(path, "sat")
        GT_rad_dir = os.path.join(path, "rad")
        GT_sat_dir = os.path.join(path, "sat")
        lead = timedelta(hours=self.hours_predicted)
        list_dir = []
        for name in os.listdir(pred_rad_dir):
            if not self._is_anchor_file(name):
                continue
            input_time = self.get_date_time(name)
            target_time = input_time + lead
            sample = self._complete_sample(
                os.path.join(pred_rad_dir, name),
                os.path.join(pred_sat_dir, input_time.strftime('%Y%m%d%H') + '.npz'),
                os.path.join(GT_rad_dir, target_time.strftime('%Y%m%d%H%M') + '.npz'),
                os.path.join(GT_sat_dir, target_time.strftime('%Y%m%d%H') + '.npz'),
            )
            if sample is not None:
                list_dir.append(sample)
        return list_dir

    def gen_list_img_time(self, path):
        """Like the "no" layout, plus a short history of ``rad``/``sat`` files per sample.

        Each sample is ``[rad_inputs, sat_inputs, [GT_rad], [GT_sat]]`` where
        ``rad_inputs`` holds four ``rad`` files from 210 to 180 minutes before the
        timestamp followed by the ``pred_rad`` file, and ``sat_inputs`` holds one
        ``sat`` file 180 minutes before followed by the ``pred_sat`` file. Samples
        with any missing file are dropped.
        """
        pred_rad_dir = os.path.join(path, "pred_rad")
        pred_sat_dir = os.path.join(path, "pred_sat")
        GT_rad_dir = os.path.join(path, "rad")
        GT_sat_dir = os.path.join(path, "sat")
        list_dir = []
        for name in os.listdir(pred_rad_dir):
            # Filter first so stray files are skipped instead of failing timestamp parsing.
            if not self._is_anchor_file(name):
                continue
            temp_date = self.get_date_time(name)
            temp = [[], [], [], []]
            for i in range(4):
                temp_path = os.path.join(GT_rad_dir, (temp_date + timedelta(minutes=-210 + i * 10)).strftime('%Y%m%d%H%M') + '.npz')
                if os.path.isfile(temp_path):
                    temp[0].append(temp_path)
            for i in range(1):
                temp_path = os.path.join(GT_sat_dir, (temp_date + timedelta(minutes=-180 + i * 10)).strftime('%Y%m%d%H') + '.npz')
                if os.path.isfile(temp_path):
                    temp[1].append(temp_path)
            temp[0].append(os.path.join(pred_rad_dir, name))
            pred_sat_path = os.path.join(pred_sat_dir, self._hour_stamp(name))
            GT_rad_path = os.path.join(GT_rad_dir, name)
            GT_sat_path = os.path.join(GT_sat_dir, self._hour_stamp(name))
            if os.path.isfile(pred_sat_path):
                temp[1].append(pred_sat_path)
            if os.path.isfile(GT_rad_path):
                temp[2].append(GT_rad_path)
            if os.path.isfile(GT_sat_path):
                temp[3].append(GT_sat_path)
            if [len(slot) for slot in temp] == [5, 2, 1, 1]:
                list_dir.append(temp)
        return list_dir

    def get_date_time(self, name):
        """Parse the leading ``YYYYMMDDHHMM`` characters of ``name`` into a datetime."""
        year = int(name[0:4])
        month = int(name[4:6])
        day = int(name[6:8])
        hour = int(name[8:10])
        minute = int(name[10:12])
        return datetime(year, month, day, hour, minute)
232
+
233
+
234
+
235
+
236
class WeatherForecastDataModule(LightningDataModule):
    """Lightning data module exposing train/val/test DataReader datasets and their loaders."""

    def __init__(
        self,
        dir_data: str,
        batch_size: int,
        hours_predicted: int,
        num_workers: int,
        pin_memory: bool,
        time_points_rad: int,
        time_points_sat: int,
        sat_inp_vars: str,
        sat_out_vars: str,
        sat_size: int,
        rad_inp_vars: str,
        rad_out_vars: str,
        rad_size: int,
        ablation: str,
    ):
        """Store all arguments as hyper-parameters and load the normalisation statistics."""
        super().__init__()
        # Makes every init argument available through ``self.hparams``.
        self.save_hyperparameters(logger=True)
        self.data_train = None
        self.data_test = None
        self.data_val = None
        stats_root = self.hparams.dir_data
        rad_key = self.hparams.rad_inp_vars
        sat_key = self.hparams.sat_inp_vars
        self.rad_mean = np.load(os.path.join(stats_root, 'rad_mean.npz'))[rad_key]
        self.rad_std = np.load(os.path.join(stats_root, 'rad_std.npz'))[rad_key]
        self.sat_mean = np.load(os.path.join(stats_root, 'sat_mean.npz'))[sat_key]
        self.sat_std = np.load(os.path.join(stats_root, 'sat_std.npz'))[sat_key]

    def prepare_data(self):
        """Nothing to download or pre-process."""
        pass

    def _make_reader(self, split):
        """Build a DataReader for ``split`` from the stored hyper-parameters."""
        hp = self.hparams
        return DataReader(
            dir_data=hp.dir_data,
            type_data=split,
            rad_attribute=hp.rad_inp_vars,
            sat_attribute=hp.sat_inp_vars,
            hours_predicted=hp.hours_predicted,
            rad_predicted=hp.rad_out_vars,
            sat_predicted=hp.sat_out_vars,
            time_points_rad=hp.time_points_rad,
            time_points_sat=hp.time_points_sat,
            sat_size=hp.sat_size,
            rad_size=hp.rad_size,
            ablation=hp.ablation,
        )

    def setup(self, stage):
        """Create the train, test and val datasets (in that order), regardless of ``stage``."""
        self.data_train = self._make_reader("train")
        self.data_test = self._make_reader("test")
        self.data_val = self._make_reader("val")

    def _make_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader configured from the stored hyper-parameters."""
        return DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            drop_last=False,
            pin_memory=self.hparams.pin_memory,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        """Shuffled loader over the training dataset."""
        return self._make_loader(self.data_train, shuffle=True)

    def val_dataloader(self):
        """Sequential loader over the validation dataset."""
        return self._make_loader(self.data_val, shuffle=False)

    def test_dataloader(self):
        """Sequential loader over the test dataset."""
        return self._make_loader(self.data_test, shuffle=False)
src/lr_scheduler.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT license.
3
+
4
+ import math
5
+ import warnings
6
+ from typing import List
7
+
8
+ from torch.optim import Optimizer
9
+ from torch.optim.lr_scheduler import _LRScheduler
10
+
11
+
12
class LinearWarmupCosineAnnealingLR(_LRScheduler):
    """Linear warmup from ``warmup_start_lr`` to the base learning rate, then cosine
    annealing from the base learning rate down to ``eta_min``."""

    def __init__(
        self,
        optimizer: Optimizer,
        warmup_epochs: int,
        max_epochs: int,
        warmup_start_lr: float = 0.0,
        eta_min: float = 0.0,
        last_epoch: int = -1,
    ) -> None:
        """
        Args:
            optimizer (Optimizer): Wrapped optimizer.
            warmup_epochs (int): Maximum number of iterations for linear warmup
            max_epochs (int): Maximum number of iterations
            warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0.
            eta_min (float): Minimum learning rate. Default: 0.
            last_epoch (int): The index of last epoch. Default: -1.
        """
        # Set the schedule parameters before delegating to the base constructor,
        # so they are available if it evaluates the schedule during initialisation.
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.warmup_start_lr = warmup_start_lr
        self.eta_min = eta_min

        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        """Compute learning rate using chainable form of the scheduler."""
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.",
                UserWarning,
            )

        epoch = self.last_epoch
        warmup = self.warmup_epochs
        cosine_span = self.max_epochs - warmup
        groups = self.optimizer.param_groups

        # End of warmup: jump exactly onto the base learning rates.
        if epoch == warmup:
            return self.base_lrs
        # Very first step: start from the warmup learning rate.
        if epoch == 0:
            return [self.warmup_start_lr] * len(self.base_lrs)
        # Inside warmup: add one constant increment to the current learning rate.
        if epoch < warmup:
            return [
                group["lr"] + (base_lr - self.warmup_start_lr) / (warmup - 1)
                for base_lr, group in zip(self.base_lrs, groups)
            ]
        # First step of a new cosine cycle after a full period has elapsed.
        if (epoch - 1 - self.max_epochs) % (2 * cosine_span) == 0:
            return [
                group["lr"] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / cosine_span)) / 2
                for base_lr, group in zip(self.base_lrs, groups)
            ]

        # Regular cosine step, expressed as a ratio relative to the previous step.
        current_term = 1 + math.cos(math.pi * (epoch - warmup) / cosine_span)
        previous_term = 1 + math.cos(math.pi * (epoch - warmup - 1) / cosine_span)
        return [
            current_term / previous_term * (group["lr"] - self.eta_min) + self.eta_min
            for group in groups
        ]

    def _get_closed_form_lr(self) -> List[float]:
        """Called when epoch is passed as a param to the `step` function of the scheduler."""
        epoch = self.last_epoch
        warmup = self.warmup_epochs
        if epoch < warmup:
            start = self.warmup_start_lr
            return [
                start + epoch * (base_lr - start) / max(1, warmup - 1)
                for base_lr in self.base_lrs
            ]

        cosine_term = 1 + math.cos(math.pi * (epoch - warmup) / (self.max_epochs - warmup))
        return [
            self.eta_min + 0.5 * (base_lr - self.eta_min) * cosine_term
            for base_lr in self.base_lrs
        ]
src/metric.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
def MSE(pred, GT, lat, clim):
    """Mean squared error over all elements.

    ``lat`` and ``clim`` are accepted but unused, so every metric shares the
    same four-argument call signature.
    """
    residual = pred - GT
    return residual.pow(2).mean()
5
def RMSE(pred, GT, lat, clim):
    """Root of the mean squared error over all elements.

    ``lat`` and ``clim`` are accepted but unused, so every metric shares the
    same four-argument call signature.
    """
    residual = pred - GT
    return residual.pow(2).mean().sqrt()
7
def MAE(pred, GT, lat, clim):
    """Mean absolute error over all elements.

    ``lat`` and ``clim`` are accepted but unused, so every metric shares the
    same four-argument call signature.
    """
    residual = pred - GT
    return residual.abs().mean()
9
def WMSE(pred, y, lat, clim):
    """Latitude-weighted mean squared error.

    Weights are cos(latitude in degrees), normalised to mean 1 and broadcast
    along the second-to-last axis of the squared error. Returns the plain
    integer 0 when ``lat`` is None. ``clim`` is unused.
    """
    if lat is None:
        return 0

    sq_err = (pred - y) ** 2  # assumed [N, C, H, W] per the original note -- TODO confirm

    # Latitude weights, normalised so they average to one.
    cos_lat = np.cos(np.deg2rad(lat))
    weights = torch.from_numpy(cos_lat / cos_lat.mean())
    # Add leading and trailing singleton axes: (H,) -> (1, H, 1).
    weights = weights[None, ..., None].to(dtype=sq_err.dtype, device=sq_err.device)

    return (sq_err * weights).mean()
18
def WRMSE(pred, GT, lat, clim):
    """Latitude-weighted root mean squared error.

    The weighted squared error is averaged over the last two axes, the square
    root is taken per remaining entry, and those roots are averaged. Returns
    the plain integer 0 when ``lat`` is None. ``clim`` is unused.
    """
    if lat is None:
        return 0

    sq_err = (pred - GT) ** 2  # assumed [B, V, H, W] per the original note -- TODO confirm

    # Latitude weights, normalised so they average to one.
    cos_lat = np.cos(np.deg2rad(lat))
    weights = torch.from_numpy(cos_lat / cos_lat.mean())
    # Add leading and trailing singleton axes: (H,) -> (1, H, 1).
    weights = weights[None, ..., None].to(dtype=sq_err.dtype, device=sq_err.device)

    per_field_rmse = (sq_err * weights).mean(dim=(-2, -1)).sqrt()
    return per_field_rmse.mean()
29
def ACC(pred, GT, lat, clim):
    """Latitude-weighted anomaly correlation coefficient.

    Both fields are turned into anomalies by subtracting ``clim`` (moved to
    ``GT``'s device and given a leading axis), then centred by their own
    global mean before the weighted correlation is computed. Returns the
    plain integer 0 when ``lat`` is None.
    """
    if lat is None:
        return 0

    # Latitude weights, normalised so they average to one: (H,) -> (1, H, 1).
    cos_lat = np.cos(np.deg2rad(lat))
    weights = torch.from_numpy(cos_lat / cos_lat.mean())
    weights = weights[None, ..., None].to(dtype=pred.dtype, device=pred.device)

    climatology = clim.to(device=GT.device).unsqueeze(0)
    pred_anom = pred - climatology
    gt_anom = GT - climatology

    pred_centered = pred_anom - pred_anom.mean()
    gt_centered = gt_anom - gt_anom.mean()

    covariance = torch.sum(weights * pred_centered * gt_centered)
    pred_energy = torch.sum(weights * pred_centered**2)
    gt_energy = torch.sum(weights * gt_centered**2)
    return covariance / torch.sqrt(pred_energy * gt_energy)
src/module.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT license.
3
+
4
+ # credits: https://github.com/ashleve/lightning-hydra-template/blob/main/src/models/mnist_module.py
5
+ from typing import Any
6
+ import os
7
+ import numpy as np
8
+ import torch
9
+ from pytorch_lightning import LightningModule
10
+ from torchvision.transforms import transforms
11
+ from lr_scheduler import LinearWarmupCosineAnnealingLR
12
+ from arch import Network
13
+ from metric import (
14
+ MSE,RMSE,MAE,ACC,WMSE,WRMSE
15
+ )
16
class WeatherForecastModule(LightningModule):
    """Lightning module that trains and evaluates a two-output forecasting network.

    Every batch is a 4-tuple ``(inp_rad, inp_sat, out_rad, out_sat)``. The
    network maps the two inputs to two predictions, each scored with MSE; the
    optimised loss is the sum of both. ("rad"/"sat" presumably stand for radar
    and satellite -- TODO confirm.)

    ``test_step`` additionally needs state that is NOT set in ``__init__``.
    Before testing, call ``set_path`` and ``set_size`` first, then ``set_lat``,
    ``set_clim``, ``set_normalize`` and finally ``set_denormalize`` (which
    reads the statistics loaded by ``set_normalize``).

    The LR scheduler is stepped once per optimizer step (``interval="step"``),
    so ``warmup_epochs`` and ``max_epochs`` are counted in steps.

    Args:
        net: Deeplearning model.
        pretrained_path (str, optional): Path to pre-trained checkpoint.
        lr (float, optional): Learning rate.
        beta_1 (float, optional): Beta 1 for AdamW.
        beta_2 (float, optional): Beta 2 for AdamW.
        weight_decay (float, optional): Weight decay for AdamW.
        warmup_epochs (int, optional): Number of warmup scheduler steps.
        max_epochs (int, optional): Total number of scheduler steps.
        warmup_start_lr (float, optional): Starting learning rate for warmup.
        eta_min (float, optional): Minimum learning rate.
    """

    def __init__(
        self,
        net: Network,
        pretrained_path: str = "",
        lr: float = 5e-4,
        beta_1: float = 0.9,
        beta_2: float = 0.99,
        weight_decay: float = 1e-5,
        warmup_epochs: int = 10000,
        max_epochs: int = 200000,
        warmup_start_lr: float = 1e-8,
        eta_min: float = 1e-8,
    ):
        super().__init__()
        # "net" is excluded so the module itself is not stored as a hyperparameter.
        self.save_hyperparameters(logger=True, ignore=["net"])
        self.net = net
        if len(pretrained_path) > 0:
            self.load_pretrained_weights(pretrained_path)

    def load_pretrained_weights(self, pretrained_path):
        """Load a state dict stored at ``pretrained_path`` into ``self.net``."""
        # map_location="cpu" lets a checkpoint saved on a GPU be read on a
        # machine without one; load_state_dict then copies the values into
        # the network's existing parameters on whatever device they live.
        state_dict = torch.load(pretrained_path, map_location="cpu")
        self.net.load_state_dict(state_dict)

    def set_path(self, path):
        """Remember the directory holding the ``*.npy`` / ``*.npz`` statistics files."""
        self.path = path

    def set_size(self, rad_size, sat_size):
        """Store the sizes used for the latitude vectors in ``test_step``/``set_lat``."""
        self.rad_size = rad_size
        self.sat_size = sat_size

    def set_lat(self):
        """Load ``sat_lat.npy`` and keep a slice centred on its middle.

        Requires ``set_path`` and ``set_size`` to have been called.
        NOTE(review): for an odd ``sat_size`` the slice holds ``sat_size - 1``
        entries -- confirm this matches how the data itself is cropped.
        """
        lat = np.load(os.path.join(self.path, 'sat_lat.npy'))
        center = lat.shape[-1] // 2
        half = self.sat_size // 2
        self.sat_lat = lat[center - half:center + half]

    def set_clim(self):
        """Load the climatology arrays used by the ACC metric as tensors."""
        rad_clim = np.load(os.path.join(self.path, 'rad_clim.npz'))['precipitation']
        sat_clim = np.load(os.path.join(self.path, 'sat_clim.npz'))['total_precipitation']
        self.rad_clim = torch.from_numpy(rad_clim)
        self.sat_clim = torch.from_numpy(sat_clim)

    def set_normalize(self):
        """Load the mean/std statistics for both outputs from ``self.path``."""
        self.rad_mean = np.load(os.path.join(self.path, 'rad_mean.npz'))['precipitation']
        self.rad_std = np.load(os.path.join(self.path, 'rad_std.npz'))['precipitation']
        self.sat_mean = np.load(os.path.join(self.path, 'sat_mean.npz'))['total_precipitation']
        self.sat_std = np.load(os.path.join(self.path, 'sat_std.npz'))['total_precipitation']

    def set_denormalize(self):
        """Build transforms that undo ``(x - mean) / std``.

        ``Normalize(-mean/std, 1/std)`` computes ``(x + mean/std) * std``,
        i.e. ``x * std + mean``. Requires ``set_normalize`` to have run.
        """
        self.rad_denormalization = transforms.Normalize(-self.rad_mean / self.rad_std, 1 / self.rad_std)
        self.sat_denormalization = transforms.Normalize(-self.sat_mean / self.sat_std, 1 / self.sat_std)

    def _shared_step(self, batch: Any, stage: str):
        """Run the network on ``batch``, log ``{stage}/rad|sat|mse`` and return the summed MSE."""
        inp_rad, inp_sat, out_rad, out_sat = batch
        # Call the module (not .forward) so registered hooks are honoured.
        pred_rad, pred_sat = self.net(inp_rad, inp_sat)
        criterion = torch.nn.MSELoss()
        loss_rad = criterion(pred_rad, out_rad)
        loss_sat = criterion(pred_sat, out_sat)
        loss_tot = loss_rad + loss_sat
        self.log(f"{stage}/rad", loss_rad, prog_bar=True, logger=True)
        self.log(f"{stage}/sat", loss_sat, prog_bar=True, logger=True)
        self.log(f"{stage}/mse", loss_tot, prog_bar=True, logger=True)
        return loss_tot

    def training_step(self, batch: Any, batch_idx: int):
        """Return the summed MSE (in normalised space) for one training batch."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch: Any, batch_idx: int):
        """Return the summed MSE (in normalised space) for one validation batch."""
        with torch.no_grad():
            return self._shared_step(batch, "val")

    def test_step(self, batch: Any, batch_idx: int):
        """Evaluate one test batch on denormalised values and log every metric.

        Logs ``test/rad``, ``test/sat``, ``test/mse`` plus one entry per
        metric function (``test/rad_<name>`` / ``test/sat_<name>``).
        Returns the summed denormalised MSE.
        """
        inp_rad, inp_sat, out_rad, out_sat = batch
        rad_metric = [MSE, RMSE, ACC, MAE]
        sat_metric = [MSE, WMSE, RMSE, WRMSE, ACC, MAE]

        with torch.no_grad():
            pred_rad, pred_sat = self.net(inp_rad, inp_sat)

            # Denormalise once and reuse for the losses and every metric.
            pred_rad_dn = self.rad_denormalization(pred_rad)
            out_rad_dn = self.rad_denormalization(out_rad)
            pred_sat_dn = self.sat_denormalization(pred_sat)
            out_sat_dn = self.sat_denormalization(out_sat)

            criterion = torch.nn.MSELoss()
            loss_rad = criterion(pred_rad_dn, out_rad_dn)
            loss_sat = criterion(pred_sat_dn, out_sat_dn)
            loss_tot = loss_rad + loss_sat
            self.log("test/rad", loss_rad, prog_bar=True, logger=True)
            self.log("test/sat", loss_sat, prog_bar=True, logger=True)
            self.log("test/mse", loss_tot, prog_bar=True, logger=True)

            # A constant latitude vector normalises to all-ones weights, so
            # the "rad" metrics are effectively unweighted.
            rad_lat = np.ones(self.rad_size)
            for met in rad_metric:
                value = met(pred_rad_dn, out_rad_dn, rad_lat, self.rad_clim)
                self.log(f"test/rad_{met.__name__}", value, prog_bar=True, logger=True)

            for met in sat_metric:
                value = met(pred_sat_dn, out_sat_dn, self.sat_lat, self.sat_clim)
                self.log(f"test/sat_{met.__name__}", value, prog_bar=True, logger=True)
        return loss_tot

    def configure_optimizers(self):
        """Create AdamW (no weight decay on embedding parameters) and the warmup-cosine scheduler."""
        no_decay_tokens = ("var_embed", "pos_embed", "time_pos_embed")
        decay = []
        no_decay = []
        for name, param in self.named_parameters():
            if any(token in name for token in no_decay_tokens):
                no_decay.append(param)
            else:
                decay.append(param)

        betas = (self.hparams.beta_1, self.hparams.beta_2)
        optimizer = torch.optim.AdamW(
            [
                {
                    "params": decay,
                    "lr": self.hparams.lr,
                    "betas": betas,
                    "weight_decay": self.hparams.weight_decay,
                },
                {
                    "params": no_decay,
                    "lr": self.hparams.lr,
                    "betas": betas,
                    "weight_decay": 0,
                },
            ]
        )

        lr_scheduler = LinearWarmupCosineAnnealingLR(
            optimizer,
            self.hparams.warmup_epochs,
            self.hparams.max_epochs,
            self.hparams.warmup_start_lr,
            self.hparams.eta_min,
        )
        # interval="step": the scheduler advances after every optimizer step.
        scheduler = {"scheduler": lr_scheduler, "interval": "step", "frequency": 1}

        return {"optimizer": optimizer, "lr_scheduler": scheduler}
168
+
src/rad_clim.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import math
4
+ import numpy as np
5
+ from copy import deepcopy
6
# Average every field over all archives in path_load and store the result
# (with a leading axis of length 1) as rad_clim.npz in path_save.
path_load = '/data/weather2025/NhaBe/train/rad'
path_save = '/data/weather2025/NhaBe'
num = 0  # number of archives accumulated so far
rad_clim = {}  # field name -> running sum, turned into the mean below
# sorted() fixes the accumulation order so float rounding is reproducible
# from run to run (os.listdir order is arbitrary).
for name in sorted(os.listdir(path_load)):
    # The archive object keeps its file open until closed; the context
    # manager releases the handle after each file instead of relying on GC.
    with np.load(os.path.join(path_load, name)) as archive:
        for field in archive.keys():
            if num == 0:
                rad_clim[field] = archive[field]
            else:
                rad_clim[field] = rad_clim[field] + archive[field]
    num += 1
    print(num, end='\r')
for field in rad_clim:
    rad_clim[field] = rad_clim[field] / num
    rad_clim[field] = np.expand_dims(rad_clim[field], axis=0)
    print(rad_clim[field].shape)
np.savez(os.path.join(path_save, 'rad_clim.npz'), **rad_clim)