|
from .stream import MultiStream |
|
from .operator import MultiStreamOperator, InstanceOperatorWithGlobalAccess |
|
from .generator_utils import ReusableGenerator |
|
from .artifact import Artifact |
|
|
|
|
|
from typing import Optional, Dict, List |
|
from dataclasses import field |
|
|
|
|
|
class Splitter(MultiStreamOperator):
    """Common base class for the stream-splitting operators in this module.

    Adds no fields or behavior of its own; ``SplitRandomMix`` and
    ``SliceSplit`` derive from it.
    """
|
|
|
|
|
import random |
|
|
|
from .split_utils import ( |
|
parse_random_mix_string, |
|
random_mix_streams, |
|
parse_slices_string, |
|
slice_streams, |
|
) |
|
|
|
|
|
class SplitRandomMix(Splitter):
    """Build new streams by mixing existing ones according to ``mix``.

    Each value of ``mix`` is parsed with ``parse_random_mix_string`` and the
    parsed mapping is handed to ``random_mix_streams`` together with the
    input multi-stream.
    """

    # Output stream name -> mix specification string, for example
    # "train[90%]+validation[50%]". The grammar itself is defined by
    # parse_random_mix_string, not here.
    mix: Dict[str, str]

    def process(self, multi_stream: MultiStream) -> MultiStream:
        """Return a streaming ``MultiStream`` built from the mixed generators.

        Args:
            multi_stream: The input streams to draw from.

        Returns:
            A ``MultiStream`` created with ``MultiStream.from_generators`` in
            streaming mode from whatever ``random_mix_streams`` returns.
        """
        parsed_mix = {}
        for target_name, mix_spec in self.mix.items():
            parsed_mix[target_name] = parse_random_mix_string(mix_spec)

        mixed = random_mix_streams(multi_stream, parsed_mix)
        return MultiStream.from_generators(mixed, streaming=True)
|
|
|
|
|
class SliceSplit(Splitter):
    """Build new streams from slices of existing ones according to ``slices``.

    Each value of ``slices`` is parsed with ``parse_slices_string`` and the
    parsed mapping is handed to ``slice_streams`` together with the input
    multi-stream.
    """

    # Output stream name -> slice specification string, for example
    # "train[:2]+train[2:4]". The grammar itself is defined by
    # parse_slices_string, not here.
    slices: Dict[str, str]

    def process(self, multi_stream: MultiStream) -> MultiStream:
        """Return a streaming ``MultiStream`` built from the sliced generators.

        Args:
            multi_stream: The input streams to slice.

        Returns:
            A ``MultiStream`` created with ``MultiStream.from_generators`` in
            streaming mode from whatever ``slice_streams`` returns.
        """
        parsed_slices = {}
        for target_name, slice_spec in self.slices.items():
            parsed_slices[target_name] = parse_slices_string(slice_spec)

        sliced = slice_streams(multi_stream, parsed_slices)
        return MultiStream.from_generators(sliced, streaming=True)
|
|
|
|
|
class Sampler(Artifact):
    """Base artifact for samplers; it only declares the ``sample_size`` field.

    No ``sample`` method is defined here. ``RandomSampler`` in this module
    provides one, and ``SpreadSplit.process`` calls ``sampler.sample(...)``.
    """

    # Number of instances to draw; RandomSampler passes it to random.sample
    # as the sample count.
    sample_size: int
|
|
|
|
|
class RandomSampler(Sampler):
    """Sampler that draws ``sample_size`` instances with ``random.sample``."""

    def sample(self, instances_pool: List[Dict[str, object]]) -> List[Dict[str, object]]:
        """Return ``self.sample_size`` instances picked from ``instances_pool``.

        Args:
            instances_pool: Iterable of candidate instances. It is copied
                into a new list before sampling.

        Returns:
            A new list holding the sampled instances.

        Raises:
            ValueError: Raised by ``random.sample`` when ``sample_size`` is
                negative or larger than the pool.
        """
        # Sampling uses the module-level ``random`` generator, so results
        # depend on its global state.
        return random.sample(list(instances_pool), self.sample_size)
|
|
|
|
|
class SpreadSplit(InstanceOperatorWithGlobalAccess):
    """Attach to every instance a sample drawn from another stream.

    For each processed instance, ``sampler.sample`` is called on the
    materialized content of ``source_stream`` and the result is stored in
    the instance under ``target_field``.
    """

    # Name of the stream the samples are drawn from.
    source_stream: str = None
    # Key under which the sampled instances are written into each instance.
    target_field: str = None
    # Object providing ``sample(pool)``; see ``Sampler`` / ``RandomSampler``.
    sampler: Sampler = None

    def prepare(self):
        """Initialise the stream-access settings and the per-operator cache."""
        # NOTE(review): the two attributes below are presumably consumed by
        # InstanceOperatorWithGlobalAccess; their semantics are not visible
        # in this file.
        self.accessible_streams = [self.source_stream]
        self.cache_accessible_streams = True
        # Filled on the first call to ``process``.
        self.local_cache = None

    def verify(self):
        """Check that all three fields were supplied, then defer to the base.

        Raises:
            AssertionError: If ``source_stream``, ``target_field`` or
                ``sampler`` is ``None`` (checked in that order).
        """
        required = (
            (self.source_stream, "Source stream must be specified"),
            (self.target_field, "Target field must be specified"),
            (self.sampler, "Sampler must be specified"),
        )
        for value, message in required:
            assert value is not None, message
        return super().verify()

    def process(self, instance: Dict[str, object], multi_stream: MultiStream) -> Dict[str, object]:
        """Store a fresh sample from the source stream in ``instance``.

        Args:
            instance: The instance being processed; it is modified in place.
            multi_stream: The streams available to this operator.

        Returns:
            The same ``instance`` object, with ``target_field`` set to the
            value returned by ``self.sampler.sample``.
        """
        pool = self.local_cache
        if pool is None:
            # Materialize the source stream once and reuse the list for all
            # later instances.
            pool = list(multi_stream[self.source_stream])
            self.local_cache = pool

        instance[self.target_field] = self.sampler.sample(pool)
        return instance
|
|
|
|
|
if __name__ == "__main__":
    # Small manual demo: run both splitters over three 10-item toy streams
    # and print every resulting stream.
    import random

    # Seed the global RNG before the random-mix demo.
    # NOTE(review): presumably this makes the mix reproducible; that depends
    # on random_mix_streams using the ``random`` module, which is not visible
    # here.
    random.seed(0)

    def generator(name, size):
        """Yield ``size`` dicts of the form ``{"text": "<name>_<i>"}``."""
        for i in range(size):
            yield {"text": f"{name}_{i}"}

    def show(result):
        """Print each stream name followed by all of its items."""
        for split_name, split_stream in result.items():
            print(split_name)
            for record in split_stream:
                print(record)

    splitter = SplitRandomMix(
        mix={
            "train": "train[90%]+validation[50%]",
            "validation": "train[10%]+validation[50%]",
            "test": "test",
        }
    )

    stream = MultiStream.from_generators(
        {
            split: ReusableGenerator(generator, gen_kwargs={"name": split, "size": 10})
            for split in ("train", "validation", "test")
        }
    )

    ds = splitter(stream)
    show(ds)

    splitter = SliceSplit(
        slices={
            "train": "train[:2]+train[2:4]",
            "validation": "train[4:6]",
            "test": "train[6:]+test",
        }
    )

    # The same input multi-stream is fed to the second splitter.
    ds = splitter(stream)
    show(ds)
|
|