|
import copy |
|
from abc import abstractmethod |
|
from typing import Generator, List, Optional |
|
|
|
from .dataclass import NonPositionalField |
|
from .operator import SourceOperator |
|
from .random_utils import new_random_generator |
|
from .stream import MultiStream, Stream |
|
|
|
|
|
class BaseFusion(SourceOperator):
    """Base class for operators that combine the streams of several origins.

    Subclasses implement ``fusion_generator`` to decide how the instances of
    one split are drawn from the origins; ``process`` wires that generator
    into one ``Stream`` per split.

    Args:
        origins: List of SourceOperator objects whose outputs are fused.
        include_splits: List of splits to include. If None, all splits are included.
    """

    origins: List[SourceOperator]
    include_splits: Optional[List[str]] = NonPositionalField(default=None)

    @abstractmethod
    def fusion_generator(self, split) -> Generator:
        """Yield the fused instances for ``split``; implemented by subclasses."""
        pass

    def splits(self) -> List[str]:
        """Return the split names offered by the origins, in first-seen order.

        Every origin is called and its keys are collected without duplicates.
        When ``include_splits`` is set, only the names listed there are kept.
        """
        splits = []
        for origin in self.origins:
            for s in origin().keys():
                if s in splits:
                    # Already contributed by an earlier origin.
                    continue
                if self.include_splits is None or s in self.include_splits:
                    splits.append(s)
        return splits

    def process(
        self,
    ) -> MultiStream:
        """Build a MultiStream holding one fused Stream per split."""
        return MultiStream(
            {
                split: Stream(self.fusion_generator, gen_kwargs={"split": split})
                for split in self.splits()
            }
        )
|
|
|
|
|
class FixedFusion(BaseFusion):
    """Fusion operator that concatenates the origins one after another.

    For the requested split, the origins are visited in order and each one
    contributes at most ``max_instances_per_origin`` instances.

    Args:
        origins: List of SourceOperator objects.
        max_instances_per_origin: Maximum number of instances taken from each
            origin. If None, all instances of every origin are returned.
        include_splits: List of splits to include. If None, all splits are included.
    """

    max_instances_per_origin: Optional[int] = None

    def fusion_generator(self, split) -> Generator:
        """Yield the instances of ``split`` from each origin in turn."""
        for origin in self.origins:
            streams = origin()
            if split in streams:
                instances = iter(streams[split])
                limit = self.max_instances_per_origin
                if limit is None:
                    yield from instances
                else:
                    # zip advances range(limit) first, so once the cap is
                    # reached no further instance is pulled from the stream.
                    for _, instance in zip(range(limit), instances):
                        yield instance
|
|
|
|
|
class WeightedFusion(BaseFusion):
    """Fusion operator that combines multiple streams based on per-origin weights.

    Each instance is taken from an origin picked at random according to
    ``weights``. An origin whose stream is exhausted is dropped together with
    its weight, and sampling continues among the remaining origins.

    Args:
        origins: List of SourceOperator objects.
        weights: List of weights for each origin.
        max_total_examples: Total number of examples to return. If None, all examples are returned.
    """

    origins: List[SourceOperator] = None
    weights: List[float] = None
    max_total_examples: int = None

    def verify(self):
        """Check that origins and weights are both given and equally long."""
        super().verify()
        assert self.origins is not None, "origins must be specified"
        assert self.weights is not None, "weights must be specified"
        assert len(self.origins) == len(
            self.weights
        ), "origins and weights must have the same length"

    def fusion_generator(self, split) -> Generator:
        """Yield instances of ``split`` sampled across origins by weight.

        Stops once ``max_total_examples`` instances were yielded (when it is
        not None) or when every participating origin is exhausted.
        """
        # Local, mutable lists: entries are removed as origins run dry, so
        # self.weights itself is never modified.
        iterators = []
        weights = []
        for origin, weight in zip(self.origins, self.weights):
            multi_stream = origin()
            if split not in multi_stream:
                # splits() returns the union over all origins, so an origin
                # may lack this split; skip it, as FixedFusion does.
                continue
            iterators.append(iter(multi_stream[split]))
            weights.append(weight)

        total_examples = 0
        random_generator = new_random_generator(sub_seed="weighted_fusion_" + split)
        while (
            self.max_total_examples is None or total_examples < self.max_total_examples
        ) and iterators:
            # Sample a position rather than the iterator object so the
            # exhausted entry can be removed without searching the list.
            index = random_generator.choices(
                population=range(len(iterators)), weights=weights
            )[0]
            try:
                example = next(iterators[index])
            except StopIteration:
                iterators.pop(index)
                weights.pop(index)
                continue
            yield example
            total_examples += 1
|
|