Elron commited on
Commit
78a0600
·
1 Parent(s): 26c0f39

Upload splitters.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. splitters.py +51 -0
splitters.py CHANGED
@@ -102,6 +102,57 @@ class RandomSampler(Sampler):
102
  return random.sample(instances_pool, self.sample_size)
103
 
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  class SpreadSplit(InstanceOperatorWithGlobalAccess):
106
  source_stream: str = None
107
  target_field: str = None
 
102
  return random.sample(instances_pool, self.sample_size)
103
 
104
 
105
+ class DiverseLabelsSampler(Sampler):
106
+ choices: str = "choices"
107
+
108
+ def prepare(self):
109
+ super().prepare()
110
+ self.labels = None
111
+
112
+ def examplar_repr(self, examplar):
113
+ assert (
114
+ "inputs" in examplar and self.choices in examplar["inputs"]
115
+ ), f"DiverseLabelsSampler assumes each examplar has {self.choices} field in it input"
116
+ examplar_outputs = next(iter(examplar["outputs"].values()))
117
+ return str([choice for choice in examplar["inputs"][self.choices] if choice in examplar_outputs])
118
+
119
+ def divide_by_repr(self, examplars_pool):
120
+ labels = dict()
121
+ for examplar in examplars_pool:
122
+ label_repr = self.examplar_repr(examplar)
123
+ if label_repr not in labels:
124
+ labels[label_repr] = []
125
+ labels[label_repr].append(examplar)
126
+ return labels
127
+
128
+ def sample(self, instances_pool: List[Dict[str, object]]) -> List[Dict[str, object]]:
129
+ if self.labels is None:
130
+ self.labels = self.divide_by_repr(instances_pool)
131
+ all_labels = list(self.labels.keys())
132
+ random.shuffle(all_labels)
133
+ from collections import Counter
134
+
135
+ total_allocated = 0
136
+ allocations = Counter()
137
+
138
+ while total_allocated < self.sample_size:
139
+ for label in all_labels:
140
+ if total_allocated < self.sample_size:
141
+ if len(self.labels[label]) - allocations[label] > 0:
142
+ allocations[label] += 1
143
+ total_allocated += 1
144
+ else:
145
+ break
146
+
147
+ result = []
148
+ for label, allocation in allocations.items():
149
+ sample = random.sample(self.labels[label], allocation)
150
+ result.extend(sample)
151
+
152
+ random.shuffle(result)
153
+ return result
154
+
155
+
156
  class SpreadSplit(InstanceOperatorWithGlobalAccess):
157
  source_stream: str = None
158
  target_field: str = None