File size: 7,169 Bytes
d6d123e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 |
# --------------------------------------------------------------------------------------------------------------------- #
# This YAML file configures a 6-hourly FuXi model on NSF NCAR HPCs (casper.ucar.edu and derecho.hpc.ucar.edu)
# The FuXi architecture has been modified to reduce the overall model size
# The model is trained on 6-hourly model-level ERA5 data with top-of-atmosphere solar irradiance, surface geopotential, and land-sea mask inputs
# Output variables: model level [U, V, T, Q], single level [SP, t2m], and 500 hPa [U, V, T, Z, Q]
#
# Yingkai Sha
# ksha@ucar.edu
# --------------------------------------------------------------------------------------------------------------------- #
# Experiment output directory for this run
# (distinct from data.save_loc, which is a glob over the input data files).
save_loc: "/glade/work/ksha/CREDIT_runs/fuxi_6h/"
# Random seed for the run.
seed: 1000
data:
  # upper-air (model-level) variables
  variables: ['U', 'V', 'T', 'Q']
  save_loc: '/glade/derecho/scratch/wchapman/SixHourly_y_TOTAL*'

  # surface variables (the 500 hPa fields are treated as surface channels here)
  surface_variables: ['SP', 't2m', 'V500', 'U500', 'T500', 'Z500', 'Q500']
  save_loc_surface: '/glade/derecho/scratch/wchapman/SixHourly_y_TOTAL*'

  # dynamic forcing variables
  dynamic_forcing_variables: ['tsi']
  save_loc_dynamic_forcing: '/glade/derecho/scratch/dgagne/credit_solar_6h_0.25deg/*.nc'

  # diagnostic variables (disabled; the 500 hPa fields are in surface_variables instead)
  # diagnostic_variables: ['V500', 'U500', 'T500', 'Z500', 'Q500']
  # save_loc_diagnostic: '/glade/derecho/scratch/wchapman/SixHourly_y_TOTAL*'

  # static variables
  static_variables: ['Z_GDS4_SFC', 'LSM']
  save_loc_static: '/glade/derecho/scratch/ksha/CREDIT_data/static_norm_old.nc'

  # mean / std normalization files
  mean_path: '/glade/derecho/scratch/ksha/CREDIT_data/mean_6h_1979_2018_16lev_0.25deg.nc'
  std_path: '/glade/derecho/scratch/ksha/CREDIT_data/std_residual_6h_1979_2018_16lev_0.25deg.nc'

  # train / validation split
  # NOTE(review): 2018 appears in both ranges. This is only non-overlapping if the
  # consumer treats the upper bound as exclusive (Python range semantics) — confirm.
  train_years: [1979, 2018]
  valid_years: [2018, 2019]

  # data workflow
  scaler_type: 'std_new'

  # number of input states (FuXi has 2 input states)
  history_len: 2
  valid_history_len: 2

  # number of forecast steps to compute loss:
  #   0 for single-step training / validation
  #   larger than 0 for multi-step training / validation
  forecast_len: 0
  valid_forecast_len: 0

  # one_shot: true  --> compute loss on the last forecast step only
  # one_shot: false --> compute loss on all forecast steps
  one_shot: true

  # time step in hours: 6 for this 6-hourly model (1 would be an hourly model)
  lead_time_periods: 6

  # do not use skip_period
  skip_periods: null

  # compatible with the old 'std' scaler
  static_first: true
trainer:
  type: standard  # <---------- change to your type
  mode: fsdp
  cpu_offload: false
  activation_checkpoint: true

  # checkpoint reloading
  load_weights: true
  load_optimizer: true
  load_scaler: true
  # NOTE(review): `load_sheduler` is misspelled. It is kept in case the training
  # code reads this exact key, and `load_scheduler` is added with the same value
  # so that a reader of the correctly spelled key also restores the scheduler.
  # Confirm which key the trainer reads and drop the other.
  load_sheduler: true
  load_scheduler: true

  skip_validation: false
  update_learning_rate: false
  save_backup_weights: true
  save_best_weights: true

  # optimizer
  learning_rate: 1.0e-03  # <-- change to your lr
  weight_decay: 0

  # batching
  train_batch_size: 1
  valid_batch_size: 1
  batches_per_epoch: 0
  valid_batches_per_epoch: 0

  # epochs
  stopping_patience: 50
  start_epoch: 0
  # NOTE(review): presumably the number of epochs run per job submission, while
  # `epochs` below is the overall total (it also sets the scheduler T_max) — confirm.
  num_epoch: 2
  reload_epoch: true
  epochs: &epochs 70

  # learning-rate scheduler; T_max is tied to `epochs` via the YAML alias
  use_scheduler: true
  scheduler:
    scheduler_type: 'cosine-annealing'
    T_max: *epochs
    last_epoch: -1

  # Automatic Mixed Precision: off
  amp: false

  # rescale loss as loss = loss / grad_accum_every
  grad_accum_every: 1

  # gradient clipping
  grad_max_norm: 1.0

  # number of data-loading workers
  thread_workers: 4
  valid_thread_workers: 0
model:
  type: 'fuxi'

  # --- input / output geometry ---
  # number of input states (equal to data.history_len in this file)
  frames: 2
  # number of latitude grid points
  image_height: 640
  # number of longitude grid points
  image_width: 1280
  # number of upper-air variable levels
  levels: 16
  # upper-air variable channels (U, V, T, Q)
  channels: 4
  # surface variable channels (one per entry in data.surface_variables)
  surface_channels: 7
  # input-only channels: 1 dynamic forcing (tsi) + 2 static (Z_GDS4_SFC, LSM)
  input_only_channels: 3
  # diagnostic variable channels
  output_only_channels: 0

  # --- patchify layer ---
  # latitude grid points in each 3D patch
  patch_height: 4
  # longitude grid points in each 3D patch
  patch_width: 4
  # input states in each 3D patch
  frame_patch_size: 2

  # --- hidden layers (reduced from the stated defaults) ---
  # embedding dimension (default: 1536)
  dim: 1024
  # number of groups (default: 32)
  num_groups: 32
  # number of attention heads (default: 8)
  num_heads: 8
  # attention window size (default: 7)
  window_size: 7
  # number of swin transformer blocks (default: 48)
  depth: 16

  # --- map boundary padding ---
  # grid points padded at the 0 / 360 deg longitude edges
  pad_lon: 80
  # grid points padded at the -90 / 90 deg latitude edges
  pad_lat: 80

  # spectral normalization on
  use_spectral_norm: true
loss:
  # the main training loss
  training_loss: 'mse'

  # power loss and spectral loss are both disabled
  use_power_loss: false
  use_spectral_loss: false

  # latitude weighting
  use_latitude_weights: true
  latitude_weights: '/glade/u/home/wchapman/MLWPS/DataLoader/LSM_static_variables_ERA5_zhght.nc'

  # variable weighting is turned off
  use_variable_weights: false
  # NOTE(review): the example below lists 15 weights per upper-air variable, but
  # model.levels is 16. Presumably one weight per level is expected, so add a
  # 16th value to each list before enabling it — confirm.
  # variable_weights:
  #   U: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
  #   V: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
  #   T: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
  #   Q: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
  #   SP: 0.1
  #   t2m: 1.0
  #   V500: 0.1
  #   U500: 0.1
  #   T500: 0.1
  #   Z500: 0.1
  #   Q500: 0.1
predict:
  forecasts:
    type: 'custom'  # keep it as "custom"
    start_year: 2020  # year of the first initialization (where rollout will start)
    start_month: 1  # month of the first initialization
    start_day: 1  # day of the first initialization
    start_hours: [0, 12]  # hour-of-day for each initialization, 0 for 00Z, 12 for 12Z
    duration: 30  # number of days to initialize, starting from the (year, mon, day) above
    # duration should be divisible by the number of GPUs
    # (e.g., duration: 384 for a 365-day rollout using 32 GPUs)
    # NOTE(review): 30 is divisible neither by the 32 GPUs requested in the pbs
    # section (8 nodes x 4 GPUs) nor by the 4 GPUs of a single node. Confirm the
    # GPU count used for prediction and adjust duration accordingly.
    days: 2  # forecast lead time in days (1 means a 24-hour forecast)
  save_forecast: '/glade/derecho/scratch/ksha/CREDIT/fuxi_6h/'
  save_vars: ['SP', 't2m', 'V500', 'U500', 'T500', 'Z500', 'Q500']
  # low-pass filter is turned off
  use_laplace_filter: false
  # deprecated
  # save_format: "nc"
pbs:  # derecho
  # environment and accounting
  conda: '/glade/work/ksha/miniconda3/envs/credit'
  project: 'NAML0001'
  job_name: 'fuxi_6h'
  # keep quoted: an unquoted HH:MM:SS is parsed as a base-60 integer by YAML 1.1 loaders
  walltime: '12:00:00'
  # resources
  nodes: 8
  ncpus: 64
  ngpus: 4
  mem: '480GB'
  queue: 'main'
|