Upload folder using huggingface_hub
Browse files- api.py +73 -36
- artifact.py +4 -4
- benchmark.py +58 -0
- blocks.py +2 -1
- card.py +2 -2
- catalog.py +41 -0
- collections_operators.py +9 -7
- dataclass.py +0 -23
- dataset.py +2 -0
- dict_utils.py +71 -19
- formats.py +5 -5
- fusion.py +42 -46
- image_operators.py +26 -0
- inference.py +84 -5
- llm_as_judge.py +35 -52
- loaders.py +21 -21
- metric.py +2 -0
- metric_utils.py +184 -114
- metrics.py +171 -6
- operator.py +32 -27
- operators.py +116 -76
- processors.py +40 -2
- schema.py +77 -17
- settings_utils.py +14 -0
- standard.py +57 -32
- stream.py +8 -2
- struct_data_operators.py +37 -0
- task.py +30 -18
- templates.py +55 -23
- version.py +1 -1
api.py
CHANGED
@@ -6,8 +6,9 @@ from datasets import DatasetDict
|
|
6 |
from .artifact import fetch_artifact
|
7 |
from .dataset_utils import get_dataset_artifact
|
8 |
from .logging_utils import get_logger
|
9 |
-
from .metric_utils import _compute,
|
10 |
from .operator import SourceOperator
|
|
|
11 |
from .standard import StandardRecipe
|
12 |
|
13 |
logger = get_logger()
|
@@ -22,21 +23,60 @@ def load(source: Union[SourceOperator, str]) -> DatasetDict:
|
|
22 |
return source().to_dataset()
|
23 |
|
24 |
|
25 |
-
def
|
26 |
dataset_query = dataset_query.replace("sys_prompt", "instruction")
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
29 |
|
30 |
|
31 |
-
def
|
32 |
recipe_attributes = list(StandardRecipe.__dict__["__fields__"].keys())
|
33 |
for param in dataset_params.keys():
|
34 |
assert param in recipe_attributes, (
|
35 |
f"The parameter '{param}' is not an attribute of the 'StandardRecipe' class. "
|
36 |
f"Please check if the name is correct. The available attributes are: '{recipe_attributes}'."
|
37 |
)
|
38 |
-
|
39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
|
42 |
def load_dataset(dataset_query: Optional[str] = None, **kwargs) -> DatasetDict:
|
@@ -47,7 +87,7 @@ def load_dataset(dataset_query: Optional[str] = None, **kwargs) -> DatasetDict:
|
|
47 |
Alternatively, dataset is loaded from a provided card based on explicitly given parameters.
|
48 |
|
49 |
Args:
|
50 |
-
dataset_query (str, optional): A string query which specifies dataset to load from local catalog.
|
51 |
For example:
|
52 |
"card=cards.wnli,template=templates.classification.multi_class.relation.default".
|
53 |
**kwargs: Arguments used to load dataset from provided card, which is not present in local catalog.
|
@@ -65,26 +105,9 @@ def load_dataset(dataset_query: Optional[str] = None, **kwargs) -> DatasetDict:
|
|
65 |
loader_limit = 10
|
66 |
dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)
|
67 |
"""
|
68 |
-
|
69 |
-
raise ValueError(
|
70 |
-
"Cannot provide 'dataset_query' and key-worded arguments at the same time. "
|
71 |
-
"If you want to load dataset from a card in local catalog, use query only. "
|
72 |
-
"Otherwise, use key-worded arguments only to specify properties of dataset."
|
73 |
-
)
|
74 |
-
|
75 |
-
if dataset_query:
|
76 |
-
if not isinstance(dataset_query, str):
|
77 |
-
raise ValueError(
|
78 |
-
f"If specified, 'dataset_query' must be a string, however, "
|
79 |
-
f"'{dataset_query}' was provided instead, which is of type "
|
80 |
-
f"'{type(dataset_query)}'."
|
81 |
-
)
|
82 |
-
return _load_dataset_from_query(dataset_query)
|
83 |
-
|
84 |
-
if kwargs:
|
85 |
-
return _load_dataset_from_dict(kwargs)
|
86 |
|
87 |
-
|
88 |
|
89 |
|
90 |
def evaluate(predictions, data) -> List[Dict[str, Any]]:
|
@@ -92,26 +115,40 @@ def evaluate(predictions, data) -> List[Dict[str, Any]]:
|
|
92 |
|
93 |
|
94 |
def post_process(predictions, data) -> List[Dict[str, Any]]:
|
95 |
-
return
|
96 |
|
97 |
|
98 |
@lru_cache
|
99 |
-
def _get_produce_with_cache(
|
100 |
-
return
|
101 |
|
102 |
|
103 |
-
def produce(instance_or_instances,
|
104 |
is_list = isinstance(instance_or_instances, list)
|
105 |
if not is_list:
|
106 |
instance_or_instances = [instance_or_instances]
|
107 |
-
result = _get_produce_with_cache(
|
108 |
if not is_list:
|
109 |
result = result[0]
|
110 |
return result
|
111 |
|
112 |
|
113 |
-
def infer(
|
114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
engine, _ = fetch_artifact(engine)
|
116 |
-
|
117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
from .artifact import fetch_artifact
|
7 |
from .dataset_utils import get_dataset_artifact
|
8 |
from .logging_utils import get_logger
|
9 |
+
from .metric_utils import _compute, _inference_post_process
|
10 |
from .operator import SourceOperator
|
11 |
+
from .schema import UNITXT_DATASET_SCHEMA
|
12 |
from .standard import StandardRecipe
|
13 |
|
14 |
logger = get_logger()
|
|
|
23 |
return source().to_dataset()
|
24 |
|
25 |
|
26 |
+
def _get_recipe_from_query(dataset_query: str) -> StandardRecipe:
|
27 |
dataset_query = dataset_query.replace("sys_prompt", "instruction")
|
28 |
+
try:
|
29 |
+
dataset_stream, _ = fetch_artifact(dataset_query)
|
30 |
+
except:
|
31 |
+
dataset_stream = get_dataset_artifact(dataset_query)
|
32 |
+
return dataset_stream
|
33 |
|
34 |
|
35 |
+
def _get_recipe_from_dict(dataset_params: Dict[str, Any]) -> StandardRecipe:
|
36 |
recipe_attributes = list(StandardRecipe.__dict__["__fields__"].keys())
|
37 |
for param in dataset_params.keys():
|
38 |
assert param in recipe_attributes, (
|
39 |
f"The parameter '{param}' is not an attribute of the 'StandardRecipe' class. "
|
40 |
f"Please check if the name is correct. The available attributes are: '{recipe_attributes}'."
|
41 |
)
|
42 |
+
return StandardRecipe(**dataset_params)
|
43 |
+
|
44 |
+
|
45 |
+
def _verify_dataset_args(dataset_query: Optional[str] = None, dataset_args=None):
|
46 |
+
if dataset_query and dataset_args:
|
47 |
+
raise ValueError(
|
48 |
+
"Cannot provide 'dataset_query' and key-worded arguments at the same time. "
|
49 |
+
"If you want to load dataset from a card in local catalog, use query only. "
|
50 |
+
"Otherwise, use key-worded arguments only to specify properties of dataset."
|
51 |
+
)
|
52 |
+
|
53 |
+
if dataset_query:
|
54 |
+
if not isinstance(dataset_query, str):
|
55 |
+
raise ValueError(
|
56 |
+
f"If specified, 'dataset_query' must be a string, however, "
|
57 |
+
f"'{dataset_query}' was provided instead, which is of type "
|
58 |
+
f"'{type(dataset_query)}'."
|
59 |
+
)
|
60 |
+
|
61 |
+
if not dataset_query and not dataset_args:
|
62 |
+
raise ValueError(
|
63 |
+
"Either 'dataset_query' or key-worded arguments must be provided."
|
64 |
+
)
|
65 |
+
|
66 |
+
|
67 |
+
def load_recipe(dataset_query: Optional[str] = None, **kwargs) -> StandardRecipe:
|
68 |
+
if isinstance(dataset_query, StandardRecipe):
|
69 |
+
return dataset_query
|
70 |
+
|
71 |
+
_verify_dataset_args(dataset_query, kwargs)
|
72 |
+
|
73 |
+
if dataset_query:
|
74 |
+
recipe = _get_recipe_from_query(dataset_query)
|
75 |
+
|
76 |
+
if kwargs:
|
77 |
+
recipe = _get_recipe_from_dict(kwargs)
|
78 |
+
|
79 |
+
return recipe
|
80 |
|
81 |
|
82 |
def load_dataset(dataset_query: Optional[str] = None, **kwargs) -> DatasetDict:
|
|
|
87 |
Alternatively, dataset is loaded from a provided card based on explicitly given parameters.
|
88 |
|
89 |
Args:
|
90 |
+
dataset_query (str, optional): A string query which specifies a dataset to load from local catalog or name of specific recipe or benchmark in the catalog.
|
91 |
For example:
|
92 |
"card=cards.wnli,template=templates.classification.multi_class.relation.default".
|
93 |
**kwargs: Arguments used to load dataset from provided card, which is not present in local catalog.
|
|
|
105 |
loader_limit = 10
|
106 |
dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)
|
107 |
"""
|
108 |
+
recipe = load_recipe(dataset_query, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
|
110 |
+
return recipe().to_dataset(features=UNITXT_DATASET_SCHEMA)
|
111 |
|
112 |
|
113 |
def evaluate(predictions, data) -> List[Dict[str, Any]]:
|
|
|
115 |
|
116 |
|
117 |
def post_process(predictions, data) -> List[Dict[str, Any]]:
|
118 |
+
return _inference_post_process(predictions=predictions, references=data)
|
119 |
|
120 |
|
121 |
@lru_cache
|
122 |
+
def _get_produce_with_cache(dataset_query: Optional[str] = None, **kwargs):
|
123 |
+
return load_recipe(dataset_query, **kwargs).produce
|
124 |
|
125 |
|
126 |
+
def produce(instance_or_instances, dataset_query: Optional[str] = None, **kwargs):
|
127 |
is_list = isinstance(instance_or_instances, list)
|
128 |
if not is_list:
|
129 |
instance_or_instances = [instance_or_instances]
|
130 |
+
result = _get_produce_with_cache(dataset_query, **kwargs)(instance_or_instances)
|
131 |
if not is_list:
|
132 |
result = result[0]
|
133 |
return result
|
134 |
|
135 |
|
136 |
+
def infer(
|
137 |
+
instance_or_instances,
|
138 |
+
engine,
|
139 |
+
dataset_query: Optional[str] = None,
|
140 |
+
return_data=False,
|
141 |
+
**kwargs,
|
142 |
+
):
|
143 |
+
dataset = produce(instance_or_instances, dataset_query, **kwargs)
|
144 |
engine, _ = fetch_artifact(engine)
|
145 |
+
raw_predictions = engine.infer(dataset)
|
146 |
+
predictions = post_process(raw_predictions, dataset)
|
147 |
+
if return_data:
|
148 |
+
for prediction, raw_prediction, instance in zip(
|
149 |
+
predictions, raw_predictions, dataset
|
150 |
+
):
|
151 |
+
instance["prediction"] = prediction
|
152 |
+
instance["raw_prediction"] = raw_prediction
|
153 |
+
return dataset
|
154 |
+
return predictions
|
artifact.py
CHANGED
@@ -439,10 +439,10 @@ def fetch_artifact(artifact_rep) -> Tuple[Artifact, Union[Artifactory, None]]:
|
|
439 |
"""Loads an artifict from one of possible representations.
|
440 |
|
441 |
(1) If artifact representation is already an Artifact object, return it.
|
442 |
-
(2) If artifact representation is a string location of a local file, load the Artifact from local file.
|
443 |
-
(3) If artifact representation is a string name
|
444 |
-
(4) If artifact representation is a json string, create dictionary representation from the string and build an Artifact object from it.
|
445 |
-
(5) Otherwise, check the artifact representation is a dictionary and build an Artifact object from it.
|
446 |
"""
|
447 |
if isinstance(artifact_rep, Artifact):
|
448 |
return artifact_rep, None
|
|
|
439 |
"""Loads an artifict from one of possible representations.
|
440 |
|
441 |
(1) If artifact representation is already an Artifact object, return it.
|
442 |
+
(2) If artifact representation is a string location of a local file, load the Artifact from the local file.
|
443 |
+
(3) If artifact representation is a string name in the catalog, load the Artifact from the catalog.
|
444 |
+
(4) If artifact representation is a json string, create a dictionary representation from the string and build an Artifact object from it.
|
445 |
+
(5) Otherwise, check that the artifact representation is a dictionary and build an Artifact object from it.
|
446 |
"""
|
447 |
if isinstance(artifact_rep, Artifact):
|
448 |
return artifact_rep, None
|
benchmark.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Union
|
2 |
+
|
3 |
+
from .dataclass import NonPositionalField
|
4 |
+
from .formats import Format
|
5 |
+
from .fusion import FixedFusion, WeightedFusion
|
6 |
+
from .operator import SourceOperator
|
7 |
+
from .standard import StandardRecipe
|
8 |
+
from .stream import MultiStream
|
9 |
+
from .system_prompts import SystemPrompt
|
10 |
+
|
11 |
+
|
12 |
+
class BaseBenchmark(SourceOperator):
|
13 |
+
format: Format = NonPositionalField(default=None)
|
14 |
+
num_demos: int = NonPositionalField(default=None)
|
15 |
+
system_prompt: SystemPrompt = NonPositionalField(default=None)
|
16 |
+
loader_limit: int = NonPositionalField(default=None)
|
17 |
+
|
18 |
+
|
19 |
+
class Benchmark(BaseBenchmark):
|
20 |
+
subsets: Dict[str, Union[StandardRecipe, BaseBenchmark]]
|
21 |
+
|
22 |
+
max_total_samples: int = None
|
23 |
+
max_samples_per_subset: int = None
|
24 |
+
|
25 |
+
def verify(self):
|
26 |
+
if (
|
27 |
+
self.max_total_samples is not None
|
28 |
+
and self.max_samples_per_subset is not None
|
29 |
+
):
|
30 |
+
raise ValueError("Set either max_total_samples or max_samples_per_subset")
|
31 |
+
|
32 |
+
def prepare(self):
|
33 |
+
if self.format is not None or self.num_demos is not None:
|
34 |
+
for subset in self.subsets.values():
|
35 |
+
if self.num_demos is not None:
|
36 |
+
subset.num_demos = self.num_demos
|
37 |
+
if self.format is not None:
|
38 |
+
subset.format = self.format
|
39 |
+
if self.system_prompt is not None:
|
40 |
+
subset.system_prompt = self.system_prompt
|
41 |
+
if self.loader_limit is not None:
|
42 |
+
subset.loader_limit = self.loader_limit
|
43 |
+
subset.prepare()
|
44 |
+
|
45 |
+
def process(
|
46 |
+
self,
|
47 |
+
) -> MultiStream:
|
48 |
+
if self.max_total_samples is None:
|
49 |
+
operator = FixedFusion(
|
50 |
+
subsets=self.subsets,
|
51 |
+
max_instances_per_subset=self.max_samples_per_subset,
|
52 |
+
)
|
53 |
+
else:
|
54 |
+
operator = WeightedFusion(
|
55 |
+
subsets=self.subsets, max_total_samples=self.max_total_samples
|
56 |
+
)
|
57 |
+
|
58 |
+
return operator()
|
blocks.py
CHANGED
@@ -13,7 +13,7 @@ from .operators import (
|
|
13 |
Copy,
|
14 |
DivideAllFieldsBy,
|
15 |
MapInstanceValues,
|
16 |
-
|
17 |
Set,
|
18 |
)
|
19 |
from .processors import ToString, ToStringStripped
|
@@ -21,6 +21,7 @@ from .recipe import SequentialRecipe
|
|
21 |
from .splitters import RandomSampler, Sample, SliceSplit, SplitRandomMix
|
22 |
from .stream import MultiStream
|
23 |
from .struct_data_operators import (
|
|
|
24 |
ListToKeyValPairs,
|
25 |
MapHTMLTableToJSON,
|
26 |
SerializeKeyValPairs,
|
|
|
13 |
Copy,
|
14 |
DivideAllFieldsBy,
|
15 |
MapInstanceValues,
|
16 |
+
Rename,
|
17 |
Set,
|
18 |
)
|
19 |
from .processors import ToString, ToStringStripped
|
|
|
21 |
from .splitters import RandomSampler, Sample, SliceSplit, SplitRandomMix
|
22 |
from .stream import MultiStream
|
23 |
from .struct_data_operators import (
|
24 |
+
ConstructTableFromRowsCols,
|
25 |
ListToKeyValPairs,
|
26 |
MapHTMLTableToJSON,
|
27 |
SerializeKeyValPairs,
|
card.py
CHANGED
@@ -10,12 +10,12 @@ from .task import Task
|
|
10 |
|
11 |
|
12 |
class TaskCard(Artifact):
|
13 |
-
"""TaskCard delineates the phases in transforming the source dataset into
|
14 |
|
15 |
Attributes:
|
16 |
loader: specifies the source address and the loading operator that can access that source and transform it into a unitxt multistream.
|
17 |
|
18 |
-
preprocess_steps: list of unitxt operators to process the data source into
|
19 |
|
20 |
task: specifies the fields (of the already (pre)processed instance) making the inputs, the fields making the outputs, and the metrics to be used for evaluating the model output.
|
21 |
|
|
|
10 |
|
11 |
|
12 |
class TaskCard(Artifact):
|
13 |
+
"""TaskCard delineates the phases in transforming the source dataset into model input, and specifies the metrics for evaluation of model output.
|
14 |
|
15 |
Attributes:
|
16 |
loader: specifies the source address and the loading operator that can access that source and transform it into a unitxt multistream.
|
17 |
|
18 |
+
preprocess_steps: list of unitxt operators to process the data source into model input.
|
19 |
|
20 |
task: specifies the fields (of the already (pre)processed instance) making the inputs, the fields making the outputs, and the metrics to be used for evaluating the model output.
|
21 |
|
catalog.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import os
|
2 |
from collections import Counter
|
3 |
from functools import lru_cache
|
@@ -195,6 +196,46 @@ def summary():
|
|
195 |
return result
|
196 |
|
197 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
def ls(to_file=None):
|
199 |
done = set()
|
200 |
result = []
|
|
|
1 |
+
import json
|
2 |
import os
|
3 |
from collections import Counter
|
4 |
from functools import lru_cache
|
|
|
196 |
return result
|
197 |
|
198 |
|
199 |
+
def _get_tags_from_file(file_path):
|
200 |
+
result = Counter()
|
201 |
+
|
202 |
+
with open(file_path) as f:
|
203 |
+
data = json.load(f)
|
204 |
+
if "__tags__" in data and isinstance(data["__tags__"], dict):
|
205 |
+
tags = data["__tags__"]
|
206 |
+
for key, value in tags.items():
|
207 |
+
if isinstance(value, list):
|
208 |
+
for item in value:
|
209 |
+
result[f"{key}:{item}"] += 1
|
210 |
+
else:
|
211 |
+
result[f"{key}:{value}"] += 1
|
212 |
+
|
213 |
+
return result
|
214 |
+
|
215 |
+
|
216 |
+
def count_tags():
|
217 |
+
result = Counter()
|
218 |
+
done = set()
|
219 |
+
|
220 |
+
for local_catalog_path in get_local_catalogs_paths():
|
221 |
+
if local_catalog_path not in done:
|
222 |
+
for root, _, files in os.walk(local_catalog_path):
|
223 |
+
for file in files:
|
224 |
+
if file.endswith(".json"):
|
225 |
+
file_path = os.path.join(root, file)
|
226 |
+
try:
|
227 |
+
result += _get_tags_from_file(file_path)
|
228 |
+
except json.JSONDecodeError:
|
229 |
+
logger.info(f"Error decoding JSON in file: {file_path}")
|
230 |
+
except OSError:
|
231 |
+
logger.info(f"Error reading file: {file_path}")
|
232 |
+
|
233 |
+
done.add(local_catalog_path)
|
234 |
+
|
235 |
+
print_dict(result)
|
236 |
+
return result
|
237 |
+
|
238 |
+
|
239 |
def ls(to_file=None):
|
240 |
done = set()
|
241 |
result = []
|
collections_operators.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
from typing import Any, Generator, List, Optional
|
2 |
|
|
|
3 |
from .operators import FieldOperator, StreamOperator
|
4 |
from .stream import Stream
|
5 |
from .utils import deepcopy
|
@@ -66,20 +67,21 @@ class DuplicateByList(StreamOperator):
|
|
66 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
67 |
to_field = self.field if self.to_field is None else self.to_field
|
68 |
for instance in stream:
|
69 |
-
elements = instance
|
70 |
for element in elements:
|
71 |
if self.use_deep_copy:
|
72 |
instance_copy = deepcopy(instance)
|
73 |
-
|
74 |
else:
|
75 |
-
instance_copy = {
|
76 |
-
|
77 |
-
self.field: elements,
|
78 |
-
to_field: element,
|
79 |
-
}
|
80 |
yield instance_copy
|
81 |
|
82 |
|
|
|
|
|
|
|
|
|
83 |
class DuplicateBySubLists(StreamOperator):
|
84 |
field: str
|
85 |
to_field: Optional[str] = None
|
|
|
1 |
from typing import Any, Generator, List, Optional
|
2 |
|
3 |
+
from .dict_utils import dict_get, dict_set
|
4 |
from .operators import FieldOperator, StreamOperator
|
5 |
from .stream import Stream
|
6 |
from .utils import deepcopy
|
|
|
67 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
68 |
to_field = self.field if self.to_field is None else self.to_field
|
69 |
for instance in stream:
|
70 |
+
elements = dict_get(instance, self.field)
|
71 |
for element in elements:
|
72 |
if self.use_deep_copy:
|
73 |
instance_copy = deepcopy(instance)
|
74 |
+
|
75 |
else:
|
76 |
+
instance_copy = {**instance}
|
77 |
+
dict_set(instance_copy, to_field, element)
|
|
|
|
|
|
|
78 |
yield instance_copy
|
79 |
|
80 |
|
81 |
+
class Explode(DuplicateByList):
|
82 |
+
pass
|
83 |
+
|
84 |
+
|
85 |
class DuplicateBySubLists(StreamOperator):
|
86 |
field: str
|
87 |
to_field: Optional[str] = None
|
dataclass.py
CHANGED
@@ -2,7 +2,6 @@ import copy
|
|
2 |
import dataclasses
|
3 |
import functools
|
4 |
import inspect
|
5 |
-
import warnings
|
6 |
from abc import ABCMeta
|
7 |
from inspect import Parameter, Signature
|
8 |
from typing import Any, Dict, List, Optional, final
|
@@ -38,7 +37,6 @@ class Field:
|
|
38 |
final: bool = False
|
39 |
abstract: bool = False
|
40 |
required: bool = False
|
41 |
-
deprecated: bool = False
|
42 |
internal: bool = False
|
43 |
origin_cls: type = None
|
44 |
metadata: Dict[str, str] = dataclasses.field(default_factory=dict)
|
@@ -55,12 +53,6 @@ class FinalField(Field):
|
|
55 |
self.final = True
|
56 |
|
57 |
|
58 |
-
@dataclasses.dataclass
|
59 |
-
class DeprecatedField(Field):
|
60 |
-
def __post_init__(self):
|
61 |
-
self.deprecated = True
|
62 |
-
|
63 |
-
|
64 |
@dataclasses.dataclass
|
65 |
class RequiredField(Field):
|
66 |
def __post_init__(self):
|
@@ -251,10 +243,6 @@ def required_fields(cls):
|
|
251 |
return [field for field in fields(cls) if field.required]
|
252 |
|
253 |
|
254 |
-
def deprecated_fields(cls):
|
255 |
-
return [field for field in fields(cls) if field.deprecated]
|
256 |
-
|
257 |
-
|
258 |
def abstract_fields(cls):
|
259 |
return [field for field in fields(cls) if field.abstract]
|
260 |
|
@@ -267,10 +255,6 @@ def is_final_field(field):
|
|
267 |
return field.final
|
268 |
|
269 |
|
270 |
-
def is_deprecated_field(field):
|
271 |
-
return field.deprecated
|
272 |
-
|
273 |
-
|
274 |
def get_field_default(field):
|
275 |
if field.default_factory is not None:
|
276 |
return field.default_factory()
|
@@ -424,7 +408,6 @@ class Dataclass(metaclass=DataclassMeta):
|
|
424 |
"""Initialize fields based on kwargs.
|
425 |
|
426 |
Checks for abstract fields when an instance is created.
|
427 |
-
Warn when a deprecated is used
|
428 |
"""
|
429 |
super().__init__()
|
430 |
_init_fields = [field for field in fields(self) if field.init]
|
@@ -433,12 +416,6 @@ class Dataclass(metaclass=DataclassMeta):
|
|
433 |
field.name for field in _init_fields if field.also_positional
|
434 |
]
|
435 |
|
436 |
-
_init_deprecated_fields = [field for field in _init_fields if field.deprecated]
|
437 |
-
for dep_field in _init_deprecated_fields:
|
438 |
-
warnings.warn(
|
439 |
-
dep_field.metadata["deprecation_msg"], DeprecationWarning, stacklevel=2
|
440 |
-
)
|
441 |
-
|
442 |
for name in _init_positional_fields_names[: len(argv)]:
|
443 |
if name in kwargs:
|
444 |
raise TypeError(
|
|
|
2 |
import dataclasses
|
3 |
import functools
|
4 |
import inspect
|
|
|
5 |
from abc import ABCMeta
|
6 |
from inspect import Parameter, Signature
|
7 |
from typing import Any, Dict, List, Optional, final
|
|
|
37 |
final: bool = False
|
38 |
abstract: bool = False
|
39 |
required: bool = False
|
|
|
40 |
internal: bool = False
|
41 |
origin_cls: type = None
|
42 |
metadata: Dict[str, str] = dataclasses.field(default_factory=dict)
|
|
|
53 |
self.final = True
|
54 |
|
55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
@dataclasses.dataclass
|
57 |
class RequiredField(Field):
|
58 |
def __post_init__(self):
|
|
|
243 |
return [field for field in fields(cls) if field.required]
|
244 |
|
245 |
|
|
|
|
|
|
|
|
|
246 |
def abstract_fields(cls):
|
247 |
return [field for field in fields(cls) if field.abstract]
|
248 |
|
|
|
255 |
return field.final
|
256 |
|
257 |
|
|
|
|
|
|
|
|
|
258 |
def get_field_default(field):
|
259 |
if field.default_factory is not None:
|
260 |
return field.default_factory()
|
|
|
408 |
"""Initialize fields based on kwargs.
|
409 |
|
410 |
Checks for abstract fields when an instance is created.
|
|
|
411 |
"""
|
412 |
super().__init__()
|
413 |
_init_fields = [field for field in fields(self) if field.init]
|
|
|
416 |
field.name for field in _init_fields if field.also_positional
|
417 |
]
|
418 |
|
|
|
|
|
|
|
|
|
|
|
|
|
419 |
for name in _init_positional_fields_names[: len(argv)]:
|
420 |
if name in kwargs:
|
421 |
raise TypeError(
|
dataset.py
CHANGED
@@ -4,6 +4,7 @@ import datasets
|
|
4 |
|
5 |
from .api import __file__ as _
|
6 |
from .artifact import __file__ as _
|
|
|
7 |
from .blocks import __file__ as _
|
8 |
from .card import __file__ as _
|
9 |
from .catalog import __file__ as _
|
@@ -23,6 +24,7 @@ from .fusion import __file__ as _
|
|
23 |
from .generator_utils import __file__ as _
|
24 |
from .hf_utils import __file__ as _
|
25 |
from .hf_utils import verify_versions_compatibility
|
|
|
26 |
from .inference import __file__ as _
|
27 |
from .instructions import __file__ as _
|
28 |
from .llm_as_judge import __file__ as _
|
|
|
4 |
|
5 |
from .api import __file__ as _
|
6 |
from .artifact import __file__ as _
|
7 |
+
from .benchmark import __file__ as _
|
8 |
from .blocks import __file__ as _
|
9 |
from .card import __file__ as _
|
10 |
from .catalog import __file__ as _
|
|
|
24 |
from .generator_utils import __file__ as _
|
25 |
from .hf_utils import __file__ as _
|
26 |
from .hf_utils import verify_versions_compatibility
|
27 |
+
from .image_operators import __file__ as _
|
28 |
from .inference import __file__ as _
|
29 |
from .instructions import __file__ as _
|
30 |
from .llm_as_judge import __file__ as _
|
dict_utils.py
CHANGED
@@ -4,8 +4,23 @@ from typing import Any, List, Tuple
|
|
4 |
from .text_utils import construct_dict_str
|
5 |
|
6 |
indx = re.compile(r"^(\d+)$")
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
name = re.compile(r"^[\w. -]+$")
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
# formal definition of qpath syntax by which a query is specified:
|
10 |
# qpath -> A (/A)*
|
11 |
# A -> name | * | non-neg-int
|
@@ -51,7 +66,9 @@ name = re.compile(r"^[\w. -]+$")
|
|
51 |
|
52 |
|
53 |
# validate and normalizes into components
|
54 |
-
def validate_query_and_break_to_components(
|
|
|
|
|
55 |
if not isinstance(query, str) or len(query) == 0:
|
56 |
raise ValueError(
|
57 |
f"invalid query: either not a string or an empty string: {query}"
|
@@ -69,9 +86,9 @@ def validate_query_and_break_to_components(query: str) -> List[str]:
|
|
69 |
components = [component.strip() for component in components]
|
70 |
for component in components:
|
71 |
if not (
|
72 |
-
|
73 |
-
or component
|
74 |
-
or
|
75 |
):
|
76 |
raise ValueError(
|
77 |
f"Component {component} in input query is none of: valid field-name, non-neg-int, or '*'"
|
@@ -79,10 +96,14 @@ def validate_query_and_break_to_components(query: str) -> List[str]:
|
|
79 |
return components
|
80 |
|
81 |
|
82 |
-
def is_subpath(subpath, fullpath):
|
83 |
# Split the paths into individual components
|
84 |
-
subpath_components = validate_query_and_break_to_components(
|
85 |
-
|
|
|
|
|
|
|
|
|
86 |
|
87 |
# Check if the full path starts with the subpath
|
88 |
return fullpath_components[: len(subpath_components)] == subpath_components
|
@@ -100,16 +121,17 @@ def delete_values(
|
|
100 |
query: List[str],
|
101 |
index_into_query: int,
|
102 |
remove_empty_ancestors=False,
|
|
|
103 |
) -> Tuple[bool, Any]:
|
104 |
component = query[index_into_query]
|
105 |
if index_into_query == -1:
|
106 |
-
if component
|
107 |
# delete all members of the list or dict
|
108 |
current_element = [] if isinstance(current_element, list) else {}
|
109 |
return (True, current_element)
|
110 |
# component is a either a dictionary key or an index into a list,
|
111 |
# pop the respective element from current_element
|
112 |
-
if
|
113 |
component = int(component)
|
114 |
try:
|
115 |
current_element.pop(component)
|
@@ -141,6 +163,7 @@ def delete_values(
|
|
141 |
query=query,
|
142 |
index_into_query=index_into_query + 1,
|
143 |
remove_empty_ancestors=remove_empty_ancestors,
|
|
|
144 |
)
|
145 |
if not success:
|
146 |
continue
|
@@ -155,7 +178,7 @@ def delete_values(
|
|
155 |
return (any_success, current_element)
|
156 |
|
157 |
# current component is index into a list or a key into a dictionary
|
158 |
-
if
|
159 |
component = int(component)
|
160 |
try:
|
161 |
success, new_val = delete_values(
|
@@ -163,6 +186,7 @@ def delete_values(
|
|
163 |
query=query,
|
164 |
index_into_query=index_into_query + 1,
|
165 |
remove_empty_ancestors=remove_empty_ancestors,
|
|
|
166 |
)
|
167 |
if not success:
|
168 |
return (False, None)
|
@@ -176,7 +200,11 @@ def delete_values(
|
|
176 |
|
177 |
|
178 |
def dict_delete(
|
179 |
-
dic: dict,
|
|
|
|
|
|
|
|
|
180 |
):
|
181 |
# We remove from dic the value from each and every element lead to by a path matching the query.
|
182 |
# If remove_empty_ancestors=True, and the removal of any such value leaves its containing element (list or dict)
|
@@ -197,7 +225,9 @@ def dict_delete(
|
|
197 |
dic.pop(query.strip())
|
198 |
return
|
199 |
|
200 |
-
qpath = validate_query_and_break_to_components(
|
|
|
|
|
201 |
|
202 |
try:
|
203 |
success, new_val = delete_values(
|
@@ -205,6 +235,7 @@ def dict_delete(
|
|
205 |
query=qpath,
|
206 |
index_into_query=(-1) * len(qpath),
|
207 |
remove_empty_ancestors=remove_empty_ancestors,
|
|
|
208 |
)
|
209 |
|
210 |
if success:
|
@@ -225,7 +256,10 @@ def dict_delete(
|
|
225 |
# if query includes * then return a list of values reached by all paths that match the query
|
226 |
# flake8: noqa: C901
|
227 |
def get_values(
|
228 |
-
current_element: Any,
|
|
|
|
|
|
|
229 |
) -> Tuple[bool, Any]:
|
230 |
# going down from current_element through query[index_into_query].
|
231 |
if index_into_query == 0:
|
@@ -244,7 +278,12 @@ def get_values(
|
|
244 |
sub_elements = current_element
|
245 |
for sub_element in sub_elements:
|
246 |
try:
|
247 |
-
success, val = get_values(
|
|
|
|
|
|
|
|
|
|
|
248 |
if success:
|
249 |
to_ret.append(val)
|
250 |
except:
|
@@ -253,11 +292,14 @@ def get_values(
|
|
253 |
return (len(to_ret) > 0 or index_into_query == -1, to_ret)
|
254 |
# when * is the last component, it refers to 'all the contents' of an empty list being current_element.
|
255 |
# next_component is indx or name, current_element must be a list or a dict
|
256 |
-
if
|
257 |
component = int(component)
|
258 |
try:
|
259 |
success, new_val = get_values(
|
260 |
-
current_element[component],
|
|
|
|
|
|
|
261 |
)
|
262 |
if success:
|
263 |
return (True, new_val)
|
@@ -274,6 +316,7 @@ def set_values(
|
|
274 |
index_into_query: int,
|
275 |
fixed_parameters: dict,
|
276 |
set_multiple: bool = False,
|
|
|
277 |
) -> Tuple[bool, Any]:
|
278 |
if index_into_query == 0:
|
279 |
return (True, value) # matched query all along!
|
@@ -321,6 +364,7 @@ def set_values(
|
|
321 |
index_into_query=index_into_query + 1,
|
322 |
set_multiple=False, # now used, not allowed again,
|
323 |
fixed_parameters=fixed_parameters,
|
|
|
324 |
)
|
325 |
if not success:
|
326 |
continue
|
@@ -335,7 +379,7 @@ def set_values(
|
|
335 |
)
|
336 |
|
337 |
# component is an index into a list or a key into a dictionary
|
338 |
-
if
|
339 |
if current_element is None or not isinstance(current_element, list):
|
340 |
if not fixed_parameters["generate_if_not_exists"]:
|
341 |
return (False, None)
|
@@ -368,6 +412,7 @@ def set_values(
|
|
368 |
index_into_query=index_into_query + 1,
|
369 |
fixed_parameters=fixed_parameters,
|
370 |
set_multiple=set_multiple,
|
|
|
371 |
)
|
372 |
if success:
|
373 |
current_element[component] = new_val
|
@@ -383,6 +428,7 @@ def dict_get(
|
|
383 |
query: str,
|
384 |
not_exist_ok: bool = False,
|
385 |
default: Any = None,
|
|
|
386 |
):
|
387 |
if len(query.strip()) == 0:
|
388 |
return dic
|
@@ -393,7 +439,9 @@ def dict_get(
|
|
393 |
if isinstance(dic, dict) and query.strip() in dic:
|
394 |
return dic[query.strip()]
|
395 |
|
396 |
-
components = validate_query_and_break_to_components(
|
|
|
|
|
397 |
if len(components) > 1:
|
398 |
try:
|
399 |
success, values = get_values(dic, components, -1 * len(components))
|
@@ -474,6 +522,7 @@ def dict_set(
|
|
474 |
value: Any,
|
475 |
not_exist_ok=True,
|
476 |
set_multiple=False,
|
|
|
477 |
): # sets dic to its new value
|
478 |
if dic is None or not isinstance(dic, (list, dict)):
|
479 |
raise ValueError(
|
@@ -510,7 +559,9 @@ def dict_set(
|
|
510 |
f"set_multiple=True, but value, {value}, can not be broken up, as either it is not a list or it is an empty list"
|
511 |
)
|
512 |
|
513 |
-
components = validate_query_and_break_to_components(
|
|
|
|
|
514 |
fixed_parameters = {
|
515 |
"query": components,
|
516 |
"generate_if_not_exists": not_exist_ok,
|
@@ -522,6 +573,7 @@ def dict_set(
|
|
522 |
index_into_query=(-1) * len(components),
|
523 |
fixed_parameters=fixed_parameters,
|
524 |
set_multiple=set_multiple,
|
|
|
525 |
)
|
526 |
if not success and not not_exist_ok:
|
527 |
raise ValueError(f"No path in dic {dic} matches query {query}.")
|
|
|
4 |
from .text_utils import construct_dict_str
|
5 |
|
6 |
indx = re.compile(r"^(\d+)$")
|
7 |
+
|
8 |
+
|
9 |
+
def is_index(string):
|
10 |
+
return bool(indx.match(string))
|
11 |
+
|
12 |
+
|
13 |
name = re.compile(r"^[\w. -]+$")
|
14 |
|
15 |
+
|
16 |
+
def is_name(string):
|
17 |
+
return bool(name.match(string))
|
18 |
+
|
19 |
+
|
20 |
+
def is_wildcard(string):
|
21 |
+
return string == "*"
|
22 |
+
|
23 |
+
|
24 |
# formal definition of qpath syntax by which a query is specified:
|
25 |
# qpath -> A (/A)*
|
26 |
# A -> name | * | non-neg-int
|
|
|
66 |
|
67 |
|
68 |
# validate and normalizes into components
|
69 |
+
def validate_query_and_break_to_components(
|
70 |
+
query: str, allow_int_index=True
|
71 |
+
) -> List[str]:
|
72 |
if not isinstance(query, str) or len(query) == 0:
|
73 |
raise ValueError(
|
74 |
f"invalid query: either not a string or an empty string: {query}"
|
|
|
86 |
components = [component.strip() for component in components]
|
87 |
for component in components:
|
88 |
if not (
|
89 |
+
is_name(component)
|
90 |
+
or is_wildcard(component)
|
91 |
+
or (is_index(component) and allow_int_index)
|
92 |
):
|
93 |
raise ValueError(
|
94 |
f"Component {component} in input query is none of: valid field-name, non-neg-int, or '*'"
|
|
|
96 |
return components
|
97 |
|
98 |
|
99 |
+
def is_subpath(subpath, fullpath, allow_int_index=True):
|
100 |
# Split the paths into individual components
|
101 |
+
subpath_components = validate_query_and_break_to_components(
|
102 |
+
subpath, allow_int_index=allow_int_index
|
103 |
+
)
|
104 |
+
fullpath_components = validate_query_and_break_to_components(
|
105 |
+
fullpath, allow_int_index=allow_int_index
|
106 |
+
)
|
107 |
|
108 |
# Check if the full path starts with the subpath
|
109 |
return fullpath_components[: len(subpath_components)] == subpath_components
|
|
|
121 |
query: List[str],
|
122 |
index_into_query: int,
|
123 |
remove_empty_ancestors=False,
|
124 |
+
allow_int_index=True,
|
125 |
) -> Tuple[bool, Any]:
|
126 |
component = query[index_into_query]
|
127 |
if index_into_query == -1:
|
128 |
+
if is_wildcard(component):
|
129 |
# delete all members of the list or dict
|
130 |
current_element = [] if isinstance(current_element, list) else {}
|
131 |
return (True, current_element)
|
132 |
# component is a either a dictionary key or an index into a list,
|
133 |
# pop the respective element from current_element
|
134 |
+
if is_index(component) and allow_int_index:
|
135 |
component = int(component)
|
136 |
try:
|
137 |
current_element.pop(component)
|
|
|
163 |
query=query,
|
164 |
index_into_query=index_into_query + 1,
|
165 |
remove_empty_ancestors=remove_empty_ancestors,
|
166 |
+
allow_int_index=allow_int_index,
|
167 |
)
|
168 |
if not success:
|
169 |
continue
|
|
|
178 |
return (any_success, current_element)
|
179 |
|
180 |
# current component is index into a list or a key into a dictionary
|
181 |
+
if is_index(component) and allow_int_index:
|
182 |
component = int(component)
|
183 |
try:
|
184 |
success, new_val = delete_values(
|
|
|
186 |
query=query,
|
187 |
index_into_query=index_into_query + 1,
|
188 |
remove_empty_ancestors=remove_empty_ancestors,
|
189 |
+
allow_int_index=allow_int_index,
|
190 |
)
|
191 |
if not success:
|
192 |
return (False, None)
|
|
|
200 |
|
201 |
|
202 |
def dict_delete(
|
203 |
+
dic: dict,
|
204 |
+
query: str,
|
205 |
+
not_exist_ok: bool = False,
|
206 |
+
remove_empty_ancestors=False,
|
207 |
+
allow_int_index=True,
|
208 |
):
|
209 |
# We remove from dic the value from each and every element lead to by a path matching the query.
|
210 |
# If remove_empty_ancestors=True, and the removal of any such value leaves its containing element (list or dict)
|
|
|
225 |
dic.pop(query.strip())
|
226 |
return
|
227 |
|
228 |
+
qpath = validate_query_and_break_to_components(
|
229 |
+
query, allow_int_index=allow_int_index
|
230 |
+
)
|
231 |
|
232 |
try:
|
233 |
success, new_val = delete_values(
|
|
|
235 |
query=qpath,
|
236 |
index_into_query=(-1) * len(qpath),
|
237 |
remove_empty_ancestors=remove_empty_ancestors,
|
238 |
+
allow_int_index=allow_int_index,
|
239 |
)
|
240 |
|
241 |
if success:
|
|
|
256 |
# if query includes * then return a list of values reached by all paths that match the query
|
257 |
# flake8: noqa: C901
|
258 |
def get_values(
|
259 |
+
current_element: Any,
|
260 |
+
query: List[str],
|
261 |
+
index_into_query: int,
|
262 |
+
allow_int_index=True,
|
263 |
) -> Tuple[bool, Any]:
|
264 |
# going down from current_element through query[index_into_query].
|
265 |
if index_into_query == 0:
|
|
|
278 |
sub_elements = current_element
|
279 |
for sub_element in sub_elements:
|
280 |
try:
|
281 |
+
success, val = get_values(
|
282 |
+
sub_element,
|
283 |
+
query,
|
284 |
+
index_into_query + 1,
|
285 |
+
allow_int_index=allow_int_index,
|
286 |
+
)
|
287 |
if success:
|
288 |
to_ret.append(val)
|
289 |
except:
|
|
|
292 |
return (len(to_ret) > 0 or index_into_query == -1, to_ret)
|
293 |
# when * is the last component, it refers to 'all the contents' of an empty list being current_element.
|
294 |
# next_component is indx or name, current_element must be a list or a dict
|
295 |
+
if is_index(component) and allow_int_index:
|
296 |
component = int(component)
|
297 |
try:
|
298 |
success, new_val = get_values(
|
299 |
+
current_element[component],
|
300 |
+
query,
|
301 |
+
index_into_query + 1,
|
302 |
+
allow_int_index=allow_int_index,
|
303 |
)
|
304 |
if success:
|
305 |
return (True, new_val)
|
|
|
316 |
index_into_query: int,
|
317 |
fixed_parameters: dict,
|
318 |
set_multiple: bool = False,
|
319 |
+
allow_int_index=True,
|
320 |
) -> Tuple[bool, Any]:
|
321 |
if index_into_query == 0:
|
322 |
return (True, value) # matched query all along!
|
|
|
364 |
index_into_query=index_into_query + 1,
|
365 |
set_multiple=False, # now used, not allowed again,
|
366 |
fixed_parameters=fixed_parameters,
|
367 |
+
allow_int_index=allow_int_index,
|
368 |
)
|
369 |
if not success:
|
370 |
continue
|
|
|
379 |
)
|
380 |
|
381 |
# component is an index into a list or a key into a dictionary
|
382 |
+
if is_index(component) and allow_int_index:
|
383 |
if current_element is None or not isinstance(current_element, list):
|
384 |
if not fixed_parameters["generate_if_not_exists"]:
|
385 |
return (False, None)
|
|
|
412 |
index_into_query=index_into_query + 1,
|
413 |
fixed_parameters=fixed_parameters,
|
414 |
set_multiple=set_multiple,
|
415 |
+
allow_int_index=allow_int_index,
|
416 |
)
|
417 |
if success:
|
418 |
current_element[component] = new_val
|
|
|
428 |
query: str,
|
429 |
not_exist_ok: bool = False,
|
430 |
default: Any = None,
|
431 |
+
allow_int_index=True,
|
432 |
):
|
433 |
if len(query.strip()) == 0:
|
434 |
return dic
|
|
|
439 |
if isinstance(dic, dict) and query.strip() in dic:
|
440 |
return dic[query.strip()]
|
441 |
|
442 |
+
components = validate_query_and_break_to_components(
|
443 |
+
query, allow_int_index=allow_int_index
|
444 |
+
)
|
445 |
if len(components) > 1:
|
446 |
try:
|
447 |
success, values = get_values(dic, components, -1 * len(components))
|
|
|
522 |
value: Any,
|
523 |
not_exist_ok=True,
|
524 |
set_multiple=False,
|
525 |
+
allow_int_index=True,
|
526 |
): # sets dic to its new value
|
527 |
if dic is None or not isinstance(dic, (list, dict)):
|
528 |
raise ValueError(
|
|
|
559 |
f"set_multiple=True, but value, {value}, can not be broken up, as either it is not a list or it is an empty list"
|
560 |
)
|
561 |
|
562 |
+
components = validate_query_and_break_to_components(
|
563 |
+
query, allow_int_index=allow_int_index
|
564 |
+
)
|
565 |
fixed_parameters = {
|
566 |
"query": components,
|
567 |
"generate_if_not_exists": not_exist_ok,
|
|
|
573 |
index_into_query=(-1) * len(components),
|
574 |
fixed_parameters=fixed_parameters,
|
575 |
set_multiple=set_multiple,
|
576 |
+
allow_int_index=allow_int_index,
|
577 |
)
|
578 |
if not success and not not_exist_ok:
|
579 |
raise ValueError(f"No path in dic {dic} matches query {query}.")
|
formats.py
CHANGED
@@ -79,13 +79,13 @@ class SystemFormat(BaseFormat):
|
|
79 |
Important: formats can use '\N' notations that means new-line if no new-line before and no empty string before.
|
80 |
|
81 |
SystemFormat expects the input instance to contain:
|
82 |
-
1. A field named "system_prompt" whose value is a string (potentially empty) that delivers a task
|
83 |
2. A field named "source" whose value is a string verbalizing the original values in the instance (as read
|
84 |
from the source dataset), in the context of the underlying task.
|
85 |
3. A field named "instruction" that contains a (non-None) string.
|
86 |
4. A field named with the value in arg 'demos_field', containing a list of dicts, each dict with fields "source"
|
87 |
and "target", representing a single demo.
|
88 |
-
5. A field named "
|
89 |
|
90 |
SystemFormat formats the above fields into a single string to be inputted to the model. This string overwrites
|
91 |
field "source" of the instance. Formatting is driven by two args: 'demo_format' and 'model_input_format'.
|
@@ -200,16 +200,16 @@ class SystemFormat(BaseFormat):
|
|
200 |
|
201 |
|
202 |
class HFSystemFormat(BaseFormat):
|
203 |
-
r"""Formats the complete input for the model using the
|
204 |
|
205 |
HFSystemFormat expects the input instance to contain:
|
206 |
-
1. A field named "system_prompt" whose value is a string (potentially empty) that delivers a task
|
207 |
2. A field named "source" whose value is a string verbalizing the original values in the instance (as read
|
208 |
from the source dataset), in the context of the underlying task.
|
209 |
3. A field named "instruction" that contains a (non-None) string.
|
210 |
4. A field named with the value in arg 'demos_field', containing a list of dicts, each dict with fields "source"
|
211 |
and "target", representing a single demo.
|
212 |
-
5. A field named "
|
213 |
|
214 |
SystemFormat formats the above fields into a single string to be inputted to the model. This string overwrites
|
215 |
field "source" of the instance.
|
|
|
79 |
Important: formats can use '\N' notations that means new-line if no new-line before and no empty string before.
|
80 |
|
81 |
SystemFormat expects the input instance to contain:
|
82 |
+
1. A field named "system_prompt" whose value is a string (potentially empty) that delivers a task-independent opening text.
|
83 |
2. A field named "source" whose value is a string verbalizing the original values in the instance (as read
|
84 |
from the source dataset), in the context of the underlying task.
|
85 |
3. A field named "instruction" that contains a (non-None) string.
|
86 |
4. A field named with the value in arg 'demos_field', containing a list of dicts, each dict with fields "source"
|
87 |
and "target", representing a single demo.
|
88 |
+
5. A field named "target_prefix" that contains a string to prefix the target in each demo, and to end the whole generated prompt
|
89 |
|
90 |
SystemFormat formats the above fields into a single string to be inputted to the model. This string overwrites
|
91 |
field "source" of the instance. Formatting is driven by two args: 'demo_format' and 'model_input_format'.
|
|
|
200 |
|
201 |
|
202 |
class HFSystemFormat(BaseFormat):
|
203 |
+
r"""Formats the complete input for the model using the HuggingFace chat template of a given model.
|
204 |
|
205 |
HFSystemFormat expects the input instance to contain:
|
206 |
+
1. A field named "system_prompt" whose value is a string (potentially empty) that delivers a task-independent opening text.
|
207 |
2. A field named "source" whose value is a string verbalizing the original values in the instance (as read
|
208 |
from the source dataset), in the context of the underlying task.
|
209 |
3. A field named "instruction" that contains a (non-None) string.
|
210 |
4. A field named with the value in arg 'demos_field', containing a list of dicts, each dict with fields "source"
|
211 |
and "target", representing a single demo.
|
212 |
+
5. A field named "target_prefix" that contains a string to prefix the target in each demo, and to end the whole generated prompt.
|
213 |
|
214 |
SystemFormat formats the above fields into a single string to be inputted to the model. This string overwrites
|
215 |
field "source" of the instance.
|
fusion.py
CHANGED
@@ -12,13 +12,13 @@ class BaseFusion(SourceOperator):
|
|
12 |
"""BaseFusion operator that combines multiple multistreams into one.
|
13 |
|
14 |
Args:
|
15 |
-
|
16 |
each is specified along with its input, so can generate a MultiStream
|
17 |
include_splits: List of splits to include from each input MultiStream.
|
18 |
If None, all splits are included.
|
19 |
"""
|
20 |
|
21 |
-
|
22 |
include_splits: Optional[List[str]] = NonPositionalField(default=None)
|
23 |
|
24 |
@abstractmethod
|
@@ -26,18 +26,18 @@ class BaseFusion(SourceOperator):
|
|
26 |
pass
|
27 |
|
28 |
def prepare(self):
|
29 |
-
assert isoftype(self.
|
30 |
-
self.
|
31 |
)
|
32 |
-
self.
|
33 |
-
{i: self.
|
34 |
-
if isinstance(self.
|
35 |
-
else {name: origin() for name, origin in self.
|
36 |
)
|
37 |
|
38 |
def splits(self) -> List[str]:
|
39 |
splits = []
|
40 |
-
for _, origin in self.
|
41 |
for s in origin.keys():
|
42 |
if s not in splits:
|
43 |
if self.include_splits is None or s in self.include_splits:
|
@@ -59,69 +59,69 @@ class FixedFusion(BaseFusion):
|
|
59 |
"""FixedFusion operator that combines multiple multistreams into one, limiting the number of instances taken from each split of each input multistream.
|
60 |
|
61 |
Args:
|
62 |
-
|
63 |
splits: List of splits (stream_names) to include, over all input multistreams. If None, all splits are included.
|
64 |
-
|
65 |
If None, all instances of each split (that is specified in include_splits) are included in the result.
|
66 |
|
67 |
"""
|
68 |
|
69 |
-
|
70 |
|
71 |
def prepare(self):
|
72 |
super().prepare()
|
73 |
|
74 |
# flake8: noqa: C901
|
75 |
def fusion_generator(self, split) -> Generator:
|
76 |
-
for origin_name, origin in self.
|
77 |
if split not in origin:
|
78 |
continue
|
79 |
emitted_from_this_split = 0
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
|
|
94 |
|
95 |
|
96 |
class WeightedFusion(BaseFusion):
|
97 |
"""Fusion operator that combines multiple MultiStream-s.
|
98 |
|
99 |
Args:
|
100 |
-
|
101 |
weights: Dict of named weights for each origin, or a list thereof
|
102 |
max_total_examples: Total number of instances to return per returned split.
|
103 |
If None, all instances are returned
|
104 |
"""
|
105 |
|
106 |
-
|
107 |
weights: Union[Dict[str, Union[float, int]], List[Union[int, float]]] = None
|
108 |
-
|
109 |
-
ignore_origin_groups: List[str] = ["unitxt"]
|
110 |
|
111 |
def verify(self):
|
112 |
super().verify()
|
113 |
-
assert self.
|
114 |
assert self.weights is not None, "weights must be specified"
|
115 |
-
assert len(self.
|
116 |
self.weights
|
117 |
-
), "
|
118 |
-
assert isoftype(self.
|
119 |
-
self.
|
120 |
)
|
121 |
assert isoftype(self.weights, Dict[str, Union[int, float]]) or isoftype(
|
122 |
self.weights, List[Union[int, float]]
|
123 |
)
|
124 |
-
assert isinstance(self.
|
125 |
|
126 |
def prepare(self):
|
127 |
super().prepare()
|
@@ -134,12 +134,12 @@ class WeightedFusion(BaseFusion):
|
|
134 |
def fusion_generator(self, split) -> Generator:
|
135 |
iterators = {
|
136 |
named_origin: iter(origin[split])
|
137 |
-
for named_origin, origin in self.
|
138 |
}
|
139 |
total_examples = 0
|
140 |
random_generator = new_random_generator(sub_seed="weighted_fusion_" + split)
|
141 |
while (
|
142 |
-
self.
|
143 |
) and len(iterators) > 0:
|
144 |
population = list(iterators.keys())
|
145 |
origin_name = random_generator.choices(
|
@@ -150,13 +150,9 @@ class WeightedFusion(BaseFusion):
|
|
150 |
try:
|
151 |
instance = next(iterator)
|
152 |
if isinstance(origin_name, str):
|
153 |
-
if
|
154 |
-
"
|
155 |
-
|
156 |
-
):
|
157 |
-
instance["group"] = origin_name + "/" + instance["group"]
|
158 |
-
else:
|
159 |
-
instance["group"] = origin_name
|
160 |
total_examples += 1
|
161 |
yield instance
|
162 |
|
|
|
12 |
"""BaseFusion operator that combines multiple multistreams into one.
|
13 |
|
14 |
Args:
|
15 |
+
subsets: a dict of named SourceOperator objects (each to yield a MultiStream) or a list thereof,
|
16 |
each is specified along with its input, so can generate a MultiStream
|
17 |
include_splits: List of splits to include from each input MultiStream.
|
18 |
If None, all splits are included.
|
19 |
"""
|
20 |
|
21 |
+
subsets: Union[List[SourceOperator], Dict[str, SourceOperator]]
|
22 |
include_splits: Optional[List[str]] = NonPositionalField(default=None)
|
23 |
|
24 |
@abstractmethod
|
|
|
26 |
pass
|
27 |
|
28 |
def prepare(self):
|
29 |
+
assert isoftype(self.subsets, Dict[str, SourceOperator]) or isoftype(
|
30 |
+
self.subsets, List[SourceOperator]
|
31 |
)
|
32 |
+
self.named_subsets = (
|
33 |
+
{i: self.subsets[i]() for i in range(len(self.subsets))}
|
34 |
+
if isinstance(self.subsets, list)
|
35 |
+
else {name: origin() for name, origin in self.subsets.items()}
|
36 |
)
|
37 |
|
38 |
def splits(self) -> List[str]:
|
39 |
splits = []
|
40 |
+
for _, origin in self.named_subsets.items():
|
41 |
for s in origin.keys():
|
42 |
if s not in splits:
|
43 |
if self.include_splits is None or s in self.include_splits:
|
|
|
59 |
"""FixedFusion operator that combines multiple multistreams into one, limiting the number of instances taken from each split of each input multistream.
|
60 |
|
61 |
Args:
|
62 |
+
subsets: Dict of named SourceOperator objects (each to yield a MultiStream), or a list thereof
|
63 |
splits: List of splits (stream_names) to include, over all input multistreams. If None, all splits are included.
|
64 |
+
max_instances_per_subset: Number of instances to take from each input split of each input multistream.
|
65 |
If None, all instances of each split (that is specified in include_splits) are included in the result.
|
66 |
|
67 |
"""
|
68 |
|
69 |
+
max_instances_per_subset: Optional[int] = None
|
70 |
|
71 |
def prepare(self):
|
72 |
super().prepare()
|
73 |
|
74 |
# flake8: noqa: C901
|
75 |
def fusion_generator(self, split) -> Generator:
|
76 |
+
for origin_name, origin in self.named_subsets.items():
|
77 |
if split not in origin:
|
78 |
continue
|
79 |
emitted_from_this_split = 0
|
80 |
+
try:
|
81 |
+
for instance in origin[split]:
|
82 |
+
if (
|
83 |
+
self.max_instances_per_subset is not None
|
84 |
+
and emitted_from_this_split >= self.max_instances_per_subset
|
85 |
+
):
|
86 |
+
break
|
87 |
+
if isinstance(origin_name, str):
|
88 |
+
if "subset" not in instance:
|
89 |
+
instance["subset"] = []
|
90 |
+
instance["subset"].insert(0, origin_name)
|
91 |
+
emitted_from_this_split += 1
|
92 |
+
yield instance
|
93 |
+
except Exception as e:
|
94 |
+
raise RuntimeError(f"Exception in subset: {origin_name}") from e
|
95 |
|
96 |
|
97 |
class WeightedFusion(BaseFusion):
|
98 |
"""Fusion operator that combines multiple MultiStream-s.
|
99 |
|
100 |
Args:
|
101 |
+
subsets: Dict of named MultiStream objects, or a list thereof
|
102 |
weights: Dict of named weights for each origin, or a list thereof
|
103 |
max_total_examples: Total number of instances to return per returned split.
|
104 |
If None, all instances are returned
|
105 |
"""
|
106 |
|
107 |
+
subsets: Union[Dict[str, SourceOperator], List[SourceOperator]] = None
|
108 |
weights: Union[Dict[str, Union[float, int]], List[Union[int, float]]] = None
|
109 |
+
max_total_samples: int = None
|
|
|
110 |
|
111 |
def verify(self):
|
112 |
super().verify()
|
113 |
+
assert self.subsets is not None, "subsets must be specified"
|
114 |
assert self.weights is not None, "weights must be specified"
|
115 |
+
assert len(self.subsets) == len(
|
116 |
self.weights
|
117 |
+
), "subsets and weights must have the same length"
|
118 |
+
assert isoftype(self.subsets, Dict[str, SourceOperator]) or isoftype(
|
119 |
+
self.subsets, List[SourceOperator]
|
120 |
)
|
121 |
assert isoftype(self.weights, Dict[str, Union[int, float]]) or isoftype(
|
122 |
self.weights, List[Union[int, float]]
|
123 |
)
|
124 |
+
assert isinstance(self.subsets, dict) == isinstance(self.weights, dict)
|
125 |
|
126 |
def prepare(self):
|
127 |
super().prepare()
|
|
|
134 |
def fusion_generator(self, split) -> Generator:
|
135 |
iterators = {
|
136 |
named_origin: iter(origin[split])
|
137 |
+
for named_origin, origin in self.named_subsets.items()
|
138 |
}
|
139 |
total_examples = 0
|
140 |
random_generator = new_random_generator(sub_seed="weighted_fusion_" + split)
|
141 |
while (
|
142 |
+
self.max_total_samples is None or total_examples < self.max_total_samples
|
143 |
) and len(iterators) > 0:
|
144 |
population = list(iterators.keys())
|
145 |
origin_name = random_generator.choices(
|
|
|
150 |
try:
|
151 |
instance = next(iterator)
|
152 |
if isinstance(origin_name, str):
|
153 |
+
if "subset" not in instance:
|
154 |
+
instance["subset"] = []
|
155 |
+
instance["subset"].insert(0, origin_name)
|
|
|
|
|
|
|
|
|
156 |
total_examples += 1
|
157 |
yield instance
|
158 |
|
image_operators.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from typing import Any, Dict
|
3 |
+
|
4 |
+
from .dict_utils import dict_get
|
5 |
+
from .operators import InstanceFieldOperator
|
6 |
+
|
7 |
+
|
8 |
+
def extract_images(text, instance):
|
9 |
+
regex = r'<img\s+src=["\'](.*?)["\']'
|
10 |
+
image_sources = re.findall(regex, text)
|
11 |
+
images = []
|
12 |
+
for image_source in image_sources:
|
13 |
+
image = dict_get(instance, image_source)
|
14 |
+
images.append(image)
|
15 |
+
return images
|
16 |
+
|
17 |
+
|
18 |
+
class ImageToText(InstanceFieldOperator):
|
19 |
+
def process_instance_value(self, value: Any, instance: Dict[str, Any]):
|
20 |
+
if "media" not in instance:
|
21 |
+
instance["media"] = {}
|
22 |
+
if "images" not in instance["media"]:
|
23 |
+
instance["media"]["images"] = []
|
24 |
+
idx = len(instance["media"]["images"])
|
25 |
+
instance["media"]["images"].append(value)
|
26 |
+
return f'<img src="media/images/{idx}">'
|
inference.py
CHANGED
@@ -1,12 +1,14 @@
|
|
1 |
import abc
|
2 |
import os
|
|
|
3 |
from typing import Any, Dict, List, Literal, Optional, Union
|
4 |
|
5 |
from tqdm import tqdm
|
6 |
|
7 |
from .artifact import Artifact
|
8 |
-
from .dataclass import InternalField
|
9 |
from .deprecation_utils import deprecation
|
|
|
10 |
from .logging_utils import get_logger
|
11 |
from .operator import PackageRequirementsMixin
|
12 |
|
@@ -61,11 +63,20 @@ class LogProbInferenceEngine(abc.ABC, Artifact):
|
|
61 |
return self._infer_log_probs(dataset)
|
62 |
|
63 |
|
64 |
-
class
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
model_name: str
|
66 |
max_new_tokens: int
|
67 |
use_fp16: bool = True
|
68 |
-
lazy_load: bool = False
|
69 |
|
70 |
_requirements_list = {
|
71 |
"transformers": "Install huggingface package using 'pip install --upgrade transformers"
|
@@ -115,11 +126,11 @@ class HFPipelineBasedInferenceEngine(InferenceEngine, PackageRequirementsMixin):
|
|
115 |
if not self.lazy_load:
|
116 |
self._prepare_pipeline()
|
117 |
|
118 |
-
def
|
119 |
return hasattr(self, "model") and self.model is not None
|
120 |
|
121 |
def _infer(self, dataset):
|
122 |
-
if not self.
|
123 |
self._prepare_pipeline()
|
124 |
|
125 |
outputs = []
|
@@ -497,3 +508,71 @@ class WMLInferenceEngine(
|
|
497 |
prompt=dataset["source"],
|
498 |
params=self.to_dict([WMLInferenceEngineParamsMixin], keep_empty=False),
|
499 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import abc
|
2 |
import os
|
3 |
+
import re
|
4 |
from typing import Any, Dict, List, Literal, Optional, Union
|
5 |
|
6 |
from tqdm import tqdm
|
7 |
|
8 |
from .artifact import Artifact
|
9 |
+
from .dataclass import InternalField, NonPositionalField
|
10 |
from .deprecation_utils import deprecation
|
11 |
+
from .image_operators import extract_images
|
12 |
from .logging_utils import get_logger
|
13 |
from .operator import PackageRequirementsMixin
|
14 |
|
|
|
63 |
return self._infer_log_probs(dataset)
|
64 |
|
65 |
|
66 |
+
class LazyLoadMixin(Artifact):
|
67 |
+
lazy_load: bool = NonPositionalField(default=False)
|
68 |
+
|
69 |
+
@abc.abstractmethod
|
70 |
+
def _is_loaded(self):
|
71 |
+
pass
|
72 |
+
|
73 |
+
|
74 |
+
class HFPipelineBasedInferenceEngine(
|
75 |
+
InferenceEngine, PackageRequirementsMixin, LazyLoadMixin
|
76 |
+
):
|
77 |
model_name: str
|
78 |
max_new_tokens: int
|
79 |
use_fp16: bool = True
|
|
|
80 |
|
81 |
_requirements_list = {
|
82 |
"transformers": "Install huggingface package using 'pip install --upgrade transformers"
|
|
|
126 |
if not self.lazy_load:
|
127 |
self._prepare_pipeline()
|
128 |
|
129 |
+
def _is_loaded(self):
|
130 |
return hasattr(self, "model") and self.model is not None
|
131 |
|
132 |
def _infer(self, dataset):
|
133 |
+
if not self._is_loaded():
|
134 |
self._prepare_pipeline()
|
135 |
|
136 |
outputs = []
|
|
|
508 |
prompt=dataset["source"],
|
509 |
params=self.to_dict([WMLInferenceEngineParamsMixin], keep_empty=False),
|
510 |
)
|
511 |
+
|
512 |
+
|
513 |
+
class HFLlavaInferenceEngine(InferenceEngine, LazyLoadMixin):
|
514 |
+
model_name: str
|
515 |
+
max_new_tokens: int
|
516 |
+
lazy_load = True
|
517 |
+
|
518 |
+
_requirements_list = {
|
519 |
+
"transformers": "Install huggingface package using 'pip install --upgrade transformers",
|
520 |
+
"torch": "Install torch, go on PyTorch website for mode details.",
|
521 |
+
"accelerate": "pip install accelerate",
|
522 |
+
}
|
523 |
+
|
524 |
+
def _prepare_engine(self):
|
525 |
+
import torch
|
526 |
+
from transformers import AutoProcessor, LlavaForConditionalGeneration
|
527 |
+
|
528 |
+
self.device = torch.device(
|
529 |
+
"mps"
|
530 |
+
if torch.backends.mps.is_available()
|
531 |
+
else 0
|
532 |
+
if torch.cuda.is_available()
|
533 |
+
else "cpu"
|
534 |
+
)
|
535 |
+
|
536 |
+
self.model = LlavaForConditionalGeneration.from_pretrained(
|
537 |
+
self.model_name,
|
538 |
+
torch_dtype=torch.float16,
|
539 |
+
low_cpu_mem_usage=True,
|
540 |
+
).to(self.device)
|
541 |
+
|
542 |
+
self.processor = AutoProcessor.from_pretrained(self.model_name)
|
543 |
+
|
544 |
+
def prepare(self):
|
545 |
+
if not self.lazy_load:
|
546 |
+
self._prepare_engine()
|
547 |
+
|
548 |
+
def _is_loaded(self):
|
549 |
+
return hasattr(self, "model") and self.model is not None
|
550 |
+
|
551 |
+
def _infer(self, dataset):
|
552 |
+
if not self._is_loaded():
|
553 |
+
self._prepare_engine()
|
554 |
+
|
555 |
+
import torch
|
556 |
+
|
557 |
+
results = []
|
558 |
+
for instance in dataset:
|
559 |
+
text = instance["source"]
|
560 |
+
images = extract_images(text, instance)
|
561 |
+
# Regular expression to match all <img src="..."> tags
|
562 |
+
regex = r'<img\s+src=["\'](.*?)["\']\s*/?>'
|
563 |
+
model_input = re.sub(regex, "<image>", text)
|
564 |
+
if len(images) == 1:
|
565 |
+
images = images[0]
|
566 |
+
inputs = self.processor(
|
567 |
+
images=images, text=model_input, return_tensors="pt"
|
568 |
+
).to(self.device, torch.float16)
|
569 |
+
input_len = len(inputs["input_ids"][0])
|
570 |
+
output = self.model.generate(
|
571 |
+
**inputs, max_new_tokens=self.max_new_tokens, do_sample=False
|
572 |
+
)
|
573 |
+
result = self.processor.decode(
|
574 |
+
output[0][input_len:], skip_special_tokens=True
|
575 |
+
)
|
576 |
+
results.append(result)
|
577 |
+
|
578 |
+
return results
|
llm_as_judge.py
CHANGED
@@ -1,28 +1,32 @@
|
|
1 |
from typing import Any, Dict, List, Literal, Optional
|
2 |
|
3 |
-
from .api import
|
4 |
-
from .artifact import
|
5 |
-
from .
|
|
|
6 |
from .inference import InferenceEngine, OpenAiInferenceEngine
|
7 |
from .metrics import BulkInstanceMetric
|
8 |
from .operator import SequentialOperator
|
9 |
-
from .
|
|
|
10 |
from .templates import Template
|
11 |
|
|
|
|
|
12 |
|
13 |
class LLMAsJudge(BulkInstanceMetric):
|
14 |
-
"""LLM
|
15 |
|
16 |
Attributes:
|
17 |
main_score (str): The main score label used for evaluation.
|
18 |
-
task (Literal["rating.single_turn"]): The type of task the llm
|
19 |
-
format of the
|
20 |
template (Template): The template used when generating inputs for the judge llm.
|
21 |
format (Format): The format used when generating inputs for judge llm.
|
22 |
system_prompt (SystemPrompt): The system prompt used when generating inputs for judge llm.
|
23 |
strip_system_prompt_and_format_from_inputs (bool): Whether to strip the system prompt and formatting from the
|
24 |
inputs that the models that is being judges received, when they are inserted to the llm-as-judge prompt.
|
25 |
-
inference_model (InferenceEngine):
|
26 |
reduction_map (dict): A dictionary specifying the reduction method for the metric.
|
27 |
batch_size (int): The size of the bulk.
|
28 |
"""
|
@@ -34,8 +38,8 @@ class LLMAsJudge(BulkInstanceMetric):
|
|
34 |
"pairwise_comparative_rating.single_turn",
|
35 |
]
|
36 |
template: Template
|
37 |
-
|
38 |
-
|
39 |
strip_system_prompt_and_format_from_inputs: bool = True
|
40 |
inference_model: InferenceEngine
|
41 |
reduction_map: Optional[Dict[str, List[str]]] = None
|
@@ -72,7 +76,6 @@ class LLMAsJudge(BulkInstanceMetric):
|
|
72 |
{
|
73 |
"question": input_instance,
|
74 |
"answer": prediction,
|
75 |
-
"rating": 5.0, # This is a dummy value that is not used in practice
|
76 |
}
|
77 |
for input_instance, prediction, reference in zip(
|
78 |
input_instances, predictions, references
|
@@ -84,7 +87,6 @@ class LLMAsJudge(BulkInstanceMetric):
|
|
84 |
"question": input_instance,
|
85 |
"answer": prediction,
|
86 |
"reference_answer": reference[0],
|
87 |
-
"rating": 5.0, # This is a dummy value that is not used in practice
|
88 |
}
|
89 |
for input_instance, prediction, reference in zip(
|
90 |
input_instances, predictions, references
|
@@ -98,7 +100,6 @@ class LLMAsJudge(BulkInstanceMetric):
|
|
98 |
"answer_b": reference[0],
|
99 |
"model_a": "input_model",
|
100 |
"model_b": "baseline_model",
|
101 |
-
"answer_a_preference": 0, # This is a dummy value that is not used in practice,
|
102 |
}
|
103 |
for input_instance, prediction, reference in zip(
|
104 |
input_instances, predictions, references
|
@@ -110,15 +111,6 @@ class LLMAsJudge(BulkInstanceMetric):
|
|
110 |
)
|
111 |
return instances
|
112 |
|
113 |
-
@staticmethod
|
114 |
-
def _add_metadata_to_judge_instances(
|
115 |
-
instances: List[List[Any]], task_data: List[Dict]
|
116 |
-
):
|
117 |
-
for instance, data in zip(instances, task_data):
|
118 |
-
instance["data_classification_policy"] = data["metadata"][
|
119 |
-
"data_classification_policy"
|
120 |
-
]
|
121 |
-
|
122 |
def prepare(self):
|
123 |
super().prepare()
|
124 |
if self.task == "pairwise_comparative_rating.single_turn":
|
@@ -176,47 +168,38 @@ class LLMAsJudge(BulkInstanceMetric):
|
|
176 |
instances = self._get_instance_for_judge_model(
|
177 |
input_instances, predictions, references
|
178 |
)
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
if self.format:
|
192 |
-
recipe_args["format"] = self.format
|
193 |
-
recipe = Artifact.from_dict(recipe_args)
|
194 |
-
dataset = produce(instances, recipe)
|
195 |
-
verdicts = self.inference_model.infer(dataset)
|
196 |
-
meta_scores = evaluate(predictions=verdicts, data=dataset)
|
197 |
-
|
198 |
-
res_list = []
|
199 |
-
for instance, verdict in zip(meta_scores, verdicts):
|
200 |
if self.task == "pairwise_comparative_rating.single_turn":
|
201 |
is_model_b_the_baseline = (
|
202 |
instance["task_data"]["model_b"] == "baseline_model"
|
203 |
)
|
204 |
if is_model_b_the_baseline:
|
205 |
-
model_a_preference_score = instance["
|
206 |
else:
|
207 |
-
model_a_preference_score = instance["
|
208 |
|
209 |
-
|
210 |
self.main_score: model_a_preference_score,
|
211 |
-
"judge_raw_output":
|
212 |
"judge_raw_input": instance["source"],
|
213 |
}
|
214 |
else:
|
215 |
-
|
216 |
-
self.main_score: instance["
|
217 |
-
"judge_raw_output":
|
218 |
"judge_raw_input": instance["source"],
|
219 |
}
|
220 |
-
|
221 |
|
222 |
-
return
|
|
|
1 |
from typing import Any, Dict, List, Literal, Optional
|
2 |
|
3 |
+
from .api import infer
|
4 |
+
from .artifact import fetch_artifact
|
5 |
+
from .dataclass import Field
|
6 |
+
from .formats import Format, SystemFormat
|
7 |
from .inference import InferenceEngine, OpenAiInferenceEngine
|
8 |
from .metrics import BulkInstanceMetric
|
9 |
from .operator import SequentialOperator
|
10 |
+
from .settings_utils import get_settings
|
11 |
+
from .system_prompts import EmptySystemPrompt, SystemPrompt
|
12 |
from .templates import Template
|
13 |
|
14 |
+
settings = get_settings()
|
15 |
+
|
16 |
|
17 |
class LLMAsJudge(BulkInstanceMetric):
|
18 |
+
"""LLM-as-judge-based metric class for evaluating correctness.
|
19 |
|
20 |
Attributes:
|
21 |
main_score (str): The main score label used for evaluation.
|
22 |
+
task (Literal["rating.single_turn"]): The type of task the llm as judge runs. This defines the output and input
|
23 |
+
format of the judge model.
|
24 |
template (Template): The template used when generating inputs for the judge llm.
|
25 |
format (Format): The format used when generating inputs for judge llm.
|
26 |
system_prompt (SystemPrompt): The system prompt used when generating inputs for judge llm.
|
27 |
strip_system_prompt_and_format_from_inputs (bool): Whether to strip the system prompt and formatting from the
|
28 |
inputs that the models that is being judges received, when they are inserted to the llm-as-judge prompt.
|
29 |
+
inference_model (InferenceEngine): The module that creates the inference of the judge llm.
|
30 |
reduction_map (dict): A dictionary specifying the reduction method for the metric.
|
31 |
batch_size (int): The size of the bulk.
|
32 |
"""
|
|
|
38 |
"pairwise_comparative_rating.single_turn",
|
39 |
]
|
40 |
template: Template
|
41 |
+
system_prompt: SystemPrompt = Field(default_factory=EmptySystemPrompt)
|
42 |
+
format: Format = Field(default_factory=SystemFormat)
|
43 |
strip_system_prompt_and_format_from_inputs: bool = True
|
44 |
inference_model: InferenceEngine
|
45 |
reduction_map: Optional[Dict[str, List[str]]] = None
|
|
|
76 |
{
|
77 |
"question": input_instance,
|
78 |
"answer": prediction,
|
|
|
79 |
}
|
80 |
for input_instance, prediction, reference in zip(
|
81 |
input_instances, predictions, references
|
|
|
87 |
"question": input_instance,
|
88 |
"answer": prediction,
|
89 |
"reference_answer": reference[0],
|
|
|
90 |
}
|
91 |
for input_instance, prediction, reference in zip(
|
92 |
input_instances, predictions, references
|
|
|
100 |
"answer_b": reference[0],
|
101 |
"model_a": "input_model",
|
102 |
"model_b": "baseline_model",
|
|
|
103 |
}
|
104 |
for input_instance, prediction, reference in zip(
|
105 |
input_instances, predictions, references
|
|
|
111 |
)
|
112 |
return instances
|
113 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
def prepare(self):
|
115 |
super().prepare()
|
116 |
if self.task == "pairwise_comparative_rating.single_turn":
|
|
|
168 |
instances = self._get_instance_for_judge_model(
|
169 |
input_instances, predictions, references
|
170 |
)
|
171 |
+
outputs = infer(
|
172 |
+
instances,
|
173 |
+
engine=self.inference_model,
|
174 |
+
task=f"tasks.response_assessment.{self.task}",
|
175 |
+
template=self.template,
|
176 |
+
system_prompt=self.system_prompt,
|
177 |
+
format=self.format,
|
178 |
+
return_data=True,
|
179 |
+
)
|
180 |
+
|
181 |
+
results = []
|
182 |
+
for instance in outputs:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
if self.task == "pairwise_comparative_rating.single_turn":
|
184 |
is_model_b_the_baseline = (
|
185 |
instance["task_data"]["model_b"] == "baseline_model"
|
186 |
)
|
187 |
if is_model_b_the_baseline:
|
188 |
+
model_a_preference_score = instance["prediction"]
|
189 |
else:
|
190 |
+
model_a_preference_score = instance["prediction"] * -1
|
191 |
|
192 |
+
result = {
|
193 |
self.main_score: model_a_preference_score,
|
194 |
+
"judge_raw_output": instance["raw_prediction"],
|
195 |
"judge_raw_input": instance["source"],
|
196 |
}
|
197 |
else:
|
198 |
+
result = {
|
199 |
+
self.main_score: instance["prediction"],
|
200 |
+
"judge_raw_output": instance["raw_prediction"],
|
201 |
"judge_raw_input": instance["source"],
|
202 |
}
|
203 |
+
results.append(result)
|
204 |
|
205 |
+
return results
|
loaders.py
CHANGED
@@ -7,23 +7,23 @@ Unitxt is all about readily preparing of any given data source for feeding into
|
|
7 |
post-processing the model's output, preparing it for any given evaluator.
|
8 |
|
9 |
Through that journey, the data advances in the form of Unitxt Multistream, undergoing a sequential application
|
10 |
-
of various off
|
11 |
-
The journey starts by a Unitxt
|
12 |
A loader, therefore, is the first item on any Unitxt Recipe.
|
13 |
|
14 |
Unitxt catalog contains several loaders for the most popular datasource formats.
|
15 |
-
All these loaders inherit from Loader, and hence, implementing a loader to expand over a new type of datasource
|
16 |
-
|
17 |
|
18 |
Available Loaders Overview:
|
19 |
-
- :ref:`LoadHF <unitxt.loaders.LoadHF>` - Loads data from
|
20 |
- :ref:`LoadCSV <unitxt.loaders.LoadCSV>` - Imports data from CSV (Comma-Separated Values) files.
|
21 |
- :ref:`LoadFromKaggle <unitxt.loaders.LoadFromKaggle>` - Retrieves datasets from the Kaggle community site.
|
22 |
- :ref:`LoadFromIBMCloud <unitxt.loaders.LoadFromIBMCloud>` - Fetches datasets hosted on IBM Cloud.
|
23 |
- :ref:`LoadFromSklearn <unitxt.loaders.LoadFromSklearn>` - Loads datasets available through the sklearn library.
|
24 |
- :ref:`MultipleSourceLoader <unitxt.loaders.MultipleSourceLoader>` - Combines data from multiple different sources.
|
25 |
- :ref:`LoadFromDictionary <unitxt.loaders.LoadFromDictionary>` - Loads data from a user-defined Python dictionary.
|
26 |
-
- :ref:`LoadFromHFSpace <unitxt.loaders.LoadFromHFSpace>` - Downloads and loads data from
|
27 |
|
28 |
|
29 |
|
@@ -64,7 +64,7 @@ class Loader(SourceOperator):
|
|
64 |
|
65 |
A loader is the first component in the Unitxt Recipe,
|
66 |
responsible for loading data from various sources and preparing it as a MultiStream for processing.
|
67 |
-
The loader_limit an optional parameter used to control the maximum number of instances to load from the data source. It is applied for each split separately.
|
68 |
It is usually provided to the loader via the recipe (see standard.py)
|
69 |
The loader can use this value to limit the amount of data downloaded from the source
|
70 |
to reduce loading time. However, this may not always be possible, so the
|
@@ -140,13 +140,13 @@ class Loader(SourceOperator):
|
|
140 |
|
141 |
|
142 |
class LoadHF(Loader):
|
143 |
-
"""Loads datasets from the
|
144 |
|
145 |
It supports loading with or without streaming,
|
146 |
-
and can filter datasets upon loading.
|
147 |
|
148 |
Args:
|
149 |
-
path: The path or identifier of the dataset on the
|
150 |
name: An optional dataset name.
|
151 |
data_dir: Optional directory to store downloaded data.
|
152 |
split: Optional specification of which split to load.
|
@@ -652,7 +652,7 @@ class MultipleSourceLoader(Loader):
|
|
652 |
sources: A list of loaders that will be combined to form a unified dataset.
|
653 |
|
654 |
Examples:
|
655 |
-
1) Loading the train split from
|
656 |
|
657 |
.. code-block:: python
|
658 |
|
@@ -678,12 +678,12 @@ class MultipleSourceLoader(Loader):
|
|
678 |
|
679 |
def load_data(self):
|
680 |
return FixedFusion(
|
681 |
-
|
682 |
).process()
|
683 |
|
684 |
|
685 |
class LoadFromDictionary(Loader):
|
686 |
-
"""Allows loading data from dictionary of constants.
|
687 |
|
688 |
The loader can be used, for example, when debugging or working with small datasets.
|
689 |
|
@@ -733,29 +733,29 @@ class LoadFromDictionary(Loader):
|
|
733 |
|
734 |
|
735 |
class LoadFromHFSpace(LoadHF):
|
736 |
-
"""Used to load data from
|
737 |
|
738 |
Loaders firstly tries to download all files specified in the 'data_files' parameter
|
739 |
-
from the given space and then reads them as a
|
740 |
|
741 |
Args:
|
742 |
-
space_name (str): Name of the
|
743 |
data_files (str | Sequence[str] | Mapping[str, str | Sequence[str]]): Relative
|
744 |
paths to files within a given repository. If given as a mapping, paths should
|
745 |
be values, while keys should represent the type of respective files
|
746 |
(training, testing etc.).
|
747 |
-
path (str, optional): Absolute path to a directory where data should be downloaded
|
748 |
revision (str, optional): ID of a Git branch or commit to be used. By default, it is
|
749 |
set to None, thus data is downloaded from the main branch of the accessed
|
750 |
repository.
|
751 |
-
use_token (bool, optional): Whether token used for authentication when accessing
|
752 |
-
the
|
753 |
config folder.
|
754 |
token_env (str, optional): Key of an env variable which value will be used for
|
755 |
-
authentication when accessing the
|
756 |
|
757 |
Example:
|
758 |
-
Loading from
|
759 |
|
760 |
.. code-block:: python
|
761 |
|
|
|
7 |
post-processing the model's output, preparing it for any given evaluator.
|
8 |
|
9 |
Through that journey, the data advances in the form of Unitxt Multistream, undergoing a sequential application
|
10 |
+
of various off-the-shelf operators (i.e., picked from Unitxt catalog), or operators easily implemented by inheriting.
|
11 |
+
The journey starts by a Unitxt Loader bearing a Multistream from the given datasource.
|
12 |
A loader, therefore, is the first item on any Unitxt Recipe.
|
13 |
|
14 |
Unitxt catalog contains several loaders for the most popular datasource formats.
|
15 |
+
All these loaders inherit from Loader, and hence, implementing a loader to expand over a new type of datasource is
|
16 |
+
straightforward.
|
17 |
|
18 |
Available Loaders Overview:
|
19 |
+
- :ref:`LoadHF <unitxt.loaders.LoadHF>` - Loads data from HuggingFace Datasets.
|
20 |
- :ref:`LoadCSV <unitxt.loaders.LoadCSV>` - Imports data from CSV (Comma-Separated Values) files.
|
21 |
- :ref:`LoadFromKaggle <unitxt.loaders.LoadFromKaggle>` - Retrieves datasets from the Kaggle community site.
|
22 |
- :ref:`LoadFromIBMCloud <unitxt.loaders.LoadFromIBMCloud>` - Fetches datasets hosted on IBM Cloud.
|
23 |
- :ref:`LoadFromSklearn <unitxt.loaders.LoadFromSklearn>` - Loads datasets available through the sklearn library.
|
24 |
- :ref:`MultipleSourceLoader <unitxt.loaders.MultipleSourceLoader>` - Combines data from multiple different sources.
|
25 |
- :ref:`LoadFromDictionary <unitxt.loaders.LoadFromDictionary>` - Loads data from a user-defined Python dictionary.
|
26 |
+
- :ref:`LoadFromHFSpace <unitxt.loaders.LoadFromHFSpace>` - Downloads and loads data from HuggingFace Spaces.
|
27 |
|
28 |
|
29 |
|
|
|
64 |
|
65 |
A loader is the first component in the Unitxt Recipe,
|
66 |
responsible for loading data from various sources and preparing it as a MultiStream for processing.
|
67 |
+
The loader_limit is an optional parameter used to control the maximum number of instances to load from the data source. It is applied for each split separately.
|
68 |
It is usually provided to the loader via the recipe (see standard.py)
|
69 |
The loader can use this value to limit the amount of data downloaded from the source
|
70 |
to reduce loading time. However, this may not always be possible, so the
|
|
|
140 |
|
141 |
|
142 |
class LoadHF(Loader):
|
143 |
+
"""Loads datasets from the HuggingFace Hub.
|
144 |
|
145 |
It supports loading with or without streaming,
|
146 |
+
and it can filter datasets upon loading.
|
147 |
|
148 |
Args:
|
149 |
+
path: The path or identifier of the dataset on the HuggingFace Hub.
|
150 |
name: An optional dataset name.
|
151 |
data_dir: Optional directory to store downloaded data.
|
152 |
split: Optional specification of which split to load.
|
|
|
652 |
sources: A list of loaders that will be combined to form a unified dataset.
|
653 |
|
654 |
Examples:
|
655 |
+
1) Loading the train split from a HuggingFace Hub and the test set from a local file:
|
656 |
|
657 |
.. code-block:: python
|
658 |
|
|
|
678 |
|
679 |
def load_data(self):
|
680 |
return FixedFusion(
|
681 |
+
subsets=self.sources, max_instances_per_subset=self.get_limit()
|
682 |
).process()
|
683 |
|
684 |
|
685 |
class LoadFromDictionary(Loader):
|
686 |
+
"""Allows loading data from a dictionary of constants.
|
687 |
|
688 |
The loader can be used, for example, when debugging or working with small datasets.
|
689 |
|
|
|
733 |
|
734 |
|
735 |
class LoadFromHFSpace(LoadHF):
|
736 |
+
"""Used to load data from HuggingFace Spaces.
|
737 |
|
738 |
Loaders firstly tries to download all files specified in the 'data_files' parameter
|
739 |
+
from the given space and then reads them as a HuggingFace Dataset.
|
740 |
|
741 |
Args:
|
742 |
+
space_name (str): Name of the HuggingFace Space to be accessed.
|
743 |
data_files (str | Sequence[str] | Mapping[str, str | Sequence[str]]): Relative
|
744 |
paths to files within a given repository. If given as a mapping, paths should
|
745 |
be values, while keys should represent the type of respective files
|
746 |
(training, testing etc.).
|
747 |
+
path (str, optional): Absolute path to a directory where data should be downloaded.
|
748 |
revision (str, optional): ID of a Git branch or commit to be used. By default, it is
|
749 |
set to None, thus data is downloaded from the main branch of the accessed
|
750 |
repository.
|
751 |
+
use_token (bool, optional): Whether a token is used for authentication when accessing
|
752 |
+
the HuggingFace Space. If necessary, the token is read from the HuggingFace
|
753 |
config folder.
|
754 |
token_env (str, optional): Key of an env variable which value will be used for
|
755 |
+
authentication when accessing the HuggingFace Space - if necessary.
|
756 |
|
757 |
Example:
|
758 |
+
Loading from a HuggingFace Space
|
759 |
|
760 |
.. code-block:: python
|
761 |
|
metric.py
CHANGED
@@ -4,6 +4,7 @@ import evaluate
|
|
4 |
|
5 |
from .api import __file__ as _
|
6 |
from .artifact import __file__ as _
|
|
|
7 |
from .blocks import __file__ as _
|
8 |
from .card import __file__ as _
|
9 |
from .catalog import __file__ as _
|
@@ -22,6 +23,7 @@ from .fusion import __file__ as _
|
|
22 |
from .generator_utils import __file__ as _
|
23 |
from .hf_utils import __file__ as _
|
24 |
from .hf_utils import verify_versions_compatibility
|
|
|
25 |
from .inference import __file__ as _
|
26 |
from .instructions import __file__ as _
|
27 |
from .llm_as_judge import __file__ as _
|
|
|
4 |
|
5 |
from .api import __file__ as _
|
6 |
from .artifact import __file__ as _
|
7 |
+
from .benchmark import __file__ as _
|
8 |
from .blocks import __file__ as _
|
9 |
from .card import __file__ as _
|
10 |
from .catalog import __file__ as _
|
|
|
23 |
from .generator_utils import __file__ as _
|
24 |
from .hf_utils import __file__ as _
|
25 |
from .hf_utils import verify_versions_compatibility
|
26 |
+
from .image_operators import __file__ as _
|
27 |
from .inference import __file__ as _
|
28 |
from .instructions import __file__ as _
|
29 |
from .llm_as_judge import __file__ as _
|
metric_utils.py
CHANGED
@@ -1,11 +1,12 @@
|
|
1 |
import json
|
2 |
-
from
|
|
|
|
|
|
|
3 |
|
4 |
from datasets import Features, Value
|
5 |
-
from numpy import nanmean
|
6 |
|
7 |
from .dataclass import Dataclass
|
8 |
-
from .dict_utils import dict_set
|
9 |
from .operator import (
|
10 |
MultiStreamOperator,
|
11 |
SequentialOperator,
|
@@ -17,111 +18,20 @@ from .operators import (
|
|
17 |
ApplyOperatorsField,
|
18 |
Copy,
|
19 |
FlattenInstances,
|
20 |
-
|
21 |
-
RenameFields,
|
22 |
-
SplitByNestedGroup,
|
23 |
)
|
24 |
from .register import _reset_env_local_catalogs, register_all_artifacts
|
25 |
from .schema import UNITXT_DATASET_SCHEMA
|
26 |
-
from .settings_utils import get_settings
|
27 |
from .stream import DynamicStream, MultiStream
|
28 |
from .struct_data_operators import LoadJson
|
29 |
from .utils import deepcopy
|
30 |
|
|
|
31 |
|
32 |
-
class MultiStreamScoreMean(MultiStreamOperator):
|
33 |
-
"""Given a multi-stream where each stream is already scored globally, generate a nested global score for the whole multi-stream.
|
34 |
-
|
35 |
-
The whole-ms-global-score is a nested structure, specifying (also) the individual global scores of the
|
36 |
-
individual streams participating in the input multi_stream.
|
37 |
-
The instances of all these individual streams are assumed to have the "group" field indicate the stream
|
38 |
-
they belong to.
|
39 |
-
Potentially, these individual streams were produced from a SplitByNestedGroup
|
40 |
-
operator that did not use the full length of the value in field "group" of the instances, but only the
|
41 |
-
first g components thereof, indicated by argument 'number_of_fusion_generations' of operator SplitByNestedGroup.
|
42 |
-
At any rate, a distinguishing prefix of the "group" value is recorded, by operator SplitByNestedGroup, in the stream_name.
|
43 |
-
The nested structure of the whole-ms-global-score is induced by these distinguishing prefixes,
|
44 |
-
by virtue of the global score of each individual stream sitting in the nested whole-ms-global-score,
|
45 |
-
deep in that dictionary, at the leaf lead to by a path being the distinguishing prefix indicated in the stream_name.
|
46 |
-
Thus, the global score of the stream becomes a leaf (though a dict by itself) of the whole-ms-global-score.
|
47 |
-
|
48 |
-
The ancestor nodes of the above leaves, in the whole-ms-global-score, contain each (in addition to dicts
|
49 |
-
leading down to leaves) a field named "score" whose value is set to be the mean of the values
|
50 |
-
sitting in field "score" of its immediate children nodes, and a field named "score_name" whose
|
51 |
-
value is set to be "group_mean".
|
52 |
-
|
53 |
-
When the input multistream consists of one single stream, it is returned as is, mainly for backward compatibility.
|
54 |
-
"""
|
55 |
-
|
56 |
-
def update_intermediate_level_scores(self, level: dict) -> float:
|
57 |
-
if "score" in level:
|
58 |
-
return level["score"]
|
59 |
-
# the global score of the stream participating in this MultiStream
|
60 |
-
sub_scores = []
|
61 |
-
for key in level:
|
62 |
-
if isinstance(level[key], dict):
|
63 |
-
sub_scores.append(self.update_intermediate_level_scores(level[key]))
|
64 |
-
level.update({"score": nanmean(sub_scores), "score_name": "groups_mean"})
|
65 |
-
return level["score"]
|
66 |
-
|
67 |
-
def process(self, multi_stream: MultiStream) -> MultiStream:
|
68 |
-
# each stream went through Metric which is a single-stream-operator , and ended up with all
|
69 |
-
# its instance["score"]["global"] linking to the same single dict object.
|
70 |
-
# Here we first generate a new, nested version, for the whole-ms-global_score, and then update
|
71 |
-
# each stream's global score with the new version
|
72 |
-
# but if only one stream in the multistream - we return it as is
|
73 |
-
if len(multi_stream) == 1:
|
74 |
-
return multi_stream
|
75 |
-
global_score = {}
|
76 |
-
first_instances = {}
|
77 |
-
iterators = {}
|
78 |
|
79 |
-
|
80 |
-
|
81 |
-
try:
|
82 |
-
first_instances[stream_name] = next(iterators[stream_name])
|
83 |
-
except StopIteration:
|
84 |
-
continue # an empty stream, goto next stream
|
85 |
-
instance = first_instances[stream_name]
|
86 |
-
dict_set(
|
87 |
-
dic=global_score,
|
88 |
-
query=stream_name.split("~")[-1],
|
89 |
-
value=deepcopy(instance["score"]["global"]),
|
90 |
-
not_exist_ok=True,
|
91 |
-
)
|
92 |
-
|
93 |
-
self.update_intermediate_level_scores(global_score)
|
94 |
-
# update the global_score object for each stream. Recall that all instances
|
95 |
-
# in each stream link all to same python dict object
|
96 |
-
for stream_name in multi_stream.keys():
|
97 |
-
instance = first_instances[stream_name]
|
98 |
-
instance["score"]["global"].clear()
|
99 |
-
instance["score"]["global"].update(global_score)
|
100 |
-
|
101 |
-
def never_peek_twice_generator(
|
102 |
-
stream_name: str, first_instances: dict, iterators: dict
|
103 |
-
) -> Generator:
|
104 |
-
while True:
|
105 |
-
if stream_name in first_instances:
|
106 |
-
yield first_instances.pop(stream_name)
|
107 |
-
try:
|
108 |
-
yield next(iterators[stream_name])
|
109 |
-
except StopIteration:
|
110 |
-
return
|
111 |
-
|
112 |
-
return MultiStream(
|
113 |
-
{
|
114 |
-
stream_name: DynamicStream(
|
115 |
-
never_peek_twice_generator,
|
116 |
-
gen_kwargs={
|
117 |
-
"stream_name": stream_name,
|
118 |
-
"first_instances": first_instances,
|
119 |
-
"iterators": iterators,
|
120 |
-
},
|
121 |
-
)
|
122 |
-
for stream_name in multi_stream.keys()
|
123 |
-
}
|
124 |
-
)
|
125 |
|
126 |
|
127 |
class FromPredictionsAndOriginalData(StreamInitializerOperator):
|
@@ -142,11 +52,6 @@ class FromPredictionsAndOriginalData(StreamInitializerOperator):
|
|
142 |
)
|
143 |
|
144 |
|
145 |
-
# The task_data field in the schema is defined as
|
146 |
-
# Sequence({"key": Value(dtype="string"), "value": Value("string")})
|
147 |
-
# When receiving instances from this scheme, the keys and values are returned as two separate
|
148 |
-
# lists, and are converted to a dictionary.
|
149 |
-
|
150 |
_post_process_steps = SequentialOperator(
|
151 |
steps=[
|
152 |
Copy(
|
@@ -156,6 +61,7 @@ _post_process_steps = SequentialOperator(
|
|
156 |
Copy(
|
157 |
field="references",
|
158 |
to_field="raw_references",
|
|
|
159 |
),
|
160 |
Copy(
|
161 |
field="source",
|
@@ -171,11 +77,177 @@ _post_process_steps = SequentialOperator(
|
|
171 |
Copy(
|
172 |
field="references",
|
173 |
to_field="processed_references",
|
|
|
174 |
),
|
175 |
]
|
176 |
)
|
177 |
|
178 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
class PostProcessRecipe(SequentialOperatorInitializer):
|
180 |
def prepare(self):
|
181 |
register_all_artifacts()
|
@@ -185,10 +257,10 @@ class PostProcessRecipe(SequentialOperatorInitializer):
|
|
185 |
]
|
186 |
|
187 |
|
188 |
-
def
|
189 |
predictions: List[str],
|
190 |
references: Iterable,
|
191 |
-
split_name: str =
|
192 |
):
|
193 |
_reset_env_local_catalogs()
|
194 |
register_all_artifacts()
|
@@ -203,7 +275,7 @@ def _post_process(
|
|
203 |
|
204 |
class MetricRecipe(SequentialOperatorInitializer):
|
205 |
calc_confidence_intervals: bool = True
|
206 |
-
|
207 |
|
208 |
def prepare(self):
|
209 |
register_all_artifacts()
|
@@ -211,21 +283,19 @@ class MetricRecipe(SequentialOperatorInitializer):
|
|
211 |
FromPredictionsAndOriginalData(),
|
212 |
LoadJson(field="task_data"),
|
213 |
_post_process_steps,
|
214 |
-
|
215 |
-
|
216 |
-
number_of_fusion_generations=self.number_of_fusion_generations,
|
217 |
),
|
218 |
ApplyMetric(
|
219 |
"metrics",
|
220 |
calc_confidence_intervals=self.calc_confidence_intervals,
|
221 |
),
|
222 |
-
|
223 |
-
|
224 |
-
RenameFields(
|
225 |
field="raw_prediction",
|
226 |
to_field="prediction",
|
227 |
),
|
228 |
-
|
229 |
field="raw_references",
|
230 |
to_field="references",
|
231 |
),
|
|
|
1 |
import json
|
2 |
+
from collections import defaultdict
|
3 |
+
from functools import lru_cache
|
4 |
+
from statistics import mean
|
5 |
+
from typing import Any, Dict, Iterable, List, Optional
|
6 |
|
7 |
from datasets import Features, Value
|
|
|
8 |
|
9 |
from .dataclass import Dataclass
|
|
|
10 |
from .operator import (
|
11 |
MultiStreamOperator,
|
12 |
SequentialOperator,
|
|
|
18 |
ApplyOperatorsField,
|
19 |
Copy,
|
20 |
FlattenInstances,
|
21 |
+
Rename,
|
|
|
|
|
22 |
)
|
23 |
from .register import _reset_env_local_catalogs, register_all_artifacts
|
24 |
from .schema import UNITXT_DATASET_SCHEMA
|
25 |
+
from .settings_utils import get_constants, get_settings
|
26 |
from .stream import DynamicStream, MultiStream
|
27 |
from .struct_data_operators import LoadJson
|
28 |
from .utils import deepcopy
|
29 |
|
30 |
+
constants = get_constants()
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
+
def nan_mean(scores):
|
34 |
+
return mean(score for score in scores if score == score)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
|
37 |
class FromPredictionsAndOriginalData(StreamInitializerOperator):
|
|
|
52 |
)
|
53 |
|
54 |
|
|
|
|
|
|
|
|
|
|
|
55 |
_post_process_steps = SequentialOperator(
|
56 |
steps=[
|
57 |
Copy(
|
|
|
61 |
Copy(
|
62 |
field="references",
|
63 |
to_field="raw_references",
|
64 |
+
dont_apply_to_streams=[constants.inference_stream],
|
65 |
),
|
66 |
Copy(
|
67 |
field="source",
|
|
|
77 |
Copy(
|
78 |
field="references",
|
79 |
to_field="processed_references",
|
80 |
+
dont_apply_to_streams=[constants.inference_stream],
|
81 |
),
|
82 |
]
|
83 |
)
|
84 |
|
85 |
|
86 |
+
@lru_cache(maxsize=None)
|
87 |
+
def group_str(json_str):
|
88 |
+
data = json.loads(json_str)
|
89 |
+
return ",".join(f"{k}:{v}" for k, v in data.items())
|
90 |
+
|
91 |
+
|
92 |
+
class SplitSubsetsAndGroups(MultiStreamOperator):
|
93 |
+
"""Splits a MultiStream that is small - for metrics, hence: whole stream can sit in memory, split by the value of field 'group'.
|
94 |
+
|
95 |
+
Args:
|
96 |
+
number_of_fusion_generations: int
|
97 |
+
|
98 |
+
the value in field group is of the form "sourcen/sourcenminus1/..." describing the sources in which the instance sat
|
99 |
+
when these were fused, potentially several phases of fusion. the name of the most recent source sits first in this value.
|
100 |
+
(See BaseFusion and its extensions)
|
101 |
+
subsets_depth specifies the depth of the prefix by which to split the stream.
|
102 |
+
"""
|
103 |
+
|
104 |
+
subsets_field: str = "subset"
|
105 |
+
groups_field: str = "groups"
|
106 |
+
subset_depth: Optional[int] = None
|
107 |
+
|
108 |
+
def process(self, multi_stream: MultiStream) -> MultiStream:
|
109 |
+
result = defaultdict(list)
|
110 |
+
|
111 |
+
for stream_name, stream in multi_stream.items():
|
112 |
+
for i, instance in enumerate(stream):
|
113 |
+
instance["__idx__"] = i
|
114 |
+
|
115 |
+
for field in [self.subsets_field, self.groups_field]:
|
116 |
+
if field not in instance:
|
117 |
+
raise ValueError(
|
118 |
+
f"Field {field} is missing from instance {instance}"
|
119 |
+
)
|
120 |
+
|
121 |
+
subset_stream_name = (
|
122 |
+
stream_name
|
123 |
+
+ "://"
|
124 |
+
+ "/".join(instance[self.subsets_field][: self.subset_depth])
|
125 |
+
)
|
126 |
+
|
127 |
+
result[subset_stream_name].append(instance)
|
128 |
+
|
129 |
+
for group in instance[self.groups_field]:
|
130 |
+
result[subset_stream_name + "?" + group_str(group)].append(instance)
|
131 |
+
|
132 |
+
return MultiStream.from_iterables(result, copying=True)
|
133 |
+
|
134 |
+
|
135 |
+
@lru_cache(maxsize=None)
|
136 |
+
def group_str_to_key_value(group_str):
|
137 |
+
keys = []
|
138 |
+
values = []
|
139 |
+
for k_v in group_str.split(","):
|
140 |
+
k, v = k_v.split(":")
|
141 |
+
if v.isdigit():
|
142 |
+
v = int(v)
|
143 |
+
keys.append(k)
|
144 |
+
values.append(v)
|
145 |
+
|
146 |
+
if len(keys) == 1:
|
147 |
+
key = keys[0]
|
148 |
+
else:
|
149 |
+
key = tuple(keys)
|
150 |
+
|
151 |
+
if len(values) == 1:
|
152 |
+
value = values[0]
|
153 |
+
else:
|
154 |
+
value = tuple(values)
|
155 |
+
|
156 |
+
return key, value
|
157 |
+
|
158 |
+
|
159 |
+
@lru_cache(maxsize=None)
|
160 |
+
def stream_name_to_origin_subset_group(stream_name):
|
161 |
+
origin, subset_group = stream_name.split("://")
|
162 |
+
if "?" in subset_group:
|
163 |
+
subset, group = subset_group.split("?")
|
164 |
+
else:
|
165 |
+
subset, group = subset_group, None
|
166 |
+
return origin, subset, group
|
167 |
+
|
168 |
+
|
169 |
+
class JoinSubsetsAndGroups(MultiStreamOperator):
|
170 |
+
def process(self, multi_stream: MultiStream) -> MultiStream:
|
171 |
+
instances = defaultdict(dict)
|
172 |
+
global_scores = defaultdict(dict)
|
173 |
+
|
174 |
+
for stream_name, stream in multi_stream.items():
|
175 |
+
origin, subset, group = stream_name_to_origin_subset_group(stream_name)
|
176 |
+
|
177 |
+
for i, instance in enumerate(stream):
|
178 |
+
global_score = instance["score"].pop("global")
|
179 |
+
|
180 |
+
idx = instance.pop("__idx__")
|
181 |
+
if idx not in instances[origin]:
|
182 |
+
instances[origin][idx] = instance
|
183 |
+
|
184 |
+
# from here below setting the global scores from that stream
|
185 |
+
# can be done with first instance only
|
186 |
+
if i > 0:
|
187 |
+
continue
|
188 |
+
|
189 |
+
if not group and not subset:
|
190 |
+
global_scores[origin]["global"] = global_score
|
191 |
+
else:
|
192 |
+
path = []
|
193 |
+
|
194 |
+
if subset:
|
195 |
+
path += ["subsets", *subset.split("/")]
|
196 |
+
|
197 |
+
if group:
|
198 |
+
key, value = group_str_to_key_value(group)
|
199 |
+
path += ["groups", key, value]
|
200 |
+
|
201 |
+
target = global_scores[origin]
|
202 |
+
for part in path[:-1]:
|
203 |
+
if part not in target:
|
204 |
+
target[part] = {}
|
205 |
+
target = target[part]
|
206 |
+
target[path[-1]] = global_score
|
207 |
+
|
208 |
+
# the leafs always have score_name and score
|
209 |
+
def recursive_mean(dic):
|
210 |
+
if isinstance(dic, dict):
|
211 |
+
if "score" in dic and "score_name" in dic:
|
212 |
+
return dic
|
213 |
+
|
214 |
+
result = {}
|
215 |
+
all_scores = []
|
216 |
+
for k, v in dic.items():
|
217 |
+
score = recursive_mean(v)
|
218 |
+
if score is not None:
|
219 |
+
all_scores.append(score["score"])
|
220 |
+
result[k] = score
|
221 |
+
|
222 |
+
result["score"] = nan_mean(all_scores)
|
223 |
+
result["score_name"] = "subsets_mean"
|
224 |
+
|
225 |
+
if result:
|
226 |
+
return result
|
227 |
+
|
228 |
+
return None
|
229 |
+
|
230 |
+
result = {}
|
231 |
+
for stream_name, stream_instances in instances.items():
|
232 |
+
score = global_scores[stream_name]
|
233 |
+
|
234 |
+
if "subsets" in score:
|
235 |
+
score["subsets"] = recursive_mean(score["subsets"])
|
236 |
+
score["global"] = {
|
237 |
+
"score": score["subsets"]["score"],
|
238 |
+
"score_name": score["subsets"]["score_name"],
|
239 |
+
}
|
240 |
+
|
241 |
+
sorted_instances = []
|
242 |
+
for key in sorted(stream_instances.keys()):
|
243 |
+
instance = stream_instances[key]
|
244 |
+
instance["score"].update(deepcopy(score))
|
245 |
+
sorted_instances.append(instance)
|
246 |
+
result[stream_name] = sorted_instances
|
247 |
+
|
248 |
+
return MultiStream.from_iterables(result, copying=True)
|
249 |
+
|
250 |
+
|
251 |
class PostProcessRecipe(SequentialOperatorInitializer):
|
252 |
def prepare(self):
|
253 |
register_all_artifacts()
|
|
|
257 |
]
|
258 |
|
259 |
|
260 |
+
def _inference_post_process(
|
261 |
predictions: List[str],
|
262 |
references: Iterable,
|
263 |
+
split_name: str = constants.inference_stream,
|
264 |
):
|
265 |
_reset_env_local_catalogs()
|
266 |
register_all_artifacts()
|
|
|
275 |
|
276 |
class MetricRecipe(SequentialOperatorInitializer):
|
277 |
calc_confidence_intervals: bool = True
|
278 |
+
subset_depth: int = 2
|
279 |
|
280 |
def prepare(self):
|
281 |
register_all_artifacts()
|
|
|
283 |
FromPredictionsAndOriginalData(),
|
284 |
LoadJson(field="task_data"),
|
285 |
_post_process_steps,
|
286 |
+
SplitSubsetsAndGroups(
|
287 |
+
subset_depth=self.subset_depth,
|
|
|
288 |
),
|
289 |
ApplyMetric(
|
290 |
"metrics",
|
291 |
calc_confidence_intervals=self.calc_confidence_intervals,
|
292 |
),
|
293 |
+
JoinSubsetsAndGroups(),
|
294 |
+
Rename(
|
|
|
295 |
field="raw_prediction",
|
296 |
to_field="prediction",
|
297 |
),
|
298 |
+
Rename(
|
299 |
field="raw_references",
|
300 |
to_field="references",
|
301 |
),
|
metrics.py
CHANGED
@@ -21,7 +21,6 @@ from scipy.stats._warnings_errors import DegenerateDataWarning
|
|
21 |
from .artifact import Artifact, fetch_artifact
|
22 |
from .dataclass import (
|
23 |
AbstractField,
|
24 |
-
DeprecatedField,
|
25 |
InternalField,
|
26 |
NonPositionalField,
|
27 |
OptionalField,
|
@@ -1425,11 +1424,7 @@ class MetricPipeline(MultiStreamOperator, Metric):
|
|
1425 |
main_score: str = None
|
1426 |
preprocess_steps: Optional[List[StreamingOperator]] = field(default_factory=list)
|
1427 |
postprocess_steps: Optional[List[StreamingOperator]] = field(default_factory=list)
|
1428 |
-
postpreprocess_steps: Optional[List[StreamingOperator]] =
|
1429 |
-
metadata={
|
1430 |
-
"deprecation_msg": "Field 'postpreprocess_steps' is deprecated. Please use 'postprocess_steps' for the same purpose."
|
1431 |
-
}
|
1432 |
-
)
|
1433 |
metric: Metric = None
|
1434 |
|
1435 |
def disable_confidence_interval_calculation(self):
|
@@ -1446,6 +1441,9 @@ class MetricPipeline(MultiStreamOperator, Metric):
|
|
1446 |
assert isinstance(
|
1447 |
self.metric, Metric
|
1448 |
), f"'metric' is not set to a Metric class in {self.get_metric_name()} (type{self.metric})"
|
|
|
|
|
|
|
1449 |
|
1450 |
def prepare(self):
|
1451 |
super().prepare()
|
@@ -4729,3 +4727,170 @@ class MetricsEnsemble(InstanceMetric):
|
|
4729 |
|
4730 |
def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> dict:
|
4731 |
return {self.main_score: prediction}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
from .artifact import Artifact, fetch_artifact
|
22 |
from .dataclass import (
|
23 |
AbstractField,
|
|
|
24 |
InternalField,
|
25 |
NonPositionalField,
|
26 |
OptionalField,
|
|
|
1424 |
main_score: str = None
|
1425 |
preprocess_steps: Optional[List[StreamingOperator]] = field(default_factory=list)
|
1426 |
postprocess_steps: Optional[List[StreamingOperator]] = field(default_factory=list)
|
1427 |
+
postpreprocess_steps: Optional[List[StreamingOperator]] = None
|
|
|
|
|
|
|
|
|
1428 |
metric: Metric = None
|
1429 |
|
1430 |
def disable_confidence_interval_calculation(self):
|
|
|
1441 |
assert isinstance(
|
1442 |
self.metric, Metric
|
1443 |
), f"'metric' is not set to a Metric class in {self.get_metric_name()} (type{self.metric})"
|
1444 |
+
if self.postpreprocess_steps is not None:
|
1445 |
+
depr_message = "Field 'postpreprocess_steps' is deprecated. Please use 'postprocess_steps' for the same purpose."
|
1446 |
+
warnings.warn(depr_message, DeprecationWarning, stacklevel=2)
|
1447 |
|
1448 |
def prepare(self):
|
1449 |
super().prepare()
|
|
|
4727 |
|
4728 |
def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> dict:
|
4729 |
return {self.main_score: prediction}
|
4730 |
+
|
4731 |
+
|
4732 |
+
class F1Strings(InstanceMetric):
|
4733 |
+
main_score = "f1_strings"
|
4734 |
+
reduction_map = {"mean": ["f1_strings"]}
|
4735 |
+
prediction_type = str
|
4736 |
+
single_reference_per_prediction = True
|
4737 |
+
_requirements_list = {
|
4738 |
+
"spacy": "Please pip install spacy",
|
4739 |
+
}
|
4740 |
+
|
4741 |
+
def prepare(self):
|
4742 |
+
super().prepare()
|
4743 |
+
import spacy
|
4744 |
+
|
4745 |
+
try:
|
4746 |
+
self.nlp = spacy.load("en_core_web_sm")
|
4747 |
+
except OSError:
|
4748 |
+
from spacy.cli import download
|
4749 |
+
|
4750 |
+
download("en_core_web_sm")
|
4751 |
+
self.nlp = spacy.load("en_core_web_sm")
|
4752 |
+
|
4753 |
+
def compute(
|
4754 |
+
self,
|
4755 |
+
references: List[str],
|
4756 |
+
prediction: str,
|
4757 |
+
task_data: List[Dict],
|
4758 |
+
) -> dict:
|
4759 |
+
doc_ref = self.nlp(references[0])
|
4760 |
+
set_ref = Counter([token.text.lower() for token in doc_ref])
|
4761 |
+
doc_pred = self.nlp(prediction)
|
4762 |
+
set_pred = Counter([token.text.lower() for token in doc_pred])
|
4763 |
+
|
4764 |
+
true_positives = sum((set_ref & set_pred).values())
|
4765 |
+
false_positives = sum((set_ref - set_pred).values())
|
4766 |
+
false_negatives = sum((set_pred - set_ref).values())
|
4767 |
+
|
4768 |
+
if true_positives == 0:
|
4769 |
+
f1 = 0.0
|
4770 |
+
else:
|
4771 |
+
precision = true_positives / (true_positives + false_positives)
|
4772 |
+
recall = true_positives / (true_positives + false_negatives)
|
4773 |
+
if precision + recall == 0:
|
4774 |
+
f1 = 0.0
|
4775 |
+
else:
|
4776 |
+
f1 = 2 * (precision * recall) / (precision + recall)
|
4777 |
+
|
4778 |
+
return {self.main_score: [f1], "score_name": self.main_score}
|
4779 |
+
|
4780 |
+
|
4781 |
+
class RandomForestMetricsEnsemble(MetricsEnsemble):
|
4782 |
+
"""This class extends the `MetricsEnsemble` base class and leverages a pre-trained scikit-learn Random Forest classification model to combine and aggregate scores from multiple judges.
|
4783 |
+
|
4784 |
+
`load_weights` method:
|
4785 |
+
Loads model weights from dictionary representation of a random forest classifier.
|
4786 |
+
`ensemble` method:
|
4787 |
+
Decodes the RandomForestClassifier object and predict a score based on the given instance.
|
4788 |
+
"""
|
4789 |
+
|
4790 |
+
_requirements_list: List[str] = ["sklearn"]
|
4791 |
+
|
4792 |
+
def decode_tree(self, tree_dict, n_features, n_classes, n_outputs):
|
4793 |
+
from sklearn.tree._tree import Tree
|
4794 |
+
|
4795 |
+
tree_dict["nodes"] = [tuple(lst) for lst in tree_dict["nodes"]]
|
4796 |
+
|
4797 |
+
tree_dict["values"] = np.array(tree_dict["values"])
|
4798 |
+
names = [
|
4799 |
+
"left_child",
|
4800 |
+
"right_child",
|
4801 |
+
"feature",
|
4802 |
+
"threshold",
|
4803 |
+
"impurity",
|
4804 |
+
"n_node_samples",
|
4805 |
+
"weighted_n_node_samples",
|
4806 |
+
"missing_go_to_left",
|
4807 |
+
]
|
4808 |
+
tree_dict["nodes"] = np.array(
|
4809 |
+
tree_dict["nodes"],
|
4810 |
+
dtype=np.dtype({"names": names, "formats": tree_dict["nodes_dtype"]}),
|
4811 |
+
)
|
4812 |
+
|
4813 |
+
tree = Tree(n_features, np.array([n_classes], dtype=np.intp), n_outputs)
|
4814 |
+
tree.__setstate__(tree_dict)
|
4815 |
+
|
4816 |
+
return tree
|
4817 |
+
|
4818 |
+
def decode_decision_tree(self, model_dict):
|
4819 |
+
from sklearn.tree import DecisionTreeClassifier
|
4820 |
+
|
4821 |
+
decoded_model = DecisionTreeClassifier(**model_dict["params"])
|
4822 |
+
|
4823 |
+
decoded_model.n_features_in_ = model_dict["n_features_in_"]
|
4824 |
+
decoded_model.n_outputs_ = model_dict["n_outputs_"]
|
4825 |
+
decoded_model.max_features_ = model_dict["max_features_"]
|
4826 |
+
decoded_model.n_classes_ = model_dict["n_classes_"]
|
4827 |
+
decoded_model.classes_ = np.array(model_dict["classes_"])
|
4828 |
+
|
4829 |
+
tree = self.decode_tree(
|
4830 |
+
model_dict["tree_"],
|
4831 |
+
model_dict["n_features_in_"],
|
4832 |
+
model_dict["n_classes_"],
|
4833 |
+
model_dict["n_outputs_"],
|
4834 |
+
)
|
4835 |
+
decoded_model.tree_ = tree
|
4836 |
+
|
4837 |
+
return decoded_model
|
4838 |
+
|
4839 |
+
def decode_forest(self, model_dict):
|
4840 |
+
from sklearn.ensemble import RandomForestClassifier
|
4841 |
+
|
4842 |
+
model = RandomForestClassifier(**model_dict["params"])
|
4843 |
+
estimators = [
|
4844 |
+
self.decode_decision_tree(decision_tree)
|
4845 |
+
for decision_tree in model_dict["estimators_"]
|
4846 |
+
]
|
4847 |
+
model.estimators_ = np.array(estimators)
|
4848 |
+
|
4849 |
+
model.n_features_in_ = model_dict["n_features_in_"]
|
4850 |
+
model.feature_names_in_ = np.array(model_dict["feature_names_in_"])
|
4851 |
+
|
4852 |
+
model.min_samples_split = model_dict["min_samples_split"]
|
4853 |
+
model.max_depth = model_dict["max_depth"]
|
4854 |
+
model.min_samples_leaf = model_dict["min_samples_leaf"]
|
4855 |
+
model.min_weight_fraction_leaf = model_dict["min_weight_fraction_leaf"]
|
4856 |
+
model.max_features = model_dict["max_features"]
|
4857 |
+
model.classes_ = np.array(model_dict["classes_"])
|
4858 |
+
model.max_leaf_nodes = model_dict["max_leaf_nodes"]
|
4859 |
+
model.min_impurity_decrease = model_dict["min_impurity_decrease"]
|
4860 |
+
model.n_outputs_ = model_dict["n_outputs_"]
|
4861 |
+
|
4862 |
+
if isinstance(model_dict["n_classes_"], list):
|
4863 |
+
model.n_classes_ = np.array(model_dict["n_classes_"])
|
4864 |
+
else:
|
4865 |
+
model.n_classes_ = model_dict["n_classes_"]
|
4866 |
+
|
4867 |
+
if "oob_score_" in model_dict:
|
4868 |
+
model.oob_score_ = model_dict["oob_score_"]
|
4869 |
+
if "oob_decision_function_" in model_dict:
|
4870 |
+
model.oob_decision_function_ = model_dict["oob_decision_function_"]
|
4871 |
+
|
4872 |
+
return model
|
4873 |
+
|
4874 |
+
def prepare(self):
|
4875 |
+
super().prepare()
|
4876 |
+
|
4877 |
+
@staticmethod
|
4878 |
+
def load_weights(json_file):
|
4879 |
+
with open(json_file) as file:
|
4880 |
+
return json.load(file)
|
4881 |
+
|
4882 |
+
def ensemble(self, instance):
|
4883 |
+
assert (
|
4884 |
+
self.weights is not None
|
4885 |
+
), "RandomForestMetricsEnsemble must set self.weights before it can be used"
|
4886 |
+
ensemble_model = self.decode_forest(self.weights)
|
4887 |
+
|
4888 |
+
prediction_lst = []
|
4889 |
+
for i, metric in enumerate(self.metrics):
|
4890 |
+
prediction_lst.append(
|
4891 |
+
instance["score"]["instance"][
|
4892 |
+
self.get_prefix_name(i) + metric.main_score
|
4893 |
+
]
|
4894 |
+
)
|
4895 |
+
score = ensemble_model.predict([prediction_lst])
|
4896 |
+
return score.tolist()[0]
|
operator.py
CHANGED
@@ -4,9 +4,12 @@ from typing import Any, Dict, Generator, List, Optional, Union
|
|
4 |
|
5 |
from .artifact import Artifact
|
6 |
from .dataclass import InternalField, NonPositionalField
|
|
|
7 |
from .stream import DynamicStream, EmptyStreamError, MultiStream, Stream
|
8 |
from .utils import is_module_available
|
9 |
|
|
|
|
|
10 |
|
11 |
class Operator(Artifact):
|
12 |
pass
|
@@ -135,8 +138,13 @@ class MultiStreamOperator(StreamingOperator):
|
|
135 |
|
136 |
caching: bool = NonPositionalField(default=None)
|
137 |
|
138 |
-
def __call__(
|
|
|
|
|
139 |
self.before_process_multi_stream()
|
|
|
|
|
|
|
140 |
result = self._process_multi_stream(multi_stream)
|
141 |
if self.caching is not None:
|
142 |
result.set_caching(self.caching)
|
@@ -158,11 +166,11 @@ class MultiStreamOperator(StreamingOperator):
|
|
158 |
def process(self, multi_stream: MultiStream) -> MultiStream:
|
159 |
pass
|
160 |
|
161 |
-
def process_instance(self, instance, stream_name=
|
162 |
instance = self.verify_instance(instance)
|
163 |
multi_stream = MultiStream({stream_name: stream_single(instance)})
|
164 |
processed_multi_stream = self(multi_stream)
|
165 |
-
return
|
166 |
|
167 |
|
168 |
class SourceOperator(MultiStreamOperator):
|
@@ -214,6 +222,15 @@ class StreamInitializerOperator(SourceOperator):
|
|
214 |
pass
|
215 |
|
216 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
217 |
class StreamOperator(MultiStreamOperator):
|
218 |
"""A class representing a single-stream operator in the streaming system.
|
219 |
|
@@ -277,12 +294,12 @@ class StreamOperator(MultiStreamOperator):
|
|
277 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
278 |
pass
|
279 |
|
280 |
-
def process_instance(self, instance, stream_name=
|
281 |
instance = self.verify_instance(instance)
|
282 |
processed_stream = self._process_single_stream(
|
283 |
stream_single(instance), stream_name
|
284 |
)
|
285 |
-
return
|
286 |
|
287 |
|
288 |
class SingleStreamOperator(StreamOperator):
|
@@ -323,10 +340,10 @@ class PagedStreamOperator(StreamOperator):
|
|
323 |
def process(self, page: List[Dict], stream_name: Optional[str] = None) -> Generator:
|
324 |
pass
|
325 |
|
326 |
-
def process_instance(self, instance, stream_name=
|
327 |
instance = self.verify_instance(instance)
|
328 |
processed_stream = self._process_page([instance], stream_name)
|
329 |
-
return
|
330 |
|
331 |
|
332 |
class SingleStreamReducer(StreamingOperator):
|
@@ -381,7 +398,7 @@ class InstanceOperator(StreamOperator):
|
|
381 |
) -> Dict[str, Any]:
|
382 |
pass
|
383 |
|
384 |
-
def process_instance(self, instance, stream_name=
|
385 |
return self._process_instance(instance, stream_name)
|
386 |
|
387 |
|
@@ -404,30 +421,13 @@ class InstanceOperatorValidator(InstanceOperator):
|
|
404 |
except StopIteration as e:
|
405 |
raise EmptyStreamError(f"Stream '{stream_name}' is empty") from e
|
406 |
result = self._process_instance(first_instance, stream_name)
|
407 |
-
self.validate(result)
|
408 |
yield result
|
409 |
yield from (
|
410 |
self._process_instance(instance, stream_name) for instance in iterator
|
411 |
)
|
412 |
|
413 |
|
414 |
-
class BaseFieldOperator(Artifact):
|
415 |
-
"""A class representing a field operator in the streaming system.
|
416 |
-
|
417 |
-
A field operator is a type of `Artifact` that operates on a single field within an instance. It takes an instance and a field name as input, processes the field, and updates the field in the instance with the processed value.
|
418 |
-
"""
|
419 |
-
|
420 |
-
def __call__(self, data: Dict[str, Any], field: str) -> dict:
|
421 |
-
data = self.verify_instance(data)
|
422 |
-
value = self.process(data[field])
|
423 |
-
data[field] = value
|
424 |
-
return data
|
425 |
-
|
426 |
-
@abstractmethod
|
427 |
-
def process(self, value: Any) -> Any:
|
428 |
-
pass
|
429 |
-
|
430 |
-
|
431 |
class InstanceOperatorWithMultiStreamAccess(StreamingOperator):
|
432 |
"""A class representing an instance operator with global access in the streaming system.
|
433 |
|
@@ -436,7 +436,12 @@ class InstanceOperatorWithMultiStreamAccess(StreamingOperator):
|
|
436 |
In order to make this efficient and to avoid qudratic complexity, it caches the accessible streams by default.
|
437 |
"""
|
438 |
|
439 |
-
def __call__(
|
|
|
|
|
|
|
|
|
|
|
440 |
result = {}
|
441 |
|
442 |
for stream_name, stream in multi_stream.items():
|
|
|
4 |
|
5 |
from .artifact import Artifact
|
6 |
from .dataclass import InternalField, NonPositionalField
|
7 |
+
from .settings_utils import get_constants
|
8 |
from .stream import DynamicStream, EmptyStreamError, MultiStream, Stream
|
9 |
from .utils import is_module_available
|
10 |
|
11 |
+
constants = get_constants()
|
12 |
+
|
13 |
|
14 |
class Operator(Artifact):
|
15 |
pass
|
|
|
138 |
|
139 |
caching: bool = NonPositionalField(default=None)
|
140 |
|
141 |
+
def __call__(
|
142 |
+
self, multi_stream: Optional[MultiStream] = None, **instance: Dict[str, Any]
|
143 |
+
) -> Union[MultiStream, Dict[str, Any]]:
|
144 |
self.before_process_multi_stream()
|
145 |
+
if instance:
|
146 |
+
if multi_stream is not None:
|
147 |
+
return self.process_instance(instance)
|
148 |
result = self._process_multi_stream(multi_stream)
|
149 |
if self.caching is not None:
|
150 |
result.set_caching(self.caching)
|
|
|
166 |
def process(self, multi_stream: MultiStream) -> MultiStream:
|
167 |
pass
|
168 |
|
169 |
+
def process_instance(self, instance, stream_name=constants.instance_stream):
|
170 |
instance = self.verify_instance(instance)
|
171 |
multi_stream = MultiStream({stream_name: stream_single(instance)})
|
172 |
processed_multi_stream = self(multi_stream)
|
173 |
+
return instance_result(processed_multi_stream[stream_name])
|
174 |
|
175 |
|
176 |
class SourceOperator(MultiStreamOperator):
|
|
|
222 |
pass
|
223 |
|
224 |
|
225 |
+
def instance_result(result_stream):
|
226 |
+
result = list(result_stream)
|
227 |
+
if len(result) == 0:
|
228 |
+
return None
|
229 |
+
if len(result) == 1:
|
230 |
+
return result[0]
|
231 |
+
return result
|
232 |
+
|
233 |
+
|
234 |
class StreamOperator(MultiStreamOperator):
|
235 |
"""A class representing a single-stream operator in the streaming system.
|
236 |
|
|
|
294 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
295 |
pass
|
296 |
|
297 |
+
def process_instance(self, instance, stream_name=constants.instance_stream):
|
298 |
instance = self.verify_instance(instance)
|
299 |
processed_stream = self._process_single_stream(
|
300 |
stream_single(instance), stream_name
|
301 |
)
|
302 |
+
return instance_result(processed_stream)
|
303 |
|
304 |
|
305 |
class SingleStreamOperator(StreamOperator):
|
|
|
340 |
def process(self, page: List[Dict], stream_name: Optional[str] = None) -> Generator:
|
341 |
pass
|
342 |
|
343 |
+
def process_instance(self, instance, stream_name=constants.instance_stream):
|
344 |
instance = self.verify_instance(instance)
|
345 |
processed_stream = self._process_page([instance], stream_name)
|
346 |
+
return instance_result(processed_stream)
|
347 |
|
348 |
|
349 |
class SingleStreamReducer(StreamingOperator):
|
|
|
398 |
) -> Dict[str, Any]:
|
399 |
pass
|
400 |
|
401 |
+
def process_instance(self, instance, stream_name=constants.instance_stream):
|
402 |
return self._process_instance(instance, stream_name)
|
403 |
|
404 |
|
|
|
421 |
except StopIteration as e:
|
422 |
raise EmptyStreamError(f"Stream '{stream_name}' is empty") from e
|
423 |
result = self._process_instance(first_instance, stream_name)
|
424 |
+
self.validate(result, stream_name)
|
425 |
yield result
|
426 |
yield from (
|
427 |
self._process_instance(instance, stream_name) for instance in iterator
|
428 |
)
|
429 |
|
430 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
431 |
class InstanceOperatorWithMultiStreamAccess(StreamingOperator):
|
432 |
"""A class representing an instance operator with global access in the streaming system.
|
433 |
|
|
|
436 |
In order to make this efficient and to avoid qudratic complexity, it caches the accessible streams by default.
|
437 |
"""
|
438 |
|
439 |
+
def __call__(
|
440 |
+
self, multi_stream: Optional[MultiStream] = None, **instance: Dict[str, Any]
|
441 |
+
) -> MultiStream:
|
442 |
+
if instance:
|
443 |
+
raise NotImplementedError("Instance mode is not supported")
|
444 |
+
|
445 |
result = {}
|
446 |
|
447 |
for stream_name, stream in multi_stream.items():
|
operators.py
CHANGED
@@ -14,9 +14,9 @@ To enhance the functionality of Unitxt, users are encouraged to develop custom o
|
|
14 |
This can be achieved by inheriting from any of the existing operators listed below or from one of the fundamental :class:`base operators<unitxt.operator>`.
|
15 |
The primary task in any operator development is to implement the `process` function, which defines the unique manipulations the operator will perform.
|
16 |
|
17 |
-
General or
|
18 |
--------------------------------
|
19 |
-
Some operators are
|
20 |
|
21 |
- :class:`loaders<unitxt.loaders>` for accessing data from various sources.
|
22 |
- :class:`splitters<unitxt.splitters>` for fixing data splits.
|
@@ -28,12 +28,12 @@ Some operators are specielized in specific data or specific operations such as:
|
|
28 |
- :class:`span_labeling_operators<unitxt.span_labeling_operators>` for handling strings.
|
29 |
- :class:`fusion<unitxt.fusion>` for fusing and mixing datasets.
|
30 |
|
31 |
-
Other
|
32 |
|
33 |
- :class:`templates<unitxt.templates>` for verbalizing data examples.
|
34 |
- :class:`formats<unitxt.formats>` for preparing data for models.
|
35 |
|
36 |
-
The rest of this section is dedicated
|
37 |
|
38 |
General Operators List:
|
39 |
------------------------
|
@@ -42,6 +42,7 @@ General Operators List:
|
|
42 |
import copy
|
43 |
import operator
|
44 |
import uuid
|
|
|
45 |
import zipfile
|
46 |
from abc import abstractmethod
|
47 |
from collections import Counter, defaultdict
|
@@ -63,7 +64,7 @@ from typing import (
|
|
63 |
import requests
|
64 |
|
65 |
from .artifact import Artifact, fetch_artifact
|
66 |
-
from .dataclass import
|
67 |
from .deprecation_utils import deprecation
|
68 |
from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
|
69 |
from .operator import (
|
@@ -81,13 +82,14 @@ from .operator import (
|
|
81 |
StreamOperator,
|
82 |
)
|
83 |
from .random_utils import new_random_generator
|
84 |
-
from .settings_utils import get_settings
|
85 |
from .stream import DynamicStream, Stream
|
86 |
from .text_utils import nested_tuple_to_string
|
87 |
from .type_utils import isoftype
|
88 |
from .utils import deepcopy, flatten_dict
|
89 |
|
90 |
settings = get_settings()
|
|
|
91 |
|
92 |
|
93 |
class FromIterables(StreamInitializerOperator):
|
@@ -253,14 +255,15 @@ class Set(InstanceOperator):
|
|
253 |
"""
|
254 |
|
255 |
fields: Dict[str, object]
|
256 |
-
use_query: bool =
|
257 |
-
metadata={
|
258 |
-
"deprecation_msg": "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. "
|
259 |
-
"Please remove this field from your code."
|
260 |
-
}
|
261 |
-
)
|
262 |
use_deepcopy: bool = False
|
263 |
|
|
|
|
|
|
|
|
|
|
|
|
|
264 |
def process(
|
265 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
266 |
) -> Dict[str, Any]:
|
@@ -341,19 +344,37 @@ class InstanceFieldOperator(InstanceOperator):
|
|
341 |
field: Optional[str] = None
|
342 |
to_field: Optional[str] = None
|
343 |
field_to_field: Optional[Union[List[List[str]], Dict[str, str]]] = None
|
344 |
-
use_query: bool =
|
345 |
-
metadata={
|
346 |
-
"deprecation_msg": "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. "
|
347 |
-
"Please remove this field from your code."
|
348 |
-
}
|
349 |
-
)
|
350 |
process_every_value: bool = False
|
351 |
get_default: Any = None
|
352 |
not_exist_ok: bool = False
|
353 |
|
354 |
def verify(self):
|
355 |
super().verify()
|
|
|
|
|
|
|
356 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
357 |
assert (
|
358 |
self.field is not None or self.field_to_field is not None
|
359 |
), "Must supply a field to work on"
|
@@ -363,7 +384,9 @@ class InstanceFieldOperator(InstanceOperator):
|
|
363 |
assert (
|
364 |
self.field is None or self.field_to_field is None
|
365 |
), f"Can not apply operator both on {self.field} and on the from fields in the mapping {self.field_to_field}"
|
366 |
-
assert
|
|
|
|
|
367 |
assert (
|
368 |
len(self._field_to_field) > 0
|
369 |
), f"'input argument 'field_to_field' should convey at least one field to process. Got {self.field_to_field}"
|
@@ -402,31 +425,10 @@ class InstanceFieldOperator(InstanceOperator):
|
|
402 |
def process_instance_value(self, value: Any, instance: Dict[str, Any]):
|
403 |
pass
|
404 |
|
405 |
-
def prepare(self):
|
406 |
-
super().prepare()
|
407 |
-
|
408 |
-
# prepare is invoked before verify, hence must make some checks here, before the changes done here
|
409 |
-
assert (
|
410 |
-
(self.field is None) != (self.field_to_field is None)
|
411 |
-
), "Must uniquely define the field to work on, through exactly one of either 'field' or 'field_to_field'"
|
412 |
-
assert (
|
413 |
-
self.to_field is None or self.field_to_field is None
|
414 |
-
), f"Can not apply operator to create both {self.to_field} and the to fields in the mapping {self.field_to_field}"
|
415 |
-
|
416 |
-
if self.field_to_field is None:
|
417 |
-
self._field_to_field = [
|
418 |
-
(self.field, self.to_field if self.to_field is not None else self.field)
|
419 |
-
]
|
420 |
-
else:
|
421 |
-
self._field_to_field = (
|
422 |
-
list(self.field_to_field.items())
|
423 |
-
if isinstance(self.field_to_field, dict)
|
424 |
-
else self.field_to_field
|
425 |
-
)
|
426 |
-
|
427 |
def process(
|
428 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
429 |
) -> Dict[str, Any]:
|
|
|
430 |
# Need to deep copy instance, because when assigning two dictionary fields,
|
431 |
# dict_set() the target field dictionary fields.
|
432 |
# This means that if this target field was assigned to another field before,
|
@@ -474,23 +476,23 @@ class FieldOperator(InstanceFieldOperator):
|
|
474 |
pass
|
475 |
|
476 |
|
477 |
-
class
|
478 |
"""Renames fields.
|
479 |
|
480 |
Move value from one field to another, potentially, if field name contains a /, from one branch into another.
|
481 |
Remove the from field, potentially part of it in case of / in from_field.
|
482 |
|
483 |
Examples:
|
484 |
-
|
485 |
will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "c": 2}, {"a": 2, "c": 3}]
|
486 |
|
487 |
-
|
488 |
will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "c": {"d": 2}}, {"a": 2, "c": {"d": 3}}]
|
489 |
|
490 |
-
|
491 |
will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "b": {"d": 2}}, {"a": 2, "b": {"d": 3}}]
|
492 |
|
493 |
-
|
494 |
will change inputs [{"a": 1, "b": {"c": {"e": 2, "f": 20}}}] to [{"a": 1, "b": {"c": {"f": 20}, "d": 2}}]
|
495 |
|
496 |
"""
|
@@ -511,6 +513,11 @@ class RenameFields(FieldOperator):
|
|
511 |
return res
|
512 |
|
513 |
|
|
|
|
|
|
|
|
|
|
|
514 |
class AddConstant(FieldOperator):
|
515 |
"""Adds a constant, being argument 'add', to the processed value.
|
516 |
|
@@ -777,7 +784,7 @@ class Apply(InstanceOperator):
|
|
777 |
Args:
|
778 |
function (str): name of function.
|
779 |
to_field (str): the field to store the result
|
780 |
-
additional arguments are field names passed to the function
|
781 |
|
782 |
Examples:
|
783 |
Store in field "b" the uppercase string of the value in field "a"
|
@@ -846,12 +853,13 @@ class ListFieldValues(InstanceOperator):
|
|
846 |
|
847 |
fields: List[str]
|
848 |
to_field: str
|
849 |
-
use_query: bool =
|
850 |
-
|
851 |
-
|
852 |
-
|
853 |
-
|
854 |
-
|
|
|
855 |
|
856 |
def process(
|
857 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
@@ -878,12 +886,13 @@ class ZipFieldValues(InstanceOperator):
|
|
878 |
fields: List[str]
|
879 |
to_field: str
|
880 |
longest: bool = False
|
881 |
-
use_query: bool =
|
882 |
-
|
883 |
-
|
884 |
-
|
885 |
-
|
886 |
-
|
|
|
887 |
|
888 |
def process(
|
889 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
@@ -950,12 +959,13 @@ class IndexOf(InstanceOperator):
|
|
950 |
search_in: str
|
951 |
index_of: str
|
952 |
to_field: str
|
953 |
-
use_query: bool =
|
954 |
-
|
955 |
-
|
956 |
-
|
957 |
-
|
958 |
-
|
|
|
959 |
|
960 |
def process(
|
961 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
@@ -972,12 +982,13 @@ class TakeByField(InstanceOperator):
|
|
972 |
field: str
|
973 |
index: str
|
974 |
to_field: str = None
|
975 |
-
use_query: bool =
|
976 |
-
|
977 |
-
|
978 |
-
|
979 |
-
|
980 |
-
|
|
|
981 |
|
982 |
def prepare(self):
|
983 |
if self.to_field is None:
|
@@ -1060,8 +1071,12 @@ class Copy(FieldOperator):
|
|
1060 |
|
1061 |
"""
|
1062 |
|
|
|
|
|
1063 |
def process_value(self, value: Any) -> Any:
|
1064 |
-
|
|
|
|
|
1065 |
|
1066 |
|
1067 |
@deprecation(version="2.0.0", alternative=Copy)
|
@@ -1090,6 +1105,31 @@ class AddID(InstanceOperator):
|
|
1090 |
return instance
|
1091 |
|
1092 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1093 |
class CastFields(InstanceOperator):
|
1094 |
"""Casts specified fields to specified types.
|
1095 |
|
@@ -1247,7 +1287,7 @@ class ApplyOperatorsField(InstanceOperator):
|
|
1247 |
|
1248 |
# we now have a list of nanes of operators, each is equipped with process_instance method.
|
1249 |
operator = SequentialOperator(steps=operator_names)
|
1250 |
-
return operator.process_instance(instance)
|
1251 |
|
1252 |
|
1253 |
class FilterByCondition(StreamOperator):
|
@@ -1767,7 +1807,7 @@ class ApplyStreamOperatorsField(StreamOperator, ArtifactFetcherMixin):
|
|
1767 |
operator, StreamingOperator
|
1768 |
), f"Operator {operator_name} must be a StreamOperator"
|
1769 |
|
1770 |
-
stream = operator(MultiStream({
|
1771 |
|
1772 |
yield from stream
|
1773 |
|
@@ -1808,7 +1848,7 @@ class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
|
|
1808 |
# Here we keep all the fields besides the score, and restore them after the metric finishes.
|
1809 |
first_instance = stream.peek()
|
1810 |
keys_to_restore = set(first_instance.keys()).difference({"score"})
|
1811 |
-
multi_stream = MultiStream({
|
1812 |
multi_stream = CopyFields(
|
1813 |
field_to_field={k: f"{k}_orig" for k in keys_to_restore}
|
1814 |
)(multi_stream)
|
@@ -1830,7 +1870,7 @@ class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
|
|
1830 |
multi_stream = RemoveFields(fields=[f"{k}_orig" for k in keys_to_restore])(
|
1831 |
multi_stream
|
1832 |
)
|
1833 |
-
stream = multi_stream[
|
1834 |
yield from stream
|
1835 |
|
1836 |
|
|
|
14 |
This can be achieved by inheriting from any of the existing operators listed below or from one of the fundamental :class:`base operators<unitxt.operator>`.
|
15 |
The primary task in any operator development is to implement the `process` function, which defines the unique manipulations the operator will perform.
|
16 |
|
17 |
+
General or Specialized Operators
|
18 |
--------------------------------
|
19 |
+
Some operators are specialized in specific data or specific operations such as:
|
20 |
|
21 |
- :class:`loaders<unitxt.loaders>` for accessing data from various sources.
|
22 |
- :class:`splitters<unitxt.splitters>` for fixing data splits.
|
|
|
28 |
- :class:`span_labeling_operators<unitxt.span_labeling_operators>` for handling strings.
|
29 |
- :class:`fusion<unitxt.fusion>` for fusing and mixing datasets.
|
30 |
|
31 |
+
Other specialized operators are used by unitxt internally:
|
32 |
|
33 |
- :class:`templates<unitxt.templates>` for verbalizing data examples.
|
34 |
- :class:`formats<unitxt.formats>` for preparing data for models.
|
35 |
|
36 |
+
The rest of this section is dedicated to general operators.
|
37 |
|
38 |
General Operators List:
|
39 |
------------------------
|
|
|
42 |
import copy
|
43 |
import operator
|
44 |
import uuid
|
45 |
+
import warnings
|
46 |
import zipfile
|
47 |
from abc import abstractmethod
|
48 |
from collections import Counter, defaultdict
|
|
|
64 |
import requests
|
65 |
|
66 |
from .artifact import Artifact, fetch_artifact
|
67 |
+
from .dataclass import NonPositionalField, OptionalField
|
68 |
from .deprecation_utils import deprecation
|
69 |
from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
|
70 |
from .operator import (
|
|
|
82 |
StreamOperator,
|
83 |
)
|
84 |
from .random_utils import new_random_generator
|
85 |
+
from .settings_utils import get_constants, get_settings
|
86 |
from .stream import DynamicStream, Stream
|
87 |
from .text_utils import nested_tuple_to_string
|
88 |
from .type_utils import isoftype
|
89 |
from .utils import deepcopy, flatten_dict
|
90 |
|
91 |
settings = get_settings()
|
92 |
+
constants = get_constants()
|
93 |
|
94 |
|
95 |
class FromIterables(StreamInitializerOperator):
|
|
|
255 |
"""
|
256 |
|
257 |
fields: Dict[str, object]
|
258 |
+
use_query: Optional[bool] = None
|
|
|
|
|
|
|
|
|
|
|
259 |
use_deepcopy: bool = False
|
260 |
|
261 |
+
def verify(self):
|
262 |
+
super().verify()
|
263 |
+
if self.use_query is not None:
|
264 |
+
depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
|
265 |
+
warnings.warn(depr_message, DeprecationWarning, stacklevel=2)
|
266 |
+
|
267 |
def process(
|
268 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
269 |
) -> Dict[str, Any]:
|
|
|
344 |
field: Optional[str] = None
|
345 |
to_field: Optional[str] = None
|
346 |
field_to_field: Optional[Union[List[List[str]], Dict[str, str]]] = None
|
347 |
+
use_query: Optional[bool] = None
|
|
|
|
|
|
|
|
|
|
|
348 |
process_every_value: bool = False
|
349 |
get_default: Any = None
|
350 |
not_exist_ok: bool = False
|
351 |
|
352 |
def verify(self):
|
353 |
super().verify()
|
354 |
+
if self.use_query is not None:
|
355 |
+
depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
|
356 |
+
warnings.warn(depr_message, DeprecationWarning, stacklevel=2)
|
357 |
|
358 |
+
def verify_field_definition(self):
|
359 |
+
if hasattr(self, "_field_to_field") and self._field_to_field is not None:
|
360 |
+
return
|
361 |
+
assert (
|
362 |
+
(self.field is None) != (self.field_to_field is None)
|
363 |
+
), "Must uniquely define the field to work on, through exactly one of either 'field' or 'field_to_field'"
|
364 |
+
assert (
|
365 |
+
self.to_field is None or self.field_to_field is None
|
366 |
+
), f"Can not apply operator to create both {self.to_field} and the to fields in the mapping {self.field_to_field}"
|
367 |
+
|
368 |
+
if self.field_to_field is None:
|
369 |
+
self._field_to_field = [
|
370 |
+
(self.field, self.to_field if self.to_field is not None else self.field)
|
371 |
+
]
|
372 |
+
else:
|
373 |
+
self._field_to_field = (
|
374 |
+
list(self.field_to_field.items())
|
375 |
+
if isinstance(self.field_to_field, dict)
|
376 |
+
else self.field_to_field
|
377 |
+
)
|
378 |
assert (
|
379 |
self.field is not None or self.field_to_field is not None
|
380 |
), "Must supply a field to work on"
|
|
|
384 |
assert (
|
385 |
self.field is None or self.field_to_field is None
|
386 |
), f"Can not apply operator both on {self.field} and on the from fields in the mapping {self.field_to_field}"
|
387 |
+
assert (
|
388 |
+
self._field_to_field is not None
|
389 |
+
), f"the from and to fields must be defined or implied from the other inputs got: {self._field_to_field}"
|
390 |
assert (
|
391 |
len(self._field_to_field) > 0
|
392 |
), f"'input argument 'field_to_field' should convey at least one field to process. Got {self.field_to_field}"
|
|
|
425 |
def process_instance_value(self, value: Any, instance: Dict[str, Any]):
|
426 |
pass
|
427 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
428 |
def process(
|
429 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
430 |
) -> Dict[str, Any]:
|
431 |
+
self.verify_field_definition()
|
432 |
# Need to deep copy instance, because when assigning two dictionary fields,
|
433 |
# dict_set() the target field dictionary fields.
|
434 |
# This means that if this target field was assigned to another field before,
|
|
|
476 |
pass
|
477 |
|
478 |
|
479 |
+
class Rename(FieldOperator):
|
480 |
"""Renames fields.
|
481 |
|
482 |
Move value from one field to another, potentially, if field name contains a /, from one branch into another.
|
483 |
Remove the from field, potentially part of it in case of / in from_field.
|
484 |
|
485 |
Examples:
|
486 |
+
Rename(field_to_field={"b": "c"})
|
487 |
will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "c": 2}, {"a": 2, "c": 3}]
|
488 |
|
489 |
+
Rename(field_to_field={"b": "c/d"})
|
490 |
will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "c": {"d": 2}}, {"a": 2, "c": {"d": 3}}]
|
491 |
|
492 |
+
Rename(field_to_field={"b": "b/d"})
|
493 |
will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "b": {"d": 2}}, {"a": 2, "b": {"d": 3}}]
|
494 |
|
495 |
+
Rename(field_to_field={"b/c/e": "b/d"})
|
496 |
will change inputs [{"a": 1, "b": {"c": {"e": 2, "f": 20}}}] to [{"a": 1, "b": {"c": {"f": 20}, "d": 2}}]
|
497 |
|
498 |
"""
|
|
|
513 |
return res
|
514 |
|
515 |
|
516 |
+
@deprecation(version="2.0.0", alternative=Rename)
|
517 |
+
class RenameFields(Rename):
|
518 |
+
pass
|
519 |
+
|
520 |
+
|
521 |
class AddConstant(FieldOperator):
|
522 |
"""Adds a constant, being argument 'add', to the processed value.
|
523 |
|
|
|
784 |
Args:
|
785 |
function (str): name of function.
|
786 |
to_field (str): the field to store the result
|
787 |
+
any additional arguments are field names whose values will be passed directly to the function specified
|
788 |
|
789 |
Examples:
|
790 |
Store in field "b" the uppercase string of the value in field "a"
|
|
|
853 |
|
854 |
fields: List[str]
|
855 |
to_field: str
|
856 |
+
use_query: Optional[bool] = None
|
857 |
+
|
858 |
+
def verify(self):
|
859 |
+
super().verify()
|
860 |
+
if self.use_query is not None:
|
861 |
+
depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
|
862 |
+
warnings.warn(depr_message, DeprecationWarning, stacklevel=2)
|
863 |
|
864 |
def process(
|
865 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
|
|
886 |
fields: List[str]
|
887 |
to_field: str
|
888 |
longest: bool = False
|
889 |
+
use_query: Optional[bool] = None
|
890 |
+
|
891 |
+
def verify(self):
|
892 |
+
super().verify()
|
893 |
+
if self.use_query is not None:
|
894 |
+
depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
|
895 |
+
warnings.warn(depr_message, DeprecationWarning, stacklevel=2)
|
896 |
|
897 |
def process(
|
898 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
|
|
959 |
search_in: str
|
960 |
index_of: str
|
961 |
to_field: str
|
962 |
+
use_query: Optional[bool] = None
|
963 |
+
|
964 |
+
def verify(self):
|
965 |
+
super().verify()
|
966 |
+
if self.use_query is not None:
|
967 |
+
depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
|
968 |
+
warnings.warn(depr_message, DeprecationWarning, stacklevel=2)
|
969 |
|
970 |
def process(
|
971 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
|
|
982 |
field: str
|
983 |
index: str
|
984 |
to_field: str = None
|
985 |
+
use_query: Optional[bool] = None
|
986 |
+
|
987 |
+
def verify(self):
|
988 |
+
super().verify()
|
989 |
+
if self.use_query is not None:
|
990 |
+
depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
|
991 |
+
warnings.warn(depr_message, DeprecationWarning, stacklevel=2)
|
992 |
|
993 |
def prepare(self):
|
994 |
if self.to_field is None:
|
|
|
1071 |
|
1072 |
"""
|
1073 |
|
1074 |
+
use_deep_copy: bool = True
|
1075 |
+
|
1076 |
def process_value(self, value: Any) -> Any:
|
1077 |
+
if self.use_deep_copy:
|
1078 |
+
return copy.deepcopy(value)
|
1079 |
+
return value
|
1080 |
|
1081 |
|
1082 |
@deprecation(version="2.0.0", alternative=Copy)
|
|
|
1105 |
return instance
|
1106 |
|
1107 |
|
1108 |
+
class Cast(FieldOperator):
|
1109 |
+
"""Casts specified fields to specified types.
|
1110 |
+
|
1111 |
+
Args:
|
1112 |
+
default (object): A dictionary mapping field names to default values for cases of casting failure.
|
1113 |
+
process_every_value (bool): If true, all fields involved must contain lists, and each value in the list is then casted. Defaults to False.
|
1114 |
+
"""
|
1115 |
+
|
1116 |
+
to: str
|
1117 |
+
failure_default: Optional[Any] = "__UNDEFINED__"
|
1118 |
+
|
1119 |
+
def prepare(self):
|
1120 |
+
self.types = {"int": int, "float": float, "str": str, "bool": bool}
|
1121 |
+
|
1122 |
+
def process_value(self, value):
|
1123 |
+
try:
|
1124 |
+
return self.types[self.to](value)
|
1125 |
+
except ValueError as e:
|
1126 |
+
if self.failure_default == "__UNDEFINED__":
|
1127 |
+
raise ValueError(
|
1128 |
+
f'Failed to cast value {value} to type "{self.to}", and no default value is provided.'
|
1129 |
+
) from e
|
1130 |
+
return self.failure_default
|
1131 |
+
|
1132 |
+
|
1133 |
class CastFields(InstanceOperator):
|
1134 |
"""Casts specified fields to specified types.
|
1135 |
|
|
|
1287 |
|
1288 |
# we now have a list of nanes of operators, each is equipped with process_instance method.
|
1289 |
operator = SequentialOperator(steps=operator_names)
|
1290 |
+
return operator.process_instance(instance, stream_name=stream_name)
|
1291 |
|
1292 |
|
1293 |
class FilterByCondition(StreamOperator):
|
|
|
1807 |
operator, StreamingOperator
|
1808 |
), f"Operator {operator_name} must be a StreamOperator"
|
1809 |
|
1810 |
+
stream = operator(MultiStream({stream_name: stream}))[stream_name]
|
1811 |
|
1812 |
yield from stream
|
1813 |
|
|
|
1848 |
# Here we keep all the fields besides the score, and restore them after the metric finishes.
|
1849 |
first_instance = stream.peek()
|
1850 |
keys_to_restore = set(first_instance.keys()).difference({"score"})
|
1851 |
+
multi_stream = MultiStream({stream_name: stream})
|
1852 |
multi_stream = CopyFields(
|
1853 |
field_to_field={k: f"{k}_orig" for k in keys_to_restore}
|
1854 |
)(multi_stream)
|
|
|
1870 |
multi_stream = RemoveFields(fields=[f"{k}_orig" for k in keys_to_restore])(
|
1871 |
multi_stream
|
1872 |
)
|
1873 |
+
stream = multi_stream[stream_name]
|
1874 |
yield from stream
|
1875 |
|
1876 |
|
processors.py
CHANGED
@@ -1,10 +1,38 @@
|
|
1 |
import ast
|
|
|
2 |
import json
|
3 |
import re
|
4 |
from difflib import get_close_matches
|
5 |
from typing import Any, Dict
|
6 |
|
|
|
|
|
7 |
from .operators import FieldOperator, InstanceFieldOperator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
|
10 |
class ToString(FieldOperator):
|
@@ -114,11 +142,16 @@ class LowerCaseTillPunc(FieldOperator):
|
|
114 |
return non_empty_line
|
115 |
|
116 |
|
117 |
-
class
|
118 |
def process_value(self, text: Any) -> Any:
|
119 |
return text.lower()
|
120 |
|
121 |
|
|
|
|
|
|
|
|
|
|
|
122 |
class Capitalize(FieldOperator):
|
123 |
def process_value(self, text: Any) -> Any:
|
124 |
return text.capitalize()
|
@@ -212,7 +245,7 @@ class StanceToProCon(FieldOperator):
|
|
212 |
return "none"
|
213 |
|
214 |
|
215 |
-
class
|
216 |
string: str
|
217 |
|
218 |
def process_value(self, text: Any) -> Any:
|
@@ -223,6 +256,11 @@ class StringOrNotString(FieldOperator):
|
|
223 |
return text
|
224 |
|
225 |
|
|
|
|
|
|
|
|
|
|
|
226 |
class ExtractMtBenchRatingJudgment(FieldOperator):
|
227 |
def process_value(self, text: Any) -> Any:
|
228 |
match = re.search(r"\[\[([\d]+\.?[\d]*)\]\]", text)
|
|
|
1 |
import ast
|
2 |
+
import copy
|
3 |
import json
|
4 |
import re
|
5 |
from difflib import get_close_matches
|
6 |
from typing import Any, Dict
|
7 |
|
8 |
+
from .deprecation_utils import deprecation
|
9 |
+
from .operator import MultiStreamOperator
|
10 |
from .operators import FieldOperator, InstanceFieldOperator
|
11 |
+
from .settings_utils import get_constants
|
12 |
+
|
13 |
+
constants = get_constants()
|
14 |
+
|
15 |
+
|
16 |
+
class PostProcess(MultiStreamOperator):
|
17 |
+
operator: InstanceFieldOperator
|
18 |
+
process_prediction: bool = True
|
19 |
+
process_references: bool = True
|
20 |
+
|
21 |
+
def prepare(self):
|
22 |
+
super().prepare()
|
23 |
+
self.prediction_operator = copy.deepcopy(self.operator)
|
24 |
+
self.prediction_operator.field = "prediction"
|
25 |
+
self.references_operator = copy.deepcopy(self.operator)
|
26 |
+
self.references_operator.field = "references"
|
27 |
+
self.references_operator.process_every_value = True
|
28 |
+
self.references_operator.dont_apply_to_streams = [constants.inference_stream]
|
29 |
+
|
30 |
+
def process(self, multi_stream):
|
31 |
+
if self.process_prediction:
|
32 |
+
multi_stream = self.prediction_operator(multi_stream)
|
33 |
+
if self.process_references:
|
34 |
+
multi_stream = self.references_operator(multi_stream)
|
35 |
+
return multi_stream
|
36 |
|
37 |
|
38 |
class ToString(FieldOperator):
|
|
|
142 |
return non_empty_line
|
143 |
|
144 |
|
145 |
+
class Lower(FieldOperator):
|
146 |
def process_value(self, text: Any) -> Any:
|
147 |
return text.lower()
|
148 |
|
149 |
|
150 |
+
@deprecation("2.0.0", alternative=Lower)
|
151 |
+
class LowerCase(Lower):
|
152 |
+
pass
|
153 |
+
|
154 |
+
|
155 |
class Capitalize(FieldOperator):
|
156 |
def process_value(self, text: Any) -> Any:
|
157 |
return text.capitalize()
|
|
|
245 |
return "none"
|
246 |
|
247 |
|
248 |
+
class StringEquals(FieldOperator):
|
249 |
string: str
|
250 |
|
251 |
def process_value(self, text: Any) -> Any:
|
|
|
256 |
return text
|
257 |
|
258 |
|
259 |
+
@deprecation("2.0.0", alternative=StringEquals)
|
260 |
+
class StringOrNotString(StringEquals):
|
261 |
+
pass
|
262 |
+
|
263 |
+
|
264 |
class ExtractMtBenchRatingJudgment(FieldOperator):
|
265 |
def process_value(self, text: Any) -> Any:
|
266 |
match = re.search(r"\[\[([\d]+\.?[\d]*)\]\]", text)
|
schema.py
CHANGED
@@ -1,10 +1,14 @@
|
|
1 |
import json
|
2 |
-
from typing import Any, Dict, Optional
|
3 |
|
4 |
-
from datasets import Features, Sequence, Value
|
5 |
|
6 |
from .artifact import Artifact
|
|
|
7 |
from .operator import InstanceOperatorValidator
|
|
|
|
|
|
|
8 |
|
9 |
UNITXT_DATASET_SCHEMA = Features(
|
10 |
{
|
@@ -12,7 +16,24 @@ UNITXT_DATASET_SCHEMA = Features(
|
|
12 |
"target": Value("string"),
|
13 |
"references": Sequence(Value("string")),
|
14 |
"metrics": Sequence(Value("string")),
|
15 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
"postprocessors": Sequence(Value("string")),
|
17 |
"task_data": Value(dtype="string"),
|
18 |
"data_classification_policy": Sequence(Value("string")),
|
@@ -20,7 +41,14 @@ UNITXT_DATASET_SCHEMA = Features(
|
|
20 |
)
|
21 |
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
class Finalize(InstanceOperatorValidator):
|
|
|
24 |
remove_unnecessary_fields: bool = True
|
25 |
|
26 |
@staticmethod
|
@@ -29,33 +57,63 @@ class Finalize(InstanceOperatorValidator):
|
|
29 |
return artifact.to_dict()
|
30 |
return artifact.__id__
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
def process(
|
33 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
34 |
) -> Dict[str, Any]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
task_data = {
|
36 |
**instance["input_fields"],
|
37 |
-
|
38 |
-
"metadata": {
|
39 |
-
"data_classification_policy": instance["data_classification_policy"],
|
40 |
-
"template": self.artifact_to_jsonable(
|
41 |
-
instance["recipe_metadata"]["template"]
|
42 |
-
),
|
43 |
-
"num_demos": instance["recipe_metadata"]["num_demos"],
|
44 |
-
},
|
45 |
}
|
|
|
|
|
|
|
|
|
46 |
instance["task_data"] = json.dumps(task_data)
|
47 |
|
48 |
if self.remove_unnecessary_fields:
|
49 |
keys_to_delete = []
|
50 |
|
51 |
for key in instance.keys():
|
52 |
-
if key not in
|
53 |
keys_to_delete.append(key)
|
54 |
|
55 |
for key in keys_to_delete:
|
56 |
del instance[key]
|
57 |
-
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
instance["metrics"] = [
|
60 |
metric.to_json() if isinstance(metric, Artifact) else metric
|
61 |
for metric in instance["metrics"]
|
@@ -64,6 +122,7 @@ class Finalize(InstanceOperatorValidator):
|
|
64 |
processor.to_json() if isinstance(processor, Artifact) else processor
|
65 |
for processor in instance["postprocessors"]
|
66 |
]
|
|
|
67 |
return instance
|
68 |
|
69 |
def validate(self, instance: Dict[str, Any], stream_name: Optional[str] = None):
|
@@ -72,7 +131,8 @@ class Finalize(InstanceOperatorValidator):
|
|
72 |
assert isinstance(
|
73 |
instance, dict
|
74 |
), f"Instance should be a dict, got {type(instance)}"
|
|
|
75 |
assert all(
|
76 |
-
key in instance for key in
|
77 |
-
), f"Instance should have the following keys: {
|
78 |
-
|
|
|
1 |
import json
|
2 |
+
from typing import Any, Dict, List, Optional
|
3 |
|
4 |
+
from datasets import Audio, Features, Image, Sequence, Value
|
5 |
|
6 |
from .artifact import Artifact
|
7 |
+
from .dict_utils import dict_get
|
8 |
from .operator import InstanceOperatorValidator
|
9 |
+
from .settings_utils import get_constants
|
10 |
+
|
11 |
+
constants = get_constants()
|
12 |
|
13 |
UNITXT_DATASET_SCHEMA = Features(
|
14 |
{
|
|
|
16 |
"target": Value("string"),
|
17 |
"references": Sequence(Value("string")),
|
18 |
"metrics": Sequence(Value("string")),
|
19 |
+
"groups": Sequence(Value("string")),
|
20 |
+
"subset": Sequence(Value("string")),
|
21 |
+
"media": {
|
22 |
+
"images": Sequence(Image()),
|
23 |
+
"audios": Sequence(Audio()),
|
24 |
+
},
|
25 |
+
"postprocessors": Sequence(Value("string")),
|
26 |
+
"task_data": Value(dtype="string"),
|
27 |
+
"data_classification_policy": Sequence(Value("string")),
|
28 |
+
}
|
29 |
+
)
|
30 |
+
|
31 |
+
UNITXT_INFERENCE_SCHEMA = Features(
|
32 |
+
{
|
33 |
+
"source": Value("string"),
|
34 |
+
"metrics": Sequence(Value("string")),
|
35 |
+
"groups": Sequence(Value("string")),
|
36 |
+
"subset": Sequence(Value("string")),
|
37 |
"postprocessors": Sequence(Value("string")),
|
38 |
"task_data": Value(dtype="string"),
|
39 |
"data_classification_policy": Sequence(Value("string")),
|
|
|
41 |
)
|
42 |
|
43 |
|
44 |
+
def get_schema(stream_name):
|
45 |
+
if stream_name == constants.inference_stream:
|
46 |
+
return UNITXT_INFERENCE_SCHEMA
|
47 |
+
return UNITXT_DATASET_SCHEMA
|
48 |
+
|
49 |
+
|
50 |
class Finalize(InstanceOperatorValidator):
|
51 |
+
group_by: List[List[str]]
|
52 |
remove_unnecessary_fields: bool = True
|
53 |
|
54 |
@staticmethod
|
|
|
57 |
return artifact.to_dict()
|
58 |
return artifact.__id__
|
59 |
|
60 |
+
def _prepare_media(self, instance):
|
61 |
+
if "media" not in instance:
|
62 |
+
instance["media"] = {}
|
63 |
+
|
64 |
+
if "images" not in instance["media"]:
|
65 |
+
instance["media"]["images"] = []
|
66 |
+
|
67 |
+
if "audios" not in instance["media"]:
|
68 |
+
instance["media"]["audios"] = []
|
69 |
+
|
70 |
+
return instance
|
71 |
+
|
72 |
def process(
|
73 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
74 |
) -> Dict[str, Any]:
|
75 |
+
metadata = {
|
76 |
+
"data_classification_policy": instance["data_classification_policy"],
|
77 |
+
"template": self.artifact_to_jsonable(
|
78 |
+
instance["recipe_metadata"]["template"]
|
79 |
+
),
|
80 |
+
"num_demos": instance["recipe_metadata"]["num_demos"],
|
81 |
+
}
|
82 |
task_data = {
|
83 |
**instance["input_fields"],
|
84 |
+
"metadata": metadata,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
}
|
86 |
+
|
87 |
+
if stream_name != constants.inference_stream:
|
88 |
+
task_data = {**task_data, **instance["reference_fields"]}
|
89 |
+
|
90 |
instance["task_data"] = json.dumps(task_data)
|
91 |
|
92 |
if self.remove_unnecessary_fields:
|
93 |
keys_to_delete = []
|
94 |
|
95 |
for key in instance.keys():
|
96 |
+
if key not in get_schema(stream_name):
|
97 |
keys_to_delete.append(key)
|
98 |
|
99 |
for key in keys_to_delete:
|
100 |
del instance[key]
|
101 |
+
|
102 |
+
data = {**task_data, **metadata}
|
103 |
+
groups = []
|
104 |
+
for group_attributes in self.group_by:
|
105 |
+
group = {}
|
106 |
+
if isinstance(group_attributes, str):
|
107 |
+
group_attributes = [group_attributes]
|
108 |
+
for attribute in group_attributes:
|
109 |
+
group[attribute] = dict_get(data, attribute)
|
110 |
+
groups.append(json.dumps(group))
|
111 |
+
|
112 |
+
instance["groups"] = groups
|
113 |
+
instance["subset"] = []
|
114 |
+
|
115 |
+
instance = self._prepare_media(instance)
|
116 |
+
|
117 |
instance["metrics"] = [
|
118 |
metric.to_json() if isinstance(metric, Artifact) else metric
|
119 |
for metric in instance["metrics"]
|
|
|
122 |
processor.to_json() if isinstance(processor, Artifact) else processor
|
123 |
for processor in instance["postprocessors"]
|
124 |
]
|
125 |
+
|
126 |
return instance
|
127 |
|
128 |
def validate(self, instance: Dict[str, Any], stream_name: Optional[str] = None):
|
|
|
131 |
assert isinstance(
|
132 |
instance, dict
|
133 |
), f"Instance should be a dict, got {type(instance)}"
|
134 |
+
schema = get_schema(stream_name)
|
135 |
assert all(
|
136 |
+
key in instance for key in schema
|
137 |
+
), f"Instance should have the following keys: {schema}. Instance is: {instance}"
|
138 |
+
schema.encode_example(instance)
|
settings_utils.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import importlib.util
|
2 |
import os
|
|
|
3 |
|
4 |
from .version import version
|
5 |
|
@@ -87,6 +88,17 @@ class Settings:
|
|
87 |
self.environment_variable_key_name(key) for key in self._settings.keys()
|
88 |
]
|
89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
|
91 |
class Constants:
|
92 |
_instance = None
|
@@ -163,6 +175,8 @@ if Constants.is_uninitilized():
|
|
163 |
]
|
164 |
constants.codebase_url = "https://github.com/IBM/unitxt"
|
165 |
constants.website_url = "https://www.unitxt.org"
|
|
|
|
|
166 |
|
167 |
|
168 |
def get_settings():
|
|
|
1 |
import importlib.util
|
2 |
import os
|
3 |
+
from contextlib import contextmanager
|
4 |
|
5 |
from .version import version
|
6 |
|
|
|
88 |
self.environment_variable_key_name(key) for key in self._settings.keys()
|
89 |
]
|
90 |
|
91 |
+
@contextmanager
|
92 |
+
def context(self, **kwargs):
|
93 |
+
old_values = {key: self._settings.get(key, None) for key in kwargs}
|
94 |
+
try:
|
95 |
+
for key, value in kwargs.items():
|
96 |
+
self.__setattr__(key, value)
|
97 |
+
yield
|
98 |
+
finally:
|
99 |
+
for key, value in old_values.items():
|
100 |
+
self.__setattr__(key, value)
|
101 |
+
|
102 |
|
103 |
class Constants:
|
104 |
_instance = None
|
|
|
175 |
]
|
176 |
constants.codebase_url = "https://github.com/IBM/unitxt"
|
177 |
constants.website_url = "https://www.unitxt.org"
|
178 |
+
constants.inference_stream = "__INFERENCE_STREAM__"
|
179 |
+
constants.instance_stream = "__INSTANCE_STREAM__"
|
180 |
|
181 |
|
182 |
def get_settings():
|
standard.py
CHANGED
@@ -9,11 +9,14 @@ from .operator import SequentialOperator, SourceSequentialOperator, StreamingOpe
|
|
9 |
from .operators import Augmentor, NullAugmentor, Set, StreamRefiner
|
10 |
from .recipe import Recipe
|
11 |
from .schema import Finalize
|
|
|
12 |
from .splitters import ConstantSizeSample, RandomSizeSample, Sampler, SeparateSplit
|
13 |
from .stream import MultiStream
|
14 |
from .system_prompts import EmptySystemPrompt, SystemPrompt
|
|
|
15 |
from .templates import ApplyRandomTemplate, ApplySingleTemplate, Template
|
16 |
|
|
|
17 |
logger = get_logger()
|
18 |
|
19 |
|
@@ -24,7 +27,8 @@ class CreateDemosPool(SeparateSplit):
|
|
24 |
|
25 |
class BaseRecipe(Recipe, SourceSequentialOperator):
|
26 |
# Base parameters
|
27 |
-
card: TaskCard
|
|
|
28 |
template: Union[Template, List[Template]] = None
|
29 |
system_prompt: SystemPrompt = Field(default_factory=EmptySystemPrompt)
|
30 |
format: Format = Field(default_factory=SystemFormat)
|
@@ -34,6 +38,8 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
|
|
34 |
metrics: List[str] = NonPositionalField(default=None)
|
35 |
postprocessors: List[str] = NonPositionalField(default=None)
|
36 |
|
|
|
|
|
37 |
loader_limit: int = None
|
38 |
|
39 |
max_train_instances: int = None
|
@@ -68,6 +74,17 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
|
|
68 |
|
69 |
def verify(self):
|
70 |
super().verify()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
if self.use_demos:
|
72 |
if self.demos_pool_size is None or self.demos_pool_size < 1:
|
73 |
raise ValueError(
|
@@ -143,19 +160,18 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
|
|
143 |
)
|
144 |
|
145 |
def set_pipelines(self):
|
146 |
-
self.loading = SequentialOperator(
|
147 |
-
|
148 |
-
self.metadata = SequentialOperator()
|
149 |
-
self.metadata.__description__ = (
|
150 |
-
"Adding metadata (e.g. format, system prompt, template) "
|
151 |
)
|
152 |
-
self.
|
153 |
-
|
154 |
-
"Standardizing the raw dataset fields to task field definition."
|
155 |
)
|
156 |
-
self.
|
157 |
-
|
158 |
-
|
|
|
|
|
|
|
159 |
)
|
160 |
self.verbalization = SequentialOperator()
|
161 |
self.verbalization.__description__ = "Verbalizing the input to the model and gold references to the 'source', 'target' and 'references' fields."
|
@@ -197,8 +213,8 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
|
|
197 |
self._demos_pool_cache = None
|
198 |
|
199 |
def production_preprocess(self, task_instances):
|
200 |
-
ms = MultiStream.from_iterables({
|
201 |
-
return list(self.inference_instance(ms)[
|
202 |
|
203 |
def production_demos_pool(self):
|
204 |
if self.use_demos:
|
@@ -222,30 +238,34 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
|
|
222 |
self.before_process_multi_stream()
|
223 |
multi_stream = MultiStream.from_iterables(
|
224 |
{
|
225 |
-
|
226 |
self.demos_pool_name: self.production_demos_pool(),
|
227 |
}
|
228 |
)
|
229 |
multi_stream = self.inference(multi_stream)
|
230 |
-
return list(multi_stream[
|
231 |
|
232 |
-
def
|
233 |
-
|
234 |
-
# and then set it to an empty list if it is None.
|
235 |
-
if self.card.preprocess_steps is None:
|
236 |
self.card.preprocess_steps = []
|
237 |
|
238 |
-
self.
|
|
|
239 |
|
240 |
-
|
241 |
-
if self.loader_limit:
|
242 |
-
loader.loader_limit = self.loader_limit
|
243 |
-
logger.info(f"Loader line limit was set to {self.loader_limit}")
|
244 |
-
self.loading.steps.append(loader)
|
245 |
|
246 |
-
|
247 |
-
|
248 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
249 |
|
250 |
self.metadata.steps.append(
|
251 |
Set(
|
@@ -256,9 +276,10 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
|
|
256 |
)
|
257 |
)
|
258 |
|
259 |
-
self.
|
|
|
260 |
|
261 |
-
self.processing.steps.append(self.
|
262 |
|
263 |
if self.augmentor.augment_task_input:
|
264 |
self.augmentor.set_task_input_fields(self.card.task.augmentable_inputs)
|
@@ -352,7 +373,10 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
|
|
352 |
if self.metrics is not None:
|
353 |
self.finalize.steps.append(Set(fields={"metrics": self.metrics}))
|
354 |
|
355 |
-
self.finalize.steps.append(Finalize())
|
|
|
|
|
|
|
356 |
|
357 |
|
358 |
class StandardRecipeWithIndexes(BaseRecipe):
|
@@ -395,6 +419,7 @@ class StandardRecipe(StandardRecipeWithIndexes):
|
|
395 |
format (SystemFormat, optional): SystemFormat object to be used for the recipe.
|
396 |
metrics (List[str]): list of catalog metrics to use with this recipe.
|
397 |
postprocessors (List[str]): list of catalog processors to apply at post processing. (Not recommended to use from here)
|
|
|
398 |
train_refiner (StreamRefiner, optional): Train refiner to be used in the recipe.
|
399 |
max_train_instances (int, optional): Maximum training instances for the refiner.
|
400 |
validation_refiner (StreamRefiner, optional): Validation refiner to be used in the recipe.
|
|
|
9 |
from .operators import Augmentor, NullAugmentor, Set, StreamRefiner
|
10 |
from .recipe import Recipe
|
11 |
from .schema import Finalize
|
12 |
+
from .settings_utils import get_constants
|
13 |
from .splitters import ConstantSizeSample, RandomSizeSample, Sampler, SeparateSplit
|
14 |
from .stream import MultiStream
|
15 |
from .system_prompts import EmptySystemPrompt, SystemPrompt
|
16 |
+
from .task import Task
|
17 |
from .templates import ApplyRandomTemplate, ApplySingleTemplate, Template
|
18 |
|
19 |
+
constants = get_constants()
|
20 |
logger = get_logger()
|
21 |
|
22 |
|
|
|
27 |
|
28 |
class BaseRecipe(Recipe, SourceSequentialOperator):
|
29 |
# Base parameters
|
30 |
+
card: TaskCard = None
|
31 |
+
task: Task = None
|
32 |
template: Union[Template, List[Template]] = None
|
33 |
system_prompt: SystemPrompt = Field(default_factory=EmptySystemPrompt)
|
34 |
format: Format = Field(default_factory=SystemFormat)
|
|
|
38 |
metrics: List[str] = NonPositionalField(default=None)
|
39 |
postprocessors: List[str] = NonPositionalField(default=None)
|
40 |
|
41 |
+
group_by: List[Union[str, List[str]]] = []
|
42 |
+
|
43 |
loader_limit: int = None
|
44 |
|
45 |
max_train_instances: int = None
|
|
|
74 |
|
75 |
def verify(self):
|
76 |
super().verify()
|
77 |
+
|
78 |
+
if self.task is None and self.card is None:
|
79 |
+
raise ValueError("Set card or task in the recipe")
|
80 |
+
|
81 |
+
if self.card is None and (
|
82 |
+
self.num_demos > 0 or self.demos_pool_size is not None
|
83 |
+
):
|
84 |
+
raise ValueError(
|
85 |
+
"To use num_demos and demos_pool_size in recipe set a card."
|
86 |
+
)
|
87 |
+
|
88 |
if self.use_demos:
|
89 |
if self.demos_pool_size is None or self.demos_pool_size < 1:
|
90 |
raise ValueError(
|
|
|
160 |
)
|
161 |
|
162 |
def set_pipelines(self):
|
163 |
+
self.loading = SequentialOperator(
|
164 |
+
__description__="Loading the data from the data source."
|
|
|
|
|
|
|
165 |
)
|
166 |
+
self.metadata = SequentialOperator(
|
167 |
+
__description__="Adding metadata (e.g. format, system prompt, template) "
|
|
|
168 |
)
|
169 |
+
self.standardization = SequentialOperator(
|
170 |
+
__description__="Standardizing the raw dataset fields to task field definition."
|
171 |
+
)
|
172 |
+
|
173 |
+
self.processing = SequentialOperator(
|
174 |
+
__description__="Setting task fields (and selecting demos per sample if needed)."
|
175 |
)
|
176 |
self.verbalization = SequentialOperator()
|
177 |
self.verbalization.__description__ = "Verbalizing the input to the model and gold references to the 'source', 'target' and 'references' fields."
|
|
|
213 |
self._demos_pool_cache = None
|
214 |
|
215 |
def production_preprocess(self, task_instances):
|
216 |
+
ms = MultiStream.from_iterables({constants.inference_stream: task_instances})
|
217 |
+
return list(self.inference_instance(ms)[constants.inference_stream])
|
218 |
|
219 |
def production_demos_pool(self):
|
220 |
if self.use_demos:
|
|
|
238 |
self.before_process_multi_stream()
|
239 |
multi_stream = MultiStream.from_iterables(
|
240 |
{
|
241 |
+
constants.inference_stream: self.production_preprocess(task_instances),
|
242 |
self.demos_pool_name: self.production_demos_pool(),
|
243 |
}
|
244 |
)
|
245 |
multi_stream = self.inference(multi_stream)
|
246 |
+
return list(multi_stream[constants.inference_stream])
|
247 |
|
248 |
+
def reset_pipeline(self):
|
249 |
+
if self.card and self.card.preprocess_steps is None:
|
|
|
|
|
250 |
self.card.preprocess_steps = []
|
251 |
|
252 |
+
if self.task is None:
|
253 |
+
self.task = self.card.task
|
254 |
|
255 |
+
self.set_pipelines()
|
|
|
|
|
|
|
|
|
256 |
|
257 |
+
if self.card is not None:
|
258 |
+
loader = self.card.loader
|
259 |
+
if self.loader_limit:
|
260 |
+
loader.loader_limit = self.loader_limit
|
261 |
+
logger.info(f"Loader line limit was set to {self.loader_limit}")
|
262 |
+
self.loading.steps.append(loader)
|
263 |
+
|
264 |
+
# This is required in case loader_limit is not enforced by the loader
|
265 |
+
if self.loader_limit:
|
266 |
+
self.loading.steps.append(
|
267 |
+
StreamRefiner(max_instances=self.loader_limit)
|
268 |
+
)
|
269 |
|
270 |
self.metadata.steps.append(
|
271 |
Set(
|
|
|
276 |
)
|
277 |
)
|
278 |
|
279 |
+
if self.card:
|
280 |
+
self.standardization.steps.extend(self.card.preprocess_steps)
|
281 |
|
282 |
+
self.processing.steps.append(self.task)
|
283 |
|
284 |
if self.augmentor.augment_task_input:
|
285 |
self.augmentor.set_task_input_fields(self.card.task.augmentable_inputs)
|
|
|
373 |
if self.metrics is not None:
|
374 |
self.finalize.steps.append(Set(fields={"metrics": self.metrics}))
|
375 |
|
376 |
+
self.finalize.steps.append(Finalize(group_by=self.group_by))
|
377 |
+
|
378 |
+
def prepare(self):
|
379 |
+
self.reset_pipeline()
|
380 |
|
381 |
|
382 |
class StandardRecipeWithIndexes(BaseRecipe):
|
|
|
419 |
format (SystemFormat, optional): SystemFormat object to be used for the recipe.
|
420 |
metrics (List[str]): list of catalog metrics to use with this recipe.
|
421 |
postprocessors (List[str]): list of catalog processors to apply at post processing. (Not recommended to use from here)
|
422 |
+
group_by (List[Union[str, List[str]]]): list of task_data or metadata keys to group global scores by.
|
423 |
train_refiner (StreamRefiner, optional): Train refiner to be used in the recipe.
|
424 |
max_train_instances (int, optional): Maximum training instances for the refiner.
|
425 |
validation_refiner (StreamRefiner, optional): Validation refiner to be used in the recipe.
|
stream.py
CHANGED
@@ -222,7 +222,9 @@ class MultiStream(dict):
|
|
222 |
for stream in self.values():
|
223 |
stream.set_copying(copying)
|
224 |
|
225 |
-
def to_dataset(
|
|
|
|
|
226 |
with tempfile.TemporaryDirectory() as dir_to_be_deleted:
|
227 |
cache_dir = dir_to_be_deleted if disable_cache else cache_dir
|
228 |
return DatasetDict(
|
@@ -232,6 +234,7 @@ class MultiStream(dict):
|
|
232 |
keep_in_memory=disable_cache,
|
233 |
cache_dir=cache_dir,
|
234 |
gen_kwargs={"key": key},
|
|
|
235 |
)
|
236 |
for key in self.keys()
|
237 |
}
|
@@ -281,7 +284,10 @@ class MultiStream(dict):
|
|
281 |
|
282 |
@classmethod
|
283 |
def from_iterables(
|
284 |
-
cls,
|
|
|
|
|
|
|
285 |
):
|
286 |
"""Creates a MultiStream from a dictionary of iterables.
|
287 |
|
|
|
222 |
for stream in self.values():
|
223 |
stream.set_copying(copying)
|
224 |
|
225 |
+
def to_dataset(
|
226 |
+
self, disable_cache=True, cache_dir=None, features=None
|
227 |
+
) -> DatasetDict:
|
228 |
with tempfile.TemporaryDirectory() as dir_to_be_deleted:
|
229 |
cache_dir = dir_to_be_deleted if disable_cache else cache_dir
|
230 |
return DatasetDict(
|
|
|
234 |
keep_in_memory=disable_cache,
|
235 |
cache_dir=cache_dir,
|
236 |
gen_kwargs={"key": key},
|
237 |
+
features=features,
|
238 |
)
|
239 |
for key in self.keys()
|
240 |
}
|
|
|
284 |
|
285 |
@classmethod
|
286 |
def from_iterables(
|
287 |
+
cls,
|
288 |
+
iterables: Dict[str, Iterable[Dict[str, Any]]],
|
289 |
+
caching=False,
|
290 |
+
copying=False,
|
291 |
):
|
292 |
"""Creates a MultiStream from a dictionary of iterables.
|
293 |
|
struct_data_operators.py
CHANGED
@@ -623,3 +623,40 @@ class MapTableListsToStdTableJSON(FieldOperator):
|
|
623 |
|
624 |
def map_tablelists_to_stdtablejson_util(self, table_content: str) -> Dict:
|
625 |
return {"header": table_content[0], "rows": table_content[1:]}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
623 |
|
624 |
def map_tablelists_to_stdtablejson_util(self, table_content: str) -> Dict:
|
625 |
return {"header": table_content[0], "rows": table_content[1:]}
|
626 |
+
|
627 |
+
|
628 |
+
class ConstructTableFromRowsCols(InstanceOperator):
|
629 |
+
"""Maps column and row field into single table field encompassing both header and rows.
|
630 |
+
|
631 |
+
field[0] = header string as List
|
632 |
+
field[1] = rows string as List[List]
|
633 |
+
field[2] = table caption string(optional)
|
634 |
+
"""
|
635 |
+
|
636 |
+
fields: List[str]
|
637 |
+
to_field: str
|
638 |
+
|
639 |
+
def process(
|
640 |
+
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
641 |
+
) -> Dict[str, Any]:
|
642 |
+
header = dict_get(instance, self.fields[0])
|
643 |
+
rows = dict_get(instance, self.fields[1])
|
644 |
+
|
645 |
+
if len(self.fields) >= 3:
|
646 |
+
caption = instance[self.fields[2]]
|
647 |
+
else:
|
648 |
+
caption = None
|
649 |
+
|
650 |
+
import ast
|
651 |
+
|
652 |
+
header_processed = ast.literal_eval(header)
|
653 |
+
rows_processed = ast.literal_eval(rows)
|
654 |
+
|
655 |
+
output_dict = {"header": header_processed, "rows": rows_processed}
|
656 |
+
|
657 |
+
if caption is not None:
|
658 |
+
output_dict["caption"] = caption
|
659 |
+
|
660 |
+
instance[self.to_field] = output_dict
|
661 |
+
|
662 |
+
return instance
|
task.py
CHANGED
@@ -1,11 +1,13 @@
|
|
|
|
1 |
from functools import lru_cache
|
2 |
from typing import Any, Dict, List, Optional, Union
|
3 |
|
4 |
from .artifact import fetch_artifact
|
5 |
-
from .dataclass import DeprecatedField
|
6 |
from .deprecation_utils import deprecation
|
7 |
from .error_utils import Documentation, UnitxtError, UnitxtWarning
|
|
|
8 |
from .operator import InstanceOperator
|
|
|
9 |
from .type_utils import (
|
10 |
Type,
|
11 |
get_args,
|
@@ -19,6 +21,9 @@ from .type_utils import (
|
|
19 |
verify_required_schema,
|
20 |
)
|
21 |
|
|
|
|
|
|
|
22 |
|
23 |
@deprecation(
|
24 |
version="2.0.0",
|
@@ -57,18 +62,8 @@ class Task(InstanceOperator):
|
|
57 |
|
58 |
input_fields: Optional[Union[Dict[str, Type], Dict[str, str], List[str]]] = None
|
59 |
reference_fields: Optional[Union[Dict[str, Type], Dict[str, str], List[str]]] = None
|
60 |
-
inputs: Union[Dict[str, Type], Dict[str, str], List[str]] =
|
61 |
-
|
62 |
-
metadata={
|
63 |
-
"deprecation_msg": "The 'inputs' field is deprecated. Please use 'input_fields' instead."
|
64 |
-
},
|
65 |
-
)
|
66 |
-
outputs: Union[Dict[str, Type], Dict[str, str], List[str]] = DeprecatedField(
|
67 |
-
default=None,
|
68 |
-
metadata={
|
69 |
-
"deprecation_msg": "The 'outputs' field is deprecated. Please use 'reference_fields' instead."
|
70 |
-
},
|
71 |
-
)
|
72 |
metrics: List[str]
|
73 |
prediction_type: Optional[Union[Type, str]] = None
|
74 |
augmentable_inputs: List[str] = []
|
@@ -108,6 +103,16 @@ class Task(InstanceOperator):
|
|
108 |
)
|
109 |
|
110 |
def verify(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
if self.input_fields is None:
|
112 |
raise UnitxtError(
|
113 |
"Missing attribute in task: 'input_fields' not set.",
|
@@ -249,19 +254,26 @@ class Task(InstanceOperator):
|
|
249 |
instance = self.set_default_values(instance)
|
250 |
|
251 |
verify_required_schema(self.input_fields, instance)
|
252 |
-
verify_required_schema(self.reference_fields, instance)
|
253 |
-
|
254 |
input_fields = {key: instance[key] for key in self.input_fields.keys()}
|
255 |
-
reference_fields = {key: instance[key] for key in self.reference_fields.keys()}
|
256 |
data_classification_policy = instance.get("data_classification_policy", [])
|
257 |
|
258 |
-
|
259 |
"input_fields": input_fields,
|
260 |
-
"reference_fields": reference_fields,
|
261 |
"metrics": self.metrics,
|
262 |
"data_classification_policy": data_classification_policy,
|
|
|
263 |
}
|
264 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
265 |
|
266 |
@deprecation(version="2.0.0", alternative=Task)
|
267 |
class FormTask(Task):
|
|
|
1 |
+
import warnings
|
2 |
from functools import lru_cache
|
3 |
from typing import Any, Dict, List, Optional, Union
|
4 |
|
5 |
from .artifact import fetch_artifact
|
|
|
6 |
from .deprecation_utils import deprecation
|
7 |
from .error_utils import Documentation, UnitxtError, UnitxtWarning
|
8 |
+
from .logging_utils import get_logger
|
9 |
from .operator import InstanceOperator
|
10 |
+
from .settings_utils import get_constants
|
11 |
from .type_utils import (
|
12 |
Type,
|
13 |
get_args,
|
|
|
21 |
verify_required_schema,
|
22 |
)
|
23 |
|
24 |
+
constants = get_constants()
|
25 |
+
logger = get_logger()
|
26 |
+
|
27 |
|
28 |
@deprecation(
|
29 |
version="2.0.0",
|
|
|
62 |
|
63 |
input_fields: Optional[Union[Dict[str, Type], Dict[str, str], List[str]]] = None
|
64 |
reference_fields: Optional[Union[Dict[str, Type], Dict[str, str], List[str]]] = None
|
65 |
+
inputs: Optional[Union[Dict[str, Type], Dict[str, str], List[str]]] = None
|
66 |
+
outputs: Optional[Union[Dict[str, Type], Dict[str, str], List[str]]] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
metrics: List[str]
|
68 |
prediction_type: Optional[Union[Type, str]] = None
|
69 |
augmentable_inputs: List[str] = []
|
|
|
103 |
)
|
104 |
|
105 |
def verify(self):
|
106 |
+
if hasattr(self, "inputs") and self.inputs is not None:
|
107 |
+
depr_message = (
|
108 |
+
"The 'inputs' field is deprecated. Please use 'input_fields' instead."
|
109 |
+
)
|
110 |
+
warnings.warn(depr_message, DeprecationWarning, stacklevel=2)
|
111 |
+
|
112 |
+
if hasattr(self, "outputs") and self.outputs is not None:
|
113 |
+
depr_message = "The 'outputs' field is deprecated. Please use 'reference_fields' instead."
|
114 |
+
warnings.warn(depr_message, DeprecationWarning, stacklevel=2)
|
115 |
+
|
116 |
if self.input_fields is None:
|
117 |
raise UnitxtError(
|
118 |
"Missing attribute in task: 'input_fields' not set.",
|
|
|
254 |
instance = self.set_default_values(instance)
|
255 |
|
256 |
verify_required_schema(self.input_fields, instance)
|
|
|
|
|
257 |
input_fields = {key: instance[key] for key in self.input_fields.keys()}
|
|
|
258 |
data_classification_policy = instance.get("data_classification_policy", [])
|
259 |
|
260 |
+
result = {
|
261 |
"input_fields": input_fields,
|
|
|
262 |
"metrics": self.metrics,
|
263 |
"data_classification_policy": data_classification_policy,
|
264 |
+
"media": instance.get("media", {}),
|
265 |
}
|
266 |
|
267 |
+
if stream_name == constants.inference_stream:
|
268 |
+
return result
|
269 |
+
|
270 |
+
verify_required_schema(self.reference_fields, instance)
|
271 |
+
result["reference_fields"] = {
|
272 |
+
key: instance[key] for key in self.reference_fields.keys()
|
273 |
+
}
|
274 |
+
|
275 |
+
return result
|
276 |
+
|
277 |
|
278 |
@deprecation(version="2.0.0", alternative=Task)
|
279 |
class FormTask(Task):
|
templates.py
CHANGED
@@ -10,8 +10,11 @@ from .dict_utils import dict_set
|
|
10 |
from .error_utils import Documentation, UnitxtError
|
11 |
from .operator import InstanceOperator
|
12 |
from .random_utils import new_random_generator
|
|
|
13 |
from .type_utils import isoftype
|
14 |
|
|
|
|
|
15 |
|
16 |
class TemplateFormatKeyError(UnitxtError):
|
17 |
def __init__(self, template, data, data_type, format_str, format_name):
|
@@ -84,20 +87,30 @@ class Template(InstanceOperator):
|
|
84 |
instruction, target_prefix = self.input_fields_to_instruction_and_target_prefix(
|
85 |
input_fields
|
86 |
)
|
87 |
-
target, references = self.reference_fields_to_target_and_references(
|
88 |
-
reference_fields
|
89 |
-
)
|
90 |
|
91 |
-
|
92 |
**instance,
|
93 |
"source": source,
|
94 |
-
"target": target,
|
95 |
-
"references": references,
|
96 |
"instruction": instruction,
|
97 |
"target_prefix": target_prefix,
|
98 |
"postprocessors": self.postprocessors,
|
99 |
}
|
100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
@abstractmethod
|
102 |
def input_fields_to_source(self, input_fields: Dict[str, object]) -> str:
|
103 |
pass
|
@@ -143,8 +156,13 @@ class ApplyTemplate(InstanceOperator):
|
|
143 |
def get_template(self, instance: Dict[str, Any]) -> Template:
|
144 |
pass
|
145 |
|
146 |
-
def apply(
|
147 |
-
|
|
|
|
|
|
|
|
|
|
|
148 |
|
149 |
def process(
|
150 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
@@ -155,11 +173,11 @@ class ApplyTemplate(InstanceOperator):
|
|
155 |
if self.demos_field not in instance:
|
156 |
raise ValueError("Demos field is missing.")
|
157 |
instance[self.demos_field] = [
|
158 |
-
self.apply(template, demo_instance)
|
159 |
for demo_instance in instance[self.demos_field]
|
160 |
]
|
161 |
dict_set(instance, "recipe_metadata/template", template)
|
162 |
-
return self.apply(template, instance)
|
163 |
|
164 |
|
165 |
class ApplySingleTemplate(ApplyTemplate):
|
@@ -268,6 +286,9 @@ class PairwiseChoiceTemplate(InputOutputTemplate):
|
|
268 |
choice_tie_label: str
|
269 |
shuffle: bool
|
270 |
|
|
|
|
|
|
|
271 |
def verbalize_answer_field(self, reference_fields: Dict[str, object]):
|
272 |
answer = reference_fields[self.answer_field]
|
273 |
assert answer in ["choice_a", "choice_b", "tie"]
|
@@ -552,34 +573,45 @@ class MultipleChoiceTemplate(Template):
|
|
552 |
|
553 |
return target, [target]
|
554 |
|
555 |
-
def _shuffle_choices(self, instance):
|
556 |
-
|
557 |
-
|
558 |
-
|
559 |
-
|
|
|
560 |
choices = instance["input_fields"][self.choices_field]
|
561 |
-
|
562 |
-
|
563 |
-
|
|
|
564 |
random_generator.shuffle(choices)
|
565 |
instance["input_fields"][self.choices_field] = choices
|
|
|
|
|
|
|
|
|
566 |
instance["reference_fields"][self.choices_field] = choices
|
567 |
instance["reference_fields"][self.target_field] = choices.index(
|
568 |
original_label_choice
|
569 |
)
|
|
|
570 |
return instance
|
571 |
|
572 |
def process(
|
573 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
574 |
) -> Dict[str, Any]:
|
575 |
if self.shuffle_choices:
|
576 |
-
instance = self._shuffle_choices(instance)
|
577 |
result = super().process(instance, stream_name)
|
578 |
-
|
579 |
-
|
580 |
-
|
581 |
-
instance["reference_fields"], self.target_choice_format
|
582 |
)
|
|
|
|
|
|
|
|
|
|
|
583 |
return result
|
584 |
|
585 |
|
|
|
10 |
from .error_utils import Documentation, UnitxtError
|
11 |
from .operator import InstanceOperator
|
12 |
from .random_utils import new_random_generator
|
13 |
+
from .settings_utils import get_constants
|
14 |
from .type_utils import isoftype
|
15 |
|
16 |
+
constants = get_constants()
|
17 |
+
|
18 |
|
19 |
class TemplateFormatKeyError(UnitxtError):
|
20 |
def __init__(self, template, data, data_type, format_str, format_name):
|
|
|
87 |
instruction, target_prefix = self.input_fields_to_instruction_and_target_prefix(
|
88 |
input_fields
|
89 |
)
|
|
|
|
|
|
|
90 |
|
91 |
+
result = {
|
92 |
**instance,
|
93 |
"source": source,
|
|
|
|
|
94 |
"instruction": instruction,
|
95 |
"target_prefix": target_prefix,
|
96 |
"postprocessors": self.postprocessors,
|
97 |
}
|
98 |
|
99 |
+
if stream_name == constants.inference_stream:
|
100 |
+
return result
|
101 |
+
|
102 |
+
if reference_fields is None:
|
103 |
+
raise ValueError("Should have reference_fields")
|
104 |
+
|
105 |
+
target, references = self.reference_fields_to_target_and_references(
|
106 |
+
reference_fields
|
107 |
+
)
|
108 |
+
|
109 |
+
result["target"] = target
|
110 |
+
result["references"] = references
|
111 |
+
|
112 |
+
return result
|
113 |
+
|
114 |
@abstractmethod
|
115 |
def input_fields_to_source(self, input_fields: Dict[str, object]) -> str:
|
116 |
pass
|
|
|
156 |
def get_template(self, instance: Dict[str, Any]) -> Template:
|
157 |
pass
|
158 |
|
159 |
+
def apply(
|
160 |
+
self,
|
161 |
+
template: Template,
|
162 |
+
instance: Dict[str, Any],
|
163 |
+
stream_name: Optional[str] = None,
|
164 |
+
):
|
165 |
+
return template.process_instance(instance, stream_name)
|
166 |
|
167 |
def process(
|
168 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
|
|
173 |
if self.demos_field not in instance:
|
174 |
raise ValueError("Demos field is missing.")
|
175 |
instance[self.demos_field] = [
|
176 |
+
self.apply(template, demo_instance, stream_name)
|
177 |
for demo_instance in instance[self.demos_field]
|
178 |
]
|
179 |
dict_set(instance, "recipe_metadata/template", template)
|
180 |
+
return self.apply(template, instance, stream_name)
|
181 |
|
182 |
|
183 |
class ApplySingleTemplate(ApplyTemplate):
|
|
|
286 |
choice_tie_label: str
|
287 |
shuffle: bool
|
288 |
|
289 |
+
def verify(self):
|
290 |
+
super().verify()
|
291 |
+
|
292 |
def verbalize_answer_field(self, reference_fields: Dict[str, object]):
|
293 |
answer = reference_fields[self.answer_field]
|
294 |
assert answer in ["choice_a", "choice_b", "tie"]
|
|
|
573 |
|
574 |
return target, [target]
|
575 |
|
576 |
+
def _shuffle_choices(self, instance, stream_name):
|
577 |
+
if stream_name != constants.inference_stream:
|
578 |
+
target_index = self.outputs_to_target_index(instance["reference_fields"])
|
579 |
+
original_label_choice = instance["reference_fields"][self.choices_field][
|
580 |
+
target_index
|
581 |
+
]
|
582 |
choices = instance["input_fields"][self.choices_field]
|
583 |
+
|
584 |
+
random_seed = {**instance["input_fields"]}
|
585 |
+
|
586 |
+
random_generator = new_random_generator(random_seed)
|
587 |
random_generator.shuffle(choices)
|
588 |
instance["input_fields"][self.choices_field] = choices
|
589 |
+
|
590 |
+
if stream_name == constants.inference_stream:
|
591 |
+
return instance
|
592 |
+
|
593 |
instance["reference_fields"][self.choices_field] = choices
|
594 |
instance["reference_fields"][self.target_field] = choices.index(
|
595 |
original_label_choice
|
596 |
)
|
597 |
+
|
598 |
return instance
|
599 |
|
600 |
def process(
|
601 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
602 |
) -> Dict[str, Any]:
|
603 |
if self.shuffle_choices:
|
604 |
+
instance = self._shuffle_choices(instance, stream_name)
|
605 |
result = super().process(instance, stream_name)
|
606 |
+
if stream_name == constants.inference_stream:
|
607 |
+
result["input_fields"]["options"] = self.inputs_to_choices(
|
608 |
+
instance["input_fields"], self.target_choice_format
|
|
|
609 |
)
|
610 |
+
else:
|
611 |
+
if "options" not in result["reference_fields"]:
|
612 |
+
result["reference_fields"]["options"] = self.inputs_to_choices(
|
613 |
+
instance["reference_fields"], self.target_choice_format
|
614 |
+
)
|
615 |
return result
|
616 |
|
617 |
|
version.py
CHANGED
@@ -1 +1 @@
|
|
1 |
-
version = "1.12.
|
|
|
1 |
+
version = "1.12.4"
|