Elron committed
Commit
9d5b4c0
1 Parent(s): 24278cc

Upload folder using huggingface_hub

Files changed (21):
  1. artifact.py +1 -2
  2. blocks.py +1 -1
  3. collections_operators.py +6 -1
  4. dataset.py +1 -0
  5. error_utils.py +50 -0
  6. generator_utils.py +2 -2
  7. inference.py +44 -29
  8. loaders.py +1 -1
  9. metric.py +1 -0
  10. metric_utils.py +1 -1
  11. metrics.py +152 -44
  12. operators.py +1 -2
  13. schema.py +14 -11
  14. splitters.py +56 -47
  15. standard.py +114 -67
  16. stream.py +1 -1
  17. struct_data_operators.py +1 -1
  18. task.py +27 -15
  19. templates.py +76 -21
  20. utils.py +5 -0
  21. version.py +1 -1
artifact.py CHANGED
@@ -5,7 +5,6 @@ import os
 import pkgutil
 import re
 from abc import abstractmethod
-from copy import deepcopy
 from typing import Any, Dict, List, Optional, Tuple, Union, final
 
 from .dataclass import (
@@ -23,7 +22,7 @@ from .parsing_utils import (
 from .settings_utils import get_constants, get_settings
 from .text_utils import camel_to_snake_case, is_camel_case
 from .type_utils import issubtype
-from .utils import artifacts_json_cache, json_dump, save_to_file
+from .utils import artifacts_json_cache, deepcopy, json_dump, save_to_file
 
 logger = get_logger()
 settings = get_settings()
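
A pattern that repeats through this commit: every `from copy import deepcopy` is swapped for a `deepcopy` imported from the package's own utils module (extended by +5 lines here), so all deep-copying flows through one choke point. The diff does not show the body of utils.deepcopy; a minimal sketch of what such a wrapper could look like, purely as an assumption:

import copy

def deepcopy(obj):
    # Assumed implementation: delegate to the standard library.
    # Centralizing the call lets the project later swap in a different
    # copying strategy without touching the many call sites below.
    return copy.deepcopy(obj)
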
blocks.py CHANGED
@@ -18,7 +18,7 @@ from .operators import (
 )
 from .processors import ToString, ToStringStripped
 from .recipe import SequentialRecipe
-from .splitters import RandomSampler, SliceSplit, SplitRandomMix, SpreadSplit
+from .splitters import RandomSampler, Sample, SliceSplit, SplitRandomMix
 from .stream import MultiStream
 from .struct_data_operators import (
     ListToKeyValPairs,
collections_operators.py CHANGED
@@ -1,8 +1,8 @@
-from copy import deepcopy
 from typing import Any, Generator, List, Optional
 
 from .operators import FieldOperator, StreamOperator
 from .stream import Stream
+from .utils import deepcopy
 
 
 class Dictify(FieldOperator):
@@ -100,3 +100,8 @@ class DuplicateBySubLists(StreamOperator):
                 to_field: elements[:i],
             }
             yield instance_copy
+
+
+class GetLength(FieldOperator):
+    def process_value(self, collection: Any) -> Any:
+        return len(collection)
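
GetLength is a standard FieldOperator, so it should compose like any other field operator; a usage sketch (the field names below are made up for illustration):

from unitxt.collections_operators import GetLength

# Writes len(instance["choices"]) into instance["num_choices"]
# for every instance flowing through the stream.
get_length = GetLength(field="choices", to_field="num_choices")
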
dataset.py CHANGED
@@ -15,6 +15,7 @@ from .dataset_utils import get_dataset_artifact
 from .deprecation_utils import __file__ as _
 from .dialog_operators import __file__ as _
 from .dict_utils import __file__ as _
+from .error_utils import __file__ as _
 from .eval_utils import __file__ as _
 from .file_utils import __file__ as _
 from .formats import __file__ as _
error_utils.py ADDED
@@ -0,0 +1,50 @@
+from typing import Optional
+
+from .logging_utils import get_logger
+
+logger = get_logger()
+
+
+class Documentation:
+    URL = "https://www.unitxt.ai/en/latest/"
+    HUGGINGFACE_METRICS = "docs/adding_metric.html#adding-a-hugginface-metric"
+    ADDING_TASK = "docs/adding_task.html"
+    ADDING_TEMPLATE = "docs/adding_template.html"
+    MULTIPLE_METRICS_OUTPUTS = (
+        "docs/adding_metric.html#metric-outputs-with-multiple-metrics"
+    )
+
+
+def additional_info(path: str) -> str:
+    return f"\nFor more information: see {Documentation.URL}/{path} \n"
+
+
+class UnitxtError(Exception):
+    """Exception raised for Unitxt errors.
+
+    Attributes:
+        message : str -- explanation of the error
+        additional_info_id : Optional[str] -- relative path to additional documentation on web
+            If set, should be one of the DOCUMENATION_* constants in the error_utils.py file.
+
+    """
+
+    def __init__(self, message: str, additional_info_id: Optional[str] = None):
+        if additional_info_id is not None:
+            message += additional_info(additional_info_id)
+        super().__init__(message)
+
+
+class UnitxtWarning:
+    """Object to format warning message to log.
+
+    Attributes:
+        message -- explanation of the warning
+        additional_info_id : Optional[str] -- relative path to additional documentation on web
+            If set, should be one of the DOCUMENATION_* constants in the error_utils.py file.
+    """
+
+    def __init__(self, message: str, additional_info_id: Optional[str] = None):
+        if additional_info_id is not None:
+            message += additional_info(additional_info_id)
+        logger.warning(message)
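
A short sketch of how these two helpers are meant to be called (the messages are invented; only the class names and Documentation constants come from the file above):

from unitxt.error_utils import Documentation, UnitxtError, UnitxtWarning

# UnitxtWarning formats and logs immediately on construction:
UnitxtWarning(
    "Two metrics wrote the same score name.",
    additional_info_id=Documentation.MULTIPLE_METRICS_OUTPUTS,
)

# UnitxtError appends a documentation link to the message before raising:
raise UnitxtError(
    "This task is missing a required field.",
    additional_info_id=Documentation.ADDING_TASK,
)
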
generator_utils.py CHANGED
@@ -1,7 +1,7 @@
-import copy
 from typing import Any, Dict, List
 
 from .dataclass import Dataclass, OptionalField
+from .utils import deepcopy
 
 
 class ReusableGenerator(Dataclass):
@@ -22,7 +22,7 @@ class ReusableGenerator(Dataclass):
 class CopyingReusableGenerator(ReusableGenerator):
     def __iter__(self):
         for instance in self.activate():
-            yield copy.deepcopy(instance)
+            yield deepcopy(instance)
 
 
 # if __name__ == "__main__":
inference.py CHANGED
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Literal, Optional, Union
 from tqdm import tqdm
 
 from .artifact import Artifact
+from .dataclass import InternalField
 from .deprecation_utils import deprecation
 from .logging_utils import get_logger
 from .operator import PackageRequirementsMixin
@@ -376,13 +377,11 @@ class WMLInferenceEngine(
     """Runs inference using ibm-watsonx-ai.
 
     Attributes:
-        client: By default, it is created by a class instance but can be directly
-            provided instead as an instance of 'ibm_watsonx_ai.client.APIClient'.
-        credentials: By default, it is created by a class instance which tries to retrieve
-            proper environment variables ("WML_URL", "WML_PROJECT_ID", "WML_APIKEY").
-            However, either a dictionary with the following keys: "url", "apikey",
-            "project_id", or an instance of 'ibm_watsonx_ai.credentials.Credentials'
-            can be directly provided instead.
+        credentials (Dict[str, str], optional): By default, it is created by a class
+            instance which tries to retrieve proper environment variables
+            ("WML_URL", "WML_PROJECT_ID", "WML_APIKEY"). However, a dictionary with
+            the following keys: "url", "apikey", "project_id" can be directly provided
+            instead.
         model_name (str, optional): ID of a model to be used for inference. Mutually
             exclusive with 'deployment_id'.
         deployment_id (str, optional): Deployment ID of a tuned model to be used for
@@ -412,8 +411,7 @@ class WMLInferenceEngine(
         results = wml_inference.infer(dataset["test"])
     """
 
-    client: Any = None
-    credentials: Any = None
+    credentials: Optional[Dict[Literal["url", "apikey", "project_id"], str]] = None
     model_name: Optional[str] = None
     deployment_id: Optional[str] = None
    label: str = "wml"
@@ -422,11 +420,40 @@ class WMLInferenceEngine(
         "It is advised to have Python version >=3.10 installed, as at lower version this package "
         "may cause conflicts with other installed packages."
     }
-    data_classification_policy = ["proprietary"]
+    data_classification_policy = ["public", "proprietary"]
     parameters: Optional[WMLInferenceEngineParams] = None
 
+    _client: Any = InternalField(default=None, name="WML client")
+
+    def verify(self):
+        super().verify()
+
+        if self.credentials is not None:
+            for key in self.credentials:
+                if key not in ["url", "apikey", "project_id"]:
+                    raise ValueError(
+                        f'Illegal credential key: {key}, use only ["url", "apikey", "project_id"]'
+                    )
+
+        assert (
+            self.model_name
+            or self.deployment_id
+            and not (self.model_name and self.deployment_id)
+        ), "Either 'model_name' or 'deployment_id' must be specified, but not both at the same time."
+
+    def process_data_before_dump(self, data):
+        if "credentials" in data:
+            for key, value in data["credentials"].items():
+                if key != "url":
+                    data["credentials"][key] = "<hidden>"
+                else:
+                    data["credentials"][key] = value
+        return data
+
     @staticmethod
-    def _read_wml_credentials_from_env() -> Dict[str, str]:
+    def _read_wml_credentials_from_env() -> (
+        Dict[Literal["url", "apikey", "project_id"], str]
+    ):
         credentials = {}
         for env_var_name in ["WML_URL", "WML_PROJECT_ID", "WML_APIKEY"]:
             env_var = os.environ.get(env_var_name)
@@ -453,32 +480,20 @@ class WMLInferenceEngine(
         return client
 
     def prepare(self):
-        if self.client is None:
-            self.client = self._initialize_wml_client()
+        self._client = self._initialize_wml_client()
 
         self._set_inference_parameters()
 
-    def verify(self):
-        assert (
-            self.model_name
-            or self.deployment_id
-            and not (self.model_name and self.deployment_id)
-        ), "Either 'model_name' or 'deployment_id' must be specified, but not both at the same time."
-        super().verify()
-
     def _infer(self, dataset):
         from ibm_watsonx_ai.foundation_models import ModelInference
 
         model = ModelInference(
             model_id=self.model_name,
             deployment_id=self.deployment_id,
-            api_client=self.client,
+            api_client=self._client,
         )
 
-        return [
-            model.generate_text(
-                prompt=instance["source"],
-                params=self.to_dict([WMLInferenceEngineParamsMixin], keep_empty=False),
-            )
-            for instance in dataset
-        ]
+        return model.generate_text(
+            prompt=dataset["source"],
+            params=self.to_dict([WMLInferenceEngineParamsMixin], keep_empty=False),
+        )
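
Putting the new credentials flow together, constructing the engine might look roughly like this (the credential values and model id are placeholders, not from the diff):

from unitxt.inference import WMLInferenceEngine

wml_inference = WMLInferenceEngine(
    # Omit credentials to fall back on the WML_URL / WML_PROJECT_ID /
    # WML_APIKEY environment variables; any key other than "url",
    # "apikey", "project_id" now raises ValueError in verify().
    credentials={"url": "...", "apikey": "...", "project_id": "..."},
    model_name="some-model-id",  # mutually exclusive with deployment_id
)

Note also that process_data_before_dump() masks every credential except "url" as "<hidden>", so API keys no longer leak when the engine is serialized.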
 
 
 
loaders.py CHANGED
@@ -36,7 +36,6 @@ import itertools
 import os
 import tempfile
 from abc import abstractmethod
-from copy import deepcopy
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
@@ -54,6 +53,7 @@ from .operators import Set
 from .settings_utils import get_settings
 from .stream import DynamicStream, MultiStream
 from .type_utils import isoftype
+from .utils import deepcopy
 
 logger = get_logger()
 settings = get_settings()
metric.py CHANGED
@@ -14,6 +14,7 @@ from .dataset_utils import __file__ as _
 from .deprecation_utils import __file__ as _
 from .dialog_operators import __file__ as _
 from .dict_utils import __file__ as _
+from .error_utils import __file__ as _
 from .eval_utils import __file__ as _
 from .file_utils import __file__ as _
 from .formats import __file__ as _
metric_utils.py CHANGED
@@ -1,5 +1,4 @@
 import json
-from copy import deepcopy
 from typing import Any, Dict, Generator, Iterable, List, Optional
 
 from datasets import Features, Value
@@ -27,6 +26,7 @@ from .schema import UNITXT_DATASET_SCHEMA
 from .settings_utils import get_settings
 from .stream import DynamicStream, MultiStream
 from .struct_data_operators import LoadJson
+from .utils import deepcopy
 
 
 class MultiStreamScoreMean(MultiStreamOperator):
metrics.py CHANGED
@@ -1,15 +1,14 @@
 import ast
 import json
+import os
 import re
 import string
 import uuid
 import warnings
 from abc import ABC, abstractmethod
 from collections import Counter, defaultdict
-from copy import deepcopy
 from dataclasses import field
 from operator import itemgetter
-from statistics import mean
 from typing import Any, Dict, Generator, List, Optional, Tuple, Union
 
 import evaluate
@@ -22,11 +21,13 @@ from scipy.stats._warnings_errors import DegenerateDataWarning
 from .artifact import Artifact, fetch_artifact
 from .dataclass import (
     AbstractField,
+    DeprecatedField,
     InternalField,
     NonPositionalField,
     OptionalField,
 )
 from .deprecation_utils import deprecation
+from .error_utils import Documentation, UnitxtWarning
 from .inference import HFPipelineBasedInferenceEngine, InferenceEngine
 from .logging_utils import get_logger
 from .metric_utils import InstanceInput, MetricRequest, MetricResponse
@@ -42,6 +43,7 @@ from .random_utils import get_seed
 from .settings_utils import get_settings
 from .stream import MultiStream, Stream
 from .type_utils import Type, isoftype, parse_type_string, to_type_string
+from .utils import deepcopy
 
 logger = get_logger()
 settings = get_settings()
@@ -141,13 +143,25 @@ class Metric(Artifact):
             else score_name
         )
 
-    def _add_score_prefixes_to_score_dict(self, scores: Dict[str, Any]):
+    def _add_score_prefixes_to_score_dict_and_check_against_existing_scores(
+        self, scores: Dict[str, Any], existing_scores: Dict[str, Any]
+    ) -> Dict[str, Any]:
         new_scores = {}
         for score_name, score in scores.items():
             score_with_prefix = self._add_score_prefix(score_name)
             new_scores[score_with_prefix] = (
                 score if score_name not in ["score_name"] else self.score_prefix + score
             )
+        for new_score_name in new_scores:
+            if new_score_name in ["score", "score_name"]:
+                continue
+            if new_score_name in existing_scores:
+                UnitxtWarning(
+                    message=f"Metric '{new_score_name}' that has just been evaluated to {new_scores[new_score_name]}, is already recorded "
+                    f"to have value {existing_scores[new_score_name]} by a previous metric evaluation on this instance or stream. "
+                    f"To avoid overwriting the existing value, add a score_prefix to the metric (e.g. score_prefix='my_second_').",
+                    additional_info_id=Documentation.MULTIPLE_METRICS_OUTPUTS,
+                )
         return new_scores
 
     def _validate_references_and_prediction(self, references, predictions):
@@ -238,12 +252,14 @@ class Metric(Artifact):
     def disable_confidence_interval_calculation(self):
         pass
 
-    # update instance["score"]["global"] with the newly computed global score, global_score, for the
-    # current metric computed. global_score contains "score" and "score_name" fields that reflect
-    # (the main_score of) the current metric.
+    # update instance["score"]["global"] with the global_score just computed for the
+    # current metric. global_score contains "score" and "score_name" fields that reflect
+    # (the main_score of) the current metric. If CI was computed for global_score, then global_score
+    # also contains "score_ci_low" and "score_ci_high" that reflect (the main_score of) the current metric.
     # A simple python-dictionary-update adds new fields to instance["score"]["global"], and also replaces the values
-    # of its fields "score" and "score_name", to reflect the current metric, overwriting previous metrics' settings
-    # of these fields (if any previous metric exists).
+    # of its fields "score" and "score_name" (and "score_ci_low", "score_ci_high" if applicable),
+    # to reflect the current metric, overwriting previous metrics' settings of these fields
+    # (if any previous metric exists).
     # When global_score does NOT contain ci score (because CI was not computed for the current metric), but
     # one of the previous metrics computed did have, the last of such previous metrics set the values in
     # fields "score_ci_low" and "score_ci_high" in instance["score"]["global"] to reflect its
@@ -254,15 +270,25 @@ class Metric(Artifact):
     # therefore, not consistent with "score_name".
     # In such a case, following the python-dictionary-update, we pop out fields "score_ci_low" and
     # "score_ci_high" from instance["score"]["global"], so that now all the fields "score.." in
-    # instance["score"]["global"] are consistent with the current metric: The current metric
-    # is named instance["score"]["global"]["score_name"], its score shows in
+    # instance["score"]["global"] are consistent with the current metric: The metric that is named
+    # instance["score"]["global"]["score_name"], its score shows in
     # field instance["score"]["global"]["score"], and it does not have ci_scores,
     # which is also reflected in the absence of fields "score_ci_low" and "score_ci_high" from instance["score"]["global"].
     # If ci IS computed for the current metric, global_score contains "score_ci_low" and "score_ci_high", and these overwrite
-    # the ones existing in instance["score"]["global"] by a simple python-dictionary-update, and no need for any further fixeup.
+    # the ones existing in instance["score"]["global"] by the simple python-dictionary-update, and no need for any further fixeup.
     def update_and_adjust_global_score(
         self, instance: Dict[str, Any], global_score: dict
     ):
+        for score_name in global_score:
+            if score_name in ["score", "score_name", "score_ci_low", "score_ci_high"]:
+                continue
+            if score_name in instance["score"]["global"]:
+                UnitxtWarning(
+                    message=f"Global metric '{score_name}' that has just been evaluated to {global_score[score_name]}, is already recorded "
+                    f"to have value {instance['score']['global'][score_name]} by a previous metric evaluation on this stream. "
+                    f"To avoid overwriting the value, add a score_prefix to the metric (e.g. score_prefix='my_{score_name}'.",
+                    additional_info_id=Documentation.MULTIPLE_METRICS_OUTPUTS,
+                )
         instance["score"]["global"].update(global_score)
         for score_ci in ["score_ci_low", "score_ci_high"]:
             if score_ci in global_score:
@@ -559,12 +585,18 @@ class GlobalMetric(StreamOperator, MetricWithConfidenceInterval):
             instance_score[self.main_score] = no_score_value
 
         instance["score"]["instance"].update(
-            self._add_score_prefixes_to_score_dict(instance_score)
+            self._add_score_prefixes_to_score_dict_and_check_against_existing_scores(
+                instance_score, instance["score"]["instance"]
+            )
         )
         self._validate_references_and_prediction(references, predictions)
 
         result = self._compute(references, predictions, task_data)
-        global_score.update(self._add_score_prefixes_to_score_dict(result))
+        global_score.update(
+            self._add_score_prefixes_to_score_dict_and_check_against_existing_scores(
+                result, global_score
+            )
+        )
         score_name = global_score["score_name"]
         confidence_interval = self.compute_global_confidence_intervals(
             references, predictions, task_data, score_name
@@ -657,7 +689,9 @@ class BulkInstanceMetric(StreamOperator, MetricWithConfidenceInterval):
             instance["score"] = {"global": {}, "instance": {}}
 
             instance["score"]["instance"].update(
-                self._add_score_prefixes_to_score_dict(score)
+                self._add_score_prefixes_to_score_dict_and_check_against_existing_scores(
+                    score, instance["score"]["instance"]
+                )
             )
             instances.append(instance)
 
@@ -669,7 +703,7 @@ class BulkInstanceMetric(StreamOperator, MetricWithConfidenceInterval):
         if reduction == "mean":
             for field_name in fields:
                 field_name_with_prefix = self._add_score_prefix(field_name)
-                global_score[field_name_with_prefix] = mean(
+                global_score[field_name_with_prefix] = nan_mean(
                     [
                         instance["score"]["instance"][field_name_with_prefix]
                         for instance in instances
@@ -1140,7 +1174,9 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
             instance["score"] = {"global": {}, "instance": {}}
 
             instance["score"]["instance"].update(
-                self._add_score_prefixes_to_score_dict(instance_score)
+                self._add_score_prefixes_to_score_dict_and_check_against_existing_scores(
+                    instance_score, instance["score"]["instance"]
+                )
            )
 
             instances.append(instance)
@@ -1326,7 +1362,6 @@ class StringContainment(InstanceMetric):
     ci_scores = ["string_containment"]
 
     prediction_type = Any  # string representation is compared
-    single_reference_per_prediction = False  # multiple references allowed
 
     def compute(
         self, references: List[Any], prediction: Any, task_data: List[Dict]
@@ -1341,11 +1376,59 @@ class StringContainment(InstanceMetric):
         return result
 
 
+class StringContainmentRatio(InstanceMetric):
+    """Metric that returns the ratio of values from a specific field contained in the prediction.
+
+    Attributes:
+        field: The field from the task_data that contains the values to be checked for containment.
+    Example task:
+        Task(
+            input_fields={"question": str},
+            reference_fields={"entities": str},
+            prediction_type=str,
+            metrics=["string_containment_ratio[field=entities]"],
+        )
+    """
+
+    reduction_map = {"mean": ["string_containment"]}
+    main_score = "string_containment"
+    ci_scores = ["string_containment"]
+    field: str = None
+
+    prediction_type = Any  # string representation is compared
+
+    def compute(
+        self, references: List[Any], prediction: Any, task_data: List[Dict]
+    ) -> dict:
+        if self.field not in task_data:
+            raise ValueError(
+                f"'{self.field}' field required by {__class__.__name__} is not in passed in task_data: {task_data}"
+            )
+        contain_results = [
+            str(value) in str(prediction) for value in task_data[self.field]
+        ]
+        score = sum(contain_results) / len(contain_results)
+        result = {self.main_score: score}
+        result["score"] = result[self.main_score]
+        result["score_name"] = self.main_score
+        return result
+
+    def verify(self):
+        super().verify()
+        if self.field is None:
+            raise ValueError(
+                "StringContainmentRatio metric requires the 'field' attribute to be set."
+            )
+
+
 class MetricPipeline(MultiStreamOperator, Metric):
     main_score: str = None
     preprocess_steps: Optional[List[StreamingOperator]] = field(default_factory=list)
-    postpreprocess_steps: Optional[List[StreamingOperator]] = field(
-        default_factory=list
+    postprocess_steps: Optional[List[StreamingOperator]] = field(default_factory=list)
+    postpreprocess_steps: Optional[List[StreamingOperator]] = DeprecatedField(
+        metadata={
+            "deprecation_msg": "Field 'postpreprocess_steps' is deprecated. Please use 'postprocess_steps' for the same purpose."
+        }
     )
     metric: Metric = None
 
@@ -1366,6 +1449,23 @@ class MetricPipeline(MultiStreamOperator, Metric):
 
     def prepare(self):
         super().prepare()
+        has_postpreprocess = (
+            hasattr(self, "postpreprocess_steps")
+            and self.postpreprocess_steps is not None
+            and isinstance(self.postpreprocess_steps, list)
+            and len(self.postpreprocess_steps) > 0
+        )
+        has_postprocess = (
+            hasattr(self, "postprocess_steps")
+            and self.postprocess_steps is not None
+            and isinstance(self.postprocess_steps, list)
+            and len(self.postprocess_steps) > 0
+        )
+        assert not (
+            has_postpreprocess and has_postprocess
+        ), "Must define at most one of postpreprocess_steps (which is deprecated) and postprocess_steps (to be used from now on)"
+        if has_postpreprocess:
+            self.postprocess_steps = self.postpreprocess_steps
         self.prepare_score = Copy(
             field_to_field=[
                 [
@@ -1383,7 +1483,7 @@ class MetricPipeline(MultiStreamOperator, Metric):
         for step in self.preprocess_steps:
             multi_stream = step(multi_stream)
         multi_stream = self.metric(multi_stream)
-        for step in self.postpreprocess_steps:
+        for step in self.postprocess_steps:
             multi_stream = step(multi_stream)
         return self.prepare_score(multi_stream)
 
@@ -1409,6 +1509,13 @@ class HuggingfaceMetric(GlobalMetric):
     experiment_id: str = OptionalField(default_factory=lambda: str(uuid.uuid4()))
 
     def verify(self):
+        if os.path.exists(self.hf_metric_name):
+            UnitxtWarning(
+                f"{self.get_metric_name()} uses a huggingface metric {self.hf_metric_name} which is defined in a local file."
+                f"This may cause issues when running on different machine or different root directories.",
+                Documentation.HUGGINGFACE_METRICS,
+            )
+
         assert (
             self.hf_additional_input_fields is None
             or isoftype(self.hf_additional_input_fields, List[str])
@@ -1654,7 +1761,7 @@ class F1(GlobalMetric):
             average=self.average,
         )
         if isinstance(result[self.metric], numpy.ndarray):
-            final_result = {self.main_score: mean(result[self.metric])}
+            final_result = {self.main_score: nan_mean(result[self.metric])}
             for i, label in enumerate(labels):
                 final_result[f"{self.metric}_" + self.id_to_str[label]] = result[
                     self.metric
@@ -1959,7 +2066,7 @@ class F1MultiLabel(GlobalMetric):
            assert (
                 len(result[self.metric]) == len(labels)
             ), f"F1 result ({result[self.metric]}) has more entries than labels ({labels})"
-            final_result = {self.main_score: mean(result[self.metric])}
+            final_result = {self.main_score: nan_mean(result[self.metric])}
             for i, label in enumerate(labels):
                 final_result[self.metric + "_" + label] = result[self.metric][i]
         else:
@@ -2001,7 +2108,17 @@ class F1MacroMultiLabel(F1MultiLabel):
     average = None
 
 
-class Rouge(InstanceMetric):
+class NLTKMixin(Artifact):
+    def prepare(self):
+        super().prepare()
+        import nltk
+
+        nltk.download("punkt", quiet=True)
+        nltk.download("punkt_tab", quiet=True)
+        self.nltk = nltk
+
+
+class Rouge(InstanceMetric, NLTKMixin):
     main_score = "rougeL"
     prediction_type = str
     single_reference_per_prediction = False  # multiple references allowed
@@ -2014,21 +2131,17 @@ class Rouge(InstanceMetric):
 
     def prepare(self):
         super().prepare()
-        import nltk
         from rouge_score import rouge_scorer
 
         self.rouge_scorer = rouge_scorer
 
-        nltk.download("punkt", quiet=True)
-        self.sent_tokenize = nltk.sent_tokenize
-
     def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> dict:
         # for a single instance, prediction is of type str, and references: list of str
         if self.sent_split_newline:
-            prediction = "\n".join(self.sent_tokenize(prediction.strip()))
+            prediction = "\n".join(self.nltk.sent_tokenize(prediction.strip()))
 
             references = [
-                "\n".join(self.sent_tokenize(reference.strip()))
+                "\n".join(self.nltk.sent_tokenize(reference.strip()))
                 for reference in references
             ]
 
@@ -2044,7 +2157,7 @@ class Rouge(InstanceMetric):
         return score
 
 
-class RougeHF(HuggingfaceInstanceMetric):
+class RougeHF(HuggingfaceInstanceMetric, NLTKMixin):
     hf_metric_name = "rouge"
     main_score = "rougeL"
     scale = 1.0
@@ -2070,18 +2183,13 @@ class RougeHF(HuggingfaceInstanceMetric):
             {"use_aggregator": False, "rouge_types": self.rouge_types}
         )
 
-        import nltk
-
-        nltk.download("punkt", quiet=True)
-        self.sent_tokenize = nltk.sent_tokenize
-
     def compute(self, references, prediction, task_data: List[Dict]):
         # for a single instance, prediction is of type str, and references: list of str
         if self.sent_split_newline:
-            prediction = "\n".join(self.sent_tokenize(prediction.strip()))
+            prediction = "\n".join(self.nltk.sent_tokenize(prediction.strip()))
 
             references = [
-                "\n".join(self.sent_tokenize(reference.strip()))
+                "\n".join(self.nltk.sent_tokenize(reference.strip()))
                 for reference in references
             ]
 
@@ -3360,7 +3468,7 @@ class NDCG(GlobalMetric):
                 for pred in q_predictions
             ]
             scores.append(self.eval([q_references], [q_predictions]))
-        return {self.main_score: mean(scores) if len(scores) > 0 else np.nan}
+        return {self.main_score: nan_mean(scores) if len(scores) > 0 else np.nan}
 
 
 class RetrievalMetric(InstanceMetric):
@@ -3695,8 +3803,8 @@ def performance_drop_rate(
     if any(len(scores) == 0 for scores in group_scores_list):
         # no comparison can be made since there is not at least one score per type
         return np.nan
-    control_mean = mean(group_scores_list[0])
-    comparison_mean = mean(group_scores_list[1])
+    control_mean = nan_mean(group_scores_list[0])
+    comparison_mean = nan_mean(group_scores_list[1])
     if control_mean == 0:
         # return 0 if comparison is also 0
         if comparison_mean == 0:
@@ -3809,8 +3917,8 @@ def normalized_cohens_h(
         # no comparison can be made since there is not at least one score per type
         h, norm_h = np.nan, np.nan
     else:
-        control_mean = mean(group_scores_list[0])
-        comparison_mean = mean(group_scores_list[1])
+        control_mean = nan_mean(group_scores_list[0])
+        comparison_mean = nan_mean(group_scores_list[1])
         h = 2 * (np.arcsin(np.sqrt(comparison_mean)) - np.arcsin(np.sqrt(control_mean)))
         norm_h = np.clip(a=h / np.pi, a_min=-1, a_max=1)
 
@@ -3863,7 +3971,7 @@ def normalized_hedges_g(
         g, norm_g = np.nan, np.nan
     else:
         # otherwise, calculate the variances
-        group_mean = [mean(scores) for scores in group_scores_list]
+        group_mean = [nan_mean(scores) for scores in group_scores_list]
         # sample variance with 1 degree of freedom (denominator n-1); if n=1, return 0 since otherwise throws an error
         group_var = [
             0.0 if nn == 1 else np.var(scores, ddof=1)
@@ -3922,7 +4030,7 @@ def mean_subgroup_score(
     if len(score_list) == 0:
         # no scores to use
         return np.nan
-    return mean(score_list)
+    return nan_mean(score_list)
 
 
 # metrics using mean reduction
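
Two of these changes touch user-facing configuration. First, a UnitxtWarning now fires when a metric writes a score name that an earlier metric already recorded on the same instance or stream; the suggested remedy is a score prefix, roughly like this (the choice of F1Micro is illustrative):

from unitxt.metrics import F1Micro

# The second metric keeps its scores distinct via score_prefix,
# so it no longer overwrites the first one's scores silently.
first = F1Micro()
second = F1Micro(score_prefix="my_second_")

Second, MetricPipeline's postpreprocess_steps is deprecated in favor of postprocess_steps: prepare() asserts that at most one of the two is set and copies the deprecated field over, so existing assets keep working during the transition.
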
operators.py CHANGED
@@ -45,7 +45,6 @@ import uuid
 import zipfile
 from abc import abstractmethod
 from collections import Counter, defaultdict
-from copy import deepcopy
 from dataclasses import field
 from itertools import zip_longest
 from random import Random
@@ -86,7 +85,7 @@ from .settings_utils import get_settings
 from .stream import DynamicStream, Stream
 from .text_utils import nested_tuple_to_string
 from .type_utils import isoftype
-from .utils import flatten_dict
+from .utils import deepcopy, flatten_dict
 
 settings = get_settings()
schema.py CHANGED
@@ -1,9 +1,9 @@
 import json
-from dataclasses import field
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Optional
 
 from datasets import Features, Sequence, Value
 
+from .artifact import Artifact
 from .operator import InstanceOperatorValidator
 
 UNITXT_DATASET_SCHEMA = Features(
@@ -20,10 +20,7 @@ UNITXT_DATASET_SCHEMA = Features(
 )
 
 
-class ToUnitxtGroup(InstanceOperatorValidator):
-    group: str
-    metrics: List[str] = None
-    postprocessors: List[str] = field(default_factory=lambda: ["to_string_stripped"])
+class Finalize(InstanceOperatorValidator):
     remove_unnecessary_fields: bool = True
 
     @staticmethod
@@ -43,6 +40,7 @@ class ToUnitxtGroup(InstanceOperatorValidator):
                 "template": self.artifact_to_jsonable(
                     instance["recipe_metadata"]["template"]
                 ),
+                "num_demos": instance["recipe_metadata"]["num_demos"],
             },
         }
         instance["task_data"] = json.dumps(task_data)
@@ -56,11 +54,16 @@ class ToUnitxtGroup(InstanceOperatorValidator):
 
         for key in keys_to_delete:
             del instance[key]
-        instance["group"] = self.group
-        if self.metrics is not None:
-            instance["metrics"] = self.metrics
-        if self.postprocessors is not None:
-            instance["postprocessors"] = self.postprocessors
+        if "group" not in instance:
+            instance["group"] = "unitxt"
+        instance["metrics"] = [
+            metric.to_json() if isinstance(metric, Artifact) else metric
+            for metric in instance["metrics"]
+        ]
+        instance["postprocessors"] = [
+            processor.to_json() if isinstance(processor, Artifact) else processor
+            for processor in instance["postprocessors"]
+        ]
         return instance
 
     def validate(self, instance: Dict[str, Any], stream_name: Optional[str] = None):
splitters.py CHANGED
@@ -1,6 +1,5 @@
1
  import itertools
2
  from abc import abstractmethod
3
- from copy import deepcopy
4
  from difflib import get_close_matches
5
  from typing import Dict, List, Optional
6
 
@@ -17,6 +16,7 @@ from .split_utils import (
17
  )
18
  from .stream import EmptyStreamError, FaultyStreamError, MultiStream
19
  from .type_utils import isoftype
 
20
 
21
 
22
  class Splitter(MultiStreamOperator):
@@ -109,36 +109,25 @@ class SliceSplit(Splitter):
109
  return MultiStream.from_generators(generators)
110
 
111
 
112
- class Sampler(Artifact):
113
- sample_size: int = None
114
-
115
- def prepare(self):
116
- super().prepare()
117
- self.set_size(self.sample_size)
118
 
119
- def set_size(self, size):
120
- if isinstance(size, str):
121
- assert (
122
- size.isdigit()
123
- ), f"sample_size must be a natural number, got {self.sample_size}"
124
- size = int(size)
125
- self.sample_size = size
126
 
 
127
  @abstractmethod
128
  def sample(
129
- self, instances_pool: List[Dict[str, object]], instance: Dict[str, object]
 
 
 
130
  ) -> List[Dict[str, object]]:
131
  pass
132
 
133
- def get_random_generator_based_on_instance(self, instance):
134
- return new_random_generator(sub_seed={**instance["input_fields"]})
135
-
136
  def filter_source_by_instance(
137
  self, instances_pool: List[Dict[str, object]], instance: Dict[str, object]
138
  ) -> List[Dict[str, object]]:
139
  if "input_fields" not in instance:
140
  raise ValueError(f"'input_fields' field is missing from '{instance}'.")
141
- # l = list(filter(lambda x: x["inputs"] != instance["inputs"], instances_pool))
142
  try:
143
  return [
144
  item
@@ -154,12 +143,13 @@ class RandomSampler(Sampler):
154
 
155
  def sample(
156
  self,
 
157
  instances_pool: List[Dict[str, object]],
158
  instance: Optional[Dict[str, object]],
159
  ) -> List[Dict[str, object]]:
160
  instances_pool = list(instances_pool)
161
- random_generator = self.get_random_generator_based_on_instance(instance)
162
- return random_generator.sample(instances_pool, self.sample_size)
163
 
164
 
165
  class FixedIndicesSampler(Sampler):
@@ -175,13 +165,14 @@ class FixedIndicesSampler(Sampler):
175
 
176
  def sample(
177
  self,
 
178
  instances_pool: List[Dict[str, object]],
179
  instance: Optional[Dict[str, object]],
180
  ) -> List[Dict[str, object]]:
181
  num_instances = len(instances_pool)
182
 
183
  instances = []
184
- for index in self.indices[0 : self.sample_size]:
185
  if index >= num_instances:
186
  raise ValueError(
187
  f"FixedIndicesSampler 'indices' field contains index ({index}) which is out of bounds of the instance pool ( of size {num_instances})"
@@ -200,7 +191,10 @@ class CloseTextSampler(Sampler):
200
  field: str
201
 
202
  def sample(
203
- self, instances_pool: List[Dict[str, object]], instance: Dict[str, object]
 
 
 
204
  ) -> List[Dict[str, object]]:
205
  field = f"input_fields/{self.field}"
206
  value = dict_get(instance, field)
@@ -211,9 +205,7 @@ class CloseTextSampler(Sampler):
211
  options = []
212
  for instance_in_pool in instances_pool:
213
  options.append(dict_get(instance_in_pool, field))
214
- closest_matches = get_close_matches(
215
- value, options, n=self.sample_size, cutoff=0
216
- )
217
  # Randmly select 'sample_size' instances that are from the closest matches text
218
  # (There may be multiple instance with same text in the given field, and the order returned is
219
  # is also randomized )
@@ -222,8 +214,8 @@ class CloseTextSampler(Sampler):
222
  for instance_in_pool in instances_pool
223
  if dict_get(instance_in_pool, field) in closest_matches
224
  ]
225
- random_generator = self.get_random_generator_based_on_instance(instance)
226
- return random_generator.sample(instances_pool, self.sample_size)
227
 
228
 
229
  class DiverseLabelsSampler(Sampler):
@@ -306,26 +298,27 @@ class DiverseLabelsSampler(Sampler):
306
 
307
  def sample(
308
  self,
 
309
  instances_pool: List[Dict[str, object]],
310
  instance: Optional[Dict[str, object]],
311
  ) -> List[Dict[str, object]]:
312
  if self.labels_cache is None:
313
  self.labels_cache = self.divide_by_repr(instances_pool)
314
  all_labels = list(self.labels_cache.keys())
315
- random_generator = self.get_random_generator_based_on_instance(instance)
316
  random_generator.shuffle(all_labels)
317
  from collections import Counter
318
 
319
- if self.sample_size > len(instances_pool):
320
  raise ValueError(
321
- f"Request sample size {self.sample_size} is greater than number of instances {len(instances_pool)}"
322
  )
323
  total_allocated = 0
324
  allocations = Counter()
325
 
326
- while total_allocated < self.sample_size:
327
  for label in all_labels:
328
- if total_allocated < self.sample_size:
329
  if len(self.labels_cache[label]) - allocations[label] > 0:
330
  allocations[label] += 1
331
  total_allocated += 1
@@ -341,40 +334,56 @@ class DiverseLabelsSampler(Sampler):
341
  return result
342
 
343
 
344
- class SpreadSplit(InstanceOperatorWithMultiStreamAccess):
345
- source_stream: str = None
346
- target_field: str = None
347
- sampler: Sampler = None
348
 
349
  def prepare(self):
350
  self.local_cache = None
351
  self.sampler.prepare()
352
 
353
- def verify(self):
354
- assert self.source_stream is not None, "Source stream must be specified"
355
- assert self.target_field is not None, "Target field must be specified"
356
- assert self.sampler is not None, "Sampler must be specified"
357
- return super().verify()
358
 
359
  def process(
360
  self, instance: Dict[str, object], multi_stream: MultiStream
361
  ) -> Dict[str, object]:
 
362
  try:
363
  if self.local_cache is None:
364
- self.local_cache = deepcopy(list(multi_stream[self.source_stream]))
365
 
366
  source_stream = self.local_cache
367
  source_stream = self.sampler.filter_source_by_instance(
368
  source_stream, instance
369
  )
370
- if len(source_stream) < self.sampler.sample_size:
371
  raise ValueError(
372
  f"Size of population to sample from: {len(source_stream)} is smaller than the needed sample_size: {self.sampler.sample_size}."
373
  )
374
- sampled_instances = self.sampler.sample(source_stream, instance)
375
- instance[self.target_field] = sampled_instances
 
 
376
  return instance
377
  except FaultyStreamError as e:
378
  raise EmptyStreamError(
379
- f"Unable to fetch instances from '{self.source_stream}' to '{self.target_field}', due to {e.__class__.__name__}: {e}"
380
  ) from e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import itertools
2
  from abc import abstractmethod
 
3
  from difflib import get_close_matches
4
  from typing import Dict, List, Optional
5
 
 
16
  )
17
  from .stream import EmptyStreamError, FaultyStreamError, MultiStream
18
  from .type_utils import isoftype
19
+ from .utils import deepcopy
20
 
21
 
22
  class Splitter(MultiStreamOperator):
 
109
  return MultiStream.from_generators(generators)
110
 
111
 
112
+ def get_random_generator_based_on_instance(instance):
113
+ return new_random_generator(sub_seed={**instance["input_fields"]})
 
 
 
 
114
 
 
 
 
 
 
 
 
115
 
116
+ class Sampler(Artifact):
117
  @abstractmethod
118
  def sample(
119
+ self,
120
+ sample_size: int,
121
+ instances_pool: List[Dict[str, object]],
122
+ instance: Dict[str, object],
123
  ) -> List[Dict[str, object]]:
124
  pass
125
 
 
 
 
126
  def filter_source_by_instance(
127
  self, instances_pool: List[Dict[str, object]], instance: Dict[str, object]
128
  ) -> List[Dict[str, object]]:
129
  if "input_fields" not in instance:
130
  raise ValueError(f"'input_fields' field is missing from '{instance}'.")
 
131
  try:
132
  return [
133
  item
 
143
 
144
  def sample(
145
  self,
146
+ sample_size,
147
  instances_pool: List[Dict[str, object]],
148
  instance: Optional[Dict[str, object]],
149
  ) -> List[Dict[str, object]]:
150
  instances_pool = list(instances_pool)
151
+ random_generator = get_random_generator_based_on_instance(instance)
152
+        return random_generator.sample(instances_pool, sample_size)
 
 
 class FixedIndicesSampler(Sampler):
@@ ... @@
     def sample(
         self,
+        sample_size,
         instances_pool: List[Dict[str, object]],
         instance: Optional[Dict[str, object]],
     ) -> List[Dict[str, object]]:
         num_instances = len(instances_pool)
 
         instances = []
+        for index in self.indices[0:sample_size]:
             if index >= num_instances:
                 raise ValueError(
                     f"FixedIndicesSampler 'indices' field contains index ({index}) which is out of bounds of the instance pool (of size {num_instances})"
@@ ... @@
     field: str
 
     def sample(
+        self,
+        sample_size: int,
+        instances_pool: List[Dict[str, object]],
+        instance: Dict[str, object],
     ) -> List[Dict[str, object]]:
         field = f"input_fields/{self.field}"
         value = dict_get(instance, field)
@@ ... @@
         options = []
         for instance_in_pool in instances_pool:
             options.append(dict_get(instance_in_pool, field))
+        closest_matches = get_close_matches(value, options, n=sample_size, cutoff=0)
         # Randomly select 'sample_size' instances from the closest-matching texts.
         # (There may be multiple instances with the same text in the given field,
         # and the order returned is also randomized.)
         instances_pool = [
             instance_in_pool
             for instance_in_pool in instances_pool
             if dict_get(instance_in_pool, field) in closest_matches
         ]
+        random_generator = get_random_generator_based_on_instance(instance)
+        return random_generator.sample(instances_pool, sample_size)
 
 
 class DiverseLabelsSampler(Sampler):
@@ ... @@
     def sample(
         self,
+        sample_size: int,
         instances_pool: List[Dict[str, object]],
         instance: Optional[Dict[str, object]],
     ) -> List[Dict[str, object]]:
         if self.labels_cache is None:
             self.labels_cache = self.divide_by_repr(instances_pool)
         all_labels = list(self.labels_cache.keys())
+        random_generator = get_random_generator_based_on_instance(instance)
         random_generator.shuffle(all_labels)
         from collections import Counter
 
+        if sample_size > len(instances_pool):
             raise ValueError(
+                f"Requested sample size {sample_size} is greater than number of instances {len(instances_pool)}"
             )
         total_allocated = 0
         allocations = Counter()
 
+        while total_allocated < sample_size:
             for label in all_labels:
+                if total_allocated < sample_size:
                     if len(self.labels_cache[label]) - allocations[label] > 0:
                         allocations[label] += 1
                         total_allocated += 1
@@ ... @@
         return result
 
 
+class Sample(InstanceOperatorWithMultiStreamAccess):
+    from_stream: str
+    to_field: str
+    sampler: Sampler
 
     def prepare(self):
         self.local_cache = None
         self.sampler.prepare()
 
+    @abstractmethod
+    def get_sample_size(self, instance) -> int:
+        pass
 
     def process(
         self, instance: Dict[str, object], multi_stream: MultiStream
     ) -> Dict[str, object]:
+        sample_size = self.get_sample_size(instance)
         try:
             if self.local_cache is None:
+                self.local_cache = deepcopy(list(multi_stream[self.from_stream]))
 
             source_stream = self.local_cache
             source_stream = self.sampler.filter_source_by_instance(
                 source_stream, instance
             )
+            if len(source_stream) < sample_size:
                 raise ValueError(
                     f"Size of population to sample from: {len(source_stream)} is smaller than the needed sample_size: {sample_size}."
                 )
+            sampled_instances = self.sampler.sample(
+                sample_size=sample_size, instances_pool=source_stream, instance=instance
+            )
+            instance[self.to_field] = sampled_instances
             return instance
         except FaultyStreamError as e:
             raise EmptyStreamError(
+                f"Unable to fetch instances from '{self.from_stream}' to '{self.to_field}', due to {e.__class__.__name__}: {e}"
             ) from e
+
+
+class ConstantSizeSample(Sample):
+    sample_size: int
+
+    def get_sample_size(self, instance) -> int:
+        return self.sample_size
+
+
+class RandomSizeSample(Sample):
+    sample_sizes: List[int]
+
+    def get_sample_size(self, instance) -> int:
+        random_generator = get_random_generator_based_on_instance(instance)
+        return random_generator.choice(self.sample_sizes)
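
This change separates concerns: the Sampler decides *which* instances to pick, while the new Sample operators (ConstantSizeSample, RandomSizeSample) decide *how many*, per instance. Below is a minimal self-contained sketch of that split; ToySampler and the instance-seeded generator are stand-ins for unitxt's real Sampler and get_random_generator_based_on_instance, not the library's API.

import random
from abc import ABC, abstractmethod
from typing import Dict, List


class ToySampler:
    """Stand-in for a unitxt Sampler: picks `sample_size` items from a pool."""

    def sample(self, sample_size: int, instances_pool: List[Dict], instance: Dict) -> List[Dict]:
        # Seed from the instance so the same instance always gets the same demos.
        rng = random.Random(str(sorted(instance.items())))
        return rng.sample(instances_pool, sample_size)


class ToySample(ABC):
    """Stand-in for the new Sample operator: resolves the demo count per instance."""

    def __init__(self, sampler: ToySampler):
        self.sampler = sampler

    @abstractmethod
    def get_sample_size(self, instance: Dict) -> int:
        ...

    def process(self, instance: Dict, pool: List[Dict]) -> List[Dict]:
        return self.sampler.sample(self.get_sample_size(instance), pool, instance)


class ToyRandomSizeSample(ToySample):
    """Mirrors RandomSizeSample: a deterministic per-instance choice among sizes."""

    def __init__(self, sampler: ToySampler, sample_sizes: List[int]):
        super().__init__(sampler)
        self.sample_sizes = sample_sizes

    def get_sample_size(self, instance: Dict) -> int:
        return random.Random(str(sorted(instance.items()))).choice(self.sample_sizes)


pool = [{"text": f"demo {i}"} for i in range(10)]
op = ToyRandomSizeSample(ToySampler(), sample_sizes=[0, 1, 5])
print(len(op.process({"text": "query"}, pool)))  # prints 0, 1, or 5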
standard.py CHANGED
@@ -1,17 +1,18 @@
-from typing import List
+from typing import List, Optional, Union
 
 from .card import TaskCard
+from .collections_operators import GetLength
 from .dataclass import Field, InternalField, NonPositionalField, OptionalField
 from .formats import Format, SystemFormat
 from .logging_utils import get_logger
 from .operator import SequentialOperator, SourceSequentialOperator, StreamingOperator
 from .operators import Augmentor, NullAugmentor, Set, StreamRefiner
 from .recipe import Recipe
-from .schema import ToUnitxtGroup
-from .splitters import Sampler, SeparateSplit, SpreadSplit
+from .schema import Finalize
+from .splitters import ConstantSizeSample, RandomSizeSample, Sampler, SeparateSplit
 from .stream import MultiStream
 from .system_prompts import EmptySystemPrompt, SystemPrompt
-from .templates import Template
+from .templates import ApplyRandomTemplate, ApplySingleTemplate, Template
 
 logger = get_logger()
 
@@ -21,15 +22,15 @@ class CreateDemosPool(SeparateSplit):
     pass
 
 
-class AddDemosField(SpreadSplit):
-    pass
-
-
 class BaseRecipe(Recipe, SourceSequentialOperator):
+    # Base parameters
     card: TaskCard
-    template: Template = None
+    template: Union[Template, List[Template]] = None
     system_prompt: SystemPrompt = Field(default_factory=EmptySystemPrompt)
     format: Format = Field(default_factory=SystemFormat)
+
+    # Additional parameters
+    template_card_index: int = NonPositionalField(default=None)
     metrics: List[str] = NonPositionalField(default=None)
     postprocessors: List[str] = NonPositionalField(default=None)
 
@@ -44,7 +45,7 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
     test_refiner: StreamRefiner = OptionalField(default_factory=StreamRefiner)
 
     demos_pool_size: int = None
-    num_demos: int = 0
+    num_demos: Optional[Union[int, List[int]]] = 0
     demos_removed_from_data: bool = True
 
     demos_pool_name: str = "demos_pool"
@@ -59,16 +60,22 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
     def before_process_multi_stream(self):
         super().before_process_multi_stream()
 
+    @property
+    def max_demos_size(self):
+        if isinstance(self.num_demos, list):
+            return max(self.num_demos)
+        return self.num_demos
+
     def verify(self):
         super().verify()
-        if self.num_demos > 0:
+        if self.use_demos:
             if self.demos_pool_size is None or self.demos_pool_size < 1:
                 raise ValueError(
                     "When using demonstrations both num_demos and demos_pool_size should be assigned with positive integers."
                 )
-            if self.demos_pool_size < self.num_demos:
+            if self.demos_pool_size < self.max_demos_size:
                 raise ValueError(
-                    f"num_demos (got: {self.num_demos}) should not exceed demos_pool_size (got: {self.demos_pool_size})"
+                    f"num_demos (got: {self.max_demos_size}) should not exceed demos_pool_size (got: {self.demos_pool_size})"
                 )
             if self.loader_limit and self.demos_pool_size > self.loader_limit:
                 raise ValueError(
@@ -105,6 +112,17 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
                 f"post processors must be a list of post processor. Got postprocessors = {self.postprocessors}"
             )
 
+        if self.template is None:
+            raise ValueError(
+                "You must set in the recipe either `template`, `template_card_index` or `templates`."
+            )
+
+        if isinstance(self.template, list):
+            for template in self.template:
+                self.verify_template(template)
+        else:
+            self.verify_template(self.template)
+
     def prepare_refiners(self):
         self.train_refiner.max_instances = self.max_train_instances
         self.train_refiner.apply_to_streams = ["train"]
@@ -118,31 +136,12 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
         self.test_refiner.apply_to_streams = ["test"]
         self.processing.steps.append(self.test_refiner)
 
-    def prepare_metrics_and_postprocessors(self):
-        # Check is done here to ensure get_postprocessor is called on
-        # a Template object
-        if self.template is not None and not isinstance(self.template, Template):
+    def verify_template(self, template):
+        if not isinstance(template, Template):
             raise ValueError(
-                f"template argument must be an object of type Template. Got template = {self.template}"
+                f"template argument must be an object of type Template. Got template = {template}"
             )
 
-        if self.postprocessors is None:
-            postprocessors = self.template.get_postprocessors()
-        else:
-            postprocessors = self.postprocessors
-
-        if self.metrics is None:
-            metrics = self.card.task.metrics
-        else:
-            metrics = self.metrics
-
-        metrics = [
-            metric if isinstance(metric, str) else metric.to_json()
-            for metric in metrics
-        ]
-
-        return metrics, postprocessors
-
     def set_pipelines(self):
         self.loading = SequentialOperator()
         self.loading.__description__ = "Loading the data from the data source."
@@ -158,8 +157,8 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
         self.processing.__description__ = (
             "Setting task fields (and selecting demos per sample if needed)."
         )
-        self.verblization = SequentialOperator()
-        self.verblization.__description__ = "Verbalizing the input to the model and gold references to the 'source', 'target' and 'references' fields."
+        self.verbalization = SequentialOperator()
+        self.verbalization.__description__ = "Verbalizing the input to the model and gold references to the 'source', 'target' and 'references' fields."
         self.finalize = SequentialOperator()
         self.finalize.__description__ = "Adding post processors. Removing intermediate fields. Creating the final output dataset."
 
@@ -169,7 +168,7 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
             self.standardization,
             self.processing,
             self.metadata,
-            self.verblization,
+            self.verbalization,
             self.finalize,
         ]
 
@@ -193,7 +192,7 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
 
         self.inference = SequentialOperator()
 
-        self.inference.steps = [self.verblization, self.finalize]
+        self.inference.steps = [self.verbalization, self.finalize]
 
         self._demos_pool_cache = None
 
@@ -202,7 +201,7 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
         return list(self.inference_instance(ms)["__inference__"])
 
     def production_demos_pool(self):
-        if self.num_demos > 0:
+        if self.use_demos:
             if self._demos_pool_cache is None:
                 self._demos_pool_cache = list(
                     self.inference_demos()[self.demos_pool_name]
@@ -210,6 +209,14 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
             return self._demos_pool_cache
         return []
 
+    @property
+    def has_custom_demos_pool(self):
+        return self.demos_pool_size is not None and self.demos_pool_size > 0
+
+    @property
+    def use_demos(self):
+        return self.num_demos is not None and self.max_demos_size > 0
+
     def produce(self, task_instances):
         """Use the recipe in production to produce model ready query from standard task instance."""
         self.before_process_multi_stream()
@@ -243,11 +250,8 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
         self.metadata.steps.append(
             Set(
                 fields={
-                    "recipe_metadata": {
-                        "template": self.template,
-                        "system_prompt": self.system_prompt,
-                        "format": self.format,
-                    }
+                    "recipe_metadata/system_prompt": self.system_prompt,
+                    "recipe_metadata/format": self.format,
                 }
             )
         )
@@ -260,7 +264,7 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
         self.augmentor.set_task_input_fields(self.card.task.augmentable_inputs)
         self.processing.steps.append(self.augmentor)
 
-        if self.demos_pool_size is not None and self.demos_pool_size > 0:
+        if self.has_custom_demos_pool:
             self.processing.steps.append(
                 CreateDemosPool(
                     from_split=self.demos_taken_from,
@@ -270,7 +274,7 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
                 )
             )
 
-        if self.num_demos > 0:
+        if self.use_demos:
             if self.sampler is None:
                 if self.card.sampler is None:
                     raise ValueError(
@@ -279,33 +283,76 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
                 )
                 self.sampler = self.card.sampler
 
-            self.sampler.set_size(self.num_demos)
-
         self.prepare_refiners()
 
-        self.verblization.steps.append(self.template)
-        if self.num_demos > 0:
-            self.verblization.steps.append(
-                AddDemosField(
-                    source_stream=self.demos_pool_name,
-                    target_field=self.demos_field,
-                    sampler=self.sampler,
-                )
-            )
-        self.verblization.steps.append(self.system_prompt)
-        self.verblization.steps.append(self.format)
-        if self.augmentor.augment_model_input:
-            self.verblization.steps.append(self.augmentor)
-
-        metrics, postprocessors = self.prepare_metrics_and_postprocessors()
-
-        self.finalize.steps.append(
-            ToUnitxtGroup(
-                group="unitxt",
-                metrics=metrics,
-                postprocessors=postprocessors,
-            )
-        )
+        if self.use_demos:
+            if isinstance(self.num_demos, int):
+                self.verbalization.steps.append(
+                    ConstantSizeSample(
+                        from_stream=self.demos_pool_name,
+                        to_field=self.demos_field,
+                        sampler=self.sampler,
+                        sample_size=self.num_demos,
+                    )
+                )
+                self.verbalization.steps.append(
+                    Set(fields={"recipe_metadata/num_demos": self.num_demos})
+                )
+
+            elif isinstance(self.num_demos, list):
+                self.verbalization.steps.append(
+                    RandomSizeSample(
+                        from_stream=self.demos_pool_name,
+                        to_field=self.demos_field,
+                        sampler=self.sampler,
+                        sample_sizes=self.num_demos,
+                    )
+                )
+                self.verbalization.steps.append(
+                    GetLength(field="demos", to_field="recipe_metadata/num_demos")
+                )
+            else:
+                raise ValueError("num_demos must be int or List[int]")
+
+            if isinstance(self.template, list):
+                self.verbalization.steps.append(
+                    ApplyRandomTemplate(
+                        templates=self.template, demos_field=self.demos_field
+                    )
+                )
+            else:
+                self.verbalization.steps.append(
+                    ApplySingleTemplate(
+                        template=self.template, demos_field=self.demos_field
+                    )
+                )
+        else:
+            self.verbalization.steps.append(
+                Set(fields={"recipe_metadata/num_demos": 0})
+            )
+            if isinstance(self.template, list):
+                self.verbalization.steps.append(
+                    ApplyRandomTemplate(templates=self.template)
+                )
+            else:
+                self.verbalization.steps.append(
+                    ApplySingleTemplate(template=self.template)
+                )
 
+        self.verbalization.steps.append(self.system_prompt)
+        self.verbalization.steps.append(self.format)
+        if self.augmentor.augment_model_input:
+            self.verbalization.steps.append(self.augmentor)
+
+        if self.postprocessors is not None:
+            self.finalize.steps.append(
+                Set(fields={"postprocessors": self.postprocessors})
+            )
+
+        if self.metrics is not None:
+            self.finalize.steps.append(Set(fields={"metrics": self.metrics}))
+
+        self.finalize.steps.append(Finalize())
 
 
 class StandardRecipeWithIndexes(BaseRecipe):
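
With these changes, num_demos may be a list of candidate demo counts and template may be a list of templates: each instance then gets a per-instance random demo count (RandomSizeSample) and template (ApplyRandomTemplate), both recorded under recipe_metadata. A hedged usage sketch follows, assuming a recipe is callable as a source operator and that a string card argument resolves through the catalog; "cards.sst2" is a placeholder name.

from unitxt.standard import StandardRecipeWithIndexes

recipe = StandardRecipeWithIndexes(
    card="cards.sst2",            # hypothetical catalog card
    template_card_index=0,        # or template=[...] for a random template per instance
    num_demos=[0, 1, 5],          # RandomSizeSample draws one of these per instance
    demos_pool_size=20,           # must be >= max(num_demos), per verify()
)
multi_stream = recipe()           # assumed: source operators produce a MultiStream when called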
stream.py CHANGED
@@ -2,7 +2,6 @@ import tempfile
 import traceback
 import warnings
 from abc import abstractmethod
-from copy import deepcopy
 from typing import Any, Callable, Dict, Generator, Iterable, List
 
 from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
@@ -11,6 +10,7 @@ from .dataclass import Dataclass, OptionalField
 from .generator_utils import CopyingReusableGenerator, ReusableGenerator
 from .logging_utils import get_logger
 from .settings_utils import get_settings
+from .utils import deepcopy
 
 settings = get_settings()
 logger = get_logger()
struct_data_operators.py CHANGED
@@ -18,7 +18,6 @@ For key-value pairs, expected input format is:
 import json
 import random
 from abc import ABC, abstractmethod
-from copy import deepcopy
 from typing import (
     Any,
     Dict,
@@ -30,6 +29,7 @@ import pandas as pd
 
 from .dict_utils import dict_get
 from .operators import FieldOperator, InstanceOperator
+from .utils import deepcopy
 
 
 class SerializeTable(ABC, FieldOperator):
task.py CHANGED
@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Union
 from .artifact import fetch_artifact
 from .dataclass import DeprecatedField
 from .deprecation_utils import deprecation
-from .logging_utils import get_logger
+from .error_utils import Documentation, UnitxtError, UnitxtWarning
 from .operator import InstanceOperator
 from .type_utils import (
     Type,
@@ -77,12 +77,14 @@ class Task(InstanceOperator):
     def prepare(self):
         super().prepare()
         if self.input_fields is not None and self.inputs is not None:
-            raise ValueError(
-                "Conflicting attributes: 'input_fields' cannot be set simultaneously with 'inputs'. Use only 'input_fields'"
+            raise UnitxtError(
+                "Conflicting attributes: 'input_fields' cannot be set simultaneously with 'inputs'. Use only 'input_fields'",
+                Documentation.ADDING_TASK,
             )
         if self.reference_fields is not None and self.outputs is not None:
-            raise ValueError(
-                "Conflicting attributes: 'reference_fields' cannot be set simultaneously with 'output'. Use only 'reference_fields'"
+            raise UnitxtError(
+                "Conflicting attributes: 'reference_fields' cannot be set simultaneously with 'output'. Use only 'reference_fields'",
+                Documentation.ADDING_TASK,
             )
 
         self.input_fields = (
@@ -107,9 +109,15 @@ class Task(InstanceOperator):
 
     def verify(self):
         if self.input_fields is None:
-            raise ValueError("Missing attribute in task: 'input_fields' not set.")
+            raise UnitxtError(
+                "Missing attribute in task: 'input_fields' not set.",
+                Documentation.ADDING_TASK,
+            )
         if self.reference_fields is None:
-            raise ValueError("Missing attribute in task: 'reference_fields' not set.")
+            raise UnitxtError(
+                "Missing attribute in task: 'reference_fields' not set.",
+                Documentation.ADDING_TASK,
+            )
         for io_type in ["input_fields", "reference_fields"]:
             data = (
                 self.input_fields
@@ -118,11 +126,12 @@ class Task(InstanceOperator):
             )
 
             if isinstance(data, list) or not is_type_dict(data):
-                get_logger().warning(
+                UnitxtWarning(
                     f"'{io_type}' field of Task should be a dictionary of field names and their types. "
                     f"For example, {{'text': str, 'classes': List[str]}}. Instead only '{data}' was "
                     f"passed. All types will be assumed to be 'Any'. In future version of unitxt this "
-                    f"will raise an exception."
+                    f"will raise an exception.",
+                    Documentation.ADDING_TASK,
                 )
                 data = {key: Any for key in data}
             if io_type == "input_fields":
@@ -131,11 +140,12 @@ class Task(InstanceOperator):
                 self.reference_fields = data
 
         if not self.prediction_type:
-            get_logger().warning(
+            UnitxtWarning(
                 "'prediction_type' was not set in Task. It is used to check the output of "
                 "template post processors is compatible with the expected input of the metrics. "
                 "Setting `prediction_type` to 'Any' (no checking is done). In future version "
-                "of unitxt this will raise an exception."
+                "of unitxt this will raise an exception.",
+                Documentation.ADDING_TASK,
             )
             self.prediction_type = Any
 
@@ -191,18 +201,20 @@ class Task(InstanceOperator):
             ):
                 continue
 
-            raise ValueError(
+            raise UnitxtError(
                 f"The task's prediction type ({prediction_type}) and '{metric_id}' "
-                f"metric's prediction type ({metric_prediction_type}) are different."
+                f"metric's prediction type ({metric_prediction_type}) are different.",
+                Documentation.ADDING_TASK,
             )
 
     def verify_defaults(self):
         if self.defaults:
             if not isinstance(self.defaults, dict):
-                raise ValueError(
+                raise UnitxtError(
                     f"If specified, the 'defaults' must be a dictionary, "
                     f"however, '{self.defaults}' was provided instead, "
-                    f"which is of type '{to_type_string(type(self.defaults))}'."
+                    f"which is of type '{to_type_string(type(self.defaults))}'.",
                    Documentation.ADDING_TASK,
                 )
 
         for default_name, default_value in self.defaults.items():
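
task.py swaps plain ValueError/logger warnings for UnitxtError/UnitxtWarning that carry a pointer into the documentation. error_utils.py itself is not shown in this view; the sketch below is one plausible shape inferred from the call sites (a message plus a Documentation constant), not the actual implementation, and the constant values are assumptions.

class Documentation:
    URL = "https://www.unitxt.ai/en/latest/"     # assumed base URL
    ADDING_TASK = "docs/adding_task.html"        # assumed constant value


class UnitxtError(Exception):
    """Error that appends a 'see the docs' pointer to its message."""

    def __init__(self, message: str, additional_info_id: str = None):
        if additional_info_id is not None:
            message += f"\nFor more information see: {Documentation.URL}{additional_info_id}"
        super().__init__(message)


try:
    raise UnitxtError("Missing attribute in task: 'input_fields' not set.", Documentation.ADDING_TASK)
except UnitxtError as e:
    print(e)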
templates.py CHANGED
@@ -6,17 +6,20 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 from .artifact import Artifact
 from .collections import ListCollection
 from .dataclass import NonPositionalField
+from .dict_utils import dict_set
+from .error_utils import Documentation, UnitxtError
 from .operator import InstanceOperator
 from .random_utils import new_random_generator
 from .type_utils import isoftype
 
 
-class TemplateFormatKeyError(KeyError):
+class TemplateFormatKeyError(UnitxtError):
     def __init__(self, template, data, data_type, format_str, format_name):
         keys = ", ".join(data.keys())
         super().__init__(
             f"Available {data_type}s are [{keys}] "
-            f"but {template.__class__.__name__}.{format_name} format requires a different ones: '{format_str}'"
+            f"but {template.__class__.__name__}.{format_name} format requires different ones: '{format_str}'",
+            Documentation.ADDING_TEMPLATE,
         )
 
 
@@ -92,6 +95,7 @@ class Template(InstanceOperator):
             "references": references,
             "instruction": instruction,
             "target_prefix": target_prefix,
+            "postprocessors": self.postprocessors,
         }
 
     @abstractmethod
@@ -108,9 +112,6 @@ class Template(InstanceOperator):
     ) -> Tuple[str, List[str]]:
         pass
 
-    def get_postprocessors(self) -> List[str]:
-        return self.postprocessors
-
     def serialize_data(self, data):
         return {
             k: ", ".join(str(t) for t in v) if isinstance(v, list) else v
@@ -123,6 +124,11 @@ class Template(InstanceOperator):
         if serialize:
             data = self.serialize_data(data)
         try:
+            if format_str is None:
+                raise UnitxtError(
+                    f"Required field 'output_format' of class {self.__class__.__name__} not set in {self.__class__.__name__}",
+                    Documentation.ADDING_TEMPLATE,
+                )
             return format_str.format(**data)
         except KeyError as e:
             raise TemplateFormatKeyError(
@@ -130,6 +136,49 @@ class Template(InstanceOperator):
             ) from e
 
 
+class ApplyTemplate(InstanceOperator):
+    demos_field: Optional[str] = None
+
+    @abstractmethod
+    def get_template(self, instance: Dict[str, Any]) -> Template:
+        pass
+
+    def apply(self, template: Template, instance: Dict[str, Any]):
+        return template.process_instance(instance)
+
+    def process(
+        self, instance: Dict[str, Any], stream_name: Optional[str] = None
+    ) -> Dict[str, Any]:
+        template = self.get_template(instance)
+
+        if self.demos_field is not None:
+            if self.demos_field not in instance:
+                raise ValueError("Demos field is missing.")
+            instance[self.demos_field] = [
+                self.apply(template, demo_instance)
+                for demo_instance in instance[self.demos_field]
+            ]
+        dict_set(instance, "recipe_metadata/template", template)
+        return self.apply(template, instance)
+
+
+class ApplySingleTemplate(ApplyTemplate):
+    template: Template
+
+    def get_template(self, instance: Dict[str, Any]) -> Template:
+        return self.template
+
+
+class ApplyRandomTemplate(ApplyTemplate):
+    templates: List[Template]
+
+    def get_template(self, instance: Dict[str, Any]) -> Template:
+        random_generator = new_random_generator(
+            {**instance["input_fields"], **instance["reference_fields"]}
+        )
+        return random_generator.choice(self.templates)
+
+
 class InputOutputTemplate(Template):
     """Generate field 'source' from fields designated as input, and fields 'target' and 'references' from fields designated as output, of the processed instance.
 
@@ -471,8 +520,9 @@ class MultipleChoiceTemplate(Template):
         try:
             return reference_fields[self.choices_field].index(target)
         except ValueError as e:
-            raise ValueError(
-                f"MultipleChoiceTemplate could not locate textual target '{target}' in choices list: {reference_fields[self.choices_field]}"
+            raise UnitxtError(
+                f"MultipleChoiceTemplate could not locate textual target '{target}' in choices list: {reference_fields[self.choices_field]}",
+                Documentation.ADDING_TEMPLATE,
             ) from e
         return target
 
@@ -485,8 +535,9 @@ class MultipleChoiceTemplate(Template):
         try:
             target = reference_fields[self.choices_field].index(target)
         except ValueError as e:
-            raise ValueError(
-                f"MultipleChoiceTemplate could not locate textual target '{target}' in choices list: {reference_fields[self.choices_field]}"
+            raise UnitxtError(
+                f"MultipleChoiceTemplate could not locate textual target '{target}' in choices list: {reference_fields[self.choices_field]}",
+                Documentation.ADDING_TEMPLATE,
             ) from e
 
         choices = self.inputs_to_choices(reference_fields, self.target_choice_format)
@@ -494,8 +545,9 @@ class MultipleChoiceTemplate(Template):
         try:
             target = choices[target]
         except IndexError as e:
-            raise IndexError(
-                f"MultipleChoiceTemplate cannot find index number {target} in choices: {choices}"
+            raise UnitxtError(
+                f"MultipleChoiceTemplate cannot find index number {target} in choices: {choices}",
+                Documentation.ADDING_TEMPLATE,
             ) from e
 
         return target, [target]
@@ -574,21 +626,21 @@ class YesNoTemplate(Template):
         try:
             gold_class_names = reference_fields[self.label_field]
         except KeyError as e:
-            raise RuntimeError(
+            raise UnitxtError(
                 f"Available reference_fields are {list(reference_fields.keys())}, missing required label field: '{self.label_field}'."
             ) from e
         if not isinstance(gold_class_names, list):
-            raise RuntimeError(
+            raise UnitxtError(
                 f"Unexpected value for gold_class_names: '{gold_class_names}'. Expecting a list."
             )
         try:
             queried_class_name = reference_fields[self.class_field]
         except KeyError as e:
-            raise RuntimeError(
+            raise UnitxtError(
                 f"Available reference_fields are {list(reference_fields.keys())}, missing required class field: '{self.class_field}'."
             ) from e
         if not queried_class_name or not isinstance(queried_class_name, str):
-            raise RuntimeError(
+            raise UnitxtError(
                 f"Unexpected value for queried_class_names: '{queried_class_name}'. Expected a string."
             )
         if queried_class_name in gold_class_names:
@@ -674,8 +726,9 @@ class MultiLabelTemplate(InputOutputTemplate):
     ) -> str:
         labels = reference_fields[self.labels_field]
         if not isinstance(labels, list):
-            raise ValueError(
-                f"MultiLabelTemplate requires labels field '{self.labels_field}' to be a list. Got {self.labels_field}<{type(labels).__name__}>: {labels}"
+            raise UnitxtError(
+                f"MultiLabelTemplate requires labels field '{self.labels_field}' to be a list. Got {self.labels_field}<{type(labels).__name__}>: {labels}",
+                Documentation.ADDING_TEMPLATE,
             )
         if len(labels) == 0:
             labels = [self.empty_label]
@@ -694,12 +747,14 @@ class MultiReferenceTemplate(InputOutputTemplate):
     ) -> List[str]:
         references = reference_fields[self.references_field]
         if not isoftype(references, List[str]):
-            raise ValueError(
-                f"MultiReferenceTemplate requires references field '{self.references_field}' to be List[str]. Got {self.references_field}<{type(references).__name__}>: {references}"
+            raise UnitxtError(
+                f"MultiReferenceTemplate requires references field '{self.references_field}' to be List[str]. Got {self.references_field}<{type(references).__name__}>: {references}",
+                Documentation.ADDING_TEMPLATE,
            )
         if len(references) == 0:
-            raise ValueError(
-                "No references found. MultiReferenceTemplate requires at least one reference."
+            raise UnitxtError(
+                "No references found. MultiReferenceTemplate requires at least one reference.",
+                Documentation.ADDING_TEMPLATE,
             )
 
         if self.random_reference:
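
ApplyRandomTemplate seeds its generator from the instance's own input and reference fields, so template assignment varies across the dataset but is reproducible for any given instance. A small sketch of that seeding idea, approximating new_random_generator with random.Random over a stable key (the real helper may derive the seed differently):

import random
from typing import Any, Dict, List


def choose_template(instance: Dict[str, Any], templates: List[str]) -> str:
    # Stable key from the instance content -> same instance, same template.
    seed = repr(sorted({**instance["input_fields"], **instance["reference_fields"]}.items()))
    return random.Random(seed).choice(templates)


instance = {"input_fields": {"text": "hello"}, "reference_fields": {"label": "pos"}}
templates = ["templates.classification.t1", "templates.classification.t2"]
assert choose_template(instance, templates) == choose_template(instance, templates)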
utils.py CHANGED
@@ -1,3 +1,4 @@
+import copy
 import importlib.util
 import json
 import os
@@ -125,3 +126,7 @@ def import_module_from_file(file_path):
     # Load the module
     spec.loader.exec_module(module)
     return module
+
+
+def deepcopy(obj):
+    return copy.deepcopy(obj)
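
Routing deepcopy through utils gives the library a single seam for copying, so it can later be instrumented or swapped for a faster copier without touching the call sites migrated above (artifact.py, stream.py, collections_operators.py, struct_data_operators.py). A trivial usage sketch, assuming the package imports as unitxt:

from unitxt.utils import deepcopy

data = {"a": [1, 2, 3]}
clone = deepcopy(data)              # currently delegates to copy.deepcopy
assert clone == data and clone is not data
assert clone["a"] is not data["a"]  # nested containers are copied too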
version.py CHANGED
@@ -1 +1 @@
-version = "1.12.2"
+version = "1.12.3"