File size: 3,716 Bytes
300a7be
cbca7b8
5c545d2
5852323
300a7be
2d210f5
cc653d8
5c545d2
 
5852323
cbca7b8
17a636b
5c545d2
cbca7b8
 
 
5c545d2
2d210f5
300a7be
5c545d2
cbca7b8
 
 
5c545d2
cbca7b8
 
 
 
 
 
 
 
5852323
5c545d2
 
 
cbca7b8
 
5c545d2
cbca7b8
5852323
5c545d2
cbca7b8
17a636b
5c545d2
cbca7b8
f747a71
cbca7b8
 
 
5c545d2
300a7be
5c545d2
cbca7b8
300a7be
ef1f482
 
 
 
300a7be
 
 
 
 
 
cbca7b8
 
5c545d2
5852323
cbca7b8
17a636b
5c545d2
cbca7b8
f747a71
cbca7b8
300a7be
cbca7b8
5c545d2
2d210f5
cbca7b8
300a7be
5c545d2
cbca7b8
 
 
 
17a636b
 
 
5c545d2
cbca7b8
300a7be
cbca7b8
 
cc653d8
17a636b
 
 
cc653d8
 
 
cbca7b8
 
 
 
300a7be
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import copy
from abc import abstractmethod
from typing import Generator, List, Optional

from .dataclass import NonPositionalField
from .operator import SourceOperator
from .random_utils import new_random_generator
from .stream import MultiStream, Stream


class BaseFusion(SourceOperator):
    """BaseFusion operator that combines multiple streams into one.

    Subclasses implement ``fusion_generator`` to decide how the instances of a
    single split are produced from the origins.

    Args:
        origins: List of SourceOperator objects whose streams are combined.
        include_splits: List of splits to include. If None, all splits are included.
    """

    origins: List[SourceOperator]
    include_splits: Optional[List[str]] = NonPositionalField(default=None)

    @abstractmethod
    def fusion_generator(self, split) -> Generator:
        """Yield the fused instances of ``split``; must be implemented by subclasses."""

    def splits(self) -> List[str]:
        """Return the names of the splits the fused stream will contain.

        Every origin is invoked and its split names are collected in
        first-seen order without duplicates. When ``include_splits`` is set,
        only the names listed there are kept.
        """
        splits = []
        for origin in self.origins:
            for s in origin().keys():
                if s in splits:
                    continue
                if self.include_splits is None or s in self.include_splits:
                    splits.append(s)
        return splits

    def process(
        self,
    ) -> MultiStream:
        """Build a MultiStream holding one Stream per split from ``splits()``.

        Each Stream is handed ``fusion_generator`` together with the split
        name as its generator keyword argument.
        """
        return MultiStream(
            {
                split: Stream(self.fusion_generator, gen_kwargs={"split": split})
                for split in self.splits()
            }
        )


class FixedFusion(BaseFusion):
    """Fusion operator that emits the instances of the origins one origin after another.

    For the requested split the origins are visited in their listed order; an
    origin that does not provide the split is skipped.

    Args:
        origins: List of SourceOperator objects.
        max_instances_per_origin: Maximal number of instances taken from each
            origin. If None, all instances of every origin are returned.
        include_splits: List of splits to include. If None, all splits are included.
    """

    max_instances_per_origin: Optional[int] = None

    def fusion_generator(self, split) -> Generator:
        """Yield the instances of ``split`` from every origin, in origin order."""
        # Unique marker returned by next() once an origin has no more instances.
        exhausted = object()
        for source in self.origins:
            streams = source()
            if split in streams:
                split_iter = iter(streams[split])
                limit = self.max_instances_per_origin
                if limit is None:
                    yield from split_iter
                else:
                    # Pull at most `limit` instances and never one more, so the
                    # underlying stream is not advanced past what is emitted.
                    for _ in range(limit):
                        instance = next(split_iter, exhausted)
                        if instance is exhausted:
                            break
                        yield instance


class WeightedFusion(BaseFusion):
    """Fusion operator that combines multiple streams by weighted random choice.

    On every step one origin is picked with ``choices(..., weights=...)`` and
    its next instance is emitted. An origin that runs out of instances is
    dropped together with its weight and the drawing continues over the
    remaining origins.

    Args:
        origins: List of SourceOperator objects.
        weights: List of weights for each origin.
        max_total_examples: Total number of examples to return. If None, all examples are returned.
    """

    origins: List[SourceOperator] = None
    weights: List[float] = None
    max_total_examples: Optional[int] = None

    def verify(self):
        """Check that origins and weights are both given and have equal length."""
        super().verify()
        assert self.origins is not None, "origins must be specified"
        assert self.weights is not None, "weights must be specified"
        assert len(self.origins) == len(
            self.weights
        ), "origins and weights must have the same length"

    def fusion_generator(self, split) -> Generator:
        """Yield instances of ``split`` drawn from the origins according to the weights.

        At most ``max_total_examples`` instances are yielded (no limit when it
        is None). Origins that do not provide ``split`` take no part in the
        drawing.
        """
        # Build fresh, parallel lists so that dropping an exhausted origin
        # never mutates self.weights. Origins lacking the split are skipped
        # (together with their weight), matching FixedFusion, because
        # BaseFusion.splits() reports the union of the splits of all origins.
        iterators = []
        weights = []
        for origin, weight in zip(self.origins, self.weights):
            multi_stream = origin()
            if split not in multi_stream:
                continue
            iterators.append(iter(multi_stream[split]))
            weights.append(weight)

        total_examples = 0
        random_generator = new_random_generator(sub_seed="weighted_fusion_" + split)
        # Strict '<' so that exactly max_total_examples instances are yielded;
        # '<=' would emit one instance too many.
        while (
            self.max_total_examples is None or total_examples < self.max_total_examples
        ) and len(iterators) > 0:
            iterator = random_generator.choices(population=iterators, weights=weights)[
                0
            ]
            try:
                instance = next(iterator)
            except StopIteration:
                # The chosen origin is exhausted: remove it and its weight so
                # the two lists stay aligned, then draw again.
                index = iterators.index(iterator)
                iterators.pop(index)
                weights.pop(index)
                continue
            yield instance
            total_examples += 1