Upload folder using huggingface_hub
- api.py +41 -9
- artifact.py +7 -2
- collections_operators.py +22 -4
- dialog_operators.py +2 -2
- formats.py +1 -0
- generator_utils.py +2 -32
- inference.py +376 -55
- llm_as_judge.py +261 -62
- loaders.py +14 -6
- metric_utils.py +18 -9
- metrics.py +206 -67
- operators.py +79 -47
- processors.py +77 -2
- settings_utils.py +1 -0
- split_utils.py +6 -1
- splitters.py +4 -2
- standard.py +6 -6
- stream.py +4 -3
- stream_operators.py +5 -3
- string_operators.py +9 -0
- struct_data_operators.py +194 -5
- templates.py +1 -1
- type_utils.py +3 -0
- utils.py +84 -1
- version.py +1 -1
api.py
CHANGED
@@ -1,10 +1,10 @@
+import json
 from functools import lru_cache
 from typing import Any, Dict, List, Optional, Union
 
-from datasets import DatasetDict
-
 from .artifact import fetch_artifact
 from .dataset_utils import get_dataset_artifact
+from .inference import InferenceEngine, LogProbInferenceEngine
 from .logging_utils import get_logger
 from .metric_utils import _compute, _inference_post_process
 from .operator import SourceOperator
@@ -14,7 +14,7 @@ from .standard import StandardRecipe
 logger = get_logger()
 
 
-def load(source: Union[SourceOperator, str])
+def load(source: Union[SourceOperator, str]):
     assert isinstance(
         source, (SourceOperator, str)
     ), "source must be a SourceOperator or a string"
@@ -79,7 +79,9 @@ def load_recipe(dataset_query: Optional[str] = None, **kwargs) -> StandardRecipe
     return recipe
 
 
-def load_dataset(dataset_query: Optional[str] = None, **kwargs) -> DatasetDict:
+def load_dataset(
+    dataset_query: Optional[str] = None, streaming: bool = False, **kwargs
+):
     """Loads dataset.
 
     If the 'dataset_query' argument is provided, then dataset is loaded from a card in local
@@ -90,6 +92,7 @@ def load_dataset(dataset_query: Optional[str] = None, **kwargs) -> DatasetDict:
        dataset_query (str, optional): A string query which specifies a dataset to load from local catalog or name of specific recipe or benchmark in the catalog.
            For example:
            "card=cards.wnli,template=templates.classification.multi_class.relation.default".
+       streaming (bool, False): When True yields the data as Unitxt streams dictionary
        **kwargs: Arguments used to load dataset from provided card, which is not present in local catalog.
 
    Returns:
@@ -107,6 +110,9 @@ def load_dataset(dataset_query: Optional[str] = None, **kwargs) -> DatasetDict:
     """
     recipe = load_recipe(dataset_query, **kwargs)
 
+    if streaming:
+        return recipe()
+
     return recipe().to_dataset(features=UNITXT_DATASET_SCHEMA)
@@ -135,19 +141,45 @@ def produce(instance_or_instances, dataset_query: Optional[str] = None, **kwargs
 
 def infer(
     instance_or_instances,
-    engine,
+    engine: InferenceEngine,
     dataset_query: Optional[str] = None,
-    return_data=False,
+    return_data: bool = False,
+    return_log_probs: bool = False,
+    return_meta_data: bool = False,
     **kwargs,
 ):
     dataset = produce(instance_or_instances, dataset_query, **kwargs)
     engine, _ = fetch_artifact(engine)
-    raw_predictions = engine.infer(dataset)
+    if return_log_probs:
+        if not isinstance(engine, LogProbInferenceEngine):
+            raise NotImplementedError(
+                f"Error in infer: return_log_probs set to True but supplied engine "
+                f"{engine.__class__.__name__} does not support logprobs."
+            )
+        infer_outputs = engine.infer_log_probs(dataset, return_meta_data)
+        raw_predictions = (
+            [output.prediction for output in infer_outputs]
+            if return_meta_data
+            else infer_outputs
+        )
+        raw_predictions = [
+            json.dumps(raw_prediction) for raw_prediction in raw_predictions
+        ]
+    else:
+        infer_outputs = engine.infer(dataset, return_meta_data)
+        raw_predictions = (
+            [output.prediction for output in infer_outputs]
+            if return_meta_data
+            else infer_outputs
+        )
     predictions = post_process(raw_predictions, dataset)
     if return_data:
-        for prediction, raw_prediction, instance in zip(
-            predictions, raw_predictions, dataset
+        for prediction, raw_prediction, instance, infer_output in zip(
+            predictions, raw_predictions, dataset, infer_outputs
         ):
+            if return_meta_data:
+                instance["infer_meta_data"] = infer_output.__dict__
+                del instance["infer_meta_data"]["prediction"]
            instance["prediction"] = prediction
            instance["raw_prediction"] = raw_prediction
     return dataset
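Taken together, the new parameters can be exercised roughly as follows. This is a hedged sketch: the card/template string is the one quoted in the docstring above, while the split name is an assumption.

    from unitxt.api import load_dataset

    # streaming=True returns the recipe's stream dictionary instead of a
    # materialized dataset, so instances are generated lazily.
    streams = load_dataset(
        "card=cards.wnli,template=templates.classification.multi_class.relation.default",
        streaming=True,
    )
    for instance in streams["test"]:  # assumes the recipe yields a "test" split
        print(instance["source"])
        break

The return_log_probs / return_meta_data flags of infer() follow the same pattern: return_log_probs requires an engine implementing LogProbInferenceEngine, and with return_data=True the per-instance metadata is attached under instance["infer_meta_data"].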
artifact.py
CHANGED
@@ -22,7 +22,12 @@ from .parsing_utils import (
 from .settings_utils import get_constants, get_settings
 from .text_utils import camel_to_snake_case, is_camel_case
 from .type_utils import issubtype
-from .utils import
+from .utils import (
+    artifacts_json_cache,
+    json_dump,
+    save_to_file,
+    shallow_copy,
+)
 
 logger = get_logger()
 settings = get_settings()
@@ -405,7 +410,7 @@ def get_raw(obj):
     if isinstance(obj, dict):
         return type(obj)({get_raw(k): get_raw(v) for k, v in obj.items()})
 
-    return
+    return shallow_copy(obj)
 
 
 class ArtifactList(list, Artifact):
collections_operators.py
CHANGED
@@ -3,7 +3,7 @@ from typing import Any, Generator, List, Optional
 from .dict_utils import dict_get, dict_set
 from .operators import FieldOperator, StreamOperator
 from .stream import Stream
-from .utils import
+from .utils import recursive_shallow_copy
 
 
 class Dictify(FieldOperator):
@@ -70,10 +70,10 @@ class DuplicateByList(StreamOperator):
         elements = dict_get(instance, self.field)
         for element in elements:
             if self.use_deep_copy:
-                instance_copy =
+                instance_copy = recursive_shallow_copy(instance)
 
             else:
-                instance_copy =
+                instance_copy = instance.copy()
             dict_set(instance_copy, to_field, element)
             yield instance_copy
 
@@ -93,7 +93,7 @@ class DuplicateBySubLists(StreamOperator):
         elements = instance[self.field]
         for i in range(1, len(elements) + 1):
             if self.use_deep_copy:
-                instance_copy =
+                instance_copy = recursive_shallow_copy(instance)
                 instance_copy[to_field] = elements[:i]
             else:
                 instance_copy = {
@@ -107,3 +107,21 @@ class DuplicateBySubLists(StreamOperator):
 class GetLength(FieldOperator):
     def process_value(self, collection: Any) -> Any:
         return len(collection)
+
+
+class Filter(FieldOperator):
+    values: List[Any]
+
+    def process_value(self, collection: Any) -> Any:
+        # If collection is a list, tuple, or set
+        if isinstance(collection, (list, set, tuple)):
+            return type(collection)(
+                item for item in collection if item not in self.values
+            )
+
+        # If collection is a dictionary, filter by keys
+        if isinstance(collection, dict):
+            return {k: v for k, v in collection.items() if k not in self.values}
+
+        # If collection is of an unsupported type
+        raise TypeError(f"Unsupported collection type: {type(collection)}")
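A standalone illustration of the new Filter operator; the field name and the filtered values are made up for the example.

    from unitxt.collections_operators import Filter

    op = Filter(field="labels", to_field="labels", values=["other"])

    # Lists, tuples and sets drop the listed values; dictionaries are filtered by key.
    assert op.process_value(["cat", "dog", "other"]) == ["cat", "dog"]
    assert op.process_value({"cat": 1, "other": 2}) == {"cat": 1}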
dialog_operators.py
CHANGED
@@ -157,13 +157,13 @@ class SerializeOpenAiFormatDialog(SerializeDialog):
                     f"Entry {i} has a non-string 'content': {entry['content']}. The 'content' value must be a string."
                 )
 
-            if entry["role"] not in {"user", "assistant"}:
+            if entry["role"].lower() not in {"user", "assistant"}:
                 raise ValueError(
                     f"Entry {i} has an invalid role: {entry['role']}. Allowed roles are 'user' and 'assistant'."
                 )
 
         first_entry = dialog[0]
-        if first_entry["role"] != "user":
+        if first_entry["role"].lower() != "user":
             raise ValueError(
                 f"First entry role is expected to be 'user' It is {first_entry['role']}."
             )
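With the .lower() normalization above, dialogs whose clients capitalize the role names now pass validation; for example, the following shape is accepted where it previously raised ValueError:

    dialog = [
        {"role": "User", "content": "What is the capital of France?"},
        {"role": "Assistant", "content": "Paris."},
    ]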
formats.py
CHANGED
@@ -182,6 +182,7 @@ class SystemFormat(BaseFormat):
                 target_prefix=demo_target_prefix,
                 source=demo_source,
                 target=demo_target,
+                instruction=instruction,
                 **self.format_args,
             )
             demos_string += demo_str
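Because the per-demo formatting call now also receives instruction, a demo template can interpolate it. A hedged sketch; the placeholder names other than {instruction} are assumed from SystemFormat's existing demo fields:

    from unitxt.formats import SystemFormat

    format = SystemFormat(
        # Repeat the task instruction before every in-context demonstration.
        demo_format="{instruction}\n{source}\n{target_prefix}{target}\n\n",
    )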
generator_utils.py
CHANGED
@@ -1,7 +1,7 @@
 from typing import Any, Dict, List
 
 from .dataclass import Dataclass, OptionalField
-from .utils import
+from .utils import recursive_shallow_copy
 
 
 class ReusableGenerator(Dataclass):
@@ -22,34 +22,4 @@ class ReusableGenerator(Dataclass):
 class CopyingReusableGenerator(ReusableGenerator):
     def __iter__(self):
         for instance in self.activate():
-            yield
+            yield recursive_shallow_copy(instance)
-
-
-# if __name__ == "__main__":
-#     from itertools import chain, islice
-
-#     # Creating objects of MyIterable
-#     iterable1 = ReusableGenerator(range, gen_argv=[1, 4])
-#     iterable2 = ReusableGenerator(range, gen_argv=[4, 7])
-
-#     # Using itertools.chain
-#     chained = list(chain(iterable1, iterable2))
-#     logger.info(chained)  # Prints: [1, 2, 3, 4, 5, 6]
-
-#     # Using itertools.islice
-#     sliced = list(islice(ReusableGenerator(range, gen_argv=[1, 7]), 1, 4))
-#     logger.info(sliced)  # Prints: [2, 3, 4]
-
-#     # now same test with generators
-#     def generator(start, end):
-#         for i in range(start, end):
-#             yield i
-
-#     iterable1 = ReusableGenerator(generator, gen_argv=[1, 4])
-#     iterable2 = ReusableGenerator(generator, gen_argv=[4, 7])
-
-#     chained = list(chain(iterable1, iterable2))
-#     logger.info(chained)  # Prints: [1, 2, 3, 4, 5, 6]
-
-#     sliced = list(islice(ReusableGenerator(generator, gen_argv=[1, 7]), 1, 4))
-#     logger.info(sliced)  # Prints: [2, 3, 4]
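recursive_shallow_copy itself lives in .utils and is not shown in this commit; conceptually it rebuilds nested dicts and lists while sharing the leaf values, so consumers of CopyingReusableGenerator can mutate the yielded instance without corrupting the shared stream. A rough, illustrative equivalent (not the library's implementation):

    def recursive_shallow_copy_sketch(obj):
        # Copy containers recursively; leaf objects are shared, not duplicated.
        if isinstance(obj, dict):
            return {key: recursive_shallow_copy_sketch(value) for key, value in obj.items()}
        if isinstance(obj, list):
            return [recursive_shallow_copy_sketch(value) for value in obj]
        return obj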
inference.py
CHANGED
@@ -1,8 +1,10 @@
 import abc
+import dataclasses
 import os
 import re
 from typing import Any, Dict, List, Literal, Optional, Union
 
+from datasets import DatasetDict
 from tqdm import tqdm
 
 from .artifact import Artifact, fetch_artifact
@@ -16,12 +18,52 @@ from .settings_utils import get_settings
 settings = get_settings()
 
 
+def get_model_and_label_id(model_name, label):
+    model_id = model_name.split("/")[-1].replace("-", "_").replace(".", ",").lower()
+    return f"{model_id}_{label}"
+
+
+@dataclasses.dataclass
+class TextGenerationInferenceOutput:
+    """Contains the prediction results and metadata for the inference.
+
+    Args:
+        prediction (Union[str, List[Dict[str, Any]]]): If this is the result of an _infer call, the string predicted by the model.
+        If this is the results of an _infer_log_probs call, a list of dictionaries. The i'th dictionary represents
+        the i'th token in the response. The entry "top_tokens" in the dictionary holds a sorted list of the top tokens
+        for this position and their probabilities.
+        For example: [ {.. "top_tokens": [ {"text": "a", 'logprob': }, {"text": "b", 'logprob': } ....]},
+        {.. "top_tokens": [ {"text": "c", 'logprob': }, {"text": "d", 'logprob': } ....]}
+        ]
+
+        input_tokens (int) : number of input tokens to the model.
+        output_tokens (int) : number of output tokens to the model.
+        model_name (str): the model_name as kept in the InferenceEngine.
+        inference_type (str): The label stating the type of the InferenceEngine.
+    """
+
+    prediction: Union[str, List[Dict[str, Any]]]
+    input_tokens: Optional[int] = None
+    output_tokens: Optional[int] = None
+    model_name: Optional[str] = None
+    inference_type: Optional[str] = None
+
+
 class InferenceEngine(abc.ABC, Artifact):
     """Abstract base class for inference."""
 
     @abc.abstractmethod
-    def _infer(
+    def _infer(
+        self,
+        dataset: Union[List[Dict[str, Any]], DatasetDict],
+        return_meta_data: bool = False,
+    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
+        """Perform inference on the input dataset.
+
+        If return_meta_data - returns a list of TextGenerationInferenceOutput, else returns a list of the string.
+        return_meta_data is only supported for some InferenceEngines.
+        predictions.
+        """
         pass
 
     @abc.abstractmethod
@@ -33,12 +75,29 @@ class InferenceEngine(abc.ABC, Artifact):
         if not settings.mock_inference_mode:
             self.prepare_engine()
 
-    def infer(
+    def infer(
+        self,
+        dataset: Union[List[Dict[str, Any]], DatasetDict],
+        return_meta_data: bool = False,
+    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
+        """Verifies instances of a dataset and perform inference on the input dataset.
+
+        If return_meta_data - returns a list of TextGenerationInferenceOutput, else returns a list of the string
+        predictions.
+        """
+        if return_meta_data and not hasattr(self, "get_return_object"):
+            raise NotImplementedError(
+                f"Inference engine {self.__class__.__name__} does not support return_meta_data as it "
+                f"does not contain a 'get_return_object' method. Please set return_meta_data=False."
+            )
+
         [self.verify_instance(instance) for instance in dataset]
         if settings.mock_inference_mode:
             return [instance["source"] for instance in dataset]
-        return self._infer(dataset)
+        return self._infer(dataset, return_meta_data)
+
+    def get_engine_id(self):
+        raise NotImplementedError()
 
     @deprecation(version="2.0.0")
     def _set_inference_parameters(self):
@@ -62,19 +121,39 @@ class LogProbInferenceEngine(abc.ABC, Artifact):
     """Abstract base class for inference with log probs."""
 
     @abc.abstractmethod
-    def _infer_log_probs(
+    def _infer_log_probs(
+        self,
+        dataset: Union[List[Dict[str, Any]], DatasetDict],
+        return_meta_data: bool = False,
+    ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
+        """Perform inference on the input dataset that returns log probs.
+
+        If return_meta_data - returns a list of TextGenerationInferenceOutput, else returns a list of the logprob dicts.
+        return_meta_data is only supported for some InferenceEngines.
+        predictions.
+        """
         pass
 
-    def infer_log_probs(
+    def infer_log_probs(
+        self,
+        dataset: Union[List[Dict[str, Any]], DatasetDict],
+        return_meta_data: bool = False,
+    ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
         """Verifies instances of a dataset and performs inference that returns log probabilities of top tokens.
 
-        For each instance ,
+        For each instance , generates a list of top tokens per position.
         [ "top_tokens": [ { "text": ..., "logprob": ...} , ... ]
+        If return_meta_data - returns a list of TextGenerationInferenceOutput, else returns the list of the logprob dicts.
+        return_meta_data is only supported for some InferenceEngines.
         """
+        if return_meta_data and not hasattr(self, "get_return_object"):
+            raise NotImplementedError(
+                f"Inference engine {self.__class__.__name__} does not support return_meta_data as it "
+                f"does not contain a 'get_return_object' method. Please set return_meta_data=False."
+            )
+
         [self.verify_instance(instance) for instance in dataset]
-        return self._infer_log_probs(dataset)
+        return self._infer_log_probs(dataset, return_meta_data)
 
 
 class LazyLoadMixin(Artifact):
@@ -96,6 +175,9 @@ class HFPipelineBasedInferenceEngine(
         "transformers": "Install huggingface package using 'pip install --upgrade transformers"
     }
 
+    def get_engine_id(self):
+        return get_model_and_label_id(self.model_name, "hf_pipeline")
+
     def _prepare_pipeline(self):
         import torch
         from transformers import AutoConfig, pipeline
@@ -143,7 +225,11 @@ class HFPipelineBasedInferenceEngine(
     def _is_loaded(self):
         return hasattr(self, "model") and self.model is not None
 
-    def _infer(
+    def _infer(
+        self,
+        dataset: Union[List[Dict[str, Any]], DatasetDict],
+        return_meta_data: bool = False,
+    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
         if not self._is_loaded():
             self._prepare_pipeline()
 
@@ -157,12 +243,20 @@ class HFPipelineBasedInferenceEngine(
 
 class MockInferenceEngine(InferenceEngine):
     model_name: str
+    default_inference_value: str = "[[10]]"
+
+    def get_engine_id(self):
+        return get_model_and_label_id(self.model_name, "mock")
 
     def prepare_engine(self):
         return
 
-    def _infer(
+    def _infer(
+        self,
+        dataset: Union[List[Dict[str, Any]], DatasetDict],
+        return_meta_data: bool = False,
+    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
+        return [self.default_inference_value for instance in dataset]
 
 
 class MockModeMixin(Artifact):
@@ -226,7 +320,14 @@ class GenericInferenceEngine(InferenceEngine):
             engine_reference = self.default
         self.engine, _ = fetch_artifact(engine_reference)
 
-    def
+    def get_engine_id(self):
+        return "generic_inference_engine"
+
+    def _infer(
+        self,
+        dataset: Union[List[Dict[str, Any]], DatasetDict],
+        return_meta_data: bool = False,
+    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
         return self.engine._infer(dataset)
 
 
@@ -238,10 +339,17 @@ class OllamaInferenceEngine(InferenceEngine, PackageRequirementsMixin):
     }
     data_classification_policy = ["public", "proprietary"]
 
+    def get_engine_id(self):
+        return get_model_and_label_id(self.model_name, self.label)
+
     def prepare_engine(self):
         pass
 
-    def _infer(
+    def _infer(
+        self,
+        dataset: Union[List[Dict[str, Any]], DatasetDict],
+        return_meta_data: bool = False,
+    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
         import ollama
 
         result = [
@@ -260,7 +368,10 @@ class OllamaInferenceEngine(InferenceEngine, PackageRequirementsMixin):
 
 
 class IbmGenAiInferenceEngine(
-    InferenceEngine,
+    InferenceEngine,
+    IbmGenAiInferenceEngineParamsMixin,
+    PackageRequirementsMixin,
+    LogProbInferenceEngine,
 ):
     label: str = "ibm_genai"
     model_name: str
@@ -270,6 +381,9 @@ class IbmGenAiInferenceEngine(
     data_classification_policy = ["public", "proprietary"]
     parameters: Optional[IbmGenAiInferenceEngineParams] = None
 
+    def get_engine_id(self):
+        return get_model_and_label_id(self.model_name, self.label)
+
     def prepare_engine(self):
         from genai import Client, Credentials
 
@@ -285,21 +399,88 @@ class IbmGenAiInferenceEngine(
 
         self._set_inference_parameters()
 
-    def _infer(
+    def _infer(
+        self,
+        dataset: Union[List[Dict[str, Any]], DatasetDict],
+        return_meta_data: bool = False,
+    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
         from genai.schema import TextGenerationParameters
 
         genai_params = TextGenerationParameters(
             **self.to_dict([IbmGenAiInferenceEngineParamsMixin])
         )
 
+        results = []
+        responses = self.client.text.generation.create(
+            model_id=self.model_name,
+            inputs=[instance["source"] for instance in dataset],
+            parameters=genai_params,
+        )
+        for response in responses:
+            generated_text = response.results[0].generated_text
+            result = self.get_return_object(
+                generated_text, response.results[0], return_meta_data
+            )
+            results.append(result)
+        return results
+
+    def _infer_log_probs(
+        self,
+        dataset: Union[List[Dict[str, Any]], DatasetDict],
+        return_meta_data: bool = False,
+    ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
+        from genai.schema import TextGenerationParameters
+
+        logprobs_return_options = {
+            "generated_tokens": True,
+            "input_text": False,
+            "input_tokens": False,
+            "token_logprobs": True,
+            "token_ranks": True,
+            "top_n_tokens": 5,
+        }
+        genai_params = self.to_dict(
+            [IbmGenAiInferenceEngineParamsMixin], keep_empty=False
+        )
+        genai_params = {**genai_params, "return_options": logprobs_return_options}
+        genai_params = TextGenerationParameters(**genai_params)
+        predictions = self.client.text.generation.create(
+            model_id=self.model_name,
+            inputs=[instance["source"] for instance in dataset],
+            parameters=genai_params,
+        )
+
+        predict_results = []
+        for prediction in predictions:
+            result = prediction.results[0]
+            assert isinstance(
+                result.generated_tokens, list
+            ), "result.generated_tokens should be a list"
+
+            predict_result = []
+            for base_token in result.generated_tokens:
+                res = {**base_token.__dict__, **base_token.model_extra}
+                res["top_tokens"] = [
+                    {"logprob": top_token.logprob, "text": top_token.text}
+                    for top_token in res["top_tokens"]
+                ]
+                predict_result.append(res)
+            final_results = self.get_return_object(
+                predict_result, result, return_meta_data
+            )
+            predict_results.append(final_results)
+        return predict_results
+
+    def get_return_object(self, predict_result, result, return_meta_data):
+        if return_meta_data:
+            return TextGenerationInferenceOutput(
+                prediction=predict_result,
+                input_tokens=result.input_token_count,
+                output_tokens=result.generated_token_count,
+                model_name=self.model_name,
+                inference_type=self.label,
+            )
+        return predict_result
 
 
 class OpenAiInferenceEngineParamsMixin(Artifact):
@@ -349,18 +530,29 @@ class OpenAiInferenceEngine(
     data_classification_policy = ["public"]
     parameters: Optional[OpenAiInferenceEngineParams] = None
 
-    def
+    def get_engine_id(self):
+        return get_model_and_label_id(self.model_name, self.label)
 
+    @classmethod
+    def get_api_param(cls, inference_engine: str, api_param_env_var_name: str):
+        api_key = os.environ.get(api_param_env_var_name)
         assert api_key is not None, (
-            f"Error while trying to run
-            f" Please set the environment param '{
+            f"Error while trying to run {inference_engine}."
+            f" Please set the environment param '{api_param_env_var_name}'."
         )
+        return api_key
 
+    def create_client(self):
+        from openai import OpenAI
 
+        api_key = self.get_api_param(
+            inference_engine="OpenAiInferenceEngine",
+            api_param_env_var_name="OPENAI_API_KEY",
+        )
+        return OpenAI(api_key=api_key)
+
+    def prepare_engine(self):
+        self.client = self.create_client()
         self._set_inference_parameters()
 
     def _get_completion_kwargs(self):
@@ -370,7 +562,11 @@ class OpenAiInferenceEngine(
             if v is not None
         }
 
-    def _infer(
+    def _infer(
+        self,
+        dataset: Union[List[Dict[str, Any]], DatasetDict],
+        return_meta_data: bool = False,
+    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
         outputs = []
         for instance in tqdm(dataset, desc="Inferring with openAI API"):
             response = self.client.chat.completions.create(
@@ -387,13 +583,18 @@ class OpenAiInferenceEngine(
                 model=self.model_name,
                 **self._get_completion_kwargs(),
             )
+            prediction = response.choices[0].message.content
+            output = self.get_return_object(prediction, response, return_meta_data)
 
             outputs.append(output)
 
         return outputs
 
-    def _infer_log_probs(
+    def _infer_log_probs(
+        self,
+        dataset: Union[List[Dict[str, Any]], DatasetDict],
+        return_meta_data: bool = False,
+    ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
         outputs = []
         for instance in tqdm(dataset, desc="Inferring with openAI API"):
             response = self.client.chat.completions.create(
@@ -411,7 +612,7 @@ class OpenAiInferenceEngine(
                 **self._get_completion_kwargs(),
             )
             top_logprobs_response = response.choices[0].logprobs.content
-            output = [
+            pred_output = [
                 {
                     "top_tokens": [
                         {"text": obj.token, "logprob": obj.logprob}
@@ -420,9 +621,21 @@ class OpenAiInferenceEngine(
                 }
                 for generated_token in top_logprobs_response
             ]
+            output = self.get_return_object(pred_output, response, return_meta_data)
             outputs.append(output)
         return outputs
 
+    def get_return_object(self, predict_result, response, return_meta_data):
+        if return_meta_data:
+            return TextGenerationInferenceOutput(
+                prediction=predict_result,
+                input_tokens=response.usage.prompt_tokens,
+                output_tokens=response.usage.completion_tokens,
+                model_name=self.model_name,
+                inference_type=self.label,
+            )
+        return predict_result
+
 
 class TogetherAiInferenceEngineParamsMixin(Artifact):
     max_tokens: Optional[int] = None
@@ -450,6 +663,9 @@ class TogetherAiInferenceEngine(
     data_classification_policy = ["public"]
     parameters: Optional[TogetherAiInferenceEngineParamsMixin] = None
 
+    def get_engine_id(self):
+        return get_model_and_label_id(self.model_name, self.label)
+
     def prepare_engine(self):
         from together import Together
         from together.types.models import ModelType
@@ -501,7 +717,11 @@ class TogetherAiInferenceEngine(
         )
         return response.choices[0].text
 
-    def _infer(
+    def _infer(
+        self,
+        dataset: Union[List[Dict[str, Any]], DatasetDict],
+        return_meta_data: bool = False,
+    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
        from together.types.models import ModelType
 
        outputs = []
@@ -514,6 +734,23 @@ class TogetherAiInferenceEngine(
         return outputs
 
 
+class VLLMRemoteInferenceEngine(OpenAiInferenceEngine):
+    label: str = "vllm"
+
+    def create_client(self):
+        from openai import OpenAI
+
+        api_key = self.get_api_param(
+            inference_engine="VLLMRemoteInferenceEngine",
+            api_param_env_var_name="VLLM_API_KEY",
+        )
+        api_url = self.get_api_param(
+            inference_engine="VLLMRemoteInferenceEngine",
+            api_param_env_var_name="VLLM_API_URL",
+        )
+        return OpenAI(api_key=api_key, base_url=api_url)
+
+
 class WMLInferenceEngineParamsMixin(Artifact):
     decoding_method: Optional[Literal["greedy", "sample"]] = None
     length_penalty: Optional[Dict[str, Union[int, float]]] = None
@@ -550,7 +787,10 @@ class WMLInferenceEngineParams(Artifact):
 
 
 class WMLInferenceEngine(
-    InferenceEngine,
+    InferenceEngine,
+    WMLInferenceEngineParamsMixin,
+    PackageRequirementsMixin,
+    LogProbInferenceEngine,
 ):
     """Runs inference using ibm-watsonx-ai.
 
@@ -604,14 +844,17 @@ class WMLInferenceEngine(
     concurrency_limit: int = 10
     _client: Any = InternalField(default=None, name="WML client")
 
+    def get_engine_id(self):
+        return get_model_and_label_id(self.model_name, self.label)
+
     def verify(self):
         super().verify()
 
         if self.credentials is not None:
             for key in self.credentials:
-                if key not in ["url", "apikey", "project_id"]:
+                if key not in ["url", "apikey", "project_id", "space_id"]:
                     raise ValueError(
-                        f'Illegal credential key: {key}, use only ["url", "apikey", "project_id"]'
+                        f'Illegal credential key: {key}, use only ["url", "apikey", "project_id", "space_id"]'
                     )
 
         assert (
@@ -631,10 +874,14 @@ class WMLInferenceEngine(
 
     @staticmethod
     def _read_wml_credentials_from_env() -> (
-        Dict[Literal["url", "apikey", "project_id"], str]
+        Dict[Literal["url", "apikey", "project_id", "space_id"], str]
     ):
         credentials = {}
+        project_or_deployment_var_name = (
+            "WML_SPACE_ID" if "WML_SPACE_ID" in os.environ else "WML_PROJECT_ID"
+        )
+
+        for env_var_name in ["WML_URL", project_or_deployment_var_name, "WML_APIKEY"]:
             env_var = os.environ.get(env_var_name)
             assert env_var, (
                 f"Error while trying to run 'WMLInferenceEngine'. "
@@ -655,7 +902,10 @@ class WMLInferenceEngine(
             self.credentials = self._read_wml_credentials_from_env()
 
         client = APIClient(credentials=self.credentials)
+        if "space_id" in self.credentials:
+            client.set.default_space(self.credentials["space_id"])
+        else:
+            client.set.default_project(self.credentials["project_id"])
         return client
 
     def prepare_engine(self):
@@ -663,7 +913,7 @@ class WMLInferenceEngine(
 
         self._set_inference_parameters()
 
-    def
+    def _load_model_and_params(self):
         from ibm_watsonx_ai.foundation_models import ModelInference
 
         model = ModelInference(
@@ -671,20 +921,81 @@ class WMLInferenceEngine(
             deployment_id=self.deployment_id,
             api_client=self._client,
         )
-        dataset = dataset if isinstance(dataset, list) else [dataset]
+        params = self.to_dict([WMLInferenceEngineParamsMixin], keep_empty=False)
 
+        return model, params
+
+    def _infer(
+        self,
+        dataset: Union[List[Dict[str, Any]], DatasetDict],
+        return_meta_data: bool = False,
+    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
+        model, params = self._load_model_and_params()
+
+        result = []
+        for instance in dataset:
+            instance_result = model.generate(
                 prompt=instance["source"],
                 params=self.to_dict([WMLInferenceEngineParamsMixin], keep_empty=False),
             )
+            prediction = instance_result["results"][0]["generated_text"]
+            instance_final_results = self.get_return_object(
+                prediction, instance_result, return_meta_data
+            )
+            result.append(instance_final_results)
+
+        return result
+
+    def _infer_log_probs(
+        self,
+        dataset: Union[List[Dict[str, Any]], DatasetDict],
+        return_meta_data: bool = False,
+    ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
+        model, params = self._load_model_and_params()
+
+        user_return_options = params.pop("return_options", {})
+        # currently this is the only configuration that returns generated logprobs and behaves as expected
+        logprobs_return_options = {
+            "input_tokens": True,
+            "generated_tokens": True,
+            "token_logprobs": True,
+            "top_n_tokens": user_return_options.get("top_n_tokens", 5),
+        }
+        for key, value in logprobs_return_options.items():
+            if key in user_return_options and user_return_options[key] != value:
+                raise ValueError(
+                    f"'{key}={user_return_options[key]}' is not supported for the 'infer_log_probs' "
+                    f"method of {self.__class__.__name__}. For obtaining the logprobs of generated tokens "
+                    f"please use '{key}={value}'."
+                )
+
+        params = {
+            **params,
+            "return_options": logprobs_return_options,
+        }
 
+        results = model.generate(
+            prompt=[instance["source"] for instance in dataset],
+            params=params,
+        )
+        final_results = []
+        for result in results:
+            generated_tokens = result["results"][0]["generated_tokens"]
+            final_results.append(
+                self.get_return_object(generated_tokens, result, return_meta_data)
+            )
+        return final_results
+
+    def get_return_object(self, predict_result, result, return_meta_data):
+        if return_meta_data:
+            return TextGenerationInferenceOutput(
+                prediction=predict_result,
+                input_tokens=result["results"][0]["input_token_count"],
+                output_tokens=result["results"][0]["generated_token_count"],
+                model_name=self.model_name,
+                inference_type=self.label,
+            )
+        return predict_result
 
 
 class HFLlavaInferenceEngine(InferenceEngine, LazyLoadMixin):
@@ -698,6 +1009,9 @@ class HFLlavaInferenceEngine(InferenceEngine, LazyLoadMixin):
        "accelerate": "pip install accelerate",
     }
 
+    def get_engine_id(self):
+        return get_model_and_label_id(self.model_name, "hf_lava")
+
     def _prepare_engine(self):
        import torch
        from transformers import AutoProcessor, LlavaForConditionalGeneration
@@ -725,14 +1039,18 @@ class HFLlavaInferenceEngine(InferenceEngine, LazyLoadMixin):
     def _is_loaded(self):
         return hasattr(self, "model") and self.model is not None
 
-    def _infer(
+    def _infer(
+        self,
+        dataset: Union[List[Dict[str, Any]], DatasetDict],
+        return_meta_data: bool = False,
+    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
         if not self._is_loaded():
             self._prepare_engine()
 
         import torch
 
         results = []
-        for instance in dataset:
+        for instance in tqdm(dataset):
             text = instance["source"]
             images = extract_images(text, instance)
             # Regular expression to match all <img src="..."> tags
@@ -745,7 +1063,10 @@ class HFLlavaInferenceEngine(InferenceEngine, LazyLoadMixin):
             ).to(self.device, torch.float16)
             input_len = len(inputs["input_ids"][0])
             output = self.model.generate(
-                **inputs,
+                **inputs,
+                max_new_tokens=self.max_new_tokens,
+                do_sample=False,
+                pad_token_id=self.processor.tokenizer.eos_token_id,
             )
             result = self.processor.decode(
                 output[0][input_len:], skip_special_tokens=True
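A hedged usage sketch of the new metadata path: engines that define get_return_object can wrap each prediction in a TextGenerationInferenceOutput. The model name and the hand-built instance below are illustrative (real instances normally come from produce()/load_dataset()), and OPENAI_API_KEY is assumed to be set.

    from unitxt.inference import OpenAiInferenceEngine

    engine = OpenAiInferenceEngine(model_name="gpt-4o")  # illustrative model name
    dataset = [
        {"source": "Translate to French: cheese", "data_classification_policy": ["public"]}
    ]

    outputs = engine.infer(dataset, return_meta_data=True)
    for output in outputs:
        # TextGenerationInferenceOutput carries prediction, input_tokens,
        # output_tokens, model_name and inference_type.
        print(output.prediction, output.input_tokens, output.output_tokens, output.inference_type)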
llm_as_judge.py
CHANGED
@@ -1,10 +1,11 @@
|
|
|
|
1 |
from typing import Any, Dict, List, Literal, Optional
|
2 |
|
3 |
from .api import infer
|
4 |
from .artifact import fetch_artifact
|
5 |
from .dataclass import Field
|
6 |
from .formats import Format, SystemFormat
|
7 |
-
from .inference import InferenceEngine, OpenAiInferenceEngine
|
8 |
from .metrics import BulkInstanceMetric
|
9 |
from .operator import SequentialOperator
|
10 |
from .settings_utils import get_settings
|
@@ -14,38 +15,142 @@ from .templates import Template
|
|
14 |
settings = get_settings()
|
15 |
|
16 |
|
17 |
-
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
Attributes:
|
21 |
main_score (str): The main score label used for evaluation.
|
22 |
-
task (
|
23 |
format of the judge model.
|
24 |
template (Template): The template used when generating inputs for the judge llm.
|
25 |
format (Format): The format used when generating inputs for judge llm.
|
26 |
system_prompt (SystemPrompt): The system prompt used when generating inputs for judge llm.
|
27 |
-
strip_system_prompt_and_format_from_inputs (bool): Whether to strip the system prompt and formatting from the
|
28 |
-
inputs that the models that is being judges received, when they are inserted to the llm-as-judge prompt.
|
29 |
inference_model (InferenceEngine): The module that creates the inference of the judge llm.
|
30 |
reduction_map (dict): A dictionary specifying the reduction method for the metric.
|
31 |
batch_size (int): The size of the bulk.
|
32 |
"""
|
33 |
|
34 |
main_score: str = "llm_as_judge"
|
35 |
-
task:
|
36 |
-
"rating.single_turn",
|
37 |
-
"rating.single_turn_with_reference",
|
38 |
-
"pairwise_comparative_rating.single_turn",
|
39 |
-
]
|
40 |
template: Template
|
41 |
system_prompt: SystemPrompt = Field(default_factory=EmptySystemPrompt)
|
42 |
format: Format = Field(default_factory=SystemFormat)
|
43 |
-
strip_system_prompt_and_format_from_inputs: bool = True
|
44 |
inference_model: InferenceEngine
|
45 |
reduction_map: Optional[Dict[str, List[str]]] = None
|
46 |
batch_size: int = 32
|
47 |
prediction_type = Any # Because handled with multiple tasks
|
48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
def _get_input_instances(self, task_data: List[Dict]) -> List:
|
50 |
if self.strip_system_prompt_and_format_from_inputs:
|
51 |
instances = []
|
@@ -119,6 +224,7 @@ class LLMAsJudge(BulkInstanceMetric):
|
|
119 |
self.reduction_map = {"mean": [self.main_score]}
|
120 |
|
121 |
def verify(self):
|
|
|
122 |
supported_tasks = [
|
123 |
"rating.single_turn",
|
124 |
"rating.single_turn_with_reference",
|
@@ -129,68 +235,25 @@ class LLMAsJudge(BulkInstanceMetric):
|
|
129 |
f"The supported tasks types are: {', '.join(supported_tasks)}."
|
130 |
)
|
131 |
|
132 |
-
|
133 |
-
|
134 |
-
f"Provided template argument to 'LLMAsJudge' metric is not of type Template, but {type(self.template)}"
|
135 |
-
)
|
136 |
-
if self.format and not isinstance(self.format, Format):
|
137 |
-
raise ValueError(
|
138 |
-
f"Provided format argument to 'LLMAsJudge' metric is not of type Format, but {type(self.format)}"
|
139 |
-
)
|
140 |
-
|
141 |
-
if self.system_prompt and not isinstance(self.system_prompt, SystemPrompt):
|
142 |
-
raise ValueError(
|
143 |
-
f"Provided system_prompt argument to 'LLMAsJudge' metric is not of type SystemPrompt, but {type(self.system_prompt)}"
|
144 |
-
)
|
145 |
-
|
146 |
-
if isinstance(self.inference_model, OpenAiInferenceEngine):
|
147 |
-
if self.format and type(self.format) is not SystemFormat:
|
148 |
-
raise ValueError(
|
149 |
-
"Error in 'LLMAsJudge' metric. Inference model 'OpenAiInferenceEngine' does "
|
150 |
-
"not support formatting. Please remove the format definition from the recipe"
|
151 |
-
" (OpenAi Chat API take care of the formatting automatically)."
|
152 |
-
)
|
153 |
-
if self.system_prompt and type(self.system_prompt) is not EmptySystemPrompt:
|
154 |
-
raise ValueError(
|
155 |
-
"Error in 'LLMAsJudge' metric. Inference model 'OpenAiInferenceEngine' does "
|
156 |
-
"not support system prompt. Please remove the system_prompt definition from the recipe"
|
157 |
-
" (Current implementation of Unitxt does not support this."
|
158 |
-
" Support will be added in future updates)."
|
159 |
-
)
|
160 |
|
161 |
-
def
|
162 |
-
|
163 |
-
references: List[List[Any]],
|
164 |
-
predictions: List[Any],
|
165 |
-
task_data: List[Dict],
|
166 |
-
) -> List[Dict[str, Any]]:
|
167 |
-
input_instances = self._get_input_instances(task_data)
|
168 |
-
instances = self._get_instance_for_judge_model(
|
169 |
-
input_instances, predictions, references
|
170 |
-
)
|
171 |
-
outputs = infer(
|
172 |
instances,
|
173 |
engine=self.inference_model,
|
174 |
-
task=
|
175 |
template=self.template,
|
176 |
system_prompt=self.system_prompt,
|
177 |
format=self.format,
|
178 |
return_data=True,
|
179 |
)
|
180 |
|
|
|
181 |
results = []
|
182 |
for instance in outputs:
|
183 |
if self.task == "pairwise_comparative_rating.single_turn":
|
184 |
-
|
185 |
-
|
186 |
-
# seems like the task data sometimes comes as a string, not a dict
|
187 |
-
# this fixes it
|
188 |
-
task_data = (
|
189 |
-
json.loads(instance["task_data"])
|
190 |
-
if isinstance(instance["task_data"], str)
|
191 |
-
else instance["task_data"]
|
192 |
-
)
|
193 |
-
|
194 |
is_model_b_the_baseline = task_data["model_b"] == "baseline_model"
|
195 |
if is_model_b_the_baseline:
|
196 |
model_a_preference_score = instance["prediction"]
|
@@ -209,5 +272,141 @@ class LLMAsJudge(BulkInstanceMetric):
|
|
209 |
"judge_raw_input": instance["source"],
|
210 |
}
|
211 |
results.append(result)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
212 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
213 |
return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+from abc import abstractmethod
 from typing import Any, Dict, List, Literal, Optional

 from .api import infer
 from .artifact import fetch_artifact
 from .dataclass import Field
 from .formats import Format, SystemFormat
+from .inference import InferenceEngine, LogProbInferenceEngine, OpenAiInferenceEngine
 from .metrics import BulkInstanceMetric
 from .operator import SequentialOperator
 from .settings_utils import get_settings

 settings = get_settings()


+def get_task_data_dict(task_data):
+    import json
+
+    # seems like the task data sometimes comes as a string, not a dict
+    # this fixes it
+    return json.loads(task_data) if isinstance(task_data, str) else task_data
+
+
+class LLMAsJudgeBase(BulkInstanceMetric):
+    """LLM-as-judge-base metric class for evaluating correctness of generated predictions.

     Attributes:
         main_score (str): The main score label used for evaluation.
+        task (str): The type of task the llm as judge runs. This defines the output and input
             format of the judge model.
         template (Template): The template used when generating inputs for the judge llm.
         format (Format): The format used when generating inputs for judge llm.
         system_prompt (SystemPrompt): The system prompt used when generating inputs for judge llm.
         inference_model (InferenceEngine): The module that creates the inference of the judge llm.
         reduction_map (dict): A dictionary specifying the reduction method for the metric.
         batch_size (int): The size of the bulk.
     """

     main_score: str = "llm_as_judge"
+    task: str
     template: Template
     system_prompt: SystemPrompt = Field(default_factory=EmptySystemPrompt)
     format: Format = Field(default_factory=SystemFormat)
     inference_model: InferenceEngine
     reduction_map: Optional[Dict[str, List[str]]] = None
     batch_size: int = 32
     prediction_type = Any  # Because handled with multiple tasks

+    def verify(self):
+        if not isinstance(self.template, Template):
+            raise ValueError(
+                f"Provided template argument to 'LLMAsJudge' metric is not of type Template, but {type(self.template)}"
+            )
+        if self.format and not isinstance(self.format, Format):
+            raise ValueError(
+                f"Provided format argument to 'LLMAsJudge' metric is not of type Format, but {type(self.format)}"
+            )
+
+        if self.system_prompt and not isinstance(self.system_prompt, SystemPrompt):
+            raise ValueError(
+                f"Provided system_prompt argument to 'LLMAsJudge' metric is not of type SystemPrompt, but {type(self.system_prompt)}"
+            )
+
+        if isinstance(self.inference_model, OpenAiInferenceEngine):
+            if self.format and type(self.format) is not SystemFormat:
+                raise ValueError(
+                    "Error in 'LLMAsJudge' metric. Inference model 'OpenAiInferenceEngine' does "
+                    "not support formatting. Please remove the format definition from the recipe"
+                    " (OpenAi Chat API take care of the formatting automatically)."
+                )
+            if self.system_prompt and type(self.system_prompt) is not EmptySystemPrompt:
+                raise ValueError(
+                    "Error in 'LLMAsJudge' metric. Inference model 'OpenAiInferenceEngine' does "
+                    "not support system prompt. Please remove the system_prompt definition from the recipe"
+                    " (Current implementation of Unitxt does not support this."
+                    " Support will be added in future updates)."
+                )
+
+    @abstractmethod
+    def get_full_task_name(self):
+        pass
+
+    def compute(
+        self,
+        references: List[List[Any]],
+        predictions: List[Any],
+        task_data: List[Dict],
+    ) -> List[Dict[str, Any]]:
+        instances = self.prepare_instances(references, predictions, task_data)
+        outputs = self.infer_instances(instances)
+        return self.get_metric_results_from_prediction_outputs(outputs)
+
+    @abstractmethod
+    def prepare_instances(
+        self, references, predictions, task_data
+    ) -> List[Dict[str, Any]]:
+        """Generate a list of instances for inference.
+
+        Each generated instance should include all the fields required by the metrics' task and template, to
+        create the source prompt for the judge.
+        """
+        pass
+
+    @abstractmethod
+    def infer_instances(self, instances: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        """Generate the dataset and call the inference engine to generate the judges' predictions.
+
+        Return the list of the produced instances with their generated judge predictions.
+        """
+        pass
+
+    @abstractmethod
+    def get_metric_results_from_prediction_outputs(
+        self, outputs: List[Dict[str, Any]]
+    ) -> List[Dict[str, Any]]:
+        """Generate a scores' dictionary for each instance.
+
+        Return the list of scores dictionaries for the input instances.
+        """
+        pass
+
+
+class LLMAsJudge(LLMAsJudgeBase):
+    """LLM-as-judge-based metric class for evaluating correctness of generated predictions.
+
+    This class uses the source prompt given to the generator and the generator's predictions to evaluate
+    correctness using one of three supported tasks (rating.single_turn, rating.single_turn_with_reference,
+    pairwise_comparative_rating.single_turn).
+
+    Attributes:
+        main_score (str): The main score label used for evaluation.
+        task (Literal["rating.single_turn","rating.single_turn_with_reference",
+            "pairwise_comparative_rating.single_turn"]): The type of task the llm as judge runs.
+            This defines the output and input format of the judge model.
+        template (Template): The template used when generating inputs for the judge llm.
+        format (Format): The format used when generating inputs for judge llm.
+        system_prompt (SystemPrompt): The system prompt used when generating inputs for judge llm.
+        strip_system_prompt_and_format_from_inputs (bool): Whether to strip the system prompt and formatting from the
+            inputs that the models that is being judges received, when they are inserted to the llm-as-judge prompt.
+        inference_model (InferenceEngine): The module that creates the inference of the judge llm.
+        reduction_map (dict): A dictionary specifying the reduction method for the metric.
+        batch_size (int): The size of the bulk.
+    """
+
+    task: Literal[
+        "rating.single_turn",
+        "rating.single_turn_with_reference",
+        "pairwise_comparative_rating.single_turn",
+    ]
+    strip_system_prompt_and_format_from_inputs: bool = True
+
     def _get_input_instances(self, task_data: List[Dict]) -> List:
         if self.strip_system_prompt_and_format_from_inputs:
             instances = []

         self.reduction_map = {"mean": [self.main_score]}

     def verify(self):
+        super().verify()
         supported_tasks = [
             "rating.single_turn",
             "rating.single_turn_with_reference",

             f"The supported tasks types are: {', '.join(supported_tasks)}."
         )

+    def get_full_task_name(self):
+        return f"tasks.response_assessment.{self.task}"

+    def infer_instances(self, instances):
+        return infer(
             instances,
             engine=self.inference_model,
+            task=self.get_full_task_name(),
             template=self.template,
             system_prompt=self.system_prompt,
             format=self.format,
             return_data=True,
         )

+    def get_metric_results_from_prediction_outputs(self, outputs):
         results = []
         for instance in outputs:
             if self.task == "pairwise_comparative_rating.single_turn":
+                task_data = get_task_data_dict(instance["task_data"])
                 is_model_b_the_baseline = task_data["model_b"] == "baseline_model"
                 if is_model_b_the_baseline:
                     model_a_preference_score = instance["prediction"]

                 "judge_raw_input": instance["source"],
             }
             results.append(result)
+        return results
+
+    def prepare_instances(self, references, predictions, task_data):
+        input_instances = self._get_input_instances(task_data)
+        return self._get_instance_for_judge_model(
+            input_instances, predictions, references
+        )
+
+
+class TaskBasedLLMasJudge(LLMAsJudgeBase):
+    """LLM-as-judge-based metric class for evaluating correctness of generated predictions.
+
+    This class can use any task and matching template to evaluate the predictions. All
+    task/templates field are taken from the instance's task_data.
+    The instances sent to the judge can either be: 1.a unitxt dataset, in which case the predictions are
+    copied to a specified field of the task. 2. dictionaries with the fields required by the task and template.
+
+    Attributes:
+        main_score (str): The main score label used for evaluation.
+        task (str): The type of task the llm as judge runs.
+            This defines the output and input format of the judge model.
+        template (Template): The template used when generating inputs for the judge llm.
+        format (Format): The format used when generating inputs for judge llm.
+        system_prompt (SystemPrompt): The system prompt used when generating inputs for judge llm.
+        strip_system_prompt_and_format_from_inputs (bool): Whether to strip the system prompt and formatting from the
+            inputs that the models that is being judges received, when they are inserted to the llm-as-judge prompt.
+        inference_model (InferenceEngine): The module that creates the inference of the judge llm.
+        reduction_map (dict): A dictionary specifying the reduction method for the metric.
+        batch_size (int): The size of the bulk.
+        infer_log_probs(bool): whether to perform the inference using logprobs. If true, the template's
+            post-processing must support the logprobs output.
+        judge_to_generator_fields_mapping (Dict[str, str]): optional mapping between the names of the fields in the generator task and the
+            judge task. For example, if the generator task uses "reference_answers" and the judge task expect "ground_truth",
+            include {"ground_truth": "reference_answers"} in this dictionary.
+        prediction_field: if indicated, and prediction exist, copy prediction to this field name in task_data.
+        include_meta_data (bool): whether to include the inference per-instance metadata in the returned results.
+
+    """
+
+    infer_log_probs: bool = False
+    judge_to_generator_fields_mapping: Dict[str, str] = {}
+    prediction_field: Optional[str] = None
+    include_meta_data: bool = True
+
+    # Allow for input which is a dictionary of all input fields. In this case, all input fields are
+    # treated as the task data, and the predictions and references are taken directly from there
+    # by the judge's template
+    def preprocess_instance(self, instance):
+        if "task_data" not in instance:
+            instance["task_data"] = instance.copy()
+        if "prediction" not in instance:
+            instance["prediction"] = None
+        if "references" not in instance:
+            instance["references"] = [""]
+        return instance
+
+    def verify(self):
+        super().verify()
+        if self.infer_log_probs and not isinstance(
+            self.inference_model, LogProbInferenceEngine
+        ):
+            raise NotImplementedError(
+                f"Error in TaskBasedLLMasJudge: return_log_probs set to True but supplied engine "
+                f"{self.inference_model.__class__.__name__} does not support logprobs."
+            )
+        if self.include_meta_data and not hasattr(
+            self.inference_model, "get_return_object"
+        ):
+            Warning(
+                f"Supplied inference engine {self.inference_model.__class__.__name__} does not support "
+                "return_meta_data. Setting return_meta_data to False. Metadata scores will not appear "
+                "in returned instances scores."
+            )
+            self.include_meta_data = False
+
+    def prepare(self):
+        super().prepare()
+        self.reduction_map = {"mean": [self.main_score]}
+        self.score_prefix = f"{self.inference_model.get_engine_id()}_"
+
+    def get_full_task_name(self):
+        return self.task
+
+    def get_metric_results_from_prediction_outputs(self, outputs):
+        results = []
+        for instance in outputs:
+            result = {
+                self.main_score: instance["prediction"],
+                f"{self.main_score}_judge_raw_output": instance["raw_prediction"],
+                f"{self.main_score}_judge_raw_input": instance["source"],
+            }
+            if self.include_meta_data:
+                meta_data = {
+                    f"{self.main_score}_{k}": v
+                    for k, v in instance["infer_meta_data"].items()
+                }
+                result.update(meta_data)
+            results.append(result)
         return results
+
+    def prepare_instances(self, references, predictions, task_data):
+        from . import get_from_catalog
+
+        instances = []
+        judge_task = get_from_catalog(self.get_full_task_name())
+        judge_task_input_fields = judge_task.input_fields
+
+        for input_instance, prediction, _ in zip(task_data, predictions, references):
+            input_instance = get_task_data_dict(input_instance)
+
+            instance_task_data = {}
+            for judge_task_input_field in judge_task_input_fields:
+                orig_task_field_name = self.judge_to_generator_fields_mapping.get(
+                    judge_task_input_field, judge_task_input_field
+                )
+                new_val = input_instance.get(orig_task_field_name)
+                if new_val:
+                    instance_task_data[judge_task_input_field] = new_val
+
+            if self.prediction_field and prediction:
+                instance_task_data[self.prediction_field] = str(prediction)
+            instance_task_data = judge_task.process(instance_task_data)["input_fields"]
+            instances.append(instance_task_data)
+
+        return instances
+
+    def infer_instances(self, instances):
+        return infer(
+            instances,
+            engine=self.inference_model,
+            task=self.get_full_task_name(),
+            template=self.template,
+            system_prompt=self.system_prompt,
+            format=self.format,
+            return_data=True,
+            return_log_probs=self.infer_log_probs,
+            return_meta_data=self.include_meta_data,
+        )
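For orientation, a minimal configuration sketch for the two judge metrics above. This is not part of the diff: the import path, catalog ids, and placeholder objects are illustrative assumptions only, not verified names.

# Hedged sketch only -- every id below is an assumption; substitute real catalog artifacts.
from unitxt.llm_as_judge import LLMAsJudge, TaskBasedLLMasJudge  # assumed module path

some_rating_template = None   # placeholder: supply a real Template artifact
some_judge_template = None    # placeholder: supply a real Template artifact
some_inference_engine = None  # placeholder: supply an InferenceEngine instance
some_logprob_engine = None    # placeholder: supply a LogProbInferenceEngine instance

generic_judge = LLMAsJudge(
    task="rating.single_turn",                  # one of the three Literal values defined above
    template=some_rating_template,
    inference_model=some_inference_engine,
    main_score="llm_judge_rating",
)

task_based_judge = TaskBasedLLMasJudge(
    task="tasks.my_judge_task",                 # hypothetical judge task id in the local catalog
    template=some_judge_template,
    inference_model=some_logprob_engine,
    main_score="judge_correctness",
    judge_to_generator_fields_mapping={"ground_truth": "reference_answers"},
    prediction_field="answer",                  # the generator prediction is copied into this task_data field
    infer_log_probs=True,                       # requires a LogProbInferenceEngine (enforced in verify())
)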
loaders.py
CHANGED
@@ -53,7 +53,7 @@ from .operators import Set
-from .utils import
@@ -195,6 +195,10 @@ class LoadHF(Loader):
@@ -203,7 +207,7 @@ class LoadHF(Loader):
-                        cache_dir=
@@ -231,6 +235,10 @@ class LoadHF(Loader):
@@ -239,7 +247,7 @@ class LoadHF(Loader):
-                        cache_dir=
@@ -664,7 +672,7 @@ class MultipleSourceLoader(Loader):
-        MultipleSourceLoader(
@@ -672,7 +680,7 @@ class MultipleSourceLoader(Loader):
-        MultipleSourceLoader(
@@ -737,7 +745,7 @@ class LoadFromDictionary(Loader):
-        return MultiStream.from_iterables(

 from .settings_utils import get_settings
 from .stream import DynamicStream, MultiStream
 from .type_utils import isoftype
+from .utils import recursive_copy

 logger = get_logger()
 settings = get_settings()

     def stream_dataset(self):
         if self._cache is None:
             with tempfile.TemporaryDirectory() as dir_to_be_deleted:
+                if settings.disable_hf_datasets_cache and not self.streaming:
+                    cache_dir = dir_to_be_deleted
+                else:
+                    cache_dir = None
                 try:
                     dataset = hf_load_dataset(
                         self.path,

                         data_files=self.data_files,
                         revision=self.revision,
                         streaming=self.streaming,
+                        cache_dir=cache_dir,
                         split=self.split,
                         trust_remote_code=settings.allow_unverified_code,
                         num_proc=self.num_proc,

     def load_dataset(self):
         if self._cache is None:
             with tempfile.TemporaryDirectory() as dir_to_be_deleted:
+                if settings.disable_hf_datasets_cache:
+                    cache_dir = dir_to_be_deleted
+                else:
+                    cache_dir = None
                 try:
                     dataset = hf_load_dataset(
                         self.path,

                         data_files=self.data_files,
                         streaming=False,
                         keep_in_memory=True,
+                        cache_dir=cache_dir,
                         split=self.split,
                         trust_remote_code=settings.allow_unverified_code,
                         num_proc=self.num_proc,

     .. code-block:: python

+        MultipleSourceLoader(sources = [ LoadHF(path="public/data", split="train"), LoadCSV({"test": "mytest.csv"}) ])


     .. code-block:: python

+        MultipleSourceLoader(sources = [ LoadCSV({"test": "mytest1.csv"}), LoadCSV({"test": "mytest2.csv"}) ])
     """

     sources: List[Loader]

         self.sef_default_data_classification(
             ["proprietary"], "when loading from python dictionary"
         )
+        return MultiStream.from_iterables(recursive_copy(self.data))


 class LoadFromHFSpace(LoadHF):
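The cache_dir handling added above can be summarized with a small, self-contained sketch. This is not from the diff; the helper function name is made up for illustration only.

# Hedged sketch of the cache_dir selection added to LoadHF above.
import tempfile

def pick_cache_dir(disable_hf_datasets_cache: bool, streaming: bool, tmp_dir: str):
    # When the datasets cache is disabled (and we are not streaming), route the
    # HF datasets cache into a throwaway directory so nothing persists after loading.
    if disable_hf_datasets_cache and not streaming:
        return tmp_dir
    return None  # None lets huggingface datasets fall back to its default cache

with tempfile.TemporaryDirectory() as tmp_dir:
    assert pick_cache_dir(True, False, tmp_dir) == tmp_dir
    assert pick_cache_dir(False, False, tmp_dir) is None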
metric_utils.py
CHANGED
@@ -16,8 +16,8 @@ from .operator import (
-    Copy,
@@ -25,7 +25,7 @@ from .schema import UNITXT_DATASET_SCHEMA
-from .utils import
@@ -54,27 +54,27 @@ class FromPredictionsAndOriginalData(StreamInitializerOperator):
@@ -213,14 +213,19 @@ class JoinSubsetsAndGroups(MultiStreamOperator):
@@ -237,11 +242,15 @@ class JoinSubsetsAndGroups(MultiStreamOperator):
-            instance["score"].update(
@@ -299,7 +308,7 @@ class MetricRecipe(SequentialOperatorInitializer):

 from .operators import (
     ApplyMetric,
     ApplyOperatorsField,
     FlattenInstances,
+    RecursiveCopy,
     Rename,
 )
 from .register import _reset_env_local_catalogs, register_all_artifacts

 from .settings_utils import get_constants, get_settings
 from .stream import DynamicStream, MultiStream
 from .struct_data_operators import LoadJson
+from .utils import recursive_shallow_copy

 constants = get_constants()

 _post_process_steps = SequentialOperator(
     steps=[
+        RecursiveCopy(
             field="prediction",
             to_field="raw_prediction",
         ),
+        RecursiveCopy(
             field="references",
             to_field="raw_references",
             dont_apply_to_streams=[constants.inference_stream],
         ),
+        RecursiveCopy(
             field="source",
             to_field="task_data/source",
         ),
         ApplyOperatorsField(
             operators_field="postprocessors",
         ),
+        RecursiveCopy(
             field="prediction",
             to_field="processed_prediction",
         ),
+        RecursiveCopy(
             field="references",
             to_field="processed_references",
             dont_apply_to_streams=[constants.inference_stream],

         result = {}
         all_scores = []
+        all_num_of_instances = []
         for k, v in dic.items():
             score = recursive_mean(v)
             if score is not None:
                 all_scores.append(score["score"])
+                if "num_of_instances" in score:
+                    all_num_of_instances.append(score["num_of_instances"])
                 result[k] = score

         result["score"] = nan_mean(all_scores)
         result["score_name"] = "subsets_mean"
+        if all_num_of_instances:
+            result["num_of_instances"] = sum(all_num_of_instances)

         if result:
             return result

             "score": score["subsets"]["score"],
             "score_name": score["subsets"]["score_name"],
         }
+        if "num_of_instances" in score["subsets"]:
+            score["global"]["num_of_instances"] = score["subsets"][
+                "num_of_instances"
+            ]

         sorted_instances = []
         for key in sorted(stream_instances.keys()):
             instance = stream_instances[key]
+            instance["score"].update(recursive_shallow_copy(score))
             sorted_instances.append(instance)
             result[stream_name] = sorted_instances

             field="raw_references",
             to_field="references",
         ),
+        RecursiveCopy(
             field="source",
             to_field="task_data/source",
         ),
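A small illustration of the subset aggregation behaviour added above (not from the diff; plain Python stands in for nan_mean): subset scores are averaged into the global score while their num_of_instances are summed.

# Hedged sketch of how the new num_of_instances field is aggregated across subsets.
from statistics import mean

subset_scores = {
    "qa": {"score": 0.8, "num_of_instances": 100},
    "rag": {"score": 0.6, "num_of_instances": 50},
}

global_score = {
    "score": mean(s["score"] for s in subset_scores.values()),                        # 0.7
    "score_name": "subsets_mean",
    "num_of_instances": sum(s["num_of_instances"] for s in subset_scores.values()),   # 150
}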
metrics.py
CHANGED
@@ -8,10 +8,9 @@ import warnings
-from
-import evaluate
@@ -37,20 +36,18 @@ from .operator import (
-from .operators import Copy
-from .utils import
-warnings.filterwarnings("ignore", category=DegenerateDataWarning)
-
@@ -139,6 +136,7 @@ class Metric(Artifact):
@@ -147,18 +145,24 @@ class Metric(Artifact):
-        if new_score_name in ["score", "score_name"]
-            f"To avoid overwriting the existing value, add a score_prefix to the metric (e.g. score_prefix='my_second_'
@@ -279,7 +283,12 @@ class Metric(Artifact):
-            if score_name in [
@@ -469,11 +478,17 @@ class MetricWithConfidenceInterval(Metric):
@@ -538,7 +553,6 @@ class GlobalMetric(StreamOperator, MetricWithConfidenceInterval):
-        global_score = {}
@@ -589,6 +603,7 @@ class GlobalMetric(StreamOperator, MetricWithConfidenceInterval):
@@ -596,11 +611,18 @@ class GlobalMetric(StreamOperator, MetricWithConfidenceInterval):
@@ -649,28 +671,24 @@ class BulkInstanceMetric(StreamOperator, MetricWithConfidenceInterval):
-        global_score = {}
-        references
-            list,
-            zip(
-                *[
-                    itemgetter("references", "prediction")(
-                        self.verify_instance(instance)
-                    )
-                    for instance in stream
-                ]
-            ),
-        )
-            for instance in
@@ -683,7 +701,7 @@ class BulkInstanceMetric(StreamOperator, MetricWithConfidenceInterval):
-        for instance, score in zip(
@@ -692,7 +710,6 @@ class BulkInstanceMetric(StreamOperator, MetricWithConfidenceInterval):
@@ -1059,7 +1076,7 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
-        global_score = {}
@@ -1096,7 +1113,10 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
-                instances,
@@ -1171,13 +1191,16 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
@@ -1187,7 +1210,9 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
-        prepend_score_prefix: bool
@@ -1199,6 +1224,8 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
@@ -1233,8 +1260,27 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
@@ -1255,9 +1301,25 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
-        self,
@@ -1265,7 +1327,12 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
-            instances,
@@ -1277,7 +1344,12 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
-            instances,
@@ -1315,6 +1387,19 @@ class ANLS(InstanceMetric):
@@ -1324,20 +1409,14 @@ class ANLS(InstanceMetric):
-        for
-        gt_answer = " ".join(answer.strip().lower().split())
-        det_answer = " ".join(prediction.strip().lower().split())
-        # dist = levenshtein_distance(answer.lower(), detObject['answer'].lower())
-        dist = self.levenshtein_distance(gt_answer, det_answer)
-        length = max(len(answer.upper()), len(prediction.upper()))
-        values.append(0.0 if length == 0 else float(dist) / float(length))
@@ -1345,6 +1424,7 @@ class ANLS(InstanceMetric):
@@ -1526,16 +1606,40 @@ class MetricPipeline(MultiStreamOperator, Metric):
-        self.prepare_score =
-                f"score/instance/{self.metric._add_score_prefix(self.main_score)}",
-                "score/instance/score",
-                f"score/global/{self.metric._add_score_prefix(self.main_score)}",
-                "score/global/score",
@@ -1589,6 +1693,8 @@ class HuggingfaceMetric(GlobalMetric):
@@ -1663,6 +1769,8 @@ class HuggingfaceBulkMetric(BulkInstanceMetric):
@@ -1709,6 +1817,8 @@ class HuggingfaceInstanceMetric(InstanceMetric):
@@ -1788,6 +1898,8 @@ class F1(GlobalMetric):
@@ -1847,6 +1959,7 @@ class F1Binary(GlobalMetric):
@@ -2064,6 +2177,8 @@ class F1MultiLabel(GlobalMetric):
@@ -3033,7 +3148,7 @@ class SafetyMetric(GlobalMetric):
-    prediction_type
@@ -3679,6 +3794,7 @@ class RetrievalAtK(RetrievalMetric):
@@ -3725,7 +3841,7 @@ class RemoteMetric(StreamOperator, Metric):
-    metric_pipeline =
@@ -4376,6 +4492,7 @@ class BinaryMaxF1(F1Binary):
@@ -4799,17 +4916,22 @@ class F1Strings(InstanceMetric):
-    def
-        super().prepare()
-        self.
-        self.
@@ -4955,3 +5077,20 @@ class RandomForestMetricsEnsemble(MetricsEnsemble):

 from abc import ABC, abstractmethod
 from collections import Counter, defaultdict
 from dataclasses import field
+from functools import lru_cache
 from typing import Any, Dict, Generator, List, Optional, Tuple, Union

 import numpy
 import numpy as np
 import pandas as pd

     StreamingOperator,
     StreamOperator,
 )
+from .operators import Copy, Set
 from .random_utils import get_seed
 from .settings_utils import get_settings
 from .stream import MultiStream, Stream
 from .type_utils import Type, isoftype, parse_type_string, to_type_string
+from .utils import deep_copy

 logger = get_logger()
 settings = get_settings()

 warnings.filterwarnings("ignore", category=DegenerateDataWarning)


 def abstract_factory():
     return {}

         return (
             self.score_prefix + score_name
             if score_name not in ["score", "score_name"]
+            and not score_name.startswith("num_of_instances")
             else score_name
         )

     ) -> Dict[str, Any]:
         new_scores = {}
         for score_name, score in scores.items():
+            if isinstance(score, dict):
+                new_scores[score_name] = score
+                continue  # do not prefix group names
             score_with_prefix = self._add_score_prefix(score_name)
             new_scores[score_with_prefix] = (
                 score if score_name not in ["score_name"] else self.score_prefix + score
             )
         for new_score_name in new_scores:
+            if new_score_name in ["score", "score_name"] or new_score_name.startswith(
+                "num_of_instances"
+            ):
                 continue
             if new_score_name in existing_scores:
                 UnitxtWarning(
                     message=f"Metric '{new_score_name}' that has just been evaluated to {new_scores[new_score_name]}, is already recorded "
                     f"to have value {existing_scores[new_score_name]} by a previous metric evaluation on this instance or stream. "
+                    f"To avoid overwriting the existing value, add a score_prefix to the metric name (e.g. score_prefix='my_second_' , "
+                    f"which will yield, in this case, a score named: 'my_second_{new_score_name}')",
                     additional_info_id=Documentation.MULTIPLE_METRICS_OUTPUTS,
                 )
         return new_scores

         self, instance: Dict[str, Any], global_score: dict
     ):
         for score_name in global_score:
+            if score_name in [
+                "score",
+                "score_name",
+                "score_ci_low",
+                "score_ci_high",
+            ] or score_name.startswith("num_of_instances"):
                 continue
             if score_name in instance["score"]["global"]:
                 UnitxtWarning(

         # iterate over the rows and compute the metric on each resampling
         def metric(sample_refs, sample_preds, sample_task_data):
             try:
+                results = self._compute(
                     references=sample_refs,
                     predictions=sample_preds,
                     task_data=sample_task_data,
+                )
+                results.update(
+                    self._add_score_prefixes_to_score_dict_and_check_against_existing_scores(
+                        results, {}
+                    )
+                )
+                return results[score_name]
             except Exception as e:
                 # this happens in edge cases, for example, when the sampling creates a
                 # sample where all strings are empty and this fails bleu.

         references = []
         predictions = []
         task_data = []

         instances = []

             )
         )
         self._validate_references_and_prediction(references, predictions)
+        global_score = {"num_of_instances": len(instances)}

         result = self._compute(references, predictions, task_data)
         global_score.update(

                 result, global_score
             )
         )
+        if self.ci_scores:
+            score_names = [
+                self._add_score_prefix(score_name) for score_name in self.ci_scores
+            ]
+        else:
+            score_names = [global_score["score_name"]]
+
+        for score_name in score_names:
+            confidence_interval = self.compute_global_confidence_intervals(
+                references, predictions, task_data, score_name
+            )
+            global_score.update(confidence_interval)

         for instance in instances:
             self.update_and_adjust_global_score(instance, global_score)

         default_factory=lambda: ["mean", "weighted_win_rate"]
     )

+    def preprocess_instance(self, instance):
+        return instance
+
     def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
         instances = []
+        for instance in stream:
+            self.verify_instance(instance)
+            instance = self.preprocess_instance(instance)
+            instances.append(instance)

+        predictions = [instance["prediction"] for instance in instances]
+        references = [instance["references"] for instance in instances]
         task_data = [
             instance["task_data"] if "task_data" in instance else {}
+            for instance in instances
         ]
         self._validate_references_and_prediction(references, predictions)
+        global_score = {"num_of_instances": len(instances)}
         # compute the metric over all refs and preds
         instance_scores = self.compute(
             references=references,

             instance_score["score"] = instance_score[self.main_score]
             instance_score["score_name"] = self.main_score

+        for instance, score in zip(instances, instance_scores):
             if "score" not in instance:
                 instance["score"] = {"global": {}, "instance": {}}

                     score, instance["score"]["instance"]
                 )
             )

         for reduction, fields in self.reduction_map.items():
             assert (

     def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
         instances = self.compute_instance_scores(stream)
+        global_score = {"num_of_instances": len(instances)}
         for reduction_type, reduction_params in self.reduction_map.items():
             assert (
                 reduction_type in self.implemented_reductions

                     scores_to_resample,
                     aggregation_function,
                 ) = self._set_up_group_mean_aggregation(
+                    instances,
+                    reduction_params,
+                    reduction_fields,
+                    global_score,
                 )
             else:
                 raise ValueError(

             instance_score["score_name"] = self.main_score
             if "score" not in instance:
                 instance["score"] = {"global": {}, "instance": {}}
+            if "global" not in instance["score"]:
+                instance["score"]["global"] = {}
+            if "instance" not in instance["score"]:
+                instance["score"]["instance"] = {}

             instance["score"]["instance"].update(
                 self._add_score_prefixes_to_score_dict_and_check_against_existing_scores(
                     instance_score, instance["score"]["instance"]
                 )
             )
             instances.append(instance)

         return instances

         instances: List[dict],
         score_names: List[str],
         group_aggregation_func,
+        prepend_score_prefix: bool,
+        global_score: dict,
+        aggregation_function_name: str,
     ):
         """Group scores by the group_id and subgroup_type fields of each instance, and compute group_aggregation_func by group.

             callable function returns a single score for the group
         prepend_score_prefix: if True - prepend the score_prefix to the score names in the returned dicts. Set to False
             if down the stream such a prepending is expected.
+        global_score: the being built up global score. It will be filled here with number of instances per each group, and group scores.
+        aggregation_function_name: used to annotate the groups' global scores.

         Returns:
             List of dicts, each corresponding to a group of instances (defined by 'group_id'),

             ]
         )

+        # count the instances in each group and subgroup.
+        # Each instance goes into group_to_instances per each score_name.
+        # So we count over the first score_name only
+        for group_key in group_to_instance_scores:
+            if group_key not in global_score:
+                global_score[group_key] = {}
+            global_score[group_key]["num_of_instances"] = sum(
+                [
+                    len(
+                        group_to_instance_scores[group_key][score_names[0]][
+                            subgroup_type
+                        ]
+                    )
+                    for subgroup_type in group_to_instance_scores[group_key][
+                        score_names[0]
+                    ]
+                ]
+            )
+
         # if group_aggregation_func expects a subgroup-types score dict, pass it; otherwise pass the default type list of scores
+        to_return = [
             {
                 "score": {
                     "instance": {

             )  # sorted for consistency
         ]

+        # update each group section in global_score
+        for i, group_name in enumerate(sorted(group_to_instance_scores.keys())):
+            global_score[group_name].update(
+                {
+                    aggregation_function_name + "_" + k: v
+                    for k, v in to_return[i]["score"]["instance"].items()
+                }
+            )
+
+        return to_return
+
     def _set_up_group_mean_aggregation(
+        self,
+        instances,
+        reduction_params,
+        reduction_fields,
+        global_score,
     ):
+        aggregation_function_name = str(reduction_params["agg_func"][0])
         group_aggregation_func = reduction_params["agg_func"][1]
         # if treat groups as units
         do_resample_as_group = reduction_params["agg_func"][2]

             # pass the group aggregate---not instance---scores to resample as usual
             aggregation_function = self.average_item_scores
             scores_to_resample = self.get_group_scores(
+                instances=instances,
+                score_names=reduction_fields,
+                group_aggregation_func=group_aggregation_func,
+                prepend_score_prefix=True,
+                global_score=global_score,
+                aggregation_function_name=aggregation_function_name,
             )
         else:
             # pass the instance scores to resample, and calculate the group aggregation on the resamplings

                 group_aggregation_func=group_aggregation_func,
             ):
                 group_scores = self.get_group_scores(
+                    instances=instances,
+                    score_names=[field_name],
+                    group_aggregation_func=group_aggregation_func,
+                    prepend_score_prefix=False,
+                    global_score=global_score,
+                    aggregation_function_name=aggregation_function_name,
                 )
                 return nan_mean(
                     [group["score"]["instance"][field_name] for group in group_scores]

     reduction_map = {"mean": ["anls"]}
     prediction_type = Any  # string representation is compared

+    @staticmethod
+    @lru_cache(maxsize=10000)
+    def preprocess_text(text):
+        return " ".join(text.strip().lower().split()), len(text.upper())
+
+    def distance(self, prediction, reference):
+        processed_reference, len_reference = self.preprocess_text(reference)
+        processed_prediction, len_prediction = self.preprocess_text(prediction)
+
+        dist = self.levenshtein_distance(processed_reference, processed_prediction)
+        length = max(len_reference, len_prediction)
+        return 0.0 if length == 0 else float(dist) / float(length)
+
     def compute(
         self,
         references: List[Any],

     ) -> dict:
         """ANLS image-text accuracy metric."""
         values = []
+        for reference in references:
+            values.append(self.distance(prediction, reference))

         question_result = 1.0 - min(values)

         if question_result < threshold:
             question_result = 0.0
+
         result = {}
         result["score"] = question_result
         result[self.main_score] = question_result

         return result

     @staticmethod
+    @lru_cache(maxsize=10000)
     def levenshtein_distance(s1, s2):
         if len(s1) > len(s2):
             s1, s2 = s2, s1

         ), "Must define at most one of postpreprocess_steps (which is deprecated) and postprocess_steps (to be used from now on)"
         if has_postpreprocess:
             self.postprocess_steps = self.postpreprocess_steps
+        self.prepare_score = SequentialOperator(
+            steps=[
+                Copy(
+                    field=f"score/instance/{self.metric._add_score_prefix(self.main_score)}",
+                    to_field="score/instance/score",
+                ),
+                Copy(
+                    field=f"score/global/{self.metric._add_score_prefix(self.main_score)}",
+                    to_field="score/global/score",
+                ),
+                Copy(
+                    field=f"score/global/{self.metric._add_score_prefix(self.main_score)}_ci_low",
+                    to_field="score/global/score_ci_low",
+                    not_exist_do_nothing=True,
+                ),
+                Copy(
+                    field=f"score/global/{self.metric._add_score_prefix(self.main_score)}_ci_high",
+                    to_field="score/global/score_ci_high",
+                    not_exist_do_nothing=True,
+                ),
+                Set(
+                    fields={
+                        "score/instance/score_name": self.metric._add_score_prefix(
+                            self.main_score
+                        )
+                    }
+                ),
+                Set(
+                    fields={
+                        "score/global/score_name": self.metric._add_score_prefix(
+                            self.main_score
+                        )
+                    }
+                ),
             ],
         )

     def prepare(self):
         super().prepare()
+        import evaluate
+
         self.metric = evaluate.load(
             self.hf_metric_name, experiment_id=self.experiment_id
         )

     def prepare(self):
         super().prepare()
+        import evaluate
+
         self.metric = evaluate.load(
             self.hf_metric_name, experiment_id=str(uuid.uuid4())
         )

     def prepare(self):
         super().prepare()
+        import evaluate
+
         self.metric = evaluate.load(
             self.hf_metric_name, experiment_id=str(uuid.uuid4())
         )

     def prepare(self):
         super().prepare()
+        import evaluate
+
         self._metric = evaluate.load(self.metric, experiment_id=str(uuid.uuid4()))

     def get_str_id(self, str):

     _metric = None
     metric = "f1"
     single_reference_per_prediction = True
+    ci_scores = [main_score, "f1_binary_neg"]
     _requirements_list: List[str] = ["sklearn"]

     def prepare(self):

     def prepare(self):
         super().prepare()
+        import evaluate
+
         self._metric = evaluate.load(
             self.metric, "multilabel", experiment_id=str(uuid.uuid4())
         )

 class LlamaIndexLLMMetric(InstanceMetric):
     model_name: str = ""
     main_score: str = ""
+    prediction_type = str
     reduction_map: Dict[str, List[str]] = None
     openai_models: List[str] = ["gpt-3.5-turbo"]
     anthropic_models: List[

         (recall_at_k, "recall"),
         (match_at_k, "match"),
     ]:
+        measure_array[0] = 0.0  # to support cases where the prediction is empty.
         max_k = max(measure_array.keys())
         for k in self.k_list:
             result[self.score_name(measure_name, k)] = measure_array[min(k, max_k)]

     remotely (pre and post processing steps in the MetricPipeline will be computed locally).
     """
     local_inner_metric = metric_pipeline.metric
+    metric_pipeline = deep_copy(
         metric_pipeline
     )  # To avoid unintentional changes to the catalog contents
     metric_pipeline.metric = RemoteMetric(

     main_score = "max_f1_binary"
     single_reference_per_prediction = True
     average = None
+    ci_scores = [main_score, "max_f1_binary_neg"]

     def compute(
         self,

         "spacy": "Please pip install spacy",
     }

+    def load_spacy(self):
         import spacy

+        self.nlp = spacy.load(
+            "en_core_web_sm", disable=["tagger", "parser", "ner", "lemmatizer"]
+        )
+
+    def prepare(self):
+        super().prepare()
         try:
+            self.load_spacy()
         except OSError:
             from spacy.cli import download

             download("en_core_web_sm")
+            self.load_spacy()

     def compute(
         self,

         )
         score = ensemble_model.predict([prediction_lst])
         return score.tolist()[0]
+
+
+class PredictionLength(InstanceMetric):
+    """Returns the length of the prediction."""
+
+    main_score = "prediction_length"
+    reduction_map = {"mean": ["prediction_length"]}
+    prediction_type = str
+    single_reference_per_prediction = True
+
+    def compute(
+        self,
+        references: List[str],
+        prediction: str,
+        task_data: List[Dict],
+    ) -> dict:
+        return {self.main_score: [len(prediction)], "score_name": self.main_score}
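A hedged illustration of what the new PredictionLength metric reports per instance, mirroring the compute() shown above but outside the streaming machinery:

# Sketch only: reproduces the per-instance score dict produced by PredictionLength.compute().
prediction = "The capital of France is Paris."
instance_score = {"prediction_length": [len(prediction)], "score_name": "prediction_length"}
assert instance_score["prediction_length"] == [31]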
operators.py
CHANGED
@@ -39,7 +39,6 @@ General Operators List:
 ------------------------
 """

-import copy
 import operator
 import uuid
 import warnings
@@ -82,14 +81,19 @@ from .operator import (
     StreamOperator,
 )
 from .random_utils import new_random_generator
-from .settings_utils import
-from .stream import DynamicStream, Stream
 from .text_utils import nested_tuple_to_string
 from .type_utils import isoftype
-from .utils import

 settings = get_settings()
-constants = get_constants()


 class FromIterables(StreamInitializerOperator):
@@ -132,8 +136,8 @@ class MapInstanceValues(InstanceOperator):
     it maps values of instances in a stream using predefined mappers.

     Attributes:
-        mappers (Dict[str, Dict[str,
-            Keys are the names of the fields to
             that define the mapping from old values to new values.
         strict (bool): If True, the mapping is applied strictly. That means if a value
             does not exist in the mapper, it will raise a KeyError. If False, values
@@ -203,13 +207,12 @@ class MapInstanceValues(InstanceOperator):

     def get_mapped_value(self, instance, key, mapper, val):
         val_as_str = str(val)  # make sure the value is a string
-        if
             raise KeyError(
                 f"value '{val}' in instance '{instance}' is not found in mapper '{mapper}', associated with field '{key}'."
             )
-        # By default deep copy the value in mapper to avoid shared modifications
-        if val_as_str in mapper:
-            return deepcopy(mapper[val_as_str])
         return val


@@ -269,7 +272,7 @@ class Set(InstanceOperator):
     ) -> Dict[str, Any]:
         for key, value in self.fields.items():
             if self.use_deepcopy:
-                value =
             dict_set(instance, key, value)
         return instance

@@ -318,6 +321,13 @@ class SelectFields(InstanceOperator):
         return new_instance


 class InstanceFieldOperator(InstanceOperator):
     """A general stream instance operator that processes the values of a field (or multiple ones).

@@ -348,6 +358,7 @@ class InstanceFieldOperator(InstanceOperator):
     process_every_value: bool = False
     get_default: Any = None
     not_exist_ok: bool = False

     def verify(self):
         super().verify()
@@ -429,19 +440,18 @@ class InstanceFieldOperator(InstanceOperator):
         self, instance: Dict[str, Any], stream_name: Optional[str] = None
     ) -> Dict[str, Any]:
         self.verify_field_definition()
-        # Need to deep copy instance, because when assigning two dictionary fields,
-        # dict_set() the target field dictionary fields.
-        # This means that if this target field was assigned to another field before,
-        # the field is updated as well.
-        instance = deepcopy(instance)
         for from_field, to_field in self._field_to_field:
             try:
                 old_value = dict_get(
                     instance,
                     from_field,
-                    default=
-                    not_exist_ok=self.not_exist_ok,
                 )
             except Exception as e:
                 raise ValueError(
                     f"Failed to get '{from_field}' from {instance} due to : {e}"
@@ -476,6 +486,13 @@ class FieldOperator(InstanceFieldOperator):
         pass


 class Rename(FieldOperator):
     """Renames fields.

@@ -643,7 +660,9 @@ class ListFieldValues(InstanceOperator):
         values = []
         for field_name in self.fields:
             values.append(dict_get(instance, field_name))
-
         return instance


@@ -680,7 +699,7 @@ class ZipFieldValues(InstanceOperator):
             zipped = zip_longest(*values)
         else:
             zipped = zip(*values)
-        instance
         return instance


@@ -847,14 +866,15 @@ class Copy(FieldOperator):

     """

-    use_deep_copy: bool = True
-
     def process_value(self, value: Any) -> Any:
-        if self.use_deep_copy:
-            return copy.deepcopy(value)
return copy.deepcopy(value)
|
855 |
return value
|
856 |
|
857 |
|
|
|
|
|
|
|
|
|
|
|
858 |
@deprecation(version="2.0.0", alternative=Copy)
|
859 |
class CopyFields(Copy):
|
860 |
pass
|
@@ -1022,7 +1042,7 @@ class ArtifactFetcherMixin:
|
|
1022 |
if artifact_identifier not in cls.cache:
|
1023 |
artifact, artifactory = fetch_artifact(artifact_identifier)
|
1024 |
cls.cache[artifact_identifier] = artifact
|
1025 |
-
return
|
1026 |
|
1027 |
|
1028 |
class ApplyOperatorsField(InstanceOperator):
|
@@ -1602,7 +1622,23 @@ class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
|
|
1602 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
1603 |
from .metrics import Metric
|
1604 |
|
1605 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1606 |
|
1607 |
metric_names = first_instance.get(self.metric_field, [])
|
1608 |
if not metric_names:
|
@@ -1619,16 +1655,6 @@ class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
|
|
1619 |
# by the first listed metric (as desired).
|
1620 |
metric_names = list(reversed(metric_names))
|
1621 |
|
1622 |
-
# Workaround: The metric/MetricPipeline modifies the stream itself, sometimes making it incompatible
|
1623 |
-
# for further metrics' processing, instead of just modifying the score field.
|
1624 |
-
# Here we keep all the fields besides the score, and restore them after the metric finishes.
|
1625 |
-
first_instance = stream.peek()
|
1626 |
-
keys_to_restore = set(first_instance.keys()).difference({"score"})
|
1627 |
-
multi_stream = MultiStream({stream_name: stream})
|
1628 |
-
multi_stream = CopyFields(
|
1629 |
-
field_to_field={k: f"{k}_orig" for k in keys_to_restore}
|
1630 |
-
)(multi_stream)
|
1631 |
-
|
1632 |
for metric_name in metric_names:
|
1633 |
metric = self.get_artifact(metric_name)
|
1634 |
assert isinstance(
|
@@ -1637,17 +1663,23 @@ class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
|
|
1637 |
|
1638 |
if not self.calc_confidence_intervals:
|
1639 |
metric.disable_confidence_interval_calculation()
|
1640 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1641 |
multi_stream = metric(multi_stream)
|
1642 |
-
|
1643 |
-
|
1644 |
-
)
|
|
|
|
|
|
|
1645 |
|
1646 |
-
|
1647 |
-
multi_stream
|
1648 |
-
)
|
1649 |
-
stream = multi_stream[stream_name]
|
1650 |
-
yield from stream
|
1651 |
|
1652 |
|
1653 |
class MergeStreams(MultiStreamOperator):
|
@@ -2066,7 +2098,7 @@ class DuplicateInstances(StreamOperator):
|
|
2066 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
2067 |
for instance in stream:
|
2068 |
for idx in range(self.num_duplications):
|
2069 |
-
duplicate =
|
2070 |
if self.duplication_index_field:
|
2071 |
duplicate.update({self.duplication_index_field: idx})
|
2072 |
yield duplicate
|
|
|
39 |
------------------------
|
40 |
"""
|
41 |
|
|
|
42 |
import operator
|
43 |
import uuid
|
44 |
import warnings
|
|
|
81 |
StreamOperator,
|
82 |
)
|
83 |
from .random_utils import new_random_generator
|
84 |
+
from .settings_utils import get_settings
|
85 |
+
from .stream import DynamicStream, ListStream, Stream
|
86 |
from .text_utils import nested_tuple_to_string
|
87 |
from .type_utils import isoftype
|
88 |
+
from .utils import (
|
89 |
+
deep_copy,
|
90 |
+
flatten_dict,
|
91 |
+
recursive_copy,
|
92 |
+
recursive_shallow_copy,
|
93 |
+
shallow_copy,
|
94 |
+
)
|
95 |
|
96 |
settings = get_settings()
|
|
|
97 |
|
98 |
|
99 |
class FromIterables(StreamInitializerOperator):
|
|
|
136 |
it maps values of instances in a stream using predefined mappers.
|
137 |
|
138 |
Attributes:
|
139 |
+
mappers (Dict[str, Dict[str, Any]]): The mappers to use for mapping instance values.
|
140 |
+
Keys are the names of the fields to undergo mapping, and values are dictionaries
|
141 |
that define the mapping from old values to new values.
|
142 |
strict (bool): If True, the mapping is applied strictly. That means if a value
|
143 |
does not exist in the mapper, it will raise a KeyError. If False, values
|
|
|
207 |
|
208 |
def get_mapped_value(self, instance, key, mapper, val):
|
209 |
val_as_str = str(val) # make sure the value is a string
|
210 |
+
if val_as_str in mapper:
|
211 |
+
return recursive_copy(mapper[val_as_str])
|
212 |
+
if self.strict:
|
213 |
raise KeyError(
|
214 |
f"value '{val}' in instance '{instance}' is not found in mapper '{mapper}', associated with field '{key}'."
|
215 |
)
|
|
|
|
|
|
|
216 |
return val
|
217 |
|
218 |
|
|
|
272 |
) -> Dict[str, Any]:
|
273 |
for key, value in self.fields.items():
|
274 |
if self.use_deepcopy:
|
275 |
+
value = deep_copy(value)
|
276 |
dict_set(instance, key, value)
|
277 |
return instance
|
278 |
|
|
|
321 |
return new_instance
|
322 |
|
323 |
|
324 |
+
class DefaultPlaceHolder:
|
325 |
+
pass
|
326 |
+
|
327 |
+
|
328 |
+
default_place_holder = DefaultPlaceHolder()
|
329 |
+
|
330 |
+
|
331 |
class InstanceFieldOperator(InstanceOperator):
|
332 |
"""A general stream instance operator that processes the values of a field (or multiple ones).
|
333 |
|
|
|
358 |
process_every_value: bool = False
|
359 |
get_default: Any = None
|
360 |
not_exist_ok: bool = False
|
361 |
+
not_exist_do_nothing: bool = False
|
362 |
|
363 |
def verify(self):
|
364 |
super().verify()
|
|
|
440 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
441 |
) -> Dict[str, Any]:
|
442 |
self.verify_field_definition()
|
|
|
|
|
|
|
|
|
|
|
443 |
for from_field, to_field in self._field_to_field:
|
444 |
try:
|
445 |
old_value = dict_get(
|
446 |
instance,
|
447 |
from_field,
|
448 |
+
default=default_place_holder,
|
449 |
+
not_exist_ok=self.not_exist_ok or self.not_exist_do_nothing,
|
450 |
)
|
451 |
+
if old_value is default_place_holder:
|
452 |
+
if self.not_exist_do_nothing:
|
453 |
+
return instance
|
454 |
+
old_value = self.get_default
|
455 |
except Exception as e:
|
456 |
raise ValueError(
|
457 |
f"Failed to get '{from_field}' from {instance} due to : {e}"
|
|
|
486 |
pass
|
487 |
|
488 |
|
489 |
+
class MapValues(FieldOperator):
|
490 |
+
mapping: Dict[str, str]
|
491 |
+
|
492 |
+
def process_value(self, value: Any) -> Any:
|
493 |
+
return self.mapping[str(value)]
|
494 |
+
|
495 |
+
|
496 |
class Rename(FieldOperator):
|
497 |
"""Renames fields.
|
498 |
|
|
|
660 |
values = []
|
661 |
for field_name in self.fields:
|
662 |
values.append(dict_get(instance, field_name))
|
663 |
+
|
664 |
+
dict_set(instance, self.to_field, values)
|
665 |
+
|
666 |
return instance
|
667 |
|
668 |
|
|
|
699 |
zipped = zip_longest(*values)
|
700 |
else:
|
701 |
zipped = zip(*values)
|
702 |
+
dict_set(instance, self.to_field, list(zipped))
|
703 |
return instance
|
704 |
|
705 |
|
|
|
866 |
|
867 |
"""
|
868 |
|
|
|
|
|
869 |
def process_value(self, value: Any) -> Any:
|
|
|
|
|
870 |
return value
|
871 |
|
872 |
|
873 |
+
class RecursiveCopy(FieldOperator):
|
874 |
+
def process_value(self, value: Any) -> Any:
|
875 |
+
return recursive_copy(value)
|
876 |
+
|
877 |
+
|
878 |
@deprecation(version="2.0.0", alternative=Copy)
|
879 |
class CopyFields(Copy):
|
880 |
pass
|
|
|
1042 |
if artifact_identifier not in cls.cache:
|
1043 |
artifact, artifactory = fetch_artifact(artifact_identifier)
|
1044 |
cls.cache[artifact_identifier] = artifact
|
1045 |
+
return shallow_copy(cls.cache[artifact_identifier])
|
1046 |
|
1047 |
|
1048 |
class ApplyOperatorsField(InstanceOperator):
|
|
|
1622 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
1623 |
from .metrics import Metric
|
1624 |
|
1625 |
+
# Number of instances in input stream is assumed to be small. This is why
|
1626 |
+
# each metric consumes all of them and lays them in its main memory, and even generates
|
1627 |
+
# some 1000 copies thereof for the sake of CI.
|
1628 |
+
# So we start with deep copying here, to make a 'frozen' status of the stream, having
|
1629 |
+
# passed the preprocess_steps of the task, and inference, and now getting to be evaluated,
|
1630 |
+
# a frozen status to be fed into each of the metrics listed in metric_field,
|
1631 |
+
# so that the evaluation of one does not affect the evaluation of another
|
1632 |
+
# (typically, affecting via change of instance as part of
|
1633 |
+
# preprocess_steps of MetricPipeline, as illustrated in docs/adding_metrics/Using Metric Pipelines).
|
1634 |
+
|
1635 |
+
instances_upon_entrance_to_metrics_evaluations = []
|
1636 |
+
for instance in stream:
|
1637 |
+
instances_upon_entrance_to_metrics_evaluations.append(
|
1638 |
+
recursive_copy(instance)
|
1639 |
+
)
|
1640 |
+
|
1641 |
+
first_instance = instances_upon_entrance_to_metrics_evaluations[0]
|
1642 |
|
1643 |
metric_names = first_instance.get(self.metric_field, [])
|
1644 |
if not metric_names:
|
|
|
1655 |
# by the first listed metric (as desired).
|
1656 |
metric_names = list(reversed(metric_names))
|
1657 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1658 |
for metric_name in metric_names:
|
1659 |
metric = self.get_artifact(metric_name)
|
1660 |
assert isinstance(
|
|
|
1663 |
|
1664 |
if not self.calc_confidence_intervals:
|
1665 |
metric.disable_confidence_interval_calculation()
|
1666 |
+
multi_stream = MultiStream(
|
1667 |
+
{
|
1668 |
+
"tmp": ListStream(
|
1669 |
+
instances_list=instances_upon_entrance_to_metrics_evaluations,
|
1670 |
+
copying=True, # ensures deep copy when iterating over instances
|
1671 |
+
)
|
1672 |
+
}
|
1673 |
+
)
|
1674 |
multi_stream = metric(multi_stream)
|
1675 |
+
for evaluated_instance, freezed_instance in zip(
|
1676 |
+
multi_stream["tmp"], instances_upon_entrance_to_metrics_evaluations
|
1677 |
+
):
|
1678 |
+
freezed_instance["score"] = recursive_shallow_copy(
|
1679 |
+
evaluated_instance["score"]
|
1680 |
+
)
|
1681 |
|
1682 |
+
yield from instances_upon_entrance_to_metrics_evaluations
|
|
|
|
|
|
|
|
|
1683 |
|
1684 |
|
1685 |
class MergeStreams(MultiStreamOperator):
|
|
|
2098 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
2099 |
for instance in stream:
|
2100 |
for idx in range(self.num_duplications):
|
2101 |
+
duplicate = recursive_shallow_copy(instance)
|
2102 |
if self.duplication_index_field:
|
2103 |
duplicate.update({self.duplication_index_field: idx})
|
2104 |
yield duplicate
|
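The reworked MapInstanceValues.get_mapped_value above now copies mapped values before returning them and raises only when strict is set. A hedged, standalone sketch of that flow (a plain function and copy.deepcopy stand in for the operator class and recursive_copy):

    import copy

    def get_mapped_value(mapper: dict, val, strict: bool):
        val_as_str = str(val)  # mapper keys are strings
        if val_as_str in mapper:
            return copy.deepcopy(mapper[val_as_str])  # stand-in for recursive_copy
        if strict:
            raise KeyError(f"value '{val}' is not found in mapper")
        return val  # non-strict: pass the original value through

    assert get_mapped_value({"0": "negative", "1": "positive"}, 1, strict=True) == "positive"
    assert get_mapped_value({"0": "negative"}, 2, strict=False) == 2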
processors.py
CHANGED

@@ -2,9 +2,12 @@ import ast
 import copy
 import json
 import re
+import string
 from difflib import get_close_matches
 from typing import Any, Dict
 
+import numpy as np
+
 from .deprecation_utils import deprecation
 from .operator import MultiStreamOperator
 from .operators import FieldOperator, InstanceFieldOperator

@@ -20,9 +23,9 @@ class PostProcess(MultiStreamOperator):
 
     def prepare(self):
         super().prepare()
-        self.prediction_operator = copy.…
+        self.prediction_operator = copy.copy(self.operator)
         self.prediction_operator.field = "prediction"
-        self.references_operator = copy.…
+        self.references_operator = copy.copy(self.operator)
         self.references_operator.field = "references"
         self.references_operator.process_every_value = True
         self.references_operator.dont_apply_to_streams = [constants.inference_stream]

@@ -315,3 +318,75 @@ class ExtractArenaHardNumericalJudgment(FieldOperator):
 
         except:
             return 0
+
+
+class InferDictsToBinaryLogprobs(FieldOperator):
+    neg_class_name: str
+    pos_class_name: str
+
+    take_logprobs_from_end: bool = False
+    num_logprobs_to_take: int = 3
+    min_probability_mass = 0.0001
+
+    def verify(self):
+        super().verify()
+        if (
+            self.neg_class_name.lower() in self.pos_class_name.lower()
+            or self.pos_class_name.lower() in self.neg_class_name.lower()
+        ):
+            raise ValueError(
+                f"""Class names in {self.__class__.__name__} should not overlap, got "{self.pos_class_name}" and "{self.neg_class_name}"""
+            )
+
+    def process_value(self, obj: Any) -> Any:
+        for i in self.get_token_range(obj):
+            try:
+                pos_probs, neg_probs = self.get_pos_neg_probs(pred_dict=obj[i])
+                if pos_probs or neg_probs:
+                    sum_probs = sum(pos_probs) + sum(neg_probs)
+                    if sum_probs > self.min_probability_mass:
+                        return sum(pos_probs) / sum_probs
+            except:
+                pass
+        return 0
+
+    def get_pos_neg_probs(self, pred_dict):
+        token_logprobs = pred_dict["top_tokens"]
+
+        pos_and_neg_probs = []
+        for class_name in [self.pos_class_name, self.neg_class_name]:
+            # We need to capture different variants of model behavior and tokenizers, for example with opening space,
+            # punctuation etc. but avoid longer words that contain the class name.
+            # For example, for class "yes" we would capture "YES," and " Yes" but not "yesterday".
+            name_regex = re.compile(
+                rf"(\W|Ġ|_)*{class_name}(\W|Ġ|_)*", flags=re.IGNORECASE
+            )
+            class_probs = [
+                np.exp(d["logprob"])
+                for d in token_logprobs
+                if name_regex.fullmatch(d["text"])
+            ]
+            pos_and_neg_probs.append(class_probs)
+        return pos_and_neg_probs
+
+    def get_token_range(self, obj: Any) -> range:
+        n_tokens = min([self.num_logprobs_to_take, len(obj)])
+        if self.take_logprobs_from_end:
+            return range(-1, -(n_tokens + 1), -1)
+        return range(n_tokens)
+
+
+class RemoveArticles(FieldOperator):
+    def process_value(self, text: Any) -> Any:
+        return re.sub(r"\b(a|an|the)\b", " ", text)
+
+
+class RemovePunctuations(FieldOperator):
+    def process_value(self, text: Any) -> Any:
+        puncs_to_exclude = set(string.punctuation)
+        return "".join(c for c in text if c not in puncs_to_exclude)
+
+
+class FixWhiteSpace(FieldOperator):
+    def process_value(self, text: Any) -> Any:
+        return " ".join(text.split())
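The three small post-processors added at the end of processors.py (RemoveArticles, RemovePunctuations, FixWhiteSpace) can be chained to normalize text before matching. A hedged sketch of the same three transformations applied directly:

    import re
    import string

    def normalize(text: str) -> str:
        text = re.sub(r"\b(a|an|the)\b", " ", text)                          # RemoveArticles
        text = "".join(c for c in text if c not in set(string.punctuation))  # RemovePunctuations
        return " ".join(text.split())                                        # FixWhiteSpace

    assert normalize("the cat sat on a mat, near the door!") == "cat sat on mat near door"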
settings_utils.py
CHANGED

@@ -147,6 +147,7 @@ if Settings.is_uninitilized():
     settings.skip_artifacts_prepare_and_verify = (bool, False)
     settings.data_classification_policy = None
     settings.mock_inference_mode = (bool, False)
+    settings.disable_hf_datasets_cache = (bool, True)
 
 if Constants.is_uninitilized():
     constants = Constants()
split_utils.py
CHANGED

@@ -226,7 +226,12 @@ def rename_split(input_streams: Dict[str, Stream], mapping: Dict[str, str]):
         dict: A dictionary containing the generated new streams, where each key is the name
         of the new stream and the value is a generator representing the stream.
     """
-    …
+    new_streams = {}
+    for key, val in mapping.items():
+        if key not in input_streams:
+            raise ValueError("Wrong stream name")
+        new_streams[val] = input_streams.pop(key)
+    return {**input_streams, **new_streams}
 
 
 def random_mix_generator(
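The rename_split body filled in above moves each mapped stream to its new name and keeps the rest untouched. A hedged sketch of the same behaviour using plain dicts in place of Stream objects:

    def rename_split(input_streams: dict, mapping: dict) -> dict:
        new_streams = {}
        for key, val in mapping.items():
            if key not in input_streams:
                raise ValueError("Wrong stream name")
            new_streams[val] = input_streams.pop(key)
        return {**input_streams, **new_streams}

    streams = {"validation": ["v1"], "train": ["t1"]}
    assert rename_split(streams, {"validation": "test"}) == {"train": ["t1"], "test": ["v1"]}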
splitters.py
CHANGED

@@ -16,7 +16,7 @@ from .split_utils import (
 )
 from .stream import EmptyStreamError, FaultyStreamError, MultiStream
 from .type_utils import isoftype
-from .utils import …
+from .utils import recursive_shallow_copy
 
 
 class Splitter(MultiStreamOperator):

@@ -353,7 +353,9 @@ class Sample(InstanceOperatorWithMultiStreamAccess):
         sample_size = self.get_sample_size(instance)
         try:
             if self.local_cache is None:
-                self.local_cache = …
+                self.local_cache = recursive_shallow_copy(
+                    list(multi_stream[self.from_stream])
+                )
 
             source_stream = self.local_cache
             source_stream = self.sampler.filter_source_by_instance(
standard.py
CHANGED

@@ -249,12 +249,12 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
     def produce(self, task_instances):
         """Use the recipe in production to produce model ready query from standard task instance."""
         self.before_process_multi_stream()
-        …
-        …
-        …
-        …
-        …
-        )
+        streams = {
+            constants.inference_stream: self.production_preprocess(task_instances),
+        }
+        if self.use_demos:
+            streams[self.demos_pool_name] = self.production_demos_pool()
+        multi_stream = MultiStream.from_iterables(streams)
         multi_stream = self.inference(multi_stream)
         return list(multi_stream[constants.inference_stream])
 
stream.py
CHANGED

@@ -10,7 +10,7 @@ from .dataclass import Dataclass, OptionalField
 from .generator_utils import CopyingReusableGenerator, ReusableGenerator
 from .logging_utils import get_logger
 from .settings_utils import get_settings
-from .utils import …
+from .utils import recursive_copy
 
 settings = get_settings()
 logger = get_logger()

@@ -40,7 +40,7 @@ class ListStream(Stream):
 
     def __iter__(self):
         if self.copying:
-            return iter(…
+            return iter(recursive_copy(self.instances_list))
         return iter(self.instances_list)
 
     def peek(self):

@@ -244,7 +244,8 @@ class MultiStream(dict):
         return IterableDatasetDict(
             {
                 key: IterableDataset.from_generator(
-                    self.get_generator, …
+                    self.get_generator,
+                    gen_kwargs={"key": key},
                 )
                 for key in self.keys()
             }
stream_operators.py
CHANGED

@@ -31,6 +31,7 @@ The rest of this section is dedicated for operators that operates on streams.
 
 """
 
+import copy
 from typing import (
     List,
     Literal,

@@ -154,6 +155,7 @@ class DuplicateSplit(MultiStreamOperator):
 
     def process(self, multi_stream: MultiStream) -> MultiStream:
         assert self.split in multi_stream
-        …
-        …
-        …
+        new_stream = copy.deepcopy(multi_stream[self.split])
+        new_stream.set_copying(copying=True)
+        multi_stream[self.to_split] = new_stream
+        return multi_stream
string_operators.py
CHANGED

@@ -87,3 +87,12 @@ class Replace(FieldOperator):
 
     def process_value(self, value: str) -> str:
         return value.replace(self.old, self.new)
+
+
+class MapReplace(FieldOperator):
+    mapping: Dict[str, str]
+
+    def process_value(self, value: Any) -> Any:
+        for key, val in self.mapping.items():
+            value = value.replace(key, val)
+        return value
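The new MapReplace operator applies each mapping entry as a plain substring replacement, in dictionary order. A hedged standalone sketch of that behaviour:

    def map_replace(value: str, mapping: dict) -> str:
        for key, val in mapping.items():
            value = value.replace(key, val)
        return value

    assert map_replace("colour and flavour", {"colour": "color", "flavour": "flavor"}) == "color and flavor"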
struct_data_operators.py
CHANGED

@@ -32,7 +32,7 @@ from .operators import FieldOperator, InstanceOperator
 from .random_utils import new_random_generator
 from .serializers import TableSerializer
 from .types import Table
-from .utils import …
+from .utils import recursive_copy
 
 
 def shuffle_columns(table: Table, seed=0) -> Table:

@@ -76,7 +76,7 @@ class SerializeTable(ABC, TableSerializer):
     shuffle_columns: bool = False
 
     def serialize(self, value: Table, instance: Dict[str, Any]) -> str:
-        value = …
+        value = recursive_copy(value)
         if self.shuffle_columns:
             value = shuffle_columns(table=value, seed=self.seed)
 

@@ -207,6 +207,12 @@ class SerializeTableAsDFLoader(SerializeTable):
 
         assert header and rows, "Incorrect input table format"
 
+        # Fix duplicate columns, ensuring the first occurrence has no suffix
+        header = [
+            f"{col}_{header[:i].count(col)}" if header[:i].count(col) > 0 else col
+            for i, col in enumerate(header)
+        ]
+
         # Create a pandas DataFrame
         df = pd.DataFrame(rows, columns=header)
 

@@ -252,6 +258,59 @@ class SerializeTableAsJson(SerializeTable):
         return json.dumps(output_dict)
 
 
+class SerializeTableAsHTML(SerializeTable):
+    """HTML Table Serializer.
+
+    HTML table format used for rendering tables in web pages.
+    Format(Sample):
+    <table>
+      <thead>
+        <tr><th>name</th><th>age</th><th>sex</th></tr>
+      </thead>
+      <tbody>
+        <tr><td>Alice</td><td>26</td><td>F</td></tr>
+        <tr><td>Raj</td><td>34</td><td>M</td></tr>
+      </tbody>
+    </table>
+    """
+
+    # main method that serializes a table.
+    # table_content must be in the prescribed input format.
+    def serialize_table(self, table_content: Dict) -> str:
+        # Extract headers and rows from the dictionary
+        header = table_content.get("header", [])
+        rows = table_content.get("rows", [])
+
+        assert header and rows, "Incorrect input table format"
+
+        # Build the HTML table structure
+        serialized_tbl_str = "<table>\n"
+        serialized_tbl_str += self.process_header(header) + "\n"
+        serialized_tbl_str += self.process_rows(rows) + "\n"
+        serialized_tbl_str += "</table>"
+
+        return serialized_tbl_str.strip()
+
+    # serialize the header into an HTML <thead> section
+    def process_header(self, header: List) -> str:
+        header_html = "  <thead>\n    <tr>"
+        for col in header:
+            header_html += f"<th>{col}</th>"
+        header_html += "</tr>\n  </thead>"
+        return header_html
+
+    # serialize the rows into an HTML <tbody> section
+    def process_rows(self, rows: List[List]) -> str:
+        rows_html = "  <tbody>"
+        for row in rows:
+            rows_html += "\n    <tr>"
+            for cell in row:
+                rows_html += f"<td>{cell}</td>"
+            rows_html += "</tr>"
+        rows_html += "\n  </tbody>"
+        return rows_html
+
+
 # truncate cell value to maximum allowed length
 def truncate_cell(cell_value, max_len):
     if cell_value is None:

@@ -490,7 +549,7 @@ class ConvertTableColNamesToSequential(FieldOperator):
     """
 
     def process_value(self, table: Any) -> Any:
-        table_input = …
+        table_input = recursive_copy(table)
         return self.replace_header(table_content=table_input)
 
     # replaces header with sequential column names

@@ -523,7 +582,7 @@ class ShuffleTableRows(FieldOperator):
     """
 
     def process_value(self, table: Any) -> Any:
-        table_input = …
+        table_input = recursive_copy(table)
         return shuffle_rows(table_input)
 
 

@@ -544,7 +603,7 @@ class ShuffleTableColumns(FieldOperator):
     """
 
     def process_value(self, table: Any) -> Any:
-        table_input = …
+        table_input = recursive_copy(table)
         return shuffle_columns(table_input)
 
 

@@ -658,3 +717,133 @@ class ConstructTableFromRowsCols(InstanceOperator):
         instance[self.to_field] = output_dict
 
         return instance
+
+
+class TransposeTable(FieldOperator):
+    """Transpose a table.
+
+    Sample Input:
+        {
+            "header": ["name", "age", "sex"],
+            "rows": [["Alice", 26, "F"], ["Raj", 34, "M"], ["Donald", 39, "M"]],
+        }
+
+    Sample Output:
+        {
+            "header": [" ", "0", "1", "2"],
+            "rows": [["name", "Alice", "Raj", "Donald"], ["age", 26, 34, 39], ["sex", "F", "M", "M"]],
+        }
+    """
+
+    def process_value(self, table: Any) -> Any:
+        return self.transpose_table(table)
+
+    def transpose_table(self, table: Dict) -> Dict:
+        # Extract the header and rows from the table object
+        header = table["header"]
+        rows = table["rows"]
+
+        # Transpose the table by converting rows as columns and vice versa
+        transposed_header = [" "] + [str(i) for i in range(len(rows))]
+        transposed_rows = [
+            [header[i]] + [row[i] for row in rows] for i in range(len(header))
+        ]
+
+        return {"header": transposed_header, "rows": transposed_rows}
+
+
+class DuplicateTableRows(FieldOperator):
+    """Duplicates specific rows of a table for the given number of times.
+
+    Args:
+        row_indices (List[int]) - rows to be duplicated
+        times(int) - how many times to duplicate
+    """
+
+    row_indices: List[int] = []
+    times: int = 1
+
+    def process_value(self, table: Any) -> Any:
+        # Extract the header and rows from the table
+        header = table["header"]
+        rows = table["rows"]
+
+        # Duplicate only the specified rows
+        duplicated_rows = []
+        for i, row in enumerate(rows):
+            if i in self.row_indices:
+                duplicated_rows.extend(
+                    [row] * self.times
+                )  # Duplicate the selected rows
+            else:
+                duplicated_rows.append(row)  # Leave other rows unchanged
+
+        # Return the new table with selectively duplicated rows
+        return {"header": header, "rows": duplicated_rows}
+
+
+class DuplicateTableColumns(FieldOperator):
+    """Duplicates specific columns of a table for the given number of times.
+
+    Args:
+        column_indices (List[int]) - columns to be duplicated
+        times(int) - how many times to duplicate
+    """
+
+    column_indices: List[int] = []
+    times: int = 1
+
+    def process_value(self, table: Any) -> Any:
+        # Extract the header and rows from the table
+        header = table["header"]
+        rows = table["rows"]
+
+        # Duplicate the specified columns in the header
+        duplicated_header = []
+        for i, col in enumerate(header):
+            if i in self.column_indices:
+                duplicated_header.extend([col] * self.times)
+            else:
+                duplicated_header.append(col)
+
+        # Duplicate the specified columns in each row
+        duplicated_rows = []
+        for row in rows:
+            new_row = []
+            for i, value in enumerate(row):
+                if i in self.column_indices:
+                    new_row.extend([value] * self.times)
+                else:
+                    new_row.append(value)
+            duplicated_rows.append(new_row)
+
+        # Return the new table with selectively duplicated columns
+        return {"header": duplicated_header, "rows": duplicated_rows}
+
+
+class InsertEmptyTableRows(FieldOperator):
+    """Inserts empty rows in a table randomly for the given number of times.
+
+    Args:
+        times(int) - how many times to insert
+    """
+
+    times: int = 0
+
+    def process_value(self, table: Any) -> Any:
+        # Extract the header and rows from the table
+        header = table["header"]
+        rows = table["rows"]
+
+        # Insert empty rows at random positions
+        for _ in range(self.times):
+            empty_row = [""] * len(
+                header
+            )  # Create an empty row with the same number of columns
+            insert_pos = random.randint(
+                0, len(rows)
+            )  # Get a random position to insert the empty row created
+            rows.insert(insert_pos, empty_row)
+
+        # Return the modified table
+        return {"header": header, "rows": rows}
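The TransposeTable operator added above turns rows into columns with a generated numeric header. A hedged sketch of the same transformation on a smaller table than the docstring's sample:

    def transpose_table(table: dict) -> dict:
        header, rows = table["header"], table["rows"]
        transposed_header = [" "] + [str(i) for i in range(len(rows))]
        transposed_rows = [
            [header[i]] + [row[i] for row in rows] for i in range(len(header))
        ]
        return {"header": transposed_header, "rows": transposed_rows}

    table = {"header": ["name", "age"], "rows": [["Alice", 26], ["Raj", 34]]}
    assert transpose_table(table) == {
        "header": [" ", "0", "1"],
        "rows": [["name", "Alice", "Raj"], ["age", 26, 34]],
    }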
templates.py
CHANGED

@@ -210,7 +210,7 @@ class ApplyTemplate(InstanceOperator):
         if self.demos_field not in instance:
             raise ValueError("Demos field is missing.")
         instance[self.demos_field] = [
-            self.apply(template, demo_instance…
+            self.apply(template, demo_instance)
             for demo_instance in instance[self.demos_field]
         ]
         dict_set(instance, "recipe_metadata/template", template)
type_utils.py
CHANGED

@@ -4,6 +4,7 @@ import io
 import itertools
 import re
 import typing
+from functools import lru_cache
 from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union
 
 from .utils import safe_eval

@@ -810,6 +811,7 @@ class NormalizedType(typing.NamedTuple):
         return f"{self.origin}[{self.args}])"
 
 
+@lru_cache(maxsize=None)
 def _normalize_args(tps: TypeArgs):
     if isinstance(tps, str):
         return tps

@@ -918,6 +920,7 @@ def _is_origin_subtype_args(
     return _is_normal_subtype(left, right, forward_refs)
 
 
+@lru_cache(maxsize=None)
 def _is_normal_subtype(
     left: NormalizedType,
     right: NormalizedType,
utils.py
CHANGED

@@ -148,5 +148,88 @@ def import_module_from_file(file_path):
     return module
 
 
-def …
+def deep_copy(obj):
+    """Creates a deep copy of the given object.
+
+    Args:
+        obj: The object to be deep copied.
+
+    Returns:
+        A deep copy of the original object.
+    """
     return copy.deepcopy(obj)
+
+
+def shallow_copy(obj):
+    """Creates a shallow copy of the given object.
+
+    Args:
+        obj: The object to be shallow copied.
+
+    Returns:
+        A shallow copy of the original object.
+    """
+    return copy.copy(obj)
+
+
+def recursive_copy(obj, internal_copy=None):
+    """Recursively copies an object with a selective copy method.
+
+    For `list`, `dict`, and `tuple` types, it recursively copies their contents.
+    For other types, it uses the provided `internal_copy` function if available.
+    Objects without a `copy` method are returned as is.
+
+    Args:
+        obj: The object to be copied.
+        internal_copy (callable, optional): The copy function to use for non-container objects.
+            If `None`, objects without a `copy` method are returned as is.
+
+    Returns:
+        The recursively copied object.
+    """
+    # Handle dictionaries
+    if isinstance(obj, dict):
+        return type(obj)(
+            {key: recursive_copy(value, internal_copy) for key, value in obj.items()}
+        )
+
+    # Handle named tuples
+    if isinstance(obj, tuple) and hasattr(obj, "_fields"):
+        return type(obj)(*(recursive_copy(item, internal_copy) for item in obj))
+
+    # Handle tuples and lists
+    if isinstance(obj, (tuple, list)):
+        return type(obj)(recursive_copy(item, internal_copy) for item in obj)
+
+    if internal_copy is None:
+        return obj
+
+    return internal_copy(obj)
+
+
+def recursive_deep_copy(obj):
+    """Performs a recursive deep copy of the given object.
+
+    This function uses `deep_copy` as the internal copy method for non-container objects.
+
+    Args:
+        obj: The object to be deep copied.
+
+    Returns:
+        A recursively deep-copied version of the original object.
+    """
+    return recursive_copy(obj, deep_copy)
+
+
+def recursive_shallow_copy(obj):
+    """Performs a recursive shallow copy of the given object.
+
+    This function uses `shallow_copy` as the internal copy method for non-container objects.
+
+    Args:
+        obj: The object to be shallow copied.
+
+    Returns:
+        A recursively shallow-copied version of the original object.
+    """
+    return recursive_copy(obj, shallow_copy)
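The copy helpers added to utils.py rebuild containers while controlling how the leaves are copied: recursive_copy with no internal_copy shares the leaves, recursive_shallow_copy shallow-copies them, and recursive_deep_copy deep-copies them. A hedged, self-contained sketch of the container-rebuilding behaviour:

    def recursive_copy(obj, internal_copy=None):
        # Rebuild dicts, lists and tuples; copy leaves with internal_copy if given.
        if isinstance(obj, dict):
            return type(obj)({k: recursive_copy(v, internal_copy) for k, v in obj.items()})
        if isinstance(obj, (tuple, list)):
            return type(obj)(recursive_copy(item, internal_copy) for item in obj)
        return obj if internal_copy is None else internal_copy(obj)

    original = {"scores": [1, 2], "meta": {"name": "demo"}}
    copied = recursive_copy(original)
    copied["scores"].append(3)  # the inner list was rebuilt, so the original is unaffected
    assert original["scores"] == [1, 2]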
version.py
CHANGED

@@ -1 +1 @@
-version = "1.…
+version = "1.14.0"