""" Defines a Loader class to load data from a file or file wildcard """ import argparse import h5py import torch import numpy as np import glob from typing import Tuple class Loader: """ Data loader class """ def __init__(self, **kwargs): parser = Loader.add_argparse_args() for action in parser._actions: if action.dest in kwargs: action.default = kwargs[action.dest] args = parser.parse_args([]) self.__dict__.update(vars(args)) if type(self.label_vars) is str: self.label_vars = [self.label_vars] @staticmethod def add_argparse_args(parent_parser=None): """ Add argeparse argument for the data loader """ parser = argparse.ArgumentParser( prog='Loader', usage=Loader.__doc__, parents=[parent_parser] if parent_parser is not None else [], add_help=False) parser.add_argument('--input_var', default='p_f5.0_o0', help='Variable name for the label data') parser.add_argument('--label_vars', nargs='*', default='c0', help='Variable name(s) for the label data') parser.add_argument('--inputs_crop', type=int, default=[0, 1, 32, 96, 42, 2090], nargs='*', help='Crop input data on load [layer_min layer_max x_min x_max y_min y_max]') parser.add_argument('--labels_crop', type=int, default=[322, 830, 60, 1076], nargs='*', help='Crop label data on load [x_min x_max y_min y_max]') parser.add_argument('--labels_resize', type=float, default=256.0 / 1016.0, help='scaling factor for labels image') parser.add_argument('--data_scale', type=float, default=1.0, help='Data scaling factor') parser.add_argument('--data_gain', type=float, default=1.8, help='Data gain factor in dB/20 at farthest point in data.') return parser def load_data(self, test_file_pattern: str, train_file_pattern: str = None) -> Tuple[torch.FloatTensor, torch.FloatTensor]: """Loads training/testing data from file(s) Arguments: test_file_pattern {str} -- testing dataset(s) pattern train_file_pattern {str} -- training dataset(s) pattern Returns: (test_inputs, test_labels, train_inputs, train_labels) -- None for values that are not loaded """ test_inputs, test_labels = self._load_data_files(test_file_pattern) train_inputs, train_labels = self._load_data_files(train_file_pattern) if train_file_pattern is not None and train_inputs is None: raise ValueError('Failed to load train set') if test_file_pattern is not None and test_inputs is None: raise ValueError('Failed to load train set') return test_inputs, test_labels, train_inputs, train_labels def _load_data_files(self, file_pattern: str) -> Tuple[torch.FloatTensor, torch.FloatTensor]: """ Perform actual data loading Args: file_pattern: file name pattern Returns: inputs and labels tensors """ inputs, labels = None, None if file_pattern is None: return inputs, labels files = glob.glob(file_pattern) if len(files) == 0: raise ValueError(f'{file_pattern=} comes up empty') # Load first file to get output dimensions with h5py.File(files[0], 'r') as f: if self.input_var not in f: raise ValueError(f'input data key not in file: {self.input_var=}') shape = list(f[self.input_var].shape) if self.inputs_crop is not None: for i in range(len(self.inputs_crop) // 2): shape[-i - 1] = self.inputs_crop[-i * 2 - 1] - self.inputs_crop[-i * 2 - 2] shape[0] *= len(files) inputs = np.empty(shape, np.single) if len(self.label_vars): if not all([v in f for v in self.label_vars]): raise ValueError(f'labels data key(s) not in file: {self.label_vars=}') shape = list(f[self.label_vars[0]].shape) shape[1] *= len(self.label_vars) if self.labels_crop is not None: for i in range(len(self.labels_crop) // 2): shape[-i - 1] = 
self.labels_crop[-i * 2 - 1] - self.labels_crop[-i * 2 - 2] shape[-1] = int(shape[-1] * self.labels_resize) shape[-2] = int(shape[-2] * self.labels_resize) shape[0] *= len(files) labels = np.empty(shape, np.single) # Load data from files pos = 0 for file in files: with h5py.File(files[0], 'r') as f: tmp_inputs = np.array(f[self.input_var]) if self.inputs_crop is not None: slc = [slice(None)] * 4 for i in range(len(self.inputs_crop) // 2): slc[-i - 1] = slice(self.inputs_crop[-i * 2 - 2], self.inputs_crop[-i * 2 - 1]) tmp_inputs = tmp_inputs[tuple(slc)] inputs[pos:pos + tmp_inputs.shape[0], ...] = tmp_inputs if len(self.label_vars): tmp_labels = [] for v in self.label_vars: tmp_labels.append(np.array(f[v])) tmp_labels = np.concatenate(tmp_labels, axis=1) if self.labels_crop is not None and self.labels_crop: slc = [slice(None)] * 4 for i in range(len(self.labels_crop) // 2): slc[-i - 1] = slice(self.labels_crop[-i * 2 - 2], self.labels_crop[-i * 2 - 1]) tmp_labels = tmp_labels[tuple(slc)] if self.labels_resize != 1.0: tmp_labels = torch.nn.Upsample(scale_factor=self.labels_resize, mode='nearest')(torch.from_numpy(tmp_labels)).numpy() labels[pos:pos + tmp_labels.shape[0], ...] = tmp_labels pos += tmp_inputs.shape[0] inputs = inputs[:pos, ...] if len(self.label_vars): labels = labels[:pos, ...] if self.data_scale != 1.0: inputs *= self.data_scale if self.data_gain != 0.0: gain = 10.0 ** np.linspace(0, self.data_gain, inputs.shape[-1], np.single).reshape((1, 1, 1, -1)) inputs *= gain # Required when inputs is non-continuous due to transpose # TODO: Could probably use a check on strides and do a conditional copy. inputs = torch.from_numpy(inputs.copy()) if len(self.label_vars): labels = torch.from_numpy(labels) return inputs, labels
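

# Minimal usage sketch. The file patterns and variable names below are
# hypothetical placeholders (assumptions, not part of the original module);
# substitute the HDF5 dataset names and paths your files actually contain.
if __name__ == '__main__':
    # Keyword arguments override the argparse defaults set in add_argparse_args
    loader = Loader(input_var='p_f5.0_o0', label_vars=['c0'])
    test_inputs, test_labels, train_inputs, train_labels = loader.load_data(
        'data/test_*.h5', 'data/train_*.h5')
    print('test inputs:', tuple(test_inputs.shape))
    if test_labels is not None:
        print('test labels:', tuple(test_labels.shape))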