Elron commited on
Commit
058c80a
·
verified ·
1 Parent(s): 7f5e8be

Upload folder using huggingface_hub

Browse files
Files changed (27) hide show
  1. README.md +28 -9
  2. api.py +74 -3
  3. artifact.py +52 -18
  4. blocks.py +2 -2
  5. catalog.py +1 -7
  6. dataset.py +7 -3
  7. dataset_utils.py +14 -6
  8. deprecation_utils.py +9 -5
  9. dialog_operators.py +1 -0
  10. fusion.py +6 -2
  11. inference.py +212 -4
  12. llm_as_judge.py +16 -12
  13. loaders.py +44 -21
  14. metric.py +6 -1
  15. metric_utils.py +68 -7
  16. metrics.py +62 -37
  17. operator.py +60 -61
  18. operators.py +23 -9
  19. processors.py +5 -0
  20. settings_utils.py +6 -7
  21. standard.py +8 -2
  22. string_operators.py +21 -0
  23. struct_data_operators.py +1 -0
  24. task.py +44 -0
  25. text_utils.py +55 -16
  26. utils.py +20 -3
  27. version.py +1 -1
README.md CHANGED
@@ -11,11 +11,11 @@ pinned: false
11
  <img src="https://raw.githubusercontent.com/IBM/unitxt/main/assets/banner.png" alt="Image Description" width="100%" />
12
  </div>
13
 
14
- [![Button](https://img.shields.io/badge/Video-pink?style=for-the-badge)](https://unitxt.readthedocs.io/)
 
15
  [![Button](https://img.shields.io/badge/Demo-pink?style=for-the-badge)](https://unitxt.readthedocs.io/en/latest/docs/demo.html)
16
  [![Button](https://img.shields.io/badge/Tutorial-pink?style=for-the-badge)](https://unitxt.readthedocs.io/en/latest/docs/adding_dataset.html)
17
  [![Button](https://img.shields.io/badge/Paper-pink?style=for-the-badge)](https://arxiv.org/abs/2401.14019)
18
- [![Button](https://img.shields.io/badge/Documentation-pink?style=for-the-badge)](https://unitxt.readthedocs.io/en/latest/modules.html)
19
  [![Button](https://img.shields.io/badge/Catalog-pink?style=for-the-badge)](https://unitxt.readthedocs.io/en/latest/catalog/catalog.__dir__.html)
20
  [![Button](https://img.shields.io/badge/Contributors-pink?style=for-the-badge)](https://github.com/IBM/unitxt/blob/main/CONTRIBUTING.md)
21
  [![Button](https://img.shields.io/badge/PyPi-pink?style=for-the-badge)](https://pypi.org/project/unitxt/)
@@ -72,13 +72,32 @@ pre-commit install
72
  If you use Unitxt in your research, please cite our paper:
73
 
74
  ```
75
- @misc{unitxt,
76
- title={Unitxt: Flexible, Shareable and Reusable Data Preparation and Evaluation for Generative AI},
77
- author={Elron Bandel and Yotam Perlitz and Elad Venezian and Roni Friedman-Melamed and Ofir Arviv and Matan Orbach and Shachar Don-Yehyia and Dafna Sheinwald and Ariel Gera and Leshem Choshen and Michal Shmueli-Scheuer and Yoav Katz},
78
- year={2024},
79
- eprint={2401.14019},
80
- archivePrefix={arXiv},
81
- primaryClass={cs.CL}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  }
83
  ```
84
 
 
 
11
  <img src="https://raw.githubusercontent.com/IBM/unitxt/main/assets/banner.png" alt="Image Description" width="100%" />
12
  </div>
13
 
14
+ [![Button](https://img.shields.io/badge/Video-pink?style=for-the-badge)](https://unitxt.readthedocs.io/en/latest/_static/video.mov)
15
+ [![Button](https://img.shields.io/badge/Documentation-pink?style=for-the-badge)](https://unitxt.readthedocs.io/en/latest/docs/introduction.html)
16
  [![Button](https://img.shields.io/badge/Demo-pink?style=for-the-badge)](https://unitxt.readthedocs.io/en/latest/docs/demo.html)
17
  [![Button](https://img.shields.io/badge/Tutorial-pink?style=for-the-badge)](https://unitxt.readthedocs.io/en/latest/docs/adding_dataset.html)
18
  [![Button](https://img.shields.io/badge/Paper-pink?style=for-the-badge)](https://arxiv.org/abs/2401.14019)
 
19
  [![Button](https://img.shields.io/badge/Catalog-pink?style=for-the-badge)](https://unitxt.readthedocs.io/en/latest/catalog/catalog.__dir__.html)
20
  [![Button](https://img.shields.io/badge/Contributors-pink?style=for-the-badge)](https://github.com/IBM/unitxt/blob/main/CONTRIBUTING.md)
21
  [![Button](https://img.shields.io/badge/PyPi-pink?style=for-the-badge)](https://pypi.org/project/unitxt/)
 
72
  If you use Unitxt in your research, please cite our paper:
73
 
74
  ```
75
+ @inproceedings{bandel-etal-2024-unitxt,
76
+ title = "Unitxt: Flexible, Shareable and Reusable Data Preparation and Evaluation for Generative {AI}",
77
+ author = "Bandel, Elron and
78
+ Perlitz, Yotam and
79
+ Venezian, Elad and
80
+ Friedman, Roni and
81
+ Arviv, Ofir and
82
+ Orbach, Matan and
83
+ Don-Yehiya, Shachar and
84
+ Sheinwald, Dafna and
85
+ Gera, Ariel and
86
+ Choshen, Leshem and
87
+ Shmueli-Scheuer, Michal and
88
+ Katz, Yoav",
89
+ editor = "Chang, Kai-Wei and
90
+ Lee, Annie and
91
+ Rajani, Nazneen",
92
+ booktitle = "Proceedings of the 2024 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (Volume 3: System Demonstrations)",
93
+ month = jun,
94
+ year = "2024",
95
+ address = "Mexico City, Mexico",
96
+ publisher = "Association for Computational Linguistics",
97
+ url = "https://aclanthology.org/2024.naacl-demo.21",
98
+ pages = "207--215",
99
+ abstract = "In the dynamic landscape of generative NLP, traditional text processing pipelines limit research flexibility and reproducibility, as they are tailored to specific dataset, task, and model combinations. The escalating complexity, involving system prompts, model-specific formats, instructions, and more, calls for a shift to a structured, modular, and customizable solution.Addressing this need, we present Unitxt, an innovative library for customizable textual data preparation and evaluation tailored to generative language models. Unitxt natively integrates with common libraries like HuggingFace and LM-eval-harness and deconstructs processing flows into modular components, enabling easy customization and sharing between practitioners. These components encompass model-specific formats, task prompts, and many other comprehensive dataset processing definitions. The Unitxt Catalog centralizes these components, fostering collaboration and exploration in modern textual data workflows. Beyond being a tool, Unitxt is a community-driven platform, empowering users to build, share, and advance their pipelines collaboratively. Join the Unitxt community at https://github.com/IBM/unitxt",
100
  }
101
  ```
102
 
103
+ Unitxt emoji designed by [OpenMoji](https://openmoji.org/#) - the open-source emoji and icon project. License: [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/#)
api.py CHANGED
@@ -1,13 +1,14 @@
1
  from functools import lru_cache
2
- from typing import Any, Dict, List, Union
3
 
4
  from datasets import DatasetDict
5
 
6
  from .artifact import fetch_artifact
7
  from .dataset_utils import get_dataset_artifact
8
  from .logging_utils import get_logger
9
- from .metric_utils import _compute
10
  from .operator import SourceOperator
 
11
 
12
  logger = get_logger()
13
 
@@ -21,16 +22,79 @@ def load(source: Union[SourceOperator, str]) -> DatasetDict:
21
  return source().to_dataset()
22
 
23
 
24
- def load_dataset(dataset_query: str) -> DatasetDict:
25
  dataset_query = dataset_query.replace("sys_prompt", "instruction")
26
  dataset_stream = get_dataset_artifact(dataset_query)
27
  return dataset_stream().to_dataset()
28
 
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  def evaluate(predictions, data) -> List[Dict[str, Any]]:
31
  return _compute(predictions=predictions, references=data)
32
 
33
 
 
 
 
 
34
  @lru_cache
35
  def _get_produce_with_cache(recipe_query):
36
  return get_dataset_artifact(recipe_query).produce
@@ -44,3 +108,10 @@ def produce(instance_or_instances, recipe_query):
44
  if not is_list:
45
  result = result[0]
46
  return result
 
 
 
 
 
 
 
 
1
  from functools import lru_cache
2
+ from typing import Any, Dict, List, Optional, Union
3
 
4
  from datasets import DatasetDict
5
 
6
  from .artifact import fetch_artifact
7
  from .dataset_utils import get_dataset_artifact
8
  from .logging_utils import get_logger
9
+ from .metric_utils import _compute, _post_process
10
  from .operator import SourceOperator
11
+ from .standard import StandardRecipe
12
 
13
  logger = get_logger()
14
 
 
22
  return source().to_dataset()
23
 
24
 
25
+ def _load_dataset_from_query(dataset_query: str) -> DatasetDict:
26
  dataset_query = dataset_query.replace("sys_prompt", "instruction")
27
  dataset_stream = get_dataset_artifact(dataset_query)
28
  return dataset_stream().to_dataset()
29
 
30
 
31
+ def _load_dataset_from_dict(dataset_params: Dict[str, Any]) -> DatasetDict:
32
+ recipe_attributes = list(StandardRecipe.__dict__["__fields__"].keys())
33
+ for param in dataset_params.keys():
34
+ assert param in recipe_attributes, (
35
+ f"The parameter '{param}' is not an attribute of the 'StandardRecipe' class. "
36
+ f"Please check if the name is correct. The available attributes are: '{recipe_attributes}'."
37
+ )
38
+ recipe = StandardRecipe(**dataset_params)
39
+ return recipe().to_dataset()
40
+
41
+
42
+ def load_dataset(dataset_query: Optional[str] = None, **kwargs) -> DatasetDict:
43
+ """Loads dataset.
44
+
45
+ If the 'dataset_query' argument is provided, then dataset is loaded from a card in local
46
+ catalog based on parameters specified in the query.
47
+ Alternatively, dataset is loaded from a provided card based on explicitly given parameters.
48
+
49
+ Args:
50
+ dataset_query (str, optional): A string query which specifies dataset to load from local catalog.
51
+ For example:
52
+ "card=cards.wnli,template=templates.classification.multi_class.relation.default".
53
+ **kwargs: Arguments used to load dataset from provided card, which is not present in local catalog.
54
+
55
+ Returns:
56
+ DatasetDict
57
+
58
+ Examples:
59
+ dataset = load_dataset(
60
+ dataset_query="card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5"
61
+ ) # card must be present in local catalog
62
+
63
+ card = TaskCard(...)
64
+ template = Template(...)
65
+ loader_limit = 10
66
+ dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)
67
+ """
68
+ if dataset_query and kwargs:
69
+ raise ValueError(
70
+ "Cannot provide 'dataset_query' and key-worded arguments at the same time. "
71
+ "If you want to load dataset from a card in local catalog, use query only. "
72
+ "Otherwise, use key-worded arguments only to specify properties of dataset."
73
+ )
74
+
75
+ if dataset_query:
76
+ if not isinstance(dataset_query, str):
77
+ raise ValueError(
78
+ f"If specified, 'dataset_query' must be a string, however, "
79
+ f"'{dataset_query}' was provided instead, which is of type "
80
+ f"'{type(dataset_query)}'."
81
+ )
82
+ return _load_dataset_from_query(dataset_query)
83
+
84
+ if kwargs:
85
+ return _load_dataset_from_dict(kwargs)
86
+
87
+ raise ValueError("Either 'dataset_query' or key-worded arguments must be provided.")
88
+
89
+
90
  def evaluate(predictions, data) -> List[Dict[str, Any]]:
91
  return _compute(predictions=predictions, references=data)
92
 
93
 
94
+ def post_process(predictions, data) -> List[Dict[str, Any]]:
95
+ return _post_process(predictions=predictions, references=data)
96
+
97
+
98
  @lru_cache
99
  def _get_produce_with_cache(recipe_query):
100
  return get_dataset_artifact(recipe_query).produce
 
108
  if not is_list:
109
  result = result[0]
110
  return result
111
+
112
+
113
+ def infer(instance_or_instances, recipe, engine):
114
+ dataset = produce(instance_or_instances, recipe)
115
+ engine, _ = fetch_artifact(engine)
116
+ predictions = engine.infer(dataset)
117
+ return post_process(predictions, dataset)
artifact.py CHANGED
@@ -3,9 +3,10 @@ import inspect
3
  import json
4
  import os
5
  import pkgutil
 
6
  from abc import abstractmethod
7
  from copy import deepcopy
8
- from typing import Any, Dict, List, Optional, Union, final
9
 
10
  from .dataclass import (
11
  AbstractField,
@@ -19,13 +20,24 @@ from .logging_utils import get_logger
19
  from .parsing_utils import (
20
  separate_inside_and_outside_square_brackets,
21
  )
22
- from .settings_utils import get_settings
23
  from .text_utils import camel_to_snake_case, is_camel_case
24
  from .type_utils import issubtype
25
- from .utils import artifacts_json_cache, save_json
26
 
27
  logger = get_logger()
28
  settings = get_settings()
 
 
 
 
 
 
 
 
 
 
 
29
 
30
 
31
  class Artifactories:
@@ -120,7 +132,7 @@ class MissingArtifactTypeError(ValueError):
120
  class Artifact(Dataclass):
121
  _class_register = {}
122
 
123
- type: str = Field(default=None, final=True, init=False)
124
  __description__: str = NonPositionalField(
125
  default=None, required=False, also_positional=False
126
  )
@@ -135,7 +147,7 @@ class Artifact(Dataclass):
135
 
136
  @classmethod
137
  def is_artifact_dict(cls, d):
138
- return isinstance(d, dict) and "type" in d
139
 
140
  @classmethod
141
  def verify_artifact_dict(cls, d):
@@ -143,10 +155,10 @@ class Artifact(Dataclass):
143
  raise ValueError(
144
  f"Artifact dict <{d}> must be of type 'dict', got '{type(d)}'."
145
  )
146
- if "type" not in d:
147
  raise MissingArtifactTypeError(d)
148
- if not cls.is_registered_type(d["type"]):
149
- raise UnrecognizedArtifactTypeError(d["type"])
150
 
151
  @classmethod
152
  def get_artifact_type(cls):
@@ -212,7 +224,7 @@ class Artifact(Dataclass):
212
  pass
213
  if cls.is_artifact_dict(obj):
214
  cls.verify_artifact_dict(obj)
215
- return cls._class_register[obj.pop("type")](**obj)
216
 
217
  return obj
218
 
@@ -261,7 +273,7 @@ class Artifact(Dataclass):
261
 
262
  @final
263
  def __post_init__(self):
264
- self.type = self.register_class(self.__class__)
265
 
266
  for field in fields(self):
267
  if issubtype(
@@ -277,11 +289,24 @@ class Artifact(Dataclass):
277
  self.verify()
278
 
279
  def _to_raw_dict(self):
280
- return {"type": self.type, **self._init_dict}
281
 
282
- def save(self, path):
283
  data = self.to_dict()
284
- save_json(path, data)
 
 
 
 
 
 
 
 
 
 
 
 
 
285
 
286
  def verify_instance(
287
  self, instance: Dict[str, Any], name: Optional[str] = None
@@ -404,13 +429,22 @@ class UnitxtArtifactNotFoundError(Exception):
404
  return f"Artifact {self.name} does not exist, in artifactories:{self.artifactories}"
405
 
406
 
407
- def fetch_artifact(name):
408
- if Artifact.is_artifact_file(name):
409
- return Artifact.load(name), None
 
 
410
 
411
- artifactory, name, args = get_artifactory_name_and_args(name=name)
 
 
 
 
 
 
 
412
 
413
- return artifactory.get_with_overwrite(name, overwrite_args=args), artifactory
414
 
415
 
416
  def get_artifactory_name_and_args(
 
3
  import json
4
  import os
5
  import pkgutil
6
+ import re
7
  from abc import abstractmethod
8
  from copy import deepcopy
9
+ from typing import Any, Dict, List, Optional, Tuple, Union, final
10
 
11
  from .dataclass import (
12
  AbstractField,
 
20
  from .parsing_utils import (
21
  separate_inside_and_outside_square_brackets,
22
  )
23
+ from .settings_utils import get_constants, get_settings
24
  from .text_utils import camel_to_snake_case, is_camel_case
25
  from .type_utils import issubtype
26
+ from .utils import artifacts_json_cache, json_dump, save_to_file
27
 
28
  logger = get_logger()
29
  settings = get_settings()
30
+ constants = get_constants()
31
+
32
+
33
+ def is_name_legal_for_catalog(name):
34
+ return re.match(r"^[\w" + constants.catalog_hierarchy_sep + "]+$", name)
35
+
36
+
37
+ def verify_legal_catalog_name(name):
38
+ assert is_name_legal_for_catalog(
39
+ name
40
+ ), f'Artifict name ("{name}") should be alphanumeric. Use "." for nesting (e.g. myfolder.my_artifact)'
41
 
42
 
43
  class Artifactories:
 
132
  class Artifact(Dataclass):
133
  _class_register = {}
134
 
135
+ __type__: str = Field(default=None, final=True, init=False)
136
  __description__: str = NonPositionalField(
137
  default=None, required=False, also_positional=False
138
  )
 
147
 
148
  @classmethod
149
  def is_artifact_dict(cls, d):
150
+ return isinstance(d, dict) and "__type__" in d
151
 
152
  @classmethod
153
  def verify_artifact_dict(cls, d):
 
155
  raise ValueError(
156
  f"Artifact dict <{d}> must be of type 'dict', got '{type(d)}'."
157
  )
158
+ if "__type__" not in d:
159
  raise MissingArtifactTypeError(d)
160
+ if not cls.is_registered_type(d["__type__"]):
161
+ raise UnrecognizedArtifactTypeError(d["__type__"])
162
 
163
  @classmethod
164
  def get_artifact_type(cls):
 
224
  pass
225
  if cls.is_artifact_dict(obj):
226
  cls.verify_artifact_dict(obj)
227
+ return cls._class_register[obj.pop("__type__")](**obj)
228
 
229
  return obj
230
 
 
273
 
274
  @final
275
  def __post_init__(self):
276
+ self.__type__ = self.register_class(self.__class__)
277
 
278
  for field in fields(self):
279
  if issubtype(
 
289
  self.verify()
290
 
291
  def _to_raw_dict(self):
292
+ return {"__type__": self.__type__, **self._init_dict}
293
 
294
+ def to_json(self):
295
  data = self.to_dict()
296
+ return json_dump(data)
297
+
298
+ def serialize(self):
299
+ if self.__id__ is not None:
300
+ return self.__id__
301
+ return self.to_json()
302
+
303
+ def save(self, path):
304
+ save_to_file(path, self.to_json())
305
+
306
+ @classmethod
307
+ def deserialize(cls, artifact_rep):
308
+ data = json.loads(artifact_rep)
309
+ return Artifact.from_dict(data)
310
 
311
  def verify_instance(
312
  self, instance: Dict[str, Any], name: Optional[str] = None
 
429
  return f"Artifact {self.name} does not exist, in artifactories:{self.artifactories}"
430
 
431
 
432
+ def fetch_artifact(artifact_rep) -> Tuple[Artifact, Union[Artifactory, None]]:
433
+ if isinstance(artifact_rep, Artifact):
434
+ return artifact_rep, None
435
+ if Artifact.is_artifact_file(artifact_rep):
436
+ return Artifact.load(artifact_rep), None
437
 
438
+ name, _ = separate_inside_and_outside_square_brackets(artifact_rep)
439
+ if is_name_legal_for_catalog(name):
440
+ artifactory, artifact_rep, args = get_artifactory_name_and_args(
441
+ name=artifact_rep
442
+ )
443
+ return artifactory.get_with_overwrite(
444
+ artifact_rep, overwrite_args=args
445
+ ), artifactory
446
 
447
+ return Artifact.deserialize(artifact_rep), None
448
 
449
 
450
  def get_artifactory_name_and_args(
blocks.py CHANGED
@@ -8,13 +8,13 @@ from .loaders import LoadFromIBMCloud, LoadFromKaggle, LoadHF
8
  from .metrics import Accuracy
9
  from .normalizers import NormalizeListFields
10
  from .operators import (
11
- AddFields,
12
  AddID,
13
  CastFields,
14
- CopyFields,
15
  DivideAllFieldsBy,
16
  MapInstanceValues,
17
  RenameFields,
 
18
  )
19
  from .processors import ToString, ToStringStripped
20
  from .recipe import SequentialRecipe
 
8
  from .metrics import Accuracy
9
  from .normalizers import NormalizeListFields
10
  from .operators import (
 
11
  AddID,
12
  CastFields,
13
+ Copy,
14
  DivideAllFieldsBy,
15
  MapInstanceValues,
16
  RenameFields,
17
+ Set,
18
  )
19
  from .processors import ToString, ToStringStripped
20
  from .recipe import SequentialRecipe
catalog.py CHANGED
@@ -1,5 +1,4 @@
1
  import os
2
- import re
3
  from collections import Counter
4
  from functools import lru_cache
5
  from pathlib import Path
@@ -13,6 +12,7 @@ from .artifact import (
13
  Artifactory,
14
  get_artifactory_name_and_args,
15
  reset_artifacts_json_cache,
 
16
  )
17
  from .logging_utils import get_logger
18
  from .settings_utils import get_constants
@@ -114,12 +114,6 @@ class GithubCatalog(LocalCatalog):
114
  return response.status_code == 200
115
 
116
 
117
- def verify_legal_catalog_name(name):
118
- assert re.match(
119
- r"^[\w" + constants.catalog_hierarchy_sep + "]+$", name
120
- ), f'Artifict name ("{name}") should be alphanumeric. Use "." for nesting (e.g. myfolder.my_artifact)'
121
-
122
-
123
  def add_to_catalog(
124
  artifact: Artifact,
125
  name: str,
 
1
  import os
 
2
  from collections import Counter
3
  from functools import lru_cache
4
  from pathlib import Path
 
12
  Artifactory,
13
  get_artifactory_name_and_args,
14
  reset_artifacts_json_cache,
15
+ verify_legal_catalog_name,
16
  )
17
  from .logging_utils import get_logger
18
  from .settings_utils import get_constants
 
114
  return response.status_code == 200
115
 
116
 
 
 
 
 
 
 
117
  def add_to_catalog(
118
  artifact: Artifact,
119
  name: str,
dataset.py CHANGED
@@ -10,6 +10,7 @@ from .catalog import __file__ as _
10
  from .collections import __file__ as _
11
  from .collections_operators import __file__ as _
12
  from .dataclass import __file__ as _
 
13
  from .dataset_utils import get_dataset_artifact
14
  from .deprecation_utils import __file__ as _
15
  from .dialog_operators import __file__ as _
@@ -19,11 +20,13 @@ from .file_utils import __file__ as _
19
  from .formats import __file__ as _
20
  from .fusion import __file__ as _
21
  from .generator_utils import __file__ as _
 
22
  from .hf_utils import verify_versions_compatibility
23
  from .inference import __file__ as _
24
  from .instructions import __file__ as _
25
  from .llm_as_judge import __file__ as _
26
  from .loaders import __file__ as _
 
27
  from .logging_utils import get_logger
28
  from .metric import __file__ as _
29
  from .metric_utils import __file__ as _
@@ -37,6 +40,7 @@ from .random_utils import __file__ as _
37
  from .recipe import __file__ as _
38
  from .register import __file__ as _
39
  from .schema import __file__ as _
 
40
  from .settings_utils import get_constants
41
  from .span_lableing_operators import __file__ as _
42
  from .split_utils import __file__ as _
@@ -51,6 +55,7 @@ from .task import __file__ as _
51
  from .templates import __file__ as _
52
  from .text_utils import __file__ as _
53
  from .type_utils import __file__ as _
 
54
  from .utils import is_package_installed
55
  from .validate import __file__ as _
56
  from .version import __file__ as _
@@ -71,9 +76,8 @@ class Dataset(datasets.GeneratorBasedBuilder):
71
  if is_package_installed("unitxt"):
72
  verify_versions_compatibility("dataset", self.VERSION)
73
 
74
- from unitxt.dataset_utils import (
75
- get_dataset_artifact as get_dataset_artifact_installed,
76
- )
77
 
78
  logger.info("Loading with installed unitxt library...")
79
  dataset = get_dataset_artifact_installed(self.config.name)
 
10
  from .collections import __file__ as _
11
  from .collections_operators import __file__ as _
12
  from .dataclass import __file__ as _
13
+ from .dataset_utils import __file__ as _
14
  from .dataset_utils import get_dataset_artifact
15
  from .deprecation_utils import __file__ as _
16
  from .dialog_operators import __file__ as _
 
20
  from .formats import __file__ as _
21
  from .fusion import __file__ as _
22
  from .generator_utils import __file__ as _
23
+ from .hf_utils import __file__ as _
24
  from .hf_utils import verify_versions_compatibility
25
  from .inference import __file__ as _
26
  from .instructions import __file__ as _
27
  from .llm_as_judge import __file__ as _
28
  from .loaders import __file__ as _
29
+ from .logging_utils import __file__ as _
30
  from .logging_utils import get_logger
31
  from .metric import __file__ as _
32
  from .metric_utils import __file__ as _
 
40
  from .recipe import __file__ as _
41
  from .register import __file__ as _
42
  from .schema import __file__ as _
43
+ from .settings_utils import __file__ as _
44
  from .settings_utils import get_constants
45
  from .span_lableing_operators import __file__ as _
46
  from .split_utils import __file__ as _
 
55
  from .templates import __file__ as _
56
  from .text_utils import __file__ as _
57
  from .type_utils import __file__ as _
58
+ from .utils import __file__ as _
59
  from .utils import is_package_installed
60
  from .validate import __file__ as _
61
  from .version import __file__ as _
 
76
  if is_package_installed("unitxt"):
77
  verify_versions_compatibility("dataset", self.VERSION)
78
 
79
+ from unitxt.dataset_utils import \
80
+ get_dataset_artifact as get_dataset_artifact_installed
 
81
 
82
  logger.info("Loading with installed unitxt library...")
83
  dataset = get_dataset_artifact_installed(self.config.name)
dataset_utils.py CHANGED
@@ -1,8 +1,11 @@
 
 
1
  from .artifact import Artifact, UnitxtArtifactNotFoundError, fetch_artifact
2
  from .logging_utils import get_logger
3
  from .parsing_utils import parse_key_equals_value_string_to_dict
4
  from .register import _reset_env_local_catalogs, register_all_artifacts
5
  from .settings_utils import get_settings
 
6
 
7
  logger = get_logger()
8
  settings = get_settings()
@@ -12,7 +15,7 @@ def fetch(artifact_name):
12
  try:
13
  artifact, _ = fetch_artifact(artifact_name)
14
  return artifact
15
- except UnitxtArtifactNotFoundError:
16
  return None
17
 
18
 
@@ -20,13 +23,18 @@ def parse(query: str):
20
  return parse_key_equals_value_string_to_dict(query)
21
 
22
 
23
- def get_dataset_artifact(dataset_str):
 
 
 
 
 
24
  _reset_env_local_catalogs()
25
  register_all_artifacts()
26
- recipe = fetch(dataset_str)
27
  if recipe is None:
28
- args = parse(dataset_str)
29
- if "type" not in args:
30
- args["type"] = settings.default_recipe
31
  recipe = Artifact.from_dict(args)
32
  return recipe
 
1
+ from json.decoder import JSONDecodeError
2
+
3
  from .artifact import Artifact, UnitxtArtifactNotFoundError, fetch_artifact
4
  from .logging_utils import get_logger
5
  from .parsing_utils import parse_key_equals_value_string_to_dict
6
  from .register import _reset_env_local_catalogs, register_all_artifacts
7
  from .settings_utils import get_settings
8
+ from .standard import BaseRecipe
9
 
10
  logger = get_logger()
11
  settings = get_settings()
 
15
  try:
16
  artifact, _ = fetch_artifact(artifact_name)
17
  return artifact
18
+ except (UnitxtArtifactNotFoundError, JSONDecodeError):
19
  return None
20
 
21
 
 
23
  return parse_key_equals_value_string_to_dict(query)
24
 
25
 
26
+ def get_dataset_artifact(dataset):
27
+ if isinstance(dataset, BaseRecipe):
28
+ return dataset
29
+ assert isinstance(
30
+ dataset, str
31
+ ), "dataset should be string description of recipe, or recipe object."
32
  _reset_env_local_catalogs()
33
  register_all_artifacts()
34
+ recipe = fetch(dataset)
35
  if recipe is None:
36
+ args = parse(dataset)
37
+ if "__type__" not in args:
38
+ args["__type__"] = settings.default_recipe
39
  recipe = Artifact.from_dict(args)
40
  return recipe
deprecation_utils.py CHANGED
@@ -1,9 +1,10 @@
1
  import functools
2
  import warnings
3
 
4
- from .settings_utils import get_constants
5
 
6
  constants = get_constants()
 
7
 
8
 
9
  class DeprecationError(Exception):
@@ -60,9 +61,12 @@ def depraction_wrapper(obj, version, alt_text):
60
  @functools.wraps(obj)
61
  def wrapper(*args, **kwargs):
62
  if constants.version < version:
63
- warnings.warn(
64
- f"{obj.__name__} is deprecated.", DeprecationWarning, stacklevel=2
65
- )
 
 
 
66
  elif constants.version >= version:
67
  raise DeprecationError(f"{obj.__name__} is no longer supported.{alt_text}")
68
  return obj(*args, **kwargs)
@@ -82,7 +86,7 @@ def deprecation(version, alternative=None):
82
  """
83
 
84
  def decorator(obj):
85
- alt_text = f" Use {alternative} instead." if alternative else ""
86
  if callable(obj):
87
  func = obj
88
  elif hasattr(obj, "__init__"):
 
1
  import functools
2
  import warnings
3
 
4
+ from .settings_utils import get_constants, get_settings
5
 
6
  constants = get_constants()
7
+ settings = get_settings()
8
 
9
 
10
  class DeprecationError(Exception):
 
61
  @functools.wraps(obj)
62
  def wrapper(*args, **kwargs):
63
  if constants.version < version:
64
+ if settings.default_verbosity in ["debug", "info", "warning"]:
65
+ warnings.warn(
66
+ f"{obj.__name__} is deprecated.{alt_text}",
67
+ DeprecationWarning,
68
+ stacklevel=2,
69
+ )
70
  elif constants.version >= version:
71
  raise DeprecationError(f"{obj.__name__} is no longer supported.{alt_text}")
72
  return obj(*args, **kwargs)
 
86
  """
87
 
88
  def decorator(obj):
89
+ alt_text = f" Use {alternative} instead." if alternative is not None else ""
90
  if callable(obj):
91
  func = obj
92
  elif hasattr(obj, "__init__"):
dialog_operators.py CHANGED
@@ -11,6 +11,7 @@ dialog = [
11
  {"user": "kkk", "system": ""},
12
  ]
13
  """
 
14
  from typing import Any, Dict, List, Optional
15
 
16
  from .formats import SystemFormat
 
11
  {"user": "kkk", "system": ""},
12
  ]
13
  """
14
+
15
  from typing import Any, Dict, List, Optional
16
 
17
  from .formats import SystemFormat
fusion.py CHANGED
@@ -103,9 +103,10 @@ class WeightedFusion(BaseFusion):
103
  If None, all instances are returned
104
  """
105
 
106
- origins: Union[Dict[str, MultiStream], List[MultiStream]] = None
107
  weights: Union[Dict[str, Union[float, int]], List[Union[int, float]]] = None
108
  max_total_examples: int = None
 
109
 
110
  def verify(self):
111
  super().verify()
@@ -149,7 +150,10 @@ class WeightedFusion(BaseFusion):
149
  try:
150
  instance = next(iterator)
151
  if isinstance(origin_name, str):
152
- if "group" in instance:
 
 
 
153
  instance["group"] = origin_name + "/" + instance["group"]
154
  else:
155
  instance["group"] = origin_name
 
103
  If None, all instances are returned
104
  """
105
 
106
+ origins: Union[Dict[str, SourceOperator], List[SourceOperator]] = None
107
  weights: Union[Dict[str, Union[float, int]], List[Union[int, float]]] = None
108
  max_total_examples: int = None
109
+ ignore_origin_groups: List[str] = ["unitxt"]
110
 
111
  def verify(self):
112
  super().verify()
 
150
  try:
151
  instance = next(iterator)
152
  if isinstance(origin_name, str):
153
+ if (
154
+ "group" in instance
155
+ and instance["group"] not in self.ignore_origin_groups
156
+ ):
157
  instance["group"] = origin_name + "/" + instance["group"]
158
  else:
159
  instance["group"] = origin_name
inference.py CHANGED
@@ -3,6 +3,8 @@ import os
3
  from dataclasses import field
4
  from typing import Any, Dict, List, Literal, Optional, Union
5
 
 
 
6
  from .artifact import Artifact
7
  from .operator import PackageRequirementsMixin
8
 
@@ -15,12 +17,31 @@ class InferenceEngine(abc.ABC, Artifact):
15
  """Perform inference on the input dataset."""
16
  pass
17
 
18
- def infer(self, dataset):
19
  """Verifies instances of a dataset and performs inference."""
20
  [self.verify_instance(instance) for instance in dataset]
21
  return self._infer(dataset)
22
 
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  class HFPipelineBasedInferenceEngine(InferenceEngine, PackageRequirementsMixin):
25
  model_name: str
26
  max_new_tokens: int
@@ -158,9 +179,12 @@ class OpenAiInferenceEngineParams(Artifact):
158
  stop: Union[Optional[str], List[str]] = None
159
  temperature: Optional[float] = None
160
  top_p: Optional[float] = None
 
161
 
162
 
163
- class OpenAiInferenceEngine(InferenceEngine, PackageRequirementsMixin):
 
 
164
  label: str = "openai"
165
  model_name: str
166
  parameters: OpenAiInferenceEngineParams = field(
@@ -169,6 +193,7 @@ class OpenAiInferenceEngine(InferenceEngine, PackageRequirementsMixin):
169
  _requirement = {
170
  "openai": "Install openai package using 'pip install --upgrade openai"
171
  }
 
172
 
173
  def prepare(self):
174
  from openai import OpenAI
@@ -183,8 +208,38 @@ class OpenAiInferenceEngine(InferenceEngine, PackageRequirementsMixin):
183
  self.client = OpenAI(api_key=api_key)
184
 
185
  def _infer(self, dataset):
186
- return [
187
- self.client.chat.completions.create(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  messages=[
189
  # {
190
  # "role": "system",
@@ -203,6 +258,159 @@ class OpenAiInferenceEngine(InferenceEngine, PackageRequirementsMixin):
203
  stop=self.parameters.stop,
204
  temperature=self.parameters.temperature,
205
  top_p=self.parameters.top_p,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  )
207
  for instance in dataset
208
  ]
 
3
  from dataclasses import field
4
  from typing import Any, Dict, List, Literal, Optional, Union
5
 
6
+ from tqdm import tqdm
7
+
8
  from .artifact import Artifact
9
  from .operator import PackageRequirementsMixin
10
 
 
17
  """Perform inference on the input dataset."""
18
  pass
19
 
20
+ def infer(self, dataset) -> str:
21
  """Verifies instances of a dataset and performs inference."""
22
  [self.verify_instance(instance) for instance in dataset]
23
  return self._infer(dataset)
24
 
25
 
26
+ class LogProbInferenceEngine(abc.ABC, Artifact):
27
+ """Abstract base class for inference with log probs."""
28
+
29
+ @abc.abstractmethod
30
+ def _infer_log_probs(self, dataset):
31
+ """Perform inference on the input dataset that returns log probs."""
32
+ pass
33
+
34
+ def infer_log_probs(self, dataset) -> List[Dict]:
35
+ """Verifies instances of a dataset and performs inference that returns log probabilities of top tokens.
36
+
37
+ For each instance , returns a list of top tokens per position.
38
+ [ "top_tokens": [ { "text": ..., "logprob": ...} , ... ]
39
+
40
+ """
41
+ [self.verify_instance(instance) for instance in dataset]
42
+ return self._infer_log_probs(dataset)
43
+
44
+
45
  class HFPipelineBasedInferenceEngine(InferenceEngine, PackageRequirementsMixin):
46
  model_name: str
47
  max_new_tokens: int
 
179
  stop: Union[Optional[str], List[str]] = None
180
  temperature: Optional[float] = None
181
  top_p: Optional[float] = None
182
+ top_logprobs: Optional[int] = 20
183
 
184
 
185
+ class OpenAiInferenceEngine(
186
+ InferenceEngine, LogProbInferenceEngine, PackageRequirementsMixin
187
+ ):
188
  label: str = "openai"
189
  model_name: str
190
  parameters: OpenAiInferenceEngineParams = field(
 
193
  _requirement = {
194
  "openai": "Install openai package using 'pip install --upgrade openai"
195
  }
196
+ data_classification_policy = ["public"]
197
 
198
  def prepare(self):
199
  from openai import OpenAI
 
208
  self.client = OpenAI(api_key=api_key)
209
 
210
  def _infer(self, dataset):
211
+ outputs = []
212
+ for instance in tqdm(dataset, desc="Inferring with openAI API"):
213
+ response = self.client.chat.completions.create(
214
+ messages=[
215
+ # {
216
+ # "role": "system",
217
+ # "content": self.system_prompt,
218
+ # },
219
+ {
220
+ "role": "user",
221
+ "content": instance["source"],
222
+ }
223
+ ],
224
+ model=self.model_name,
225
+ frequency_penalty=self.parameters.frequency_penalty,
226
+ presence_penalty=self.parameters.presence_penalty,
227
+ max_tokens=self.parameters.max_tokens,
228
+ seed=self.parameters.seed,
229
+ stop=self.parameters.stop,
230
+ temperature=self.parameters.temperature,
231
+ top_p=self.parameters.top_p,
232
+ )
233
+ output = response.choices[0].message.content
234
+
235
+ outputs.append(output)
236
+
237
+ return outputs
238
+
239
+ def _infer_log_probs(self, dataset):
240
+ outputs = []
241
+ for instance in tqdm(dataset, desc="Inferring with openAI API"):
242
+ response = self.client.chat.completions.create(
243
  messages=[
244
  # {
245
  # "role": "system",
 
258
  stop=self.parameters.stop,
259
  temperature=self.parameters.temperature,
260
  top_p=self.parameters.top_p,
261
+ logprobs=True,
262
+ top_logprobs=self.parameters.top_logprobs,
263
+ )
264
+ top_logprobs_response = response.choices[0].logprobs.content
265
+ output = [
266
+ {
267
+ "top_tokens": [
268
+ {"text": obj.token, "logprob": obj.logprob}
269
+ for obj in generated_token.top_logprobs
270
+ ]
271
+ }
272
+ for generated_token in top_logprobs_response
273
+ ]
274
+ outputs.append(output)
275
+ return outputs
276
+
277
+
278
+ class WMLInferenceEngineParams(Artifact):
279
+ decoding_method: Optional[Literal["greedy", "sample"]] = None
280
+ length_penalty: Optional[Dict[str, Union[int, float]]] = None
281
+ temperature: Optional[float] = None
282
+ top_p: Optional[float] = None
283
+ top_k: Optional[int] = None
284
+ random_seed: Optional[int] = None
285
+ repetition_penalty: Optional[float] = None
286
+ min_new_tokens: Optional[int] = None
287
+ max_new_tokens: Optional[int] = None
288
+ stop_sequences: Optional[List[str]] = None
289
+ time_limit: Optional[int] = None
290
+ truncate_input_tokens: Optional[int] = None
291
+ prompt_variables: Optional[Dict[str, Any]] = None
292
+ return_options: Optional[Dict[str, bool]] = None
293
+
294
+ def initialize_wml_parameters(self) -> Dict[str, Any]:
295
+ from ibm_watsonx_ai.metanames import GenTextParamsMetaNames
296
+
297
+ return {
298
+ param_name.upper(): param_value
299
+ for param_name, param_value in self.to_dict().items()
300
+ if param_value and param_name.upper() in GenTextParamsMetaNames().get()
301
+ }
302
+
303
+
304
+ class WMLInferenceEngine(InferenceEngine, PackageRequirementsMixin):
305
+ """Runs inference using ibm-watsonx-ai.
306
+
307
+ Attributes:
308
+ client: By default, it is created by a class instance but can be directly
309
+ provided instead as an instance of 'ibm_watsonx_ai.client.APIClient'.
310
+ credentials: By default, it is created by a class instance which tries to retrieve
311
+ proper environment variables ("WML_URL", "WML_PROJECT_ID", "WML_APIKEY").
312
+ However, either a dictionary with the following keys: "url", "apikey",
313
+ "project_id", or an instance of 'ibm_watsonx_ai.credentials.Credentials'
314
+ can be directly provided instead.
315
+ model_name (str, optional): ID of a model to be used for inference. Mutually
316
+ exclusive with 'deployment_id'.
317
+ deployment_id (str, optional): Deployment ID of a tuned model to be used for
318
+ inference. Mutually exclusive with 'model_name'.
319
+ parameters (WMLInferenceEngineParams): An instance of 'WMLInferenceEngineParams'
320
+ which defines parameters used for inference. All the parameters are optional.
321
+
322
+ Examples:
323
+ from .api import load_dataset
324
+
325
+ wml_parameters = WMLInferenceEngineParams(top_p=0.5, random_seed=123)
326
+ wml_credentials = {
327
+ "url": "some_url", "project_id": "some_id", "api_key": "some_key"
328
+ }
329
+ model_name = "google/flan-t5-xxl"
330
+ wml_inference = WMLInferenceEngine(
331
+ credentials=wml_credentials,
332
+ parameters=wml_parameters,
333
+ model_name=model_name,
334
+ )
335
+
336
+ dataset = load_dataset(
337
+ dataset_query="card=cards.argument_topic,template_card_index=0,loader_limit=5"
338
+ )
339
+ results = wml_inference.infer(dataset["test"])
340
+ """
341
+
342
+ client = None
343
+ credentials = None
344
+ model_name: Optional[str] = None
345
+ deployment_id: Optional[str] = None
346
+ parameters: WMLInferenceEngineParams = field(
347
+ default_factory=WMLInferenceEngineParams
348
+ )
349
+
350
+ _parameters: Dict[str, Any] = field(default_factory=dict)
351
+
352
+ label: str = "wml"
353
+ _requirement = {
354
+ "ibm-watsonx-ai": "Install ibm-watsonx-ai package using 'pip install --upgrade ibm-watsonx-ai'. "
355
+ "It is advised to have Python version >=3.10 installed, as at lower version this package "
356
+ "may cause conflicts with other installed packages."
357
+ }
358
+
359
+ data_classification_policy = ["proprietary"]
360
+
361
+ @staticmethod
362
+ def _read_wml_credentials_from_env() -> Dict[str, str]:
363
+ credentials = {}
364
+ for env_var_name in ["WML_URL", "WML_PROJECT_ID", "WML_APIKEY"]:
365
+ env_var = os.environ.get(env_var_name)
366
+ assert env_var, (
367
+ f"Error while trying to run 'WMLInferenceEngine'. "
368
+ f"Please set the env variable: '{env_var_name}', or "
369
+ f"directly provide an instance of ibm-watsonx-ai 'Credentials' "
370
+ f"to the engine."
371
+ )
372
+
373
+ name = env_var_name.lower().replace("wml_", "")
374
+ credentials[name] = env_var
375
+
376
+ return credentials
377
+
378
+ def _initialize_wml_client(self):
379
+ from ibm_watsonx_ai.client import APIClient
380
+
381
+ if self.credentials is None:
382
+ self.credentials = self._read_wml_credentials_from_env()
383
+
384
+ client = APIClient(credentials=self.credentials)
385
+ client.set.default_project(self.credentials["project_id"])
386
+ return client
387
+
388
+ def prepare(self):
389
+ if self.client is None:
390
+ self.client = self._initialize_wml_client()
391
+ self._parameters = self.parameters.initialize_wml_parameters()
392
+
393
+ def verify(self):
394
+ assert (
395
+ self.model_name
396
+ or self.deployment_id
397
+ and not (self.model_name and self.deployment_id)
398
+ ), "Either 'model_name' or 'deployment_id' must be specified, but not both at the same time."
399
+ super().verify()
400
+
401
+ def _infer(self, dataset):
402
+ from ibm_watsonx_ai.foundation_models import ModelInference
403
+
404
+ model = ModelInference(
405
+ model_id=self.model_name,
406
+ deployment_id=self.deployment_id,
407
+ api_client=self.client,
408
+ )
409
+
410
+ return [
411
+ model.generate_text(
412
+ prompt=instance["source"],
413
+ params=self._parameters,
414
  )
415
  for instance in dataset
416
  ]
llm_as_judge.py CHANGED
@@ -1,6 +1,7 @@
1
  from typing import Any, Dict, List, Literal, Optional
2
 
3
  from .api import evaluate, produce
 
4
  from .inference import InferenceEngine, OpenAiInferenceEngine
5
  from .metrics import BulkInstanceMetric
6
  from .operator import SequentialOperator
@@ -121,22 +122,25 @@ class LLMAsJudge(BulkInstanceMetric):
121
  )
122
 
123
  card = f"cards.dynamic_cards_for_llm_judges.{self.task}"
124
- recipe = (
125
- f"card={card},"
126
- f"template={self.template},"
127
- "demos_pool_size=0,"
128
- "num_demos=0"
129
- )
 
130
  if self.system_prompt:
131
- recipe = f"{recipe},system_prompt={self.system_prompt}"
132
  if self.format:
133
- recipe = f"{recipe},format={self.format}"
134
-
135
  dataset = produce(instances, recipe)
136
  verdicts = self.inference_model.infer(dataset)
137
  meta_scores = evaluate(predictions=verdicts, data=dataset)
138
  return [
139
- {self.main_score: instance["prediction"], "judge_raw_output": verdict}
140
- for instance in meta_scores
141
- for verdict in verdicts
 
 
142
  ]
 
1
  from typing import Any, Dict, List, Literal, Optional
2
 
3
  from .api import evaluate, produce
4
+ from .artifact import Artifact, settings
5
  from .inference import InferenceEngine, OpenAiInferenceEngine
6
  from .metrics import BulkInstanceMetric
7
  from .operator import SequentialOperator
 
122
  )
123
 
124
  card = f"cards.dynamic_cards_for_llm_judges.{self.task}"
125
+ recipe_args = {
126
+ "card": card,
127
+ "template": self.template,
128
+ "demos_pool_size": 0,
129
+ "num_demos": 0,
130
+ "__type__": settings.default_recipe,
131
+ }
132
  if self.system_prompt:
133
+ recipe_args["system_prompt"] = self.system_prompt
134
  if self.format:
135
+ recipe_args["format"] = self.format
136
+ recipe = Artifact.from_dict(recipe_args)
137
  dataset = produce(instances, recipe)
138
  verdicts = self.inference_model.infer(dataset)
139
  meta_scores = evaluate(predictions=verdicts, data=dataset)
140
  return [
141
+ {
142
+ self.main_score: instance["processed_prediction"],
143
+ "judge_raw_output": verdict,
144
+ }
145
+ for instance, verdict in zip(meta_scores, verdicts)
146
  ]
loaders.py CHANGED
@@ -30,6 +30,7 @@ Available Loaders Overview:
30
 
31
  ------------------------
32
  """
 
33
  import fnmatch
34
  import itertools
35
  import os
@@ -49,9 +50,10 @@ from .dataclass import InternalField, OptionalField
49
  from .fusion import FixedFusion
50
  from .logging_utils import get_logger
51
  from .operator import SourceOperator
52
- from .operators import AddFields
53
  from .settings_utils import get_settings
54
  from .stream import DynamicStream, MultiStream
 
55
 
56
  logger = get_logger()
57
  settings = get_settings()
@@ -110,7 +112,7 @@ class Loader(SourceOperator):
110
  f"data_classification_policy =['confidential','pii'])\n"
111
  )
112
 
113
- operator = AddFields(
114
  fields={"data_classification_policy": self.data_classification_policy}
115
  )
116
  return operator(multi_stream)
@@ -176,14 +178,13 @@ class LoadHF(Loader):
176
  self._requirements_list.append(requirement)
177
  super().verify()
178
 
179
- def filtered_load(self, dataset):
 
 
 
 
180
  logger.info(f"\nLoading filtered by: {self.filtering_lambda};")
181
- return MultiStream(
182
- {
183
- name: dataset[name].filter(eval(self.filtering_lambda))
184
- for name in dataset
185
- }
186
- )
187
 
188
  def stream_dataset(self):
189
  if self._cache is None:
@@ -206,16 +207,17 @@ class LoadHF(Loader):
206
  ) from e
207
  raise e
208
 
209
- if self.filtering_lambda is not None:
210
- dataset = self.filtered_load(dataset)
211
-
212
  if self.split is not None:
213
  dataset = {self.split: dataset}
214
 
215
  self._cache = dataset
 
216
  else:
217
  dataset = self._cache
218
 
 
 
 
219
  return dataset
220
 
221
  def load_dataset(self):
@@ -239,9 +241,6 @@ class LoadHF(Loader):
239
  f"{self.__class__.__name__} cannot run remote code from huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE."
240
  ) from e
241
 
242
- if self.filtering_lambda is not None:
243
- dataset = self.filtered_load(dataset)
244
-
245
  if self.split is None:
246
  for split in dataset.keys():
247
  dataset[split] = dataset[split].to_iterable_dataset()
@@ -249,20 +248,25 @@ class LoadHF(Loader):
249
  dataset = {self.split: dataset}
250
 
251
  self._cache = dataset
 
252
  else:
253
  dataset = self._cache
254
 
 
 
 
255
  return dataset
256
 
257
- def split_limited_load(self, split_name):
258
- yield from itertools.islice(self._cache[split_name], self.get_limit())
259
 
260
- def limited_load(self):
261
  self.log_limited_loading()
262
  return MultiStream(
263
  {
264
  name: DynamicStream(
265
- generator=self.split_limited_load, gen_kwargs={"split_name": name}
 
266
  )
267
  for name in self._cache.keys()
268
  }
@@ -285,7 +289,7 @@ class LoadHF(Loader):
285
  dataset = self.load_dataset()
286
 
287
  if self.get_limit() is not None:
288
- return self.limited_load()
289
 
290
  return MultiStream.from_iterables(dataset)
291
 
@@ -616,7 +620,7 @@ class LoadFromIBMCloud(Loader):
616
  object_key,
617
  local_dir + "/" + os.path.basename(temp_file.name),
618
  )
619
- os.rename(
620
  local_dir + "/" + os.path.basename(temp_file.name),
621
  local_dir + "/" + data_file,
622
  )
@@ -692,6 +696,25 @@ class LoadFromDictionary(Loader):
692
 
693
  data: Dict[str, List[Dict[str, Any]]]
694
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
695
  def load_data(self) -> MultiStream:
696
  self.sef_default_data_classification(
697
  ["proprietary"], "when loading from python dictionary"
 
30
 
31
  ------------------------
32
  """
33
+
34
  import fnmatch
35
  import itertools
36
  import os
 
50
  from .fusion import FixedFusion
51
  from .logging_utils import get_logger
52
  from .operator import SourceOperator
53
+ from .operators import Set
54
  from .settings_utils import get_settings
55
  from .stream import DynamicStream, MultiStream
56
+ from .type_utils import isoftype
57
 
58
  logger = get_logger()
59
  settings = get_settings()
 
112
  f"data_classification_policy =['confidential','pii'])\n"
113
  )
114
 
115
+ operator = Set(
116
  fields={"data_classification_policy": self.data_classification_policy}
117
  )
118
  return operator(multi_stream)
 
178
  self._requirements_list.append(requirement)
179
  super().verify()
180
 
181
+ def filter_load(self, dataset):
182
+ if not settings.allow_unverified_code:
183
+ raise ValueError(
184
+ f"{self.__class__.__name__} cannot run use filtering_lambda expression without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE."
185
+ )
186
  logger.info(f"\nLoading filtered by: {self.filtering_lambda};")
187
+ return dataset.filter(eval(self.filtering_lambda))
 
 
 
 
 
188
 
189
  def stream_dataset(self):
190
  if self._cache is None:
 
207
  ) from e
208
  raise e
209
 
 
 
 
210
  if self.split is not None:
211
  dataset = {self.split: dataset}
212
 
213
  self._cache = dataset
214
+
215
  else:
216
  dataset = self._cache
217
 
218
+ if self.filtering_lambda is not None:
219
+ dataset = self.filter_load(dataset)
220
+
221
  return dataset
222
 
223
  def load_dataset(self):
 
241
  f"{self.__class__.__name__} cannot run remote code from huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE."
242
  ) from e
243
 
 
 
 
244
  if self.split is None:
245
  for split in dataset.keys():
246
  dataset[split] = dataset[split].to_iterable_dataset()
 
248
  dataset = {self.split: dataset}
249
 
250
  self._cache = dataset
251
+
252
  else:
253
  dataset = self._cache
254
 
255
+ if self.filtering_lambda is not None:
256
+ dataset = self.filter_load(dataset)
257
+
258
  return dataset
259
 
260
+ def split_limited_load(self, dataset, split_name):
261
+ yield from itertools.islice(dataset[split_name], self.get_limit())
262
 
263
+ def limited_load(self, dataset):
264
  self.log_limited_loading()
265
  return MultiStream(
266
  {
267
  name: DynamicStream(
268
+ generator=self.split_limited_load,
269
+ gen_kwargs={"dataset": dataset, "split_name": name},
270
  )
271
  for name in self._cache.keys()
272
  }
 
289
  dataset = self.load_dataset()
290
 
291
  if self.get_limit() is not None:
292
+ return self.limited_load(dataset=dataset)
293
 
294
  return MultiStream.from_iterables(dataset)
295
 
 
620
  object_key,
621
  local_dir + "/" + os.path.basename(temp_file.name),
622
  )
623
+ os.renames(
624
  local_dir + "/" + os.path.basename(temp_file.name),
625
  local_dir + "/" + data_file,
626
  )
 
696
 
697
  data: Dict[str, List[Dict[str, Any]]]
698
 
699
+ def verify(self):
700
+ super().verify()
701
+ if not isoftype(self.data, Dict[str, List[Dict[str, Any]]]):
702
+ raise ValueError(
703
+ f"Passed data to LoadFromDictionary is not of type Dict[str, List[Dict[str, Any]]].\n"
704
+ f"Expected data should map between split name and list of instances.\n"
705
+ f"Received value: {self.data}\n"
706
+ )
707
+ for split in self.data.keys():
708
+ if len(self.data[split]) == 0:
709
+ raise ValueError(f"Split {split} has no instances.")
710
+ first_instance = self.data[split][0]
711
+ for instance in self.data[split]:
712
+ if instance.keys() != first_instance.keys():
713
+ raise ValueError(
714
+ f"Not all instances in split '{split}' have the same fields.\n"
715
+ f"instance {instance} has different fields different from {first_instance}"
716
+ )
717
+
718
  def load_data(self) -> MultiStream:
719
  self.sef_default_data_classification(
720
  ["proprietary"], "when loading from python dictionary"
metric.py CHANGED
@@ -19,13 +19,16 @@ from .file_utils import __file__ as _
19
  from .formats import __file__ as _
20
  from .fusion import __file__ as _
21
  from .generator_utils import __file__ as _
 
22
  from .hf_utils import verify_versions_compatibility
23
  from .inference import __file__ as _
24
  from .instructions import __file__ as _
25
  from .llm_as_judge import __file__ as _
26
  from .loaders import __file__ as _
27
  from .logging_utils import __file__ as _
28
- from .metric_utils import UNITXT_METRIC_SCHEMA, _compute
 
 
29
  from .metrics import __file__ as _
30
  from .normalizers import __file__ as _
31
  from .operator import __file__ as _
@@ -36,6 +39,7 @@ from .random_utils import __file__ as _
36
  from .recipe import __file__ as _
37
  from .register import __file__ as _
38
  from .schema import __file__ as _
 
39
  from .settings_utils import get_constants
40
  from .span_lableing_operators import __file__ as _
41
  from .split_utils import __file__ as _
@@ -50,6 +54,7 @@ from .task import __file__ as _
50
  from .templates import __file__ as _
51
  from .text_utils import __file__ as _
52
  from .type_utils import __file__ as _
 
53
  from .utils import is_package_installed
54
  from .validate import __file__ as _
55
  from .version import __file__ as _
 
19
  from .formats import __file__ as _
20
  from .fusion import __file__ as _
21
  from .generator_utils import __file__ as _
22
+ from .hf_utils import __file__ as _
23
  from .hf_utils import verify_versions_compatibility
24
  from .inference import __file__ as _
25
  from .instructions import __file__ as _
26
  from .llm_as_judge import __file__ as _
27
  from .loaders import __file__ as _
28
  from .logging_utils import __file__ as _
29
+ from .metric_utils import UNITXT_METRIC_SCHEMA
30
+ from .metric_utils import __file__ as _
31
+ from .metric_utils import _compute
32
  from .metrics import __file__ as _
33
  from .normalizers import __file__ as _
34
  from .operator import __file__ as _
 
39
  from .recipe import __file__ as _
40
  from .register import __file__ as _
41
  from .schema import __file__ as _
42
+ from .settings_utils import __file__ as _
43
  from .settings_utils import get_constants
44
  from .span_lableing_operators import __file__ as _
45
  from .split_utils import __file__ as _
 
54
  from .templates import __file__ as _
55
  from .text_utils import __file__ as _
56
  from .type_utils import __file__ as _
57
+ from .utils import __file__ as _
58
  from .utils import is_package_installed
59
  from .validate import __file__ as _
60
  from .version import __file__ as _
metric_utils.py CHANGED
@@ -9,6 +9,7 @@ from .dataclass import Dataclass
9
  from .dict_utils import dict_set
10
  from .operator import (
11
  MultiStreamOperator,
 
12
  SequentialOperatorInitializer,
13
  StreamInitializerOperator,
14
  )
@@ -18,6 +19,7 @@ from .operators import (
18
  Copy,
19
  FlattenInstances,
20
  MergeStreams,
 
21
  SplitByNestedGroup,
22
  )
23
  from .register import _reset_env_local_catalogs, register_all_artifacts
@@ -145,6 +147,59 @@ class FromPredictionsAndOriginalData(StreamInitializerOperator):
145
  # When receiving instances from this scheme, the keys and values are returned as two separate
146
  # lists, and are converted to a dictionary.
147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
  class MetricRecipe(SequentialOperatorInitializer):
150
  calc_confidence_intervals: bool = True
@@ -155,13 +210,7 @@ class MetricRecipe(SequentialOperatorInitializer):
155
  self.steps = [
156
  FromPredictionsAndOriginalData(),
157
  LoadJson(field="task_data"),
158
- Copy(
159
- field="source",
160
- to_field="task_data/source",
161
- ),
162
- ApplyOperatorsField(
163
- operators_field="postprocessors",
164
- ),
165
  SplitByNestedGroup(
166
  field_name_of_group="group",
167
  number_of_fusion_generations=self.number_of_fusion_generations,
@@ -172,6 +221,18 @@ class MetricRecipe(SequentialOperatorInitializer):
172
  ),
173
  MultiStreamScoreMean(),
174
  MergeStreams(),
 
 
 
 
 
 
 
 
 
 
 
 
175
  ]
176
 
177
 
 
9
  from .dict_utils import dict_set
10
  from .operator import (
11
  MultiStreamOperator,
12
+ SequentialOperator,
13
  SequentialOperatorInitializer,
14
  StreamInitializerOperator,
15
  )
 
19
  Copy,
20
  FlattenInstances,
21
  MergeStreams,
22
+ RenameFields,
23
  SplitByNestedGroup,
24
  )
25
  from .register import _reset_env_local_catalogs, register_all_artifacts
 
147
  # When receiving instances from this scheme, the keys and values are returned as two separate
148
  # lists, and are converted to a dictionary.
149
 
150
+ _post_process_steps = SequentialOperator(
151
+ steps=[
152
+ Copy(
153
+ field="prediction",
154
+ to_field="raw_prediction",
155
+ ),
156
+ Copy(
157
+ field="references",
158
+ to_field="raw_references",
159
+ ),
160
+ Copy(
161
+ field="source",
162
+ to_field="task_data/source",
163
+ ),
164
+ ApplyOperatorsField(
165
+ operators_field="postprocessors",
166
+ ),
167
+ Copy(
168
+ field="prediction",
169
+ to_field="processed_prediction",
170
+ ),
171
+ Copy(
172
+ field="references",
173
+ to_field="processed_references",
174
+ ),
175
+ ]
176
+ )
177
+
178
+
179
+ class PostProcessRecipe(SequentialOperatorInitializer):
180
+ def prepare(self):
181
+ register_all_artifacts()
182
+ self.steps = [
183
+ FromPredictionsAndOriginalData(),
184
+ _post_process_steps,
185
+ ]
186
+
187
+
188
+ def _post_process(
189
+ predictions: List[str],
190
+ references: Iterable,
191
+ split_name: str = "all",
192
+ ):
193
+ _reset_env_local_catalogs()
194
+ register_all_artifacts()
195
+ recipe = PostProcessRecipe()
196
+
197
+ multi_stream = recipe(
198
+ predictions=predictions, references=references, split_name=split_name
199
+ )
200
+
201
+ return [instance["processed_prediction"] for instance in multi_stream[split_name]]
202
+
203
 
204
  class MetricRecipe(SequentialOperatorInitializer):
205
  calc_confidence_intervals: bool = True
 
210
  self.steps = [
211
  FromPredictionsAndOriginalData(),
212
  LoadJson(field="task_data"),
213
+ _post_process_steps,
 
 
 
 
 
 
214
  SplitByNestedGroup(
215
  field_name_of_group="group",
216
  number_of_fusion_generations=self.number_of_fusion_generations,
 
221
  ),
222
  MultiStreamScoreMean(),
223
  MergeStreams(),
224
+ RenameFields(
225
+ field="raw_prediction",
226
+ to_field="prediction",
227
+ ),
228
+ RenameFields(
229
+ field="raw_references",
230
+ to_field="references",
231
+ ),
232
+ Copy(
233
+ field="source",
234
+ to_field="task_data/source",
235
+ ),
236
  ]
237
 
238
 
metrics.py CHANGED
@@ -27,7 +27,7 @@ from .operator import (
27
  StreamingOperator,
28
  StreamOperator,
29
  )
30
- from .operators import CopyFields
31
  from .random_utils import get_seed
32
  from .settings_utils import get_settings
33
  from .stream import MultiStream, Stream
@@ -1123,7 +1123,7 @@ class MetricPipeline(MultiStreamOperator, Metric):
1123
 
1124
  def prepare(self):
1125
  super().prepare()
1126
- self.prepare_score = CopyFields(
1127
  field_to_field=[
1128
  [
1129
  f"score/instance/{self.metric._add_score_prefix(self.main_score)}",
@@ -2134,9 +2134,7 @@ class Detector(BulkInstanceMetric):
2134
  return self.pipe(predictions, batch_size=self.batch_size)
2135
 
2136
 
2137
- class LlamaIndexCorrectness(InstanceMetric):
2138
- """LlamaIndex based metric class for evaluating correctness."""
2139
-
2140
  model_name: str = ""
2141
  main_score: str = ""
2142
  prediction_type: str = "str"
@@ -2151,6 +2149,34 @@ class LlamaIndexCorrectness(InstanceMetric):
2151
 
2152
  _requirements_list: List[str] = ["llama_index"]
2153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2154
  @staticmethod
2155
  def _custom_parser(eval_response: str):
2156
  """Default parser function for evaluation response.
@@ -2174,37 +2200,14 @@ class LlamaIndexCorrectness(InstanceMetric):
2174
  reasoning = reasoning_str.lstrip("\n")
2175
  return score, reasoning
2176
 
2177
- def _model_using_extrnal_api(self):
2178
- return self.model_name in self.external_api_models
2179
-
2180
  def prepare(self):
2181
  """Initialization method for the metric. Initializes the CorrectnessEvaluator with the OpenAI model."""
2182
  super().prepare()
2183
 
2184
- self.model_name_normalized = self.model_name.replace(".", "_").replace("-", "_")
2185
- self.main_score: str = (
2186
- f"correctness_llama_index_by_{self.model_name_normalized}_judge"
2187
- )
2188
-
2189
- self.reduction_map: Dict[str, List[str]] = {"mean": [self.main_score]}
2190
-
2191
  from llama_index.core.evaluation import CorrectnessEvaluator
2192
 
2193
- if self.model_name in self.openai_models:
2194
- from llama_index.llms.openai import OpenAI
2195
-
2196
- llm = OpenAI("gpt-3.5-turbo")
2197
- elif self.model_name in self.mock_models:
2198
- from llama_index.core.llms.mock import MockLLM
2199
-
2200
- llm = MockLLM(system_prompt="5") # perfect score
2201
- else:
2202
- raise NotImplementedError(
2203
- f"LlamaIndexCorrectnessMetric does not support {self.model_name}, currently only gpt-3.5-turbo is supported"
2204
- )
2205
-
2206
  self.evaluator = CorrectnessEvaluator(
2207
- llm=llm, parser_function=self._custom_parser
2208
  )
2209
 
2210
  def compute(
@@ -2226,9 +2229,6 @@ class LlamaIndexCorrectness(InstanceMetric):
2226
  Raises:
2227
  AssertionError: If the input does not meet the expected format.
2228
  """
2229
- # treat the references as the questions and the predictions as answers
2230
- # assume a single reference
2231
-
2232
  query = task_data["question"]
2233
 
2234
  contexts = None
@@ -2247,11 +2247,36 @@ class LlamaIndexCorrectness(InstanceMetric):
2247
  )
2248
  result = max([results.score for results in per_reference_results])
2249
 
2250
- return {
2251
- self.main_score: result / 5,
2252
- # "score_name": self.main_score,
2253
- # "feedback": result.feedback, # removed since this cannot be tested
2254
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2255
 
2256
 
2257
  class Perplexity(BulkInstanceMetric):
 
27
  StreamingOperator,
28
  StreamOperator,
29
  )
30
+ from .operators import Copy
31
  from .random_utils import get_seed
32
  from .settings_utils import get_settings
33
  from .stream import MultiStream, Stream
 
1123
 
1124
  def prepare(self):
1125
  super().prepare()
1126
+ self.prepare_score = Copy(
1127
  field_to_field=[
1128
  [
1129
  f"score/instance/{self.metric._add_score_prefix(self.main_score)}",
 
2134
  return self.pipe(predictions, batch_size=self.batch_size)
2135
 
2136
 
2137
+ class LlamaIndexLLMMetric(InstanceMetric):
 
 
2138
  model_name: str = ""
2139
  main_score: str = ""
2140
  prediction_type: str = "str"
 
2149
 
2150
  _requirements_list: List[str] = ["llama_index"]
2151
 
2152
+ def prepare(self):
2153
+ self.model_name_normalized = self.model_name.replace(".", "_").replace("-", "_")
2154
+ self.main_score: str = f"llama_index_by_{self.model_name_normalized}_judge"
2155
+
2156
+ self.reduction_map: Dict[str, List[str]] = {"mean": [self.main_score]}
2157
+
2158
+ if self.model_name in self.openai_models:
2159
+ from llama_index.llms.openai import OpenAI
2160
+
2161
+ self.llm = OpenAI("gpt-3.5-turbo")
2162
+ elif self.model_name in self.mock_models:
2163
+ from llama_index.core.llms.mock import MockLLM
2164
+
2165
+ self.llm = MockLLM(system_prompt="5") # perfect score
2166
+ else:
2167
+ raise NotImplementedError(
2168
+ f"LlamaIndexLLM metric does not support {self.model_name}, currently only gpt-3.5-turbo is supported"
2169
+ )
2170
+
2171
+ def _model_using_extrnal_api(self):
2172
+ return self.model_name in self.external_api_models
2173
+
2174
+
2175
+ class LlamaIndexCorrectness(LlamaIndexLLMMetric):
2176
+ """LlamaIndex based metric class for evaluating correctness."""
2177
+
2178
+ score_prefix = "correctness_"
2179
+
2180
  @staticmethod
2181
  def _custom_parser(eval_response: str):
2182
  """Default parser function for evaluation response.
 
2200
  reasoning = reasoning_str.lstrip("\n")
2201
  return score, reasoning
2202
 
 
 
 
2203
  def prepare(self):
2204
  """Initialization method for the metric. Initializes the CorrectnessEvaluator with the OpenAI model."""
2205
  super().prepare()
2206
 
 
 
 
 
 
 
 
2207
  from llama_index.core.evaluation import CorrectnessEvaluator
2208
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2209
  self.evaluator = CorrectnessEvaluator(
2210
+ llm=self.llm, parser_function=self._custom_parser
2211
  )
2212
 
2213
  def compute(
 
2229
  Raises:
2230
  AssertionError: If the input does not meet the expected format.
2231
  """
 
 
 
2232
  query = task_data["question"]
2233
 
2234
  contexts = None
 
2247
  )
2248
  result = max([results.score for results in per_reference_results])
2249
 
2250
+ return {self.main_score: result / 5}
2251
+
2252
+
2253
+ class LlamaIndexFaithfulness(LlamaIndexLLMMetric):
2254
+ """LlamaIndex based metric class for evaluating faithfulness."""
2255
+
2256
+ score_prefix = "faithfulness_"
2257
+
2258
+ def prepare(self):
2259
+ """Initialization method for the metric. Initializes the FaithfulnessEvaluator with the OpenAI model."""
2260
+ super().prepare()
2261
+
2262
+ from llama_index.core.evaluation import FaithfulnessEvaluator
2263
+
2264
+ self.evaluator = FaithfulnessEvaluator(llm=self.llm)
2265
+
2266
+ def compute(
2267
+ self,
2268
+ references: List[str],
2269
+ prediction: str,
2270
+ task_data: Dict,
2271
+ ) -> Dict[str, Any]:
2272
+ result = self.evaluator.evaluate(
2273
+ query=task_data["question"],
2274
+ response=prediction,
2275
+ contexts=task_data["contexts"],
2276
+ )
2277
+ score = result.score
2278
+
2279
+ return {self.main_score: score}
2280
 
2281
 
2282
  class Perplexity(BulkInstanceMetric):
operator.py CHANGED
@@ -117,54 +117,6 @@ class SideEffectOperator(StreamingOperator):
117
  pass
118
 
119
 
120
- class SourceOperator(StreamingOperator):
121
- """A class representing a source operator in the streaming system.
122
-
123
- A source operator is responsible for generating the data stream from some source, such as a database or a file.
124
- This is the starting point of a stream processing pipeline.
125
- The `SourceOperator` class is a type of `SourceOperator`, which is a special type of `StreamingOperator`
126
- that generates an output stream but does not take any input streams.
127
-
128
- When called, a `SourceOperator` invokes its `process` method, which should be implemented by all subclasses
129
- to generate the required `MultiStream`.
130
-
131
- """
132
-
133
- caching: bool = NonPositionalField(default=None)
134
-
135
- def __call__(self, multi_stream: Optional[MultiStream] = None) -> MultiStream:
136
- multi_stream = self.process()
137
- if self.caching is not None:
138
- multi_stream.set_caching(self.caching)
139
- return multi_stream
140
-
141
- @abstractmethod
142
- def process(self) -> MultiStream:
143
- pass
144
-
145
-
146
- class StreamInitializerOperator(SourceOperator):
147
- """A class representing a stream initializer operator in the streaming system.
148
-
149
- A stream initializer operator is a special type of `SourceOperator` that is capable of taking parameters during the stream generation process. This can be useful in situations where the stream generation process needs to be customized or configured based on certain parameters.
150
-
151
- When called, a `StreamInitializerOperator` invokes its `process` method, passing any supplied arguments and keyword arguments. The `process` method should be implemented by all subclasses to generate the required `MultiStream` based on the given arguments and keyword arguments.
152
-
153
- """
154
-
155
- caching: bool = NonPositionalField(default=None)
156
-
157
- def __call__(self, *args, **kwargs) -> MultiStream:
158
- multi_stream = self.process(*args, **kwargs)
159
- if self.caching is not None:
160
- multi_stream.set_caching(self.caching)
161
- return self.process(*args, **kwargs)
162
-
163
- @abstractmethod
164
- def process(self, *args, **kwargs) -> MultiStream:
165
- pass
166
-
167
-
168
  def instance_generator(instance):
169
  yield instance
170
 
@@ -213,6 +165,55 @@ class MultiStreamOperator(StreamingOperator):
213
  return next(iter(processed_multi_stream[stream_name]))
214
 
215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  class StreamOperator(MultiStreamOperator):
217
  """A class representing a single-stream operator in the streaming system.
218
 
@@ -458,15 +459,8 @@ class InstanceOperatorWithMultiStreamAccess(StreamingOperator):
458
  pass
459
 
460
 
461
- class SequentialOperator(MultiStreamOperator):
462
- """A class representing a sequential operator in the streaming system.
463
-
464
- A sequential operator is a type of `MultiStreamOperator` that applies a sequence of other operators to a
465
- `MultiStream`. It maintains a list of `StreamingOperator`s and applies them in order to the `MultiStream`.
466
- """
467
-
468
- max_steps = None
469
-
470
  steps: List[StreamingOperator] = field(default_factory=list)
471
 
472
  def num_steps(self) -> int:
@@ -488,13 +482,21 @@ class SequentialOperator(MultiStreamOperator):
488
  def _get_max_steps(self):
489
  return self.max_steps if self.max_steps is not None else len(self.steps)
490
 
 
 
 
 
 
 
 
 
491
  def process(self, multi_stream: Optional[MultiStream] = None) -> MultiStream:
492
  for operator in self.steps[0 : self._get_max_steps()]:
493
  multi_stream = operator(multi_stream)
494
  return multi_stream
495
 
496
 
497
- class SourceSequentialOperator(SequentialOperator):
498
  """A class representing a source sequential operator in the streaming system.
499
 
500
  A source sequential operator is a type of `SequentialOperator` that starts with a source operator.
@@ -502,9 +504,6 @@ class SourceSequentialOperator(SequentialOperator):
502
  that the other operators then process.
503
  """
504
 
505
- def __call__(self) -> MultiStream:
506
- return super().__call__()
507
-
508
  def process(self, multi_stream: Optional[MultiStream] = None) -> MultiStream:
509
  assert (
510
  self.num_steps() > 0
 
117
  pass
118
 
119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  def instance_generator(instance):
121
  yield instance
122
 
 
165
  return next(iter(processed_multi_stream[stream_name]))
166
 
167
 
168
+ class SourceOperator(MultiStreamOperator):
169
+ """A class representing a source operator in the streaming system.
170
+
171
+ A source operator is responsible for generating the data stream from some source, such as a database or a file.
172
+ This is the starting point of a stream processing pipeline.
173
+ The `SourceOperator` class is a type of `SourceOperator`, which is a special type of `StreamingOperator`
174
+ that generates an output stream but does not take any input streams.
175
+
176
+ When called, a `SourceOperator` invokes its `process` method, which should be implemented by all subclasses
177
+ to generate the required `MultiStream`.
178
+
179
+ """
180
+
181
+ def _process_multi_stream(
182
+ self, multi_stream: Optional[MultiStream] = None
183
+ ) -> MultiStream:
184
+ result = self.process()
185
+ assert isinstance(
186
+ result, MultiStream
187
+ ), "MultiStreamOperator must return a MultiStream"
188
+ return result
189
+
190
+ @abstractmethod
191
+ def process(self) -> MultiStream:
192
+ pass
193
+
194
+
195
+ class StreamInitializerOperator(SourceOperator):
196
+ """A class representing a stream initializer operator in the streaming system.
197
+
198
+ A stream initializer operator is a special type of `SourceOperator` that is capable of taking parameters during the stream generation process. This can be useful in situations where the stream generation process needs to be customized or configured based on certain parameters.
199
+
200
+ When called, a `StreamInitializerOperator` invokes its `process` method, passing any supplied arguments and keyword arguments. The `process` method should be implemented by all subclasses to generate the required `MultiStream` based on the given arguments and keyword arguments.
201
+
202
+ """
203
+
204
+ caching: bool = NonPositionalField(default=None)
205
+
206
+ def __call__(self, *args, **kwargs) -> MultiStream:
207
+ multi_stream = self.process(*args, **kwargs)
208
+ if self.caching is not None:
209
+ multi_stream.set_caching(self.caching)
210
+ return self.process(*args, **kwargs)
211
+
212
+ @abstractmethod
213
+ def process(self, *args, **kwargs) -> MultiStream:
214
+ pass
215
+
216
+
217
  class StreamOperator(MultiStreamOperator):
218
  """A class representing a single-stream operator in the streaming system.
219
 
 
459
  pass
460
 
461
 
462
+ class SequentialMixin(Artifact):
463
+ max_steps: Optional[int] = None
 
 
 
 
 
 
 
464
  steps: List[StreamingOperator] = field(default_factory=list)
465
 
466
  def num_steps(self) -> int:
 
482
  def _get_max_steps(self):
483
  return self.max_steps if self.max_steps is not None else len(self.steps)
484
 
485
+
486
+ class SequentialOperator(MultiStreamOperator, SequentialMixin):
487
+ """A class representing a sequential operator in the streaming system.
488
+
489
+ A sequential operator is a type of `MultiStreamOperator` that applies a sequence of other operators to a
490
+ `MultiStream`. It maintains a list of `StreamingOperator`s and applies them in order to the `MultiStream`.
491
+ """
492
+
493
  def process(self, multi_stream: Optional[MultiStream] = None) -> MultiStream:
494
  for operator in self.steps[0 : self._get_max_steps()]:
495
  multi_stream = operator(multi_stream)
496
  return multi_stream
497
 
498
 
499
+ class SourceSequentialOperator(SourceOperator, SequentialMixin):
500
  """A class representing a source sequential operator in the streaming system.
501
 
502
  A source sequential operator is a type of `SequentialOperator` that starts with a source operator.
 
504
  that the other operators then process.
505
  """
506
 
 
 
 
507
  def process(self, multi_stream: Optional[MultiStream] = None) -> MultiStream:
508
  assert (
509
  self.num_steps() > 0
operators.py CHANGED
@@ -16,11 +16,17 @@ The primary task in any operator development is to implement the `process` funct
16
 
17
  General or Specelized Operators
18
  --------------------------------
19
- Some operators are specielized in specific task such as:
20
 
21
- - :class:`loaders<unitxt.loaders>` for loading data.
22
  - :class:`splitters<unitxt.splitters>` for fixing data splits.
 
23
  - :class:`struct_data_operators<unitxt.struct_data_operators>` for structured data operators.
 
 
 
 
 
24
 
25
  Other specelized operators are used by unitxt internally:
26
 
@@ -59,6 +65,7 @@ import requests
59
 
60
  from .artifact import Artifact, fetch_artifact
61
  from .dataclass import DeprecatedField, NonPositionalField, OptionalField
 
62
  from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
63
  from .operator import (
64
  InstanceOperator,
@@ -222,7 +229,7 @@ class FlattenInstances(InstanceOperator):
222
  return flatten_dict(instance, parent_key=self.parent_key, sep=self.sep)
223
 
224
 
225
- class AddFields(InstanceOperator):
226
  """Adds specified fields to each instance in a given stream or all streams (default) If fields exist, updates them.
227
 
228
  Args:
@@ -232,17 +239,17 @@ class AddFields(InstanceOperator):
232
 
233
  Examples:
234
  # Add a 'classes' field with a value of a list "positive" and "negative" to all streams
235
- AddFields(fields={"classes": ["positive","negatives"]})
236
 
237
  # Add a 'start' field under the 'span' field with a value of 0 to all streams
238
- AddFields(fields={"span/start": 0}
239
 
240
  # Add a 'classes' field with a value of a list "positive" and "negative" to 'train' stream
241
- AddFields(fields={"classes": ["positive","negatives"], apply_to_stream=["train"]})
242
 
243
  # Add a 'classes' field on a given list, prevent modification of original list
244
  # from changing the instance.
245
- AddFields(fields={"classes": alist}), use_deepcopy=True)
246
  # if now alist is modified, still the instances remain intact.
247
  """
248
 
@@ -265,6 +272,11 @@ class AddFields(InstanceOperator):
265
  return instance
266
 
267
 
 
 
 
 
 
268
  class RemoveFields(InstanceOperator):
269
  """Remove specified fields from each instance in a stream.
270
 
@@ -1049,6 +1061,7 @@ class Copy(FieldOperator):
1049
  return copy.deepcopy(value)
1050
 
1051
 
 
1052
  class CopyFields(Copy):
1053
  pass
1054
 
@@ -1392,7 +1405,8 @@ class ComputeExpressionMixin(Artifact):
1392
  return eval(self.expression, self.globals, instance)
1393
 
1394
  raise ValueError(
1395
- f"Cannot run expression by {self} when unitxt.settings.allow_unverified_code=False either set it to True or set {settings.allow_unverified_code_key} environment variable."
 
1396
  )
1397
 
1398
 
@@ -1556,7 +1570,7 @@ class ExtractMostCommonFieldValues(MultiStreamOperator):
1556
  for ele in values_and_counts
1557
  ]
1558
 
1559
- addmostcommons = AddFields(fields={self.to_field: values_to_keep})
1560
  return addmostcommons(multi_stream)
1561
 
1562
 
 
16
 
17
  General or Specelized Operators
18
  --------------------------------
19
+ Some operators are specielized in specific data or specific operations such as:
20
 
21
+ - :class:`loaders<unitxt.loaders>` for accessing data from various sources.
22
  - :class:`splitters<unitxt.splitters>` for fixing data splits.
23
+ - :class:`stream_operators<unitxt.stream_operators>` for changing joining and mixing streams.
24
  - :class:`struct_data_operators<unitxt.struct_data_operators>` for structured data operators.
25
+ - :class:`collections_operators<unitxt.collections_operators>` for handling collections such as lists and dictionaries.
26
+ - :class:`dialog_operators<unitxt.dialog_operators>` for handling dialogs.
27
+ - :class:`string_operators<unitxt.string_operators>` for handling strings.
28
+ - :class:`span_labeling_operators<unitxt.span_labeling_operators>` for handling strings.
29
+ - :class:`fusion<unitxt.fusion>` for fusing and mixing datasets.
30
 
31
  Other specelized operators are used by unitxt internally:
32
 
 
65
 
66
  from .artifact import Artifact, fetch_artifact
67
  from .dataclass import DeprecatedField, NonPositionalField, OptionalField
68
+ from .deprecation_utils import deprecation
69
  from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
70
  from .operator import (
71
  InstanceOperator,
 
229
  return flatten_dict(instance, parent_key=self.parent_key, sep=self.sep)
230
 
231
 
232
+ class Set(InstanceOperator):
233
  """Adds specified fields to each instance in a given stream or all streams (default) If fields exist, updates them.
234
 
235
  Args:
 
239
 
240
  Examples:
241
  # Add a 'classes' field with a value of a list "positive" and "negative" to all streams
242
+ Set(fields={"classes": ["positive","negatives"]})
243
 
244
  # Add a 'start' field under the 'span' field with a value of 0 to all streams
245
+ Set(fields={"span/start": 0}
246
 
247
  # Add a 'classes' field with a value of a list "positive" and "negative" to 'train' stream
248
+ Set(fields={"classes": ["positive","negatives"], apply_to_stream=["train"]})
249
 
250
  # Add a 'classes' field on a given list, prevent modification of original list
251
  # from changing the instance.
252
+ Set(fields={"classes": alist}), use_deepcopy=True)
253
  # if now alist is modified, still the instances remain intact.
254
  """
255
 
 
272
  return instance
273
 
274
 
275
+ @deprecation(version="1.11.0", alternative=Set)
276
+ class AddFields(Set):
277
+ pass
278
+
279
+
280
  class RemoveFields(InstanceOperator):
281
  """Remove specified fields from each instance in a stream.
282
 
 
1061
  return copy.deepcopy(value)
1062
 
1063
 
1064
+ @deprecation(version="1.11.0", alternative=Copy)
1065
  class CopyFields(Copy):
1066
  pass
1067
 
 
1405
  return eval(self.expression, self.globals, instance)
1406
 
1407
  raise ValueError(
1408
+ f"Cannot evaluate expression in {self} when unitxt.settings.allow_unverified_code=False - either set it to True or set {settings.allow_unverified_code_key} environment variable."
1409
+ "\nNote: If using test_card() with the default setting, increase loader_limit to avoid missing conditions due to limited data sampling."
1410
  )
1411
 
1412
 
 
1570
  for ele in values_and_counts
1571
  ]
1572
 
1573
+ addmostcommons = Set(fields={self.to_field: values_to_keep})
1574
  return addmostcommons(multi_stream)
1575
 
1576
 
processors.py CHANGED
@@ -33,6 +33,11 @@ class ToListByComma(SplitStrip):
33
  strip_every_element = True
34
 
35
 
 
 
 
 
 
36
  class RegexParser(FieldOperator):
37
  """A processor that uses regex in order to parse a string."""
38
 
 
33
  strip_every_element = True
34
 
35
 
36
+ class ToListByCommaSpace(SplitStrip):
37
+ delimiter = ", "
38
+ strip_every_element = True
39
+
40
+
41
  class RegexParser(FieldOperator):
42
  """A processor that uses regex in order to parse a string."""
43
 
settings_utils.py CHANGED
@@ -1,7 +1,6 @@
 
1
  import os
2
 
3
- import pkg_resources
4
-
5
  from .version import version
6
 
7
 
@@ -141,11 +140,11 @@ if Constants.is_uninitilized():
141
  constants.dataset_file = os.path.join(os.path.dirname(__file__), "dataset.py")
142
  constants.metric_file = os.path.join(os.path.dirname(__file__), "metric.py")
143
  constants.local_catalog_path = os.path.join(os.path.dirname(__file__), "catalog")
144
- try:
145
- constants.default_catalog_path = pkg_resources.resource_filename(
146
- "unitxt", "catalog"
147
- )
148
- except ModuleNotFoundError:
149
  constants.default_catalog_path = constants.local_catalog_path
150
  constants.catalog_dir = constants.local_catalog_path
151
  constants.dataset_url = "unitxt/data"
 
1
+ import importlib.util
2
  import os
3
 
 
 
4
  from .version import version
5
 
6
 
 
140
  constants.dataset_file = os.path.join(os.path.dirname(__file__), "dataset.py")
141
  constants.metric_file = os.path.join(os.path.dirname(__file__), "metric.py")
142
  constants.local_catalog_path = os.path.join(os.path.dirname(__file__), "catalog")
143
+ unitxt_pkg = importlib.util.find_spec("unitxt")
144
+ if unitxt_pkg and unitxt_pkg.origin:
145
+ unitxt_dir = os.path.dirname(unitxt_pkg.origin)
146
+ constants.default_catalog_path = os.path.join(unitxt_dir, "catalog")
147
+ else:
148
  constants.default_catalog_path = constants.local_catalog_path
149
  constants.catalog_dir = constants.local_catalog_path
150
  constants.dataset_url = "unitxt/data"
standard.py CHANGED
@@ -5,7 +5,7 @@ from .dataclass import Field, InternalField, NonPositionalField, OptionalField
5
  from .formats import Format, SystemFormat
6
  from .logging_utils import get_logger
7
  from .operator import SequentialOperator, SourceSequentialOperator, StreamingOperator
8
- from .operators import AddFields, Augmentor, NullAugmentor, StreamRefiner
9
  from .recipe import Recipe
10
  from .schema import ToUnitxtGroup
11
  from .splitters import Sampler, SeparateSplit, SpreadSplit
@@ -120,6 +120,12 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
120
  metrics = self.card.task.metrics
121
  else:
122
  metrics = self.metrics
 
 
 
 
 
 
123
  return metrics, postprocessors
124
 
125
  def set_pipelines(self):
@@ -220,7 +226,7 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
220
  self.loading.steps.append(StreamRefiner(max_instances=self.loader_limit))
221
 
222
  self.metadata.steps.append(
223
- AddFields(
224
  fields={
225
  "recipe_metadata": {
226
  "template": self.template,
 
5
  from .formats import Format, SystemFormat
6
  from .logging_utils import get_logger
7
  from .operator import SequentialOperator, SourceSequentialOperator, StreamingOperator
8
+ from .operators import Augmentor, NullAugmentor, Set, StreamRefiner
9
  from .recipe import Recipe
10
  from .schema import ToUnitxtGroup
11
  from .splitters import Sampler, SeparateSplit, SpreadSplit
 
120
  metrics = self.card.task.metrics
121
  else:
122
  metrics = self.metrics
123
+
124
+ metrics = [
125
+ metric if isinstance(metric, str) else metric.to_json()
126
+ for metric in metrics
127
+ ]
128
+
129
  return metrics, postprocessors
130
 
131
  def set_pipelines(self):
 
226
  self.loading.steps.append(StreamRefiner(max_instances=self.loader_limit))
227
 
228
  self.metadata.steps.append(
229
+ Set(
230
  fields={
231
  "recipe_metadata": {
232
  "template": self.template,
string_operators.py CHANGED
@@ -37,6 +37,27 @@ class TokensSplit(FieldOperator):
37
  return self.tokenizer.tokenize(value)
38
 
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  class Join(FieldOperator):
41
  by: str
42
 
 
37
  return self.tokenizer.tokenize(value)
38
 
39
 
40
+ class TokensSlice(FieldOperator):
41
+ model: str
42
+ start: Optional[int] = None
43
+ stop: Optional[int] = None
44
+ step: Optional[int] = None
45
+
46
+ _requirements_list = ["transformers"]
47
+
48
+ def prepare(self):
49
+ super().prepare()
50
+ from transformers import AutoTokenizer
51
+
52
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model)
53
+
54
+ def process_value(self, value: str) -> str:
55
+ encoded = self.tokenizer.encode(value)
56
+ slicer = slice(self.start, self.stop, self.step)
57
+ sliced = encoded[slicer]
58
+ return self.tokenizer.decode(sliced)
59
+
60
+
61
  class Join(FieldOperator):
62
  by: str
63
 
struct_data_operators.py CHANGED
@@ -14,6 +14,7 @@ For key-value pairs, expected input format is:
14
  {"key1": "value1", "key2": value2, "key3": "value3"}
15
  ------------------------
16
  """
 
17
  import json
18
  import random
19
  from abc import ABC, abstractmethod
 
14
  {"key1": "value1", "key2": value2, "key3": "value3"}
15
  ------------------------
16
  """
17
+
18
  import json
19
  import random
20
  from abc import ABC, abstractmethod
task.py CHANGED
@@ -27,6 +27,10 @@ class Task(InstanceOperator):
27
  prediction_type (Optional[str]):
28
  Need to be consistent with all used metrics. Defaults to None, which means that it will
29
  be set to Any.
 
 
 
 
30
 
31
  The output instance contains three fields:
32
  "inputs" whose value is a sub-dictionary of the input instance, consisting of all the fields listed in Arg 'inputs'.
@@ -39,6 +43,7 @@ class Task(InstanceOperator):
39
  metrics: List[str]
40
  prediction_type: Optional[str] = None
41
  augmentable_inputs: List[str] = []
 
42
 
43
  def verify(self):
44
  for io_type in ["inputs", "outputs"]:
@@ -72,6 +77,8 @@ class Task(InstanceOperator):
72
  augmentable_input in self.inputs
73
  ), f"augmentable_input {augmentable_input} is not part of {self.inputs}"
74
 
 
 
75
  @staticmethod
76
  @lru_cache(maxsize=None)
77
  def get_metric_prediction_type(metric_id: str):
@@ -99,9 +106,46 @@ class Task(InstanceOperator):
99
  f"metric's prediction type ({metric_prediction_type}) are different."
100
  )
101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  def process(
103
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
104
  ) -> Dict[str, Any]:
 
 
105
  verify_required_schema(self.inputs, instance)
106
  verify_required_schema(self.outputs, instance)
107
 
 
27
  prediction_type (Optional[str]):
28
  Need to be consistent with all used metrics. Defaults to None, which means that it will
29
  be set to Any.
30
+ defaults (Optional[Dict[str, Any]]):
31
+ An optional dictionary with default values for chosen input/output keys. Needs to be
32
+ consistent with names and types provided in 'inputs' and/or 'outputs' arguments.
33
+ Will not overwrite values if already provided in a given instance.
34
 
35
  The output instance contains three fields:
36
  "inputs" whose value is a sub-dictionary of the input instance, consisting of all the fields listed in Arg 'inputs'.
 
43
  metrics: List[str]
44
  prediction_type: Optional[str] = None
45
  augmentable_inputs: List[str] = []
46
+ defaults: Optional[Dict[str, Any]] = None
47
 
48
  def verify(self):
49
  for io_type in ["inputs", "outputs"]:
 
77
  augmentable_input in self.inputs
78
  ), f"augmentable_input {augmentable_input} is not part of {self.inputs}"
79
 
80
+ self.verify_defaults()
81
+
82
  @staticmethod
83
  @lru_cache(maxsize=None)
84
  def get_metric_prediction_type(metric_id: str):
 
106
  f"metric's prediction type ({metric_prediction_type}) are different."
107
  )
108
 
109
+ def verify_defaults(self):
110
+ if self.defaults:
111
+ if not isinstance(self.defaults, dict):
112
+ raise ValueError(
113
+ f"If specified, the 'defaults' must be a dictionary, "
114
+ f"however, '{self.defaults}' was provided instead, "
115
+ f"which is of type '{type(self.defaults)}'."
116
+ )
117
+
118
+ for default_name, default_value in self.defaults.items():
119
+ assert isinstance(default_name, str), (
120
+ f"If specified, all keys of the 'defaults' must be strings, "
121
+ f"however, the key '{default_name}' is of type '{type(default_name)}'."
122
+ )
123
+
124
+ val_type = self.inputs.get(default_name) or self.outputs.get(
125
+ default_name
126
+ )
127
+
128
+ assert val_type, (
129
+ f"If specified, all keys of the 'defaults' must refer to a chosen "
130
+ f"key in either 'inputs' or 'outputs'. However, the name '{default_name}' "
131
+ f"was provided which does not match any of the keys."
132
+ )
133
+
134
+ assert isoftype(default_value, parse_type_string(val_type)), (
135
+ f"The value of '{default_name}' from the 'defaults' must be of "
136
+ f"type '{val_type}', however, it is of type '{type(default_value)}'."
137
+ )
138
+
139
+ def set_default_values(self, instance: Dict[str, Any]) -> Dict[str, Any]:
140
+ if self.defaults:
141
+ instance = {**self.defaults, **instance}
142
+ return instance
143
+
144
  def process(
145
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
146
  ) -> Dict[str, Any]:
147
+ instance = self.set_default_values(instance)
148
+
149
  verify_required_schema(self.inputs, instance)
150
  verify_required_schema(self.outputs, instance)
151
 
text_utils.py CHANGED
@@ -69,7 +69,7 @@ def camel_to_snake_case(s):
69
  return s.lower()
70
 
71
 
72
- def construct_dict_str(d, indent=0, indent_delta=4, max_chars=None):
73
  """Constructs a formatted string of a dictionary.
74
 
75
  Args:
@@ -77,13 +77,21 @@ def construct_dict_str(d, indent=0, indent_delta=4, max_chars=None):
77
  indent (int, optional): The current level of indentation. Defaults to 0.
78
  indent_delta (int, optional): The amount of spaces to add for each level of indentation. Defaults to 4.
79
  max_chars (int, optional): The maximum number of characters for each line. Defaults to terminal width - 10.
 
80
  """
81
  max_chars = max_chars or shutil.get_terminal_size()[0] - 10
82
  indent_str = " " * indent
83
  indent_delta_str = " " * indent_delta
84
  res = ""
85
 
86
- for key, value in d.items():
 
 
 
 
 
 
 
87
  if isinstance(value, dict):
88
  res += f"{indent_str}{key}:\n"
89
  res += construct_dict_str(value, indent + indent_delta, max_chars=max_chars)
@@ -106,10 +114,12 @@ def construct_dict_str(d, indent=0, indent_delta=4, max_chars=None):
106
  return res
107
 
108
 
109
- def print_dict(d, indent=0, indent_delta=4, max_chars=None):
110
- dict_str = construct_dict_str(d, indent, indent_delta, max_chars)
 
 
111
  dict_str = "\n" + dict_str
112
- logger.info(dict_str)
113
 
114
 
115
  def nested_tuple_to_string(nested_tuple: tuple) -> str:
@@ -150,6 +160,7 @@ def is_made_of_sub_strings(string, sub_strings):
150
  # It also prepares for the case that __description__ tag does not contain balanced
151
  # parentheses, since it is often cut in the middle, (with "... see more at")
152
  # flake8: noqa: B007
 
153
  def lines_defining_obj_in_card(
154
  all_lines: List[str], obj_name: str, start_search_at_line: int = 0
155
  ) -> Tuple[int, int]:
@@ -165,25 +176,53 @@ def lines_defining_obj_in_card(
165
  ending_line = starting_line - 1
166
  while ending_line < len(all_lines):
167
  ending_line += 1
168
- num_of_opens += len(re.findall(r"[({[]", all_lines[ending_line]))
169
- num_of_closes += len(re.findall(r"[)}\]]", all_lines[ending_line]))
170
- if num_of_closes == num_of_opens:
171
- break
172
  if "__description__" in all_lines[ending_line]:
173
- # can not trust parentheses inside description.
174
- # trust the indentation enforced by ruff, and the way we build __description__:
 
175
  # a line consisting of only __description__=(
176
  # followed by one or more lines of text, can not trust opens and closes
177
  # in them, followed by a line consisting of only: ),
178
  # where the ) is indented with the beginning of __description__
 
 
 
 
179
  tag_indentation = all_lines[ending_line].index("__description__")
180
- last_line_to_start_with = (" " * tag_indentation) + ")"
181
- while not all_lines[ending_line].startswith(last_line_to_start_with):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  ending_line += 1
183
  if "__description__" in obj_name:
184
- return (starting_line, ending_line)
185
- num_of_closes += 1 # for this last line of desc
186
- # continue to the line following the end of description
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
  if num_of_closes != num_of_opens:
189
  raise ValueError(
 
69
  return s.lower()
70
 
71
 
72
+ def construct_dict_str(d, indent=0, indent_delta=4, max_chars=None, keys=None):
73
  """Constructs a formatted string of a dictionary.
74
 
75
  Args:
 
77
  indent (int, optional): The current level of indentation. Defaults to 0.
78
  indent_delta (int, optional): The amount of spaces to add for each level of indentation. Defaults to 4.
79
  max_chars (int, optional): The maximum number of characters for each line. Defaults to terminal width - 10.
80
+ keys (List[Str], optional): the list of fields to print
81
  """
82
  max_chars = max_chars or shutil.get_terminal_size()[0] - 10
83
  indent_str = " " * indent
84
  indent_delta_str = " " * indent_delta
85
  res = ""
86
 
87
+ if keys is None:
88
+ keys = d.keys()
89
+ for key in keys:
90
+ if key not in d.keys():
91
+ raise ValueError(
92
+ f"Dictionary does not contain field {key} specified in 'keys' argument. The available keys are {d.keys()}"
93
+ )
94
+ value = d[key]
95
  if isinstance(value, dict):
96
  res += f"{indent_str}{key}:\n"
97
  res += construct_dict_str(value, indent + indent_delta, max_chars=max_chars)
 
114
  return res
115
 
116
 
117
+ def print_dict(
118
+ d, indent=0, indent_delta=4, max_chars=None, keys_to_print=None, log_level="info"
119
+ ):
120
+ dict_str = construct_dict_str(d, indent, indent_delta, max_chars, keys_to_print)
121
  dict_str = "\n" + dict_str
122
+ getattr(logger, log_level)(dict_str)
123
 
124
 
125
  def nested_tuple_to_string(nested_tuple: tuple) -> str:
 
160
  # It also prepares for the case that __description__ tag does not contain balanced
161
  # parentheses, since it is often cut in the middle, (with "... see more at")
162
  # flake8: noqa: B007
163
+ # flake8: noqa: C901
164
  def lines_defining_obj_in_card(
165
  all_lines: List[str], obj_name: str, start_search_at_line: int = 0
166
  ) -> Tuple[int, int]:
 
176
  ending_line = starting_line - 1
177
  while ending_line < len(all_lines):
178
  ending_line += 1
179
+
 
 
 
180
  if "__description__" in all_lines[ending_line]:
181
+ # can not trust parentheses inside description, because this is mainly truncated
182
+ # free text.
183
+ # We do trust the indentation enforced by ruff, and the way we build __description__:
184
  # a line consisting of only __description__=(
185
  # followed by one or more lines of text, can not trust opens and closes
186
  # in them, followed by a line consisting of only: ),
187
  # where the ) is indented with the beginning of __description__
188
+ # We also prepare for the case that, when not entered by us, __description__=
189
+ # is not followed by a ( and the whole description does not end with a single ) in its line.
190
+ # We build on ruff making the line following the description start with same indentation
191
+ # or 4 less (i.e., the following line is the closing of the card).
192
  tag_indentation = all_lines[ending_line].index("__description__")
193
+ starts_with_parent = "__description__=(" in all_lines[ending_line]
194
+ if starts_with_parent:
195
+ last_line_to_start_with = (" " * tag_indentation) + r"\)"
196
+ else:
197
+ # actually, the line that follows the description
198
+ last_line_to_start_with1 = (" " * tag_indentation) + "[^ ]"
199
+ last_line_to_start_with2 = (" " * (tag_indentation - 4)) + "[^ ]"
200
+ last_line_to_start_with = (
201
+ "("
202
+ + last_line_to_start_with1
203
+ + "|"
204
+ + last_line_to_start_with2
205
+ + ")"
206
+ )
207
+ ending_line += 1
208
+ while not re.search("^" + last_line_to_start_with, all_lines[ending_line]):
209
  ending_line += 1
210
  if "__description__" in obj_name:
211
+ return (
212
+ starting_line,
213
+ ending_line if starts_with_parent else ending_line - 1,
214
+ )
215
+
216
+ if starts_with_parent:
217
+ ending_line += 1
218
+
219
+ # we conrinue in card, having passed the description, ending line points
220
+ # to the line that follows description
221
+
222
+ num_of_opens += len(re.findall(r"[({[]", all_lines[ending_line]))
223
+ num_of_closes += len(re.findall(r"[)}\]]", all_lines[ending_line]))
224
+ if num_of_closes == num_of_opens:
225
+ break
226
 
227
  if num_of_closes != num_of_opens:
228
  raise ValueError(
utils.py CHANGED
@@ -1,4 +1,6 @@
 
1
  import json
 
2
  from functools import lru_cache
3
  from typing import Any, Dict
4
 
@@ -47,13 +49,16 @@ def load_json(path):
47
  ) from e
48
 
49
 
50
- def save_json(path, data):
51
  with open(path, "w") as f:
52
- dumped = json.dumps(data, indent=4, ensure_ascii=False)
53
- f.write(dumped)
54
  f.write("\n")
55
 
56
 
 
 
 
 
57
  def is_package_installed(package_name):
58
  """Check if a package is installed.
59
 
@@ -113,3 +118,15 @@ def safe_eval(expression: str, context: dict, allowed_tokens: list) -> any:
113
  raise ValueError(
114
  f"The expression '{expression}' can not be evaluated because it contains tokens outside the allowed list of {allowed_sub_strings}."
115
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib.util
2
  import json
3
+ import os
4
  from functools import lru_cache
5
  from typing import Any, Dict
6
 
 
49
  ) from e
50
 
51
 
52
+ def save_to_file(path, data):
53
  with open(path, "w") as f:
54
+ f.write(data)
 
55
  f.write("\n")
56
 
57
 
58
+ def json_dump(data):
59
+ return json.dumps(data, indent=4, ensure_ascii=False)
60
+
61
+
62
  def is_package_installed(package_name):
63
  """Check if a package is installed.
64
 
 
118
  raise ValueError(
119
  f"The expression '{expression}' can not be evaluated because it contains tokens outside the allowed list of {allowed_sub_strings}."
120
  )
121
+
122
+
123
+ def import_module_from_file(file_path):
124
+ # Get the module name (file name without extension)
125
+ module_name = os.path.splitext(os.path.basename(file_path))[0]
126
+ # Create a module specification
127
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
128
+ # Create a new module based on the specification
129
+ module = importlib.util.module_from_spec(spec)
130
+ # Load the module
131
+ spec.loader.exec_module(module)
132
+ return module
version.py CHANGED
@@ -1 +1 @@
1
- version = "1.10.0"
 
1
+ version = "1.10.1"