"""Roll an AFNONet backbone forward from an ERA5 sample and plot the results.

The script loads normalised ERA5 fields, feeds the first time step through a
backbone AFNONet (20 -> 20 channels) and a precipitation AFNONet (20 -> 1
channel) autoregressively, and writes wind-speed, precipitation and 2 m
temperature maps for both the real and the predicted data.
"""
import pickle
from datetime import datetime, timedelta
from functools import partial
from pathlib import Path

import cartopy
import cartopy.crs as ccrs
import dask
import dask.array as da
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import xarray as xr

try:
    # ``xarray.ufuncs`` no longer exists in current xarray releases. Nothing
    # in this file needs it any more (NumPy ufuncs are used instead); the
    # import is kept only so older environments still expose the ``xu`` name.
    import xarray.ufuncs as xu
except ImportError:
    xu = None

# The AFNONet model code is not part of this file and must be obtained
# separately (the source URL is missing from the original comment).
from model.afnonet import AFNONet

# One ERA5 directory per entry; '<variable>@<level>' marks a pressure level.
DATANAMES = ['10m_u_component_of_wind', '10m_v_component_of_wind', '2m_temperature',
             'geopotential@1000', 'geopotential@50', 'geopotential@500', 'geopotential@850',
             'mean_sea_level_pressure', 'relative_humidity@500', 'relative_humidity@850',
             'surface_pressure', 'temperature@500', 'temperature@850', 'total_column_water_vapour',
             'u_component_of_wind@1000', 'u_component_of_wind@500', 'u_component_of_wind@850',
             'v_component_of_wind@1000', 'v_component_of_wind@500', 'v_component_of_wind@850',
             'total_precipitation']

# Long ERA5 variable name -> short name used inside the NetCDF files.
DATAMAP = {
    'geopotential': 'z',
    'relative_humidity': 'r',
    'temperature': 't',
    'u_component_of_wind': 'u',
    'v_component_of_wind': 'v'
}

# Channel order of the model input; the stacking in get_data() follows it.
INPUT_VARS = ['u10', 'v10', 't2m', 'z@1000', 'z@50', 'z@500', 'z@850', 'msl', 'r@500', 'r@850',
              'sp', 't@500', 't@850', 'tcwv', 'u@1000', 'u@500', 'u@850', 'v@1000', 'v@500', 'v@850']

# Variables that are plotted / returned as the "real" reference data.
OUTPUT_VARS = ['u10', 'v10', 't2m', 'tp']


def load_model(backbone_ckpt='./', precip_ckpt='./'):
    """Build the backbone and precipitation AFNONets and load their weights.

    Args:
        backbone_ckpt: Path of the backbone checkpoint file.
        precip_ckpt: Path of the precipitation checkpoint file.

    Returns:
        Tuple ``(backbone_model, precip_model)``, both loaded on the CPU.

    NOTE(review): the default ``'./'`` is what the original code passed to
    ``torch.load``; it names a directory, so the real checkpoint file names
    appear to have been lost. Pass explicit paths until that is restored.
    """
    # input size
    h, w = 720, 1440
    x_c, y_c, p_c = 20, 20, 1
    norm_layer = partial(nn.LayerNorm, eps=1e-6)

    backbone_model = AFNONet(img_size=[h, w], in_chans=x_c, out_chans=y_c, norm_layer=norm_layer)
    ckpt = torch.load(backbone_ckpt, map_location="cpu")
    backbone_model.load_state_dict(ckpt['model'])

    precip_model = AFNONet(img_size=[h, w], in_chans=x_c, out_chans=p_c, norm_layer=norm_layer)
    ckpt = torch.load(precip_ckpt, map_location="cpu")
    precip_model.load_state_dict(ckpt['model'])

    # The original built both models but never returned them, so the tuple
    # unpacking in get_pred() failed.
    return backbone_model, precip_model


def imcol(data, img_path, img_name, **kwargs):
    """Render ``data`` on a PlateCarree map and save it as a JPEG.

    Args:
        data: Object with a ``.plot(ax=..., transform=..., ...)`` method
            (an xarray DataArray in this script).
        img_path: ``pathlib.Path`` of the output directory.
        img_name: File name without extension.
        **kwargs: Forwarded to ``data.plot`` (e.g. ``cmap``, ``vmin``, ``vmax``).
    """
    fig = plt.figure(figsize=(20, 10))
    try:
        ax = plt.axes(projection=ccrs.PlateCarree())
        data.plot(ax=ax, transform=ccrs.PlateCarree(), add_colorbar=False, add_labels=False,
                  rasterized=True, **kwargs)
        ax.coastlines(resolution='110m')

        out_file = f'{img_path.absolute()}/{img_name}.jpg'
        plt.axis('off')
        plt.savefig(out_file, bbox_inches='tight', pad_inches=0.)
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)


def _wind_speed(u, v):
    """Return the wind-speed magnitude from its u and v components."""
    return np.sqrt(u ** 2 + v ** 2)


def _shared_range(real, pred):
    """Return ``(min, max)`` over both arrays so real/pred share a colour scale."""
    real_vals, pred_vals = real.values, pred.values
    return min(pred_vals.min(), real_vals.min()), max(pred_vals.max(), real_vals.max())


def _plot_frames(data, save_path, suffix, cmap, wind_range, precip_range, temp_range):
    """Save wind, precipitation and temperature maps for every time step of ``data``."""
    for i in range(len(data.time)):
        wind = _wind_speed(data['u10'].isel(time=i), data['v10'].isel(time=i))
        precip = data['tp'].isel(time=i)
        temp = data['t2m'].isel(time=i)

        timestamp = pd.to_datetime(str(wind['time'].values)).strftime('%Y-%m-%d %H:%M:%S')
        print(f'plot {timestamp}')

        imcol(wind, save_path, img_name=f'wind_{timestamp}_{suffix}',
              cmap=cmap, vmin=wind_range[0], vmax=wind_range[1])
        imcol(precip, save_path, img_name=f'precipitation_{timestamp}_{suffix}',
              cmap=cmap, vmin=precip_range[0], vmax=precip_range[1])
        imcol(temp, save_path, img_name=f'temperature_{timestamp}_{suffix}',
              cmap=cmap, vmin=temp_range[0], vmax=temp_range[1])


def plot(real_data, pred_data, save_path):
    """Write per-time-step maps of the real and the predicted fields.

    Both datasets must provide ``u10``, ``v10``, ``tp`` and ``t2m`` with a
    ``time`` dimension. Colour limits are computed over real and predicted
    data together so the two image sets are directly comparable.

    Args:
        real_data: Dataset with the reference fields.
        pred_data: Dataset with the predicted fields.
        save_path: ``pathlib.Path`` of the output directory.
    """
    cmap_t = 'RdYlBu_r'

    wind_range = _shared_range(_wind_speed(real_data['u10'], real_data['v10']),
                               _wind_speed(pred_data['u10'], pred_data['v10']))
    precip_range = _shared_range(real_data['tp'], pred_data['tp'])
    temp_range = _shared_range(real_data['t2m'], pred_data['t2m'])

    _plot_frames(real_data, save_path, 'real', cmap_t, wind_range, precip_range, temp_range)
    _plot_frames(pred_data, save_path, 'pred', cmap_t, wind_range, precip_range, temp_range)


def get_pred(sample, scaler, times=None, latitude=None, longitude=None):
    """Run the models autoregressively and return the de-normalised forecast.

    The first entry of ``sample`` is used as the initial state; the backbone
    output of each step becomes the input of the next one.

    Args:
        sample: Normalised input array; ``sample[0]`` is the initial state
            (assumes layout (lat, lon, channel) -- TODO confirm against get_data).
        scaler: Mapping with ``'mean'`` and ``'std'`` arrays. The first three
            entries are applied to the first three output channels and the
            last entry to the precipitation output.
        times: Sequence of time labels, one per forecast step.
        latitude: Latitude coordinate of the output grid.
        longitude: Longitude coordinate of the output grid.

    Returns:
        ``xr.Dataset`` with ``u10``, ``v10``, ``t2m`` and ``tp`` on
        ``(time, latitude, longitude)``.

    Raises:
        ValueError: If ``times`` is ``None`` or empty.
    """
    if times is None or len(times) == 0:
        raise ValueError("get_pred() needs at least one forecast time in 'times'")

    backbone_model, precip_model = load_model()
    backbone_model.eval()
    precip_model.eval()

    # (lat, lon, channel) -> (1, channel, lat, lon)
    x = torch.from_numpy(sample[0]).float().unsqueeze(0).permute(0, 3, 1, 2)

    pred = []
    # Inference only: skip building the autograd graph across the rollout.
    with torch.no_grad():
        for step_time in times:
            print(f"predict {step_time}")

            with torch.cuda.amp.autocast():
                x = backbone_model(x)
                # (1, channel, lat, lon) -> (1, lat, lon, channel)
                tmp = x.permute(0, 2, 3, 1)
                p = precip_model(x)

            # Undo the normalisation for the plotted channels only.
            tmp = tmp.detach().numpy()[0, :, :, :3] * scaler['std'][:3] + scaler['mean'][:3]
            p = p.detach().numpy()[0, 0, :, :, np.newaxis] * scaler['std'][-1] + scaler['mean'][-1]
            pred.append(np.concatenate([tmp, p], axis=-1))

    pred = np.asarray(pred)

    dims = ['time', 'latitude', 'longitude']
    chunks = (7, 720, 1440)
    data_vars = {
        name: (dims, da.from_array(pred[:, :, :, channel], chunks=chunks))
        for channel, name in enumerate(OUTPUT_VARS)
    }
    pred_data = xr.Dataset(
        data_vars,
        coords={'time': (['time'], times),
                'latitude': (['latitude'], latitude),
                'longitude': (['longitude'], longitude)},
    )
    return pred_data


def get_data(start_time, end_time):
    """Load ERA5 fields for a time window and normalise the model inputs.

    Reads ``./scaler.pkl`` and every ``./ERA5_rawdata/<name>/*.nc`` listed in
    ``DATANAMES``.

    Args:
        start_time: First time of the window (inclusive slice bound).
        end_time: Last time of the window (inclusive slice bound).

    Returns:
        Tuple ``(real_data, data, scaler)``: the ``u10``/``v10``/``t2m``/``tp``
        dataset selected at ``expver=1``, the normalised input array with the
        first row of axis 1 dropped, and the loaded scaler.
    """
    times = slice(start_time, end_time)

    # NOTE: pickle executes arbitrary code on load -- only use a trusted scaler file.
    with open('./scaler.pkl', "rb") as f:
        # the mean and std of each atmospheric variable
        scaler = pickle.load(f)

    # load weather data
    datas = []
    for file in DATANAMES:
        tmp = xr.open_mfdataset(f'./ERA5_rawdata/{file}/*.nc', combine='by_coords').sel(time=times)
        if '@' in file:
            # Pressure-level variables share a short name; tag it with the level.
            k, v = file.split('@')
            tmp = tmp.rename_vars({DATAMAP[k]: f'{DATAMAP[k]}@{v}'})
        datas.append(tmp)

    with dask.config.set(**{'array.slicing.split_large_chunks': False}):
        raw_data = xr.merge(datas, compat="identical", join="inner")
        data = np.stack([raw_data[name].values for name in INPUT_VARS], axis=-1)

    # assumes scaler['mean'] / scaler['std'] broadcast against the 20 stacked
    # channels -- TODO confirm against the contents of scaler.pkl
    data = (data - scaler['mean']) / scaler['std']
    data = data[:, 1:, :, :]  # 721*1440 -> 720*1440

    return raw_data[OUTPUT_VARS].sel(expver=1), data, scaler


def main():
    """Predict from 2023-01-01 00:00 towards 2023-01-05 18:00 and plot everything."""
    start_time = datetime(2023, 1, 1, 0, 0)
    end_time = datetime(2023, 1, 5, 18, 0)
    num = int((end_time - start_time) / timedelta(hours=6))
    print(f"start_time: {start_time}, end_time: {end_time}, pred_num: {num}")

    # The original call omitted end_time, which get_data() requires.
    real_data, sample, scaler = get_data(start_time, end_time)
    print(sample.shape)

    # NOTE(review): range(1, num) yields num - 1 steps, one fewer than the
    # 'pred_num' printed above and stopping one step before end_time --
    # confirm whether range(1, num + 1) was intended.
    pred_times = [start_time + timedelta(hours=6) * i for i in range(1, num)]
    pred = get_pred(sample, scaler=scaler, times=pred_times,
                    latitude=real_data.latitude[1:], longitude=real_data.longitude)

    save_path = Path("./output/")
    save_path.mkdir(parents=True, exist_ok=True)
    plot(real_data, pred, save_path)


if __name__ == '__main__':
    main()