Elron commited on
Commit
300a7be
1 Parent(s): 59bff37

Upload fusion.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. fusion.py +20 -40
fusion.py CHANGED
@@ -1,9 +1,11 @@
 
1
  from abc import abstractmethod
2
  from dataclasses import asdict
3
  from typing import Generator, List, Optional
4
 
5
  from .card import ICLCard, TaskCard
6
  from .common import CommonRecipe
 
7
  from .operator import SourceOperator, StreamSource
8
  from .random_utils import random
9
  from .stream import MultiStream, Stream
@@ -17,7 +19,8 @@ class BaseFusion(SourceOperator):
17
  include_splits: List of splits to include. If None, all splits are included.
18
  """
19
 
20
- include_splits: Optional[List[str]] = None
 
21
 
22
  @abstractmethod
23
  def fusion_generator(self, split) -> Generator:
@@ -51,14 +54,17 @@ class FixedFusion(BaseFusion):
51
  splits: List of splits to include. If None, all splits are included.
52
  """
53
 
54
- examples_per_task: Optional[int] = None
55
 
56
  def fusion_generator(self, split) -> Generator:
57
- for origin in self.orgins:
58
  iterator = iter(origin()[split])
59
- if self.examples_per_task is not None:
60
- for i in range(self.examples_per_task):
61
- yield next(iterator)
 
 
 
62
  else:
63
  yield from iterator
64
 
@@ -70,12 +76,12 @@ class WeightedFusion(BaseFusion):
70
  Args:
71
  orgins: List of StreamSource objects.
72
  weights: List of weights for each origin.
73
- total_examples: Total number of examples to return. If None, all examples are returned.
74
  """
75
 
76
  origins: List[StreamSource] = None
77
  weights: List[float] = None
78
- total_examples: int = None
79
 
80
  def verify(self):
81
  super().verify()
@@ -84,41 +90,15 @@ class WeightedFusion(BaseFusion):
84
  assert len(self.origins) == len(self.weights), "origins and weights must have the same length"
85
 
86
  def fusion_generator(self, split) -> Generator:
 
87
  iterators = [iter(origin()[split]) for origin in self.origins]
88
  total_examples = 0
89
- while (self.total_examples is None or total_examples <= self.total_examples) and len(iterators) > 0:
90
- iterator = random.choices(population=iterators, weights=self.weights)[0]
91
  try:
92
  yield next(iterator)
93
  total_examples += 1
94
  except StopIteration:
95
- iterators.remove(iterator)
96
-
97
-
98
- class TasksFusion(SourceOperator):
99
- """
100
- TasksFusion operator that combines multiple tasks into one.
101
-
102
- Args:
103
- tasks: List of TaskCard objects.
104
- config: ICLCard object.
105
- examples_per_task: Number of examples per task. If None, all examples are returned.
106
- include_splits: List of splits to include. If None, all splits are included.
107
- """
108
-
109
- tasks: List[TaskCard]
110
- config: ICLCard
111
- examples_per_task: Optional[int] = None
112
- include_splits: Optional[List[str]] = None
113
-
114
- def prepare(self):
115
- self.recipes = []
116
- for task in self.tasks:
117
- recipe = CommonRecipe(card=task, **asdict(self.config))
118
-
119
- self.fusion = FixedFusion(
120
- origins=self.recipes, examples_per_task=self.examples_per_task, include_splits=self.include_splits
121
- )
122
-
123
- def process(self) -> MultiStream:
124
- return self.fusion()
 
1
+ import copy
2
  from abc import abstractmethod
3
  from dataclasses import asdict
4
  from typing import Generator, List, Optional
5
 
6
  from .card import ICLCard, TaskCard
7
  from .common import CommonRecipe
8
+ from .dataclass import NonPositionalField
9
  from .operator import SourceOperator, StreamSource
10
  from .random_utils import random
11
  from .stream import MultiStream, Stream
 
19
  include_splits: List of splits to include. If None, all splits are included.
20
  """
21
 
22
+ origins: List[StreamSource]
23
+ include_splits: Optional[List[str]] = NonPositionalField(default=None)
24
 
25
  @abstractmethod
26
  def fusion_generator(self, split) -> Generator:
 
54
  splits: List of splits to include. If None, all splits are included.
55
  """
56
 
57
+ max_instances_per_origin: Optional[int] = None
58
 
59
  def fusion_generator(self, split) -> Generator:
60
+ for origin in self.origins:
61
  iterator = iter(origin()[split])
62
+ if self.max_instances_per_origin is not None:
63
+ for _ in range(self.max_instances_per_origin):
64
+ try:
65
+ yield next(iterator)
66
+ except StopIteration:
67
+ break
68
  else:
69
  yield from iterator
70
 
 
76
  Args:
77
  orgins: List of StreamSource objects.
78
  weights: List of weights for each origin.
79
+ max_total_examples: Total number of examples to return. If None, all examples are returned.
80
  """
81
 
82
  origins: List[StreamSource] = None
83
  weights: List[float] = None
84
+ max_total_examples: int = None
85
 
86
  def verify(self):
87
  super().verify()
 
90
  assert len(self.origins) == len(self.weights), "origins and weights must have the same length"
91
 
92
  def fusion_generator(self, split) -> Generator:
93
+ weights = copy.deepcopy(self.weights)
94
  iterators = [iter(origin()[split]) for origin in self.origins]
95
  total_examples = 0
96
+ while (self.max_total_examples is None or total_examples <= self.max_total_examples) and len(iterators) > 0:
97
+ iterator = random.choices(population=iterators, weights=weights)[0]
98
  try:
99
  yield next(iterator)
100
  total_examples += 1
101
  except StopIteration:
102
+ index = iterators.index(iterator)
103
+ iterators.pop(index)
104
+ weights.pop(index)