OpenCastKit / infer2img.py
VachelHu's picture
Upload infer2img.py
b310f83
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from functools import partial
from datetime import datetime, timedelta
from pathlib import Path
import pickle
import dask
import dask.array as da
import cartopy
import cartopy.crs as ccrs
import xarray as xr
import xarray.ufuncs as xu
import matplotlib.pyplot as plt
from model.afnonet import AFNONet # download the model code from https://github.com/HFAiLab/FourCastNet/blob/master/model/afnonet.py
# Sub-directory names under ./ERA5_rawdata/, one per variable; get_data()
# iterates over this list in order. Pressure-level variables carry an
# "@<level>" suffix.
DATANAMES = [
    '10m_u_component_of_wind',
    '10m_v_component_of_wind',
    '2m_temperature',
    'geopotential@1000',
    'geopotential@50',
    'geopotential@500',
    'geopotential@850',
    'mean_sea_level_pressure',
    'relative_humidity@500',
    'relative_humidity@850',
    'surface_pressure',
    'temperature@500',
    'temperature@850',
    'total_column_water_vapour',
    'u_component_of_wind@1000',
    'u_component_of_wind@500',
    'u_component_of_wind@850',
    'v_component_of_wind@1000',
    'v_component_of_wind@500',
    'v_component_of_wind@850',
    'total_precipitation',
]

# Long variable name (the part of a DATANAMES entry before '@') -> short
# variable name, which get_data() renames to "<short>@<level>".
DATAMAP = dict(
    geopotential='z',
    relative_humidity='r',
    temperature='t',
    u_component_of_wind='u',
    v_component_of_wind='v',
)
def load_model(backbone_path='./backbone.pt', precip_path='./precipitation.pt'):
    """Build the AFNONet backbone and precipitation networks and load their weights.

    Both networks take a 20-channel 720x1440 input. The backbone emits 20
    channels; the precipitation network emits a single channel.

    Args:
        backbone_path: checkpoint file for the backbone model.
        precip_path: checkpoint file for the precipitation model.

    Returns:
        ``(backbone_model, precip_model)``, with weights mapped to CPU.
    """
    h, w = 720, 1440           # input grid size
    x_c, y_c, p_c = 20, 20, 1  # input, backbone-output, precip-output channels
    norm_layer = partial(nn.LayerNorm, eps=1e-6)

    # NOTE: torch.load unpickles the checkpoint; only load trusted files.
    backbone_model = AFNONet(img_size=[h, w], in_chans=x_c, out_chans=y_c, norm_layer=norm_layer)
    ckpt = torch.load(backbone_path, map_location="cpu")
    backbone_model.load_state_dict(ckpt['model'])

    precip_model = AFNONet(img_size=[h, w], in_chans=x_c, out_chans=p_c, norm_layer=norm_layer)
    ckpt = torch.load(precip_path, map_location="cpu")
    precip_model.load_state_dict(ckpt['model'])

    # Callers unpack both models; previously nothing was returned.
    return backbone_model, precip_model
def imcol(data, img_path, img_name, **kwargs):
    """Draw one field on a PlateCarree map with coastlines and save it as a JPEG.

    Args:
        data: object with an xarray-style ``.plot`` method; presumably a 2-D
            (latitude, longitude) DataArray -- TODO confirm.
        img_path: output directory (``pathlib.Path`` or str).
        img_name: file name without extension; ``.jpg`` is appended. Avoid
            characters that are invalid in file names on the target OS
            (e.g. ':' on Windows).
        **kwargs: forwarded to ``data.plot`` (e.g. ``cmap``, ``vmin``, ``vmax``).
    """
    fig = plt.figure(figsize=(20, 10))
    try:
        ax = plt.axes(projection=ccrs.PlateCarree())
        data.plot(ax=ax, transform=ccrs.PlateCarree(), add_colorbar=False,
                  add_labels=False, rasterized=True, **kwargs)
        ax.coastlines(resolution='110m')
        ax.axis('off')
        out_file = Path(img_path).absolute() / f'{img_name}.jpg'
        fig.savefig(out_file, bbox_inches='tight', pad_inches=0.)
    finally:
        # Release the figure even if plotting/saving fails; otherwise it stays
        # registered in pyplot and leaks memory across many calls.
        plt.close(fig)
def _wind_speed(u, v):
    """Return the horizontal wind speed sqrt(u**2 + v**2) (works on DataArrays)."""
    # np.sqrt dispatches through xarray's __array_ufunc__; xarray.ufuncs was
    # removed from recent xarray releases.
    return np.sqrt(u ** 2 + v ** 2)


def _joint_range(*arrays):
    """Return ``(min, max)`` over all *arrays*, ignoring NaNs."""
    return min(np.nanmin(a) for a in arrays), max(np.nanmax(a) for a in arrays)


def _plot_frames(data, save_path, suffix, cmap, ranges, time_fmt):
    """Save wind / precipitation / temperature maps for every time step of *data*.

    Files are named ``<field>_<timestamp>_<suffix>.jpg``; ``ranges`` maps each
    field name to its shared ``(vmin, vmax)``.
    """
    for i in range(len(data.time)):
        frame = data.isel(time=i)
        stamp = pd.to_datetime(str(frame['time'].values)).strftime(time_fmt)
        print(f'plot {stamp}')
        fields = {
            'wind': _wind_speed(frame['u10'], frame['v10']),
            'precipitation': frame['tp'],
            'temperature': frame['t2m'],
        }
        for label, field in fields.items():
            vmin, vmax = ranges[label]
            imcol(field, save_path, img_name=f'{label}_{stamp}_{suffix}',
                  cmap=cmap, vmin=vmin, vmax=vmax)


def plot(real_data, pred_data, save_path, time_fmt='%Y-%m-%d %H:%M:%S'):
    """Save maps of wind speed, precipitation and 2 m temperature for both datasets.

    Real and predicted maps of the same variable share one colour scale,
    computed over both datasets, so the images are directly comparable.

    Args:
        real_data: Dataset with ``u10``, ``v10``, ``t2m``, ``tp`` and a ``time`` dim.
        pred_data: Dataset with the same variables (e.g. from ``get_pred``).
        save_path: output directory passed to ``imcol``.
        time_fmt: strftime format for the timestamp in file names. The default
            contains ':' which is not allowed in Windows file names; pass e.g.
            ``'%Y-%m-%d_%H%M'`` there.
    """
    cmap_t = 'RdYlBu_r'
    real_wind = _wind_speed(real_data['u10'], real_data['v10'])
    pred_wind = _wind_speed(pred_data['u10'], pred_data['v10'])
    # nanmin/nanmax: missing values (e.g. ERA5 rows absent for the selected
    # expver) would otherwise turn the whole colour range into NaN.
    ranges = {
        'wind': _joint_range(real_wind.values, pred_wind.values),
        'precipitation': _joint_range(real_data['tp'].values, pred_data['tp'].values),
        'temperature': _joint_range(real_data['t2m'].values, pred_data['t2m'].values),
    }
    _plot_frames(real_data, save_path, 'real', cmap_t, ranges, time_fmt)
    _plot_frames(pred_data, save_path, 'pred', cmap_t, ranges, time_fmt)
def get_pred(sample, scaler, times=None, latitude=None, longitude=None):
    """Roll the backbone forward autoregressively and decode surface fields.

    Starting from ``sample[0]``, each step feeds the previous backbone output
    back into the backbone. The precipitation model is then run on the new
    state. Channels 0-2 of the state (u10, v10, t2m) and the precipitation
    output are de-normalized with ``scaler``.

    Args:
        sample: normalized input array; only ``sample[0]`` is used and is
            assumed to be (H, W, C) = (720, 1440, 20) -- TODO confirm.
        scaler: dict with per-channel ``'mean'`` and ``'std'`` arrays.
        times: one timestamp per forecast step; ``len(times)`` steps are run.
        latitude, longitude: coordinate values for the output grid.

    Returns:
        xarray Dataset with ``u10``, ``v10``, ``t2m``, ``tp`` on
        (time, latitude, longitude).

    Raises:
        ValueError: if ``times`` is None or empty.
    """
    if times is None or len(times) == 0:
        raise ValueError("get_pred needs `times`: one timestamp per forecast step")

    backbone_model, precip_model = load_model()
    backbone_model.eval()
    precip_model.eval()

    mean, std = scaler['mean'], scaler['std']
    # (H, W, C) -> (1, C, H, W), the layout the models consume.
    x = torch.from_numpy(sample[0]).float().unsqueeze(0).permute(0, 3, 1, 2)

    pred = []
    # no_grad: x is fed back every step, so tracking gradients would keep the
    # whole rollout's autograd graph alive and memory would grow each step.
    # (The former torch.cuda.amp.autocast had no effect on these CPU tensors.)
    with torch.no_grad():
        for t in times:
            print(f"predict {t}")
            x = backbone_model(x)
            p = precip_model(x)
            # (1, C, H, W) -> (H, W, C); keep u10, v10, t2m and de-normalize.
            surface = x[0].permute(1, 2, 0).cpu().numpy()[:, :, :3] * std[:3] + mean[:3]
            # NOTE(review): precipitation uses the LAST scaler entry, while
            # get_data normalizes 20 input channels with this same scaler --
            # confirm index -1 really holds the tp statistics.
            precip = p[0, 0].cpu().numpy()[:, :, np.newaxis] * std[-1] + mean[-1]
            pred.append(np.concatenate([surface, precip], axis=-1))

    pred = np.stack(pred)  # (time, H, W, 4): u10, v10, t2m, tp
    dims = ('time', 'latitude', 'longitude')
    data_vars = {
        name: (dims, da.from_array(pred[:, :, :, k], chunks=(7, 720, 1440)))
        for k, name in enumerate(('u10', 'v10', 't2m', 'tp'))
    }
    return xr.Dataset(
        data_vars,
        coords={
            'time': (['time'], times),
            'latitude': (['latitude'], latitude),
            'longitude': (['longitude'], longitude),
        },
    )
def get_data(start_time, end_time=None):
    """Load ERA5 fields for a time window and build the normalized model input.

    Args:
        start_time: first time step to load.
        end_time: last time step; the window is selected with a label-based
            ``slice(start_time, end_time)``, which includes both ends.
            Defaults to ``start_time``, i.e. only the initial condition.

    Returns:
        ``(real_data, data, scaler)``:
        real_data -- Dataset of ``u10``, ``v10``, ``t2m``, ``tp`` at ``expver=1``;
        data -- the 20 input variables stacked on the last axis, first
        latitude row dropped (721 -> 720 rows), normalized with ``scaler``;
        scaler -- object loaded from ``./scaler.pkl`` with ``'mean'``/``'std'``.
    """
    if end_time is None:
        end_time = start_time
    times = slice(start_time, end_time)

    # NOTE: pickle can execute arbitrary code; only load a trusted scaler file.
    with open('./scaler.pkl', "rb") as f:  # mean and std of each atmospheric variable
        scaler = pickle.load(f)

    # One dataset per variable. Pressure-level files hold the variable under
    # its short name (DATAMAP); rename it to e.g. "z@500" so several levels
    # of the same variable do not collide when merging.
    datas = []
    for file in DATANAMES:
        ds = xr.open_mfdataset(f'./ERA5_rawdata/{file}/*.nc', combine='by_coords').sel(time=times)
        if '@' in file:
            k, v = file.split('@')
            ds = ds.rename_vars({DATAMAP[k]: f'{DATAMAP[k]}@{v}'})
        datas.append(ds)
    with dask.config.set(**{'array.slicing.split_large_chunks': False}):
        raw_data = xr.merge(datas, compat="identical", join="inner")

    input_vars = ['u10', 'v10', 't2m', 'z@1000', 'z@50', 'z@500', 'z@850', 'msl',
                  'r@500', 'r@850', 'sp', 't@500', 't@850', 'tcwv',
                  'u@1000', 'u@500', 'u@850', 'v@1000', 'v@500', 'v@850']
    # NOTE(review): assumes each variable is (time, latitude, longitude) with
    # no expver axis; otherwise axis 1 below is not latitude -- confirm.
    data = np.stack([raw_data[name].values for name in input_vars], axis=-1)
    data = data[:, 1:, :, :]  # 721*1440 -> 720*1440 (drop first latitude row)
    data = (data - scaler['mean']) / scaler['std']
    return raw_data[['u10', 'v10', 't2m', 'tp']].sel(expver=1), data, scaler
if __name__ == '__main__':
    start_time = datetime(2023, 1, 1, 0, 0)
    end_time = datetime(2023, 1, 5, 18, 0)
    step = timedelta(hours=6)  # spacing between consecutive forecast steps
    num = int((end_time - start_time) / step)
    print(f"start_time: {start_time}, end_time: {end_time}, pred_num: {num}")
    # Observed fields over the whole window (used for the comparison plots);
    # get_pred starts from the first time step.
    real_data, sample, scaler = get_data(start_time, end_time)
    print(sample.shape)
    # `num` forecasts: start+6h, ..., end_time (inclusive).
    pred_times = [start_time + step * i for i in range(1, num + 1)]
    pred = get_pred(sample, scaler=scaler, times=pred_times,
                    latitude=real_data.latitude[1:], longitude=real_data.longitude)
    save_path = Path("./output/")
    save_path.mkdir(parents=True, exist_ok=True)
    plot(real_data, pred, save_path)