Upload folder using huggingface_hub
Browse files
- README.md +28 -9
- api.py +74 -3
- artifact.py +52 -18
- blocks.py +2 -2
- catalog.py +1 -7
- dataset.py +7 -3
- dataset_utils.py +14 -6
- deprecation_utils.py +9 -5
- dialog_operators.py +1 -0
- fusion.py +6 -2
- inference.py +212 -4
- llm_as_judge.py +16 -12
- loaders.py +44 -21
- metric.py +6 -1
- metric_utils.py +68 -7
- metrics.py +62 -37
- operator.py +60 -61
- operators.py +23 -9
- processors.py +5 -0
- settings_utils.py +6 -7
- standard.py +8 -2
- string_operators.py +21 -0
- struct_data_operators.py +1 -0
- task.py +44 -0
- text_utils.py +55 -16
- utils.py +20 -3
- version.py +1 -1
README.md
CHANGED
@@ -11,11 +11,11 @@ pinned: false
|
|
11 |
<img src="https://raw.githubusercontent.com/IBM/unitxt/main/assets/banner.png" alt="Image Description" width="100%" />
|
12 |
</div>
|
13 |
|
14 |
-
[![Button](https://img.shields.io/badge/Video-pink?style=for-the-badge)](https://unitxt.readthedocs.io/)
|
|
|
15 |
[![Button](https://img.shields.io/badge/Demo-pink?style=for-the-badge)](https://unitxt.readthedocs.io/en/latest/docs/demo.html)
|
16 |
[![Button](https://img.shields.io/badge/Tutorial-pink?style=for-the-badge)](https://unitxt.readthedocs.io/en/latest/docs/adding_dataset.html)
|
17 |
[![Button](https://img.shields.io/badge/Paper-pink?style=for-the-badge)](https://arxiv.org/abs/2401.14019)
|
18 |
-
[![Button](https://img.shields.io/badge/Documentation-pink?style=for-the-badge)](https://unitxt.readthedocs.io/en/latest/modules.html)
|
19 |
[![Button](https://img.shields.io/badge/Catalog-pink?style=for-the-badge)](https://unitxt.readthedocs.io/en/latest/catalog/catalog.__dir__.html)
|
20 |
[![Button](https://img.shields.io/badge/Contributors-pink?style=for-the-badge)](https://github.com/IBM/unitxt/blob/main/CONTRIBUTING.md)
|
21 |
[![Button](https://img.shields.io/badge/PyPi-pink?style=for-the-badge)](https://pypi.org/project/unitxt/)
|
@@ -72,13 +72,32 @@ pre-commit install
|
|
72 |
If you use Unitxt in your research, please cite our paper:
|
73 |
|
74 |
```
|
75 |
-
@
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
}
|
83 |
```
|
84 |
|
|
|
|
11 |
<img src="https://raw.githubusercontent.com/IBM/unitxt/main/assets/banner.png" alt="Image Description" width="100%" />
|
12 |
</div>
|
13 |
|
14 |
+
[![Button](https://img.shields.io/badge/Video-pink?style=for-the-badge)](https://unitxt.readthedocs.io/en/latest/_static/video.mov)
|
15 |
+
[![Button](https://img.shields.io/badge/Documentation-pink?style=for-the-badge)](https://unitxt.readthedocs.io/en/latest/docs/introduction.html)
|
16 |
[![Button](https://img.shields.io/badge/Demo-pink?style=for-the-badge)](https://unitxt.readthedocs.io/en/latest/docs/demo.html)
|
17 |
[![Button](https://img.shields.io/badge/Tutorial-pink?style=for-the-badge)](https://unitxt.readthedocs.io/en/latest/docs/adding_dataset.html)
|
18 |
[![Button](https://img.shields.io/badge/Paper-pink?style=for-the-badge)](https://arxiv.org/abs/2401.14019)
|
|
|
19 |
[![Button](https://img.shields.io/badge/Catalog-pink?style=for-the-badge)](https://unitxt.readthedocs.io/en/latest/catalog/catalog.__dir__.html)
|
20 |
[![Button](https://img.shields.io/badge/Contributors-pink?style=for-the-badge)](https://github.com/IBM/unitxt/blob/main/CONTRIBUTING.md)
|
21 |
[![Button](https://img.shields.io/badge/PyPi-pink?style=for-the-badge)](https://pypi.org/project/unitxt/)
|
|
|
72 |
If you use Unitxt in your research, please cite our paper:
|
73 |
|
74 |
```
|
75 |
+
@inproceedings{bandel-etal-2024-unitxt,
|
76 |
+
title = "Unitxt: Flexible, Shareable and Reusable Data Preparation and Evaluation for Generative {AI}",
|
77 |
+
author = "Bandel, Elron and
|
78 |
+
Perlitz, Yotam and
|
79 |
+
Venezian, Elad and
|
80 |
+
Friedman, Roni and
|
81 |
+
Arviv, Ofir and
|
82 |
+
Orbach, Matan and
|
83 |
+
Don-Yehiya, Shachar and
|
84 |
+
Sheinwald, Dafna and
|
85 |
+
Gera, Ariel and
|
86 |
+
Choshen, Leshem and
|
87 |
+
Shmueli-Scheuer, Michal and
|
88 |
+
Katz, Yoav",
|
89 |
+
editor = "Chang, Kai-Wei and
|
90 |
+
Lee, Annie and
|
91 |
+
Rajani, Nazneen",
|
92 |
+
booktitle = "Proceedings of the 2024 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (Volume 3: System Demonstrations)",
|
93 |
+
month = jun,
|
94 |
+
year = "2024",
|
95 |
+
address = "Mexico City, Mexico",
|
96 |
+
publisher = "Association for Computational Linguistics",
|
97 |
+
url = "https://aclanthology.org/2024.naacl-demo.21",
|
98 |
+
pages = "207--215",
|
99 |
+
abstract = "In the dynamic landscape of generative NLP, traditional text processing pipelines limit research flexibility and reproducibility, as they are tailored to specific dataset, task, and model combinations. The escalating complexity, involving system prompts, model-specific formats, instructions, and more, calls for a shift to a structured, modular, and customizable solution. Addressing this need, we present Unitxt, an innovative library for customizable textual data preparation and evaluation tailored to generative language models. Unitxt natively integrates with common libraries like HuggingFace and LM-eval-harness and deconstructs processing flows into modular components, enabling easy customization and sharing between practitioners. These components encompass model-specific formats, task prompts, and many other comprehensive dataset processing definitions. The Unitxt Catalog centralizes these components, fostering collaboration and exploration in modern textual data workflows. Beyond being a tool, Unitxt is a community-driven platform, empowering users to build, share, and advance their pipelines collaboratively. Join the Unitxt community at https://github.com/IBM/unitxt",
|
100 |
}
|
101 |
```
|
102 |
|
103 |
+
Unitxt emoji designed by [OpenMoji](https://openmoji.org/#) - the open-source emoji and icon project. License: [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/#)
|
api.py
CHANGED
@@ -1,13 +1,14 @@
|
|
1 |
from functools import lru_cache
|
2 |
-
from typing import Any, Dict, List, Union
|
3 |
|
4 |
from datasets import DatasetDict
|
5 |
|
6 |
from .artifact import fetch_artifact
|
7 |
from .dataset_utils import get_dataset_artifact
|
8 |
from .logging_utils import get_logger
|
9 |
-
from .metric_utils import _compute
|
10 |
from .operator import SourceOperator
|
|
|
11 |
|
12 |
logger = get_logger()
|
13 |
|
@@ -21,16 +22,79 @@ def load(source: Union[SourceOperator, str]) -> DatasetDict:
|
|
21 |
return source().to_dataset()
|
22 |
|
23 |
|
24 |
-
def
|
25 |
dataset_query = dataset_query.replace("sys_prompt", "instruction")
|
26 |
dataset_stream = get_dataset_artifact(dataset_query)
|
27 |
return dataset_stream().to_dataset()
|
28 |
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
def evaluate(predictions, data) -> List[Dict[str, Any]]:
|
31 |
return _compute(predictions=predictions, references=data)
|
32 |
|
33 |
|
|
|
|
|
|
|
|
|
34 |
@lru_cache
|
35 |
def _get_produce_with_cache(recipe_query):
|
36 |
return get_dataset_artifact(recipe_query).produce
|
@@ -44,3 +108,10 @@ def produce(instance_or_instances, recipe_query):
|
|
44 |
if not is_list:
|
45 |
result = result[0]
|
46 |
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from functools import lru_cache
|
2 |
+
from typing import Any, Dict, List, Optional, Union
|
3 |
|
4 |
from datasets import DatasetDict
|
5 |
|
6 |
from .artifact import fetch_artifact
|
7 |
from .dataset_utils import get_dataset_artifact
|
8 |
from .logging_utils import get_logger
|
9 |
+
from .metric_utils import _compute, _post_process
|
10 |
from .operator import SourceOperator
|
11 |
+
from .standard import StandardRecipe
|
12 |
|
13 |
logger = get_logger()
|
14 |
|
|
|
22 |
return source().to_dataset()
|
23 |
|
24 |
|
25 |
+
def _load_dataset_from_query(dataset_query: str) -> DatasetDict:
|
26 |
dataset_query = dataset_query.replace("sys_prompt", "instruction")
|
27 |
dataset_stream = get_dataset_artifact(dataset_query)
|
28 |
return dataset_stream().to_dataset()
|
29 |
|
30 |
|
31 |
+
def _load_dataset_from_dict(dataset_params: Dict[str, Any]) -> DatasetDict:
|
32 |
+
recipe_attributes = list(StandardRecipe.__dict__["__fields__"].keys())
|
33 |
+
for param in dataset_params.keys():
|
34 |
+
assert param in recipe_attributes, (
|
35 |
+
f"The parameter '{param}' is not an attribute of the 'StandardRecipe' class. "
|
36 |
+
f"Please check if the name is correct. The available attributes are: '{recipe_attributes}'."
|
37 |
+
)
|
38 |
+
recipe = StandardRecipe(**dataset_params)
|
39 |
+
return recipe().to_dataset()
|
40 |
+
|
41 |
+
|
42 |
+
def load_dataset(dataset_query: Optional[str] = None, **kwargs) -> DatasetDict:
|
43 |
+
"""Loads dataset.
|
44 |
+
|
45 |
+
If the 'dataset_query' argument is provided, then dataset is loaded from a card in local
|
46 |
+
catalog based on parameters specified in the query.
|
47 |
+
Alternatively, dataset is loaded from a provided card based on explicitly given parameters.
|
48 |
+
|
49 |
+
Args:
|
50 |
+
dataset_query (str, optional): A string query which specifies dataset to load from local catalog.
|
51 |
+
For example:
|
52 |
+
"card=cards.wnli,template=templates.classification.multi_class.relation.default".
|
53 |
+
**kwargs: Arguments used to load dataset from provided card, which is not present in local catalog.
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
DatasetDict
|
57 |
+
|
58 |
+
Examples:
|
59 |
+
dataset = load_dataset(
|
60 |
+
dataset_query="card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5"
|
61 |
+
) # card must be present in local catalog
|
62 |
+
|
63 |
+
card = TaskCard(...)
|
64 |
+
template = Template(...)
|
65 |
+
loader_limit = 10
|
66 |
+
dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)
|
67 |
+
"""
|
68 |
+
if dataset_query and kwargs:
|
69 |
+
raise ValueError(
|
70 |
+
"Cannot provide 'dataset_query' and key-worded arguments at the same time. "
|
71 |
+
"If you want to load dataset from a card in local catalog, use query only. "
|
72 |
+
"Otherwise, use key-worded arguments only to specify properties of dataset."
|
73 |
+
)
|
74 |
+
|
75 |
+
if dataset_query:
|
76 |
+
if not isinstance(dataset_query, str):
|
77 |
+
raise ValueError(
|
78 |
+
f"If specified, 'dataset_query' must be a string, however, "
|
79 |
+
f"'{dataset_query}' was provided instead, which is of type "
|
80 |
+
f"'{type(dataset_query)}'."
|
81 |
+
)
|
82 |
+
return _load_dataset_from_query(dataset_query)
|
83 |
+
|
84 |
+
if kwargs:
|
85 |
+
return _load_dataset_from_dict(kwargs)
|
86 |
+
|
87 |
+
raise ValueError("Either 'dataset_query' or key-worded arguments must be provided.")
|
88 |
+
|
89 |
+
|
90 |
def evaluate(predictions, data) -> List[Dict[str, Any]]:
|
91 |
return _compute(predictions=predictions, references=data)
|
92 |
|
93 |
|
94 |
+
def post_process(predictions, data) -> List[Dict[str, Any]]:
|
95 |
+
return _post_process(predictions=predictions, references=data)
|
96 |
+
|
97 |
+
|
98 |
@lru_cache
|
99 |
def _get_produce_with_cache(recipe_query):
|
100 |
return get_dataset_artifact(recipe_query).produce
|
|
|
108 |
if not is_list:
|
109 |
result = result[0]
|
110 |
return result
|
111 |
+
|
112 |
+
|
113 |
+
def infer(instance_or_instances, recipe, engine):
|
114 |
+
dataset = produce(instance_or_instances, recipe)
|
115 |
+
engine, _ = fetch_artifact(engine)
|
116 |
+
predictions = engine.infer(dataset)
|
117 |
+
return post_process(predictions, dataset)
|
artifact.py
CHANGED
@@ -3,9 +3,10 @@ import inspect
|
|
3 |
import json
|
4 |
import os
|
5 |
import pkgutil
|
|
|
6 |
from abc import abstractmethod
|
7 |
from copy import deepcopy
|
8 |
-
from typing import Any, Dict, List, Optional, Union, final
|
9 |
|
10 |
from .dataclass import (
|
11 |
AbstractField,
|
@@ -19,13 +20,24 @@ from .logging_utils import get_logger
|
|
19 |
from .parsing_utils import (
|
20 |
separate_inside_and_outside_square_brackets,
|
21 |
)
|
22 |
-
from .settings_utils import get_settings
|
23 |
from .text_utils import camel_to_snake_case, is_camel_case
|
24 |
from .type_utils import issubtype
|
25 |
-
from .utils import artifacts_json_cache,
|
26 |
|
27 |
logger = get_logger()
|
28 |
settings = get_settings()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
|
31 |
class Artifactories:
|
@@ -120,7 +132,7 @@ class MissingArtifactTypeError(ValueError):
|
|
120 |
class Artifact(Dataclass):
|
121 |
_class_register = {}
|
122 |
|
123 |
-
|
124 |
__description__: str = NonPositionalField(
|
125 |
default=None, required=False, also_positional=False
|
126 |
)
|
@@ -135,7 +147,7 @@ class Artifact(Dataclass):
|
|
135 |
|
136 |
@classmethod
|
137 |
def is_artifact_dict(cls, d):
|
138 |
-
return isinstance(d, dict) and "
|
139 |
|
140 |
@classmethod
|
141 |
def verify_artifact_dict(cls, d):
|
@@ -143,10 +155,10 @@ class Artifact(Dataclass):
|
|
143 |
raise ValueError(
|
144 |
f"Artifact dict <{d}> must be of type 'dict', got '{type(d)}'."
|
145 |
)
|
146 |
-
if "
|
147 |
raise MissingArtifactTypeError(d)
|
148 |
-
if not cls.is_registered_type(d["
|
149 |
-
raise UnrecognizedArtifactTypeError(d["
|
150 |
|
151 |
@classmethod
|
152 |
def get_artifact_type(cls):
|
@@ -212,7 +224,7 @@ class Artifact(Dataclass):
|
|
212 |
pass
|
213 |
if cls.is_artifact_dict(obj):
|
214 |
cls.verify_artifact_dict(obj)
|
215 |
-
return cls._class_register[obj.pop("
|
216 |
|
217 |
return obj
|
218 |
|
@@ -261,7 +273,7 @@ class Artifact(Dataclass):
|
|
261 |
|
262 |
@final
|
263 |
def __post_init__(self):
|
264 |
-
self.
|
265 |
|
266 |
for field in fields(self):
|
267 |
if issubtype(
|
@@ -277,11 +289,24 @@ class Artifact(Dataclass):
|
|
277 |
self.verify()
|
278 |
|
279 |
def _to_raw_dict(self):
|
280 |
-
return {"
|
281 |
|
282 |
-
def
|
283 |
data = self.to_dict()
|
284 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
285 |
|
286 |
def verify_instance(
|
287 |
self, instance: Dict[str, Any], name: Optional[str] = None
|
@@ -404,13 +429,22 @@ class UnitxtArtifactNotFoundError(Exception):
|
|
404 |
return f"Artifact {self.name} does not exist, in artifactories:{self.artifactories}"
|
405 |
|
406 |
|
407 |
-
def fetch_artifact(
|
408 |
-
if Artifact
|
409 |
-
return
|
|
|
|
|
410 |
|
411 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
412 |
|
413 |
-
return
|
414 |
|
415 |
|
416 |
def get_artifactory_name_and_args(
|
|
|
3 |
import json
|
4 |
import os
|
5 |
import pkgutil
|
6 |
+
import re
|
7 |
from abc import abstractmethod
|
8 |
from copy import deepcopy
|
9 |
+
from typing import Any, Dict, List, Optional, Tuple, Union, final
|
10 |
|
11 |
from .dataclass import (
|
12 |
AbstractField,
|
|
|
20 |
from .parsing_utils import (
|
21 |
separate_inside_and_outside_square_brackets,
|
22 |
)
|
23 |
+
from .settings_utils import get_constants, get_settings
|
24 |
from .text_utils import camel_to_snake_case, is_camel_case
|
25 |
from .type_utils import issubtype
|
26 |
+
from .utils import artifacts_json_cache, json_dump, save_to_file
|
27 |
|
28 |
logger = get_logger()
|
29 |
settings = get_settings()
|
30 |
+
constants = get_constants()
|
31 |
+
|
32 |
+
|
33 |
+
def is_name_legal_for_catalog(name):
|
34 |
+
return re.match(r"^[\w" + constants.catalog_hierarchy_sep + "]+$", name)
|
35 |
+
|
36 |
+
|
37 |
+
def verify_legal_catalog_name(name):
|
38 |
+
assert is_name_legal_for_catalog(
|
39 |
+
name
|
40 |
+
), f'Artifict name ("{name}") should be alphanumeric. Use "." for nesting (e.g. myfolder.my_artifact)'
|
41 |
|
42 |
|
43 |
class Artifactories:
|
|
|
132 |
class Artifact(Dataclass):
|
133 |
_class_register = {}
|
134 |
|
135 |
+
__type__: str = Field(default=None, final=True, init=False)
|
136 |
__description__: str = NonPositionalField(
|
137 |
default=None, required=False, also_positional=False
|
138 |
)
|
|
|
147 |
|
148 |
@classmethod
|
149 |
def is_artifact_dict(cls, d):
|
150 |
+
return isinstance(d, dict) and "__type__" in d
|
151 |
|
152 |
@classmethod
|
153 |
def verify_artifact_dict(cls, d):
|
|
|
155 |
raise ValueError(
|
156 |
f"Artifact dict <{d}> must be of type 'dict', got '{type(d)}'."
|
157 |
)
|
158 |
+
if "__type__" not in d:
|
159 |
raise MissingArtifactTypeError(d)
|
160 |
+
if not cls.is_registered_type(d["__type__"]):
|
161 |
+
raise UnrecognizedArtifactTypeError(d["__type__"])
|
162 |
|
163 |
@classmethod
|
164 |
def get_artifact_type(cls):
|
|
|
224 |
pass
|
225 |
if cls.is_artifact_dict(obj):
|
226 |
cls.verify_artifact_dict(obj)
|
227 |
+
return cls._class_register[obj.pop("__type__")](**obj)
|
228 |
|
229 |
return obj
|
230 |
|
|
|
273 |
|
274 |
@final
|
275 |
def __post_init__(self):
|
276 |
+
self.__type__ = self.register_class(self.__class__)
|
277 |
|
278 |
for field in fields(self):
|
279 |
if issubtype(
|
|
|
289 |
self.verify()
|
290 |
|
291 |
def _to_raw_dict(self):
|
292 |
+
return {"__type__": self.__type__, **self._init_dict}
|
293 |
|
294 |
+
def to_json(self):
|
295 |
data = self.to_dict()
|
296 |
+
return json_dump(data)
|
297 |
+
|
298 |
+
def serialize(self):
|
299 |
+
if self.__id__ is not None:
|
300 |
+
return self.__id__
|
301 |
+
return self.to_json()
|
302 |
+
|
303 |
+
def save(self, path):
|
304 |
+
save_to_file(path, self.to_json())
|
305 |
+
|
306 |
+
@classmethod
|
307 |
+
def deserialize(cls, artifact_rep):
|
308 |
+
data = json.loads(artifact_rep)
|
309 |
+
return Artifact.from_dict(data)
|
310 |
|
311 |
def verify_instance(
|
312 |
self, instance: Dict[str, Any], name: Optional[str] = None
|
|
|
429 |
return f"Artifact {self.name} does not exist, in artifactories:{self.artifactories}"
|
430 |
|
431 |
|
432 |
+
def fetch_artifact(artifact_rep) -> Tuple[Artifact, Union[Artifactory, None]]:
|
433 |
+
if isinstance(artifact_rep, Artifact):
|
434 |
+
return artifact_rep, None
|
435 |
+
if Artifact.is_artifact_file(artifact_rep):
|
436 |
+
return Artifact.load(artifact_rep), None
|
437 |
|
438 |
+
name, _ = separate_inside_and_outside_square_brackets(artifact_rep)
|
439 |
+
if is_name_legal_for_catalog(name):
|
440 |
+
artifactory, artifact_rep, args = get_artifactory_name_and_args(
|
441 |
+
name=artifact_rep
|
442 |
+
)
|
443 |
+
return artifactory.get_with_overwrite(
|
444 |
+
artifact_rep, overwrite_args=args
|
445 |
+
), artifactory
|
446 |
|
447 |
+
return Artifact.deserialize(artifact_rep), None
|
448 |
|
449 |
|
450 |
def get_artifactory_name_and_args(
|
blocks.py
CHANGED
@@ -8,13 +8,13 @@ from .loaders import LoadFromIBMCloud, LoadFromKaggle, LoadHF
|
|
8 |
from .metrics import Accuracy
|
9 |
from .normalizers import NormalizeListFields
|
10 |
from .operators import (
|
11 |
-
AddFields,
|
12 |
AddID,
|
13 |
CastFields,
|
14 |
-
|
15 |
DivideAllFieldsBy,
|
16 |
MapInstanceValues,
|
17 |
RenameFields,
|
|
|
18 |
)
|
19 |
from .processors import ToString, ToStringStripped
|
20 |
from .recipe import SequentialRecipe
|
|
|
8 |
from .metrics import Accuracy
|
9 |
from .normalizers import NormalizeListFields
|
10 |
from .operators import (
|
|
|
11 |
AddID,
|
12 |
CastFields,
|
13 |
+
Copy,
|
14 |
DivideAllFieldsBy,
|
15 |
MapInstanceValues,
|
16 |
RenameFields,
|
17 |
+
Set,
|
18 |
)
|
19 |
from .processors import ToString, ToStringStripped
|
20 |
from .recipe import SequentialRecipe
|
catalog.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
import os
|
2 |
-
import re
|
3 |
from collections import Counter
|
4 |
from functools import lru_cache
|
5 |
from pathlib import Path
|
@@ -13,6 +12,7 @@ from .artifact import (
|
|
13 |
Artifactory,
|
14 |
get_artifactory_name_and_args,
|
15 |
reset_artifacts_json_cache,
|
|
|
16 |
)
|
17 |
from .logging_utils import get_logger
|
18 |
from .settings_utils import get_constants
|
@@ -114,12 +114,6 @@ class GithubCatalog(LocalCatalog):
|
|
114 |
return response.status_code == 200
|
115 |
|
116 |
|
117 |
-
def verify_legal_catalog_name(name):
|
118 |
-
assert re.match(
|
119 |
-
r"^[\w" + constants.catalog_hierarchy_sep + "]+$", name
|
120 |
-
), f'Artifict name ("{name}") should be alphanumeric. Use "." for nesting (e.g. myfolder.my_artifact)'
|
121 |
-
|
122 |
-
|
123 |
def add_to_catalog(
|
124 |
artifact: Artifact,
|
125 |
name: str,
|
|
|
1 |
import os
|
|
|
2 |
from collections import Counter
|
3 |
from functools import lru_cache
|
4 |
from pathlib import Path
|
|
|
12 |
Artifactory,
|
13 |
get_artifactory_name_and_args,
|
14 |
reset_artifacts_json_cache,
|
15 |
+
verify_legal_catalog_name,
|
16 |
)
|
17 |
from .logging_utils import get_logger
|
18 |
from .settings_utils import get_constants
|
|
|
114 |
return response.status_code == 200
|
115 |
|
116 |
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
def add_to_catalog(
|
118 |
artifact: Artifact,
|
119 |
name: str,
|
dataset.py
CHANGED
@@ -10,6 +10,7 @@ from .catalog import __file__ as _
|
|
10 |
from .collections import __file__ as _
|
11 |
from .collections_operators import __file__ as _
|
12 |
from .dataclass import __file__ as _
|
|
|
13 |
from .dataset_utils import get_dataset_artifact
|
14 |
from .deprecation_utils import __file__ as _
|
15 |
from .dialog_operators import __file__ as _
|
@@ -19,11 +20,13 @@ from .file_utils import __file__ as _
|
|
19 |
from .formats import __file__ as _
|
20 |
from .fusion import __file__ as _
|
21 |
from .generator_utils import __file__ as _
|
|
|
22 |
from .hf_utils import verify_versions_compatibility
|
23 |
from .inference import __file__ as _
|
24 |
from .instructions import __file__ as _
|
25 |
from .llm_as_judge import __file__ as _
|
26 |
from .loaders import __file__ as _
|
|
|
27 |
from .logging_utils import get_logger
|
28 |
from .metric import __file__ as _
|
29 |
from .metric_utils import __file__ as _
|
@@ -37,6 +40,7 @@ from .random_utils import __file__ as _
|
|
37 |
from .recipe import __file__ as _
|
38 |
from .register import __file__ as _
|
39 |
from .schema import __file__ as _
|
|
|
40 |
from .settings_utils import get_constants
|
41 |
from .span_lableing_operators import __file__ as _
|
42 |
from .split_utils import __file__ as _
|
@@ -51,6 +55,7 @@ from .task import __file__ as _
|
|
51 |
from .templates import __file__ as _
|
52 |
from .text_utils import __file__ as _
|
53 |
from .type_utils import __file__ as _
|
|
|
54 |
from .utils import is_package_installed
|
55 |
from .validate import __file__ as _
|
56 |
from .version import __file__ as _
|
@@ -71,9 +76,8 @@ class Dataset(datasets.GeneratorBasedBuilder):
|
|
71 |
if is_package_installed("unitxt"):
|
72 |
verify_versions_compatibility("dataset", self.VERSION)
|
73 |
|
74 |
-
from unitxt.dataset_utils import
|
75 |
-
get_dataset_artifact as get_dataset_artifact_installed
|
76 |
-
)
|
77 |
|
78 |
logger.info("Loading with installed unitxt library...")
|
79 |
dataset = get_dataset_artifact_installed(self.config.name)
|
|
|
10 |
from .collections import __file__ as _
|
11 |
from .collections_operators import __file__ as _
|
12 |
from .dataclass import __file__ as _
|
13 |
+
from .dataset_utils import __file__ as _
|
14 |
from .dataset_utils import get_dataset_artifact
|
15 |
from .deprecation_utils import __file__ as _
|
16 |
from .dialog_operators import __file__ as _
|
|
|
20 |
from .formats import __file__ as _
|
21 |
from .fusion import __file__ as _
|
22 |
from .generator_utils import __file__ as _
|
23 |
+
from .hf_utils import __file__ as _
|
24 |
from .hf_utils import verify_versions_compatibility
|
25 |
from .inference import __file__ as _
|
26 |
from .instructions import __file__ as _
|
27 |
from .llm_as_judge import __file__ as _
|
28 |
from .loaders import __file__ as _
|
29 |
+
from .logging_utils import __file__ as _
|
30 |
from .logging_utils import get_logger
|
31 |
from .metric import __file__ as _
|
32 |
from .metric_utils import __file__ as _
|
|
|
40 |
from .recipe import __file__ as _
|
41 |
from .register import __file__ as _
|
42 |
from .schema import __file__ as _
|
43 |
+
from .settings_utils import __file__ as _
|
44 |
from .settings_utils import get_constants
|
45 |
from .span_lableing_operators import __file__ as _
|
46 |
from .split_utils import __file__ as _
|
|
|
55 |
from .templates import __file__ as _
|
56 |
from .text_utils import __file__ as _
|
57 |
from .type_utils import __file__ as _
|
58 |
+
from .utils import __file__ as _
|
59 |
from .utils import is_package_installed
|
60 |
from .validate import __file__ as _
|
61 |
from .version import __file__ as _
|
|
|
76 |
if is_package_installed("unitxt"):
|
77 |
verify_versions_compatibility("dataset", self.VERSION)
|
78 |
|
79 |
+
from unitxt.dataset_utils import \
|
80 |
+
get_dataset_artifact as get_dataset_artifact_installed
|
|
|
81 |
|
82 |
logger.info("Loading with installed unitxt library...")
|
83 |
dataset = get_dataset_artifact_installed(self.config.name)
|
dataset_utils.py
CHANGED
@@ -1,8 +1,11 @@
|
|
|
|
|
|
1 |
from .artifact import Artifact, UnitxtArtifactNotFoundError, fetch_artifact
|
2 |
from .logging_utils import get_logger
|
3 |
from .parsing_utils import parse_key_equals_value_string_to_dict
|
4 |
from .register import _reset_env_local_catalogs, register_all_artifacts
|
5 |
from .settings_utils import get_settings
|
|
|
6 |
|
7 |
logger = get_logger()
|
8 |
settings = get_settings()
|
@@ -12,7 +15,7 @@ def fetch(artifact_name):
|
|
12 |
try:
|
13 |
artifact, _ = fetch_artifact(artifact_name)
|
14 |
return artifact
|
15 |
-
except UnitxtArtifactNotFoundError:
|
16 |
return None
|
17 |
|
18 |
|
@@ -20,13 +23,18 @@ def parse(query: str):
|
|
20 |
return parse_key_equals_value_string_to_dict(query)
|
21 |
|
22 |
|
23 |
-
def get_dataset_artifact(
|
|
|
|
|
|
|
|
|
|
|
24 |
_reset_env_local_catalogs()
|
25 |
register_all_artifacts()
|
26 |
-
recipe = fetch(
|
27 |
if recipe is None:
|
28 |
-
args = parse(
|
29 |
-
if "
|
30 |
-
args["
|
31 |
recipe = Artifact.from_dict(args)
|
32 |
return recipe
|
|
|
1 |
+
from json.decoder import JSONDecodeError
|
2 |
+
|
3 |
from .artifact import Artifact, UnitxtArtifactNotFoundError, fetch_artifact
|
4 |
from .logging_utils import get_logger
|
5 |
from .parsing_utils import parse_key_equals_value_string_to_dict
|
6 |
from .register import _reset_env_local_catalogs, register_all_artifacts
|
7 |
from .settings_utils import get_settings
|
8 |
+
from .standard import BaseRecipe
|
9 |
|
10 |
logger = get_logger()
|
11 |
settings = get_settings()
|
|
|
15 |
try:
|
16 |
artifact, _ = fetch_artifact(artifact_name)
|
17 |
return artifact
|
18 |
+
except (UnitxtArtifactNotFoundError, JSONDecodeError):
|
19 |
return None
|
20 |
|
21 |
|
|
|
23 |
return parse_key_equals_value_string_to_dict(query)
|
24 |
|
25 |
|
26 |
+
def get_dataset_artifact(dataset):
|
27 |
+
if isinstance(dataset, BaseRecipe):
|
28 |
+
return dataset
|
29 |
+
assert isinstance(
|
30 |
+
dataset, str
|
31 |
+
), "dataset should be string description of recipe, or recipe object."
|
32 |
_reset_env_local_catalogs()
|
33 |
register_all_artifacts()
|
34 |
+
recipe = fetch(dataset)
|
35 |
if recipe is None:
|
36 |
+
args = parse(dataset)
|
37 |
+
if "__type__" not in args:
|
38 |
+
args["__type__"] = settings.default_recipe
|
39 |
recipe = Artifact.from_dict(args)
|
40 |
return recipe
|
deprecation_utils.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1 |
import functools
|
2 |
import warnings
|
3 |
|
4 |
-
from .settings_utils import get_constants
|
5 |
|
6 |
constants = get_constants()
|
|
|
7 |
|
8 |
|
9 |
class DeprecationError(Exception):
|
@@ -60,9 +61,12 @@ def depraction_wrapper(obj, version, alt_text):
|
|
60 |
@functools.wraps(obj)
|
61 |
def wrapper(*args, **kwargs):
|
62 |
if constants.version < version:
|
63 |
-
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
66 |
elif constants.version >= version:
|
67 |
raise DeprecationError(f"{obj.__name__} is no longer supported.{alt_text}")
|
68 |
return obj(*args, **kwargs)
|
@@ -82,7 +86,7 @@ def deprecation(version, alternative=None):
|
|
82 |
"""
|
83 |
|
84 |
def decorator(obj):
|
85 |
-
alt_text = f" Use {alternative} instead." if alternative else ""
|
86 |
if callable(obj):
|
87 |
func = obj
|
88 |
elif hasattr(obj, "__init__"):
|
|
|
1 |
import functools
|
2 |
import warnings
|
3 |
|
4 |
+
from .settings_utils import get_constants, get_settings
|
5 |
|
6 |
constants = get_constants()
|
7 |
+
settings = get_settings()
|
8 |
|
9 |
|
10 |
class DeprecationError(Exception):
|
|
|
61 |
@functools.wraps(obj)
|
62 |
def wrapper(*args, **kwargs):
|
63 |
if constants.version < version:
|
64 |
+
if settings.default_verbosity in ["debug", "info", "warning"]:
|
65 |
+
warnings.warn(
|
66 |
+
f"{obj.__name__} is deprecated.{alt_text}",
|
67 |
+
DeprecationWarning,
|
68 |
+
stacklevel=2,
|
69 |
+
)
|
70 |
elif constants.version >= version:
|
71 |
raise DeprecationError(f"{obj.__name__} is no longer supported.{alt_text}")
|
72 |
return obj(*args, **kwargs)
|
|
|
86 |
"""
|
87 |
|
88 |
def decorator(obj):
|
89 |
+
alt_text = f" Use {alternative} instead." if alternative is not None else ""
|
90 |
if callable(obj):
|
91 |
func = obj
|
92 |
elif hasattr(obj, "__init__"):
|
dialog_operators.py
CHANGED
@@ -11,6 +11,7 @@ dialog = [
|
|
11 |
{"user": "kkk", "system": ""},
|
12 |
]
|
13 |
"""
|
|
|
14 |
from typing import Any, Dict, List, Optional
|
15 |
|
16 |
from .formats import SystemFormat
|
|
|
11 |
{"user": "kkk", "system": ""},
|
12 |
]
|
13 |
"""
|
14 |
+
|
15 |
from typing import Any, Dict, List, Optional
|
16 |
|
17 |
from .formats import SystemFormat
|
fusion.py
CHANGED
@@ -103,9 +103,10 @@ class WeightedFusion(BaseFusion):
|
|
103 |
If None, all instances are returned
|
104 |
"""
|
105 |
|
106 |
-
origins: Union[Dict[str,
|
107 |
weights: Union[Dict[str, Union[float, int]], List[Union[int, float]]] = None
|
108 |
max_total_examples: int = None
|
|
|
109 |
|
110 |
def verify(self):
|
111 |
super().verify()
|
@@ -149,7 +150,10 @@ class WeightedFusion(BaseFusion):
|
|
149 |
try:
|
150 |
instance = next(iterator)
|
151 |
if isinstance(origin_name, str):
|
152 |
-
if
|
|
|
|
|
|
|
153 |
instance["group"] = origin_name + "/" + instance["group"]
|
154 |
else:
|
155 |
instance["group"] = origin_name
|
|
|
103 |
If None, all instances are returned
|
104 |
"""
|
105 |
|
106 |
+
origins: Union[Dict[str, SourceOperator], List[SourceOperator]] = None
|
107 |
weights: Union[Dict[str, Union[float, int]], List[Union[int, float]]] = None
|
108 |
max_total_examples: int = None
|
109 |
+
ignore_origin_groups: List[str] = ["unitxt"]
|
110 |
|
111 |
def verify(self):
|
112 |
super().verify()
|
|
|
150 |
try:
|
151 |
instance = next(iterator)
|
152 |
if isinstance(origin_name, str):
|
153 |
+
if (
|
154 |
+
"group" in instance
|
155 |
+
and instance["group"] not in self.ignore_origin_groups
|
156 |
+
):
|
157 |
instance["group"] = origin_name + "/" + instance["group"]
|
158 |
else:
|
159 |
instance["group"] = origin_name
|
inference.py
CHANGED
@@ -3,6 +3,8 @@ import os
|
|
3 |
from dataclasses import field
|
4 |
from typing import Any, Dict, List, Literal, Optional, Union
|
5 |
|
|
|
|
|
6 |
from .artifact import Artifact
|
7 |
from .operator import PackageRequirementsMixin
|
8 |
|
@@ -15,12 +17,31 @@ class InferenceEngine(abc.ABC, Artifact):
|
|
15 |
"""Perform inference on the input dataset."""
|
16 |
pass
|
17 |
|
18 |
-
def infer(self, dataset):
|
19 |
"""Verifies instances of a dataset and performs inference."""
|
20 |
[self.verify_instance(instance) for instance in dataset]
|
21 |
return self._infer(dataset)
|
22 |
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
class HFPipelineBasedInferenceEngine(InferenceEngine, PackageRequirementsMixin):
|
25 |
model_name: str
|
26 |
max_new_tokens: int
|
@@ -158,9 +179,12 @@ class OpenAiInferenceEngineParams(Artifact):
|
|
158 |
stop: Union[Optional[str], List[str]] = None
|
159 |
temperature: Optional[float] = None
|
160 |
top_p: Optional[float] = None
|
|
|
161 |
|
162 |
|
163 |
-
class OpenAiInferenceEngine(
|
|
|
|
|
164 |
label: str = "openai"
|
165 |
model_name: str
|
166 |
parameters: OpenAiInferenceEngineParams = field(
|
@@ -169,6 +193,7 @@ class OpenAiInferenceEngine(InferenceEngine, PackageRequirementsMixin):
|
|
169 |
_requirement = {
|
170 |
"openai": "Install openai package using 'pip install --upgrade openai"
|
171 |
}
|
|
|
172 |
|
173 |
def prepare(self):
|
174 |
from openai import OpenAI
|
@@ -183,8 +208,38 @@ class OpenAiInferenceEngine(InferenceEngine, PackageRequirementsMixin):
|
|
183 |
self.client = OpenAI(api_key=api_key)
|
184 |
|
185 |
def _infer(self, dataset):
|
186 |
-
|
187 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
188 |
messages=[
|
189 |
# {
|
190 |
# "role": "system",
|
@@ -203,6 +258,159 @@ class OpenAiInferenceEngine(InferenceEngine, PackageRequirementsMixin):
|
|
203 |
stop=self.parameters.stop,
|
204 |
temperature=self.parameters.temperature,
|
205 |
top_p=self.parameters.top_p,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
206 |
)
|
207 |
for instance in dataset
|
208 |
]
|
|
|
3 |
from dataclasses import field
|
4 |
from typing import Any, Dict, List, Literal, Optional, Union
|
5 |
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
from .artifact import Artifact
|
9 |
from .operator import PackageRequirementsMixin
|
10 |
|
|
|
17 |
"""Perform inference on the input dataset."""
|
18 |
pass
|
19 |
|
20 |
+
def infer(self, dataset) -> str:
|
21 |
"""Verifies instances of a dataset and performs inference."""
|
22 |
[self.verify_instance(instance) for instance in dataset]
|
23 |
return self._infer(dataset)
|
24 |
|
25 |
|
26 |
+
class LogProbInferenceEngine(abc.ABC, Artifact):
|
27 |
+
"""Abstract base class for inference with log probs."""
|
28 |
+
|
29 |
+
@abc.abstractmethod
|
30 |
+
def _infer_log_probs(self, dataset):
|
31 |
+
"""Perform inference on the input dataset that returns log probs."""
|
32 |
+
pass
|
33 |
+
|
34 |
+
def infer_log_probs(self, dataset) -> List[Dict]:
|
35 |
+
"""Verifies instances of a dataset and performs inference that returns log probabilities of top tokens.
|
36 |
+
|
37 |
+
For each instance , returns a list of top tokens per position.
|
38 |
+
[ "top_tokens": [ { "text": ..., "logprob": ...} , ... ]
|
39 |
+
|
40 |
+
"""
|
41 |
+
[self.verify_instance(instance) for instance in dataset]
|
42 |
+
return self._infer_log_probs(dataset)
|
43 |
+
|
44 |
+
|
45 |
class HFPipelineBasedInferenceEngine(InferenceEngine, PackageRequirementsMixin):
|
46 |
model_name: str
|
47 |
max_new_tokens: int
|
|
|
179 |
stop: Union[Optional[str], List[str]] = None
|
180 |
temperature: Optional[float] = None
|
181 |
top_p: Optional[float] = None
|
182 |
+
top_logprobs: Optional[int] = 20
|
183 |
|
184 |
|
185 |
+
class OpenAiInferenceEngine(
|
186 |
+
InferenceEngine, LogProbInferenceEngine, PackageRequirementsMixin
|
187 |
+
):
|
188 |
label: str = "openai"
|
189 |
model_name: str
|
190 |
parameters: OpenAiInferenceEngineParams = field(
|
|
|
193 |
_requirement = {
|
194 |
"openai": "Install openai package using 'pip install --upgrade openai"
|
195 |
}
|
196 |
+
data_classification_policy = ["public"]
|
197 |
|
198 |
def prepare(self):
|
199 |
from openai import OpenAI
|
|
|
208 |
self.client = OpenAI(api_key=api_key)
|
209 |
|
210 |
def _infer(self, dataset):
|
211 |
+
outputs = []
|
212 |
+
for instance in tqdm(dataset, desc="Inferring with openAI API"):
|
213 |
+
response = self.client.chat.completions.create(
|
214 |
+
messages=[
|
215 |
+
# {
|
216 |
+
# "role": "system",
|
217 |
+
# "content": self.system_prompt,
|
218 |
+
# },
|
219 |
+
{
|
220 |
+
"role": "user",
|
221 |
+
"content": instance["source"],
|
222 |
+
}
|
223 |
+
],
|
224 |
+
model=self.model_name,
|
225 |
+
frequency_penalty=self.parameters.frequency_penalty,
|
226 |
+
presence_penalty=self.parameters.presence_penalty,
|
227 |
+
max_tokens=self.parameters.max_tokens,
|
228 |
+
seed=self.parameters.seed,
|
229 |
+
stop=self.parameters.stop,
|
230 |
+
temperature=self.parameters.temperature,
|
231 |
+
top_p=self.parameters.top_p,
|
232 |
+
)
|
233 |
+
output = response.choices[0].message.content
|
234 |
+
|
235 |
+
outputs.append(output)
|
236 |
+
|
237 |
+
return outputs
|
238 |
+
|
239 |
+
def _infer_log_probs(self, dataset):
|
240 |
+
outputs = []
|
241 |
+
for instance in tqdm(dataset, desc="Inferring with openAI API"):
|
242 |
+
response = self.client.chat.completions.create(
|
243 |
messages=[
|
244 |
# {
|
245 |
# "role": "system",
|
|
|
258 |
stop=self.parameters.stop,
|
259 |
temperature=self.parameters.temperature,
|
260 |
top_p=self.parameters.top_p,
|
261 |
+
logprobs=True,
|
262 |
+
top_logprobs=self.parameters.top_logprobs,
|
263 |
+
)
|
264 |
+
top_logprobs_response = response.choices[0].logprobs.content
|
265 |
+
output = [
|
266 |
+
{
|
267 |
+
"top_tokens": [
|
268 |
+
{"text": obj.token, "logprob": obj.logprob}
|
269 |
+
for obj in generated_token.top_logprobs
|
270 |
+
]
|
271 |
+
}
|
272 |
+
for generated_token in top_logprobs_response
|
273 |
+
]
|
274 |
+
outputs.append(output)
|
275 |
+
return outputs
|
276 |
+
|
277 |
+
|
278 |
+
class WMLInferenceEngineParams(Artifact):
|
279 |
+
decoding_method: Optional[Literal["greedy", "sample"]] = None
|
280 |
+
length_penalty: Optional[Dict[str, Union[int, float]]] = None
|
281 |
+
temperature: Optional[float] = None
|
282 |
+
top_p: Optional[float] = None
|
283 |
+
top_k: Optional[int] = None
|
284 |
+
random_seed: Optional[int] = None
|
285 |
+
repetition_penalty: Optional[float] = None
|
286 |
+
min_new_tokens: Optional[int] = None
|
287 |
+
max_new_tokens: Optional[int] = None
|
288 |
+
stop_sequences: Optional[List[str]] = None
|
289 |
+
time_limit: Optional[int] = None
|
290 |
+
truncate_input_tokens: Optional[int] = None
|
291 |
+
prompt_variables: Optional[Dict[str, Any]] = None
|
292 |
+
return_options: Optional[Dict[str, bool]] = None
|
293 |
+
|
294 |
+
def initialize_wml_parameters(self) -> Dict[str, Any]:
|
295 |
+
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames
|
296 |
+
|
297 |
+
return {
|
298 |
+
param_name.upper(): param_value
|
299 |
+
for param_name, param_value in self.to_dict().items()
|
300 |
+
if param_value and param_name.upper() in GenTextParamsMetaNames().get()
|
301 |
+
}
|
302 |
+
|
303 |
+
|
304 |
+
class WMLInferenceEngine(InferenceEngine, PackageRequirementsMixin):
|
305 |
+
"""Runs inference using ibm-watsonx-ai.
|
306 |
+
|
307 |
+
Attributes:
|
308 |
+
client: By default, it is created by a class instance but can be directly
|
309 |
+
provided instead as an instance of 'ibm_watsonx_ai.client.APIClient'.
|
310 |
+
credentials: By default, it is created by a class instance which tries to retrieve
|
311 |
+
proper environment variables ("WML_URL", "WML_PROJECT_ID", "WML_APIKEY").
|
312 |
+
However, either a dictionary with the following keys: "url", "apikey",
|
313 |
+
"project_id", or an instance of 'ibm_watsonx_ai.credentials.Credentials'
|
314 |
+
can be directly provided instead.
|
315 |
+
model_name (str, optional): ID of a model to be used for inference. Mutually
|
316 |
+
exclusive with 'deployment_id'.
|
317 |
+
deployment_id (str, optional): Deployment ID of a tuned model to be used for
|
318 |
+
inference. Mutually exclusive with 'model_name'.
|
319 |
+
parameters (WMLInferenceEngineParams): An instance of 'WMLInferenceEngineParams'
|
320 |
+
which defines parameters used for inference. All the parameters are optional.
|
321 |
+
|
322 |
+
Examples:
|
323 |
+
from .api import load_dataset
|
324 |
+
|
325 |
+
wml_parameters = WMLInferenceEngineParams(top_p=0.5, random_seed=123)
|
326 |
+
wml_credentials = {
|
327 |
+
"url": "some_url", "project_id": "some_id", "api_key": "some_key"
|
328 |
+
}
|
329 |
+
model_name = "google/flan-t5-xxl"
|
330 |
+
wml_inference = WMLInferenceEngine(
|
331 |
+
credentials=wml_credentials,
|
332 |
+
parameters=wml_parameters,
|
333 |
+
model_name=model_name,
|
334 |
+
)
|
335 |
+
|
336 |
+
dataset = load_dataset(
|
337 |
+
dataset_query="card=cards.argument_topic,template_card_index=0,loader_limit=5"
|
338 |
+
)
|
339 |
+
results = wml_inference.infer(dataset["test"])
|
340 |
+
"""
|
341 |
+
|
342 |
+
client = None
|
343 |
+
credentials = None
|
344 |
+
model_name: Optional[str] = None
|
345 |
+
deployment_id: Optional[str] = None
|
346 |
+
parameters: WMLInferenceEngineParams = field(
|
347 |
+
default_factory=WMLInferenceEngineParams
|
348 |
+
)
|
349 |
+
|
350 |
+
_parameters: Dict[str, Any] = field(default_factory=dict)
|
351 |
+
|
352 |
+
label: str = "wml"
|
353 |
+
_requirement = {
|
354 |
+
"ibm-watsonx-ai": "Install ibm-watsonx-ai package using 'pip install --upgrade ibm-watsonx-ai'. "
|
355 |
+
"It is advised to have Python version >=3.10 installed, as at lower version this package "
|
356 |
+
"may cause conflicts with other installed packages."
|
357 |
+
}
|
358 |
+
|
359 |
+
data_classification_policy = ["proprietary"]
|
360 |
+
|
361 |
+
@staticmethod
|
362 |
+
def _read_wml_credentials_from_env() -> Dict[str, str]:
|
363 |
+
credentials = {}
|
364 |
+
for env_var_name in ["WML_URL", "WML_PROJECT_ID", "WML_APIKEY"]:
|
365 |
+
env_var = os.environ.get(env_var_name)
|
366 |
+
assert env_var, (
|
367 |
+
f"Error while trying to run 'WMLInferenceEngine'. "
|
368 |
+
f"Please set the env variable: '{env_var_name}', or "
|
369 |
+
f"directly provide an instance of ibm-watsonx-ai 'Credentials' "
|
370 |
+
f"to the engine."
|
371 |
+
)
|
372 |
+
|
373 |
+
name = env_var_name.lower().replace("wml_", "")
|
374 |
+
credentials[name] = env_var
|
375 |
+
|
376 |
+
return credentials
|
377 |
+
|
378 |
+
def _initialize_wml_client(self):
|
379 |
+
from ibm_watsonx_ai.client import APIClient
|
380 |
+
|
381 |
+
if self.credentials is None:
|
382 |
+
self.credentials = self._read_wml_credentials_from_env()
|
383 |
+
|
384 |
+
client = APIClient(credentials=self.credentials)
|
385 |
+
client.set.default_project(self.credentials["project_id"])
|
386 |
+
return client
|
387 |
+
|
388 |
+
def prepare(self):
|
389 |
+
if self.client is None:
|
390 |
+
self.client = self._initialize_wml_client()
|
391 |
+
self._parameters = self.parameters.initialize_wml_parameters()
|
392 |
+
|
393 |
+
def verify(self):
|
394 |
+
assert (
|
395 |
+
self.model_name
|
396 |
+
or self.deployment_id
|
397 |
+
and not (self.model_name and self.deployment_id)
|
398 |
+
), "Either 'model_name' or 'deployment_id' must be specified, but not both at the same time."
|
399 |
+
super().verify()
|
400 |
+
|
401 |
+
def _infer(self, dataset):
|
402 |
+
from ibm_watsonx_ai.foundation_models import ModelInference
|
403 |
+
|
404 |
+
model = ModelInference(
|
405 |
+
model_id=self.model_name,
|
406 |
+
deployment_id=self.deployment_id,
|
407 |
+
api_client=self.client,
|
408 |
+
)
|
409 |
+
|
410 |
+
return [
|
411 |
+
model.generate_text(
|
412 |
+
prompt=instance["source"],
|
413 |
+
params=self._parameters,
|
414 |
)
|
415 |
for instance in dataset
|
416 |
]
|
llm_as_judge.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
from typing import Any, Dict, List, Literal, Optional
|
2 |
|
3 |
from .api import evaluate, produce
|
|
|
4 |
from .inference import InferenceEngine, OpenAiInferenceEngine
|
5 |
from .metrics import BulkInstanceMetric
|
6 |
from .operator import SequentialOperator
|
@@ -121,22 +122,25 @@ class LLMAsJudge(BulkInstanceMetric):
|
|
121 |
)
|
122 |
|
123 |
card = f"cards.dynamic_cards_for_llm_judges.{self.task}"
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
"demos_pool_size
|
128 |
-
"num_demos
|
129 |
-
|
|
|
130 |
if self.system_prompt:
|
131 |
-
|
132 |
if self.format:
|
133 |
-
|
134 |
-
|
135 |
dataset = produce(instances, recipe)
|
136 |
verdicts = self.inference_model.infer(dataset)
|
137 |
meta_scores = evaluate(predictions=verdicts, data=dataset)
|
138 |
return [
|
139 |
-
{
|
140 |
-
|
141 |
-
|
|
|
|
|
142 |
]
|
|
|
1 |
from typing import Any, Dict, List, Literal, Optional
|
2 |
|
3 |
from .api import evaluate, produce
|
4 |
+
from .artifact import Artifact, settings
|
5 |
from .inference import InferenceEngine, OpenAiInferenceEngine
|
6 |
from .metrics import BulkInstanceMetric
|
7 |
from .operator import SequentialOperator
|
|
|
122 |
)
|
123 |
|
124 |
card = f"cards.dynamic_cards_for_llm_judges.{self.task}"
|
125 |
+
recipe_args = {
|
126 |
+
"card": card,
|
127 |
+
"template": self.template,
|
128 |
+
"demos_pool_size": 0,
|
129 |
+
"num_demos": 0,
|
130 |
+
"__type__": settings.default_recipe,
|
131 |
+
}
|
132 |
if self.system_prompt:
|
133 |
+
recipe_args["system_prompt"] = self.system_prompt
|
134 |
if self.format:
|
135 |
+
recipe_args["format"] = self.format
|
136 |
+
recipe = Artifact.from_dict(recipe_args)
|
137 |
dataset = produce(instances, recipe)
|
138 |
verdicts = self.inference_model.infer(dataset)
|
139 |
meta_scores = evaluate(predictions=verdicts, data=dataset)
|
140 |
return [
|
141 |
+
{
|
142 |
+
self.main_score: instance["processed_prediction"],
|
143 |
+
"judge_raw_output": verdict,
|
144 |
+
}
|
145 |
+
for instance, verdict in zip(meta_scores, verdicts)
|
146 |
]
|
loaders.py
CHANGED
@@ -30,6 +30,7 @@ Available Loaders Overview:
|
|
30 |
|
31 |
------------------------
|
32 |
"""
|
|
|
33 |
import fnmatch
|
34 |
import itertools
|
35 |
import os
|
@@ -49,9 +50,10 @@ from .dataclass import InternalField, OptionalField
|
|
49 |
from .fusion import FixedFusion
|
50 |
from .logging_utils import get_logger
|
51 |
from .operator import SourceOperator
|
52 |
-
from .operators import
|
53 |
from .settings_utils import get_settings
|
54 |
from .stream import DynamicStream, MultiStream
|
|
|
55 |
|
56 |
logger = get_logger()
|
57 |
settings = get_settings()
|
@@ -110,7 +112,7 @@ class Loader(SourceOperator):
|
|
110 |
f"data_classification_policy =['confidential','pii'])\n"
|
111 |
)
|
112 |
|
113 |
-
operator =
|
114 |
fields={"data_classification_policy": self.data_classification_policy}
|
115 |
)
|
116 |
return operator(multi_stream)
|
@@ -176,14 +178,13 @@ class LoadHF(Loader):
|
|
176 |
self._requirements_list.append(requirement)
|
177 |
super().verify()
|
178 |
|
179 |
-
def
|
|
|
|
|
|
|
|
|
180 |
logger.info(f"\nLoading filtered by: {self.filtering_lambda};")
|
181 |
-
return
|
182 |
-
{
|
183 |
-
name: dataset[name].filter(eval(self.filtering_lambda))
|
184 |
-
for name in dataset
|
185 |
-
}
|
186 |
-
)
|
187 |
|
188 |
def stream_dataset(self):
|
189 |
if self._cache is None:
|
@@ -206,16 +207,17 @@ class LoadHF(Loader):
|
|
206 |
) from e
|
207 |
raise e
|
208 |
|
209 |
-
if self.filtering_lambda is not None:
|
210 |
-
dataset = self.filtered_load(dataset)
|
211 |
-
|
212 |
if self.split is not None:
|
213 |
dataset = {self.split: dataset}
|
214 |
|
215 |
self._cache = dataset
|
|
|
216 |
else:
|
217 |
dataset = self._cache
|
218 |
|
|
|
|
|
|
|
219 |
return dataset
|
220 |
|
221 |
def load_dataset(self):
|
@@ -239,9 +241,6 @@ class LoadHF(Loader):
|
|
239 |
f"{self.__class__.__name__} cannot run remote code from huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE."
|
240 |
) from e
|
241 |
|
242 |
-
if self.filtering_lambda is not None:
|
243 |
-
dataset = self.filtered_load(dataset)
|
244 |
-
|
245 |
if self.split is None:
|
246 |
for split in dataset.keys():
|
247 |
dataset[split] = dataset[split].to_iterable_dataset()
|
@@ -249,20 +248,25 @@ class LoadHF(Loader):
|
|
249 |
dataset = {self.split: dataset}
|
250 |
|
251 |
self._cache = dataset
|
|
|
252 |
else:
|
253 |
dataset = self._cache
|
254 |
|
|
|
|
|
|
|
255 |
return dataset
|
256 |
|
257 |
-
def split_limited_load(self, split_name):
|
258 |
-
yield from itertools.islice(
|
259 |
|
260 |
-
def limited_load(self):
|
261 |
self.log_limited_loading()
|
262 |
return MultiStream(
|
263 |
{
|
264 |
name: DynamicStream(
|
265 |
-
generator=self.split_limited_load,
|
|
|
266 |
)
|
267 |
for name in self._cache.keys()
|
268 |
}
|
@@ -285,7 +289,7 @@ class LoadHF(Loader):
|
|
285 |
dataset = self.load_dataset()
|
286 |
|
287 |
if self.get_limit() is not None:
|
288 |
-
return self.limited_load()
|
289 |
|
290 |
return MultiStream.from_iterables(dataset)
|
291 |
|
@@ -616,7 +620,7 @@ class LoadFromIBMCloud(Loader):
|
|
616 |
object_key,
|
617 |
local_dir + "/" + os.path.basename(temp_file.name),
|
618 |
)
|
619 |
-
os.
|
620 |
local_dir + "/" + os.path.basename(temp_file.name),
|
621 |
local_dir + "/" + data_file,
|
622 |
)
|
@@ -692,6 +696,25 @@ class LoadFromDictionary(Loader):
|
|
692 |
|
693 |
data: Dict[str, List[Dict[str, Any]]]
|
694 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
695 |
def load_data(self) -> MultiStream:
|
696 |
self.sef_default_data_classification(
|
697 |
["proprietary"], "when loading from python dictionary"
|
|
|
30 |
|
31 |
------------------------
|
32 |
"""
|
33 |
+
|
34 |
import fnmatch
|
35 |
import itertools
|
36 |
import os
|
|
|
50 |
from .fusion import FixedFusion
|
51 |
from .logging_utils import get_logger
|
52 |
from .operator import SourceOperator
|
53 |
+
from .operators import Set
|
54 |
from .settings_utils import get_settings
|
55 |
from .stream import DynamicStream, MultiStream
|
56 |
+
from .type_utils import isoftype
|
57 |
|
58 |
logger = get_logger()
|
59 |
settings = get_settings()
|
|
|
112 |
f"data_classification_policy =['confidential','pii'])\n"
|
113 |
)
|
114 |
|
115 |
+
operator = Set(
|
116 |
fields={"data_classification_policy": self.data_classification_policy}
|
117 |
)
|
118 |
return operator(multi_stream)
|
|
|
178 |
self._requirements_list.append(requirement)
|
179 |
super().verify()
|
180 |
|
181 |
+
def filter_load(self, dataset):
|
182 |
+
if not settings.allow_unverified_code:
|
183 |
+
raise ValueError(
|
184 |
+
f"{self.__class__.__name__} cannot run use filtering_lambda expression without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE."
|
185 |
+
)
|
186 |
logger.info(f"\nLoading filtered by: {self.filtering_lambda};")
|
187 |
+
return dataset.filter(eval(self.filtering_lambda))
|
|
|
|
|
|
|
|
|
|
|
188 |
|
189 |
def stream_dataset(self):
|
190 |
if self._cache is None:
|
|
|
207 |
) from e
|
208 |
raise e
|
209 |
|
|
|
|
|
|
|
210 |
if self.split is not None:
|
211 |
dataset = {self.split: dataset}
|
212 |
|
213 |
self._cache = dataset
|
214 |
+
|
215 |
else:
|
216 |
dataset = self._cache
|
217 |
|
218 |
+
if self.filtering_lambda is not None:
|
219 |
+
dataset = self.filter_load(dataset)
|
220 |
+
|
221 |
return dataset
|
222 |
|
223 |
def load_dataset(self):
|
|
|
241 |
f"{self.__class__.__name__} cannot run remote code from huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE."
|
242 |
) from e
|
243 |
|
|
|
|
|
|
|
244 |
if self.split is None:
|
245 |
for split in dataset.keys():
|
246 |
dataset[split] = dataset[split].to_iterable_dataset()
|
|
|
248 |
dataset = {self.split: dataset}
|
249 |
|
250 |
self._cache = dataset
|
251 |
+
|
252 |
else:
|
253 |
dataset = self._cache
|
254 |
|
255 |
+
if self.filtering_lambda is not None:
|
256 |
+
dataset = self.filter_load(dataset)
|
257 |
+
|
258 |
return dataset
|
259 |
|
260 |
+
def split_limited_load(self, dataset, split_name):
|
261 |
+
yield from itertools.islice(dataset[split_name], self.get_limit())
|
262 |
|
263 |
+
def limited_load(self, dataset):
|
264 |
self.log_limited_loading()
|
265 |
return MultiStream(
|
266 |
{
|
267 |
name: DynamicStream(
|
268 |
+
generator=self.split_limited_load,
|
269 |
+
gen_kwargs={"dataset": dataset, "split_name": name},
|
270 |
)
|
271 |
for name in self._cache.keys()
|
272 |
}
|
|
|
289 |
dataset = self.load_dataset()
|
290 |
|
291 |
if self.get_limit() is not None:
|
292 |
+
return self.limited_load(dataset=dataset)
|
293 |
|
294 |
return MultiStream.from_iterables(dataset)
|
295 |
|
|
|
620 |
object_key,
|
621 |
local_dir + "/" + os.path.basename(temp_file.name),
|
622 |
)
|
623 |
+
os.renames(
|
624 |
local_dir + "/" + os.path.basename(temp_file.name),
|
625 |
local_dir + "/" + data_file,
|
626 |
)
|
|
|
696 |
|
697 |
data: Dict[str, List[Dict[str, Any]]]
|
698 |
|
699 |
+
def verify(self):
|
700 |
+
super().verify()
|
701 |
+
if not isoftype(self.data, Dict[str, List[Dict[str, Any]]]):
|
702 |
+
raise ValueError(
|
703 |
+
f"Passed data to LoadFromDictionary is not of type Dict[str, List[Dict[str, Any]]].\n"
|
704 |
+
f"Expected data should map between split name and list of instances.\n"
|
705 |
+
f"Received value: {self.data}\n"
|
706 |
+
)
|
707 |
+
for split in self.data.keys():
|
708 |
+
if len(self.data[split]) == 0:
|
709 |
+
raise ValueError(f"Split {split} has no instances.")
|
710 |
+
first_instance = self.data[split][0]
|
711 |
+
for instance in self.data[split]:
|
712 |
+
if instance.keys() != first_instance.keys():
|
713 |
+
raise ValueError(
|
714 |
+
f"Not all instances in split '{split}' have the same fields.\n"
|
715 |
+
f"instance {instance} has different fields different from {first_instance}"
|
716 |
+
)
|
717 |
+
|
718 |
def load_data(self) -> MultiStream:
|
719 |
self.sef_default_data_classification(
|
720 |
["proprietary"], "when loading from python dictionary"
|
metric.py
CHANGED
@@ -19,13 +19,16 @@ from .file_utils import __file__ as _
|
|
19 |
from .formats import __file__ as _
|
20 |
from .fusion import __file__ as _
|
21 |
from .generator_utils import __file__ as _
|
|
|
22 |
from .hf_utils import verify_versions_compatibility
|
23 |
from .inference import __file__ as _
|
24 |
from .instructions import __file__ as _
|
25 |
from .llm_as_judge import __file__ as _
|
26 |
from .loaders import __file__ as _
|
27 |
from .logging_utils import __file__ as _
|
28 |
-
from .metric_utils import UNITXT_METRIC_SCHEMA
|
|
|
|
|
29 |
from .metrics import __file__ as _
|
30 |
from .normalizers import __file__ as _
|
31 |
from .operator import __file__ as _
|
@@ -36,6 +39,7 @@ from .random_utils import __file__ as _
|
|
36 |
from .recipe import __file__ as _
|
37 |
from .register import __file__ as _
|
38 |
from .schema import __file__ as _
|
|
|
39 |
from .settings_utils import get_constants
|
40 |
from .span_lableing_operators import __file__ as _
|
41 |
from .split_utils import __file__ as _
|
@@ -50,6 +54,7 @@ from .task import __file__ as _
|
|
50 |
from .templates import __file__ as _
|
51 |
from .text_utils import __file__ as _
|
52 |
from .type_utils import __file__ as _
|
|
|
53 |
from .utils import is_package_installed
|
54 |
from .validate import __file__ as _
|
55 |
from .version import __file__ as _
|
|
|
19 |
from .formats import __file__ as _
|
20 |
from .fusion import __file__ as _
|
21 |
from .generator_utils import __file__ as _
|
22 |
+
from .hf_utils import __file__ as _
|
23 |
from .hf_utils import verify_versions_compatibility
|
24 |
from .inference import __file__ as _
|
25 |
from .instructions import __file__ as _
|
26 |
from .llm_as_judge import __file__ as _
|
27 |
from .loaders import __file__ as _
|
28 |
from .logging_utils import __file__ as _
|
29 |
+
from .metric_utils import UNITXT_METRIC_SCHEMA
|
30 |
+
from .metric_utils import __file__ as _
|
31 |
+
from .metric_utils import _compute
|
32 |
from .metrics import __file__ as _
|
33 |
from .normalizers import __file__ as _
|
34 |
from .operator import __file__ as _
|
|
|
39 |
from .recipe import __file__ as _
|
40 |
from .register import __file__ as _
|
41 |
from .schema import __file__ as _
|
42 |
+
from .settings_utils import __file__ as _
|
43 |
from .settings_utils import get_constants
|
44 |
from .span_lableing_operators import __file__ as _
|
45 |
from .split_utils import __file__ as _
|
|
|
54 |
from .templates import __file__ as _
|
55 |
from .text_utils import __file__ as _
|
56 |
from .type_utils import __file__ as _
|
57 |
+
from .utils import __file__ as _
|
58 |
from .utils import is_package_installed
|
59 |
from .validate import __file__ as _
|
60 |
from .version import __file__ as _
|
metric_utils.py
CHANGED
@@ -9,6 +9,7 @@ from .dataclass import Dataclass
|
|
9 |
from .dict_utils import dict_set
|
10 |
from .operator import (
|
11 |
MultiStreamOperator,
|
|
|
12 |
SequentialOperatorInitializer,
|
13 |
StreamInitializerOperator,
|
14 |
)
|
@@ -18,6 +19,7 @@ from .operators import (
|
|
18 |
Copy,
|
19 |
FlattenInstances,
|
20 |
MergeStreams,
|
|
|
21 |
SplitByNestedGroup,
|
22 |
)
|
23 |
from .register import _reset_env_local_catalogs, register_all_artifacts
|
@@ -145,6 +147,59 @@ class FromPredictionsAndOriginalData(StreamInitializerOperator):
|
|
145 |
# When receiving instances from this scheme, the keys and values are returned as two separate
|
146 |
# lists, and are converted to a dictionary.
|
147 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
|
149 |
class MetricRecipe(SequentialOperatorInitializer):
|
150 |
calc_confidence_intervals: bool = True
|
@@ -155,13 +210,7 @@ class MetricRecipe(SequentialOperatorInitializer):
|
|
155 |
self.steps = [
|
156 |
FromPredictionsAndOriginalData(),
|
157 |
LoadJson(field="task_data"),
|
158 |
-
|
159 |
-
field="source",
|
160 |
-
to_field="task_data/source",
|
161 |
-
),
|
162 |
-
ApplyOperatorsField(
|
163 |
-
operators_field="postprocessors",
|
164 |
-
),
|
165 |
SplitByNestedGroup(
|
166 |
field_name_of_group="group",
|
167 |
number_of_fusion_generations=self.number_of_fusion_generations,
|
@@ -172,6 +221,18 @@ class MetricRecipe(SequentialOperatorInitializer):
|
|
172 |
),
|
173 |
MultiStreamScoreMean(),
|
174 |
MergeStreams(),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
175 |
]
|
176 |
|
177 |
|
|
|
9 |
from .dict_utils import dict_set
|
10 |
from .operator import (
|
11 |
MultiStreamOperator,
|
12 |
+
SequentialOperator,
|
13 |
SequentialOperatorInitializer,
|
14 |
StreamInitializerOperator,
|
15 |
)
|
|
|
19 |
Copy,
|
20 |
FlattenInstances,
|
21 |
MergeStreams,
|
22 |
+
RenameFields,
|
23 |
SplitByNestedGroup,
|
24 |
)
|
25 |
from .register import _reset_env_local_catalogs, register_all_artifacts
|
|
|
147 |
# When receiving instances from this scheme, the keys and values are returned as two separate
|
148 |
# lists, and are converted to a dictionary.
|
149 |
|
150 |
+
_post_process_steps = SequentialOperator(
|
151 |
+
steps=[
|
152 |
+
Copy(
|
153 |
+
field="prediction",
|
154 |
+
to_field="raw_prediction",
|
155 |
+
),
|
156 |
+
Copy(
|
157 |
+
field="references",
|
158 |
+
to_field="raw_references",
|
159 |
+
),
|
160 |
+
Copy(
|
161 |
+
field="source",
|
162 |
+
to_field="task_data/source",
|
163 |
+
),
|
164 |
+
ApplyOperatorsField(
|
165 |
+
operators_field="postprocessors",
|
166 |
+
),
|
167 |
+
Copy(
|
168 |
+
field="prediction",
|
169 |
+
to_field="processed_prediction",
|
170 |
+
),
|
171 |
+
Copy(
|
172 |
+
field="references",
|
173 |
+
to_field="processed_references",
|
174 |
+
),
|
175 |
+
]
|
176 |
+
)
|
177 |
+
|
178 |
+
|
179 |
+
class PostProcessRecipe(SequentialOperatorInitializer):
|
180 |
+
def prepare(self):
|
181 |
+
register_all_artifacts()
|
182 |
+
self.steps = [
|
183 |
+
FromPredictionsAndOriginalData(),
|
184 |
+
_post_process_steps,
|
185 |
+
]
|
186 |
+
|
187 |
+
|
188 |
+
def _post_process(
|
189 |
+
predictions: List[str],
|
190 |
+
references: Iterable,
|
191 |
+
split_name: str = "all",
|
192 |
+
):
|
193 |
+
_reset_env_local_catalogs()
|
194 |
+
register_all_artifacts()
|
195 |
+
recipe = PostProcessRecipe()
|
196 |
+
|
197 |
+
multi_stream = recipe(
|
198 |
+
predictions=predictions, references=references, split_name=split_name
|
199 |
+
)
|
200 |
+
|
201 |
+
return [instance["processed_prediction"] for instance in multi_stream[split_name]]
|
202 |
+
|
203 |
|
204 |
class MetricRecipe(SequentialOperatorInitializer):
|
205 |
calc_confidence_intervals: bool = True
|
|
|
210 |
self.steps = [
|
211 |
FromPredictionsAndOriginalData(),
|
212 |
LoadJson(field="task_data"),
|
213 |
+
_post_process_steps,
|
|
|
|
|
|
|
|
|
|
|
|
|
214 |
SplitByNestedGroup(
|
215 |
field_name_of_group="group",
|
216 |
number_of_fusion_generations=self.number_of_fusion_generations,
|
|
|
221 |
),
|
222 |
MultiStreamScoreMean(),
|
223 |
MergeStreams(),
|
224 |
+
RenameFields(
|
225 |
+
field="raw_prediction",
|
226 |
+
to_field="prediction",
|
227 |
+
),
|
228 |
+
RenameFields(
|
229 |
+
field="raw_references",
|
230 |
+
to_field="references",
|
231 |
+
),
|
232 |
+
Copy(
|
233 |
+
field="source",
|
234 |
+
to_field="task_data/source",
|
235 |
+
),
|
236 |
]
|
237 |
|
238 |
|
metrics.py
CHANGED
@@ -27,7 +27,7 @@ from .operator import (
|
|
27 |
StreamingOperator,
|
28 |
StreamOperator,
|
29 |
)
|
30 |
-
from .operators import
|
31 |
from .random_utils import get_seed
|
32 |
from .settings_utils import get_settings
|
33 |
from .stream import MultiStream, Stream
|
@@ -1123,7 +1123,7 @@ class MetricPipeline(MultiStreamOperator, Metric):
|
|
1123 |
|
1124 |
def prepare(self):
|
1125 |
super().prepare()
|
1126 |
-
self.prepare_score =
|
1127 |
field_to_field=[
|
1128 |
[
|
1129 |
f"score/instance/{self.metric._add_score_prefix(self.main_score)}",
|
@@ -2134,9 +2134,7 @@ class Detector(BulkInstanceMetric):
|
|
2134 |
return self.pipe(predictions, batch_size=self.batch_size)
|
2135 |
|
2136 |
|
2137 |
-
class
|
2138 |
-
"""LlamaIndex based metric class for evaluating correctness."""
|
2139 |
-
|
2140 |
model_name: str = ""
|
2141 |
main_score: str = ""
|
2142 |
prediction_type: str = "str"
|
@@ -2151,6 +2149,34 @@ class LlamaIndexCorrectness(InstanceMetric):
|
|
2151 |
|
2152 |
_requirements_list: List[str] = ["llama_index"]
|
2153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2154 |
@staticmethod
|
2155 |
def _custom_parser(eval_response: str):
|
2156 |
"""Default parser function for evaluation response.
|
@@ -2174,37 +2200,14 @@ class LlamaIndexCorrectness(InstanceMetric):
|
|
2174 |
reasoning = reasoning_str.lstrip("\n")
|
2175 |
return score, reasoning
|
2176 |
|
2177 |
-
def _model_using_extrnal_api(self):
|
2178 |
-
return self.model_name in self.external_api_models
|
2179 |
-
|
2180 |
def prepare(self):
|
2181 |
"""Initialization method for the metric. Initializes the CorrectnessEvaluator with the OpenAI model."""
|
2182 |
super().prepare()
|
2183 |
|
2184 |
-
self.model_name_normalized = self.model_name.replace(".", "_").replace("-", "_")
|
2185 |
-
self.main_score: str = (
|
2186 |
-
f"correctness_llama_index_by_{self.model_name_normalized}_judge"
|
2187 |
-
)
|
2188 |
-
|
2189 |
-
self.reduction_map: Dict[str, List[str]] = {"mean": [self.main_score]}
|
2190 |
-
|
2191 |
from llama_index.core.evaluation import CorrectnessEvaluator
|
2192 |
|
2193 |
-
if self.model_name in self.openai_models:
|
2194 |
-
from llama_index.llms.openai import OpenAI
|
2195 |
-
|
2196 |
-
llm = OpenAI("gpt-3.5-turbo")
|
2197 |
-
elif self.model_name in self.mock_models:
|
2198 |
-
from llama_index.core.llms.mock import MockLLM
|
2199 |
-
|
2200 |
-
llm = MockLLM(system_prompt="5") # perfect score
|
2201 |
-
else:
|
2202 |
-
raise NotImplementedError(
|
2203 |
-
f"LlamaIndexCorrectnessMetric does not support {self.model_name}, currently only gpt-3.5-turbo is supported"
|
2204 |
-
)
|
2205 |
-
|
2206 |
self.evaluator = CorrectnessEvaluator(
|
2207 |
-
llm=llm, parser_function=self._custom_parser
|
2208 |
)
|
2209 |
|
2210 |
def compute(
|
@@ -2226,9 +2229,6 @@ class LlamaIndexCorrectness(InstanceMetric):
|
|
2226 |
Raises:
|
2227 |
AssertionError: If the input does not meet the expected format.
|
2228 |
"""
|
2229 |
-
# treat the references as the questions and the predictions as answers
|
2230 |
-
# assume a single reference
|
2231 |
-
|
2232 |
query = task_data["question"]
|
2233 |
|
2234 |
contexts = None
|
@@ -2247,11 +2247,36 @@ class LlamaIndexCorrectness(InstanceMetric):
|
|
2247 |
)
|
2248 |
result = max([results.score for results in per_reference_results])
|
2249 |
|
2250 |
-
return {
|
2251 |
-
|
2252 |
-
|
2253 |
-
|
2254 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2255 |
|
2256 |
|
2257 |
class Perplexity(BulkInstanceMetric):
|
|
|
27 |
StreamingOperator,
|
28 |
StreamOperator,
|
29 |
)
|
30 |
+
from .operators import Copy
|
31 |
from .random_utils import get_seed
|
32 |
from .settings_utils import get_settings
|
33 |
from .stream import MultiStream, Stream
|
|
|
1123 |
|
1124 |
def prepare(self):
|
1125 |
super().prepare()
|
1126 |
+
self.prepare_score = Copy(
|
1127 |
field_to_field=[
|
1128 |
[
|
1129 |
f"score/instance/{self.metric._add_score_prefix(self.main_score)}",
|
|
|
2134 |
return self.pipe(predictions, batch_size=self.batch_size)
|
2135 |
|
2136 |
|
2137 |
+
class LlamaIndexLLMMetric(InstanceMetric):
|
|
|
|
|
2138 |
model_name: str = ""
|
2139 |
main_score: str = ""
|
2140 |
prediction_type: str = "str"
|
|
|
2149 |
|
2150 |
_requirements_list: List[str] = ["llama_index"]
|
2151 |
|
2152 |
+
def prepare(self):
    """Derive the score name from the model name and instantiate the judge LLM.

    Sets ``model_name_normalized`` (the model name with ``.`` and ``-``
    replaced by ``_``), ``main_score``, ``reduction_map`` and ``llm``.

    Raises:
        NotImplementedError: If ``model_name`` is listed in neither
            ``openai_models`` nor ``mock_models``.
    """
    # Replace both "." and "-" with "_" in a single pass.
    normalized = self.model_name.translate(str.maketrans(".-", "__"))
    self.model_name_normalized = normalized
    self.main_score: str = f"llama_index_by_{normalized}_judge"
    self.reduction_map: Dict[str, List[str]] = {"mean": [self.main_score]}

    if self.model_name in self.openai_models:
        # Imported lazily so llama_index is only needed when the metric is used.
        from llama_index.llms.openai import OpenAI

        self.llm = OpenAI("gpt-3.5-turbo")
        return

    if self.model_name in self.mock_models:
        from llama_index.core.llms.mock import MockLLM

        self.llm = MockLLM(system_prompt="5")  # perfect score
        return

    raise NotImplementedError(
        f"LlamaIndexLLM metric does not support {self.model_name}, currently only gpt-3.5-turbo is supported"
    )
|
2170 |
+
|
2171 |
+
def _model_using_extrnal_api(self):
|
2172 |
+
return self.model_name in self.external_api_models
|
2173 |
+
|
2174 |
+
|
2175 |
+
class LlamaIndexCorrectness(LlamaIndexLLMMetric):
|
2176 |
+
"""LlamaIndex based metric class for evaluating correctness."""
|
2177 |
+
|
2178 |
+
score_prefix = "correctness_"
|
2179 |
+
|
2180 |
@staticmethod
|
2181 |
def _custom_parser(eval_response: str):
|
2182 |
"""Default parser function for evaluation response.
|
|
|
2200 |
reasoning = reasoning_str.lstrip("\n")
|
2201 |
return score, reasoning
|
2202 |
|
|
|
|
|
|
|
2203 |
def prepare(self):
|
2204 |
"""Initialization method for the metric. Initializes the CorrectnessEvaluator with the OpenAI model."""
|
2205 |
super().prepare()
|
2206 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2207 |
from llama_index.core.evaluation import CorrectnessEvaluator
|
2208 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2209 |
self.evaluator = CorrectnessEvaluator(
|
2210 |
+
llm=self.llm, parser_function=self._custom_parser
|
2211 |
)
|
2212 |
|
2213 |
def compute(
|
|
|
2229 |
Raises:
|
2230 |
AssertionError: If the input does not meet the expected format.
|
2231 |
"""
|
|
|
|
|
|
|
2232 |
query = task_data["question"]
|
2233 |
|
2234 |
contexts = None
|
|
|
2247 |
)
|
2248 |
result = max([results.score for results in per_reference_results])
|
2249 |
|
2250 |
+
return {self.main_score: result / 5}
|
2251 |
+
|
2252 |
+
|
2253 |
+
class LlamaIndexFaithfulness(LlamaIndexLLMMetric):
    """LlamaIndex based metric class for evaluating faithfulness."""

    score_prefix = "faithfulness_"

    def prepare(self):
        """Run the parent preparation, then build a FaithfulnessEvaluator around ``self.llm``."""
        super().prepare()

        # Imported lazily so llama_index is only needed when the metric is used.
        from llama_index.core.evaluation import FaithfulnessEvaluator

        self.evaluator = FaithfulnessEvaluator(llm=self.llm)

    def compute(
        self,
        references: List[str],
        prediction: str,
        task_data: Dict,
    ) -> Dict[str, Any]:
        """Score one prediction with the evaluator.

        Args:
            references: Not used by this metric.
            prediction: The response passed to the evaluator.
            task_data: Must contain the keys ``"question"`` and ``"contexts"``.

        Returns:
            A dict mapping ``self.main_score`` to the evaluator's score.
        """
        evaluation = self.evaluator.evaluate(
            query=task_data["question"],
            response=prediction,
            contexts=task_data["contexts"],
        )
        return {self.main_score: evaluation.score}
|
2280 |
|
2281 |
|
2282 |
class Perplexity(BulkInstanceMetric):
|
operator.py
CHANGED
@@ -117,54 +117,6 @@ class SideEffectOperator(StreamingOperator):
|
|
117 |
pass
|
118 |
|
119 |
|
120 |
-
class SourceOperator(StreamingOperator):
|
121 |
-
"""A class representing a source operator in the streaming system.
|
122 |
-
|
123 |
-
A source operator is responsible for generating the data stream from some source, such as a database or a file.
|
124 |
-
This is the starting point of a stream processing pipeline.
|
125 |
-
The `SourceOperator` class is a type of `SourceOperator`, which is a special type of `StreamingOperator`
|
126 |
-
that generates an output stream but does not take any input streams.
|
127 |
-
|
128 |
-
When called, a `SourceOperator` invokes its `process` method, which should be implemented by all subclasses
|
129 |
-
to generate the required `MultiStream`.
|
130 |
-
|
131 |
-
"""
|
132 |
-
|
133 |
-
caching: bool = NonPositionalField(default=None)
|
134 |
-
|
135 |
-
def __call__(self, multi_stream: Optional[MultiStream] = None) -> MultiStream:
|
136 |
-
multi_stream = self.process()
|
137 |
-
if self.caching is not None:
|
138 |
-
multi_stream.set_caching(self.caching)
|
139 |
-
return multi_stream
|
140 |
-
|
141 |
-
@abstractmethod
|
142 |
-
def process(self) -> MultiStream:
|
143 |
-
pass
|
144 |
-
|
145 |
-
|
146 |
-
class StreamInitializerOperator(SourceOperator):
|
147 |
-
"""A class representing a stream initializer operator in the streaming system.
|
148 |
-
|
149 |
-
A stream initializer operator is a special type of `SourceOperator` that is capable of taking parameters during the stream generation process. This can be useful in situations where the stream generation process needs to be customized or configured based on certain parameters.
|
150 |
-
|
151 |
-
When called, a `StreamInitializerOperator` invokes its `process` method, passing any supplied arguments and keyword arguments. The `process` method should be implemented by all subclasses to generate the required `MultiStream` based on the given arguments and keyword arguments.
|
152 |
-
|
153 |
-
"""
|
154 |
-
|
155 |
-
caching: bool = NonPositionalField(default=None)
|
156 |
-
|
157 |
-
def __call__(self, *args, **kwargs) -> MultiStream:
|
158 |
-
multi_stream = self.process(*args, **kwargs)
|
159 |
-
if self.caching is not None:
|
160 |
-
multi_stream.set_caching(self.caching)
|
161 |
-
return self.process(*args, **kwargs)
|
162 |
-
|
163 |
-
@abstractmethod
|
164 |
-
def process(self, *args, **kwargs) -> MultiStream:
|
165 |
-
pass
|
166 |
-
|
167 |
-
|
168 |
def instance_generator(instance):
|
169 |
yield instance
|
170 |
|
@@ -213,6 +165,55 @@ class MultiStreamOperator(StreamingOperator):
|
|
213 |
return next(iter(processed_multi_stream[stream_name]))
|
214 |
|
215 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
class StreamOperator(MultiStreamOperator):
|
217 |
"""A class representing a single-stream operator in the streaming system.
|
218 |
|
@@ -458,15 +459,8 @@ class InstanceOperatorWithMultiStreamAccess(StreamingOperator):
|
|
458 |
pass
|
459 |
|
460 |
|
461 |
-
class
|
462 |
-
|
463 |
-
|
464 |
-
A sequential operator is a type of `MultiStreamOperator` that applies a sequence of other operators to a
|
465 |
-
`MultiStream`. It maintains a list of `StreamingOperator`s and applies them in order to the `MultiStream`.
|
466 |
-
"""
|
467 |
-
|
468 |
-
max_steps = None
|
469 |
-
|
470 |
steps: List[StreamingOperator] = field(default_factory=list)
|
471 |
|
472 |
def num_steps(self) -> int:
|
@@ -488,13 +482,21 @@ class SequentialOperator(MultiStreamOperator):
|
|
488 |
def _get_max_steps(self):
|
489 |
return self.max_steps if self.max_steps is not None else len(self.steps)
|
490 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
491 |
def process(self, multi_stream: Optional[MultiStream] = None) -> MultiStream:
|
492 |
for operator in self.steps[0 : self._get_max_steps()]:
|
493 |
multi_stream = operator(multi_stream)
|
494 |
return multi_stream
|
495 |
|
496 |
|
497 |
-
class SourceSequentialOperator(
|
498 |
"""A class representing a source sequential operator in the streaming system.
|
499 |
|
500 |
A source sequential operator is a type of `SequentialOperator` that starts with a source operator.
|
@@ -502,9 +504,6 @@ class SourceSequentialOperator(SequentialOperator):
|
|
502 |
that the other operators then process.
|
503 |
"""
|
504 |
|
505 |
-
def __call__(self) -> MultiStream:
|
506 |
-
return super().__call__()
|
507 |
-
|
508 |
def process(self, multi_stream: Optional[MultiStream] = None) -> MultiStream:
|
509 |
assert (
|
510 |
self.num_steps() > 0
|
|
|
117 |
pass
|
118 |
|
119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
def instance_generator(instance):
|
121 |
yield instance
|
122 |
|
|
|
165 |
return next(iter(processed_multi_stream[stream_name]))
|
166 |
|
167 |
|
168 |
+
class SourceOperator(MultiStreamOperator):
    """An operator that produces a `MultiStream` without consuming one.

    A source operator generates the data stream from some source, such as a
    database or a file, and is therefore the starting point of a stream
    processing pipeline. It is a special kind of `MultiStreamOperator` whose
    input stream, if any is passed, is ignored.

    Subclasses implement `process`, which takes no arguments and returns the
    generated `MultiStream`.
    """

    def _process_multi_stream(
        self, multi_stream: Optional[MultiStream] = None
    ) -> MultiStream:
        # The incoming multi_stream is deliberately unused: a source creates its own.
        generated = self.process()
        assert isinstance(
            generated, MultiStream
        ), "MultiStreamOperator must return a MultiStream"
        return generated

    @abstractmethod
    def process(self) -> MultiStream:
        pass
|
193 |
+
|
194 |
+
|
195 |
+
class StreamInitializerOperator(SourceOperator):
    """A source operator whose stream generation accepts parameters.

    A stream initializer operator is a special type of `SourceOperator` that
    can take arguments during stream generation. This is useful when the
    generation needs to be customized or configured per call.

    When called, it invokes `process` with the supplied positional and keyword
    arguments. Subclasses implement `process` to build the `MultiStream` from
    those arguments.
    """

    caching: bool = NonPositionalField(default=None)

    def __call__(self, *args, **kwargs) -> MultiStream:
        """Generate the stream once and apply the caching setting, if one is set.

        Args:
            *args: Positional arguments forwarded to `process`.
            **kwargs: Keyword arguments forwarded to `process`.

        Returns:
            The `MultiStream` produced by `process`, after `set_caching` has
            been called on it when `caching` is not None.
        """
        multi_stream = self.process(*args, **kwargs)
        if self.caching is not None:
            multi_stream.set_caching(self.caching)
        # Return the stream configured above. Calling process() again here would
        # generate the source a second time and drop the caching setting.
        return multi_stream

    @abstractmethod
    def process(self, *args, **kwargs) -> MultiStream:
        pass
|
215 |
+
|
216 |
+
|
217 |
class StreamOperator(MultiStreamOperator):
|
218 |
"""A class representing a single-stream operator in the streaming system.
|
219 |
|
|
|
459 |
pass
|
460 |
|
461 |
|
462 |
+
class SequentialMixin(Artifact):
|
463 |
+
max_steps: Optional[int] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
464 |
steps: List[StreamingOperator] = field(default_factory=list)
|
465 |
|
466 |
def num_steps(self) -> int:
|
|
|
482 |
def _get_max_steps(self):
|
483 |
return self.max_steps if self.max_steps is not None else len(self.steps)
|
484 |
|
485 |
+
|
486 |
+
class SequentialOperator(MultiStreamOperator, SequentialMixin):
    """An operator that chains other operators one after another.

    A sequential operator is a `MultiStreamOperator` that holds a list of
    `StreamingOperator`s in `steps` and applies them, in order, to a
    `MultiStream`. Only the first `_get_max_steps()` steps are applied.
    """

    def process(self, multi_stream: Optional[MultiStream] = None) -> MultiStream:
        step_limit = self._get_max_steps()
        current = multi_stream
        for step in self.steps[:step_limit]:
            # Each step receives the output of the previous one.
            current = step(current)
        return current
|
497 |
|
498 |
|
499 |
+
class SourceSequentialOperator(SourceOperator, SequentialMixin):
|
500 |
"""A class representing a source sequential operator in the streaming system.
|
501 |
|
502 |
A source sequential operator is a type of `SequentialOperator` that starts with a source operator.
|
|
|
504 |
that the other operators then process.
|
505 |
"""
|
506 |
|
|
|
|
|
|
|
507 |
def process(self, multi_stream: Optional[MultiStream] = None) -> MultiStream:
|
508 |
assert (
|
509 |
self.num_steps() > 0
|
operators.py
CHANGED
@@ -16,11 +16,17 @@ The primary task in any operator development is to implement the `process` funct
|
|
16 |
|
17 |
General or Specialized Operators
|
18 |
--------------------------------
|
19 |
-
Some operators are specielized in specific
|
20 |
|
21 |
-
- :class:`loaders<unitxt.loaders>` for
|
22 |
- :class:`splitters<unitxt.splitters>` for fixing data splits.
|
|
|
23 |
- :class:`struct_data_operators<unitxt.struct_data_operators>` for structured data operators.
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
Other specialized operators are used by unitxt internally:
|
26 |
|
@@ -59,6 +65,7 @@ import requests
|
|
59 |
|
60 |
from .artifact import Artifact, fetch_artifact
|
61 |
from .dataclass import DeprecatedField, NonPositionalField, OptionalField
|
|
|
62 |
from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
|
63 |
from .operator import (
|
64 |
InstanceOperator,
|
@@ -222,7 +229,7 @@ class FlattenInstances(InstanceOperator):
|
|
222 |
return flatten_dict(instance, parent_key=self.parent_key, sep=self.sep)
|
223 |
|
224 |
|
225 |
-
class
|
226 |
"""Adds specified fields to each instance in a given stream or all streams (default) If fields exist, updates them.
|
227 |
|
228 |
Args:
|
@@ -232,17 +239,17 @@ class AddFields(InstanceOperator):
|
|
232 |
|
233 |
Examples:
|
234 |
# Add a 'classes' field with a value of a list "positive" and "negative" to all streams
|
235 |
-
|
236 |
|
237 |
# Add a 'start' field under the 'span' field with a value of 0 to all streams
|
238 |
-
|
239 |
|
240 |
# Add a 'classes' field with a value of a list "positive" and "negative" to 'train' stream
|
241 |
-
|
242 |
|
243 |
# Add a 'classes' field on a given list, prevent modification of original list
|
244 |
# from changing the instance.
|
245 |
-
|
246 |
# if now alist is modified, still the instances remain intact.
|
247 |
"""
|
248 |
|
@@ -265,6 +272,11 @@ class AddFields(InstanceOperator):
|
|
265 |
return instance
|
266 |
|
267 |
|
|
|
|
|
|
|
|
|
|
|
268 |
class RemoveFields(InstanceOperator):
|
269 |
"""Remove specified fields from each instance in a stream.
|
270 |
|
@@ -1049,6 +1061,7 @@ class Copy(FieldOperator):
|
|
1049 |
return copy.deepcopy(value)
|
1050 |
|
1051 |
|
|
|
1052 |
class CopyFields(Copy):
|
1053 |
pass
|
1054 |
|
@@ -1392,7 +1405,8 @@ class ComputeExpressionMixin(Artifact):
|
|
1392 |
return eval(self.expression, self.globals, instance)
|
1393 |
|
1394 |
raise ValueError(
|
1395 |
-
f"Cannot
|
|
|
1396 |
)
|
1397 |
|
1398 |
|
@@ -1556,7 +1570,7 @@ class ExtractMostCommonFieldValues(MultiStreamOperator):
|
|
1556 |
for ele in values_and_counts
|
1557 |
]
|
1558 |
|
1559 |
-
addmostcommons =
|
1560 |
return addmostcommons(multi_stream)
|
1561 |
|
1562 |
|
|
|
16 |
|
17 |
General or Specialized Operators
|
18 |
--------------------------------
|
19 |
+
Some operators are specialized in specific data or specific operations such as:
|
20 |
|
21 |
+
- :class:`loaders<unitxt.loaders>` for accessing data from various sources.
|
22 |
- :class:`splitters<unitxt.splitters>` for fixing data splits.
|
23 |
+
- :class:`stream_operators<unitxt.stream_operators>` for changing, joining and mixing streams.
|
24 |
- :class:`struct_data_operators<unitxt.struct_data_operators>` for structured data operators.
|
25 |
+
- :class:`collections_operators<unitxt.collections_operators>` for handling collections such as lists and dictionaries.
|
26 |
+
- :class:`dialog_operators<unitxt.dialog_operators>` for handling dialogs.
|
27 |
+
- :class:`string_operators<unitxt.string_operators>` for handling strings.
|
28 |
+
- :class:`span_labeling_operators<unitxt.span_labeling_operators>` for handling span labeling.
|
29 |
+
- :class:`fusion<unitxt.fusion>` for fusing and mixing datasets.
|
30 |
|
31 |
Other specialized operators are used by unitxt internally:
|
32 |
|
|
|
65 |
|
66 |
from .artifact import Artifact, fetch_artifact
|
67 |
from .dataclass import DeprecatedField, NonPositionalField, OptionalField
|
68 |
+
from .deprecation_utils import deprecation
|
69 |
from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
|
70 |
from .operator import (
|
71 |
InstanceOperator,
|
|
|
229 |
return flatten_dict(instance, parent_key=self.parent_key, sep=self.sep)
|
230 |
|
231 |
|
232 |
+
class Set(InstanceOperator):
|
233 |
"""Adds specified fields to each instance in a given stream or all streams (default) If fields exist, updates them.
|
234 |
|
235 |
Args:
|
|
|
239 |
|
240 |
Examples:
|
241 |
# Add a 'classes' field with a value of a list "positive" and "negative" to all streams
|
242 |
+
Set(fields={"classes": ["positive","negatives"]})
|
243 |
|
244 |
# Add a 'start' field under the 'span' field with a value of 0 to all streams
|
245 |
+
Set(fields={"span/start": 0})
|
246 |
|
247 |
# Add a 'classes' field with a value of a list "positive" and "negative" to 'train' stream
|
248 |
+
Set(fields={"classes": ["positive","negatives"]}, apply_to_stream=["train"])
|
249 |
|
250 |
# Add a 'classes' field on a given list, prevent modification of original list
|
251 |
# from changing the instance.
|
252 |
+
Set(fields={"classes": alist}, use_deepcopy=True)
|
253 |
# if now alist is modified, still the instances remain intact.
|
254 |
"""
|
255 |
|
|
|
272 |
return instance
|
273 |
|
274 |
|
275 |
+
# Legacy name for `Set`: the subclass adds nothing of its own.
# The decorator is given version "1.11.0" and names `Set` as the alternative.
# NOTE(review): confirm whether the version marks the start of the deprecation or the removal.
@deprecation(version="1.11.0", alternative=Set)
class AddFields(Set):
    pass
|
278 |
+
|
279 |
+
|
280 |
class RemoveFields(InstanceOperator):
|
281 |
"""Remove specified fields from each instance in a stream.
|
282 |
|
|
|
1061 |
return copy.deepcopy(value)
|
1062 |
|
1063 |
|
1064 |
+
# Legacy name for `Copy`: the subclass adds nothing of its own.
# The decorator is given version "1.11.0" and names `Copy` as the alternative.
# NOTE(review): confirm whether the version marks the start of the deprecation or the removal.
@deprecation(version="1.11.0", alternative=Copy)
class CopyFields(Copy):
    pass
|
1067 |
|
|
|
1405 |
return eval(self.expression, self.globals, instance)
|
1406 |
|
1407 |
raise ValueError(
|
1408 |
+
f"Cannot evaluate expression in {self} when unitxt.settings.allow_unverified_code=False - either set it to True or set {settings.allow_unverified_code_key} environment variable."
|
1409 |
+
"\nNote: If using test_card() with the default setting, increase loader_limit to avoid missing conditions due to limited data sampling."
|
1410 |
)
|
1411 |
|
1412 |
|
|
|
1570 |
for ele in values_and_counts
|
1571 |
]
|
1572 |
|
1573 |
+
addmostcommons = Set(fields={self.to_field: values_to_keep})
|
1574 |
return addmostcommons(multi_stream)
|
1575 |
|
1576 |
|
processors.py
CHANGED
@@ -33,6 +33,11 @@ class ToListByComma(SplitStrip):
|
|
33 |
strip_every_element = True
|
34 |
|
35 |
|
|
|
|
|
|
|
|
|
|
|
36 |
class RegexParser(FieldOperator):
|
37 |
"""A processor that uses regex in order to parse a string."""
|
38 |
|
|
|
33 |
strip_every_element = True
|
34 |
|
35 |
|
36 |
+
# SplitStrip configuration that uses the two-character delimiter ", " (comma
# followed by a space) and enables stripping of every element.
# Presumably SplitStrip splits the field value on `delimiter` - verify against SplitStrip.
class ToListByCommaSpace(SplitStrip):
    delimiter = ", "
    strip_every_element = True
|
39 |
+
|
40 |
+
|
41 |
class RegexParser(FieldOperator):
|
42 |
"""A processor that uses regex in order to parse a string."""
|
43 |
|
settings_utils.py
CHANGED
@@ -1,7 +1,6 @@
|
|
|
|
1 |
import os
|
2 |
|
3 |
-
import pkg_resources
|
4 |
-
|
5 |
from .version import version
|
6 |
|
7 |
|
@@ -141,11 +140,11 @@ if Constants.is_uninitilized():
|
|
141 |
constants.dataset_file = os.path.join(os.path.dirname(__file__), "dataset.py")
|
142 |
constants.metric_file = os.path.join(os.path.dirname(__file__), "metric.py")
|
143 |
constants.local_catalog_path = os.path.join(os.path.dirname(__file__), "catalog")
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
)
|
148 |
-
|
149 |
constants.default_catalog_path = constants.local_catalog_path
|
150 |
constants.catalog_dir = constants.local_catalog_path
|
151 |
constants.dataset_url = "unitxt/data"
|
|
|
1 |
+
import importlib.util
|
2 |
import os
|
3 |
|
|
|
|
|
4 |
from .version import version
|
5 |
|
6 |
|
|
|
140 |
constants.dataset_file = os.path.join(os.path.dirname(__file__), "dataset.py")
|
141 |
constants.metric_file = os.path.join(os.path.dirname(__file__), "metric.py")
|
142 |
constants.local_catalog_path = os.path.join(os.path.dirname(__file__), "catalog")
|
143 |
+
unitxt_pkg = importlib.util.find_spec("unitxt")
|
144 |
+
if unitxt_pkg and unitxt_pkg.origin:
|
145 |
+
unitxt_dir = os.path.dirname(unitxt_pkg.origin)
|
146 |
+
constants.default_catalog_path = os.path.join(unitxt_dir, "catalog")
|
147 |
+
else:
|
148 |
constants.default_catalog_path = constants.local_catalog_path
|
149 |
constants.catalog_dir = constants.local_catalog_path
|
150 |
constants.dataset_url = "unitxt/data"
|
standard.py
CHANGED
@@ -5,7 +5,7 @@ from .dataclass import Field, InternalField, NonPositionalField, OptionalField
|
|
5 |
from .formats import Format, SystemFormat
|
6 |
from .logging_utils import get_logger
|
7 |
from .operator import SequentialOperator, SourceSequentialOperator, StreamingOperator
|
8 |
-
from .operators import
|
9 |
from .recipe import Recipe
|
10 |
from .schema import ToUnitxtGroup
|
11 |
from .splitters import Sampler, SeparateSplit, SpreadSplit
|
@@ -120,6 +120,12 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
|
|
120 |
metrics = self.card.task.metrics
|
121 |
else:
|
122 |
metrics = self.metrics
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
return metrics, postprocessors
|
124 |
|
125 |
def set_pipelines(self):
|
@@ -220,7 +226,7 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
|
|
220 |
self.loading.steps.append(StreamRefiner(max_instances=self.loader_limit))
|
221 |
|
222 |
self.metadata.steps.append(
|
223 |
-
|
224 |
fields={
|
225 |
"recipe_metadata": {
|
226 |
"template": self.template,
|
|
|
5 |
from .formats import Format, SystemFormat
|
6 |
from .logging_utils import get_logger
|
7 |
from .operator import SequentialOperator, SourceSequentialOperator, StreamingOperator
|
8 |
+
from .operators import Augmentor, NullAugmentor, Set, StreamRefiner
|
9 |
from .recipe import Recipe
|
10 |
from .schema import ToUnitxtGroup
|
11 |
from .splitters import Sampler, SeparateSplit, SpreadSplit
|
|
|
120 |
metrics = self.card.task.metrics
|
121 |
else:
|
122 |
metrics = self.metrics
|
123 |
+
|
124 |
+
metrics = [
|
125 |
+
metric if isinstance(metric, str) else metric.to_json()
|
126 |
+
for metric in metrics
|
127 |
+
]
|
128 |
+
|
129 |
return metrics, postprocessors
|
130 |
|
131 |
def set_pipelines(self):
|
|
|
226 |
self.loading.steps.append(StreamRefiner(max_instances=self.loader_limit))
|
227 |
|
228 |
self.metadata.steps.append(
|
229 |
+
Set(
|
230 |
fields={
|
231 |
"recipe_metadata": {
|
232 |
"template": self.template,
|
string_operators.py
CHANGED
@@ -37,6 +37,27 @@ class TokensSplit(FieldOperator):
|
|
37 |
return self.tokenizer.tokenize(value)
|
38 |
|
39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
class Join(FieldOperator):
|
41 |
by: str
|
42 |
|
|
|
37 |
return self.tokenizer.tokenize(value)
|
38 |
|
39 |
|
40 |
+
class TokensSlice(FieldOperator):
    """Keep only a slice of a string's tokens, measured with a pretrained tokenizer.

    The value is encoded with the tokenizer loaded for `model`, the resulting
    token sequence is sliced with `start`, `stop` and `step` (standard Python
    slice semantics), and the remaining tokens are decoded back to a string.
    """

    model: str
    start: Optional[int] = None
    stop: Optional[int] = None
    step: Optional[int] = None

    _requirements_list = ["transformers"]

    def prepare(self):
        super().prepare()
        # Imported lazily so transformers is only needed when the operator is used.
        from transformers import AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(self.model)

    def process_value(self, value: str) -> str:
        token_ids = self.tokenizer.encode(value)
        kept_ids = token_ids[self.start : self.stop : self.step]
        return self.tokenizer.decode(kept_ids)
|
59 |
+
|
60 |
+
|
61 |
class Join(FieldOperator):
|
62 |
by: str
|
63 |
|
struct_data_operators.py
CHANGED
@@ -14,6 +14,7 @@ For key-value pairs, expected input format is:
|
|
14 |
{"key1": "value1", "key2": value2, "key3": "value3"}
|
15 |
------------------------
|
16 |
"""
|
|
|
17 |
import json
|
18 |
import random
|
19 |
from abc import ABC, abstractmethod
|
|
|
14 |
{"key1": "value1", "key2": value2, "key3": "value3"}
|
15 |
------------------------
|
16 |
"""
|
17 |
+
|
18 |
import json
|
19 |
import random
|
20 |
from abc import ABC, abstractmethod
|
task.py
CHANGED
@@ -27,6 +27,10 @@ class Task(InstanceOperator):
|
|
27 |
prediction_type (Optional[str]):
|
28 |
Need to be consistent with all used metrics. Defaults to None, which means that it will
|
29 |
be set to Any.
|
|
|
|
|
|
|
|
|
30 |
|
31 |
The output instance contains three fields:
|
32 |
"inputs" whose value is a sub-dictionary of the input instance, consisting of all the fields listed in Arg 'inputs'.
|
@@ -39,6 +43,7 @@ class Task(InstanceOperator):
|
|
39 |
metrics: List[str]
|
40 |
prediction_type: Optional[str] = None
|
41 |
augmentable_inputs: List[str] = []
|
|
|
42 |
|
43 |
def verify(self):
|
44 |
for io_type in ["inputs", "outputs"]:
|
@@ -72,6 +77,8 @@ class Task(InstanceOperator):
|
|
72 |
augmentable_input in self.inputs
|
73 |
), f"augmentable_input {augmentable_input} is not part of {self.inputs}"
|
74 |
|
|
|
|
|
75 |
@staticmethod
|
76 |
@lru_cache(maxsize=None)
|
77 |
def get_metric_prediction_type(metric_id: str):
|
@@ -99,9 +106,46 @@ class Task(InstanceOperator):
|
|
99 |
f"metric's prediction type ({metric_prediction_type}) are different."
|
100 |
)
|
101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
def process(
|
103 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
104 |
) -> Dict[str, Any]:
|
|
|
|
|
105 |
verify_required_schema(self.inputs, instance)
|
106 |
verify_required_schema(self.outputs, instance)
|
107 |
|
|
|
27 |
prediction_type (Optional[str]):
|
28 |
Need to be consistent with all used metrics. Defaults to None, which means that it will
|
29 |
be set to Any.
|
30 |
+
defaults (Optional[Dict[str, Any]]):
|
31 |
+
An optional dictionary with default values for chosen input/output keys. Needs to be
|
32 |
+
consistent with names and types provided in 'inputs' and/or 'outputs' arguments.
|
33 |
+
Will not overwrite values if already provided in a given instance.
|
34 |
|
35 |
The output instance contains three fields:
|
36 |
"inputs" whose value is a sub-dictionary of the input instance, consisting of all the fields listed in Arg 'inputs'.
|
|
|
43 |
metrics: List[str]
|
44 |
prediction_type: Optional[str] = None
|
45 |
augmentable_inputs: List[str] = []
|
46 |
+
defaults: Optional[Dict[str, Any]] = None
|
47 |
|
48 |
def verify(self):
|
49 |
for io_type in ["inputs", "outputs"]:
|
|
|
77 |
augmentable_input in self.inputs
|
78 |
), f"augmentable_input {augmentable_input} is not part of {self.inputs}"
|
79 |
|
80 |
+
self.verify_defaults()
|
81 |
+
|
82 |
@staticmethod
|
83 |
@lru_cache(maxsize=None)
|
84 |
def get_metric_prediction_type(metric_id: str):
|
|
|
106 |
f"metric's prediction type ({metric_prediction_type}) are different."
|
107 |
)
|
108 |
|
109 |
+
def verify_defaults(self):
    """Validate the optional `defaults` mapping against `inputs` and `outputs`.

    Nothing is checked when `defaults` is empty or None. Otherwise `defaults`
    must be a dict with string keys. Each key must name a field declared in
    `inputs` or `outputs`, and each value must match that field's declared type.

    Raises:
        ValueError: If `defaults` is truthy but not a dict.
        AssertionError: If a key is not a string, is not declared in
            `inputs`/`outputs`, or its value has the wrong type.
    """
    if not self.defaults:
        return

    if not isinstance(self.defaults, dict):
        raise ValueError(
            f"If specified, the 'defaults' must be a dictionary, "
            f"however, '{self.defaults}' was provided instead, "
            f"which is of type '{type(self.defaults)}'."
        )

    for default_name, default_value in self.defaults.items():
        assert isinstance(default_name, str), (
            f"If specified, all keys of the 'defaults' must be strings, "
            f"however, the key '{default_name}' is of type '{type(default_name)}'."
        )

        # Inputs take precedence; outputs are consulted only when the
        # inputs lookup yields a falsy value.
        val_type = self.inputs.get(default_name)
        if not val_type:
            val_type = self.outputs.get(default_name)

        assert val_type, (
            f"If specified, all keys of the 'defaults' must refer to a chosen "
            f"key in either 'inputs' or 'outputs'. However, the name '{default_name}' "
            f"was provided which does not match any of the keys."
        )

        expected_type = parse_type_string(val_type)
        assert isoftype(default_value, expected_type), (
            f"The value of '{default_name}' from the 'defaults' must be of "
            f"type '{val_type}', however, it is of type '{type(default_value)}'."
        )
|
138 |
+
|
139 |
+
def set_default_values(self, instance: Dict[str, Any]) -> Dict[str, Any]:
    """Fill in fields missing from `instance` using `self.defaults`.

    Values already present in `instance` are never overwritten. When
    `defaults` is empty or None, the instance is returned as is. Otherwise a
    new dict is returned and the given instance is left unmodified.
    """
    if not self.defaults:
        return instance

    # Start from the defaults, then let the instance's own values win.
    filled = dict(self.defaults)
    filled.update(instance)
    return filled
|
143 |
+
|
144 |
def process(
|
145 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
146 |
) -> Dict[str, Any]:
|
147 |
+
instance = self.set_default_values(instance)
|
148 |
+
|
149 |
verify_required_schema(self.inputs, instance)
|
150 |
verify_required_schema(self.outputs, instance)
|
151 |
|
text_utils.py
CHANGED
@@ -69,7 +69,7 @@ def camel_to_snake_case(s):
|
|
69 |
return s.lower()
|
70 |
|
71 |
|
72 |
-
def construct_dict_str(d, indent=0, indent_delta=4, max_chars=None):
|
73 |
"""Constructs a formatted string of a dictionary.
|
74 |
|
75 |
Args:
|
@@ -77,13 +77,21 @@ def construct_dict_str(d, indent=0, indent_delta=4, max_chars=None):
|
|
77 |
indent (int, optional): The current level of indentation. Defaults to 0.
|
78 |
indent_delta (int, optional): The amount of spaces to add for each level of indentation. Defaults to 4.
|
79 |
max_chars (int, optional): The maximum number of characters for each line. Defaults to terminal width - 10.
|
|
|
80 |
"""
|
81 |
max_chars = max_chars or shutil.get_terminal_size()[0] - 10
|
82 |
indent_str = " " * indent
|
83 |
indent_delta_str = " " * indent_delta
|
84 |
res = ""
|
85 |
|
86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
if isinstance(value, dict):
|
88 |
res += f"{indent_str}{key}:\n"
|
89 |
res += construct_dict_str(value, indent + indent_delta, max_chars=max_chars)
|
@@ -106,10 +114,12 @@ def construct_dict_str(d, indent=0, indent_delta=4, max_chars=None):
|
|
106 |
return res
|
107 |
|
108 |
|
109 |
-
def print_dict(
|
110 |
-
|
|
|
|
|
111 |
dict_str = "\n" + dict_str
|
112 |
-
logger
|
113 |
|
114 |
|
115 |
def nested_tuple_to_string(nested_tuple: tuple) -> str:
|
@@ -150,6 +160,7 @@ def is_made_of_sub_strings(string, sub_strings):
|
|
150 |
# It also prepares for the case that __description__ tag does not contain balanced
|
151 |
# parentheses, since it is often cut in the middle, (with "... see more at")
|
152 |
# flake8: noqa: B007
|
|
|
153 |
def lines_defining_obj_in_card(
|
154 |
all_lines: List[str], obj_name: str, start_search_at_line: int = 0
|
155 |
) -> Tuple[int, int]:
|
@@ -165,25 +176,53 @@ def lines_defining_obj_in_card(
|
|
165 |
ending_line = starting_line - 1
|
166 |
while ending_line < len(all_lines):
|
167 |
ending_line += 1
|
168 |
-
|
169 |
-
num_of_closes += len(re.findall(r"[)}\]]", all_lines[ending_line]))
|
170 |
-
if num_of_closes == num_of_opens:
|
171 |
-
break
|
172 |
if "__description__" in all_lines[ending_line]:
|
173 |
-
# can not trust parentheses inside description
|
174 |
-
#
|
|
|
175 |
# a line consisting of only __description__=(
|
176 |
# followed by one or more lines of text, can not trust opens and closes
|
177 |
# in them, followed by a line consisting of only: ),
|
178 |
# where the ) is indented with the beginning of __description__
|
|
|
|
|
|
|
|
|
179 |
tag_indentation = all_lines[ending_line].index("__description__")
|
180 |
-
|
181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
ending_line += 1
|
183 |
if "__description__" in obj_name:
|
184 |
-
return (
|
185 |
-
|
186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
187 |
|
188 |
if num_of_closes != num_of_opens:
|
189 |
raise ValueError(
|
|
|
69 |
return s.lower()
|
70 |
|
71 |
|
72 |
+
def construct_dict_str(d, indent=0, indent_delta=4, max_chars=None, keys=None):
|
73 |
"""Constructs a formatted string of a dictionary.
|
74 |
|
75 |
Args:
|
|
|
77 |
indent (int, optional): The current level of indentation. Defaults to 0.
|
78 |
indent_delta (int, optional): The amount of spaces to add for each level of indentation. Defaults to 4.
|
79 |
max_chars (int, optional): The maximum number of characters for each line. Defaults to terminal width - 10.
|
80 |
+
keys (List[str], optional): The list of fields to print. Defaults to all keys of the dictionary.
|
81 |
"""
|
82 |
max_chars = max_chars or shutil.get_terminal_size()[0] - 10
|
83 |
indent_str = " " * indent
|
84 |
indent_delta_str = " " * indent_delta
|
85 |
res = ""
|
86 |
|
87 |
+
if keys is None:
|
88 |
+
keys = d.keys()
|
89 |
+
for key in keys:
|
90 |
+
if key not in d.keys():
|
91 |
+
raise ValueError(
|
92 |
+
f"Dictionary does not contain field {key} specified in 'keys' argument. The available keys are {d.keys()}"
|
93 |
+
)
|
94 |
+
value = d[key]
|
95 |
if isinstance(value, dict):
|
96 |
res += f"{indent_str}{key}:\n"
|
97 |
res += construct_dict_str(value, indent + indent_delta, max_chars=max_chars)
|
|
|
114 |
return res
|
115 |
|
116 |
|
117 |
+
def print_dict(
|
118 |
+
d, indent=0, indent_delta=4, max_chars=None, keys_to_print=None, log_level="info"
|
119 |
+
):
|
120 |
+
dict_str = construct_dict_str(d, indent, indent_delta, max_chars, keys_to_print)
|
121 |
dict_str = "\n" + dict_str
|
122 |
+
getattr(logger, log_level)(dict_str)
|
123 |
|
124 |
|
125 |
def nested_tuple_to_string(nested_tuple: tuple) -> str:
|
|
|
160 |
# It also prepares for the case that __description__ tag does not contain balanced
|
161 |
# parentheses, since it is often cut in the middle, (with "... see more at")
|
162 |
# flake8: noqa: B007
|
163 |
+
# flake8: noqa: C901
|
164 |
def lines_defining_obj_in_card(
|
165 |
all_lines: List[str], obj_name: str, start_search_at_line: int = 0
|
166 |
) -> Tuple[int, int]:
|
|
|
176 |
ending_line = starting_line - 1
|
177 |
while ending_line < len(all_lines):
|
178 |
ending_line += 1
|
179 |
+
|
|
|
|
|
|
|
180 |
if "__description__" in all_lines[ending_line]:
|
181 |
+
# can not trust parentheses inside description, because this is mainly truncated
|
182 |
+
# free text.
|
183 |
+
# We do trust the indentation enforced by ruff, and the way we build __description__:
|
184 |
# a line consisting of only __description__=(
|
185 |
# followed by one or more lines of text, can not trust opens and closes
|
186 |
# in them, followed by a line consisting of only: ),
|
187 |
# where the ) is indented with the beginning of __description__
|
188 |
+
# We also prepare for the case that, when not entered by us, __description__=
|
189 |
+
# is not followed by a ( and the whole description does not end with a single ) in its line.
|
190 |
+
# We build on ruff making the line following the description start with same indentation
|
191 |
+
# or 4 less (i.e., the following line is the closing of the card).
|
192 |
tag_indentation = all_lines[ending_line].index("__description__")
|
193 |
+
starts_with_parent = "__description__=(" in all_lines[ending_line]
|
194 |
+
if starts_with_parent:
|
195 |
+
last_line_to_start_with = (" " * tag_indentation) + r"\)"
|
196 |
+
else:
|
197 |
+
# actually, the line that follows the description
|
198 |
+
last_line_to_start_with1 = (" " * tag_indentation) + "[^ ]"
|
199 |
+
last_line_to_start_with2 = (" " * (tag_indentation - 4)) + "[^ ]"
|
200 |
+
last_line_to_start_with = (
|
201 |
+
"("
|
202 |
+
+ last_line_to_start_with1
|
203 |
+
+ "|"
|
204 |
+
+ last_line_to_start_with2
|
205 |
+
+ ")"
|
206 |
+
)
|
207 |
+
ending_line += 1
|
208 |
+
while not re.search("^" + last_line_to_start_with, all_lines[ending_line]):
|
209 |
ending_line += 1
|
210 |
if "__description__" in obj_name:
|
211 |
+
return (
|
212 |
+
starting_line,
|
213 |
+
ending_line if starts_with_parent else ending_line - 1,
|
214 |
+
)
|
215 |
+
|
216 |
+
if starts_with_parent:
|
217 |
+
ending_line += 1
|
218 |
+
|
219 |
+
# we continue in card, having passed the description, ending line points
|
220 |
+
# to the line that follows description
|
221 |
+
|
222 |
+
num_of_opens += len(re.findall(r"[({[]", all_lines[ending_line]))
|
223 |
+
num_of_closes += len(re.findall(r"[)}\]]", all_lines[ending_line]))
|
224 |
+
if num_of_closes == num_of_opens:
|
225 |
+
break
|
226 |
|
227 |
if num_of_closes != num_of_opens:
|
228 |
raise ValueError(
|
utils.py
CHANGED
@@ -1,4 +1,6 @@
|
|
|
|
1 |
import json
|
|
|
2 |
from functools import lru_cache
|
3 |
from typing import Any, Dict
|
4 |
|
@@ -47,13 +49,16 @@ def load_json(path):
|
|
47 |
) from e
|
48 |
|
49 |
|
50 |
-
def
|
51 |
with open(path, "w") as f:
|
52 |
-
|
53 |
-
f.write(dumped)
|
54 |
f.write("\n")
|
55 |
|
56 |
|
|
|
|
|
|
|
|
|
57 |
def is_package_installed(package_name):
|
58 |
"""Check if a package is installed.
|
59 |
|
@@ -113,3 +118,15 @@ def safe_eval(expression: str, context: dict, allowed_tokens: list) -> any:
|
|
113 |
raise ValueError(
|
114 |
f"The expression '{expression}' can not be evaluated because it contains tokens outside the allowed list of {allowed_sub_strings}."
|
115 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib.util
|
2 |
import json
|
3 |
+
import os
|
4 |
from functools import lru_cache
|
5 |
from typing import Any, Dict
|
6 |
|
|
|
49 |
) from e
|
50 |
|
51 |
|
52 |
+
def save_to_file(path, data):
|
53 |
with open(path, "w") as f:
|
54 |
+
f.write(data)
|
|
|
55 |
f.write("\n")
|
56 |
|
57 |
|
58 |
+
def json_dump(data):
|
59 |
+
return json.dumps(data, indent=4, ensure_ascii=False)
|
60 |
+
|
61 |
+
|
62 |
def is_package_installed(package_name):
|
63 |
"""Check if a package is installed.
|
64 |
|
|
|
118 |
raise ValueError(
|
119 |
f"The expression '{expression}' can not be evaluated because it contains tokens outside the allowed list of {allowed_sub_strings}."
|
120 |
)
|
121 |
+
|
122 |
+
|
123 |
+
def import_module_from_file(file_path):
|
124 |
+
# Get the module name (file name without extension)
|
125 |
+
module_name = os.path.splitext(os.path.basename(file_path))[0]
|
126 |
+
# Create a module specification
|
127 |
+
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
128 |
+
# Create a new module based on the specification
|
129 |
+
module = importlib.util.module_from_spec(spec)
|
130 |
+
# Load the module
|
131 |
+
spec.loader.exec_module(module)
|
132 |
+
return module
|
version.py
CHANGED
@@ -1 +1 @@
|
|
1 |
-
version = "1.10.
|
|
|
1 |
+
version = "1.10.1"
|