Elron committed on
Commit
cbca7b8
1 Parent(s): 93a9b92

Upload fusion.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. fusion.py +124 -13
fusion.py CHANGED
@@ -1,18 +1,129 @@
1
- from typing import List, Optional
 
 
 
2
 
3
- from .loaders import Loader
4
- from .splitters import Splitter
5
- from .stream import MultiStream
6
- from .task import Tasker
7
 
8
- # class Fusion(StreamSource):
9
- # pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
 
 
 
 
 
11
 
12
- # class RecipeFusion(StreamSource):
13
- # recepies: List[Recipe]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
- # def __call__(self) -> MultiStream:
16
- # for recipe in self.recepies:
17
- # stream = recipe()
18
- # return stream
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Generator
2
+ from dataclasses import asdict
3
+ import random
4
+ from abc import abstractmethod
5
 
6
+ from .stream import MultiStream, Stream
7
+ from .operator import SourceOperator, StreamSource
8
+ from .card import TaskCard, ICLCard
9
+ from .common import CommonRecipe
10
 
11
class BaseFusion(SourceOperator):
    """
    Base operator that combines multiple stream sources into one MultiStream.

    Subclasses implement `fusion_generator`. This class reads `self.origins`
    (iterated, and each element called with no arguments) but does not declare
    it, so a subclass has to provide that attribute.

    Args:
        include_splits: List of splits to include. If None, all splits are included.
    """
    include_splits: Optional[List[str]] = None

    @abstractmethod
    def fusion_generator(self, split) -> Generator:
        """Yield the fused instances for the given split name."""
        pass

    def splits(self) -> List[str]:
        """
        Return the split names to produce, in first-seen order.

        Every origin is called once and its keys are collected without
        duplicates; names not listed in `include_splits` are dropped when
        `include_splits` is set.
        """
        splits = []
        for origin in self.origins:
            for s in origin().keys():
                if s in splits:
                    continue
                if self.include_splits is None or s in self.include_splits:
                    splits.append(s)
        return splits

    def process(self) -> MultiStream:
        """Build a MultiStream with one Stream per split returned by `splits()`."""
        result = {}
        for split in self.splits():
            # The split name reaches fusion_generator through gen_kwargs.
            result[split] = Stream(self.fusion_generator, gen_kwargs={'split': split})
        return MultiStream(result)
39
 
40
class FixedFusion(BaseFusion):
    """
    FixedFusion operator that combines multiple streams into one based on a fixed number of examples per task.

    Origins are consumed one after the other, in the order given.

    Args:
        origins: List of StreamSource objects.
        examples_per_task: Number of examples per task. If None, all examples are returned.
        include_splits: List of splits to include. If None, all splits are included.
    """
    # Declared here the same way WeightedFusion declares it, because the base
    # class reads `self.origins` without declaring it.
    origins: List[StreamSource] = None
    examples_per_task: Optional[int] = None

    def fusion_generator(self, split) -> Generator:
        """
        Yield instances of `split` from each origin in turn.

        At most `examples_per_task` instances are taken from every origin; an
        origin holding fewer instances simply contributes all it has.
        """
        for origin in self.origins:
            iterator = iter(origin()[split])
            if self.examples_per_task is None:
                yield from iterator
                continue
            for _ in range(self.examples_per_task):
                # A StopIteration escaping a generator body is turned into
                # RuntimeError (PEP 479), so a short origin is handled here.
                try:
                    instance = next(iterator)
                except StopIteration:
                    break
                yield instance
59
+
60
 
61
class WeightedFusion(BaseFusion):
    """
    Fusion operator that combines multiple streams by random sampling based on per-origin weights.

    Args:
        origins: List of StreamSource objects.
        weights: List of weights for each origin.
        total_examples: Total number of examples to return. If None, all examples are returned.
    """
    origins: List[StreamSource] = None
    weights: List[float] = None
    total_examples: Optional[int] = None

    def verify(self):
        """Check that origins and weights are both given and have equal length."""
        super().verify()
        assert self.origins is not None, "origins must be specified"
        assert self.weights is not None, "weights must be specified"
        assert len(self.origins) == len(self.weights), "origins and weights must have the same length"

    def fusion_generator(self, split) -> Generator:
        """
        Yield instances of `split`, picking the source origin at random each time.

        The origin is drawn with `random.choices` using `weights`. An exhausted
        origin is dropped together with its weight. Generation stops after
        `total_examples` instances, or when every origin is exhausted.
        """
        iterators = [iter(origin()[split]) for origin in self.origins]
        # Private copy: weights are removed in step with exhausted iterators
        # so both lists stay aligned, without mutating self.weights.
        weights = list(self.weights)
        produced = 0
        while iterators and (self.total_examples is None or produced < self.total_examples):
            index = random.choices(range(len(iterators)), weights=weights)[0]
            try:
                instance = next(iterators[index])
            except StopIteration:
                del iterators[index]
                del weights[index]
                continue
            yield instance
            produced += 1
91
+
92
class TasksFusion(SourceOperator):
    """
    TasksFusion operator that combines multiple tasks into one.

    A CommonRecipe is built for every task card using the shared config, and
    the recipes are fused with a FixedFusion.

    Args:
        tasks: List of TaskCard objects.
        config: ICLCard object.
        examples_per_task: Number of examples per task. If None, all examples are returned.
        include_splits: List of splits to include. If None, all splits are included.
    """
    tasks: List[TaskCard]
    config: ICLCard
    examples_per_task: Optional[int] = None
    include_splits: Optional[List[str]] = None

    def prepare(self):
        """Create one recipe per task and the FixedFusion that combines them."""
        # asdict() requires `config` to be a dataclass instance; its fields are
        # forwarded to every recipe as keyword arguments.
        config_kwargs = asdict(self.config)
        self.recipes = []
        for task in self.tasks:
            recipe = CommonRecipe(
                card=task,
                **config_kwargs
            )
            # Collect the recipe; otherwise the fusion gets an empty origin list.
            self.recipes.append(recipe)

        self.fusion = FixedFusion(
            origins=self.recipes,
            examples_per_task=self.examples_per_task,
            include_splits=self.include_splits
        )

    def process(self) -> MultiStream:
        """Return the result of calling the prepared fusion."""
        return self.fusion()
123
+
124
+
125
+
126
+
127
+
128
+
129
+