| | import torch |
| | import numpy as np |
| | class BalanceSampler(torch.utils.data.sampler.Sampler): |
| | def __init__ (self, data): |
| | self.data = data |
| |
|
| | self.labels = torch.stack([self.data[entry_idx][2] for entry_idx in range(len(self.data))]) |
| | self.sums = self.labels.sum(dim=0) |
| | self.avg = int(torch.mean(self.sums).item()) |
| |
|
| |
|
| | def __len__(self): |
| | return len(self.data) |
| |
|
| | def __iter__(self): |
| | training = [] |
| | minority_classes = torch.where(self.sums < self.avg)[0] |
| | majority_classes = torch.where(self.sums >= self.avg)[0] |
| |
|
| | for class_idx in minority_classes: |
| | class_indices = torch.where(self.labels[:, class_idx] == 1)[0] |
| | oversampled_indices = np.random.choice(class_indices, size=self.avg, replace=True) |
| | training.extend(oversampled_indices.tolist()) |
| |
|
| | |
| | for class_idx in majority_classes: |
| | class_indices = torch.where(self.labels[:, class_idx] == 1)[0] |
| | undersampled_indices = np.random.choice(class_indices, size=self.avg, replace=False) |
| | training.extend(undersampled_indices.tolist()) |
| | training=np.random.choice(training, size=6300, replace=False) |
| |
|
| |
|
| | return iter(training) |
| |
|
| | def __getitem__(self, index): |
| | return self.data[index] |
| |
|