VachelHu commited on
Commit
b310f83
1 Parent(s): 301b050

Upload infer2img.py

Browse files
Files changed (1) hide show
  1. infer2img.py +196 -0
infer2img.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import pandas as pd
4
+ import numpy as np
5
+ from functools import partial
6
+ from datetime import datetime, timedelta
7
+ from pathlib import Path
8
+ import pickle
9
+
10
+ import dask
11
+ import dask.array as da
12
+ import cartopy
13
+ import cartopy.crs as ccrs
14
+ import xarray as xr
15
+ import xarray.ufuncs as xu
16
+ import matplotlib.pyplot as plt
17
+
18
+ from model.afnonet import AFNONet # download the model code from https://github.com/HFAiLab/FourCastNet/blob/master/model/afnonet.py
19
+
20
+ DATANAMES = ['10m_u_component_of_wind', '10m_v_component_of_wind', '2m_temperature',
21
+ 'geopotential@1000', 'geopotential@50', 'geopotential@500', 'geopotential@850',
22
+ 'mean_sea_level_pressure', 'relative_humidity@500', 'relative_humidity@850',
23
+ 'surface_pressure', 'temperature@500', 'temperature@850', 'total_column_water_vapour',
24
+ 'u_component_of_wind@1000', 'u_component_of_wind@500', 'u_component_of_wind@850',
25
+ 'v_component_of_wind@1000', 'v_component_of_wind@500', 'v_component_of_wind@850',
26
+ 'total_precipitation']
27
+ DATAMAP = {
28
+ 'geopotential': 'z',
29
+ 'relative_humidity': 'r',
30
+ 'temperature': 't',
31
+ 'u_component_of_wind': 'u',
32
+ 'v_component_of_wind': 'v'
33
+ }
34
+
35
+
36
+ def load_model():
37
+ # input size
38
+ h, w = 720, 1440
39
+ x_c, y_c, p_c = 20, 20, 1
40
+
41
+ backbone_model = AFNONet(img_size=[h, w], in_chans=x_c, out_chans=y_c, norm_layer=partial(nn.LayerNorm, eps=1e-6))
42
+ ckpt = torch.load('./backbone.pt', map_location="cpu")
43
+ backbone_model.load_state_dict(ckpt['model'])
44
+
45
+ precip_model = AFNONet(img_size=[h, w], in_chans=x_c, out_chans=p_c, norm_layer=partial(nn.LayerNorm, eps=1e-6))
46
+ ckpt = torch.load('./precipitation.pt', map_location="cpu")
47
+ precip_model.load_state_dict(ckpt['model'])
48
+
49
+
50
+ def imcol(data, img_path, img_name, **kwargs):
51
+ fig = plt.figure(figsize=(20, 10))
52
+ ax = plt.axes(projection=ccrs.PlateCarree())
53
+
54
+ I = data.plot(ax=ax, transform=ccrs.PlateCarree(), add_colorbar=False, add_labels=False, rasterized=True, **kwargs)
55
+ ax.coastlines(resolution='110m')
56
+
57
+ dirname = f'{img_path.absolute()}/{img_name}.jpg'
58
+
59
+ plt.axis('off')
60
+ plt.savefig(dirname, bbox_inches='tight', pad_inches=0.)
61
+ plt.close(fig)
62
+
63
+
64
+ def plot(real_data, pred_data, save_path):
65
+ cmap_t = 'RdYlBu_r'
66
+
67
+ wind = xu.sqrt(real_data['u10'] ** 2 + real_data['v10'] ** 2)
68
+ wmin, wmax = wind.values.min(), wind.values.max()
69
+ wind = xu.sqrt(pred_data['u10'] ** 2 + pred_data['v10'] ** 2)
70
+ wmin, wmax = min(wind.values.min(), wmin), max(wind.values.max(), wmax)
71
+
72
+ pmin, pmax = real_data['tp'].values.min(), real_data['tp'].values.max()
73
+ pmin, pmax = min(pred_data['tp'].values.min(), pmin), max(pred_data['tp'].values.max(), pmax)
74
+
75
+ tmin, tmax = real_data['t2m'].values.min(), real_data['t2m'].values.max()
76
+ tmin, tmax = min(pred_data['t2m'].values.min(), tmin), max(pred_data['t2m'].values.max(), tmax)
77
+
78
+ for i in range(len(real_data.time)):
79
+ u = real_data['u10'].isel(time=i)
80
+ v = real_data['v10'].isel(time=i)
81
+ wind = xu.sqrt(u ** 2 + v ** 2)
82
+ precip = real_data['tp'].isel(time=i)
83
+ temp = real_data['t2m'].isel(time=i)
84
+
85
+ datetime = pd.to_datetime(str(wind['time'].values))
86
+ datetime = datetime.strftime('%Y-%m-%d %H:%M:%S')
87
+ print(f'plot {datetime}')
88
+
89
+ imcol(wind, save_path, img_name=f'wind_{datetime}_real', cmap=cmap_t, vmin=wmin, vmax=wmax),
90
+ imcol(precip, save_path, img_name=f'precipitation_{datetime}_real', cmap=cmap_t, vmin=pmin, vmax=pmax),
91
+ imcol(temp, save_path, img_name=f'temperature_{datetime}_real', cmap=cmap_t, vmin=tmin, vmax=tmax)
92
+
93
+ for i in range(len(pred_data.time)):
94
+ u = pred_data['u10'].isel(time=i)
95
+ v = pred_data['v10'].isel(time=i)
96
+ wind = xu.sqrt(u ** 2 + v ** 2)
97
+ precip = pred_data['tp'].isel(time=i)
98
+ temp = pred_data['t2m'].isel(time=i)
99
+
100
+ datetime = pd.to_datetime(str(wind['time'].values))
101
+ datetime = datetime.strftime('%Y-%m-%d %H:%M:%S')
102
+ print(f'plot {datetime}')
103
+
104
+ imcol(wind, save_path, img_name=f'wind_{datetime}_pred', cmap=cmap_t, vmin=wmin, vmax=wmax),
105
+ imcol(precip, save_path, img_name=f'precipitation_{datetime}_pred', cmap=cmap_t, vmin=pmin, vmax=pmax),
106
+ imcol(temp, save_path, img_name=f'temperature_{datetime}_pred', cmap=cmap_t, vmin=tmin, vmax=tmax)
107
+
108
+
109
+ def get_pred(sample, scaler, times=None, latitude=None, longitude=None):
110
+
111
+ backbone_model, precip_model = load_model()
112
+
113
+ sample = torch.from_numpy(sample[0])
114
+ sample = sample.float()
115
+
116
+ backbone_model.eval()
117
+ precip_model.eval()
118
+ pred = []
119
+ x = sample.unsqueeze(0).transpose(3, 2).transpose(2, 1)
120
+ for i in range(len(times)):
121
+ print(f"predict {times[i]}")
122
+
123
+ with torch.cuda.amp.autocast():
124
+ x = backbone_model(x)
125
+ tmp = x.transpose(1, 2).transpose(2, 3)
126
+ p = precip_model(x)
127
+
128
+ tmp = tmp.detach().numpy()[0, :, :, :3] * scaler['std'][:3] + scaler['mean'][:3]
129
+ p = p.detach().numpy()[0, 0, :, :, np.newaxis] * scaler['std'][-1] + scaler['mean'][-1]
130
+ tmp = np.concatenate([tmp, p], axis=-1)
131
+ pred.append(tmp)
132
+
133
+ pred = np.asarray(pred)
134
+ pred_data = xr.Dataset({
135
+ 'u10': (['time', 'latitude', 'longitude'], da.from_array(pred[:, :, :, 0], chunks=(7, 720, 1440))),
136
+ 'v10': (['time', 'latitude', 'longitude'], da.from_array(pred[:, :, :, 1], chunks=(7, 720, 1440))),
137
+ 't2m': (['time', 'latitude', 'longitude'], da.from_array(pred[:, :, :, 2], chunks=(7, 720, 1440))),
138
+ 'tp': (['time', 'latitude', 'longitude'], da.from_array(pred[:, :, :, 3], chunks=(7, 720, 1440))),
139
+ },
140
+ coords={'time': (['time'], times),
141
+ 'latitude': (['latitude'], latitude),
142
+ 'longitude': (['longitude'], longitude)
143
+ }
144
+ )
145
+
146
+ return pred_data
147
+
148
+
149
+ def get_data(start_time, end_time):
150
+ times = slice(start_time, end_time)
151
+
152
+ with open(f'./scaler.pkl', "rb") as f: # the mean and std of each atmospheric variables
153
+ scaler = pickle.load(f)
154
+
155
+ # load weather data
156
+ datas = []
157
+ for file in DATANAMES:
158
+ tmp = xr.open_mfdataset(f'./ERA5_rawdata/{file}/*.nc', combine='by_coords').sel(time=times)
159
+ if '@' in file:
160
+ k, v = file.split('@')
161
+ tmp = tmp.rename_vars({DATAMAP[k]: f'{DATAMAP[k]}@{v}'})
162
+ datas.append(tmp)
163
+ with dask.config.set(**{'array.slicing.split_large_chunks': False}):
164
+ raw_data = xr.merge(datas, compat="identical", join="inner")
165
+
166
+ data = []
167
+ for name in ['u10', 'v10', 't2m', 'z@1000', 'z@50', 'z@500', 'z@850', 'msl', 'r@500', 'r@850', 'sp', 't@500', 't@850', 'tcwv', 'u@1000', 'u@500', 'u@850', 'v@1000', 'v@500', 'v@850']:
168
+ raw = raw_data[name].values
169
+ data.append(raw)
170
+
171
+ data = np.stack(data, axis=-1)
172
+ data = (data - scaler['mean']) / scaler['std']
173
+ data = data[:, 1:, :, :] # 721*1440 -> 720*1440
174
+
175
+ return raw_data[['u10', 'v10', 't2m', 'tp']].sel(expver=1), data, scaler
176
+
177
+
178
+
179
+ if __name__ == '__main__':
180
+
181
+ start_time = datetime(2023, 1, 1, 0, 0)
182
+ end_time = datetime(2023, 1, 5, 18, 0)
183
+ num = int((end_time - start_time) / timedelta(hours=6))
184
+
185
+ print(f"start_time: {start_time}, end_time: {end_time}, pred_num: {num}")
186
+
187
+ real_data, sample, scaler = get_data(start_time)
188
+ print(sample.shape)
189
+
190
+ pred_times = [start_time + timedelta(hours=6) * i for i in range(1, num)]
191
+ pred = get_pred(sample, scaler=scaler, times=pred_times, latitude=real_data.latitude[1:], longitude=real_data.longitude)
192
+
193
+ save_path = Path(f"./output/")
194
+ save_path.mkdir(parents=True, exist_ok=True)
195
+
196
+ plot(real_data, pred, save_path)