File size: 7,066 Bytes
99fde4e
e203384
04d2454
2341544
 
04d2454
 
e7c76e5
 
 
2341544
99fde4e
e7c76e5
 
04d2454
 
 
 
 
e7c76e5
 
99fde4e
 
 
 
 
 
 
 
e7c76e5
 
 
 
 
 
e755967
e7c76e5
 
99fde4e
04d2454
 
99fde4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
04d2454
 
 
 
 
99fde4e
04d2454
 
 
99fde4e
 
 
 
e755967
99fde4e
 
e7c76e5
 
 
 
 
 
e755967
e7c76e5
 
 
e203384
 
 
 
 
 
 
 
04d2454
 
 
e203384
 
 
 
04d2454
 
 
e203384
e7c76e5
 
 
04d2454
 
 
e7c76e5
04d2454
e7c76e5
 
78a0600
 
 
 
 
 
 
 
04d2454
 
 
 
 
 
 
 
 
 
 
 
 
78a0600
04d2454
 
 
 
 
 
78a0600
 
04d2454
78a0600
 
 
 
 
 
 
04d2454
 
 
78a0600
 
 
04d2454
78a0600
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
04d2454
78a0600
 
04d2454
78a0600
 
 
04d2454
e7c76e5
 
 
 
 
 
04d2454
e7c76e5
 
 
 
 
 
 
04d2454
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
import itertools
from abc import abstractmethod
from typing import Dict, List

from .artifact import Artifact
from .operator import InstanceOperatorWithMultiStreamAccess, MultiStreamOperator
from .random_utils import get_random
from .split_utils import (
    parse_random_mix_string,
    parse_slices_string,
    random_mix_streams,
    rename_split,
    slice_streams,
)
from .stream import MultiStream


class Splitter(MultiStreamOperator):
    """Marker base class for the split-manipulating operators of this module.

    It adds nothing on top of ``MultiStreamOperator``; concrete subclasses
    provide their own ``process``.
    """


class RenameSplits(Splitter):
    """Operator that renames the splits of a multi-stream via ``rename_split``.

    Attributes are documented inline; the renaming itself is delegated
    entirely to the ``rename_split`` helper.
    """

    # Passed unchanged to ``rename_split``.
    # NOTE(review): presumably maps old split name -> new split name; confirm
    # against ``rename_split``.
    mapper: Dict[str, str]

    def process(self, multi_stream: MultiStream) -> MultiStream:
        """Wrap the result of ``rename_split`` in a new ``MultiStream``."""
        return MultiStream(rename_split(multi_stream, self.mapper))


class SplitRandomMix(Splitter):
    """Operator that builds new splits with ``random_mix_streams``.

    Every value of ``mix`` is parsed with ``parse_random_mix_string`` before
    being handed to ``random_mix_streams``.
    """

    # Key -> mix specification string understood by ``parse_random_mix_string``.
    mix: Dict[str, str]

    def process(self, multi_stream: MultiStream) -> MultiStream:
        """Parse each mix specification and return the mixed multi-stream."""
        parsed_mix = {}
        for split_name, mix_spec in self.mix.items():
            parsed_mix[split_name] = parse_random_mix_string(mix_spec)
        return MultiStream.from_generators(
            random_mix_streams(multi_stream, parsed_mix)
        )


class SeparateSplit(Splitter):
    """Break one split (e.g. train) into several consecutive splits (e.g. train1, train2).

    A size must be given for every target split except, optionally, the last
    one. When the last size is omitted, that split is created with a size of
    ``None``.
    """

    # Name of the split to carve up.
    from_split: str
    # Names of the splits to create, in order.
    to_split_names: List[str]
    # Sizes of the created splits; may be one shorter than ``to_split_names``.
    to_split_sizes: List[int]

    def verify(self):
        """Check that sizes are given for all, or all but the last, split names."""
        missing_sizes = len(self.to_split_names) - len(self.to_split_sizes)
        assert missing_sizes in (
            0,
            1,
        ), f"Examples num should be specified to all or all but the last splits, instead given {len(self.to_split_names)} split names and {len(self.to_split_sizes)} split sizes. \n split names:{self.to_split_names} split sizes {self.to_split_sizes}"
        return super().verify()

    def process(self, multi_stream: MultiStream) -> MultiStream:
        """Build the slicing plan and delegate the actual slicing to ``slice_streams``."""
        # Every split other than ``from_split`` is carried over under its own
        # name with a ``(None, None)`` slice.
        mapping = {}
        for key in multi_stream.keys():
            if key != self.from_split:
                mapping[key] = {key: [(None, None)]}

        # Walk the target splits in order, tracking how many examples the
        # earlier targets have already claimed.
        offset = 0
        names_and_sizes = itertools.zip_longest(
            self.to_split_names, self.to_split_sizes
        )
        for name, size in names_and_sizes:
            mapping[name] = {self.from_split: [(offset, size)]}
            # ``size`` is None for a trailing split that has no explicit size.
            offset += size or 0

        return MultiStream.from_generators(slice_streams(multi_stream, mapping))


class SliceSplit(Splitter):
    """Operator that builds new splits with ``slice_streams``.

    Every value of ``slices`` is parsed with ``parse_slices_string`` before
    being handed to ``slice_streams``.
    """

    # Key -> slice specification string understood by ``parse_slices_string``.
    slices: Dict[str, str]

    def process(self, multi_stream: MultiStream) -> MultiStream:
        """Parse each slice specification and return the sliced multi-stream."""
        parsed_slices = {}
        for split_name, slice_spec in self.slices.items():
            parsed_slices[split_name] = parse_slices_string(slice_spec)
        return MultiStream.from_generators(
            slice_streams(multi_stream, parsed_slices)
        )


class Sampler(Artifact):
    """Base class for objects that pick instances out of a pool.

    Subclasses implement ``sample``; the number of instances to pick is kept
    in ``sample_size``.
    """

    # Number of instances to sample; may be supplied as a digit string, which
    # ``set_size`` converts to ``int``.
    sample_size: int = None

    def prepare(self):
        """Run the parent preparation, then normalize ``sample_size``."""
        super().prepare()
        self.set_size(self.sample_size)

    def set_size(self, size):
        """Store ``size`` as the sample size.

        Args:
            size: The new sample size. A string is accepted when it consists
                of digits only and is converted to ``int``; any other value
                is stored unchanged.

        Raises:
            AssertionError: If ``size`` is a string that is not all digits.
        """
        if isinstance(size, str):
            # Report the rejected argument itself rather than the previously
            # stored ``self.sample_size``, which may be unrelated when this
            # method is called directly.
            assert (
                size.isdigit()
            ), f"sample_size must be a natural number, got {size}"
            size = int(size)
        self.sample_size = size

    @abstractmethod
    def sample(
        self, instances_pool: List[Dict[str, object]]
    ) -> List[Dict[str, object]]:
        """Return the instances selected from ``instances_pool``."""
        pass


class RandomSampler(Sampler):
    """Sampler that draws ``sample_size`` instances via ``get_random().sample``."""

    def sample(
        self, instances_pool: List[Dict[str, object]]
    ) -> List[Dict[str, object]]:
        """Return ``sample_size`` instances drawn from a list copy of the pool."""
        # Materialize first so any iterable can be sampled from.
        pool = list(instances_pool)
        rng = get_random()
        return rng.sample(pool, self.sample_size)


class DiverseLabelsSampler(Sampler):
    """Sampler that spreads the sample over the different label groups.

    Instances are grouped by the representation returned from
    ``examplar_repr``. ``sample`` then hands out the ``sample_size`` slots to
    the groups one at a time, in a shuffled round-robin order, and draws the
    allocated number of instances from each group.
    """

    # Name of the field inside ``inputs`` that holds the list of choices.
    choices: str = "choices"

    def prepare(self):
        """Run the parent preparation and reset the cached grouping."""
        super().prepare()
        # Filled on the first ``sample`` call and reused afterwards.
        self.labels = None

    def examplar_repr(self, examplar):
        """Return the string key used to group ``examplar``.

        The key is ``str`` of the list of choices (in choices order) that also
        appear in the first value of ``examplar["outputs"]``.

        Raises:
            ValueError: If ``inputs``, the choices field or ``outputs`` is
                missing, or if the choices or the first output value is not a
                list.
        """
        if "inputs" not in examplar:
            raise ValueError(f"'inputs' field is missing from '{examplar}'.")
        inputs = examplar["inputs"]
        if self.choices not in inputs:
            raise ValueError(f"{self.choices} field is missing from '{inputs}'.")
        choices = inputs[self.choices]
        if not isinstance(choices, list):
            raise ValueError(
                f"Unexpected input choices value '{choices}'. Expected a list."
            )

        if "outputs" not in examplar:
            raise ValueError(f"'outputs' field is missing from '{examplar}'.")
        # Only the first value of the outputs mapping is inspected.
        examplar_outputs = next(iter(examplar["outputs"].values()))
        if not isinstance(examplar_outputs, list):
            raise ValueError(
                f"Unexpected examplar_outputs value '{examplar_outputs}'. Expected a list."
            )

        return str([choice for choice in choices if choice in examplar_outputs])

    def divide_by_repr(self, examplars_pool):
        """Group ``examplars_pool`` into a dict keyed by ``examplar_repr``."""
        labels = {}
        for examplar in examplars_pool:
            labels.setdefault(self.examplar_repr(examplar), []).append(examplar)
        return labels

    def sample(
        self, instances_pool: List[Dict[str, object]]
    ) -> List[Dict[str, object]]:
        """Return ``sample_size`` instances spread across the label groups.

        The grouping is computed from ``instances_pool`` on the first call
        only; later calls reuse the cached grouping and ignore their argument.

        Raises:
            ValueError: If ``sample_size`` exceeds the number of grouped
                instances. Without this check the allocation loop below could
                never terminate.
        """
        from collections import Counter

        if self.labels is None:
            self.labels = self.divide_by_repr(instances_pool)

        available = sum(len(group) for group in self.labels.values())
        if self.sample_size > available:
            raise ValueError(
                f"Cannot sample {self.sample_size} instances: only {available} are available in the pool."
            )

        all_labels = list(self.labels.keys())
        get_random().shuffle(all_labels)

        total_allocated = 0
        allocations = Counter()

        # Round-robin over the shuffled labels, giving one slot per pass to
        # every label that still has unallocated instances.
        while total_allocated < self.sample_size:
            for label in all_labels:
                if total_allocated >= self.sample_size:
                    break
                if allocations[label] < len(self.labels[label]):
                    allocations[label] += 1
                    total_allocated += 1

        result = []
        for label, allocation in allocations.items():
            result.extend(get_random().sample(self.labels[label], allocation))

        # Shuffle so the output is not ordered by label group.
        get_random().shuffle(result)
        return result


class SpreadSplit(InstanceOperatorWithMultiStreamAccess):
    """Attach instances sampled from another stream to every processed instance.

    For each instance, ``sampler`` draws from the contents of
    ``source_stream`` and the result is stored under ``target_field``.
    """

    # Name of the stream the samples are drawn from.
    source_stream: str = None
    # Field of the processed instance that receives the sampled instances.
    target_field: str = None
    # Sampler used to pick instances from the source stream.
    sampler: Sampler = None

    def prepare(self):
        """Reset the source-stream cache and prepare the sampler."""
        self.local_cache = None
        self.sampler.prepare()

    def verify(self):
        """Assert that all three settings were provided, then defer to the parent."""
        required = (
            (self.source_stream, "Source stream must be specified"),
            (self.target_field, "Target field must be specified"),
            (self.sampler, "Sampler must be specified"),
        )
        for value, message in required:
            assert value is not None, message
        return super().verify()

    def process(
        self, instance: Dict[str, object], multi_stream: MultiStream
    ) -> Dict[str, object]:
        """Store a fresh sample from the source stream in ``instance`` and return it.

        Raises:
            Exception: Wrapping any error raised while reading the source
                stream or sampling from it.
        """
        try:
            pool = self.local_cache
            if pool is None:
                # Materialize the source stream once; later instances reuse it.
                pool = self.local_cache = list(multi_stream[self.source_stream])
            instance[self.target_field] = self.sampler.sample(pool)
        except Exception as e:
            raise Exception(
                f"Unable to fetch instances from '{self.source_stream}' to '{self.target_field}'"
            ) from e
        return instance