# medical imaging / ultrasound
# dl_us_sos_inversion / loader.py
# Source: laughingrice, "Upload 11 files", commit 6ce7d82
"""
Defines a Loader class to load data from a file or file wildcard
"""
import argparse
import h5py
import torch
import numpy as np
import glob
from typing import Tuple
class Loader:
    """
    Data loader class
    """

    def __init__(self, **kwargs):
        """Construct a Loader.

        Keyword arguments override the defaults declared in
        ``add_argparse_args`` (e.g. ``input_var``, ``label_vars``,
        ``data_scale``); unknown keywords are silently ignored.
        """
        parser = Loader.add_argparse_args()
        # Inject supplied kwargs as parser defaults, then parse an empty
        # argv so the result is exactly defaults + overrides.
        for action in parser._actions:
            if action.dest in kwargs:
                action.default = kwargs[action.dest]
        args = parser.parse_args([])
        self.__dict__.update(vars(args))
        # A single label variable may be given as a plain string;
        # normalize to a list so downstream code can always iterate.
        if type(self.label_vars) is str:
            self.label_vars = [self.label_vars]

    @staticmethod
    def add_argparse_args(parent_parser=None):
        """
        Add argparse arguments for the data loader

        Arguments:
            parent_parser -- optional parser whose arguments are inherited

        Returns:
            argparse.ArgumentParser carrying the loader's options
        """
        parser = argparse.ArgumentParser(
            prog='Loader',
            usage=Loader.__doc__,
            parents=[parent_parser] if parent_parser is not None else [],
            add_help=False)
        # BUGFIX: help text previously said 'label data' for --input_var.
        parser.add_argument('--input_var', default='p_f5.0_o0', help='Variable name for the input data')
        parser.add_argument('--label_vars', nargs='*', default='c0', help='Variable name(s) for the label data')
        parser.add_argument('--inputs_crop', type=int, default=[0, 1, 32, 96, 42, 2090], nargs='*',
                            help='Crop input data on load [layer_min layer_max x_min x_max y_min y_max]')
        parser.add_argument('--labels_crop', type=int, default=[322, 830, 60, 1076], nargs='*', help='Crop label data on load [x_min x_max y_min y_max]')
        parser.add_argument('--labels_resize', type=float, default=256.0 / 1016.0, help='scaling factor for labels image')
        parser.add_argument('--data_scale', type=float, default=1.0, help='Data scaling factor')
        parser.add_argument('--data_gain', type=float, default=1.8, help='Data gain factor in dB/20 at farthest point in data.')
        return parser

    def load_data(self, test_file_pattern: str, train_file_pattern: str = None) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Loads training/testing data from file(s)

        Arguments:
            test_file_pattern {str} -- testing dataset(s) pattern
            train_file_pattern {str} -- training dataset(s) pattern

        Returns:
            (test_inputs, test_labels, train_inputs, train_labels) -- None for values that are not loaded

        Raises:
            ValueError -- if a non-None pattern fails to produce data
        """
        test_inputs, test_labels = self._load_data_files(test_file_pattern)
        train_inputs, train_labels = self._load_data_files(train_file_pattern)

        if train_file_pattern is not None and train_inputs is None:
            raise ValueError('Failed to load train set')
        # BUGFIX: this branch previously reported 'train set' as well.
        if test_file_pattern is not None and test_inputs is None:
            raise ValueError('Failed to load test set')

        return test_inputs, test_labels, train_inputs, train_labels

    def _load_data_files(self, file_pattern: str) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """ Perform actual data loading

        Args:
            file_pattern: file name pattern (glob); None is allowed and
                yields (None, None)

        Returns:
            inputs and labels tensors (labels stays None when
            self.label_vars is empty)

        Raises:
            ValueError: if the pattern matches no files, or a required
                variable name is missing from the first matched file
        """
        inputs, labels = None, None

        if file_pattern is None:
            return inputs, labels

        files = glob.glob(file_pattern)
        if len(files) == 0:
            raise ValueError(f'{file_pattern=} comes up empty')

        # Load first file to get output dimensions.
        # NOTE(review): assumes every matched file shares the first file's
        # per-file shapes -- confirm against the dataset generator.
        with h5py.File(files[0], 'r') as f:
            if self.input_var not in f:
                raise ValueError(f'input data key not in file: {self.input_var=}')

            shape = list(f[self.input_var].shape)
            if self.inputs_crop is not None:
                # Crop bounds are (min, max) pairs applied to the trailing
                # axes; each pair shrinks the corresponding dimension.
                for i in range(len(self.inputs_crop) // 2):
                    shape[-i - 1] = self.inputs_crop[-i * 2 - 1] - self.inputs_crop[-i * 2 - 2]

            shape[0] *= len(files)
            inputs = np.empty(shape, np.single)

            if len(self.label_vars):
                if not all([v in f for v in self.label_vars]):
                    raise ValueError(f'labels data key(s) not in file: {self.label_vars=}')

                shape = list(f[self.label_vars[0]].shape)
                # Multiple label variables are concatenated along axis 1.
                shape[1] *= len(self.label_vars)
                if self.labels_crop is not None:
                    for i in range(len(self.labels_crop) // 2):
                        shape[-i - 1] = self.labels_crop[-i * 2 - 1] - self.labels_crop[-i * 2 - 2]
                shape[-1] = int(shape[-1] * self.labels_resize)
                shape[-2] = int(shape[-2] * self.labels_resize)

                shape[0] *= len(files)
                labels = np.empty(shape, np.single)

        # Load data from files
        pos = 0
        for file in files:
            # BUGFIX: previously opened files[0] on every iteration, so the
            # first file's contents were duplicated for every matched file.
            with h5py.File(file, 'r') as f:
                tmp_inputs = np.array(f[self.input_var])
                if self.inputs_crop is not None:
                    slc = [slice(None)] * 4
                    for i in range(len(self.inputs_crop) // 2):
                        slc[-i - 1] = slice(self.inputs_crop[-i * 2 - 2], self.inputs_crop[-i * 2 - 1])
                    tmp_inputs = tmp_inputs[tuple(slc)]
                inputs[pos:pos + tmp_inputs.shape[0], ...] = tmp_inputs

                if len(self.label_vars):
                    tmp_labels = []
                    for v in self.label_vars:
                        tmp_labels.append(np.array(f[v]))
                    tmp_labels = np.concatenate(tmp_labels, axis=1)

                    # None is falsy, so a single truthiness test suffices.
                    if self.labels_crop:
                        slc = [slice(None)] * 4
                        for i in range(len(self.labels_crop) // 2):
                            slc[-i - 1] = slice(self.labels_crop[-i * 2 - 2], self.labels_crop[-i * 2 - 1])
                        tmp_labels = tmp_labels[tuple(slc)]

                    if self.labels_resize != 1.0:
                        tmp_labels = torch.nn.Upsample(scale_factor=self.labels_resize, mode='nearest')(torch.from_numpy(tmp_labels)).numpy()

                    labels[pos:pos + tmp_labels.shape[0], ...] = tmp_labels

                pos += tmp_inputs.shape[0]

        # Trim in case files held fewer samples than were pre-allocated.
        inputs = inputs[:pos, ...]
        if len(self.label_vars):
            labels = labels[:pos, ...]

        if self.data_scale != 1.0:
            inputs *= self.data_scale

        if self.data_gain != 0.0:
            # Depth-dependent gain ramp: data_gain is dB/20 at the deepest
            # (last-axis) sample.
            # BUGFIX: np.single was previously passed as np.linspace's
            # fourth positional argument, which is `endpoint`, not `dtype`;
            # pass it by keyword so the ramp is actually float32.
            # NOTE(review): the reshape assumes 4-D inputs with depth as
            # the last axis -- confirm against the input pipeline.
            gain = 10.0 ** np.linspace(0.0, self.data_gain, inputs.shape[-1], dtype=np.single).reshape((1, 1, 1, -1))
            inputs *= gain

        # Required when inputs is non-continuous due to transpose
        # TODO: Could probably use a check on strides and do a conditional copy.
        inputs = torch.from_numpy(inputs.copy())
        if len(self.label_vars):
            labels = torch.from_numpy(labels)

        return inputs, labels