Elron committed on
Commit
80500e3
·
verified ·
1 Parent(s): 5de16c3

Upload splitters.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. splitters.py +57 -12
splitters.py CHANGED
@@ -1,10 +1,11 @@
1
  import itertools
2
  from abc import abstractmethod
 
3
  from typing import Dict, List
4
 
5
  from .artifact import Artifact
6
  from .operator import InstanceOperatorWithMultiStreamAccess, MultiStreamOperator
7
- from .random_utils import get_random
8
  from .split_utils import (
9
  parse_random_mix_string,
10
  parse_slices_string,
@@ -82,6 +83,7 @@ class SliceSplit(Splitter):
82
 
83
  class Sampler(Artifact):
84
  sample_size: int = None
 
85
 
86
  def prepare(self):
87
  super().prepare()
@@ -95,6 +97,11 @@ class Sampler(Artifact):
95
  size = int(size)
96
  self.sample_size = size
97
 
 
 
 
 
 
98
  @abstractmethod
99
  def sample(
100
  self, instances_pool: List[Dict[str, object]]
@@ -107,22 +114,52 @@ class RandomSampler(Sampler):
107
  self, instances_pool: List[Dict[str, object]]
108
  ) -> List[Dict[str, object]]:
109
  instances_pool = list(instances_pool)
110
- return get_random().sample(instances_pool, self.sample_size)
111
 
112
 
113
  class DiverseLabelsSampler(Sampler):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  choices: str = "choices"
 
115
 
116
  def prepare(self):
117
  super().prepare()
118
- self.labels = None
119
 
120
  def examplar_repr(self, examplar):
121
  if "inputs" not in examplar:
122
  raise ValueError(f"'inputs' field is missing from '{examplar}'.")
123
  inputs = examplar["inputs"]
124
  if self.choices not in inputs:
125
- raise ValueError(f"{self.choices} field is missing from '{inputs}'.")
126
  choices = inputs[self.choices]
127
  if not isinstance(choices, list):
128
  raise ValueError(
@@ -131,7 +168,11 @@ class DiverseLabelsSampler(Sampler):
131
 
132
  if "outputs" not in examplar:
133
  raise ValueError(f"'outputs' field is missing from '{examplar}'.")
134
- examplar_outputs = next(iter(examplar["outputs"].values()))
 
 
 
 
135
  if not isinstance(examplar_outputs, list):
136
  raise ValueError(
137
  f"Unexpected examplar_outputs value '{examplar_outputs}'. Expected a list."
@@ -151,19 +192,23 @@ class DiverseLabelsSampler(Sampler):
151
  def sample(
152
  self, instances_pool: List[Dict[str, object]]
153
  ) -> List[Dict[str, object]]:
154
- if self.labels is None:
155
- self.labels = self.divide_by_repr(instances_pool)
156
- all_labels = list(self.labels.keys())
157
- get_random().shuffle(all_labels)
158
  from collections import Counter
159
 
 
 
 
 
160
  total_allocated = 0
161
  allocations = Counter()
162
 
163
  while total_allocated < self.sample_size:
164
  for label in all_labels:
165
  if total_allocated < self.sample_size:
166
- if len(self.labels[label]) - allocations[label] > 0:
167
  allocations[label] += 1
168
  total_allocated += 1
169
  else:
@@ -171,10 +216,10 @@ class DiverseLabelsSampler(Sampler):
171
 
172
  result = []
173
  for label, allocation in allocations.items():
174
- sample = get_random().sample(self.labels[label], allocation)
175
  result.extend(sample)
176
 
177
- get_random().shuffle(result)
178
  return result
179
 
180
 
 
1
  import itertools
2
  from abc import abstractmethod
3
+ from random import Random
4
  from typing import Dict, List
5
 
6
  from .artifact import Artifact
7
  from .operator import InstanceOperatorWithMultiStreamAccess, MultiStreamOperator
8
+ from .random_utils import new_random_generator
9
  from .split_utils import (
10
  parse_random_mix_string,
11
  parse_slices_string,
 
83
 
84
  class Sampler(Artifact):
85
  sample_size: int = None
86
+ random_generator: Random = new_random_generator(sub_seed="Sampler")
87
 
88
  def prepare(self):
89
  super().prepare()
 
97
  size = int(size)
98
  self.sample_size = size
99
 
100
+ def init_new_random_generator(self):
101
+ self.random_generator = new_random_generator(
102
+ sub_seed="init_new_random_generator"
103
+ )
104
+
105
  @abstractmethod
106
  def sample(
107
  self, instances_pool: List[Dict[str, object]]
 
114
  self, instances_pool: List[Dict[str, object]]
115
  ) -> List[Dict[str, object]]:
116
  instances_pool = list(instances_pool)
117
+ return self.random_generator.sample(instances_pool, self.sample_size)
118
 
119
 
120
  class DiverseLabelsSampler(Sampler):
121
+ """Selects a balanced sample of instances based on an output field.
122
+
123
+ (used for selecting demonstrations for in-context learning)
124
+
125
+ The field must contain a list of values, e.g. ['dog'], ['cat'], ['dog','cat','cow'].
126
+ The balancing is done such that each value or combination of values
127
+ appears as equally as possible in the samples.
128
+
129
+ The `choices` param is required and determines which values should be considered.
130
+
131
+ Example:
132
+ If choices is ['dog','cat'], then the following combinations will be considered.
133
+ ['']
134
+ ['cat']
135
+ ['dog']
136
+ ['dog','cat']
137
+
138
+ If the instance contains a value not in the 'choices' param, it is ignored. For example,
139
+ if choices is ['dog','cat'] and the instance field is ['dog','cat','cow'], then 'cow' is ignored
140
+ and the instance is considered as ['dog','cat'].
141
+
142
+ Args:
143
+ sample_size - number of samples to extract
144
+ choices - name of input field that contains the list of values to balance on
145
+ labels - name of output field with labels that must be balanced
146
+
147
+
148
+ """
149
+
150
  choices: str = "choices"
151
+ labels: str = "labels"
152
 
153
  def prepare(self):
154
  super().prepare()
155
+ self.labels_cache = None
156
 
157
  def examplar_repr(self, examplar):
158
  if "inputs" not in examplar:
159
  raise ValueError(f"'inputs' field is missing from '{examplar}'.")
160
  inputs = examplar["inputs"]
161
  if self.choices not in inputs:
162
+ raise ValueError(f"'{self.choices}' field is missing from '{inputs}'.")
163
  choices = inputs[self.choices]
164
  if not isinstance(choices, list):
165
  raise ValueError(
 
168
 
169
  if "outputs" not in examplar:
170
  raise ValueError(f"'outputs' field is missing from '{examplar}'.")
171
+ outputs = examplar["outputs"]
172
+ if self.labels not in outputs:
173
+ raise ValueError(f"'{self.labels}' field is missing from '{outputs}'.")
174
+
175
+ examplar_outputs = examplar["outputs"][self.labels]
176
  if not isinstance(examplar_outputs, list):
177
  raise ValueError(
178
  f"Unexpected examplar_outputs value '{examplar_outputs}'. Expected a list."
 
192
  def sample(
193
  self, instances_pool: List[Dict[str, object]]
194
  ) -> List[Dict[str, object]]:
195
+ if self.labels_cache is None:
196
+ self.labels_cache = self.divide_by_repr(instances_pool)
197
+ all_labels = list(self.labels_cache.keys())
198
+ self.random_generator.shuffle(all_labels)
199
  from collections import Counter
200
 
201
+ if self.sample_size > len(instances_pool):
202
+ raise ValueError(
203
+ f"Request sample size {self.sample_size} is greater than number of instances {len(instances_pool)}"
204
+ )
205
  total_allocated = 0
206
  allocations = Counter()
207
 
208
  while total_allocated < self.sample_size:
209
  for label in all_labels:
210
  if total_allocated < self.sample_size:
211
+ if len(self.labels_cache[label]) - allocations[label] > 0:
212
  allocations[label] += 1
213
  total_allocated += 1
214
  else:
 
216
 
217
  result = []
218
  for label, allocation in allocations.items():
219
+ sample = self.random_generator.sample(self.labels_cache[label], allocation)
220
  result.extend(sample)
221
 
222
+ self.random_generator.shuffle(result)
223
  return result
224
 
225