File size: 3,798 Bytes
92f0e98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import pandas as pd
import os
from pathlib import Path
from monai.data import Dataset, DataLoader
import numpy as np
from torch.utils.data import Subset
import pytorch_lightning as pl
from torchsampler import ImbalancedDatasetSampler
from torch.utils.data.dataloader import default_collate

from transforms import get_training_transforms, get_base_transforms

class VerseDataModule(pl.LightningDataModule):
  def __init__(self, hparams):
    super().__init__()
    self.save_hyperparameters(dict(hparams), logger=False)
    self.data_dir = Path(self.hparams.dataset_path)
    self.csv_path = self.data_dir / 'fxall_labels.csv'

    if "modelsgenesis" in hparams.transforms:
      self.image_dir = self.data_dir / 'raw'
    else:
      self.image_dir = self.data_dir / 'ct'

    if not os.path.exists(self.image_dir):
      # legacy support
      self.image_dir = self.data_dir
      self.csv_path = self.data_dir / 'slice_labels.csv'

    self.mask_dir = self.data_dir / 'seg'
    if hparams.mask != 'none' and not os.path.exists(self.mask_dir):
      raise RuntimeError("Configured to use masks, but 'seg' folder missing in dataset path")

    self.df = pd.read_csv(self.csv_path, index_col=0)

    # TODO temporary fix to check for non-existing files
    if "path" in self.df.columns:
      self.df = self.df[self.df["path"].apply(lambda p: os.path.exists(self.image_dir / p))]

    # FIXME slice_labels.csv provides the path in 'image', fxall_labels.csv in 'path'
    if "image" not in self.df.columns:
      self.df['image'] = self.df['path']

    self.transforms = {
      'training': get_training_transforms(hparams, self.image_dir, self.mask_dir),
      'validation': get_base_transforms(hparams, self.image_dir, self.mask_dir),
      'test': get_base_transforms(hparams, self.image_dir, self.mask_dir)
    }

    self.datasets = {}
    self.idxs = {}

  def setup(self, stage=None):
    # dropping samples without fracture grading
    # graded_idxs = ~self.df.fx.isna()
    # vertebrae_level_idxs = self.df.level_idx >= self.hparams.min_vertebrae_level
    # included_idxs = graded_idxs & vertebrae_level_idxs
    
    if stage == 'fit' or stage is None:
      phases = ['training', 'validation']
    else:
      phases = ['test']

    for split in phases:
      # get official verse partitions
      idxs = self.df[f'split_{self.hparams.fold}'] == split
      # idxs = np.where(included_idxs & idxs)[0]
      idxs = np.where(idxs)[0]
      self.idxs[split] = idxs
      self.datasets[split] = Dataset(
        self.df.iloc[idxs].to_dict('records'), 
        transform=self.transforms[split]
      )

  def get_label(self, data):
    train_df = self.df.iloc[self.idxs['training']]
    grading = train_df.fx_grading
    if self.hparams.task == 'detection':
        return train_df.fx
    elif self.hparams.task == 'grading':
        return grading
    elif self.hparams.task == 'simple_grading':
        if grading in [2,3]:
          return 1
        if grading>3:
          return grading-2
        else:
          return grading

  def train_dataloader(self):
    return DataLoader(
      self.datasets['training'], 
      batch_size=self.hparams.batch_size,
      sampler=ImbalancedDatasetSampler(
        num_samples=self.df.iloc[self.idxs['training']].fx.sum() * 2,
        dataset=self.datasets['training'],
        callback_get_label=self.get_label,
        ) if self.hparams.oversampling else None, 
      num_workers=2,
      shuffle=not self.hparams.oversampling
    )

  def val_dataloader(self):
    return DataLoader(
      self.datasets['validation'], 
      batch_size=self.hparams.batch_size, 
      num_workers=8,
    )

  def test_dataloader(self):
    return DataLoader(
      self.datasets['test'], 
      batch_size=self.hparams.batch_size, 
      num_workers=8,
    )