Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
from typing import Any | |
import numpy as np | |
from pydantic import ConfigDict | |
from bytelatent.data.iterators.abstract_iterator import ( | |
PydanticIteratorState, | |
StatefulIterator, | |
) | |
from bytelatent.data.iterators.sequence_iterator import SequenceIteratorState | |
class SamplingIteratorState(PydanticIteratorState):
    """Serializable checkpoint of a ``SamplingIterator``.

    Holds the numpy bit-generator state, the per-source sampling weights and
    the saved state of every per-source iterator, which is everything
    ``build()`` needs to reconstruct an equivalent ``SamplingIterator``.
    """

    # Reject unknown fields during validation instead of ignoring them.
    model_config = ConfigDict(extra="forbid")

    # Raw ``bit_generator.state`` mapping captured from the numpy Generator.
    rng_state: dict[str, Any]
    # Sampling weight per source name (normalized at sampling time).
    source_to_weight: dict[str, float]
    # Saved state of each source's iterator, keyed by source name.
    source_to_iterator_state: dict[str, SequenceIteratorState]

    def build(self) -> "SamplingIterator":
        """Reconstruct a ``SamplingIterator`` from this saved state."""
        rebuilt_iterators = {}
        for name, saved_state in self.source_to_iterator_state.items():
            rebuilt_iterators[name] = saved_state.build()
        return SamplingIterator(
            rng_state=self.rng_state,
            source_to_weight=self.source_to_weight,
            source_to_iterator=rebuilt_iterators,
        )
class SamplingIterator(StatefulIterator):
    """Interleave several source iterators by weighted random sampling.

    On each step one source is drawn at random, with probability proportional
    to its weight in ``source_to_weight``, and the next item from that source
    is yielded. The numpy RNG state is restored on construction and captured
    by ``get_state``, so sampling can resume from a checkpoint.
    """

    def __init__(
        self,
        *,
        rng_state: dict[str, Any],
        source_to_weight: dict[str, float],
        source_to_iterator: dict[str, StatefulIterator],
    ):
        """
        Args:
            rng_state: A ``bit_generator.state`` mapping, restored into a fresh
                ``np.random.default_rng()`` generator.
            source_to_weight: Non-negative sampling weight per source name;
                weights are normalized to probabilities when sampling. Every
                key must also be present in ``source_to_iterator``.
            source_to_iterator: The stateful iterator backing each source.
        """
        self.rng = np.random.default_rng()
        # Overwrite the freshly seeded state so sampling continues exactly
        # where the supplied state left off.
        self.rng.bit_generator.state = rng_state
        self.source_to_weight = source_to_weight
        self.source_to_iterator = source_to_iterator

    def get_state(self) -> SamplingIteratorState:
        """Snapshot the RNG state, the weights and each source iterator's state."""
        return SamplingIteratorState(
            rng_state=self.rng.bit_generator.state,
            source_to_weight=self.source_to_weight,
            source_to_iterator_state={
                source: iterator.get_state()
                for source, iterator in self.source_to_iterator.items()
            },
        )

    def create_iter(self):
        """Yield items indefinitely, each from a weighted-randomly chosen source.

        The weights are snapshotted when iteration starts; later changes to
        ``self.source_to_weight`` do not affect an iterator already running.

        Raises:
            ValueError: (on the first ``next``) if there are no sources or the
                weights do not sum to a positive value.
            RuntimeError: if the chosen source's iterator is exhausted.
        """
        possible_sources = list(self.source_to_weight)
        if not possible_sources:
            raise ValueError("SamplingIterator requires at least one source")
        weights = np.array(
            [self.source_to_weight[source] for source in possible_sources]
        )
        total = weights.sum()
        # `not total > 0` also rejects NaN sums.
        if not total > 0:
            raise ValueError(
                f"Source weights must sum to a positive value, got {total}"
            )
        # Loop-invariant: normalize once instead of on every sample.
        norm_weights = weights / total
        n_sources = len(possible_sources)

        source_to_python_iter = {
            source: self.source_to_iterator[source].create_iter()
            for source in possible_sources
        }
        while True:
            source_choice = possible_sources[
                self.rng.choice(n_sources, p=norm_weights)
            ]
            try:
                item = next(source_to_python_iter[source_choice])
            except StopIteration as err:
                # Inside a generator, a leaking StopIteration would become an
                # opaque RuntimeError (PEP 479); name the exhausted source instead.
                raise RuntimeError(
                    f"Source {source_choice!r} was exhausted"
                ) from err
            yield item