|
import logging |
|
import random |
|
from typing import * |
|
|
|
from allennlp.data.samplers.batch_sampler import BatchSampler |
|
from allennlp.data.samplers.max_tokens_batch_sampler import MaxTokensBatchSampler |
|
from torch.utils import data |
|
|
|
# Module-level logger. It is registered under the fixed name 'mix_sampler'
# (not ``__name__``), so logging configuration must refer to that name.
logger = logging.getLogger('mix_sampler')
|
|
|
|
|
@BatchSampler.register('mix_sampler')
class MixSampler(MaxTokensBatchSampler):
    """Max-tokens batch sampler that randomly down-samples instances by type.

    Each instance's type is read from the ``'type'`` key of its ``'meta'``
    field's metadata; it falls back to ``'default'`` when the instance has no
    ``'meta'`` field or the metadata has no ``'type'`` key. On every pass over
    the data, an instance is dropped when a fresh ``random.random()`` draw
    exceeds the ratio configured for its type, so a ratio in ``[0, 1]`` acts
    as a keep-probability. Types without a configured ratio use ``1.0`` and
    are therefore never dropped. Ratios above ``1.0`` do not up-sample.
    """

    def __init__(
        self,
        max_tokens: int,
        sorting_keys: Optional[List[str]] = None,
        padding_noise: float = 0.1,
        sampling_ratios: Optional[Dict[str, float]] = None,
    ):
        """Create the sampler.

        Args:
            max_tokens: Forwarded unchanged to the parent sampler.
            sorting_keys: Forwarded unchanged to the parent sampler.
            padding_noise: Forwarded unchanged to the parent sampler.
            sampling_ratios: Maps an instance type to the ratio used as its
                keep threshold. ``None`` (or an empty mapping) disables
                down-sampling entirely.
        """
        super().__init__(max_tokens, sorting_keys, padding_noise)

        self.sampling_ratios = sampling_ratios or dict()

    def __iter__(self) -> Iterator[List[int]]:
        """Yield batches of instance indices after per-type down-sampling.

        Down-sampling is redone on every call, so different passes may see
        different subsets of the data. The batches of one pass are yielded
        in shuffled order.
        """
        # NOTE(review): ``self.data_source`` is not assigned in this class; it
        # is assumed to be provided by the parent sampler - confirm against
        # the installed AllenNLP version.
        indices, lengths = self._argsort_by_padding(self.data_source)
        original_num = len(indices)

        # Decide which dataset positions to drop. Exactly one random draw is
        # made per instance, in dataset order, so results are reproducible
        # for a fixed ``random`` seed.
        dropped: Set[int] = set()
        for idx, ins in enumerate(self.data_source):
            if 'meta' in ins.fields:
                ins_type = ins.fields['meta'].metadata.get('type', 'default')
            else:
                ins_type = 'default'
            threshold = self.sampling_ratios.get(ins_type, 1.0)
            if random.random() > threshold:
                dropped.add(idx)

        if dropped:
            # Filter both parallel lists in a single pass instead of calling
            # ``list.index`` + ``del`` per dropped instance (quadratic).
            # The relative order of the surviving entries is preserved.
            survivors = [
                (index, length)
                for index, length in zip(indices, lengths)
                if index not in dropped
            ]
            indices = [index for index, _ in survivors]
            lengths = [length for _, length in survivors]

        if original_num != len(indices):
            logger.info(f'#instances reduced from {original_num} to {len(indices)}.')

        # Each entry of ``lengths`` is itself a collection of lengths; its
        # maximum is what the parent's grouping helper is given per instance.
        max_lengths = [max(length) for length in lengths]
        group_iterator = self._lazy_groups_of_max_size(indices, max_lengths)

        # Materialise all batches so that their order can be shuffled.
        batches = [list(group) for group in group_iterator]
        random.shuffle(batches)
        for batch in batches:
            yield batch
|
|