File size: 3,484 Bytes
300a7be
cbca7b8
5c545d2
5852323
300a7be
5c545d2
17a636b
5c545d2
 
5852323
cbca7b8
17a636b
5c545d2
cbca7b8
 
 
5c545d2
300a7be
 
5c545d2
cbca7b8
 
 
5c545d2
cbca7b8
 
 
 
 
 
 
 
5852323
5c545d2
 
 
cbca7b8
 
5c545d2
cbca7b8
5852323
5c545d2
cbca7b8
17a636b
5c545d2
cbca7b8
 
 
 
 
5c545d2
300a7be
5c545d2
cbca7b8
300a7be
cbca7b8
300a7be
 
 
 
 
 
cbca7b8
 
5c545d2
5852323
cbca7b8
17a636b
5c545d2
cbca7b8
 
 
300a7be
cbca7b8
5c545d2
cbca7b8
 
300a7be
5c545d2
cbca7b8
 
 
 
17a636b
 
 
5c545d2
cbca7b8
300a7be
cbca7b8
 
17a636b
 
 
 
cbca7b8
 
 
 
300a7be
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import copy
from abc import abstractmethod
from itertools import islice
from typing import Generator, List, Optional

from .dataclass import NonPositionalField
from .operator import SourceOperator, StreamSource
from .random_utils import get_random
from .stream import MultiStream, Stream


class BaseFusion(SourceOperator):
    """BaseFusion operator that combines multiple streams into one.

    Subclasses implement ``fusion_generator`` to decide how the instances of
    the individual origins are combined for a given split.

    Args:
        origins: List of StreamSource objects whose outputs are fused.
        include_splits: List of splits to include. If None, all splits are included.
    """

    origins: List[StreamSource]
    include_splits: Optional[List[str]] = NonPositionalField(default=None)

    @abstractmethod
    def fusion_generator(self, split) -> Generator:
        """Yield the fused instances of ``split``; implemented by subclasses."""

    def splits(self) -> List[str]:
        """Return the split names to produce.

        Names are collected from every origin in first-seen order, without
        duplicates, keeping only those listed in ``include_splits`` when it
        is set.
        """
        splits = []
        for origin in self.origins:
            for s in origin().keys():
                if s not in splits and (
                    self.include_splits is None or s in self.include_splits
                ):
                    splits.append(s)
        return splits

    def process(
        self,
    ) -> MultiStream:
        """Build a MultiStream with one Stream per split.

        Each Stream wraps ``fusion_generator`` with that split as its argument.
        """
        result = {}
        for split in self.splits():
            result[split] = Stream(self.fusion_generator, gen_kwargs={"split": split})
        return MultiStream(result)


class FixedFusion(BaseFusion):
    """FixedFusion operator that combines multiple streams into one, taking a fixed number of instances per origin.

    Origins are consumed one after another: the selected instances of the
    first origin are yielded before any instance of the second, and so on.

    Args:
        origins: List of StreamSource objects.
        max_instances_per_origin: Maximum number of instances taken from each
            origin. If None, all instances are returned.
        include_splits: List of splits to include. If None, all splits are included.
    """

    max_instances_per_origin: Optional[int] = None

    def fusion_generator(self, split) -> Generator:
        limit = self.max_instances_per_origin
        if limit is not None:
            # A negative cap yields nothing (as a range()-bounded loop would);
            # islice itself would reject a negative stop with ValueError.
            limit = max(limit, 0)
        for origin in self.origins:
            # islice(..., None) passes every instance through; otherwise it
            # stops after `limit` items without pulling an extra one.
            yield from islice(origin()[split], limit)


class WeightedFusion(BaseFusion):
    """Fusion operator that combines multiple streams by sampling them according to weights.

    At each step one origin is drawn at random, with probability proportional
    to its weight, and its next instance is yielded. Exhausted origins are
    dropped; the remaining ones keep their relative weights.

    Args:
        origins: List of StreamSource objects.
        weights: List of weights for each origin, aligned with ``origins``.
            Origins whose weight is not positive are never sampled.
        max_total_examples: Total number of examples to return. If None, all examples are returned.
    """

    origins: List[StreamSource] = None
    weights: List[float] = None
    max_total_examples: int = None

    def verify(self):
        super().verify()
        assert self.origins is not None, "origins must be specified"
        assert self.weights is not None, "weights must be specified"
        assert len(self.origins) == len(
            self.weights
        ), "origins and weights must have the same length"

    def fusion_generator(self, split) -> Generator:
        # Keep only origins that can actually be drawn. Zero weights are never
        # selected by random.choices, and if only they remained it would raise
        # ValueError (total weight must be positive).
        iterators = []
        weights = []
        for origin, weight in zip(self.origins, self.weights):
            if weight > 0:
                iterators.append(iter(origin()[split]))
                weights.append(weight)

        total_examples = 0
        # Strict `<` so that exactly max_total_examples instances are produced.
        while iterators and (
            self.max_total_examples is None
            or total_examples < self.max_total_examples
        ):
            # Draw a position rather than the iterator object itself, so the
            # removal below never depends on iterator equality.
            index = get_random().choices(range(len(iterators)), weights=weights)[0]
            try:
                instance = next(iterators[index])
            except StopIteration:
                iterators.pop(index)
                weights.pop(index)
                continue
            yield instance
            total_examples += 1