|
|
''' |
|
|
dutils.py |
|
|
A utility library of customized data loading functions for the SEVIR and MeteoNet datasets
|
|
''' |
|
|
import os |
|
|
import gzip |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
|
|
|
|
|
import cv2 |
|
|
from typing import List, Union, Dict, Sequence |
|
|
|
|
import numpy.random as nprand |
|
|
import datetime |
|
|
|
|
import h5py |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from torch.nn.functional import avg_pool2d |
|
|
import random |
|
|
from torchvision import transforms as T |
|
|
from torchvision import datasets |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from PIL import Image |
|
|
|
|
|
SEVIR_ROOT_DIR = "data/SEVIR" |
|
|
METEO_FILE_DIR = "data/meteonet" |
|
|
|
|
|
def resize(seq, size): |
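    """
    Bilinearly resize a sequence tensor with a singleton channel dim (e.g. layout
    (N, T, 1, H, W)) to the target spatial `size`, clamping values to [0, 1] and
    restoring the channel dim before returning.
    """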
|
|
|
|
|
seq = F.interpolate(seq.squeeze(dim=2), size=size, mode='bilinear', align_corners=False) |
|
|
seq = seq.clamp(0,1) |
|
|
return seq.unsqueeze(2) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def pixel_to_dBZ_nonlinear(img): |
|
|
''' |
|
|
[0, 255] OR [0, 1] pixel => [0, 80] dBZ |
|
|
''' |
|
|
if img.mean() > 1.0: |
|
|
img = img / 255.0 |
|
|
ashift = 31.0 |
|
|
afact = 4.0 |
|
|
atan_dBZ_min = -1.482 |
|
|
atan_dBZ_max = 1.412 |
|
|
tan_pix = np.tan(img * (atan_dBZ_max - atan_dBZ_min) + atan_dBZ_min) |
|
|
return tan_pix * afact + ashift |
|
|
|
|
|
def dbZ_to_pixel_nonlinear(dbZ): |
|
|
''' |
|
|
    [0, 80] dBZ => [0, 1] pixel
|
|
''' |
|
|
ashift = 31.0 |
|
|
afact = 4.0 |
|
|
atan_dBZ_min = -1.482 |
|
|
atan_dBZ_max = 1.412 |
|
|
dbZ_adjusted = (dbZ - ashift) / afact |
|
|
return (np.arctan(dbZ_adjusted) - atan_dBZ_min) / (atan_dBZ_max - atan_dBZ_min) |
|
|
|
|
|
def dbZ_to_pixel(dbZ): |
|
|
''' |
|
|
    [-10, 60] dBZ => [0, 1] pixel (linear mapping, quantized to steps of 1/255)
|
|
''' |
|
|
return np.floor((dbZ + 10) * 255 / 70 + 0.5) / 255.0 |
|
|
|
|
|
def pixel_to_dBZ(pixel): |
|
|
''' |
|
|
    [0, 255] (or [0, 1]) pixel => [-10, 60] dBZ (linear mapping)
|
|
''' |
|
|
if pixel.mean() > 1.0: |
|
|
pixel = pixel / 255.0 |
|
|
return (70 * pixel) - 10 |
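
# Worked example (illustrative): a linear pixel value of 0.5 corresponds to
# 70 * 0.5 - 10 = 25 dBZ (see pixel_to_dBZ), and dbZ_to_pixel(25.0) quantizes it
# back to floor(127.5 + 0.5) / 255 = 128 / 255 ≈ 0.502.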
|
|
|
|
|
def nonlinear_to_linear(im): |
|
|
return dbZ_to_pixel(pixel_to_dBZ_nonlinear(im)) |
|
|
|
|
|
def nonlinear_to_linear_batched(seq, datetime): |
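    """
    Convert a batch of sequences to the linear pixel encoding.

    Sequences whose first timestamp (datetime[i][0]) falls in 2016 or later are
    assumed to be stored with the nonlinear encoding and are converted; earlier
    sequences are passed through unchanged. The result is clipped to [0, 1].
    """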
|
|
seq_linear = np.zeros_like(seq) |
|
|
for i, (seq_b, dt_b) in enumerate(zip(seq, datetime)): |
|
|
if dt_b[0].year >= 2016: |
|
|
seq_linear[i] = nonlinear_to_linear(seq_b) |
|
|
else: |
|
|
seq_linear[i] = seq_b |
|
|
seq_linear = np.clip(seq_linear, 0.0, 1.0) |
|
|
return seq_linear |
|
|
|
|
|
def linear_to_nonlinear(im): |
|
|
return dbZ_to_pixel_nonlinear(pixel_to_dBZ(im)) |
|
|
|
|
|
def linear_to_nonlinear_batched(seq, datetime): |
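    """
    Convert a batch of sequences to the nonlinear pixel encoding.

    Sequences whose first timestamp (datetime[i][0]) falls before 2016 are assumed
    to be stored with the linear encoding and are converted; later sequences are
    passed through unchanged. The result is clipped to [0, 1].
    """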
|
|
seq_nonlinear = np.zeros_like(seq) |
|
|
for i, (seq_b, dt_b) in enumerate(zip(seq, datetime)): |
|
|
if dt_b[0].year < 2016: |
|
|
seq_nonlinear[i] = linear_to_nonlinear(seq_b) |
|
|
else: |
|
|
seq_nonlinear[i] = seq_b |
|
|
seq_nonlinear = np.clip(seq_nonlinear, 0.0, 1.0) |
|
|
return seq_nonlinear |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
SEVIR_DATA_TYPES = ['vis', 'ir069', 'ir107', 'vil', 'lght'] |
|
|
SEVIR_RAW_DTYPES = {'vis': np.int16, |
|
|
'ir069': np.int16, |
|
|
'ir107': np.int16, |
|
|
'vil': np.uint8, |
|
|
'lght': np.int16} |
|
|
LIGHTING_FRAME_TIMES = np.arange(-120.0, 125.0, 5) * 60
|
|
SEVIR_DATA_SHAPE = {'lght': (48, 48), } |
|
|
PREPROCESS_SCALE_SEVIR = {'vis': 1, |
|
|
'ir069': 1 / 1174.68, |
|
|
'ir107': 1 / 2562.43, |
|
|
'vil': 1 / 47.54, |
|
|
'lght': 1 / 0.60517} |
|
|
PREPROCESS_OFFSET_SEVIR = {'vis': 0, |
|
|
'ir069': 3683.58, |
|
|
'ir107': 1552.80, |
|
|
                           'vil': -33.44,
|
|
                           'lght': -0.02990}
|
|
PREPROCESS_SCALE_01 = {'vis': 1, |
|
|
'ir069': 1, |
|
|
'ir107': 1, |
|
|
'vil': 1 / 255, |
|
|
'lght': 1} |
|
|
PREPROCESS_OFFSET_01 = {'vis': 0, |
|
|
'ir069': 0, |
|
|
'ir107': 0, |
|
|
'vil': 0, |
|
|
'lght': 0} |
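
# These constants are applied as `scale * (data + offset)` in preprocess_data_dict;
# e.g. under the '01' scheme, 'vil' values in [0, 255] are mapped to [0, 1]
# via 1/255 * (x + 0).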
|
|
|
|
|
|
|
|
SEVIR_CATALOG = os.path.join(SEVIR_ROOT_DIR, "CATALOG.csv") |
|
|
SEVIR_DATA_DIR = os.path.join(SEVIR_ROOT_DIR, "data") |
|
|
SEVIR_RAW_SEQ_LEN = 49 |
|
|
|
|
|
SEVIR_TRAIN_VAL_SPLIT_DATE = datetime.datetime(2019, 1, 1) |
|
|
SEVIR_TRAIN_TEST_SPLIT_DATE = datetime.datetime(2019, 6, 1) |
|
|
|
|
|
def change_layout_np(data, |
|
|
in_layout='NHWT', out_layout='NHWT', |
|
|
ret_contiguous=False): |
|
|
|
|
|
if in_layout == 'NHWT': |
|
|
pass |
|
|
elif in_layout == 'NTHW': |
|
|
data = np.transpose(data, |
|
|
axes=(0, 2, 3, 1)) |
|
|
elif in_layout == 'NWHT': |
|
|
data = np.transpose(data, |
|
|
axes=(0, 2, 1, 3)) |
|
|
elif in_layout == 'NTCHW': |
|
|
data = data[:, :, 0, :, :] |
|
|
data = np.transpose(data, |
|
|
axes=(0, 2, 3, 1)) |
|
|
elif in_layout == 'NTHWC': |
|
|
data = data[:, :, :, :, 0] |
|
|
data = np.transpose(data, |
|
|
axes=(0, 2, 3, 1)) |
|
|
elif in_layout == 'NTWHC': |
|
|
data = data[:, :, :, :, 0] |
|
|
data = np.transpose(data, |
|
|
axes=(0, 3, 2, 1)) |
|
|
elif in_layout == 'TNHW': |
|
|
data = np.transpose(data, |
|
|
axes=(1, 2, 3, 0)) |
|
|
elif in_layout == 'TNCHW': |
|
|
data = data[:, :, 0, :, :] |
|
|
data = np.transpose(data, |
|
|
axes=(1, 2, 3, 0)) |
|
|
else: |
|
|
raise NotImplementedError |
|
|
|
|
|
if out_layout == 'NHWT': |
|
|
pass |
|
|
elif out_layout == 'NTHW': |
|
|
data = np.transpose(data, |
|
|
axes=(0, 3, 1, 2)) |
|
|
elif out_layout == 'NWHT': |
|
|
data = np.transpose(data, |
|
|
axes=(0, 2, 1, 3)) |
|
|
elif out_layout == 'NTCHW': |
|
|
data = np.transpose(data, |
|
|
axes=(0, 3, 1, 2)) |
|
|
data = np.expand_dims(data, axis=2) |
|
|
elif out_layout == 'NTHWC': |
|
|
data = np.transpose(data, |
|
|
axes=(0, 3, 1, 2)) |
|
|
data = np.expand_dims(data, axis=-1) |
|
|
elif out_layout == 'NTWHC': |
|
|
data = np.transpose(data, |
|
|
axes=(0, 3, 2, 1)) |
|
|
data = np.expand_dims(data, axis=-1) |
|
|
elif out_layout == 'TNHW': |
|
|
data = np.transpose(data, |
|
|
axes=(3, 0, 1, 2)) |
|
|
elif out_layout == 'TNCHW': |
|
|
data = np.transpose(data, |
|
|
axes=(3, 0, 1, 2)) |
|
|
data = np.expand_dims(data, axis=2) |
|
|
else: |
|
|
raise NotImplementedError |
|
|
if ret_contiguous: |
|
|
        data = np.ascontiguousarray(data)
|
|
return data |
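
# Example (illustrative shapes): converting a batch from the raw 'NHWT' layout to
# 'NTCHW' turns shape (2, 384, 384, 49) into (2, 49, 1, 384, 384):
#
#     x = np.zeros((2, 384, 384, 49), dtype=np.float32)
#     y = change_layout_np(x, in_layout='NHWT', out_layout='NTCHW')
#     assert y.shape == (2, 49, 1, 384, 384)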
|
|
|
|
|
def change_layout_torch(data, |
|
|
in_layout='NHWT', out_layout='NHWT', |
|
|
ret_contiguous=False): |
|
|
|
|
|
if in_layout == 'NHWT': |
|
|
pass |
|
|
elif in_layout == 'NTHW': |
|
|
data = data.permute(0, 2, 3, 1) |
|
|
elif in_layout == 'NTCHW': |
|
|
data = data[:, :, 0, :, :] |
|
|
data = data.permute(0, 2, 3, 1) |
|
|
elif in_layout == 'NTHWC': |
|
|
data = data[:, :, :, :, 0] |
|
|
data = data.permute(0, 2, 3, 1) |
|
|
elif in_layout == 'TNHW': |
|
|
data = data.permute(1, 2, 3, 0) |
|
|
elif in_layout == 'TNCHW': |
|
|
data = data[:, :, 0, :, :] |
|
|
data = data.permute(1, 2, 3, 0) |
|
|
else: |
|
|
raise NotImplementedError |
|
|
|
|
|
if out_layout == 'NHWT': |
|
|
pass |
|
|
elif out_layout == 'NTHW': |
|
|
data = data.permute(0, 3, 1, 2) |
|
|
elif out_layout == 'NTCHW': |
|
|
data = data.permute(0, 3, 1, 2) |
|
|
data = torch.unsqueeze(data, dim=2) |
|
|
elif out_layout == 'NTHWC': |
|
|
data = data.permute(0, 3, 1, 2) |
|
|
data = torch.unsqueeze(data, dim=-1) |
|
|
elif out_layout == 'TNHW': |
|
|
data = data.permute(3, 0, 1, 2) |
|
|
elif out_layout == 'TNCHW': |
|
|
data = data.permute(3, 0, 1, 2) |
|
|
data = torch.unsqueeze(data, dim=2) |
|
|
else: |
|
|
raise NotImplementedError |
|
|
if ret_contiguous: |
|
|
data = data.contiguous() |
|
|
return data |
|
|
|
|
|
class SEVIRDataLoader: |
|
|
r""" |
|
|
    DataLoader that loads SEVIR sequences and splits each event
|
|
into segments according to specified sequence length. |
|
|
|
|
|
Event Frames: |
|
|
[-----------------------raw_seq_len----------------------] |
|
|
[-----seq_len-----] |
|
|
<--stride-->[-----seq_len-----] |
|
|
<--stride-->[-----seq_len-----] |
|
|
... |
|
|
""" |
|
|
def __init__(self, |
|
|
data_types: Sequence[str] = None, |
|
|
seq_len: int = 49, |
|
|
raw_seq_len: int = 49, |
|
|
sample_mode: str = 'sequent', |
|
|
stride: int = 12, |
|
|
batch_size: int = 1, |
|
|
layout: str = 'NHWT', |
|
|
num_shard: int = 1, |
|
|
rank: int = 0, |
|
|
split_mode: str = "uneven", |
|
|
sevir_catalog: Union[str, pd.DataFrame] = None, |
|
|
sevir_data_dir: str = None, |
|
|
start_date: datetime.datetime = None, |
|
|
end_date: datetime.datetime = None, |
|
|
datetime_filter=None, |
|
|
catalog_filter='default', |
|
|
shuffle: bool = False, |
|
|
shuffle_seed: int = 1, |
|
|
output_type=np.float32, |
|
|
preprocess: bool = True, |
|
|
rescale_method: str = '01', |
|
|
downsample_dict: Dict[str, Sequence[int]] = None, |
|
|
verbose: bool = False): |
|
|
r""" |
|
|
Parameters |
|
|
---------- |
|
|
data_types |
|
|
A subset of SEVIR_DATA_TYPES. |
|
|
seq_len |
|
|
The length of the data sequences. Should be smaller than the max length raw_seq_len. |
|
|
raw_seq_len |
|
|
The length of the raw data sequences. |
|
|
sample_mode |
|
|
'random' or 'sequent' |
|
|
stride |
|
|
Useful when sample_mode == 'sequent' |
|
|
stride must not be smaller than out_len to prevent data leakage in testing. |
|
|
batch_size |
|
|
Number of sequences in one batch. |
|
|
layout |
|
|
str: consists of batch_size 'N', seq_len 'T', channel 'C', height 'H', width 'W' |
|
|
The layout of sampled data. Raw data layout is 'NHWT'. |
|
|
            valid layout: 'NHWT', 'NTHW', 'NTCHW', 'NTHWC', 'TNHW', 'TNCHW'.
|
|
num_shard |
|
|
Split the whole dataset into num_shard parts for distributed training. |
|
|
rank |
|
|
Rank of the current process within num_shard. |
|
|
split_mode: str |
|
|
if 'ceil', all `num_shard` dataloaders have the same length = ceil(total_len / num_shard). |
|
|
            Different dataloaders may have some duplicated data batches if the total dataset size is not divisible by num_shard.
|
|
if 'floor', all `num_shard` dataloaders have the same length = floor(total_len / num_shard). |
|
|
            The last several data batches may be wasted if the total dataset size is not divisible by num_shard.
|
|
            if 'uneven', the last dataloader has a larger length when the total length is not divisible by num_shard.
|
|
            The uneven split can lead to synchronization errors in dist.all_reduce() or dist.barrier().
|
|
See related issue: https://github.com/pytorch/pytorch/issues/33148 |
|
|
Notice: this also affects the behavior of `self.use_up`. |
|
|
sevir_catalog |
|
|
Name of SEVIR catalog CSV file. |
|
|
sevir_data_dir |
|
|
Directory path to SEVIR data. |
|
|
start_date |
|
|
Start time of SEVIR samples to generate. |
|
|
end_date |
|
|
End time of SEVIR samples to generate. |
|
|
datetime_filter |
|
|
function |
|
|
Mask function applied to time_utc column of catalog (return true to keep the row). |
|
|
Pass function of the form lambda t : COND(t) |
|
|
Example: lambda t: np.logical_and(t.dt.hour>=13,t.dt.hour<=21) # Generate only day-time events |
|
|
catalog_filter |
|
|
function or None or 'default' |
|
|
Mask function applied to entire catalog dataframe (return true to keep row). |
|
|
Pass function of the form lambda catalog: COND(catalog) |
|
|
Example: lambda c: [s[0]=='S' for s in c.id] # Generate only the 'S' events |
|
|
shuffle |
|
|
bool, If True, data samples are shuffled before each epoch. |
|
|
shuffle_seed |
|
|
int, Seed to use for shuffling. |
|
|
output_type |
|
|
np.dtype, dtype of generated tensors |
|
|
preprocess |
|
|
bool, If True, self.preprocess_data_dict(data_dict) is called before each sample generated |
|
|
downsample_dict: |
|
|
dict, downsample_dict.keys() == data_types. downsample_dict[key] is a Sequence of (t_factor, h_factor, w_factor), |
|
|
representing the downsampling factors of all dimensions. |
|
|
verbose |
|
|
bool, verbose when opening raw data files |
|
|
""" |
|
|
super(SEVIRDataLoader, self).__init__() |
|
|
if sevir_catalog is None: |
|
|
sevir_catalog = SEVIR_CATALOG |
|
|
if sevir_data_dir is None: |
|
|
sevir_data_dir = SEVIR_DATA_DIR |
|
|
if data_types is None: |
|
|
data_types = SEVIR_DATA_TYPES |
|
|
else: |
|
|
assert set(data_types).issubset(SEVIR_DATA_TYPES) |
|
|
|
|
|
|
|
|
self._dtypes = SEVIR_RAW_DTYPES |
|
|
self.lght_frame_times = LIGHTING_FRAME_TIMES |
|
|
self.data_shape = SEVIR_DATA_SHAPE |
|
|
|
|
|
self.raw_seq_len = raw_seq_len |
|
|
assert seq_len <= self.raw_seq_len, f'seq_len must not be larger than raw_seq_len = {raw_seq_len}, got {seq_len}.' |
|
|
self.seq_len = seq_len |
|
|
assert sample_mode in ['random', 'sequent'], f'Invalid sample_mode = {sample_mode}, must be \'random\' or \'sequent\'.' |
|
|
self.sample_mode = sample_mode |
|
|
self.stride = stride |
|
|
self.batch_size = batch_size |
|
|
valid_layout = ('NHWT', 'NTHW', 'NTCHW', 'NTHWC', 'TNHW', 'TNCHW') |
|
|
if layout not in valid_layout: |
|
|
raise ValueError(f'Invalid layout = {layout}! Must be one of {valid_layout}.') |
|
|
self.layout = layout |
|
|
self.num_shard = num_shard |
|
|
self.rank = rank |
|
|
valid_split_mode = ('ceil', 'floor', 'uneven') |
|
|
if split_mode not in valid_split_mode: |
|
|
raise ValueError(f'Invalid split_mode: {split_mode}! Must be one of {valid_split_mode}.') |
|
|
self.split_mode = split_mode |
|
|
self._samples = None |
|
|
self._hdf_files = {} |
|
|
self.data_types = data_types |
|
|
if isinstance(sevir_catalog, str): |
|
|
self.catalog = pd.read_csv(sevir_catalog, parse_dates=['time_utc'], low_memory=False) |
|
|
else: |
|
|
self.catalog = sevir_catalog |
|
|
self.sevir_data_dir = sevir_data_dir |
|
|
self.datetime_filter = datetime_filter |
|
|
self.catalog_filter = catalog_filter |
|
|
self.start_date = start_date |
|
|
self.end_date = end_date |
|
|
self.shuffle = shuffle |
|
|
self.shuffle_seed = int(shuffle_seed) |
|
|
self.output_type = output_type |
|
|
self.preprocess = preprocess |
|
|
self.downsample_dict = downsample_dict |
|
|
self.rescale_method = rescale_method |
|
|
self.verbose = verbose |
|
|
|
|
|
if self.start_date is not None: |
|
|
self.catalog = self.catalog[self.catalog.time_utc > self.start_date] |
|
|
if self.end_date is not None: |
|
|
self.catalog = self.catalog[self.catalog.time_utc <= self.end_date] |
|
|
if self.datetime_filter: |
|
|
self.catalog = self.catalog[self.datetime_filter(self.catalog.time_utc)] |
|
|
|
|
|
if self.catalog_filter is not None: |
|
|
if self.catalog_filter == 'default': |
|
|
self.catalog_filter = lambda c: c.pct_missing == 0 |
|
|
self.catalog = self.catalog[self.catalog_filter(self.catalog)] |
|
|
|
|
|
self._compute_samples() |
|
|
self._open_files(verbose=self.verbose) |
|
|
self.reset() |
|
|
|
|
|
def _compute_samples(self): |
|
|
""" |
|
|
Computes the list of samples in catalog to be used. This sets self._samples |
|
|
""" |
|
|
|
|
|
imgt = self.data_types |
|
|
imgts = set(imgt) |
|
|
filtcat = self.catalog[ np.logical_or.reduce([self.catalog.img_type==i for i in imgt]) ] |
|
|
|
|
|
filtcat = filtcat.groupby('id').filter(lambda x: imgts.issubset(set(x['img_type']))) |
|
|
|
|
|
|
|
|
filtcat = filtcat.groupby('id').filter(lambda x: x.shape[0]==len(imgt)) |
|
|
self._samples = filtcat.groupby('id').apply(lambda df: self._df_to_series(df,imgt) ) |
|
|
if self.shuffle: |
|
|
self.shuffle_samples() |
|
|
|
|
|
def shuffle_samples(self): |
|
|
self._samples = self._samples.sample(frac=1, random_state=self.shuffle_seed) |
|
|
|
|
|
def _df_to_series(self, df, imgt): |
|
|
d = {} |
|
|
df = df.set_index('img_type') |
|
|
for i in imgt: |
|
|
s = df.loc[i] |
|
|
idx = s.file_index if i != 'lght' else s.id |
|
|
d.update({f'{i}_filename': [s.file_name], |
|
|
f'{i}_index': [idx]}) |
|
|
|
|
|
return pd.DataFrame(d) |
|
|
|
|
|
def _open_files(self, verbose=True): |
|
|
""" |
|
|
Opens HDF files |
|
|
""" |
|
|
imgt = self.data_types |
|
|
hdf_filenames = [] |
|
|
for t in imgt: |
|
|
hdf_filenames += list(np.unique( self._samples[f'{t}_filename'].values )) |
|
|
self._hdf_files = {} |
|
|
for f in hdf_filenames: |
|
|
if verbose: |
|
|
print('Opening HDF5 file for reading', f) |
|
|
self._hdf_files[f] = h5py.File(self.sevir_data_dir + '/' + f, 'r') |
|
|
|
|
|
def close(self): |
|
|
""" |
|
|
Closes all open file handles |
|
|
""" |
|
|
for f in self._hdf_files: |
|
|
self._hdf_files[f].close() |
|
|
self._hdf_files = {} |
|
|
|
|
|
@property |
|
|
def num_seq_per_event(self): |
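        """
        Number of sequences extracted from each event.

        E.g., with raw_seq_len=49, seq_len=25 and stride=12:
        1 + (49 - 25) // 12 = 3 sequences per event.
        """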
|
|
return 1 + (self.raw_seq_len - self.seq_len) // self.stride |
|
|
|
|
|
@property |
|
|
def total_num_seq(self): |
|
|
""" |
|
|
The total number of sequences within each shard. |
|
|
Notice that it is not the product of `self.num_seq_per_event` and `self.total_num_event`. |
|
|
""" |
|
|
return int(self.num_seq_per_event * self.num_event) |
|
|
|
|
|
@property |
|
|
def total_num_event(self): |
|
|
""" |
|
|
The total number of events in the whole dataset, before split into different shards. |
|
|
""" |
|
|
return int(self._samples.shape[0]) |
|
|
|
|
|
@property |
|
|
def start_event_idx(self): |
|
|
""" |
|
|
        The event idx used on a given rank should satisfy event_idx >= start_event_idx.
|
|
""" |
|
|
return self.total_num_event // self.num_shard * self.rank |
|
|
|
|
|
@property |
|
|
def end_event_idx(self): |
|
|
""" |
|
|
        The event idx used on a given rank should satisfy event_idx < end_event_idx.
|
|
|
|
|
""" |
|
|
if self.split_mode == 'ceil': |
|
|
_last_start_event_idx = self.total_num_event // self.num_shard * (self.num_shard - 1) |
|
|
_num_event = self.total_num_event - _last_start_event_idx |
|
|
return self.start_event_idx + _num_event |
|
|
elif self.split_mode == 'floor': |
|
|
return self.total_num_event // self.num_shard * (self.rank + 1) |
|
|
else: |
|
|
if self.rank == self.num_shard - 1: |
|
|
return self.total_num_event |
|
|
else: |
|
|
return self.total_num_event // self.num_shard * (self.rank + 1) |
|
|
|
|
|
@property |
|
|
def num_event(self): |
|
|
""" |
|
|
        The number of events assigned to this rank.
|
|
""" |
|
|
return self.end_event_idx - self.start_event_idx |
|
|
|
|
|
def _read_data(self, row, data): |
|
|
""" |
|
|
Iteratively read data into data dict. Finally data[imgt] gets shape (batch_size, height, width, raw_seq_len). |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
row |
|
|
            A series with fields IMGTYPE_filename and IMGTYPE_index.
|
|
data |
|
|
Dict, data[imgt] is a data tensor with shape = (tmp_batch_size, height, width, raw_seq_len). |
|
|
|
|
|
Returns |
|
|
------- |
|
|
data |
|
|
Updated data. Updated shape = (tmp_batch_size + 1, height, width, raw_seq_len). |
|
|
""" |
|
|
imgtyps = np.unique([x.split('_')[0] for x in list(row.keys())]) |
|
|
for t in imgtyps: |
|
|
fname = row[f'{t}_filename'] |
|
|
idx = row[f'{t}_index'] |
|
|
t_slice = slice(0, None) |
|
|
|
|
|
if t == 'lght': |
|
|
lght_data = self._hdf_files[fname][idx][:] |
|
|
data_i = self._lght_to_grid(lght_data, t_slice) |
|
|
else: |
|
|
data_i = self._hdf_files[fname][t][idx:idx + 1, :, :, t_slice] |
|
|
data[t] = np.concatenate((data[t], data_i), axis=0) if (t in data) else data_i |
|
|
|
|
|
return data |
|
|
|
|
|
def _lght_to_grid(self, data, t_slice=slice(0, None)): |
|
|
""" |
|
|
Converts Nx5 lightning data matrix into a 2D grid of pixel counts |
|
|
""" |
|
|
|
|
|
out_size = (*self.data_shape['lght'], len(self.lght_frame_times)) if t_slice.stop is None else (*self.data_shape['lght'], 1) |
|
|
if data.shape[0] == 0: |
|
|
return np.zeros((1,) + out_size, dtype=np.float32) |
|
|
|
|
|
|
|
|
x, y = data[:, 3], data[:, 4] |
|
|
m = np.logical_and.reduce([x >= 0, x < out_size[0], y >= 0, y < out_size[1]]) |
|
|
data = data[m, :] |
|
|
if data.shape[0] == 0: |
|
|
return np.zeros((1,) + out_size, dtype=np.float32) |
|
|
|
|
|
|
|
|
t = data[:, 0] |
|
|
if t_slice.stop is not None: |
|
|
if t_slice.stop > 0: |
|
|
if t_slice.stop < len(self.lght_frame_times): |
|
|
tm = np.logical_and(t >= self.lght_frame_times[t_slice.stop - 1], |
|
|
t < self.lght_frame_times[t_slice.stop]) |
|
|
else: |
|
|
tm = t >= self.lght_frame_times[-1] |
|
|
else: |
|
|
tm = np.logical_and(t >= self.lght_frame_times[0], t < self.lght_frame_times[1]) |
|
|
|
|
|
|
|
|
data = data[tm, :] |
|
|
z = np.zeros(data.shape[0], dtype=np.int64) |
|
|
else: |
|
|
z = np.digitize(t, self.lght_frame_times) - 1 |
|
|
z[z == -1] = 0 |
|
|
|
|
|
x = data[:, 3].astype(np.int64) |
|
|
y = data[:, 4].astype(np.int64) |
|
|
|
|
|
k = np.ravel_multi_index(np.array([y, x, z]), out_size) |
|
|
n = np.bincount(k, minlength=np.prod(out_size)) |
|
|
return np.reshape(n, out_size).astype(np.int16)[np.newaxis, :] |
|
|
|
|
|
def _old_save_downsampled_dataset(self, save_dir, downsample_dict, verbose=True): |
|
|
""" |
|
|
        This method does not save the .h5 dataset correctly: some batches are missing due to an unknown error.
|
|
E.g., the first converted .h5 file `SEVIR_VIL_RANDOMEVENTS_2017_0501_0831.h5` only has batch_dim = 1414, |
|
|
while it should be 1440 in the original .h5 file. |
|
|
""" |
|
|
import os |
|
|
from skimage.measure import block_reduce |
|
|
assert not os.path.exists(save_dir), f"save_dir {save_dir} already exists!" |
|
|
os.makedirs(save_dir) |
|
|
sample_counter = 0 |
|
|
for index, row in self._samples.iterrows(): |
|
|
if verbose: |
|
|
print(f"Downsampling {sample_counter}-th data item.", end='\r') |
|
|
for data_type in self.data_types: |
|
|
fname = row[f'{data_type}_filename'] |
|
|
idx = row[f'{data_type}_index'] |
|
|
t_slice = slice(0, None) |
|
|
if data_type == 'lght': |
|
|
lght_data = self._hdf_files[fname][idx][:] |
|
|
data_i = self._lght_to_grid(lght_data, t_slice) |
|
|
else: |
|
|
data_i = self._hdf_files[fname][data_type][idx:idx + 1, :, :, t_slice] |
|
|
|
|
|
t_slice = [slice(None, None), ] * 4 |
|
|
t_slice[-1] = slice(None, None, downsample_dict[data_type][0]) |
|
|
data_i = data_i[tuple(t_slice)] |
|
|
|
|
|
data_i = block_reduce(data_i, |
|
|
block_size=(1, *downsample_dict[data_type][1:], 1), |
|
|
func=np.max) |
|
|
|
|
|
new_file_path = os.path.join(save_dir, fname) |
|
|
if not os.path.exists(new_file_path): |
|
|
if not os.path.exists(os.path.dirname(new_file_path)): |
|
|
os.makedirs(os.path.dirname(new_file_path)) |
|
|
|
|
|
with h5py.File(new_file_path, 'w') as hf: |
|
|
hf.create_dataset( |
|
|
data_type, data=data_i, |
|
|
maxshape=(None, *data_i.shape[1:])) |
|
|
else: |
|
|
|
|
|
with h5py.File(new_file_path, 'a') as hf: |
|
|
hf[data_type].resize((hf[data_type].shape[0] + data_i.shape[0]), axis=0) |
|
|
hf[data_type][-data_i.shape[0]:] = data_i |
|
|
|
|
|
sample_counter += 1 |
|
|
|
|
|
def save_downsampled_dataset(self, save_dir, downsample_dict, verbose=True): |
|
|
""" |
|
|
Parameters |
|
|
---------- |
|
|
save_dir |
|
|
        downsample_dict: Dict[str, Sequence[int]]
|
|
Notice that this is different from `self.downsample_dict`, which is used during runtime. |
|
|
""" |
|
|
import os |
|
|
from skimage.measure import block_reduce |
|
|
from ...utils.utils import path_splitall |
|
|
assert not os.path.exists(save_dir), f"save_dir {save_dir} already exists!" |
|
|
os.makedirs(save_dir) |
|
|
for fname, hdf_file in self._hdf_files.items(): |
|
|
if verbose: |
|
|
print(f"Downsampling data in {fname}.") |
|
|
data_type = path_splitall(fname)[0] |
|
|
if data_type == 'lght': |
|
|
|
|
|
raise NotImplementedError |
|
|
|
|
|
|
|
|
|
|
|
else: |
|
|
data_i = self._hdf_files[fname][data_type] |
|
|
|
|
|
t_slice = [slice(None, None), ] * 4 |
|
|
t_slice[-1] = slice(None, None, downsample_dict[data_type][0]) |
|
|
data_i = data_i[tuple(t_slice)] |
|
|
|
|
|
data_i = block_reduce(data_i, |
|
|
block_size=(1, *downsample_dict[data_type][1:], 1), |
|
|
func=np.max) |
|
|
|
|
|
new_file_path = os.path.join(save_dir, fname) |
|
|
if not os.path.exists(os.path.dirname(new_file_path)): |
|
|
os.makedirs(os.path.dirname(new_file_path)) |
|
|
|
|
|
with h5py.File(new_file_path, 'w') as hf: |
|
|
hf.create_dataset( |
|
|
data_type, data=data_i, |
|
|
maxshape=(None, *data_i.shape[1:])) |
|
|
|
|
|
@property |
|
|
def sample_count(self): |
|
|
""" |
|
|
Record how many times self.__next__() is called. |
|
|
""" |
|
|
return self._sample_count |
|
|
|
|
|
def inc_sample_count(self): |
|
|
self._sample_count += 1 |
|
|
|
|
|
@property |
|
|
def curr_event_idx(self): |
|
|
return self._curr_event_idx |
|
|
|
|
|
@property |
|
|
def curr_seq_idx(self): |
|
|
""" |
|
|
Used only when self.sample_mode == 'sequent' |
|
|
""" |
|
|
return self._curr_seq_idx |
|
|
|
|
|
def set_curr_event_idx(self, val): |
|
|
self._curr_event_idx = val |
|
|
|
|
|
def set_curr_seq_idx(self, val): |
|
|
""" |
|
|
Used only when self.sample_mode == 'sequent' |
|
|
""" |
|
|
self._curr_seq_idx = val |
|
|
|
|
|
def reset(self, shuffle: bool = None): |
|
|
self.set_curr_event_idx(val=self.start_event_idx) |
|
|
self.set_curr_seq_idx(0) |
|
|
self._sample_count = 0 |
|
|
if shuffle is None: |
|
|
shuffle = self.shuffle |
|
|
if shuffle: |
|
|
self.shuffle_samples() |
|
|
|
|
|
def __len__(self): |
|
|
""" |
|
|
Used only when self.sample_mode == 'sequent' |
|
|
""" |
|
|
return self.total_num_seq // self.batch_size |
|
|
|
|
|
@property |
|
|
def use_up(self): |
|
|
""" |
|
|
Check if dataset is used up in 'sequent' mode. |
|
|
""" |
|
|
if self.sample_mode == 'random': |
|
|
return False |
|
|
else: |
|
|
|
|
|
curr_event_remain_seq = self.num_seq_per_event - self.curr_seq_idx |
|
|
all_remain_seq = curr_event_remain_seq + ( |
|
|
self.end_event_idx - self.curr_event_idx - 1) * self.num_seq_per_event |
|
|
if self.split_mode == "floor": |
|
|
|
|
|
return all_remain_seq < self.batch_size |
|
|
else: |
|
|
return all_remain_seq <= 0 |
|
|
|
|
|
def _load_event_batch(self, event_idx, event_batch_size): |
|
|
""" |
|
|
Loads a selected batch of events (not batch of sequences) into memory. |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
        event_idx
|
|
event_batch_size |
|
|
            event_batch[i] = all_type_i_available_events[event_idx:event_idx + event_batch_size]
|
|
Returns |
|
|
------- |
|
|
event_batch |
|
|
list of event batches. |
|
|
event_batch[i] is the event batch of the i-th data type. |
|
|
Each event_batch[i] is a np.ndarray with shape = (event_batch_size, height, width, raw_seq_len) |
|
|
""" |
|
|
event_idx_slice_end = event_idx + event_batch_size |
|
|
pad_size = 0 |
|
|
if event_idx_slice_end > self.end_event_idx: |
|
|
pad_size = event_idx_slice_end - self.end_event_idx |
|
|
event_idx_slice_end = self.end_event_idx |
|
|
pd_batch = self._samples.iloc[event_idx:event_idx_slice_end] |
|
|
data = {} |
|
|
for index, row in pd_batch.iterrows(): |
|
|
data = self._read_data(row, data) |
|
|
if pad_size > 0: |
|
|
event_batch = [] |
|
|
for t in self.data_types: |
|
|
pad_shape = [pad_size, ] + list(data[t].shape[1:]) |
|
|
data_pad = np.concatenate((data[t].astype(self.output_type), |
|
|
np.zeros(pad_shape, dtype=self.output_type)), |
|
|
axis=0) |
|
|
event_batch.append(data_pad) |
|
|
else: |
|
|
event_batch = [data[t].astype(self.output_type) for t in self.data_types] |
|
|
return event_batch |
|
|
|
|
|
def __iter__(self): |
|
|
return self |
|
|
|
|
|
def __next__(self): |
|
|
if self.sample_mode == 'random': |
|
|
self.inc_sample_count() |
|
|
ret_dict = self._random_sample() |
|
|
else: |
|
|
if self.use_up: |
|
|
raise StopIteration |
|
|
else: |
|
|
self.inc_sample_count() |
|
|
ret_dict = self._sequent_sample() |
|
|
ret_dict = self.data_dict_to_tensor(data_dict=ret_dict, |
|
|
data_types=self.data_types) |
|
|
if self.preprocess: |
|
|
ret_dict = self.preprocess_data_dict(data_dict=ret_dict, |
|
|
data_types=self.data_types, |
|
|
layout=self.layout, |
|
|
rescale=self.rescale_method) |
|
|
if self.downsample_dict is not None: |
|
|
ret_dict = self.downsample_data_dict(data_dict=ret_dict, |
|
|
data_types=self.data_types, |
|
|
factors_dict=self.downsample_dict, |
|
|
layout=self.layout) |
|
|
return ret_dict |
|
|
|
|
|
def __getitem__(self, index): |
|
|
data_dict = self._idx_sample(index=index) |
|
|
return data_dict |
|
|
|
|
|
@staticmethod |
|
|
def preprocess_data_dict(data_dict, data_types=None, layout='NHWT', rescale='01'): |
|
|
""" |
|
|
Parameters |
|
|
---------- |
|
|
data_dict: Dict[str, Union[np.ndarray, torch.Tensor]] |
|
|
data_types: Sequence[str] |
|
|
The data types that we want to rescale. This mainly excludes "mask" from preprocessing. |
|
|
layout: str |
|
|
consists of batch_size 'N', seq_len 'T', channel 'C', height 'H', width 'W' |
|
|
rescale: str |
|
|
'sevir': use the offsets and scale factors in original implementation. |
|
|
'01': scale all values to range 0 to 1, currently only supports 'vil' |
|
|
Returns |
|
|
------- |
|
|
data_dict: Dict[str, Union[np.ndarray, torch.Tensor]] |
|
|
preprocessed data |
|
|
""" |
|
|
if rescale == 'sevir': |
|
|
scale_dict = PREPROCESS_SCALE_SEVIR |
|
|
offset_dict = PREPROCESS_OFFSET_SEVIR |
|
|
elif rescale == '01': |
|
|
scale_dict = PREPROCESS_SCALE_01 |
|
|
offset_dict = PREPROCESS_OFFSET_01 |
|
|
else: |
|
|
raise ValueError(f'Invalid rescale option: {rescale}.') |
|
|
if data_types is None: |
|
|
data_types = data_dict.keys() |
|
|
for key, data in data_dict.items(): |
|
|
if key in data_types: |
|
|
if isinstance(data, np.ndarray): |
|
|
data = scale_dict[key] * ( |
|
|
data.astype(np.float32) + |
|
|
offset_dict[key]) |
|
|
data = change_layout_np(data=data, |
|
|
in_layout='NHWT', |
|
|
out_layout=layout) |
|
|
elif isinstance(data, torch.Tensor): |
|
|
data = scale_dict[key] * ( |
|
|
data.float() + |
|
|
offset_dict[key]) |
|
|
data = change_layout_torch(data=data, |
|
|
in_layout='NHWT', |
|
|
out_layout=layout) |
|
|
data_dict[key] = data |
|
|
return data_dict |
|
|
|
|
|
@staticmethod |
|
|
def process_data_dict_back(data_dict, data_types=None, rescale='01'): |
|
|
""" |
|
|
Parameters |
|
|
---------- |
|
|
data_dict |
|
|
each data_dict[key] is a torch.Tensor. |
|
|
rescale |
|
|
str: |
|
|
'sevir': data are scaled using the offsets and scale factors in original implementation. |
|
|
'01': data are all scaled to range 0 to 1, currently only supports 'vil' |
|
|
Returns |
|
|
------- |
|
|
data_dict |
|
|
each data_dict[key] is the data processed back in torch.Tensor. |
|
|
""" |
|
|
if rescale == 'sevir': |
|
|
scale_dict = PREPROCESS_SCALE_SEVIR |
|
|
offset_dict = PREPROCESS_OFFSET_SEVIR |
|
|
elif rescale == '01': |
|
|
scale_dict = PREPROCESS_SCALE_01 |
|
|
offset_dict = PREPROCESS_OFFSET_01 |
|
|
else: |
|
|
raise ValueError(f'Invalid rescale option: {rescale}.') |
|
|
if data_types is None: |
|
|
data_types = data_dict.keys() |
|
|
for key in data_types: |
|
|
data = data_dict[key] |
|
|
data = data.float() / scale_dict[key] - offset_dict[key] |
|
|
data_dict[key] = data |
|
|
return data_dict |
|
|
|
|
|
@staticmethod |
|
|
def data_dict_to_tensor(data_dict, data_types=None): |
|
|
""" |
|
|
Convert each element in data_dict to torch.Tensor (copy without grad). |
|
|
""" |
|
|
ret_dict = {} |
|
|
if data_types is None: |
|
|
data_types = data_dict.keys() |
|
|
for key, data in data_dict.items(): |
|
|
if key in data_types: |
|
|
if isinstance(data, torch.Tensor): |
|
|
ret_dict[key] = data.detach().clone() |
|
|
elif isinstance(data, np.ndarray): |
|
|
ret_dict[key] = torch.from_numpy(data) |
|
|
else: |
|
|
raise ValueError(f"Invalid data type: {type(data)}. Should be torch.Tensor or np.ndarray") |
|
|
else: |
|
|
ret_dict[key] = data |
|
|
return ret_dict |
|
|
|
|
|
@staticmethod |
|
|
def downsample_data_dict(data_dict, data_types=None, factors_dict=None, layout='NHWT'): |
|
|
""" |
|
|
Parameters |
|
|
---------- |
|
|
data_dict: Dict[str, Union[np.array, torch.Tensor]] |
|
|
factors_dict: Optional[Dict[str, Sequence[int]]] |
|
|
each element `factors` is a Sequence of int, representing (t_factor, h_factor, w_factor) |
|
|
|
|
|
Returns |
|
|
------- |
|
|
downsampled_data_dict: Dict[str, torch.Tensor] |
|
|
            The downsampling is applied to a copy of data_dict; the original data_dict is not modified in place.
|
|
""" |
|
|
if factors_dict is None: |
|
|
factors_dict = {} |
|
|
if data_types is None: |
|
|
data_types = data_dict.keys() |
|
|
downsampled_data_dict = SEVIRDataLoader.data_dict_to_tensor( |
|
|
data_dict=data_dict, |
|
|
data_types=data_types) |
|
|
for key, data in data_dict.items(): |
|
|
factors = factors_dict.get(key, None) |
|
|
if factors is not None: |
|
|
downsampled_data_dict[key] = change_layout_torch( |
|
|
data=downsampled_data_dict[key], |
|
|
in_layout=layout, |
|
|
out_layout='NTHW') |
|
|
|
|
|
t_slice = [slice(None, None), ] * 4 |
|
|
t_slice[1] = slice(None, None, factors[0]) |
|
|
downsampled_data_dict[key] = downsampled_data_dict[key][tuple(t_slice)] |
|
|
|
|
|
downsampled_data_dict[key] = avg_pool2d( |
|
|
input=downsampled_data_dict[key], |
|
|
kernel_size=(factors[1], factors[2])) |
|
|
|
|
|
downsampled_data_dict[key] = change_layout_torch( |
|
|
data=downsampled_data_dict[key], |
|
|
in_layout='NTHW', |
|
|
out_layout=layout) |
|
|
|
|
|
return downsampled_data_dict |
|
|
|
|
|
def _random_sample(self): |
|
|
""" |
|
|
Returns |
|
|
------- |
|
|
ret_dict |
|
|
dict. ret_dict.keys() == self.data_types. |
|
|
If self.preprocess == False: |
|
|
ret_dict[imgt].shape == (batch_size, height, width, seq_len) |
|
|
""" |
|
|
num_sampled = 0 |
|
|
event_idx_list = nprand.randint(low=self.start_event_idx, |
|
|
high=self.end_event_idx, |
|
|
size=self.batch_size) |
|
|
seq_idx_list = nprand.randint(low=0, |
|
|
high=self.num_seq_per_event, |
|
|
size=self.batch_size) |
|
|
seq_slice_list = [slice(seq_idx * self.stride, |
|
|
seq_idx * self.stride + self.seq_len) |
|
|
for seq_idx in seq_idx_list] |
|
|
ret_dict = {} |
|
|
while num_sampled < self.batch_size: |
|
|
event = self._load_event_batch(event_idx=event_idx_list[num_sampled], |
|
|
event_batch_size=1) |
|
|
for imgt_idx, imgt in enumerate(self.data_types): |
|
|
sampled_seq = event[imgt_idx][[0, ], :, :, seq_slice_list[num_sampled]] |
|
|
if imgt in ret_dict: |
|
|
ret_dict[imgt] = np.concatenate((ret_dict[imgt], sampled_seq), |
|
|
axis=0) |
|
|
else: |
|
|
                    ret_dict.update({imgt: sampled_seq})
            num_sampled += 1
|
|
return ret_dict |
|
|
|
|
|
def _sequent_sample(self): |
|
|
""" |
|
|
Returns |
|
|
------- |
|
|
ret_dict: Dict |
|
|
`ret_dict.keys()` contains `self.data_types`. |
|
|
            `ret_dict["mask"]` is a list of bool indicating whether each entry is real or padded; it is None when no entry is padded.
|
|
If self.preprocess == False: |
|
|
ret_dict[imgt].shape == (batch_size, height, width, seq_len) |
|
|
""" |
|
|
assert not self.use_up, 'Data loader used up! Reset it to reuse.' |
|
|
event_idx = self.curr_event_idx |
|
|
seq_idx = self.curr_seq_idx |
|
|
num_sampled = 0 |
|
|
sampled_idx_list = [] |
|
|
while num_sampled < self.batch_size: |
|
|
sampled_idx_list.append({'event_idx': event_idx, |
|
|
'seq_idx': seq_idx}) |
|
|
seq_idx += 1 |
|
|
if seq_idx >= self.num_seq_per_event: |
|
|
event_idx += 1 |
|
|
seq_idx = 0 |
|
|
num_sampled += 1 |
|
|
|
|
|
start_event_idx = sampled_idx_list[0]['event_idx'] |
|
|
event_batch_size = sampled_idx_list[-1]['event_idx'] - start_event_idx + 1 |
|
|
|
|
|
event_batch = self._load_event_batch(event_idx=start_event_idx, |
|
|
event_batch_size=event_batch_size) |
|
|
ret_dict = {"mask": []} |
|
|
all_no_pad_flag = True |
|
|
for sampled_idx in sampled_idx_list: |
|
|
batch_slice = [sampled_idx['event_idx'] - start_event_idx, ] |
|
|
seq_slice = slice(sampled_idx['seq_idx'] * self.stride, |
|
|
sampled_idx['seq_idx'] * self.stride + self.seq_len) |
|
|
for imgt_idx, imgt in enumerate(self.data_types): |
|
|
sampled_seq = event_batch[imgt_idx][batch_slice, :, :, seq_slice] |
|
|
if imgt in ret_dict: |
|
|
ret_dict[imgt] = np.concatenate((ret_dict[imgt], sampled_seq), |
|
|
axis=0) |
|
|
else: |
|
|
ret_dict.update({imgt: sampled_seq}) |
|
|
|
|
|
no_pad_flag = sampled_idx['event_idx'] < self.end_event_idx |
|
|
if not no_pad_flag: |
|
|
all_no_pad_flag = False |
|
|
ret_dict["mask"].append(no_pad_flag) |
|
|
if all_no_pad_flag: |
|
|
|
|
|
ret_dict["mask"] = None |
|
|
|
|
|
self.set_curr_event_idx(event_idx) |
|
|
self.set_curr_seq_idx(seq_idx) |
|
|
return ret_dict |
|
|
|
|
|
def _idx_sample(self, index): |
|
|
""" |
|
|
Parameters |
|
|
---------- |
|
|
index |
|
|
The index of the batch to sample. |
|
|
Returns |
|
|
------- |
|
|
ret_dict |
|
|
dict. ret_dict.keys() == self.data_types. |
|
|
If self.preprocess == False: |
|
|
ret_dict[imgt].shape == (batch_size, height, width, seq_len) |
|
|
""" |
|
|
event_idx = (index * self.batch_size) // self.num_seq_per_event |
|
|
seq_idx = (index * self.batch_size) % self.num_seq_per_event |
|
|
num_sampled = 0 |
|
|
sampled_idx_list = [] |
|
|
while num_sampled < self.batch_size: |
|
|
sampled_idx_list.append({'event_idx': event_idx, |
|
|
'seq_idx': seq_idx}) |
|
|
seq_idx += 1 |
|
|
if seq_idx >= self.num_seq_per_event: |
|
|
event_idx += 1 |
|
|
seq_idx = 0 |
|
|
num_sampled += 1 |
|
|
|
|
|
start_event_idx = sampled_idx_list[0]['event_idx'] |
|
|
event_batch_size = sampled_idx_list[-1]['event_idx'] - start_event_idx + 1 |
|
|
|
|
|
event_batch = self._load_event_batch(event_idx=start_event_idx, |
|
|
event_batch_size=event_batch_size) |
|
|
ret_dict = {} |
|
|
for sampled_idx in sampled_idx_list: |
|
|
batch_slice = [sampled_idx['event_idx'] - start_event_idx, ] |
|
|
seq_slice = slice(sampled_idx['seq_idx'] * self.stride, |
|
|
sampled_idx['seq_idx'] * self.stride + self.seq_len) |
|
|
for imgt_idx, imgt in enumerate(self.data_types): |
|
|
sampled_seq = event_batch[imgt_idx][batch_slice, :, :, seq_slice] |
|
|
if imgt in ret_dict: |
|
|
ret_dict[imgt] = np.concatenate((ret_dict[imgt], sampled_seq), |
|
|
axis=0) |
|
|
else: |
|
|
ret_dict.update({imgt: sampled_seq}) |
|
|
|
|
|
ret_dict = self.data_dict_to_tensor(data_dict=ret_dict, |
|
|
data_types=self.data_types) |
|
|
if self.preprocess: |
|
|
ret_dict = self.preprocess_data_dict(data_dict=ret_dict, |
|
|
data_types=self.data_types, |
|
|
layout=self.layout, |
|
|
rescale=self.rescale_method) |
|
|
|
|
|
if self.downsample_dict is not None: |
|
|
ret_dict = self.downsample_data_dict(data_dict=ret_dict, |
|
|
data_types=self.data_types, |
|
|
factors_dict=self.downsample_dict, |
|
|
layout=self.layout) |
|
|
return ret_dict |
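
# Usage sketch (illustrative; assumes the SEVIR catalog and HDF5 files are present
# under SEVIR_ROOT_DIR, and that names like `loader`/`batch` are placeholders):
#
#     loader = SEVIRDataLoader(data_types=['vil'], seq_len=25, stride=12,
#                              sample_mode='sequent', batch_size=2, layout='NTHW',
#                              start_date=SEVIR_TRAIN_VAL_SPLIT_DATE,
#                              end_date=SEVIR_TRAIN_TEST_SPLIT_DATE)
#     for batch in loader:
#         vil = batch['vil']        # torch.Tensor in 'NTHW' layout, values in [0, 1]
#     loader.reset()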
|
|
|
|
|
|
|
|
class SEVIRDataIterator:
|
|
''' |
|
|
    A wrapper around SEVIRDataLoader that exposes a sample() interface.
|
|
    All constructor arguments are forwarded to the inner SEVIRDataLoader object.
|
|
    If you want a Pythonic iterator, use SEVIRDataLoader directly.
|
|
''' |
|
|
def __init__(self, **kwargs): |
|
|
self.loader = SEVIRDataLoader(**kwargs) |
|
|
        self.sample_mode = kwargs.get('sample_mode', 'random')
|
|
|
|
|
def reset(self): |
|
|
self.loader.reset() |
|
|
|
|
|
def sample(self, batch_size=None): |
|
|
''' |
|
|
        The batch_size argument is ignored; the batch size is fixed by the inner SEVIRDataLoader.
|
|
''' |
|
|
out = next(self.loader, None) |
|
|
if out is None and self.sample_mode == 'random': |
|
|
self.loader.reset() |
|
|
out = next(self.loader, None) |
|
|
return out |
|
|
|
|
|
def __len__(self): |
|
|
""" |
|
|
Used only when self.sample_mode == 'sequent' |
|
|
""" |
|
|
return len(self.loader) |
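
# Usage sketch (illustrative): drawing random batches via the sample() interface.
#
#     data_iter = SEVIRDataIterator(data_types=['vil'], seq_len=25,
#                                   sample_mode='random', batch_size=2, layout='NTHW')
#     batch = data_iter.sample()    # dict with a 'vil' tensor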
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Meteo(Dataset): |
|
|
def __init__(self, data_path, img_size, type='train', trans=None, in_len=-1): |
|
|
super().__init__() |
|
|
|
|
|
self.pixel_scale = 70.0 |
|
|
|
|
|
self.data_path = data_path |
|
|
self.img_size = img_size |
|
|
self.in_len = in_len |
|
|
|
|
|
assert type in ['train', 'test', 'val'] |
|
|
        self.type = type if type != 'val' else 'test'  # the 'val' split reuses the test split of the .h5 file
|
|
with h5py.File(data_path,'r') as f: |
|
|
self.all_len = int(f[f'{self.type}_len'][()]) |
|
|
if trans is not None: |
|
|
self.transform = trans |
|
|
else: |
|
|
self.transform = T.Compose([ |
|
|
T.Resize((img_size, img_size)), |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
]) |
|
|
|
|
|
def __len__(self): |
|
|
return self.all_len |
|
|
|
|
|
def sample(self): |
|
|
index = np.random.randint(0, self.all_len) |
|
|
return self.__getitem__(index) |
|
|
|
|
|
|
|
|
def __getitem__(self, index): |
|
|
|
|
|
with h5py.File(self.data_path,'r') as f: |
|
|
imgs = f[self.type][str(index)][()] |
|
|
|
|
|
frames = torch.from_numpy(imgs).float().squeeze() |
|
|
frames = frames / self.pixel_scale |
|
|
frames = self.transform(frames).unsqueeze(1) |
|
|
|
|
|
|
|
|
return frames[:self.in_len], frames[self.in_len:] |
|
|
|
|
|
|
|
|
def load_meteonet(batch_size, val_batch_size, in_len, train=False, num_workers=0, img_size=128): |
|
|
meteo_filepath = os.path.join(METEO_FILE_DIR, "meteo.h5") |
|
|
if train: |
|
|
train_set = Meteo(meteo_filepath, img_size, 'train', in_len=in_len) |
|
|
valid_set = Meteo(meteo_filepath, img_size, 'val', in_len=in_len) |
|
|
dataloader_train = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers) |
|
|
dataloader_valid = torch.utils.data.DataLoader(valid_set, batch_size=val_batch_size, shuffle=False, drop_last=True, num_workers=num_workers) |
|
|
return dataloader_train, dataloader_valid |
|
|
else: |
|
|
test_set = Meteo(meteo_filepath, img_size, 'test', in_len=in_len) |
|
|
dataloader_test = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers) |
|
|
return None, dataloader_test |
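
# Usage sketch (illustrative; assumes data/meteonet/meteo.h5 exists and that the
# stored sequences are (T, H, W) arrays):
#
#     train_loader, val_loader = load_meteonet(batch_size=4, val_batch_size=4,
#                                              in_len=10, train=True, img_size=128)
#     x, y = next(iter(train_loader))   # x: (B, in_len, 1, H, W); y: the remaining frames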