Elron committed on
Commit
5ba849c
1 Parent(s): 9245edf

Upload folder using huggingface_hub

Browse files
Files changed (7) hide show
  1. artifact.py +76 -5
  2. augmentors.py +16 -11
  3. catalog.py +34 -5
  4. llm_as_judge.py +7 -1
  5. operator.py +1 -1
  6. operators.py +9 -9
  7. version.py +1 -1
artifact.py CHANGED
@@ -4,6 +4,7 @@ import json
4
  import os
5
  import pkgutil
6
  import re
 
7
  from abc import abstractmethod
8
  from typing import Any, Dict, List, Optional, Tuple, Union, final
9
 
@@ -138,6 +139,12 @@ class Artifact(Dataclass):
138
  )
139
  __id__: str = InternalField(default=None, required=False, also_positional=False)
140
 
 
 
 
 
 
 
141
  data_classification_policy: List[str] = NonPositionalField(
142
  default=None, required=False, also_positional=False
143
  )
@@ -237,6 +244,11 @@ class Artifact(Dataclass):
237
  @classmethod
238
  def load(cls, path, artifact_identifier=None, overwrite_args=None):
239
  d = artifacts_json_cache(path)
 
 
 
 
 
240
  new_artifact = cls.from_dict(d, overwrite_args=overwrite_args)
241
  new_artifact.__id__ = artifact_identifier
242
  return new_artifact
@@ -247,7 +259,8 @@ class Artifact(Dataclass):
247
  return self.__class__.__name__
248
 
249
  def prepare(self):
250
- pass
 
251
 
252
  def verify(self):
253
  pass
@@ -396,6 +409,57 @@ class Artifact(Dataclass):
396
  return instance
397
 
398
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
399
  def get_raw(obj):
400
  if isinstance(obj, Artifact):
401
  return obj._to_raw_dict()
@@ -456,20 +520,27 @@ def fetch_artifact(artifact_rep) -> Tuple[Artifact, Union[AbstractCatalog, None]
456
  (5) Otherwise, check that the artifact representation is a dictionary and build an Artifact object from it.
457
  """
458
  if isinstance(artifact_rep, Artifact):
 
 
459
  return artifact_rep, None
460
 
461
  # If local file
462
  if isinstance(artifact_rep, str) and Artifact.is_artifact_file(artifact_rep):
463
- return Artifact.load(artifact_rep), None
 
 
464
 
465
- # If artifact name in catalog
 
 
466
  if isinstance(artifact_rep, str):
467
  name, _ = separate_inside_and_outside_square_brackets(artifact_rep)
468
  if is_name_legal_for_catalog(name):
469
  catalog, artifact_rep, args = get_catalog_name_and_args(name=artifact_rep)
470
- return catalog.get_with_overwrite(
471
  artifact_rep, overwrite_args=args
472
- ), catalog
 
473
 
474
  # If Json string, first load into dictionary
475
  if isinstance(artifact_rep, str):
 
4
  import os
5
  import pkgutil
6
  import re
7
+ import warnings
8
  from abc import abstractmethod
9
  from typing import Any, Dict, List, Optional, Tuple, Union, final
10
 
 
139
  )
140
  __id__: str = InternalField(default=None, required=False, also_positional=False)
141
 
142
+ # if not None, the artifact is deprecated, and once instantiated, that msg
143
+ # is logged as a warning
144
+ __deprecated_msg__: str = NonPositionalField(
145
+ default=None, required=False, also_positional=False
146
+ )
147
+
148
  data_classification_policy: List[str] = NonPositionalField(
149
  default=None, required=False, also_positional=False
150
  )
 
244
  @classmethod
245
  def load(cls, path, artifact_identifier=None, overwrite_args=None):
246
  d = artifacts_json_cache(path)
247
+ if "artifact_linked_to" in d and d["artifact_linked_to"] is not None:
248
+ # d stands for an ArtifactLink
249
+ artifact_link = ArtifactLink.from_dict(d)
250
+ return artifact_link.load(overwrite_args)
251
+
252
  new_artifact = cls.from_dict(d, overwrite_args=overwrite_args)
253
  new_artifact.__id__ = artifact_identifier
254
  return new_artifact
 
259
  return self.__class__.__name__
260
 
261
  def prepare(self):
262
+ if self.__deprecated_msg__:
263
+ warnings.warn(self.__deprecated_msg__, DeprecationWarning, stacklevel=2)
264
 
265
  def verify(self):
266
  pass
 
409
  return instance
410
 
411
 
412
+ class ArtifactLink(Artifact):
413
+ # the artifact linked to, expressed by its catalog id
414
+ artifact_linked_to: str = Field(default=None, required=True)
415
+
416
+ @classmethod
417
+ def from_dict(cls, d: dict):
418
+ assert isinstance(d, dict), f"argument must be a dictionary, got: d = {d}."
419
+ assert (
420
+ "artifact_linked_to" in d and d["artifact_linked_to"] is not None
421
+ ), f"A non-none field named 'artifact_linked_to' is expected in input argument d, but got: {d}."
422
+ artifact_linked_to = d["artifact_linked_to"]
423
+ # artifact_linked_to is a name of catalog entry
424
+ assert isinstance(
425
+ artifact_linked_to, str
426
+ ), f"'artifact_linked_to' should be a string expressing a name of a catalog entry. Got{artifact_linked_to}."
427
+ msg = d["__deprecated_msg__"] if "__deprecated_msg__" in d else None
428
+ return ArtifactLink(
429
+ artifact_linked_to=artifact_linked_to, __deprecated_msg__=msg
430
+ )
431
+
432
+ def load(self, overwrite_args: dict) -> Artifact:
433
+ # identify the catalog for the artifact_linked_to
434
+ assert (
435
+ self.artifact_linked_to is not None
436
+ ), "'artifact_linked_to' must be non-None in order to load it from the catalog. Currently, it is None."
437
+ assert isinstance(
438
+ self.artifact_linked_to, str
439
+ ), f"'artifact_linked_to' should be a string (expressing a name of a catalog entry). Currently, its type is: {type(self.artifact_linked_to)}."
440
+ needed_catalog = None
441
+ catalogs = list(Catalogs())
442
+ for catalog in catalogs:
443
+ if self.artifact_linked_to in catalog:
444
+ needed_catalog = catalog
445
+
446
+ if needed_catalog is None:
447
+ raise UnitxtArtifactNotFoundError(self.artifact_linked_to, catalogs)
448
+
449
+ path = needed_catalog.path(self.artifact_linked_to)
450
+ d = artifacts_json_cache(path)
451
+ # if needed, follow, in a recursive manner, over multiple links,
452
+ # passing through instantiating of the ArtifactLink-s on the way, triggering
453
+ deprecation warning as needed.
454
+ if "artifact_linked_to" in d and d["artifact_linked_to"] is not None:
455
+ # d stands for an ArtifactLink
456
+ artifact_link = ArtifactLink.from_dict(d)
457
+ return artifact_link.load(overwrite_args)
458
+ new_artifact = Artifact.from_dict(d, overwrite_args=overwrite_args)
459
+ new_artifact.__id__ = self.artifact_linked_to
460
+ return new_artifact
461
+
462
+
463
  def get_raw(obj):
464
  if isinstance(obj, Artifact):
465
  return obj._to_raw_dict()
 
520
  (5) Otherwise, check that the artifact representation is a dictionary and build an Artifact object from it.
521
  """
522
  if isinstance(artifact_rep, Artifact):
523
+ if isinstance(artifact_rep, ArtifactLink):
524
+ return fetch_artifact(artifact_rep.artifact_linked_to)
525
  return artifact_rep, None
526
 
527
  # If local file
528
  if isinstance(artifact_rep, str) and Artifact.is_artifact_file(artifact_rep):
529
+ artifact_to_return = Artifact.load(artifact_rep)
530
+ if isinstance(artifact_rep, ArtifactLink):
531
+ artifact_to_return = fetch_artifact(artifact_to_return.artifact_linked_to)
532
 
533
+ return artifact_to_return, None
534
+
535
+ # if artifact is a name of a catalog entry
536
  if isinstance(artifact_rep, str):
537
  name, _ = separate_inside_and_outside_square_brackets(artifact_rep)
538
  if is_name_legal_for_catalog(name):
539
  catalog, artifact_rep, args = get_catalog_name_and_args(name=artifact_rep)
540
+ artifact_to_return = catalog.get_with_overwrite(
541
  artifact_rep, overwrite_args=args
542
+ )
543
+ return artifact_to_return, catalog
544
 
545
  # If Json string, first load into dictionary
546
  if isinstance(artifact_rep, str):
augmentors.py CHANGED
@@ -80,21 +80,26 @@ class AugmentWhitespace(TextAugmentor):
80
 
81
 
82
  class AugmentPrefixSuffix(TextAugmentor):
83
- r"""Augments the input by prepending and appending randomly selected (typically, whitespace) patterns.
84
 
85
  Args:
86
- prefixes, suffixes (list or dict) : the potential patterns (typically, whitespace) to select from. The dictionary version allows the specification relative weights for the different patterns.
87
- prefix_len, suffix_len (positive int) : The added prefix or suffix will be of a certain length.
88
- remove_existing_whitespaces : Clean any existing leading and trailing whitespaces. The strings made of repetitions of the selected pattern(s) are then prepended and/or appended to the potentially trimmed input.
89
- If only either just prefixes or just suffixes are needed, set the other to None.
 
 
 
 
 
90
 
91
  Examples:
92
- To prepend the input with a prefix made of 4 '\n'-s or '\t'-s, employ
93
- AugmentPrefixSuffix(augment_model_input=True, prefixes=['\n','\t'], prefix_len=4, suffixes = None)
94
- To append the input with a suffix made of 3 '\n'-s or '\t'-s, with triple '\n' suffixes
95
- being preferred over triple '\t', at 2:1 ratio, employ
96
- AugmentPrefixSuffix(augment_model_input=True, suffixes={'\n':2,'\t':1}, suffix_len=3, prefixes = None)
97
- which will append '\n'-s twice as often as '\t'-s.
98
 
99
  """
100
 
 
80
 
81
 
82
  class AugmentPrefixSuffix(TextAugmentor):
83
+ r"""Augments the input by prepending and appending randomly selected patterns (typically, whitespace).
84
 
85
  Args:
86
+ prefixes (list or dict or None): the potential patterns (typically, whitespace) to select prefix from. The dictionary version allows the specification of relative weights for the different patterns. Set to None if not needed (i.e., only suffixes are needed).
87
+
88
+ suffixes (list or dict or None): the potential patterns (typically, whitespace) to select suffix from. The dictionary version allows the specification of relative weights for the different patterns. Set to None if not needed (i.e., only prefixes are needed).
89
+
90
+ prefix_len (positive int): the length of the prefix to be added.
91
+
92
+ suffix_len (positive int): The length of the suffix to be added.
93
+
94
+ remove_existing_whitespaces (bool): Clean any existing leading and trailing whitespaces. The selected pattern(s) are then prepended and/or appended to the potentially trimmed input.
95
 
96
  Examples:
97
+ To prepend the input with a prefix made of 4 ``\n``-s or ``\t``-s, employ
98
+ ``AugmentPrefixSuffix(augment_model_input=True, prefixes=['\n','\t'], prefix_len=4, suffixes = None)``.
99
+
100
+ To append the input with a suffix made of 3 ``\n``-s or ``\t``-s, with ``\n`` being preferred over ``\t``, at 2:1 ratio, employ
101
+ ``AugmentPrefixSuffix(augment_model_input=True, suffixes={'\n':2,'\t':1}, suffix_len=3, prefixes = None)``
102
+ which will append ``\n``-s twice as often as ``\t``-s.
103
 
104
  """
105
 
catalog.py CHANGED
@@ -10,6 +10,7 @@ import requests
10
  from .artifact import (
11
  AbstractCatalog,
12
  Artifact,
 
13
  Catalogs,
14
  get_catalog_name_and_args,
15
  reset_artifacts_json_cache,
@@ -51,7 +52,9 @@ class LocalCatalog(Catalog):
51
  ), f"Artifact with name {artifact_identifier} does not exist"
52
  path = self.path(artifact_identifier)
53
  return Artifact.load(
54
- path, artifact_identifier=artifact_identifier, overwrite_args=overwrite_args
 
 
55
  )
56
 
57
  def __getitem__(self, name) -> Artifact:
@@ -132,10 +135,36 @@ def add_to_catalog(
132
  catalog_path = constants.default_catalog_path
133
  catalog = LocalCatalog(location=catalog_path)
134
  verify_legal_catalog_name(name)
135
- catalog.save_artifact(
136
- artifact, name, overwrite=overwrite, verbose=verbose
137
- ) # remove collection (its actually the dir).
138
- # verify name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
 
141
  @lru_cache(maxsize=None)
 
10
  from .artifact import (
11
  AbstractCatalog,
12
  Artifact,
13
+ ArtifactLink,
14
  Catalogs,
15
  get_catalog_name_and_args,
16
  reset_artifacts_json_cache,
 
52
  ), f"Artifact with name {artifact_identifier} does not exist"
53
  path = self.path(artifact_identifier)
54
  return Artifact.load(
55
+ path,
56
+ artifact_identifier=artifact_identifier,
57
+ overwrite_args=overwrite_args,
58
  )
59
 
60
  def __getitem__(self, name) -> Artifact:
 
135
  catalog_path = constants.default_catalog_path
136
  catalog = LocalCatalog(location=catalog_path)
137
  verify_legal_catalog_name(name)
138
+ catalog.save_artifact(artifact, name, overwrite=overwrite, verbose=verbose)
139
+
140
+
141
+ def add_link_to_catalog(
142
+ artifact_linked_to: str,
143
+ name: str,
144
+ deprecate: bool = False,
145
+ catalog: Catalog = None,
146
+ overwrite: bool = False,
147
+ catalog_path: Optional[str] = None,
148
+ verbose=True,
149
+ ):
150
+ if deprecate:
151
+ deprecated_msg = f"Artifact '{name}' is deprecated. Artifact '{artifact_linked_to}' will be instantiated instead. "
152
+ deprecated_msg += f"In future uses, please reference artifact '{artifact_linked_to}' directly."
153
+ else:
154
+ deprecated_msg = None
155
+
156
+ artifact_link = ArtifactLink(
157
+ artifact_linked_to=artifact_linked_to, __deprecated_msg__=deprecated_msg
158
+ )
159
+
160
+ add_to_catalog(
161
+ artifact=artifact_link,
162
+ name=name,
163
+ catalog=catalog,
164
+ overwrite=overwrite,
165
+ catalog_path=catalog_path,
166
+ verbose=verbose,
167
+ )
168
 
169
 
170
  @lru_cache(maxsize=None)
llm_as_judge.py CHANGED
@@ -297,9 +297,15 @@ class LLMAsJudge(LLMAsJudgeBase):
297
 
298
  def prepare_instances(self, references, predictions, task_data):
299
  input_instances = self._get_input_instances(task_data)
300
- return self._get_instance_for_judge_model(
301
  input_instances, predictions, references
302
  )
 
 
 
 
 
 
303
 
304
 
305
  class TaskBasedLLMasJudge(LLMAsJudgeBase):
 
297
 
298
  def prepare_instances(self, references, predictions, task_data):
299
  input_instances = self._get_input_instances(task_data)
300
+ instances = self._get_instance_for_judge_model(
301
  input_instances, predictions, references
302
  )
303
+ # Copy the data classification policy from the original instance
304
+ for instance, single_task_data in zip(instances, task_data):
305
+ instance["data_classification_policy"] = single_task_data.get(
306
+ "metadata", {}
307
+ ).get("data_classification_policy")
308
+ return instances
309
 
310
 
311
  class TaskBasedLLMasJudge(LLMAsJudgeBase):
operator.py CHANGED
@@ -427,7 +427,7 @@ class InstanceOperator(StreamOperator):
427
  raise e
428
  else:
429
  raise ValueError(
430
- f"Error processing instance '{_index}' from stream '{stream_name}' in {self.__class__.__name__} due to: {e}"
431
  ) from e
432
 
433
  def _process_instance(
 
427
  raise e
428
  else:
429
  raise ValueError(
430
+ f"Error processing instance '{_index}' from stream '{stream_name}' in {self.__class__.__name__} due to the exception above."
431
  ) from e
432
 
433
  def _process_instance(
operators.py CHANGED
@@ -190,7 +190,7 @@ class MapInstanceValues(InstanceOperator):
190
  if value is not None:
191
  if (self.process_every_value is True) and (not isinstance(value, list)):
192
  raise ValueError(
193
- f"'process_every_field' == True is allowed only when all fields which have mappers, i.e., {list(self.mappers.keys())} are lists. Instance = {instance}"
194
  )
195
  if isinstance(value, list) and self.process_every_value:
196
  for i, val in enumerate(value):
@@ -211,7 +211,7 @@ class MapInstanceValues(InstanceOperator):
211
  return recursive_copy(mapper[val_as_str])
212
  if self.strict:
213
  raise KeyError(
214
- f"value '{val}' in instance '{instance}' is not found in mapper '{mapper}', associated with field '{key}'."
215
  )
216
  return val
217
 
@@ -454,7 +454,7 @@ class InstanceFieldOperator(InstanceOperator):
454
  old_value = self.get_default
455
  except Exception as e:
456
  raise ValueError(
457
- f"Failed to get '{from_field}' from {instance} due to : {e}"
458
  ) from e
459
  try:
460
  if self.process_every_value:
@@ -466,7 +466,7 @@ class InstanceFieldOperator(InstanceOperator):
466
  new_value = self.process_instance_value(old_value, instance)
467
  except Exception as e:
468
  raise ValueError(
469
- f"Failed to process '{from_field}' from {instance} due to : {e}"
470
  ) from e
471
  dict_set(
472
  instance,
@@ -977,7 +977,7 @@ class CastFields(InstanceOperator):
977
  if self.process_every_value:
978
  assert isinstance(
979
  value, list
980
- ), f"'process_every_value' can be set to True only for fields that contain lists, whereas in instance {instance}, the contents of field '{field_name}' is of type '{type(value)}'"
981
  casted_value = self._cast_multiple(value, type, field_name)
982
  else:
983
  casted_value = self._cast_single(value, type, field_name)
@@ -1154,7 +1154,7 @@ class FilterByCondition(StreamOperator):
1154
  instance_key = dict_get(instance, key)
1155
  except ValueError as ve:
1156
  raise ValueError(
1157
- f"Required filter field ('{key}') in FilterByCondition is not found in {instance}"
1158
  ) from ve
1159
  if self.condition == "in":
1160
  if instance_key not in value:
@@ -1194,13 +1194,13 @@ class FilterByConditionBasedOnFields(FilterByCondition):
1194
  instance_key = dict_get(instance, key)
1195
  except ValueError as ve:
1196
  raise ValueError(
1197
- f"Required filter field ('{key}') in FilterByCondition is not found in {instance}"
1198
  ) from ve
1199
  try:
1200
  instance_value = dict_get(instance, value)
1201
  except ValueError as ve:
1202
  raise ValueError(
1203
- f"Required filter field ('{value}') in FilterByCondition is not found in {instance}"
1204
  ) from ve
1205
  if self.condition == "in":
1206
  if instance_key not in instance_value:
@@ -1551,7 +1551,7 @@ class SplitByNestedGroup(MultiStreamOperator):
1551
  for instance in stream:
1552
  if self.field_name_of_group not in instance:
1553
  raise ValueError(
1554
- f"Field {self.field_name_of_group} is missing from instance {instance}"
1555
  )
1556
  signature = (
1557
  stream_name
 
190
  if value is not None:
191
  if (self.process_every_value is True) and (not isinstance(value, list)):
192
  raise ValueError(
193
+ f"'process_every_field' == True is allowed only for fields whose values are lists, but value of field '{key}' is '{value}'"
194
  )
195
  if isinstance(value, list) and self.process_every_value:
196
  for i, val in enumerate(value):
 
211
  return recursive_copy(mapper[val_as_str])
212
  if self.strict:
213
  raise KeyError(
214
+ f"value '{val_as_str}', the string representation of the value in field '{key}', is not found in mapper '{mapper}'"
215
  )
216
  return val
217
 
 
454
  old_value = self.get_default
455
  except Exception as e:
456
  raise ValueError(
457
+ f"Failed to get '{from_field}' from instance due to the exception above."
458
  ) from e
459
  try:
460
  if self.process_every_value:
 
466
  new_value = self.process_instance_value(old_value, instance)
467
  except Exception as e:
468
  raise ValueError(
469
+ f"Failed to process field '{from_field}' from instance due to the exception above."
470
  ) from e
471
  dict_set(
472
  instance,
 
977
  if self.process_every_value:
978
  assert isinstance(
979
  value, list
980
+ ), f"'process_every_field' == True is allowed only for fields whose values are lists, but value of field '{field_name}' is '{value}'"
981
  casted_value = self._cast_multiple(value, type, field_name)
982
  else:
983
  casted_value = self._cast_single(value, type, field_name)
 
1154
  instance_key = dict_get(instance, key)
1155
  except ValueError as ve:
1156
  raise ValueError(
1157
+ f"Required filter field ('{key}') in FilterByCondition is not found in instance."
1158
  ) from ve
1159
  if self.condition == "in":
1160
  if instance_key not in value:
 
1194
  instance_key = dict_get(instance, key)
1195
  except ValueError as ve:
1196
  raise ValueError(
1197
+ f"Required filter field ('{key}') in FilterByCondition is not found in instance"
1198
  ) from ve
1199
  try:
1200
  instance_value = dict_get(instance, value)
1201
  except ValueError as ve:
1202
  raise ValueError(
1203
+ f"Required filter field ('{value}') in FilterByCondition is not found in instance"
1204
  ) from ve
1205
  if self.condition == "in":
1206
  if instance_key not in instance_value:
 
1551
  for instance in stream:
1552
  if self.field_name_of_group not in instance:
1553
  raise ValueError(
1554
+ f"Field {self.field_name_of_group} is missing from instance. Available fields: {instance.keys()}"
1555
  )
1556
  signature = (
1557
  stream_name
version.py CHANGED
@@ -1 +1 @@
1
- version = "1.15.8"
 
1
+ version = "1.15.9"