Elron committed on
Commit
d08fbc6
1 Parent(s): 9d5b4c0

Upload folder using huggingface_hub

Browse files
Files changed (30) hide show
  1. api.py +73 -36
  2. artifact.py +4 -4
  3. benchmark.py +58 -0
  4. blocks.py +2 -1
  5. card.py +2 -2
  6. catalog.py +41 -0
  7. collections_operators.py +9 -7
  8. dataclass.py +0 -23
  9. dataset.py +2 -0
  10. dict_utils.py +71 -19
  11. formats.py +5 -5
  12. fusion.py +42 -46
  13. image_operators.py +26 -0
  14. inference.py +84 -5
  15. llm_as_judge.py +35 -52
  16. loaders.py +21 -21
  17. metric.py +2 -0
  18. metric_utils.py +184 -114
  19. metrics.py +171 -6
  20. operator.py +32 -27
  21. operators.py +116 -76
  22. processors.py +40 -2
  23. schema.py +77 -17
  24. settings_utils.py +14 -0
  25. standard.py +57 -32
  26. stream.py +8 -2
  27. struct_data_operators.py +37 -0
  28. task.py +30 -18
  29. templates.py +55 -23
  30. version.py +1 -1
api.py CHANGED
@@ -6,8 +6,9 @@ from datasets import DatasetDict
6
  from .artifact import fetch_artifact
7
  from .dataset_utils import get_dataset_artifact
8
  from .logging_utils import get_logger
9
- from .metric_utils import _compute, _post_process
10
  from .operator import SourceOperator
 
11
  from .standard import StandardRecipe
12
 
13
  logger = get_logger()
@@ -22,21 +23,60 @@ def load(source: Union[SourceOperator, str]) -> DatasetDict:
22
  return source().to_dataset()
23
 
24
 
25
- def _load_dataset_from_query(dataset_query: str) -> DatasetDict:
26
  dataset_query = dataset_query.replace("sys_prompt", "instruction")
27
- dataset_stream = get_dataset_artifact(dataset_query)
28
- return dataset_stream().to_dataset()
 
 
 
29
 
30
 
31
- def _load_dataset_from_dict(dataset_params: Dict[str, Any]) -> DatasetDict:
32
  recipe_attributes = list(StandardRecipe.__dict__["__fields__"].keys())
33
  for param in dataset_params.keys():
34
  assert param in recipe_attributes, (
35
  f"The parameter '{param}' is not an attribute of the 'StandardRecipe' class. "
36
  f"Please check if the name is correct. The available attributes are: '{recipe_attributes}'."
37
  )
38
- recipe = StandardRecipe(**dataset_params)
39
- return recipe().to_dataset()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
 
42
  def load_dataset(dataset_query: Optional[str] = None, **kwargs) -> DatasetDict:
@@ -47,7 +87,7 @@ def load_dataset(dataset_query: Optional[str] = None, **kwargs) -> DatasetDict:
47
  Alternatively, dataset is loaded from a provided card based on explicitly given parameters.
48
 
49
  Args:
50
- dataset_query (str, optional): A string query which specifies dataset to load from local catalog.
51
  For example:
52
  "card=cards.wnli,template=templates.classification.multi_class.relation.default".
53
  **kwargs: Arguments used to load dataset from provided card, which is not present in local catalog.
@@ -65,26 +105,9 @@ def load_dataset(dataset_query: Optional[str] = None, **kwargs) -> DatasetDict:
65
  loader_limit = 10
66
  dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)
67
  """
68
- if dataset_query and kwargs:
69
- raise ValueError(
70
- "Cannot provide 'dataset_query' and key-worded arguments at the same time. "
71
- "If you want to load dataset from a card in local catalog, use query only. "
72
- "Otherwise, use key-worded arguments only to specify properties of dataset."
73
- )
74
-
75
- if dataset_query:
76
- if not isinstance(dataset_query, str):
77
- raise ValueError(
78
- f"If specified, 'dataset_query' must be a string, however, "
79
- f"'{dataset_query}' was provided instead, which is of type "
80
- f"'{type(dataset_query)}'."
81
- )
82
- return _load_dataset_from_query(dataset_query)
83
-
84
- if kwargs:
85
- return _load_dataset_from_dict(kwargs)
86
 
87
- raise ValueError("Either 'dataset_query' or key-worded arguments must be provided.")
88
 
89
 
90
  def evaluate(predictions, data) -> List[Dict[str, Any]]:
@@ -92,26 +115,40 @@ def evaluate(predictions, data) -> List[Dict[str, Any]]:
92
 
93
 
94
  def post_process(predictions, data) -> List[Dict[str, Any]]:
95
- return _post_process(predictions=predictions, references=data)
96
 
97
 
98
  @lru_cache
99
- def _get_produce_with_cache(recipe_query):
100
- return get_dataset_artifact(recipe_query).produce
101
 
102
 
103
- def produce(instance_or_instances, recipe_query):
104
  is_list = isinstance(instance_or_instances, list)
105
  if not is_list:
106
  instance_or_instances = [instance_or_instances]
107
- result = _get_produce_with_cache(recipe_query)(instance_or_instances)
108
  if not is_list:
109
  result = result[0]
110
  return result
111
 
112
 
113
- def infer(instance_or_instances, recipe, engine):
114
- dataset = produce(instance_or_instances, recipe)
 
 
 
 
 
 
115
  engine, _ = fetch_artifact(engine)
116
- predictions = engine.infer(dataset)
117
- return post_process(predictions, dataset)
 
 
 
 
 
 
 
 
 
6
  from .artifact import fetch_artifact
7
  from .dataset_utils import get_dataset_artifact
8
  from .logging_utils import get_logger
9
+ from .metric_utils import _compute, _inference_post_process
10
  from .operator import SourceOperator
11
+ from .schema import UNITXT_DATASET_SCHEMA
12
  from .standard import StandardRecipe
13
 
14
  logger = get_logger()
 
23
  return source().to_dataset()
24
 
25
 
26
def _get_recipe_from_query(dataset_query: str) -> StandardRecipe:
    """Resolve a dataset query string into a recipe object.

    Every occurrence of ``sys_prompt`` in the query is first replaced by
    ``instruction``. The query is then passed to ``fetch_artifact``; if that
    call raises, the query is passed to ``get_dataset_artifact`` instead.

    Args:
        dataset_query: The query string (or catalog name) to resolve.

    Returns:
        Whatever ``fetch_artifact`` / ``get_dataset_artifact`` produced for
        the query.
    """
    dataset_query = dataset_query.replace("sys_prompt", "instruction")
    try:
        dataset_stream, _ = fetch_artifact(dataset_query)
    except Exception:
        # Previously a bare ``except:``, which also intercepted
        # KeyboardInterrupt and SystemExit. Only ordinary errors should
        # trigger the fallback lookup.
        dataset_stream = get_dataset_artifact(dataset_query)
    return dataset_stream
33
 
34
 
35
def _get_recipe_from_dict(dataset_params: Dict[str, Any]) -> StandardRecipe:
    """Build a ``StandardRecipe`` from keyword parameters.

    Each parameter name is checked against the field names registered on
    ``StandardRecipe``; the first unknown name triggers an assertion error
    listing the accepted names.
    """
    allowed = [*StandardRecipe.__dict__["__fields__"].keys()]
    for param in dataset_params:
        assert param in allowed, (
            f"The parameter '{param}' is not an attribute of the 'StandardRecipe' class. "
            f"Please check if the name is correct. The available attributes are: '{allowed}'."
        )
    return StandardRecipe(**dataset_params)
43
+
44
+
45
+ def _verify_dataset_args(dataset_query: Optional[str] = None, dataset_args=None):
46
+ if dataset_query and dataset_args:
47
+ raise ValueError(
48
+ "Cannot provide 'dataset_query' and key-worded arguments at the same time. "
49
+ "If you want to load dataset from a card in local catalog, use query only. "
50
+ "Otherwise, use key-worded arguments only to specify properties of dataset."
51
+ )
52
+
53
+ if dataset_query:
54
+ if not isinstance(dataset_query, str):
55
+ raise ValueError(
56
+ f"If specified, 'dataset_query' must be a string, however, "
57
+ f"'{dataset_query}' was provided instead, which is of type "
58
+ f"'{type(dataset_query)}'."
59
+ )
60
+
61
+ if not dataset_query and not dataset_args:
62
+ raise ValueError(
63
+ "Either 'dataset_query' or key-worded arguments must be provided."
64
+ )
65
+
66
+
67
def load_recipe(dataset_query: Optional[str] = None, **kwargs) -> StandardRecipe:
    """Return a recipe for the given query or keyword parameters.

    A ``StandardRecipe`` instance passed as ``dataset_query`` is returned
    unchanged. Otherwise the arguments are validated with
    ``_verify_dataset_args`` and the recipe is resolved from the query string
    or built from the keyword arguments.
    """
    if isinstance(dataset_query, StandardRecipe):
        return dataset_query

    # Raises unless exactly one of the two inputs is provided, so the
    # branches below are mutually exclusive.
    _verify_dataset_args(dataset_query, kwargs)

    if dataset_query:
        return _get_recipe_from_query(dataset_query)
    return _get_recipe_from_dict(kwargs)
80
 
81
 
82
  def load_dataset(dataset_query: Optional[str] = None, **kwargs) -> DatasetDict:
 
87
  Alternatively, dataset is loaded from a provided card based on explicitly given parameters.
88
 
89
  Args:
90
+ dataset_query (str, optional): A string query which specifies a dataset to load from local catalog or name of specific recipe or benchmark in the catalog.
91
  For example:
92
  "card=cards.wnli,template=templates.classification.multi_class.relation.default".
93
  **kwargs: Arguments used to load dataset from provided card, which is not present in local catalog.
 
105
  loader_limit = 10
106
  dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)
107
  """
108
+ recipe = load_recipe(dataset_query, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
+ return recipe().to_dataset(features=UNITXT_DATASET_SCHEMA)
111
 
112
 
113
  def evaluate(predictions, data) -> List[Dict[str, Any]]:
 
115
 
116
 
117
  def post_process(predictions, data) -> List[Dict[str, Any]]:
118
+ return _inference_post_process(predictions=predictions, references=data)
119
 
120
 
121
  @lru_cache
122
def _get_produce_with_cache(dataset_query: Optional[str] = None, **kwargs):
    """Return the ``produce`` attribute of the recipe resolved by ``load_recipe``.

    NOTE(review): this function is wrapped in ``lru_cache``, so every
    argument value must be hashable - confirm for objects passed via kwargs.
    """
    recipe = load_recipe(dataset_query, **kwargs)
    return recipe.produce
124
 
125
 
126
def produce(instance_or_instances, dataset_query: Optional[str] = None, **kwargs):
    """Run one instance or a list of instances through the recipe's ``produce``.

    A non-list input is wrapped in a one-element list for the call and the
    single result is unwrapped again; a list input yields the full result.
    """
    single = not isinstance(instance_or_instances, list)
    batch = [instance_or_instances] if single else instance_or_instances
    produced = _get_produce_with_cache(dataset_query, **kwargs)(batch)
    return produced[0] if single else produced
134
 
135
 
136
def infer(
    instance_or_instances,
    engine,
    dataset_query: Optional[str] = None,
    return_data=False,
    **kwargs,
):
    """Produce model inputs, run the engine on them and post-process the output.

    Args:
        instance_or_instances: Input forwarded to ``produce``.
        engine: Engine representation resolved through ``fetch_artifact``.
        dataset_query: Optional query forwarded to ``produce``.
        return_data: When true, each produced instance is updated in place
            with ``prediction`` and ``raw_prediction`` and the produced data
            is returned instead of the predictions.
        **kwargs: Extra arguments forwarded to ``produce``.

    Returns:
        The post-processed predictions, or the annotated produced data when
        ``return_data`` is true.
    """
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    engine, _ = fetch_artifact(engine)
    raw_predictions = engine.infer(dataset)
    predictions = post_process(raw_predictions, dataset)

    if not return_data:
        return predictions

    for instance, prediction, raw_prediction in zip(
        dataset, predictions, raw_predictions
    ):
        instance["prediction"] = prediction
        instance["raw_prediction"] = raw_prediction
    return dataset
artifact.py CHANGED
@@ -439,10 +439,10 @@ def fetch_artifact(artifact_rep) -> Tuple[Artifact, Union[Artifactory, None]]:
439
  """Loads an artifict from one of possible representations.
440
 
441
  (1) If artifact representation is already an Artifact object, return it.
442
- (2) If artifact representation is a string location of a local file, load the Artifact from local file.
443
- (3) If artifact representation is a string name iin the catalog, load the Artifact from the catalog.
444
- (4) If artifact representation is a json string, create dictionary representation from the string and build an Artifact object from it.
445
- (5) Otherwise, check the artifact representation is a dictionary and build an Artifact object from it.
446
  """
447
  if isinstance(artifact_rep, Artifact):
448
  return artifact_rep, None
 
439
  """Loads an artifact from one of the possible representations.
440
 
441
  (1) If artifact representation is already an Artifact object, return it.
442
+ (2) If artifact representation is a string location of a local file, load the Artifact from the local file.
443
+ (3) If artifact representation is a string name in the catalog, load the Artifact from the catalog.
444
+ (4) If artifact representation is a json string, create a dictionary representation from the string and build an Artifact object from it.
445
+ (5) Otherwise, check that the artifact representation is a dictionary and build an Artifact object from it.
446
  """
447
  if isinstance(artifact_rep, Artifact):
448
  return artifact_rep, None
benchmark.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Union
2
+
3
+ from .dataclass import NonPositionalField
4
+ from .formats import Format
5
+ from .fusion import FixedFusion, WeightedFusion
6
+ from .operator import SourceOperator
7
+ from .standard import StandardRecipe
8
+ from .stream import MultiStream
9
+ from .system_prompts import SystemPrompt
10
+
11
+
12
class BaseBenchmark(SourceOperator):
    """Settings shared by benchmark sources.

    All four fields are declared with ``NonPositionalField`` and default to
    ``None``, meaning "not set".
    """

    format: Format = NonPositionalField(default=None)
    num_demos: int = NonPositionalField(default=None)
    system_prompt: SystemPrompt = NonPositionalField(default=None)
    loader_limit: int = NonPositionalField(default=None)


class Benchmark(BaseBenchmark):
    """A named collection of subsets fused into a single multi-stream.

    Attributes:
        subsets: Mapping from subset name to a recipe or a nested benchmark.
        max_total_samples: When set, the subsets are combined with
            ``WeightedFusion`` using this total.
        max_samples_per_subset: Passed to ``FixedFusion`` as
            ``max_instances_per_subset`` when ``max_total_samples`` is unset.
    """

    subsets: Dict[str, Union[StandardRecipe, BaseBenchmark]]

    max_total_samples: int = None
    max_samples_per_subset: int = None

    def verify(self):
        """Reject configurations that set both sample limits at once."""
        if (
            self.max_total_samples is not None
            and self.max_samples_per_subset is not None
        ):
            raise ValueError("Set either max_total_samples or max_samples_per_subset")

    def prepare(self):
        """Push the benchmark-level overrides down to every subset.

        Previously the propagation ran only when ``format`` or ``num_demos``
        was set, so a benchmark that set just ``system_prompt`` or
        ``loader_limit`` never forwarded them. Any of the four settings now
        triggers propagation; subsets are left untouched when none is set.
        """
        overrides = {
            "num_demos": self.num_demos,
            "format": self.format,
            "system_prompt": self.system_prompt,
            "loader_limit": self.loader_limit,
        }
        overrides = {
            field: value for field, value in overrides.items() if value is not None
        }
        if not overrides:
            return

        for subset in self.subsets.values():
            for field, value in overrides.items():
                setattr(subset, field, value)
            # Re-run the subset's own preparation so it picks up the new values.
            subset.prepare()

    def process(
        self,
    ) -> MultiStream:
        """Fuse the subsets and return the resulting multi-stream."""
        if self.max_total_samples is None:
            operator = FixedFusion(
                subsets=self.subsets,
                max_instances_per_subset=self.max_samples_per_subset,
            )
        else:
            operator = WeightedFusion(
                subsets=self.subsets, max_total_samples=self.max_total_samples
            )

        return operator()
blocks.py CHANGED
@@ -13,7 +13,7 @@ from .operators import (
13
  Copy,
14
  DivideAllFieldsBy,
15
  MapInstanceValues,
16
- RenameFields,
17
  Set,
18
  )
19
  from .processors import ToString, ToStringStripped
@@ -21,6 +21,7 @@ from .recipe import SequentialRecipe
21
  from .splitters import RandomSampler, Sample, SliceSplit, SplitRandomMix
22
  from .stream import MultiStream
23
  from .struct_data_operators import (
 
24
  ListToKeyValPairs,
25
  MapHTMLTableToJSON,
26
  SerializeKeyValPairs,
 
13
  Copy,
14
  DivideAllFieldsBy,
15
  MapInstanceValues,
16
+ Rename,
17
  Set,
18
  )
19
  from .processors import ToString, ToStringStripped
 
21
  from .splitters import RandomSampler, Sample, SliceSplit, SplitRandomMix
22
  from .stream import MultiStream
23
  from .struct_data_operators import (
24
+ ConstructTableFromRowsCols,
25
  ListToKeyValPairs,
26
  MapHTMLTableToJSON,
27
  SerializeKeyValPairs,
card.py CHANGED
@@ -10,12 +10,12 @@ from .task import Task
10
 
11
 
12
  class TaskCard(Artifact):
13
- """TaskCard delineates the phases in transforming the source dataset into a model-input, and specifies the metrics for evaluation of model-output.
14
 
15
  Attributes:
16
  loader: specifies the source address and the loading operator that can access that source and transform it into a unitxt multistream.
17
 
18
- preprocess_steps: list of unitxt operators to process the data source into a model-input.
19
 
20
  task: specifies the fields (of the already (pre)processed instance) making the inputs, the fields making the outputs, and the metrics to be used for evaluating the model output.
21
 
 
10
 
11
 
12
  class TaskCard(Artifact):
13
+ """TaskCard delineates the phases in transforming the source dataset into model input, and specifies the metrics for evaluation of model output.
14
 
15
  Attributes:
16
  loader: specifies the source address and the loading operator that can access that source and transform it into a unitxt multistream.
17
 
18
+ preprocess_steps: list of unitxt operators to process the data source into model input.
19
 
20
  task: specifies the fields (of the already (pre)processed instance) making the inputs, the fields making the outputs, and the metrics to be used for evaluating the model output.
21
 
catalog.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  from collections import Counter
3
  from functools import lru_cache
@@ -195,6 +196,46 @@ def summary():
195
  return result
196
 
197
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  def ls(to_file=None):
199
  done = set()
200
  result = []
 
1
+ import json
2
  import os
3
  from collections import Counter
4
  from functools import lru_cache
 
196
  return result
197
 
198
 
199
+ def _get_tags_from_file(file_path):
200
+ result = Counter()
201
+
202
+ with open(file_path) as f:
203
+ data = json.load(f)
204
+ if "__tags__" in data and isinstance(data["__tags__"], dict):
205
+ tags = data["__tags__"]
206
+ for key, value in tags.items():
207
+ if isinstance(value, list):
208
+ for item in value:
209
+ result[f"{key}:{item}"] += 1
210
+ else:
211
+ result[f"{key}:{value}"] += 1
212
+
213
+ return result
214
+
215
+
216
def count_tags():
    """Aggregate ``__tags__`` counts over all JSON files in the local catalogs.

    Each distinct catalog path is walked once. Files that cannot be decoded
    or read are reported through the logger and skipped. The combined counter
    is printed with ``print_dict`` and returned.
    """
    totals = Counter()
    visited = set()

    for catalog_path in get_local_catalogs_paths():
        if catalog_path in visited:
            continue

        for root, _, files in os.walk(catalog_path):
            for file in files:
                if not file.endswith(".json"):
                    continue
                file_path = os.path.join(root, file)
                try:
                    totals += _get_tags_from_file(file_path)
                except json.JSONDecodeError:
                    logger.info(f"Error decoding JSON in file: {file_path}")
                except OSError:
                    logger.info(f"Error reading file: {file_path}")

        visited.add(catalog_path)

    print_dict(totals)
    return totals
237
+
238
+
239
  def ls(to_file=None):
240
  done = set()
241
  result = []
collections_operators.py CHANGED
@@ -1,5 +1,6 @@
1
  from typing import Any, Generator, List, Optional
2
 
 
3
  from .operators import FieldOperator, StreamOperator
4
  from .stream import Stream
5
  from .utils import deepcopy
@@ -66,20 +67,21 @@ class DuplicateByList(StreamOperator):
66
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
67
  to_field = self.field if self.to_field is None else self.to_field
68
  for instance in stream:
69
- elements = instance[self.field]
70
  for element in elements:
71
  if self.use_deep_copy:
72
  instance_copy = deepcopy(instance)
73
- instance_copy[to_field] = element
74
  else:
75
- instance_copy = {
76
- **instance,
77
- self.field: elements,
78
- to_field: element,
79
- }
80
  yield instance_copy
81
 
82
 
 
 
 
 
83
  class DuplicateBySubLists(StreamOperator):
84
  field: str
85
  to_field: Optional[str] = None
 
1
  from typing import Any, Generator, List, Optional
2
 
3
+ from .dict_utils import dict_get, dict_set
4
  from .operators import FieldOperator, StreamOperator
5
  from .stream import Stream
6
  from .utils import deepcopy
 
67
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
68
  to_field = self.field if self.to_field is None else self.to_field
69
  for instance in stream:
70
+ elements = dict_get(instance, self.field)
71
  for element in elements:
72
  if self.use_deep_copy:
73
  instance_copy = deepcopy(instance)
74
+
75
  else:
76
+ instance_copy = {**instance}
77
+ dict_set(instance_copy, to_field, element)
 
 
 
78
  yield instance_copy
79
 
80
 
81
class Explode(DuplicateByList):
    """Alternative name for ``DuplicateByList``; adds no behavior of its own."""
83
+
84
+
85
  class DuplicateBySubLists(StreamOperator):
86
  field: str
87
  to_field: Optional[str] = None
dataclass.py CHANGED
@@ -2,7 +2,6 @@ import copy
2
  import dataclasses
3
  import functools
4
  import inspect
5
- import warnings
6
  from abc import ABCMeta
7
  from inspect import Parameter, Signature
8
  from typing import Any, Dict, List, Optional, final
@@ -38,7 +37,6 @@ class Field:
38
  final: bool = False
39
  abstract: bool = False
40
  required: bool = False
41
- deprecated: bool = False
42
  internal: bool = False
43
  origin_cls: type = None
44
  metadata: Dict[str, str] = dataclasses.field(default_factory=dict)
@@ -55,12 +53,6 @@ class FinalField(Field):
55
  self.final = True
56
 
57
 
58
- @dataclasses.dataclass
59
- class DeprecatedField(Field):
60
- def __post_init__(self):
61
- self.deprecated = True
62
-
63
-
64
  @dataclasses.dataclass
65
  class RequiredField(Field):
66
  def __post_init__(self):
@@ -251,10 +243,6 @@ def required_fields(cls):
251
  return [field for field in fields(cls) if field.required]
252
 
253
 
254
- def deprecated_fields(cls):
255
- return [field for field in fields(cls) if field.deprecated]
256
-
257
-
258
  def abstract_fields(cls):
259
  return [field for field in fields(cls) if field.abstract]
260
 
@@ -267,10 +255,6 @@ def is_final_field(field):
267
  return field.final
268
 
269
 
270
- def is_deprecated_field(field):
271
- return field.deprecated
272
-
273
-
274
  def get_field_default(field):
275
  if field.default_factory is not None:
276
  return field.default_factory()
@@ -424,7 +408,6 @@ class Dataclass(metaclass=DataclassMeta):
424
  """Initialize fields based on kwargs.
425
 
426
  Checks for abstract fields when an instance is created.
427
- Warn when a deprecated is used
428
  """
429
  super().__init__()
430
  _init_fields = [field for field in fields(self) if field.init]
@@ -433,12 +416,6 @@ class Dataclass(metaclass=DataclassMeta):
433
  field.name for field in _init_fields if field.also_positional
434
  ]
435
 
436
- _init_deprecated_fields = [field for field in _init_fields if field.deprecated]
437
- for dep_field in _init_deprecated_fields:
438
- warnings.warn(
439
- dep_field.metadata["deprecation_msg"], DeprecationWarning, stacklevel=2
440
- )
441
-
442
  for name in _init_positional_fields_names[: len(argv)]:
443
  if name in kwargs:
444
  raise TypeError(
 
2
  import dataclasses
3
  import functools
4
  import inspect
 
5
  from abc import ABCMeta
6
  from inspect import Parameter, Signature
7
  from typing import Any, Dict, List, Optional, final
 
37
  final: bool = False
38
  abstract: bool = False
39
  required: bool = False
 
40
  internal: bool = False
41
  origin_cls: type = None
42
  metadata: Dict[str, str] = dataclasses.field(default_factory=dict)
 
53
  self.final = True
54
 
55
 
 
 
 
 
 
 
56
  @dataclasses.dataclass
57
  class RequiredField(Field):
58
  def __post_init__(self):
 
243
  return [field for field in fields(cls) if field.required]
244
 
245
 
 
 
 
 
246
  def abstract_fields(cls):
247
  return [field for field in fields(cls) if field.abstract]
248
 
 
255
  return field.final
256
 
257
 
 
 
 
 
258
  def get_field_default(field):
259
  if field.default_factory is not None:
260
  return field.default_factory()
 
408
  """Initialize fields based on kwargs.
409
 
410
  Checks for abstract fields when an instance is created.
 
411
  """
412
  super().__init__()
413
  _init_fields = [field for field in fields(self) if field.init]
 
416
  field.name for field in _init_fields if field.also_positional
417
  ]
418
 
 
 
 
 
 
 
419
  for name in _init_positional_fields_names[: len(argv)]:
420
  if name in kwargs:
421
  raise TypeError(
dataset.py CHANGED
@@ -4,6 +4,7 @@ import datasets
4
 
5
  from .api import __file__ as _
6
  from .artifact import __file__ as _
 
7
  from .blocks import __file__ as _
8
  from .card import __file__ as _
9
  from .catalog import __file__ as _
@@ -23,6 +24,7 @@ from .fusion import __file__ as _
23
  from .generator_utils import __file__ as _
24
  from .hf_utils import __file__ as _
25
  from .hf_utils import verify_versions_compatibility
 
26
  from .inference import __file__ as _
27
  from .instructions import __file__ as _
28
  from .llm_as_judge import __file__ as _
 
4
 
5
  from .api import __file__ as _
6
  from .artifact import __file__ as _
7
+ from .benchmark import __file__ as _
8
  from .blocks import __file__ as _
9
  from .card import __file__ as _
10
  from .catalog import __file__ as _
 
24
  from .generator_utils import __file__ as _
25
  from .hf_utils import __file__ as _
26
  from .hf_utils import verify_versions_compatibility
27
+ from .image_operators import __file__ as _
28
  from .inference import __file__ as _
29
  from .instructions import __file__ as _
30
  from .llm_as_judge import __file__ as _
dict_utils.py CHANGED
@@ -4,8 +4,23 @@ from typing import Any, List, Tuple
4
  from .text_utils import construct_dict_str
5
 
6
  indx = re.compile(r"^(\d+)$")
 
 
 
 
 
 
7
  name = re.compile(r"^[\w. -]+$")
8
 
 
 
 
 
 
 
 
 
 
9
  # formal definition of qpath syntax by which a query is specified:
10
  # qpath -> A (/A)*
11
  # A -> name | * | non-neg-int
@@ -51,7 +66,9 @@ name = re.compile(r"^[\w. -]+$")
51
 
52
 
53
  # validate and normalizes into components
54
- def validate_query_and_break_to_components(query: str) -> List[str]:
 
 
55
  if not isinstance(query, str) or len(query) == 0:
56
  raise ValueError(
57
  f"invalid query: either not a string or an empty string: {query}"
@@ -69,9 +86,9 @@ def validate_query_and_break_to_components(query: str) -> List[str]:
69
  components = [component.strip() for component in components]
70
  for component in components:
71
  if not (
72
- bool(name.match(component))
73
- or component == "*"
74
- or bool(indx.match(component))
75
  ):
76
  raise ValueError(
77
  f"Component {component} in input query is none of: valid field-name, non-neg-int, or '*'"
@@ -79,10 +96,14 @@ def validate_query_and_break_to_components(query: str) -> List[str]:
79
  return components
80
 
81
 
82
- def is_subpath(subpath, fullpath):
83
  # Split the paths into individual components
84
- subpath_components = validate_query_and_break_to_components(subpath)
85
- fullpath_components = validate_query_and_break_to_components(fullpath)
 
 
 
 
86
 
87
  # Check if the full path starts with the subpath
88
  return fullpath_components[: len(subpath_components)] == subpath_components
@@ -100,16 +121,17 @@ def delete_values(
100
  query: List[str],
101
  index_into_query: int,
102
  remove_empty_ancestors=False,
 
103
  ) -> Tuple[bool, Any]:
104
  component = query[index_into_query]
105
  if index_into_query == -1:
106
- if component == "*":
107
  # delete all members of the list or dict
108
  current_element = [] if isinstance(current_element, list) else {}
109
  return (True, current_element)
110
  # component is a either a dictionary key or an index into a list,
111
  # pop the respective element from current_element
112
- if indx.match(component):
113
  component = int(component)
114
  try:
115
  current_element.pop(component)
@@ -141,6 +163,7 @@ def delete_values(
141
  query=query,
142
  index_into_query=index_into_query + 1,
143
  remove_empty_ancestors=remove_empty_ancestors,
 
144
  )
145
  if not success:
146
  continue
@@ -155,7 +178,7 @@ def delete_values(
155
  return (any_success, current_element)
156
 
157
  # current component is index into a list or a key into a dictionary
158
- if indx.match(component):
159
  component = int(component)
160
  try:
161
  success, new_val = delete_values(
@@ -163,6 +186,7 @@ def delete_values(
163
  query=query,
164
  index_into_query=index_into_query + 1,
165
  remove_empty_ancestors=remove_empty_ancestors,
 
166
  )
167
  if not success:
168
  return (False, None)
@@ -176,7 +200,11 @@ def delete_values(
176
 
177
 
178
  def dict_delete(
179
- dic: dict, query: str, not_exist_ok: bool = False, remove_empty_ancestors=False
 
 
 
 
180
  ):
181
  # We remove from dic the value from each and every element lead to by a path matching the query.
182
  # If remove_empty_ancestors=True, and the removal of any such value leaves its containing element (list or dict)
@@ -197,7 +225,9 @@ def dict_delete(
197
  dic.pop(query.strip())
198
  return
199
 
200
- qpath = validate_query_and_break_to_components(query)
 
 
201
 
202
  try:
203
  success, new_val = delete_values(
@@ -205,6 +235,7 @@ def dict_delete(
205
  query=qpath,
206
  index_into_query=(-1) * len(qpath),
207
  remove_empty_ancestors=remove_empty_ancestors,
 
208
  )
209
 
210
  if success:
@@ -225,7 +256,10 @@ def dict_delete(
225
  # if query includes * then return a list of values reached by all paths that match the query
226
  # flake8: noqa: C901
227
  def get_values(
228
- current_element: Any, query: List[str], index_into_query: int
 
 
 
229
  ) -> Tuple[bool, Any]:
230
  # going down from current_element through query[index_into_query].
231
  if index_into_query == 0:
@@ -244,7 +278,12 @@ def get_values(
244
  sub_elements = current_element
245
  for sub_element in sub_elements:
246
  try:
247
- success, val = get_values(sub_element, query, index_into_query + 1)
 
 
 
 
 
248
  if success:
249
  to_ret.append(val)
250
  except:
@@ -253,11 +292,14 @@ def get_values(
253
  return (len(to_ret) > 0 or index_into_query == -1, to_ret)
254
  # when * is the last component, it refers to 'all the contents' of an empty list being current_element.
255
  # next_component is indx or name, current_element must be a list or a dict
256
- if indx.match(component):
257
  component = int(component)
258
  try:
259
  success, new_val = get_values(
260
- current_element[component], query, index_into_query + 1
 
 
 
261
  )
262
  if success:
263
  return (True, new_val)
@@ -274,6 +316,7 @@ def set_values(
274
  index_into_query: int,
275
  fixed_parameters: dict,
276
  set_multiple: bool = False,
 
277
  ) -> Tuple[bool, Any]:
278
  if index_into_query == 0:
279
  return (True, value) # matched query all along!
@@ -321,6 +364,7 @@ def set_values(
321
  index_into_query=index_into_query + 1,
322
  set_multiple=False, # now used, not allowed again,
323
  fixed_parameters=fixed_parameters,
 
324
  )
325
  if not success:
326
  continue
@@ -335,7 +379,7 @@ def set_values(
335
  )
336
 
337
  # component is an index into a list or a key into a dictionary
338
- if indx.match(component):
339
  if current_element is None or not isinstance(current_element, list):
340
  if not fixed_parameters["generate_if_not_exists"]:
341
  return (False, None)
@@ -368,6 +412,7 @@ def set_values(
368
  index_into_query=index_into_query + 1,
369
  fixed_parameters=fixed_parameters,
370
  set_multiple=set_multiple,
 
371
  )
372
  if success:
373
  current_element[component] = new_val
@@ -383,6 +428,7 @@ def dict_get(
383
  query: str,
384
  not_exist_ok: bool = False,
385
  default: Any = None,
 
386
  ):
387
  if len(query.strip()) == 0:
388
  return dic
@@ -393,7 +439,9 @@ def dict_get(
393
  if isinstance(dic, dict) and query.strip() in dic:
394
  return dic[query.strip()]
395
 
396
- components = validate_query_and_break_to_components(query)
 
 
397
  if len(components) > 1:
398
  try:
399
  success, values = get_values(dic, components, -1 * len(components))
@@ -474,6 +522,7 @@ def dict_set(
474
  value: Any,
475
  not_exist_ok=True,
476
  set_multiple=False,
 
477
  ): # sets dic to its new value
478
  if dic is None or not isinstance(dic, (list, dict)):
479
  raise ValueError(
@@ -510,7 +559,9 @@ def dict_set(
510
  f"set_multiple=True, but value, {value}, can not be broken up, as either it is not a list or it is an empty list"
511
  )
512
 
513
- components = validate_query_and_break_to_components(query)
 
 
514
  fixed_parameters = {
515
  "query": components,
516
  "generate_if_not_exists": not_exist_ok,
@@ -522,6 +573,7 @@ def dict_set(
522
  index_into_query=(-1) * len(components),
523
  fixed_parameters=fixed_parameters,
524
  set_multiple=set_multiple,
 
525
  )
526
  if not success and not not_exist_ok:
527
  raise ValueError(f"No path in dic {dic} matches query {query}.")
 
4
  from .text_utils import construct_dict_str
5
 
6
  indx = re.compile(r"^(\d+)$")
7
+
8
+
9
def is_index(string):
    """Tell whether a query component is a non-negative integer index.

    Args:
        string: a single, already stripped, component of a query path.

    Returns:
        True if the component matches the module-level ``indx`` pattern
        (one or more digit characters and nothing else), False otherwise.
    """
    return bool(indx.match(string))
11
+
12
+
13
  name = re.compile(r"^[\w. -]+$")
14
 
15
+
16
def is_name(string):
    """Tell whether a query component is a valid field name.

    Args:
        string: a single, already stripped, component of a query path.

    Returns:
        True if the component matches the module-level ``name`` pattern
        (word characters, dots, spaces and hyphens only), False otherwise.
    """
    return bool(name.match(string))
18
+
19
+
20
def is_wildcard(string):
    """Tell whether a query component is the wildcard ``*``.

    Args:
        string: a single, already stripped, component of a query path.

    Returns:
        True only when the component is exactly ``"*"``.
    """
    return string == "*"
22
+
23
+
24
  # formal definition of qpath syntax by which a query is specified:
25
  # qpath -> A (/A)*
26
  # A -> name | * | non-neg-int
 
66
 
67
 
68
  # validate and normalizes into components
69
+ def validate_query_and_break_to_components(
70
+ query: str, allow_int_index=True
71
+ ) -> List[str]:
72
  if not isinstance(query, str) or len(query) == 0:
73
  raise ValueError(
74
  f"invalid query: either not a string or an empty string: {query}"
 
86
  components = [component.strip() for component in components]
87
  for component in components:
88
  if not (
89
+ is_name(component)
90
+ or is_wildcard(component)
91
+ or (is_index(component) and allow_int_index)
92
  ):
93
  raise ValueError(
94
  f"Component {component} in input query is none of: valid field-name, non-neg-int, or '*'"
 
96
  return components
97
 
98
 
99
def is_subpath(subpath, fullpath, allow_int_index=True):
    """Check whether ``subpath`` is a leading portion of ``fullpath``.

    Both arguments are query strings; each is validated and split into
    components by ``validate_query_and_break_to_components`` before the
    comparison, so the check is made component by component rather than
    on raw string prefixes.

    Args:
        subpath: the candidate prefix query.
        fullpath: the query that is expected to start with ``subpath``.
        allow_int_index: forwarded to the validation of both queries.

    Returns:
        True if the first ``len(subpath components)`` components of
        ``fullpath`` equal the components of ``subpath``.

    Raises:
        ValueError: propagated from the validation when either query is invalid.
    """
    # Split the paths into individual components
    subpath_components = validate_query_and_break_to_components(
        subpath, allow_int_index=allow_int_index
    )
    fullpath_components = validate_query_and_break_to_components(
        fullpath, allow_int_index=allow_int_index
    )

    # Check if the full path starts with the subpath
    return fullpath_components[: len(subpath_components)] == subpath_components
 
121
  query: List[str],
122
  index_into_query: int,
123
  remove_empty_ancestors=False,
124
+ allow_int_index=True,
125
  ) -> Tuple[bool, Any]:
126
  component = query[index_into_query]
127
  if index_into_query == -1:
128
+ if is_wildcard(component):
129
  # delete all members of the list or dict
130
  current_element = [] if isinstance(current_element, list) else {}
131
  return (True, current_element)
132
  # component is a either a dictionary key or an index into a list,
133
  # pop the respective element from current_element
134
+ if is_index(component) and allow_int_index:
135
  component = int(component)
136
  try:
137
  current_element.pop(component)
 
163
  query=query,
164
  index_into_query=index_into_query + 1,
165
  remove_empty_ancestors=remove_empty_ancestors,
166
+ allow_int_index=allow_int_index,
167
  )
168
  if not success:
169
  continue
 
178
  return (any_success, current_element)
179
 
180
  # current component is index into a list or a key into a dictionary
181
+ if is_index(component) and allow_int_index:
182
  component = int(component)
183
  try:
184
  success, new_val = delete_values(
 
186
  query=query,
187
  index_into_query=index_into_query + 1,
188
  remove_empty_ancestors=remove_empty_ancestors,
189
+ allow_int_index=allow_int_index,
190
  )
191
  if not success:
192
  return (False, None)
 
200
 
201
 
202
  def dict_delete(
203
+ dic: dict,
204
+ query: str,
205
+ not_exist_ok: bool = False,
206
+ remove_empty_ancestors=False,
207
+ allow_int_index=True,
208
  ):
209
  # We remove from dic the value from each and every element lead to by a path matching the query.
210
  # If remove_empty_ancestors=True, and the removal of any such value leaves its containing element (list or dict)
 
225
  dic.pop(query.strip())
226
  return
227
 
228
+ qpath = validate_query_and_break_to_components(
229
+ query, allow_int_index=allow_int_index
230
+ )
231
 
232
  try:
233
  success, new_val = delete_values(
 
235
  query=qpath,
236
  index_into_query=(-1) * len(qpath),
237
  remove_empty_ancestors=remove_empty_ancestors,
238
+ allow_int_index=allow_int_index,
239
  )
240
 
241
  if success:
 
256
  # if query includes * then return a list of values reached by all paths that match the query
257
  # flake8: noqa: C901
258
  def get_values(
259
+ current_element: Any,
260
+ query: List[str],
261
+ index_into_query: int,
262
+ allow_int_index=True,
263
  ) -> Tuple[bool, Any]:
264
  # going down from current_element through query[index_into_query].
265
  if index_into_query == 0:
 
278
  sub_elements = current_element
279
  for sub_element in sub_elements:
280
  try:
281
+ success, val = get_values(
282
+ sub_element,
283
+ query,
284
+ index_into_query + 1,
285
+ allow_int_index=allow_int_index,
286
+ )
287
  if success:
288
  to_ret.append(val)
289
  except:
 
292
  return (len(to_ret) > 0 or index_into_query == -1, to_ret)
293
  # when * is the last component, it refers to 'all the contents' of an empty list being current_element.
294
  # next_component is indx or name, current_element must be a list or a dict
295
+ if is_index(component) and allow_int_index:
296
  component = int(component)
297
  try:
298
  success, new_val = get_values(
299
+ current_element[component],
300
+ query,
301
+ index_into_query + 1,
302
+ allow_int_index=allow_int_index,
303
  )
304
  if success:
305
  return (True, new_val)
 
316
  index_into_query: int,
317
  fixed_parameters: dict,
318
  set_multiple: bool = False,
319
+ allow_int_index=True,
320
  ) -> Tuple[bool, Any]:
321
  if index_into_query == 0:
322
  return (True, value) # matched query all along!
 
364
  index_into_query=index_into_query + 1,
365
  set_multiple=False, # now used, not allowed again,
366
  fixed_parameters=fixed_parameters,
367
+ allow_int_index=allow_int_index,
368
  )
369
  if not success:
370
  continue
 
379
  )
380
 
381
  # component is an index into a list or a key into a dictionary
382
+ if is_index(component) and allow_int_index:
383
  if current_element is None or not isinstance(current_element, list):
384
  if not fixed_parameters["generate_if_not_exists"]:
385
  return (False, None)
 
412
  index_into_query=index_into_query + 1,
413
  fixed_parameters=fixed_parameters,
414
  set_multiple=set_multiple,
415
+ allow_int_index=allow_int_index,
416
  )
417
  if success:
418
  current_element[component] = new_val
 
428
  query: str,
429
  not_exist_ok: bool = False,
430
  default: Any = None,
431
+ allow_int_index=True,
432
  ):
433
  if len(query.strip()) == 0:
434
  return dic
 
439
  if isinstance(dic, dict) and query.strip() in dic:
440
  return dic[query.strip()]
441
 
442
+ components = validate_query_and_break_to_components(
443
+ query, allow_int_index=allow_int_index
444
+ )
445
  if len(components) > 1:
446
  try:
447
  success, values = get_values(dic, components, -1 * len(components))
 
522
  value: Any,
523
  not_exist_ok=True,
524
  set_multiple=False,
525
+ allow_int_index=True,
526
  ): # sets dic to its new value
527
  if dic is None or not isinstance(dic, (list, dict)):
528
  raise ValueError(
 
559
  f"set_multiple=True, but value, {value}, can not be broken up, as either it is not a list or it is an empty list"
560
  )
561
 
562
+ components = validate_query_and_break_to_components(
563
+ query, allow_int_index=allow_int_index
564
+ )
565
  fixed_parameters = {
566
  "query": components,
567
  "generate_if_not_exists": not_exist_ok,
 
573
  index_into_query=(-1) * len(components),
574
  fixed_parameters=fixed_parameters,
575
  set_multiple=set_multiple,
576
+ allow_int_index=allow_int_index,
577
  )
578
  if not success and not not_exist_ok:
579
  raise ValueError(f"No path in dic {dic} matches query {query}.")
formats.py CHANGED
@@ -79,13 +79,13 @@ class SystemFormat(BaseFormat):
79
  Important: formats can use '\N' notations that means new-line if no new-line before and no empty string before.
80
 
81
  SystemFormat expects the input instance to contain:
82
- 1. A field named "system_prompt" whose value is a string (potentially empty) that delivers a task independent opening text.
83
  2. A field named "source" whose value is a string verbalizing the original values in the instance (as read
84
  from the source dataset), in the context of the underlying task.
85
  3. A field named "instruction" that contains a (non-None) string.
86
  4. A field named with the value in arg 'demos_field', containing a list of dicts, each dict with fields "source"
87
  and "target", representing a single demo.
88
- 5. A field named "target_prefx" that contains a string to prefix the target in both each demo, and to end the whole generated prompt
89
 
90
  SystemFormat formats the above fields into a single string to be inputted to the model. This string overwrites
91
  field "source" of the instance. Formatting is driven by two args: 'demo_format' and 'model_input_format'.
@@ -200,16 +200,16 @@ class SystemFormat(BaseFormat):
200
 
201
 
202
  class HFSystemFormat(BaseFormat):
203
- r"""Formats the complete input for the model using the Hugginface chat template of a given model.
204
 
205
  HFSystemFormat expects the input instance to contain:
206
- 1. A field named "system_prompt" whose value is a string (potentially empty) that delivers a task independent opening text.
207
  2. A field named "source" whose value is a string verbalizing the original values in the instance (as read
208
  from the source dataset), in the context of the underlying task.
209
  3. A field named "instruction" that contains a (non-None) string.
210
  4. A field named with the value in arg 'demos_field', containing a list of dicts, each dict with fields "source"
211
  and "target", representing a single demo.
212
- 5. A field named "target_prefx" that contains a string to prefix the target in both each demo, and to end the whole generated prompt
213
 
214
  SystemFormat formats the above fields into a single string to be inputted to the model. This string overwrites
215
  field "source" of the instance.
 
79
  Important: formats can use '\N' notations that means new-line if no new-line before and no empty string before.
80
 
81
  SystemFormat expects the input instance to contain:
82
+ 1. A field named "system_prompt" whose value is a string (potentially empty) that delivers a task-independent opening text.
83
  2. A field named "source" whose value is a string verbalizing the original values in the instance (as read
84
  from the source dataset), in the context of the underlying task.
85
  3. A field named "instruction" that contains a (non-None) string.
86
  4. A field named with the value in arg 'demos_field', containing a list of dicts, each dict with fields "source"
87
  and "target", representing a single demo.
88
+ 5. A field named "target_prefix" that contains a string to prefix the target in each demo, and to end the whole generated prompt
89
 
90
  SystemFormat formats the above fields into a single string to be inputted to the model. This string overwrites
91
  field "source" of the instance. Formatting is driven by two args: 'demo_format' and 'model_input_format'.
 
200
 
201
 
202
  class HFSystemFormat(BaseFormat):
203
+ r"""Formats the complete input for the model using the HuggingFace chat template of a given model.
204
 
205
  HFSystemFormat expects the input instance to contain:
206
+ 1. A field named "system_prompt" whose value is a string (potentially empty) that delivers a task-independent opening text.
207
  2. A field named "source" whose value is a string verbalizing the original values in the instance (as read
208
  from the source dataset), in the context of the underlying task.
209
  3. A field named "instruction" that contains a (non-None) string.
210
  4. A field named with the value in arg 'demos_field', containing a list of dicts, each dict with fields "source"
211
  and "target", representing a single demo.
212
+ 5. A field named "target_prefix" that contains a string to prefix the target in each demo, and to end the whole generated prompt.
213
 
214
  SystemFormat formats the above fields into a single string to be inputted to the model. This string overwrites
215
  field "source" of the instance.
fusion.py CHANGED
@@ -12,13 +12,13 @@ class BaseFusion(SourceOperator):
12
  """BaseFusion operator that combines multiple multistreams into one.
13
 
14
  Args:
15
- origins: a dict of named SourceOperator objects (each to yield a MultiStream) or a list thereof,
16
  each is specified along with its input, so can generate a MultiStream
17
  include_splits: List of splits to include from each input MultiStream.
18
  If None, all splits are included.
19
  """
20
 
21
- origins: Union[List[SourceOperator], Dict[str, SourceOperator]]
22
  include_splits: Optional[List[str]] = NonPositionalField(default=None)
23
 
24
  @abstractmethod
@@ -26,18 +26,18 @@ class BaseFusion(SourceOperator):
26
  pass
27
 
28
  def prepare(self):
29
- assert isoftype(self.origins, Dict[str, SourceOperator]) or isoftype(
30
- self.origins, List[SourceOperator]
31
  )
32
- self.named_origins = (
33
- {i: self.origins[i]() for i in range(len(self.origins))}
34
- if isinstance(self.origins, list)
35
- else {name: origin() for name, origin in self.origins.items()}
36
  )
37
 
38
  def splits(self) -> List[str]:
39
  splits = []
40
- for _, origin in self.named_origins.items():
41
  for s in origin.keys():
42
  if s not in splits:
43
  if self.include_splits is None or s in self.include_splits:
@@ -59,69 +59,69 @@ class FixedFusion(BaseFusion):
59
  """FixedFusion operator that combines multiple multistreams into one, limiting the number of instances taken from each split of each input multistream.
60
 
61
  Args:
62
- origins: Dict of named SourceOperator objects (each to yield a MultiStream), or a list thereof
63
  splits: List of splits (stream_names) to include, over all input multistreams. If None, all splits are included.
64
- max_instances_per_origin_split: Number of instances to take from each input split of each input multistream.
65
  If None, all instances of each split (that is specified in include_splits) are included in the result.
66
 
67
  """
68
 
69
- max_instances_per_origin_split: Optional[int] = None
70
 
71
  def prepare(self):
72
  super().prepare()
73
 
74
  # flake8: noqa: C901
75
  def fusion_generator(self, split) -> Generator:
76
- for origin_name, origin in self.named_origins.items():
77
  if split not in origin:
78
  continue
79
  emitted_from_this_split = 0
80
- for instance in origin[split]:
81
- if (
82
- self.max_instances_per_origin_split is not None
83
- and emitted_from_this_split >= self.max_instances_per_origin_split
84
- ):
85
- break
86
- if isinstance(origin_name, str):
87
- # named origins, not anonymous, record in instance
88
- if "group" in instance:
89
- instance["group"] = origin_name + "/" + instance["group"]
90
- else:
91
- instance["group"] = origin_name
92
- emitted_from_this_split += 1
93
- yield instance
 
94
 
95
 
96
  class WeightedFusion(BaseFusion):
97
  """Fusion operator that combines multiple MultiStream-s.
98
 
99
  Args:
100
- origins: Dict of named MultiStream objects, or a list thereof
101
  weights: Dict of named weights for each origin, or a list thereof
102
  max_total_examples: Total number of instances to return per returned split.
103
  If None, all instances are returned
104
  """
105
 
106
- origins: Union[Dict[str, SourceOperator], List[SourceOperator]] = None
107
  weights: Union[Dict[str, Union[float, int]], List[Union[int, float]]] = None
108
- max_total_examples: int = None
109
- ignore_origin_groups: List[str] = ["unitxt"]
110
 
111
  def verify(self):
112
  super().verify()
113
- assert self.origins is not None, "origins must be specified"
114
  assert self.weights is not None, "weights must be specified"
115
- assert len(self.origins) == len(
116
  self.weights
117
- ), "origins and weights must have the same length"
118
- assert isoftype(self.origins, Dict[str, SourceOperator]) or isoftype(
119
- self.origins, List[SourceOperator]
120
  )
121
  assert isoftype(self.weights, Dict[str, Union[int, float]]) or isoftype(
122
  self.weights, List[Union[int, float]]
123
  )
124
- assert isinstance(self.origins, dict) == isinstance(self.weights, dict)
125
 
126
  def prepare(self):
127
  super().prepare()
@@ -134,12 +134,12 @@ class WeightedFusion(BaseFusion):
134
  def fusion_generator(self, split) -> Generator:
135
  iterators = {
136
  named_origin: iter(origin[split])
137
- for named_origin, origin in self.named_origins.items()
138
  }
139
  total_examples = 0
140
  random_generator = new_random_generator(sub_seed="weighted_fusion_" + split)
141
  while (
142
- self.max_total_examples is None or total_examples < self.max_total_examples
143
  ) and len(iterators) > 0:
144
  population = list(iterators.keys())
145
  origin_name = random_generator.choices(
@@ -150,13 +150,9 @@ class WeightedFusion(BaseFusion):
150
  try:
151
  instance = next(iterator)
152
  if isinstance(origin_name, str):
153
- if (
154
- "group" in instance
155
- and instance["group"] not in self.ignore_origin_groups
156
- ):
157
- instance["group"] = origin_name + "/" + instance["group"]
158
- else:
159
- instance["group"] = origin_name
160
  total_examples += 1
161
  yield instance
162
 
 
12
  """BaseFusion operator that combines multiple multistreams into one.
13
 
14
  Args:
15
+ subsets: a dict of named SourceOperator objects (each to yield a MultiStream) or a list thereof,
16
  each is specified along with its input, so can generate a MultiStream
17
  include_splits: List of splits to include from each input MultiStream.
18
  If None, all splits are included.
19
  """
20
 
21
+ subsets: Union[List[SourceOperator], Dict[str, SourceOperator]]
22
  include_splits: Optional[List[str]] = NonPositionalField(default=None)
23
 
24
  @abstractmethod
 
26
  pass
27
 
28
  def prepare(self):
29
+ assert isoftype(self.subsets, Dict[str, SourceOperator]) or isoftype(
30
+ self.subsets, List[SourceOperator]
31
  )
32
+ self.named_subsets = (
33
+ {i: self.subsets[i]() for i in range(len(self.subsets))}
34
+ if isinstance(self.subsets, list)
35
+ else {name: origin() for name, origin in self.subsets.items()}
36
  )
37
 
38
  def splits(self) -> List[str]:
39
  splits = []
40
+ for _, origin in self.named_subsets.items():
41
  for s in origin.keys():
42
  if s not in splits:
43
  if self.include_splits is None or s in self.include_splits:
 
59
  """FixedFusion operator that combines multiple multistreams into one, limiting the number of instances taken from each split of each input multistream.
60
 
61
  Args:
62
+ subsets: Dict of named SourceOperator objects (each to yield a MultiStream), or a list thereof
63
  splits: List of splits (stream_names) to include, over all input multistreams. If None, all splits are included.
64
+ max_instances_per_subset: Number of instances to take from each input split of each input multistream.
65
  If None, all instances of each split (that is specified in include_splits) are included in the result.
66
 
67
  """
68
 
69
+ max_instances_per_subset: Optional[int] = None
70
 
71
  def prepare(self):
72
  super().prepare()
73
 
74
  # flake8: noqa: C901
75
  def fusion_generator(self, split) -> Generator:
76
+ for origin_name, origin in self.named_subsets.items():
77
  if split not in origin:
78
  continue
79
  emitted_from_this_split = 0
80
+ try:
81
+ for instance in origin[split]:
82
+ if (
83
+ self.max_instances_per_subset is not None
84
+ and emitted_from_this_split >= self.max_instances_per_subset
85
+ ):
86
+ break
87
+ if isinstance(origin_name, str):
88
+ if "subset" not in instance:
89
+ instance["subset"] = []
90
+ instance["subset"].insert(0, origin_name)
91
+ emitted_from_this_split += 1
92
+ yield instance
93
+ except Exception as e:
94
+ raise RuntimeError(f"Exception in subset: {origin_name}") from e
95
 
96
 
97
  class WeightedFusion(BaseFusion):
98
  """Fusion operator that combines multiple MultiStream-s.
99
 
100
  Args:
101
+ subsets: Dict of named MultiStream objects, or a list thereof
102
  weights: Dict of named weights for each origin, or a list thereof
103
  max_total_examples: Total number of instances to return per returned split.
104
  If None, all instances are returned
105
  """
106
 
107
+ subsets: Union[Dict[str, SourceOperator], List[SourceOperator]] = None
108
  weights: Union[Dict[str, Union[float, int]], List[Union[int, float]]] = None
109
+ max_total_samples: int = None
 
110
 
111
  def verify(self):
112
  super().verify()
113
+ assert self.subsets is not None, "subsets must be specified"
114
  assert self.weights is not None, "weights must be specified"
115
+ assert len(self.subsets) == len(
116
  self.weights
117
+ ), "subsets and weights must have the same length"
118
+ assert isoftype(self.subsets, Dict[str, SourceOperator]) or isoftype(
119
+ self.subsets, List[SourceOperator]
120
  )
121
  assert isoftype(self.weights, Dict[str, Union[int, float]]) or isoftype(
122
  self.weights, List[Union[int, float]]
123
  )
124
+ assert isinstance(self.subsets, dict) == isinstance(self.weights, dict)
125
 
126
  def prepare(self):
127
  super().prepare()
 
134
  def fusion_generator(self, split) -> Generator:
135
  iterators = {
136
  named_origin: iter(origin[split])
137
+ for named_origin, origin in self.named_subsets.items()
138
  }
139
  total_examples = 0
140
  random_generator = new_random_generator(sub_seed="weighted_fusion_" + split)
141
  while (
142
+ self.max_total_samples is None or total_examples < self.max_total_samples
143
  ) and len(iterators) > 0:
144
  population = list(iterators.keys())
145
  origin_name = random_generator.choices(
 
150
  try:
151
  instance = next(iterator)
152
  if isinstance(origin_name, str):
153
+ if "subset" not in instance:
154
+ instance["subset"] = []
155
+ instance["subset"].insert(0, origin_name)
 
 
 
 
156
  total_examples += 1
157
  yield instance
158
 
image_operators.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import Any, Dict
3
+
4
+ from .dict_utils import dict_get
5
+ from .operators import InstanceFieldOperator
6
+
7
+
8
def extract_images(text, instance):
    """Collect the values referenced by ``<img src="...">`` tags in ``text``.

    Every ``src`` attribute found in ``text`` is used as a query that is
    resolved against ``instance`` with ``dict_get``.

    Args:
        text: the string to scan for ``<img src=...>`` tags.
        instance: the instance in which each ``src`` query is looked up.

    Returns:
        A list with one looked-up value per tag, in order of appearance;
        an empty list when ``text`` contains no such tag.
    """
    regex = r'<img\s+src=["\'](.*?)["\']'
    image_sources = re.findall(regex, text)
    return [dict_get(instance, image_source) for image_source in image_sources]
16
+
17
+
18
class ImageToText(InstanceFieldOperator):
    """Replace a field value by an ``<img>`` tag that points at the stored value.

    The processed value is appended to ``instance["media"]["images"]``
    (both levels are created when missing) and the field receives the text
    ``<img src="media/images/{idx}">``, where ``idx`` is the position of the
    value in that list.
    """

    def process_instance_value(self, value: Any, instance: Dict[str, Any]):
        # create the containers on first use, reuse them afterwards
        media = instance.setdefault("media", {})
        images = media.setdefault("images", [])
        # the index of the new value is the list length before appending
        idx = len(images)
        images.append(value)
        return f'<img src="media/images/{idx}">'
inference.py CHANGED
@@ -1,12 +1,14 @@
1
  import abc
2
  import os
 
3
  from typing import Any, Dict, List, Literal, Optional, Union
4
 
5
  from tqdm import tqdm
6
 
7
  from .artifact import Artifact
8
- from .dataclass import InternalField
9
  from .deprecation_utils import deprecation
 
10
  from .logging_utils import get_logger
11
  from .operator import PackageRequirementsMixin
12
 
@@ -61,11 +63,20 @@ class LogProbInferenceEngine(abc.ABC, Artifact):
61
  return self._infer_log_probs(dataset)
62
 
63
 
64
- class HFPipelineBasedInferenceEngine(InferenceEngine, PackageRequirementsMixin):
 
 
 
 
 
 
 
 
 
 
65
  model_name: str
66
  max_new_tokens: int
67
  use_fp16: bool = True
68
- lazy_load: bool = False
69
 
70
  _requirements_list = {
71
  "transformers": "Install huggingface package using 'pip install --upgrade transformers"
@@ -115,11 +126,11 @@ class HFPipelineBasedInferenceEngine(InferenceEngine, PackageRequirementsMixin):
115
  if not self.lazy_load:
116
  self._prepare_pipeline()
117
 
118
- def is_pipeline_initialized(self):
119
  return hasattr(self, "model") and self.model is not None
120
 
121
  def _infer(self, dataset):
122
- if not self.is_pipeline_initialized():
123
  self._prepare_pipeline()
124
 
125
  outputs = []
@@ -497,3 +508,71 @@ class WMLInferenceEngine(
497
  prompt=dataset["source"],
498
  params=self.to_dict([WMLInferenceEngineParamsMixin], keep_empty=False),
499
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import abc
2
  import os
3
+ import re
4
  from typing import Any, Dict, List, Literal, Optional, Union
5
 
6
  from tqdm import tqdm
7
 
8
  from .artifact import Artifact
9
+ from .dataclass import InternalField, NonPositionalField
10
  from .deprecation_utils import deprecation
11
+ from .image_operators import extract_images
12
  from .logging_utils import get_logger
13
  from .operator import PackageRequirementsMixin
14
 
 
63
  return self._infer_log_probs(dataset)
64
 
65
 
66
class LazyLoadMixin(Artifact):
    """Mixin for artifacts whose heavy resource may be loaded on demand.

    Subclasses report through ``_is_loaded`` whether the resource is
    already available.
    """

    # non-positional flag, False by default
    lazy_load: bool = NonPositionalField(default=False)

    @abc.abstractmethod
    def _is_loaded(self):
        """Return whether the underlying resource has already been loaded."""
        pass
72
+
73
+
74
+ class HFPipelineBasedInferenceEngine(
75
+ InferenceEngine, PackageRequirementsMixin, LazyLoadMixin
76
+ ):
77
  model_name: str
78
  max_new_tokens: int
79
  use_fp16: bool = True
 
80
 
81
  _requirements_list = {
82
  "transformers": "Install huggingface package using 'pip install --upgrade transformers"
 
126
  if not self.lazy_load:
127
  self._prepare_pipeline()
128
 
129
+ def _is_loaded(self):
130
  return hasattr(self, "model") and self.model is not None
131
 
132
  def _infer(self, dataset):
133
+ if not self._is_loaded():
134
  self._prepare_pipeline()
135
 
136
  outputs = []
 
508
  prompt=dataset["source"],
509
  params=self.to_dict([WMLInferenceEngineParamsMixin], keep_empty=False),
510
  )
511
+
512
+
513
class HFLlavaInferenceEngine(InferenceEngine, LazyLoadMixin):
    """Inference engine running a LLaVA model through HuggingFace ``transformers``.

    The model and its processor are loaded by ``_prepare_engine``. With
    ``lazy_load`` set (the default for this class) ``prepare`` skips the
    loading and the first call to ``_infer`` performs it instead.
    """

    model_name: str
    max_new_tokens: int
    lazy_load = True

    _requirements_list = {
        "transformers": "Install huggingface package using 'pip install --upgrade transformers'",
        "torch": "Install torch, go on PyTorch website for more details.",
        "accelerate": "pip install accelerate",
    }

    def _prepare_engine(self):
        """Select a device and load the model and processor named by ``model_name``."""
        import torch
        from transformers import AutoProcessor, LlavaForConditionalGeneration

        # preference order: Apple MPS, then the first CUDA device, then CPU
        if torch.backends.mps.is_available():
            device = "mps"
        elif torch.cuda.is_available():
            device = 0
        else:
            device = "cpu"
        self.device = torch.device(device)

        self.model = LlavaForConditionalGeneration.from_pretrained(
            self.model_name,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        ).to(self.device)

        self.processor = AutoProcessor.from_pretrained(self.model_name)

    def prepare(self):
        """Load the engine immediately unless lazy loading was requested."""
        if not self.lazy_load:
            self._prepare_engine()

    def _is_loaded(self):
        """Return True once a model object has been attached to the engine."""
        return hasattr(self, "model") and self.model is not None

    def _infer(self, dataset):
        """Generate one decoded text per instance of ``dataset``.

        For each instance, the images referenced by ``<img src="...">`` tags
        in its ``"source"`` field are fetched with ``extract_images``, the
        tags are replaced by ``<image>``, and the newly generated tokens
        (those following the prompt tokens) are decoded.

        Returns:
            A list of decoded strings, one per instance, in dataset order.
        """
        if not self._is_loaded():
            self._prepare_engine()

        import torch

        # Regular expression to match all <img src="..."> tags
        regex = r'<img\s+src=["\'](.*?)["\']\s*/?>'

        results = []
        for instance in dataset:
            text = instance["source"]
            images = extract_images(text, instance)
            model_input = re.sub(regex, "<image>", text)
            # a single image is handed over as the object itself, not as a list
            if len(images) == 1:
                images = images[0]
            inputs = self.processor(
                images=images, text=model_input, return_tensors="pt"
            ).to(self.device, torch.float16)
            # number of prompt tokens, used to cut the prompt off the output
            input_len = len(inputs["input_ids"][0])
            output = self.model.generate(
                **inputs, max_new_tokens=self.max_new_tokens, do_sample=False
            )
            result = self.processor.decode(
                output[0][input_len:], skip_special_tokens=True
            )
            results.append(result)

        return results
llm_as_judge.py CHANGED
@@ -1,28 +1,32 @@
1
  from typing import Any, Dict, List, Literal, Optional
2
 
3
- from .api import evaluate, produce
4
- from .artifact import Artifact, fetch_artifact, settings
5
- from .formats import Format
 
6
  from .inference import InferenceEngine, OpenAiInferenceEngine
7
  from .metrics import BulkInstanceMetric
8
  from .operator import SequentialOperator
9
- from .system_prompts import SystemPrompt
 
10
  from .templates import Template
11
 
 
 
12
 
13
  class LLMAsJudge(BulkInstanceMetric):
14
- """LLM as judge based metric class for evaluating correctness.
15
 
16
  Attributes:
17
  main_score (str): The main score label used for evaluation.
18
- task (Literal["rating.single_turn"]): The type of task the llm-as-judge runs. This defines the output and input
19
- format of the jude model.
20
  template (Template): The template used when generating inputs for the judge llm.
21
  format (Format): The format used when generating inputs for judge llm.
22
  system_prompt (SystemPrompt): The system prompt used when generating inputs for judge llm.
23
  strip_system_prompt_and_format_from_inputs (bool): Whether to strip the system prompt and formatting from the
24
  inputs that the models that is being judges received, when they are inserted to the llm-as-judge prompt.
25
- inference_model (InferenceEngine): the module that creates the inference of the judge llm.
26
  reduction_map (dict): A dictionary specifying the reduction method for the metric.
27
  batch_size (int): The size of the bulk.
28
  """
@@ -34,8 +38,8 @@ class LLMAsJudge(BulkInstanceMetric):
34
  "pairwise_comparative_rating.single_turn",
35
  ]
36
  template: Template
37
- format: Format = None
38
- system_prompt: SystemPrompt = None
39
  strip_system_prompt_and_format_from_inputs: bool = True
40
  inference_model: InferenceEngine
41
  reduction_map: Optional[Dict[str, List[str]]] = None
@@ -72,7 +76,6 @@ class LLMAsJudge(BulkInstanceMetric):
72
  {
73
  "question": input_instance,
74
  "answer": prediction,
75
- "rating": 5.0, # This is a dummy value that is not used in practice
76
  }
77
  for input_instance, prediction, reference in zip(
78
  input_instances, predictions, references
@@ -84,7 +87,6 @@ class LLMAsJudge(BulkInstanceMetric):
84
  "question": input_instance,
85
  "answer": prediction,
86
  "reference_answer": reference[0],
87
- "rating": 5.0, # This is a dummy value that is not used in practice
88
  }
89
  for input_instance, prediction, reference in zip(
90
  input_instances, predictions, references
@@ -98,7 +100,6 @@ class LLMAsJudge(BulkInstanceMetric):
98
  "answer_b": reference[0],
99
  "model_a": "input_model",
100
  "model_b": "baseline_model",
101
- "answer_a_preference": 0, # This is a dummy value that is not used in practice,
102
  }
103
  for input_instance, prediction, reference in zip(
104
  input_instances, predictions, references
@@ -110,15 +111,6 @@ class LLMAsJudge(BulkInstanceMetric):
110
  )
111
  return instances
112
 
113
- @staticmethod
114
- def _add_metadata_to_judge_instances(
115
- instances: List[List[Any]], task_data: List[Dict]
116
- ):
117
- for instance, data in zip(instances, task_data):
118
- instance["data_classification_policy"] = data["metadata"][
119
- "data_classification_policy"
120
- ]
121
-
122
  def prepare(self):
123
  super().prepare()
124
  if self.task == "pairwise_comparative_rating.single_turn":
@@ -176,47 +168,38 @@ class LLMAsJudge(BulkInstanceMetric):
176
  instances = self._get_instance_for_judge_model(
177
  input_instances, predictions, references
178
  )
179
- self._add_metadata_to_judge_instances(instances, task_data)
180
-
181
- card = f"cards.dynamic_cards_for_llm_judges.{self.task}"
182
- recipe_args = {
183
- "card": card,
184
- "template": self.template,
185
- "demos_pool_size": 0,
186
- "num_demos": 0,
187
- "__type__": settings.default_recipe,
188
- }
189
- if self.system_prompt:
190
- recipe_args["system_prompt"] = self.system_prompt
191
- if self.format:
192
- recipe_args["format"] = self.format
193
- recipe = Artifact.from_dict(recipe_args)
194
- dataset = produce(instances, recipe)
195
- verdicts = self.inference_model.infer(dataset)
196
- meta_scores = evaluate(predictions=verdicts, data=dataset)
197
-
198
- res_list = []
199
- for instance, verdict in zip(meta_scores, verdicts):
200
  if self.task == "pairwise_comparative_rating.single_turn":
201
  is_model_b_the_baseline = (
202
  instance["task_data"]["model_b"] == "baseline_model"
203
  )
204
  if is_model_b_the_baseline:
205
- model_a_preference_score = instance["processed_prediction"]
206
  else:
207
- model_a_preference_score = instance["processed_prediction"] * -1
208
 
209
- res = {
210
  self.main_score: model_a_preference_score,
211
- "judge_raw_output": verdict,
212
  "judge_raw_input": instance["source"],
213
  }
214
  else:
215
- res = {
216
- self.main_score: instance["processed_prediction"],
217
- "judge_raw_output": verdict,
218
  "judge_raw_input": instance["source"],
219
  }
220
- res_list.append(res)
221
 
222
- return res_list
 
1
  from typing import Any, Dict, List, Literal, Optional
2
 
3
+ from .api import infer
4
+ from .artifact import fetch_artifact
5
+ from .dataclass import Field
6
+ from .formats import Format, SystemFormat
7
  from .inference import InferenceEngine, OpenAiInferenceEngine
8
  from .metrics import BulkInstanceMetric
9
  from .operator import SequentialOperator
10
+ from .settings_utils import get_settings
11
+ from .system_prompts import EmptySystemPrompt, SystemPrompt
12
  from .templates import Template
13
 
14
+ settings = get_settings()
15
+
16
 
17
  class LLMAsJudge(BulkInstanceMetric):
18
+ """LLM-as-judge-based metric class for evaluating correctness.
19
 
20
  Attributes:
21
  main_score (str): The main score label used for evaluation.
22
+ task (Literal["rating.single_turn"]): The type of task the llm as judge runs. This defines the output and input
23
+ format of the judge model.
24
  template (Template): The template used when generating inputs for the judge llm.
25
  format (Format): The format used when generating inputs for judge llm.
26
  system_prompt (SystemPrompt): The system prompt used when generating inputs for judge llm.
27
  strip_system_prompt_and_format_from_inputs (bool): Whether to strip the system prompt and formatting from the
28
  inputs that the models that is being judges received, when they are inserted to the llm-as-judge prompt.
29
+ inference_model (InferenceEngine): The module that creates the inference of the judge llm.
30
  reduction_map (dict): A dictionary specifying the reduction method for the metric.
31
  batch_size (int): The size of the bulk.
32
  """
 
38
  "pairwise_comparative_rating.single_turn",
39
  ]
40
  template: Template
41
+ system_prompt: SystemPrompt = Field(default_factory=EmptySystemPrompt)
42
+ format: Format = Field(default_factory=SystemFormat)
43
  strip_system_prompt_and_format_from_inputs: bool = True
44
  inference_model: InferenceEngine
45
  reduction_map: Optional[Dict[str, List[str]]] = None
 
76
  {
77
  "question": input_instance,
78
  "answer": prediction,
 
79
  }
80
  for input_instance, prediction, reference in zip(
81
  input_instances, predictions, references
 
87
  "question": input_instance,
88
  "answer": prediction,
89
  "reference_answer": reference[0],
 
90
  }
91
  for input_instance, prediction, reference in zip(
92
  input_instances, predictions, references
 
100
  "answer_b": reference[0],
101
  "model_a": "input_model",
102
  "model_b": "baseline_model",
 
103
  }
104
  for input_instance, prediction, reference in zip(
105
  input_instances, predictions, references
 
111
  )
112
  return instances
113
 
 
 
 
 
 
 
 
 
 
114
  def prepare(self):
115
  super().prepare()
116
  if self.task == "pairwise_comparative_rating.single_turn":
 
168
  instances = self._get_instance_for_judge_model(
169
  input_instances, predictions, references
170
  )
171
+ outputs = infer(
172
+ instances,
173
+ engine=self.inference_model,
174
+ task=f"tasks.response_assessment.{self.task}",
175
+ template=self.template,
176
+ system_prompt=self.system_prompt,
177
+ format=self.format,
178
+ return_data=True,
179
+ )
180
+
181
+ results = []
182
+ for instance in outputs:
 
 
 
 
 
 
 
 
 
183
  if self.task == "pairwise_comparative_rating.single_turn":
184
  is_model_b_the_baseline = (
185
  instance["task_data"]["model_b"] == "baseline_model"
186
  )
187
  if is_model_b_the_baseline:
188
+ model_a_preference_score = instance["prediction"]
189
  else:
190
+ model_a_preference_score = instance["prediction"] * -1
191
 
192
+ result = {
193
  self.main_score: model_a_preference_score,
194
+ "judge_raw_output": instance["raw_prediction"],
195
  "judge_raw_input": instance["source"],
196
  }
197
  else:
198
+ result = {
199
+ self.main_score: instance["prediction"],
200
+ "judge_raw_output": instance["raw_prediction"],
201
  "judge_raw_input": instance["source"],
202
  }
203
+ results.append(result)
204
 
205
+ return results
loaders.py CHANGED
@@ -7,23 +7,23 @@ Unitxt is all about readily preparing of any given data source for feeding into
7
  post-processing the model's output, preparing it for any given evaluator.
8
 
9
  Through that journey, the data advances in the form of Unitxt Multistream, undergoing a sequential application
10
- of various off the shelf operators (i.e, picked from Unitxt catalog), or operators easily implemented by inheriting.
11
- The journey starts by a Unitxt Loeader bearing a Multistream from the given datasource.
12
  A loader, therefore, is the first item on any Unitxt Recipe.
13
 
14
  Unitxt catalog contains several loaders for the most popular datasource formats.
15
- All these loaders inherit from Loader, and hence, implementing a loader to expand over a new type of datasource, is
16
- straight forward.
17
 
18
  Available Loaders Overview:
19
- - :ref:`LoadHF <unitxt.loaders.LoadHF>` - Loads data from Huggingface datasets.
20
  - :ref:`LoadCSV <unitxt.loaders.LoadCSV>` - Imports data from CSV (Comma-Separated Values) files.
21
  - :ref:`LoadFromKaggle <unitxt.loaders.LoadFromKaggle>` - Retrieves datasets from the Kaggle community site.
22
  - :ref:`LoadFromIBMCloud <unitxt.loaders.LoadFromIBMCloud>` - Fetches datasets hosted on IBM Cloud.
23
  - :ref:`LoadFromSklearn <unitxt.loaders.LoadFromSklearn>` - Loads datasets available through the sklearn library.
24
  - :ref:`MultipleSourceLoader <unitxt.loaders.MultipleSourceLoader>` - Combines data from multiple different sources.
25
  - :ref:`LoadFromDictionary <unitxt.loaders.LoadFromDictionary>` - Loads data from a user-defined Python dictionary.
26
- - :ref:`LoadFromHFSpace <unitxt.loaders.LoadFromHFSpace>` - Downloads and loads data from Huggingface Spaces.
27
 
28
 
29
 
@@ -64,7 +64,7 @@ class Loader(SourceOperator):
64
 
65
  A loader is the first component in the Unitxt Recipe,
66
  responsible for loading data from various sources and preparing it as a MultiStream for processing.
67
- The loader_limit an optional parameter used to control the maximum number of instances to load from the data source. It is applied for each split separately.
68
  It is usually provided to the loader via the recipe (see standard.py)
69
  The loader can use this value to limit the amount of data downloaded from the source
70
  to reduce loading time. However, this may not always be possible, so the
@@ -140,13 +140,13 @@ class Loader(SourceOperator):
140
 
141
 
142
  class LoadHF(Loader):
143
- """Loads datasets from the Huggingface Hub.
144
 
145
  It supports loading with or without streaming,
146
- and can filter datasets upon loading.
147
 
148
  Args:
149
- path: The path or identifier of the dataset on the Huggingface Hub.
150
  name: An optional dataset name.
151
  data_dir: Optional directory to store downloaded data.
152
  split: Optional specification of which split to load.
@@ -652,7 +652,7 @@ class MultipleSourceLoader(Loader):
652
  sources: A list of loaders that will be combined to form a unified dataset.
653
 
654
  Examples:
655
- 1) Loading the train split from Huggingface hub and the test set from a local file:
656
 
657
  .. code-block:: python
658
 
@@ -678,12 +678,12 @@ class MultipleSourceLoader(Loader):
678
 
679
  def load_data(self):
680
  return FixedFusion(
681
- origins=self.sources, max_instances_per_origin_split=self.get_limit()
682
  ).process()
683
 
684
 
685
  class LoadFromDictionary(Loader):
686
- """Allows loading data from dictionary of constants.
687
 
688
  The loader can be used, for example, when debugging or working with small datasets.
689
 
@@ -733,29 +733,29 @@ class LoadFromDictionary(Loader):
733
 
734
 
735
  class LoadFromHFSpace(LoadHF):
736
- """Used to load data from Huggingface spaces.
737
 
738
  Loaders firstly tries to download all files specified in the 'data_files' parameter
739
- from the given space and then reads them as a Huggingface dataset.
740
 
741
  Args:
742
- space_name (str): Name of the Huggingface space to be accessed to.
743
  data_files (str | Sequence[str] | Mapping[str, str | Sequence[str]]): Relative
744
  paths to files within a given repository. If given as a mapping, paths should
745
  be values, while keys should represent the type of respective files
746
  (training, testing etc.).
747
- path (str, optional): Absolute path to a directory where data should be downloaded to.
748
  revision (str, optional): ID of a Git branch or commit to be used. By default, it is
749
  set to None, thus data is downloaded from the main branch of the accessed
750
  repository.
751
- use_token (bool, optional): Whether token used for authentication when accessing
752
- the Huggingface space - if necessary - should be read from the Huggingface
753
  config folder.
754
  token_env (str, optional): Key of an env variable whose value will be used for
755
- authentication when accessing the Huggingface space - if necessary.
756
 
757
  Example:
758
- Loading from Huggingface Space
759
 
760
  .. code-block:: python
761
 
 
7
  post-processing the model's output, preparing it for any given evaluator.
8
 
9
  Through that journey, the data advances in the form of Unitxt Multistream, undergoing a sequential application
10
+ of various off-the-shelf operators (i.e., picked from Unitxt catalog), or operators easily implemented by inheriting.
11
+ The journey starts by a Unitxt Loader bearing a Multistream from the given datasource.
12
  A loader, therefore, is the first item on any Unitxt Recipe.
13
 
14
  Unitxt catalog contains several loaders for the most popular datasource formats.
15
+ All these loaders inherit from Loader, and hence, implementing a loader to expand over a new type of datasource is
16
+ straightforward.
17
 
18
  Available Loaders Overview:
19
+ - :ref:`LoadHF <unitxt.loaders.LoadHF>` - Loads data from HuggingFace Datasets.
20
  - :ref:`LoadCSV <unitxt.loaders.LoadCSV>` - Imports data from CSV (Comma-Separated Values) files.
21
  - :ref:`LoadFromKaggle <unitxt.loaders.LoadFromKaggle>` - Retrieves datasets from the Kaggle community site.
22
  - :ref:`LoadFromIBMCloud <unitxt.loaders.LoadFromIBMCloud>` - Fetches datasets hosted on IBM Cloud.
23
  - :ref:`LoadFromSklearn <unitxt.loaders.LoadFromSklearn>` - Loads datasets available through the sklearn library.
24
  - :ref:`MultipleSourceLoader <unitxt.loaders.MultipleSourceLoader>` - Combines data from multiple different sources.
25
  - :ref:`LoadFromDictionary <unitxt.loaders.LoadFromDictionary>` - Loads data from a user-defined Python dictionary.
26
+ - :ref:`LoadFromHFSpace <unitxt.loaders.LoadFromHFSpace>` - Downloads and loads data from HuggingFace Spaces.
27
 
28
 
29
 
 
64
 
65
  A loader is the first component in the Unitxt Recipe,
66
  responsible for loading data from various sources and preparing it as a MultiStream for processing.
67
+ The loader_limit is an optional parameter used to control the maximum number of instances to load from the data source. It is applied for each split separately.
68
  It is usually provided to the loader via the recipe (see standard.py)
69
  The loader can use this value to limit the amount of data downloaded from the source
70
  to reduce loading time. However, this may not always be possible, so the
 
140
 
141
 
142
  class LoadHF(Loader):
143
+ """Loads datasets from the HuggingFace Hub.
144
 
145
  It supports loading with or without streaming,
146
+ and it can filter datasets upon loading.
147
 
148
  Args:
149
+ path: The path or identifier of the dataset on the HuggingFace Hub.
150
  name: An optional dataset name.
151
  data_dir: Optional directory to store downloaded data.
152
  split: Optional specification of which split to load.
 
652
  sources: A list of loaders that will be combined to form a unified dataset.
653
 
654
  Examples:
655
+ 1) Loading the train split from the HuggingFace Hub and the test set from a local file:
656
 
657
  .. code-block:: python
658
 
 
678
 
679
  def load_data(self):
680
  return FixedFusion(
681
+ subsets=self.sources, max_instances_per_subset=self.get_limit()
682
  ).process()
683
 
684
 
685
  class LoadFromDictionary(Loader):
686
+ """Allows loading data from a dictionary of constants.
687
 
688
  The loader can be used, for example, when debugging or working with small datasets.
689
 
 
733
 
734
 
735
  class LoadFromHFSpace(LoadHF):
736
+ """Used to load data from HuggingFace Spaces.
737
 
738
  Loaders firstly tries to download all files specified in the 'data_files' parameter
739
+ from the given space and then reads them as a HuggingFace Dataset.
740
 
741
  Args:
742
+ space_name (str): Name of the HuggingFace Space to be accessed.
743
  data_files (str | Sequence[str] | Mapping[str, str | Sequence[str]]): Relative
744
  paths to files within a given repository. If given as a mapping, paths should
745
  be values, while keys should represent the type of respective files
746
  (training, testing etc.).
747
+ path (str, optional): Absolute path to a directory where data should be downloaded.
748
  revision (str, optional): ID of a Git branch or commit to be used. By default, it is
749
  set to None, thus data is downloaded from the main branch of the accessed
750
  repository.
751
+ use_token (bool, optional): Whether a token is used for authentication when accessing
752
+ the HuggingFace Space. If necessary, the token is read from the HuggingFace
753
  config folder.
754
  token_env (str, optional): Key of an env variable whose value will be used for
755
+ authentication when accessing the HuggingFace Space - if necessary.
756
 
757
  Example:
758
+ Loading from a HuggingFace Space
759
 
760
  .. code-block:: python
761
 
metric.py CHANGED
@@ -4,6 +4,7 @@ import evaluate
4
 
5
  from .api import __file__ as _
6
  from .artifact import __file__ as _
 
7
  from .blocks import __file__ as _
8
  from .card import __file__ as _
9
  from .catalog import __file__ as _
@@ -22,6 +23,7 @@ from .fusion import __file__ as _
22
  from .generator_utils import __file__ as _
23
  from .hf_utils import __file__ as _
24
  from .hf_utils import verify_versions_compatibility
 
25
  from .inference import __file__ as _
26
  from .instructions import __file__ as _
27
  from .llm_as_judge import __file__ as _
 
4
 
5
  from .api import __file__ as _
6
  from .artifact import __file__ as _
7
+ from .benchmark import __file__ as _
8
  from .blocks import __file__ as _
9
  from .card import __file__ as _
10
  from .catalog import __file__ as _
 
23
  from .generator_utils import __file__ as _
24
  from .hf_utils import __file__ as _
25
  from .hf_utils import verify_versions_compatibility
26
+ from .image_operators import __file__ as _
27
  from .inference import __file__ as _
28
  from .instructions import __file__ as _
29
  from .llm_as_judge import __file__ as _
metric_utils.py CHANGED
@@ -1,11 +1,12 @@
1
  import json
2
- from typing import Any, Dict, Generator, Iterable, List, Optional
 
 
 
3
 
4
  from datasets import Features, Value
5
- from numpy import nanmean
6
 
7
  from .dataclass import Dataclass
8
- from .dict_utils import dict_set
9
  from .operator import (
10
  MultiStreamOperator,
11
  SequentialOperator,
@@ -17,111 +18,20 @@ from .operators import (
17
  ApplyOperatorsField,
18
  Copy,
19
  FlattenInstances,
20
- MergeStreams,
21
- RenameFields,
22
- SplitByNestedGroup,
23
  )
24
  from .register import _reset_env_local_catalogs, register_all_artifacts
25
  from .schema import UNITXT_DATASET_SCHEMA
26
- from .settings_utils import get_settings
27
  from .stream import DynamicStream, MultiStream
28
  from .struct_data_operators import LoadJson
29
  from .utils import deepcopy
30
 
 
31
 
32
- class MultiStreamScoreMean(MultiStreamOperator):
33
- """Given a multi-stream where each stream is already scored globally, generate a nested global score for the whole multi-stream.
34
-
35
- The whole-ms-global-score is a nested structure, specifying (also) the individual global scores of the
36
- individual streams participating in the input multi_stream.
37
- The instances of all these individual streams are assumed to have the "group" field indicate the stream
38
- they belong to.
39
- Potentially, these individual streams were produced from a SplitByNestedGroup
40
- operator that did not use the full length of the value in field "group" of the instances, but only the
41
- first g components thereof, indicated by argument 'number_of_fusion_generations' of operator SplitByNestedGroup.
42
- At any rate, a distinguishing prefix of the "group" value is recorded, by operator SplitByNestedGroup, in the stream_name.
43
- The nested structure of the whole-ms-global-score is induced by these distinguishing prefixes,
44
- by virtue of the global score of each individual stream sitting in the nested whole-ms-global-score,
45
- deep in that dictionary, at the leaf lead to by a path being the distinguishing prefix indicated in the stream_name.
46
- Thus, the global score of the stream becomes a leaf (though a dict by itself) of the whole-ms-global-score.
47
-
48
- The ancestor nodes of the above leaves, in the whole-ms-global-score, contain each (in addition to dicts
49
- leading down to leaves) a field named "score" whose value is set to be the mean of the values
50
- sitting in field "score" of its immediate children nodes, and a field named "score_name" whose
51
- value is set to be "group_mean".
52
-
53
- When the input multistream consists of one single stream, it is returned as is, mainly for backward compatibility.
54
- """
55
-
56
- def update_intermediate_level_scores(self, level: dict) -> float:
57
- if "score" in level:
58
- return level["score"]
59
- # the global score of the stream participating in this MultiStream
60
- sub_scores = []
61
- for key in level:
62
- if isinstance(level[key], dict):
63
- sub_scores.append(self.update_intermediate_level_scores(level[key]))
64
- level.update({"score": nanmean(sub_scores), "score_name": "groups_mean"})
65
- return level["score"]
66
-
67
- def process(self, multi_stream: MultiStream) -> MultiStream:
68
- # each stream went through Metric which is a single-stream-operator , and ended up with all
69
- # its instance["score"]["global"] linking to the same single dict object.
70
- # Here we first generate a new, nested version, for the whole-ms-global_score, and then update
71
- # each stream's global score with the new version
72
- # but if only one stream in the multistream - we return it as is
73
- if len(multi_stream) == 1:
74
- return multi_stream
75
- global_score = {}
76
- first_instances = {}
77
- iterators = {}
78
 
79
- for stream_name, stream in multi_stream.items():
80
- iterators[stream_name] = iter(stream)
81
- try:
82
- first_instances[stream_name] = next(iterators[stream_name])
83
- except StopIteration:
84
- continue # an empty stream, goto next stream
85
- instance = first_instances[stream_name]
86
- dict_set(
87
- dic=global_score,
88
- query=stream_name.split("~")[-1],
89
- value=deepcopy(instance["score"]["global"]),
90
- not_exist_ok=True,
91
- )
92
-
93
- self.update_intermediate_level_scores(global_score)
94
- # update the global_score object for each stream. Recall that all instances
95
- # in each stream link all to same python dict object
96
- for stream_name in multi_stream.keys():
97
- instance = first_instances[stream_name]
98
- instance["score"]["global"].clear()
99
- instance["score"]["global"].update(global_score)
100
-
101
- def never_peek_twice_generator(
102
- stream_name: str, first_instances: dict, iterators: dict
103
- ) -> Generator:
104
- while True:
105
- if stream_name in first_instances:
106
- yield first_instances.pop(stream_name)
107
- try:
108
- yield next(iterators[stream_name])
109
- except StopIteration:
110
- return
111
-
112
- return MultiStream(
113
- {
114
- stream_name: DynamicStream(
115
- never_peek_twice_generator,
116
- gen_kwargs={
117
- "stream_name": stream_name,
118
- "first_instances": first_instances,
119
- "iterators": iterators,
120
- },
121
- )
122
- for stream_name in multi_stream.keys()
123
- }
124
- )
125
 
126
 
127
  class FromPredictionsAndOriginalData(StreamInitializerOperator):
@@ -142,11 +52,6 @@ class FromPredictionsAndOriginalData(StreamInitializerOperator):
142
  )
143
 
144
 
145
- # The task_data field in the schema is defined as
146
- # Sequence({"key": Value(dtype="string"), "value": Value("string")})
147
- # When receiving instances from this scheme, the keys and values are returned as two separate
148
- # lists, and are converted to a dictionary.
149
-
150
  _post_process_steps = SequentialOperator(
151
  steps=[
152
  Copy(
@@ -156,6 +61,7 @@ _post_process_steps = SequentialOperator(
156
  Copy(
157
  field="references",
158
  to_field="raw_references",
 
159
  ),
160
  Copy(
161
  field="source",
@@ -171,11 +77,177 @@ _post_process_steps = SequentialOperator(
171
  Copy(
172
  field="references",
173
  to_field="processed_references",
 
174
  ),
175
  ]
176
  )
177
 
178
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  class PostProcessRecipe(SequentialOperatorInitializer):
180
  def prepare(self):
181
  register_all_artifacts()
@@ -185,10 +257,10 @@ class PostProcessRecipe(SequentialOperatorInitializer):
185
  ]
186
 
187
 
188
- def _post_process(
189
  predictions: List[str],
190
  references: Iterable,
191
- split_name: str = "all",
192
  ):
193
  _reset_env_local_catalogs()
194
  register_all_artifacts()
@@ -203,7 +275,7 @@ def _post_process(
203
 
204
  class MetricRecipe(SequentialOperatorInitializer):
205
  calc_confidence_intervals: bool = True
206
- number_of_fusion_generations: int = 2
207
 
208
  def prepare(self):
209
  register_all_artifacts()
@@ -211,21 +283,19 @@ class MetricRecipe(SequentialOperatorInitializer):
211
  FromPredictionsAndOriginalData(),
212
  LoadJson(field="task_data"),
213
  _post_process_steps,
214
- SplitByNestedGroup(
215
- field_name_of_group="group",
216
- number_of_fusion_generations=self.number_of_fusion_generations,
217
  ),
218
  ApplyMetric(
219
  "metrics",
220
  calc_confidence_intervals=self.calc_confidence_intervals,
221
  ),
222
- MultiStreamScoreMean(),
223
- MergeStreams(),
224
- RenameFields(
225
  field="raw_prediction",
226
  to_field="prediction",
227
  ),
228
- RenameFields(
229
  field="raw_references",
230
  to_field="references",
231
  ),
 
1
  import json
2
+ from collections import defaultdict
3
+ from functools import lru_cache
4
+ from statistics import mean
5
+ from typing import Any, Dict, Iterable, List, Optional
6
 
7
  from datasets import Features, Value
 
8
 
9
  from .dataclass import Dataclass
 
10
  from .operator import (
11
  MultiStreamOperator,
12
  SequentialOperator,
 
18
  ApplyOperatorsField,
19
  Copy,
20
  FlattenInstances,
21
+ Rename,
 
 
22
  )
23
  from .register import _reset_env_local_catalogs, register_all_artifacts
24
  from .schema import UNITXT_DATASET_SCHEMA
25
+ from .settings_utils import get_constants, get_settings
26
  from .stream import DynamicStream, MultiStream
27
  from .struct_data_operators import LoadJson
28
  from .utils import deepcopy
29
 
30
+ constants = get_constants()
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
def nan_mean(scores):
    """Return the arithmetic mean of ``scores``, ignoring NaN entries.

    Args:
        scores: An iterable of numeric scores. NaN values are skipped.

    Returns:
        The mean of the non-NaN scores, or ``float("nan")`` when there is
        nothing to average (empty input, or NaN values only).
    """
    # NaN is the only value that does not compare equal to itself.
    valid_scores = [score for score in scores if score == score]
    if not valid_scores:
        # statistics.mean raises StatisticsError on empty data; returning NaN
        # keeps score aggregation from crashing when no valid score exists.
        return float("nan")
    return mean(valid_scores)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
 
37
  class FromPredictionsAndOriginalData(StreamInitializerOperator):
 
52
  )
53
 
54
 
 
 
 
 
 
55
  _post_process_steps = SequentialOperator(
56
  steps=[
57
  Copy(
 
61
  Copy(
62
  field="references",
63
  to_field="raw_references",
64
+ dont_apply_to_streams=[constants.inference_stream],
65
  ),
66
  Copy(
67
  field="source",
 
77
  Copy(
78
  field="references",
79
  to_field="processed_references",
80
+ dont_apply_to_streams=[constants.inference_stream],
81
  ),
82
  ]
83
  )
84
 
85
 
86
@lru_cache(maxsize=None)
def group_str(json_str):
    """Render a JSON-encoded object as a ``key:value,key:value`` string.

    Args:
        json_str: A JSON document whose top-level value is an object.

    Returns:
        The object's items formatted as ``key:value`` and joined with commas,
        in the order the keys appear. Results are memoized per input string.
    """
    parsed = json.loads(json_str)
    pairs = []
    for name, val in parsed.items():
        pairs.append(f"{name}:{val}")
    return ",".join(pairs)
90
+
91
+
92
class SplitSubsetsAndGroups(MultiStreamOperator):
    """Split every stream into one stream per subset and one per (subset, group).

    All instances are collected in lists before the new MultiStream is built,
    so this operator is meant for small streams (e.g. metric computation).

    Each instance is tagged with ``"__idx__"``, its position within its source
    stream, and is then appended to:

    * the stream named ``<stream_name>://<subset path>``, where the subset path
      is the first ``subset_depth`` entries of the instance's subsets field
      joined with ``/``;
    * for every entry of the instance's groups field, the stream named
      ``<stream_name>://<subset path>?<group_str(entry)>``.

    Attributes:
        subsets_field: Name of the instance field holding the sequence of
            subset names.
        groups_field: Name of the instance field holding the group entries;
            each entry is passed to ``group_str``, which parses it as JSON.
        subset_depth: How many leading subset names to keep when building the
            subset path; ``None`` keeps all of them.

    Raises:
        ValueError: If an instance lacks the subsets field or the groups field.
    """

    subsets_field: str = "subset"
    groups_field: str = "groups"
    subset_depth: Optional[int] = None

    def process(self, multi_stream: MultiStream) -> MultiStream:
        result = defaultdict(list)
        required_fields = (self.subsets_field, self.groups_field)

        for stream_name, stream in multi_stream.items():
            for position, instance in enumerate(stream):
                # Remember the original position within the source stream.
                instance["__idx__"] = position

                for required_field in required_fields:
                    if required_field not in instance:
                        raise ValueError(
                            f"Field {required_field} is missing from instance {instance}"
                        )

                subset_path = "/".join(
                    instance[self.subsets_field][: self.subset_depth]
                )
                subset_stream_name = stream_name + "://" + subset_path
                result[subset_stream_name].append(instance)

                for group in instance[self.groups_field]:
                    group_stream_name = subset_stream_name + "?" + group_str(group)
                    result[group_stream_name].append(instance)

        return MultiStream.from_iterables(result, copying=True)
133
+
134
+
135
@lru_cache(maxsize=None)
def group_str_to_key_value(group_str):
    """Parse a ``key:value[,key:value...]`` group string into a key and a value.

    Args:
        group_str: Comma-separated ``key:value`` pairs.

    Returns:
        A ``(key, value)`` tuple. With a single pair both are scalars; with
        several pairs ``key`` is a tuple of the keys and ``value`` a tuple of
        the values, in order. Values consisting only of decimal digits are
        converted to ``int``; all other values stay strings.

    Raises:
        ValueError: If a pair contains no ``:`` separator.
    """
    keys = []
    values = []
    for k_v in group_str.split(","):
        # Split on the first colon only, so a value that itself contains ':'
        # (e.g. a time or a URL) is kept whole instead of raising.
        k, v = k_v.split(":", 1)
        # isdecimal() only accepts characters that int() can convert, unlike
        # isdigit(), which also accepts e.g. superscript digits.
        if v.isdecimal():
            v = int(v)
        keys.append(k)
        values.append(v)

    key = keys[0] if len(keys) == 1 else tuple(keys)
    value = values[0] if len(values) == 1 else tuple(values)
    return key, value
157
+
158
+
159
@lru_cache(maxsize=None)
def stream_name_to_origin_subset_group(stream_name):
    """Split a stream name of the form ``origin://subset[?group]`` into its parts.

    Args:
        stream_name: The composite stream name.

    Returns:
        An ``(origin, subset, group)`` tuple; ``group`` is ``None`` when the
        name has no ``?`` part, and ``subset`` may be the empty string.

    Raises:
        ValueError: If ``stream_name`` contains no ``://`` separator.
    """
    # Split on the first separator only: the group part may itself contain
    # "://" or "?" and must not break the unpacking.
    origin, subset_group = stream_name.split("://", 1)
    if "?" in subset_group:
        subset, group = subset_group.split("?", 1)
    else:
        subset, group = subset_group, None
    return origin, subset, group
167
+
168
+
169
class JoinSubsetsAndGroups(MultiStreamOperator):
    """Recombine scored per-subset and per-group streams into one stream per origin.

    Stream names are parsed with ``stream_name_to_origin_subset_group``. Every
    instance is restored once per origin, keyed by its ``"__idx__"`` value, and
    the ``"global"`` score carried by the first instance of each stream is
    stored in a nested dictionary under ``"subsets"`` / ``"groups"`` paths.
    Subset levels that hold no score of their own receive the NaN-ignoring
    mean of their children's scores, named ``"subsets_mean"``.
    """

    def process(self, multi_stream: MultiStream) -> MultiStream:
        instances = defaultdict(dict)
        global_scores = defaultdict(dict)

        for stream_name, stream in multi_stream.items():
            origin, subset, group = stream_name_to_origin_subset_group(stream_name)

            for position, instance in enumerate(stream):
                # Both helper entries are removed from every instance.
                stream_global_score = instance["score"].pop("global")
                idx = instance.pop("__idx__")

                origin_instances = instances[origin]
                if idx not in origin_instances:
                    origin_instances[idx] = instance

                # The stream's global score is recorded from its first
                # instance only.
                if position > 0:
                    continue

                if not (group or subset):
                    global_scores[origin]["global"] = stream_global_score
                    continue

                path = []
                if subset:
                    path.extend(["subsets", *subset.split("/")])
                if group:
                    key, value = group_str_to_key_value(group)
                    path.extend(["groups", key, value])

                # Walk (creating as needed) down to the parent of the leaf.
                node = global_scores[origin]
                *parents, leaf = path
                for part in parents:
                    node = node.setdefault(part, {})
                node[leaf] = stream_global_score

        # the leaves always have score_name and score
        def recursive_mean(dic):
            if not isinstance(dic, dict):
                return None
            if "score" in dic and "score_name" in dic:
                return dic

            result = {}
            all_scores = []
            for k, v in dic.items():
                score = recursive_mean(v)
                if score is not None:
                    all_scores.append(score["score"])
                    result[k] = score

            result["score"] = nan_mean(all_scores)
            result["score_name"] = "subsets_mean"
            return result

        result = {}
        for origin, origin_instances in instances.items():
            score = global_scores[origin]

            if "subsets" in score:
                # With subsets present, the origin's global score becomes the
                # mean over its subsets.
                score["subsets"] = recursive_mean(score["subsets"])
                score["global"] = {
                    "score": score["subsets"]["score"],
                    "score_name": score["subsets"]["score_name"],
                }

            ordered_instances = []
            for idx in sorted(origin_instances):
                instance = origin_instances[idx]
                instance["score"].update(deepcopy(score))
                ordered_instances.append(instance)
            result[origin] = ordered_instances

        return MultiStream.from_iterables(result, copying=True)
249
+
250
+
251
  class PostProcessRecipe(SequentialOperatorInitializer):
252
  def prepare(self):
253
  register_all_artifacts()
 
257
  ]
258
 
259
 
260
+ def _inference_post_process(
261
  predictions: List[str],
262
  references: Iterable,
263
+ split_name: str = constants.inference_stream,
264
  ):
265
  _reset_env_local_catalogs()
266
  register_all_artifacts()
 
275
 
276
  class MetricRecipe(SequentialOperatorInitializer):
277
  calc_confidence_intervals: bool = True
278
+ subset_depth: int = 2
279
 
280
  def prepare(self):
281
  register_all_artifacts()
 
283
  FromPredictionsAndOriginalData(),
284
  LoadJson(field="task_data"),
285
  _post_process_steps,
286
+ SplitSubsetsAndGroups(
287
+ subset_depth=self.subset_depth,
 
288
  ),
289
  ApplyMetric(
290
  "metrics",
291
  calc_confidence_intervals=self.calc_confidence_intervals,
292
  ),
293
+ JoinSubsetsAndGroups(),
294
+ Rename(
 
295
  field="raw_prediction",
296
  to_field="prediction",
297
  ),
298
+ Rename(
299
  field="raw_references",
300
  to_field="references",
301
  ),
metrics.py CHANGED
@@ -21,7 +21,6 @@ from scipy.stats._warnings_errors import DegenerateDataWarning
21
  from .artifact import Artifact, fetch_artifact
22
  from .dataclass import (
23
  AbstractField,
24
- DeprecatedField,
25
  InternalField,
26
  NonPositionalField,
27
  OptionalField,
@@ -1425,11 +1424,7 @@ class MetricPipeline(MultiStreamOperator, Metric):
1425
  main_score: str = None
1426
  preprocess_steps: Optional[List[StreamingOperator]] = field(default_factory=list)
1427
  postprocess_steps: Optional[List[StreamingOperator]] = field(default_factory=list)
1428
- postpreprocess_steps: Optional[List[StreamingOperator]] = DeprecatedField(
1429
- metadata={
1430
- "deprecation_msg": "Field 'postpreprocess_steps' is deprecated. Please use 'postprocess_steps' for the same purpose."
1431
- }
1432
- )
1433
  metric: Metric = None
1434
 
1435
  def disable_confidence_interval_calculation(self):
@@ -1446,6 +1441,9 @@ class MetricPipeline(MultiStreamOperator, Metric):
1446
  assert isinstance(
1447
  self.metric, Metric
1448
  ), f"'metric' is not set to a Metric class in {self.get_metric_name()} (type{self.metric})"
 
 
 
1449
 
1450
  def prepare(self):
1451
  super().prepare()
@@ -4729,3 +4727,170 @@ class MetricsEnsemble(InstanceMetric):
4729
 
4730
  def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> dict:
4731
  return {self.main_score: prediction}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  from .artifact import Artifact, fetch_artifact
22
  from .dataclass import (
23
  AbstractField,
 
24
  InternalField,
25
  NonPositionalField,
26
  OptionalField,
 
1424
  main_score: str = None
1425
  preprocess_steps: Optional[List[StreamingOperator]] = field(default_factory=list)
1426
  postprocess_steps: Optional[List[StreamingOperator]] = field(default_factory=list)
1427
+ postpreprocess_steps: Optional[List[StreamingOperator]] = None
 
 
 
 
1428
  metric: Metric = None
1429
 
1430
  def disable_confidence_interval_calculation(self):
 
1441
  assert isinstance(
1442
  self.metric, Metric
1443
  ), f"'metric' is not set to a Metric class in {self.get_metric_name()} (type{self.metric})"
1444
+ if self.postpreprocess_steps is not None:
1445
+ depr_message = "Field 'postpreprocess_steps' is deprecated. Please use 'postprocess_steps' for the same purpose."
1446
+ warnings.warn(depr_message, DeprecationWarning, stacklevel=2)
1447
 
1448
  def prepare(self):
1449
  super().prepare()
 
4727
 
4728
  def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> dict:
4729
  return {self.main_score: prediction}
4730
+
4731
+
4732
+ class F1Strings(InstanceMetric):
4733
+ main_score = "f1_strings"
4734
+ reduction_map = {"mean": ["f1_strings"]}
4735
+ prediction_type = str
4736
+ single_reference_per_prediction = True
4737
+ _requirements_list = {
4738
+ "spacy": "Please pip install spacy",
4739
+ }
4740
+
4741
+ def prepare(self):
4742
+ super().prepare()
4743
+ import spacy
4744
+
4745
+ try:
4746
+ self.nlp = spacy.load("en_core_web_sm")
4747
+ except OSError:
4748
+ from spacy.cli import download
4749
+
4750
+ download("en_core_web_sm")
4751
+ self.nlp = spacy.load("en_core_web_sm")
4752
+
4753
+ def compute(
4754
+ self,
4755
+ references: List[str],
4756
+ prediction: str,
4757
+ task_data: List[Dict],
4758
+ ) -> dict:
4759
+ doc_ref = self.nlp(references[0])
4760
+ set_ref = Counter([token.text.lower() for token in doc_ref])
4761
+ doc_pred = self.nlp(prediction)
4762
+ set_pred = Counter([token.text.lower() for token in doc_pred])
4763
+
4764
+ true_positives = sum((set_ref & set_pred).values())
4765
+ false_positives = sum((set_ref - set_pred).values())
4766
+ false_negatives = sum((set_pred - set_ref).values())
4767
+
4768
+ if true_positives == 0:
4769
+ f1 = 0.0
4770
+ else:
4771
+ precision = true_positives / (true_positives + false_positives)
4772
+ recall = true_positives / (true_positives + false_negatives)
4773
+ if precision + recall == 0:
4774
+ f1 = 0.0
4775
+ else:
4776
+ f1 = 2 * (precision * recall) / (precision + recall)
4777
+
4778
+ return {self.main_score: [f1], "score_name": self.main_score}
4779
+
4780
+
4781
+ class RandomForestMetricsEnsemble(MetricsEnsemble):
4782
+ """This class extends the `MetricsEnsemble` base class and leverages a pre-trained scikit-learn Random Forest classification model to combine and aggregate scores from multiple judges.
4783
+
4784
+ `load_weights` method:
4785
+ Loads model weights from dictionary representation of a random forest classifier.
4786
+ `ensemble` method:
4787
+ Decodes the RandomForestClassifier object and predict a score based on the given instance.
4788
+ """
4789
+
4790
+ _requirements_list: List[str] = ["sklearn"]
4791
+
4792
+ def decode_tree(self, tree_dict, n_features, n_classes, n_outputs):
4793
+ from sklearn.tree._tree import Tree
4794
+
4795
+ tree_dict["nodes"] = [tuple(lst) for lst in tree_dict["nodes"]]
4796
+
4797
+ tree_dict["values"] = np.array(tree_dict["values"])
4798
+ names = [
4799
+ "left_child",
4800
+ "right_child",
4801
+ "feature",
4802
+ "threshold",
4803
+ "impurity",
4804
+ "n_node_samples",
4805
+ "weighted_n_node_samples",
4806
+ "missing_go_to_left",
4807
+ ]
4808
+ tree_dict["nodes"] = np.array(
4809
+ tree_dict["nodes"],
4810
+ dtype=np.dtype({"names": names, "formats": tree_dict["nodes_dtype"]}),
4811
+ )
4812
+
4813
+ tree = Tree(n_features, np.array([n_classes], dtype=np.intp), n_outputs)
4814
+ tree.__setstate__(tree_dict)
4815
+
4816
+ return tree
4817
+
4818
+ def decode_decision_tree(self, model_dict):
4819
+ from sklearn.tree import DecisionTreeClassifier
4820
+
4821
+ decoded_model = DecisionTreeClassifier(**model_dict["params"])
4822
+
4823
+ decoded_model.n_features_in_ = model_dict["n_features_in_"]
4824
+ decoded_model.n_outputs_ = model_dict["n_outputs_"]
4825
+ decoded_model.max_features_ = model_dict["max_features_"]
4826
+ decoded_model.n_classes_ = model_dict["n_classes_"]
4827
+ decoded_model.classes_ = np.array(model_dict["classes_"])
4828
+
4829
+ tree = self.decode_tree(
4830
+ model_dict["tree_"],
4831
+ model_dict["n_features_in_"],
4832
+ model_dict["n_classes_"],
4833
+ model_dict["n_outputs_"],
4834
+ )
4835
+ decoded_model.tree_ = tree
4836
+
4837
+ return decoded_model
4838
+
4839
+ def decode_forest(self, model_dict):
4840
+ from sklearn.ensemble import RandomForestClassifier
4841
+
4842
+ model = RandomForestClassifier(**model_dict["params"])
4843
+ estimators = [
4844
+ self.decode_decision_tree(decision_tree)
4845
+ for decision_tree in model_dict["estimators_"]
4846
+ ]
4847
+ model.estimators_ = np.array(estimators)
4848
+
4849
+ model.n_features_in_ = model_dict["n_features_in_"]
4850
+ model.feature_names_in_ = np.array(model_dict["feature_names_in_"])
4851
+
4852
+ model.min_samples_split = model_dict["min_samples_split"]
4853
+ model.max_depth = model_dict["max_depth"]
4854
+ model.min_samples_leaf = model_dict["min_samples_leaf"]
4855
+ model.min_weight_fraction_leaf = model_dict["min_weight_fraction_leaf"]
4856
+ model.max_features = model_dict["max_features"]
4857
+ model.classes_ = np.array(model_dict["classes_"])
4858
+ model.max_leaf_nodes = model_dict["max_leaf_nodes"]
4859
+ model.min_impurity_decrease = model_dict["min_impurity_decrease"]
4860
+ model.n_outputs_ = model_dict["n_outputs_"]
4861
+
4862
+ if isinstance(model_dict["n_classes_"], list):
4863
+ model.n_classes_ = np.array(model_dict["n_classes_"])
4864
+ else:
4865
+ model.n_classes_ = model_dict["n_classes_"]
4866
+
4867
+ if "oob_score_" in model_dict:
4868
+ model.oob_score_ = model_dict["oob_score_"]
4869
+ if "oob_decision_function_" in model_dict:
4870
+ model.oob_decision_function_ = model_dict["oob_decision_function_"]
4871
+
4872
+ return model
4873
+
4874
+ def prepare(self):
4875
+ super().prepare()
4876
+
4877
+ @staticmethod
4878
+ def load_weights(json_file):
4879
+ with open(json_file) as file:
4880
+ return json.load(file)
4881
+
4882
+ def ensemble(self, instance):
4883
+ assert (
4884
+ self.weights is not None
4885
+ ), "RandomForestMetricsEnsemble must set self.weights before it can be used"
4886
+ ensemble_model = self.decode_forest(self.weights)
4887
+
4888
+ prediction_lst = []
4889
+ for i, metric in enumerate(self.metrics):
4890
+ prediction_lst.append(
4891
+ instance["score"]["instance"][
4892
+ self.get_prefix_name(i) + metric.main_score
4893
+ ]
4894
+ )
4895
+ score = ensemble_model.predict([prediction_lst])
4896
+ return score.tolist()[0]
operator.py CHANGED
@@ -4,9 +4,12 @@ from typing import Any, Dict, Generator, List, Optional, Union
4
 
5
  from .artifact import Artifact
6
  from .dataclass import InternalField, NonPositionalField
 
7
  from .stream import DynamicStream, EmptyStreamError, MultiStream, Stream
8
  from .utils import is_module_available
9
 
 
 
10
 
11
  class Operator(Artifact):
12
  pass
@@ -135,8 +138,13 @@ class MultiStreamOperator(StreamingOperator):
135
 
136
  caching: bool = NonPositionalField(default=None)
137
 
138
- def __call__(self, multi_stream: Optional[MultiStream] = None) -> MultiStream:
 
 
139
  self.before_process_multi_stream()
 
 
 
140
  result = self._process_multi_stream(multi_stream)
141
  if self.caching is not None:
142
  result.set_caching(self.caching)
@@ -158,11 +166,11 @@ class MultiStreamOperator(StreamingOperator):
158
  def process(self, multi_stream: MultiStream) -> MultiStream:
159
  pass
160
 
161
- def process_instance(self, instance, stream_name="tmp"):
162
  instance = self.verify_instance(instance)
163
  multi_stream = MultiStream({stream_name: stream_single(instance)})
164
  processed_multi_stream = self(multi_stream)
165
- return next(iter(processed_multi_stream[stream_name]))
166
 
167
 
168
  class SourceOperator(MultiStreamOperator):
@@ -214,6 +222,15 @@ class StreamInitializerOperator(SourceOperator):
214
  pass
215
 
216
 
 
 
 
 
 
 
 
 
 
217
  class StreamOperator(MultiStreamOperator):
218
  """A class representing a single-stream operator in the streaming system.
219
 
@@ -277,12 +294,12 @@ class StreamOperator(MultiStreamOperator):
277
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
278
  pass
279
 
280
- def process_instance(self, instance, stream_name="tmp"):
281
  instance = self.verify_instance(instance)
282
  processed_stream = self._process_single_stream(
283
  stream_single(instance), stream_name
284
  )
285
- return next(iter(processed_stream))
286
 
287
 
288
  class SingleStreamOperator(StreamOperator):
@@ -323,10 +340,10 @@ class PagedStreamOperator(StreamOperator):
323
  def process(self, page: List[Dict], stream_name: Optional[str] = None) -> Generator:
324
  pass
325
 
326
- def process_instance(self, instance, stream_name="tmp"):
327
  instance = self.verify_instance(instance)
328
  processed_stream = self._process_page([instance], stream_name)
329
- return next(iter(processed_stream))
330
 
331
 
332
  class SingleStreamReducer(StreamingOperator):
@@ -381,7 +398,7 @@ class InstanceOperator(StreamOperator):
381
  ) -> Dict[str, Any]:
382
  pass
383
 
384
- def process_instance(self, instance, stream_name="tmp"):
385
  return self._process_instance(instance, stream_name)
386
 
387
 
@@ -404,30 +421,13 @@ class InstanceOperatorValidator(InstanceOperator):
404
  except StopIteration as e:
405
  raise EmptyStreamError(f"Stream '{stream_name}' is empty") from e
406
  result = self._process_instance(first_instance, stream_name)
407
- self.validate(result)
408
  yield result
409
  yield from (
410
  self._process_instance(instance, stream_name) for instance in iterator
411
  )
412
 
413
 
414
- class BaseFieldOperator(Artifact):
415
- """A class representing a field operator in the streaming system.
416
-
417
- A field operator is a type of `Artifact` that operates on a single field within an instance. It takes an instance and a field name as input, processes the field, and updates the field in the instance with the processed value.
418
- """
419
-
420
- def __call__(self, data: Dict[str, Any], field: str) -> dict:
421
- data = self.verify_instance(data)
422
- value = self.process(data[field])
423
- data[field] = value
424
- return data
425
-
426
- @abstractmethod
427
- def process(self, value: Any) -> Any:
428
- pass
429
-
430
-
431
  class InstanceOperatorWithMultiStreamAccess(StreamingOperator):
432
  """A class representing an instance operator with global access in the streaming system.
433
 
@@ -436,7 +436,12 @@ class InstanceOperatorWithMultiStreamAccess(StreamingOperator):
436
  In order to make this efficient and to avoid qudratic complexity, it caches the accessible streams by default.
437
  """
438
 
439
- def __call__(self, multi_stream: Optional[MultiStream] = None) -> MultiStream:
 
 
 
 
 
440
  result = {}
441
 
442
  for stream_name, stream in multi_stream.items():
 
4
 
5
  from .artifact import Artifact
6
  from .dataclass import InternalField, NonPositionalField
7
+ from .settings_utils import get_constants
8
  from .stream import DynamicStream, EmptyStreamError, MultiStream, Stream
9
  from .utils import is_module_available
10
 
11
+ constants = get_constants()
12
+
13
 
14
  class Operator(Artifact):
15
  pass
 
138
 
139
  caching: bool = NonPositionalField(default=None)
140
 
141
+ def __call__(
142
+ self, multi_stream: Optional[MultiStream] = None, **instance: Dict[str, Any]
143
+ ) -> Union[MultiStream, Dict[str, Any]]:
144
  self.before_process_multi_stream()
145
+ if instance:
146
+ if multi_stream is not None:
147
+ return self.process_instance(instance)
148
  result = self._process_multi_stream(multi_stream)
149
  if self.caching is not None:
150
  result.set_caching(self.caching)
 
166
  def process(self, multi_stream: MultiStream) -> MultiStream:
167
  pass
168
 
169
+ def process_instance(self, instance, stream_name=constants.instance_stream):
170
  instance = self.verify_instance(instance)
171
  multi_stream = MultiStream({stream_name: stream_single(instance)})
172
  processed_multi_stream = self(multi_stream)
173
+ return instance_result(processed_multi_stream[stream_name])
174
 
175
 
176
  class SourceOperator(MultiStreamOperator):
 
222
  pass
223
 
224
 
225
+ def instance_result(result_stream):
226
+ result = list(result_stream)
227
+ if len(result) == 0:
228
+ return None
229
+ if len(result) == 1:
230
+ return result[0]
231
+ return result
232
+
233
+
234
  class StreamOperator(MultiStreamOperator):
235
  """A class representing a single-stream operator in the streaming system.
236
 
 
294
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
295
  pass
296
 
297
+ def process_instance(self, instance, stream_name=constants.instance_stream):
298
  instance = self.verify_instance(instance)
299
  processed_stream = self._process_single_stream(
300
  stream_single(instance), stream_name
301
  )
302
+ return instance_result(processed_stream)
303
 
304
 
305
  class SingleStreamOperator(StreamOperator):
 
340
  def process(self, page: List[Dict], stream_name: Optional[str] = None) -> Generator:
341
  pass
342
 
343
+ def process_instance(self, instance, stream_name=constants.instance_stream):
344
  instance = self.verify_instance(instance)
345
  processed_stream = self._process_page([instance], stream_name)
346
+ return instance_result(processed_stream)
347
 
348
 
349
  class SingleStreamReducer(StreamingOperator):
 
398
  ) -> Dict[str, Any]:
399
  pass
400
 
401
+ def process_instance(self, instance, stream_name=constants.instance_stream):
402
  return self._process_instance(instance, stream_name)
403
 
404
 
 
421
  except StopIteration as e:
422
  raise EmptyStreamError(f"Stream '{stream_name}' is empty") from e
423
  result = self._process_instance(first_instance, stream_name)
424
+ self.validate(result, stream_name)
425
  yield result
426
  yield from (
427
  self._process_instance(instance, stream_name) for instance in iterator
428
  )
429
 
430
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
431
  class InstanceOperatorWithMultiStreamAccess(StreamingOperator):
432
  """A class representing an instance operator with global access in the streaming system.
433
 
 
436
  In order to make this efficient and to avoid qudratic complexity, it caches the accessible streams by default.
437
  """
438
 
439
+ def __call__(
440
+ self, multi_stream: Optional[MultiStream] = None, **instance: Dict[str, Any]
441
+ ) -> MultiStream:
442
+ if instance:
443
+ raise NotImplementedError("Instance mode is not supported")
444
+
445
  result = {}
446
 
447
  for stream_name, stream in multi_stream.items():
operators.py CHANGED
@@ -14,9 +14,9 @@ To enhance the functionality of Unitxt, users are encouraged to develop custom o
14
  This can be achieved by inheriting from any of the existing operators listed below or from one of the fundamental :class:`base operators<unitxt.operator>`.
15
  The primary task in any operator development is to implement the `process` function, which defines the unique manipulations the operator will perform.
16
 
17
- General or Specelized Operators
18
  --------------------------------
19
- Some operators are specielized in specific data or specific operations such as:
20
 
21
  - :class:`loaders<unitxt.loaders>` for accessing data from various sources.
22
  - :class:`splitters<unitxt.splitters>` for fixing data splits.
@@ -28,12 +28,12 @@ Some operators are specielized in specific data or specific operations such as:
28
  - :class:`span_labeling_operators<unitxt.span_labeling_operators>` for handling strings.
29
  - :class:`fusion<unitxt.fusion>` for fusing and mixing datasets.
30
 
31
- Other specelized operators are used by unitxt internally:
32
 
33
  - :class:`templates<unitxt.templates>` for verbalizing data examples.
34
  - :class:`formats<unitxt.formats>` for preparing data for models.
35
 
36
- The rest of this section is dedicated for general operators.
37
 
38
  General Operators List:
39
  ------------------------
@@ -42,6 +42,7 @@ General Operators List:
42
  import copy
43
  import operator
44
  import uuid
 
45
  import zipfile
46
  from abc import abstractmethod
47
  from collections import Counter, defaultdict
@@ -63,7 +64,7 @@ from typing import (
63
  import requests
64
 
65
  from .artifact import Artifact, fetch_artifact
66
- from .dataclass import DeprecatedField, NonPositionalField, OptionalField
67
  from .deprecation_utils import deprecation
68
  from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
69
  from .operator import (
@@ -81,13 +82,14 @@ from .operator import (
81
  StreamOperator,
82
  )
83
  from .random_utils import new_random_generator
84
- from .settings_utils import get_settings
85
  from .stream import DynamicStream, Stream
86
  from .text_utils import nested_tuple_to_string
87
  from .type_utils import isoftype
88
  from .utils import deepcopy, flatten_dict
89
 
90
  settings = get_settings()
 
91
 
92
 
93
  class FromIterables(StreamInitializerOperator):
@@ -253,14 +255,15 @@ class Set(InstanceOperator):
253
  """
254
 
255
  fields: Dict[str, object]
256
- use_query: bool = DeprecatedField(
257
- metadata={
258
- "deprecation_msg": "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. "
259
- "Please remove this field from your code."
260
- }
261
- )
262
  use_deepcopy: bool = False
263
 
 
 
 
 
 
 
264
  def process(
265
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
266
  ) -> Dict[str, Any]:
@@ -341,19 +344,37 @@ class InstanceFieldOperator(InstanceOperator):
341
  field: Optional[str] = None
342
  to_field: Optional[str] = None
343
  field_to_field: Optional[Union[List[List[str]], Dict[str, str]]] = None
344
- use_query: bool = DeprecatedField(
345
- metadata={
346
- "deprecation_msg": "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. "
347
- "Please remove this field from your code."
348
- }
349
- )
350
  process_every_value: bool = False
351
  get_default: Any = None
352
  not_exist_ok: bool = False
353
 
354
  def verify(self):
355
  super().verify()
 
 
 
356
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357
  assert (
358
  self.field is not None or self.field_to_field is not None
359
  ), "Must supply a field to work on"
@@ -363,7 +384,9 @@ class InstanceFieldOperator(InstanceOperator):
363
  assert (
364
  self.field is None or self.field_to_field is None
365
  ), f"Can not apply operator both on {self.field} and on the from fields in the mapping {self.field_to_field}"
366
- assert self._field_to_field, f"the from and to fields must be defined or implied from the other inputs got: {self._field_to_field}"
 
 
367
  assert (
368
  len(self._field_to_field) > 0
369
  ), f"'input argument 'field_to_field' should convey at least one field to process. Got {self.field_to_field}"
@@ -402,31 +425,10 @@ class InstanceFieldOperator(InstanceOperator):
402
  def process_instance_value(self, value: Any, instance: Dict[str, Any]):
403
  pass
404
 
405
- def prepare(self):
406
- super().prepare()
407
-
408
- # prepare is invoked before verify, hence must make some checks here, before the changes done here
409
- assert (
410
- (self.field is None) != (self.field_to_field is None)
411
- ), "Must uniquely define the field to work on, through exactly one of either 'field' or 'field_to_field'"
412
- assert (
413
- self.to_field is None or self.field_to_field is None
414
- ), f"Can not apply operator to create both {self.to_field} and the to fields in the mapping {self.field_to_field}"
415
-
416
- if self.field_to_field is None:
417
- self._field_to_field = [
418
- (self.field, self.to_field if self.to_field is not None else self.field)
419
- ]
420
- else:
421
- self._field_to_field = (
422
- list(self.field_to_field.items())
423
- if isinstance(self.field_to_field, dict)
424
- else self.field_to_field
425
- )
426
-
427
  def process(
428
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
429
  ) -> Dict[str, Any]:
 
430
  # Need to deep copy instance, because when assigning two dictionary fields,
431
  # dict_set() the target field dictionary fields.
432
  # This means that if this target field was assigned to another field before,
@@ -474,23 +476,23 @@ class FieldOperator(InstanceFieldOperator):
474
  pass
475
 
476
 
477
- class RenameFields(FieldOperator):
478
  """Renames fields.
479
 
480
  Move value from one field to another, potentially, if field name contains a /, from one branch into another.
481
  Remove the from field, potentially part of it in case of / in from_field.
482
 
483
  Examples:
484
- RenameFields(field_to_field={"b": "c"})
485
  will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "c": 2}, {"a": 2, "c": 3}]
486
 
487
- RenameFields(field_to_field={"b": "c/d"})
488
  will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "c": {"d": 2}}, {"a": 2, "c": {"d": 3}}]
489
 
490
- RenameFields(field_to_field={"b": "b/d"})
491
  will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "b": {"d": 2}}, {"a": 2, "b": {"d": 3}}]
492
 
493
- RenameFields(field_to_field={"b/c/e": "b/d"})
494
  will change inputs [{"a": 1, "b": {"c": {"e": 2, "f": 20}}}] to [{"a": 1, "b": {"c": {"f": 20}, "d": 2}}]
495
 
496
  """
@@ -511,6 +513,11 @@ class RenameFields(FieldOperator):
511
  return res
512
 
513
 
 
 
 
 
 
514
  class AddConstant(FieldOperator):
515
  """Adds a constant, being argument 'add', to the processed value.
516
 
@@ -777,7 +784,7 @@ class Apply(InstanceOperator):
777
  Args:
778
  function (str): name of function.
779
  to_field (str): the field to store the result
780
- additional arguments are field names passed to the function
781
 
782
  Examples:
783
  Store in field "b" the uppercase string of the value in field "a"
@@ -846,12 +853,13 @@ class ListFieldValues(InstanceOperator):
846
 
847
  fields: List[str]
848
  to_field: str
849
- use_query: bool = DeprecatedField(
850
- metadata={
851
- "deprecation_msg": "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. "
852
- "Please remove this field from your code."
853
- }
854
- )
 
855
 
856
  def process(
857
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
@@ -878,12 +886,13 @@ class ZipFieldValues(InstanceOperator):
878
  fields: List[str]
879
  to_field: str
880
  longest: bool = False
881
- use_query: bool = DeprecatedField(
882
- metadata={
883
- "deprecation_msg": "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. "
884
- "Please remove this field from your code."
885
- }
886
- )
 
887
 
888
  def process(
889
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
@@ -950,12 +959,13 @@ class IndexOf(InstanceOperator):
950
  search_in: str
951
  index_of: str
952
  to_field: str
953
- use_query: bool = DeprecatedField(
954
- metadata={
955
- "deprecation_msg": "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. "
956
- "Please remove this field from your code."
957
- }
958
- )
 
959
 
960
  def process(
961
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
@@ -972,12 +982,13 @@ class TakeByField(InstanceOperator):
972
  field: str
973
  index: str
974
  to_field: str = None
975
- use_query: bool = DeprecatedField(
976
- metadata={
977
- "deprecation_msg": "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. "
978
- "Please remove this field from your code."
979
- }
980
- )
 
981
 
982
  def prepare(self):
983
  if self.to_field is None:
@@ -1060,8 +1071,12 @@ class Copy(FieldOperator):
1060
 
1061
  """
1062
 
 
 
1063
  def process_value(self, value: Any) -> Any:
1064
- return copy.deepcopy(value)
 
 
1065
 
1066
 
1067
  @deprecation(version="2.0.0", alternative=Copy)
@@ -1090,6 +1105,31 @@ class AddID(InstanceOperator):
1090
  return instance
1091
 
1092
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1093
  class CastFields(InstanceOperator):
1094
  """Casts specified fields to specified types.
1095
 
@@ -1247,7 +1287,7 @@ class ApplyOperatorsField(InstanceOperator):
1247
 
1248
  # we now have a list of nanes of operators, each is equipped with process_instance method.
1249
  operator = SequentialOperator(steps=operator_names)
1250
- return operator.process_instance(instance)
1251
 
1252
 
1253
  class FilterByCondition(StreamOperator):
@@ -1767,7 +1807,7 @@ class ApplyStreamOperatorsField(StreamOperator, ArtifactFetcherMixin):
1767
  operator, StreamingOperator
1768
  ), f"Operator {operator_name} must be a StreamOperator"
1769
 
1770
- stream = operator(MultiStream({"tmp": stream}))["tmp"]
1771
 
1772
  yield from stream
1773
 
@@ -1808,7 +1848,7 @@ class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
1808
  # Here we keep all the fields besides the score, and restore them after the metric finishes.
1809
  first_instance = stream.peek()
1810
  keys_to_restore = set(first_instance.keys()).difference({"score"})
1811
- multi_stream = MultiStream({"tmp": stream})
1812
  multi_stream = CopyFields(
1813
  field_to_field={k: f"{k}_orig" for k in keys_to_restore}
1814
  )(multi_stream)
@@ -1830,7 +1870,7 @@ class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
1830
  multi_stream = RemoveFields(fields=[f"{k}_orig" for k in keys_to_restore])(
1831
  multi_stream
1832
  )
1833
- stream = multi_stream["tmp"]
1834
  yield from stream
1835
 
1836
 
 
14
  This can be achieved by inheriting from any of the existing operators listed below or from one of the fundamental :class:`base operators<unitxt.operator>`.
15
  The primary task in any operator development is to implement the `process` function, which defines the unique manipulations the operator will perform.
16
 
17
+ General or Specialized Operators
18
  --------------------------------
19
+ Some operators are specialized in specific data or specific operations such as:
20
 
21
  - :class:`loaders<unitxt.loaders>` for accessing data from various sources.
22
  - :class:`splitters<unitxt.splitters>` for fixing data splits.
 
28
  - :class:`span_labeling_operators<unitxt.span_labeling_operators>` for handling strings.
29
  - :class:`fusion<unitxt.fusion>` for fusing and mixing datasets.
30
 
31
+ Other specialized operators are used by unitxt internally:
32
 
33
  - :class:`templates<unitxt.templates>` for verbalizing data examples.
34
  - :class:`formats<unitxt.formats>` for preparing data for models.
35
 
36
+ The rest of this section is dedicated to general operators.
37
 
38
  General Operators List:
39
  ------------------------
 
42
  import copy
43
  import operator
44
  import uuid
45
+ import warnings
46
  import zipfile
47
  from abc import abstractmethod
48
  from collections import Counter, defaultdict
 
64
  import requests
65
 
66
  from .artifact import Artifact, fetch_artifact
67
+ from .dataclass import NonPositionalField, OptionalField
68
  from .deprecation_utils import deprecation
69
  from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
70
  from .operator import (
 
82
  StreamOperator,
83
  )
84
  from .random_utils import new_random_generator
85
+ from .settings_utils import get_constants, get_settings
86
  from .stream import DynamicStream, Stream
87
  from .text_utils import nested_tuple_to_string
88
  from .type_utils import isoftype
89
  from .utils import deepcopy, flatten_dict
90
 
91
  settings = get_settings()
92
+ constants = get_constants()
93
 
94
 
95
  class FromIterables(StreamInitializerOperator):
 
255
  """
256
 
257
  fields: Dict[str, object]
258
+ use_query: Optional[bool] = None
 
 
 
 
 
259
  use_deepcopy: bool = False
260
 
261
+ def verify(self):
262
+ super().verify()
263
+ if self.use_query is not None:
264
+ depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
265
+ warnings.warn(depr_message, DeprecationWarning, stacklevel=2)
266
+
267
  def process(
268
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
269
  ) -> Dict[str, Any]:
 
344
  field: Optional[str] = None
345
  to_field: Optional[str] = None
346
  field_to_field: Optional[Union[List[List[str]], Dict[str, str]]] = None
347
+ use_query: Optional[bool] = None
 
 
 
 
 
348
  process_every_value: bool = False
349
  get_default: Any = None
350
  not_exist_ok: bool = False
351
 
352
  def verify(self):
353
  super().verify()
354
+ if self.use_query is not None:
355
+ depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
356
+ warnings.warn(depr_message, DeprecationWarning, stacklevel=2)
357
 
358
+ def verify_field_definition(self):
359
+ if hasattr(self, "_field_to_field") and self._field_to_field is not None:
360
+ return
361
+ assert (
362
+ (self.field is None) != (self.field_to_field is None)
363
+ ), "Must uniquely define the field to work on, through exactly one of either 'field' or 'field_to_field'"
364
+ assert (
365
+ self.to_field is None or self.field_to_field is None
366
+ ), f"Can not apply operator to create both {self.to_field} and the to fields in the mapping {self.field_to_field}"
367
+
368
+ if self.field_to_field is None:
369
+ self._field_to_field = [
370
+ (self.field, self.to_field if self.to_field is not None else self.field)
371
+ ]
372
+ else:
373
+ self._field_to_field = (
374
+ list(self.field_to_field.items())
375
+ if isinstance(self.field_to_field, dict)
376
+ else self.field_to_field
377
+ )
378
  assert (
379
  self.field is not None or self.field_to_field is not None
380
  ), "Must supply a field to work on"
 
384
  assert (
385
  self.field is None or self.field_to_field is None
386
  ), f"Can not apply operator both on {self.field} and on the from fields in the mapping {self.field_to_field}"
387
+ assert (
388
+ self._field_to_field is not None
389
+ ), f"the from and to fields must be defined or implied from the other inputs got: {self._field_to_field}"
390
  assert (
391
  len(self._field_to_field) > 0
392
  ), f"'input argument 'field_to_field' should convey at least one field to process. Got {self.field_to_field}"
 
425
  def process_instance_value(self, value: Any, instance: Dict[str, Any]):
426
  pass
427
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
428
  def process(
429
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
430
  ) -> Dict[str, Any]:
431
+ self.verify_field_definition()
432
  # Need to deep copy instance, because when assigning two dictionary fields,
433
  # dict_set() the target field dictionary fields.
434
  # This means that if this target field was assigned to another field before,
 
476
  pass
477
 
478
 
479
+ class Rename(FieldOperator):
480
  """Renames fields.
481
 
482
  Move value from one field to another, potentially, if field name contains a /, from one branch into another.
483
  Remove the from field, potentially part of it in case of / in from_field.
484
 
485
  Examples:
486
+ Rename(field_to_field={"b": "c"})
487
  will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "c": 2}, {"a": 2, "c": 3}]
488
 
489
+ Rename(field_to_field={"b": "c/d"})
490
  will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "c": {"d": 2}}, {"a": 2, "c": {"d": 3}}]
491
 
492
+ Rename(field_to_field={"b": "b/d"})
493
  will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "b": {"d": 2}}, {"a": 2, "b": {"d": 3}}]
494
 
495
+ Rename(field_to_field={"b/c/e": "b/d"})
496
  will change inputs [{"a": 1, "b": {"c": {"e": 2, "f": 20}}}] to [{"a": 1, "b": {"c": {"f": 20}, "d": 2}}]
497
 
498
  """
 
513
  return res
514
 
515
 
516
+ @deprecation(version="2.0.0", alternative=Rename)
517
+ class RenameFields(Rename):
518
+ pass
519
+
520
+
521
  class AddConstant(FieldOperator):
522
  """Adds a constant, being argument 'add', to the processed value.
523
 
 
784
  Args:
785
  function (str): name of function.
786
  to_field (str): the field to store the result
787
+ any additional arguments are field names whose values will be passed directly to the function specified
788
 
789
  Examples:
790
  Store in field "b" the uppercase string of the value in field "a"
 
853
 
854
  fields: List[str]
855
  to_field: str
856
+ use_query: Optional[bool] = None
857
+
858
+ def verify(self):
859
+ super().verify()
860
+ if self.use_query is not None:
861
+ depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
862
+ warnings.warn(depr_message, DeprecationWarning, stacklevel=2)
863
 
864
  def process(
865
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
 
886
  fields: List[str]
887
  to_field: str
888
  longest: bool = False
889
+ use_query: Optional[bool] = None
890
+
891
+ def verify(self):
892
+ super().verify()
893
+ if self.use_query is not None:
894
+ depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
895
+ warnings.warn(depr_message, DeprecationWarning, stacklevel=2)
896
 
897
  def process(
898
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
 
959
  search_in: str
960
  index_of: str
961
  to_field: str
962
+ use_query: Optional[bool] = None
963
+
964
+ def verify(self):
965
+ super().verify()
966
+ if self.use_query is not None:
967
+ depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
968
+ warnings.warn(depr_message, DeprecationWarning, stacklevel=2)
969
 
970
  def process(
971
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
 
982
  field: str
983
  index: str
984
  to_field: str = None
985
+ use_query: Optional[bool] = None
986
+
987
+ def verify(self):
988
+ super().verify()
989
+ if self.use_query is not None:
990
+ depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
991
+ warnings.warn(depr_message, DeprecationWarning, stacklevel=2)
992
 
993
  def prepare(self):
994
  if self.to_field is None:
 
1071
 
1072
  """
1073
 
1074
+ use_deep_copy: bool = True
1075
+
1076
  def process_value(self, value: Any) -> Any:
1077
+ if self.use_deep_copy:
1078
+ return copy.deepcopy(value)
1079
+ return value
1080
 
1081
 
1082
  @deprecation(version="2.0.0", alternative=Copy)
 
1105
  return instance
1106
 
1107
 
1108
class Cast(FieldOperator):
    """Casts the value of the processed field(s) to a built-in type.

    Args:
        to (str): Name of the target type. One of "int", "float", "str" or "bool".
        failure_default (object): Value returned when the cast fails. When left at
            its default (the "__UNDEFINED__" sentinel) a failed cast raises instead.

    Raises:
        ValueError: If the value cannot be cast and no ``failure_default`` was given.
        KeyError: If ``to`` is not one of the supported type names (the lookup in
            ``process_value`` is deliberately not covered by the fallback).

    NOTE(review): the previous docstring also listed ``process_every_value``;
    presumably that option is inherited from FieldOperator -- confirm.
    """

    to: str
    failure_default: Optional[Any] = "__UNDEFINED__"

    def prepare(self):
        # Note: "bool" follows Python truthiness, so every non-empty string
        # (including "False") is cast to True.
        self.types = {"int": int, "float": float, "str": str, "bool": bool}

    def process_value(self, value):
        caster = self.types[self.to]
        try:
            return caster(value)
        # int()/float() raise TypeError (not ValueError) for unsupported input
        # types such as None or a list; treat both as a casting failure so that
        # failure_default is honoured for them as well.
        except (ValueError, TypeError) as e:
            if self.failure_default == "__UNDEFINED__":
                raise ValueError(
                    f'Failed to cast value {value} to type "{self.to}", and no default value is provided.'
                ) from e
            return self.failure_default
1131
+
1132
+
1133
  class CastFields(InstanceOperator):
1134
  """Casts specified fields to specified types.
1135
 
 
1287
 
1288
  # we now have a list of names of operators, each of which is equipped with a process_instance method.
1289
  operator = SequentialOperator(steps=operator_names)
1290
+ return operator.process_instance(instance, stream_name=stream_name)
1291
 
1292
 
1293
  class FilterByCondition(StreamOperator):
 
1807
  operator, StreamingOperator
1808
  ), f"Operator {operator_name} must be a StreamOperator"
1809
 
1810
+ stream = operator(MultiStream({stream_name: stream}))[stream_name]
1811
 
1812
  yield from stream
1813
 
 
1848
  # Here we keep all the fields besides the score, and restore them after the metric finishes.
1849
  first_instance = stream.peek()
1850
  keys_to_restore = set(first_instance.keys()).difference({"score"})
1851
+ multi_stream = MultiStream({stream_name: stream})
1852
  multi_stream = CopyFields(
1853
  field_to_field={k: f"{k}_orig" for k in keys_to_restore}
1854
  )(multi_stream)
 
1870
  multi_stream = RemoveFields(fields=[f"{k}_orig" for k in keys_to_restore])(
1871
  multi_stream
1872
  )
1873
+ stream = multi_stream[stream_name]
1874
  yield from stream
1875
 
1876
 
processors.py CHANGED
@@ -1,10 +1,38 @@
1
  import ast
 
2
  import json
3
  import re
4
  from difflib import get_close_matches
5
  from typing import Any, Dict
6
 
 
 
7
  from .operators import FieldOperator, InstanceFieldOperator
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
 
10
  class ToString(FieldOperator):
@@ -114,11 +142,16 @@ class LowerCaseTillPunc(FieldOperator):
114
  return non_empty_line
115
 
116
 
117
- class LowerCase(FieldOperator):
118
  def process_value(self, text: Any) -> Any:
119
  return text.lower()
120
 
121
 
 
 
 
 
 
122
  class Capitalize(FieldOperator):
123
  def process_value(self, text: Any) -> Any:
124
  return text.capitalize()
@@ -212,7 +245,7 @@ class StanceToProCon(FieldOperator):
212
  return "none"
213
 
214
 
215
- class StringOrNotString(FieldOperator):
216
  string: str
217
 
218
  def process_value(self, text: Any) -> Any:
@@ -223,6 +256,11 @@ class StringOrNotString(FieldOperator):
223
  return text
224
 
225
 
 
 
 
 
 
226
  class ExtractMtBenchRatingJudgment(FieldOperator):
227
  def process_value(self, text: Any) -> Any:
228
  match = re.search(r"\[\[([\d]+\.?[\d]*)\]\]", text)
 
1
  import ast
2
+ import copy
3
  import json
4
  import re
5
  from difflib import get_close_matches
6
  from typing import Any, Dict
7
 
8
+ from .deprecation_utils import deprecation
9
+ from .operator import MultiStreamOperator
10
  from .operators import FieldOperator, InstanceFieldOperator
11
+ from .settings_utils import get_constants
12
+
13
+ constants = get_constants()
14
+
15
+
16
class PostProcess(MultiStreamOperator):
    """Applies one field operator to the 'prediction' and 'references' fields.

    Two independent deep copies of ``operator`` are configured in ``prepare``:
    one targeting the "prediction" field, and one targeting the "references"
    field with ``process_every_value`` enabled and ``dont_apply_to_streams``
    set to the inference stream name (presumably so it is skipped on that
    stream -- confirm against the operator base class).
    """

    operator: InstanceFieldOperator
    process_prediction: bool = True
    process_references: bool = True

    def prepare(self):
        super().prepare()

        on_prediction = copy.deepcopy(self.operator)
        on_prediction.field = "prediction"

        on_references = copy.deepcopy(self.operator)
        on_references.field = "references"
        on_references.process_every_value = True
        on_references.dont_apply_to_streams = [constants.inference_stream]

        self.prediction_operator = on_prediction
        self.references_operator = on_references

    def process(self, multi_stream):
        # Prediction is handled first, then references; each stage is optional.
        stages = (
            (self.process_prediction, self.prediction_operator),
            (self.process_references, self.references_operator),
        )
        for enabled, stage in stages:
            if enabled:
                multi_stream = stage(multi_stream)
        return multi_stream
36
 
37
 
38
  class ToString(FieldOperator):
 
142
  return non_empty_line
143
 
144
 
145
+ class Lower(FieldOperator):
146
  def process_value(self, text: Any) -> Any:
147
  return text.lower()
148
 
149
 
150
+ @deprecation("2.0.0", alternative=Lower)
151
+ class LowerCase(Lower):
152
+ pass
153
+
154
+
155
  class Capitalize(FieldOperator):
156
  def process_value(self, text: Any) -> Any:
157
  return text.capitalize()
 
245
  return "none"
246
 
247
 
248
+ class StringEquals(FieldOperator):
249
  string: str
250
 
251
  def process_value(self, text: Any) -> Any:
 
256
  return text
257
 
258
 
259
+ @deprecation("2.0.0", alternative=StringEquals)
260
+ class StringOrNotString(StringEquals):
261
+ pass
262
+
263
+
264
  class ExtractMtBenchRatingJudgment(FieldOperator):
265
  def process_value(self, text: Any) -> Any:
266
  match = re.search(r"\[\[([\d]+\.?[\d]*)\]\]", text)
schema.py CHANGED
@@ -1,10 +1,14 @@
1
  import json
2
- from typing import Any, Dict, Optional
3
 
4
- from datasets import Features, Sequence, Value
5
 
6
  from .artifact import Artifact
 
7
  from .operator import InstanceOperatorValidator
 
 
 
8
 
9
  UNITXT_DATASET_SCHEMA = Features(
10
  {
@@ -12,7 +16,24 @@ UNITXT_DATASET_SCHEMA = Features(
12
  "target": Value("string"),
13
  "references": Sequence(Value("string")),
14
  "metrics": Sequence(Value("string")),
15
- "group": Value("string"),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  "postprocessors": Sequence(Value("string")),
17
  "task_data": Value(dtype="string"),
18
  "data_classification_policy": Sequence(Value("string")),
@@ -20,7 +41,14 @@ UNITXT_DATASET_SCHEMA = Features(
20
  )
21
 
22
 
 
 
 
 
 
 
23
  class Finalize(InstanceOperatorValidator):
 
24
  remove_unnecessary_fields: bool = True
25
 
26
  @staticmethod
@@ -29,33 +57,63 @@ class Finalize(InstanceOperatorValidator):
29
  return artifact.to_dict()
30
  return artifact.__id__
31
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  def process(
33
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
34
  ) -> Dict[str, Any]:
 
 
 
 
 
 
 
35
  task_data = {
36
  **instance["input_fields"],
37
- **instance["reference_fields"],
38
- "metadata": {
39
- "data_classification_policy": instance["data_classification_policy"],
40
- "template": self.artifact_to_jsonable(
41
- instance["recipe_metadata"]["template"]
42
- ),
43
- "num_demos": instance["recipe_metadata"]["num_demos"],
44
- },
45
  }
 
 
 
 
46
  instance["task_data"] = json.dumps(task_data)
47
 
48
  if self.remove_unnecessary_fields:
49
  keys_to_delete = []
50
 
51
  for key in instance.keys():
52
- if key not in UNITXT_DATASET_SCHEMA:
53
  keys_to_delete.append(key)
54
 
55
  for key in keys_to_delete:
56
  del instance[key]
57
- if "group" not in instance:
58
- instance["group"] = "unitxt"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  instance["metrics"] = [
60
  metric.to_json() if isinstance(metric, Artifact) else metric
61
  for metric in instance["metrics"]
@@ -64,6 +122,7 @@ class Finalize(InstanceOperatorValidator):
64
  processor.to_json() if isinstance(processor, Artifact) else processor
65
  for processor in instance["postprocessors"]
66
  ]
 
67
  return instance
68
 
69
  def validate(self, instance: Dict[str, Any], stream_name: Optional[str] = None):
@@ -72,7 +131,8 @@ class Finalize(InstanceOperatorValidator):
72
  assert isinstance(
73
  instance, dict
74
  ), f"Instance should be a dict, got {type(instance)}"
 
75
  assert all(
76
- key in instance for key in UNITXT_DATASET_SCHEMA
77
- ), f"Instance should have the following keys: {UNITXT_DATASET_SCHEMA}. Instance is: {instance}"
78
- UNITXT_DATASET_SCHEMA.encode_example(instance)
 
1
  import json
2
+ from typing import Any, Dict, List, Optional
3
 
4
+ from datasets import Audio, Features, Image, Sequence, Value
5
 
6
  from .artifact import Artifact
7
+ from .dict_utils import dict_get
8
  from .operator import InstanceOperatorValidator
9
+ from .settings_utils import get_constants
10
+
11
+ constants = get_constants()
12
 
13
  UNITXT_DATASET_SCHEMA = Features(
14
  {
 
16
  "target": Value("string"),
17
  "references": Sequence(Value("string")),
18
  "metrics": Sequence(Value("string")),
19
+ "groups": Sequence(Value("string")),
20
+ "subset": Sequence(Value("string")),
21
+ "media": {
22
+ "images": Sequence(Image()),
23
+ "audios": Sequence(Audio()),
24
+ },
25
+ "postprocessors": Sequence(Value("string")),
26
+ "task_data": Value(dtype="string"),
27
+ "data_classification_policy": Sequence(Value("string")),
28
+ }
29
+ )
30
+
31
+ UNITXT_INFERENCE_SCHEMA = Features(
32
+ {
33
+ "source": Value("string"),
34
+ "metrics": Sequence(Value("string")),
35
+ "groups": Sequence(Value("string")),
36
+ "subset": Sequence(Value("string")),
37
  "postprocessors": Sequence(Value("string")),
38
  "task_data": Value(dtype="string"),
39
  "data_classification_policy": Sequence(Value("string")),
 
41
  )
42
 
43
 
44
def get_schema(stream_name):
    """Return the feature schema matching the given stream name.

    The inference stream (``constants.inference_stream``) uses
    ``UNITXT_INFERENCE_SCHEMA``; every other stream name uses
    ``UNITXT_DATASET_SCHEMA``.
    """
    is_inference = stream_name == constants.inference_stream
    return UNITXT_INFERENCE_SCHEMA if is_inference else UNITXT_DATASET_SCHEMA
48
+
49
+
50
  class Finalize(InstanceOperatorValidator):
51
+ group_by: List[List[str]]
52
  remove_unnecessary_fields: bool = True
53
 
54
  @staticmethod
 
57
  return artifact.to_dict()
58
  return artifact.__id__
59
 
60
+ def _prepare_media(self, instance):
61
+ if "media" not in instance:
62
+ instance["media"] = {}
63
+
64
+ if "images" not in instance["media"]:
65
+ instance["media"]["images"] = []
66
+
67
+ if "audios" not in instance["media"]:
68
+ instance["media"]["audios"] = []
69
+
70
+ return instance
71
+
72
  def process(
73
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
74
  ) -> Dict[str, Any]:
75
+ metadata = {
76
+ "data_classification_policy": instance["data_classification_policy"],
77
+ "template": self.artifact_to_jsonable(
78
+ instance["recipe_metadata"]["template"]
79
+ ),
80
+ "num_demos": instance["recipe_metadata"]["num_demos"],
81
+ }
82
  task_data = {
83
  **instance["input_fields"],
84
+ "metadata": metadata,
 
 
 
 
 
 
 
85
  }
86
+
87
+ if stream_name != constants.inference_stream:
88
+ task_data = {**task_data, **instance["reference_fields"]}
89
+
90
  instance["task_data"] = json.dumps(task_data)
91
 
92
  if self.remove_unnecessary_fields:
93
  keys_to_delete = []
94
 
95
  for key in instance.keys():
96
+ if key not in get_schema(stream_name):
97
  keys_to_delete.append(key)
98
 
99
  for key in keys_to_delete:
100
  del instance[key]
101
+
102
+ data = {**task_data, **metadata}
103
+ groups = []
104
+ for group_attributes in self.group_by:
105
+ group = {}
106
+ if isinstance(group_attributes, str):
107
+ group_attributes = [group_attributes]
108
+ for attribute in group_attributes:
109
+ group[attribute] = dict_get(data, attribute)
110
+ groups.append(json.dumps(group))
111
+
112
+ instance["groups"] = groups
113
+ instance["subset"] = []
114
+
115
+ instance = self._prepare_media(instance)
116
+
117
  instance["metrics"] = [
118
  metric.to_json() if isinstance(metric, Artifact) else metric
119
  for metric in instance["metrics"]
 
122
  processor.to_json() if isinstance(processor, Artifact) else processor
123
  for processor in instance["postprocessors"]
124
  ]
125
+
126
  return instance
127
 
128
  def validate(self, instance: Dict[str, Any], stream_name: Optional[str] = None):
 
131
  assert isinstance(
132
  instance, dict
133
  ), f"Instance should be a dict, got {type(instance)}"
134
+ schema = get_schema(stream_name)
135
  assert all(
136
+ key in instance for key in schema
137
+ ), f"Instance should have the following keys: {schema}. Instance is: {instance}"
138
+ schema.encode_example(instance)
settings_utils.py CHANGED
@@ -1,5 +1,6 @@
1
  import importlib.util
2
  import os
 
3
 
4
  from .version import version
5
 
@@ -87,6 +88,17 @@ class Settings:
87
  self.environment_variable_key_name(key) for key in self._settings.keys()
88
  ]
89
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  class Constants:
92
  _instance = None
@@ -163,6 +175,8 @@ if Constants.is_uninitilized():
163
  ]
164
  constants.codebase_url = "https://github.com/IBM/unitxt"
165
  constants.website_url = "https://www.unitxt.org"
 
 
166
 
167
 
168
  def get_settings():
 
1
  import importlib.util
2
  import os
3
+ from contextlib import contextmanager
4
 
5
  from .version import version
6
 
 
88
  self.environment_variable_key_name(key) for key in self._settings.keys()
89
  ]
90
 
91
@contextmanager
def context(self, **kwargs):
    """Temporarily override settings for the duration of a ``with`` block.

    Each keyword argument is assigned through ``__setattr__`` on entry. On
    exit (including when the body raises) every overridden key is assigned
    back the value it had in ``_settings`` beforehand; a key that was not
    present before is assigned ``None``.
    """
    previous = {}
    for name in kwargs:
        previous[name] = self._settings.get(name, None)
    try:
        for name, override in kwargs.items():
            self.__setattr__(name, override)
        yield
    finally:
        for name, original in previous.items():
            self.__setattr__(name, original)
101
+
102
 
103
  class Constants:
104
  _instance = None
 
175
  ]
176
  constants.codebase_url = "https://github.com/IBM/unitxt"
177
  constants.website_url = "https://www.unitxt.org"
178
+ constants.inference_stream = "__INFERENCE_STREAM__"
179
+ constants.instance_stream = "__INSTANCE_STREAM__"
180
 
181
 
182
  def get_settings():
standard.py CHANGED
@@ -9,11 +9,14 @@ from .operator import SequentialOperator, SourceSequentialOperator, StreamingOpe
9
  from .operators import Augmentor, NullAugmentor, Set, StreamRefiner
10
  from .recipe import Recipe
11
  from .schema import Finalize
 
12
  from .splitters import ConstantSizeSample, RandomSizeSample, Sampler, SeparateSplit
13
  from .stream import MultiStream
14
  from .system_prompts import EmptySystemPrompt, SystemPrompt
 
15
  from .templates import ApplyRandomTemplate, ApplySingleTemplate, Template
16
 
 
17
  logger = get_logger()
18
 
19
 
@@ -24,7 +27,8 @@ class CreateDemosPool(SeparateSplit):
24
 
25
  class BaseRecipe(Recipe, SourceSequentialOperator):
26
  # Base parameters
27
- card: TaskCard
 
28
  template: Union[Template, List[Template]] = None
29
  system_prompt: SystemPrompt = Field(default_factory=EmptySystemPrompt)
30
  format: Format = Field(default_factory=SystemFormat)
@@ -34,6 +38,8 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
34
  metrics: List[str] = NonPositionalField(default=None)
35
  postprocessors: List[str] = NonPositionalField(default=None)
36
 
 
 
37
  loader_limit: int = None
38
 
39
  max_train_instances: int = None
@@ -68,6 +74,17 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
68
 
69
  def verify(self):
70
  super().verify()
 
 
 
 
 
 
 
 
 
 
 
71
  if self.use_demos:
72
  if self.demos_pool_size is None or self.demos_pool_size < 1:
73
  raise ValueError(
@@ -143,19 +160,18 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
143
  )
144
 
145
  def set_pipelines(self):
146
- self.loading = SequentialOperator()
147
- self.loading.__description__ = "Loading the data from the data source."
148
- self.metadata = SequentialOperator()
149
- self.metadata.__description__ = (
150
- "Adding metadata (e.g. format, system prompt, template) "
151
  )
152
- self.standardization = SequentialOperator()
153
- self.standardization.__description__ = (
154
- "Standardizing the raw dataset fields to task field definition."
155
  )
156
- self.processing = SequentialOperator()
157
- self.processing.__description__ = (
158
- "Setting task fields (and selecting demos per sample if needed)."
 
 
 
159
  )
160
  self.verbalization = SequentialOperator()
161
  self.verbalization.__description__ = "Verbalizing the input to the model and gold references to the 'source', 'target' and 'references' fields."
@@ -197,8 +213,8 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
197
  self._demos_pool_cache = None
198
 
199
  def production_preprocess(self, task_instances):
200
- ms = MultiStream.from_iterables({"__inference__": task_instances})
201
- return list(self.inference_instance(ms)["__inference__"])
202
 
203
  def production_demos_pool(self):
204
  if self.use_demos:
@@ -222,30 +238,34 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
222
  self.before_process_multi_stream()
223
  multi_stream = MultiStream.from_iterables(
224
  {
225
- "__inference__": self.production_preprocess(task_instances),
226
  self.demos_pool_name: self.production_demos_pool(),
227
  }
228
  )
229
  multi_stream = self.inference(multi_stream)
230
- return list(multi_stream["__inference__"])
231
 
232
- def prepare(self):
233
- # To avoid the Python's mutable default list trap, we set the default value to None
234
- # and then set it to an empty list if it is None.
235
- if self.card.preprocess_steps is None:
236
  self.card.preprocess_steps = []
237
 
238
- self.set_pipelines()
 
239
 
240
- loader = self.card.loader
241
- if self.loader_limit:
242
- loader.loader_limit = self.loader_limit
243
- logger.info(f"Loader line limit was set to {self.loader_limit}")
244
- self.loading.steps.append(loader)
245
 
246
- # This is required in case loader_limit is not enforced by the loader
247
- if self.loader_limit:
248
- self.loading.steps.append(StreamRefiner(max_instances=self.loader_limit))
 
 
 
 
 
 
 
 
 
249
 
250
  self.metadata.steps.append(
251
  Set(
@@ -256,9 +276,10 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
256
  )
257
  )
258
 
259
- self.standardization.steps.extend(self.card.preprocess_steps)
 
260
 
261
- self.processing.steps.append(self.card.task)
262
 
263
  if self.augmentor.augment_task_input:
264
  self.augmentor.set_task_input_fields(self.card.task.augmentable_inputs)
@@ -352,7 +373,10 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
352
  if self.metrics is not None:
353
  self.finalize.steps.append(Set(fields={"metrics": self.metrics}))
354
 
355
- self.finalize.steps.append(Finalize())
 
 
 
356
 
357
 
358
  class StandardRecipeWithIndexes(BaseRecipe):
@@ -395,6 +419,7 @@ class StandardRecipe(StandardRecipeWithIndexes):
395
  format (SystemFormat, optional): SystemFormat object to be used for the recipe.
396
  metrics (List[str]): list of catalog metrics to use with this recipe.
397
  postprocessors (List[str]): list of catalog processors to apply at post processing. (Not recommended to use from here)
 
398
  train_refiner (StreamRefiner, optional): Train refiner to be used in the recipe.
399
  max_train_instances (int, optional): Maximum training instances for the refiner.
400
  validation_refiner (StreamRefiner, optional): Validation refiner to be used in the recipe.
 
9
  from .operators import Augmentor, NullAugmentor, Set, StreamRefiner
10
  from .recipe import Recipe
11
  from .schema import Finalize
12
+ from .settings_utils import get_constants
13
  from .splitters import ConstantSizeSample, RandomSizeSample, Sampler, SeparateSplit
14
  from .stream import MultiStream
15
  from .system_prompts import EmptySystemPrompt, SystemPrompt
16
+ from .task import Task
17
  from .templates import ApplyRandomTemplate, ApplySingleTemplate, Template
18
 
19
+ constants = get_constants()
20
  logger = get_logger()
21
 
22
 
 
27
 
28
  class BaseRecipe(Recipe, SourceSequentialOperator):
29
  # Base parameters
30
+ card: TaskCard = None
31
+ task: Task = None
32
  template: Union[Template, List[Template]] = None
33
  system_prompt: SystemPrompt = Field(default_factory=EmptySystemPrompt)
34
  format: Format = Field(default_factory=SystemFormat)
 
38
  metrics: List[str] = NonPositionalField(default=None)
39
  postprocessors: List[str] = NonPositionalField(default=None)
40
 
41
+ group_by: List[Union[str, List[str]]] = []
42
+
43
  loader_limit: int = None
44
 
45
  max_train_instances: int = None
 
74
 
75
  def verify(self):
76
  super().verify()
77
+
78
+ if self.task is None and self.card is None:
79
+ raise ValueError("Set card or task in the recipe")
80
+
81
+ if self.card is None and (
82
+ self.num_demos > 0 or self.demos_pool_size is not None
83
+ ):
84
+ raise ValueError(
85
+ "To use num_demos and demos_pool_size in recipe set a card."
86
+ )
87
+
88
  if self.use_demos:
89
  if self.demos_pool_size is None or self.demos_pool_size < 1:
90
  raise ValueError(
 
160
  )
161
 
162
  def set_pipelines(self):
163
+ self.loading = SequentialOperator(
164
+ __description__="Loading the data from the data source."
 
 
 
165
  )
166
+ self.metadata = SequentialOperator(
167
+ __description__="Adding metadata (e.g. format, system prompt, template) "
 
168
  )
169
+ self.standardization = SequentialOperator(
170
+ __description__="Standardizing the raw dataset fields to task field definition."
171
+ )
172
+
173
+ self.processing = SequentialOperator(
174
+ __description__="Setting task fields (and selecting demos per sample if needed)."
175
  )
176
  self.verbalization = SequentialOperator()
177
  self.verbalization.__description__ = "Verbalizing the input to the model and gold references to the 'source', 'target' and 'references' fields."
 
213
  self._demos_pool_cache = None
214
 
215
  def production_preprocess(self, task_instances):
216
+ ms = MultiStream.from_iterables({constants.inference_stream: task_instances})
217
+ return list(self.inference_instance(ms)[constants.inference_stream])
218
 
219
  def production_demos_pool(self):
220
  if self.use_demos:
 
238
  self.before_process_multi_stream()
239
  multi_stream = MultiStream.from_iterables(
240
  {
241
+ constants.inference_stream: self.production_preprocess(task_instances),
242
  self.demos_pool_name: self.production_demos_pool(),
243
  }
244
  )
245
  multi_stream = self.inference(multi_stream)
246
+ return list(multi_stream[constants.inference_stream])
247
 
248
+ def reset_pipeline(self):
249
+ if self.card and self.card.preprocess_steps is None:
 
 
250
  self.card.preprocess_steps = []
251
 
252
+ if self.task is None:
253
+ self.task = self.card.task
254
 
255
+ self.set_pipelines()
 
 
 
 
256
 
257
+ if self.card is not None:
258
+ loader = self.card.loader
259
+ if self.loader_limit:
260
+ loader.loader_limit = self.loader_limit
261
+ logger.info(f"Loader line limit was set to {self.loader_limit}")
262
+ self.loading.steps.append(loader)
263
+
264
+ # This is required in case loader_limit is not enforced by the loader
265
+ if self.loader_limit:
266
+ self.loading.steps.append(
267
+ StreamRefiner(max_instances=self.loader_limit)
268
+ )
269
 
270
  self.metadata.steps.append(
271
  Set(
 
276
  )
277
  )
278
 
279
+ if self.card:
280
+ self.standardization.steps.extend(self.card.preprocess_steps)
281
 
282
+ self.processing.steps.append(self.task)
283
 
284
  if self.augmentor.augment_task_input:
285
  self.augmentor.set_task_input_fields(self.card.task.augmentable_inputs)
 
373
  if self.metrics is not None:
374
  self.finalize.steps.append(Set(fields={"metrics": self.metrics}))
375
 
376
+ self.finalize.steps.append(Finalize(group_by=self.group_by))
377
+
378
+ def prepare(self):
379
+ self.reset_pipeline()
380
 
381
 
382
  class StandardRecipeWithIndexes(BaseRecipe):
 
419
  format (SystemFormat, optional): SystemFormat object to be used for the recipe.
420
  metrics (List[str]): list of catalog metrics to use with this recipe.
421
  postprocessors (List[str]): list of catalog processors to apply at post processing. (Not recommended to use from here)
422
+ group_by (List[Union[str, List[str]]]): list of task_data or metadata keys to group global scores by.
423
  train_refiner (StreamRefiner, optional): Train refiner to be used in the recipe.
424
  max_train_instances (int, optional): Maximum training instances for the refiner.
425
  validation_refiner (StreamRefiner, optional): Validation refiner to be used in the recipe.
stream.py CHANGED
@@ -222,7 +222,9 @@ class MultiStream(dict):
222
  for stream in self.values():
223
  stream.set_copying(copying)
224
 
225
- def to_dataset(self, disable_cache=True, cache_dir=None) -> DatasetDict:
 
 
226
  with tempfile.TemporaryDirectory() as dir_to_be_deleted:
227
  cache_dir = dir_to_be_deleted if disable_cache else cache_dir
228
  return DatasetDict(
@@ -232,6 +234,7 @@ class MultiStream(dict):
232
  keep_in_memory=disable_cache,
233
  cache_dir=cache_dir,
234
  gen_kwargs={"key": key},
 
235
  )
236
  for key in self.keys()
237
  }
@@ -281,7 +284,10 @@ class MultiStream(dict):
281
 
282
  @classmethod
283
  def from_iterables(
284
- cls, iterables: Dict[str, Iterable], caching=False, copying=False
 
 
 
285
  ):
286
  """Creates a MultiStream from a dictionary of iterables.
287
 
 
222
  for stream in self.values():
223
  stream.set_copying(copying)
224
 
225
+ def to_dataset(
226
+ self, disable_cache=True, cache_dir=None, features=None
227
+ ) -> DatasetDict:
228
  with tempfile.TemporaryDirectory() as dir_to_be_deleted:
229
  cache_dir = dir_to_be_deleted if disable_cache else cache_dir
230
  return DatasetDict(
 
234
  keep_in_memory=disable_cache,
235
  cache_dir=cache_dir,
236
  gen_kwargs={"key": key},
237
+ features=features,
238
  )
239
  for key in self.keys()
240
  }
 
284
 
285
  @classmethod
286
  def from_iterables(
287
+ cls,
288
+ iterables: Dict[str, Iterable[Dict[str, Any]]],
289
+ caching=False,
290
+ copying=False,
291
  ):
292
  """Creates a MultiStream from a dictionary of iterables.
293
 
struct_data_operators.py CHANGED
@@ -623,3 +623,40 @@ class MapTableListsToStdTableJSON(FieldOperator):
623
 
624
  def map_tablelists_to_stdtablejson_util(self, table_content: str) -> Dict:
625
  return {"header": table_content[0], "rows": table_content[1:]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
623
 
624
  def map_tablelists_to_stdtablejson_util(self, table_content: str) -> Dict:
625
  return {"header": table_content[0], "rows": table_content[1:]}
626
+
627
+
628
+ class ConstructTableFromRowsCols(InstanceOperator):
629
+ """Maps column and row field into single table field encompassing both header and rows.
630
+
631
+ field[0] = header string as List
632
+ field[1] = rows string as List[List]
633
+ field[2] = table caption string(optional)
634
+ """
635
+
636
+ fields: List[str]
637
+ to_field: str
638
+
639
+ def process(
640
+ self, instance: Dict[str, Any], stream_name: Optional[str] = None
641
+ ) -> Dict[str, Any]:
642
+ header = dict_get(instance, self.fields[0])
643
+ rows = dict_get(instance, self.fields[1])
644
+
645
+ if len(self.fields) >= 3:
646
+ caption = instance[self.fields[2]]
647
+ else:
648
+ caption = None
649
+
650
+ import ast
651
+
652
+ header_processed = ast.literal_eval(header)
653
+ rows_processed = ast.literal_eval(rows)
654
+
655
+ output_dict = {"header": header_processed, "rows": rows_processed}
656
+
657
+ if caption is not None:
658
+ output_dict["caption"] = caption
659
+
660
+ instance[self.to_field] = output_dict
661
+
662
+ return instance
task.py CHANGED
@@ -1,11 +1,13 @@
 
1
  from functools import lru_cache
2
  from typing import Any, Dict, List, Optional, Union
3
 
4
  from .artifact import fetch_artifact
5
- from .dataclass import DeprecatedField
6
  from .deprecation_utils import deprecation
7
  from .error_utils import Documentation, UnitxtError, UnitxtWarning
 
8
  from .operator import InstanceOperator
 
9
  from .type_utils import (
10
  Type,
11
  get_args,
@@ -19,6 +21,9 @@ from .type_utils import (
19
  verify_required_schema,
20
  )
21
 
 
 
 
22
 
23
  @deprecation(
24
  version="2.0.0",
@@ -57,18 +62,8 @@ class Task(InstanceOperator):
57
 
58
  input_fields: Optional[Union[Dict[str, Type], Dict[str, str], List[str]]] = None
59
  reference_fields: Optional[Union[Dict[str, Type], Dict[str, str], List[str]]] = None
60
- inputs: Union[Dict[str, Type], Dict[str, str], List[str]] = DeprecatedField(
61
- default=None,
62
- metadata={
63
- "deprecation_msg": "The 'inputs' field is deprecated. Please use 'input_fields' instead."
64
- },
65
- )
66
- outputs: Union[Dict[str, Type], Dict[str, str], List[str]] = DeprecatedField(
67
- default=None,
68
- metadata={
69
- "deprecation_msg": "The 'outputs' field is deprecated. Please use 'reference_fields' instead."
70
- },
71
- )
72
  metrics: List[str]
73
  prediction_type: Optional[Union[Type, str]] = None
74
  augmentable_inputs: List[str] = []
@@ -108,6 +103,16 @@ class Task(InstanceOperator):
108
  )
109
 
110
  def verify(self):
 
 
 
 
 
 
 
 
 
 
111
  if self.input_fields is None:
112
  raise UnitxtError(
113
  "Missing attribute in task: 'input_fields' not set.",
@@ -249,19 +254,26 @@ class Task(InstanceOperator):
249
  instance = self.set_default_values(instance)
250
 
251
  verify_required_schema(self.input_fields, instance)
252
- verify_required_schema(self.reference_fields, instance)
253
-
254
  input_fields = {key: instance[key] for key in self.input_fields.keys()}
255
- reference_fields = {key: instance[key] for key in self.reference_fields.keys()}
256
  data_classification_policy = instance.get("data_classification_policy", [])
257
 
258
- return {
259
  "input_fields": input_fields,
260
- "reference_fields": reference_fields,
261
  "metrics": self.metrics,
262
  "data_classification_policy": data_classification_policy,
 
263
  }
264
 
 
 
 
 
 
 
 
 
 
 
265
 
266
  @deprecation(version="2.0.0", alternative=Task)
267
  class FormTask(Task):
 
1
+ import warnings
2
  from functools import lru_cache
3
  from typing import Any, Dict, List, Optional, Union
4
 
5
  from .artifact import fetch_artifact
 
6
  from .deprecation_utils import deprecation
7
  from .error_utils import Documentation, UnitxtError, UnitxtWarning
8
+ from .logging_utils import get_logger
9
  from .operator import InstanceOperator
10
+ from .settings_utils import get_constants
11
  from .type_utils import (
12
  Type,
13
  get_args,
 
21
  verify_required_schema,
22
  )
23
 
24
+ constants = get_constants()
25
+ logger = get_logger()
26
+
27
 
28
  @deprecation(
29
  version="2.0.0",
 
62
 
63
  input_fields: Optional[Union[Dict[str, Type], Dict[str, str], List[str]]] = None
64
  reference_fields: Optional[Union[Dict[str, Type], Dict[str, str], List[str]]] = None
65
+ inputs: Optional[Union[Dict[str, Type], Dict[str, str], List[str]]] = None
66
+ outputs: Optional[Union[Dict[str, Type], Dict[str, str], List[str]]] = None
 
 
 
 
 
 
 
 
 
 
67
  metrics: List[str]
68
  prediction_type: Optional[Union[Type, str]] = None
69
  augmentable_inputs: List[str] = []
 
103
  )
104
 
105
  def verify(self):
106
+ if hasattr(self, "inputs") and self.inputs is not None:
107
+ depr_message = (
108
+ "The 'inputs' field is deprecated. Please use 'input_fields' instead."
109
+ )
110
+ warnings.warn(depr_message, DeprecationWarning, stacklevel=2)
111
+
112
+ if hasattr(self, "outputs") and self.outputs is not None:
113
+ depr_message = "The 'outputs' field is deprecated. Please use 'reference_fields' instead."
114
+ warnings.warn(depr_message, DeprecationWarning, stacklevel=2)
115
+
116
  if self.input_fields is None:
117
  raise UnitxtError(
118
  "Missing attribute in task: 'input_fields' not set.",
 
254
  instance = self.set_default_values(instance)
255
 
256
  verify_required_schema(self.input_fields, instance)
 
 
257
  input_fields = {key: instance[key] for key in self.input_fields.keys()}
 
258
  data_classification_policy = instance.get("data_classification_policy", [])
259
 
260
+ result = {
261
  "input_fields": input_fields,
 
262
  "metrics": self.metrics,
263
  "data_classification_policy": data_classification_policy,
264
+ "media": instance.get("media", {}),
265
  }
266
 
267
+ if stream_name == constants.inference_stream:
268
+ return result
269
+
270
+ verify_required_schema(self.reference_fields, instance)
271
+ result["reference_fields"] = {
272
+ key: instance[key] for key in self.reference_fields.keys()
273
+ }
274
+
275
+ return result
276
+
277
 
278
  @deprecation(version="2.0.0", alternative=Task)
279
  class FormTask(Task):
templates.py CHANGED
@@ -10,8 +10,11 @@ from .dict_utils import dict_set
10
  from .error_utils import Documentation, UnitxtError
11
  from .operator import InstanceOperator
12
  from .random_utils import new_random_generator
 
13
  from .type_utils import isoftype
14
 
 
 
15
 
16
  class TemplateFormatKeyError(UnitxtError):
17
  def __init__(self, template, data, data_type, format_str, format_name):
@@ -84,20 +87,30 @@ class Template(InstanceOperator):
84
  instruction, target_prefix = self.input_fields_to_instruction_and_target_prefix(
85
  input_fields
86
  )
87
- target, references = self.reference_fields_to_target_and_references(
88
- reference_fields
89
- )
90
 
91
- return {
92
  **instance,
93
  "source": source,
94
- "target": target,
95
- "references": references,
96
  "instruction": instruction,
97
  "target_prefix": target_prefix,
98
  "postprocessors": self.postprocessors,
99
  }
100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  @abstractmethod
102
  def input_fields_to_source(self, input_fields: Dict[str, object]) -> str:
103
  pass
@@ -143,8 +156,13 @@ class ApplyTemplate(InstanceOperator):
143
  def get_template(self, instance: Dict[str, Any]) -> Template:
144
  pass
145
 
146
- def apply(self, template: Template, instance: Dict[str, Any]):
147
- return template.process_instance(instance)
 
 
 
 
 
148
 
149
  def process(
150
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
@@ -155,11 +173,11 @@ class ApplyTemplate(InstanceOperator):
155
  if self.demos_field not in instance:
156
  raise ValueError("Demos field is missing.")
157
  instance[self.demos_field] = [
158
- self.apply(template, demo_instance)
159
  for demo_instance in instance[self.demos_field]
160
  ]
161
  dict_set(instance, "recipe_metadata/template", template)
162
- return self.apply(template, instance)
163
 
164
 
165
  class ApplySingleTemplate(ApplyTemplate):
@@ -268,6 +286,9 @@ class PairwiseChoiceTemplate(InputOutputTemplate):
268
  choice_tie_label: str
269
  shuffle: bool
270
 
 
 
 
271
  def verbalize_answer_field(self, reference_fields: Dict[str, object]):
272
  answer = reference_fields[self.answer_field]
273
  assert answer in ["choice_a", "choice_b", "tie"]
@@ -552,34 +573,45 @@ class MultipleChoiceTemplate(Template):
552
 
553
  return target, [target]
554
 
555
- def _shuffle_choices(self, instance):
556
- target_index = self.outputs_to_target_index(instance["reference_fields"])
557
- original_label_choice = instance["reference_fields"][self.choices_field][
558
- target_index
559
- ]
 
560
  choices = instance["input_fields"][self.choices_field]
561
- random_generator = new_random_generator(
562
- {**instance["input_fields"], **instance["reference_fields"]}
563
- )
 
564
  random_generator.shuffle(choices)
565
  instance["input_fields"][self.choices_field] = choices
 
 
 
 
566
  instance["reference_fields"][self.choices_field] = choices
567
  instance["reference_fields"][self.target_field] = choices.index(
568
  original_label_choice
569
  )
 
570
  return instance
571
 
572
  def process(
573
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
574
  ) -> Dict[str, Any]:
575
  if self.shuffle_choices:
576
- instance = self._shuffle_choices(instance)
577
  result = super().process(instance, stream_name)
578
-
579
- if "options" not in result["reference_fields"]:
580
- result["reference_fields"]["options"] = self.inputs_to_choices(
581
- instance["reference_fields"], self.target_choice_format
582
  )
 
 
 
 
 
583
  return result
584
 
585
 
 
10
  from .error_utils import Documentation, UnitxtError
11
  from .operator import InstanceOperator
12
  from .random_utils import new_random_generator
13
+ from .settings_utils import get_constants
14
  from .type_utils import isoftype
15
 
16
+ constants = get_constants()
17
+
18
 
19
  class TemplateFormatKeyError(UnitxtError):
20
  def __init__(self, template, data, data_type, format_str, format_name):
 
87
  instruction, target_prefix = self.input_fields_to_instruction_and_target_prefix(
88
  input_fields
89
  )
 
 
 
90
 
91
+ result = {
92
  **instance,
93
  "source": source,
 
 
94
  "instruction": instruction,
95
  "target_prefix": target_prefix,
96
  "postprocessors": self.postprocessors,
97
  }
98
 
99
+ if stream_name == constants.inference_stream:
100
+ return result
101
+
102
+ if reference_fields is None:
103
+ raise ValueError("Should have reference_fields")
104
+
105
+ target, references = self.reference_fields_to_target_and_references(
106
+ reference_fields
107
+ )
108
+
109
+ result["target"] = target
110
+ result["references"] = references
111
+
112
+ return result
113
+
114
  @abstractmethod
115
  def input_fields_to_source(self, input_fields: Dict[str, object]) -> str:
116
  pass
 
156
  def get_template(self, instance: Dict[str, Any]) -> Template:
157
  pass
158
 
159
+ def apply(
160
+ self,
161
+ template: Template,
162
+ instance: Dict[str, Any],
163
+ stream_name: Optional[str] = None,
164
+ ):
165
+ return template.process_instance(instance, stream_name)
166
 
167
  def process(
168
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
 
173
  if self.demos_field not in instance:
174
  raise ValueError("Demos field is missing.")
175
  instance[self.demos_field] = [
176
+ self.apply(template, demo_instance, stream_name)
177
  for demo_instance in instance[self.demos_field]
178
  ]
179
  dict_set(instance, "recipe_metadata/template", template)
180
+ return self.apply(template, instance, stream_name)
181
 
182
 
183
  class ApplySingleTemplate(ApplyTemplate):
 
286
  choice_tie_label: str
287
  shuffle: bool
288
 
289
+ def verify(self):
290
+ super().verify()
291
+
292
  def verbalize_answer_field(self, reference_fields: Dict[str, object]):
293
  answer = reference_fields[self.answer_field]
294
  assert answer in ["choice_a", "choice_b", "tie"]
 
573
 
574
  return target, [target]
575
 
576
+ def _shuffle_choices(self, instance, stream_name):
577
+ if stream_name != constants.inference_stream:
578
+ target_index = self.outputs_to_target_index(instance["reference_fields"])
579
+ original_label_choice = instance["reference_fields"][self.choices_field][
580
+ target_index
581
+ ]
582
  choices = instance["input_fields"][self.choices_field]
583
+
584
+ random_seed = {**instance["input_fields"]}
585
+
586
+ random_generator = new_random_generator(random_seed)
587
  random_generator.shuffle(choices)
588
  instance["input_fields"][self.choices_field] = choices
589
+
590
+ if stream_name == constants.inference_stream:
591
+ return instance
592
+
593
  instance["reference_fields"][self.choices_field] = choices
594
  instance["reference_fields"][self.target_field] = choices.index(
595
  original_label_choice
596
  )
597
+
598
  return instance
599
 
600
  def process(
601
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
602
  ) -> Dict[str, Any]:
603
  if self.shuffle_choices:
604
+ instance = self._shuffle_choices(instance, stream_name)
605
  result = super().process(instance, stream_name)
606
+ if stream_name == constants.inference_stream:
607
+ result["input_fields"]["options"] = self.inputs_to_choices(
608
+ instance["input_fields"], self.target_choice_format
 
609
  )
610
+ else:
611
+ if "options" not in result["reference_fields"]:
612
+ result["reference_fields"]["options"] = self.inputs_to_choices(
613
+ instance["reference_fields"], self.target_choice_format
614
+ )
615
  return result
616
 
617
 
version.py CHANGED
@@ -1 +1 @@
1
- version = "1.12.3"
 
1
+ version = "1.12.4"