File size: 7,075 Bytes
b7c39fe
 
 
778ad61
 
b7c39fe
778ad61
 
 
 
b7c39fe
 
778ad61
b7c39fe
778ad61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
from dataclasses import field
from typing import Any, Dict, Generator, Iterable, List, Optional, Union

from .artifact import Artifact, fetch_artifact
from .operator import (
    MultiStream,
    MultiStreamOperator,
    SingleStreamOperator,
    SingleStreamReducer,
    Stream,
    StreamInitializerOperator,
    StreamInstanceOperator,
)
from .stream import MultiStream, Stream


class FromIterables(StreamInitializerOperator):
    """Stream initializer that builds a MultiStream from named iterables."""

    def process(self, iterables: Dict[str, Iterable]) -> MultiStream:
        """Turn each named iterable into a stream of the resulting MultiStream.

        Args:
            iterables: mapping from stream name to the iterable backing that stream.

        Returns:
            The MultiStream built by ``MultiStream.from_iterables``.
        """
        multi_stream = MultiStream.from_iterables(iterables)
        return multi_stream


class MapInstanceValues(StreamInstanceOperator):
    """Replace field values of each instance using per-field lookup tables.

    Values are looked up by their ``str()`` representation, so a mapper key
    ``"1"`` matches both the integer ``1`` and the string ``"1"``.

    Attributes:
        mappers: field name -> {stringified original value: replacement value}.
        strict: when True, every value of a mapped field must be present in its
            mapper and a ``KeyError`` is raised otherwise; when False, values
            missing from the mapper are passed through unchanged.
    """

    mappers: Dict[str, Dict[str, str]]
    strict: bool = True

    def verify(self):
        # Each mapper must be a dict keyed by strings, because lookups are
        # performed on str(value).
        for key, mapper in self.mappers.items():
            assert isinstance(mapper, dict), f"Mapper for given field {key} should be a dict, got {type(mapper)}"
            for k in mapper:
                assert isinstance(k, str), f'Key "{k}" in mapper for field "{key}" should be a string, got {type(k)}'

    def process(self, instance: Dict[str, Any], stream_name: Optional[str] = None) -> Dict[str, Any]:
        """Return a new dict with mapped field values; *instance* is not modified.

        Fields without a mapper are copied as-is.

        Raises:
            KeyError: in strict mode, when a mapped field holds a value whose
                string form is not a key of that field's mapper.
        """
        result = {}
        for key, value in instance.items():
            if key in self.mappers:
                mapper = self.mappers[key]
                str_value = str(value)
                if str_value in mapper:
                    value = mapper[str_value]
                elif self.strict:
                    # Name the field and value instead of letting a bare
                    # KeyError(str_value) escape with no context.
                    raise KeyError(f'No mapping for value "{str_value}" of field "{key}" (strict mode)')
            result[key] = value
        return result


def flatten_dict(d: Dict[str, Any], parent_key: str = "", sep: str = "_") -> Dict[str, Any]:
    """Flatten nested dicts into a single level, joining key paths with *sep*.

    Args:
        d: the (possibly nested) dict to flatten.
        parent_key: prefix for every key produced at this level; "" means none.
        sep: separator placed between a parent key and a child key.

    Returns:
        A new flat dict. Non-dict values are kept as leaves. An empty nested
        dict contributes no keys. If two paths produce the same flat key, the
        later value wins.

    Example:
        >>> flatten_dict({"a": 1, "b": {"c": 2}})
        {'a': 1, 'b_c': 2}
    """
    flat: Dict[str, Any] = {}
    for key, value in d.items():
        full_key = parent_key + sep + key if parent_key else key
        if not isinstance(value, dict):
            flat[full_key] = value
            continue
        # Recurse and merge the sub-result, carrying the accumulated path.
        flat.update(flatten_dict(value, full_key, sep=sep))
    return flat


class FlattenInstances(StreamInstanceOperator):
    """Flatten nested dict fields of each instance via ``flatten_dict`` with its default separator."""

    def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
        """Return a new, flattened dict built from *instance*; *stream_name* is ignored."""
        flattened = flatten_dict(instance)
        return flattened


class AddFields(StreamInstanceOperator):
    """Set a fixed group of fields on every instance.

    Attributes:
        fields: field name -> value to set; these override existing fields
            with the same name.
    """

    fields: Dict[str, object]

    def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
        """Return a shallow copy of *instance* updated with ``self.fields``.

        The values in ``self.fields`` are not copied, so every output instance
        references the same value objects.
        """
        augmented = dict(instance)
        augmented.update(self.fields)
        return augmented


class ArtifactFetcherMixin:
    """Mixin that memoizes ``fetch_artifact`` lookups by identifier."""

    # Class-level dict mutated in place: it is shared by every class that mixes
    # this in (unless a subclass rebinds `cache`), so each identifier is fetched once.
    cache: Dict[str, Artifact] = {}

    @classmethod
    def get_artifact(cls, artifact_identifier: str) -> Artifact:
        """Return the artifact for *artifact_identifier*, fetching it on first use."""
        if artifact_identifier in cls.cache:
            return cls.cache[artifact_identifier]
        # fetch_artifact returns a pair; only the artifact itself is cached.
        fetched, _artifactory = fetch_artifact(artifact_identifier)
        cls.cache[artifact_identifier] = fetched
        return fetched


class ApplyValueOperatorsField(StreamInstanceOperator, ArtifactFetcherMixin):
    """Apply artifact operators, named inside each instance, to one value field.

    Attributes:
        value_field: field name passed to every operator.
        operators_field: instance field holding the operator name(s), either a
            single string or a list of strings.
        default_operators: operator names used when the instance's
            ``operators_field`` is missing or None.
    """

    value_field: str
    operators_field: str
    default_operators: Optional[List[str]] = None

    def process(self, instance: Dict[str, Any], stream_name: Optional[str] = None) -> Dict[str, Any]:
        """Run each named operator in order, threading the instance through them.

        Raises:
            AssertionError: if the instance names no operators and no
                ``default_operators`` are configured.
        """
        operator_names = instance.get(self.operators_field)
        if operator_names is None:
            # Report the field actually consulted (operators_field). This class
            # declares no `field` attribute, so the message must not use one.
            assert (
                self.default_operators is not None
            ), f"No operators found in {self.operators_field} field and no default operators provided"
            operator_names = self.default_operators

        if isinstance(operator_names, str):
            operator_names = [operator_names]

        for name in operator_names:
            operator = self.get_artifact(name)
            # Each operator receives the current instance and the field name,
            # and its return value becomes the input to the next operator.
            instance = operator(instance, self.value_field)

        return instance


class FilterByValues(SingleStreamOperator):
    """Keep only the instances whose fields equal all the required values.

    Attributes:
        values: field name -> required value; an instance is kept only if every
            listed field compares equal (``==``) to its required value.
    """

    values: Dict[str, Any]

    def process(self, stream: Stream, stream_name: str = None) -> Generator:
        """Lazily yield the instances of *stream* that match ``self.values``.

        A listed field missing from an instance raises ``KeyError``.
        """

        def _matches(candidate):
            return all(candidate[name] == wanted for name, wanted in self.values.items())

        yield from filter(_matches, stream)


class Unique(SingleStreamReducer):
    """Collect the distinct combinations of ``fields`` values seen in a stream.

    Attributes:
        fields: names of the fields whose joint values are de-duplicated.
    """

    fields: List[str] = field(default_factory=list)

    @staticmethod
    def to_tuple(instance: dict, fields: List[str]) -> tuple:
        """Return the values of *fields* in *instance* as a hashable tuple.

        Top-level list values are converted to tuples so they can be hashed.
        """
        values = []
        # `name` rather than `field`, to avoid shadowing dataclasses.field.
        for name in fields:
            value = instance[name]
            values.append(tuple(value) if isinstance(value, list) else value)
        return tuple(values)

    def process(self, stream: Stream) -> List[tuple]:
        """Return the distinct value tuples in order of first appearance.

        A dict serves as an insertion-ordered set. Iteration order of a plain
        set depends on hash values, which for strings vary between interpreter
        runs, so the result here is deterministic where ``list(set)`` was not.
        """
        seen: Dict[tuple, None] = {}
        for instance in stream:
            seen.setdefault(self.to_tuple(instance, self.fields), None)
        return list(seen)


from .text_utils import nested_tuple_to_string


class SplitByValue(MultiStreamOperator):
    """Split each stream into sub-streams, one per distinct combination of ``fields`` values.

    Sub-streams are named ``<stream name>_<nested_tuple_to_string(values)>``.

    Attributes:
        fields: names of the fields whose joint values define the split.
    """

    fields: List[str] = field(default_factory=list)

    def process(self, multi_stream: MultiStream) -> MultiStream:
        unique_per_stream = Unique(fields=self.fields)(multi_stream)

        split_streams = {}
        for name, stream in multi_stream.items():
            for combo in unique_per_stream[name]:
                # Pair each field with its value in this combination.
                # NOTE(review): Unique turns list values into tuples, and
                # FilterByValues compares with ==, so list-valued fields
                # probably match nothing — confirm.
                criteria = dict(zip(self.fields, combo))
                filtered = FilterByValues(values=criteria)._process_single_stream(stream)
                split_streams[name + "_" + nested_tuple_to_string(combo)] = filtered

        return MultiStream(split_streams)


class ApplyStreamOperatorsField(SingleStreamOperator, ArtifactFetcherMixin):
    """Apply a chain of stream operators whose names are read from the first instance.

    Attributes:
        field: instance field holding the operator name(s), either a single
            string or a list. Only the first instance is consulted.
        reversed: when True, the operators are applied in reverse order.
    """

    field: str
    reversed: bool = False

    def process(self, stream: Stream, stream_name: str = None) -> Generator:
        # The first instance decides the operators for the whole stream.
        # NOTE(review): presumably `peak` looks at the first instance without
        # consuming it — confirm in Stream.
        names = stream.peak().get(self.field, [])
        if isinstance(names, str):
            names = [names]

        if self.reversed:
            # Bare `reversed` here is the builtin: class attributes are not
            # visible as bare names inside methods.
            names = list(reversed(names))

        for operator_name in names:
            operator = self.get_artifact(operator_name)
            assert isinstance(
                operator, SingleStreamOperator
            ), f"Operator {operator_name} must be a SingleStreamOperator"
            stream = operator.process(stream)

        yield from stream


class AddFieldNamePrefix(StreamInstanceOperator):
    """Prefix every field name with a string chosen by the instance's stream name.

    Attributes:
        prefix_dict: stream name -> prefix prepended to each field name.
    """

    prefix_dict: Dict[str, str]

    def process(self, instance: Dict[str, Any], stream_name: Optional[str] = None) -> Dict[str, Any]:
        """Return a new dict whose keys are ``prefix_dict[stream_name] + key``.

        Raises:
            KeyError: if *stream_name* has no entry in ``prefix_dict``.
        """
        # The prefix does not depend on the field, so look it up once per
        # instance rather than once per key.
        try:
            prefix = self.prefix_dict[stream_name]
        except KeyError:
            raise KeyError(
                f'No prefix configured for stream "{stream_name}"; configured streams: {list(self.prefix_dict)}'
            ) from None
        return {prefix + key: value for key, value in instance.items()}


class MergeStreams(MultiStreamOperator):
    """Concatenate every stream of a MultiStream into one named stream.

    Attributes:
        new_stream_name: name of the single output stream.
        add_origin_stream_name: when True, each output instance gets a field
            holding the name of the stream it came from.
        origin_stream_name_field_name: name of that field.
    """

    new_stream_name: str = "all"
    add_origin_stream_name: bool = True
    origin_stream_name_field_name: str = "origin"

    def merge(self, multi_stream):
        """Yield the instances of each stream in *multi_stream*, one stream after another."""
        for stream_name, stream in multi_stream.items():
            for instance in stream:
                if self.add_origin_stream_name:
                    # Build a new dict instead of assigning into `instance`. The
                    # source dict may be owned by the input stream (e.g. an
                    # in-memory list), and merging must not modify it.
                    instance = {**instance, self.origin_stream_name_field_name: stream_name}
                yield instance

    def process(self, multi_stream: MultiStream) -> MultiStream:
        # Stream receives the generator function and its kwargs rather than a
        # materialized list of instances.
        return MultiStream({self.new_stream_name: Stream(self.merge, gen_kwargs={"multi_stream": multi_stream})})