| import json |
| import os |
| import numpy as np |
| from nowcasting.config import cfg |
| from nowcasting.helpers.visualization import save_hko_movie |
| from nowcasting.hko_iterator import HKOIterator |
| from nowcasting.hko_evaluation import HKOEvaluation |
|
|
class HKOBenchmarkEnv(object):
    """The Benchmark environment for the HKO7 Dataset

    There are two settings for the Benchmark, the "fixed" setting and the "online" setting.

    In the "fixed" setting, pre-defined input sequences that have the same length will be
    fed into the model for prediction.
    This setting tests the model's ability to use the instant past to predict the future.

    In the "online" setting, M frames will be given each time and the forecasting model
    is required to predict the next K frames every stride steps.
    If the begin_new_episode flag is turned on, a new episode has begun, which means that
    the current received images have no relationship with the previous images.
    If the need_upload_prediction flag is turned on, the model is required to predict the
    next K frames and pass them to `upload_prediction` before asking for the next
    observation.
    This setting tests both the model's ability to adapt in an online fashion and
    the ability to capture the long-term dependency.
    The input frame will be missing in some timestamps.

    To run the benchmark:

        env = HKOBenchmarkEnv(...)
        while not env.done:
            # Get the observation
            in_frame_dat, in_datetime_clips, out_datetime_clips, \
                begin_new_episode, need_upload_prediction = \
                env.get_observation(batch_size=batch_size)
            if need_upload_prediction:
                # Running your algorithm to get the prediction
                prediction = ...
                # Upload prediction to the environment
                env.upload_prediction(prediction)
        env.save_eval()

    Parameters
    ----------
    pd_path : str
        Path of the file describing the sequences; forwarded to `HKOIterator` and also
        used (without directory and extension) as the prefix of the benchmark fingerprint.
    save_dir : str
        Directory that receives the saved videos and evaluation results.
        Created if it does not exist.
    mode : str
        Either "fixed" or "online".
    """

    def __init__(self,
                 pd_path,
                 save_dir="hko7_benchmark",
                 mode="fixed"):
        assert mode == "fixed" or mode == "online"
        self._pd_path = pd_path
        self._save_dir = save_dir
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        self._mode = mode
        self._out_seq_len = cfg.HKO.BENCHMARK.OUT_LEN
        self._stride = cfg.HKO.BENCHMARK.STRIDE
        if mode == "fixed":
            self._in_seq_len = cfg.HKO.BENCHMARK.IN_LEN
        else:
            # In the online setting each observation carries exactly `stride` input frames.
            self._in_seq_len = cfg.HKO.BENCHMARK.STRIDE
        self._hko_iter = HKOIterator(pd_path=pd_path,
                                     sample_mode="sequent",
                                     seq_len=self._in_seq_len + self._out_seq_len,
                                     stride=self._stride)
        # Must come after the iterator is built: computing the statistics may scan it.
        self._stat_dict = self._get_benchmark_stat()
        self._begin_new_episode = True
        self._received_pred_seq_num = 0
        self._need_upload_prediction = False

        # Indices (counted in uploaded sequences) whose input / ground-truth / prediction
        # are dumped as videos: VISUALIZE_SEQ_NUM multiples of an even step through the
        # total number of predictions.
        visualize_seq_num = cfg.HKO.BENCHMARK.VISUALIZE_SEQ_NUM
        self._save_seq_inds = set(np.arange(1, visualize_seq_num + 1) *
                                  (self._stat_dict['pred_seq_num'] // visualize_seq_num))
        self._all_eval = HKOEvaluation(seq_len=self._out_seq_len, use_central=False)

        # Holders of the most recent observation; filled by get_observation().
        self._in_frame_dat = None
        self._in_mask_dat = None
        self._in_datetime_clips = None
        self._out_frame_dat = None
        self._out_mask_dat = None
        self._out_datetime_clips = None

    def reset(self):
        """Rewind the iterator, clear the evaluation and restart the upload bookkeeping."""
        self._hko_iter.reset()
        self._all_eval.clear_all()
        self._begin_new_episode = True
        self._received_pred_seq_num = 0
        self._need_upload_prediction = False

    @property
    def _fingerprint(self):
        """Name identifying this benchmark configuration.

        Built from the base name of `pd_path`, the sequence lengths, the stride and the
        mode. The input length is only part of the name in the "fixed" setting.
        """
        pd_file_name = os.path.splitext(os.path.basename(self._pd_path))[0]
        parts = [pd_file_name]
        if self._mode == "fixed":
            parts.append("in" + str(self._in_seq_len))
        parts.append("out" + str(self._out_seq_len))
        parts.append("stride" + str(self._stride))
        parts.append(self._mode)
        return "_".join(parts)

    @property
    def _stat_filepath(self):
        """Path of the JSON file that caches the benchmark statistics."""
        filename = self._fingerprint + ".json"
        return os.path.join(cfg.HKO.BENCHMARK.STAT_PATH, filename)

    @property
    def _result_dir(self):
        """Directory `<save_dir>/<fingerprint>` holding videos and evaluation files."""
        return os.path.join(self._save_dir, self._fingerprint)

    def _ensure_result_dir(self):
        """Create the result directory if needed and return its path."""
        result_dir = self._result_dir
        if not os.path.exists(result_dir):
            os.makedirs(result_dir)
        return result_dir

    def _get_benchmark_stat(self):
        """Get the general statistics of the benchmark

        The statistics are loaded from the cached JSON file when it exists. Otherwise the
        whole iterator is scanned once (datetimes only), the iterator is reset, and the
        result is written to the cache file.

        Returns
        -------
        stat_dict : dict
            'pred_seq_num' --> Total number of predictions the model needs to make
            'episode_num' --> Number of episodes
            'episode_start_datetime' --> Start time ('%Y%m%d%H%M') of every episode;
                                         only collected in the "online" setting
        """
        stat_filepath = self._stat_filepath
        if os.path.exists(stat_filepath):
            with open(stat_filepath) as stat_file:
                return json.load(stat_file)

        seq_num = 0
        episode_num = 0
        episode_start_datetime = []
        while not self._hko_iter.use_up:
            if self._mode == "fixed":
                datetime_clips, new_start =\
                    self._hko_iter.sample(batch_size=1024, only_return_datetime=True)
                if len(datetime_clips) == 0:
                    continue
                # Every fixed-length sequence is both a prediction and its own episode.
                seq_num += len(datetime_clips)
                episode_num += len(datetime_clips)
            elif self._mode == "online":
                datetime_clips, new_start =\
                    self._hko_iter.sample(batch_size=1, only_return_datetime=True)
                if len(datetime_clips) == 0:
                    continue
                episode_num += int(new_start)
                if new_start:
                    episode_start_datetime.append(
                        datetime_clips[0][0].strftime('%Y%m%d%H%M'))
                    # With stride 1 the first observation of an episode needs no
                    # prediction (see get_observation), so it is not counted.
                    if self._stride != 1:
                        seq_num += 1
                else:
                    seq_num += 1
            print(self._fingerprint, seq_num, episode_num)
        self._hko_iter.reset()
        stat_dict = {'pred_seq_num': seq_num,
                     'episode_num': episode_num,
                     'episode_start_datetime': episode_start_datetime}
        stat_dir = os.path.dirname(stat_filepath)
        if stat_dir and not os.path.exists(stat_dir):
            os.makedirs(stat_dir)
        with open(stat_filepath, 'w') as stat_file:
            json.dump(stat_dict, stat_file, indent=3)
        return stat_dict

    @property
    def done(self):
        """True once the number of uploaded predictions reaches 'pred_seq_num'."""
        return self._received_pred_seq_num >= self._stat_dict["pred_seq_num"]

    def get_observation(self, batch_size=1):
        """Fetch the next input frames from the iterator.

        The matching ground-truth frames and masks are kept internally and are used by
        the next call to `upload_prediction`.

        Parameters
        ----------
        batch_size : int
            Number of sequences to fetch. Must be 1 in the "online" setting.

        Returns
        -------
        in_frame_dat : np.ndarray
            Input frames as float32, divided by 255 (between 0 and 1 if the iterator
            yields 8-bit pixel values). The first axis indexes the frames of a sequence.
        in_datetime_clips : list
            For every sequence, the first `in_seq_len` entries of its datetime clip.
        out_datetime_clips : list
            For every sequence, the following `out_seq_len` entries of its datetime clip.
        begin_new_episode : bool
            Always True in the "fixed" setting; in the "online" setting, the new-start
            flag reported by the iterator.
        need_upload_prediction : bool
            Whether `upload_prediction` must be called before the next observation.
            False only in the "online" setting with stride 1 at the start of an episode.
        """
        if self._mode == "online":
            assert batch_size == 1
        assert not self._need_upload_prediction
        assert not self.done
        assert not self._hko_iter.use_up
        # The iterator may return an empty batch; keep sampling until it yields data.
        while True:
            frame_dat, mask_dat, datetime_clips, new_start =\
                self._hko_iter.sample(batch_size=batch_size, only_return_datetime=False)
            if len(datetime_clips) > 0:
                break
        frame_dat = frame_dat.astype(np.float32) / 255.0
        self._need_upload_prediction = True
        if self._mode == "online":
            self._begin_new_episode = new_start
            if new_start and self._stride == 1:
                self._need_upload_prediction = False
        else:
            self._begin_new_episode = True

        split = self._in_seq_len
        end = self._in_seq_len + self._out_seq_len
        self._in_datetime_clips = [clip[:split] for clip in datetime_clips]
        self._out_datetime_clips = [clip[split:end] for clip in datetime_clips]
        self._in_frame_dat = frame_dat[:split, ...]
        self._out_frame_dat = frame_dat[split:end, ...]
        self._in_mask_dat = mask_dat[:split, ...]
        self._out_mask_dat = mask_dat[split:end, ...]
        return (self._in_frame_dat,
                self._in_datetime_clips,
                self._out_datetime_clips,
                self._begin_new_episode,
                self._need_upload_prediction)

    def _save_sequence_movies(self, prediction, ind, result_dir):
        """Save input, ground-truth and predicted videos of sequence `ind` of the batch."""
        # All three files are named after the first input timestamp of the sequence.
        stamp = self._in_datetime_clips[ind][0].strftime('%Y%m%d%H%M')
        save_hko_movie(im_dat=self._in_frame_dat[:, ind, 0, ...],
                       mask_dat=self._in_mask_dat[:, ind, 0, :, :],
                       datetime_list=self._in_datetime_clips[ind],
                       save_path=os.path.join(result_dir, "%s_in.mp4" % stamp))
        save_hko_movie(im_dat=self._out_frame_dat[:, ind, 0, ...],
                       mask_dat=self._out_mask_dat[:, ind, 0, :, :],
                       masked=False,
                       datetime_list=self._out_datetime_clips[ind],
                       save_path=os.path.join(result_dir, "%s_out.mp4" % stamp))
        save_hko_movie(im_dat=prediction[:, ind, 0, ...],
                       mask_dat=self._out_mask_dat[:, ind, 0, :, :],
                       masked=False,
                       datetime_list=self._out_datetime_clips[ind],
                       save_path=os.path.join(result_dir, "%s_pred.mp4" % stamp))

    def upload_prediction(self, prediction):
        """Submit the prediction for the most recent observation.

        The prediction is scored against the stored ground truth, and videos are saved
        for every sequence in the batch whose global index was selected for visualization.

        Parameters
        ----------
        prediction : np.ndarray
            Predicted frames for the last observation. Its second axis indexes the
            sequences of the batch (`prediction.shape[1]` predictions are counted).
            Presumably laid out like the ground-truth frames and in the same value
            range -- confirm against HKOEvaluation.update.
        """
        assert self._need_upload_prediction, "Must call get_observation first!" \
                                             " Also, check the value of need_upload_predction" \
                                             " after calling"
        self._need_upload_prediction = False
        batch_seq_num = prediction.shape[1]
        received_seq_inds = range(self._received_pred_seq_num,
                                  self._received_pred_seq_num + batch_seq_num)
        save_ind_set = set(received_seq_inds).intersection(self._save_seq_inds)
        if len(save_ind_set) > 0:
            result_dir = self._ensure_result_dir()
            print("Saving prediction videos to %s" % result_dir)
            # A large batch may cover several visualization indices; save each of them.
            for global_ind in sorted(save_ind_set):
                # Convert the global sequence index into a position inside this batch.
                self._save_sequence_movies(prediction=prediction,
                                           ind=global_ind - self._received_pred_seq_num,
                                           result_dir=result_dir)
        self._received_pred_seq_num += batch_seq_num
        if self._mode == "online":
            if self._stride == 1:
                # With stride 1 no prediction is requested at the start of an episode.
                assert not self._begin_new_episode
        self._all_eval.update(gt=self._out_frame_dat,
                              pred=prediction,
                              mask=self._out_mask_dat,
                              start_datetimes=[clip[0] for clip in self._out_datetime_clips])

    def print_stat_readable(self):
        """Print the current evaluation, prefixed with the number of received predictions."""
        self._all_eval.print_stat_readable(
            prefix="Received:%d " % self._received_pred_seq_num)

    def save_eval(self):
        """Save the evaluation under `<save_dir>/<fingerprint>/eval_all`.

        May only be called after every required prediction has been uploaded.
        """
        assert self._received_pred_seq_num == self._stat_dict['pred_seq_num'],\
            "Must upload all the predictions to the testbed!"
        # The directory is otherwise only created when a video is saved.
        result_dir = self._ensure_result_dir()
        print("Saving evaluation result to %s" % result_dir)
        self._all_eval.save(prefix=os.path.join(result_dir, "eval_all"))
|
|
|
if __name__ == '__main__':
    # Build (or load from cache) the statistics of every benchmark dataset in both
    # settings and report how many predictions each one requires.
    # Each entry: (attribute of cfg.HKO_PD, label for "fixed", label for "online").
    for pd_attr, fixed_label, online_label in (
            ("RAINY_TEST", "Fixed Rainy Test SeqNum:", "Online Rainy Test SeqNum:"),
            ("RAINY_VALID", "Fixed Rainy Valid SeqNum:", "Online Rainy Valid SeqNum:"),
            ("ALL_15", "Fixed Rainy2015 ALL SeqNum:", "Online Rainy2015 ALL SeqNum:")):
        # Both environments are constructed before anything is printed for the dataset.
        env_fixed = HKOBenchmarkEnv(pd_path=getattr(cfg.HKO_PD, pd_attr), mode="fixed")
        env_online = HKOBenchmarkEnv(pd_path=getattr(cfg.HKO_PD, pd_attr), mode="online")
        print(fixed_label, env_fixed._stat_dict['pred_seq_num'])
        print(online_label, env_online._stat_dict['pred_seq_num'])
|
|