|
""" |
|
Defines a Loader class to load data from a file or file wildcard |
|
""" |
|
|
|
import argparse |
|
|
|
import h5py |
|
import torch |
|
import numpy as np |
|
import glob |
|
from typing import Tuple |
|
|
|
|
|
class Loader:
    """
    Data loader class
    """

    def __init__(self, **kwargs):
        """Initialize the loader.

        Keyword arguments whose names match an argparse option destination
        (see :meth:`add_argparse_args`) override that option's default; all
        resulting values become instance attributes.
        """
        # Build the parser so defaults live in one place, then override any
        # default whose dest matches a supplied keyword before parsing.
        parser = Loader.add_argparse_args()
        for action in parser._actions:
            if action.dest in kwargs:
                action.default = kwargs[action.dest]
        args = parser.parse_args([])
        self.__dict__.update(vars(args))

        # Normalize a single label name into a one-element list so the rest
        # of the code can always iterate over label_vars.
        if isinstance(self.label_vars, str):
            self.label_vars = [self.label_vars]

    @staticmethod
    def add_argparse_args(parent_parser=None):
        """
        Add argparse arguments for the data loader.

        Args:
            parent_parser: optional parser whose arguments are inherited.

        Returns:
            argparse.ArgumentParser configured with the loader's options.
        """
        parser = argparse.ArgumentParser(
            prog='Loader',
            usage=Loader.__doc__,
            parents=[parent_parser] if parent_parser is not None else [],
            add_help=False)

        # FIX: help text previously said "label data" for the input variable.
        parser.add_argument('--input_var', default='p_f5.0_o0', help='Variable name for the input data')
        parser.add_argument('--label_vars', nargs='*', default='c0', help='Variable name(s) for the label data')

        parser.add_argument('--inputs_crop', type=int, default=[0, 1, 32, 96, 42, 2090], nargs='*',
                            help='Crop input data on load [layer_min layer_max x_min x_max y_min y_max]')
        parser.add_argument('--labels_crop', type=int, default=[322, 830, 60, 1076], nargs='*', help='Crop label data on load [x_min x_max y_min y_max]')
        parser.add_argument('--labels_resize', type=float, default=256.0 / 1016.0, help='scaling factor for labels image')

        parser.add_argument('--data_scale', type=float, default=1.0, help='Data scaling factor')
        parser.add_argument('--data_gain', type=float, default=1.8, help='Data gain factor in dB/20 at farthest point in data.')

        return parser

    def load_data(self, test_file_pattern: str, train_file_pattern: str = None) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Loads training/testing data from file(s)

        Arguments:
            test_file_pattern {str} -- testing dataset(s) pattern
            train_file_pattern {str} -- training dataset(s) pattern (optional)

        Returns:
            (test_inputs, test_labels, train_inputs, train_labels) -- None for values that are not loaded

        Raises:
            ValueError: if a non-None pattern fails to produce any data.
        """
        test_inputs, test_labels = self._load_data_files(test_file_pattern)
        train_inputs, train_labels = self._load_data_files(train_file_pattern)

        if train_file_pattern is not None and train_inputs is None:
            raise ValueError('Failed to load train set')

        # FIX: this branch previously reported 'train set' for a test-set failure.
        if test_file_pattern is not None and test_inputs is None:
            raise ValueError('Failed to load test set')

        return test_inputs, test_labels, train_inputs, train_labels

    def _load_data_files(self, file_pattern: str) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """ Perform actual data loading

        Opens every HDF5 file matching ``file_pattern``, reads ``input_var``
        (and ``label_vars`` if any), applies the configured crops / resize /
        scale / gain, and returns the stacked results as float32 tensors.

        Args:
            file_pattern: file name pattern (glob); None skips loading

        Returns:
            inputs and labels tensors (labels is None when label_vars is empty;
            both are None when file_pattern is None)

        Raises:
            ValueError: if the pattern matches no files, or a required
                variable is missing from the first file.
        """
        inputs, labels = None, None

        if file_pattern is None:
            return inputs, labels

        files = glob.glob(file_pattern)

        if len(files) == 0:
            raise ValueError(f'{file_pattern=} comes up empty')

        # Inspect the first file to pre-allocate the output arrays; all files
        # are assumed to share the same per-file shape (TODO confirm).
        with h5py.File(files[0], 'r') as f:
            if self.input_var not in f:
                raise ValueError(f'input data key not in file: {self.input_var=}')

            shape = list(f[self.input_var].shape)
            if self.inputs_crop is not None:
                # Crop pairs apply to the trailing axes, last axis first.
                for i in range(len(self.inputs_crop) // 2):
                    shape[-i - 1] = self.inputs_crop[-i * 2 - 1] - self.inputs_crop[-i * 2 - 2]

            shape[0] *= len(files)

            inputs = np.empty(shape, np.single)

            if len(self.label_vars):
                if not all([v in f for v in self.label_vars]):
                    raise ValueError(f'labels data key(s) not in file: {self.label_vars=}')

                shape = list(f[self.label_vars[0]].shape)
                # Label variables are concatenated along the channel axis.
                shape[1] *= len(self.label_vars)
                if self.labels_crop is not None:
                    for i in range(len(self.labels_crop) // 2):
                        shape[-i - 1] = self.labels_crop[-i * 2 - 1] - self.labels_crop[-i * 2 - 2]

                shape[-1] = int(shape[-1] * self.labels_resize)
                shape[-2] = int(shape[-2] * self.labels_resize)
                shape[0] *= len(files)

                labels = np.empty(shape, np.single)

        pos = 0
        for file in files:
            # FIX: this loop previously re-opened files[0] on every iteration,
            # so only the first file's data was ever read.
            with h5py.File(file, 'r') as f:
                tmp_inputs = np.array(f[self.input_var])

                if self.inputs_crop is not None:
                    # assumes input datasets are 4-D — TODO confirm
                    slc = [slice(None)] * 4
                    for i in range(len(self.inputs_crop) // 2):
                        slc[-i - 1] = slice(self.inputs_crop[-i * 2 - 2], self.inputs_crop[-i * 2 - 1])
                    tmp_inputs = tmp_inputs[tuple(slc)]

                inputs[pos:pos + tmp_inputs.shape[0], ...] = tmp_inputs

                if len(self.label_vars):
                    tmp_labels = []
                    for v in self.label_vars:
                        tmp_labels.append(np.array(f[v]))
                    tmp_labels = np.concatenate(tmp_labels, axis=1)

                    if self.labels_crop is not None and self.labels_crop:
                        slc = [slice(None)] * 4
                        for i in range(len(self.labels_crop) // 2):
                            slc[-i - 1] = slice(self.labels_crop[-i * 2 - 2], self.labels_crop[-i * 2 - 1])
                        tmp_labels = tmp_labels[tuple(slc)]

                    if self.labels_resize != 1.0:
                        tmp_labels = torch.nn.Upsample(scale_factor=self.labels_resize, mode='nearest')(torch.from_numpy(tmp_labels)).numpy()

                    labels[pos:pos + tmp_labels.shape[0], ...] = tmp_labels

                pos += tmp_inputs.shape[0]

        # Trim in case files contributed fewer rows than the pre-allocation
        # assumed (e.g. per-file first-axis sizes differ).
        inputs = inputs[:pos, ...]
        if len(self.label_vars):
            labels = labels[:pos, ...]

        if self.data_scale != 1.0:
            inputs *= self.data_scale

        if self.data_gain != 0.0:
            # Depth-dependent gain ramp: 0 dB at the nearest sample up to
            # data_gain (dB/20) at the farthest.
            # FIX: dtype was passed positionally into linspace's `endpoint`
            # slot; pass it by keyword so the ramp is computed as float32.
            gain = 10.0 ** np.linspace(0, self.data_gain, inputs.shape[-1], dtype=np.single).reshape((1, 1, 1, -1))
            inputs *= gain

        # copy() ensures the trimmed view owns contiguous memory before the
        # zero-copy handoff to torch.
        inputs = torch.from_numpy(inputs.copy())
        if len(self.label_vars):
            labels = torch.from_numpy(labels)

        return inputs, labels
|
|