Elron commited on
Commit
778ad61
·
1 Parent(s): e7ab33e

Upload operators.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. operators.py +207 -0
operators.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .stream import MultiStream, Stream
2
+ from .artifact import Artifact, fetch_artifact
3
+ from .operator import (
4
+ StreamInstanceOperator,
5
+ MultiStreamOperator,
6
+ SingleStreamOperator,
7
+ SingleStreamReducer,
8
+ StreamInitializerOperator,
9
+ Stream,
10
+ MultiStream,
11
+ )
12
+
13
+ from dataclasses import field
14
+ from typing import List, Union, Dict, Optional, Generator, Any, Iterable
15
+
16
+ from typing import Dict, Any
17
+
18
+
19
+ class FromIterables(StreamInitializerOperator):
20
+ def process(self, iterables: Dict[str, Iterable]) -> MultiStream:
21
+ return MultiStream.from_iterables(iterables)
22
+
23
+
24
+ class MapInstanceValues(StreamInstanceOperator):
25
+ mappers: Dict[str, Dict[str, str]]
26
+ strict: bool = True
27
+
28
+ def verify(self):
29
+ # make sure the mappers are valid
30
+ for key, mapper in self.mappers.items():
31
+ assert isinstance(mapper, dict), f"Mapper for given field {key} should be a dict, got {type(mapper)}"
32
+ for k, v in mapper.items():
33
+ assert isinstance(k, str), f'Key "{k}" in mapper for field "{key}" should be a string, got {type(k)}'
34
+
35
+ def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
36
+ result = {}
37
+ for key, value in instance.items():
38
+ str_value = str(value)
39
+ if key in self.mappers:
40
+ mapper = self.mappers[key]
41
+ if self.strict:
42
+ value = mapper[str_value]
43
+ else:
44
+ if str_value in mapper:
45
+ value = mapper[str_value]
46
+ result[key] = value
47
+ return result
48
+
49
+
50
+ def flatten_dict(d: Dict[str, Any], parent_key: str = "", sep: str = "_") -> Dict[str, Any]:
51
+ items = []
52
+ for k, v in d.items():
53
+ new_key = parent_key + sep + k if parent_key else k
54
+ if isinstance(v, dict):
55
+ items.extend(flatten_dict(v, new_key, sep=sep).items())
56
+ else:
57
+ items.append((new_key, v))
58
+ return dict(items)
59
+
60
+
61
+ class FlattenInstances(StreamInstanceOperator):
62
+ def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
63
+ return flatten_dict(instance)
64
+
65
+
66
+ class AddFields(StreamInstanceOperator):
67
+ fields: Dict[str, object]
68
+
69
+ def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
70
+ return {**instance, **self.fields}
71
+
72
+
73
+ class ArtifactFetcherMixin:
74
+ cache: Dict[str, Artifact] = {}
75
+
76
+ @classmethod
77
+ def get_artifact(cls, artifact_identifier: str) -> Artifact:
78
+ if artifact_identifier not in cls.cache:
79
+ artifact, artifactory = fetch_artifact(artifact_identifier)
80
+ cls.cache[artifact_identifier] = artifact
81
+ return cls.cache[artifact_identifier]
82
+
83
+
84
+ class ApplyValueOperatorsField(StreamInstanceOperator, ArtifactFetcherMixin):
85
+ value_field: str
86
+ operators_field: str
87
+ default_operators: List[str] = None
88
+
89
+ def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
90
+ operator_names = instance.get(self.operators_field)
91
+ if operator_names is None:
92
+ assert (
93
+ self.default_operators is not None
94
+ ), f"No operators found in {self.field} field and no default operators provided"
95
+ operator_names = self.default_operators
96
+
97
+ if isinstance(operator_names, str):
98
+ operator_names = [operator_names]
99
+
100
+ for name in operator_names:
101
+ operator = self.get_artifact(name)
102
+ instance = operator(instance, self.value_field)
103
+
104
+ return instance
105
+
106
+
107
+ class FilterByValues(SingleStreamOperator):
108
+ values: Dict[str, Any]
109
+
110
+ def process(self, stream: Stream, stream_name: str = None) -> Generator:
111
+ for instance in stream:
112
+ if all(instance[key] == value for key, value in self.values.items()):
113
+ yield instance
114
+
115
+
116
+ class Unique(SingleStreamReducer):
117
+ fields: List[str] = field(default_factory=list)
118
+
119
+ @staticmethod
120
+ def to_tuple(instance: dict, fields: List[str]) -> tuple:
121
+ result = []
122
+ for field in fields:
123
+ value = instance[field]
124
+ if isinstance(value, list):
125
+ value = tuple(value)
126
+ result.append(value)
127
+ return tuple(result)
128
+
129
+ def process(self, stream: Stream) -> Stream:
130
+ seen = set()
131
+ for instance in stream:
132
+ values = self.to_tuple(instance, self.fields)
133
+ if values not in seen:
134
+ seen.add(values)
135
+ return list(seen)
136
+
137
+
138
+ from .text_utils import nested_tuple_to_string
139
+
140
+
141
+ class SplitByValue(MultiStreamOperator):
142
+ fields: List[str] = field(default_factory=list)
143
+
144
+ def process(self, multi_stream: MultiStream) -> MultiStream:
145
+ uniques = Unique(fields=self.fields)(multi_stream)
146
+
147
+ result = {}
148
+
149
+ for stream_name, stream in multi_stream.items():
150
+ stream_unique_values = uniques[stream_name]
151
+ for unique_values in stream_unique_values:
152
+ filtering_values = {field: value for field, value in zip(self.fields, unique_values)}
153
+ filtered_streams = FilterByValues(values=filtering_values)._process_single_stream(stream)
154
+ filtered_stream_name = stream_name + "_" + nested_tuple_to_string(unique_values)
155
+ result[filtered_stream_name] = filtered_streams
156
+
157
+ return MultiStream(result)
158
+
159
+
160
+ class ApplyStreamOperatorsField(SingleStreamOperator, ArtifactFetcherMixin):
161
+ field: str
162
+ reversed: bool = False
163
+
164
+ def process(self, stream: Stream, stream_name: str = None) -> Generator:
165
+ first_instance = stream.peak()
166
+
167
+ operators = first_instance.get(self.field, [])
168
+ if isinstance(operators, str):
169
+ operators = [operators]
170
+
171
+ if self.reversed:
172
+ operators = list(reversed(operators))
173
+
174
+ for operator_name in operators:
175
+ operator = self.get_artifact(operator_name)
176
+ assert isinstance(
177
+ operator, SingleStreamOperator
178
+ ), f"Operator {operator_name} must be a SingleStreamOperator"
179
+ stream = operator.process(stream)
180
+
181
+ yield from stream
182
+
183
+
184
+ class AddFieldNamePrefix(StreamInstanceOperator):
185
+ prefix_dict: Dict[str, str]
186
+
187
+ def prepare(self):
188
+ return super().prepare()
189
+
190
+ def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
191
+ return {self.prefix_dict[stream_name] + key: value for key, value in instance.items()}
192
+
193
+
194
+ class MergeStreams(MultiStreamOperator):
195
+ new_stream_name: str = "all"
196
+ add_origin_stream_name: bool = True
197
+ origin_stream_name_field_name: str = "origin"
198
+
199
+ def merge(self, multi_stream):
200
+ for stream_name, stream in multi_stream.items():
201
+ for instance in stream:
202
+ if self.add_origin_stream_name:
203
+ instance[self.origin_stream_name_field_name] = stream_name
204
+ yield instance
205
+
206
+ def process(self, multi_stream: MultiStream) -> MultiStream:
207
+ return MultiStream({self.new_stream_name: Stream(self.merge, gen_kwargs={"multi_stream": multi_stream})})