import lightning as L
import numpy as np
import torch
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision.transforms import v2 as T

from src.dataset import DRDataset


class DRDataModule(L.LightningDataModule):
    def __init__(
        self,
        train_csv_path,
        val_csv_path,
        image_size: int = 224,
        batch_size: int = 8,
        num_workers: int = 4,
        use_class_weighting: bool = False,
        use_weighted_sampler: bool = False,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Ensure mutual exclusivity between use_class_weighting and use_weighted_sampler
        if use_class_weighting and use_weighted_sampler:
            raise ValueError(
                "use_class_weighting and use_weighted_sampler cannot both be True"
            )

        self.train_csv_path = train_csv_path
        self.val_csv_path = val_csv_path
        self.use_class_weighting = use_class_weighting
        self.use_weighted_sampler = use_weighted_sampler

        # Define the transformations
        self.train_transform = T.Compose(
            [
                T.Resize((image_size, image_size), antialias=True),
                T.RandomAffine(degrees=10, translate=(0.01, 0.01), scale=(0.99, 1.01)),
                T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.01),
                T.RandomHorizontalFlip(p=0.5),
                T.ToDtype(torch.float32, scale=True),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

        self.val_transform = T.Compose(
            [
                T.Resize((image_size, image_size), antialias=True),
                T.ToDtype(torch.float32, scale=True),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

    def setup(self, stage=None):
        """Set up datasets for training and validation."""
        # Initialize datasets with specified transformations
        self.train_dataset = DRDataset(
            self.train_csv_path, transform=self.train_transform
        )
        self.val_dataset = DRDataset(self.val_csv_path, transform=self.val_transform)

        # Compute the number of classes and, if requested, balanced class weights for the loss
        labels = self.train_dataset.labels.numpy()
        self.num_classes = len(np.unique(labels))
        self.class_weights = (
            self._compute_class_weights(labels) if self.use_class_weighting else None
        )

    def train_dataloader(self):
        """Returns a DataLoader for training data."""
        if self.use_weighted_sampler:
            sampler = self._get_weighted_sampler(self.train_dataset.labels.numpy())
            shuffle = False  # Sampler will handle shuffling
        else:
            sampler = None
            shuffle = True

        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            sampler=sampler,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        """Returns a DataLoader for validation data."""
        return DataLoader(
            self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers
        )

    def _compute_class_weights(self, labels):
        """Returns balanced per-class weights, e.g. for weighting the loss function."""
        class_weights = compute_class_weight(
            class_weight="balanced", classes=np.unique(labels), y=labels
        )
        return torch.tensor(class_weights, dtype=torch.float32)

    def _get_weighted_sampler(self, labels: np.ndarray) -> WeightedRandomSampler:
        """Returns a WeightedRandomSampler based on class weights.

        The weights tensor should contain a weight for each sample, not the class weights.
        Have a look at this post for an example: https://discuss.pytorch.org/t/how-to-handle-imbalanced-classes/11264/2
        https://www.maskaravivek.com/post/pytorch-weighted-random-sampler/
        """

        class_sample_count = np.array(
            [len(np.where(labels == label)[0]) for label in np.unique(labels)]
        )
        weight = 1.0 / class_sample_count
        samples_weight = np.array([weight[label] for label in labels])
        samples_weight = torch.from_numpy(samples_weight)

        return WeightedRandomSampler(
            weights=samples_weight, num_samples=len(labels), replacement=True
        )
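

# Minimal usage sketch (assumptions: the CSV paths below are placeholders and
# DRDataset yields (image, label) pairs with integer labels starting at 0).
if __name__ == "__main__":
    dm = DRDataModule(
        train_csv_path="data/train.csv",  # placeholder path
        val_csv_path="data/val.csv",  # placeholder path
        image_size=224,
        batch_size=8,
        use_class_weighting=True,  # mutually exclusive with use_weighted_sampler
    )
    dm.setup()

    # The balanced weights are meant to be consumed by the model's loss,
    # e.g. torch.nn.CrossEntropyLoss(weight=dm.class_weights).
    print("num_classes:", dm.num_classes)
    print("class_weights:", dm.class_weights)

    # Pull one training batch to sanity-check shapes
    images, targets = next(iter(dm.train_dataloader()))
    print("batch:", images.shape, targets.shape)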