File size: 5,620 Bytes
99fde4e e203384 e7c76e5 2341544 e7c76e5 99fde4e e7c76e5 2341544 99fde4e e7c76e5 99fde4e e7c76e5 e755967 e7c76e5 99fde4e e755967 99fde4e e7c76e5 e755967 e7c76e5 e203384 e7c76e5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
import itertools
from abc import abstractmethod
from dataclasses import field
from typing import Dict, List, Optional
from .artifact import Artifact
from .generator_utils import ReusableGenerator
from .operator import InstanceOperatorWithGlobalAccess, MultiStreamOperator
from .stream import MultiStream
class Splitter(MultiStreamOperator):
    """Base class for operators that reorganize the splits of a MultiStream.

    Concrete splitters (renaming, random mixing, slicing, separating) subclass
    this and implement ``process(multi_stream) -> MultiStream``.
    """
from .random_utils import random
from .split_utils import (
parse_random_mix_string,
parse_slices_string,
random_mix_streams,
rename_split,
slice_streams,
)
class RenameSplits(Splitter):
    """Rename splits of the input MultiStream according to ``mapper``.

    ``mapper`` is passed as-is to ``rename_split``; presumably it maps an
    existing split name to its new name — confirm against ``rename_split``.
    """

    mapper: Dict[str, str]

    def process(self, multi_stream: MultiStream) -> MultiStream:
        # NOTE(review): unlike the sibling splitters, the result is wrapped with
        # the MultiStream constructor rather than MultiStream.from_generators;
        # presumably rename_split returns streams, not generators — confirm.
        return MultiStream(rename_split(multi_stream, self.mapper))
class SplitRandomMix(Splitter):
    """Build new splits by randomly mixing portions of existing splits.

    ``mix`` maps each output split name to a mix expression, e.g.
    ``{"train": "train[90%]+validation[50%]", "test": "test"}``. Each
    expression is parsed with ``parse_random_mix_string`` and the parsed
    mapping is handed to ``random_mix_streams``.
    """

    mix: Dict[str, str]

    def process(self, multi_stream: MultiStream) -> MultiStream:
        parsed_mix = {}
        for split_name, expression in self.mix.items():
            parsed_mix[split_name] = parse_random_mix_string(expression)
        return MultiStream.from_generators(random_mix_streams(multi_stream, parsed_mix))
class SeparateSplit(Splitter):
    """Separate one split (e.g. train) into several splits (e.g. train1, train2).

    ``to_split_sizes`` must give the number of examples for every split in
    ``to_split_names``, or for every split except the last. If no size is
    given for the last split, it receives all the examples not allocated to
    the preceding splits.

    Every other input split is copied through in full. ``from_split`` itself
    gets no entry of its own in the slicing mapping, so it only appears in
    the output if its name is also listed in ``to_split_names``.
    """

    from_split: str
    to_split_names: List[str]
    to_split_sizes: List[int]

    def verify(self):
        assert (
            len(self.to_split_names) == len(self.to_split_sizes)
            or len(self.to_split_names) == len(self.to_split_sizes) + 1
        ), f"Examples num should be specified to all or all but the last splits, instead given {len(self.to_split_names)} split names and {len(self.to_split_sizes)} split sizes. \n split names:{self.to_split_names} split sizes {self.to_split_sizes}"
        return super().verify()

    def process(self, multi_stream: MultiStream) -> MultiStream:
        # Copy every split other than from_split through unchanged (open slice).
        mapping = {
            key: {key: [(None, None)]}
            for key in multi_stream.keys()
            if key != self.from_split
        }
        so_far = 0
        # zip_longest pads a missing last size with None, which leaves the last
        # slice open-ended so it takes every remaining example.
        # NOTE(review): (so_far, size) assumes slice_streams reads each pair as
        # (start, count); if it uses (start, stop) the bound must be
        # so_far + size — verify against split_utils.slice_streams.
        for name, size in itertools.zip_longest(self.to_split_names, self.to_split_sizes):
            mapping[name] = {self.from_split: [(so_far, size)]}
            if size is not None:
                so_far += size
        generators = slice_streams(multi_stream, mapping)
        return MultiStream.from_generators(generators)
class SliceSplit(Splitter):
    """Build new splits from slices of existing splits.

    ``slices`` maps each output split name to a slice expression, e.g.
    ``{"train": "train[:2]+train[2:4]", "test": "train[6:]+test"}``. Each
    expression is parsed with ``parse_slices_string`` and the parsed mapping
    is handed to ``slice_streams``.
    """

    slices: Dict[str, str]

    def process(self, multi_stream: MultiStream) -> MultiStream:
        parsed_slices = {
            split_name: parse_slices_string(expression)
            for split_name, expression in self.slices.items()
        }
        return MultiStream.from_generators(slice_streams(multi_stream, parsed_slices))
class Sampler(Artifact):
    """Base class for strategies that draw a sample of instances from a pool.

    ``sample_size`` may be configured as an int or as a string of digits; it
    is normalized to an int by ``set_size`` during ``prepare``.
    """

    sample_size: int = None

    def prepare(self):
        super().prepare()
        self.set_size(self.sample_size)

    def set_size(self, size):
        """Store ``size`` as ``sample_size``, converting a digit string to int.

        Args:
            size: an int, ``None``, or a string made only of digits.

        Raises:
            AssertionError: if ``size`` is a string that is not all digits.
        """
        if isinstance(size, str):
            # Report the rejected value itself, not the previously stored
            # sample_size (they differ when set_size is called directly).
            assert size.isdigit(), f"sample_size must be a natural number, got {size}"
            size = int(size)
        self.sample_size = size

    @abstractmethod
    def sample(self, instances_pool: List[Dict[str, object]]) -> List[Dict[str, object]]:
        """Return a sample of instances drawn from ``instances_pool``."""
        pass
class RandomSampler(Sampler):
    """Sampler that picks ``sample_size`` instances with ``random.sample``.

    ``random`` here is the object imported from ``.random_utils``.
    """

    def sample(self, instances_pool: List[Dict[str, object]]) -> List[Dict[str, object]]:
        # The pool is copied into a list first so any iterable is accepted.
        return random.sample(list(instances_pool), self.sample_size)
class SpreadSplit(InstanceOperatorWithGlobalAccess):
    """Attach to every instance a sample drawn from another stream.

    For each processed instance, ``sampler.sample`` is called on the instances
    of ``source_stream`` and the result is stored in ``target_field``. The
    source stream is materialized once into ``local_cache`` and reused for
    all later instances.

    NOTE(review): ``local_cache`` is only reset in ``prepare``, so if the same
    operator object processes a different MultiStream later, it keeps sampling
    from the first one's source stream — confirm this is intended.
    """

    source_stream: str = None
    target_field: str = None
    sampler: Sampler = None

    def prepare(self):
        # Run base-class preparation first, consistent with Sampler.prepare.
        super().prepare()
        self.accessible_streams = [self.source_stream]
        self.cache_accessible_streams = True
        self.local_cache = None

    def verify(self):
        assert self.source_stream is not None, "Source stream must be specified"
        assert self.target_field is not None, "Target field must be specified"
        assert self.sampler is not None, "Sampler must be specified"
        return super().verify()

    def process(self, instance: Dict[str, object], multi_stream: MultiStream) -> Dict[str, object]:
        # Materialize the source stream on first use; later calls reuse the list.
        if self.local_cache is None:
            self.local_cache = list(multi_stream[self.source_stream])
        instance[self.target_field] = self.sampler.sample(self.local_cache)
        return instance
if __name__ == "__main__":
    # Manual smoke test: build a small three-split stream, then print the
    # result of a random mix and of a slice-based re-split.
    #
    # The stdlib module is imported under an alias: a plain `import random`
    # here runs at module scope and would rebind the global name `random`
    # (imported from .random_utils), silently changing the generator that
    # RandomSampler uses.
    # NOTE(review): if split_utils draws from random_utils' generator rather
    # than the stdlib one, seed that generator as well — confirm.
    import random as python_random

    python_random.seed(0)

    def generator(name, size):
        """Yield ``size`` toy instances whose text is ``"<name>_<i>"``."""
        for i in range(size):
            yield {"text": f"{name}_{i}"}

    def _print_streams(multi_stream):
        """Print each split name followed by all of its instances."""
        for key, value in multi_stream.items():
            print(key)
            for item in value:
                print(item)

    stream = MultiStream.from_generators(
        {
            split: ReusableGenerator(generator, gen_kwargs={"name": split, "size": 10})
            for split in ("train", "validation", "test")
        }
    )

    splitter = SplitRandomMix(
        mix={
            "train": "train[90%]+validation[50%]",
            "validation": "train[10%]+validation[50%]",
            "test": "test",
        }
    )
    _print_streams(splitter(stream))

    splitter = SliceSplit(
        slices={
            "train": "train[:2]+train[2:4]",
            "validation": "train[4:6]",
            "test": "train[6:]+test",
        }
    )
    _print_streams(splitter(stream))
|