Elron commited on
Commit
b462f85
1 Parent(s): ff375eb

Upload folder using huggingface_hub

Browse files
Files changed (24) hide show
  1. artifact.py +4 -3
  2. blocks.py +1 -1
  3. card.py +2 -2
  4. formats.py +2 -2
  5. fusion.py +95 -43
  6. inference.py +78 -19
  7. llm_as_judge.py +111 -31
  8. loaders.py +33 -6
  9. metric_utils.py +103 -44
  10. metrics.py +76 -22
  11. operator.py +6 -4
  12. operators.py +97 -6
  13. processors.py +22 -9
  14. schema.py +14 -6
  15. settings_utils.py +2 -1
  16. splitters.py +6 -3
  17. standard.py +3 -0
  18. stream.py +36 -5
  19. string_operators.py +18 -2
  20. struct_data_operators.py +19 -0
  21. task.py +3 -51
  22. templates.py +142 -11
  23. text_utils.py +40 -0
  24. version.py +1 -1
artifact.py CHANGED
@@ -248,8 +248,9 @@ class Artifact(Dataclass):
248
  value = map_values_in_place(value, maybe_recover_artifact)
249
  setattr(self, field.name, value)
250
 
251
- self.prepare()
252
- self.verify()
 
253
 
254
  def _to_raw_dict(self):
255
  return {"type": self.type, **self._init_dict}
@@ -335,7 +336,7 @@ def get_artifactory_name_and_args(
335
 
336
  def verbosed_fetch_artifact(identifier):
337
  artifact, artifactory = fetch_artifact(identifier)
338
- logger.info(f"Artifact {identifier} is fetched from {artifactory}")
339
  return artifact
340
 
341
 
 
248
  value = map_values_in_place(value, maybe_recover_artifact)
249
  setattr(self, field.name, value)
250
 
251
+ if not settings.skip_artifacts_prepare_and_verify:
252
+ self.prepare()
253
+ self.verify()
254
 
255
  def _to_raw_dict(self):
256
  return {"type": self.type, **self._init_dict}
 
336
 
337
  def verbosed_fetch_artifact(identifier):
338
  artifact, artifactory = fetch_artifact(identifier)
339
+ logger.debug(f"Artifact {identifier} is fetched from {artifactory}")
340
  return artifact
341
 
342
 
blocks.py CHANGED
@@ -31,7 +31,7 @@ from .struct_data_operators import (
31
  TruncateTableCells,
32
  TruncateTableRows,
33
  )
34
- from .task import FormTask
35
  from .templates import (
36
  InputOutputTemplate,
37
  MultiLabelTemplate,
 
31
  TruncateTableCells,
32
  TruncateTableRows,
33
  )
34
+ from .task import Task
35
  from .templates import (
36
  InputOutputTemplate,
37
  MultiLabelTemplate,
card.py CHANGED
@@ -6,7 +6,7 @@ from .dataclass import OptionalField
6
  from .loaders import Loader
7
  from .operator import StreamingOperator
8
  from .splitters import RandomSampler, Sampler
9
- from .task import FormTask
10
 
11
 
12
  class TaskCard(Artifact):
@@ -24,6 +24,6 @@ class TaskCard(Artifact):
24
 
25
  loader: Loader
26
  preprocess_steps: List[StreamingOperator] = None
27
- task: FormTask
28
  templates: Collection = None
29
  sampler: Sampler = OptionalField(default_factory=RandomSampler)
 
6
  from .loaders import Loader
7
  from .operator import StreamingOperator
8
  from .splitters import RandomSampler, Sampler
9
+ from .task import Task
10
 
11
 
12
  class TaskCard(Artifact):
 
24
 
25
  loader: Loader
26
  preprocess_steps: List[StreamingOperator] = None
27
+ task: Task
28
  templates: Collection = None
29
  sampler: Sampler = OptionalField(default_factory=RandomSampler)
formats.py CHANGED
@@ -114,9 +114,9 @@ class SystemFormat(Format):
114
  """
115
 
116
  demos_field: str = "demos"
117
- demo_format: str = "{source}\n{target_prefix}{target}\n\n" # example: "User: {source}\nAgent: {target}\n\n"
118
  model_input_format: str = (
119
- "{system_prompt}{instruction}{demos}{source}\n{target_prefix}"
120
  )
121
  format_args: Dict[str, str] = OptionalField(default_factory=dict)
122
 
 
114
  """
115
 
116
  demos_field: str = "demos"
117
+ demo_format: str = "{source}\\N{target_prefix}{target}\n\n" # example: "User: {source}\nAgent: {target}\n\n"
118
  model_input_format: str = (
119
+ "{system_prompt}\\N{instruction}\\N{demos}{source}\\N{target_prefix}"
120
  )
121
  format_args: Dict[str, str] = OptionalField(default_factory=dict)
122
 
fusion.py CHANGED
@@ -1,31 +1,44 @@
1
- import copy
2
  from abc import abstractmethod
3
- from typing import Generator, List, Optional
4
 
5
  from .dataclass import NonPositionalField
6
  from .operator import SourceOperator
7
  from .random_utils import new_random_generator
8
- from .stream import MultiStream, Stream
 
9
 
10
 
11
  class BaseFusion(SourceOperator):
12
- """BaseFusion operator that combines multiple streams into one.
13
 
14
  Args:
15
- include_splits: List of splits to include. If None, all splits are included.
 
 
 
16
  """
17
 
18
- origins: List[SourceOperator]
19
  include_splits: Optional[List[str]] = NonPositionalField(default=None)
20
 
21
  @abstractmethod
22
  def fusion_generator(self, split) -> Generator:
23
  pass
24
 
25
- def splits(self) -> Generator:
 
 
 
 
 
 
 
 
 
 
26
  splits = []
27
- for origin in self.origins:
28
- for s in origin().keys():
29
  if s not in splits:
30
  if self.include_splits is None or s in self.include_splits:
31
  splits.append(s)
@@ -36,48 +49,62 @@ class BaseFusion(SourceOperator):
36
  ) -> MultiStream:
37
  result = {}
38
  for split in self.splits():
39
- result[split] = Stream(self.fusion_generator, gen_kwargs={"split": split})
 
 
40
  return MultiStream(result)
41
 
42
 
43
  class FixedFusion(BaseFusion):
44
- """FixedFusion operator that combines multiple streams into one based on a fixed number of examples per task.
45
 
46
  Args:
47
- origins: List of SourceOperator objects.
48
- examples_per_task: Number of examples per task. If None, all examples are returned.
49
- splits: List of splits to include. If None, all splits are included.
 
 
50
  """
51
 
52
- max_instances_per_origin: Optional[int] = None
 
 
 
53
 
 
54
  def fusion_generator(self, split) -> Generator:
55
- for origin in self.origins:
56
- multi_stream = origin()
57
- if split not in multi_stream:
58
  continue
59
- iterator = iter(multi_stream[split])
60
- if self.max_instances_per_origin is not None:
61
- for _ in range(self.max_instances_per_origin):
62
- try:
63
- yield next(iterator)
64
- except StopIteration:
65
- break
66
- else:
67
- yield from iterator
 
 
 
 
 
 
68
 
69
 
70
  class WeightedFusion(BaseFusion):
71
- """Fusion operator that combines multiple streams based.
72
 
73
  Args:
74
- origins: List of SourceOperator objects.
75
- weights: List of weights for each origin.
76
- max_total_examples: Total number of examples to return. If None, all examples are returned.
 
77
  """
78
 
79
- origins: List[SourceOperator] = None
80
- weights: List[float] = None
81
  max_total_examples: int = None
82
 
83
  def verify(self):
@@ -87,22 +114,47 @@ class WeightedFusion(BaseFusion):
87
  assert len(self.origins) == len(
88
  self.weights
89
  ), "origins and weights must have the same length"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  def fusion_generator(self, split) -> Generator:
92
- weights = copy.deepcopy(self.weights)
93
- iterators = [iter(origin()[split]) for origin in self.origins]
 
 
94
  total_examples = 0
95
  random_generator = new_random_generator(sub_seed="weighted_fusion_" + split)
96
  while (
97
- self.max_total_examples is None or total_examples <= self.max_total_examples
98
  ) and len(iterators) > 0:
99
- iterator = random_generator.choices(population=iterators, weights=weights)[
100
- 0
101
- ]
 
 
 
102
  try:
103
- yield next(iterator)
 
 
 
 
 
104
  total_examples += 1
 
 
105
  except StopIteration:
106
- index = iterators.index(iterator)
107
- iterators.pop(index)
108
- weights.pop(index)
 
 
1
  from abc import abstractmethod
2
+ from typing import Dict, Generator, List, Optional, Union
3
 
4
  from .dataclass import NonPositionalField
5
  from .operator import SourceOperator
6
  from .random_utils import new_random_generator
7
+ from .stream import GeneratorStream, MultiStream
8
+ from .type_utils import isoftype
9
 
10
 
11
  class BaseFusion(SourceOperator):
12
+ """BaseFusion operator that combines multiple multistreams into one.
13
 
14
  Args:
15
+ origins: a dict of named SourceOperator objects (each to yield a MultiStream) or a list thereof,
16
+ each is specified along with its input, so can generate a MultiStream
17
+ include_splits: List of splits to include from each input MultiStream.
18
+ If None, all splits are included.
19
  """
20
 
21
+ origins: Union[List[SourceOperator], Dict[str, SourceOperator]]
22
  include_splits: Optional[List[str]] = NonPositionalField(default=None)
23
 
24
  @abstractmethod
25
  def fusion_generator(self, split) -> Generator:
26
  pass
27
 
28
+ def prepare(self):
29
+ assert isoftype(self.origins, Dict[str, SourceOperator]) or isoftype(
30
+ self.origins, List[SourceOperator]
31
+ )
32
+ self.named_origins = (
33
+ {i: self.origins[i]() for i in range(len(self.origins))}
34
+ if isinstance(self.origins, list)
35
+ else {name: origin() for name, origin in self.origins.items()}
36
+ )
37
+
38
+ def splits(self) -> List[str]:
39
  splits = []
40
+ for _, origin in self.named_origins.items():
41
+ for s in origin.keys():
42
  if s not in splits:
43
  if self.include_splits is None or s in self.include_splits:
44
  splits.append(s)
 
49
  ) -> MultiStream:
50
  result = {}
51
  for split in self.splits():
52
+ result[split] = GeneratorStream(
53
+ self.fusion_generator, gen_kwargs={"split": split}
54
+ )
55
  return MultiStream(result)
56
 
57
 
58
  class FixedFusion(BaseFusion):
59
+ """FixedFusion operator that combines multiple multistreams into one, limiting the number of instances taken from each split of each input multistream.
60
 
61
  Args:
62
+ origins: Dict of named SourceOperator objects (each to yield a MultiStream), or a list thereof
63
+ splits: List of splits (stream_names) to include, over all input multistreams. If None, all splits are included.
64
+ max_instances_per_origin_split: Number of instances to take from each input split of each input multistream.
65
+ If None, all instances of each split (that is specified in include_splits) are included in the result.
66
+
67
  """
68
 
69
+ max_instances_per_origin_split: Optional[int] = None
70
+
71
+ def prepare(self):
72
+ super().prepare()
73
 
74
+ # flake8: noqa: C901
75
  def fusion_generator(self, split) -> Generator:
76
+ for origin_name, origin in self.named_origins.items():
77
+ if split not in origin:
 
78
  continue
79
+ emitted_from_this_split = 0
80
+ for instance in origin[split]:
81
+ if (
82
+ self.max_instances_per_origin_split is not None
83
+ and emitted_from_this_split >= self.max_instances_per_origin_split
84
+ ):
85
+ break
86
+ if isinstance(origin_name, str):
87
+ # named origins, not anonymous, record in instance
88
+ if "group" in instance:
89
+ instance["group"] = origin_name + "/" + instance["group"]
90
+ else:
91
+ instance["group"] = origin_name
92
+ emitted_from_this_split += 1
93
+ yield instance
94
 
95
 
96
  class WeightedFusion(BaseFusion):
97
+ """Fusion operator that combines multiple MultiStream-s.
98
 
99
  Args:
100
+ origins: Dict of named MultiStream objects, or a list thereof
101
+ weights: Dict of named weights for each origin, or a list thereof
102
+ max_total_examples: Total number of instances to return per returned split.
103
+ If None, all instances are returned
104
  """
105
 
106
+ origins: Union[Dict[str, MultiStream], List[MultiStream]] = None
107
+ weights: Union[Dict[str, Union[float, int]], List[Union[int, float]]] = None
108
  max_total_examples: int = None
109
 
110
  def verify(self):
 
114
  assert len(self.origins) == len(
115
  self.weights
116
  ), "origins and weights must have the same length"
117
+ assert isoftype(self.origins, Dict[str, SourceOperator]) or isoftype(
118
+ self.origins, List[SourceOperator]
119
+ )
120
+ assert isoftype(self.weights, Dict[str, Union[int, float]]) or isoftype(
121
+ self.weights, List[Union[int, float]]
122
+ )
123
+ assert isinstance(self.origins, dict) == isinstance(self.weights, dict)
124
+
125
+ def prepare(self):
126
+ super().prepare()
127
+ self.named_weights = (
128
+ {i: float(self.weights[i]) for i in range(len(self.weights))}
129
+ if isinstance(self.weights, list)
130
+ else {k: float(v) for (k, v) in self.weights.items()}
131
+ )
132
 
133
  def fusion_generator(self, split) -> Generator:
134
+ iterators = {
135
+ named_origin: iter(origin[split])
136
+ for named_origin, origin in self.named_origins.items()
137
+ }
138
  total_examples = 0
139
  random_generator = new_random_generator(sub_seed="weighted_fusion_" + split)
140
  while (
141
+ self.max_total_examples is None or total_examples < self.max_total_examples
142
  ) and len(iterators) > 0:
143
+ population = list(iterators.keys())
144
+ origin_name = random_generator.choices(
145
+ population=population,
146
+ weights=[self.named_weights[name] for name in population],
147
+ )[0]
148
+ iterator = iterators[origin_name]
149
  try:
150
+ instance = next(iterator)
151
+ if isinstance(origin_name, str):
152
+ if "group" in instance:
153
+ instance["group"] = origin_name + "/" + instance["group"]
154
+ else:
155
+ instance["group"] = origin_name
156
  total_examples += 1
157
+ yield instance
158
+
159
  except StopIteration:
160
+ iterators.pop(origin_name)
 
 
inference.py CHANGED
@@ -1,7 +1,7 @@
1
  import abc
2
  import os
3
- from dataclasses import dataclass
4
- from typing import List, Optional, Union
5
 
6
  from .artifact import Artifact
7
  from .operator import PackageRequirementsMixin
@@ -28,28 +28,72 @@ class InferenceEngine(abc.ABC, Artifact):
28
  class HFPipelineBasedInferenceEngine(InferenceEngine, PackageRequirementsMixin):
29
  model_name: str
30
  max_new_tokens: int
 
31
  _requirement = {
32
  "transformers": "Install huggingface package using 'pip install --upgrade transformers"
33
  }
34
 
35
  def prepare(self):
36
- from transformers import pipeline
 
37
 
38
- self.model = pipeline(model=self.model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  def infer(self, dataset):
41
- return [
42
- output["generated_text"]
43
- for output in self.model(
44
- [instance["source"] for instance in dataset],
45
- max_new_tokens=self.max_new_tokens,
46
- )
47
- ]
48
 
49
 
50
- @dataclass()
51
- class IbmGenAiInferenceEngineParams:
52
- decoding_method: str = None
 
 
 
 
 
 
 
 
 
53
  max_new_tokens: Optional[int] = None
54
  min_new_tokens: Optional[int] = None
55
  random_seed: Optional[int] = None
@@ -64,7 +108,9 @@ class IbmGenAiInferenceEngineParams:
64
  class IbmGenAiInferenceEngine(InferenceEngine, PackageRequirementsMixin):
65
  label: str = "ibm_genai"
66
  model_name: str
67
- parameters: IbmGenAiInferenceEngineParams = IbmGenAiInferenceEngineParams()
 
 
68
  _requirement = {
69
  "genai": "Install ibm-genai package using 'pip install --upgrade ibm-generative-ai"
70
  }
@@ -87,7 +133,19 @@ class IbmGenAiInferenceEngine(InferenceEngine, PackageRequirementsMixin):
87
  def infer(self, dataset):
88
  from genai.schema import TextGenerationParameters
89
 
90
- genai_params = TextGenerationParameters(**self.parameters.__dict__)
 
 
 
 
 
 
 
 
 
 
 
 
91
  return list(
92
  self.client.text.generation.create(
93
  model_id=self.model_name,
@@ -97,8 +155,7 @@ class IbmGenAiInferenceEngine(InferenceEngine, PackageRequirementsMixin):
97
  )
98
 
99
 
100
- @dataclass
101
- class OpenAiInferenceEngineParams:
102
  frequency_penalty: Optional[float] = None
103
  presence_penalty: Optional[float] = None
104
  max_tokens: Optional[int] = None
@@ -111,7 +168,9 @@ class OpenAiInferenceEngineParams:
111
  class OpenAiInferenceEngine(InferenceEngine, PackageRequirementsMixin):
112
  label: str = "openai"
113
  model_name: str
114
- parameters: OpenAiInferenceEngineParams = OpenAiInferenceEngineParams()
 
 
115
  _requirement = {
116
  "openai": "Install openai package using 'pip install --upgrade openai"
117
  }
 
1
  import abc
2
  import os
3
+ from dataclasses import field
4
+ from typing import Any, Dict, List, Literal, Optional, Union
5
 
6
  from .artifact import Artifact
7
  from .operator import PackageRequirementsMixin
 
28
  class HFPipelineBasedInferenceEngine(InferenceEngine, PackageRequirementsMixin):
29
  model_name: str
30
  max_new_tokens: int
31
+ use_fp16: bool = True
32
  _requirement = {
33
  "transformers": "Install huggingface package using 'pip install --upgrade transformers"
34
  }
35
 
36
  def prepare(self):
37
+ import torch
38
+ from transformers import AutoConfig, pipeline
39
 
40
+ model_args: Dict[str, Any] = (
41
+ {"torch_dtype": torch.float16} if self.use_fp16 else {}
42
+ )
43
+ model_args.update({"max_new_tokens": self.max_new_tokens})
44
+
45
+ device = torch.device(
46
+ "mps"
47
+ if torch.backends.mps.is_available()
48
+ else 0
49
+ if torch.cuda.is_available()
50
+ else "cpu"
51
+ )
52
+ # We do this, because in some cases, using device:auto will offload some weights to the cpu
53
+ # (even though the model might *just* fit to a single gpu), even if there is a gpu available, and this will
54
+ # cause an error because the data is always on the gpu
55
+ if torch.cuda.device_count() > 1:
56
+ assert device == torch.device(0)
57
+ model_args.update({"device_map": "auto"})
58
+ else:
59
+ model_args.update({"device": device})
60
+
61
+ task = (
62
+ "text2text-generation"
63
+ if AutoConfig.from_pretrained(
64
+ self.model_name, trust_remote_code=True
65
+ ).is_encoder_decoder
66
+ else "text-generation"
67
+ )
68
+
69
+ if task == "text-generation":
70
+ model_args.update({"return_full_text": False})
71
+
72
+ self.model = pipeline(
73
+ model=self.model_name, trust_remote_code=True, **model_args
74
+ )
75
 
76
  def infer(self, dataset):
77
+ outputs = []
78
+ for output in self.model([instance["source"] for instance in dataset]):
79
+ if isinstance(output, list):
80
+ output = output[0]
81
+ outputs.append(output["generated_text"])
82
+ return outputs
 
83
 
84
 
85
+ class MockInferenceEngine(InferenceEngine):
86
+ model_name: str
87
+
88
+ def prepare(self):
89
+ return
90
+
91
+ def infer(self, dataset):
92
+ return ["[[10]]" for instance in dataset]
93
+
94
+
95
+ class IbmGenAiInferenceEngineParams(Artifact):
96
+ decoding_method: Optional[Literal["greedy", "sample"]] = None
97
  max_new_tokens: Optional[int] = None
98
  min_new_tokens: Optional[int] = None
99
  random_seed: Optional[int] = None
 
108
  class IbmGenAiInferenceEngine(InferenceEngine, PackageRequirementsMixin):
109
  label: str = "ibm_genai"
110
  model_name: str
111
+ parameters: IbmGenAiInferenceEngineParams = field(
112
+ default_factory=IbmGenAiInferenceEngineParams
113
+ )
114
  _requirement = {
115
  "genai": "Install ibm-genai package using 'pip install --upgrade ibm-generative-ai"
116
  }
 
133
  def infer(self, dataset):
134
  from genai.schema import TextGenerationParameters
135
 
136
+ genai_params = TextGenerationParameters(
137
+ max_new_tokens=self.parameters.max_new_tokens,
138
+ min_new_tokens=self.parameters.min_new_tokens,
139
+ random_seed=self.parameters.random_seed,
140
+ repetition_penalty=self.parameters.repetition_penalty,
141
+ stop_sequences=self.parameters.stop_sequences,
142
+ temperature=self.parameters.temperature,
143
+ top_p=self.parameters.top_p,
144
+ top_k=self.parameters.top_k,
145
+ typical_p=self.parameters.typical_p,
146
+ decoding_method=self.parameters.decoding_method,
147
+ )
148
+
149
  return list(
150
  self.client.text.generation.create(
151
  model_id=self.model_name,
 
155
  )
156
 
157
 
158
+ class OpenAiInferenceEngineParams(Artifact):
 
159
  frequency_penalty: Optional[float] = None
160
  presence_penalty: Optional[float] = None
161
  max_tokens: Optional[int] = None
 
168
  class OpenAiInferenceEngine(InferenceEngine, PackageRequirementsMixin):
169
  label: str = "openai"
170
  model_name: str
171
+ parameters: OpenAiInferenceEngineParams = field(
172
+ default_factory=OpenAiInferenceEngineParams
173
+ )
174
  _requirement = {
175
  "openai": "Install openai package using 'pip install --upgrade openai"
176
  }
llm_as_judge.py CHANGED
@@ -1,58 +1,138 @@
1
- from typing import Any, Dict, List
2
 
3
- import evaluate
4
-
5
- from .api import produce
6
- from .inference import InferenceEngine
7
  from .metrics import BulkInstanceMetric
 
8
 
9
 
10
  class LLMAsJudge(BulkInstanceMetric):
11
  """LLM as judge based metric class for evaluating correctness.
12
 
13
  Attributes:
14
- main_score (str): The main score used for evaluation.
 
 
 
 
 
 
 
 
15
  reduction_map (dict): A dictionary specifying the reduction method for the metric.
16
- betch_size (int): The size of the bulk.
17
- recipe (str): The unitxt recipe that will be used to create the judge dataset.
18
- inference (InferenceEngine): the module that creates the inference.
19
-
20
- Methods:
21
- prepare(self): Initialization method for the metric.
22
- compute(self, references, predictions, additional_inputs): Method to compute the metric.
23
-
24
- Usage:
25
- metric = LlamaIndexCorrectnessMetric()
26
- scores = metric.compute(references, prediction, additional_inputs)
27
  """
28
 
29
  main_score: str = "llm_as_judge"
30
- reduction_map: Dict[str, List[str]] = None
31
- batch_size: int = 32
32
- recipe: str
 
 
33
  inference_model: InferenceEngine
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  def prepare(self):
36
  super().prepare()
37
  if self.reduction_map is None:
38
  self.reduction_map = {"mean": [self.main_score]}
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  def compute(
41
  self,
42
  references: List[List[Any]],
43
  predictions: List[Any],
44
  task_data: List[Dict],
45
  ) -> List[Dict[str, Any]]:
46
- instances = [
47
- {
48
- **task_data_instance,
49
- **{"model_output": prediction, "rating_label": "[[5]]"},
50
- }
51
- for task_data_instance, prediction in zip(task_data, predictions)
52
- ]
53
-
54
- dataset = produce(instances, self.recipe)
 
 
 
 
 
 
 
 
 
55
  verdicts = self.inference_model.infer(dataset)
56
- meta_metric = evaluate.load("unitxt/metric")
57
- meta_scores = meta_metric.compute(predictions=verdicts, references=dataset)
58
  return [{self.main_score: instance["prediction"]} for instance in meta_scores]
 
1
+ from typing import Any, Dict, List, Literal, Optional
2
 
3
+ from .api import evaluate, produce
4
+ from .inference import InferenceEngine, OpenAiInferenceEngine
 
 
5
  from .metrics import BulkInstanceMetric
6
+ from .operator import SequentialOperator
7
 
8
 
9
  class LLMAsJudge(BulkInstanceMetric):
10
  """LLM as judge based metric class for evaluating correctness.
11
 
12
  Attributes:
13
+ main_score (str): The main score label used for evaluation.
14
+ task (Literal["rating.single_turn"]): The type of task the llm-as-judge runs. This defines the output and input
15
+ format of the jude model.
16
+ template (str): The template used when generating inputs for the judge llm.
17
+ format (str): The format used when generating inputs for judge llm.
18
+ system_prompt (str): The system prompt used when generating inputs for judge llm.
19
+ strip_system_prompt_and_format_from_inputs (bool): Whether to strip the system prompt and formatting from the
20
+ inputs that the models that is being judges received, when they are inserted to the llm-as-judge prompt.
21
+ inference_model (InferenceEngine): the module that creates the inference of the judge llm.
22
  reduction_map (dict): A dictionary specifying the reduction method for the metric.
23
+ batch_size (int): The size of the bulk.
 
 
 
 
 
 
 
 
 
 
24
  """
25
 
26
  main_score: str = "llm_as_judge"
27
+ task: Literal["rating.single_turn", "single_turn_with_reference"]
28
+ template: str
29
+ format: Optional[str] = None
30
+ system_prompt: Optional[str] = None
31
+ strip_system_prompt_and_format_from_inputs: bool = True
32
  inference_model: InferenceEngine
33
+ reduction_map: Optional[Dict[str, List[str]]] = None
34
+ batch_size: int = 32
35
+
36
+ def _get_input_instances(self, task_data: List[Dict]) -> List:
37
+ if self.strip_system_prompt_and_format_from_inputs:
38
+ instances = []
39
+ for task_data_instance in task_data:
40
+ template = task_data_instance["metadata"]["template"]
41
+ instance = SequentialOperator(
42
+ steps=[template, "formats.empty"]
43
+ ).process_instance(
44
+ {"inputs": task_data_instance, "outputs": task_data_instance}
45
+ )
46
+ instances.append(instance["source"])
47
+ """
48
+ We also have access to: instance["target"]
49
+ instance["references"]
50
+ """
51
+ return instances
52
+ return [t["source"] for t in task_data]
53
+
54
+ def _get_instance_for_judge_model(
55
+ self, input_instances: List[str], predictions: List, references: List
56
+ ) -> List[Dict]:
57
+ if self.task == "rating.single_turn":
58
+ instances = [
59
+ {
60
+ "question": input_instance,
61
+ "answer": prediction,
62
+ "rating": 5.0, # This is a dummy value that is not used in practice
63
+ }
64
+ for input_instance, prediction, reference in zip(
65
+ input_instances, predictions, references
66
+ )
67
+ ]
68
+ elif self.task == "rating.single_turn_with_reference":
69
+ instances = [
70
+ {
71
+ "question": input_instance,
72
+ "answer": prediction,
73
+ "reference_answer": reference,
74
+ "rating": 5.0, # This is a dummy value that is not used in practice
75
+ }
76
+ for input_instance, prediction, reference in zip(
77
+ input_instances, predictions, references
78
+ )
79
+ ]
80
+ else:
81
+ raise NotImplementedError(
82
+ f"Error in 'LLMAsJudge' metric. {self.task} is not a supported task type."
83
+ )
84
+ return instances
85
 
86
  def prepare(self):
87
  super().prepare()
88
  if self.reduction_map is None:
89
  self.reduction_map = {"mean": [self.main_score]}
90
 
91
+ supported_tasks = ["rating.single_turn", "rating.single_turn_with_reference"]
92
+ assert self.task in supported_tasks, (
93
+ f"Error in 'LLMAsJudge' metric. {self.task} is not a supported task type."
94
+ f"The supported tasks types are: {', '.join(supported_tasks)}."
95
+ )
96
+
97
+ if isinstance(self.inference_model, OpenAiInferenceEngine):
98
+ if self.format:
99
+ raise ValueError(
100
+ "Error in 'LLMAsJudge' metric. Inference model 'OpenAiInferenceEngine' does "
101
+ "not support formatting. Please remove the format definition from the recipe"
102
+ " (OpenAi Chat API take care of the formatting automatically)."
103
+ )
104
+ if self.system_prompt:
105
+ raise ValueError(
106
+ "Error in 'LLMAsJudge' metric. Inference model 'OpenAiInferenceEngine' does "
107
+ "not support system prompt. Please remove the system_prompt definition from the recipe"
108
+ " (Current implementation of Unitxt does not support this."
109
+ " Support will be added in future updates)."
110
+ )
111
+
112
  def compute(
113
  self,
114
  references: List[List[Any]],
115
  predictions: List[Any],
116
  task_data: List[Dict],
117
  ) -> List[Dict[str, Any]]:
118
+ input_instances = self._get_input_instances(task_data)
119
+ instances = self._get_instance_for_judge_model(
120
+ input_instances, predictions, references
121
+ )
122
+
123
+ card = f"cards.dynamic_cards_for_llm_judges.{self.task}"
124
+ recipe = (
125
+ f"card={card},"
126
+ f"template={self.template},"
127
+ "demos_pool_size=0,"
128
+ "num_demos=0"
129
+ )
130
+ if self.system_prompt:
131
+ recipe = f"{recipe},system_prompt={self.system_prompt}"
132
+ if self.format:
133
+ recipe = f"{recipe},format={self.format}"
134
+
135
+ dataset = produce(instances, recipe)
136
  verdicts = self.inference_model.infer(dataset)
137
+ meta_scores = evaluate(predictions=verdicts, data=dataset)
 
138
  return [{self.main_score: instance["prediction"]} for instance in meta_scores]
loaders.py CHANGED
@@ -27,7 +27,7 @@ import os
27
  import tempfile
28
  from pathlib import Path
29
  from tempfile import TemporaryDirectory
30
- from typing import Dict, List, Mapping, Optional, Sequence, Union
31
 
32
  import pandas as pd
33
  from datasets import load_dataset as hf_load_dataset
@@ -38,7 +38,7 @@ from .fusion import FixedFusion
38
  from .logging_utils import get_logger
39
  from .operator import SourceOperator
40
  from .settings_utils import get_settings
41
- from .stream import MultiStream, Stream
42
 
43
  logger = get_logger()
44
  settings = get_settings()
@@ -180,7 +180,7 @@ class LoadHF(Loader):
180
  self.log_limited_loading()
181
  return MultiStream(
182
  {
183
- name: Stream(
184
  generator=self.split_limited_load, gen_kwargs={"split_name": name}
185
  )
186
  for name in self._cache.keys()
@@ -240,14 +240,18 @@ class LoadCSV(Loader):
240
  if self.streaming:
241
  return MultiStream(
242
  {
243
- name: Stream(generator=self.stream_csv, gen_kwargs={"file": file})
 
 
244
  for name, file in self.files.items()
245
  }
246
  )
247
 
248
  return MultiStream(
249
  {
250
- name: Stream(generator=self.load_csv, gen_kwargs={"file": file})
 
 
251
  for name, file in self.files.items()
252
  }
253
  )
@@ -472,5 +476,28 @@ class MultipleSourceLoader(Loader):
472
 
473
  def process(self):
474
  return FixedFusion(
475
- origins=self.sources, max_instances_per_origin=self.get_limit()
476
  ).process()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  import tempfile
28
  from pathlib import Path
29
  from tempfile import TemporaryDirectory
30
+ from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
31
 
32
  import pandas as pd
33
  from datasets import load_dataset as hf_load_dataset
 
38
  from .logging_utils import get_logger
39
  from .operator import SourceOperator
40
  from .settings_utils import get_settings
41
+ from .stream import GeneratorStream, MultiStream
42
 
43
  logger = get_logger()
44
  settings = get_settings()
 
180
  self.log_limited_loading()
181
  return MultiStream(
182
  {
183
+ name: GeneratorStream(
184
  generator=self.split_limited_load, gen_kwargs={"split_name": name}
185
  )
186
  for name in self._cache.keys()
 
240
  if self.streaming:
241
  return MultiStream(
242
  {
243
+ name: GeneratorStream(
244
+ generator=self.stream_csv, gen_kwargs={"file": file}
245
+ )
246
  for name, file in self.files.items()
247
  }
248
  )
249
 
250
  return MultiStream(
251
  {
252
+ name: GeneratorStream(
253
+ generator=self.load_csv, gen_kwargs={"file": file}
254
+ )
255
  for name, file in self.files.items()
256
  }
257
  )
 
476
 
477
  def process(self):
478
  return FixedFusion(
479
+ origins=self.sources, max_instances_per_origin_split=self.get_limit()
480
  ).process()
481
+
482
+
483
+ class LoadFromDictionary(Loader):
484
+ """Allows loading data from dictionary of constants.
485
+
486
+ The loader can be used, for example, when debugging or working with small datasets.
487
+
488
+ Attributes:
489
+ data (Dict[str, List[Dict[str, Any]]]): a dictionary of constants from which the data will be loaded
490
+
491
+ Examples:
492
+ data = {
493
+ "train": {"input": "SomeInput1", "output": "SomeResult1"},
494
+ "test": {"input": "SomeInput2", "output": "SomeResult2"},
495
+ }
496
+ loader = LoadFromDictionary(data=data)
497
+ multi_stream = loader.process()
498
+ """
499
+
500
+ data: Dict[str, List[Dict[str, Any]]]
501
+
502
+ def process(self) -> MultiStream:
503
+ return MultiStream.from_iterables(self.data)
metric_utils.py CHANGED
@@ -1,71 +1,125 @@
1
  import json
2
- from typing import Any, Dict, Iterable, List, Optional
 
3
 
4
  from datasets import Features, Value
 
5
 
6
  from .dataclass import Dataclass
 
7
  from .operator import (
8
  MultiStreamOperator,
9
  SequentialOperatorInitializer,
10
  StreamInitializerOperator,
11
  )
12
  from .operators import (
13
- Apply,
14
  ApplyMetric,
15
  ApplyOperatorsField,
 
16
  FlattenInstances,
17
  MergeStreams,
18
- SplitByValue,
19
  )
20
  from .register import _reset_env_local_catalogs, register_all_artifacts
21
  from .schema import UNITXT_DATASET_SCHEMA
22
  from .settings_utils import get_settings
23
- from .stream import MultiStream, Stream
 
24
 
25
 
26
  class MultiStreamScoreMean(MultiStreamOperator):
27
- def aggregate_results(self, multi_stream: MultiStream):
28
- scores = []
29
- for stream in multi_stream.values():
30
- instance = stream.peek()
31
- scores.append(instance["score"]["global"]["score"])
32
-
33
- from statistics import mean
34
-
35
- return mean(scores)
36
-
37
- def spread_results(self, stream: Stream, score: float):
38
- for instance in stream:
39
- instance["score"]["global"]["groups_mean_score"] = score
40
- yield instance
 
 
 
 
 
 
 
 
41
 
42
- def spread_results_one_stream(self, stream: Stream):
43
- for instance in stream:
44
- instance["score"]["global"]["groups_mean_score"] = instance["score"][
45
- "global"
46
- ]["score"]
47
- yield instance
 
 
 
 
48
 
49
  def process(self, multi_stream: MultiStream) -> MultiStream:
50
- result = {}
51
-
52
- # optimization in to avoid double calculation of metrics
53
- # when aggregating results, if there is only one stream.
 
54
  if len(multi_stream) == 1:
55
- for stream_name, stream in multi_stream.items():
56
- result[stream_name] = Stream(
57
- self.spread_results_one_stream, gen_kwargs={"stream": stream}
58
- )
59
- return MultiStream(result)
60
 
61
- mean_score = self.aggregate_results(multi_stream)
62
- result = {}
63
  for stream_name, stream in multi_stream.items():
64
- result[stream_name] = Stream(
65
- self.spread_results, gen_kwargs={"stream": stream, "score": mean_score}
 
 
 
 
 
 
 
 
 
66
  )
67
 
68
- return MultiStream(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
 
71
  class FromPredictionsAndOriginalData(StreamInitializerOperator):
@@ -78,7 +132,7 @@ class FromPredictionsAndOriginalData(StreamInitializerOperator):
78
  ) -> MultiStream:
79
  return MultiStream(
80
  {
81
- split_name: Stream(
82
  self.zip,
83
  gen_kwargs={"predictions": predictions, "references": references},
84
  )
@@ -94,20 +148,25 @@ class FromPredictionsAndOriginalData(StreamInitializerOperator):
94
 
95
  class MetricRecipe(SequentialOperatorInitializer):
96
  calc_confidence_intervals: bool = True
 
97
 
98
  def prepare(self):
99
  register_all_artifacts()
100
  self.steps = [
101
  FromPredictionsAndOriginalData(),
102
- Apply(
103
- "task_data",
104
- function="json.loads",
105
- to_field="task_data",
 
106
  ),
107
  ApplyOperatorsField(
108
  operators_field="postprocessors",
109
  ),
110
- SplitByValue(["group"]),
 
 
 
111
  ApplyMetric(
112
  "metrics",
113
  calc_confidence_intervals=self.calc_confidence_intervals,
 
1
  import json
2
+ from copy import deepcopy
3
+ from typing import Any, Dict, Generator, Iterable, List, Optional
4
 
5
  from datasets import Features, Value
6
+ from numpy import nanmean
7
 
8
  from .dataclass import Dataclass
9
+ from .dict_utils import dict_set
10
  from .operator import (
11
  MultiStreamOperator,
12
  SequentialOperatorInitializer,
13
  StreamInitializerOperator,
14
  )
15
  from .operators import (
 
16
  ApplyMetric,
17
  ApplyOperatorsField,
18
+ CopyFields,
19
  FlattenInstances,
20
  MergeStreams,
21
+ SplitByNestedGroup,
22
  )
23
  from .register import _reset_env_local_catalogs, register_all_artifacts
24
  from .schema import UNITXT_DATASET_SCHEMA
25
  from .settings_utils import get_settings
26
+ from .stream import GeneratorStream, MultiStream
27
+ from .struct_data_operators import LoadJson
28
 
29
 
30
  class MultiStreamScoreMean(MultiStreamOperator):
31
+ """Given a multi-stream where each stream is already scored globally, generate a nested global score for the whole multi-stream.
32
+
33
+ The whole-ms-global-score is a nested structure, specifying (also) the individual global scores of the
34
+ individual streams participating in the input multi_stream.
35
+ The instances of all these individual streams are assumed to have the "group" field indicate the stream
36
+ they belong to.
37
+ Potentially, these individual streams were produced from a SplitByNestedGroup
38
+ operator that did not use the full length of the value in field "group" of the instances, but only the
39
+ first g components thereof, indicated by argument 'number_of_fusion_generations' of operator SplitByNestedGroup.
40
+ At any rate, a distinguishing prefix of the "group" value is recorded, by operator SplitByNestedGroup, in the stream_name.
41
+ The nested structure of the whole-ms-global-score is induced by these distinguishing prefixes,
42
+ by virtue of the global score of each individual stream sitting in the nested whole-ms-global-score,
43
+ deep in that dictionary, at the leaf lead to by a path being the distinguishing prefix indicated in the stream_name.
44
+ Thus, the global score of the stream becomes a leaf (though a dict by itself) of the whole-ms-global-score.
45
+
46
+ The ancestor nodes of the above leaves, in the whole-ms-global-score, contain each (in addition to dicts
47
+ leading down to leaves) a field named "score" whose value is set to be the mean of the values
48
+ sitting in field "score" of its immediate children nodes, and a field named "score_name" whose
49
+ value is set to be "group_mean".
50
+
51
+ When the input multistream consists of one single stream, it is returned as is, mainly for backward compatibility.
52
+ """
53
 
54
+ def update_intermediate_level_scores(self, level: dict) -> float:
55
+ if "score" in level:
56
+ return level["score"]
57
+ # the global score of the stream participating in this MultiStream
58
+ sub_scores = []
59
+ for key in level:
60
+ if isinstance(level[key], dict):
61
+ sub_scores.append(self.update_intermediate_level_scores(level[key]))
62
+ level.update({"score": nanmean(sub_scores), "score_name": "groups_mean"})
63
+ return level["score"]
64
 
65
  def process(self, multi_stream: MultiStream) -> MultiStream:
66
+ # each stream went through Metric which is a single-stream-operator , and ended up with all
67
+ # its instance["score"]["global"] linking to the same single dict object.
68
+ # Here we first generate a new, nested version, for the whole-ms-global_score, and then update
69
+ # each stream's global score with the new version
70
+ # but if only one stream in the multistream - we return it as is
71
  if len(multi_stream) == 1:
72
+ return multi_stream
73
+ global_score = {}
74
+ first_instances = {}
75
+ iterators = {}
 
76
 
 
 
77
  for stream_name, stream in multi_stream.items():
78
+ iterators[stream_name] = iter(stream)
79
+ try:
80
+ first_instances[stream_name] = next(iterators[stream_name])
81
+ except StopIteration:
82
+ continue # an empty stream, goto next stream
83
+ instance = first_instances[stream_name]
84
+ dict_set(
85
+ dic=global_score,
86
+ query=stream_name.split("~")[-1],
87
+ value=deepcopy(instance["score"]["global"]),
88
+ not_exist_ok=True,
89
  )
90
 
91
+ self.update_intermediate_level_scores(global_score)
92
+ # update the global_score object for each stream. Recall that all instances
93
+ # in each stream link all to same python dict object
94
+ for stream_name in multi_stream.keys():
95
+ instance = first_instances[stream_name]
96
+ instance["score"]["global"].clear()
97
+ instance["score"]["global"].update(global_score)
98
+
99
+ def never_peek_twice_generator(
100
+ stream_name: str, first_instances: dict, iterators: dict
101
+ ) -> Generator:
102
+ while True:
103
+ if stream_name in first_instances:
104
+ yield first_instances.pop(stream_name)
105
+ try:
106
+ yield next(iterators[stream_name])
107
+ except StopIteration:
108
+ return
109
+
110
+ return MultiStream(
111
+ {
112
+ stream_name: GeneratorStream(
113
+ never_peek_twice_generator,
114
+ gen_kwargs={
115
+ "stream_name": stream_name,
116
+ "first_instances": first_instances,
117
+ "iterators": iterators,
118
+ },
119
+ )
120
+ for stream_name in multi_stream.keys()
121
+ }
122
+ )
123
 
124
 
125
  class FromPredictionsAndOriginalData(StreamInitializerOperator):
 
132
  ) -> MultiStream:
133
  return MultiStream(
134
  {
135
+ split_name: GeneratorStream(
136
  self.zip,
137
  gen_kwargs={"predictions": predictions, "references": references},
138
  )
 
148
 
149
  class MetricRecipe(SequentialOperatorInitializer):
150
  calc_confidence_intervals: bool = True
151
+ number_of_fusion_generations: int = 2
152
 
153
  def prepare(self):
154
  register_all_artifacts()
155
  self.steps = [
156
  FromPredictionsAndOriginalData(),
157
+ LoadJson(field="task_data"),
158
+ CopyFields(
159
+ field_to_field={
160
+ "source": "task_data/source",
161
+ }
162
  ),
163
  ApplyOperatorsField(
164
  operators_field="postprocessors",
165
  ),
166
+ SplitByNestedGroup(
167
+ field_name_of_group="group",
168
+ number_of_fusion_generations=self.number_of_fusion_generations,
169
+ ),
170
  ApplyMetric(
171
  "metrics",
172
  calc_confidence_intervals=self.calc_confidence_intervals,
metrics.py CHANGED
@@ -3,7 +3,7 @@ import string
3
  import uuid
4
  import warnings
5
  from abc import ABC, abstractmethod
6
- from collections import Counter
7
  from copy import deepcopy
8
  from dataclasses import field
9
  from statistics import mean
@@ -915,11 +915,15 @@ class InstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
915
  if uses_subgroups
916
  else score_dict[default_subgroup_name]
917
  )
918
- for score_name, score_dict in group_scores.items()
 
 
919
  }
920
  }
921
  }
922
- for group_scores in group_to_instance_scores.values()
 
 
923
  ]
924
 
925
  def _set_up_group_mean_aggregation(
@@ -977,6 +981,40 @@ class Accuracy(InstanceMetric):
977
  return result
978
 
979
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
980
  class MaxAccuracy(Accuracy):
981
  """Calculate the maximal accuracy over all instances as the global score."""
982
 
@@ -1274,10 +1312,13 @@ class F1Binary(GlobalMetric):
1274
  _metric = None
1275
  metric = "f1"
1276
  single_reference_per_prediction = True
 
1277
 
1278
  def prepare(self):
1279
  super().prepare()
1280
- self._metric = evaluate.load(self.metric)
 
 
1281
 
1282
  def _validate_reference(self, reference):
1283
  super()._validate_reference(reference)
@@ -1294,19 +1335,27 @@ class F1Binary(GlobalMetric):
1294
  ) -> dict:
1295
  flattened_int_references = [int(r[0]) for r in references]
1296
  int_predictions = [int(p > self.threshold) for p in predictions]
1297
-
1298
- result = self._metric.compute(
1299
- references=flattened_int_references,
1300
- predictions=int_predictions,
1301
  labels=[0, 1],
1302
  average=self.average,
1303
  )
1304
- if isinstance(result[self.metric], numpy.ndarray):
1305
  return {
1306
- self.main_score: result[self.metric][1],
1307
- f"{self.main_score}_neg": result[self.metric][0],
 
 
 
 
1308
  }
1309
- return {self.main_score: result[self.metric]}
 
 
 
 
 
1310
 
1311
 
1312
  class RecallBinary(F1Binary):
@@ -3358,6 +3407,7 @@ class BinaryMaxF1(F1Binary):
3358
 
3359
  main_score = "max_f1_binary"
3360
  single_reference_per_prediction = True
 
3361
 
3362
  def compute(
3363
  self,
@@ -3366,9 +3416,9 @@ class BinaryMaxF1(F1Binary):
3366
  task_data: List[Dict],
3367
  ) -> dict:
3368
  best_thr = -1
3369
- best_f1 = -1
3370
  best_thr_neg = -1
3371
- best_f1_neg = -1
3372
  thrs = {round(fp, 3) for fp in predictions}
3373
  for thr in thrs:
3374
  new_predictions = [
@@ -3377,21 +3427,25 @@ class BinaryMaxF1(F1Binary):
3377
  ]
3378
  f1_results = super().compute(references, new_predictions, task_data)
3379
 
3380
- f1 = f1_results[self.main_score]
3381
- if f1 > best_f1:
3382
- best_f1 = f1
3383
  best_thr = thr
3384
 
3385
- f1_neg = f1_results[f"{self.main_score}_neg"]
3386
- if f1_neg > best_f1_neg:
3387
- best_f1_neg = f1_neg
3388
  best_thr_neg = thr
3389
 
3390
  return {
3391
- self.main_score: best_f1,
3392
  "best_thr_maxf1": best_thr,
3393
- f"{self.main_score}_neg": best_f1_neg,
3394
  "best_thr_maxf1_neg": best_thr_neg,
 
 
 
 
3395
  }
3396
 
3397
 
 
3
  import uuid
4
  import warnings
5
  from abc import ABC, abstractmethod
6
+ from collections import Counter, defaultdict
7
  from copy import deepcopy
8
  from dataclasses import field
9
  from statistics import mean
 
915
  if uses_subgroups
916
  else score_dict[default_subgroup_name]
917
  )
918
+ for score_name, score_dict in group_to_instance_scores[
919
+ group_name
920
+ ].items()
921
  }
922
  }
923
  }
924
+ for group_name in sorted(
925
+ group_to_instance_scores.keys()
926
+ ) # sorted for consistency
927
  ]
928
 
929
  def _set_up_group_mean_aggregation(
 
981
  return result
982
 
983
 
984
+ class JaccardIndex(InstanceMetric):
985
+ reduction_map = {"mean": ["jaccard_index"]}
986
+ main_score = "jaccard_index"
987
+ ci_scores = ["jaccard_index"]
988
+
989
+ prediction_type = "Any" # string representation is compared
990
+
991
+ def compute(
992
+ self, references: List[Any], prediction: Any, task_data: List[Dict]
993
+ ) -> dict:
994
+ if not isinstance(prediction, set):
995
+ prediction = set(prediction)
996
+ references = [set(reference) for reference in references]
997
+
998
+ result = {
999
+ self.main_score: max(
1000
+ [
1001
+ float(
1002
+ (len(reference.intersection(prediction)))
1003
+ / (
1004
+ len(reference)
1005
+ + len(prediction)
1006
+ - len(reference.intersection(prediction))
1007
+ )
1008
+ )
1009
+ for reference in references
1010
+ ]
1011
+ )
1012
+ }
1013
+ result["score"] = result[self.main_score]
1014
+ result["score_name"] = self.main_score
1015
+ return result
1016
+
1017
+
1018
  class MaxAccuracy(Accuracy):
1019
  """Calculate the maximal accuracy over all instances as the global score."""
1020
 
 
1312
  _metric = None
1313
  metric = "f1"
1314
  single_reference_per_prediction = True
1315
+ _requirements_list: List[str] = ["sklearn"]
1316
 
1317
  def prepare(self):
1318
  super().prepare()
1319
+ from sklearn import metrics
1320
+
1321
+ self._metric = metrics.precision_recall_fscore_support
1322
 
1323
  def _validate_reference(self, reference):
1324
  super()._validate_reference(reference)
 
1335
  ) -> dict:
1336
  flattened_int_references = [int(r[0]) for r in references]
1337
  int_predictions = [int(p > self.threshold) for p in predictions]
1338
+ precision, recall, f1, _ = self._metric(
1339
+ y_true=flattened_int_references,
1340
+ y_pred=int_predictions,
 
1341
  labels=[0, 1],
1342
  average=self.average,
1343
  )
1344
+ if self.average is None:
1345
  return {
1346
+ "f1_binary": f1[1],
1347
+ "f1_binary_neg": f1[0],
1348
+ "recall_binary": recall[1],
1349
+ "recall_binary_neg": recall[0],
1350
+ "precision_binary": precision[1],
1351
+ "precision_binary_neg": precision[0],
1352
  }
1353
+ return {"f1_binary": f1, "recall_binary": recall, "precision_binary": precision}
1354
+
1355
+
1356
+ class F1BinaryPosOnly(F1Binary):
1357
+ average = "binary"
1358
+ main_score = "f1_binary"
1359
 
1360
 
1361
  class RecallBinary(F1Binary):
 
3407
 
3408
  main_score = "max_f1_binary"
3409
  single_reference_per_prediction = True
3410
+ average = None
3411
 
3412
  def compute(
3413
  self,
 
3416
  task_data: List[Dict],
3417
  ) -> dict:
3418
  best_thr = -1
3419
+ best_f1 = defaultdict(lambda: -1)
3420
  best_thr_neg = -1
3421
+ best_f1_neg = defaultdict(lambda: -1)
3422
  thrs = {round(fp, 3) for fp in predictions}
3423
  for thr in thrs:
3424
  new_predictions = [
 
3427
  ]
3428
  f1_results = super().compute(references, new_predictions, task_data)
3429
 
3430
+ f1 = f1_results["f1_binary"]
3431
+ if f1 > best_f1["f1_binary"]:
3432
+ best_f1 = f1_results.copy()
3433
  best_thr = thr
3434
 
3435
+ f1_neg = f1_results["f1_binary_neg"]
3436
+ if f1_neg > best_f1_neg["f1_binary_neg"]:
3437
+ best_f1_neg = f1_results.copy()
3438
  best_thr_neg = thr
3439
 
3440
  return {
3441
+ self.main_score: best_f1["f1_binary"],
3442
  "best_thr_maxf1": best_thr,
3443
+ f"{self.main_score}_neg": best_f1_neg["f1_binary_neg"],
3444
  "best_thr_maxf1_neg": best_thr_neg,
3445
+ "recall_at_max_f1": best_f1["recall_binary"],
3446
+ "recall_at_max_f1_neg": best_f1_neg["recall_binary_neg"],
3447
+ "precision_at_max_f1": best_f1["precision_binary"],
3448
+ "precision_at_max_f1_neg": best_f1_neg["precision_binary_neg"],
3449
  }
3450
 
3451
 
operator.py CHANGED
@@ -5,7 +5,7 @@ from typing import Any, Dict, Generator, List, Optional, Union
5
 
6
  from .artifact import Artifact
7
  from .dataclass import InternalField, NonPositionalField
8
- from .stream import MultiStream, Stream
9
  from .utils import is_module_available
10
 
11
 
@@ -171,7 +171,9 @@ def instance_generator(instance):
171
 
172
 
173
  def stream_single(instance: Dict[str, Any]) -> Stream:
174
- return Stream(generator=instance_generator, gen_kwargs={"instance": instance})
 
 
175
 
176
 
177
  class MultiStreamOperator(StreamingOperator):
@@ -244,7 +246,7 @@ class SingleStreamOperator(MultiStreamOperator):
244
  def _process_single_stream(
245
  self, stream: Stream, stream_name: Optional[str] = None
246
  ) -> Stream:
247
- return Stream(
248
  self._process_stream,
249
  gen_kwargs={"stream": stream, "stream_name": stream_name},
250
  )
@@ -445,7 +447,7 @@ class InstanceOperatorWithMultiStreamAccess(StreamingOperator):
445
  result = {}
446
 
447
  for stream_name, stream in multi_stream.items():
448
- stream = Stream(
449
  self.generator,
450
  gen_kwargs={"stream": stream, "multi_stream": multi_stream},
451
  )
 
5
 
6
  from .artifact import Artifact
7
  from .dataclass import InternalField, NonPositionalField
8
+ from .stream import GeneratorStream, MultiStream, Stream
9
  from .utils import is_module_available
10
 
11
 
 
171
 
172
 
173
  def stream_single(instance: Dict[str, Any]) -> Stream:
174
+ return GeneratorStream(
175
+ generator=instance_generator, gen_kwargs={"instance": instance}
176
+ )
177
 
178
 
179
  class MultiStreamOperator(StreamingOperator):
 
246
  def _process_single_stream(
247
  self, stream: Stream, stream_name: Optional[str] = None
248
  ) -> Stream:
249
+ return GeneratorStream(
250
  self._process_stream,
251
  gen_kwargs={"stream": stream, "stream_name": stream_name},
252
  )
 
447
  result = {}
448
 
449
  for stream_name, stream in multi_stream.items():
450
+ stream = GeneratorStream(
451
  self.generator,
452
  gen_kwargs={"stream": stream, "multi_stream": multi_stream},
453
  )
operators.py CHANGED
@@ -37,7 +37,7 @@ import operator
37
  import uuid
38
  import zipfile
39
  from abc import abstractmethod
40
- from collections import Counter
41
  from copy import deepcopy
42
  from dataclasses import field
43
  from itertools import zip_longest
@@ -75,7 +75,7 @@ from .operator import (
75
  )
76
  from .random_utils import new_random_generator
77
  from .settings_utils import get_settings
78
- from .stream import Stream
79
  from .text_utils import nested_tuple_to_string
80
  from .type_utils import isoftype
81
  from .utils import flatten_dict
@@ -490,7 +490,7 @@ class Augmentor(StreamInstanceOperator):
490
 
491
  Args:
492
  augment_model_input: Whether to augment the input to the model.
493
- augment_task_input: Whether to augment the task input fields. The specific fields are defined in the FormTask operator.
494
 
495
  """
496
 
@@ -525,7 +525,7 @@ class Augmentor(StreamInstanceOperator):
525
  if self.augment_task_input:
526
  assert (
527
  len(self._task_input_fields) > 0
528
- ), "No augmentable input fields were defined in FormTask, and augmentation was requested. Specify the fields to augment in 'argumentable_inputs' attribute of the FormTask."
529
  fields = self._task_input_fields
530
  assert not self.augment_model_input
531
 
@@ -860,6 +860,51 @@ class ZipFieldValues(StreamInstanceOperator):
860
  return instance
861
 
862
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
863
  class IndexOf(StreamInstanceOperator):
864
  """For a given instance, finds the offset of value of field 'index_of', within the value of field 'search_in'."""
865
 
@@ -1560,6 +1605,52 @@ class SplitByValue(MultiStreamOperator):
1560
  return MultiStream(result)
1561
 
1562
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1563
  class ApplyStreamOperatorsField(SingleStreamOperator, ArtifactFetcherMixin):
1564
  """Applies stream operators to a stream based on specified fields in each instance.
1565
 
@@ -1668,7 +1759,7 @@ class MergeStreams(MultiStreamOperator):
1668
  add_origin_stream_name: bool = True
1669
  origin_stream_name_field_name: str = "origin"
1670
 
1671
- def merge(self, multi_stream):
1672
  for stream_name, stream in multi_stream.items():
1673
  if self.streams_to_merge is None or stream_name in self.streams_to_merge:
1674
  for instance in stream:
@@ -1679,7 +1770,7 @@ class MergeStreams(MultiStreamOperator):
1679
  def process(self, multi_stream: MultiStream) -> MultiStream:
1680
  return MultiStream(
1681
  {
1682
- self.new_stream_name: Stream(
1683
  self.merge, gen_kwargs={"multi_stream": multi_stream}
1684
  )
1685
  }
 
37
  import uuid
38
  import zipfile
39
  from abc import abstractmethod
40
+ from collections import Counter, defaultdict
41
  from copy import deepcopy
42
  from dataclasses import field
43
  from itertools import zip_longest
 
75
  )
76
  from .random_utils import new_random_generator
77
  from .settings_utils import get_settings
78
+ from .stream import GeneratorStream, Stream
79
  from .text_utils import nested_tuple_to_string
80
  from .type_utils import isoftype
81
  from .utils import flatten_dict
 
490
 
491
  Args:
492
  augment_model_input: Whether to augment the input to the model.
493
+ augment_task_input: Whether to augment the task input fields. The specific fields are defined in the Task operator.
494
 
495
  """
496
 
 
525
  if self.augment_task_input:
526
  assert (
527
  len(self._task_input_fields) > 0
528
+ ), "No augmentable input fields were defined in Task, and augmentation was requested. Specify the fields to augment in 'argumentable_inputs' attribute of the Task."
529
  fields = self._task_input_fields
530
  assert not self.augment_model_input
531
 
 
860
  return instance
861
 
862
 
863
+ class InterleaveListsToDialogOperator(StreamInstanceOperator):
864
+ """Interleaves two lists, one of user dialog turns and one of assistant dialog turns, into a single list of tuples, alternating between "user" and "assistant".
865
+
866
+ The list of tuples if of format (role, turn_content), where the role label is specified by
867
+ the 'user_role_label' and 'assistant_role_label' fields (default to "user" and "assistant").
868
+
869
+ The user turns and assistant turns field are specified in the arguments.
870
+ The value of each of the 'fields' is assumed to be a list.
871
+
872
+ """
873
+
874
+ user_turns_field: str
875
+ assistant_turns_field: str
876
+ user_role_label: str = "user"
877
+ assistant_role_label: str = "assistant"
878
+ to_field: str
879
+
880
+ def process(
881
+ self, instance: Dict[str, Any], stream_name: Optional[str] = None
882
+ ) -> Dict[str, Any]:
883
+ user_turns = instance[self.user_turns_field]
884
+ assistant_turns = instance[self.assistant_turns_field]
885
+
886
+ assert (
887
+ len(user_turns) == len(assistant_turns)
888
+ or (len(user_turns) - len(assistant_turns) == 1)
889
+ ), "user_turns must have either the same length as assistant_turns or one more turn."
890
+
891
+ interleaved_dialog = []
892
+ i, j = 0, 0 # Indices for the user and assistant lists
893
+ # While either list has elements left, continue interleaving
894
+ while i < len(user_turns) or j < len(assistant_turns):
895
+ if i < len(user_turns):
896
+ interleaved_dialog.append((self.user_role_label, user_turns[i]))
897
+ i += 1
898
+ if j < len(assistant_turns):
899
+ interleaved_dialog.append(
900
+ (self.assistant_role_label, assistant_turns[j])
901
+ )
902
+ j += 1
903
+
904
+ instance[self.to_field] = interleaved_dialog
905
+ return instance
906
+
907
+
908
  class IndexOf(StreamInstanceOperator):
909
  """For a given instance, finds the offset of value of field 'index_of', within the value of field 'search_in'."""
910
 
 
1605
  return MultiStream(result)
1606
 
1607
 
1608
+ class SplitByNestedGroup(MultiStreamOperator):
1609
+ """Splits a MultiStream that is small - for metrics, hence: whole stream can sit in memory, split by the value of field 'group'.
1610
+
1611
+ Args:
1612
+ number_of_fusion_generations: int
1613
+
1614
+ the value in field group is of the form "sourcen/sourcenminus1/..." describing the sources in which the instance sat
1615
+ when these were fused, potentially several phases of fusion. the name of the most recent source sits first in this value.
1616
+ (See BaseFusion and its extensions)
1617
+ number_of_fuaion_generations specifies the length of the prefix by which to split the stream.
1618
+ E.g. for number_of_fusion_generations = 1, only the most recent fusion in creating this multi_stream, affects the splitting.
1619
+ For number_of_fusion_generations = -1, take the whole history written in this field, ignoring number of generations.
1620
+ """
1621
+
1622
+ field_name_of_group: str = "group"
1623
+ number_of_fusion_generations: int = 1
1624
+
1625
+ def process(self, multi_stream: MultiStream) -> MultiStream:
1626
+ result = defaultdict(list)
1627
+
1628
+ for stream_name, stream in multi_stream.items():
1629
+ for instance in stream:
1630
+ if self.field_name_of_group not in instance:
1631
+ raise ValueError(
1632
+ f"Field {self.field_name_of_group} is missing from instance {instance}"
1633
+ )
1634
+ signature = (
1635
+ stream_name
1636
+ + "~" # a sign that does not show within group values
1637
+ + (
1638
+ "/".join(
1639
+ instance[self.field_name_of_group].split("/")[
1640
+ : self.number_of_fusion_generations
1641
+ ]
1642
+ )
1643
+ if self.number_of_fusion_generations >= 0
1644
+ # for values with a smaller number of generations - take up to their last generation
1645
+ else instance[self.field_name_of_group]
1646
+ # for each instance - take all its generations
1647
+ )
1648
+ )
1649
+ result[signature].append(instance)
1650
+
1651
+ return MultiStream.from_iterables(result)
1652
+
1653
+
1654
  class ApplyStreamOperatorsField(SingleStreamOperator, ArtifactFetcherMixin):
1655
  """Applies stream operators to a stream based on specified fields in each instance.
1656
 
 
1759
  add_origin_stream_name: bool = True
1760
  origin_stream_name_field_name: str = "origin"
1761
 
1762
+ def merge(self, multi_stream) -> Generator:
1763
  for stream_name, stream in multi_stream.items():
1764
  if self.streams_to_merge is None or stream_name in self.streams_to_merge:
1765
  for instance in stream:
 
1770
  def process(self, multi_stream: MultiStream) -> MultiStream:
1771
  return MultiStream(
1772
  {
1773
+ self.new_stream_name: GeneratorStream(
1774
  self.merge, gen_kwargs={"multi_stream": multi_stream}
1775
  )
1776
  }
processors.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import json
2
  import re
3
  from difflib import get_close_matches
@@ -54,14 +55,6 @@ class ExtractWithRegex(RegexParser):
54
  return ""
55
 
56
 
57
- class LoadJson(FieldOperator):
58
- def process_value(self, text: Any) -> Any:
59
- try:
60
- return json.loads(text)
61
- except json.JSONDecodeError:
62
- return []
63
-
64
-
65
  class ListToEmptyEntitiesTuples(FieldOperator):
66
  def process_value(self, lst: Any) -> Any:
67
  try:
@@ -225,10 +218,30 @@ class StringOrNotString(FieldOperator):
225
  return text
226
 
227
 
228
- class ExtractMtBenchJudgment(FieldOperator):
229
  def process_value(self, text: Any) -> Any:
230
  match = re.search(r"\[\[([\d]+\.?[\d]*)\]\]", text)
231
  try:
232
  return float(match.group(1)) / 10
233
  except:
234
  return 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
  import json
3
  import re
4
  from difflib import get_close_matches
 
55
  return ""
56
 
57
 
 
 
 
 
 
 
 
 
58
  class ListToEmptyEntitiesTuples(FieldOperator):
59
  def process_value(self, lst: Any) -> Any:
60
  try:
 
218
  return text
219
 
220
 
221
+ class ExtractMtBenchRatingJudgment(FieldOperator):
222
  def process_value(self, text: Any) -> Any:
223
  match = re.search(r"\[\[([\d]+\.?[\d]*)\]\]", text)
224
  try:
225
  return float(match.group(1)) / 10
226
  except:
227
  return 0.0
228
+
229
+
230
+ class ExtractMtBenchLabelJudgment(FieldOperator):
231
+ def process_value(self, text: Any) -> Any:
232
+ match = re.search(r"\[\[([^\]]+)\]\]", text)
233
+ try:
234
+ return str(match.group(1))
235
+ except:
236
+ return "None"
237
+
238
+
239
+ class LiteralEval(FieldOperator):
240
+ def process_value(self, text: Any) -> Any:
241
+ if text is not None and not isinstance(text, str):
242
+ raise ValueError(
243
+ f"LiteralEval: field '{self.field}' is expected to be of 'str' input type, got: {type(text)}"
244
+ )
245
+ if text is None or text == "":
246
+ return text
247
+ return ast.literal_eval(text.strip())
schema.py CHANGED
@@ -34,16 +34,24 @@ class ToUnitxtGroup(StreamInstanceOperatorValidator):
34
  postprocessors: List[str] = field(default_factory=lambda: ["to_string_stripped"])
35
  remove_unnecessary_fields: bool = True
36
 
37
- def _to_lists_of_keys_and_values(self, dict: Dict[str, str]):
38
- return {
39
- "key": [key for key, _ in dict.items()],
40
- "value": [str(value) for _, value in dict.items()],
41
- }
42
 
43
  def process(
44
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
45
  ) -> Dict[str, Any]:
46
- task_data = {**instance["inputs"], **instance["outputs"]}
 
 
 
 
 
 
 
 
47
  instance["task_data"] = json.dumps(task_data)
48
 
49
  if self.remove_unnecessary_fields:
 
34
  postprocessors: List[str] = field(default_factory=lambda: ["to_string_stripped"])
35
  remove_unnecessary_fields: bool = True
36
 
37
+ @staticmethod
38
+ def artifact_to_jsonable(artifact):
39
+ if artifact.__id__ is None:
40
+ return artifact.to_dict()
41
+ return artifact.__id__
42
 
43
  def process(
44
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
45
  ) -> Dict[str, Any]:
46
+ task_data = {
47
+ **instance["inputs"],
48
+ **instance["outputs"],
49
+ "metadata": {
50
+ "template": self.artifact_to_jsonable(
51
+ instance["recipe_metadata"]["template"]
52
+ )
53
+ },
54
+ }
55
  instance["task_data"] = json.dumps(task_data)
56
 
57
  if self.remove_unnecessary_fields:
settings_utils.py CHANGED
@@ -126,13 +126,14 @@ if Settings.is_uninitilized():
126
  settings.max_log_message_size = (int, 100000)
127
  settings.artifactories = None
128
  settings.default_recipe = "standard_recipe"
129
- settings.default_verbosity = "debug"
130
  settings.remote_metrics = []
131
  settings.allow_passing_data_to_remote_api = (bool, False)
132
  settings.test_card_disable = (bool, False)
133
  settings.test_metric_disable = (bool, False)
134
  settings.metrics_master_key_token = None
135
  settings.seed = (int, 42)
 
136
 
137
  if Constants.is_uninitilized():
138
  constants = Constants()
 
126
  settings.max_log_message_size = (int, 100000)
127
  settings.artifactories = None
128
  settings.default_recipe = "standard_recipe"
129
+ settings.default_verbosity = "info"
130
  settings.remote_metrics = []
131
  settings.allow_passing_data_to_remote_api = (bool, False)
132
  settings.test_card_disable = (bool, False)
133
  settings.test_metric_disable = (bool, False)
134
  settings.metrics_master_key_token = None
135
  settings.seed = (int, 42)
136
+ settings.skip_artifacts_prepare_and_verify = (bool, False)
137
 
138
  if Constants.is_uninitilized():
139
  constants = Constants()
splitters.py CHANGED
@@ -196,9 +196,12 @@ class DiverseLabelsSampler(Sampler):
196
  raise ValueError(f"'{self.choices}' field is missing from '{inputs}'.")
197
  choices = inputs[self.choices]
198
  if not isinstance(choices, list):
199
- raise ValueError(
200
- f"Unexpected input choices value '{choices}'. Expected a list."
201
- )
 
 
 
202
 
203
  if "outputs" not in exemplar:
204
  raise ValueError(f"'outputs' field is missing from '{exemplar}'.")
 
196
  raise ValueError(f"'{self.choices}' field is missing from '{inputs}'.")
197
  choices = inputs[self.choices]
198
  if not isinstance(choices, list):
199
+ if isinstance(choices, str):
200
+ choices = [choices]
201
+ else:
202
+ raise ValueError(
203
+ f"Unexpected input choices value '{choices}'. Expected a list or a string."
204
+ )
205
 
206
  if "outputs" not in exemplar:
207
  raise ValueError(f"'outputs' field is missing from '{exemplar}'.")
standard.py CHANGED
@@ -135,6 +135,7 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
135
  self.metadata,
136
  self.standardization,
137
  self.processing,
 
138
  self.verblization,
139
  self.finalize,
140
  ]
@@ -144,6 +145,7 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
144
  self.inference_instance.steps = [
145
  self.metadata,
146
  self.processing,
 
147
  ]
148
 
149
  self.inference_demos = SourceSequentialOperator()
@@ -153,6 +155,7 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
153
  self.metadata,
154
  self.standardization,
155
  self.processing,
 
156
  ]
157
 
158
  self.inference = SequentialOperator()
 
135
  self.metadata,
136
  self.standardization,
137
  self.processing,
138
+ self.metadata,
139
  self.verblization,
140
  self.finalize,
141
  ]
 
145
  self.inference_instance.steps = [
146
  self.metadata,
147
  self.processing,
148
+ self.metadata,
149
  ]
150
 
151
  self.inference_demos = SourceSequentialOperator()
 
155
  self.metadata,
156
  self.standardization,
157
  self.processing,
158
+ self.metadata,
159
  ]
160
 
161
  self.inference = SequentialOperator()
stream.py CHANGED
@@ -1,5 +1,6 @@
1
  import tempfile
2
- from typing import Dict, Iterable
 
3
 
4
  from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
5
 
@@ -8,6 +9,36 @@ from .generator_utils import CopyingReusableGenerator, ReusableGenerator
8
 
9
 
10
  class Stream(Dataclass):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  """A class for handling streaming data in a customizable way.
12
 
13
  This class provides methods for generating, caching, and manipulating streaming data.
@@ -18,8 +49,8 @@ class Stream(Dataclass):
18
  caching (bool): Whether the data is cached or not. :no-index:
19
  """
20
 
21
- generator: callable
22
- gen_kwargs: Dict[str, any] = OptionalField(default_factory=dict)
23
  caching: bool = False
24
  copying: bool = False
25
 
@@ -147,7 +178,7 @@ class MultiStream(dict):
147
  assert all(isinstance(v, ReusableGenerator) for v in generators.values())
148
  return cls(
149
  {
150
- key: Stream(
151
  generator.generator,
152
  gen_kwargs=generator.gen_kwargs,
153
  caching=caching,
@@ -173,7 +204,7 @@ class MultiStream(dict):
173
  """
174
  return cls(
175
  {
176
- key: Stream(
177
  iterable.__iter__,
178
  caching=caching,
179
  copying=copying,
 
1
  import tempfile
2
+ from abc import abstractmethod
3
+ from typing import Any, Callable, Dict, Iterable, List
4
 
5
  from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
6
 
 
9
 
10
 
11
  class Stream(Dataclass):
12
+ @abstractmethod
13
+ def __iter__(self):
14
+ pass
15
+
16
+ @abstractmethod
17
+ def peek(self):
18
+ pass
19
+
20
+ @abstractmethod
21
+ def take(self, n):
22
+ pass
23
+
24
+
25
+ class ListStream(Stream):
26
+ instances_list: List[Dict[str, Any]]
27
+
28
+ def __iter__(self):
29
+ return iter(self.instances_list)
30
+
31
+ def peek(self):
32
+ return next(iter(self.instances_list))
33
+
34
+ def take(self, n):
35
+ for i, instance in enumerate(self.instances_list):
36
+ if i >= n:
37
+ break
38
+ yield instance
39
+
40
+
41
+ class GeneratorStream(Stream):
42
  """A class for handling streaming data in a customizable way.
43
 
44
  This class provides methods for generating, caching, and manipulating streaming data.
 
49
  caching (bool): Whether the data is cached or not. :no-index:
50
  """
51
 
52
+ generator: Callable
53
+ gen_kwargs: Dict[str, Any] = OptionalField(default_factory=dict)
54
  caching: bool = False
55
  copying: bool = False
56
 
 
178
  assert all(isinstance(v, ReusableGenerator) for v in generators.values())
179
  return cls(
180
  {
181
+ key: GeneratorStream(
182
  generator.generator,
183
  gen_kwargs=generator.gen_kwargs,
184
  caching=caching,
 
204
  """
205
  return cls(
206
  {
207
+ key: GeneratorStream(
208
  iterable.__iter__,
209
  caching=caching,
210
  copying=copying,
string_operators.py CHANGED
@@ -1,7 +1,12 @@
1
  import re
2
- from typing import List
 
 
 
 
 
3
 
4
- from .operators import FieldOperator
5
 
6
 
7
  class Split(FieldOperator):
@@ -39,6 +44,17 @@ class Join(FieldOperator):
39
  return self.by.join(value)
40
 
41
 
 
 
 
 
 
 
 
 
 
 
 
42
  class Strip(FieldOperator):
43
  def process_value(self, value: str) -> str:
44
  return value.strip()
 
1
  import re
2
+ from typing import (
3
+ Any,
4
+ Dict,
5
+ List,
6
+ Optional,
7
+ )
8
 
9
+ from .operators import FieldOperator, StreamInstanceOperator
10
 
11
 
12
  class Split(FieldOperator):
 
44
  return self.by.join(value)
45
 
46
 
47
+ class FormatText(StreamInstanceOperator):
48
+ to_field: str
49
+ text: str
50
+
51
+ def process(
52
+ self, instance: Dict[str, Any], stream_name: Optional[str] = None
53
+ ) -> Dict[str, Any]:
54
+ instance[self.to_field] = self.text.format(**instance)
55
+ return instance
56
+
57
+
58
  class Strip(FieldOperator):
59
  def process_value(self, value: str) -> str:
60
  return value.strip()
struct_data_operators.py CHANGED
@@ -547,3 +547,22 @@ class ShuffleTableColumns(FieldOperator):
547
  table_content["rows"] = shuffled_rows
548
 
549
  return table_content
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
547
  table_content["rows"] = shuffled_rows
548
 
549
  return table_content
550
+
551
+
552
+ class LoadJson(FieldOperator):
553
+ failure_value: Any = None
554
+ allow_failure: bool = False
555
+
556
+ def process_value(self, value: str) -> Any:
557
+ if self.allow_failure:
558
+ try:
559
+ return json.loads(value)
560
+ except json.JSONDecodeError:
561
+ return self.failure_value
562
+ else:
563
+ return json.loads(value)
564
+
565
+
566
+ class DumpJson(FieldOperator):
567
+ def process_value(self, value: str) -> str:
568
+ return json.dumps(value)
task.py CHANGED
@@ -13,11 +13,7 @@ from .type_utils import (
13
  )
14
 
15
 
16
- class Tasker:
17
- pass
18
-
19
-
20
- class FormTask(Tasker, StreamInstanceOperator):
21
  """FormTask packs the different instance fields into dictionaries by their roles in the task.
22
 
23
  Attributes:
@@ -119,49 +115,5 @@ class FormTask(Tasker, StreamInstanceOperator):
119
  }
120
 
121
 
122
- class MultipleChoiceTask(FormTask):
123
- choices_field: str = "choices"
124
- choices_separator: str = "\n"
125
- enumeration_suffix: str = ". "
126
- use_text_in_target: bool = False
127
- alphabet: str = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
128
-
129
- def process_single_choice(
130
- self, choice: str, index: int, use_text: bool = True
131
- ) -> str:
132
- try:
133
- processed_choice = f"{self.alphabet[index]}"
134
- except IndexError as e:
135
- raise ValueError(
136
- f"Too many choices, the length of alphabet '{self.alphabet}': {len(self.alphabet)} is the limit"
137
- ) from e
138
- if use_text:
139
- processed_choice += f"{self.enumeration_suffix}{choice}"
140
- return processed_choice
141
-
142
- def process_choices(self, choices: List[str]) -> str:
143
- processed_choices = []
144
- for index, choice in enumerate(choices):
145
- processed_choices.append(self.process_single_choice(choice, index))
146
- return self.choices_separator.join(processed_choices)
147
-
148
- def process_target(self, choices, target_index):
149
- return self.process_single_choice(
150
- choices[target_index], target_index, use_text=self.use_text_in_target
151
- )
152
-
153
- def process(
154
- self, instance: Dict[str, Any], stream_name: Optional[str] = None
155
- ) -> Dict[str, Any]:
156
- result = super().process(instance, stream_name)
157
- target_key, target_value = next(iter(result["outputs"].items()))
158
- choices = result["inputs"][self.choices_field]
159
- target_index_in_choices = choices.index(target_value)
160
-
161
- processed_choices = self.process_choices(choices)
162
- processed_target = self.process_target(choices, target_index_in_choices)
163
-
164
- result["inputs"][self.choices_field] = processed_choices
165
- result["outputs"][target_key] = processed_target
166
-
167
- return result
 
13
  )
14
 
15
 
16
+ class Task(StreamInstanceOperator):
 
 
 
 
17
  """FormTask packs the different instance fields into dictionaries by their roles in the task.
18
 
19
  Attributes:
 
115
  }
116
 
117
 
118
+ class FormTask(Task):
119
+ pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
templates.py CHANGED
@@ -1,7 +1,9 @@
1
  import json
2
  from abc import abstractmethod
 
3
  from typing import Any, Dict, List, Optional, Tuple, Union
4
 
 
5
  from .collections import ListCollection
6
  from .dataclass import NonPositionalField
7
  from .operator import StreamInstanceOperator
@@ -48,6 +50,11 @@ class Template(StreamInstanceOperator):
48
  )
49
  return instruction, target_prefix
50
 
 
 
 
 
 
51
  def process(
52
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
53
  ) -> Dict[str, Any]:
@@ -61,9 +68,9 @@ class Template(StreamInstanceOperator):
61
 
62
  inputs = instance.get("inputs")
63
  outputs = instance.get("outputs")
 
64
 
65
  self.set_titles(inputs)
66
-
67
  source = self.inputs_to_source(inputs)
68
  instruction, target_prefix = self.inputs_to_instruction_and_target_prefix(
69
  inputs
@@ -150,6 +157,135 @@ class InputOutputTemplateWithCustomTarget(InputOutputTemplate):
150
  return target, [reference]
151
 
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  class MultipleChoiceTemplate(Template):
154
  """Formats the input (that specifies the question), the multiple choices to select the answer from, and specifies the field with the correct answer."""
155
 
@@ -328,25 +464,20 @@ class YesNoTemplate(Template):
328
  raise RuntimeError(
329
  f"Available outputs are {list(outputs.keys())}, missing required label field: '{self.label_field}'."
330
  ) from e
331
- if not isinstance(gold_class_names, list) or not gold_class_names:
332
  raise RuntimeError(
333
- f"Unexpected value for gold_class_names: '{gold_class_names}'. Expected a non-empty list."
334
  )
335
  try:
336
- queried_class_names = outputs[self.class_field]
337
  except KeyError as e:
338
  raise RuntimeError(
339
  f"Available outputs are {list(outputs.keys())}, missing required class field: '{self.class_field}'."
340
  ) from e
341
- if (
342
- not queried_class_names
343
- or not isinstance(queried_class_names, list)
344
- or not len(queried_class_names) == 1
345
- ):
346
  raise RuntimeError(
347
- f"Unexpected value for queried_class_names: '{queried_class_names}'. Expected a list with one item."
348
  )
349
- queried_class_name = queried_class_names[0]
350
  if queried_class_name in gold_class_names:
351
  return self.yes_answer, [self.yes_answer]
352
  return self.no_answer, [self.no_answer]
 
1
  import json
2
  from abc import abstractmethod
3
+ from random import random
4
  from typing import Any, Dict, List, Optional, Tuple, Union
5
 
6
+ from .artifact import Artifact
7
  from .collections import ListCollection
8
  from .dataclass import NonPositionalField
9
  from .operator import StreamInstanceOperator
 
50
  )
51
  return instruction, target_prefix
52
 
53
+ def preprocess_inputs_and_outputs(
54
+ self, inputs: Dict[str, Any], outputs: Dict[str, Any]
55
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
56
+ return inputs, outputs
57
+
58
  def process(
59
  self, instance: Dict[str, Any], stream_name: Optional[str] = None
60
  ) -> Dict[str, Any]:
 
68
 
69
  inputs = instance.get("inputs")
70
  outputs = instance.get("outputs")
71
+ inputs, outputs = self.preprocess_inputs_and_outputs(inputs, outputs)
72
 
73
  self.set_titles(inputs)
 
74
  source = self.inputs_to_source(inputs)
75
  instruction, target_prefix = self.inputs_to_instruction_and_target_prefix(
76
  inputs
 
157
  return target, [reference]
158
 
159
 
160
+ class PairwiseChoiceTemplate(InputOutputTemplate):
161
+ """PairwiseChoiceTemplate.
162
+
163
+ Requirements:
164
+ The answer field value should be of type Literal["choice_a", "choice_b", "tie"]
165
+
166
+ Args:
167
+ choice_a_field (str): The field which contains choice_a value
168
+ choice_b_field (str): The field which contains choice_b value
169
+ answer_field (str): The field which contains the answer value.
170
+ Should be of type Literal["choice_1", "choice_2", "tie"]
171
+ choice_a_label (str): The label of choice A answer as it is verbalized in the template.
172
+ choice_b_label (str): The label of choice B answer as it is verbalized in the template.
173
+ choice_tie_label (str): The label of a tie answer as it should be verbalized in the template.
174
+ shuffle (bool): whether to shuffle the choices or not. This is done to take into account position bias.
175
+
176
+ shuffle: 50% of the time:
177
+ 1) The values of choice_a_field and choice_b_field will be swapped.
178
+ 2) If the values of answer_field is choice_a_label, set it to choice_b_label.
179
+ Else if the values of answer_field is choice_b_label, set it to choice_a_label.
180
+ Else if the value of answer_field is choice_tie_label, do nothing.
181
+
182
+ """
183
+
184
+ choice_a_field: str
185
+ choice_b_field: str
186
+ answer_field: str
187
+ choice_a_label: str
188
+ choice_b_label: str
189
+ choice_tie_label: str
190
+ shuffle: bool
191
+
192
+ def verbalize_answer_field(self, outputs: Dict[str, object]):
193
+ answer = outputs[self.answer_field]
194
+ assert answer in ["choice_a", "choice_b", "tie"]
195
+ if answer == "choice_a":
196
+ outputs[self.answer_field] = self.choice_a_label
197
+ elif answer == "choice_b":
198
+ outputs[self.answer_field] = self.choice_b_label
199
+ else:
200
+ outputs[self.answer_field] = self.choice_tie_label
201
+
202
+ return outputs
203
+
204
+ def shuffle_values(self, inputs: Dict[str, object], outputs: Dict[str, object]):
205
+ outcome = random() # A float between 0 and 1
206
+ if outcome <= 0.5:
207
+ choice_a_value = inputs[self.choice_a_field]
208
+ choice_b_value = inputs[self.choice_b_field]
209
+
210
+ inputs[self.choice_a_field] = choice_a_value
211
+ inputs[self.choice_b_field] = choice_b_value
212
+
213
+ answer = outputs[self.answer_field]
214
+ assert answer in [
215
+ self.choice_a_label,
216
+ self.choice_b_label,
217
+ self.choice_tie_label,
218
+ ]
219
+ if answer == self.choice_a_label:
220
+ outputs[self.answer_field] = self.choice_b_label
221
+ elif answer == self.choice_b_label:
222
+ outputs[self.answer_field] = self.choice_a_label
223
+
224
+ return inputs, outputs
225
+
226
+ def preprocess_inputs_and_outputs(
227
+ self, inputs: Dict[str, Any], outputs: Dict[str, Any]
228
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
229
+ outputs = self.verbalize_answer_field(outputs)
230
+ inputs, outputs = self.shuffle_values(inputs, outputs)
231
+ return inputs, outputs
232
+
233
+
234
+ class DialogFieldsData(Artifact):
235
+ user_role_label: str
236
+ assistant_role_label: str
237
+ system_role_label: str
238
+ dialog_field: str
239
+
240
+
241
+ class DialogTemplate(InputOutputTemplate):
242
+ dialog_fields: List[DialogFieldsData]
243
+ turns_separator: str = "\n\n"
244
+ label_separator: str = " "
245
+
246
+ def process_dialog(self, inputs: Dict[str, object]):
247
+ for dialog_fields in self.dialog_fields:
248
+ dialog = inputs[dialog_fields.dialog_field]
249
+ # TODO: update isoftype method to support Literal verification and check
250
+ # it's List[Tuple[Literal["user", "assistant", "system"], str]] (Issue #799)
251
+ assert isoftype(dialog, List[Tuple[str, str]])
252
+
253
+ user_role_label = dialog_fields.user_role_label
254
+ assistant_role_label = dialog_fields.assistant_role_label
255
+ system_role_label = dialog_fields.system_role_label
256
+
257
+ dialog_str = ""
258
+ for i, turn in enumerate(dialog):
259
+ (turn_type, turn_text) = turn
260
+ turns_separator = "" if i == 0 else self.turns_separator
261
+ if turn_type == "user":
262
+ dialog_str += f"{turns_separator}{user_role_label}{self.label_separator}{turn_text}"
263
+ elif turn_type == "assistant":
264
+ dialog_str += f"{turns_separator}{assistant_role_label}{self.label_separator}{turn_text}"
265
+ elif turn_type == "system":
266
+ dialog_str += f"{turns_separator}{system_role_label}{self.label_separator}{turn_text}"
267
+
268
+ inputs[dialog_fields.dialog_field] = dialog_str
269
+ return inputs
270
+
271
+ def preprocess_inputs_and_outputs(
272
+ self, inputs: Dict[str, Any], outputs: Dict[str, Any]
273
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
274
+ return self.process_dialog(inputs), outputs
275
+
276
+
277
+ class DialogPairwiseChoiceTemplate(DialogTemplate, PairwiseChoiceTemplate):
278
+ def preprocess_inputs_and_outputs(
279
+ self, inputs: Dict[str, Any], outputs: Dict[str, Any]
280
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
281
+ inputs, outputs = DialogTemplate.preprocess_inputs_and_outputs(
282
+ self, inputs, outputs
283
+ )
284
+ return PairwiseChoiceTemplate.preprocess_inputs_and_outputs(
285
+ self, inputs, outputs
286
+ )
287
+
288
+
289
  class MultipleChoiceTemplate(Template):
290
  """Formats the input (that specifies the question), the multiple choices to select the answer from, and specifies the field with the correct answer."""
291
 
 
464
  raise RuntimeError(
465
  f"Available outputs are {list(outputs.keys())}, missing required label field: '{self.label_field}'."
466
  ) from e
467
+ if not isinstance(gold_class_names, list):
468
  raise RuntimeError(
469
+ f"Unexpected value for gold_class_names: '{gold_class_names}'. Expecting a list."
470
  )
471
  try:
472
+ queried_class_name = outputs[self.class_field]
473
  except KeyError as e:
474
  raise RuntimeError(
475
  f"Available outputs are {list(outputs.keys())}, missing required class field: '{self.class_field}'."
476
  ) from e
477
+ if not queried_class_name or not isinstance(queried_class_name, str):
 
 
 
 
478
  raise RuntimeError(
479
+ f"Unexpected value for queried_class_names: '{queried_class_name}'. Expected a string."
480
  )
 
481
  if queried_class_name in gold_class_names:
482
  return self.yes_answer, [self.yes_answer]
483
  return self.no_answer, [self.no_answer]
text_utils.py CHANGED
@@ -1,5 +1,6 @@
1
  import re
2
  import shutil
 
3
 
4
  from .logging_utils import get_logger
5
 
@@ -129,3 +130,42 @@ def nested_tuple_to_string(nested_tuple: tuple) -> str:
129
  def is_made_of_sub_strings(string, sub_strings):
130
  pattern = "^(" + "|".join(map(re.escape, sub_strings)) + ")+$"
131
  return bool(re.match(pattern, string))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import re
2
  import shutil
3
+ from typing import List, Tuple
4
 
5
  from .logging_utils import get_logger
6
 
 
130
  def is_made_of_sub_strings(string, sub_strings):
131
  pattern = "^(" + "|".join(map(re.escape, sub_strings)) + ")+$"
132
  return bool(re.match(pattern, string))
133
+
134
+
135
+ # Giveמ all the lines of a file, e.g. all the lines of prepare/cards/cohere_for_ai.py,
136
+ # and an object name, e.g. TaskCard,
137
+ # return the ordinal number of the line that starts that object, in our example: the
138
+ # line number of the following line (notice that the line where TaskCard is imported
139
+ # is not supposed to return):
140
+ # card = TaskCard(
141
+ # and the line number of the line that ends the object, in our case the line that include
142
+ # the matching close:
143
+ # )
144
+ # This util depends on ruff to ensure this setting of the card file: that a close of one
145
+ # tag and the open of the next tag, do not sit in same line, both tags being
146
+ # major level within TaskCard
147
+ # flake8: noqa: B007
148
+ def lines_defining_obj(
149
+ all_lines: List[str], obj_name: str, start_search_at_line: int = 0
150
+ ) -> Tuple[int, int]:
151
+ for starting_line in range(start_search_at_line, len(all_lines)):
152
+ line = all_lines[starting_line]
153
+ if obj_name in line:
154
+ break
155
+ if obj_name not in line:
156
+ # obj_name found no where in the input lines
157
+ return (-1, -1)
158
+ num_of_opens = 0
159
+ num_of_closes = 0
160
+ for ending_line in range(starting_line, len(all_lines)):
161
+ num_of_opens += len(re.findall(r"[({[]", all_lines[ending_line]))
162
+ num_of_closes += len(re.findall(r"[)}\]]", all_lines[ending_line]))
163
+ if num_of_closes == num_of_opens:
164
+ break
165
+
166
+ if num_of_closes != num_of_opens:
167
+ raise ValueError(
168
+ "input lines were exhausted before the matching close is found"
169
+ )
170
+
171
+ return (starting_line, ending_line)
version.py CHANGED
@@ -1 +1 @@
1
- version = "1.8.1"
 
1
+ version = "1.9.0"