Elron committed on
Commit
cc5f321
1 Parent(s): d389578

Upload folder using huggingface_hub

api.py CHANGED
@@ -1,10 +1,10 @@
 
1
  from functools import lru_cache
2
  from typing import Any, Dict, List, Optional, Union
3
 
4
- from datasets import DatasetDict
5
-
6
  from .artifact import fetch_artifact
7
  from .dataset_utils import get_dataset_artifact
 
8
  from .logging_utils import get_logger
9
  from .metric_utils import _compute, _inference_post_process
10
  from .operator import SourceOperator
@@ -14,7 +14,7 @@ from .standard import StandardRecipe
14
  logger = get_logger()
15
 
16
 
17
- def load(source: Union[SourceOperator, str]) -> DatasetDict:
18
  assert isinstance(
19
  source, (SourceOperator, str)
20
  ), "source must be a SourceOperator or a string"
@@ -79,7 +79,9 @@ def load_recipe(dataset_query: Optional[str] = None, **kwargs) -> StandardRecipe
79
  return recipe
80
 
81
 
82
- def load_dataset(dataset_query: Optional[str] = None, **kwargs) -> DatasetDict:
83
  """Loads dataset.
84
 
85
  If the 'dataset_query' argument is provided, then dataset is loaded from a card in local
@@ -90,6 +92,7 @@ def load_dataset(dataset_query: Optional[str] = None, **kwargs) -> DatasetDict:
90
  dataset_query (str, optional): A string query which specifies a dataset to load from local catalog or name of specific recipe or benchmark in the catalog.
91
  For example:
92
  "card=cards.wnli,template=templates.classification.multi_class.relation.default".
 
93
  **kwargs: Arguments used to load dataset from provided card, which is not present in local catalog.
94
 
95
  Returns:
@@ -107,6 +110,9 @@ def load_dataset(dataset_query: Optional[str] = None, **kwargs) -> DatasetDict:
107
  """
108
  recipe = load_recipe(dataset_query, **kwargs)
109
 
110
  return recipe().to_dataset(features=UNITXT_DATASET_SCHEMA)
111
 
112
 
@@ -135,19 +141,45 @@ def produce(instance_or_instances, dataset_query: Optional[str] = None, **kwargs
135
 
136
  def infer(
137
  instance_or_instances,
138
- engine,
139
  dataset_query: Optional[str] = None,
140
- return_data=False,
141
  **kwargs,
142
  ):
143
  dataset = produce(instance_or_instances, dataset_query, **kwargs)
144
  engine, _ = fetch_artifact(engine)
145
146
  predictions = post_process(raw_predictions, dataset)
147
  if return_data:
148
- for prediction, raw_prediction, instance in zip(
149
- predictions, raw_predictions, dataset
150
  ):
151
  instance["prediction"] = prediction
152
  instance["raw_prediction"] = raw_prediction
153
  return dataset
 
1
+ import json
2
  from functools import lru_cache
3
  from typing import Any, Dict, List, Optional, Union
4
 
5
  from .artifact import fetch_artifact
6
  from .dataset_utils import get_dataset_artifact
7
+ from .inference import InferenceEngine, LogProbInferenceEngine
8
  from .logging_utils import get_logger
9
  from .metric_utils import _compute, _inference_post_process
10
  from .operator import SourceOperator
 
14
  logger = get_logger()
15
 
16
 
17
+ def load(source: Union[SourceOperator, str]):
18
  assert isinstance(
19
  source, (SourceOperator, str)
20
  ), "source must be a SourceOperator or a string"
 
79
  return recipe
80
 
81
 
82
+ def load_dataset(
83
+ dataset_query: Optional[str] = None, streaming: bool = False, **kwargs
84
+ ):
85
  """Loads dataset.
86
 
87
  If the 'dataset_query' argument is provided, then dataset is loaded from a card in local
 
92
  dataset_query (str, optional): A string query which specifies a dataset to load from local catalog or name of specific recipe or benchmark in the catalog.
93
  For example:
94
  "card=cards.wnli,template=templates.classification.multi_class.relation.default".
95
+ streaming (bool, optional): When True, yields the data as a dictionary of Unitxt streams. Defaults to False.
96
  **kwargs: Arguments used to load dataset from provided card, which is not present in local catalog.
97
 
98
  Returns:
 
110
  """
111
  recipe = load_recipe(dataset_query, **kwargs)
112
 
113
+ if streaming:
114
+ return recipe()
115
+
116
  return recipe().to_dataset(features=UNITXT_DATASET_SCHEMA)
117
 
118
 
 
141
 
142
  def infer(
143
  instance_or_instances,
144
+ engine: InferenceEngine,
145
  dataset_query: Optional[str] = None,
146
+ return_data: bool = False,
147
+ return_log_probs: bool = False,
148
+ return_meta_data: bool = False,
149
  **kwargs,
150
  ):
151
  dataset = produce(instance_or_instances, dataset_query, **kwargs)
152
  engine, _ = fetch_artifact(engine)
153
+ if return_log_probs:
154
+ if not isinstance(engine, LogProbInferenceEngine):
155
+ raise NotImplementedError(
156
+ f"Error in infer: return_log_probs set to True but supplied engine "
157
+ f"{engine.__class__.__name__} does not support logprobs."
158
+ )
159
+ infer_outputs = engine.infer_log_probs(dataset, return_meta_data)
160
+ raw_predictions = (
161
+ [output.prediction for output in infer_outputs]
162
+ if return_meta_data
163
+ else infer_outputs
164
+ )
165
+ raw_predictions = [
166
+ json.dumps(raw_prediction) for raw_prediction in raw_predictions
167
+ ]
168
+ else:
169
+ infer_outputs = engine.infer(dataset, return_meta_data)
170
+ raw_predictions = (
171
+ [output.prediction for output in infer_outputs]
172
+ if return_meta_data
173
+ else infer_outputs
174
+ )
175
  predictions = post_process(raw_predictions, dataset)
176
  if return_data:
177
+ for prediction, raw_prediction, instance, infer_output in zip(
178
+ predictions, raw_predictions, dataset, infer_outputs
179
  ):
180
+ if return_meta_data:
181
+ instance["infer_meta_data"] = infer_output.__dict__
182
+ del instance["infer_meta_data"]["prediction"]
183
  instance["prediction"] = prediction
184
  instance["raw_prediction"] = raw_prediction
185
  return dataset
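The api.py changes above add a streaming mode to load_dataset and let infer() return log probabilities and per-instance metadata. A minimal usage sketch follows; the catalog query, instance fields, and model name are illustrative placeholders (not part of this commit), and return_log_probs requires an engine that implements LogProbInferenceEngine.

```python
# Hedged sketch of the updated api.py entry points; names below are illustrative.
from unitxt.api import infer, load_dataset
from unitxt.inference import HFPipelineBasedInferenceEngine

query = "card=cards.wnli,template=templates.classification.multi_class.relation.default"

# streaming=True now returns the recipe's Unitxt streams instead of a materialized dataset.
streams = load_dataset(query, streaming=True)

# infer() now takes return_log_probs / return_meta_data in addition to return_data.
engine = HFPipelineBasedInferenceEngine(model_name="google/flan-t5-small", max_new_tokens=16)
dataset_with_predictions = infer(
    [{"text_a": "It rained all night.", "text_b": "The ground is wet."}],  # fields must match the card's task
    engine=engine,
    dataset_query=query,
    return_data=True,
)
```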
artifact.py CHANGED
@@ -22,7 +22,12 @@ from .parsing_utils import (
22
  from .settings_utils import get_constants, get_settings
23
  from .text_utils import camel_to_snake_case, is_camel_case
24
  from .type_utils import issubtype
25
- from .utils import artifacts_json_cache, deepcopy, json_dump, save_to_file
26
 
27
  logger = get_logger()
28
  settings = get_settings()
@@ -405,7 +410,7 @@ def get_raw(obj):
405
  if isinstance(obj, dict):
406
  return type(obj)({get_raw(k): get_raw(v) for k, v in obj.items()})
407
 
408
- return deepcopy(obj)
409
 
410
 
411
  class ArtifactList(list, Artifact):
 
22
  from .settings_utils import get_constants, get_settings
23
  from .text_utils import camel_to_snake_case, is_camel_case
24
  from .type_utils import issubtype
25
+ from .utils import (
26
+ artifacts_json_cache,
27
+ json_dump,
28
+ save_to_file,
29
+ shallow_copy,
30
+ )
31
 
32
  logger = get_logger()
33
  settings = get_settings()
 
410
  if isinstance(obj, dict):
411
  return type(obj)({get_raw(k): get_raw(v) for k, v in obj.items()})
412
 
413
+ return shallow_copy(obj)
414
 
415
 
416
  class ArtifactList(list, Artifact):
collections_operators.py CHANGED
@@ -3,7 +3,7 @@ from typing import Any, Generator, List, Optional
3
  from .dict_utils import dict_get, dict_set
4
  from .operators import FieldOperator, StreamOperator
5
  from .stream import Stream
6
- from .utils import deepcopy
7
 
8
 
9
  class Dictify(FieldOperator):
@@ -70,10 +70,10 @@ class DuplicateByList(StreamOperator):
70
  elements = dict_get(instance, self.field)
71
  for element in elements:
72
  if self.use_deep_copy:
73
- instance_copy = deepcopy(instance)
74
 
75
  else:
76
- instance_copy = {**instance}
77
  dict_set(instance_copy, to_field, element)
78
  yield instance_copy
79
 
@@ -93,7 +93,7 @@ class DuplicateBySubLists(StreamOperator):
93
  elements = instance[self.field]
94
  for i in range(1, len(elements) + 1):
95
  if self.use_deep_copy:
96
- instance_copy = deepcopy(instance)
97
  instance_copy[to_field] = elements[:i]
98
  else:
99
  instance_copy = {
@@ -107,3 +107,21 @@ class DuplicateBySubLists(StreamOperator):
107
  class GetLength(FieldOperator):
108
  def process_value(self, collection: Any) -> Any:
109
  return len(collection)
3
  from .dict_utils import dict_get, dict_set
4
  from .operators import FieldOperator, StreamOperator
5
  from .stream import Stream
6
+ from .utils import recursive_shallow_copy
7
 
8
 
9
  class Dictify(FieldOperator):
 
70
  elements = dict_get(instance, self.field)
71
  for element in elements:
72
  if self.use_deep_copy:
73
+ instance_copy = recursive_shallow_copy(instance)
74
 
75
  else:
76
+ instance_copy = instance.copy()
77
  dict_set(instance_copy, to_field, element)
78
  yield instance_copy
79
 
 
93
  elements = instance[self.field]
94
  for i in range(1, len(elements) + 1):
95
  if self.use_deep_copy:
96
+ instance_copy = recursive_shallow_copy(instance)
97
  instance_copy[to_field] = elements[:i]
98
  else:
99
  instance_copy = {
 
107
  class GetLength(FieldOperator):
108
  def process_value(self, collection: Any) -> Any:
109
  return len(collection)
110
+
111
+
112
+ class Filter(FieldOperator):
113
+ values: List[Any]
114
+
115
+ def process_value(self, collection: Any) -> Any:
116
+ # If collection is a list, tuple, or set
117
+ if isinstance(collection, (list, set, tuple)):
118
+ return type(collection)(
119
+ item for item in collection if item not in self.values
120
+ )
121
+
122
+ # If collection is a dictionary, filter by keys
123
+ if isinstance(collection, dict):
124
+ return {k: v for k, v in collection.items() if k not in self.values}
125
+
126
+ # If collection is of an unsupported type
127
+ raise TypeError(f"Unsupported collection type: {type(collection)}")
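The new Filter field operator removes the configured values from a list, tuple, or set, and drops matching keys from a dict. A small sketch of its behavior (field names are illustrative; in a recipe it is configured with field/to_field like any other FieldOperator):

```python
# Hedged sketch of the new Filter operator's process_value behavior.
from unitxt.collections_operators import Filter

op = Filter(field="labels", values=["N/A", ""])
print(op.process_value(["cat", "N/A", "dog", ""]))  # -> ['cat', 'dog']
print(op.process_value({"cat": 1, "N/A": 2}))       # -> {'cat': 1}
print(op.process_value(("a", "N/A")))               # -> ('a',)
```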
dialog_operators.py CHANGED
@@ -157,13 +157,13 @@ class SerializeOpenAiFormatDialog(SerializeDialog):
157
  f"Entry {i} has a non-string 'content': {entry['content']}. The 'content' value must be a string."
158
  )
159
 
160
- if entry["role"] not in {"user", "assistant"}:
161
  raise ValueError(
162
  f"Entry {i} has an invalid role: {entry['role']}. Allowed roles are 'user' and 'assistant'."
163
  )
164
 
165
  first_entry = dialog[0]
166
- if first_entry["role"] != "user":
167
  raise ValueError(
168
  f"First entry role is expected to be 'user' It is {first_entry['role']}."
169
  )
 
157
  f"Entry {i} has a non-string 'content': {entry['content']}. The 'content' value must be a string."
158
  )
159
 
160
+ if entry["role"].lower() not in {"user", "assistant"}:
161
  raise ValueError(
162
  f"Entry {i} has an invalid role: {entry['role']}. Allowed roles are 'user' and 'assistant'."
163
  )
164
 
165
  first_entry = dialog[0]
166
+ if first_entry["role"].lower() != "user":
167
  raise ValueError(
168
  f"First entry role is expected to be 'user' It is {first_entry['role']}."
169
  )
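The role validation in SerializeOpenAiFormatDialog is now case-insensitive. A tiny sketch of the relaxed check, with an illustrative dialog:

```python
# Mixed-case roles such as "User"/"ASSISTANT" now pass, since the check lowercases the role first.
dialog = [
    {"role": "User", "content": "Hi there"},
    {"role": "ASSISTANT", "content": "Hello!"},
]
for i, entry in enumerate(dialog):
    if entry["role"].lower() not in {"user", "assistant"}:
        raise ValueError(f"Entry {i} has an invalid role: {entry['role']}.")
```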
formats.py CHANGED
@@ -182,6 +182,7 @@ class SystemFormat(BaseFormat):
182
  target_prefix=demo_target_prefix,
183
  source=demo_source,
184
  target=demo_target,
 
185
  **self.format_args,
186
  )
187
  demos_string += demo_str
 
182
  target_prefix=demo_target_prefix,
183
  source=demo_source,
184
  target=demo_target,
185
+ instruction=instruction,
186
  **self.format_args,
187
  )
188
  demos_string += demo_str
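With this change, SystemFormat forwards the instruction when rendering each demonstration, so demo_format can reference it. A brief sketch, assuming a custom demo_format string (illustrative, not from this commit):

```python
# Hedged sketch: {instruction} is now available inside demo_format.
from unitxt.formats import SystemFormat

fmt = SystemFormat(
    demo_format="{instruction}\n{source}\n{target_prefix}{target}\n\n",
)
```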
generator_utils.py CHANGED
@@ -1,7 +1,7 @@
1
  from typing import Any, Dict, List
2
 
3
  from .dataclass import Dataclass, OptionalField
4
- from .utils import deepcopy
5
 
6
 
7
  class ReusableGenerator(Dataclass):
@@ -22,34 +22,4 @@ class ReusableGenerator(Dataclass):
22
  class CopyingReusableGenerator(ReusableGenerator):
23
  def __iter__(self):
24
  for instance in self.activate():
25
- yield deepcopy(instance)
26
-
27
-
28
- # if __name__ == "__main__":
29
- # from itertools import chain, islice
30
-
31
- # # Creating objects of MyIterable
32
- # iterable1 = ReusableGenerator(range, gen_argv=[1, 4])
33
- # iterable2 = ReusableGenerator(range, gen_argv=[4, 7])
34
-
35
- # # Using itertools.chain
36
- # chained = list(chain(iterable1, iterable2))
37
- # logger.info(chained) # Prints: [1, 2, 3, 4, 5, 6]
38
-
39
- # # Using itertools.islice
40
- # sliced = list(islice(ReusableGenerator(range, gen_argv=[1, 7]), 1, 4))
41
- # logger.info(sliced) # Prints: [2, 3, 4]
42
-
43
- # # now same test with generators
44
- # def generator(start, end):
45
- # for i in range(start, end):
46
- # yield i
47
-
48
- # iterable1 = ReusableGenerator(generator, gen_argv=[1, 4])
49
- # iterable2 = ReusableGenerator(generator, gen_argv=[4, 7])
50
-
51
- # chained = list(chain(iterable1, iterable2))
52
- # logger.info(chained) # Prints: [1, 2, 3, 4, 5, 6]
53
-
54
- # sliced = list(islice(ReusableGenerator(generator, gen_argv=[1, 7]), 1, 4))
55
- # logger.info(sliced) # Prints: [2, 3, 4]
 
1
  from typing import Any, Dict, List
2
 
3
  from .dataclass import Dataclass, OptionalField
4
+ from .utils import recursive_shallow_copy
5
 
6
 
7
  class ReusableGenerator(Dataclass):
 
22
  class CopyingReusableGenerator(ReusableGenerator):
23
  def __iter__(self):
24
  for instance in self.activate():
25
+ yield recursive_shallow_copy(instance)
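Like collections_operators.py and artifact.py above, this file replaces deepcopy with the new recursive_shallow_copy helper. Its exact implementation is not shown in this diff; the sketch below illustrates the assumed idea: containers are rebuilt at every level while leaf objects are shared, which is much cheaper than a deep copy for large instances.

```python
# Illustrative stand-in (not the library implementation) for a recursive shallow copy.
def recursive_shallow_copy_sketch(obj):
    if isinstance(obj, dict):
        return {k: recursive_shallow_copy_sketch(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [recursive_shallow_copy_sketch(v) for v in obj]
    return obj  # leaves (strings, numbers, arbitrary objects) are shared, not copied

instance = {"source": "a long prompt ...", "references": [["a"], ["b"]]}
copy_ = recursive_shallow_copy_sketch(instance)
assert copy_["references"] is not instance["references"]  # containers are new objects
assert copy_["source"] is instance["source"]              # leaves are shared
```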
inference.py CHANGED
@@ -1,8 +1,10 @@
1
  import abc
 
2
  import os
3
  import re
4
  from typing import Any, Dict, List, Literal, Optional, Union
5
 
 
6
  from tqdm import tqdm
7
 
8
  from .artifact import Artifact, fetch_artifact
@@ -16,12 +18,52 @@ from .settings_utils import get_settings
16
  settings = get_settings()
17
 
18
 
19
  class InferenceEngine(abc.ABC, Artifact):
20
  """Abstract base class for inference."""
21
 
22
  @abc.abstractmethod
23
- def _infer(self, dataset):
24
- """Perform inference on the input dataset."""
25
  pass
26
 
27
  @abc.abstractmethod
@@ -33,12 +75,29 @@ class InferenceEngine(abc.ABC, Artifact):
33
  if not settings.mock_inference_mode:
34
  self.prepare_engine()
35
 
36
- def infer(self, dataset) -> str:
37
- """Verifies instances of a dataset and performs inference."""
38
  [self.verify_instance(instance) for instance in dataset]
39
  if settings.mock_inference_mode:
40
  return [instance["source"] for instance in dataset]
41
- return self._infer(dataset)
42
 
43
  @deprecation(version="2.0.0")
44
  def _set_inference_parameters(self):
@@ -62,19 +121,39 @@ class LogProbInferenceEngine(abc.ABC, Artifact):
62
  """Abstract base class for inference with log probs."""
63
 
64
  @abc.abstractmethod
65
- def _infer_log_probs(self, dataset):
66
- """Perform inference on the input dataset that returns log probs."""
67
  pass
68
 
69
- def infer_log_probs(self, dataset) -> List[Dict]:
70
  """Verifies instances of a dataset and performs inference that returns log probabilities of top tokens.
71
 
72
- For each instance , returns a list of top tokens per position.
73
  [ "top_tokens": [ { "text": ..., "logprob": ...} , ... ]
74
-
 
75
  """
76
  [self.verify_instance(instance) for instance in dataset]
77
- return self._infer_log_probs(dataset)
78
 
79
 
80
  class LazyLoadMixin(Artifact):
@@ -96,6 +175,9 @@ class HFPipelineBasedInferenceEngine(
96
  "transformers": "Install huggingface package using 'pip install --upgrade transformers"
97
  }
98
 
 
99
  def _prepare_pipeline(self):
100
  import torch
101
  from transformers import AutoConfig, pipeline
@@ -143,7 +225,11 @@ class HFPipelineBasedInferenceEngine(
143
  def _is_loaded(self):
144
  return hasattr(self, "model") and self.model is not None
145
 
146
- def _infer(self, dataset):
147
  if not self._is_loaded():
148
  self._prepare_pipeline()
149
 
@@ -157,12 +243,20 @@ class HFPipelineBasedInferenceEngine(
157
 
158
  class MockInferenceEngine(InferenceEngine):
159
  model_name: str
160
 
161
  def prepare_engine(self):
162
  return
163
 
164
- def _infer(self, dataset):
165
- return ["[[10]]" for instance in dataset]
166
 
167
 
168
  class MockModeMixin(Artifact):
@@ -226,7 +320,14 @@ class GenericInferenceEngine(InferenceEngine):
226
  engine_reference = self.default
227
  self.engine, _ = fetch_artifact(engine_reference)
228
 
229
- def _infer(self, dataset):
230
  return self.engine._infer(dataset)
231
 
232
 
@@ -238,10 +339,17 @@ class OllamaInferenceEngine(InferenceEngine, PackageRequirementsMixin):
238
  }
239
  data_classification_policy = ["public", "proprietary"]
240
 
241
  def prepare_engine(self):
242
  pass
243
 
244
- def _infer(self, dataset):
245
  import ollama
246
 
247
  result = [
@@ -260,7 +368,10 @@ class OllamaInferenceEngine(InferenceEngine, PackageRequirementsMixin):
260
 
261
 
262
  class IbmGenAiInferenceEngine(
263
- InferenceEngine, IbmGenAiInferenceEngineParamsMixin, PackageRequirementsMixin
264
  ):
265
  label: str = "ibm_genai"
266
  model_name: str
@@ -270,6 +381,9 @@ class IbmGenAiInferenceEngine(
270
  data_classification_policy = ["public", "proprietary"]
271
  parameters: Optional[IbmGenAiInferenceEngineParams] = None
272
 
273
  def prepare_engine(self):
274
  from genai import Client, Credentials
275
 
@@ -285,21 +399,88 @@ class IbmGenAiInferenceEngine(
285
 
286
  self._set_inference_parameters()
287
 
288
- def _infer(self, dataset):
289
  from genai.schema import TextGenerationParameters
290
 
291
  genai_params = TextGenerationParameters(
292
  **self.to_dict([IbmGenAiInferenceEngineParamsMixin])
293
  )
294
 
295
- return [
296
- response.results[0].generated_text
297
- for response in self.client.text.generation.create(
298
- model_id=self.model_name,
299
- inputs=[instance["source"] for instance in dataset],
300
- parameters=genai_params,
301
  )
302
- ]
303
 
304
 
305
  class OpenAiInferenceEngineParamsMixin(Artifact):
@@ -349,18 +530,29 @@ class OpenAiInferenceEngine(
349
  data_classification_policy = ["public"]
350
  parameters: Optional[OpenAiInferenceEngineParams] = None
351
 
352
- def prepare_engine(self):
353
- from openai import OpenAI
354
 
355
- api_key_env_var_name = "OPENAI_API_KEY"
356
- api_key = os.environ.get(api_key_env_var_name)
 
357
  assert api_key is not None, (
358
- f"Error while trying to run OpenAiInferenceEngine."
359
- f" Please set the environment param '{api_key_env_var_name}'."
360
  )
 
361
 
362
- self.client = OpenAI(api_key=api_key)
 
363
 
  self._set_inference_parameters()
365
 
366
  def _get_completion_kwargs(self):
@@ -370,7 +562,11 @@ class OpenAiInferenceEngine(
370
  if v is not None
371
  }
372
 
373
- def _infer(self, dataset):
374
  outputs = []
375
  for instance in tqdm(dataset, desc="Inferring with openAI API"):
376
  response = self.client.chat.completions.create(
@@ -387,13 +583,18 @@ class OpenAiInferenceEngine(
387
  model=self.model_name,
388
  **self._get_completion_kwargs(),
389
  )
390
- output = response.choices[0].message.content
 
391
 
392
  outputs.append(output)
393
 
394
  return outputs
395
 
396
- def _infer_log_probs(self, dataset):
397
  outputs = []
398
  for instance in tqdm(dataset, desc="Inferring with openAI API"):
399
  response = self.client.chat.completions.create(
@@ -411,7 +612,7 @@ class OpenAiInferenceEngine(
411
  **self._get_completion_kwargs(),
412
  )
413
  top_logprobs_response = response.choices[0].logprobs.content
414
- output = [
415
  {
416
  "top_tokens": [
417
  {"text": obj.token, "logprob": obj.logprob}
@@ -420,9 +621,21 @@ class OpenAiInferenceEngine(
420
  }
421
  for generated_token in top_logprobs_response
422
  ]
 
423
  outputs.append(output)
424
  return outputs
425
 
426
 
427
  class TogetherAiInferenceEngineParamsMixin(Artifact):
428
  max_tokens: Optional[int] = None
@@ -450,6 +663,9 @@ class TogetherAiInferenceEngine(
450
  data_classification_policy = ["public"]
451
  parameters: Optional[TogetherAiInferenceEngineParamsMixin] = None
452
 
453
  def prepare_engine(self):
454
  from together import Together
455
  from together.types.models import ModelType
@@ -501,7 +717,11 @@ class TogetherAiInferenceEngine(
501
  )
502
  return response.choices[0].text
503
 
504
- def _infer(self, dataset):
505
  from together.types.models import ModelType
506
 
507
  outputs = []
@@ -514,6 +734,23 @@ class TogetherAiInferenceEngine(
514
  return outputs
515
 
516
 
517
  class WMLInferenceEngineParamsMixin(Artifact):
518
  decoding_method: Optional[Literal["greedy", "sample"]] = None
519
  length_penalty: Optional[Dict[str, Union[int, float]]] = None
@@ -550,7 +787,10 @@ class WMLInferenceEngineParams(Artifact):
550
 
551
 
552
  class WMLInferenceEngine(
553
- InferenceEngine, WMLInferenceEngineParamsMixin, PackageRequirementsMixin
554
  ):
555
  """Runs inference using ibm-watsonx-ai.
556
 
@@ -604,14 +844,17 @@ class WMLInferenceEngine(
604
  concurrency_limit: int = 10
605
  _client: Any = InternalField(default=None, name="WML client")
606
 
 
607
  def verify(self):
608
  super().verify()
609
 
610
  if self.credentials is not None:
611
  for key in self.credentials:
612
- if key not in ["url", "apikey", "project_id"]:
613
  raise ValueError(
614
- f'Illegal credential key: {key}, use only ["url", "apikey", "project_id"]'
615
  )
616
 
617
  assert (
@@ -631,10 +874,14 @@ class WMLInferenceEngine(
631
 
632
  @staticmethod
633
  def _read_wml_credentials_from_env() -> (
634
- Dict[Literal["url", "apikey", "project_id"], str]
635
  ):
636
  credentials = {}
637
- for env_var_name in ["WML_URL", "WML_PROJECT_ID", "WML_APIKEY"]:
638
  env_var = os.environ.get(env_var_name)
639
  assert env_var, (
640
  f"Error while trying to run 'WMLInferenceEngine'. "
@@ -655,7 +902,10 @@ class WMLInferenceEngine(
655
  self.credentials = self._read_wml_credentials_from_env()
656
 
657
  client = APIClient(credentials=self.credentials)
658
- client.set.default_project(self.credentials["project_id"])
659
  return client
660
 
661
  def prepare_engine(self):
@@ -663,7 +913,7 @@ class WMLInferenceEngine(
663
 
664
  self._set_inference_parameters()
665
 
666
- def _infer(self, dataset):
667
  from ibm_watsonx_ai.foundation_models import ModelInference
668
 
669
  model = ModelInference(
@@ -671,20 +921,81 @@ class WMLInferenceEngine(
671
  deployment_id=self.deployment_id,
672
  api_client=self._client,
673
  )
 
674
 
675
- # the class was previously used with a dataset that is a single instance
676
- dataset = dataset if isinstance(dataset, list) else [dataset]
677
 
678
- result = [
679
- model.generate_text(
680
  prompt=instance["source"],
681
  params=self.to_dict([WMLInferenceEngineParamsMixin], keep_empty=False),
682
  )
683
- for instance in dataset
684
- ]
685
 
686
- # the class was previously used with a dataset that is a single instance
687
- return result[0] if not isinstance(dataset, list) else result
688
 
689
 
690
  class HFLlavaInferenceEngine(InferenceEngine, LazyLoadMixin):
@@ -698,6 +1009,9 @@ class HFLlavaInferenceEngine(InferenceEngine, LazyLoadMixin):
698
  "accelerate": "pip install accelerate",
699
  }
700
 
701
  def _prepare_engine(self):
702
  import torch
703
  from transformers import AutoProcessor, LlavaForConditionalGeneration
@@ -725,14 +1039,18 @@ class HFLlavaInferenceEngine(InferenceEngine, LazyLoadMixin):
725
  def _is_loaded(self):
726
  return hasattr(self, "model") and self.model is not None
727
 
728
- def _infer(self, dataset):
729
  if not self._is_loaded():
730
  self._prepare_engine()
731
 
732
  import torch
733
 
734
  results = []
735
- for instance in dataset:
736
  text = instance["source"]
737
  images = extract_images(text, instance)
738
  # Regular expression to match all <img src="..."> tags
@@ -745,7 +1063,10 @@ class HFLlavaInferenceEngine(InferenceEngine, LazyLoadMixin):
745
  ).to(self.device, torch.float16)
746
  input_len = len(inputs["input_ids"][0])
747
  output = self.model.generate(
748
- **inputs, max_new_tokens=self.max_new_tokens, do_sample=False
749
  )
750
  result = self.processor.decode(
751
  output[0][input_len:], skip_special_tokens=True
 
1
  import abc
2
+ import dataclasses
3
  import os
4
  import re
5
  from typing import Any, Dict, List, Literal, Optional, Union
6
 
7
+ from datasets import DatasetDict
8
  from tqdm import tqdm
9
 
10
  from .artifact import Artifact, fetch_artifact
 
18
  settings = get_settings()
19
 
20
 
21
+ def get_model_and_label_id(model_name, label):
22
+ model_id = model_name.split("/")[-1].replace("-", "_").replace(".", ",").lower()
23
+ return f"{model_id}_{label}"
24
+
25
+
26
+ @dataclasses.dataclass
27
+ class TextGenerationInferenceOutput:
28
+ """Contains the prediction results and metadata for the inference.
29
+
30
+ Args:
31
+ prediction (Union[str, List[Dict[str, Any]]]): If this is the result of an _infer call, the string predicted by the model.
32
+ If this is the result of an _infer_log_probs call, a list of dictionaries. The i'th dictionary represents
33
+ the i'th token in the response. The entry "top_tokens" in the dictionary holds a sorted list of the top tokens
34
+ for this position and their probabilities.
35
+ For example: [ {.. "top_tokens": [ {"text": "a", 'logprob': }, {"text": "b", 'logprob': } ....]},
36
+ {.. "top_tokens": [ {"text": "c", 'logprob': }, {"text": "d", 'logprob': } ....]}
37
+ ]
38
+
39
+ input_tokens (int): number of input tokens to the model.
40
+ output_tokens (int): number of output tokens generated by the model.
41
+ model_name (str): the model_name as kept in the InferenceEngine.
42
+ inference_type (str): The label stating the type of the InferenceEngine.
43
+ """
44
+
45
+ prediction: Union[str, List[Dict[str, Any]]]
46
+ input_tokens: Optional[int] = None
47
+ output_tokens: Optional[int] = None
48
+ model_name: Optional[str] = None
49
+ inference_type: Optional[str] = None
50
+
51
+
52
  class InferenceEngine(abc.ABC, Artifact):
53
  """Abstract base class for inference."""
54
 
55
  @abc.abstractmethod
56
+ def _infer(
57
+ self,
58
+ dataset: Union[List[Dict[str, Any]], DatasetDict],
59
+ return_meta_data: bool = False,
60
+ ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
61
+ """Perform inference on the input dataset.
62
+
63
+ If return_meta_data is True, returns a list of TextGenerationInferenceOutput; otherwise returns a list of the string
64
+ predictions.
65
+ return_meta_data is only supported by some InferenceEngines.
66
+ """
67
  pass
68
 
69
  @abc.abstractmethod
 
75
  if not settings.mock_inference_mode:
76
  self.prepare_engine()
77
 
78
+ def infer(
79
+ self,
80
+ dataset: Union[List[Dict[str, Any]], DatasetDict],
81
+ return_meta_data: bool = False,
82
+ ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
83
+ """Verifies instances of a dataset and performs inference on the input dataset.
84
+
85
+ If return_meta_data is True, returns a list of TextGenerationInferenceOutput; otherwise returns a list of the string
86
+ predictions.
87
+ """
88
+ if return_meta_data and not hasattr(self, "get_return_object"):
89
+ raise NotImplementedError(
90
+ f"Inference engine {self.__class__.__name__} does not support return_meta_data as it "
91
+ f"does not contain a 'get_return_object' method. Please set return_meta_data=False."
92
+ )
93
+
94
  [self.verify_instance(instance) for instance in dataset]
95
  if settings.mock_inference_mode:
96
  return [instance["source"] for instance in dataset]
97
+ return self._infer(dataset, return_meta_data)
98
+
99
+ def get_engine_id(self):
100
+ raise NotImplementedError()
101
 
102
  @deprecation(version="2.0.0")
103
  def _set_inference_parameters(self):
 
121
  """Abstract base class for inference with log probs."""
122
 
123
  @abc.abstractmethod
124
+ def _infer_log_probs(
125
+ self,
126
+ dataset: Union[List[Dict[str, Any]], DatasetDict],
127
+ return_meta_data: bool = False,
128
+ ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
129
+ """Perform inference on the input dataset that returns log probs.
130
+
131
+ If return_meta_data is True, returns a list of TextGenerationInferenceOutput; otherwise returns a list of the
132
+ logprob dicts.
133
+ return_meta_data is only supported by some InferenceEngines.
134
+ """
135
  pass
136
 
137
+ def infer_log_probs(
138
+ self,
139
+ dataset: Union[List[Dict[str, Any]], DatasetDict],
140
+ return_meta_data: bool = False,
141
+ ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
142
  """Verifies instances of a dataset and performs inference that returns log probabilities of top tokens.
143
 
144
+ For each instance, generates a list of top tokens per position:
145
  [ "top_tokens": [ { "text": ..., "logprob": ...} , ... ]
146
+ If return_meta_data is True, returns a list of TextGenerationInferenceOutput; otherwise returns the list of logprob dicts.
147
+ return_meta_data is only supported by some InferenceEngines.
148
  """
149
+ if return_meta_data and not hasattr(self, "get_return_object"):
150
+ raise NotImplementedError(
151
+ f"Inference engine {self.__class__.__name__} does not support return_meta_data as it "
152
+ f"does not contain a 'get_return_object' method. Please set return_meta_data=False."
153
+ )
154
+
155
  [self.verify_instance(instance) for instance in dataset]
156
+ return self._infer_log_probs(dataset, return_meta_data)
157
 
158
 
159
  class LazyLoadMixin(Artifact):
 
175
  "transformers": "Install huggingface package using 'pip install --upgrade transformers"
176
  }
177
 
178
+ def get_engine_id(self):
179
+ return get_model_and_label_id(self.model_name, "hf_pipeline")
180
+
181
  def _prepare_pipeline(self):
182
  import torch
183
  from transformers import AutoConfig, pipeline
 
225
  def _is_loaded(self):
226
  return hasattr(self, "model") and self.model is not None
227
 
228
+ def _infer(
229
+ self,
230
+ dataset: Union[List[Dict[str, Any]], DatasetDict],
231
+ return_meta_data: bool = False,
232
+ ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
233
  if not self._is_loaded():
234
  self._prepare_pipeline()
235
 
 
243
 
244
  class MockInferenceEngine(InferenceEngine):
245
  model_name: str
246
+ default_inference_value: str = "[[10]]"
247
+
248
+ def get_engine_id(self):
249
+ return get_model_and_label_id(self.model_name, "mock")
250
 
251
  def prepare_engine(self):
252
  return
253
 
254
+ def _infer(
255
+ self,
256
+ dataset: Union[List[Dict[str, Any]], DatasetDict],
257
+ return_meta_data: bool = False,
258
+ ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
259
+ return [self.default_inference_value for instance in dataset]
260
 
261
 
262
  class MockModeMixin(Artifact):
 
320
  engine_reference = self.default
321
  self.engine, _ = fetch_artifact(engine_reference)
322
 
323
+ def get_engine_id(self):
324
+ return "generic_inference_engine"
325
+
326
+ def _infer(
327
+ self,
328
+ dataset: Union[List[Dict[str, Any]], DatasetDict],
329
+ return_meta_data: bool = False,
330
+ ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
331
  return self.engine._infer(dataset)
332
 
333
 
 
339
  }
340
  data_classification_policy = ["public", "proprietary"]
341
 
342
+ def get_engine_id(self):
343
+ return get_model_and_label_id(self.model_name, self.label)
344
+
345
  def prepare_engine(self):
346
  pass
347
 
348
+ def _infer(
349
+ self,
350
+ dataset: Union[List[Dict[str, Any]], DatasetDict],
351
+ return_meta_data: bool = False,
352
+ ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
353
  import ollama
354
 
355
  result = [
 
368
 
369
 
370
  class IbmGenAiInferenceEngine(
371
+ InferenceEngine,
372
+ IbmGenAiInferenceEngineParamsMixin,
373
+ PackageRequirementsMixin,
374
+ LogProbInferenceEngine,
375
  ):
376
  label: str = "ibm_genai"
377
  model_name: str
 
381
  data_classification_policy = ["public", "proprietary"]
382
  parameters: Optional[IbmGenAiInferenceEngineParams] = None
383
 
384
+ def get_engine_id(self):
385
+ return get_model_and_label_id(self.model_name, self.label)
386
+
387
  def prepare_engine(self):
388
  from genai import Client, Credentials
389
 
 
399
 
400
  self._set_inference_parameters()
401
 
402
+ def _infer(
403
+ self,
404
+ dataset: Union[List[Dict[str, Any]], DatasetDict],
405
+ return_meta_data: bool = False,
406
+ ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
407
  from genai.schema import TextGenerationParameters
408
 
409
  genai_params = TextGenerationParameters(
410
  **self.to_dict([IbmGenAiInferenceEngineParamsMixin])
411
  )
412
 
413
+ results = []
414
+ responses = self.client.text.generation.create(
415
+ model_id=self.model_name,
416
+ inputs=[instance["source"] for instance in dataset],
417
+ parameters=genai_params,
418
+ )
419
+ for response in responses:
420
+ generated_text = response.results[0].generated_text
421
+ result = self.get_return_object(
422
+ generated_text, response.results[0], return_meta_data
423
  )
424
+ results.append(result)
425
+ return results
426
+
427
+ def _infer_log_probs(
428
+ self,
429
+ dataset: Union[List[Dict[str, Any]], DatasetDict],
430
+ return_meta_data: bool = False,
431
+ ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
432
+ from genai.schema import TextGenerationParameters
433
+
434
+ logprobs_return_options = {
435
+ "generated_tokens": True,
436
+ "input_text": False,
437
+ "input_tokens": False,
438
+ "token_logprobs": True,
439
+ "token_ranks": True,
440
+ "top_n_tokens": 5,
441
+ }
442
+ genai_params = self.to_dict(
443
+ [IbmGenAiInferenceEngineParamsMixin], keep_empty=False
444
+ )
445
+ genai_params = {**genai_params, "return_options": logprobs_return_options}
446
+ genai_params = TextGenerationParameters(**genai_params)
447
+ predictions = self.client.text.generation.create(
448
+ model_id=self.model_name,
449
+ inputs=[instance["source"] for instance in dataset],
450
+ parameters=genai_params,
451
+ )
452
+
453
+ predict_results = []
454
+ for prediction in predictions:
455
+ result = prediction.results[0]
456
+ assert isinstance(
457
+ result.generated_tokens, list
458
+ ), "result.generated_tokens should be a list"
459
+
460
+ predict_result = []
461
+ for base_token in result.generated_tokens:
462
+ res = {**base_token.__dict__, **base_token.model_extra}
463
+ res["top_tokens"] = [
464
+ {"logprob": top_token.logprob, "text": top_token.text}
465
+ for top_token in res["top_tokens"]
466
+ ]
467
+ predict_result.append(res)
468
+ final_results = self.get_return_object(
469
+ predict_result, result, return_meta_data
470
+ )
471
+ predict_results.append(final_results)
472
+ return predict_results
473
+
474
+ def get_return_object(self, predict_result, result, return_meta_data):
475
+ if return_meta_data:
476
+ return TextGenerationInferenceOutput(
477
+ prediction=predict_result,
478
+ input_tokens=result.input_token_count,
479
+ output_tokens=result.generated_token_count,
480
+ model_name=self.model_name,
481
+ inference_type=self.label,
482
+ )
483
+ return predict_result
484
 
485
 
486
  class OpenAiInferenceEngineParamsMixin(Artifact):
 
530
  data_classification_policy = ["public"]
531
  parameters: Optional[OpenAiInferenceEngineParams] = None
532
 
533
+ def get_engine_id(self):
534
+ return get_model_and_label_id(self.model_name, self.label)
535
 
536
+ @classmethod
537
+ def get_api_param(cls, inference_engine: str, api_param_env_var_name: str):
538
+ api_key = os.environ.get(api_param_env_var_name)
539
  assert api_key is not None, (
540
+ f"Error while trying to run {inference_engine}."
541
+ f" Please set the environment param '{api_param_env_var_name}'."
542
  )
543
+ return api_key
544
 
545
+ def create_client(self):
546
+ from openai import OpenAI
547
 
548
+ api_key = self.get_api_param(
549
+ inference_engine="OpenAiInferenceEngine",
550
+ api_param_env_var_name="OPENAI_API_KEY",
551
+ )
552
+ return OpenAI(api_key=api_key)
553
+
554
+ def prepare_engine(self):
555
+ self.client = self.create_client()
556
  self._set_inference_parameters()
557
 
558
  def _get_completion_kwargs(self):
 
562
  if v is not None
563
  }
564
 
565
+ def _infer(
566
+ self,
567
+ dataset: Union[List[Dict[str, Any]], DatasetDict],
568
+ return_meta_data: bool = False,
569
+ ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
570
  outputs = []
571
  for instance in tqdm(dataset, desc="Inferring with openAI API"):
572
  response = self.client.chat.completions.create(
 
583
  model=self.model_name,
584
  **self._get_completion_kwargs(),
585
  )
586
+ prediction = response.choices[0].message.content
587
+ output = self.get_return_object(prediction, response, return_meta_data)
588
 
589
  outputs.append(output)
590
 
591
  return outputs
592
 
593
+ def _infer_log_probs(
594
+ self,
595
+ dataset: Union[List[Dict[str, Any]], DatasetDict],
596
+ return_meta_data: bool = False,
597
+ ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
598
  outputs = []
599
  for instance in tqdm(dataset, desc="Inferring with openAI API"):
600
  response = self.client.chat.completions.create(
 
612
  **self._get_completion_kwargs(),
613
  )
614
  top_logprobs_response = response.choices[0].logprobs.content
615
+ pred_output = [
616
  {
617
  "top_tokens": [
618
  {"text": obj.token, "logprob": obj.logprob}
 
621
  }
622
  for generated_token in top_logprobs_response
623
  ]
624
+ output = self.get_return_object(pred_output, response, return_meta_data)
625
  outputs.append(output)
626
  return outputs
627
 
628
+ def get_return_object(self, predict_result, response, return_meta_data):
629
+ if return_meta_data:
630
+ return TextGenerationInferenceOutput(
631
+ prediction=predict_result,
632
+ input_tokens=response.usage.prompt_tokens,
633
+ output_tokens=response.usage.completion_tokens,
634
+ model_name=self.model_name,
635
+ inference_type=self.label,
636
+ )
637
+ return predict_result
638
+
639
 
640
  class TogetherAiInferenceEngineParamsMixin(Artifact):
641
  max_tokens: Optional[int] = None
 
663
  data_classification_policy = ["public"]
664
  parameters: Optional[TogetherAiInferenceEngineParamsMixin] = None
665
 
666
+ def get_engine_id(self):
667
+ return get_model_and_label_id(self.model_name, self.label)
668
+
669
  def prepare_engine(self):
670
  from together import Together
671
  from together.types.models import ModelType
 
717
  )
718
  return response.choices[0].text
719
 
720
+ def _infer(
721
+ self,
722
+ dataset: Union[List[Dict[str, Any]], DatasetDict],
723
+ return_meta_data: bool = False,
724
+ ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
725
  from together.types.models import ModelType
726
 
727
  outputs = []
 
734
  return outputs
735
 
736
 
737
+ class VLLMRemoteInferenceEngine(OpenAiInferenceEngine):
738
+ label: str = "vllm"
739
+
740
+ def create_client(self):
741
+ from openai import OpenAI
742
+
743
+ api_key = self.get_api_param(
744
+ inference_engine="VLLMRemoteInferenceEngine",
745
+ api_param_env_var_name="VLLM_API_KEY",
746
+ )
747
+ api_url = self.get_api_param(
748
+ inference_engine="VLLMRemoteInferenceEngine",
749
+ api_param_env_var_name="VLLM_API_URL",
750
+ )
751
+ return OpenAI(api_key=api_key, base_url=api_url)
752
+
753
+
754
  class WMLInferenceEngineParamsMixin(Artifact):
755
  decoding_method: Optional[Literal["greedy", "sample"]] = None
756
  length_penalty: Optional[Dict[str, Union[int, float]]] = None
 
787
 
788
 
789
  class WMLInferenceEngine(
790
+ InferenceEngine,
791
+ WMLInferenceEngineParamsMixin,
792
+ PackageRequirementsMixin,
793
+ LogProbInferenceEngine,
794
  ):
795
  """Runs inference using ibm-watsonx-ai.
796
 
 
844
  concurrency_limit: int = 10
845
  _client: Any = InternalField(default=None, name="WML client")
846
 
847
+ def get_engine_id(self):
848
+ return get_model_and_label_id(self.model_name, self.label)
849
+
850
  def verify(self):
851
  super().verify()
852
 
853
  if self.credentials is not None:
854
  for key in self.credentials:
855
+ if key not in ["url", "apikey", "project_id", "space_id"]:
856
  raise ValueError(
857
+ f'Illegal credential key: {key}, use only ["url", "apikey", "project_id", "space_id"]'
858
  )
859
 
860
  assert (
 
874
 
875
  @staticmethod
876
  def _read_wml_credentials_from_env() -> (
877
+ Dict[Literal["url", "apikey", "project_id", "space_id"], str]
878
  ):
879
  credentials = {}
880
+ project_or_deployment_var_name = (
881
+ "WML_SPACE_ID" if "WML_SPACE_ID" in os.environ else "WML_PROJECT_ID"
882
+ )
883
+
884
+ for env_var_name in ["WML_URL", project_or_deployment_var_name, "WML_APIKEY"]:
885
  env_var = os.environ.get(env_var_name)
886
  assert env_var, (
887
  f"Error while trying to run 'WMLInferenceEngine'. "
 
902
  self.credentials = self._read_wml_credentials_from_env()
903
 
904
  client = APIClient(credentials=self.credentials)
905
+ if "space_id" in self.credentials:
906
+ client.set.default_space(self.credentials["space_id"])
907
+ else:
908
+ client.set.default_project(self.credentials["project_id"])
909
  return client
910
 
911
  def prepare_engine(self):
 
913
 
914
  self._set_inference_parameters()
915
 
916
+ def _load_model_and_params(self):
917
  from ibm_watsonx_ai.foundation_models import ModelInference
918
 
919
  model = ModelInference(
 
921
  deployment_id=self.deployment_id,
922
  api_client=self._client,
923
  )
924
+ params = self.to_dict([WMLInferenceEngineParamsMixin], keep_empty=False)
925
 
926
+ return model, params
 
927
 
928
+ def _infer(
929
+ self,
930
+ dataset: Union[List[Dict[str, Any]], DatasetDict],
931
+ return_meta_data: bool = False,
932
+ ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
933
+ model, params = self._load_model_and_params()
934
+
935
+ result = []
936
+ for instance in dataset:
937
+ instance_result = model.generate(
938
  prompt=instance["source"],
939
  params=self.to_dict([WMLInferenceEngineParamsMixin], keep_empty=False),
940
  )
941
+ prediction = instance_result["results"][0]["generated_text"]
942
+ instance_final_results = self.get_return_object(
943
+ prediction, instance_result, return_meta_data
944
+ )
945
+ result.append(instance_final_results)
946
+
947
+ return result
948
+
949
+ def _infer_log_probs(
950
+ self,
951
+ dataset: Union[List[Dict[str, Any]], DatasetDict],
952
+ return_meta_data: bool = False,
953
+ ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
954
+ model, params = self._load_model_and_params()
955
+
956
+ user_return_options = params.pop("return_options", {})
957
+ # currently this is the only configuration that returns generated logprobs and behaves as expected
958
+ logprobs_return_options = {
959
+ "input_tokens": True,
960
+ "generated_tokens": True,
961
+ "token_logprobs": True,
962
+ "top_n_tokens": user_return_options.get("top_n_tokens", 5),
963
+ }
964
+ for key, value in logprobs_return_options.items():
965
+ if key in user_return_options and user_return_options[key] != value:
966
+ raise ValueError(
967
+ f"'{key}={user_return_options[key]}' is not supported for the 'infer_log_probs' "
968
+ f"method of {self.__class__.__name__}. For obtaining the logprobs of generated tokens "
969
+ f"please use '{key}={value}'."
970
+ )
971
+
972
+ params = {
973
+ **params,
974
+ "return_options": logprobs_return_options,
975
+ }
976
 
977
+ results = model.generate(
978
+ prompt=[instance["source"] for instance in dataset],
979
+ params=params,
980
+ )
981
+ final_results = []
982
+ for result in results:
983
+ generated_tokens = result["results"][0]["generated_tokens"]
984
+ final_results.append(
985
+ self.get_return_object(generated_tokens, result, return_meta_data)
986
+ )
987
+ return final_results
988
+
989
+ def get_return_object(self, predict_result, result, return_meta_data):
990
+ if return_meta_data:
991
+ return TextGenerationInferenceOutput(
992
+ prediction=predict_result,
993
+ input_tokens=result["results"][0]["input_token_count"],
994
+ output_tokens=result["results"][0]["generated_token_count"],
995
+ model_name=self.model_name,
996
+ inference_type=self.label,
997
+ )
998
+ return predict_result
999
 
1000
 
1001
  class HFLlavaInferenceEngine(InferenceEngine, LazyLoadMixin):
 
1009
  "accelerate": "pip install accelerate",
1010
  }
1011
 
1012
+ def get_engine_id(self):
1013
+ return get_model_and_label_id(self.model_name, "hf_lava")
1014
+
1015
  def _prepare_engine(self):
1016
  import torch
1017
  from transformers import AutoProcessor, LlavaForConditionalGeneration
 
1039
  def _is_loaded(self):
1040
  return hasattr(self, "model") and self.model is not None
1041
 
1042
+ def _infer(
1043
+ self,
1044
+ dataset: Union[List[Dict[str, Any]], DatasetDict],
1045
+ return_meta_data: bool = False,
1046
+ ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
1047
  if not self._is_loaded():
1048
  self._prepare_engine()
1049
 
1050
  import torch
1051
 
1052
  results = []
1053
+ for instance in tqdm(dataset):
1054
  text = instance["source"]
1055
  images = extract_images(text, instance)
1056
  # Regular expression to match all <img src="..."> tags
 
1063
  ).to(self.device, torch.float16)
1064
  input_len = len(inputs["input_ids"][0])
1065
  output = self.model.generate(
1066
+ **inputs,
1067
+ max_new_tokens=self.max_new_tokens,
1068
+ do_sample=False,
1069
+ pad_token_id=self.processor.tokenizer.eos_token_id,
1070
  )
1071
  result = self.processor.decode(
1072
  output[0][input_len:], skip_special_tokens=True
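In summary, every engine now exposes get_engine_id() and a typed _infer(dataset, return_meta_data), and engines that define get_return_object (IbmGenAi, OpenAi, WML) can wrap predictions in TextGenerationInferenceOutput. A minimal sketch using only the mock engine, so no external service or credentials are needed:

```python
# Hedged sketch of the new engine surface using MockInferenceEngine.
from unitxt.inference import MockInferenceEngine

engine = MockInferenceEngine(model_name="some-model", default_inference_value="[[7]]")
print(engine.get_engine_id())   # e.g. "some_model_mock"

dataset = [{"source": "Rate the following answer ..."}]
print(engine.infer(dataset))    # -> ["[[7]]"]

# return_meta_data=True is only supported by engines that define get_return_object;
# MockInferenceEngine does not, so the call below would raise NotImplementedError:
# engine.infer(dataset, return_meta_data=True)
```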
llm_as_judge.py CHANGED
@@ -1,10 +1,11 @@
 
1
  from typing import Any, Dict, List, Literal, Optional
2
 
3
  from .api import infer
4
  from .artifact import fetch_artifact
5
  from .dataclass import Field
6
  from .formats import Format, SystemFormat
7
- from .inference import InferenceEngine, OpenAiInferenceEngine
8
  from .metrics import BulkInstanceMetric
9
  from .operator import SequentialOperator
10
  from .settings_utils import get_settings
@@ -14,38 +15,142 @@ from .templates import Template
14
  settings = get_settings()
15
 
16
 
17
- class LLMAsJudge(BulkInstanceMetric):
18
- """LLM-as-judge-based metric class for evaluating correctness.
19
 
20
  Attributes:
21
  main_score (str): The main score label used for evaluation.
22
- task (Literal["rating.single_turn"]): The type of task the llm as judge runs. This defines the output and input
23
  format of the judge model.
24
  template (Template): The template used when generating inputs for the judge llm.
25
  format (Format): The format used when generating inputs for judge llm.
26
  system_prompt (SystemPrompt): The system prompt used when generating inputs for judge llm.
27
- strip_system_prompt_and_format_from_inputs (bool): Whether to strip the system prompt and formatting from the
28
- inputs that the models that is being judges received, when they are inserted to the llm-as-judge prompt.
29
  inference_model (InferenceEngine): The module that creates the inference of the judge llm.
30
  reduction_map (dict): A dictionary specifying the reduction method for the metric.
31
  batch_size (int): The size of the bulk.
32
  """
33
 
34
  main_score: str = "llm_as_judge"
35
- task: Literal[
36
- "rating.single_turn",
37
- "rating.single_turn_with_reference",
38
- "pairwise_comparative_rating.single_turn",
39
- ]
40
  template: Template
41
  system_prompt: SystemPrompt = Field(default_factory=EmptySystemPrompt)
42
  format: Format = Field(default_factory=SystemFormat)
43
- strip_system_prompt_and_format_from_inputs: bool = True
44
  inference_model: InferenceEngine
45
  reduction_map: Optional[Dict[str, List[str]]] = None
46
  batch_size: int = 32
47
  prediction_type = Any # Because handled with multiple tasks
48
 
49
  def _get_input_instances(self, task_data: List[Dict]) -> List:
50
  if self.strip_system_prompt_and_format_from_inputs:
51
  instances = []
@@ -119,6 +224,7 @@ class LLMAsJudge(BulkInstanceMetric):
119
  self.reduction_map = {"mean": [self.main_score]}
120
 
121
  def verify(self):
 
122
  supported_tasks = [
123
  "rating.single_turn",
124
  "rating.single_turn_with_reference",
@@ -129,68 +235,25 @@ class LLMAsJudge(BulkInstanceMetric):
129
  f"The supported tasks types are: {', '.join(supported_tasks)}."
130
  )
131
 
132
- if not isinstance(self.template, Template):
133
- raise ValueError(
134
- f"Provided template argument to 'LLMAsJudge' metric is not of type Template, but {type(self.template)}"
135
- )
136
- if self.format and not isinstance(self.format, Format):
137
- raise ValueError(
138
- f"Provided format argument to 'LLMAsJudge' metric is not of type Format, but {type(self.format)}"
139
- )
140
-
141
- if self.system_prompt and not isinstance(self.system_prompt, SystemPrompt):
142
- raise ValueError(
143
- f"Provided system_prompt argument to 'LLMAsJudge' metric is not of type SystemPrompt, but {type(self.system_prompt)}"
144
- )
145
-
146
- if isinstance(self.inference_model, OpenAiInferenceEngine):
147
- if self.format and type(self.format) is not SystemFormat:
148
- raise ValueError(
149
- "Error in 'LLMAsJudge' metric. Inference model 'OpenAiInferenceEngine' does "
150
- "not support formatting. Please remove the format definition from the recipe"
151
- " (OpenAi Chat API take care of the formatting automatically)."
152
- )
153
- if self.system_prompt and type(self.system_prompt) is not EmptySystemPrompt:
154
- raise ValueError(
155
- "Error in 'LLMAsJudge' metric. Inference model 'OpenAiInferenceEngine' does "
156
- "not support system prompt. Please remove the system_prompt definition from the recipe"
157
- " (Current implementation of Unitxt does not support this."
158
- " Support will be added in future updates)."
159
- )
160
 
161
- def compute(
162
- self,
163
- references: List[List[Any]],
164
- predictions: List[Any],
165
- task_data: List[Dict],
166
- ) -> List[Dict[str, Any]]:
167
- input_instances = self._get_input_instances(task_data)
168
- instances = self._get_instance_for_judge_model(
169
- input_instances, predictions, references
170
- )
171
- outputs = infer(
172
  instances,
173
  engine=self.inference_model,
174
- task=f"tasks.response_assessment.{self.task}",
175
  template=self.template,
176
  system_prompt=self.system_prompt,
177
  format=self.format,
178
  return_data=True,
179
  )
180
 
 
181
  results = []
182
  for instance in outputs:
183
  if self.task == "pairwise_comparative_rating.single_turn":
184
- import json
185
-
186
- # seems like the task data sometimes comes as a string, not a dict
187
- # this fixes it
188
- task_data = (
189
- json.loads(instance["task_data"])
190
- if isinstance(instance["task_data"], str)
191
- else instance["task_data"]
192
- )
193
-
194
  is_model_b_the_baseline = task_data["model_b"] == "baseline_model"
195
  if is_model_b_the_baseline:
196
  model_a_preference_score = instance["prediction"]
@@ -209,5 +272,141 @@ class LLMAsJudge(BulkInstanceMetric):
209
  "judge_raw_input": instance["source"],
210
  }
211
  results.append(result)
212
 
213
  return results
 
1
+ from abc import abstractmethod
2
  from typing import Any, Dict, List, Literal, Optional
3
 
4
  from .api import infer
5
  from .artifact import fetch_artifact
6
  from .dataclass import Field
7
  from .formats import Format, SystemFormat
8
+ from .inference import InferenceEngine, LogProbInferenceEngine, OpenAiInferenceEngine
9
  from .metrics import BulkInstanceMetric
10
  from .operator import SequentialOperator
11
  from .settings_utils import get_settings
 
15
  settings = get_settings()
16
 
17
 
18
+ def get_task_data_dict(task_data):
19
+ import json
20
+
21
+ # seems like the task data sometimes comes as a string, not a dict
22
+ # this fixes it
23
+ return json.loads(task_data) if isinstance(task_data, str) else task_data
24
+
25
+
26
+ class LLMAsJudgeBase(BulkInstanceMetric):
27
+ """LLM-as-judge base metric class for evaluating the correctness of generated predictions.
28
 
29
  Attributes:
30
  main_score (str): The main score label used for evaluation.
31
+ task (str): The type of task the llm as judge runs. This defines the output and input
32
  format of the judge model.
33
  template (Template): The template used when generating inputs for the judge llm.
34
  format (Format): The format used when generating inputs for judge llm.
35
  system_prompt (SystemPrompt): The system prompt used when generating inputs for judge llm.
36
  inference_model (InferenceEngine): The module that creates the inference of the judge llm.
37
  reduction_map (dict): A dictionary specifying the reduction method for the metric.
38
  batch_size (int): The size of the bulk.
39
  """
40
 
41
  main_score: str = "llm_as_judge"
42
+ task: str
43
  template: Template
44
  system_prompt: SystemPrompt = Field(default_factory=EmptySystemPrompt)
45
  format: Format = Field(default_factory=SystemFormat)
 
46
  inference_model: InferenceEngine
47
  reduction_map: Optional[Dict[str, List[str]]] = None
48
  batch_size: int = 32
49
  prediction_type = Any # Because handled with multiple tasks
50
 
51
+ def verify(self):
52
+ if not isinstance(self.template, Template):
53
+ raise ValueError(
54
+ f"Provided template argument to 'LLMAsJudge' metric is not of type Template, but {type(self.template)}"
55
+ )
56
+ if self.format and not isinstance(self.format, Format):
57
+ raise ValueError(
58
+ f"Provided format argument to 'LLMAsJudge' metric is not of type Format, but {type(self.format)}"
59
+ )
60
+
61
+ if self.system_prompt and not isinstance(self.system_prompt, SystemPrompt):
62
+ raise ValueError(
63
+ f"Provided system_prompt argument to 'LLMAsJudge' metric is not of type SystemPrompt, but {type(self.system_prompt)}"
64
+ )
65
+
66
+ if isinstance(self.inference_model, OpenAiInferenceEngine):
67
+ if self.format and type(self.format) is not SystemFormat:
68
+ raise ValueError(
69
+ "Error in 'LLMAsJudge' metric. Inference model 'OpenAiInferenceEngine' does "
70
+ "not support formatting. Please remove the format definition from the recipe"
71
+ " (OpenAi Chat API take care of the formatting automatically)."
72
+ )
73
+ if self.system_prompt and type(self.system_prompt) is not EmptySystemPrompt:
74
+ raise ValueError(
75
+ "Error in 'LLMAsJudge' metric. Inference model 'OpenAiInferenceEngine' does "
76
+ "not support system prompt. Please remove the system_prompt definition from the recipe"
77
+ " (Current implementation of Unitxt does not support this."
78
+ " Support will be added in future updates)."
79
+ )
80
+
81
+ @abstractmethod
82
+ def get_full_task_name(self):
83
+ pass
84
+
85
+ def compute(
86
+ self,
87
+ references: List[List[Any]],
88
+ predictions: List[Any],
89
+ task_data: List[Dict],
90
+ ) -> List[Dict[str, Any]]:
91
+ instances = self.prepare_instances(references, predictions, task_data)
92
+ outputs = self.infer_instances(instances)
93
+ return self.get_metric_results_from_prediction_outputs(outputs)
94
+
95
+ @abstractmethod
96
+ def prepare_instances(
97
+ self, references, predictions, task_data
98
+ ) -> List[Dict[str, Any]]:
99
+ """Generate a list of instances for inference.
100
+
101
+ Each generated instance should include all the fields required by the metric's task and template, to
102
+ create the source prompt for the judge.
103
+ """
104
+ pass
105
+
106
+ @abstractmethod
107
+ def infer_instances(self, instances: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
108
+ """Generate the dataset and call the inference engine to generate the judges' predictions.
109
+
110
+ Return the list of the produced instances with their generated judge predictions.
111
+ """
112
+ pass
113
+
114
+ @abstractmethod
115
+ def get_metric_results_from_prediction_outputs(
116
+ self, outputs: List[Dict[str, Any]]
117
+ ) -> List[Dict[str, Any]]:
118
+ """Generate a scores' dictionary for each instance.
119
+
120
+ Return the list of scores dictionaries for the input instances.
121
+ """
122
+ pass
123
+
124
+
125
+ class LLMAsJudge(LLMAsJudgeBase):
126
+ """LLM-as-judge-based metric class for evaluating correctness of generated predictions.
127
+
128
+ This class uses the source prompt given to the generator and the generator's predictions to evaluate
129
+ correctness using one of three supported tasks (rating.single_turn, rating.single_turn_with_reference,
130
+ pairwise_comparative_rating.single_turn).
131
+
132
+ Attributes:
133
+ main_score (str): The main score label used for evaluation.
134
+ task (Literal["rating.single_turn","rating.single_turn_with_reference",
135
+ "pairwise_comparative_rating.single_turn"]): The type of task the llm as judge runs.
136
+ This defines the output and input format of the judge model.
137
+ template (Template): The template used when generating inputs for the judge llm.
138
+ format (Format): The format used when generating inputs for judge llm.
139
+ system_prompt (SystemPrompt): The system prompt used when generating inputs for judge llm.
140
+ strip_system_prompt_and_format_from_inputs (bool): Whether to strip the system prompt and formatting from the
141
+ inputs that the model being judged received, when they are inserted into the llm-as-judge prompt.
142
+ inference_model (InferenceEngine): The module that creates the inference of the judge llm.
143
+ reduction_map (dict): A dictionary specifying the reduction method for the metric.
144
+ batch_size (int): The number of instances processed in each bulk.
145
+ """
146
+
147
+ task: Literal[
148
+ "rating.single_turn",
149
+ "rating.single_turn_with_reference",
150
+ "pairwise_comparative_rating.single_turn",
151
+ ]
152
+ strip_system_prompt_and_format_from_inputs: bool = True
153
+
154
  def _get_input_instances(self, task_data: List[Dict]) -> List:
155
  if self.strip_system_prompt_and_format_from_inputs:
156
  instances = []
 
224
  self.reduction_map = {"mean": [self.main_score]}
225
 
226
  def verify(self):
227
+ super().verify()
228
  supported_tasks = [
229
  "rating.single_turn",
230
  "rating.single_turn_with_reference",
 
235
  f"The supported tasks types are: {', '.join(supported_tasks)}."
236
  )
237
 
238
+ def get_full_task_name(self):
239
+ return f"tasks.response_assessment.{self.task}"
240
 
241
+ def infer_instances(self, instances):
242
+ return infer(
243
  instances,
244
  engine=self.inference_model,
245
+ task=self.get_full_task_name(),
246
  template=self.template,
247
  system_prompt=self.system_prompt,
248
  format=self.format,
249
  return_data=True,
250
  )
251
 
252
+ def get_metric_results_from_prediction_outputs(self, outputs):
253
  results = []
254
  for instance in outputs:
255
  if self.task == "pairwise_comparative_rating.single_turn":
256
+ task_data = get_task_data_dict(instance["task_data"])
257
  is_model_b_the_baseline = task_data["model_b"] == "baseline_model"
258
  if is_model_b_the_baseline:
259
  model_a_preference_score = instance["prediction"]
 
272
  "judge_raw_input": instance["source"],
273
  }
274
  results.append(result)
275
+ return results
276
+
277
+ def prepare_instances(self, references, predictions, task_data):
278
+ input_instances = self._get_input_instances(task_data)
279
+ return self._get_instance_for_judge_model(
280
+ input_instances, predictions, references
281
+ )
282
 
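Note: a minimal sketch of instantiating a metric of this kind. The engine class is taken from unitxt.inference; the specific model name and catalog template name are illustrative assumptions.

.. code-block:: python

    from unitxt.inference import HFPipelineBasedInferenceEngine
    from unitxt.llm_as_judge import LLMAsJudge

    # Judge engine; the model name and generation length are assumptions.
    judge_engine = HFPipelineBasedInferenceEngine(
        model_name="mistralai/Mistral-7B-Instruct-v0.2", max_new_tokens=256
    )

    # The template string is resolved from the local catalog; the exact name is an assumption.
    judge_metric = LLMAsJudge(
        main_score="mistral_judge_rating",
        task="rating.single_turn",
        template="templates.response_assessment.rating.mt_bench_single_turn",
        inference_model=judge_engine,
    )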
283
+
284
+ class TaskBasedLLMasJudge(LLMAsJudgeBase):
285
+ """LLM-as-judge-based metric class for evaluating correctness of generated predictions.
286
+
287
+ This class can use any task and a matching template to evaluate the predictions. All
288
+ task/template fields are taken from the instance's task_data.
289
+ The instances sent to the judge can be either: 1. a unitxt dataset, in which case the predictions are
290
+ copied to a specified field of the task; or 2. dictionaries with the fields required by the task and template.
291
+
292
+ Attributes:
293
+ main_score (str): The main score label used for evaluation.
294
+ task (str): The type of task the llm as judge runs.
295
+ This defines the output and input format of the judge model.
296
+ template (Template): The template used when generating inputs for the judge llm.
297
+ format (Format): The format used when generating inputs for judge llm.
298
+ system_prompt (SystemPrompt): The system prompt used when generating inputs for judge llm.
299
+ strip_system_prompt_and_format_from_inputs (bool): Whether to strip the system prompt and formatting from the
300
+ inputs that the model being judged received, when they are inserted into the llm-as-judge prompt.
301
+ inference_model (InferenceEngine): The module that creates the inference of the judge llm.
302
+ reduction_map (dict): A dictionary specifying the reduction method for the metric.
303
+ batch_size (int): The number of instances processed in each bulk.
304
+ infer_log_probs(bool): whether to perform the inference using logprobs. If true, the template's
305
+ post-processing must support the logprobs output.
306
+ judge_to_generator_fields_mapping (Dict[str, str]): optional mapping between the names of the fields in the generator task and the
307
+ judge task. For example, if the generator task uses "reference_answers" and the judge task expects "ground_truth",
308
+ include {"ground_truth": "reference_answers"} in this dictionary.
309
+ prediction_field (str, optional): if given, and a prediction exists, the prediction is copied into this field of task_data.
310
+ include_meta_data (bool): whether to include the inference per-instance metadata in the returned results.
311
+
312
+ """
313
+
314
+ infer_log_probs: bool = False
315
+ judge_to_generator_fields_mapping: Dict[str, str] = {}
316
+ prediction_field: Optional[str] = None
317
+ include_meta_data: bool = True
318
+
319
+ # Allow for input which is a dictionary of all input fields. In this case, all input fields are
320
+ # treated as the task data, and the predictions and references are taken directly from there
321
+ # by the judge's template
322
+ def preprocess_instance(self, instance):
323
+ if "task_data" not in instance:
324
+ instance["task_data"] = instance.copy()
325
+ if "prediction" not in instance:
326
+ instance["prediction"] = None
327
+ if "references" not in instance:
328
+ instance["references"] = [""]
329
+ return instance
330
+
331
+ def verify(self):
332
+ super().verify()
333
+ if self.infer_log_probs and not isinstance(
334
+ self.inference_model, LogProbInferenceEngine
335
+ ):
336
+ raise NotImplementedError(
337
+ f"Error in TaskBasedLLMasJudge: return_log_probs set to True but supplied engine "
338
+ f"{self.inference_model.__class__.__name__} does not support logprobs."
339
+ )
340
+ if self.include_meta_data and not hasattr(
341
+ self.inference_model, "get_return_object"
342
+ ):
343
+ Warning(
344
+ f"Supplied inference engine {self.inference_model.__class__.__name__} does not support "
345
+ "return_meta_data. Setting return_meta_data to False. Metadata scores will not appear "
346
+ "in returned instances scores."
347
+ )
348
+ self.include_meta_data = False
349
+
350
+ def prepare(self):
351
+ super().prepare()
352
+ self.reduction_map = {"mean": [self.main_score]}
353
+ self.score_prefix = f"{self.inference_model.get_engine_id()}_"
354
+
355
+ def get_full_task_name(self):
356
+ return self.task
357
+
358
+ def get_metric_results_from_prediction_outputs(self, outputs):
359
+ results = []
360
+ for instance in outputs:
361
+ result = {
362
+ self.main_score: instance["prediction"],
363
+ f"{self.main_score}_judge_raw_output": instance["raw_prediction"],
364
+ f"{self.main_score}_judge_raw_input": instance["source"],
365
+ }
366
+ if self.include_meta_data:
367
+ meta_data = {
368
+ f"{self.main_score}_{k}": v
369
+ for k, v in instance["infer_meta_data"].items()
370
+ }
371
+ result.update(meta_data)
372
+ results.append(result)
373
  return results
374
+
375
+ def prepare_instances(self, references, predictions, task_data):
376
+ from . import get_from_catalog
377
+
378
+ instances = []
379
+ judge_task = get_from_catalog(self.get_full_task_name())
380
+ judge_task_input_fields = judge_task.input_fields
381
+
382
+ for input_instance, prediction, _ in zip(task_data, predictions, references):
383
+ input_instance = get_task_data_dict(input_instance)
384
+
385
+ instance_task_data = {}
386
+ for judge_task_input_field in judge_task_input_fields:
387
+ orig_task_field_name = self.judge_to_generator_fields_mapping.get(
388
+ judge_task_input_field, judge_task_input_field
389
+ )
390
+ new_val = input_instance.get(orig_task_field_name)
391
+ if new_val:
392
+ instance_task_data[judge_task_input_field] = new_val
393
+
394
+ if self.prediction_field and prediction:
395
+ instance_task_data[self.prediction_field] = str(prediction)
396
+ instance_task_data = judge_task.process(instance_task_data)["input_fields"]
397
+ instances.append(instance_task_data)
398
+
399
+ return instances
400
+
401
+ def infer_instances(self, instances):
402
+ return infer(
403
+ instances,
404
+ engine=self.inference_model,
405
+ task=self.get_full_task_name(),
406
+ template=self.template,
407
+ system_prompt=self.system_prompt,
408
+ format=self.format,
409
+ return_data=True,
410
+ return_log_probs=self.infer_log_probs,
411
+ return_meta_data=self.include_meta_data,
412
+ )
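Note: a minimal sketch of wiring up a TaskBasedLLMasJudge; the judge task, template, model and field names below are illustrative assumptions.

.. code-block:: python

    from unitxt.inference import HFPipelineBasedInferenceEngine
    from unitxt.llm_as_judge import TaskBasedLLMasJudge

    judge_metric = TaskBasedLLMasJudge(
        main_score="faithfulness_judgment",
        task="tasks.rag_eval.faithfulness.binary",  # assumed catalog task
        template="templates.rag_eval.faithfulness.judge_simplified",  # assumed catalog template
        inference_model=HFPipelineBasedInferenceEngine(
            model_name="meta-llama/Meta-Llama-3-8B-Instruct", max_new_tokens=5
        ),
        prediction_field="answer",  # copy each prediction into task_data["answer"]
        judge_to_generator_fields_mapping={"ground_truth": "reference_answers"},
        infer_log_probs=False,
    )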
loaders.py CHANGED
@@ -53,7 +53,7 @@ from .operators import Set
53
  from .settings_utils import get_settings
54
  from .stream import DynamicStream, MultiStream
55
  from .type_utils import isoftype
56
- from .utils import deepcopy
57
 
58
  logger = get_logger()
59
  settings = get_settings()
@@ -195,6 +195,10 @@ class LoadHF(Loader):
195
  def stream_dataset(self):
196
  if self._cache is None:
197
  with tempfile.TemporaryDirectory() as dir_to_be_deleted:
 
 
 
 
198
  try:
199
  dataset = hf_load_dataset(
200
  self.path,
@@ -203,7 +207,7 @@ class LoadHF(Loader):
203
  data_files=self.data_files,
204
  revision=self.revision,
205
  streaming=self.streaming,
206
- cache_dir=None if self.streaming else dir_to_be_deleted,
207
  split=self.split,
208
  trust_remote_code=settings.allow_unverified_code,
209
  num_proc=self.num_proc,
@@ -231,6 +235,10 @@ class LoadHF(Loader):
231
  def load_dataset(self):
232
  if self._cache is None:
233
  with tempfile.TemporaryDirectory() as dir_to_be_deleted:
 
 
 
 
234
  try:
235
  dataset = hf_load_dataset(
236
  self.path,
@@ -239,7 +247,7 @@ class LoadHF(Loader):
239
  data_files=self.data_files,
240
  streaming=False,
241
  keep_in_memory=True,
242
- cache_dir=dir_to_be_deleted,
243
  split=self.split,
244
  trust_remote_code=settings.allow_unverified_code,
245
  num_proc=self.num_proc,
@@ -664,7 +672,7 @@ class MultipleSourceLoader(Loader):
664
 
665
  .. code-block:: python
666
 
667
- MultipleSourceLoader(loaders = [ LoadHF(path="public/data",split="train"), LoadCSV({"test": "mytest.csv"}) ])
668
 
669
 
670
 
@@ -672,7 +680,7 @@ class MultipleSourceLoader(Loader):
672
 
673
  .. code-block:: python
674
 
675
- MultipleSourceLoader(loaders = [ LoadCSV({"test": "mytest1.csv"}, LoadCSV({"test": "mytest2.csv"}) ])
676
  """
677
 
678
  sources: List[Loader]
@@ -737,7 +745,7 @@ class LoadFromDictionary(Loader):
737
  self.sef_default_data_classification(
738
  ["proprietary"], "when loading from python dictionary"
739
  )
740
- return MultiStream.from_iterables(deepcopy(self.data))
741
 
742
 
743
  class LoadFromHFSpace(LoadHF):
 
53
  from .settings_utils import get_settings
54
  from .stream import DynamicStream, MultiStream
55
  from .type_utils import isoftype
56
+ from .utils import recursive_copy
57
 
58
  logger = get_logger()
59
  settings = get_settings()
 
195
  def stream_dataset(self):
196
  if self._cache is None:
197
  with tempfile.TemporaryDirectory() as dir_to_be_deleted:
198
+ if settings.disable_hf_datasets_cache and not self.streaming:
199
+ cache_dir = dir_to_be_deleted
200
+ else:
201
+ cache_dir = None
202
  try:
203
  dataset = hf_load_dataset(
204
  self.path,
 
207
  data_files=self.data_files,
208
  revision=self.revision,
209
  streaming=self.streaming,
210
+ cache_dir=cache_dir,
211
  split=self.split,
212
  trust_remote_code=settings.allow_unverified_code,
213
  num_proc=self.num_proc,
 
235
  def load_dataset(self):
236
  if self._cache is None:
237
  with tempfile.TemporaryDirectory() as dir_to_be_deleted:
238
+ if settings.disable_hf_datasets_cache:
239
+ cache_dir = dir_to_be_deleted
240
+ else:
241
+ cache_dir = None
242
  try:
243
  dataset = hf_load_dataset(
244
  self.path,
 
247
  data_files=self.data_files,
248
  streaming=False,
249
  keep_in_memory=True,
250
+ cache_dir=cache_dir,
251
  split=self.split,
252
  trust_remote_code=settings.allow_unverified_code,
253
  num_proc=self.num_proc,
 
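Note: with the cache_dir handling above, persistent HF datasets caching can be switched off globally. A small sketch of toggling the flag, using the same settings object the loader reads:

.. code-block:: python

    from unitxt.settings_utils import get_settings

    settings = get_settings()
    # Non-streaming loads will then cache into a temporary directory that is deleted afterwards.
    settings.disable_hf_datasets_cache = True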
672
 
673
  .. code-block:: python
674
 
675
+ MultipleSourceLoader(sources = [ LoadHF(path="public/data",split="train"), LoadCSV({"test": "mytest.csv"}) ])
676
 
677
 
678
 
 
680
 
681
  .. code-block:: python
682
 
683
+ MultipleSourceLoader(sources = [ LoadCSV({"test": "mytest1.csv"}, LoadCSV({"test": "mytest2.csv"}) ])
684
  """
685
 
686
  sources: List[Loader]
 
745
  self.sef_default_data_classification(
746
  ["proprietary"], "when loading from python dictionary"
747
  )
748
+ return MultiStream.from_iterables(recursive_copy(self.data))
749
 
750
 
751
  class LoadFromHFSpace(LoadHF):
metric_utils.py CHANGED
@@ -16,8 +16,8 @@ from .operator import (
16
  from .operators import (
17
  ApplyMetric,
18
  ApplyOperatorsField,
19
- Copy,
20
  FlattenInstances,
 
21
  Rename,
22
  )
23
  from .register import _reset_env_local_catalogs, register_all_artifacts
@@ -25,7 +25,7 @@ from .schema import UNITXT_DATASET_SCHEMA
25
  from .settings_utils import get_constants, get_settings
26
  from .stream import DynamicStream, MultiStream
27
  from .struct_data_operators import LoadJson
28
- from .utils import deepcopy
29
 
30
  constants = get_constants()
31
 
@@ -54,27 +54,27 @@ class FromPredictionsAndOriginalData(StreamInitializerOperator):
54
 
55
  _post_process_steps = SequentialOperator(
56
  steps=[
57
- Copy(
58
  field="prediction",
59
  to_field="raw_prediction",
60
  ),
61
- Copy(
62
  field="references",
63
  to_field="raw_references",
64
  dont_apply_to_streams=[constants.inference_stream],
65
  ),
66
- Copy(
67
  field="source",
68
  to_field="task_data/source",
69
  ),
70
  ApplyOperatorsField(
71
  operators_field="postprocessors",
72
  ),
73
- Copy(
74
  field="prediction",
75
  to_field="processed_prediction",
76
  ),
77
- Copy(
78
  field="references",
79
  to_field="processed_references",
80
  dont_apply_to_streams=[constants.inference_stream],
@@ -213,14 +213,19 @@ class JoinSubsetsAndGroups(MultiStreamOperator):
213
 
214
  result = {}
215
  all_scores = []
 
216
  for k, v in dic.items():
217
  score = recursive_mean(v)
218
  if score is not None:
219
  all_scores.append(score["score"])
 
 
220
  result[k] = score
221
 
222
  result["score"] = nan_mean(all_scores)
223
  result["score_name"] = "subsets_mean"
 
 
224
 
225
  if result:
226
  return result
@@ -237,11 +242,15 @@ class JoinSubsetsAndGroups(MultiStreamOperator):
237
  "score": score["subsets"]["score"],
238
  "score_name": score["subsets"]["score_name"],
239
  }
 
 
 
 
240
 
241
  sorted_instances = []
242
  for key in sorted(stream_instances.keys()):
243
  instance = stream_instances[key]
244
- instance["score"].update(deepcopy(score))
245
  sorted_instances.append(instance)
246
  result[stream_name] = sorted_instances
247
 
@@ -299,7 +308,7 @@ class MetricRecipe(SequentialOperatorInitializer):
299
  field="raw_references",
300
  to_field="references",
301
  ),
302
- Copy(
303
  field="source",
304
  to_field="task_data/source",
305
  ),
 
16
  from .operators import (
17
  ApplyMetric,
18
  ApplyOperatorsField,
 
19
  FlattenInstances,
20
+ RecursiveCopy,
21
  Rename,
22
  )
23
  from .register import _reset_env_local_catalogs, register_all_artifacts
 
25
  from .settings_utils import get_constants, get_settings
26
  from .stream import DynamicStream, MultiStream
27
  from .struct_data_operators import LoadJson
28
+ from .utils import recursive_shallow_copy
29
 
30
  constants = get_constants()
31
 
 
54
 
55
  _post_process_steps = SequentialOperator(
56
  steps=[
57
+ RecursiveCopy(
58
  field="prediction",
59
  to_field="raw_prediction",
60
  ),
61
+ RecursiveCopy(
62
  field="references",
63
  to_field="raw_references",
64
  dont_apply_to_streams=[constants.inference_stream],
65
  ),
66
+ RecursiveCopy(
67
  field="source",
68
  to_field="task_data/source",
69
  ),
70
  ApplyOperatorsField(
71
  operators_field="postprocessors",
72
  ),
73
+ RecursiveCopy(
74
  field="prediction",
75
  to_field="processed_prediction",
76
  ),
77
+ RecursiveCopy(
78
  field="references",
79
  to_field="processed_references",
80
  dont_apply_to_streams=[constants.inference_stream],
 
213
 
214
  result = {}
215
  all_scores = []
216
+ all_num_of_instances = []
217
  for k, v in dic.items():
218
  score = recursive_mean(v)
219
  if score is not None:
220
  all_scores.append(score["score"])
221
+ if "num_of_instances" in score:
222
+ all_num_of_instances.append(score["num_of_instances"])
223
  result[k] = score
224
 
225
  result["score"] = nan_mean(all_scores)
226
  result["score_name"] = "subsets_mean"
227
+ if all_num_of_instances:
228
+ result["num_of_instances"] = sum(all_num_of_instances)
229
 
230
  if result:
231
  return result
 
242
  "score": score["subsets"]["score"],
243
  "score_name": score["subsets"]["score_name"],
244
  }
245
+ if "num_of_instances" in score["subsets"]:
246
+ score["global"]["num_of_instances"] = score["subsets"][
247
+ "num_of_instances"
248
+ ]
249
 
250
  sorted_instances = []
251
  for key in sorted(stream_instances.keys()):
252
  instance = stream_instances[key]
253
+ instance["score"].update(recursive_shallow_copy(score))
254
  sorted_instances.append(instance)
255
  result[stream_name] = sorted_instances
256
 
 
308
  field="raw_references",
309
  to_field="references",
310
  ),
311
+ RecursiveCopy(
312
  field="source",
313
  to_field="task_data/source",
314
  ),
metrics.py CHANGED
@@ -8,10 +8,9 @@ import warnings
8
  from abc import ABC, abstractmethod
9
  from collections import Counter, defaultdict
10
  from dataclasses import field
11
- from operator import itemgetter
12
  from typing import Any, Dict, Generator, List, Optional, Tuple, Union
13
 
14
- import evaluate
15
  import numpy
16
  import numpy as np
17
  import pandas as pd
@@ -37,20 +36,18 @@ from .operator import (
37
  StreamingOperator,
38
  StreamOperator,
39
  )
40
- from .operators import Copy
41
  from .random_utils import get_seed
42
  from .settings_utils import get_settings
43
  from .stream import MultiStream, Stream
44
  from .type_utils import Type, isoftype, parse_type_string, to_type_string
45
- from .utils import deepcopy
46
 
47
  logger = get_logger()
48
  settings = get_settings()
49
 
50
  warnings.filterwarnings("ignore", category=DegenerateDataWarning)
51
 
52
- warnings.filterwarnings("ignore", category=DegenerateDataWarning)
53
-
54
 
55
  def abstract_factory():
56
  return {}
@@ -139,6 +136,7 @@ class Metric(Artifact):
139
  return (
140
  self.score_prefix + score_name
141
  if score_name not in ["score", "score_name"]
 
142
  else score_name
143
  )
144
 
@@ -147,18 +145,24 @@ class Metric(Artifact):
147
  ) -> Dict[str, Any]:
148
  new_scores = {}
149
  for score_name, score in scores.items():
 
 
 
150
  score_with_prefix = self._add_score_prefix(score_name)
151
  new_scores[score_with_prefix] = (
152
  score if score_name not in ["score_name"] else self.score_prefix + score
153
  )
154
  for new_score_name in new_scores:
155
- if new_score_name in ["score", "score_name"]:
 
 
156
  continue
157
  if new_score_name in existing_scores:
158
  UnitxtWarning(
159
  message=f"Metric '{new_score_name}' that has just been evaluated to {new_scores[new_score_name]}, is already recorded "
160
  f"to have value {existing_scores[new_score_name]} by a previous metric evaluation on this instance or stream. "
161
- f"To avoid overwriting the existing value, add a score_prefix to the metric (e.g. score_prefix='my_second_').",
 
162
  additional_info_id=Documentation.MULTIPLE_METRICS_OUTPUTS,
163
  )
164
  return new_scores
@@ -279,7 +283,12 @@ class Metric(Artifact):
279
  self, instance: Dict[str, Any], global_score: dict
280
  ):
281
  for score_name in global_score:
282
- if score_name in ["score", "score_name", "score_ci_low", "score_ci_high"]:
 
 
 
 
 
283
  continue
284
  if score_name in instance["score"]["global"]:
285
  UnitxtWarning(
@@ -469,11 +478,17 @@ class MetricWithConfidenceInterval(Metric):
469
  # iterate over the rows and compute the metric on each resampling
470
  def metric(sample_refs, sample_preds, sample_task_data):
471
  try:
472
- return self._compute(
473
  references=sample_refs,
474
  predictions=sample_preds,
475
  task_data=sample_task_data,
476
- )["score"]
 
 
 
 
 
 
477
  except Exception as e:
478
  # this happens in edge cases, for example, when the sampling creates a
479
  # sample where all strings are empty and this fails bleu.
@@ -538,7 +553,6 @@ class GlobalMetric(StreamOperator, MetricWithConfidenceInterval):
538
  references = []
539
  predictions = []
540
  task_data = []
541
- global_score = {}
542
 
543
  instances = []
544
 
@@ -589,6 +603,7 @@ class GlobalMetric(StreamOperator, MetricWithConfidenceInterval):
589
  )
590
  )
591
  self._validate_references_and_prediction(references, predictions)
 
592
 
593
  result = self._compute(references, predictions, task_data)
594
  global_score.update(
@@ -596,11 +611,18 @@ class GlobalMetric(StreamOperator, MetricWithConfidenceInterval):
596
  result, global_score
597
  )
598
  )
599
- score_name = global_score["score_name"]
600
- confidence_interval = self.compute_global_confidence_intervals(
601
- references, predictions, task_data, score_name
602
- )
603
- global_score.update(confidence_interval)
 
 
 
 
 
 
 
604
 
605
  for instance in instances:
606
  self.update_and_adjust_global_score(instance, global_score)
@@ -649,28 +671,24 @@ class BulkInstanceMetric(StreamOperator, MetricWithConfidenceInterval):
649
  default_factory=lambda: ["mean", "weighted_win_rate"]
650
  )
651
 
 
 
 
652
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
653
- global_score = {}
654
  instances = []
 
 
 
 
655
 
656
- # consume the stream
657
- references, predictions = map(
658
- list,
659
- zip(
660
- *[
661
- itemgetter("references", "prediction")(
662
- self.verify_instance(instance)
663
- )
664
- for instance in stream
665
- ]
666
- ),
667
- )
668
-
669
  task_data = [
670
  instance["task_data"] if "task_data" in instance else {}
671
- for instance in stream
672
  ]
673
  self._validate_references_and_prediction(references, predictions)
 
674
  # compute the metric over all refs and preds
675
  instance_scores = self.compute(
676
  references=references,
@@ -683,7 +701,7 @@ class BulkInstanceMetric(StreamOperator, MetricWithConfidenceInterval):
683
  instance_score["score"] = instance_score[self.main_score]
684
  instance_score["score_name"] = self.main_score
685
 
686
- for instance, score in zip(stream, instance_scores):
687
  if "score" not in instance:
688
  instance["score"] = {"global": {}, "instance": {}}
689
 
@@ -692,7 +710,6 @@ class BulkInstanceMetric(StreamOperator, MetricWithConfidenceInterval):
692
  score, instance["score"]["instance"]
693
  )
694
  )
695
- instances.append(instance)
696
 
697
  for reduction, fields in self.reduction_map.items():
698
  assert (
@@ -1059,7 +1076,7 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
1059
 
1060
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
1061
  instances = self.compute_instance_scores(stream)
1062
- global_score = {}
1063
  for reduction_type, reduction_params in self.reduction_map.items():
1064
  assert (
1065
  reduction_type in self.implemented_reductions
@@ -1096,7 +1113,10 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
1096
  scores_to_resample,
1097
  aggregation_function,
1098
  ) = self._set_up_group_mean_aggregation(
1099
- instances, reduction_params, reduction_fields
 
 
 
1100
  )
1101
  else:
1102
  raise ValueError(
@@ -1171,13 +1191,16 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
1171
  instance_score["score_name"] = self.main_score
1172
  if "score" not in instance:
1173
  instance["score"] = {"global": {}, "instance": {}}
 
 
 
 
1174
 
1175
  instance["score"]["instance"].update(
1176
  self._add_score_prefixes_to_score_dict_and_check_against_existing_scores(
1177
  instance_score, instance["score"]["instance"]
1178
  )
1179
  )
1180
-
1181
  instances.append(instance)
1182
 
1183
  return instances
@@ -1187,7 +1210,9 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
1187
  instances: List[dict],
1188
  score_names: List[str],
1189
  group_aggregation_func,
1190
- prepend_score_prefix: bool = True,
 
 
1191
  ):
1192
  """Group scores by the group_id and subgroup_type fields of each instance, and compute group_aggregation_func by group.
1193
 
@@ -1199,6 +1224,8 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
1199
  callable function returns a single score for the group
1200
  prepend_score_prefix: if True - prepend the score_prefix to the score names in the returned dicts. Set to False
1201
  if down the stream such a prepending is expected.
 
 
1202
 
1203
  Returns:
1204
  List of dicts, each corresponding to a group of instances (defined by 'group_id'),
@@ -1233,8 +1260,27 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
1233
  ]
1234
  )
1235
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1236
  # if group_aggregation_func expects a subgroup-types score dict, pass it; otherwise pass the default type list of scores
1237
- return [
1238
  {
1239
  "score": {
1240
  "instance": {
@@ -1255,9 +1301,25 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
1255
  ) # sorted for consistency
1256
  ]
1257
 
 
 
 
 
 
 
 
 
 
 
 
1258
  def _set_up_group_mean_aggregation(
1259
- self, instances, reduction_params, reduction_fields
 
 
 
 
1260
  ):
 
1261
  group_aggregation_func = reduction_params["agg_func"][1]
1262
  # if treat groups as units
1263
  do_resample_as_group = reduction_params["agg_func"][2]
@@ -1265,7 +1327,12 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
1265
  # pass the group aggregate---not instance---scores to resample as usual
1266
  aggregation_function = self.average_item_scores
1267
  scores_to_resample = self.get_group_scores(
1268
- instances, reduction_fields, group_aggregation_func
 
 
 
 
 
1269
  )
1270
  else:
1271
  # pass the instance scores to resample, and calculate the group aggregation on the resamplings
@@ -1277,7 +1344,12 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
1277
  group_aggregation_func=group_aggregation_func,
1278
  ):
1279
  group_scores = self.get_group_scores(
1280
- instances, [field_name], group_aggregation_func, False
 
 
 
 
 
1281
  )
1282
  return nan_mean(
1283
  [group["score"]["instance"][field_name] for group in group_scores]
@@ -1315,6 +1387,19 @@ class ANLS(InstanceMetric):
1315
  reduction_map = {"mean": ["anls"]}
1316
  prediction_type = Any # string representation is compared
1317
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1318
  def compute(
1319
  self,
1320
  references: List[Any],
@@ -1324,20 +1409,14 @@ class ANLS(InstanceMetric):
1324
  ) -> dict:
1325
  """ANLS image-text accuracy metric."""
1326
  values = []
1327
- for answer in references:
1328
- # preprocess both the answers - gt and prediction
1329
- gt_answer = " ".join(answer.strip().lower().split())
1330
- det_answer = " ".join(prediction.strip().lower().split())
1331
-
1332
- # dist = levenshtein_distance(answer.lower(), detObject['answer'].lower())
1333
- dist = self.levenshtein_distance(gt_answer, det_answer)
1334
- length = max(len(answer.upper()), len(prediction.upper()))
1335
- values.append(0.0 if length == 0 else float(dist) / float(length))
1336
 
1337
  question_result = 1.0 - min(values)
1338
 
1339
  if question_result < threshold:
1340
  question_result = 0.0
 
1341
  result = {}
1342
  result["score"] = question_result
1343
  result[self.main_score] = question_result
@@ -1345,6 +1424,7 @@ class ANLS(InstanceMetric):
1345
  return result
1346
 
1347
  @staticmethod
 
1348
  def levenshtein_distance(s1, s2):
1349
  if len(s1) > len(s2):
1350
  s1, s2 = s2, s1
@@ -1526,16 +1606,40 @@ class MetricPipeline(MultiStreamOperator, Metric):
1526
  ), "Must define at most one of postpreprocess_steps (which is deprecated) and postprocess_steps (to be used from now on)"
1527
  if has_postpreprocess:
1528
  self.postprocess_steps = self.postpreprocess_steps
1529
- self.prepare_score = Copy(
1530
- field_to_field=[
1531
- [
1532
- f"score/instance/{self.metric._add_score_prefix(self.main_score)}",
1533
- "score/instance/score",
1534
- ],
1535
- [
1536
- f"score/global/{self.metric._add_score_prefix(self.main_score)}",
1537
- "score/global/score",
1538
- ],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1539
  ],
1540
  )
1541
 
@@ -1589,6 +1693,8 @@ class HuggingfaceMetric(GlobalMetric):
1589
 
1590
  def prepare(self):
1591
  super().prepare()
 
 
1592
  self.metric = evaluate.load(
1593
  self.hf_metric_name, experiment_id=self.experiment_id
1594
  )
@@ -1663,6 +1769,8 @@ class HuggingfaceBulkMetric(BulkInstanceMetric):
1663
 
1664
  def prepare(self):
1665
  super().prepare()
 
 
1666
  self.metric = evaluate.load(
1667
  self.hf_metric_name, experiment_id=str(uuid.uuid4())
1668
  )
@@ -1709,6 +1817,8 @@ class HuggingfaceInstanceMetric(InstanceMetric):
1709
 
1710
  def prepare(self):
1711
  super().prepare()
 
 
1712
  self.metric = evaluate.load(
1713
  self.hf_metric_name, experiment_id=str(uuid.uuid4())
1714
  )
@@ -1788,6 +1898,8 @@ class F1(GlobalMetric):
1788
 
1789
  def prepare(self):
1790
  super().prepare()
 
 
1791
  self._metric = evaluate.load(self.metric, experiment_id=str(uuid.uuid4()))
1792
 
1793
  def get_str_id(self, str):
@@ -1847,6 +1959,7 @@ class F1Binary(GlobalMetric):
1847
  _metric = None
1848
  metric = "f1"
1849
  single_reference_per_prediction = True
 
1850
  _requirements_list: List[str] = ["sklearn"]
1851
 
1852
  def prepare(self):
@@ -2064,6 +2177,8 @@ class F1MultiLabel(GlobalMetric):
2064
 
2065
  def prepare(self):
2066
  super().prepare()
 
 
2067
  self._metric = evaluate.load(
2068
  self.metric, "multilabel", experiment_id=str(uuid.uuid4())
2069
  )
@@ -3033,7 +3148,7 @@ class SafetyMetric(GlobalMetric):
3033
  class LlamaIndexLLMMetric(InstanceMetric):
3034
  model_name: str = ""
3035
  main_score: str = ""
3036
- prediction_type: str = str
3037
  reduction_map: Dict[str, List[str]] = None
3038
  openai_models: List[str] = ["gpt-3.5-turbo"]
3039
  anthropic_models: List[
@@ -3679,6 +3794,7 @@ class RetrievalAtK(RetrievalMetric):
3679
  (recall_at_k, "recall"),
3680
  (match_at_k, "match"),
3681
  ]:
 
3682
  max_k = max(measure_array.keys())
3683
  for k in self.k_list:
3684
  result[self.score_name(measure_name, k)] = measure_array[min(k, max_k)]
@@ -3725,7 +3841,7 @@ class RemoteMetric(StreamOperator, Metric):
3725
  remotely (pre and post processing steps in the MetricPipeline will be computed locally).
3726
  """
3727
  local_inner_metric = metric_pipeline.metric
3728
- metric_pipeline = deepcopy(
3729
  metric_pipeline
3730
  ) # To avoid unintentional changes to the catalog contents
3731
  metric_pipeline.metric = RemoteMetric(
@@ -4376,6 +4492,7 @@ class BinaryMaxF1(F1Binary):
4376
  main_score = "max_f1_binary"
4377
  single_reference_per_prediction = True
4378
  average = None
 
4379
 
4380
  def compute(
4381
  self,
@@ -4799,17 +4916,22 @@ class F1Strings(InstanceMetric):
4799
  "spacy": "Please pip install spacy",
4800
  }
4801
 
4802
- def prepare(self):
4803
- super().prepare()
4804
  import spacy
4805
 
 
 
 
 
 
 
4806
  try:
4807
- self.nlp = spacy.load("en_core_web_sm")
4808
  except OSError:
4809
  from spacy.cli import download
4810
 
4811
  download("en_core_web_sm")
4812
- self.nlp = spacy.load("en_core_web_sm")
4813
 
4814
  def compute(
4815
  self,
@@ -4955,3 +5077,20 @@ class RandomForestMetricsEnsemble(MetricsEnsemble):
4955
  )
4956
  score = ensemble_model.predict([prediction_lst])
4957
  return score.tolist()[0]
 
8
  from abc import ABC, abstractmethod
9
  from collections import Counter, defaultdict
10
  from dataclasses import field
11
+ from functools import lru_cache
12
  from typing import Any, Dict, Generator, List, Optional, Tuple, Union
13
 
 
14
  import numpy
15
  import numpy as np
16
  import pandas as pd
 
36
  StreamingOperator,
37
  StreamOperator,
38
  )
39
+ from .operators import Copy, Set
40
  from .random_utils import get_seed
41
  from .settings_utils import get_settings
42
  from .stream import MultiStream, Stream
43
  from .type_utils import Type, isoftype, parse_type_string, to_type_string
44
+ from .utils import deep_copy
45
 
46
  logger = get_logger()
47
  settings = get_settings()
48
 
49
  warnings.filterwarnings("ignore", category=DegenerateDataWarning)
50
 
 
 
51
 
52
  def abstract_factory():
53
  return {}
 
136
  return (
137
  self.score_prefix + score_name
138
  if score_name not in ["score", "score_name"]
139
+ and not score_name.startswith("num_of_instances")
140
  else score_name
141
  )
142
 
 
145
  ) -> Dict[str, Any]:
146
  new_scores = {}
147
  for score_name, score in scores.items():
148
+ if isinstance(score, dict):
149
+ new_scores[score_name] = score
150
+ continue # do not prefix group names
151
  score_with_prefix = self._add_score_prefix(score_name)
152
  new_scores[score_with_prefix] = (
153
  score if score_name not in ["score_name"] else self.score_prefix + score
154
  )
155
  for new_score_name in new_scores:
156
+ if new_score_name in ["score", "score_name"] or new_score_name.startswith(
157
+ "num_of_instances"
158
+ ):
159
  continue
160
  if new_score_name in existing_scores:
161
  UnitxtWarning(
162
  message=f"Metric '{new_score_name}' that has just been evaluated to {new_scores[new_score_name]}, is already recorded "
163
  f"to have value {existing_scores[new_score_name]} by a previous metric evaluation on this instance or stream. "
164
+ f"To avoid overwriting the existing value, add a score_prefix to the metric name (e.g. score_prefix='my_second_' , "
165
+ f"which will yield, in this case, a score named: 'my_second_{new_score_name}')",
166
  additional_info_id=Documentation.MULTIPLE_METRICS_OUTPUTS,
167
  )
168
  return new_scores
 
283
  self, instance: Dict[str, Any], global_score: dict
284
  ):
285
  for score_name in global_score:
286
+ if score_name in [
287
+ "score",
288
+ "score_name",
289
+ "score_ci_low",
290
+ "score_ci_high",
291
+ ] or score_name.startswith("num_of_instances"):
292
  continue
293
  if score_name in instance["score"]["global"]:
294
  UnitxtWarning(
 
478
  # iterate over the rows and compute the metric on each resampling
479
  def metric(sample_refs, sample_preds, sample_task_data):
480
  try:
481
+ results = self._compute(
482
  references=sample_refs,
483
  predictions=sample_preds,
484
  task_data=sample_task_data,
485
+ )
486
+ results.update(
487
+ self._add_score_prefixes_to_score_dict_and_check_against_existing_scores(
488
+ results, {}
489
+ )
490
+ )
491
+ return results[score_name]
492
  except Exception as e:
493
  # this happens in edge cases, for example, when the sampling creates a
494
  # sample where all strings are empty and this fails bleu.
 
553
  references = []
554
  predictions = []
555
  task_data = []
 
556
 
557
  instances = []
558
 
 
603
  )
604
  )
605
  self._validate_references_and_prediction(references, predictions)
606
+ global_score = {"num_of_instances": len(instances)}
607
 
608
  result = self._compute(references, predictions, task_data)
609
  global_score.update(
 
611
  result, global_score
612
  )
613
  )
614
+ if self.ci_scores:
615
+ score_names = [
616
+ self._add_score_prefix(score_name) for score_name in self.ci_scores
617
+ ]
618
+ else:
619
+ score_names = [global_score["score_name"]]
620
+
621
+ for score_name in score_names:
622
+ confidence_interval = self.compute_global_confidence_intervals(
623
+ references, predictions, task_data, score_name
624
+ )
625
+ global_score.update(confidence_interval)
626
 
627
  for instance in instances:
628
  self.update_and_adjust_global_score(instance, global_score)
 
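Note: the global score now always carries the instance count alongside the main score and any requested confidence intervals. An illustrative shape (all names and values are made up):

.. code-block:: python

    example_global_score = {
        "num_of_instances": 200,
        "f1_macro": 0.71,
        "score": 0.71,
        "score_name": "f1_macro",
        "f1_macro_ci_low": 0.66,
        "f1_macro_ci_high": 0.75,
        "score_ci_low": 0.66,
        "score_ci_high": 0.75,
    }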
671
  default_factory=lambda: ["mean", "weighted_win_rate"]
672
  )
673
 
674
+ def preprocess_instance(self, instance):
675
+ return instance
676
+
677
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
 
678
  instances = []
679
+ for instance in stream:
680
+ self.verify_instance(instance)
681
+ instance = self.preprocess_instance(instance)
682
+ instances.append(instance)
683
 
684
+ predictions = [instance["prediction"] for instance in instances]
685
+ references = [instance["references"] for instance in instances]
 
 
 
 
 
 
 
 
 
 
 
686
  task_data = [
687
  instance["task_data"] if "task_data" in instance else {}
688
+ for instance in instances
689
  ]
690
  self._validate_references_and_prediction(references, predictions)
691
+ global_score = {"num_of_instances": len(instances)}
692
  # compute the metric over all refs and preds
693
  instance_scores = self.compute(
694
  references=references,
 
701
  instance_score["score"] = instance_score[self.main_score]
702
  instance_score["score_name"] = self.main_score
703
 
704
+ for instance, score in zip(instances, instance_scores):
705
  if "score" not in instance:
706
  instance["score"] = {"global": {}, "instance": {}}
707
 
 
710
  score, instance["score"]["instance"]
711
  )
712
  )
 
713
 
714
  for reduction, fields in self.reduction_map.items():
715
  assert (
 
1076
 
1077
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
1078
  instances = self.compute_instance_scores(stream)
1079
+ global_score = {"num_of_instances": len(instances)}
1080
  for reduction_type, reduction_params in self.reduction_map.items():
1081
  assert (
1082
  reduction_type in self.implemented_reductions
 
1113
  scores_to_resample,
1114
  aggregation_function,
1115
  ) = self._set_up_group_mean_aggregation(
1116
+ instances,
1117
+ reduction_params,
1118
+ reduction_fields,
1119
+ global_score,
1120
  )
1121
  else:
1122
  raise ValueError(
 
1191
  instance_score["score_name"] = self.main_score
1192
  if "score" not in instance:
1193
  instance["score"] = {"global": {}, "instance": {}}
1194
+ if "global" not in instance["score"]:
1195
+ instance["score"]["global"] = {}
1196
+ if "instance" not in instance["score"]:
1197
+ instance["score"]["instance"] = {}
1198
 
1199
  instance["score"]["instance"].update(
1200
  self._add_score_prefixes_to_score_dict_and_check_against_existing_scores(
1201
  instance_score, instance["score"]["instance"]
1202
  )
1203
  )
 
1204
  instances.append(instance)
1205
 
1206
  return instances
 
1210
  instances: List[dict],
1211
  score_names: List[str],
1212
  group_aggregation_func,
1213
+ prepend_score_prefix: bool,
1214
+ global_score: dict,
1215
+ aggregation_function_name: str,
1216
  ):
1217
  """Group scores by the group_id and subgroup_type fields of each instance, and compute group_aggregation_func by group.
1218
 
 
1224
  callable function returns a single score for the group
1225
  prepend_score_prefix: if True - prepend the score_prefix to the score names in the returned dicts. Set to False
1226
  if down the stream such a prepending is expected.
1227
+ global_score: the global score being built up. It will be filled here with the number of instances per group, and with the group scores.
1228
+ aggregation_function_name: used to annotate the groups' global scores.
1229
 
1230
  Returns:
1231
  List of dicts, each corresponding to a group of instances (defined by 'group_id'),
 
1260
  ]
1261
  )
1262
 
1263
+ # count the instances in each group and subgroup.
1264
+ # Each instance goes into group_to_instances per each score_name.
1265
+ # So we count over the first score_name only
1266
+ for group_key in group_to_instance_scores:
1267
+ if group_key not in global_score:
1268
+ global_score[group_key] = {}
1269
+ global_score[group_key]["num_of_instances"] = sum(
1270
+ [
1271
+ len(
1272
+ group_to_instance_scores[group_key][score_names[0]][
1273
+ subgroup_type
1274
+ ]
1275
+ )
1276
+ for subgroup_type in group_to_instance_scores[group_key][
1277
+ score_names[0]
1278
+ ]
1279
+ ]
1280
+ )
1281
+
1282
  # if group_aggregation_func expects a subgroup-types score dict, pass it; otherwise pass the default type list of scores
1283
+ to_return = [
1284
  {
1285
  "score": {
1286
  "instance": {
 
1301
  ) # sorted for consistency
1302
  ]
1303
 
1304
+ # update each group section in global_score
1305
+ for i, group_name in enumerate(sorted(group_to_instance_scores.keys())):
1306
+ global_score[group_name].update(
1307
+ {
1308
+ aggregation_function_name + "_" + k: v
1309
+ for k, v in to_return[i]["score"]["instance"].items()
1310
+ }
1311
+ )
1312
+
1313
+ return to_return
1314
+
1315
  def _set_up_group_mean_aggregation(
1316
+ self,
1317
+ instances,
1318
+ reduction_params,
1319
+ reduction_fields,
1320
+ global_score,
1321
  ):
1322
+ aggregation_function_name = str(reduction_params["agg_func"][0])
1323
  group_aggregation_func = reduction_params["agg_func"][1]
1324
  # if treat groups as units
1325
  do_resample_as_group = reduction_params["agg_func"][2]
 
1327
  # pass the group aggregate---not instance---scores to resample as usual
1328
  aggregation_function = self.average_item_scores
1329
  scores_to_resample = self.get_group_scores(
1330
+ instances=instances,
1331
+ score_names=reduction_fields,
1332
+ group_aggregation_func=group_aggregation_func,
1333
+ prepend_score_prefix=True,
1334
+ global_score=global_score,
1335
+ aggregation_function_name=aggregation_function_name,
1336
  )
1337
  else:
1338
  # pass the instance scores to resample, and calculate the group aggregation on the resamplings
 
1344
  group_aggregation_func=group_aggregation_func,
1345
  ):
1346
  group_scores = self.get_group_scores(
1347
+ instances=instances,
1348
+ score_names=[field_name],
1349
+ group_aggregation_func=group_aggregation_func,
1350
+ prepend_score_prefix=False,
1351
+ global_score=global_score,
1352
+ aggregation_function_name=aggregation_function_name,
1353
  )
1354
  return nan_mean(
1355
  [group["score"]["instance"][field_name] for group in group_scores]
 
1387
  reduction_map = {"mean": ["anls"]}
1388
  prediction_type = Any # string representation is compared
1389
 
1390
+ @staticmethod
1391
+ @lru_cache(maxsize=10000)
1392
+ def preprocess_text(text):
1393
+ return " ".join(text.strip().lower().split()), len(text.upper())
1394
+
1395
+ def distance(self, prediction, reference):
1396
+ processed_reference, len_reference = self.preprocess_text(reference)
1397
+ processed_prediction, len_prediction = self.preprocess_text(prediction)
1398
+
1399
+ dist = self.levenshtein_distance(processed_reference, processed_prediction)
1400
+ length = max(len_reference, len_prediction)
1401
+ return 0.0 if length == 0 else float(dist) / float(length)
1402
+
1403
  def compute(
1404
  self,
1405
  references: List[Any],
 
1409
  ) -> dict:
1410
  """ANLS image-text accuracy metric."""
1411
  values = []
1412
+ for reference in references:
1413
+ values.append(self.distance(prediction, reference))
 
 
 
 
 
 
 
1414
 
1415
  question_result = 1.0 - min(values)
1416
 
1417
  if question_result < threshold:
1418
  question_result = 0.0
1419
+
1420
  result = {}
1421
  result["score"] = question_result
1422
  result[self.main_score] = question_result
 
1424
  return result
1425
 
1426
  @staticmethod
1427
+ @lru_cache(maxsize=10000)
1428
  def levenshtein_distance(s1, s2):
1429
  if len(s1) > len(s2):
1430
  s1, s2 = s2, s1
 
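Note: a small worked check of the normalized edit distance computed by the helpers above (the strings and numbers are illustrative):

.. code-block:: python

    prediction, reference = "york", "new york"
    # both strings are already in normalized form; the edit distance is 4 (dropping "new ")
    dist, length = 4, max(len(reference), len(prediction))  # 4, 8
    normalized = dist / length                              # 0.5
    anls_for_this_pair = 1.0 - normalized                   # 0.5, kept unless it falls below the metric's threshold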
1606
  ), "Must define at most one of postpreprocess_steps (which is deprecated) and postprocess_steps (to be used from now on)"
1607
  if has_postpreprocess:
1608
  self.postprocess_steps = self.postpreprocess_steps
1609
+ self.prepare_score = SequentialOperator(
1610
+ steps=[
1611
+ Copy(
1612
+ field=f"score/instance/{self.metric._add_score_prefix(self.main_score)}",
1613
+ to_field="score/instance/score",
1614
+ ),
1615
+ Copy(
1616
+ field=f"score/global/{self.metric._add_score_prefix(self.main_score)}",
1617
+ to_field="score/global/score",
1618
+ ),
1619
+ Copy(
1620
+ field=f"score/global/{self.metric._add_score_prefix(self.main_score)}_ci_low",
1621
+ to_field="score/global/score_ci_low",
1622
+ not_exist_do_nothing=True,
1623
+ ),
1624
+ Copy(
1625
+ field=f"score/global/{self.metric._add_score_prefix(self.main_score)}_ci_high",
1626
+ to_field="score/global/score_ci_high",
1627
+ not_exist_do_nothing=True,
1628
+ ),
1629
+ Set(
1630
+ fields={
1631
+ "score/instance/score_name": self.metric._add_score_prefix(
1632
+ self.main_score
1633
+ )
1634
+ }
1635
+ ),
1636
+ Set(
1637
+ fields={
1638
+ "score/global/score_name": self.metric._add_score_prefix(
1639
+ self.main_score
1640
+ )
1641
+ }
1642
+ ),
1643
  ],
1644
  )
1645
 
 
1693
 
1694
  def prepare(self):
1695
  super().prepare()
1696
+ import evaluate
1697
+
1698
  self.metric = evaluate.load(
1699
  self.hf_metric_name, experiment_id=self.experiment_id
1700
  )
 
1769
 
1770
  def prepare(self):
1771
  super().prepare()
1772
+ import evaluate
1773
+
1774
  self.metric = evaluate.load(
1775
  self.hf_metric_name, experiment_id=str(uuid.uuid4())
1776
  )
 
1817
 
1818
  def prepare(self):
1819
  super().prepare()
1820
+ import evaluate
1821
+
1822
  self.metric = evaluate.load(
1823
  self.hf_metric_name, experiment_id=str(uuid.uuid4())
1824
  )
 
1898
 
1899
  def prepare(self):
1900
  super().prepare()
1901
+ import evaluate
1902
+
1903
  self._metric = evaluate.load(self.metric, experiment_id=str(uuid.uuid4()))
1904
 
1905
  def get_str_id(self, str):
 
1959
  _metric = None
1960
  metric = "f1"
1961
  single_reference_per_prediction = True
1962
+ ci_scores = [main_score, "f1_binary_neg"]
1963
  _requirements_list: List[str] = ["sklearn"]
1964
 
1965
  def prepare(self):
 
2177
 
2178
  def prepare(self):
2179
  super().prepare()
2180
+ import evaluate
2181
+
2182
  self._metric = evaluate.load(
2183
  self.metric, "multilabel", experiment_id=str(uuid.uuid4())
2184
  )
 
3148
  class LlamaIndexLLMMetric(InstanceMetric):
3149
  model_name: str = ""
3150
  main_score: str = ""
3151
+ prediction_type = str
3152
  reduction_map: Dict[str, List[str]] = None
3153
  openai_models: List[str] = ["gpt-3.5-turbo"]
3154
  anthropic_models: List[
 
3794
  (recall_at_k, "recall"),
3795
  (match_at_k, "match"),
3796
  ]:
3797
+ measure_array[0] = 0.0 # to support cases where the prediction is empty.
3798
  max_k = max(measure_array.keys())
3799
  for k in self.k_list:
3800
  result[self.score_name(measure_name, k)] = measure_array[min(k, max_k)]
 
3841
  remotely (pre and post processing steps in the MetricPipeline will be computed locally).
3842
  """
3843
  local_inner_metric = metric_pipeline.metric
3844
+ metric_pipeline = deep_copy(
3845
  metric_pipeline
3846
  ) # To avoid unintentional changes to the catalog contents
3847
  metric_pipeline.metric = RemoteMetric(
 
4492
  main_score = "max_f1_binary"
4493
  single_reference_per_prediction = True
4494
  average = None
4495
+ ci_scores = [main_score, "max_f1_binary_neg"]
4496
 
4497
  def compute(
4498
  self,
 
4916
  "spacy": "Please pip install spacy",
4917
  }
4918
 
4919
+ def load_spacy(self):
 
4920
  import spacy
4921
 
4922
+ self.nlp = spacy.load(
4923
+ "en_core_web_sm", disable=["tagger", "parser", "ner", "lemmatizer"]
4924
+ )
4925
+
4926
+ def prepare(self):
4927
+ super().prepare()
4928
  try:
4929
+ self.load_spacy()
4930
  except OSError:
4931
  from spacy.cli import download
4932
 
4933
  download("en_core_web_sm")
4934
+ self.load_spacy()
4935
 
4936
  def compute(
4937
  self,
 
5077
  )
5078
  score = ensemble_model.predict([prediction_lst])
5079
  return score.tolist()[0]
5080
+
5081
+
5082
+ class PredictionLength(InstanceMetric):
5083
+ """Returns the length of the prediction."""
5084
+
5085
+ main_score = "prediction_length"
5086
+ reduction_map = {"mean": ["prediction_length"]}
5087
+ prediction_type = str
5088
+ single_reference_per_prediction = True
5089
+
5090
+ def compute(
5091
+ self,
5092
+ references: List[str],
5093
+ prediction: str,
5094
+ task_data: List[Dict],
5095
+ ) -> dict:
5096
+ return {self.main_score: [len(prediction)], "score_name": self.main_score}
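Note: the new metric can be attached like any other instance metric; a short sketch:

.. code-block:: python

    from unitxt.metrics import PredictionLength

    length_metric = PredictionLength()
    # Per instance: the character length of the prediction.
    # Globally: the mean length over all instances, per reduction_map={"mean": ["prediction_length"]}.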
operators.py CHANGED
@@ -39,7 +39,6 @@ General Operators List:
39
  ------------------------
40
  """
41
 
42
- import copy
43
  import operator
44
  import uuid
45
  import warnings
@@ -82,14 +81,19 @@ from .operator import (
82
  StreamOperator,
83
  )
84
  from .random_utils import new_random_generator
85
- from .settings_utils import get_constants, get_settings
86
- from .stream import DynamicStream, Stream
87
  from .text_utils import nested_tuple_to_string
88
  from .type_utils import isoftype
89
- from .utils import deepcopy, flatten_dict
 
 
 
 
 
 
90
 
91
  settings = get_settings()
92
- constants = get_constants()
93
 
94
 
95
  class FromIterables(StreamInitializerOperator):
@@ -132,8 +136,8 @@ class MapInstanceValues(InstanceOperator):
132
  it maps values of instances in a stream using predefined mappers.
133
 
134
  Attributes:
135
- mappers (Dict[str, Dict[str, str]]): The mappers to use for mapping instance values.
136
- Keys are the names of the fields to be mapped, and values are dictionaries
137
  that define the mapping from old values to new values.
138
  strict (bool): If True, the mapping is applied strictly. That means if a value
139
  does not exist in the mapper, it will raise a KeyError. If False, values
@@ -203,13 +207,12 @@ class MapInstanceValues(InstanceOperator):
203
 
204
  def get_mapped_value(self, instance, key, mapper, val):
205
  val_as_str = str(val) # make sure the value is a string
206
- if self.strict and (val_as_str not in mapper):
 
 
207
  raise KeyError(
208
  f"value '{val}' in instance '{instance}' is not found in mapper '{mapper}', associated with field '{key}'."
209
  )
210
- # By default deep copy the value in mapper to avoid shared modifications
211
- if val_as_str in mapper:
212
- return deepcopy(mapper[val_as_str])
213
  return val
214
 
215
 
@@ -269,7 +272,7 @@ class Set(InstanceOperator):
269
  ) -> Dict[str, Any]:
270
  for key, value in self.fields.items():
271
  if self.use_deepcopy:
272
- value = deepcopy(value)
273
  dict_set(instance, key, value)
274
  return instance
275
 
@@ -318,6 +321,13 @@ class SelectFields(InstanceOperator):
318
  return new_instance
319
 
320
 
 
 
 
 
 
 
 
321
  class InstanceFieldOperator(InstanceOperator):
322
  """A general stream instance operator that processes the values of a field (or multiple ones).
323
 
@@ -348,6 +358,7 @@ class InstanceFieldOperator(InstanceOperator):
348
  process_every_value: bool = False
349
  get_default: Any = None
350
  not_exist_ok: bool = False
 
351
 
352
  def verify(self):
353
  super().verify()
@@ -429,19 +440,18 @@ class InstanceFieldOperator(InstanceOperator):
429
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
430
  ) -> Dict[str, Any]:
431
  self.verify_field_definition()
432
- # Need to deep copy instance, because when assigning two dictionary fields,
433
- # dict_set() the target field dictionary fields.
434
- # This means that if this target field was assigned to another field before,
435
- # the field is updated as well.
436
- instance = deepcopy(instance)
437
  for from_field, to_field in self._field_to_field:
438
  try:
439
  old_value = dict_get(
440
  instance,
441
  from_field,
442
- default=self.get_default,
443
- not_exist_ok=self.not_exist_ok,
444
  )
 
 
 
 
445
  except Exception as e:
446
  raise ValueError(
447
  f"Failed to get '{from_field}' from {instance} due to : {e}"
@@ -476,6 +486,13 @@ class FieldOperator(InstanceFieldOperator):
476
  pass
477
 
478
 
 
 
 
 
 
 
 
479
  class Rename(FieldOperator):
480
  """Renames fields.
481
 
@@ -643,7 +660,9 @@ class ListFieldValues(InstanceOperator):
643
  values = []
644
  for field_name in self.fields:
645
  values.append(dict_get(instance, field_name))
646
- instance[self.to_field] = values
 
 
647
  return instance
648
 
649
 
@@ -680,7 +699,7 @@ class ZipFieldValues(InstanceOperator):
680
  zipped = zip_longest(*values)
681
  else:
682
  zipped = zip(*values)
683
- instance[self.to_field] = list(zipped)
684
  return instance
685
 
686
 
@@ -847,14 +866,15 @@ class Copy(FieldOperator):
847
 
848
  """
849
 
850
- use_deep_copy: bool = True
851
-
852
  def process_value(self, value: Any) -> Any:
853
- if self.use_deep_copy:
854
- return copy.deepcopy(value)
855
  return value
856
 
857
 
 
 
 
 
 
858
  @deprecation(version="2.0.0", alternative=Copy)
859
  class CopyFields(Copy):
860
  pass
@@ -1022,7 +1042,7 @@ class ArtifactFetcherMixin:
1022
  if artifact_identifier not in cls.cache:
1023
  artifact, artifactory = fetch_artifact(artifact_identifier)
1024
  cls.cache[artifact_identifier] = artifact
1025
- return copy.deepcopy(cls.cache[artifact_identifier])
1026
 
1027
 
1028
  class ApplyOperatorsField(InstanceOperator):
@@ -1602,7 +1622,23 @@ class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
1602
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
1603
  from .metrics import Metric
1604
 
1605
- first_instance = stream.peek()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1606
 
1607
  metric_names = first_instance.get(self.metric_field, [])
1608
  if not metric_names:
@@ -1619,16 +1655,6 @@ class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
1619
  # by the first listed metric (as desired).
1620
  metric_names = list(reversed(metric_names))
1621
 
1622
- # Workaround: The metric/MetricPipeline modifies the stream itself, sometimes making it incompatible
1623
- # for further metrics' processing, instead of just modifying the score field.
1624
- # Here we keep all the fields besides the score, and restore them after the metric finishes.
1625
- first_instance = stream.peek()
1626
- keys_to_restore = set(first_instance.keys()).difference({"score"})
1627
- multi_stream = MultiStream({stream_name: stream})
1628
- multi_stream = CopyFields(
1629
- field_to_field={k: f"{k}_orig" for k in keys_to_restore}
1630
- )(multi_stream)
1631
-
1632
  for metric_name in metric_names:
1633
  metric = self.get_artifact(metric_name)
1634
  assert isinstance(
@@ -1637,17 +1663,23 @@ class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
1637
 
1638
  if not self.calc_confidence_intervals:
1639
  metric.disable_confidence_interval_calculation()
1640
-
 
 
 
 
 
 
 
1641
  multi_stream = metric(multi_stream)
1642
- multi_stream = CopyFields(
1643
- field_to_field={f"{k}_orig": k for k in keys_to_restore}
1644
- )(multi_stream)
 
 
 
1645
 
1646
- multi_stream = RemoveFields(fields=[f"{k}_orig" for k in keys_to_restore])(
1647
- multi_stream
1648
- )
1649
- stream = multi_stream[stream_name]
1650
- yield from stream
1651
 
1652
 
1653
  class MergeStreams(MultiStreamOperator):
@@ -2066,7 +2098,7 @@ class DuplicateInstances(StreamOperator):
2066
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
2067
  for instance in stream:
2068
  for idx in range(self.num_duplications):
2069
- duplicate = deepcopy(instance)
2070
  if self.duplication_index_field:
2071
  duplicate.update({self.duplication_index_field: idx})
2072
  yield duplicate
 
  ------------------------
  """

  import operator
  import uuid
  import warnings

  StreamOperator,
  )
  from .random_utils import new_random_generator
+ from .settings_utils import get_settings
+ from .stream import DynamicStream, ListStream, Stream
  from .text_utils import nested_tuple_to_string
  from .type_utils import isoftype
+ from .utils import (
+ deep_copy,
+ flatten_dict,
+ recursive_copy,
+ recursive_shallow_copy,
+ shallow_copy,
+ )

  settings = get_settings()


  class FromIterables(StreamInitializerOperator):

  it maps values of instances in a stream using predefined mappers.

  Attributes:
+ mappers (Dict[str, Dict[str, Any]]): The mappers to use for mapping instance values.
+ Keys are the names of the fields to undergo mapping, and values are dictionaries
  that define the mapping from old values to new values.
  strict (bool): If True, the mapping is applied strictly. That means if a value
  does not exist in the mapper, it will raise a KeyError. If False, values

  def get_mapped_value(self, instance, key, mapper, val):
  val_as_str = str(val) # make sure the value is a string
+ if val_as_str in mapper:
+ return recursive_copy(mapper[val_as_str])
+ if self.strict:
  raise KeyError(
  f"value '{val}' in instance '{instance}' is not found in mapper '{mapper}', associated with field '{key}'."
  )
  return val

  ) -> Dict[str, Any]:
  for key, value in self.fields.items():
  if self.use_deepcopy:
+ value = deep_copy(value)
  dict_set(instance, key, value)
  return instance

  return new_instance


+ class DefaultPlaceHolder:
+ pass
+
+
+ default_place_holder = DefaultPlaceHolder()
+
+
  class InstanceFieldOperator(InstanceOperator):
  """A general stream instance operator that processes the values of a field (or multiple ones).

  process_every_value: bool = False
  get_default: Any = None
  not_exist_ok: bool = False
+ not_exist_do_nothing: bool = False

  def verify(self):
  super().verify()

  self, instance: Dict[str, Any], stream_name: Optional[str] = None
  ) -> Dict[str, Any]:
  self.verify_field_definition()
  for from_field, to_field in self._field_to_field:
  try:
  old_value = dict_get(
  instance,
  from_field,
+ default=default_place_holder,
+ not_exist_ok=self.not_exist_ok or self.not_exist_do_nothing,
  )
+ if old_value is default_place_holder:
+ if self.not_exist_do_nothing:
+ return instance
+ old_value = self.get_default
  except Exception as e:
  raise ValueError(
  f"Failed to get '{from_field}' from {instance} due to : {e}"

  pass


+ class MapValues(FieldOperator):
+ mapping: Dict[str, str]
+
+ def process_value(self, value: Any) -> Any:
+ return self.mapping[str(value)]
+
+
  class Rename(FieldOperator):
  """Renames fields.

  values = []
  for field_name in self.fields:
  values.append(dict_get(instance, field_name))
+
+ dict_set(instance, self.to_field, values)
+
  return instance

  zipped = zip_longest(*values)
  else:
  zipped = zip(*values)
+ dict_set(instance, self.to_field, list(zipped))
  return instance

  """

  def process_value(self, value: Any) -> Any:
  return value


+ class RecursiveCopy(FieldOperator):
+ def process_value(self, value: Any) -> Any:
+ return recursive_copy(value)
+
+
  @deprecation(version="2.0.0", alternative=Copy)
  class CopyFields(Copy):
  pass

  if artifact_identifier not in cls.cache:
  artifact, artifactory = fetch_artifact(artifact_identifier)
  cls.cache[artifact_identifier] = artifact
+ return shallow_copy(cls.cache[artifact_identifier])


  class ApplyOperatorsField(InstanceOperator):

  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
  from .metrics import Metric

+ # Number of instances in input stream is assumed to be small. This is why
+ # each metric consumes all of them and lays them in its main memory, and even generates
+ # some 1000 copies thereof for the sake of CI.
+ # So we start with deep copying here, to make a 'frozen' status of the stream, having
+ # passed the preprocess_steps of the task, and inference, and now getting to be evaluated,
+ # a frozen status to be fed into each of the metrics listed in metric_field,
+ # so that the evaluation of one does not affect the evaluation of another
+ # (typically, affecting via change of instance as part of
+ # preprocess_steps of MetricPipeline, as illustrated in docs/adding_metrics/Using Metric Pipelines).
+
+ instances_upon_entrance_to_metrics_evaluations = []
+ for instance in stream:
+ instances_upon_entrance_to_metrics_evaluations.append(
+ recursive_copy(instance)
+ )
+
+ first_instance = instances_upon_entrance_to_metrics_evaluations[0]

  metric_names = first_instance.get(self.metric_field, [])
  if not metric_names:

  # by the first listed metric (as desired).
  metric_names = list(reversed(metric_names))

  for metric_name in metric_names:
  metric = self.get_artifact(metric_name)
  assert isinstance(

  if not self.calc_confidence_intervals:
  metric.disable_confidence_interval_calculation()
+ multi_stream = MultiStream(
+ {
+ "tmp": ListStream(
+ instances_list=instances_upon_entrance_to_metrics_evaluations,
+ copying=True, # ensures deep copy when iterating over instances
+ )
+ }
+ )
  multi_stream = metric(multi_stream)
+ for evaluated_instance, freezed_instance in zip(
+ multi_stream["tmp"], instances_upon_entrance_to_metrics_evaluations
+ ):
+ freezed_instance["score"] = recursive_shallow_copy(
+ evaluated_instance["score"]
+ )

+ yield from instances_upon_entrance_to_metrics_evaluations


  class MergeStreams(MultiStreamOperator):

  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
  for instance in stream:
  for idx in range(self.num_duplications):
+ duplicate = recursive_shallow_copy(instance)
  if self.duplication_index_field:
  duplicate.update({self.duplication_index_field: idx})
  yield duplicate
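Note on the ApplyMetric change above: the old stream-restoring workaround (CopyFields/RemoveFields of "*_orig" fields) is replaced by freezing the instances once and handing each metric its own copy, with only the resulting score merged back. A minimal standalone sketch of that copy-isolation pattern follows (plain Python, not the unitxt API; noisy_metric and the field names are invented for illustration):

    import copy

    def noisy_metric(instances):
        # a metric that mutates its inputs while scoring, as a MetricPipeline's
        # preprocess_steps might
        for inst in instances:
            inst["prediction"] = inst["prediction"].lower()
            inst["score"] = {"accuracy": float(inst["prediction"] == inst["reference"])}
        return instances

    frozen = [copy.deepcopy(i) for i in [{"prediction": "YES", "reference": "yes"}]]

    for metric in [noisy_metric, noisy_metric]:
        work_copy = [copy.deepcopy(i) for i in frozen]   # each metric sees the frozen state
        for frozen_inst, scored in zip(frozen, metric(work_copy)):
            frozen_inst["score"] = scored["score"]       # only the score flows back

    print(frozen[0]["prediction"])  # still "YES": the metric's mutation did not leak
    print(frozen[0]["score"])       # {'accuracy': 1.0}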
processors.py CHANGED
@@ -2,9 +2,12 @@ import ast
  import copy
  import json
  import re
  from difflib import get_close_matches
  from typing import Any, Dict

  from .deprecation_utils import deprecation
  from .operator import MultiStreamOperator
  from .operators import FieldOperator, InstanceFieldOperator
@@ -20,9 +23,9 @@ class PostProcess(MultiStreamOperator):

  def prepare(self):
  super().prepare()
- self.prediction_operator = copy.deepcopy(self.operator)
  self.prediction_operator.field = "prediction"
- self.references_operator = copy.deepcopy(self.operator)
  self.references_operator.field = "references"
  self.references_operator.process_every_value = True
  self.references_operator.dont_apply_to_streams = [constants.inference_stream]
@@ -315,3 +318,75 @@ class ExtractArenaHardNumericalJudgment(FieldOperator):

  except:
  return 0
  import copy
  import json
  import re
+ import string
  from difflib import get_close_matches
  from typing import Any, Dict

+ import numpy as np
+
  from .deprecation_utils import deprecation
  from .operator import MultiStreamOperator
  from .operators import FieldOperator, InstanceFieldOperator

  def prepare(self):
  super().prepare()
+ self.prediction_operator = copy.copy(self.operator)
  self.prediction_operator.field = "prediction"
+ self.references_operator = copy.copy(self.operator)
  self.references_operator.field = "references"
  self.references_operator.process_every_value = True
  self.references_operator.dont_apply_to_streams = [constants.inference_stream]

  except:
  return 0
+
+
+ class InferDictsToBinaryLogprobs(FieldOperator):
+ neg_class_name: str
+ pos_class_name: str
+
+ take_logprobs_from_end: bool = False
+ num_logprobs_to_take: int = 3
+ min_probability_mass = 0.0001
+
+ def verify(self):
+ super().verify()
+ if (
+ self.neg_class_name.lower() in self.pos_class_name.lower()
+ or self.pos_class_name.lower() in self.neg_class_name.lower()
+ ):
+ raise ValueError(
+ f"""Class names in {self.__class__.__name__} should not overlap, got "{self.pos_class_name}" and "{self.neg_class_name}"""
+ )
+
+ def process_value(self, obj: Any) -> Any:
+ for i in self.get_token_range(obj):
+ try:
+ pos_probs, neg_probs = self.get_pos_neg_probs(pred_dict=obj[i])
+ if pos_probs or neg_probs:
+ sum_probs = sum(pos_probs) + sum(neg_probs)
+ if sum_probs > self.min_probability_mass:
+ return sum(pos_probs) / sum_probs
+ except:
+ pass
+ return 0
+
+ def get_pos_neg_probs(self, pred_dict):
+ token_logprobs = pred_dict["top_tokens"]
+
+ pos_and_neg_probs = []
+ for class_name in [self.pos_class_name, self.neg_class_name]:
+ # We need to capture different variants of model behavior and tokenizers, for example with opening space,
+ # punctuation etc. but avoid longer words that contain the class name.
+ # For example, for class "yes" we would capture "YES," and " Yes" but not "yesterday".
+ name_regex = re.compile(
+ rf"(\W|Ġ|_)*{class_name}(\W|Ġ|_)*", flags=re.IGNORECASE
+ )
+ class_probs = [
+ np.exp(d["logprob"])
+ for d in token_logprobs
+ if name_regex.fullmatch(d["text"])
+ ]
+ pos_and_neg_probs.append(class_probs)
+ return pos_and_neg_probs
+
+ def get_token_range(self, obj: Any) -> range:
+ n_tokens = min([self.num_logprobs_to_take, len(obj)])
+ if self.take_logprobs_from_end:
+ return range(-1, -(n_tokens + 1), -1)
+ return range(n_tokens)
+
+
+ class RemoveArticles(FieldOperator):
+ def process_value(self, text: Any) -> Any:
+ return re.sub(r"\b(a|an|the)\b", " ", text)
+
+
+ class RemovePunctuations(FieldOperator):
+ def process_value(self, text: Any) -> Any:
+ puncs_to_exclude = set(string.punctuation)
+ return "".join(c for c in text if c not in puncs_to_exclude)
+
+
+ class FixWhiteSpace(FieldOperator):
+ def process_value(self, text: Any) -> Any:
+ return " ".join(text.split())
settings_utils.py CHANGED
@@ -147,6 +147,7 @@ if Settings.is_uninitilized():
  settings.skip_artifacts_prepare_and_verify = (bool, False)
  settings.data_classification_policy = None
  settings.mock_inference_mode = (bool, False)

  if Constants.is_uninitilized():
  constants = Constants()

  settings.skip_artifacts_prepare_and_verify = (bool, False)
  settings.data_classification_policy = None
  settings.mock_inference_mode = (bool, False)
+ settings.disable_hf_datasets_cache = (bool, True)

  if Constants.is_uninitilized():
  constants = Constants()
split_utils.py CHANGED
@@ -226,7 +226,12 @@ def rename_split(input_streams: Dict[str, Stream], mapping: Dict[str, str]):
  dict: A dictionary containing the generated new streams, where each key is the name
  of the new stream and the value is a generator representing the stream.
  """
- return {mapping.get(key, key): val for key, val in input_streams.items()}

  def random_mix_generator(

  dict: A dictionary containing the generated new streams, where each key is the name
  of the new stream and the value is a generator representing the stream.
  """
+ new_streams = {}
+ for key, val in mapping.items():
+ if key not in input_streams:
+ raise ValueError("Wrong stream name")
+ new_streams[val] = input_streams.pop(key)
+ return {**input_streams, **new_streams}

  def random_mix_generator(
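The rewritten rename_split above now fails loudly when a split named in the mapping is absent, instead of silently passing streams through unrenamed. A quick usage sketch of the new contract (the function is restated here only to make the demo runnable on its own):

    def rename_split(input_streams, mapping):
        new_streams = {}
        for key, val in mapping.items():
            if key not in input_streams:
                raise ValueError("Wrong stream name")
            new_streams[val] = input_streams.pop(key)
        return {**input_streams, **new_streams}

    streams = {"validation": [1, 2], "train": [3]}
    print(rename_split(streams, {"validation": "test"}))   # {'train': [3], 'test': [1, 2]}
    try:
        rename_split({"train": [3]}, {"dev": "validation"})
    except ValueError as e:
        print(e)                                            # Wrong stream name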
splitters.py CHANGED
@@ -16,7 +16,7 @@ from .split_utils import (
  )
  from .stream import EmptyStreamError, FaultyStreamError, MultiStream
  from .type_utils import isoftype
- from .utils import deepcopy


  class Splitter(MultiStreamOperator):
@@ -353,7 +353,9 @@ class Sample(InstanceOperatorWithMultiStreamAccess):
  sample_size = self.get_sample_size(instance)
  try:
  if self.local_cache is None:
- self.local_cache = deepcopy(list(multi_stream[self.from_stream]))

  source_stream = self.local_cache
  source_stream = self.sampler.filter_source_by_instance(

  )
  from .stream import EmptyStreamError, FaultyStreamError, MultiStream
  from .type_utils import isoftype
+ from .utils import recursive_shallow_copy


  class Splitter(MultiStreamOperator):

  sample_size = self.get_sample_size(instance)
  try:
  if self.local_cache is None:
+ self.local_cache = recursive_shallow_copy(
+ list(multi_stream[self.from_stream])
+ )

  source_stream = self.local_cache
  source_stream = self.sampler.filter_source_by_instance(
standard.py CHANGED
@@ -249,12 +249,12 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
  def produce(self, task_instances):
  """Use the recipe in production to produce model ready query from standard task instance."""
  self.before_process_multi_stream()
- multi_stream = MultiStream.from_iterables(
- {
- constants.inference_stream: self.production_preprocess(task_instances),
- self.demos_pool_name: self.production_demos_pool(),
- }
- )
  multi_stream = self.inference(multi_stream)
  return list(multi_stream[constants.inference_stream])

  def produce(self, task_instances):
  """Use the recipe in production to produce model ready query from standard task instance."""
  self.before_process_multi_stream()
+ streams = {
+ constants.inference_stream: self.production_preprocess(task_instances),
+ }
+ if self.use_demos:
+ streams[self.demos_pool_name] = self.production_demos_pool()
+ multi_stream = MultiStream.from_iterables(streams)
  multi_stream = self.inference(multi_stream)
  return list(multi_stream[constants.inference_stream])
stream.py CHANGED
@@ -10,7 +10,7 @@ from .dataclass import Dataclass, OptionalField
  from .generator_utils import CopyingReusableGenerator, ReusableGenerator
  from .logging_utils import get_logger
  from .settings_utils import get_settings
- from .utils import deepcopy

  settings = get_settings()
  logger = get_logger()
@@ -40,7 +40,7 @@ class ListStream(Stream):

  def __iter__(self):
  if self.copying:
- return iter(deepcopy(self.instances_list))
  return iter(self.instances_list)

  def peek(self):
@@ -244,7 +244,8 @@ class MultiStream(dict):
  return IterableDatasetDict(
  {
  key: IterableDataset.from_generator(
- self.get_generator, gen_kwargs={"key": key}
  )
  for key in self.keys()
  }

  from .generator_utils import CopyingReusableGenerator, ReusableGenerator
  from .logging_utils import get_logger
  from .settings_utils import get_settings
+ from .utils import recursive_copy

  settings = get_settings()
  logger = get_logger()

  def __iter__(self):
  if self.copying:
+ return iter(recursive_copy(self.instances_list))
  return iter(self.instances_list)

  def peek(self):

  return IterableDatasetDict(
  {
  key: IterableDataset.from_generator(
+ self.get_generator,
+ gen_kwargs={"key": key},
  )
  for key in self.keys()
  }
stream_operators.py CHANGED
@@ -31,6 +31,7 @@ The rest of this section is dedicated for operators that operates on streams.

  """

  from typing import (
  List,
  Literal,
@@ -154,6 +155,7 @@ class DuplicateSplit(MultiStreamOperator):

  def process(self, multi_stream: MultiStream) -> MultiStream:
  assert self.split in multi_stream
- generators = multi_stream
- generators[self.to_split] = generators[self.split]
- return MultiStream(generators)

  """

+ import copy
  from typing import (
  List,
  Literal,

  def process(self, multi_stream: MultiStream) -> MultiStream:
  assert self.split in multi_stream
+ new_stream = copy.deepcopy(multi_stream[self.split])
+ new_stream.set_copying(copying=True)
+ multi_stream[self.to_split] = new_stream
+ return multi_stream
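The DuplicateSplit change above stops aliasing the source split under the new name and instead deep-copies it (and marks it as copying), so consumers of one split cannot mutate the other. A plain-Python sketch of the aliasing hazard being avoided (toy data, not the Stream API):

    import copy

    streams = {"test": [{"x": 1}]}

    # aliasing (old behavior): both names point at the same instances
    streams["validation"] = streams["test"]
    streams["test"][0]["x"] = 999
    print(streams["validation"][0])   # {'x': 999} - the mutation leaked

    # copying (new behavior): the duplicated split is independent
    streams = {"test": [{"x": 1}]}
    streams["validation"] = copy.deepcopy(streams["test"])
    streams["test"][0]["x"] = 999
    print(streams["validation"][0])   # {'x': 1} - isolated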
string_operators.py CHANGED
@@ -87,3 +87,12 @@ class Replace(FieldOperator):

  def process_value(self, value: str) -> str:
  return value.replace(self.old, self.new)

  def process_value(self, value: str) -> str:
  return value.replace(self.old, self.new)
+
+
+ class MapReplace(FieldOperator):
+ mapping: Dict[str, str]
+
+ def process_value(self, value: Any) -> Any:
+ for key, val in self.mapping.items():
+ value = value.replace(key, val)
+ return value
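MapReplace applies its replacements one after another, in the mapping's insertion order, so later pairs see the output of earlier ones. A tiny standalone sketch of that order-dependence (plain Python, mimicking the loop above):

    def map_replace(value, mapping):
        for key, val in mapping.items():
            value = value.replace(key, val)
        return value

    print(map_replace("colour flavour", {"colour": "color", "flavour": "flavor"}))
    # -> "color flavor"
    print(map_replace("ab", {"a": "b", "b": "c"}))
    # -> "cc"  (the "b" produced by the first rule is rewritten by the second)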
struct_data_operators.py CHANGED
@@ -32,7 +32,7 @@ from .operators import FieldOperator, InstanceOperator
  from .random_utils import new_random_generator
  from .serializers import TableSerializer
  from .types import Table
- from .utils import deepcopy


  def shuffle_columns(table: Table, seed=0) -> Table:
@@ -76,7 +76,7 @@ class SerializeTable(ABC, TableSerializer):
  shuffle_columns: bool = False

  def serialize(self, value: Table, instance: Dict[str, Any]) -> str:
- value = deepcopy(value)
  if self.shuffle_columns:
  value = shuffle_columns(table=value, seed=self.seed)

@@ -207,6 +207,12 @@ class SerializeTableAsDFLoader(SerializeTable):

  assert header and rows, "Incorrect input table format"

  # Create a pandas DataFrame
  df = pd.DataFrame(rows, columns=header)

@@ -252,6 +258,59 @@ class SerializeTableAsJson(SerializeTable):
  return json.dumps(output_dict)


  # truncate cell value to maximum allowed length
  def truncate_cell(cell_value, max_len):
  if cell_value is None:
@@ -490,7 +549,7 @@ class ConvertTableColNamesToSequential(FieldOperator):
  """

  def process_value(self, table: Any) -> Any:
- table_input = deepcopy(table)
  return self.replace_header(table_content=table_input)

  # replaces header with sequential column names
@@ -523,7 +582,7 @@ class ShuffleTableRows(FieldOperator):
  """

  def process_value(self, table: Any) -> Any:
- table_input = deepcopy(table)
  return shuffle_rows(table_input)


@@ -544,7 +603,7 @@ class ShuffleTableColumns(FieldOperator):
  """

  def process_value(self, table: Any) -> Any:
- table_input = deepcopy(table)
  return shuffle_columns(table_input)


@@ -658,3 +717,133 @@ class ConstructTableFromRowsCols(InstanceOperator):
  instance[self.to_field] = output_dict

  return instance
  from .random_utils import new_random_generator
  from .serializers import TableSerializer
  from .types import Table
+ from .utils import recursive_copy


  def shuffle_columns(table: Table, seed=0) -> Table:

  shuffle_columns: bool = False

  def serialize(self, value: Table, instance: Dict[str, Any]) -> str:
+ value = recursive_copy(value)
  if self.shuffle_columns:
  value = shuffle_columns(table=value, seed=self.seed)

  assert header and rows, "Incorrect input table format"

+ # Fix duplicate columns, ensuring the first occurrence has no suffix
+ header = [
+ f"{col}_{header[:i].count(col)}" if header[:i].count(col) > 0 else col
+ for i, col in enumerate(header)
+ ]
+
  # Create a pandas DataFrame
  df = pd.DataFrame(rows, columns=header)

  return json.dumps(output_dict)


+ class SerializeTableAsHTML(SerializeTable):
+ """HTML Table Serializer.
+
+ HTML table format used for rendering tables in web pages.
+ Format(Sample):
+ <table>
+ <thead>
+ <tr><th>name</th><th>age</th><th>sex</th></tr>
+ </thead>
+ <tbody>
+ <tr><td>Alice</td><td>26</td><td>F</td></tr>
+ <tr><td>Raj</td><td>34</td><td>M</td></tr>
+ </tbody>
+ </table>
+ """
+
+ # main method that serializes a table.
+ # table_content must be in the prescribed input format.
+ def serialize_table(self, table_content: Dict) -> str:
+ # Extract headers and rows from the dictionary
+ header = table_content.get("header", [])
+ rows = table_content.get("rows", [])
+
+ assert header and rows, "Incorrect input table format"
+
+ # Build the HTML table structure
+ serialized_tbl_str = "<table>\n"
+ serialized_tbl_str += self.process_header(header) + "\n"
+ serialized_tbl_str += self.process_rows(rows) + "\n"
+ serialized_tbl_str += "</table>"
+
+ return serialized_tbl_str.strip()
+
+ # serialize the header into an HTML <thead> section
+ def process_header(self, header: List) -> str:
+ header_html = " <thead>\n <tr>"
+ for col in header:
+ header_html += f"<th>{col}</th>"
+ header_html += "</tr>\n </thead>"
+ return header_html
+
+ # serialize the rows into an HTML <tbody> section
+ def process_rows(self, rows: List[List]) -> str:
+ rows_html = " <tbody>"
+ for row in rows:
+ rows_html += "\n <tr>"
+ for cell in row:
+ rows_html += f"<td>{cell}</td>"
+ rows_html += "</tr>"
+ rows_html += "\n </tbody>"
+ return rows_html
+
+
  # truncate cell value to maximum allowed length
  def truncate_cell(cell_value, max_len):
  if cell_value is None:

  """

  def process_value(self, table: Any) -> Any:
+ table_input = recursive_copy(table)
  return self.replace_header(table_content=table_input)

  # replaces header with sequential column names

  """

  def process_value(self, table: Any) -> Any:
+ table_input = recursive_copy(table)
  return shuffle_rows(table_input)


  """

  def process_value(self, table: Any) -> Any:
+ table_input = recursive_copy(table)
  return shuffle_columns(table_input)


  instance[self.to_field] = output_dict

  return instance
+
+
+ class TransposeTable(FieldOperator):
+ """Transpose a table.
+
+ Sample Input:
+ {
+ "header": ["name", "age", "sex"],
+ "rows": [["Alice", 26, "F"], ["Raj", 34, "M"], ["Donald", 39, "M"]],
+ }
+
+ Sample Output:
+ {
+ "header": [" ", "0", "1", "2"],
+ "rows": [["name", "Alice", "Raj", "Donald"], ["age", 26, 34, 39], ["sex", "F", "M", "M"]],
+ }
+ """
+
+ def process_value(self, table: Any) -> Any:
+ return self.transpose_table(table)
+
+ def transpose_table(self, table: Dict) -> Dict:
+ # Extract the header and rows from the table object
+ header = table["header"]
+ rows = table["rows"]
+
+ # Transpose the table by converting rows as columns and vice versa
+ transposed_header = [" "] + [str(i) for i in range(len(rows))]
+ transposed_rows = [
+ [header[i]] + [row[i] for row in rows] for i in range(len(header))
+ ]
+
+ return {"header": transposed_header, "rows": transposed_rows}
+
+
+ class DuplicateTableRows(FieldOperator):
+ """Duplicates specific rows of a table for the given number of times.
+
+ Args:
+ row_indices (List[int]) - rows to be duplicated
+ times(int) - how many times to duplicate
+ """
+
+ row_indices: List[int] = []
+ times: int = 1
+
+ def process_value(self, table: Any) -> Any:
+ # Extract the header and rows from the table
+ header = table["header"]
+ rows = table["rows"]
+
+ # Duplicate only the specified rows
+ duplicated_rows = []
+ for i, row in enumerate(rows):
+ if i in self.row_indices:
+ duplicated_rows.extend(
+ [row] * self.times
+ ) # Duplicate the selected rows
+ else:
+ duplicated_rows.append(row) # Leave other rows unchanged
+
+ # Return the new table with selectively duplicated rows
+ return {"header": header, "rows": duplicated_rows}
+
+
+ class DuplicateTableColumns(FieldOperator):
+ """Duplicates specific columns of a table for the given number of times.
+
+ Args:
+ column_indices (List[int]) - columns to be duplicated
+ times(int) - how many times to duplicate
+ """
+
+ column_indices: List[int] = []
+ times: int = 1
+
+ def process_value(self, table: Any) -> Any:
+ # Extract the header and rows from the table
+ header = table["header"]
+ rows = table["rows"]
+
+ # Duplicate the specified columns in the header
+ duplicated_header = []
+ for i, col in enumerate(header):
+ if i in self.column_indices:
+ duplicated_header.extend([col] * self.times)
+ else:
+ duplicated_header.append(col)
+
+ # Duplicate the specified columns in each row
+ duplicated_rows = []
+ for row in rows:
+ new_row = []
+ for i, value in enumerate(row):
+ if i in self.column_indices:
+ new_row.extend([value] * self.times)
+ else:
+ new_row.append(value)
+ duplicated_rows.append(new_row)
+
+ # Return the new table with selectively duplicated columns
+ return {"header": duplicated_header, "rows": duplicated_rows}
+
+
+ class InsertEmptyTableRows(FieldOperator):
+ """Inserts empty rows in a table randomly for the given number of times.
+
+ Args:
+ times(int) - how many times to insert
+ """
+
+ times: int = 0
+
+ def process_value(self, table: Any) -> Any:
+ # Extract the header and rows from the table
+ header = table["header"]
+ rows = table["rows"]
+
+ # Insert empty rows at random positions
+ for _ in range(self.times):
+ empty_row = [""] * len(
+ header
+ ) # Create an empty row with the same number of columns
+ insert_pos = random.randint(
+ 0, len(rows)
+ ) # Get a random position to insert the empty row created
+ rows.insert(insert_pos, empty_row)
+
+ # Return the modified table
+ return {"header": header, "rows": rows}
templates.py CHANGED
@@ -210,7 +210,7 @@ class ApplyTemplate(InstanceOperator):
  if self.demos_field not in instance:
  raise ValueError("Demos field is missing.")
  instance[self.demos_field] = [
- self.apply(template, demo_instance, stream_name)
  for demo_instance in instance[self.demos_field]
  ]
  dict_set(instance, "recipe_metadata/template", template)

  if self.demos_field not in instance:
  raise ValueError("Demos field is missing.")
  instance[self.demos_field] = [
+ self.apply(template, demo_instance)
  for demo_instance in instance[self.demos_field]
  ]
  dict_set(instance, "recipe_metadata/template", template)
type_utils.py CHANGED
@@ -4,6 +4,7 @@ import io
  import itertools
  import re
  import typing
  from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union

  from .utils import safe_eval
@@ -810,6 +811,7 @@ class NormalizedType(typing.NamedTuple):
  return f"{self.origin}[{self.args}])"


  def _normalize_args(tps: TypeArgs):
  if isinstance(tps, str):
  return tps
@@ -918,6 +920,7 @@ def _is_origin_subtype_args(
  return _is_normal_subtype(left, right, forward_refs)


  def _is_normal_subtype(
  left: NormalizedType,
  right: NormalizedType,

  import itertools
  import re
  import typing
+ from functools import lru_cache
  from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union

  from .utils import safe_eval

  return f"{self.origin}[{self.args}])"


+ @lru_cache(maxsize=None)
  def _normalize_args(tps: TypeArgs):
  if isinstance(tps, str):
  return tps

  return _is_normal_subtype(left, right, forward_refs)


+ @lru_cache(maxsize=None)
  def _is_normal_subtype(
  left: NormalizedType,
  right: NormalizedType,
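The lru_cache decorators added above rely on the decorated functions' arguments being hashable; NormalizedType is a typing.NamedTuple, so repeated subtype queries over equal (left, right) pairs are answered from the cache instead of being recomputed. A toy sketch of that memoization effect (NT and is_subtype are invented stand-ins, not the real helpers):

    from functools import lru_cache
    from typing import NamedTuple

    class NT(NamedTuple):          # hypothetical stand-in for NormalizedType
        origin: str
        args: tuple

    @lru_cache(maxsize=None)
    def is_subtype(left: NT, right: NT) -> bool:
        return left == right or right.origin == "Any"

    print(is_subtype(NT("list", ("int",)), NT("Any", ())))   # True (computed, cache miss)
    print(is_subtype.cache_info().hits)                       # 0
    print(is_subtype(NT("list", ("int",)), NT("Any", ())))   # True (answered from cache)
    print(is_subtype.cache_info().hits)                       # 1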
utils.py CHANGED
@@ -148,5 +148,88 @@ def import_module_from_file(file_path):
  return module


- def deepcopy(obj):
  return copy.deepcopy(obj)

  return module


+ def deep_copy(obj):
+ """Creates a deep copy of the given object.
+
+ Args:
+ obj: The object to be deep copied.
+
+ Returns:
+ A deep copy of the original object.
+ """
  return copy.deepcopy(obj)
+
+
+ def shallow_copy(obj):
+ """Creates a shallow copy of the given object.
+
+ Args:
+ obj: The object to be shallow copied.
+
+ Returns:
+ A shallow copy of the original object.
+ """
+ return copy.copy(obj)
+
+
+ def recursive_copy(obj, internal_copy=None):
+ """Recursively copies an object with a selective copy method.
+
+ For `list`, `dict`, and `tuple` types, it recursively copies their contents.
+ For other types, it uses the provided `internal_copy` function if available.
+ Objects without a `copy` method are returned as is.
+
+ Args:
+ obj: The object to be copied.
+ internal_copy (callable, optional): The copy function to use for non-container objects.
+ If `None`, objects without a `copy` method are returned as is.
+
+ Returns:
+ The recursively copied object.
+ """
+ # Handle dictionaries
+ if isinstance(obj, dict):
+ return type(obj)(
+ {key: recursive_copy(value, internal_copy) for key, value in obj.items()}
+ )
+
+ # Handle named tuples
+ if isinstance(obj, tuple) and hasattr(obj, "_fields"):
+ return type(obj)(*(recursive_copy(item, internal_copy) for item in obj))
+
+ # Handle tuples and lists
+ if isinstance(obj, (tuple, list)):
+ return type(obj)(recursive_copy(item, internal_copy) for item in obj)
+
+ if internal_copy is None:
+ return obj
+
+ return internal_copy(obj)
+
+
+ def recursive_deep_copy(obj):
+ """Performs a recursive deep copy of the given object.
+
+ This function uses `deep_copy` as the internal copy method for non-container objects.
+
+ Args:
+ obj: The object to be deep copied.
+
+ Returns:
+ A recursively deep-copied version of the original object.
+ """
+ return recursive_copy(obj, deep_copy)
+
+
+ def recursive_shallow_copy(obj):
+ """Performs a recursive shallow copy of the given object.
+
+ This function uses `shallow_copy` as the internal copy method for non-container objects.
+
+ Args:
+ obj: The object to be shallow copied.
+
+ Returns:
+ A recursively shallow-copied version of the original object.
+ """
+ return recursive_copy(obj, shallow_copy)
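A usage sketch of the new copy helpers, showing what recursive_shallow_copy shares versus what deep_copy duplicates. The import path unitxt.utils is an assumption based on the package layout this commit ships:

    from unitxt.utils import deep_copy, recursive_shallow_copy  # assumed import path

    class Blob:
        def __init__(self, payload):
            self.payload = payload

    blob = Blob(payload=[1, 2, 3])
    instance = {"refs": [blob]}

    copied = recursive_shallow_copy(instance)
    print(copied["refs"] is instance["refs"])           # False: containers are rebuilt
    print(copied["refs"][0] is blob)                    # False: leaves go through copy.copy...
    print(copied["refs"][0].payload is blob.payload)    # True: ...so their internals stay shared

    print(deep_copy(instance)["refs"][0].payload is blob.payload)  # False: everything duplicated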
version.py CHANGED
@@ -1 +1 @@
- version = "1.13.1"
+ version = "1.14.0"