import pickle
from typing import Dict, Union

import numpy as np
import pandas as pd
import torch
import torchvision.transforms as transforms
from PIL import Image
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler

from .logger import BaseLogger


logger = BaseLogger.get_logger(__name__)

class PrivateAugment(torch.nn.Module):
    """
    Augmentations defined privately.
    Additional augmentations can be added to this class as needed.
    """
    xray_augs_list = [
        transforms.RandomAffine(degrees=(-3, 3), translate=(0.02, 0.02)),
        transforms.RandomAdjustSharpness(sharpness_factor=2),
        transforms.RandomAutocontrast()
        ]
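# A minimal usage sketch (not part of the pipeline): the list above is intended
# to be wrapped in transforms.Compose before being applied to a PIL image, as
# _make_augmentations below does. The image path here is hypothetical.
#
#   xray_augs = transforms.Compose(PrivateAugment.xray_augs_list)
#   augmented = xray_augs(Image.open('sample_xray.png').convert('L'))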
class InputDataMixin:
    """
    Class to normalize input data.
    """
    def _make_scaler(self) -> MinMaxScaler:
        """
        Make a scaler to normalize input data by min-max normalization with train data.

        Returns:
            MinMaxScaler: scaler fitted on the train split
        """
        scaler = MinMaxScaler()
        _df_train = self.df_source[self.df_source['split'] == 'train']
        _ = scaler.fit(_df_train[self.input_list])
        return scaler
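    # For reference: MinMaxScaler maps each input column to [0, 1] using the
    # statistics of the train split only, i.e.
    # x_scaled = (x - min_train) / (max_train - min_train),
    # so validation and test data are scaled with the same bounds as train data.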
    def save_scaler(self, save_path: str) -> None:
        """
        Save scaler.

        Args:
            save_path (str): path for saving scaler
        """
        with open(save_path, 'wb') as f:
            pickle.dump(self.scaler, f)

    def load_scaler(self, scaler_path: str) -> MinMaxScaler:
        """
        Load scaler.

        Args:
            scaler_path (str): path to scaler

        Returns:
            MinMaxScaler: loaded scaler
        """
        with open(scaler_path, 'rb') as f:
            scaler = pickle.load(f)
        return scaler
    def _normalize_inputs(self, df_inputs: pd.DataFrame) -> torch.FloatTensor:
        """
        Normalize inputs.

        Args:
            df_inputs (pd.DataFrame): DataFrame of inputs

        Returns:
            torch.FloatTensor: normalized inputs

        Note:
            After iloc[[idx], index_input_list], a pd.DataFrame is obtained,
            which fits the input type of self.scaler.transform.
            However, after normalizing, the shape of inputs_value is (1, N),
            where N is the number of input values. Since the shape (1, N) is
            not acceptable when forwarding, converting (1, N) -> (N,) is needed.
        """
        inputs_value = self.scaler.transform(df_inputs).reshape(-1)
        inputs_value = np.array(inputs_value, dtype=np.float32)
        inputs_value = torch.from_numpy(inputs_value).clone()
        return inputs_value
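    # Worked example (for illustration): with two inputs and fitted bounds
    # min=[0, 10], max=[100, 20], a row [50, 15] is scaled to [0.5, 0.5]
    # and returned as a tensor of shape (2,).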
    def _load_input_value_if_mlp(self, idx: int) -> Union[torch.FloatTensor, str]:
        """
        Load input values after converting them into a tensor if MLP is used.

        Args:
            idx (int): index

        Returns:
            Union[torch.FloatTensor, str]: tensor of input values, or empty string
        """
        inputs_value = ''

        if self.params.mlp is None:
            return inputs_value

        index_input_list = [self.col_index_dict[input_name] for input_name in self.input_list]
        _df_inputs = self.df_split.iloc[[idx], index_input_list]
        inputs_value = self._normalize_inputs(_df_inputs)
        return inputs_value
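    # Note: '' is returned (rather than None) when MLP inputs are not used,
    # presumably because DataLoader's default collate function can batch
    # strings but raises on None values.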
class ImageMixin:
    """
    Class to normalize and transform images.
    """
    def _make_augmentations(self) -> transforms.Compose:
        """
        Define which augmentation is applied.

        When training, augmentation is needed for train data only.
        At test time, no augmentation is applied.

        Returns:
            transforms.Compose: composed augmentations
        """
        _augmentation = []
        if (self.params.isTrain) and (self.split == 'train'):
            if self.params.augmentation == 'xrayaug':
                _augmentation = PrivateAugment.xray_augs_list
            elif self.params.augmentation == 'trivialaugwide':
                _augmentation.append(transforms.TrivialAugmentWide())
            elif self.params.augmentation == 'randaug':
                _augmentation.append(transforms.RandAugment())
            # Any other value means no augmentation.

        _augmentation = transforms.Compose(_augmentation)
        return _augmentation
    def _make_transforms(self) -> transforms.Compose:
        """
        Make transforms for image normalization.

        Returns:
            transforms.Compose: composed transforms
        """
        _transforms = []
        _transforms.append(transforms.ToTensor())

        if self.params.normalize_image == 'yes':
            if self.params.in_channel == 1:
                _transforms.append(transforms.Normalize(mean=(0.5, ), std=(0.5, )))
            else:
                _transforms.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))

        _transforms = transforms.Compose(_transforms)
        return _transforms
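    # For a 3-channel input with normalize_image == 'yes', the resulting
    # pipeline is equivalent to:
    #   transforms.Compose([
    #       transforms.ToTensor(),
    #       transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    #       ])
    # where the mean/std are the standard ImageNet statistics.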
    def _open_image_in_channel(self, imgpath: str, in_channel: int) -> Image.Image:
        """
        Open an image with the specified number of channels.

        Args:
            imgpath (str): path to image
            in_channel (int): number of channels, 1 or 3

        Returns:
            Image.Image: PIL image
        """
        if in_channel == 1:
            image = Image.open(imgpath).convert('L')    # grayscale
            return image
        else:
            image = Image.open(imgpath).convert('RGB')  # color
            return image
    def _load_image_if_cnn(self, idx: int) -> Union[torch.Tensor, str]:
        """
        Load an image and convert it to a tensor if a CNN or ViT is used.

        Args:
            idx (int): index

        Returns:
            Union[torch.Tensor, str]: tensor converted from image, or empty string
        """
        image = ''

        if self.params.net is None:
            return image

        imgpath = self.df_split.iat[idx, self.col_index_dict['imgpath']]
        image = self._open_image_in_channel(imgpath, self.params.in_channel)
        image = self.augmentation(image)
        image = self.transform(image)
        return image
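    # Order matters here: self.augmentation operates on the PIL image first,
    # and self.transform then converts it to a (optionally normalized) tensor,
    # since the augmentations used in this module accept PIL images and
    # ToTensor expects a PIL image or ndarray.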
class DeepSurvMixin:
    """
    Class to handle data required for DeepSurv.
    """
    def _load_periods_if_deepsurv(self, idx: int) -> Union[torch.FloatTensor, str]:
        """
        Return period if the task is deepsurv.

        Args:
            idx (int): index

        Returns:
            Union[torch.FloatTensor, str]: period, or empty string
        """
        periods = ''

        if self.params.task != 'deepsurv':
            return periods

        assert (len(self.label_list) == 1), 'DeepSurv cannot work with multi-label.'
        periods = self.df_split.iat[idx, self.col_index_dict[self.period_name]]
        periods = np.array(periods, dtype=np.float32)
        periods = torch.from_numpy(periods).clone()
        return periods
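    # Note: np.array(scalar, dtype=np.float32) yields a 0-dim array, so the
    # returned value is a 0-dim tensor; DataLoader's default collate stacks
    # these into a tensor of shape (batch_size,).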
class DataSetWidget(InputDataMixin, ImageMixin, DeepSurvMixin):
    """
    Widget class to inherit the mixins above simultaneously.
    """
    pass

class LoadDataSet(Dataset, DataSetWidget):
    """
    Dataset for split.
    """
    def __init__(
            self,
            params,
            split: str
            ) -> None:
        """
        Args:
            params (ParamSet): parameter for model
            split (str): split
        """
        self.params = params
        self.df_source = self.params.df_source
        self.split = split

        self.input_list = self.params.input_list
        self.label_list = self.params.label_list

        if self.params.task == 'deepsurv':
            self.period_name = self.params.period_name

        self.df_split = self.df_source[self.df_source['split'] == self.split]
        self.col_index_dict = {col_name: self.df_split.columns.get_loc(col_name) for col_name in self.df_split.columns}

        if self.params.mlp is not None:
            assert (self.input_list != []), 'input list is empty.'
            if params.isTrain:
                self.scaler = self._make_scaler()
            else:
                self.scaler = self.load_scaler(self.params.scaler_path)

        if self.params.net is not None:
            self.augmentation = self._make_augmentations()
            self.transform = self._make_transforms()

    def __len__(self) -> int:
        """
        Return the length of the DataFrame.

        Returns:
            int: length of DataFrame
        """
        return len(self.df_split)

    def _load_label(self, idx: int) -> Dict[str, Union[int, float]]:
        """
        Return labels.
        If the CSV of an external dataset has no label columns,
        an empty dictionary is returned.

        Args:
            idx (int): index

        Returns:
            Dict[str, Union[int, float]]: dictionary of label name and its value
        """
        label_list_in_split = list(self.df_split.columns[self.df_split.columns.str.startswith('label')])
        label_dict = dict()
        if label_list_in_split != []:
            for label_name in self.label_list:
                label_dict[label_name] = self.df_split.iat[idx, self.col_index_dict[label_name]]
        return label_dict

    def __getitem__(self, idx: int) -> Dict:
        """
        Return the data row specified by index.

        Args:
            idx (int): index

        Returns:
            Dict: dictionary of data to be passed to the model
        """
        uniqID = self.df_split.iat[idx, self.col_index_dict['uniqID']]
        group = self.df_split.iat[idx, self.col_index_dict['group']]
        imgpath = self.df_split.iat[idx, self.col_index_dict['imgpath']]
        split = self.df_split.iat[idx, self.col_index_dict['split']]
        inputs_value = self._load_input_value_if_mlp(idx)
        image = self._load_image_if_cnn(idx)
        label_dict = self._load_label(idx)
        periods = self._load_periods_if_deepsurv(idx)

        _data = {
            'uniqID': uniqID,
            'group': group,
            'imgpath': imgpath,
            'split': split,
            'inputs': inputs_value,
            'image': image,
            'labels': label_dict,
            'periods': periods
            }
        return _data
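    # For illustration (values hypothetical): a returned item might look like
    #   {'uniqID': 'case_001', 'group': 'A', 'imgpath': 'path/to/img.png',
    #    'split': 'train', 'inputs': tensor of shape (N,) or '',
    #    'image': tensor of shape (C, H, W) or '',
    #    'labels': {'label_outcome': 1}, 'periods': 0-dim tensor or ''}
    # where the concrete keys of 'labels' depend on the CSV columns.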
def _make_sampler(split_data: LoadDataSet) -> WeightedRandomSampler:
    """
    Make a sampler that draws samples with probability inversely
    proportional to class frequency.

    Args:
        split_data (LoadDataSet): dataset

    Returns:
        WeightedRandomSampler: sampler
    """
    _target = []
    for data in split_data:
        _target.append(list(data['labels'].values())[0])

    _target = np.array(_target)  # as an array, so the element-wise comparison below works
    class_sample_count = np.array([len(np.where(_target == t)[0]) for t in np.unique(_target)])
    weight = 1. / class_sample_count
    samples_weight = np.array([weight[t] for t in _target])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    return sampler
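# Worked example (for illustration): if the single label column over the split
# is [0, 0, 0, 1], then class_sample_count is [3, 1], weight is [1/3, 1], and
# samples_weight is [1/3, 1/3, 1/3, 1], so the minority class is drawn as often
# as the majority class in expectation. Note that weight[t] indexes by label
# value, so labels are assumed to be consecutive integers starting at 0.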
def create_dataloader(
        params,
        split: str = None
        ) -> DataLoader:
    """
    Create data loader for split.

    Args:
        params (ParamSet): parameter for dataloader
        split (str): split. Defaults to None.

    Returns:
        DataLoader: data loader
    """
    split_data = LoadDataSet(params, split)

    if params.isTrain:
        batch_size = params.batch_size
        shuffle = True
    else:
        batch_size = params.test_batch_size
        shuffle = False

    if params.sampler == 'yes':
        assert ((params.task == 'classification') or (params.task == 'deepsurv')), 'Cannot make sampler in regression.'
        assert (len(params.label_list) == 1), 'Cannot make sampler for multi-label.'
        shuffle = False  # shuffle and sampler are mutually exclusive in DataLoader
        sampler = _make_sampler(split_data)
    else:
        sampler = None

    split_loader = DataLoader(
        dataset=split_data,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=0,
        sampler=sampler
        )
    return split_loader
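# A minimal usage sketch, assuming a ParamSet-like object configured elsewhere
# in this package (the construction shown is hypothetical):
#
#   params = ParamSet(...)           # e.g. params.isTrain=True, params.sampler='no'
#   train_loader = create_dataloader(params, split='train')
#   for batch in train_loader:
#       images = batch['image']      # (batch_size, C, H, W) when a CNN/ViT is used
#       labels = batch['labels']     # dict of label name -> batched values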