Spaces:
Running
Running
import collections | |
import random | |
from typing import Callable | |
from torchdata.datapipes.iter import IterDataPipe | |
def get_second_entry(sample):
    """Return the element at index 1 of ``sample``.

    This is the default ``label_getter`` of ``UnderSamplerIterDataPipe``,
    i.e. by default the label is taken from position 1 of each sample.

    Parameters:
        sample: Any indexable object with at least two entries.

    Returns:
        ``sample[1]``.
    """
    return sample[1]
class UnderSamplerIterDataPipe(IterDataPipe):
    """Dataset wrapper for under-sampling.

    Copied from: https://github.com/MaxHalford/pytorch-resample/blob/master/pytorch_resample/under.py  # noqa
    Modified to work with multiple labels.

    MIT License

    Copyright (c) 2020 Max Halford

    This method is based on rejection sampling.

    Parameters:
        dataset: The iterable of samples to under-sample.
        desired_dist: The desired class distribution.
            The keys are the classes whilst the
            values are the desired class percentages.
            The values are normalised so that they sum
            to 1.
        label_getter: A function that takes a sample and returns its label.
        seed: Random seed for reproducibility.

    Attributes:
        actual_dist: The counts of the observed sample labels.
        rng: A random number generator instance.

    References:
        - https://www.wikiwand.com/en/Rejection_sampling
    """

    def __init__(
        self,
        dataset: IterDataPipe,
        desired_dist: dict,
        label_getter: Callable = get_second_entry,
        seed: int = None,
    ):
        super().__init__()
        self.dataset = dataset
        # Sum once instead of once per class inside the comprehension.
        total = sum(desired_dist.values())
        self.desired_dist = {c: p / total for c, p in desired_dist.items()}
        self.label_getter = label_getter
        self.seed = seed
        self.actual_dist = collections.Counter()
        self.rng = random.Random(seed)
        # Class that currently has the largest desired/observed ratio.
        self._pivot = None

    def __iter__(self):
        """Yield the samples of ``dataset`` that survive rejection sampling.

        Samples of the current pivot class are always yielded; any other
        sample is yielded with probability ``f[y] / (M * g[y])`` where ``f``
        is the desired distribution, ``g`` the observed label counts and
        ``M`` the pivot's ``f / g`` ratio.

        ``actual_dist`` and the pivot are not reset here, so they carry over
        when the pipe is iterated more than once.

        Raises:
            KeyError: If an observed label is not a key of ``desired_dist``
                (it is looked up whenever the pivot is recomputed).
        """
        desired = self.desired_dist
        observed = self.actual_dist

        for dp in self.dataset:
            label = self.label_getter(dp)
            observed[label] += 1

            # Samples of the pivot class are kept unconditionally.
            if label == self._pivot:
                yield dp
                continue

            # A non-pivot label arrived: re-evaluate which class is the pivot.
            self._pivot = max(
                observed, key=lambda c: desired[c] / observed[c]
            )
            pivot_rate = desired[self._pivot] / observed[self._pivot]

            if pivot_rate == 0:
                # Every class seen so far has a desired share of 0, so the
                # acceptance ratio below would divide by zero. Drop the
                # sample and forget the pivot so that the unconditional-keep
                # branch cannot fire for a class that must never be emitted.
                self._pivot = None
                continue

            # Acceptance probability of the observed label relative to the
            # pivot.
            ratio = desired[label] / (pivot_rate * observed[label])

            # NOTE(review): a sample whose class has just become the pivot
            # gets ratio == 1 (up to float rounding) and is rejected by the
            # ``ratio < 1`` test. Kept as is so seeded runs stay unchanged;
            # confirm this is intended.
            if ratio < 1 and self.rng.random() < ratio:
                yield dp

    @classmethod
    def expected_size(cls, n, desired_dist, actual_dist):
        """Return ``int(n / M)``, with ``M`` the largest desired/actual ratio.

        Presumably the expected number of samples kept out of ``n`` when both
        mappings hold class proportions on the same scale; verify against
        callers.

        Parameters:
            n: Number of input samples.
            desired_dist: Mapping from class to desired share.
            actual_dist: Mapping from class to actual share.

        Returns:
            ``n / M`` truncated to an ``int``.

        Raises:
            TypeError: If a class is missing from either mapping, because
                ``.get`` then returns ``None``.
            ZeroDivisionError: If an actual share is 0, or if ``M`` is 0.
        """
        classes = set(desired_dist) | set(actual_dist)
        M = max(desired_dist.get(k) / actual_dist.get(k) for k in classes)
        return int(n / M)