File size: 8,882 Bytes
99fde4e e203384 80500e3 04d2454 2341544 04d2454 80500e3 e7c76e5 2341544 99fde4e e7c76e5 04d2454 e7c76e5 99fde4e e7c76e5 e755967 e7c76e5 99fde4e 04d2454 99fde4e 04d2454 99fde4e 04d2454 99fde4e e755967 99fde4e e7c76e5 e755967 e7c76e5 e203384 80500e3 e203384 04d2454 e203384 80500e3 e203384 04d2454 e203384 e7c76e5 04d2454 e7c76e5 80500e3 e7c76e5 78a0600 80500e3 78a0600 80500e3 78a0600 80500e3 78a0600 04d2454 80500e3 04d2454 80500e3 04d2454 78a0600 04d2454 78a0600 04d2454 80500e3 78a0600 80500e3 78a0600 80500e3 78a0600 80500e3 78a0600 80500e3 78a0600 04d2454 e7c76e5 04d2454 e7c76e5 04d2454 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 |
import itertools
from abc import abstractmethod
from random import Random
from typing import Dict, List
from .artifact import Artifact
from .operator import InstanceOperatorWithMultiStreamAccess, MultiStreamOperator
from .random_utils import new_random_generator
from .split_utils import (
parse_random_mix_string,
parse_slices_string,
random_mix_streams,
rename_split,
slice_streams,
)
from .stream import MultiStream
class Splitter(MultiStreamOperator):
pass
class RenameSplits(Splitter):
mapper: Dict[str, str]
def process(self, multi_stream: MultiStream) -> MultiStream:
generators = rename_split(multi_stream, self.mapper)
return MultiStream(generators)
class SplitRandomMix(Splitter):
mix: Dict[str, str]
def process(self, multi_stream: MultiStream) -> MultiStream:
mapping = {k: parse_random_mix_string(v) for k, v in self.mix.items()}
generators = random_mix_streams(multi_stream, mapping)
return MultiStream.from_generators(generators)
class SeparateSplit(Splitter):
"""Separates a split (e.g. train) into several splits (e.g. train1, train2).
sizes must indicate the size of every split except the last. If no size is give for the last split,
it includes all the examples not allocated to any split.
"""
from_split: str
to_split_names: List[str]
to_split_sizes: List[int]
def verify(self):
assert (
len(self.to_split_names) == len(self.to_split_sizes)
or len(self.to_split_names) == len(self.to_split_sizes) + 1
), f"Examples num should be specified to all or all but the last splits, instead given {len(self.to_split_names)} split names and {len(self.to_split_sizes)} split sizes. \n split names:{self.to_split_names} split sizes {self.to_split_sizes}"
return super().verify()
def process(self, multi_stream: MultiStream) -> MultiStream:
mapping = {
key: {key: [(None, None)]}
for key in multi_stream.keys()
if key != self.from_split
}
so_far = 0
for name, size in itertools.zip_longest(
self.to_split_names, self.to_split_sizes
):
mapping[name] = {self.from_split: [(so_far, size)]}
if size:
so_far += size
generators = slice_streams(multi_stream, mapping)
return MultiStream.from_generators(generators)
class SliceSplit(Splitter):
slices: Dict[str, str]
def process(self, multi_stream: MultiStream) -> MultiStream:
mapping = {k: parse_slices_string(v) for k, v in self.slices.items()}
generators = slice_streams(multi_stream, mapping)
return MultiStream.from_generators(generators)
class Sampler(Artifact):
sample_size: int = None
random_generator: Random = new_random_generator(sub_seed="Sampler")
def prepare(self):
super().prepare()
self.set_size(self.sample_size)
def set_size(self, size):
if isinstance(size, str):
assert (
size.isdigit()
), f"sample_size must be a natural number, got {self.sample_size}"
size = int(size)
self.sample_size = size
def init_new_random_generator(self):
self.random_generator = new_random_generator(
sub_seed="init_new_random_generator"
)
@abstractmethod
def sample(
self, instances_pool: List[Dict[str, object]]
) -> List[Dict[str, object]]:
pass
class RandomSampler(Sampler):
def sample(
self, instances_pool: List[Dict[str, object]]
) -> List[Dict[str, object]]:
instances_pool = list(instances_pool)
return self.random_generator.sample(instances_pool, self.sample_size)
class DiverseLabelsSampler(Sampler):
"""Selects a balanced sample of instances based on an output field.
(used for selecting demonstrations in-context learning)
The field must contain list of values e.g ['dog'], ['cat'], ['dog','cat','cow'].
The balancing is done such that each value or combination of values
appears as equals as possible in the samples.
The `choices` param is required and determines which values should be considered.
Example:
If choices is ['dog,'cat'] , then the following combinations will be considered.
['']
['cat']
['dog']
['dog','cat']
If the instance contains a value not in the 'choice' param, it is ignored. For example,
if choices is ['dog,'cat'] and the instance field is ['dog','cat','cow'], then 'cow' is ignored
then the instance is considered as ['dog','cat'].
Args:
sample_size - number of samples to extract
choices - name of input field that contains the list of values to balance on
labels - name of output field with labels that must be balanced
"""
choices: str = "choices"
labels: str = "labels"
def prepare(self):
super().prepare()
self.labels_cache = None
def examplar_repr(self, examplar):
if "inputs" not in examplar:
raise ValueError(f"'inputs' field is missing from '{examplar}'.")
inputs = examplar["inputs"]
if self.choices not in inputs:
raise ValueError(f"'{self.choices}' field is missing from '{inputs}'.")
choices = inputs[self.choices]
if not isinstance(choices, list):
raise ValueError(
f"Unexpected input choices value '{choices}'. Expected a list."
)
if "outputs" not in examplar:
raise ValueError(f"'outputs' field is missing from '{examplar}'.")
outputs = examplar["outputs"]
if self.labels not in outputs:
raise ValueError(f"'{self.labels}' field is missing from '{outputs}'.")
examplar_outputs = examplar["outputs"][self.labels]
if not isinstance(examplar_outputs, list):
raise ValueError(
f"Unexpected examplar_outputs value '{examplar_outputs}'. Expected a list."
)
return str([choice for choice in choices if choice in examplar_outputs])
def divide_by_repr(self, examplars_pool):
labels = {}
for examplar in examplars_pool:
label_repr = self.examplar_repr(examplar)
if label_repr not in labels:
labels[label_repr] = []
labels[label_repr].append(examplar)
return labels
def sample(
self, instances_pool: List[Dict[str, object]]
) -> List[Dict[str, object]]:
if self.labels_cache is None:
self.labels_cache = self.divide_by_repr(instances_pool)
all_labels = list(self.labels_cache.keys())
self.random_generator.shuffle(all_labels)
from collections import Counter
if self.sample_size > len(instances_pool):
raise ValueError(
f"Request sample size {self.sample_size} is greater than number of instances {len(instances_pool)}"
)
total_allocated = 0
allocations = Counter()
while total_allocated < self.sample_size:
for label in all_labels:
if total_allocated < self.sample_size:
if len(self.labels_cache[label]) - allocations[label] > 0:
allocations[label] += 1
total_allocated += 1
else:
break
result = []
for label, allocation in allocations.items():
sample = self.random_generator.sample(self.labels_cache[label], allocation)
result.extend(sample)
self.random_generator.shuffle(result)
return result
class SpreadSplit(InstanceOperatorWithMultiStreamAccess):
source_stream: str = None
target_field: str = None
sampler: Sampler = None
def prepare(self):
self.local_cache = None
self.sampler.prepare()
def verify(self):
assert self.source_stream is not None, "Source stream must be specified"
assert self.target_field is not None, "Target field must be specified"
assert self.sampler is not None, "Sampler must be specified"
return super().verify()
def process(
self, instance: Dict[str, object], multi_stream: MultiStream
) -> Dict[str, object]:
try:
if self.local_cache is None:
self.local_cache = list(multi_stream[self.source_stream])
source_stream = self.local_cache
sampled_instances = self.sampler.sample(source_stream)
instance[self.target_field] = sampled_instances
return instance
except Exception as e:
raise Exception(
f"Unable to fetch instances from '{self.source_stream}' to '{self.target_field}'"
) from e
|