Elron commited on
Commit
45dfa28
1 Parent(s): a180fb2

Upload operators.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. operators.py +38 -13
operators.py CHANGED
@@ -20,6 +20,7 @@ Some operators are specielized in specific task such as:
20
 
21
  - :class:`loaders<unitxt.loaders>` for loading data.
22
  - :class:`splitters<unitxt.splitters>` for fixing data splits.
 
23
 
24
  Other specelized operators are used by unitxt internally:
25
 
@@ -32,7 +33,7 @@ General Operaotrs List:
32
  ------------------------
33
  """
34
  import collections
35
- import importlib
36
  import operator
37
  import os
38
  import uuid
@@ -41,7 +42,6 @@ from abc import abstractmethod
41
  from collections import Counter
42
  from copy import deepcopy
43
  from dataclasses import field
44
- from importlib import import_module
45
  from itertools import zip_longest
46
  from random import Random
47
  from typing import (
@@ -64,6 +64,7 @@ from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
64
  from .operator import (
65
  MultiStream,
66
  MultiStreamOperator,
 
67
  PagedStreamOperator,
68
  SequentialOperator,
69
  SideEffectOperator,
@@ -782,7 +783,7 @@ class Apply(StreamInstanceOperator):
782
  elif module_name in globals():
783
  obj = globals()[module_name]
784
  else:
785
- obj = importlib.import_module(module_name)
786
  for part in function_name.split("."):
787
  obj = getattr(obj, part)
788
  return obj
@@ -963,7 +964,16 @@ class CopyFields(FieldOperator):
963
  """
964
 
965
  def process_value(self, value: Any) -> Any:
966
- return value
 
 
 
 
 
 
 
 
 
967
 
968
 
969
  class AddID(StreamInstanceOperator):
@@ -1230,10 +1240,13 @@ class ComputeExpressionMixin(Artifact):
1230
  expression: str
1231
  imports_list: List[str] = OptionalField(default_factory=list)
1232
 
 
 
 
1233
  def prepare(self):
1234
  # can not do the imports here, because object does not pickle with imports
1235
  self.globals = {
1236
- module_name: import_module(module_name) for module_name in self.imports_list
1237
  }
1238
 
1239
  def compute_expression(self, instance: dict) -> Any:
@@ -1574,7 +1587,7 @@ class ApplyMetric(SingleStreamOperator, ArtifactFetcherMixin):
1574
  calc_confidence_intervals: bool
1575
 
1576
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
1577
- from .metrics import Metric, MetricPipeline, MetricWithConfidenceInterval
1578
 
1579
  first_instance = stream.peek()
1580
 
@@ -1593,6 +1606,16 @@ class ApplyMetric(SingleStreamOperator, ArtifactFetcherMixin):
1593
  # by the first listed metric (as desired).
1594
  metric_names = list(reversed(metric_names))
1595
 
 
 
 
 
 
 
 
 
 
 
1596
  for metric_name in metric_names:
1597
  metric = self.get_artifact(metric_name)
1598
  assert isinstance(
@@ -1600,15 +1623,17 @@ class ApplyMetric(SingleStreamOperator, ArtifactFetcherMixin):
1600
  ), f"Operator {metric_name} must be a Metric"
1601
 
1602
  if not self.calc_confidence_intervals:
1603
- if isinstance(metric, MetricWithConfidenceInterval):
1604
- metric.disable_confidence_interval_calculation()
1605
- elif isinstance(metric, MetricPipeline) and isinstance(
1606
- metric.metric, MetricWithConfidenceInterval
1607
- ):
1608
- metric.metric.disable_confidence_interval_calculation()
1609
 
1610
- stream = metric(MultiStream({"tmp": stream}))["tmp"]
 
 
 
1611
 
 
 
 
 
1612
  yield from stream
1613
 
1614
 
 
20
 
21
  - :class:`loaders<unitxt.loaders>` for loading data.
22
  - :class:`splitters<unitxt.splitters>` for fixing data splits.
23
+ - :class:`struct_data_operators<unitxt.struct_data_operators>` for structured data operators.
24
 
25
  Other specelized operators are used by unitxt internally:
26
 
 
33
  ------------------------
34
  """
35
  import collections
36
+ import copy
37
  import operator
38
  import os
39
  import uuid
 
42
  from collections import Counter
43
  from copy import deepcopy
44
  from dataclasses import field
 
45
  from itertools import zip_longest
46
  from random import Random
47
  from typing import (
 
64
  from .operator import (
65
  MultiStream,
66
  MultiStreamOperator,
67
+ PackageRequirementsMixin,
68
  PagedStreamOperator,
69
  SequentialOperator,
70
  SideEffectOperator,
 
783
  elif module_name in globals():
784
  obj = globals()[module_name]
785
  else:
786
+ obj = __import__(module_name)
787
  for part in function_name.split("."):
788
  obj = getattr(obj, part)
789
  return obj
 
964
  """
965
 
966
  def process_value(self, value: Any) -> Any:
967
+ return copy.deepcopy(value)
968
+
969
+
970
+ class GetItemByIndex(FieldOperator):
971
+ """Get from the item list by the index in the field."""
972
+
973
+ items_list: List[Any]
974
+
975
+ def process_value(self, value: Any) -> Any:
976
+ return self.items_list[value]
977
 
978
 
979
  class AddID(StreamInstanceOperator):
 
1240
  expression: str
1241
  imports_list: List[str] = OptionalField(default_factory=list)
1242
 
1243
+ def verify(self):
1244
+ PackageRequirementsMixin.check_missing_requirements(self, self.imports_list)
1245
+
1246
  def prepare(self):
1247
  # can not do the imports here, because object does not pickle with imports
1248
  self.globals = {
1249
+ module_name: __import__(module_name) for module_name in self.imports_list
1250
  }
1251
 
1252
  def compute_expression(self, instance: dict) -> Any:
 
1587
  calc_confidence_intervals: bool
1588
 
1589
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
1590
+ from .metrics import Metric
1591
 
1592
  first_instance = stream.peek()
1593
 
 
1606
  # by the first listed metric (as desired).
1607
  metric_names = list(reversed(metric_names))
1608
 
1609
+ # Workaround: The metric/MetricPipeline modifies the stream itself, sometines making it incompatible
1610
+ # for further metrics' processing, instead of just modifying the score field.
1611
+ # Here we keep all the fields besides the score, and restore them after the metric finishes.
1612
+ first_instance = stream.peek()
1613
+ keys_to_restore = set(first_instance.keys()).difference({"score"})
1614
+ multi_stream = MultiStream({"tmp": stream})
1615
+ multi_stream = CopyFields(
1616
+ field_to_field={k: f"{k}_orig" for k in keys_to_restore}
1617
+ )(multi_stream)
1618
+
1619
  for metric_name in metric_names:
1620
  metric = self.get_artifact(metric_name)
1621
  assert isinstance(
 
1623
  ), f"Operator {metric_name} must be a Metric"
1624
 
1625
  if not self.calc_confidence_intervals:
1626
+ metric.disable_confidence_interval_calculation()
 
 
 
 
 
1627
 
1628
+ multi_stream = metric(multi_stream)
1629
+ multi_stream = CopyFields(
1630
+ field_to_field={f"{k}_orig": k for k in keys_to_restore}
1631
+ )(multi_stream)
1632
 
1633
+ multi_stream = RemoveFields(fields=[f"{k}_orig" for k in keys_to_restore])(
1634
+ multi_stream
1635
+ )
1636
+ stream = multi_stream["tmp"]
1637
  yield from stream
1638
 
1639