| import torch |
| import numpy as np |
class BalanceSampler(torch.utils.data.sampler.Sampler):
    """Sampler that rebalances classes by over-/under-sampling dataset indices.

    Element ``[2]`` of every dataset entry is read as that entry's label
    tensor; the tensors are stacked into ``self.labels`` and an entry counts
    as a member of class ``c`` when ``labels[entry, c] == 1``.
    (Presumably the labels are multi-hot vectors of equal length -- confirm
    against the dataset.)

    On each iteration every class that has at least one member contributes
    ``self.avg`` indices to a pool: drawn with replacement when the class has
    fewer than ``self.avg`` members (oversampling), without replacement
    otherwise (undersampling).  The epoch is then a random draw, without
    replacement, of at most ``num_samples`` entries from that pool.

    Args:
        data: Indexable dataset supporting ``len(data)`` and ``data[i]``,
            where ``data[i][2]`` is a tensor of per-class labels.
        num_samples: Upper bound on the number of indices yielded per epoch.
            Defaults to 6300, the value that used to be hard-coded.  ``None``
            means "use the whole balanced pool".

    Raises:
        ValueError: If ``num_samples`` is negative.
    """

    def __init__(self, data, num_samples=6300):
        if num_samples is not None and num_samples < 0:
            raise ValueError(
                f"num_samples must be non-negative or None, got {num_samples}"
            )
        self.data = data
        self.num_samples = num_samples

        # One row per dataset entry, one column per class.
        self.labels = torch.stack(
            [self.data[entry_idx][2] for entry_idx in range(len(self.data))]
        )
        # Per-class label totals.
        self.sums = self.labels.sum(dim=0)
        # torch.mean only accepts floating-point input, so cast first; this
        # keeps integer/bool label tensors from raising here.
        self.avg = int(self.sums.float().mean().item())

    def _epoch_size(self, pool_size):
        """Return how many indices one epoch yields for a pool of *pool_size*."""
        if self.num_samples is None:
            return pool_size
        return min(self.num_samples, pool_size)

    def __len__(self):
        """Return the number of indices a single pass over ``__iter__`` yields."""
        # Every class with at least one member contributes exactly `avg`
        # indices to the pool; classes with no members contribute nothing.
        populated_classes = int((self.labels == 1).any(dim=0).sum().item())
        return self._epoch_size(self.avg * populated_classes)

    def __iter__(self):
        """Yield a freshly drawn, class-balanced sequence of dataset indices."""
        minority_classes = torch.where(self.sums < self.avg)[0]
        majority_classes = torch.where(self.sums >= self.avg)[0]

        pool = []
        # Minority classes are drawn first, then majority classes, so the
        # sequence of RNG calls matches the historical ordering.
        for class_idx in torch.cat([minority_classes, majority_classes]).tolist():
            class_indices = (
                torch.where(self.labels[:, class_idx] == 1)[0].cpu().numpy()
            )
            if self.avg == 0 or class_indices.size == 0:
                # np.random.choice raises on an empty population; a class
                # with no members simply has nothing to contribute.
                continue
            drawn = np.random.choice(
                class_indices,
                size=self.avg,
                # Replacement is only needed when the class cannot supply
                # `avg` distinct entries.
                replace=class_indices.size < self.avg,
            )
            pool.extend(drawn.tolist())

        epoch_size = self._epoch_size(len(pool))
        if epoch_size == 0:
            return iter([])

        # Shuffle and trim in one step; capping at the pool size keeps the
        # replace=False draw from raising on small datasets.
        chosen = np.random.choice(pool, size=epoch_size, replace=False)
        # Hand back plain Python ints rather than numpy scalars.
        return iter(chosen.tolist())

    def __getitem__(self, index):
        """Return the underlying dataset entry at *index*."""
        return self.data[index]