File size: 3,557 Bytes
e7c76e5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
from .stream import MultiStream
from .operator import MultiStreamOperator, InstanceOperatorWithGlobalAccess
from .generator_utils import ReusableGenerator
from .artifact import Artifact
from typing import Optional, Dict, List
from dataclasses import field
class Splitter(MultiStreamOperator):
    """Common base class for the stream-splitting operators in this module.

    Adds no behavior of its own on top of ``MultiStreamOperator``.
    """
import random
from .split_utils import (
parse_random_mix_string,
random_mix_streams,
parse_slices_string,
slice_streams,
)
class SplitRandomMix(Splitter):
    """Build new streams by mixing parts of existing streams.

    ``mix`` maps each output stream name to a mix specification string,
    e.g. ``"train[90%]+validation[50%]"``. Each specification is parsed with
    ``parse_random_mix_string`` and the resulting mapping is handed to
    ``random_mix_streams``.
    """

    mix: Dict[str, str]

    def process(self, multi_stream: MultiStream) -> MultiStream:
        """Return a streaming ``MultiStream`` built according to ``self.mix``.

        Args:
            multi_stream: The input streams referenced by the mix strings.

        Returns:
            A ``MultiStream`` created from the generators produced by
            ``random_mix_streams``.
        """
        # Parse every textual mix specification, keyed by output stream name.
        parsed_mix = {}
        for target_name, mix_spec in self.mix.items():
            parsed_mix[target_name] = parse_random_mix_string(mix_spec)

        mixed_generators = random_mix_streams(multi_stream, parsed_mix)
        return MultiStream.from_generators(mixed_generators, streaming=True)
class SliceSplit(Splitter):
    """Build new streams from slices of existing streams.

    ``slices`` maps each output stream name to a slice specification string,
    e.g. ``"train[:2]+train[2:4]"``. Each specification is parsed with
    ``parse_slices_string`` and the resulting mapping is handed to
    ``slice_streams``.
    """

    slices: Dict[str, str]

    def process(self, multi_stream: MultiStream) -> MultiStream:
        """Return a streaming ``MultiStream`` built according to ``self.slices``.

        Args:
            multi_stream: The input streams referenced by the slice strings.

        Returns:
            A ``MultiStream`` created from the generators produced by
            ``slice_streams``.
        """
        # Pair every output stream name with its parsed slice specification.
        parsed_pairs = (
            (target_name, parse_slices_string(slice_spec))
            for target_name, slice_spec in self.slices.items()
        )
        parsed_slices = dict(parsed_pairs)

        sliced_generators = slice_streams(multi_stream, parsed_slices)
        return MultiStream.from_generators(sliced_generators, streaming=True)
class Sampler(Artifact):
    """Base artifact for samplers, carrying the configured sample size.

    This class declares only the ``sample_size`` field; the sampling logic
    itself is presumably supplied by subclasses through a ``sample`` method.
    """

    # Size of the sample a sampler should produce.
    sample_size: int
class RandomSampler(Sampler):
    """Sampler that picks instances with ``random.sample``."""

    def sample(self, instances_pool: List[Dict[str, object]]) -> List[Dict[str, object]]:
        """Return ``self.sample_size`` instances chosen from ``instances_pool``.

        Args:
            instances_pool: The candidate instances; it is copied into a list
                before sampling.

        Returns:
            A new list of ``self.sample_size`` instances selected by
            ``random.sample`` (i.e. without replacement).

        Raises:
            ValueError: Propagated from ``random.sample`` when the requested
                size is larger than the pool or negative.
        """
        # Copy into a list so random.sample always receives a sequence.
        population = list(instances_pool)
        chosen = random.sample(population, self.sample_size)
        return chosen
class SpreadSplit(InstanceOperatorWithGlobalAccess):
    """Attach a sample drawn from one stream to every processed instance.

    For each instance, ``sampler.sample`` is called on the instances of
    ``source_stream`` and the result is stored under ``target_field``.
    """

    # Name of the stream whose instances are sampled from.
    source_stream: str = None
    # Key under which the sampled instances are written into each instance.
    target_field: str = None
    # Object providing ``sample(pool)``; used to draw the instances.
    sampler: Sampler = None

    def prepare(self):
        """Set up stream access and reset the local source-stream cache."""
        # Filled lazily by ``process`` on its first call.
        self.local_cache = None
        # Presumably read by the base class to expose (and cache) only the
        # source stream to ``process``; verify against the base class.
        self.accessible_streams = [self.source_stream]
        self.cache_accessible_streams = True

    def verify(self):
        """Check that all three configuration fields were provided.

        Raises:
            AssertionError: If ``source_stream``, ``target_field`` or
                ``sampler`` is ``None``.
        """
        assert self.source_stream is not None, "Source stream must be specified"
        assert self.target_field is not None, "Target field must be specified"
        assert self.sampler is not None, "Sampler must be specified"
        return super().verify()

    def process(self, instance: Dict[str, object], multi_stream: MultiStream) -> Dict[str, object]:
        """Store a fresh sample from the source stream on ``instance``.

        Args:
            instance: The instance to update; it is modified in place.
            multi_stream: Streams from which ``source_stream`` is read.

        Returns:
            The same ``instance`` with ``target_field`` set to the sample.
        """
        # Materialize the source stream once and reuse it for later instances.
        if self.local_cache is None:
            self.local_cache = list(multi_stream[self.source_stream])

        instance[self.target_field] = self.sampler.sample(self.local_cache)
        return instance
if __name__ == "__main__":
    # Manual smoke test: run both splitters on small synthetic streams and
    # print every resulting stream name followed by its instances.
    import random

    def generator(name, size):
        """Yield ``size`` records whose text is ``<name>_<index>``."""
        yield from ({"text": f"{name}_{i}"} for i in range(size))

    def show(split_streams):
        """Print each stream's name and then each of its instances."""
        for split_name, split_stream in split_streams.items():
            print(split_name)
            for record in split_stream:
                print(record)

    # Fixed seed so repeated runs use the same random state.
    random.seed(0)

    random_mix_splitter = SplitRandomMix(
        mix={
            "train": "train[90%]+validation[50%]",
            "validation": "train[10%]+validation[50%]",
            "test": "test",
        }
    )

    # Three input streams of ten records each.
    stream = MultiStream.from_generators(
        {
            stream_name: ReusableGenerator(
                generator, gen_kwargs={"name": stream_name, "size": 10}
            )
            for stream_name in ("train", "validation", "test")
        }
    )

    show(random_mix_splitter(stream))

    slice_splitter = SliceSplit(
        slices={
            "train": "train[:2]+train[2:4]",
            "validation": "train[4:6]",
            "test": "train[6:]+test",
        }
    )

    # The same input streams are reused for the second splitter.
    show(slice_splitter(stream))
|