Paul Engstler
Initial commit
92f0e98
raw
history blame
No virus
3.8 kB
import pandas as pd
import os
from pathlib import Path
from monai.data import Dataset, DataLoader
import numpy as np
from torch.utils.data import Subset
import pytorch_lightning as pl
from torchsampler import ImbalancedDatasetSampler
from torch.utils.data.dataloader import default_collate
from transforms import get_training_transforms, get_base_transforms
class VerseDataModule(pl.LightningDataModule):
    """Lightning data module serving vertebra samples described by a label CSV.

    The module reads a label table from ``hparams.dataset_path``, resolves the
    image and mask directories, and builds one MONAI ``Dataset`` per phase
    (``'training'``, ``'validation'``, ``'test'``) from the rows whose
    ``split_<fold>`` column equals the phase name.

    Hyperparameters read by this class: ``dataset_path``, ``transforms``,
    ``mask``, ``fold``, ``task``, ``batch_size`` and ``oversampling``.
    ``hparams`` must support both ``dict(hparams)`` and attribute access.
    """

    def __init__(self, hparams):
        """Resolve dataset paths, load the label table and build transforms.

        Args:
            hparams: Mapping-like configuration object (see class docstring).

        Raises:
            RuntimeError: If masks are requested (``hparams.mask != 'none'``)
                but the ``seg`` folder does not exist in the dataset path.
        """
        super().__init__()
        self.save_hyperparameters(dict(hparams), logger=False)

        self.data_dir = Path(self.hparams.dataset_path)
        self.csv_path = self.data_dir / 'fxall_labels.csv'

        # The image sub-folder depends on the configured transform pipeline.
        if "modelsgenesis" in hparams.transforms:
            self.image_dir = self.data_dir / 'raw'
        else:
            self.image_dir = self.data_dir / 'ct'

        if not os.path.exists(self.image_dir):
            # Legacy support: images live directly in the dataset root and
            # the labels are stored in 'slice_labels.csv'.
            self.image_dir = self.data_dir
            self.csv_path = self.data_dir / 'slice_labels.csv'

        self.mask_dir = self.data_dir / 'seg'
        if hparams.mask != 'none' and not os.path.exists(self.mask_dir):
            raise RuntimeError("Configured to use masks, but 'seg' folder missing in dataset path")

        self.df = pd.read_csv(self.csv_path, index_col=0)

        # TODO temporary fix to check for non-existing files
        if "path" in self.df.columns:
            file_exists = self.df["path"].apply(
                lambda p: os.path.exists(self.image_dir / p)
            )
            # Copy so that the column assignment below operates on an
            # independent frame instead of a slice of the loaded table.
            self.df = self.df[file_exists].copy()

        # FIXME slice_labels.csv provides the path in 'image', fxall_labels.csv in 'path'
        if "image" not in self.df.columns:
            self.df['image'] = self.df['path']

        self.transforms = {
            'training': get_training_transforms(hparams, self.image_dir, self.mask_dir),
            'validation': get_base_transforms(hparams, self.image_dir, self.mask_dir),
            'test': get_base_transforms(hparams, self.image_dir, self.mask_dir),
        }

        # Both dictionaries are keyed by phase name and filled in ``setup``.
        self.datasets = {}
        self.idxs = {}

    def setup(self, stage=None):
        """Create the datasets required for the given stage.

        Args:
            stage: ``'fit'`` or ``None`` builds the training and validation
                datasets, ``'validate'`` builds only the validation dataset,
                and any other value builds the test dataset.
        """
        # NOTE: filtering by fracture grading / vertebra level is currently
        # disabled; the former logic is kept here for reference:
        # graded_idxs = ~self.df.fx.isna()
        # vertebrae_level_idxs = self.df.level_idx >= self.hparams.min_vertebrae_level
        # included_idxs = graded_idxs & vertebrae_level_idxs
        if stage == 'fit' or stage is None:
            phases = ['training', 'validation']
        elif stage == 'validate':
            # Without this branch a validate-only run would build just the
            # test dataset and ``val_dataloader`` would fail with a KeyError.
            phases = ['validation']
        else:
            phases = ['test']

        for split in phases:
            # get official verse partitions
            in_split = self.df[f'split_{self.hparams.fold}'] == split
            # Positional row indices, usable with ``DataFrame.iloc``.
            idxs = np.where(in_split)[0]
            self.idxs[split] = idxs
            self.datasets[split] = Dataset(
                self.df.iloc[idxs].to_dict('records'),
                transform=self.transforms[split],
            )

    def get_label(self, data):
        """Return the labels of all training samples for the sampler.

        Used as ``callback_get_label`` of ``ImbalancedDatasetSampler``.

        Args:
            data: Dataset handed over by the sampler. It is not used; the
                labels are taken from the training rows of the label table.

        Returns:
            A pandas Series with one label per training row: ``fx`` for the
            ``'detection'`` task, ``fx_grading`` for ``'grading'``, and a
            remapped ``fx_grading`` for ``'simple_grading'``.

        Raises:
            ValueError: If ``hparams.task`` is none of the supported tasks.
        """
        train_df = self.df.iloc[self.idxs['training']]
        task = self.hparams.task

        if task == 'detection':
            return train_df.fx

        grading = train_df.fx_grading
        if task == 'grading':
            return grading

        if task == 'simple_grading':
            # ``grading`` is a Series, so the mapping is applied element-wise:
            # grades 2 and 3 collapse to 1, grades above 3 shift down by 2,
            # and everything else (including missing values) stays as it is.
            # Both conditions are evaluated on the original grades.
            merged = grading.mask(grading.isin([2, 3]), 1)
            return merged.mask(grading > 3, grading - 2)

        raise ValueError(f"Unsupported task for label lookup: {task!r}")

    def train_dataloader(self):
        """Return the training loader, optionally with class-balanced sampling.

        With ``hparams.oversampling`` enabled, an ``ImbalancedDatasetSampler``
        draws twice the sum of the training ``fx`` column per epoch and
        shuffling is turned off, since the sampler decides the order.
        Otherwise the data is shuffled and no sampler is used.
        """
        train_set = self.datasets['training']
        oversample = self.hparams.oversampling

        sampler = None
        if oversample:
            fx_total = self.df.iloc[self.idxs['training']].fx.sum()
            sampler = ImbalancedDatasetSampler(
                # The sum is a numpy scalar (a float if 'fx' is float-typed);
                # the sampler needs a plain integer sample count.
                num_samples=int(fx_total * 2),
                dataset=train_set,
                callback_get_label=self.get_label,
            )

        return DataLoader(
            train_set,
            batch_size=self.hparams.batch_size,
            sampler=sampler,
            num_workers=2,
            shuffle=not oversample,
        )

    def val_dataloader(self):
        """Return the loader over the validation dataset (no shuffling)."""
        return DataLoader(
            self.datasets['validation'],
            batch_size=self.hparams.batch_size,
            num_workers=8,
        )

    def test_dataloader(self):
        """Return the loader over the test dataset (no shuffling)."""
        return DataLoader(
            self.datasets['test'],
            batch_size=self.hparams.batch_size,
            num_workers=8,
        )