Upload infer2img.py
Browse files- infer2img.py +196 -0
infer2img.py
ADDED
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import pandas as pd
|
4 |
+
import numpy as np
|
5 |
+
from functools import partial
|
6 |
+
from datetime import datetime, timedelta
|
7 |
+
from pathlib import Path
|
8 |
+
import pickle
|
9 |
+
|
10 |
+
import dask
|
11 |
+
import dask.array as da
|
12 |
+
import cartopy
|
13 |
+
import cartopy.crs as ccrs
|
14 |
+
import xarray as xr
|
15 |
+
import xarray.ufuncs as xu
|
16 |
+
import matplotlib.pyplot as plt
|
17 |
+
|
18 |
+
from model.afnonet import AFNONet # download the model code from https://github.com/HFAiLab/FourCastNet/blob/master/model/afnonet.py
|
19 |
+
|
20 |
+
DATANAMES = ['10m_u_component_of_wind', '10m_v_component_of_wind', '2m_temperature',
|
21 |
+
'geopotential@1000', 'geopotential@50', 'geopotential@500', 'geopotential@850',
|
22 |
+
'mean_sea_level_pressure', 'relative_humidity@500', 'relative_humidity@850',
|
23 |
+
'surface_pressure', 'temperature@500', 'temperature@850', 'total_column_water_vapour',
|
24 |
+
'u_component_of_wind@1000', 'u_component_of_wind@500', 'u_component_of_wind@850',
|
25 |
+
'v_component_of_wind@1000', 'v_component_of_wind@500', 'v_component_of_wind@850',
|
26 |
+
'total_precipitation']
|
27 |
+
DATAMAP = {
|
28 |
+
'geopotential': 'z',
|
29 |
+
'relative_humidity': 'r',
|
30 |
+
'temperature': 't',
|
31 |
+
'u_component_of_wind': 'u',
|
32 |
+
'v_component_of_wind': 'v'
|
33 |
+
}
|
34 |
+
|
35 |
+
|
36 |
+
def load_model():
|
37 |
+
# input size
|
38 |
+
h, w = 720, 1440
|
39 |
+
x_c, y_c, p_c = 20, 20, 1
|
40 |
+
|
41 |
+
backbone_model = AFNONet(img_size=[h, w], in_chans=x_c, out_chans=y_c, norm_layer=partial(nn.LayerNorm, eps=1e-6))
|
42 |
+
ckpt = torch.load('./backbone.pt', map_location="cpu")
|
43 |
+
backbone_model.load_state_dict(ckpt['model'])
|
44 |
+
|
45 |
+
precip_model = AFNONet(img_size=[h, w], in_chans=x_c, out_chans=p_c, norm_layer=partial(nn.LayerNorm, eps=1e-6))
|
46 |
+
ckpt = torch.load('./precipitation.pt', map_location="cpu")
|
47 |
+
precip_model.load_state_dict(ckpt['model'])
|
48 |
+
|
49 |
+
|
50 |
+
def imcol(data, img_path, img_name, **kwargs):
|
51 |
+
fig = plt.figure(figsize=(20, 10))
|
52 |
+
ax = plt.axes(projection=ccrs.PlateCarree())
|
53 |
+
|
54 |
+
I = data.plot(ax=ax, transform=ccrs.PlateCarree(), add_colorbar=False, add_labels=False, rasterized=True, **kwargs)
|
55 |
+
ax.coastlines(resolution='110m')
|
56 |
+
|
57 |
+
dirname = f'{img_path.absolute()}/{img_name}.jpg'
|
58 |
+
|
59 |
+
plt.axis('off')
|
60 |
+
plt.savefig(dirname, bbox_inches='tight', pad_inches=0.)
|
61 |
+
plt.close(fig)
|
62 |
+
|
63 |
+
|
64 |
+
def plot(real_data, pred_data, save_path):
|
65 |
+
cmap_t = 'RdYlBu_r'
|
66 |
+
|
67 |
+
wind = xu.sqrt(real_data['u10'] ** 2 + real_data['v10'] ** 2)
|
68 |
+
wmin, wmax = wind.values.min(), wind.values.max()
|
69 |
+
wind = xu.sqrt(pred_data['u10'] ** 2 + pred_data['v10'] ** 2)
|
70 |
+
wmin, wmax = min(wind.values.min(), wmin), max(wind.values.max(), wmax)
|
71 |
+
|
72 |
+
pmin, pmax = real_data['tp'].values.min(), real_data['tp'].values.max()
|
73 |
+
pmin, pmax = min(pred_data['tp'].values.min(), pmin), max(pred_data['tp'].values.max(), pmax)
|
74 |
+
|
75 |
+
tmin, tmax = real_data['t2m'].values.min(), real_data['t2m'].values.max()
|
76 |
+
tmin, tmax = min(pred_data['t2m'].values.min(), tmin), max(pred_data['t2m'].values.max(), tmax)
|
77 |
+
|
78 |
+
for i in range(len(real_data.time)):
|
79 |
+
u = real_data['u10'].isel(time=i)
|
80 |
+
v = real_data['v10'].isel(time=i)
|
81 |
+
wind = xu.sqrt(u ** 2 + v ** 2)
|
82 |
+
precip = real_data['tp'].isel(time=i)
|
83 |
+
temp = real_data['t2m'].isel(time=i)
|
84 |
+
|
85 |
+
datetime = pd.to_datetime(str(wind['time'].values))
|
86 |
+
datetime = datetime.strftime('%Y-%m-%d %H:%M:%S')
|
87 |
+
print(f'plot {datetime}')
|
88 |
+
|
89 |
+
imcol(wind, save_path, img_name=f'wind_{datetime}_real', cmap=cmap_t, vmin=wmin, vmax=wmax),
|
90 |
+
imcol(precip, save_path, img_name=f'precipitation_{datetime}_real', cmap=cmap_t, vmin=pmin, vmax=pmax),
|
91 |
+
imcol(temp, save_path, img_name=f'temperature_{datetime}_real', cmap=cmap_t, vmin=tmin, vmax=tmax)
|
92 |
+
|
93 |
+
for i in range(len(pred_data.time)):
|
94 |
+
u = pred_data['u10'].isel(time=i)
|
95 |
+
v = pred_data['v10'].isel(time=i)
|
96 |
+
wind = xu.sqrt(u ** 2 + v ** 2)
|
97 |
+
precip = pred_data['tp'].isel(time=i)
|
98 |
+
temp = pred_data['t2m'].isel(time=i)
|
99 |
+
|
100 |
+
datetime = pd.to_datetime(str(wind['time'].values))
|
101 |
+
datetime = datetime.strftime('%Y-%m-%d %H:%M:%S')
|
102 |
+
print(f'plot {datetime}')
|
103 |
+
|
104 |
+
imcol(wind, save_path, img_name=f'wind_{datetime}_pred', cmap=cmap_t, vmin=wmin, vmax=wmax),
|
105 |
+
imcol(precip, save_path, img_name=f'precipitation_{datetime}_pred', cmap=cmap_t, vmin=pmin, vmax=pmax),
|
106 |
+
imcol(temp, save_path, img_name=f'temperature_{datetime}_pred', cmap=cmap_t, vmin=tmin, vmax=tmax)
|
107 |
+
|
108 |
+
|
109 |
+
def get_pred(sample, scaler, times=None, latitude=None, longitude=None):
|
110 |
+
|
111 |
+
backbone_model, precip_model = load_model()
|
112 |
+
|
113 |
+
sample = torch.from_numpy(sample[0])
|
114 |
+
sample = sample.float()
|
115 |
+
|
116 |
+
backbone_model.eval()
|
117 |
+
precip_model.eval()
|
118 |
+
pred = []
|
119 |
+
x = sample.unsqueeze(0).transpose(3, 2).transpose(2, 1)
|
120 |
+
for i in range(len(times)):
|
121 |
+
print(f"predict {times[i]}")
|
122 |
+
|
123 |
+
with torch.cuda.amp.autocast():
|
124 |
+
x = backbone_model(x)
|
125 |
+
tmp = x.transpose(1, 2).transpose(2, 3)
|
126 |
+
p = precip_model(x)
|
127 |
+
|
128 |
+
tmp = tmp.detach().numpy()[0, :, :, :3] * scaler['std'][:3] + scaler['mean'][:3]
|
129 |
+
p = p.detach().numpy()[0, 0, :, :, np.newaxis] * scaler['std'][-1] + scaler['mean'][-1]
|
130 |
+
tmp = np.concatenate([tmp, p], axis=-1)
|
131 |
+
pred.append(tmp)
|
132 |
+
|
133 |
+
pred = np.asarray(pred)
|
134 |
+
pred_data = xr.Dataset({
|
135 |
+
'u10': (['time', 'latitude', 'longitude'], da.from_array(pred[:, :, :, 0], chunks=(7, 720, 1440))),
|
136 |
+
'v10': (['time', 'latitude', 'longitude'], da.from_array(pred[:, :, :, 1], chunks=(7, 720, 1440))),
|
137 |
+
't2m': (['time', 'latitude', 'longitude'], da.from_array(pred[:, :, :, 2], chunks=(7, 720, 1440))),
|
138 |
+
'tp': (['time', 'latitude', 'longitude'], da.from_array(pred[:, :, :, 3], chunks=(7, 720, 1440))),
|
139 |
+
},
|
140 |
+
coords={'time': (['time'], times),
|
141 |
+
'latitude': (['latitude'], latitude),
|
142 |
+
'longitude': (['longitude'], longitude)
|
143 |
+
}
|
144 |
+
)
|
145 |
+
|
146 |
+
return pred_data
|
147 |
+
|
148 |
+
|
149 |
+
def get_data(start_time, end_time):
|
150 |
+
times = slice(start_time, end_time)
|
151 |
+
|
152 |
+
with open(f'./scaler.pkl', "rb") as f: # the mean and std of each atmospheric variables
|
153 |
+
scaler = pickle.load(f)
|
154 |
+
|
155 |
+
# load weather data
|
156 |
+
datas = []
|
157 |
+
for file in DATANAMES:
|
158 |
+
tmp = xr.open_mfdataset(f'./ERA5_rawdata/{file}/*.nc', combine='by_coords').sel(time=times)
|
159 |
+
if '@' in file:
|
160 |
+
k, v = file.split('@')
|
161 |
+
tmp = tmp.rename_vars({DATAMAP[k]: f'{DATAMAP[k]}@{v}'})
|
162 |
+
datas.append(tmp)
|
163 |
+
with dask.config.set(**{'array.slicing.split_large_chunks': False}):
|
164 |
+
raw_data = xr.merge(datas, compat="identical", join="inner")
|
165 |
+
|
166 |
+
data = []
|
167 |
+
for name in ['u10', 'v10', 't2m', 'z@1000', 'z@50', 'z@500', 'z@850', 'msl', 'r@500', 'r@850', 'sp', 't@500', 't@850', 'tcwv', 'u@1000', 'u@500', 'u@850', 'v@1000', 'v@500', 'v@850']:
|
168 |
+
raw = raw_data[name].values
|
169 |
+
data.append(raw)
|
170 |
+
|
171 |
+
data = np.stack(data, axis=-1)
|
172 |
+
data = (data - scaler['mean']) / scaler['std']
|
173 |
+
data = data[:, 1:, :, :] # 721*1440 -> 720*1440
|
174 |
+
|
175 |
+
return raw_data[['u10', 'v10', 't2m', 'tp']].sel(expver=1), data, scaler
|
176 |
+
|
177 |
+
|
178 |
+
|
179 |
+
if __name__ == '__main__':
|
180 |
+
|
181 |
+
start_time = datetime(2023, 1, 1, 0, 0)
|
182 |
+
end_time = datetime(2023, 1, 5, 18, 0)
|
183 |
+
num = int((end_time - start_time) / timedelta(hours=6))
|
184 |
+
|
185 |
+
print(f"start_time: {start_time}, end_time: {end_time}, pred_num: {num}")
|
186 |
+
|
187 |
+
real_data, sample, scaler = get_data(start_time)
|
188 |
+
print(sample.shape)
|
189 |
+
|
190 |
+
pred_times = [start_time + timedelta(hours=6) * i for i in range(1, num)]
|
191 |
+
pred = get_pred(sample, scaler=scaler, times=pred_times, latitude=real_data.latitude[1:], longitude=real_data.longitude)
|
192 |
+
|
193 |
+
save_path = Path(f"./output/")
|
194 |
+
save_path.mkdir(parents=True, exist_ok=True)
|
195 |
+
|
196 |
+
plot(real_data, pred, save_path)
|