Upload folder using huggingface_hub
- api.py +21 -13
- artifact.py +29 -21
- augmentors.py +2 -1
- dataclass.py +8 -1
- dialog_operators.py +7 -5
- dict_utils.py +20 -15
- formats.py +20 -32
- image_operators.py +10 -7
- inference.py +107 -102
- llm_as_judge.py +40 -12
- loaders.py +26 -11
- metrics.py +172 -35
- operators.py +123 -105
- span_lableing_operators.py +22 -16
- struct_data_operators.py +125 -86
- task.py +16 -12
- templates.py +17 -8
- type_utils.py +13 -12
- utils.py +2 -2
- version.py +1 -1
api.py
CHANGED
@@ -93,31 +93,39 @@ def load_dataset(
 ) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
     """Loads dataset.
 
-    If the 'dataset_query' argument is provided, then dataset is loaded from a card
-    catalog based on parameters specified in the query.
-
+    If the 'dataset_query' argument is provided, then dataset is loaded from a card
+    in local catalog based on parameters specified in the query.
+
+    Alternatively, dataset is loaded from a provided card based on explicitly
+    given parameters.
 
     Args:
         dataset_query (str, optional): A string query which specifies a dataset to load from local catalog or name of specific recipe or benchmark in the catalog.
-
-
+            For example: ``"card=cards.wnli,template=templates.classification.multi_class.relation.default"``.
+
         streaming (bool, False): When True yields the data as Unitxt streams dictionary
+
         split (str, optional): The split of the data to load
+
         disable_cache (str, optional): Disable caching process of the data
+
         **kwargs: Arguments used to load dataset from provided card, which is not present in local catalog.
 
     Returns:
         DatasetDict
 
-
-
-
-
-        card = TaskCard(...)
-        template = Template(...)
-        loader_limit = 10
-        dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)
+    Example:
+        .. code-block:: python
+
+            dataset = load_dataset(
+                dataset_query="card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5"
+            )  # card must be present in local catalog
+
+            card = TaskCard(...)
+            template = Template(...)
+            loader_limit = 10
+            dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)
+
     """
     recipe = load_recipe(dataset_query, **kwargs)
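Note: the documented query string composes with the other arguments above. A minimal sketch of the query form (an illustration, assuming the unitxt package is installed and the queried card exists in the local catalog):

    from unitxt.api import load_dataset

    # Query-string form: card and template are resolved from the local catalog.
    dataset = load_dataset(
        "card=cards.wnli,template=templates.classification.multi_class.relation.default",
        split="test",        # the documented 'split' argument
        disable_cache=True,  # the documented 'disable_cache' argument
    )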
artifact.py
CHANGED
@@ -89,16 +89,18 @@ class Catalogs:
         self.catalogs = []
 
 
-def
-    if
-
-
-
-
-
-
+def maybe_recover_artifacts_structure(obj):
+    if Artifact.is_possible_identifier(obj):
+        return verbosed_fetch_artifact(obj)
+    if isinstance(obj, dict):
+        for key, value in obj.items():
+            obj[key] = maybe_recover_artifact(value)
+        return obj
+    if isinstance(obj, list):
+        for i in range(len(obj)):
+            obj[i] = maybe_recover_artifact(obj[i])
+        return obj
+    return obj
 
 
 def get_closest_artifact_type(type):

@@ -150,8 +152,12 @@ class Artifact(Dataclass):
         )
 
     @classmethod
-    def is_artifact_dict(cls,
-        return isinstance(
+    def is_artifact_dict(cls, obj):
+        return isinstance(obj, dict) and "__type__" in obj
+
+    @classmethod
+    def is_possible_identifier(cls, obj):
+        return isinstance(obj, str) or cls.is_artifact_dict(obj)
 
     @classmethod
     def verify_artifact_dict(cls, d):

@@ -292,7 +298,7 @@ class Artifact(Dataclass):
             field.type, Union[Artifact, List[Artifact], Dict[str, Artifact]]
         ):
             value = getattr(self, field.name)
-            value =
+            value = maybe_recover_artifacts_structure(value)
             setattr(self, field.name, value)
 
         self.verify_data_classification_policy()

@@ -343,15 +349,18 @@ class Artifact(Dataclass):
 
         Args:
             instance (Dict[str, Any]): data which should contain its allowed data
-
+                classification policies under key 'data_classification_policy'.
+
             name (Optional[str]): name of artifact which should be used to retrieve
-
-
+                data classification from env. If not specified, then either ``__id__`` or
+                ``__class__.__name__``, are used instead, respectively.
 
         Returns:
             Dict[str, Any]: unchanged instance.
 
         Examples:
+            .. code-block:: python
+
             instance = {"x": "some_text", "data_classification_policy": ["pii"]}
 
             # Will raise an error as "pii" is not included policy

@@ -574,11 +583,10 @@ def reset_artifacts_json_cache():
     artifacts_json_cache.cache_clear()
 
 
-def maybe_recover_artifact(
-    if
-        return verbosed_fetch_artifact(
-
-    return artifact
+def maybe_recover_artifact(obj):
+    if Artifact.is_possible_identifier(obj):
+        return verbosed_fetch_artifact(obj)
+    return obj
 
 
 def register_all_artifacts(path):
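The recovery helpers above treat a string, or a dict carrying "__type__", as a possible artifact identifier; anything else passes through untouched. A minimal sketch of the intended behavior (hypothetical catalog id; assumes a populated local catalog):

    from unitxt.artifact import maybe_recover_artifact, maybe_recover_artifacts_structure

    maybe_recover_artifact(42)          # not an identifier -> returned as-is
    maybe_recover_artifacts_structure(
        {"metric": "metrics.accuracy"}  # hypothetical catalog id, fetched in place
    )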
augmentors.py
CHANGED
@@ -97,7 +97,8 @@ class AugmentPrefixSuffix(TextAugmentor):
     To prepend the input with a prefix made of 4 ``\n``-s or ``\t``-s, employ
     ``AugmentPrefixSuffix(augment_model_input=True, prefixes=['\n','\t'], prefix_len=4, suffixes = None)``.
 
-    To append the input with a suffix made of 3 ``\n``-s or ``\t``-s, with ``\n`` being preferred over ``\t``,
+    To append the input with a suffix made of 3 ``\n``-s or ``\t``-s, with ``\n`` being preferred over ``\t``,
+    at 2:1 ratio, employ
     ``AugmentPrefixSuffix(augment_model_input=True, suffixes={'\n':2,'\t':1}, suffix_len=3, prefixes = None)``
     which will append ``\n``-s twice as often as ``\t``-s.
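The 2:1 preference described above is plain weighted sampling; a minimal sketch of the idea (an illustration, not the class's actual implementation):

    import random

    suffixes = {"\n": 2, "\t": 1}  # weights taken from the docstring example
    # Three weighted draws: "\n" is expected twice as often as "\t".
    suffix = "".join(
        random.choices(list(suffixes.keys()), weights=list(suffixes.values()), k=3)
    )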
dataclass.py
CHANGED
@@ -533,6 +533,13 @@ class Dataclass(metaclass=DataclassMeta):
             if keep_empty or value is not None
         }
 
+    def get_repr_dict(self):
+        result = {}
+        for field in fields(self):
+            if not field.internal:
+                result[field.name] = getattr(self, field.name)
+        return result
+
     def __repr__(self) -> str:
         """String representation."""
-        return f"{self.__class__.__name__}({', '.join([f'{
+        return f"{self.__class__.__name__}({', '.join([f'{key}={val!r}' for key, val in self.get_repr_dict().items()])})"
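With ``get_repr_dict`` filtering out internal fields, the repr shows only user-facing fields. A hypothetical sketch of the resulting behavior (assumes the usual unitxt ``Dataclass`` declaration style):

    from unitxt.dataclass import Dataclass

    class Point(Dataclass):  # hypothetical subclass for illustration
        x: int = 0
        y: int = 0

    repr(Point(x=1, y=2))  # -> "Point(x=1, y=2)"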
dialog_operators.py
CHANGED
@@ -5,11 +5,13 @@ text that can be fed to the model.
 
 The format of the dialog is:
 
-
-
-
-
-
+.. code-block:: text
+
+    dialog = [
+        {"user": "hello", "system": "hi"},
+        {"user": "kkk", "system": ""},
+        {"user": "kkk", "system": ""},
+    ]
 """
 from typing import Any, Dict, List, Optional
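A dialog in that shape is typically flattened into one prompt string; a minimal sketch of one plausible serialization (not necessarily the operators' exact output format):

    dialog = [
        {"user": "hello", "system": "hi"},
        {"user": "kkk", "system": ""},
    ]

    # One line per side of each turn.
    text = "\n".join(f"user: {t['user']}\nagent: {t['system']}" for t in dialog)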
dict_utils.py
CHANGED
@@ -24,29 +24,32 @@ def is_wildcard(string):
 # formal definition of qpath syntax by which a query is specified:
 # qpath -> A (/A)*
 # A -> name | * | non-neg-int
-# name ->
-# *
+# name -> a string satisfying is_name above.
+# * -> ALL members (each and every) of a list or a dictionary element in the input dictionary,
 #
-#
-#
-#
-# (
-#
-#
+# A path p in dictionary dic, leading to element (aka subfield) el, is said to match query qpath
+# (alternatively said: query qpath matches path p in dic),
+# if the following recursively defined condition is satisfied:
+# (1) the prefix of length 0 of qpath (i.e., pref = "") matches the empty path in dic, the path leading to the whole of dic.
+# (2) Denoting by el the element in dic lead to by the path in dic that matches the prefix pref of qpath
+# (el must be a list or dictionary, since led to by a path matching a prefix of qpath, and not the whole of qpath),
+# and by A (as the definition above) the component, DIFFERENT from *, in qpath, that follows pref, then the element
+# lead to by the path in dic matching query pref/A is el[A]. If el[A] is missing from dic, then no path in dic matches
+# pref/A, that is either a longer prefix of qpath, or the whole of qpath,
 # and hence no path in dic matches query qpath. (E.g., when el is a list, A must match indx, and its
 # int value should be smaller than len(el) in order for the path in dic leading to element el[A] to match pref/A)
-# (3) Denoting as in (2), now with A == *
+# (3) Denoting as in (2), now with A == * , then when el is a list, each and every element in the set:
 # {el[0], el[1], .. , el[len(el)-1]} is said to be lead to by a path matching pref/*
 # and when el is a dict, each and every element in the set {el[k] for k being a key in el} is said to be lead
 # to by a path matching pref/*
 #
 # An element el lead to by path p that matches qpath as a whole is thus either a list member (when indx.match the last
-# component of p
-# of el
+# component of p) or a dictionary item (the key of which equals the last component of p). The value
+# of el is returned (dic_get) or el is popped (dic_delete) or el's value is replaced by a new value (dic_set).
 #
 # Thus, for a query with no *, dic contains at most one element the path to which matches the query.
 # If there is such one in dic - the function (either dict_get, dict_set, or dict_delete) operates on
-# that element according to its arguments, other than not_exist_ok
+# that element according to its arguments, other than not_exist_ok.
 # If there is not any such element in dic - the function throws or does not throw an exception, depending
 # on flag not_exist_ok.
 # For a query with *, there could be up to as many as there are values to match the *

@@ -54,9 +57,9 @@ def is_wildcard(string):
 # for more than one * in the query -- this effect multiplies)
 # Each of the three functions below (dict_get, dict_set, dict_delete) applies the requested
 # operation (read, set, or delete) to each and every element el in dic, the path to which matches the query in whole,
-# and reads a value from, or sets a new value to, or pops
+# and reads a value from, or sets a new value to, or pops el out from dic.
 #
-# If no path in dic matches the query, then
+# If no path in dic matches the query, then if not_exist_ok=False, the function throws an exception;
 # but if not_exist_ok=True, the function returns a default value (dict_get) or does nothing (dict_delete)
 # or generates all the needed missing suffixes (dict_set, see details below).
 #

@@ -444,7 +447,7 @@ def dict_get(
     )
     if len(components) > 1:
         try:
-            success, values = get_values(
+            success, values = get_values(
+                dic, components, -1 * len(components), allow_int_index=allow_int_index
+            )
             if success:
                 return values
         except Exception as e:
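A quick illustration of the qpath semantics spelled out above (a sketch; assumes ``dict_get(dic, query)`` as the first two arguments):

    from unitxt.dict_utils import dict_get

    dic = {"instances": [{"score": 1}, {"score": 2}, {"score": 3}]}

    # No '*': at most one element matches the query.
    dict_get(dic, "instances/0/score")  # -> 1

    # '*' matches each and every member of the list it stands for.
    dict_get(dic, "instances/*/score")  # -> [1, 2, 3]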
formats.py
CHANGED
@@ -28,26 +28,26 @@ class Format(InstanceOperator):
 def apply_capital_new_line_notation(text: str) -> str:
     r"""Transforms a given string by applying the Capital New Line Notation.
 
-    The Capital New Line Notation (\N) is designed to manage newline behavior in a string efficiently.
-    This custom notation aims to consolidate multiple newline characters (\n) into a single newline under
+    The Capital New Line Notation ``(\N)`` is designed to manage newline behavior in a string efficiently.
+    This custom notation aims to consolidate multiple newline characters ``(\n)`` into a single newline under
     specific conditions, with tailored handling based on whether there's preceding text. The function
     distinguishes between two primary scenarios:
 
-    1. If there's text (referred to as a prefix) followed by any number of
-       more
+    1. If there's text (referred to as a prefix) followed by any number of ``\n`` characters and then one or
+       more ``\N``, the entire sequence is replaced with a single ``\n``. This effectively simplifies multiple
        newlines and notation characters into a single newline when there's preceding text.
-
-
+
+    2. If the string starts with ``\n`` characters followed by ``\N`` without any text before this sequence, or if
+       ``\N`` is at the very beginning of the string, the sequence is completely removed. This case is
       applicable when the notation should not introduce any newlines due to the absence of preceding text.
 
     Args:
-        text (str): The input string to be transformed, potentially containing the Capital New Line Notation
-        (\N) mixed with actual newline characters (\n).
+        text (str): The input string to be transformed, potentially containing the Capital New Line Notation ``(\N)`` mixed with actual newline characters ``(\n)``.
 
     Returns:
-
-
-
+        The string after applying the Capital New Line Notation rules, which either consolidates multiple
+        newlines and notation characters into a single newline when text precedes them, or removes the
+        notation and any preceding newlines entirely if no text is present before the notation.
 
     Examples:
         >>> apply_capital_new_line_notation("Hello World\\n\\n\N")

@@ -131,27 +131,26 @@ class BaseFormat(Format):
 class SystemFormat(BaseFormat):
     r"""Generates the whole input to the model, from constant strings that are given as args, and from values found in specified fields of the instance.
 
-    Important: formats can use '\N' notations that means new-line if no new-line before and no empty string before.
+    Important: formats can use ``'\N'`` notations that means new-line if no new-line before and no empty string before.
 
     SystemFormat expects the input instance to contain:
     1. A field named "system_prompt" whose value is a string (potentially empty) that delivers a task-independent opening text.
     2. A field named "source" whose value is a string verbalizing the original values in the instance (as read
     from the source dataset), in the context of the underlying task.
     3. A field named "instruction" that contains a (non-None) string.
-    4. A field named with the value in arg 'demos_field'
+    4. A field named with the value in arg ``'demos_field'``, containing a list of dicts, each dict with fields "source"
     and "target", representing a single demo.
     5. A field named "target_prefix" that contains a string to prefix the target in each demo, and to end the whole generated prompt
 
-    SystemFormat formats the above fields into a single string to be
-    field "source" of the instance. Formatting is driven by two args: 'demo_format' and 'model_input_format'
+    SystemFormat formats the above fields into a single string to be input to the model. This string overwrites
+    field "source" of the instance. Formatting is driven by two args: ``'demo_format'`` and ``'model_input_format'``.
     SystemFormat also pops fields "system_prompt", "instruction", "target_prefix", and the field containing the demos out from the input instance.
 
     Args:
         demos_field (str): the name of the field that contains the demos, being a list of dicts, each with "source" and "target" keys
         demo_format (str): formatting string for a single demo, combining fields "source" and "target"
-        model_input_format (str) overall product format, combining instruction and source (as read from fields "instruction"
-
-        format_args: Dict[str,str]: additional format args to be used when formatting the different format strings
+        model_input_format (str): overall product format, combining instruction and source (as read from fields "instruction" and "source" of the input instance), together with demos (as formatted into one string)
+        format_args (Dict[str,str]): additional format args to be used when formatting the different format strings
 
     Example:
         when input instance:

@@ -423,24 +422,13 @@ class ChatAPIFormat(BaseFormat):
 class HFSystemFormat(ChatAPIFormat):
     r"""Formats the complete input for the model using the HuggingFace chat template of a given model.
 
-    HFSystemFormat
-    1. A field named "system_prompt" whose value is a string (potentially empty) that delivers a task-independent opening text.
-    2. A field named "source" whose value is a string verbalizing the original values in the instance (as read
-    from the source dataset), in the context of the underlying task.
-    3. A field named "instruction" that contains a (non-None) string.
-    4. A field named with the value in arg 'demos_field', containing a list of dicts, each dict with fields "source"
-    and "target", representing a single demo.
-    5. A field named "target_prefix" that contains a string to prefix the target in each demo, and to end the whole generated prompt.
-
-    SystemFormat formats the above fields into a single string to be inputted to the model. This string overwrites
+    HFSystemFormat formats instance fields into a single string to be inputted to the model. This string overwrites
     field "source" of the instance.
 
     Example:
-        HFSystemFormat(model_name="HuggingFaceH4/zephyr-7b-beta")
-
-        Uses the template defined the in tokenizer_config.json of the model:
+        ``HFSystemFormat(model_name="HuggingFaceH4/zephyr-7b-beta")`` Uses the template defined the in tokenizer_config.json of the model:
 
-        "chat_template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
+        ``"chat_template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"``
 
     See more details in https://huggingface.co/docs/transformers/main/en/chat_templating
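The two \N rules lend themselves to a short regex sketch (an illustration of the documented behavior, not the library's actual implementation):

    import re

    CAPITAL_N = r"\\N"  # the literal two-character notation: backslash + N

    def capital_new_line_sketch(text: str) -> str:
        # Rule 2: a \n-run followed by \N at the very start vanishes entirely.
        text = re.sub(rf"^\n*({CAPITAL_N})+", "", text)
        # Rule 1: elsewhere, a \n-run plus one or more \N collapses to one \n.
        return re.sub(rf"\n*({CAPITAL_N})+", "\n", text)

    capital_new_line_sketch("Hello World\n\n\\N")  # -> 'Hello World\n'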
image_operators.py
CHANGED
@@ -167,12 +167,14 @@ class GridLines(ImageAugmentor):
     """A class that overlays a fixed number of evenly spaced horizontal and vertical lines on an image.
 
     Attributes:
-
-
-
+        num_lines (int): The number of horizontal and vertical lines to add.
+
+        line_thickness (int): Thickness of each line in pixels.
+
+        line_color (Tuple[int, int, int]): RGB color of the grid lines.
 
     Methods:
-
+        process_image(image): Adds grid lines to the provided image and returns the modified image.
     """
 
     num_lines: int = 128

@@ -207,11 +209,12 @@ class PixelNoise(ImageAugmentor):
     """A class that overlays a mask of randomly colored nxn squares across an image based on a specified noise rate.
 
     Attributes:
-
-
+        square_size (int): Size of each square in pixels.
+
+        noise_rate (float): Proportion of the image that should be affected by noise (0 to 1).
 
     Methods:
-
+        process_image(image): Adds the random square mask to the provided image and returns the modified image.
     """
 
     square_size: int = 1
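A minimal usage sketch for the two augmentors (assumes Pillow images and a ``process_image`` method matching the docstrings above):

    from PIL import Image
    from unitxt.image_operators import GridLines, PixelNoise

    image = Image.new("RGB", (256, 256), color="white")

    gridded = GridLines(num_lines=8, line_thickness=2).process_image(image)
    noisy = PixelNoise(square_size=4, noise_rate=0.3).process_image(image)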
inference.py
CHANGED
@@ -23,7 +23,7 @@ from typing import (
|
|
23 |
Union,
|
24 |
)
|
25 |
|
26 |
-
from datasets import DatasetDict
|
27 |
from tqdm import tqdm, trange
|
28 |
from tqdm.asyncio import tqdm_asyncio
|
29 |
|
@@ -70,21 +70,26 @@ class TextGenerationInferenceOutput:
|
|
70 |
"""Contains the prediction results and metadata for the inference.
|
71 |
|
72 |
Args:
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
|
81 |
-
|
82 |
-
output_tokens (int) : number of output tokens to the model.
|
83 |
-
stop_reason (str): stop reason for text generation, for example "eos" (end of string).
|
84 |
-
seed (int): seed used by the model during generation.
|
85 |
-
input_text (str): input to the model.
|
86 |
-
model_name (str): the model_name as kept in the InferenceEngine.
|
87 |
-
inference_type (str): The label stating the type of the InferenceEngine.
|
88 |
"""
|
89 |
|
90 |
prediction: Union[str, List[Dict[str, Any]]]
|
@@ -103,7 +108,7 @@ class InferenceEngine(Artifact):
|
|
103 |
@abc.abstractmethod
|
104 |
def _infer(
|
105 |
self,
|
106 |
-
dataset: Union[List[Dict[str, Any]],
|
107 |
return_meta_data: bool = False,
|
108 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
109 |
"""Perform inference on the input dataset.
|
@@ -126,7 +131,7 @@ class InferenceEngine(Artifact):
|
|
126 |
|
127 |
def infer(
|
128 |
self,
|
129 |
-
dataset: Union[List[Dict[str, Any]],
|
130 |
return_meta_data: bool = False,
|
131 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
132 |
"""Verifies instances of a dataset and perform inference on the input dataset.
|
@@ -134,6 +139,10 @@ class InferenceEngine(Artifact):
|
|
134 |
If return_meta_data - returns a list of TextGenerationInferenceOutput, else returns a list of the string
|
135 |
predictions.
|
136 |
"""
|
|
|
|
|
|
|
|
|
137 |
if return_meta_data and not hasattr(self, "get_return_object"):
|
138 |
raise NotImplementedError(
|
139 |
f"Inference engine {self.__class__.__name__} does not support return_meta_data as it "
|
@@ -147,7 +156,7 @@ class InferenceEngine(Artifact):
|
|
147 |
|
148 |
def _mock_infer(
|
149 |
self,
|
150 |
-
dataset: Union[List[Dict[str, Any]],
|
151 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
152 |
return [str(instance["source"]) for instance in dataset]
|
153 |
|
@@ -198,7 +207,7 @@ class LogProbInferenceEngine(abc.ABC, Artifact):
|
|
198 |
@abc.abstractmethod
|
199 |
def _infer_log_probs(
|
200 |
self,
|
201 |
-
dataset: Union[List[Dict[str, Any]],
|
202 |
return_meta_data: bool = False,
|
203 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
204 |
"""Perform inference on the input dataset that returns log probs.
|
@@ -211,7 +220,7 @@ class LogProbInferenceEngine(abc.ABC, Artifact):
|
|
211 |
|
212 |
def infer_log_probs(
|
213 |
self,
|
214 |
-
dataset: Union[List[Dict[str, Any]],
|
215 |
return_meta_data: bool = False,
|
216 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
217 |
"""Verifies instances of a dataset and performs inference that returns log probabilities of top tokens.
|
@@ -446,7 +455,7 @@ class HFInferenceEngineBase(
|
|
446 |
|
447 |
def infer(
|
448 |
self,
|
449 |
-
dataset: Union[List[Dict[str, Any]],
|
450 |
return_meta_data: bool = False,
|
451 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
452 |
if not self._is_loaded():
|
@@ -456,14 +465,14 @@ class HFInferenceEngineBase(
|
|
456 |
@abc.abstractmethod
|
457 |
def _infer(
|
458 |
self,
|
459 |
-
dataset: Union[List[Dict[str, Any]],
|
460 |
return_meta_data: bool = False,
|
461 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
462 |
raise NotImplementedError
|
463 |
|
464 |
def infer_log_probs(
|
465 |
self,
|
466 |
-
dataset: Union[List[Dict[str, Any]],
|
467 |
return_meta_data: bool = False,
|
468 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
469 |
if not self._is_loaded():
|
@@ -473,7 +482,7 @@ class HFInferenceEngineBase(
|
|
473 |
@abc.abstractmethod
|
474 |
def _infer_log_probs(
|
475 |
self,
|
476 |
-
dataset: Union[List[Dict[str, Any]],
|
477 |
return_meta_data: bool = False,
|
478 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
479 |
raise NotImplementedError
|
@@ -524,7 +533,7 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
|
|
524 |
|
525 |
def _infer_fn(
|
526 |
self,
|
527 |
-
dataset: Union[List[Dict[str, Any]],
|
528 |
return_meta_data: bool,
|
529 |
return_logprobs: bool,
|
530 |
) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]:
|
@@ -565,7 +574,7 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
|
|
565 |
|
566 |
def _infer(
|
567 |
self,
|
568 |
-
dataset: Union[List[Dict[str, Any]],
|
569 |
return_meta_data: bool = False,
|
570 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
571 |
self.verify_not_chat_api(dataset)
|
@@ -573,7 +582,7 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
|
|
573 |
|
574 |
def _infer_log_probs(
|
575 |
self,
|
576 |
-
dataset: Union[List[Dict[str, Any]],
|
577 |
return_meta_data: bool = False,
|
578 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
579 |
self.verify_not_chat_api(dataset)
|
@@ -647,7 +656,7 @@ class HFLlavaInferenceEngine(HFInferenceEngineBase):
|
|
647 |
|
648 |
def _infer_fn(
|
649 |
self,
|
650 |
-
dataset: Union[List[Dict[str, Any]],
|
651 |
return_meta_data: bool,
|
652 |
return_logprobs: bool,
|
653 |
) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]:
|
@@ -681,14 +690,14 @@ class HFLlavaInferenceEngine(HFInferenceEngineBase):
|
|
681 |
|
682 |
def _infer(
|
683 |
self,
|
684 |
-
dataset: Union[List[Dict[str, Any]],
|
685 |
return_meta_data: bool = False,
|
686 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
687 |
return self._infer_fn(dataset, return_meta_data, False)
|
688 |
|
689 |
def _infer_log_probs(
|
690 |
self,
|
691 |
-
dataset: Union[List[Dict[str, Any]],
|
692 |
return_meta_data: bool = False,
|
693 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
694 |
return self._infer_fn(dataset, return_meta_data, True)
|
@@ -879,7 +888,7 @@ class HFPipelineBasedInferenceEngine(
|
|
879 |
|
880 |
def _infer(
|
881 |
self,
|
882 |
-
dataset: Union[List[Dict[str, Any]],
|
883 |
return_meta_data: bool = False,
|
884 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
885 |
if not self._is_loaded():
|
@@ -933,13 +942,13 @@ class MockInferenceEngine(InferenceEngine, LogProbInferenceEngine):
|
|
933 |
|
934 |
def _mock_infer(
|
935 |
self,
|
936 |
-
dataset: Union[List[Dict[str, Any]],
|
937 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
938 |
return [self.default_inference_value for _ in dataset]
|
939 |
|
940 |
def _infer(
|
941 |
self,
|
942 |
-
dataset: Union[List[Dict[str, Any]],
|
943 |
return_meta_data: bool = False,
|
944 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
945 |
return [
|
@@ -951,7 +960,7 @@ class MockInferenceEngine(InferenceEngine, LogProbInferenceEngine):
|
|
951 |
|
952 |
def _infer_log_probs(
|
953 |
self,
|
954 |
-
dataset: Union[List[Dict[str, Any]],
|
955 |
return_meta_data: bool = False,
|
956 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
957 |
return [
|
@@ -1047,14 +1056,14 @@ class GenericInferenceEngine(
|
|
1047 |
|
1048 |
def _infer(
|
1049 |
self,
|
1050 |
-
dataset: Union[List[Dict[str, Any]],
|
1051 |
return_meta_data: bool = False,
|
1052 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
1053 |
return self.engine._infer(dataset)
|
1054 |
|
1055 |
def _infer_log_probs(
|
1056 |
self,
|
1057 |
-
dataset: Union[List[Dict[str, Any]],
|
1058 |
return_meta_data: bool = False,
|
1059 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
1060 |
if not isinstance(self.engine, LogProbInferenceEngine):
|
@@ -1082,7 +1091,7 @@ class OllamaInferenceEngine(
|
|
1082 |
|
1083 |
def _infer(
|
1084 |
self,
|
1085 |
-
dataset: Union[List[Dict[str, Any]],
|
1086 |
return_meta_data: bool = False,
|
1087 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
1088 |
import ollama
|
@@ -1250,7 +1259,7 @@ class IbmGenAiInferenceEngine(
|
|
1250 |
|
1251 |
def _infer(
|
1252 |
self,
|
1253 |
-
dataset: Union[List[Dict[str, Any]],
|
1254 |
return_meta_data: bool = False,
|
1255 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
1256 |
from genai.schema import TextGenerationParameters, TextGenerationResult
|
@@ -1279,7 +1288,7 @@ class IbmGenAiInferenceEngine(
|
|
1279 |
|
1280 |
def _infer_log_probs(
|
1281 |
self,
|
1282 |
-
dataset: Union[List[Dict[str, Any]],
|
1283 |
return_meta_data: bool = False,
|
1284 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
1285 |
from genai.schema import TextGenerationParameters, TextGenerationResult
|
@@ -1507,7 +1516,7 @@ class OpenAiInferenceEngine(
|
|
1507 |
|
1508 |
def _infer(
|
1509 |
self,
|
1510 |
-
dataset: Union[List[Dict[str, Any]],
|
1511 |
return_meta_data: bool = False,
|
1512 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
1513 |
outputs = []
|
@@ -1527,22 +1536,14 @@ class OpenAiInferenceEngine(
|
|
1527 |
|
1528 |
def _infer_log_probs(
|
1529 |
self,
|
1530 |
-
dataset: Union[List[Dict[str, Any]],
|
1531 |
return_meta_data: bool = False,
|
1532 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
1533 |
outputs = []
|
1534 |
for instance in tqdm(dataset, desc="Inferring with openAI API"):
|
|
|
1535 |
response = self.client.chat.completions.create(
|
1536 |
-
messages=
|
1537 |
-
# {
|
1538 |
-
# "role": "system",
|
1539 |
-
# "content": self.system_prompt,
|
1540 |
-
# },
|
1541 |
-
{
|
1542 |
-
"role": "user",
|
1543 |
-
"content": instance["source"],
|
1544 |
-
}
|
1545 |
-
],
|
1546 |
model=self.model_name,
|
1547 |
**self._get_completion_kwargs(),
|
1548 |
)
|
@@ -1681,7 +1682,7 @@ class TogetherAiInferenceEngine(
|
|
1681 |
|
1682 |
def _infer(
|
1683 |
self,
|
1684 |
-
dataset: Union[List[Dict[str, Any]],
|
1685 |
return_meta_data: bool = False,
|
1686 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
1687 |
from together.types.models import ModelType
|
@@ -1943,7 +1944,7 @@ class WMLInferenceEngineBase(
|
|
1943 |
@abc.abstractmethod
|
1944 |
def _send_requests(
|
1945 |
self,
|
1946 |
-
dataset: Union[List[Dict[str, Any]],
|
1947 |
return_logprobs: bool,
|
1948 |
return_meta_data: bool,
|
1949 |
) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]:
|
@@ -1955,7 +1956,7 @@ class WMLInferenceEngineBase(
|
|
1955 |
|
1956 |
def _infer(
|
1957 |
self,
|
1958 |
-
dataset: Union[List[Dict[str, Any]],
|
1959 |
return_meta_data: bool = False,
|
1960 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
1961 |
if self._model is None:
|
@@ -1969,7 +1970,7 @@ class WMLInferenceEngineBase(
|
|
1969 |
|
1970 |
def _infer_log_probs(
|
1971 |
self,
|
1972 |
-
dataset: Union[List[Dict[str, Any]],
|
1973 |
return_meta_data: bool = False,
|
1974 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
1975 |
if self._model is None:
|
@@ -2050,27 +2051,29 @@ class WMLInferenceEngineGeneration(WMLInferenceEngineBase, WMLGenerationParamsMi
|
|
2050 |
|
2051 |
Attributes:
|
2052 |
concurrency_limit (int): Number of concurrent requests sent to a model. Default is 10,
|
2053 |
-
|
2054 |
|
2055 |
Examples:
|
2056 |
-
|
2057 |
|
2058 |
-
|
2059 |
-
"url": "some_url", "project_id": "some_id", "api_key": "some_key"
|
2060 |
-
}
|
2061 |
-
model_name = "google/flan-t5-xxl"
|
2062 |
-
wml_inference = WMLInferenceEngineGeneration(
|
2063 |
-
credentials=wml_credentials,
|
2064 |
-
model_name=model_name,
|
2065 |
-
data_classification_policy=["public"],
|
2066 |
-
top_p=0.5,
|
2067 |
-
random_seed=123,
|
2068 |
-
)
|
2069 |
|
2070 |
-
|
2071 |
-
|
2072 |
-
|
2073 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2074 |
"""
|
2075 |
|
2076 |
concurrency_limit: int = 10
|
@@ -2112,7 +2115,7 @@ class WMLInferenceEngineGeneration(WMLInferenceEngineBase, WMLGenerationParamsMi
|
|
2112 |
|
2113 |
def _send_requests(
|
2114 |
self,
|
2115 |
-
dataset: Union[List[Dict[str, Any]],
|
2116 |
return_logprobs: bool,
|
2117 |
return_meta_data: bool,
|
2118 |
) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]:
|
@@ -2178,31 +2181,33 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin):
|
|
2178 |
|
2179 |
Attributes:
|
2180 |
image_encoder (EncodeImageToString, optional): operator which encodes images in
|
2181 |
-
|
2182 |
-
|
2183 |
|
2184 |
Example:
|
2185 |
-
|
2186 |
-
from .image_operators
|
2187 |
|
2188 |
-
|
|
|
2189 |
|
2190 |
-
|
2191 |
-
"url": "some_url", "project_id": "some_id", "api_key": "some_key"
|
2192 |
-
}
|
2193 |
-
model_name = "meta-llama/llama-3-2-11b-vision-instruct"
|
2194 |
-
wml_inference = WMLInferenceEngineChat(
|
2195 |
-
credentials=wml_credentials,
|
2196 |
-
model_name=model_name,
|
2197 |
-
image_encoder=image_encoder,
|
2198 |
-
data_classification_policy=["public"],
|
2199 |
-
max_tokens=1024,
|
2200 |
-
)
|
2201 |
|
2202 |
-
|
2203 |
-
|
2204 |
-
|
2205 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2206 |
"""
|
2207 |
|
2208 |
image_encoder: Optional[EncodeImageToString] = None
|
@@ -2303,7 +2308,7 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin):
|
|
2303 |
|
2304 |
def _send_requests(
|
2305 |
self,
|
2306 |
-
dataset: Union[List[Dict[str, Any]],
|
2307 |
return_logprobs: bool,
|
2308 |
return_meta_data: bool,
|
2309 |
) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]:
|
@@ -2428,7 +2433,7 @@ class LMMSEvalInferenceEngine(LMMSEvalBaseInferenceEngine):
|
|
2428 |
|
2429 |
def _infer(
|
2430 |
self,
|
2431 |
-
dataset: Union[List[Dict[str, Any]],
|
2432 |
return_meta_data: bool = False,
|
2433 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
2434 |
if not self._is_loaded():
|
@@ -2500,7 +2505,7 @@ class LMMSEvalLoglikelihoodInferenceEngine(LMMSEvalBaseInferenceEngine):
|
|
2500 |
|
2501 |
def _infer(
|
2502 |
self,
|
2503 |
-
dataset: Union[List[Dict[str, Any]],
|
2504 |
return_meta_data: bool = False,
|
2505 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
2506 |
if not self._is_loaded():
|
@@ -2555,7 +2560,7 @@ class VLLMInferenceEngine(
|
|
2555 |
|
2556 |
def _infer(
|
2557 |
self,
|
2558 |
-
dataset: Union[List[Dict[str, Any]],
|
2559 |
return_meta_data: bool = False,
|
2560 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
2561 |
inputs = []
|
@@ -2681,7 +2686,7 @@ class LiteLLMInferenceEngine(
|
|
2681 |
|
2682 |
def _infer(
|
2683 |
self,
|
2684 |
-
dataset: Union[List[Dict[str, Any]],
|
2685 |
return_meta_data: bool = False,
|
2686 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
2687 |
"""Main inference entry point."""
|
@@ -2735,8 +2740,8 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
|
|
2735 |
"granite-3-8b-instruct": "ibm/granite-3-8b-instruct",
|
2736 |
},
|
2737 |
"together-ai": {
|
2738 |
-
"llama-3-8b-instruct": "together_ai/
|
2739 |
-
"llama-3-70b-instruct": "together_ai/
|
2740 |
"llama-3-2-1b-instruct": "together_ai/togethercomputer/llama-3-2-1b-instruct",
|
2741 |
},
|
2742 |
"aws": {
|
@@ -2812,7 +2817,7 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
|
|
2812 |
|
2813 |
def _infer(
|
2814 |
self,
|
2815 |
-
dataset: Union[List[Dict[str, Any]],
|
2816 |
return_meta_data: bool = False,
|
2817 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
2818 |
return self.engine._infer(dataset, return_meta_data)
|
@@ -2898,7 +2903,7 @@ class HFOptionSelectingInferenceEngine(InferenceEngine):
|
|
2898 |
|
2899 |
def _infer(
|
2900 |
self,
|
2901 |
-
dataset: Union[List[Dict[str, Any]],
|
2902 |
return_meta_data: bool = False,
|
2903 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
2904 |
inputs = []
|
|
|
23 |
Union,
|
24 |
)
|
25 |
|
26 |
+
from datasets import Dataset, DatasetDict
|
27 |
from tqdm import tqdm, trange
|
28 |
from tqdm.asyncio import tqdm_asyncio
|
29 |
|
|
|
70 |
"""Contains the prediction results and metadata for the inference.
|
71 |
|
72 |
Args:
|
73 |
+
prediction (Union[str, List[Dict[str, Any]]]): If this is the result of an _infer call, the string predicted by the model.
|
74 |
+
| If this is the results of an _infer_log_probs call, a list of dictionaries. The i'th dictionary represents
|
75 |
+
the i'th token in the response. The entry "top_tokens" in the dictionary holds a sorted list of the top tokens
|
76 |
+
for this position and their probabilities.
|
77 |
+
| For example: ``[ {.. "top_tokens": [ {"text": "a", 'logprob': }, {"text": "b", 'logprob': } ....]},
|
78 |
+
{.. "top_tokens": [ {"text": "c", 'logprob': }, {"text": "d", 'logprob': } ....]} ]``
|
79 |
+
|
80 |
+
input_tokens (int) : number of input tokens to the model.
|
81 |
+
|
82 |
+
output_tokens (int) : number of output tokens to the model.
|
83 |
+
|
84 |
+
stop_reason (str): stop reason for text generation, for example "eos" (end of string).
|
85 |
+
|
86 |
+
seed (int): seed used by the model during generation.
|
87 |
+
|
88 |
+
input_text (str): input to the model.
|
89 |
+
|
90 |
+
model_name (str): the model_name as kept in the InferenceEngine.
|
91 |
|
92 |
+
inference_type (str): The label stating the type of the InferenceEngine.
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
"""
|
94 |
|
95 |
prediction: Union[str, List[Dict[str, Any]]]
|
|
|
108 |
@abc.abstractmethod
|
109 |
def _infer(
|
110 |
self,
|
111 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
112 |
return_meta_data: bool = False,
|
113 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
114 |
"""Perform inference on the input dataset.
|
|
|
131 |
|
132 |
def infer(
|
133 |
self,
|
134 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
135 |
return_meta_data: bool = False,
|
136 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
137 |
"""Verifies instances of a dataset and perform inference on the input dataset.
|
|
|
139 |
If return_meta_data - returns a list of TextGenerationInferenceOutput, else returns a list of the string
|
140 |
predictions.
|
141 |
"""
|
142 |
+
if not isoftype(dataset, Union[List[Dict[str, Any]], Dataset]):
|
143 |
+
raise Exception(
|
144 |
+
"Dataset passed to infer() is not list of dictionaries or Huggingface Dataset"
|
145 |
+
)
|
146 |
if return_meta_data and not hasattr(self, "get_return_object"):
|
147 |
raise NotImplementedError(
|
148 |
f"Inference engine {self.__class__.__name__} does not support return_meta_data as it "
|
|
|
156 |
|
157 |
def _mock_infer(
|
158 |
self,
|
159 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
160 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
161 |
return [str(instance["source"]) for instance in dataset]
|
162 |
|
|
|
207 |
@abc.abstractmethod
|
208 |
def _infer_log_probs(
|
209 |
self,
|
210 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
211 |
return_meta_data: bool = False,
|
212 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
213 |
"""Perform inference on the input dataset that returns log probs.
|
|
|
220 |
|
221 |
def infer_log_probs(
|
222 |
self,
|
223 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
224 |
return_meta_data: bool = False,
|
225 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
226 |
"""Verifies instances of a dataset and performs inference that returns log probabilities of top tokens.
|
|
|
455 |
|
456 |
def infer(
|
457 |
self,
|
458 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
459 |
return_meta_data: bool = False,
|
460 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
461 |
if not self._is_loaded():
|
|
|
465 |
@abc.abstractmethod
|
466 |
def _infer(
|
467 |
self,
|
468 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
469 |
return_meta_data: bool = False,
|
470 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
471 |
raise NotImplementedError
|
472 |
|
473 |
def infer_log_probs(
|
474 |
self,
|
475 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
476 |
return_meta_data: bool = False,
|
477 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
478 |
if not self._is_loaded():
|
|
|
482 |
@abc.abstractmethod
|
483 |
def _infer_log_probs(
|
484 |
self,
|
485 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
486 |
return_meta_data: bool = False,
|
487 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
488 |
raise NotImplementedError
|
|
|
533 |
|
534 |
def _infer_fn(
|
535 |
self,
|
536 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
537 |
return_meta_data: bool,
|
538 |
return_logprobs: bool,
|
539 |
) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]:
|
|
|
574 |
|
575 |
def _infer(
|
576 |
self,
|
577 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
578 |
return_meta_data: bool = False,
|
579 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
580 |
self.verify_not_chat_api(dataset)
|
|
|
582 |
|
583 |
def _infer_log_probs(
|
584 |
self,
|
585 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
586 |
return_meta_data: bool = False,
|
587 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
588 |
self.verify_not_chat_api(dataset)
|
|
|
656 |
|
657 |
def _infer_fn(
|
658 |
self,
|
659 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
660 |
return_meta_data: bool,
|
661 |
return_logprobs: bool,
|
662 |
) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]:
|
|
|
690 |
|
691 |
def _infer(
|
692 |
self,
|
693 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
694 |
return_meta_data: bool = False,
|
695 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
696 |
return self._infer_fn(dataset, return_meta_data, False)
|
697 |
|
698 |
def _infer_log_probs(
|
699 |
self,
|
700 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
701 |
return_meta_data: bool = False,
|
702 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
703 |
return self._infer_fn(dataset, return_meta_data, True)
|
|
|
888 |
|
889 |
def _infer(
|
890 |
self,
|
891 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
892 |
return_meta_data: bool = False,
|
893 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
894 |
if not self._is_loaded():
|
|
|
942 |
|
943 |
def _mock_infer(
|
944 |
self,
|
945 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
946 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
947 |
return [self.default_inference_value for _ in dataset]
|
948 |
|
949 |
def _infer(
|
950 |
self,
|
951 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
952 |
return_meta_data: bool = False,
|
953 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
954 |
return [
|
|
|
960 |
|
961 |
def _infer_log_probs(
|
962 |
self,
|
963 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
964 |
return_meta_data: bool = False,
|
965 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
966 |
return [
|
|
|
1056 |
|
1057 |
def _infer(
|
1058 |
self,
|
1059 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
1060 |
return_meta_data: bool = False,
|
1061 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
1062 |
return self.engine._infer(dataset)
|
1063 |
|
1064 |
def _infer_log_probs(
|
1065 |
self,
|
1066 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
1067 |
return_meta_data: bool = False,
|
1068 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
1069 |
if not isinstance(self.engine, LogProbInferenceEngine):
|
|
|
1091 |
|
1092 |
def _infer(
|
1093 |
self,
|
1094 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
1095 |
return_meta_data: bool = False,
|
1096 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
1097 |
import ollama
|
|
|
1259 |
|
1260 |
def _infer(
|
1261 |
self,
|
1262 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
1263 |
return_meta_data: bool = False,
|
1264 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
1265 |
from genai.schema import TextGenerationParameters, TextGenerationResult
|
|
|
1288 |
|
1289 |
def _infer_log_probs(
|
1290 |
self,
|
1291 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
1292 |
return_meta_data: bool = False,
|
1293 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
1294 |
from genai.schema import TextGenerationParameters, TextGenerationResult
|
|
|
1516 |
|
1517 |
def _infer(
|
1518 |
self,
|
1519 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
1520 |
return_meta_data: bool = False,
|
1521 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
1522 |
outputs = []
|
|
|
1536 |
|
1537 |
def _infer_log_probs(
|
1538 |
self,
|
1539 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
1540 |
return_meta_data: bool = False,
|
1541 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
1542 |
outputs = []
|
1543 |
for instance in tqdm(dataset, desc="Inferring with openAI API"):
|
1544 |
+
messages = self.to_messages(instance)
|
1545 |
response = self.client.chat.completions.create(
|
1546 |
+
messages=messages,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1547 |
model=self.model_name,
|
1548 |
**self._get_completion_kwargs(),
|
1549 |
)
|
|
|
1682 |
|
1683 |
def _infer(
|
1684 |
self,
|
1685 |
+
     dataset: Union[List[Dict[str, Any]], Dataset],
     return_meta_data: bool = False,
 ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
     from together.types.models import ModelType

 @abc.abstractmethod
 def _send_requests(
     self,
+    dataset: Union[List[Dict[str, Any]], Dataset],
     return_logprobs: bool,
     return_meta_data: bool,
 ) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]:

 def _infer(
     self,
+    dataset: Union[List[Dict[str, Any]], Dataset],
     return_meta_data: bool = False,
 ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
     if self._model is None:

 def _infer_log_probs(
     self,
+    dataset: Union[List[Dict[str, Any]], Dataset],
     return_meta_data: bool = False,
 ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
     if self._model is None:

 Attributes:
     concurrency_limit (int): Number of concurrent requests sent to a model. Default is 10,
+        which is also the maximum value.

 Examples:
+    .. code-block:: python
+
+        from .api import load_dataset
+
+        wml_credentials = {
+            "url": "some_url", "project_id": "some_id", "api_key": "some_key"
+        }
+        model_name = "google/flan-t5-xxl"
+        wml_inference = WMLInferenceEngineGeneration(
+            credentials=wml_credentials,
+            model_name=model_name,
+            data_classification_policy=["public"],
+            top_p=0.5,
+            random_seed=123,
+        )
+
+        dataset = load_dataset(
+            dataset_query="card=cards.argument_topic,template_card_index=0,loader_limit=5"
+        )
+        results = wml_inference.infer(dataset["test"])
 """

 concurrency_limit: int = 10

 def _send_requests(
     self,
+    dataset: Union[List[Dict[str, Any]], Dataset],
     return_logprobs: bool,
     return_meta_data: bool,
 ) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]:

 Attributes:
     image_encoder (EncodeImageToString, optional): operator which encodes images in
+        given format to base64 strings required by service. You should specify it when
+        you are using images in your inputs.

 Example:
+    .. code-block:: python
+
+        from .api import load_dataset
+        from .image_operators import EncodeImageToString
+
+        image_encoder = EncodeImageToString(image_format="JPEG")
+
+        wml_credentials = {
+            "url": "some_url", "project_id": "some_id", "api_key": "some_key"
+        }
+        model_name = "meta-llama/llama-3-2-11b-vision-instruct"
+        wml_inference = WMLInferenceEngineChat(
+            credentials=wml_credentials,
+            model_name=model_name,
+            image_encoder=image_encoder,
+            data_classification_policy=["public"],
+            max_tokens=1024,
+        )
+
+        dataset = load_dataset(
+            dataset_query="card=cards.doc_vqa.en,template=templates.qa.with_context.with_type,loader_limit=30"
+        )
+        results = wml_inference.infer(dataset["test"])
 """

 image_encoder: Optional[EncodeImageToString] = None

 def _send_requests(
     self,
+    dataset: Union[List[Dict[str, Any]], Dataset],
     return_logprobs: bool,
     return_meta_data: bool,
 ) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]:

 def _infer(
     self,
+    dataset: Union[List[Dict[str, Any]], Dataset],
     return_meta_data: bool = False,
 ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
     if not self._is_loaded():

 def _infer(
     self,
+    dataset: Union[List[Dict[str, Any]], Dataset],
     return_meta_data: bool = False,
 ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
     if not self._is_loaded():

 def _infer(
     self,
+    dataset: Union[List[Dict[str, Any]], Dataset],
     return_meta_data: bool = False,
 ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
     inputs = []

 def _infer(
     self,
+    dataset: Union[List[Dict[str, Any]], Dataset],
     return_meta_data: bool = False,
 ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
     """Main inference entry point."""

     "granite-3-8b-instruct": "ibm/granite-3-8b-instruct",
 },
 "together-ai": {
+    "llama-3-8b-instruct": "together_ai/meta-llama/Llama-3-8b-chat-hf",
+    "llama-3-70b-instruct": "together_ai/meta-llama/Llama-3-70b-chat-hf",
     "llama-3-2-1b-instruct": "together_ai/togethercomputer/llama-3-2-1b-instruct",
 },
 "aws": {

 def _infer(
     self,
+    dataset: Union[List[Dict[str, Any]], Dataset],
     return_meta_data: bool = False,
 ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
     return self.engine._infer(dataset, return_meta_data)

 def _infer(
     self,
+    dataset: Union[List[Dict[str, Any]], Dataset],
     return_meta_data: bool = False,
 ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
     inputs = []
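Worth noting alongside these hunks: the widened Union[List[Dict[str, Any]], Dataset] signature works because both input shapes iterate row-wise the same way inside an engine. A minimal sketch (illustrative, not part of this commit; the "source" field name is an assumption):

from typing import Any, Dict, List, Union

from datasets import Dataset

def iter_sources(dataset: Union[List[Dict[str, Any]], Dataset]):
    # datasets.Dataset and list[dict] both yield dict rows on iteration,
    # so engines can treat the two input shapes uniformly.
    for instance in dataset:
        yield instance["source"]

rows = [{"source": "translate: hi"}, {"source": "translate: bye"}]
assert list(iter_sources(rows)) == list(iter_sources(Dataset.from_list(rows)))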
llm_as_judge.py
CHANGED
@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Literal, Optional
 
 from .api import infer
 from .dataclass import Field
-from .formats import Format, SystemFormat
+from .formats import ChatAPIFormat, Format, SystemFormat
 from .inference import InferenceEngine, LogProbInferenceEngine, OpenAiInferenceEngine
 from .metrics import BulkInstanceMetric
 from .operator import SequentialOperator

@@ -65,12 +65,17 @@ class LLMAsJudgeBase(BulkInstanceMetric, ArtifactFetcherMixin):
         )
 
         if isinstance(self.inference_model, OpenAiInferenceEngine):
-            if self.format and type(self.format) is not SystemFormat:
-                raise ValueError(
-                    "Error in 'LLMAsJudge' metric. Inference model 'OpenAiInferenceEngine' does "
-                    "not support formatting. Please remove the format definition from the recipe"
-                    " (OpenAi Chat API take care of the formatting automatically)."
-                )
+            if self.format and type(self.format) is not ChatAPIFormat:
+                if not (
+                    type(self.format) is SystemFormat
+                    and self.format.__id__ == "formats.empty"
+                ):
+                    raise ValueError(
+                        "Error in 'LLMAsJudge' metric. Inference model 'OpenAiInferenceEngine' does "
+                        "not support formatting. Please remove the format definition from the recipe,"
+                        "or set the format to either 'formats.empty' or 'formats.chat_api'"
+                        " (OpenAi Chat API take care of the formatting automatically)."
+                    )
             if self.system_prompt and type(self.system_prompt) is not EmptySystemPrompt:
                 raise ValueError(
                     "Error in 'LLMAsJudge' metric. Inference model 'OpenAiInferenceEngine' does "

@@ -132,16 +137,24 @@ class LLMAsJudge(LLMAsJudgeBase):
 
     Attributes:
         main_score (str): The main score label used for evaluation.
+
         task (Literal["rating.single_turn","rating.single_turn_with_reference",
            "pairwise_comparative_rating.single_turn"]): The type of task the llm as judge runs.
+            This defines the output and input format of the judge model.
+
         template (Template): The template used when generating inputs for the judge llm.
+
         format (Format): The format used when generating inputs for judge llm.
+
         system_prompt (SystemPrompt): The system prompt used when generating inputs for judge llm.
+
         strip_system_prompt_and_format_from_inputs (bool): Whether to strip the system prompt and formatting from the
+            inputs that the model being judged received, when they are inserted into the llm-as-judge prompt.
+
         inference_model (InferenceEngine): The module that creates the inference of the judge llm.
+
         reduction_map (dict): A dictionary specifying the reduction method for the metric.
+
         batch_size (int): The size of the bulk.
     """

@@ -318,22 +331,34 @@ class TaskBasedLLMasJudge(LLMAsJudgeBase):
 
     Attributes:
         main_score (str): The main score label used for evaluation.
+
         task (str): The type of task the llm as judge runs.
             This defines the output and input format of the judge model.
+
         template (Template): The template used when generating inputs for the judge llm.
+
         format (Format): The format used when generating inputs for judge llm.
+
         system_prompt (SystemPrompt): The system prompt used when generating inputs for judge llm.
+
         strip_system_prompt_and_format_from_inputs (bool): Whether to strip the system prompt and formatting from the
+            inputs that the model being judged received, when they are inserted into the llm-as-judge prompt.
+
         inference_model (InferenceEngine): The module that creates the inference of the judge llm.
+
         reduction_map (dict): A dictionary specifying the reduction method for the metric.
+
         batch_size (int): The size of the bulk.
+
         infer_log_probs(bool): whether to perform the inference using logprobs. If true, the template's
             post-processing must support the logprobs output.
+
         judge_to_generator_fields_mapping (Dict[str, str]): optional mapping between the names of the fields in the generator task and the
             judge task. For example, if the generator task uses "reference_answers" and the judge task expects "ground_truth",
             include {"ground_truth": "reference_answers"} in this dictionary.
+
+        prediction_field (str): if indicated, and a prediction exists, copy prediction to this field name in task_data.
+
         include_meta_data (bool): whether to include the inference per-instance metadata in the returned results.
 
     """

@@ -384,7 +409,10 @@ class TaskBasedLLMasJudge(LLMAsJudgeBase):
 
     # if format is not directly set in constructor, choose according to the inference model
     def set_format_for_inference_engine(self):
        model_name = self.inference_model.get_engine_id()
-        if re.search("llama.?3.*instruct", model_name):
+        # TODO : better format resolution to support more chat_api options
+        if "rits" in model_name:
+            format_name = "formats.chat_api"
+        elif re.search("llama.?3.*instruct", model_name):
             format_name = "formats.llama3_instruct"
         else:
             format_name = "formats.empty"
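A quick check (illustrative, not from this commit) of the pattern set_format_for_inference_engine now relies on; the engine ids below are made-up examples:

import re

PATTERN = re.compile("llama.?3.*instruct")

for engine_id, expected in [
    ("meta-llama/llama-3-8b-instruct", True),   # separator "-" matches ".?"
    ("llama3.1-70b-instruct", True),            # no separator also matches
    ("mixtral-8x7b-instruct", False),           # no "llama" prefix
]:
    assert bool(PATTERN.search(engine_id)) is expected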
loaders.py
CHANGED
@@ -162,14 +162,22 @@ class LoadHF(Loader):
 
     Args:
         path: The path or identifier of the dataset on the HuggingFace Hub.
+
         name: An optional dataset name.
+
         data_dir: Optional directory to store downloaded data.
+
         split: Optional specification of which split to load.
+
         data_files: Optional specification of particular data files to load.
+
         revision: Optional. The revision of the dataset. Often the commit id. Use in case you want to set the dataset version.
+
+        streaming (bool): indicating if streaming should be used.
+
         filtering_lambda: A lambda function for filtering the data after loading.
+
+        num_proc (int): Optional integer to specify the number of processes to use for parallel dataset loading.
 
     Example:
         Loading glue's mrpc dataset

@@ -355,7 +363,9 @@ class LoadCSV(Loader):
                     file_path, nrows=self.get_limit(), sep=self.sep
                 ).to_dict("records")
             else:
-                iterables[split_name] = pd.read_csv(file_path).to_dict("records")
+                iterables[split_name] = pd.read_csv(file_path, sep=self.sep).to_dict(
+                    "records"
+                )
         return iterables

@@ -733,19 +743,24 @@ class LoadFromHFSpace(LoadHF):
 
     Args:
         space_name (str): Name of the HuggingFace Space to be accessed.
+
         data_files (str | Sequence[str] | Mapping[str, str | Sequence[str]]): Relative
+            paths to files within a given repository. If given as a mapping, paths should
+            be values, while keys should represent the type of respective files
+            (training, testing etc.).
+
         path (str, optional): Absolute path to a directory where data should be downloaded.
+
         revision (str, optional): ID of a Git branch or commit to be used. By default, it is
+            set to None, thus data is downloaded from the main branch of the accessed
+            repository.
+
         use_token (bool, optional): Whether a token is used for authentication when accessing
+            the HuggingFace Space. If necessary, the token is read from the HuggingFace
+            config folder.
+
         token_env (str, optional): Key of an env variable whose value will be used for
+            authentication when accessing the HuggingFace Space - if necessary.
 
     Example:
         Loading from a HuggingFace Space
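The LoadCSV hunk above fixes a real parsing bug: the unlimited branch dropped sep. A minimal sketch (not from this commit) of what that dropped argument does to a tab-separated file:

from io import StringIO

import pandas as pd

tsv = "a\tb\n1\t2\n"
bad = pd.read_csv(StringIO(tsv))             # old behavior: default sep="," yields one wide column
good = pd.read_csv(StringIO(tsv), sep="\t")  # behavior after the fix

assert list(bad.columns) == ["a\tb"]
assert list(good.columns) == ["a", "b"]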
metrics.py
CHANGED
@@ -1,5 +1,6 @@
 
 import ast
 import json
+import math
 import os
 import re
 import string

@@ -27,7 +28,11 @@ from .dataclass import (
 )
 from .deprecation_utils import deprecation
 from .error_utils import Documentation, UnitxtWarning
-from .inference import HFPipelineBasedInferenceEngine, InferenceEngine
+from .inference import (
+    HFPipelineBasedInferenceEngine,
+    InferenceEngine,
+    WMLInferenceEngineGeneration,
+)
 from .logging_utils import get_logger
 from .metric_utils import InstanceInput, MetricRequest, MetricResponse
 from .operator import (

@@ -960,11 +965,13 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
     """Class for metrics for which a global score can be calculated by aggregating the instance scores (possibly with additional instance inputs).
 
     InstanceMetric currently allows two reductions:
+
     1. 'mean', which calculates the mean of instance scores,
     2. 'group_mean', which first applies an aggregation function specified in the reduction_map
+        to instance scores grouped by the field grouping_field (which must not be None), and returns the mean
+        of the group scores; if grouping_field is None, grouping is disabled.
+        See _validate_group_mean_reduction for formatting instructions.
+
     """
 
     n_resamples: int = OptionalField(

@@ -1489,13 +1496,17 @@ class StringContainmentRatio(InstanceMetric):
 
     Attributes:
         field: The field from the task_data that contains the values to be checked for containment.
+
+        Example task that contains this metric:
+
+            .. code-block:: python
+
+                Task(
+                    input_fields={"question": str},
+                    reference_fields={"entities": str},
+                    prediction_type=str,
+                    metrics=["string_containment_ratio[field=entities]"],
+                )
     """
 
     reduction_map = {"mean": ["string_containment"]}

@@ -2776,8 +2787,8 @@ class BertScore(HuggingfaceBulkMetric):
 
 
 class SentenceBert(BulkInstanceMetric):
-    main_score = "score"
-    reduction_map = {"mean": ["score"]}
+    main_score = "sbert_score"
+    reduction_map = {"mean": [main_score]}
     batch_size: int = 32
 
     model_name: str

@@ -2823,12 +2834,12 @@ class SentenceBert(BulkInstanceMetric):
             refs_group_emb = refs_emb[ref_group_bounds[0] : ref_group_bounds[1]]
             scores.append(self.util.cos_sim(pred_emb, refs_group_emb).max().item())
 
-        return [{"score": score} for score in scores]
+        return [{self.main_score: score} for score in scores]
 
 
 class Reward(BulkInstanceMetric):
-    main_score = "score"
-    reduction_map = {"mean": ["score"]}
+    main_score = "reward_score"
+    reduction_map = {"mean": [main_score]}
     batch_size: int = 32
 
     model_name: str

@@ -2864,12 +2875,15 @@ class Reward(BulkInstanceMetric):
 
         # compute the metric
         # add function_to_apply="none" to disable sigmoid
-        return self.pipe(inputs, batch_size=self.batch_size)
+        results = self.pipe(inputs, batch_size=self.batch_size)
+        for result in results:
+            result[self.main_score] = result["score"]
+        return results
 
 
 class Detector(BulkInstanceMetric):
-    main_score = "score"
-    reduction_map = {"mean": ["score"]}
+    main_score = "detector_score"
+    reduction_map = {"mean": [main_score]}
     batch_size: int = 32
 
     prediction_type = str

@@ -2896,7 +2910,10 @@ class Detector(BulkInstanceMetric):
     ) -> List[Dict[str, Any]]:
         # compute the metric
         # add function_to_apply="none" to disable sigmoid
-        return self.pipe(predictions, batch_size=self.batch_size)
+        results = self.pipe(predictions, batch_size=self.batch_size)
+        for result in results:
+            result[self.main_score] = result["score"]
+        return results
 
 
 class RegardMetric(GlobalMetric):

@@ -3537,13 +3554,13 @@ class Perplexity(BulkInstanceMetric):
 
 
 class FaithfulnessHHEM(BulkInstanceMetric):
-    reduction_map = {"mean": ["score"]}
-    main_score = "score"
+    main_score = "hhem_score"
     batch_size: int = 2
     model_name: str = "vectara/hallucination_evaluation_model"
     prediction_type = str
     single_reference_per_prediction = True
     max_context_words = 4096
+    reduction_map = {"mean": [main_score]}
 
     _requirements_list: List[str] = ["transformers", "torch"]

@@ -3587,7 +3604,7 @@ class FaithfulnessHHEM(BulkInstanceMetric):
         for input_batch in tqdm(input_batches, "input batch"):
             batch_scores = self.model.predict(input_batch).cpu().tolist()
             scores.extend(batch_scores)
-        return [{"score": score} for score in scores]
+        return [{self.main_score: score} for score in scores]
 
 
 class Squad(HuggingfaceMetric):

@@ -4019,18 +4036,21 @@ def performance_drop_rate(
 def interpret_effect_size(x: float):
     """Return a string rule-of-thumb interpretation of an effect size value, as defined by Cohen/Sawilowsky.
 
-    See https://en.wikipedia.org/wiki/Effect_size
-    Cohen, Jacob (1988). Statistical Power Analysis for the Behavioral Sciences; and
-    Sawilowsky, S (2009). "New effect size rules of thumb". Journal of Modern Applied Statistical Methods. 8 (2): 467-474.
+    | See `Effect size <https://en.wikipedia.org/wiki/Effect_size>`_
+    | Cohen, Jacob (1988). Statistical Power Analysis for the Behavioral Sciences; and
+    | Sawilowsky, S (2009). "New effect size rules of thumb". Journal of Modern Applied Statistical Methods. 8 (2): 467-474.
 
     Value has interpretation of
+
+    .. code-block:: text
+
+        - essentially 0 if |x| < 0.01
+        - very small if 0.01 <= |x| < 0.2
+        - small difference if 0.2 <= |x| < 0.5
+        - a medium difference if 0.5 <= |x| < 0.8
+        - a large difference if 0.8 <= |x| < 1.2
+        - a very large difference if 1.2 <= |x| < 2.0
+        - a huge difference if 2.0 <= |x|
 
     Args:
         x: float effect size value

@@ -4066,7 +4086,7 @@ def normalized_cohens_h(
     """Cohen's h effect size between two proportions, normalized to interval [-1,1].
 
     Allows for change-type metric when the baseline is 0 (percentage change, and thus PDR, is undefined)
-    https://en.wikipedia.org/wiki/Cohen%27s_h
+    `Cohen's h <https://en.wikipedia.org/wiki/Cohen%27s_h>`_
 
     Cohen's h effect size metric between two proportions p2 and p1 is 2 * (arcsin(sqrt(p2)) - arcsin(sqrt(p1))).
     h in -pi, pi, with +/-pi representing the largest increase/decrease (p1=0, p2=1), or (p1=1, p2=0).

@@ -4077,6 +4097,9 @@ def normalized_cohens_h(
     Interpretation: the original unscaled Cohen's h can be interpreted according to function interpret_effect_size
 
     Thus, the rule of interpreting the effect of the normalized value is to use the same thresholds divided by pi
+
+    .. code-block:: text
+
         - essentially 0 if |norm h| < 0.0031831
         - very small if 0.0031831 <= |norm h| < 0.06366198
         - small difference if 0.06366198 <= |norm h| < 0.15915494

@@ -4084,12 +4107,17 @@ def normalized_cohens_h(
         - a large difference if 0.25464791 <= |norm h| < 0.38197186
         - a very large difference if 0.38197186 <= |norm h| < 0.63661977
         - a huge difference if 0.63661977 <= |norm h|
+
     Args:
         subgroup_scores_dict: dict where keys are subgroup types and values are lists of instance scores.
+
         control_subgroup_types: list of subgroup types (potential keys of subgroup_scores_dict) that are the control (baseline) group
+
         comparison_subgroup_types: list of subgroup types (potential keys of subgroup_scores_dict) that are the group
+            to be compared to the control group.
+
         interpret: boolean, whether to interpret the significance of the score or not
+
     Returns:
         float score between -1 and 1, and a string interpretation if interpret=True
     """

@@ -5118,3 +5146,112 @@ class PredictionLength(InstanceMetric):
         task_data: List[Dict],
     ) -> dict:
         return {self.main_score: [len(prediction)], "score_name": self.main_score}
+
+
+class GraniteGuardianWMLMetric(InstanceMetric):
+    """Return metric for different kinds of "risk" from the Granite-3.0 Guardian model."""
+
+    main_score = "granite_guardian"
+    reduction_map: Dict[str, List[str]] = None
+    prediction_type = float
+
+    model_name: str = "ibm/granite-guardian-3-8b"
+    hf_model_name: str = "ibm-granite/granite-guardian-3.0-8b"
+    safe_token = "No"
+    unsafe_token = "Yes"
+
+    inference_engine: WMLInferenceEngineGeneration = None
+    generation_params: Dict = None
+    risk_name: str = None
+
+    _requirements_list: List[str] = ["ibm_watsonx_ai", "torch", "transformers"]
+
+    def prepare(self):
+        self.reduction_map = {"mean": [self.main_score]}
+
+    def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> dict:
+        from transformers import AutoTokenizer
+
+        if not hasattr(self, "_tokenizer") or self._tokenizer is None:
+            self._tokenizer = AutoTokenizer.from_pretrained(self.hf_model_name)
+            self.inference_engine = WMLInferenceEngineGeneration(
+                model_name=self.model_name,
+            )
+            self.inference_engine._load_model()
+            self.model = self.inference_engine._model
+            self.generation_params = self.inference_engine._set_logprobs_params({})
+
+        messages = self.process_input_fields(task_data)
+        guardian_config = {"risk_name": self.risk_name}
+        processed_input = self._tokenizer.apply_chat_template(
+            messages,
+            guardian_config=guardian_config,
+            tokenize=False,
+            add_generation_prompt=True,
+        )
+
+        result = self.model.generate(
+            prompt=[processed_input],
+            params=self.generation_params,
+        )
+        generated_tokens_list = result[0]["results"][0]["generated_tokens"]
+        label, prob_of_risk = self.parse_output(generated_tokens_list)
+        score = 1 - prob_of_risk if label is not None else np.nan
+        return {self.main_score: score}
+
+    def process_input_fields(self, task_data):
+        if self.risk_name == "groundedness":
+            messages = [
+                {"role": "context", "content": "\n".join(task_data["contexts"])},
+                {"role": "assistant", "content": task_data["answer"]},
+            ]
+        elif self.risk_name == "answer_relevance":
+            messages = [
+                {"role": "user", "content": task_data["question"]},
+                {"role": "assistant", "content": task_data["answer"]},
+            ]
+        elif self.risk_name == "context_relevance":
+            messages = [
+                {"role": "user", "content": task_data["question"]},
+                {"role": "context", "content": "\n".join(task_data["contexts"])},
+            ]
+        else:
+            raise NotImplementedError()
+
+        return messages
+
+    def parse_output(self, generated_tokens_list):
+        top_tokens_list = [
+            generated_tokens["top_tokens"] for generated_tokens in generated_tokens_list
+        ]
+        prob = self.get_probabilities(top_tokens_list)
+        prob_of_risk = prob[1]
+
+        res = next(iter(generated_tokens_list))["text"].strip()
+
+        if self.unsafe_token.lower() == res.lower():
+            label = self.unsafe_token
+        elif self.safe_token.lower() == res.lower():
+            label = self.safe_token
+        else:
+            label = None
+
+        return label, prob_of_risk
+
+    def get_probabilities(self, top_tokens_list):
+        import torch
+
+        safe_token_prob = 1e-50
+        unsafe_token_prob = 1e-50
+
+        for top_tokens in top_tokens_list:
+            for token in top_tokens:
+                if token["text"].strip().lower() == self.safe_token.lower():
+                    safe_token_prob += math.exp(token["logprob"])
+                if token["text"].strip().lower() == self.unsafe_token.lower():
+                    unsafe_token_prob += math.exp(token["logprob"])
+
+        return torch.softmax(
+            torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]),
+            dim=0,
+        ).numpy()
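The normalized Cohen's h described above can be recomputed independently. A minimal sketch (not from this commit) of h = 2*(arcsin(sqrt(p2)) - arcsin(sqrt(p1))), scaled by pi into [-1, 1]:

import math

def normalized_h(p1: float, p2: float) -> float:
    # raw Cohen's h lies in [-pi, pi]; dividing by pi maps it to [-1, 1]
    h = 2 * (math.asin(math.sqrt(p2)) - math.asin(math.sqrt(p1)))
    return h / math.pi

assert normalized_h(0.0, 1.0) == 1.0    # largest possible increase
assert normalized_h(1.0, 0.0) == -1.0   # largest possible decrease
assert abs(normalized_h(0.5, 0.5)) < 1e-12  # no change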
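The get_probabilities helper in the new GraniteGuardianWMLMetric boils down to normalizing exponentiated token logprobs: softmax over the two log-sums is exactly that normalization. A torch-free restatement (illustrative, not from this commit):

import math

def prob_of_risk(top_tokens_list, safe="No", unsafe="Yes"):
    # tiny floors keep log() finite when one label never appears in the top tokens
    safe_p, unsafe_p = 1e-50, 1e-50
    for top_tokens in top_tokens_list:
        for token in top_tokens:
            text = token["text"].strip().lower()
            if text == safe.lower():
                safe_p += math.exp(token["logprob"])
            elif text == unsafe.lower():
                unsafe_p += math.exp(token["logprob"])
    # softmax([log(a), log(b)]) == [a/(a+b), b/(a+b)]
    return unsafe_p / (safe_p + unsafe_p)

tokens = [[{"text": "Yes", "logprob": math.log(0.6)},
           {"text": "No", "logprob": math.log(0.2)}]]
assert abs(prob_of_risk(tokens) - 0.75) < 1e-9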
operators.py
CHANGED
@@ -137,34 +137,39 @@ class MapInstanceValues(InstanceOperator):
 
     Attributes:
         mappers (Dict[str, Dict[str, Any]]): The mappers to use for mapping instance values.
+            Keys are the names of the fields to undergo mapping, and values are dictionaries
+            that define the mapping from old values to new values.
+            Note that mapped values are defined by their string representation, so mapped values
+            are converted to strings before being looked up in the mappers.
+
         strict (bool): If True, the mapping is applied strictly. That means if a value
+            does not exist in the mapper, it will raise a KeyError. If False, values
+            that are not present in the mapper are kept as they are.
+
         process_every_value (bool): If True, all fields to be mapped should be lists, and the mapping
+            is to be applied to their individual elements. If False, mapping is only applied to a field
+            containing a single value.
 
     Examples:
+        ``MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}})``
+        replaces ``"1"`` with ``"hi"`` and ``"2"`` with ``"bye"`` in field ``"a"`` in all instances of all streams:
+        instance ``{"a": 1, "b": 2}`` becomes ``{"a": "hi", "b": 2}``. Note that the value of ``"b"`` remained intact,
+        since field-name ``"b"`` does not participate in the mappers, and that ``1`` was cast to ``"1"`` before being looked
+        up in the mapper of ``"a"``.
+
+        ``MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}}, process_every_value=True)``:
+        Assuming field ``"a"`` is a list of values, potentially including ``"1"``-s and ``"2"``-s, this replaces
+        each such ``"1"`` with ``"hi"`` and ``"2"`` -- with ``"bye"`` in all instances of all streams:
+        instance ``{"a": ["1", "2"], "b": 2}`` becomes ``{"a": ["hi", "bye"], "b": 2}``.
+
+        ``MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}}, strict=True)``:
+        To ensure that all values of field ``"a"`` are mapped in every instance, use ``strict=True``.
+        Input instance ``{"a":"3", "b": 2}`` will raise an exception per the above call,
+        because ``"3"`` is not a key in the mapper of ``"a"``.
+
+        ``MapInstanceValues(mappers={"a": {str([1,2,3,4]): "All", str([]): "None"}}, strict=True)``
+        replaces a list ``[1,2,3,4]`` with the string ``"All"`` and an empty list by string ``"None"``.
 
     """
 
     mappers: Dict[str, Dict[str, str]]

@@ -234,27 +239,25 @@ class FlattenInstances(InstanceOperator):
 
 
 class Set(InstanceOperator):
+    """Sets specified fields in each instance, in a given stream or all streams (default), with specified values. If fields exist, updates them; if they do not exist -- adds them.
 
     Args:
+        fields (Dict[str, object]): The fields to add to each instance. Use '/' to access inner fields
+
         use_deepcopy (bool) : Deep copy the input value to avoid later modifications
 
     Examples:
+        # Set a value of a list consisting of "positive" and "negative" to field "classes" in each and every instance of all streams
+        ``Set(fields={"classes": ["positive","negatives"]})``
 
+        # In each and every instance of all streams, field "span" is to become a dictionary containing a field "start", in which the value 0 is to be set
+        ``Set(fields={"span/start": 0})``
 
+        # In all instances of stream "train" only, set field "classes" to have the value of a list consisting of "positive" and "negative"
+        ``Set(fields={"classes": ["positive","negatives"], apply_to_stream=["train"]})``
 
+        # Set field "classes" to have the value of a given list, preventing modification of the original list from changing the instance.
+        ``Set(fields={"classes": alist}, use_deepcopy=True)`` if alist is modified later, the instances still remain intact.
     """
 
     fields: Dict[str, object]

@@ -333,22 +336,26 @@ class InstanceFieldOperator(InstanceOperator):
 
     Args:
         field (Optional[str]): The field to process, if only a single one is passed. Defaults to None
+
         to_field (Optional[str]): Field name to save result into, if only one field is processed, if None is passed the
+            operation would happen in-place and its result would replace the value of ``field``. Defaults to None
+
         field_to_field (Optional[Union[List[List[str]], Dict[str, str]]]): Mapping from names of fields to process,
+            to names of fields to save the results into. Inner List, if used, should be of length 2.
+            | A field is processed by feeding its value into method ``process_value`` and storing the result in ``to_field`` that
            is mapped to the field.
+            | When the type of argument ``field_to_field`` is List, the order by which the fields are processed is their order
+            in the (outer) List. But when the type of argument ``field_to_field`` is Dict, there is no uniquely determined
            order. The end result might depend on that order if either (1) two different fields are mapped to the same
            to_field, or (2) a field shows both as a key and as a value in different mappings.
+            | The operator throws an AssertionError in either of these cases.
+            | field_to_field defaults to None
+
+        process_every_value (bool): Processes the values in a list instead of the list as a value, similar to python's ``*var``. Defaults to False
+
+    Note: if ``field`` and ``to_field`` (or both members of a pair in ``field_to_field`` ) are equal (or share a common
+    prefix if ``field`` and ``to_field`` contain a / ), then the result of the operation is saved within ``field`` .
     """
 
     field: Optional[str] = None

@@ -577,17 +584,18 @@ class Apply(InstanceOperator):
     Args:
         function (str): name of function.
         to_field (str): the field to store the result
+
+        any additional arguments are field names whose values will be passed directly to the function specified
 
     Examples:
+        Store in field "b" the uppercase string of the value in field "a":
+        ``Apply("a", function=str.upper, to_field="b")``
 
+        Dump the json representation of field "t" and store back in the same field:
+        ``Apply("t", function=json.dumps, to_field="t")``
 
+        Set the time in a field 'b':
+        ``Apply(function=time.time, to_field="b")``
 
     """

@@ -667,14 +675,13 @@ class ListFieldValues(InstanceOperator):
 
 
 class ZipFieldValues(InstanceOperator):
+    """Zips values of multiple fields in a given instance, similar to ``list(zip(*fields))``.
 
     The value in each of the specified 'fields' is assumed to be a list. The lists from all 'fields'
     are zipped, and stored into 'to_field'.
 
+    | If 'longest'=False, the length of the zipped result is determined by the shortest input value.
+    | If 'longest'=True, the length of the zipped result is determined by the longest input, padding shorter inputs with None-s.
 
     """

@@ -706,11 +713,11 @@ class ZipFieldValues(InstanceOperator):
 class InterleaveListsToDialogOperator(InstanceOperator):
     """Interleaves two lists, one of user dialog turns and one of assistant dialog turns, into a single list of tuples, alternating between "user" and "assistant".
 
+    The list of tuples is of format (role, turn_content), where the role label is specified by
+    the 'user_role_label' and 'assistant_role_label' fields (default to "user" and "assistant").
 
     The user turns and assistant turns field are specified in the arguments.
+    The value of each of the 'fields' is assumed to be a list.
 
     """

@@ -854,13 +861,13 @@ class Copy(FieldOperator):
 
     Examples:
         An input instance {"a": 2, "b": 3}, when processed by
+        ``Copy(field_to_field={"a": "b"})``
         would yield {"a": 2, "b": 2}, and when processed by
+        ``Copy(field_to_field={"a": "c"})`` would yield
         {"a": 2, "b": 3, "c": 2}
 
     with field names containing / , we can also copy inside the field:
+        ``Copy(field="a/0",to_field="a")``
     would process instance {"a": [1, 3]} into {"a": 1}

@@ -930,32 +937,41 @@ class CastFields(InstanceOperator):
     """Casts specified fields to specified types.
 
     Args:
-        use_nested_query (bool): Whether to cast nested fields, expressed in dpath. Defaults to False.
         fields (Dict[str, str]): A dictionary mapping field names to the names of the types to cast the fields to.
+            e.g: "int", "str", "float", "bool". Basic names of types
+
         defaults (Dict[str, object]): A dictionary mapping field names to default values for cases of casting failure.
+
         process_every_value (bool): If true, all fields involved must contain lists, and each value in the list is then cast. Defaults to False.
 
+    Example:
+        .. code-block:: python
+
+            CastFields(
+                fields={"a/d": "float", "b": "int"},
+                failure_defaults={"a/d": 0.0, "b": 0},
+                process_every_value=True,
+            )
+
+        would process the input instance: ``{"a": {"d": ["half", "0.6", 1, 12]}, "b": ["2"]}``
+        into ``{"a": {"d": [0.0, 0.6, 1.0, 12.0]}, "b": [2]}``.
 
     """
 
     fields: Dict[str, str] = field(default_factory=dict)
     failure_defaults: Dict[str, object] = field(default_factory=dict)
-    use_nested_query: bool = False
+    use_nested_query: bool = None  # deprecated field
     process_every_value: bool = False
 
     def prepare(self):
         self.types = {"int": int, "float": float, "str": str, "bool": bool}
 
+    def verify(self):
+        super().verify()
+        if self.use_nested_query is not None:
+            depr_message = "Field 'use_nested_query' is deprecated. From now on, default behavior is compatible to use_nested_query=True. Please remove this field from your code."
+            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)
+
     def _cast_single(self, value, type, field):
         try:
             return self.types[type](value)

@@ -1093,18 +1109,18 @@ class FilterByCondition(StreamOperator):
 
     Args:
         values (Dict[str, Any]): Field names and respective Values that instances must match according the condition, to be included in the output.
+
         condition: the name of the desired condition operator between the specified (sub) field's value and the provided constant value. Supported conditions are ("gt", "ge", "lt", "le", "ne", "eq", "in","not in")
+
         error_on_filtered_all (bool, optional): If True, raises an error if all instances are filtered out. Defaults to True.
 
     Examples:
+        | ``FilterByCondition(values = {"a":4}, condition = "gt")`` will yield only instances where field ``"a"`` contains a value ``> 4``
+        | ``FilterByCondition(values = {"a":4}, condition = "le")`` will yield only instances where ``"a"<=4``
+        | ``FilterByCondition(values = {"a":[4,8]}, condition = "in")`` will yield only instances where ``"a"`` is ``4`` or ``8``
+        | ``FilterByCondition(values = {"a":[4,8]}, condition = "not in")`` will yield only instances where ``"a"`` is different from ``4`` or ``8``
+        | ``FilterByCondition(values = {"a/b":[4,8]}, condition = "not in")`` will yield only instances where ``"a"`` is a dict in which key ``"b"`` is mapped to a value that is neither ``4`` nor ``8``
+        | ``FilterByCondition(values = {"a[2]":4}, condition = "le")`` will yield only instances where "a" is a list whose 3-rd element is ``<= 4``
 
 
     """

@@ -1805,14 +1821,14 @@ class EncodeLabels(InstanceOperator):
     Args:
         fields (List[str]): The fields to encode together.
 
+    Example:
+        applying ``EncodeLabels(fields = ["a", "b/*"])``
+        on input stream = ``[{"a": "red", "b": ["red", "blue"], "c":"bread"},
+        {"a": "blue", "b": ["green"], "c":"water"}]`` will yield the
+        output stream = ``[{'a': 0, 'b': [0, 1], 'c': 'bread'}, {'a': 1, 'b': [2], 'c': 'water'}]``
 
+    Note: dict_utils are applied here, and hence, fields that are lists should be included in
+        input 'fields' with the appendix ``"/*"`` as in the above example.
 
     """

@@ -2132,21 +2148,23 @@ class CollateInstances(StreamOperator):
     batch_size (int)
 
     Example:
+        .. code-block:: text
 
+            CollateInstances(batch_size=2)
+
+            Given inputs = [
+                {"a": 1, "b": 2},
+                {"a": 2, "b": 2},
+                {"a": 3, "b": 2},
+                {"a": 4, "b": 2},
+                {"a": 5, "b": 2}
+            ]
+
+            Returns targets = [
+                {"a": [1,2], "b": [2,2]},
+                {"a": [3,4], "b": [2,2]},
+                {"a": [5], "b": [2]},
+            ]
 
 
     """
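The CastFields example in the hunk above can be verified by hand with plain casts plus per-field failure defaults. A minimal sketch (not from this commit):

def cast(value, to_type, default):
    # mirror the class's type table and fall back to the default on failure
    try:
        return {"int": int, "float": float, "str": str, "bool": bool}[to_type](value)
    except (ValueError, TypeError):
        return default

values = ["half", "0.6", 1, 12]
assert [cast(v, "float", 0.0) for v in values] == [0.0, 0.6, 1.0, 12.0]
assert [cast(v, "int", 0) for v in ["2"]] == [2]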
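And the dpath-style FilterByCondition example, read in plain Python (illustrative, not from this commit):

# FilterByCondition(values={"a/b": [4, 8]}, condition="not in") keeps instances
# whose nested value at "a" -> "b" is neither 4 nor 8.
instances = [{"a": {"b": 4}}, {"a": {"b": 5}}, {"a": {"b": 8}}]
kept = [ins for ins in instances if ins["a"]["b"] not in [4, 8]]
assert kept == [{"a": {"b": 5}}]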
span_lableing_operators.py
CHANGED
@@ -8,29 +8,35 @@ class IobExtractor(InstanceOperator):
 
     Attributes:
         labels (List[str]): A list of entity type labels, e.g., ["Person", "Organization", "Location"].
+
         begin_labels (List[str]): A list of labels indicating the beginning of an entity, e.g., ["B-PER", "B-ORG", "B-LOC"].
+
         inside_labels (List[str]): A list of labels indicating the continuation of an entity, e.g., ["I-PER", "I-ORG", "I-LOC"].
+
         outside_label (str): The label indicating tokens outside of any entity, typically "O".
 
     The extraction process identifies spans of text corresponding to entities and labels them according to their entity type. Each span is annotated with a start and end character offset, the entity text, and the corresponding label.
 
+
     Example of instantiation and usage:
+
+        .. code-block:: python
+
+            operator = IobExtractor(
+                labels=["Person", "Organization", "Location"],
+                begin_labels=["B-PER", "B-ORG", "B-LOC"],
+                inside_labels=["I-PER", "I-ORG", "I-LOC"],
+                outside_label="O",
+            )
+
+            instance = {
+                "labels": ["B-PER", "I-PER", "O", "B-ORG", "I-ORG"],
+                "tokens": ["John", "Doe", "works", "at", "OpenAI"]
+            }
+            processed_instance = operator.process(instance)
+            print(processed_instance["spans"])
+            # Output: [{'start': 0, 'end': 8, 'text': 'John Doe', 'label': 'Person'}, ...]
 
     For more details on the IOB tagging convention, see: https://en.wikipedia.org/wiki/Inside-outside-beginning_(tagging)
8 |
|
9 |
Attributes:
|
10 |
labels (List[str]): A list of entity type labels, e.g., ["Person", "Organization", "Location"].
|
11 |
+
|
12 |
begin_labels (List[str]): A list of labels indicating the beginning of an entity, e.g., ["B-PER", "B-ORG", "B-LOC"].
|
13 |
+
|
14 |
inside_labels (List[str]): A list of labels indicating the continuation of an entity, e.g., ["I-PER", "I-ORG", "I-LOC"].
|
15 |
+
|
16 |
outside_label (str): The label indicating tokens outside of any entity, typically "O".
|
17 |
|
18 |
The extraction process identifies spans of text corresponding to entities and labels them according to their entity type. Each span is annotated with a start and end character offset, the entity text, and the corresponding label.
|
19 |
|
20 |
+
|
21 |
+
|
22 |
Example of instantiation and usage:
|
23 |
+
|
24 |
+
.. code-block:: python
|
25 |
+
|
26 |
+
operator = IobExtractor(
|
27 |
+
labels=["Person", "Organization", "Location"],
|
28 |
+
begin_labels=["B-PER", "B-ORG", "B-LOC"],
|
29 |
+
inside_labels=["I-PER", "I-ORG", "I-LOC"],
|
30 |
+
outside_label="O",
|
31 |
+
)
|
32 |
+
|
33 |
+
instance = {
|
34 |
+
"labels": ["B-PER", "I-PER", "O", "B-ORG", "I-ORG"],
|
35 |
+
"tokens": ["John", "Doe", "works", "at", "OpenAI"]
|
36 |
+
}
|
37 |
+
processed_instance = operator.process(instance)
|
38 |
+
print(processed_instance["spans"])
|
39 |
+
# Output: [{'start': 0, 'end': 8, 'text': 'John Doe', 'label': 'Person'}, ...]
|
40 |
|
41 |
For more details on the IOB tagging convention, see: https://en.wikipedia.org/wiki/Inside-outside-beginning_(tagging)
|
42 |
|
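The core of IOB decoding with character offsets can be sketched in a few lines; this simplified version assumes single-space token joining and lets any "I-" tag extend the most recent span, which the real operator handles more carefully:

.. code-block:: python

    def extract_spans(tokens, tags, begin_to_label):
        spans, offset = [], 0
        for token, tag in zip(tokens, tags):
            start = offset
            offset += len(token) + 1  # +1 for the joining space
            if tag.startswith("B-"):
                spans.append({"start": start, "end": start + len(token),
                              "label": begin_to_label[tag]})
            elif tag.startswith("I-") and spans:
                spans[-1]["end"] = start + len(token)
        text = " ".join(tokens)
        for span in spans:
            span["text"] = text[span["start"]:span["end"]]
        return spans

    print(extract_spans(
        ["John", "Doe", "works", "at", "OpenAI"],
        ["B-PER", "I-PER", "O", "B-ORG", "I-ORG"],
        {"B-PER": "Person", "B-ORG": "Organization", "B-LOC": "Location"},
    ))
    # [{'start': 0, 'end': 8, 'label': 'Person', 'text': 'John Doe'},
    #  {'start': 18, 'end': 24, 'label': 'Organization', 'text': 'OpenAI'}]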
struct_data_operators.py
CHANGED
@@ -2,17 +2,25 @@
|
|
2 |
|
3 |
These operators are specialized in handling structured data like tables.
|
4 |
For tables, expected input format is:
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
|
10 |
For triples, expected input format is:
|
11 |
-
|
12 |
|
13 |
For key-value pairs, expected input format is:
|
14 |
-
|
15 |
-
16 |
"""
|
17 |
|
18 |
import json
|
@@ -148,11 +156,15 @@ class SerializeTableAsMarkdown(SerializeTable):
|
|
148 |
|
149 |
Markdown table format is used in GitHub code primarily.
|
150 |
Format:
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
"""
|
157 |
|
158 |
# main method that serializes a table.
|
@@ -192,11 +204,14 @@ class SerializeTableAsDFLoader(SerializeTable):
|
|
192 |
|
193 |
Pandas dataframe based code snippet format serializer.
|
194 |
Format(Sample):
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
"""
|
201 |
|
202 |
# main method that serializes a table.
|
@@ -234,11 +249,14 @@ class SerializeTableAsJson(SerializeTable):
|
|
234 |
|
235 |
Json format based serializer.
|
236 |
Format(Sample):
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
"""
|
243 |
|
244 |
# main method that serializes a table.
|
@@ -264,15 +282,18 @@ class SerializeTableAsHTML(SerializeTable):
|
|
264 |
|
265 |
HTML table format used for rendering tables in web pages.
|
266 |
Format(Sample):
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
"""
|
277 |
|
278 |
# main method that serializes a table.
|
@@ -404,7 +425,7 @@ class TruncateTableRows(FieldOperator):
|
|
404 |
"""Limits table rows to specified limit by removing excess rows via random selection.
|
405 |
|
406 |
Args:
|
407 |
-
rows_to_keep (int)
|
408 |
"""
|
409 |
|
410 |
rows_to_keep: int = 10
|
@@ -563,16 +584,19 @@ class ListToKeyValPairs(InstanceOperator):
|
|
563 |
class ConvertTableColNamesToSequential(FieldOperator):
|
564 |
"""Replaces actual table column names with static sequential names like col_0, col_1,...
|
565 |
|
566 |
-
|
567 |
-
|
568 |
-
|
569 |
-
|
570 |
-
|
571 |
-
|
572 |
-
|
573 |
-
|
574 |
-
|
575 |
-
|
576 |
"""
|
577 |
|
578 |
def process_value(self, table: Any) -> Any:
|
@@ -595,17 +619,19 @@ class ConvertTableColNamesToSequential(FieldOperator):
|
|
595 |
class ShuffleTableRows(TypeDependentAugmentor):
|
596 |
"""Shuffles the input table rows randomly.
|
597 |
|
598 |
-
|
599 |
-
{
|
600 |
-
"header": ["name", "age"],
|
601 |
-
"rows": [["Alex", 26], ["Raj", 34], ["Donald", 39]],
|
602 |
-
}
|
603 |
|
604 |
-
|
605 |
-
|
606 |
-
|
607 |
-
|
608 |
-
|
609 |
"""
|
610 |
|
611 |
augmented_type = Table
|
@@ -619,17 +645,19 @@ class ShuffleTableRows(TypeDependentAugmentor):
|
|
619 |
class ShuffleTableColumns(TypeDependentAugmentor):
|
620 |
"""Shuffles the table columns randomly.
|
621 |
|
622 |
-
|
623 |
-
{
|
624 |
-
"header": ["name", "age"],
|
625 |
-
"rows": [["Alex", 26], ["Raj", 34], ["Donald", 39]],
|
626 |
-
}
|
627 |
|
628 |
-
|
629 |
-
|
630 |
-
|
631 |
-
|
632 |
-
|
633 |
"""
|
634 |
|
635 |
augmented_type = Table
|
@@ -662,11 +690,14 @@ class DumpJson(FieldOperator):
|
|
662 |
class MapHTMLTableToJSON(FieldOperator):
|
663 |
"""Converts HTML table format to the basic one (JSON).
|
664 |
|
665 |
-
JSON format
|
666 |
-
|
667 |
-
|
668 |
-
|
669 |
-
|
670 |
"""
|
671 |
|
672 |
_requirements_list = ["bs4"]
|
@@ -701,11 +732,14 @@ class MapHTMLTableToJSON(FieldOperator):
|
|
701 |
class MapTableListsToStdTableJSON(FieldOperator):
|
702 |
"""Converts lists table format to the basic one (JSON).
|
703 |
|
704 |
-
JSON format
|
705 |
-
|
706 |
-
|
707 |
-
|
708 |
-
|
709 |
"""
|
710 |
|
711 |
def process_value(self, table: Any) -> Any:
|
@@ -755,17 +789,20 @@ class ConstructTableFromRowsCols(InstanceOperator):
|
|
755 |
class TransposeTable(TypeDependentAugmentor):
|
756 |
"""Transpose a table.
|
757 |
|
758 |
-
|
759 |
-
|
760 |
-
|
761 |
-
|
762 |
-
|
763 |
|
764 |
-
Sample Output:
|
765 |
-
{
|
766 |
-
"header": [" ", "0", "1", "2"],
|
767 |
-
"rows": [["name", "Alice", "Raj", "Donald"], ["age", 26, 34, 39], ["sex", "F", "M", "M"]],
|
768 |
-
}
|
769 |
"""
|
770 |
|
771 |
augmented_type = Table
|
@@ -791,8 +828,9 @@ class DuplicateTableRows(TypeDependentAugmentor):
|
|
791 |
"""Duplicates specific rows of a table for the given number of times.
|
792 |
|
793 |
Args:
|
794 |
-
row_indices (List[int])
|
795 |
-
|
|
|
796 |
"""
|
797 |
|
798 |
augmented_type = Table
|
@@ -823,8 +861,9 @@ class DuplicateTableColumns(TypeDependentAugmentor):
|
|
823 |
"""Duplicates specific columns of a table for the given number of times.
|
824 |
|
825 |
Args:
|
826 |
-
column_indices (List[int])
|
827 |
-
|
|
|
828 |
"""
|
829 |
|
830 |
augmented_type = Table
|
|
|
2 |
|
3 |
These operators are specialized in handling structured data like tables.
|
4 |
For tables, expected input format is:
|
5 |
+
|
6 |
+
.. code-block:: text
|
7 |
+
|
8 |
+
{
|
9 |
+
"header": ["col1", "col2"],
|
10 |
+
"rows": [["row11", "row12"], ["row21", "row22"], ["row31", "row32"]]
|
11 |
+
}
|
12 |
|
13 |
For triples, expected input format is:
|
14 |
+
|
15 |
+
.. code-block:: text
|
16 |
+
|
17 |
+
[[ "subject1", "relation1", "object1" ], [ "subject1", "relation2", "object2"]]
|
18 |
|
19 |
For key-value pairs, expected input format is:
|
20 |
+
|
21 |
+
.. code-block:: text
|
22 |
+
|
23 |
+
{"key1": "value1", "key2": value2, "key3": "value3"}
|
24 |
"""
|
25 |
|
26 |
import json
|
|
|
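When writing new operators against these formats, a quick shape check helps catch malformed tables early; ``is_std_table`` below is an illustration, not part of the library:

.. code-block:: python

    def is_std_table(obj) -> bool:
        # standard format: a dict with a "header" list and "rows" of equal width
        return (
            isinstance(obj, dict)
            and isinstance(obj.get("header"), list)
            and isinstance(obj.get("rows"), list)
            and all(len(row) == len(obj["header"]) for row in obj["rows"])
        )

    assert is_std_table({"header": ["col1", "col2"],
                         "rows": [["row11", "row12"], ["row21", "row22"]]})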
156 |
|
157 |
Markdown table format is primarily used in GitHub code.
|
158 |
Format:
|
159 |
+
|
160 |
+
.. code-block:: text
|
161 |
+
|
162 |
+
|col1|col2|col3|
|
163 |
+
|---|---|---|
|
164 |
+
|A|4|1|
|
165 |
+
|I|2|1|
|
166 |
+
...
|
167 |
+
|
168 |
"""
|
169 |
|
170 |
# main method that serializes a table.
|
|
|
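Producing this format from the standard table dict is straightforward; a hand-rolled sketch (not the class's actual implementation):

.. code-block:: python

    def to_markdown(table):
        header = table["header"]
        lines = ["|" + "|".join(str(c) for c in header) + "|",
                 "|" + "|".join("---" for _ in header) + "|"]
        lines += ["|" + "|".join(str(c) for c in row) + "|" for row in table["rows"]]
        return "\n".join(lines)

    print(to_markdown({"header": ["col1", "col2", "col3"],
                       "rows": [["A", 4, 1], ["I", 2, 1]]}))
    # |col1|col2|col3|
    # |---|---|---|
    # |A|4|1|
    # |I|2|1|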
204 |
|
205 |
Pandas dataframe based code snippet format serializer.
|
206 |
Format(Sample):
|
207 |
+
|
208 |
+
.. code-block:: python
|
209 |
+
|
210 |
+
pd.DataFrame({
|
211 |
+
"name" : ["Alex", "Diana", "Donald"],
|
212 |
+
"age" : [26, 34, 39]
|
213 |
+
},
|
214 |
+
index=[0,1,2])
|
215 |
"""
|
216 |
|
217 |
# main method that serializes a table.
|
|
|
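The sample above is a column-major dict plus an explicit index; one way to emit such a snippet from the standard table dict (illustrative only):

.. code-block:: python

    import json

    def to_df_loader(table):
        # pivot rows into a column-major mapping, as in the sample format
        columns = {name: [row[i] for row in table["rows"]]
                   for i, name in enumerate(table["header"])}
        index = list(range(len(table["rows"])))
        return "pd.DataFrame({},\nindex={})".format(json.dumps(columns), index)

    print(to_df_loader({"header": ["name", "age"],
                        "rows": [["Alex", 26], ["Diana", 34], ["Donald", 39]]}))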
249 |
|
250 |
JSON format based serializer.
|
251 |
Format(Sample):
|
252 |
+
|
253 |
+
.. code-block:: json
|
254 |
+
|
255 |
+
{
|
256 |
+
"0":{"name":"Alex","age":26},
|
257 |
+
"1":{"name":"Diana","age":34},
|
258 |
+
"2":{"name":"Donald","age":39}
|
259 |
+
}
|
260 |
"""
|
261 |
|
262 |
# main method that serializes a table.
|
|
|
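The row-indexed JSON shown above maps each row to a dict keyed by its position; a compact sketch:

.. code-block:: python

    import json

    def table_to_json(table):
        return json.dumps({str(i): dict(zip(table["header"], row))
                           for i, row in enumerate(table["rows"])})

    print(table_to_json({"header": ["name", "age"],
                         "rows": [["Alex", 26], ["Diana", 34], ["Donald", 39]]}))
    # {"0": {"name": "Alex", "age": 26}, "1": {"name": "Diana", "age": 34}, ...}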
282 |
|
283 |
HTML table format used for rendering tables in web pages.
|
284 |
Format(Sample):
|
285 |
+
|
286 |
+
.. code-block:: html
|
287 |
+
|
288 |
+
<table>
|
289 |
+
<thead>
|
290 |
+
<tr><th>name</th><th>age</th><th>sex</th></tr>
|
291 |
+
</thead>
|
292 |
+
<tbody>
|
293 |
+
<tr><td>Alice</td><td>26</td><td>F</td></tr>
|
294 |
+
<tr><td>Raj</td><td>34</td><td>M</td></tr>
|
295 |
+
</tbody>
|
296 |
+
</table>
|
297 |
"""
|
298 |
|
299 |
# main method that serializes a table.
|
|
|
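A minimal emitter for this HTML shape (illustrative, not the class's actual code):

.. code-block:: python

    def to_html(table):
        head = "".join(f"<th>{h}</th>" for h in table["header"])
        body = "".join("<tr>" + "".join(f"<td>{c}</td>" for c in row) + "</tr>"
                       for row in table["rows"])
        return (f"<table>\n<thead>\n<tr>{head}</tr>\n</thead>\n"
                f"<tbody>\n{body}\n</tbody>\n</table>")

    print(to_html({"header": ["name", "age"],
                   "rows": [["Alice", 26], ["Raj", 34]]}))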
425 |
"""Limits table rows to specified limit by removing excess rows via random selection.
|
426 |
|
427 |
Args:
|
428 |
+
rows_to_keep (int): number of rows to keep.
|
429 |
"""
|
430 |
|
431 |
rows_to_keep: int = 10
|
|
|
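One reasonable reading of "random selection" is to sample the indices to keep and re-sort them so surviving rows stay in their original order; a sketch under that assumption:

.. code-block:: python

    import random

    def truncate_rows(table, rows_to_keep=10):
        rows = table["rows"]
        if len(rows) <= rows_to_keep:
            return table
        keep = sorted(random.sample(range(len(rows)), rows_to_keep))
        return {"header": table["header"], "rows": [rows[i] for i in keep]}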
584 |
class ConvertTableColNamesToSequential(FieldOperator):
|
585 |
"""Replaces actual table column names with static sequential names like col_0, col_1,...
|
586 |
|
587 |
+
.. code-block:: text
|
588 |
+
|
589 |
+
Sample input:
|
590 |
+
{
|
591 |
+
"header": ["name", "age"],
|
592 |
+
"rows": [["Alex", 21], ["Donald", 34]]
|
593 |
+
}
|
594 |
+
|
595 |
+
Sample output:
|
596 |
+
{
|
597 |
+
"header": ["col_0", "col_1"],
|
598 |
+
"rows": [["Alex", 21], ["Donald", 34]]
|
599 |
+
}
|
600 |
"""
|
601 |
|
602 |
def process_value(self, table: Any) -> Any:
|
|
|
619 |
class ShuffleTableRows(TypeDependentAugmentor):
|
620 |
"""Shuffles the input table rows randomly.
|
621 |
|
622 |
+
.. code-block:: text
|
623 |
|
624 |
+
Sample Input:
|
625 |
+
{
|
626 |
+
"header": ["name", "age"],
|
627 |
+
"rows": [["Alex", 26], ["Raj", 34], ["Donald", 39]],
|
628 |
+
}
|
629 |
+
|
630 |
+
Sample Output:
|
631 |
+
{
|
632 |
+
"header": ["name", "age"],
|
633 |
+
"rows": [["Donald", 39], ["Raj", 34], ["Alex", 26]],
|
634 |
+
}
|
635 |
"""
|
636 |
|
637 |
augmented_type = Table
|
|
|
645 |
class ShuffleTableColumns(TypeDependentAugmentor):
|
646 |
"""Shuffles the table columns randomly.
|
647 |
|
648 |
+
.. code-block:: text
|
649 |
|
650 |
+
Sample Input:
|
651 |
+
{
|
652 |
+
"header": ["name", "age"],
|
653 |
+
"rows": [["Alex", 26], ["Raj", 34], ["Donald", 39]],
|
654 |
+
}
|
655 |
+
|
656 |
+
Sample Output:
|
657 |
+
{
|
658 |
+
"header": ["age", "name"],
|
659 |
+
"rows": [[26, "Alex"], [34, "Raj"], [39, "Donald"]],
|
660 |
+
}
|
661 |
"""
|
662 |
|
663 |
augmented_type = Table
|
|
|
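Column shuffling amounts to drawing one permutation and applying it to the header and to every row; a standalone sketch:

.. code-block:: python

    import random

    def shuffle_columns(table):
        perm = list(range(len(table["header"])))
        random.shuffle(perm)
        return {"header": [table["header"][i] for i in perm],
                "rows": [[row[i] for i in perm] for row in table["rows"]]}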
690 |
class MapHTMLTableToJSON(FieldOperator):
|
691 |
"""Converts HTML table format to the basic one (JSON).
|
692 |
|
693 |
+
JSON format:
|
694 |
+
|
695 |
+
.. code-block:: json
|
696 |
+
|
697 |
+
{
|
698 |
+
"header": ["col1", "col2"],
|
699 |
+
"rows": [["row11", "row12"], ["row21", "row22"], ["row31", "row32"]]
|
700 |
+
}
|
701 |
"""
|
702 |
|
703 |
_requirements_list = ["bs4"]
|
|
|
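Since the class requires ``bs4``, the mapping can be sketched with BeautifulSoup; the parsing details below are assumptions, and note that cell values come back as strings:

.. code-block:: python

    from bs4 import BeautifulSoup

    def html_table_to_json(html):
        soup = BeautifulSoup(html, "html.parser")
        header = [th.get_text() for th in soup.find_all("th")]
        rows = [[td.get_text() for td in tr.find_all("td")]
                for tr in soup.find_all("tr") if tr.find("td")]
        return {"header": header, "rows": rows}

    html = ("<table><thead><tr><th>name</th><th>age</th></tr></thead>"
            "<tbody><tr><td>Alice</td><td>26</td></tr></tbody></table>")
    print(html_table_to_json(html))
    # {'header': ['name', 'age'], 'rows': [['Alice', '26']]}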
732 |
class MapTableListsToStdTableJSON(FieldOperator):
|
733 |
"""Converts lists table format to the basic one (JSON).
|
734 |
|
735 |
+
JSON format:
|
736 |
+
|
737 |
+
.. code-block:: json
|
738 |
+
|
739 |
+
{
|
740 |
+
"header": ["col1", "col2"],
|
741 |
+
"rows": [["row11", "row12"], ["row21", "row22"], ["row31", "row32"]]
|
742 |
+
}
|
743 |
"""
|
744 |
|
745 |
def process_value(self, table: Any) -> Any:
|
|
|
789 |
class TransposeTable(TypeDependentAugmentor):
|
790 |
"""Transpose a table.
|
791 |
|
792 |
+
.. code-block:: text
|
793 |
+
|
794 |
+
Sample Input:
|
795 |
+
{
|
796 |
+
"header": ["name", "age", "sex"],
|
797 |
+
"rows": [["Alice", 26, "F"], ["Raj", 34, "M"], ["Donald", 39, "M"]],
|
798 |
+
}
|
799 |
+
|
800 |
+
Sample Output:
|
801 |
+
{
|
802 |
+
"header": [" ", "0", "1", "2"],
|
803 |
+
"rows": [["name", "Alice", "Raj", "Donald"], ["age", 26, 34, 39], ["sex", "F", "M", "M"]],
|
804 |
+
}
|
805 |
|
806 |
"""
|
807 |
|
808 |
augmented_type = Table
|
|
|
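The transpose shown above turns the old header into the first column and names the new columns " ", "0", "1", ...; a sketch that reproduces the sample:

.. code-block:: python

    def transpose(table):
        n_rows = len(table["rows"])
        return {"header": [" "] + [str(i) for i in range(n_rows)],
                "rows": [[name] + [row[i] for row in table["rows"]]
                         for i, name in enumerate(table["header"])]}

    print(transpose({"header": ["name", "age", "sex"],
                     "rows": [["Alice", 26, "F"], ["Raj", 34, "M"], ["Donald", 39, "M"]]}))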
828 |
"""Duplicates specific rows of a table for the given number of times.
|
829 |
|
830 |
Args:
|
831 |
+
row_indices (List[int]): indices of the rows to be duplicated
|
832 |
+
|
833 |
+
times (int): the number of times each selected row appears in the output
|
834 |
"""
|
835 |
|
836 |
augmented_type = Table
|
|
|
861 |
"""Duplicates specific columns of a table for the given number of times.
|
862 |
|
863 |
Args:
|
864 |
+
column_indices (List[int]): indices of the columns to be duplicated
|
865 |
+
|
866 |
+
times (int): the number of times each selected column appears in the output
|
867 |
"""
|
868 |
|
869 |
augmented_type = Table
|
task.py
CHANGED
@@ -41,24 +41,28 @@ class Task(InstanceOperator, ArtifactFetcherMixin):
|
|
41 |
|
42 |
Attributes:
|
43 |
input_fields (Union[Dict[str, str], List[str]]):
|
44 |
-
|
45 |
-
|
|
|
46 |
reference_fields (Union[Dict[str, str], List[str]]):
|
47 |
-
|
48 |
-
|
|
|
49 |
metrics (List[str]): List of names of metrics to be used in the task.
|
|
|
50 |
prediction_type (Optional[str]):
|
51 |
-
|
52 |
-
|
|
|
53 |
defaults (Optional[Dict[str, Any]]):
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
|
58 |
The output instance contains three fields:
|
59 |
-
"input_fields" whose value is a sub-dictionary of the input instance, consisting of all the fields listed in Arg 'input_fields'.
|
60 |
-
"reference_fields" -- for the fields listed in Arg "reference_fields".
|
61 |
-
"metrics" -- to contain the value of Arg 'metrics'
|
62 |
"""
|
63 |
|
64 |
input_fields: Optional[Union[Dict[str, Type], Dict[str, str], List[str]]] = None
|
|
|
41 |
|
42 |
Attributes:
|
43 |
input_fields (Union[Dict[str, str], List[str]]):
|
44 |
+
Dictionary with string names of instance input fields and types of respective values.
|
45 |
+
In case a list is passed, each type will be assumed to be Any.
|
46 |
+
|
47 |
reference_fields (Union[Dict[str, str], List[str]]):
|
48 |
+
Dictionary with string names of instance output fields and types of respective values.
|
49 |
+
In case a list is passed, each type will be assumed to be Any.
|
50 |
+
|
51 |
metrics (List[str]): List of names of metrics to be used in the task.
|
52 |
+
|
53 |
prediction_type (Optional[str]):
|
54 |
+
Needs to be consistent with all used metrics. Defaults to None, which means that it will
|
55 |
+
be set to Any.
|
56 |
+
|
57 |
defaults (Optional[Dict[str, Any]]):
|
58 |
+
An optional dictionary with default values for chosen input/output keys. Needs to be
|
59 |
+
consistent with the names and types provided in the 'input_fields' and/or 'reference_fields' arguments.
|
60 |
+
Will not overwrite values if already provided in a given instance.
|
61 |
|
62 |
The output instance contains three fields:
|
63 |
+
1. "input_fields" whose value is a sub-dictionary of the input instance, consisting of all the fields listed in Arg 'input_fields'.
|
64 |
+
2. "reference_fields" -- for the fields listed in Arg "reference_fields".
|
65 |
+
3. "metrics" -- to contain the value of Arg 'metrics'
|
66 |
"""
|
67 |
|
68 |
input_fields: Optional[Union[Dict[str, Type], Dict[str, str], List[str]]] = None
|
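A hedged usage sketch of the attributes documented above (assuming ``Task`` is imported from this module); the field names and the metric catalog entry are illustrative:

.. code-block:: python

    task = Task(
        input_fields={"text": "str", "options": "List[str]"},
        reference_fields={"label": "str"},
        prediction_type="str",
        metrics=["metrics.accuracy"],
        defaults={"options": ["yes", "no"]},
    )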
templates.py
CHANGED
@@ -308,19 +308,25 @@ class PairwiseChoiceTemplate(InputOutputTemplate):
|
|
308 |
|
309 |
Args:
|
310 |
choice_a_field (str): The field which contains choice_a value
|
|
|
311 |
choice_b_field (str): The field which contains choice_b value
|
|
|
312 |
answer_field (str): The field which contains the answer value.
|
313 |
-
|
|
|
314 |
choice_a_label (str): The label of choice A answer as it is verbalized in the template.
|
|
|
315 |
choice_b_label (str): The label of choice B answer as it is verbalized in the template.
|
|
|
316 |
choice_tie_label (str): The label of a tie answer as it should be verbalized in the template.
|
|
|
317 |
shuffle (bool): whether to shuffle the choices or not. This is done to take into account position bias.
|
318 |
|
319 |
shuffle: 50% of the time:
|
320 |
-
1
|
321 |
-
2
|
322 |
-
|
323 |
-
|
324 |
|
325 |
"""
|
326 |
|
@@ -433,14 +439,17 @@ class PairwiseComparativeRatingTemplate(InputOutputTemplate):
|
|
433 |
|
434 |
Args:
|
435 |
choice_a_field (str): The field which contains choice_a value
|
|
|
436 |
choice_b_field (str): The field which contains choice_b value
|
|
|
437 |
answer_field (str): The field which contains the answer value. The value should be an int.
|
438 |
-
|
|
|
439 |
shuffle (bool): whether to shuffle the choices or not. This is done to take into account position bias.
|
440 |
|
441 |
shuffle: 50% of the time:
|
442 |
-
|
443 |
-
|
444 |
|
445 |
"""
|
446 |
|
|
|
308 |
|
309 |
Args:
|
310 |
choice_a_field (str): The field which contains choice_a value
|
311 |
+
|
312 |
choice_b_field (str): The field which contains choice_b value
|
313 |
+
|
314 |
answer_field (str): The field which contains the answer value.
|
315 |
+
Should be of type Literal["choice_1", "choice_2", "tie"]
|
316 |
+
|
317 |
choice_a_label (str): The label of choice A answer as it is verbalized in the template.
|
318 |
+
|
319 |
choice_b_label (str): The label of choice B answer as it is verbalized in the template.
|
320 |
+
|
321 |
choice_tie_label (str): The label of a tie answer as it should be verbalized in the template.
|
322 |
+
|
323 |
shuffle (bool): whether to shuffle the choices or not. This is done to take into account position bias.
|
324 |
|
325 |
shuffle: 50% of the time:
|
326 |
+
1. The values of choice_a_field and choice_b_field will be swapped.
|
327 |
+
2. If the value of answer_field is choice_a_label, set it to choice_b_label.
|
328 |
+
| Else, if the value of answer_field is choice_b_label, set it to choice_a_label.
|
329 |
+
| Else if the value of answer_field is choice_tie_label, do nothing.
|
330 |
|
331 |
"""
|
332 |
|
|
|
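The documented 50% swap can be read as plain Python; field names follow the Args above, and the helper itself is illustrative:

.. code-block:: python

    import random

    def maybe_swap(instance, a_field, b_field, answer_field, a_label, b_label):
        if random.random() < 0.5:
            instance[a_field], instance[b_field] = instance[b_field], instance[a_field]
            if instance[answer_field] == a_label:
                instance[answer_field] = b_label
            elif instance[answer_field] == b_label:
                instance[answer_field] = a_label
            # a tie answer is left unchanged
        return instance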
439 |
|
440 |
Args:
|
441 |
choice_a_field (str): The field which contains choice_a value
|
442 |
+
|
443 |
choice_b_field (str): The field which contains choice_b value
|
444 |
+
|
445 |
answer_field (str): The field which contains the answer value. The value should be an int.
|
446 |
+
Positive for preferring choice_a, and negative for preferring choice_b
|
447 |
+
|
448 |
shuffle (bool): whether to shuffle the choices or not. This is done to take into account position bias.
|
449 |
|
450 |
shuffle: 50% of the time:
|
451 |
+
| 1) The values of choice_a_field and choice_b_field will be swapped.
|
452 |
+
| 2) Replace the values of answer_field with its mapped value according to the reverse_preference_map Dict.
|
453 |
|
454 |
"""
|
455 |
|
type_utils.py
CHANGED
@@ -307,21 +307,22 @@ def infer_type_string(obj: typing.Any) -> str:
|
|
307 |
obj:Any
|
308 |
|
309 |
Returns:
|
310 |
-
|
311 |
|
312 |
-
formal definition of the returned string:
|
313 |
-
Type -> basic | List[Type] | Dict[Type, Type] | Union[Type (, Type)* | Tuple[Type (,Type)*]
|
314 |
-
basic -> bool,str,int,float,Any
|
315 |
-
no spaces at all.
|
316 |
|
317 |
Examples:
|
318 |
-
infer_type_string({"how_much": 7}) returns "Dict[str,int]"
|
319 |
-
infer_type_string([1, 2]) returns "List[int]"
|
320 |
-
infer_type_string([]) returns "List[Any]") no contents to list to indicate any type
|
321 |
-
infer_type_string([[], [7]]) returns "List[List[int]]" type of parent list indicated
|
322 |
-
|
323 |
-
|
324 |
-
infer_type_string([[], 7, True]) returns "List[Union[List[Any],int]]"
|
|
|
325 |
|
326 |
"""
|
327 |
|
|
|
307 |
obj:Any
|
308 |
|
309 |
Returns:
|
310 |
+
A string representation of the type of the object, e.g. ``"str"``, ``"List[int]"``, ``"Dict[str, Any]"``.
|
311 |
+
|
312 |
+
| formal definition of the returned string:
|
313 |
+
| Type -> basic | List[Type] | Dict[Type, Type] | Union[Type(, Type)*] | Tuple[Type(, Type)*]
|
314 |
+
| basic -> ``bool`` | ``str`` | ``int`` | ``float`` | ``Any``
|
315 |
|
316 |
|
317 |
Examples:
|
318 |
+
| ``infer_type_string({"how_much": 7})`` returns ``"Dict[str,int]"``
|
319 |
+
| ``infer_type_string([1, 2])`` returns ``"List[int]"``
|
320 |
+
| ``infer_type_string([])`` returns ``"List[Any]"`` since an empty list has no contents from which to infer an element type
|
321 |
+
| ``infer_type_string([[], [7]])`` returns ``"List[List[int]]"``; the type of the parent list is indicated
|
322 |
+
by the type of the non-empty child list; the empty child list is, by default, assumed to be of
|
323 |
+
the same type as the non-empty child.
|
324 |
+
| ``infer_type_string([[], 7, True])`` returns ``"List[Union[List[Any],int]]"``
|
325 |
+
because ``bool`` is also an ``int``
|
326 |
|
327 |
"""
|
328 |
|
utils.py
CHANGED
@@ -32,8 +32,8 @@ class LRUCache:
|
|
32 |
|
33 |
Attributes:
|
34 |
max_size (int): The maximum number of items to store in the cache.
|
35 |
-
|
36 |
-
|
37 |
"""
|
38 |
|
39 |
def __init__(self, max_size=10):
|
|
|
32 |
|
33 |
Attributes:
|
34 |
max_size (int): The maximum number of items to store in the cache.
|
35 |
+
Items exceeding this limit are automatically removed based on least
|
36 |
+
recent usage.
|
37 |
"""
|
38 |
|
39 |
def __init__(self, max_size=10):
|
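The documented eviction policy matches the classic OrderedDict recipe; a minimal sketch (illustrative, not the class's actual code):

.. code-block:: python

    from collections import OrderedDict

    class SimpleLRU:
        def __init__(self, max_size=10):
            self.max_size = max_size
            self._data = OrderedDict()

        def get(self, key, default=None):
            if key in self._data:
                self._data.move_to_end(key)  # mark as most recently used
                return self._data[key]
            return default

        def set(self, key, value):
            self._data[key] = value
            self._data.move_to_end(key)
            while len(self._data) > self.max_size:
                self._data.popitem(last=False)  # evict least recently used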
version.py
CHANGED
@@ -1 +1 @@
|
|
1 |
-
version = "1.15.
|
|
|
1 |
+
version = "1.15.10"
|