File size: 7,169 Bytes
d6d123e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
# --------------------------------------------------------------------------------------------------------------------- #
# This YAML file configures a 6-hourly FuXi model on NSF NCAR HPCs (casper.ucar.edu and derecho.hpc.ucar.edu).
# The FuXi architecture has been modified to reduce the overall model size.
# The model is trained on 6-hourly model-level ERA5 data, with top-of-atmosphere solar irradiance (tsi),
# surface geopotential, and land-sea mask as input-only fields.
# Output variables: model level [U, V, T, Q], single level [SP, t2m], and 500 hPa [U, V, T, Z, Q].
#
# Yingkai Sha
# ksha@ucar.edu
# --------------------------------------------------------------------------------------------------------------------- #
# random seed for the run (presumably used for reproducible initialization/shuffling — confirm)
seed: 1000
# run directory for this experiment (NOTE(review): presumably where weights and logs are written — confirm)
save_loc: "/glade/work/ksha/CREDIT_runs/fuxi_6h/"

data:
    # ---- upper-air (model-level) variables and the files they are read from ----
    variables:
        - 'U'
        - 'V'
        - 'T'
        - 'Q'
    save_loc: '/glade/derecho/scratch/wchapman/SixHourly_y_TOTAL*'

    # ---- single-level variables, including the five 500 hPa fields ----
    surface_variables:
        - 'SP'
        - 't2m'
        - 'V500'
        - 'U500'
        - 'T500'
        - 'Z500'
        - 'Q500'
    save_loc_surface: '/glade/derecho/scratch/wchapman/SixHourly_y_TOTAL*'

    # ---- dynamic forcing: top-of-atmosphere solar irradiance ----
    dynamic_forcing_variables:
        - 'tsi'
    save_loc_dynamic_forcing: '/glade/derecho/scratch/dgagne/credit_solar_6h_0.25deg/*.nc'

    # ---- diagnostic variables: disabled; the 500 hPa fields are listed under surface_variables instead ----
    # diagnostic_variables: ['V500','U500','T500','Z500','Q500']
    # save_loc_diagnostic: '/glade/derecho/scratch/wchapman/SixHourly_y_TOTAL*'

    # ---- static (time-invariant) inputs ----
    static_variables:
        - 'Z_GDS4_SFC'
        - 'LSM'
    save_loc_static: '/glade/derecho/scratch/ksha/CREDIT_data/static_norm_old.nc'

    # ---- normalization statistics (mean and residual std) ----
    mean_path: '/glade/derecho/scratch/ksha/CREDIT_data/mean_6h_1979_2018_16lev_0.25deg.nc'
    std_path: '/glade/derecho/scratch/ksha/CREDIT_data/std_residual_6h_1979_2018_16lev_0.25deg.nc'

    # ---- train / validation year ranges ----
    # NOTE(review): 2018 appears in both lists; the split is only disjoint if the end
    # year is treated as exclusive — confirm against the data loader.
    train_years: [1979, 2018]
    valid_years: [2018, 2019]

    # normalization scheme used by the data workflow
    scaler_type: 'std_new'

    # number of input states per sample (FuXi takes 2)
    history_len: 2
    valid_history_len: 2

    # number of forecast steps that enter the loss:
    # 0 means single-step training/validation, > 0 means multi-step
    forecast_len: 0
    valid_forecast_len: 0

    # True: loss on the final forecast step only; False: loss on every forecast step
    one_shot: True

    # forecast step setting: 6 for this 6-hourly model (1 would be used for an hourly model)
    lead_time_periods: 6

    # skip_periods is not used
    skip_periods: null

    # kept True for compatibility with the old 'std' scaler
    static_first: True
    
trainer:
    # trainer implementation
    type: standard # <---------- change to your type

    # distributed mode and memory-saving options
    mode: fsdp
    cpu_offload: False
    activation_checkpoint: True

    # restore previously saved weights / optimizer / AMP grad scaler / LR-scheduler state
    load_weights: True
    load_optimizer: True
    load_scaler: True
    # correctly spelled key (was the typo "load_sheduler", which the trainer does not read)
    load_scheduler: True

    skip_validation: False
    update_learning_rate: False

    # checkpoint saving
    save_backup_weights: True
    save_best_weights: True

    # optimizer settings; written as 1.0e-03 (with a dot) so YAML 1.1 loaders parse it as a float
    learning_rate: 1.0e-03 # <-- change to your lr
    weight_decay: 0

    train_batch_size: 1
    valid_batch_size: 1

    # NOTE(review): 0 presumably means "iterate over all batches" — confirm with the trainer
    batches_per_epoch: 0
    valid_batches_per_epoch: 0
    stopping_patience: 50

    # epoch bookkeeping: num_epoch is presumably the number of epochs run by this job,
    # while `epochs` (anchored as &epochs) is the total and is reused as the scheduler's T_max
    start_epoch: 0
    num_epoch: 2
    reload_epoch: True
    epochs: &epochs 70

    # cosine-annealing learning-rate schedule over the full `epochs` budget
    use_scheduler: True
    scheduler:
        scheduler_type: 'cosine-annealing'
        T_max: *epochs
        last_epoch: -1

    # Automatic Mixed Precision: disabled
    amp: False

    # loss is rescaled as loss = loss / grad_accum_every
    grad_accum_every: 1
    # gradient clipping (max gradient norm)
    grad_max_norm: 1.0

    # data-loader worker counts for training and validation
    thread_workers: 4
    valid_thread_workers: 0
    
model:
    type: "fuxi"

    # ---- input / output layout ----
    # number of input states (matches data.history_len)
    frames: 2
    # number of latitude grid points
    image_height: 640
    # number of longitude grid points
    image_width: 1280
    # number of levels per upper-air variable
    levels: 16
    # upper-air variable channels (U, V, T, Q)
    channels: 4
    # single-level channels (SP, t2m and the five 500 hPa fields)
    surface_channels: 7
    # input-only channels: 1 dynamic forcing (tsi) + 2 static fields (Z_GDS4_SFC, LSM)
    input_only_channels: 3
    # diagnostic (output-only) channels: none
    output_only_channels: 0

    # ---- patch embedding ----
    # latitude / longitude grid points per 3D patch
    patch_height: 4
    patch_width: 4
    # input states per 3D patch
    frame_patch_size: 2

    # ---- hidden layers (reduced from the stated defaults) ----
    # embedding dimension (default: 1536)
    dim: 1024
    # number of groups (default: 32)
    num_groups: 32
    # attention heads (default: 8)
    num_heads: 8
    # window size (default: 7)
    window_size: 7
    # number of swin transformer blocks (default: 48)
    depth: 16

    # ---- boundary padding ----
    # grid points padded at the 0/360 deg longitude edges
    pad_lon: 80
    # grid points padded at the -90/90 deg latitude edges
    pad_lat: 80

    # enable spectral normalization
    use_spectral_norm: True
    
loss:
    # main training objective
    training_loss: "mse"

    # auxiliary power-spectrum and spectral losses: both disabled
    use_power_loss: False
    use_spectral_loss: False

    # latitude-weighted loss; latitude_weights points to the file the weights are derived from
    use_latitude_weights: True
    latitude_weights: "/glade/u/home/wchapman/MLWPS/DataLoader/LSM_static_variables_ERA5_zhght.nc"

    # per-variable weighting is disabled; the commented block below is kept for reference.
    # NOTE(review): each upper-air list below has 15 entries, but model.levels is 16 —
    # provide one weight per level before re-enabling it.
    use_variable_weights: False
    # variable_weights:
    #     U: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
    #     V: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
    #     T: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
    #     Q: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
    #     SP: 0.1
    #     t2m: 1.0
    #     V500: 0.1
    #     U500: 0.1
    #     T500: 0.1
    #     Z500: 0.1
    #     Q500: 0.1
    
predict:
    forecasts:
        # keep the type as "custom"
        type: "custom"
        # date of the first initialization, where the rollout starts
        start_year: 2020
        start_month: 1
        start_day: 1
        # hour of day of each initialization: 0 = 00Z, 12 = 12Z
        start_hours: [0, 12]
        # number of days to initialize, counted from the start date above;
        # it should be divisible by the number of GPUs (e.g., 384 for a 365-day rollout on 32 GPUs).
        # NOTE(review): the pbs section requests 8 nodes with ngpus 4 (presumably 32 GPUs), and
        # 30 is not a multiple of 32 — confirm the GPU count used for prediction, or round up (e.g., 32).
        duration: 30
        # forecast lead time in days (1 = 24-hour forecast)
        days: 2

    # output directory and the variables written to it
    save_forecast: '/glade/derecho/scratch/ksha/CREDIT/fuxi_6h/'
    save_vars:
        - 'SP'
        - 't2m'
        - 'V500'
        - 'U500'
        - 'T500'
        - 'Z500'
        - 'Q500'

    # low-pass (Laplace) filter: disabled
    use_laplace_filter: False

    # deprecated option, kept for reference:
    # save_format: "nc"
    
pbs:  # derecho
    # conda environment to activate
    conda: '/glade/work/ksha/miniconda3/envs/credit'
    # allocation and job identification
    project: 'NAML0001'
    job_name: 'fuxi_6h'
    # must stay quoted: an unquoted 12:00:00 is a base-60 integer under YAML 1.1
    walltime: '12:00:00'
    # node count and per-node resources (ngpus presumably per node)
    nodes: 8
    ncpus: 64
    ngpus: 4
    mem: "480GB"
    queue: "main"