Upload operators.py with huggingface_hub
Browse files- operators.py +54 -20
operators.py
CHANGED
@@ -41,6 +41,7 @@ from abc import abstractmethod
|
|
41 |
from collections import Counter
|
42 |
from copy import deepcopy
|
43 |
from dataclasses import field
|
|
|
44 |
from itertools import zip_longest
|
45 |
from random import Random
|
46 |
from typing import (
|
@@ -58,7 +59,7 @@ from typing import (
|
|
58 |
import requests
|
59 |
|
60 |
from .artifact import Artifact, fetch_artifact
|
61 |
-
from .dataclass import NonPositionalField
|
62 |
from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
|
63 |
from .operator import (
|
64 |
MultiStream,
|
@@ -74,11 +75,14 @@ from .operator import (
|
|
74 |
StreamInstanceOperator,
|
75 |
)
|
76 |
from .random_utils import new_random_generator
|
|
|
77 |
from .stream import Stream
|
78 |
from .text_utils import nested_tuple_to_string
|
79 |
from .type_utils import isoftype
|
80 |
from .utils import flatten_dict
|
81 |
|
|
|
|
|
82 |
|
83 |
class FromIterables(StreamInitializerOperator):
|
84 |
"""Creates a MultiStream from a dict of named iterables.
|
@@ -484,7 +488,7 @@ class AddConstant(FieldOperator):
|
|
484 |
|
485 |
|
486 |
class Augmentor(StreamInstanceOperator):
|
487 |
-
"""A stream that augments the values of either the task input fields before rendering with the template, or the
|
488 |
|
489 |
Args:
|
490 |
augment_model_input: Whether to augment the input to the model.
|
@@ -564,9 +568,9 @@ class NullAugmentor(Augmentor):
|
|
564 |
|
565 |
|
566 |
class AugmentWhitespace(Augmentor):
|
567 |
-
"""Augments the inputs by
|
568 |
|
569 |
-
Currently each whitespace is replaced by a random choice of 1-3 whitespace
|
570 |
"""
|
571 |
|
572 |
def process_value(self, value: Any) -> Any:
|
@@ -1094,7 +1098,7 @@ class ArtifactFetcherMixin:
|
|
1094 |
return cls.cache[artifact_identifier]
|
1095 |
|
1096 |
|
1097 |
-
class ApplyOperatorsField(StreamInstanceOperator
|
1098 |
"""Applies value operators to each instance in a stream based on specified fields.
|
1099 |
|
1100 |
Args:
|
@@ -1215,30 +1219,60 @@ class FilterByCondition(SingleStreamOperator):
|
|
1215 |
return True
|
1216 |
|
1217 |
|
1218 |
-
class
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1219 |
"""Filters a stream, yielding only instances which fulfil a condition specified as a string to be python's eval-uated.
|
1220 |
|
1221 |
Raises an error if a field participating in the specified condition is missing from the instance
|
1222 |
|
1223 |
Args:
|
1224 |
-
|
|
|
1225 |
error_on_filtered_all (bool, optional): If True, raises an error if all instances are filtered out. Defaults to True.
|
1226 |
|
1227 |
Examples:
|
1228 |
-
|
1229 |
-
|
1230 |
-
|
1231 |
-
|
1232 |
|
1233 |
"""
|
1234 |
|
1235 |
-
query: str
|
1236 |
error_on_filtered_all: bool = True
|
1237 |
|
1238 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
1239 |
yielded = False
|
1240 |
for instance in stream:
|
1241 |
-
if
|
1242 |
yielded = True
|
1243 |
yield instance
|
1244 |
|
@@ -1248,33 +1282,33 @@ class FilterByQuery(SingleStreamOperator):
|
|
1248 |
)
|
1249 |
|
1250 |
|
1251 |
-
class
|
1252 |
-
"""Compute an expression
|
1253 |
|
1254 |
Raises an error if a field mentioned in the query is missing from the instance.
|
1255 |
|
1256 |
Args:
|
1257 |
-
|
1258 |
to_field (str): the field where the result is to be stored into
|
|
|
1259 |
|
1260 |
Examples:
|
1261 |
When instance {"a": 2, "b": 3} is process-ed by operator
|
1262 |
-
|
1263 |
the result is {"a": 2, "b": 3, "c": 5}
|
1264 |
|
1265 |
When instance {"a": "hello", "b": "world"} is process-ed by operator
|
1266 |
-
|
1267 |
the result is {"a": "hello", "b": "world", "c": "hello world"}
|
1268 |
|
1269 |
"""
|
1270 |
|
1271 |
-
query: str
|
1272 |
to_field: str
|
1273 |
|
1274 |
def process(
|
1275 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
1276 |
) -> Dict[str, Any]:
|
1277 |
-
instance[self.to_field] =
|
1278 |
return instance
|
1279 |
|
1280 |
|
|
|
41 |
from collections import Counter
|
42 |
from copy import deepcopy
|
43 |
from dataclasses import field
|
44 |
+
from importlib import import_module
|
45 |
from itertools import zip_longest
|
46 |
from random import Random
|
47 |
from typing import (
|
|
|
59 |
import requests
|
60 |
|
61 |
from .artifact import Artifact, fetch_artifact
|
62 |
+
from .dataclass import NonPositionalField, OptionalField
|
63 |
from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
|
64 |
from .operator import (
|
65 |
MultiStream,
|
|
|
75 |
StreamInstanceOperator,
|
76 |
)
|
77 |
from .random_utils import new_random_generator
|
78 |
+
from .settings_utils import get_settings
|
79 |
from .stream import Stream
|
80 |
from .text_utils import nested_tuple_to_string
|
81 |
from .type_utils import isoftype
|
82 |
from .utils import flatten_dict
|
83 |
|
84 |
+
# Settings object fetched once at import time; ComputeExpressionMixin reads
# settings.allow_unverified_code from it before eval-uating any expression.
settings = get_settings()
|
85 |
+
|
86 |
|
87 |
class FromIterables(StreamInitializerOperator):
|
88 |
"""Creates a MultiStream from a dict of named iterables.
|
|
|
488 |
|
489 |
|
490 |
class Augmentor(StreamInstanceOperator):
|
491 |
+
"""A stream operator that augments the values of either the task input fields before rendering with the template, or the input passed to the model after rendering of the template.
|
492 |
|
493 |
Args:
|
494 |
augment_model_input: Whether to augment the input to the model.
|
|
|
568 |
|
569 |
|
570 |
class AugmentWhitespace(Augmentor):
|
571 |
+
"""Augments the inputs by replacing existing whitespaces with other whitespaces.
|
572 |
|
573 |
+
Currently, each whitespace is replaced by a random choice of 1-3 whitespace characters (space, tab, newline).
|
574 |
"""
|
575 |
|
576 |
def process_value(self, value: Any) -> Any:
|
|
|
1098 |
return cls.cache[artifact_identifier]
|
1099 |
|
1100 |
|
1101 |
+
class ApplyOperatorsField(StreamInstanceOperator):
|
1102 |
"""Applies value operators to each instance in a stream based on specified fields.
|
1103 |
|
1104 |
Args:
|
|
|
1219 |
return True
|
1220 |
|
1221 |
|
1222 |
+
class ComputeExpressionMixin(Artifact):
    """Computes an expression expressed over fields of an instance.

    The expression is evaluated by python's eval(), with the instance serving as
    the local namespace, so that names of fields of the instance can be used
    directly as variables. Evaluation is refused unless
    settings.allow_unverified_code is set.

    Args:
        expression (str): the expression, in terms of names of fields of an instance
        imports_list (List[str]): list of names of imports needed for the evaluation of the expression (e.g. 're', 'json')
    """

    expression: str
    imports_list: List[str] = OptionalField(default_factory=list)

    def prepare(self):
        # The imports can not be done here, because the object does not pickle
        # once it holds imported modules. They are deferred to the first call
        # of compute_expression.
        self.globs = {}
        self.to_import = True

    def compute_expression(self, instance: dict) -> Any:
        """Evaluate self.expression over the fields of the given instance.

        Args:
            instance (dict): the instance whose fields are made available, by name, to the expression

        Returns:
            The value that the expression evaluates to.

        Raises:
            ValueError: if settings.allow_unverified_code is not set. In that case
                nothing is imported and nothing is evaluated.
            Any exception raised by the evaluation itself, e.g. NameError when the
                expression mentions a name that is neither a field of the instance
                nor an imported module.
        """
        # Check the permission before anything else: importing a module runs its
        # top-level code, so the imports must not happen when running unverified
        # code is disallowed.
        if not settings.allow_unverified_code:
            raise ValueError(
                f"Cannot run expression by {self} when unitxt.settings.allow_unverified_code=False either set it to True or set {settings.allow_unverified_code_key} environment variable."
            )

        # Lazy, one-time import of the modules the expression needs.
        if self.to_import:
            for module_name in self.imports_list:
                self.globs[module_name] = import_module(module_name)
            self.to_import = False

        # NOTE(security): eval() executes arbitrary python code. It is reached
        # only when the user explicitly allowed unverified code (checked above);
        # never feed it expressions coming from an untrusted source.
        return eval(self.expression, self.globs, instance)
|
1250 |
+
|
1251 |
+
|
1252 |
+
class FilterByExpression(SingleStreamOperator, ComputeExpressionMixin):
|
1253 |
"""Filters a stream, yielding only instances which fulfil a condition specified as a string to be python's eval-uated.
|
1254 |
|
1255 |
Raises an error if a field participating in the specified condition is missing from the instance
|
1256 |
|
1257 |
Args:
|
1258 |
+
expression (str): a condition over fields of the instance, to be processed by python's eval()
|
1259 |
+
imports_list (List[str]): names of imports needed for the eval of the query (e.g. 're', 'json')
|
1260 |
error_on_filtered_all (bool, optional): If True, raises an error if all instances are filtered out. Defaults to True.
|
1261 |
|
1262 |
Examples:
|
1263 |
+
FilterByExpression(expression = "a > 4") will yield only instances where "a">4
|
1264 |
+
FilterByExpression(expression = "a <= 4 and b > 5") will yield only instances where the value of field "a" is not exceeding 4 and in field "b" -- greater than 5
|
1265 |
+
FilterByExpression(expression = "a in [4, 8]") will yield only instances where "a" is 4 or 8
|
1266 |
+
FilterByExpression(expression = "a not in [4, 8]") will yield only instances where "a" is neither 4 nor 8
|
1267 |
|
1268 |
"""
|
1269 |
|
|
|
1270 |
error_on_filtered_all: bool = True
|
1271 |
|
1272 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
1273 |
yielded = False
|
1274 |
for instance in stream:
|
1275 |
+
if self.compute_expression(instance):
|
1276 |
yielded = True
|
1277 |
yield instance
|
1278 |
|
|
|
1282 |
)
|
1283 |
|
1284 |
|
1285 |
+
class ExecuteExpression(StreamInstanceOperator, ComputeExpressionMixin):
    """Evaluates an expression over the fields of an instance and stores the outcome in field to_field.

    The expression is given as a string, to be eval-uated with the instance's
    fields available by name.

    Raises an error if a field mentioned in the expression is missing from the instance.

    Args:
        expression (str): an expression to be evaluated over the fields of the instance
        to_field (str): the field where the result is to be stored into
        imports_list (List[str]): names of imports needed for the eval of the expression (e.g. 're', 'json')

    Examples:
        When instance {"a": 2, "b": 3} is process-ed by operator
        ExecuteExpression(expression="a+b", to_field = "c")
        the result is {"a": 2, "b": 3, "c": 5}

        When instance {"a": "hello", "b": "world"} is process-ed by operator
        ExecuteExpression(expression = "a+' '+b", to_field = "c")
        the result is {"a": "hello", "b": "world", "c": "hello world"}

    """

    to_field: str

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        # Evaluate first, then write the outcome back into the very same
        # instance, which is also the object returned.
        outcome = self.compute_expression(instance)
        instance[self.to_field] = outcome
        return instance
|
1313 |
|
1314 |
|