File size: 5,191 Bytes
99fde4e
e7c76e5
2341544
 
 
 
 
 
e7c76e5
 
 
 
 
 
99fde4e
e7c76e5
 
 
2341544
99fde4e
e7c76e5
 
 
 
99fde4e
 
 
 
 
 
 
 
e7c76e5
 
 
 
 
 
 
 
 
99fde4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7c76e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import itertools
from dataclasses import field
from typing import Dict, List, Optional

from .artifact import Artifact
from .generator_utils import ReusableGenerator
from .operator import InstanceOperatorWithGlobalAccess, MultiStreamOperator
from .stream import MultiStream


class Splitter(MultiStreamOperator):
    """Marker base class for the split-manipulating operators in this module.

    Adds no fields or behavior of its own on top of ``MultiStreamOperator``.
    """


from .random_utils import random
from .split_utils import (
    parse_random_mix_string,
    parse_slices_string,
    random_mix_streams,
    rename_split,
    slice_streams,
)


class RenameSplits(Splitter):
    """Rename the splits of a multi-stream.

    Attributes:
        mapper: Name mapping handed unchanged to ``rename_split``
            (presumably current split name -> new split name; confirm
            against ``split_utils.rename_split``).
    """

    mapper: Dict[str, str]

    def process(self, multi_stream: MultiStream) -> MultiStream:
        """Return a new ``MultiStream`` built from the output of ``rename_split``."""
        return MultiStream(rename_split(multi_stream, self.mapper))


class SplitRandomMix(Splitter):
    """Build new splits as random mixes of existing ones.

    Attributes:
        mix: Maps each output split name to a mix expression string, e.g.
            ``"train[90%]+validation[50%]"``; every expression is parsed with
            ``parse_random_mix_string``.
    """

    mix: Dict[str, str]

    def process(self, multi_stream: MultiStream) -> MultiStream:
        """Parse every mix expression and delegate the mixing to ``random_mix_streams``."""
        parsed_mix = {}
        for target_split, mix_expression in self.mix.items():
            parsed_mix[target_split] = parse_random_mix_string(mix_expression)
        return MultiStream.from_generators(
            random_mix_streams(multi_stream, parsed_mix),
            streaming=True,
        )


class SeparateSplit(Splitter):
    """Separate one split (e.g. train) into several splits (e.g. train1, train2).

    ``to_split_sizes`` must give the size of every new split, except that the
    size of the last one may be omitted. If no size is given for the last
    split, it is requested with an open-ended size (``None``).

    Attributes:
        from_split: Name of the split to carve up.
        to_split_names: Names of the splits to create, in order.
        to_split_sizes: Sizes of those splits; same length as
            ``to_split_names`` or exactly one element shorter.
    """

    from_split: str
    to_split_names: List[str]
    to_split_sizes: List[int]

    def verify(self):
        """Check that sizes are given for all names, or for all but the last."""
        assert len(self.to_split_names) - len(self.to_split_sizes) in (0, 1), (
            f"Examples num should be specified to all or all but the last splits, instead given {len(self.to_split_names)} split names and {len(self.to_split_sizes)} split sizes. \n split names:{self.to_split_names} split sizes {self.to_split_sizes}"
        )
        return super().verify()

    def process(self, multi_stream: MultiStream) -> MultiStream:
        """Build the slicing plan for ``slice_streams`` and return the sliced multi-stream."""
        # Every split other than the source is passed through whole.
        mapping = {}
        for key in multi_stream.keys():
            if key != self.from_split:
                mapping[key] = {key: [(None, None)]}

        # Each new split is described by (running offset, requested size);
        # zip_longest supplies ``None`` as the size when the last one is omitted.
        offset = 0
        name_size_pairs = itertools.zip_longest(self.to_split_names, self.to_split_sizes)
        for split_name, split_size in name_size_pairs:
            mapping[split_name] = {self.from_split: [(offset, split_size)]}
            offset += split_size or 0

        return MultiStream.from_generators(
            slice_streams(multi_stream, mapping),
            streaming=True,
        )


class SliceSplit(Splitter):
    """Build new splits from slices of existing ones.

    Attributes:
        slices: Maps each output split name to a slice expression string, e.g.
            ``"train[:2]+train[2:4]"``; every expression is parsed with
            ``parse_slices_string``.
    """

    slices: Dict[str, str]

    def process(self, multi_stream: MultiStream) -> MultiStream:
        """Parse every slice expression and delegate the slicing to ``slice_streams``."""
        parsed_slices = {}
        for target_split, slice_expression in self.slices.items():
            parsed_slices[target_split] = parse_slices_string(slice_expression)
        return MultiStream.from_generators(
            slice_streams(multi_stream, parsed_slices),
            streaming=True,
        )


class Sampler(Artifact):
    """Base artifact for samplers; holds only the sample size configuration."""

    # Number of instances a sampler returns per call (used as the sample
    # count by RandomSampler.sample).
    sample_size: int


class RandomSampler(Sampler):
    """Sampler that draws ``sample_size`` instances at random from a pool."""

    def sample(self, instances_pool: List[Dict[str, object]]) -> List[Dict[str, object]]:
        """Return ``sample_size`` instances drawn from ``instances_pool``.

        The pool is first copied into a list, so any iterable is accepted.
        NOTE(review): ``random`` is the object imported from ``.random_utils``;
        presumably its ``sample`` mirrors the standard library's - confirm.
        """
        pool_snapshot = list(instances_pool)
        picked = random.sample(pool_snapshot, self.sample_size)
        return picked


class SpreadSplit(InstanceOperatorWithGlobalAccess):
    """Attach to every instance a sample of instances taken from another stream.

    Attributes:
        source_stream: Name of the stream the samples are drawn from.
        target_field: Instance field that receives the sampled instances.
        sampler: Sampler used to pick instances from the source stream.
    """

    source_stream: str = None
    target_field: str = None
    sampler: Sampler = None

    def prepare(self):
        """Declare the source stream as accessible and reset the local cache."""
        self.accessible_streams = [self.source_stream]
        self.cache_accessible_streams = True
        # Filled on the first call to ``process``.
        self.local_cache = None

    def verify(self):
        """Ensure all three configuration fields were provided."""
        required = (
            (self.source_stream, "Source stream must be specified"),
            (self.target_field, "Target field must be specified"),
            (self.sampler, "Sampler must be specified"),
        )
        for value, message in required:
            assert value is not None, message
        return super().verify()

    def process(self, instance: Dict[str, object], multi_stream: MultiStream) -> Dict[str, object]:
        """Store a fresh sample from the source stream under ``target_field``.

        The source stream is materialized into a list once, on first use, and
        that list is reused for every subsequent instance.
        """
        pool = self.local_cache
        if pool is None:
            pool = self.local_cache = list(multi_stream[self.source_stream])

        instance[self.target_field] = self.sampler.sample(pool)
        return instance


if __name__ == "__main__":
    # Manual smoke test: builds three 10-item streams, runs SplitRandomMix and
    # SliceSplit over them and prints every resulting split.

    # Import the standard-library module under an alias. A plain
    # ``import random`` here executes at module scope and would rebind the
    # module-level ``random`` imported from ``.random_utils``, changing what
    # ``RandomSampler.sample`` (and any other user of ``random``) refers to.
    import random as stdlib_random

    # NOTE(review): this seeds the standard-library generator, exactly as the
    # original did; whether ``random_mix_streams`` draws from it is not
    # visible here - confirm if reproducible output is required.
    stdlib_random.seed(0)

    def generator(name, size):
        """Yield ``size`` toy instances whose text is tagged with the split name."""
        for i in range(size):
            yield {"text": f"{name}_{i}"}

    def print_splits(multi_stream):
        """Print every split name followed by each of its instances."""
        for key, value in multi_stream.items():
            print(key)
            for item in value:
                print(item)

    splitter = SplitRandomMix(
        mix={
            "train": "train[90%]+validation[50%]",
            "validation": "train[10%]+validation[50%]",
            "test": "test",
        }
    )

    stream = MultiStream.from_generators(
        {
            "train": ReusableGenerator(generator, gen_kwargs={"name": "train", "size": 10}),
            "validation": ReusableGenerator(generator, gen_kwargs={"name": "validation", "size": 10}),
            "test": ReusableGenerator(generator, gen_kwargs={"name": "test", "size": 10}),
        }
    )

    ds = splitter(stream)
    print_splits(ds)

    splitter = SliceSplit(
        slices={
            "train": "train[:2]+train[2:4]",
            "validation": "train[4:6]",
            "test": "train[6:]+test",
        }
    )

    ds = splitter(stream)
    print_splits(ds)