Elron committed
Commit d443ad5 · verified · 1 Parent(s): cea5047

Upload folder using huggingface_hub

Files changed (12):
  1. augmentors.py +3 -6
  2. image_operators.py +12 -0
  3. inference.py +1381 -426
  4. llm_as_judge.py +14 -2
  5. loaders.py +9 -9
  6. metrics.py +7 -0
  7. operators.py +15 -8
  8. settings_utils.py +1 -1
  9. standard.py +6 -9
  10. task.py +23 -19
  11. text_utils.py +2 -1
  12. version.py +1 -1
augmentors.py CHANGED

@@ -49,7 +49,7 @@ class TextAugmentor(TypeDependentAugmentor):
     augmented_type = Text


- class NullAugmentor(Augmentor):
+ class NullAugmentor(TaskInputsAugmentor):
     """Does not change the input string."""

     def process_value(self, value: Any) -> Any:

@@ -83,12 +83,9 @@ class AugmentPrefixSuffix(TextAugmentor):
    r"""Augments the input by prepending and appending randomly selected (typically, whitespace) patterns.

    Args:
-         prefixes, suffixes (list or dict) : the potential (typically, whitespace) patterns to select from.
-             The dictionary version allows the specification relative weights for the different patterns.
+         prefixes, suffixes (list or dict) : the potential patterns (typically, whitespace) to select from. The dictionary version allows the specification relative weights for the different patterns.
        prefix_len, suffix_len (positive int) : The added prefix or suffix will be of a certain length.
-         remove_existing_whitespaces : Clean any existing leading and trailing whitespaces.
-             The strings made of repetitions of the selected pattern(s) are then prepended and/or appended to the potentially
-             trimmed input.
+         remove_existing_whitespaces : Clean any existing leading and trailing whitespaces. The strings made of repetitions of the selected pattern(s) are then prepended and/or appended to the potentially trimmed input.
        If only either just prefixes or just suffixes are needed, set the other to None.

    Examples:
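
Not part of the commit: a minimal sketch of how the AugmentPrefixSuffix arguments documented in the hunk above might be passed. The import path is an assumption about how the package exposes the class.

# Hedged sketch only; the unitxt.augmentors import path is assumed, not taken from the diff.
from unitxt.augmentors import AugmentPrefixSuffix

augmentor = AugmentPrefixSuffix(
    prefixes={" ": 10, "\t": 2},       # dict form: pattern -> relative weight
    suffixes=None,                      # only prefixes are needed, so suffixes is set to None
    prefix_len=3,                       # the selected pattern is repeated to this length
    remove_existing_whitespaces=True,   # trim the input before prepending the generated prefix
)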
image_operators.py CHANGED

@@ -93,6 +93,18 @@ def extract_images(text, instance):
     return images


+ class EncodeImageToString(FieldOperator):
+     image_format: str = "JPEG"
+
+     def encode_image_to_base64(self, image):
+         buffer = io.BytesIO()
+         image.save(buffer, format=self.image_format)
+         return base64.b64encode(buffer.getvalue()).decode("utf-8")
+
+     def process_value(self, value: Any) -> Any:
+         return {"image": self.encode_image_to_base64(value)}
+
+
 class DecodeImage(FieldOperator, PillowMixin):
     def process_value(self, value: str) -> Any:
         image_data = base64.b64decode(value)
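
Not part of the commit: a minimal usage sketch for the new EncodeImageToString operator shown above, assuming the module is importable as unitxt.image_operators and that process_value can be called directly on an instance.

# Hedged sketch: encode a PIL image to the {"image": <base64 string>} form produced by EncodeImageToString.
from PIL import Image
from unitxt.image_operators import EncodeImageToString  # assumed import path

encoder = EncodeImageToString(image_format="PNG")
image = Image.new("RGB", (4, 4), color="white")
encoded = encoder.process_value(image)  # -> {"image": "<base64-encoded PNG bytes>"}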
inference.py CHANGED

@@ -9,7 +9,18 @@ import sys
 import time
 import uuid
 from collections import Counter
- from typing import Any, Dict, List, Literal, Optional, Union

 from datasets import DatasetDict
 from tqdm import tqdm, trange

@@ -19,11 +30,12 @@ from .artifact import Artifact
 from .dataclass import InternalField, NonPositionalField
 from .deprecation_utils import deprecation
 from .error_utils import UnitxtError
- from .image_operators import data_url_to_image, extract_images
 from .logging_utils import get_logger
 from .operator import PackageRequirementsMixin
 from .operators import ArtifactFetcherMixin
 from .settings_utils import get_constants, get_settings

 constants = get_constants()
 settings = get_settings()

@@ -67,6 +79,9 @@ class TextGenerationInferenceOutput:

        input_tokens (int) : number of input tokens to the model.
        output_tokens (int) : number of output tokens to the model.
        model_name (str): the model_name as kept in the InferenceEngine.
        inference_type (str): The label stating the type of the InferenceEngine.
    """

@@ -74,6 +89,9 @@ class TextGenerationInferenceOutput:
    prediction: Union[str, List[Dict[str, Any]]]
    input_tokens: Optional[int] = None
    output_tokens: Optional[int] = None
    model_name: Optional[str] = None
    inference_type: Optional[str] = None

@@ -152,6 +170,10 @@ class InferenceEngine(Artifact):
            if param_inst_val is None:
                setattr(self, param, param_dict_val)

    def verify_not_chat_api(self, dataset):
        if isinstance(dataset[0]["source"], list):
            raise NotImplementedError(
@@ -216,259 +238,898 @@ class LazyLoadMixin(Artifact):
        pass


- class HFPipelineBasedInferenceEngine(
-     InferenceEngine, PackageRequirementsMixin, LazyLoadMixin
- ):
-     model_name: str
    max_new_tokens: int
-     use_fp16: bool = True
-     batch_size: int = 1
    top_k: Optional[int] = None

    _requirements_list = {
-         "transformers": "Install huggingface package using 'pip install --upgrade transformers"
    }

-     def get_engine_id(self):
-         return get_model_and_label_id(self.model_name, "hf_pipeline")

-     def _get_task(self):
-         from transformers import AutoConfig

-         return (
-             "text2text-generation"
-             if AutoConfig.from_pretrained(
-                 self.model_name, trust_remote_code=True
-             ).is_encoder_decoder
-             else "text-generation"
-         )

-     def _prepare_pipeline(self):
-         import torch
-         from transformers import pipeline

-         model_args: Dict[str, Any] = (
-             {"torch_dtype": torch.float16} if self.use_fp16 else {}
-         )
-         model_args.update({"max_new_tokens": self.max_new_tokens})

-         device = torch.device(
-             "mps"
-             if torch.backends.mps.is_available()
-             else 0
-             if torch.cuda.is_available()
-             else "cpu"
-         )
-         # We do this, because in some cases, using device:auto will offload some weights to the cpu
-         # (even though the model might *just* fit to a single gpu), even if there is a gpu available, and this will
-         # cause an error because the data is always on the gpu
-         if torch.cuda.device_count() > 1:
-             assert device == torch.device(0)
-             model_args.update({"device_map": "auto"})
-         else:
-             model_args.update({"device": device})

-         task = self._get_task()

-         if task == "text-generation":
-             model_args.update({"return_full_text": False})

-         self.model = pipeline(
-             model=self.model_name, trust_remote_code=True, **model_args
-         )

    def prepare_engine(self):
        if not self.lazy_load:
-             self._prepare_pipeline()

-     def _is_loaded(self):
-         return hasattr(self, "model") and self.model is not None

-     def _infer(
-         self,
-         dataset: Union[List[Dict[str, Any]], DatasetDict],
-         return_meta_data: bool = False,
-     ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
-         if self._get_task() == "text2text-generation":
-             self.verify_not_chat_api(dataset)

-         if not self._is_loaded():
-             self._prepare_pipeline()

-         outputs = []
-         for output in self.model(
-             [instance["source"] for instance in dataset],
-             batch_size=self.batch_size,
-             top_k=self.top_k,
-         ):
-             if isinstance(output, list):
-                 output = output[0]
-             outputs.append(output["generated_text"])
-         return outputs


- class MockInferenceEngine(InferenceEngine):
-     model_name: str
-     default_inference_value: str = "[[10]]"

-     def get_engine_id(self):
-         return get_model_and_label_id(self.model_name, "mock")

-     def prepare_engine(self):
-         return

-     def _mock_infer(
        self,
        dataset: Union[List[Dict[str, Any]], DatasetDict],
    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
-         return [self.default_inference_value for _ in dataset]

    def _infer(
        self,
        dataset: Union[List[Dict[str, Any]], DatasetDict],
        return_meta_data: bool = False,
    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
-         return self._mock_infer(dataset)
-
-
- class MockModeMixin(Artifact):
-     mock_mode: bool = False


- class IbmGenAiInferenceEngineParamsMixin(Artifact):
-     beam_width: Optional[int] = None
-     decoding_method: Optional[Literal["greedy", "sample"]] = None
-     include_stop_sequence: Optional[bool] = None
-     length_penalty: Any = None
-     max_new_tokens: Optional[int] = None
-     min_new_tokens: Optional[int] = None
-     random_seed: Optional[int] = None
-     repetition_penalty: Optional[float] = None
-     return_options: Any = None
-     stop_sequences: Optional[List[str]] = None
-     temperature: Optional[float] = None
-     time_limit: Optional[int] = None
-     top_k: Optional[int] = None
-     top_p: Optional[float] = None
-     truncate_input_tokens: Optional[int] = None
-     typical_p: Optional[float] = None


- @deprecation(version="2.0.0", alternative=IbmGenAiInferenceEngineParamsMixin)
- class IbmGenAiInferenceEngineParams(Artifact):
-     beam_width: Optional[int] = None
-     decoding_method: Optional[Literal["greedy", "sample"]] = None
-     include_stop_sequence: Optional[bool] = None
-     length_penalty: Any = None
-     max_new_tokens: Optional[int] = None
-     min_new_tokens: Optional[int] = None
-     random_seed: Optional[int] = None
-     repetition_penalty: Optional[float] = None
-     return_options: Any = None
-     stop_sequences: Optional[List[str]] = None
-     temperature: Optional[float] = None
-     time_limit: Optional[int] = None
-     top_k: Optional[int] = None
-     top_p: Optional[float] = None
-     truncate_input_tokens: Optional[int] = None
-     typical_p: Optional[float] = None


- class GenericInferenceEngine(InferenceEngine, ArtifactFetcherMixin):
-     default: Optional[str] = None

-     def prepare_engine(self):
-         if "UNITXT_INFERENCE_ENGINE" in os.environ:
-             engine_reference = os.environ["UNITXT_INFERENCE_ENGINE"]
-         else:
-             assert self.default is not None, (
-                 "GenericInferenceEngine could not be initialized"
-                 '\nThis is since both the "UNITXT_INFERENCE_ENGINE" environmental variable is not set and no default engine was not inputted.'
-                 "\nFor example, you can fix it by setting"
-                 "\nexport UNITXT_INFERENCE_ENGINE=engines.ibm_gen_ai.llama_3_70b_instruct"
-                 "\nto your ~/.bashrc"
-                 "\nor passing a similar required engine in the default argument"
-             )
-             engine_reference = self.default
-         self.engine = self.get_artifact(engine_reference)

-     def get_engine_id(self):
-         return "generic_inference_engine"

-     def _infer(
        self,
        dataset: Union[List[Dict[str, Any]], DatasetDict],
-         return_meta_data: bool = False,
-     ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
-         return self.engine._infer(dataset)


- class OllamaInferenceEngine(
-     InferenceEngine, StandardAPIParamsMixin, PackageRequirementsMixin
- ):
-     label: str = "ollama"
-     _requirements_list = {
-         "ollama": "Install ollama package using 'pip install --upgrade ollama"
-     }
-     data_classification_policy = ["public", "proprietary"]

-     def get_engine_id(self):
-         return get_model_and_label_id(self.model, self.label)

-     def prepare_engine(self):
-         pass

    def _infer(
        self,
        dataset: Union[List[Dict[str, Any]], DatasetDict],
        return_meta_data: bool = False,
    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
-         import ollama
-
-         args = self.to_dict([StandardAPIParamsMixin])
-
-         results = []
-
-         for instance in dataset:
-             messages = self.to_messages(instance)
-             response = ollama.chat(
-                 model=self.model,
-                 messages=messages,
-                 **args,
-             )
-             results.append(response)

-         return [element["message"]["content"] for element in results]


- class OptionSelectingByLogProbsInferenceEngine:
-     """OptionSelectingByLogProbsInferenceEngine inference engine is used to select an option based on the logprobs of an options list conditioned by a prompt.

-     The inference engines that inherit from this class must implement `get_token_count` and `get_options_log_probs`.
-     """

-     @abc.abstractmethod
-     def get_token_count(self, dataset):
-         """Get the token count of the source key of each dict of the dataset. Add to each instance in the data a "token_count" field.

-         Args:
-             dataset (List[Dict[str, Any]]): A list of dictionaries, each representing a data instance.

-         Returns:
-             List[int]: The token count of the texts
-         """

-     @abc.abstractmethod
-     def get_options_log_probs(self, dataset):
-         """Get the token logprobs of the options of the key task_data.options of each dict of the dataset.

-         Add to each instance in the data a "options_log_prob" field, which is a dict with str as key and a list of {text: str, logprob:float}.

-         Args:
-             dataset (List[Dict[str, Any]]): A list of dictionaries, each representing a data instance.

-         Returns:
-             List[int]: The token count of the texts
    """

    def select(self, dataset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
@@ -552,12 +1213,14 @@ class IbmGenAiInferenceEngine(
    }
    data_classification_policy = ["public", "proprietary"]
    parameters: Optional[IbmGenAiInferenceEngineParams] = None

    def get_engine_id(self):
        return get_model_and_label_id(self.model_name, self.label)

-     def prepare_engine(self):
-         from genai import Client, Credentials

        api_key_env_var_name = "GENAI_KEY"
        api_key = os.environ.get(api_key_env_var_name)

@@ -566,9 +1229,22 @@ class IbmGenAiInferenceEngine(
                f"Error while trying to run IbmGenAiInferenceEngine."
                f" Please set the environment param '{api_key_env_var_name}'."
            )
-         credentials = Credentials(api_key=api_key)
        self.client = Client(credentials=credentials)

        self._set_inference_parameters()

    def _infer(

@@ -576,22 +1252,26 @@ class IbmGenAiInferenceEngine(
        dataset: Union[List[Dict[str, Any]], DatasetDict],
        return_meta_data: bool = False,
    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
-         from genai.schema import TextGenerationParameters

        genai_params = TextGenerationParameters(
            **self.to_dict([IbmGenAiInferenceEngineParamsMixin])
        )

-         results = []
        responses = self.client.text.generation.create(
            model_id=self.model_name,
            inputs=[instance["source"] for instance in dataset],
            parameters=genai_params,
        )
        for response in responses:
-             generated_text = response.results[0].generated_text
            result = self.get_return_object(
-                 generated_text, response.results[0], return_meta_data
            )
            results.append(result)
        return results

@@ -601,7 +1281,9 @@ class IbmGenAiInferenceEngine(
        dataset: Union[List[Dict[str, Any]], DatasetDict],
        return_meta_data: bool = False,
    ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
-         from genai.schema import TextGenerationParameters

        logprobs_return_options = {
            "generated_tokens": True,

@@ -620,11 +1302,12 @@ class IbmGenAiInferenceEngine(
            model_id=self.model_name,
            inputs=[instance["source"] for instance in dataset],
            parameters=genai_params,
        )

        predict_results = []
        for prediction in predictions:
-             result = prediction.results[0]
            assert isinstance(
                result.generated_tokens, list
            ), "result.generated_tokens should be a list"

@@ -651,9 +1334,22 @@ class IbmGenAiInferenceEngine(
                output_tokens=result.generated_token_count,
                model_name=self.model_name,
                inference_type=self.label,
            )
        return predict_result

    def get_token_count(self, dataset):
        texts = [instance["source"] for instance in dataset]
        token_counts = list(

@@ -973,6 +1669,10 @@ class VLLMRemoteInferenceEngine(OpenAiInferenceEngine):
        return OpenAI(api_key=api_key, base_url=api_url)


class WMLInferenceEngineParamsMixin(Artifact):
    decoding_method: Optional[Literal["greedy", "sample"]] = None
    length_penalty: Optional[Dict[str, Union[int, float]]] = None

@@ -1008,78 +1708,87 @@ class WMLInferenceEngineParams(Artifact):
    return_options: Optional[Dict[str, bool]] = None


- class WMLInferenceEngine(
    InferenceEngine,
-     WMLInferenceEngineParamsMixin,
    PackageRequirementsMixin,
    LogProbInferenceEngine,
    OptionSelectingByLogProbsInferenceEngine,
):
-     """Runs inference using ibm-watsonx-ai.

    Attributes:
        credentials (Dict[str, str], optional): By default, it is created by a class
            instance which tries to retrieve proper environment variables
-             ("WML_URL", "WML_PROJECT_ID", "WML_APIKEY"). However, a dictionary with
-             the following keys: "url", "apikey", "project_id" can be directly provided
-             instead.
        model_name (str, optional): ID of a model to be used for inference. Mutually
            exclusive with 'deployment_id'.
        deployment_id (str, optional): Deployment ID of a tuned model to be used for
            inference. Mutually exclusive with 'model_name'.
-         parameters (WMLInferenceEngineParams, optional): Instance of WMLInferenceEngineParams
-             which defines inference parameters and their values. Deprecated attribute, please
-             pass respective parameters directly to the WMLInferenceEngine class instead.
-         concurrency_limit (int): number of requests that will be sent in parallel, max is 10.
-
-     Examples:
-         from .api import load_dataset
-
-         wml_credentials = {
-             "url": "some_url", "project_id": "some_id", "api_key": "some_key"
-         }
-         model_name = "google/flan-t5-xxl"
-         wml_inference = WMLInferenceEngine(
-             credentials=wml_credentials,
-             model_name=model_name,
-             data_classification_policy=["public"],
-             top_p=0.5,
-             random_seed=123,
-         )
-
-         dataset = load_dataset(
-             dataset_query="card=cards.argument_topic,template_card_index=0,loader_limit=5"
-         )
-         results = wml_inference.infer(dataset["test"])
    """

-     credentials: Optional[Dict[Literal["url", "apikey", "project_id"], str]] = None
    model_name: Optional[str] = None
    deployment_id: Optional[str] = None
    label: str = "wml"
    _requirements_list = {
-         "ibm-watsonx-ai==1.1.14": "Install ibm-watsonx-ai package using 'pip install --upgrade ibm-watsonx-ai'. "
        "It is advised to have Python version >=3.10 installed, as at lower version this package "
        "may cause conflicts with other installed packages."
    }
    data_classification_policy = ["public", "proprietary"]
-     parameters: Optional[WMLInferenceEngineParams] = None
-     concurrency_limit: int = 10
    _client: Any = InternalField(default=None, name="WML client")

    def get_engine_id(self):
-         return get_model_and_label_id(self.model_name, self.label)

    def verify(self):
        super().verify()

-         if self.credentials is not None:
-             for key in self.credentials:
-                 if key not in ["url", "apikey", "project_id", "space_id"]:
-                     raise ValueError(
-                         f'Illegal credential key: {key}, use only ["url", "apikey", "project_id", "space_id"]'
-                     )
-
        assert (
            self.model_name
            or self.deployment_id

@@ -1095,166 +1804,186 @@ class WMLInferenceEngine(
                data["credentials"][key] = value
        return data

    @staticmethod
-     def _read_wml_credentials_from_env() -> (
-         Dict[Literal["url", "apikey", "project_id", "space_id"], str]
-     ):
-         credentials = {}
-         project_or_deployment_var_name = (
-             "WML_SPACE_ID" if "WML_SPACE_ID" in os.environ else "WML_PROJECT_ID"
        )

-         for env_var_name in ["WML_URL", project_or_deployment_var_name, "WML_APIKEY"]:
-             env_var = os.environ.get(env_var_name)
-             assert env_var, (
-                 f"Error while trying to run 'WMLInferenceEngine'. "
-                 f"Please set the env variable: '{env_var_name}', or "
-                 f"directly provide an instance of ibm-watsonx-ai 'Credentials' "
-                 f"to the engine."
            )

-             name = env_var_name.lower().replace("wml_", "")
-             credentials[name] = env_var

        return credentials

-     def _initialize_wml_client(self):
-         from ibm_watsonx_ai.client import APIClient
-
-         if self.credentials is None:
-             self.credentials = self._read_wml_credentials_from_env()

-         client = APIClient(credentials=self.credentials)
-         if "space_id" in self.credentials:
-             client.set.default_space(self.credentials["space_id"])
-         else:
-             client.set.default_project(self.credentials["project_id"])
-         return client

    def prepare_engine(self):
        self._client = self._initialize_wml_client()

        self._set_inference_parameters()

-     def _load_model_and_params(self):
-         from ibm_watsonx_ai.foundation_models import ModelInference

-         model = ModelInference(
            model_id=self.model_name,
            deployment_id=self.deployment_id,
            api_client=self._client,
        )
-         params = self.to_dict([WMLInferenceEngineParamsMixin], keep_empty=False)

-         return model, params

    def _infer(
        self,
        dataset: Union[List[Dict[str, Any]], DatasetDict],
        return_meta_data: bool = False,
    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
-         self.verify_not_chat_api(dataset)
-         model, params = self._load_model_and_params()
-
-         result = []
-         for source in dataset["source"]:
-             instance_result = model.generate(
-                 prompt=source,
-                 params=self.to_dict([WMLInferenceEngineParamsMixin], keep_empty=False),
-             )
-             prediction = instance_result["results"][0]["generated_text"]
-             instance_final_results = self.get_return_object(
-                 prediction, instance_result, return_meta_data
-             )
-             result.append(instance_final_results)

-         return result

    def _infer_log_probs(
        self,
        dataset: Union[List[Dict[str, Any]], DatasetDict],
        return_meta_data: bool = False,
    ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
-         self.verify_not_chat_api(dataset)
-
-         model, params = self._load_model_and_params()
-
-         user_return_options = params.pop("return_options", {})
-         # currently this is the only configuration that returns generated logprobs and behaves as expected
-         logprobs_return_options = {
-             "input_tokens": True,
-             "generated_tokens": True,
-             "token_logprobs": True,
-             "top_n_tokens": user_return_options.get("top_n_tokens", 5),
-         }
-         for key, value in logprobs_return_options.items():
-             if key in user_return_options and user_return_options[key] != value:
-                 raise ValueError(
-                     f"'{key}={user_return_options[key]}' is not supported for the 'infer_log_probs' "
-                     f"method of {self.__class__.__name__}. For obtaining the logprobs of generated tokens "
-                     f"please use '{key}={value}'."
-                 )
-
-         params = {
-             **params,
-             "return_options": logprobs_return_options,
-         }

-         results = model.generate(
-             prompt=[instance["source"] for instance in dataset],
-             params=params,
        )
-         final_results = []
-         for result in results:
-             generated_tokens = result["results"][0]["generated_tokens"]
-             final_results.append(
-                 self.get_return_object(generated_tokens, result, return_meta_data)
-             )
-         return final_results

-     def get_return_object(self, predict_result, result, return_meta_data):
-         if return_meta_data:
-             return TextGenerationInferenceOutput(
-                 prediction=predict_result,
-                 input_tokens=result["results"][0]["input_token_count"],
-                 output_tokens=result["results"][0]["generated_token_count"],
-                 model_name=self.model_name,
-                 inference_type=self.label,
-             )
-         return predict_result

    def get_token_count(self, dataset):
-         from ibm_watsonx_ai.foundation_models import ModelInference

        texts = [instance["source"] for instance in dataset]

-         model = ModelInference(
-             model_id=self.model_name,
-             deployment_id=self.deployment_id,
-             api_client=self._client,
-         )
-
        for i in trange(len(texts), desc="Tokenizing"):
-             response = model.tokenize(prompt=texts[i], return_tokens=True)["result"]
            dataset[i]["token_count"] = response["token_count"]

        return dataset

    def get_options_log_probs(self, dataset):
        """Add to each instance in the data a "options_log_prob" field, which is a dict with str as key and a list of {text: str, logprob:float}."""
-         from ibm_watsonx_ai.foundation_models import ModelInference
-
-         model = ModelInference(
-             model_id=self.model_name,
-             deployment_id=self.deployment_id,
-             api_client=self._client,
-         )

        texts = [x["source"] for x in dataset]

        responses = list(
            tqdm(
-                 model.generate(
                    prompt=texts,
                    params={
                        "decoding_method": "greedy",

@@ -1286,110 +2015,335 @@ class WMLInferenceEngine(
        return dataset


- def get_images_without_text(instance):
-     return extract_images(instance["source"], instance)


- def get_text_without_images(instance, image_token="<image>"):
-     regex = r"<" + f"{constants.image_tag}" + r'\s+src=["\'](.*?)["\']\s*/?>'
-     return re.sub(regex, image_token, instance["source"])


- class HFLlavaInferenceEngine(InferenceEngine, LazyLoadMixin):
-     model_name: str
-     max_new_tokens: int
-     lazy_load = True
-     image_token = "<image>"

-     _requirements_list = {
-         "transformers": "Install huggingface package using 'pip install --upgrade transformers",
-         "torch": "Install torch, go on PyTorch website for mode details.",
-         "accelerate": "pip install accelerate",
-     }

-     def get_engine_id(self):
-         return get_model_and_label_id(self.model_name, "hf_lava")

-     def _prepare_engine(self):
-         import torch
-         from transformers import AutoProcessor, LlavaForConditionalGeneration

-         self.device = torch.device(
-             "mps"
-             if torch.backends.mps.is_available()
-             else 0
-             if torch.cuda.is_available()
-             else "cpu"
        )

-         self.model = LlavaForConditionalGeneration.from_pretrained(
-             self.model_name,
-             torch_dtype=torch.float16,
-             low_cpu_mem_usage=True,
-         ).to(self.device)
-
-         self.processor = AutoProcessor.from_pretrained(self.model_name)
-
-     def prepare_engine(self):
-         if not self.lazy_load:
-             self._prepare_engine()

-     def _is_loaded(self):
-         return hasattr(self, "model") and self.model is not None

-     def _get_input(self, instance):
-         assert isinstance(instance["source"], list), "Must use format=formats.chat_api"
-         images = []
-         conversation = []
-         for turn in instance["source"]:
-             if isinstance(turn["content"], list):
-                 for content in turn["content"]:
-                     if content["type"] == "image_url":
-                         content["type"] = "image"
-                         image_url = content.pop("image_url")["url"]
-                         image = data_url_to_image(image_url)
-                         images.append(image)
-             conversation.append(turn)
-         return conversation, images

-     def _infer(
        self,
        dataset: Union[List[Dict[str, Any]], DatasetDict],
-         return_meta_data: bool = False,
-     ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
-         if not self._is_loaded():
-             self._prepare_engine()

-         import torch

-         results = []
-         for instance in tqdm(dataset):
-             conversation, images = self._get_input(instance)

-             if len(images) == 1:
-                 images = images[0]

-             text = self.processor.apply_chat_template(
-                 conversation, add_generation_prompt=True
-             )

-             inputs = self.processor(images=images, text=text, return_tensors="pt").to(
-                 self.device, torch.float16
            )

-             input_len = len(inputs["input_ids"][0])
-             output = self.model.generate(
-                 **inputs,
-                 max_new_tokens=self.max_new_tokens,
-                 do_sample=False,
-                 pad_token_id=self.processor.tokenizer.eos_token_id,
            )
-             result = self.processor.decode(
-                 output[0][input_len:], skip_special_tokens=True
            )
-             results.append(result)

-         return results


class LMMSEvalBaseInferenceEngine(

@@ -1400,7 +2354,9 @@ class LMMSEvalBaseInferenceEngine(
    batch_size: int = 1
    image_token = "<image>"

-     _requirements_list = ["lmms-eval==0.2.4"]

    def prepare_engine(self):
        if not self.lazy_load:

@@ -1447,7 +2403,6 @@ class LMMSEvalInferenceEngine(LMMSEvalBaseInferenceEngine):
        dataset: Union[List[Dict[str, Any]], DatasetDict],
        return_meta_data: bool = False,
    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
-         self.verify_not_chat_api(dataset)
        if not self._is_loaded():
            self._prepare_engine()

 import time
 import uuid
 from collections import Counter
+ from typing import (
+     Any,
+     Dict,
+     Iterable,
+     List,
+     Literal,
+     Mapping,
+     Optional,
+     Sequence,
+     Tuple,
+     Union,
+ )

 from datasets import DatasetDict
 from tqdm import tqdm, trange

 from .dataclass import InternalField, NonPositionalField
 from .deprecation_utils import deprecation
 from .error_utils import UnitxtError
+ from .image_operators import EncodeImageToString, data_url_to_image, extract_images
 from .logging_utils import get_logger
 from .operator import PackageRequirementsMixin
 from .operators import ArtifactFetcherMixin
 from .settings_utils import get_constants, get_settings
+ from .type_utils import isoftype

 constants = get_constants()
 settings = get_settings()

        input_tokens (int) : number of input tokens to the model.
        output_tokens (int) : number of output tokens to the model.
+         stop_reason (str): stop reason for text generation, for example "eos" (end of string).
+         seed (int): seed used by the model during generation.
+         input_text (str): input to the model.
        model_name (str): the model_name as kept in the InferenceEngine.
        inference_type (str): The label stating the type of the InferenceEngine.
    """

    prediction: Union[str, List[Dict[str, Any]]]
    input_tokens: Optional[int] = None
    output_tokens: Optional[int] = None
+     stop_reason: Optional[str] = None
+     seed: Optional[int] = None
+     input_text: Optional[str] = None
    model_name: Optional[str] = None
    inference_type: Optional[str] = None
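
Not part of the commit: to make the newly added metadata fields concrete, a hedged sketch of constructing the dataclass with them; the values are invented for illustration.

# Illustrative values only; the field names come from the hunk above.
output = TextGenerationInferenceOutput(
    prediction="The answer is 42.",
    input_tokens=17,
    output_tokens=6,
    stop_reason="eos",                      # new field: why generation stopped
    seed=123,                               # new field: seed reported by the backend, if any
    input_text="What is six times seven?",  # new field: the prompt that was sent
    model_name="some-model",
    inference_type="hf_auto_model",
)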


            if param_inst_val is None:
                setattr(self, param, param_dict_val)

+     def get_model_details(self) -> Dict:
+         """Might not be possible to implement for all inference engines. Returns an empty dict by default."""
+         return {}
+
    def verify_not_chat_api(self, dataset):
        if isinstance(dataset[0]["source"], list):
            raise NotImplementedError(
 
        pass


+ class HFGenerationParamsMixin(Artifact):
    max_new_tokens: int
+     do_sample: bool = False
+     temperature: Optional[float] = None
+     top_p: Optional[float] = None
    top_k: Optional[int] = None
+     num_beams: Optional[int] = None
+     repetition_penalty: Optional[float] = None
+     pad_token_id: Optional[int] = None
+     eos_token_id: Optional[int] = None
+
+
+ class HFInferenceEngineBase(
+     InferenceEngine,
+     LogProbInferenceEngine,
+     PackageRequirementsMixin,
+     LazyLoadMixin,
+     HFGenerationParamsMixin,
+ ):
+     model_name: str
+     label: str
+
+     n_top_tokens: int = 5
+
+     device: Any = None
+     device_map: Any = None
+
+     use_fast_tokenizer: bool = True
+     low_cpu_mem_usage: bool = True
+     torch_dtype: str = "torch.float16"
+
+     model: Any = InternalField(default=None, name="Inference object")
+     processor: Any = InternalField(default=None, name="Input processor (tokenizer)")

    _requirements_list = {
+         "transformers": "Install huggingface package using 'pip install --upgrade transformers",
+         "torch": "Install torch, go on PyTorch website for mode details.",
+         "accelerate": "pip install accelerate",
    }

+     def _is_loaded(self):
+         return hasattr(self, "model") and self.model is not None

+     def _set_inference_device(self):
+         if self.device is not None and self.device_map is not None:
+             raise ValueError(
+                 f"You must specify either 'device' or 'device_map', however both "
+                 f"were given: 'device={self.device}', 'device_map={self.device_map}'."
+             )

+         if self.device is None and self.device_map is None:
+             import torch

+             self.device = torch.device(
+                 "mps"
+                 if torch.backends.mps.is_available()
+                 else 0
+                 if torch.cuda.is_available()
+                 else "cpu"
+             )

+     @abc.abstractmethod
+     def _init_processor(self):
+         raise NotImplementedError

+     @abc.abstractmethod
+     def _init_model(self):
+         raise NotImplementedError
+
+     def _get_torch_dtype(self):
+         import torch
+
+         if not isinstance(self.torch_dtype, str) or not self.torch_dtype.startswith(
+             "torch."
+         ):
+             raise ValueError(
+                 f"'torch_dtype' must be a string representing torch data "
+                 f"type used for inference. The name should be an absolute "
+                 f"import, for example: 'torch.float16'. However, "
+                 f"'{self.torch_dtype}' was given instead."
+             )

+         try:
+             dtype = eval(self.torch_dtype)
+         except (AttributeError, TypeError) as e:
+             raise ValueError(
+                 f"Incorrect value of 'torch_dtype' was given: '{self.torch_dtype}'."
+             ) from e
+
+         if not isinstance(dtype, torch.dtype):
+             raise ValueError(
+                 f"'torch_dtype' must be an instance of 'torch.dtype', however, "
+                 f"'{dtype}' is an instance of '{type(dtype)}'."
+             )

+         return dtype

+     def _prepare_engine(self):
+         self._set_inference_device()
+         self._init_processor()
+         self._init_model()

    def prepare_engine(self):
        if not self.lazy_load:
+             self._prepare_engine()

+     def get_engine_id(self):
+         return get_model_and_label_id(self.model_name, self.label)

+     def decode_tokens(self, tokens: Sequence, inp_length: int) -> List[str]:
+         return [
+             self.processor.decode(token, skip_special_tokens=True)
+             for token in tokens[inp_length:]
+         ]

+     @staticmethod
+     def create_string_from_tokens(string_tokens: List[str]) -> str:
+         return "".join(token for token in string_tokens)
+
+     def make_predictions(self, prepared_inputs: Mapping) -> Mapping:
+         return self.model.generate(
+             **prepared_inputs,
+             **self.to_dict([HFGenerationParamsMixin], keep_empty=False),
+             output_scores=True,
+             return_dict_in_generate=True,
+         )

+     def compute_transition_scores(
+         self, sequences: Sequence, scores: Sequence, beam_indices: Optional[int]
+     ) -> Sequence:
+         # Some models may not support computing scores in this form by default, so a possible
+         # child class should have its own implementation of this method if necessary.
+         return self.model.compute_transition_scores(
+             sequences,
+             scores,
+             normalize_logits=True,
+             beam_indices=beam_indices,
+         )

+     def get_logprobs(
+         self, predictions: Mapping, string_tokens: List[List[str]]
+     ) -> List[List[Dict[str, Any]]]:
+         beam_indices = (
+             predictions.beam_indices
+             if self.num_beams is not None and self.num_beams > 1
+             else None
+         )

+         transition_scores = self.compute_transition_scores(
+             sequences=predictions.sequences,
+             scores=predictions.scores,
+             beam_indices=beam_indices,
+         )

+         logprobs: List[List[Dict[str, Any]]] = []

+         for sample_no, sample_scores in enumerate(transition_scores.detach().cpu()):
+             sample_logprobs: List[Dict[str, Any]] = []

+             for n, score in enumerate(sample_scores):
+                 sample_logprobs.append(
+                     {
+                         "text": string_tokens[sample_no][n],
+                         "logprob": float(score.cpu()),
+                         "top_tokens": [
+                             {
+                                 "text": self.processor.decode(idx),
+                                 "logprob": float(
+                                     predictions.scores[n][sample_no][idx].cpu()
+                                 ),
+                             }
+                             for idx in predictions.scores[n][sample_no].argsort(
+                                 dim=0, descending=True
+                             )[: self.n_top_tokens]
+                         ],
+                     }
+                 )
+
+             logprobs.append(sample_logprobs)
+
+         return logprobs
+
+     @abc.abstractmethod
+     def prepare_inputs(self, data: Iterable) -> Mapping:
+         raise NotImplementedError
+
+     def get_return_object(
+         self,
+         output: Union[str, List[Dict[str, Any]]],
+         output_tokens: Optional[int],
+         inp: Optional[str],
+         inp_tokens: Optional[int],
+         return_meta_data: bool,
+     ) -> Union[str, List[Dict[str, Any]], TextGenerationInferenceOutput]:
+         if return_meta_data:
+             return TextGenerationInferenceOutput(
+                 prediction=output,
+                 output_tokens=output_tokens if output_tokens is not None else None,
+                 input_text=inp,
+                 input_tokens=inp_tokens if inp_tokens is not None else None,
+                 model_name=self.model_name,
+                 inference_type=self.label,
+             )
+         return output
+
+     def infer(
        self,
        dataset: Union[List[Dict[str, Any]], DatasetDict],
+         return_meta_data: bool = False,
    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
+         if not self._is_loaded():
+             self._prepare_engine()
+         return super().infer(dataset, return_meta_data)

+     @abc.abstractmethod
    def _infer(
        self,
        dataset: Union[List[Dict[str, Any]], DatasetDict],
        return_meta_data: bool = False,
    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
+         raise NotImplementedError

+     def infer_log_probs(
+         self,
+         dataset: Union[List[Dict[str, Any]], DatasetDict],
+         return_meta_data: bool = False,
+     ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
+         if not self._is_loaded():
+             self._prepare_engine()
+         return super().infer_log_probs(dataset, return_meta_data)

+     @abc.abstractmethod
+     def _infer_log_probs(
+         self,
+         dataset: Union[List[Dict[str, Any]], DatasetDict],
+         return_meta_data: bool = False,
+     ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
+         raise NotImplementedError


+ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
+     label: str = "hf_auto_model"

+     def _init_processor(self):
+         from transformers import AutoTokenizer

+         self.processor = AutoTokenizer.from_pretrained(
+             pretrained_model_name_or_path=self.model_name,
+             use_fast=self.use_fast_tokenizer,
+             padding=True,
+             truncation=True,
+         )

+     def _init_model(self):
+         from transformers import (
+             AutoConfig,
+             AutoModelForCausalLM,
+             AutoModelForSeq2SeqLM,
+         )

+         model_class = (
+             AutoModelForSeq2SeqLM
+             if AutoConfig.from_pretrained(self.model_name).is_encoder_decoder
+             else AutoModelForCausalLM
+         )

+         self.model = model_class.from_pretrained(
+             pretrained_model_name_or_path=self.model_name,
+             trust_remote_code=True,
+             device_map=self.device_map,
+             torch_dtype=self._get_torch_dtype(),
+         )
+         if self.device_map is None:
+             self.model.to(self.device)
+
+     def prepare_inputs(self, data: Iterable) -> Mapping:
+         return self.processor(
+             data,
+             padding=True,
+             truncation=True,
+             return_tensors="pt",
+         ).to(self.device or self.device_map)
+
+     def _infer_fn(
        self,
        dataset: Union[List[Dict[str, Any]], DatasetDict],
+         return_meta_data: bool,
+         return_logprobs: bool,
+     ) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]:
+         tokenized_inputs = self.prepare_inputs(
+             [instance["source"] for instance in dataset]
+         )
+         input_length = (
+             1
+             if self.model.config.is_encoder_decoder
+             else tokenized_inputs.input_ids.shape[1]
+         )

+         predictions = self.make_predictions(tokenized_inputs)
+         sequences = predictions.sequences

+         string_tokens = [
+             self.decode_tokens(sequence, input_length) for sequence in sequences
+         ]

+         final_outputs = (
+             self.get_logprobs(predictions, string_tokens)
+             if return_logprobs
+             else [self.create_string_from_tokens(strings) for strings in string_tokens]
+         )

+         return [
+             self.get_return_object(
+                 output=final_outputs[i],
+                 output_tokens=len(string_tokens[i]),
+                 inp=dataset[i]["source"],
+                 inp_tokens=len(tokenized_inputs.encodings[i].tokens)
+                 if tokenized_inputs.encodings is not None
+                 else None,
+                 return_meta_data=return_meta_data,
+             )
+             for i in range(len(sequences))
+         ]

    def _infer(
        self,
        dataset: Union[List[Dict[str, Any]], DatasetDict],
        return_meta_data: bool = False,
    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
+         self.verify_not_chat_api(dataset)
+         return self._infer_fn(dataset, return_meta_data, False)

+     def _infer_log_probs(
+         self,
+         dataset: Union[List[Dict[str, Any]], DatasetDict],
+         return_meta_data: bool = False,
+     ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
+         self.verify_not_chat_api(dataset)
+         return self._infer_fn(dataset, return_meta_data, True)
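
Not part of the commit: a hedged sketch of how the new HFAutoModelInferenceEngine might be called. The import path, the model name, and the dataset shape (a list of dicts with a "source" string) are assumptions based on the code above.

# Hedged usage sketch; the unitxt.inference import path and the model id are assumed.
from unitxt.inference import HFAutoModelInferenceEngine

engine = HFAutoModelInferenceEngine(
    model_name="google/flan-t5-small",  # hypothetical model choice
    max_new_tokens=32,                  # generation parameter required by HFGenerationParamsMixin
)
dataset = [{"source": "Translate to French: good morning"}]
texts = engine.infer(dataset)                             # list of generated strings
detailed = engine.infer(dataset, return_meta_data=True)   # TextGenerationInferenceOutput objects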
580
 
581
 
582
+ class HFLlavaInferenceEngine(HFInferenceEngineBase):
583
+ lazy_load: bool = True
584
+ label: str = "hf_lava"
585
+ image_token: str = "<image>"
586
 
587
+ def compute_transition_scores(
588
+ self, sequences: Sequence, scores: Sequence, beam_indices: Optional[int]
589
+ ) -> Sequence:
590
+ if not hasattr(self.model.config, "vocab_size"):
591
+ self.model.config.vocab_size = self.model.vocab_size
592
 
593
+ return super().compute_transition_scores(sequences, scores, beam_indices)
 
 
594
 
595
+ def _init_processor(self):
596
+ from transformers import AutoProcessor
597
 
598
+ self.processor = AutoProcessor.from_pretrained(self.model_name)
 
 
599
 
600
+ if not self.pad_token_id and hasattr(self.processor, "eos_token_id"):
601
+ self.pad_token_id = self.processor.eos_token_id
 
602
 
603
+ def _init_model(self):
604
+ from transformers import LlavaForConditionalGeneration
605
 
606
+ self.model = LlavaForConditionalGeneration.from_pretrained(
607
+ self.model_name,
608
+ torch_dtype=self._get_torch_dtype(),
609
+ low_cpu_mem_usage=self.low_cpu_mem_usage,
610
+ device_map=self.device_map,
611
+ )
612
+ if self.device_map is None:
613
+ self.model.to(self.device)
614
 
615
+ @staticmethod
616
+ def _get_input(instance):
617
+ assert isinstance(instance["source"], list), "Must use format=formats.chat_api"
618
+ images = []
619
+ conversation = []
620
+ for turn in instance["source"]:
621
+ if isinstance(turn["content"], list):
622
+ for content in turn["content"]:
623
+ if content["type"] == "image_url":
624
+ content["type"] = "image"
625
+ image_url = content.pop("image_url")["url"]
626
+ image = data_url_to_image(image_url)
627
+ images.append(image)
628
+ conversation.append(turn)
629
+ return conversation, images
630
+
631
+ def prepare_inputs(self, data: Iterable) -> Mapping:
632
+ conversation, images = self._get_input(data)
633
+
634
+ if len(images) == 1:
635
+ images = images[0]
636
+
637
+ text = self.processor.apply_chat_template(
638
+ conversation, add_generation_prompt=True
639
+ )
640
+
641
+ inputs: Mapping = self.processor(
642
+ images=images, text=text, return_tensors="pt"
643
+ ).to(self.device or self.device_map, self._get_torch_dtype())
644
+
645
+ return inputs
646
+
647
+ def _infer_fn(
648
+ self,
649
+ dataset: Union[List[Dict[str, Any]], DatasetDict],
650
+ return_meta_data: bool,
651
+ return_logprobs: bool,
652
+ ) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]:
653
+ results = []
654
+
655
+ for instance in tqdm(dataset):
656
+ processed_inputs = self.prepare_inputs(instance)
657
+ input_len = len(processed_inputs["input_ids"][0])
658
+
659
+ predictions = self.make_predictions(processed_inputs)
660
+
661
+ string_tokens = self.decode_tokens(predictions.sequences[0], input_len)
662
+
663
+ final_outputs = (
664
+ self.get_logprobs(predictions, [string_tokens])[0]
665
+ if return_logprobs
666
+ else self.create_string_from_tokens(string_tokens)
667
+ )
668
+
669
+ results.append(
670
+ self.get_return_object(
671
+ output=final_outputs,
672
+ output_tokens=len(string_tokens),
673
+ inp=instance["source"],
674
+ inp_tokens=None,
675
+ return_meta_data=return_meta_data,
676
+ )
677
+ )
678
+
679
+ return results
680
+
681
+ def _infer(
682
+ self,
683
+ dataset: Union[List[Dict[str, Any]], DatasetDict],
684
+ return_meta_data: bool = False,
685
+ ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
686
+ return self._infer_fn(dataset, return_meta_data, False)
687
+
688
+ def _infer_log_probs(
689
+ self,
690
+ dataset: Union[List[Dict[str, Any]], DatasetDict],
691
+ return_meta_data: bool = False,
692
+ ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
693
+ return self._infer_fn(dataset, return_meta_data, True)
694
+
695
+
696
+ class HFPeftInferenceEngine(HFAutoModelInferenceEngine):
697
+ label: str = "hf_peft_auto_model"
698
+
699
+ peft_config: Any = InternalField(
700
+ default=None,
701
+ name="PEFT config read from the directory or the Hub repository "
702
+ "id specified in the 'model_name'.",
703
+ )
704
+
705
+ _requirements_list = {
706
+ "transformers": "Install huggingface package using 'pip install --upgrade transformers",
707
+ "torch": "Install torch, go on PyTorch website for mode details.",
708
+ "accelerate": "pip install accelerate",
709
+ "peft": "Install 'peft' package using: 'pip install peft'.",
710
+ }
711
+
712
+ def _prepare_engine(self):
713
+ self._read_peft_config()
714
+ super()._prepare_engine()
715
+
716
+ def _read_peft_config(self):
717
+ from peft import PeftConfig
718
+
719
+ try:
720
+ config = PeftConfig.from_pretrained(self.model_name)
721
+ assert isinstance(config.base_model_name_or_path, str)
722
+ self.peft_config = config
723
+
724
+ except ValueError as e:
725
+ if "Can't find" in str(e):
726
+ raise ValueError(
727
+ f"Specified model '{self.model_name}' is not the PEFT model. "
728
+ f"Use a regular instance of the `HFAutoModelInferenceEngine` "
729
+ f"instead."
730
+ ) from e
731
+
732
+ raise e
733
+
734
+ def _init_processor(self):
735
+ from transformers import AutoTokenizer
736
+
737
+ self.processor = AutoTokenizer.from_pretrained(
738
+ self.peft_config.base_model_name_or_path
739
+ )
740
+
741
+ def _init_model(self):
742
+ from peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM
743
+ from transformers import AutoConfig
744
+
745
+ model_class = (
746
+ AutoPeftModelForSeq2SeqLM
747
+ if AutoConfig.from_pretrained(self.model_name).is_encoder_decoder
748
+ else AutoPeftModelForCausalLM
749
+ )
750
+
751
+ self.model = model_class.from_pretrained(
752
+ pretrained_model_name_or_path=self.peft_config.base_model_name_or_path,
753
+ trust_remote_code=True,
754
+ device_map=self.device_map,
755
+ low_cpu_mem_usage=self.low_cpu_mem_usage,
756
+ torch_dtype=self._get_torch_dtype(),
757
+ )
758
+ if self.device_map is None:
759
+ self.model.to(self.device)
760
+
761
+
762
+ @deprecation(
763
+ version="2.0.0", msg=" Use non-pipeline-based 'HFInferenceEngine' instead."
764
+ )
765
+ class HFPipelineBasedInferenceEngine(
766
+ InferenceEngine, PackageRequirementsMixin, LazyLoadMixin, HFGenerationParamsMixin
767
+ ):
768
+ model_name: str
769
+ label: str = "hf_pipeline_inference_engine"
770
+
771
+ use_fast_tokenizer: bool = True
772
+ use_fp16: bool = True
773
+ load_in_8bit: bool = False
774
+
775
+ task: Optional[str] = None
776
+
777
+ device: Any = None
778
+ device_map: Any = None
779
+
780
+ pipe: Any = InternalField(default=None)
781
+
782
+ _requirements_list = {
783
+ "transformers": "Install huggingface package using 'pip install --upgrade transformers",
784
+ "torch": "Install torch, go on PyTorch website for mode details.",
785
+ "accelerate": "pip install accelerate",
786
+ }
787
+
788
+ def _is_loaded(self):
789
+ return hasattr(self, "model") and self.model is not None
790
+
791
+ def get_engine_id(self):
792
+ return get_model_and_label_id(self.model_name, "hf_pipeline")
793
+
794
+ def _define_task(self):
795
+ from transformers import AutoConfig
796
+
797
+ self.task = (
798
+ "text2text-generation"
799
+ if AutoConfig.from_pretrained(
800
+ self.model_name, trust_remote_code=True
801
+ ).is_encoder_decoder
802
+ else "text-generation"
803
+ )
804
+
805
+ def _get_model_args(self) -> Dict[str, Any]:
806
+ import torch
807
+ from transformers import BitsAndBytesConfig
808
+
809
+ args = {}
810
+
811
+ if self.load_in_8bit:
812
+ quantization_config = BitsAndBytesConfig(load_in_8bit=self.load_in_8bit)
813
+ args["quantization_config"] = quantization_config
814
+ elif self.use_fp16:
815
+ if self.device == torch.device("mps"):
816
+ args["torch_dtype"] = torch.float16
817
+ else:
818
+ args["torch_dtype"] = torch.bfloat16
819
+
820
+ # We do this, because in some cases, using device:auto will offload some weights to the cpu
821
+ # (even though the model might *just* fit to a single gpu), even if there is a gpu available, and this will
822
+ # cause an error because the data is always on the gpu
823
+ if torch.cuda.device_count() > 1:
824
+ assert self.device == torch.device(0)
825
+ args["device_map"] = "auto"
826
+ else:
827
+ if not self.load_in_8bit:
828
+ args["device"] = self.device
829
+
830
+ if self.task == "text-generation":
831
+ args["return_full_text"] = False
832
+
833
+ return args
834
+
835
+ def _create_pipeline(self, model_args: Dict[str, Any]):
836
+ from transformers import pipeline
837
+
838
+ self.model = pipeline(
839
+ model=self.model_name,
840
+ task=self.task,
841
+ use_fast=self.use_fast_tokenizer,
842
+ trust_remote_code=True,
843
+ **model_args,
844
+ **self.to_dict(
845
+ [HFGenerationParamsMixin],
846
+ keep_empty=False,
847
+ ),
848
+ )
849
+
850
+ def _set_inference_device(self):
851
+ if self.device is not None and self.device_map is not None:
852
+ raise ValueError(
853
+ f"You must specify either 'device' or 'device_map', however both "
854
+ f"were given: 'device={self.device}', 'device_map={self.device_map}'."
855
+ )
856
+
857
+ if self.device is None and self.device_map is None:
858
+ import torch
859
+
860
+ self.device = torch.device(
861
+ "mps"
862
+ if torch.backends.mps.is_available()
863
+ else 0
864
+ if torch.cuda.is_available()
865
+ else "cpu"
866
+ )
867
+
868
+ def _prepare_engine(self):
869
+ self._set_inference_device()
870
+ if self.task is None:
871
+ self._define_task()
872
+ model_args = self._get_model_args()
873
+ self._create_pipeline(model_args)
874
+
875
+ def prepare_engine(self):
876
+ if not self.lazy_load:
877
+ self._prepare_engine()
878
+
879
+ def _infer(
880
+ self,
881
+ dataset: Union[List[Dict[str, Any]], DatasetDict],
882
+ return_meta_data: bool = False,
883
+ ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
884
+ if not self._is_loaded():
885
+ self._prepare_engine()
886
+
887
+ outputs = self.model([instance["source"] for instance in dataset])
888
+
889
+ return [
890
+ self.get_return_object(output[0], instance["source"], return_meta_data)
891
+ if isinstance(output, list)
892
+ else self.get_return_object(output, instance["source"], return_meta_data)
893
+ for output, instance in zip(outputs, dataset)
894
+ ]
895
+
896
+ def get_return_object(self, output, inp, return_meta_data):
897
+ if return_meta_data:
898
+ return TextGenerationInferenceOutput(
899
+ prediction=output["generated_text"],
900
+ model_name=self.model_name,
901
+ inference_type=self.label,
902
+ input_text=inp,
903
+ )
904
+ return output["generated_text"]
905
+
906
+
907
+ def mock_logprobs_default_value_factory() -> List[Dict[str, Any]]:
908
+ return [
909
+ {
910
+ "logprob": -1,
911
+ "text": "[[10]]",
912
+ "top_tokens": [
913
+ {"logprob": -1, "text": "[[10]]"},
914
+ ],
915
+ }
916
+ ]
917
+
918
+
919
+ class MockInferenceEngine(InferenceEngine, LogProbInferenceEngine):
920
+ model_name: str
921
+ default_inference_value: str = "[[10]]"
922
+ default_inference_value_logprob: List[Dict[str, Any]] = dataclasses.field(
923
+ default_factory=mock_logprobs_default_value_factory,
924
+ )
925
+ label: str = "mock_inference_engine"
926
+
927
+ def get_engine_id(self):
928
+ return get_model_and_label_id(self.model_name, "mock")
929
+
930
+ def prepare_engine(self):
931
+ return
932
+
933
+ def _mock_infer(
934
+ self,
935
+ dataset: Union[List[Dict[str, Any]], DatasetDict],
936
+ ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
937
+ return [self.default_inference_value for _ in dataset]
938
+
939
+ def _infer(
940
+ self,
941
+ dataset: Union[List[Dict[str, Any]], DatasetDict],
942
+ return_meta_data: bool = False,
943
+ ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
944
+ return [
945
+ self.get_return_object(
946
+ self.default_inference_value, instance, return_meta_data
947
+ )
948
+ for instance in dataset
949
+ ]
950
+
951
+ def _infer_log_probs(
952
+ self,
953
+ dataset: Union[List[Dict[str, Any]], DatasetDict],
954
+ return_meta_data: bool = False,
955
+ ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
956
+ return [
957
+ self.get_return_object(
958
+ self.default_inference_value_logprob, instance, return_meta_data
959
+ )
960
+ for instance in dataset
961
+ ]
962
+
963
+ def get_return_object(self, predict_result, instance, return_meta_data):
964
+ if return_meta_data:
965
+ return TextGenerationInferenceOutput(
966
+ prediction=predict_result,
967
+ input_tokens=len(instance["source"]),
968
+ output_tokens=len(predict_result),
969
+ model_name=self.model_name,
970
+ inference_type=self.label,
971
+ input_text=instance["source"],
972
+ seed=111,
973
+ stop_reason="",
974
+ )
975
+ return predict_result
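MockInferenceEngine above is handy for unit tests and dry runs: it never contacts a model and simply returns its configured defaults. A hedged sketch of how it might be exercised (the model name is arbitrary):

    # Sketch only; "[[10]]" is the class's default inference value.
    mock_engine = MockInferenceEngine(model_name="mock-model")
    outputs = mock_engine.infer([{"source": "any prompt"}])
    # outputs == ["[[10]]"]
    detailed = mock_engine.infer([{"source": "any prompt"}], return_meta_data=True)
    # detailed[0].prediction == "[[10]]", with token counts derived from input/output lengths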
976
+
977
+
978
+ class MockModeMixin(Artifact):
979
+ mock_mode: bool = False
980
+
981
+
982
+ class IbmGenAiInferenceEngineParamsMixin(Artifact):
983
+ beam_width: Optional[int] = None
984
+ decoding_method: Optional[Literal["greedy", "sample"]] = None
985
+ include_stop_sequence: Optional[bool] = None
986
+ length_penalty: Any = None
987
+ max_new_tokens: Optional[int] = None
988
+ min_new_tokens: Optional[int] = None
989
+ random_seed: Optional[int] = None
990
+ repetition_penalty: Optional[float] = None
991
+ return_options: Any = None
992
+ stop_sequences: Optional[List[str]] = None
993
+ temperature: Optional[float] = None
994
+ time_limit: Optional[int] = None
995
+ top_k: Optional[int] = None
996
+ top_p: Optional[float] = None
997
+ truncate_input_tokens: Optional[int] = None
998
+ typical_p: Optional[float] = None
999
+
1000
+
1001
+ @deprecation(version="2.0.0", alternative=IbmGenAiInferenceEngineParamsMixin)
1002
+ class IbmGenAiInferenceEngineParams(Artifact):
1003
+ beam_width: Optional[int] = None
1004
+ decoding_method: Optional[Literal["greedy", "sample"]] = None
1005
+ include_stop_sequence: Optional[bool] = None
1006
+ length_penalty: Any = None
1007
+ max_new_tokens: Optional[int] = None
1008
+ min_new_tokens: Optional[int] = None
1009
+ random_seed: Optional[int] = None
1010
+ repetition_penalty: Optional[float] = None
1011
+ return_options: Any = None
1012
+ stop_sequences: Optional[List[str]] = None
1013
+ temperature: Optional[float] = None
1014
+ time_limit: Optional[int] = None
1015
+ top_k: Optional[int] = None
1016
+ top_p: Optional[float] = None
1017
+ truncate_input_tokens: Optional[int] = None
1018
+ typical_p: Optional[float] = None
1019
+
1020
+
1021
+ class GenericInferenceEngine(
1022
+ InferenceEngine, ArtifactFetcherMixin, LogProbInferenceEngine
1023
+ ):
1024
+ default: Optional[str] = None
1025
+
1026
+ def prepare_engine(self):
1027
+ if "UNITXT_INFERENCE_ENGINE" in os.environ:
1028
+ engine_reference = os.environ["UNITXT_INFERENCE_ENGINE"]
1029
+ else:
1030
+ assert self.default is not None, (
1031
+ "GenericInferenceEngine could not be initialized"
1032
+ '\nThis is because the "UNITXT_INFERENCE_ENGINE" environment variable is not set and no default engine was provided.'
1033
+ "\nFor example, you can fix it by setting"
1034
+ "\nexport UNITXT_INFERENCE_ENGINE=engines.ibm_gen_ai.llama_3_70b_instruct"
1035
+ "\nto your ~/.bashrc"
1036
+ "\nor passing a similar required engine in the default argument"
1037
+ )
1038
+ engine_reference = self.default
1039
+ self.engine = self.get_artifact(engine_reference)
1040
+
1041
+ def get_engine_id(self):
1042
+ # If mock_inference_mode is set, no engine is prepared.
1043
+ if hasattr(self, "engine"):
1044
+ return f"generic_{self.engine.get_engine_id()}"
1045
+ return "generic_inference_engine"
1046
+
1047
+ def _infer(
1048
+ self,
1049
+ dataset: Union[List[Dict[str, Any]], DatasetDict],
1050
+ return_meta_data: bool = False,
1051
+ ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
1052
+ return self.engine._infer(dataset)
1053
+
1054
+ def _infer_log_probs(
1055
+ self,
1056
+ dataset: Union[List[Dict[str, Any]], DatasetDict],
1057
+ return_meta_data: bool = False,
1058
+ ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
1059
+ if not isinstance(self.engine, LogProbInferenceEngine):
1060
+ raise NotImplementedError(
1061
+ f"Error in infer: inference engine used by the GenericInferenceEngine"
1062
+ f"({self.engine.__class__.__name__}) does not support logprobs."
1063
+ )
1064
+ return self.engine._infer_log_probs(dataset)
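GenericInferenceEngine above defers the choice of backend to the UNITXT_INFERENCE_ENGINE environment variable, falling back to the 'default' artifact reference. A hedged sketch (the catalog reference is the one quoted in the error message above):

    import os

    # Either export the variable before constructing the engine ...
    os.environ["UNITXT_INFERENCE_ENGINE"] = "engines.ibm_gen_ai.llama_3_70b_instruct"
    engine = GenericInferenceEngine()

    # ... or supply an explicit fallback used only when the variable is unset.
    engine = GenericInferenceEngine(default="engines.ibm_gen_ai.llama_3_70b_instruct")
    predictions = engine.infer(dataset)  # dataset: list of {"source": ...} instances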
1065
+
1066
+
1067
+ class OllamaInferenceEngine(
1068
+ InferenceEngine, StandardAPIParamsMixin, PackageRequirementsMixin
1069
+ ):
1070
+ label: str = "ollama"
1071
+ _requirements_list = {
1072
+ "ollama": "Install ollama package using 'pip install --upgrade ollama"
1073
+ }
1074
+ data_classification_policy = ["public", "proprietary"]
1075
+
1076
+ def get_engine_id(self):
1077
+ return get_model_and_label_id(self.model, self.label)
1078
+
1079
+ def prepare_engine(self):
1080
+ pass
1081
+
1082
+ def _infer(
1083
+ self,
1084
+ dataset: Union[List[Dict[str, Any]], DatasetDict],
1085
+ return_meta_data: bool = False,
1086
+ ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
1087
+ import ollama
1088
+
1089
+ args = self.to_dict([StandardAPIParamsMixin])
1090
+
1091
+ results = []
1092
+
1093
+ for instance in dataset:
1094
+ messages = self.to_messages(instance)
1095
+ response = ollama.chat(
1096
+ model=self.model,
1097
+ messages=messages,
1098
+ **args,
1099
+ )
1100
+ results.append(response)
1101
+
1102
+ return [element["message"]["content"] for element in results]
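OllamaInferenceEngine above converts each instance into chat messages and sends them to a locally running Ollama server. A minimal hedged sketch (assumes the ollama package is installed, a server is running, and the named model has been pulled; the model name is illustrative):

    engine = OllamaInferenceEngine(model="llama3.1")
    answers = engine.infer([{"source": "What is the capital of France?"}])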
1103
+
1104
+
1105
+ class OptionSelectingByLogProbsInferenceEngine:
1106
+ """OptionSelectingByLogProbsInferenceEngine inference engine is used to select an option based on the logprobs of an options list conditioned by a prompt.
1107
+
1108
+ The inference engines that inherit from this class must implement `get_token_count` and `get_options_log_probs`.
1109
+ """
1110
+
1111
+ @abc.abstractmethod
1112
+ def get_token_count(self, dataset):
1113
+ """Get the token count of the source key of each dict of the dataset. Add to each instance in the data a "token_count" field.
1114
+
1115
+ Args:
1116
+ dataset (List[Dict[str, Any]]): A list of dictionaries, each representing a data instance.
1117
+
1118
+ Returns:
1119
+ List[Dict[str, Any]]: The dataset, with a "token_count" field added to each instance.
1120
+ """
1121
+
1122
+ @abc.abstractmethod
1123
+ def get_options_log_probs(self, dataset):
1124
+ """Get the token logprobs of the options of the key task_data.options of each dict of the dataset.
1125
+
1126
+ Add to each instance in the data a "options_log_prob" field, which is a dict with str as key and a list of {text: str, logprob:float}.
1127
+
1128
+ Args:
1129
+ dataset (List[Dict[str, Any]]): A list of dictionaries, each representing a data instance.
1130
+
1131
+ Returns:
1132
+ List[Dict[str, Any]]: The dataset, with an "options_log_prob" field added to each instance.
1133
  """
1134
 
1135
  def select(self, dataset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
 
1213
  }
1214
  data_classification_policy = ["public", "proprietary"]
1215
  parameters: Optional[IbmGenAiInferenceEngineParams] = None
1216
+ rate_limit: int = 10
1217
 
1218
  def get_engine_id(self):
1219
  return get_model_and_label_id(self.model_name, self.label)
1220
 
1221
+ @staticmethod
1222
+ def _get_credentials():
1223
+ from genai import Credentials
1224
 
1225
  api_key_env_var_name = "GENAI_KEY"
1226
  api_key = os.environ.get(api_key_env_var_name)
 
1229
  f"Error while trying to run IbmGenAiInferenceEngine."
1230
  f" Please set the environment param '{api_key_env_var_name}'."
1231
  )
1232
+
1233
+ return Credentials(api_key=api_key)
1234
+
1235
+ def prepare_engine(self):
1236
+ self.check_missing_requirements()
1237
+
1238
+ from genai import Client
1239
+ from genai.text.generation import CreateExecutionOptions
1240
+
1241
+ credentials = self._get_credentials()
1242
  self.client = Client(credentials=credentials)
1243
 
1244
+ self.execution_options = CreateExecutionOptions(
1245
+ concurrency_limit=self.rate_limit
1246
+ )
1247
+
1248
  self._set_inference_parameters()
1249
 
1250
  def _infer(
 
1252
  dataset: Union[List[Dict[str, Any]], DatasetDict],
1253
  return_meta_data: bool = False,
1254
  ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
1255
+ from genai.schema import TextGenerationParameters, TextGenerationResult
1256
+
1257
+ self.verify_not_chat_api(dataset)
1258
 
1259
  genai_params = TextGenerationParameters(
1260
  **self.to_dict([IbmGenAiInferenceEngineParamsMixin])
1261
  )
1262
 
 
1263
  responses = self.client.text.generation.create(
1264
  model_id=self.model_name,
1265
  inputs=[instance["source"] for instance in dataset],
1266
  parameters=genai_params,
1267
+ execution_options=self.execution_options,
1268
  )
1269
+
1270
+ results = []
1271
  for response in responses:
1272
+ generation_result: TextGenerationResult = response.results[0]
1273
  result = self.get_return_object(
1274
+ generation_result.generated_text, generation_result, return_meta_data
1275
  )
1276
  results.append(result)
1277
  return results
 
1281
  dataset: Union[List[Dict[str, Any]], DatasetDict],
1282
  return_meta_data: bool = False,
1283
  ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
1284
+ from genai.schema import TextGenerationParameters, TextGenerationResult
1285
+
1286
+ self.verify_not_chat_api(dataset)
1287
 
1288
  logprobs_return_options = {
1289
  "generated_tokens": True,
 
1302
  model_id=self.model_name,
1303
  inputs=[instance["source"] for instance in dataset],
1304
  parameters=genai_params,
1305
+ execution_options=self.execution_options,
1306
  )
1307
 
1308
  predict_results = []
1309
  for prediction in predictions:
1310
+ result: TextGenerationResult = prediction.results[0]
1311
  assert isinstance(
1312
  result.generated_tokens, list
1313
  ), "result.generated_tokens should be a list"
 
1334
  output_tokens=result.generated_token_count,
1335
  model_name=self.model_name,
1336
  inference_type=self.label,
1337
+ input_text=result.input_text,
1338
+ seed=self.random_seed,
1339
+ stop_reason=result.stop_reason,
1340
  )
1341
  return predict_result
1342
 
1343
+ def get_model_details(self) -> Dict:
1344
+ from genai import ApiClient
1345
+ from genai.model import ModelService
1346
+
1347
+ api_client = ApiClient(credentials=self._get_credentials())
1348
+ model_info = (
1349
+ ModelService(api_client=api_client).retrieve(id=self.model_name).result
1350
+ )
1351
+ return model_info.dict()
1352
+
1353
  def get_token_count(self, dataset):
1354
  texts = [instance["source"] for instance in dataset]
1355
  token_counts = list(
 
1669
  return OpenAI(api_key=api_key, base_url=api_url)
1670
 
1671
 
1672
+ @deprecation(
1673
+ version="2.0.0",
1674
+ msg=" You can specify inference parameters directly when initializing an inference engine.",
1675
+ )
1676
  class WMLInferenceEngineParamsMixin(Artifact):
1677
  decoding_method: Optional[Literal["greedy", "sample"]] = None
1678
  length_penalty: Optional[Dict[str, Union[int, float]]] = None
 
1708
  return_options: Optional[Dict[str, bool]] = None
1709
 
1710
 
1711
+ class WMLGenerationParamsMixin(Artifact):
1712
+ decoding_method: Optional[Literal["greedy", "sample"]] = None
1713
+ length_penalty: Optional[Dict[str, Union[int, float]]] = None
1714
+ temperature: Optional[float] = None
1715
+ top_p: Optional[float] = None
1716
+ top_k: Optional[int] = None
1717
+ random_seed: Optional[int] = None
1718
+ repetition_penalty: Optional[float] = None
1719
+ min_new_tokens: Optional[int] = None
1720
+ max_new_tokens: Optional[int] = None
1721
+ stop_sequences: Optional[List[str]] = None
1722
+ time_limit: Optional[int] = None
1723
+ truncate_input_tokens: Optional[int] = None
1724
+ prompt_variables: Optional[Dict[str, Any]] = None
1725
+ return_options: Optional[Dict[str, bool]] = None
1726
+
1727
+
1728
+ class WMLChatParamsMixin(Artifact):
1729
+ frequency_penalty: Optional[float] = None
1730
+ top_logprobs: Optional[int] = 5
1731
+ presence_penalty: Optional[float] = None
1732
+ response_format: Optional[Dict[str, Any]] = None
1733
+ temperature: Optional[float] = None
1734
+ max_tokens: Optional[int] = None
1735
+ time_limit: Optional[int] = None
1736
+ top_p: Optional[float] = None
1737
+ n: Optional[int] = None
1738
+
1739
+
1740
+ CredentialsWML = Dict[
1741
+ Literal["url", "username", "password", "apikey", "project_id", "space_id"], str
1742
+ ]
1743
+
1744
+
1745
+ class WMLInferenceEngineBase(
1746
  InferenceEngine,
 
1747
  PackageRequirementsMixin,
1748
  LogProbInferenceEngine,
1749
  OptionSelectingByLogProbsInferenceEngine,
1750
  ):
1751
+ """Base for classes running inference using ibm-watsonx-ai.
1752
 
1753
  Attributes:
1754
  credentials (Dict[str, str], optional): By default, it is created by a class
1755
  instance which tries to retrieve proper environment variables
1756
+ ("WML_URL", "WML_PROJECT_ID", "WML_SPACE_ID", "WML_APIKEY", "WML_USERNAME", "WML_PASSWORD").
1757
+ However, a dictionary with the following keys: "url", "apikey", "project_id", "space_id",
1758
+ "username", "password".
1759
+ can be directly provided instead.
1760
  model_name (str, optional): ID of a model to be used for inference. Mutually
1761
  exclusive with 'deployment_id'.
1762
  deployment_id (str, optional): Deployment ID of a tuned model to be used for
1763
  inference. Mutually exclusive with 'model_name'.
1764
+ parameters (Union[WMLInferenceEngineParams, WMLGenerationParamsMixin, WMLChatParamsMixin], optional):
1765
+ Defines inference parameters and their values. Deprecated attribute; please pass the
1766
+ parameters directly to the respective class instead.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1767
  """
1768
 
1769
+ credentials: Optional[CredentialsWML] = None
1770
  model_name: Optional[str] = None
1771
  deployment_id: Optional[str] = None
1772
  label: str = "wml"
1773
  _requirements_list = {
1774
+ "ibm_watsonx_ai": "Install ibm-watsonx-ai package using 'pip install --upgrade ibm-watsonx-ai'. "
1775
  "It is advised to have Python version >=3.10 installed, as at lower version this package "
1776
  "may cause conflicts with other installed packages."
1777
  }
1778
  data_classification_policy = ["public", "proprietary"]
1779
+ parameters: Optional[
1780
+ Union[WMLInferenceEngineParams, WMLGenerationParamsMixin, WMLChatParamsMixin]
1781
+ ] = None
1782
+
1783
  _client: Any = InternalField(default=None, name="WML client")
1784
+ _model: Any = InternalField(default=None, name="WML model")
1785
 
1786
  def get_engine_id(self):
1787
+ return get_model_and_label_id(self.model_name or self.deployment_id, self.label)
1788
 
1789
  def verify(self):
1790
  super().verify()
1791
 
 
 
 
 
 
 
 
1792
  assert (
1793
  self.model_name
1794
  or self.deployment_id
 
1804
  data["credentials"][key] = value
1805
  return data
1806
 
1807
+ def _initialize_wml_client(self):
1808
+ from ibm_watsonx_ai.client import APIClient
1809
+
1810
+ if self.credentials is None:
1811
+ self.credentials = self._read_wml_credentials_from_env()
1812
+ self._verify_wml_credentials(self.credentials)
1813
+
1814
+ client = APIClient(credentials=self.credentials)
1815
+ if "space_id" in self.credentials:
1816
+ client.set.default_space(self.credentials["space_id"])
1817
+ else:
1818
+ client.set.default_project(self.credentials["project_id"])
1819
+ return client
1820
+
1821
  @staticmethod
1822
+ def _read_wml_credentials_from_env() -> CredentialsWML:
1823
+ credentials: CredentialsWML = {}
1824
+
1825
+ url = os.environ.get("WML_URL")
1826
+ assert url, (
1827
+ "Error while trying to run 'WMLInferenceEngine'. "
1828
+ "Please set the env variable: 'WML_URL'"
1829
  )
1830
+ credentials["url"] = url
1831
 
1832
+ space_id = os.environ.get("WML_SPACE_ID")
1833
+ project_id = os.environ.get("WML_PROJECT_ID")
1834
+ if space_id and project_id:
1835
+ get_logger().warning(
1836
+ "Either 'WML_SPACE_ID' or 'WML_PROJECT_ID' need to be "
1837
+ "specified, however, both were found. 'WMLInferenceEngine' "
1838
+ "will use space by default. If it is not desired, then have "
1839
+ "only one of those defined in the env."
1840
+ )
1841
+ credentials["space_id"] = space_id
1842
+ elif project_id:
1843
+ credentials["project_id"] = project_id
1844
+ else:
1845
+ raise AssertionError(
1846
+ "Error while trying to run 'WMLInferenceEngine'. "
1847
+ "Please set either 'WML_SPACE_ID' or 'WML_PROJECT_ID' env "
1848
+ "variable."
1849
+ )
1850
+
1851
+ apikey = os.environ.get("WML_APIKEY")
1852
+ username = os.environ.get("WML_USERNAME")
1853
+ password = os.environ.get("WML_PASSWORD")
1854
+
1855
+ if apikey and username and password:
1856
+ get_logger().warning(
1857
+ "Either 'WML_APIKEY' or both 'WML_USERNAME' and 'WML_PASSWORD' "
1858
+ "need to be specified, however, all of them were found. "
1859
+ "'WMLInferenceEngine' will use api key only by default. If it is not "
1860
+ "desired, then have only one of those options defined in the env."
1861
  )
1862
 
1863
+ if apikey:
1864
+ credentials["apikey"] = apikey
1865
+ elif username and password:
1866
+ credentials["username"] = username
1867
+ credentials["password"] = password
1868
+ else:
1869
+ raise AssertionError(
1870
+ "Error while trying to run 'WMLInferenceEngine'. "
1871
+ "Please set either 'WML_APIKEY' or both 'WML_USERNAME' and "
1872
+ "'WML_PASSWORD' env variables."
1873
+ )
1874
 
1875
  return credentials
1876
 
1877
+ @staticmethod
1878
+ def _verify_wml_credentials(credentials: CredentialsWML) -> None:
1879
+ assert isoftype(credentials, CredentialsWML), (
1880
+ "WML credentials object must be a dictionary which may "
1881
+ "contain only the following keys: "
1882
+ "['url', 'apikey', 'username', 'password']."
1883
+ )
1884
 
1885
+ assert credentials.get(
1886
+ "url"
1887
+ ), "'url' is a mandatory key for WML credentials dict."
1888
+ assert "space_id" in credentials or "project_id" in credentials, (
1889
+ "Either 'space_id' or 'project_id' must be provided "
1890
+ "as keys for WML credentials dict."
1891
+ )
1892
+ assert "apikey" in credentials or (
1893
+ "username" in credentials and "password" in credentials
1894
+ ), (
1895
+ "Either 'apikey' or both 'username' and 'password' must be provided "
1896
+ "as keys for WML credentials dict."
1897
+ )
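# A hedged sketch of credentials supplied directly instead of being read from the
# environment (values are illustrative): "url" is mandatory, one of "project_id" /
# "space_id" is required, plus either "apikey" or the "username"/"password" pair:
#
#     wml_credentials = {
#         "url": "https://us-south.ml.cloud.ibm.com",
#         "apikey": "<your-api-key>",
#         "project_id": "<your-project-id>",
#     }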
1898
 
1899
  def prepare_engine(self):
1900
+ self.check_missing_requirements()
1901
+
1902
  self._client = self._initialize_wml_client()
1903
 
1904
  self._set_inference_parameters()
1905
 
1906
+ def _load_model(self):
1907
+ from ibm_watsonx_ai.foundation_models.inference import ModelInference
1908
 
1909
+ self._model = ModelInference(
1910
  model_id=self.model_name,
1911
  deployment_id=self.deployment_id,
1912
  api_client=self._client,
1913
  )
 
1914
 
1915
+ @abc.abstractmethod
1916
+ def _send_requests(
1917
+ self,
1918
+ dataset: Union[List[Dict[str, Any]], DatasetDict],
1919
+ return_logprobs: bool,
1920
+ return_meta_data: bool,
1921
+ ) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]:
1922
+ raise NotImplementedError(
1923
+ f"The class '{self.get_pretty_print_name()}' is an abstract class. "
1924
+ f"Please used either 'WMLInferenceEngineGeneration' or "
1925
+ f"'WMLInferenceEngineChat' instead, depending on your task."
1926
+ )
1927
 
1928
  def _infer(
1929
  self,
1930
  dataset: Union[List[Dict[str, Any]], DatasetDict],
1931
  return_meta_data: bool = False,
1932
  ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
1933
+ if self._model is None:
1934
+ self._load_model()
 
 
 
 
 
 
 
 
 
 
 
 
1935
 
1936
+ return self._send_requests(
1937
+ dataset=dataset,
1938
+ return_logprobs=False,
1939
+ return_meta_data=return_meta_data,
1940
+ )
1941
 
1942
  def _infer_log_probs(
1943
  self,
1944
  dataset: Union[List[Dict[str, Any]], DatasetDict],
1945
  return_meta_data: bool = False,
1946
  ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
1947
+ if self._model is None:
1948
+ self._load_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1949
 
1950
+ return self._send_requests(
1951
+ dataset=dataset,
1952
+ return_logprobs=True,
1953
+ return_meta_data=return_meta_data,
1954
  )
 
 
 
 
 
 
 
1955
 
1956
+ @abc.abstractmethod
1957
+ def get_return_object(self, predict_result, result, input_text, return_meta_data):
1958
+ raise NotImplementedError
1959
+
1960
+ def get_model_details(self) -> Dict:
1961
+ return self._model.get_details()
 
 
 
 
1962
 
1963
  def get_token_count(self, dataset):
1964
+ if self._model is None:
1965
+ self._load_model()
1966
 
1967
  texts = [instance["source"] for instance in dataset]
1968
 
 
 
 
 
 
 
1969
  for i in trange(len(texts), desc="Tokenizing"):
1970
+ response = self._model.tokenize(prompt=texts[i], return_tokens=True)[
1971
+ "result"
1972
+ ]
1973
  dataset[i]["token_count"] = response["token_count"]
1974
 
1975
  return dataset
1976
 
1977
  def get_options_log_probs(self, dataset):
1978
  """Add to each instance in the data a "options_log_prob" field, which is a dict with str as key and a list of {text: str, logprob:float}."""
1979
+ if self._model is None:
1980
+ self._load_model()
 
 
 
 
 
1981
 
1982
  texts = [x["source"] for x in dataset]
1983
 
1984
  responses = list(
1985
  tqdm(
1986
+ self._model.generate(
1987
  prompt=texts,
1988
  params={
1989
  "decoding_method": "greedy",
 
2015
  return dataset
2016
 
2017
 
2018
+ class WMLInferenceEngineGeneration(WMLInferenceEngineBase, WMLGenerationParamsMixin):
2019
+ """Generates text for textual inputs.
2020
 
2021
+ If you want to include images in your input, please use 'WMLInferenceEngineChat' instead.
2022
 
2023
+ Attributes:
2024
+ concurrency_limit (int): Number of concurrent requests sent to a model. Default is 10,
2025
+ which is also the maximum value.
2026
 
2027
+ Examples:
2028
+ from .api import load_dataset
2029
 
2030
+ wml_credentials = {
2031
+ "url": "some_url", "project_id": "some_id", "api_key": "some_key"
2032
+ }
2033
+ model_name = "google/flan-t5-xxl"
2034
+ wml_inference = WMLInferenceEngineGeneration(
2035
+ credentials=wml_credentials,
2036
+ model_name=model_name,
2037
+ data_classification_policy=["public"],
2038
+ top_p=0.5,
2039
+ random_seed=123,
2040
+ )
2041
 
2042
+ dataset = load_dataset(
2043
+ dataset_query="card=cards.argument_topic,template_card_index=0,loader_limit=5"
2044
+ )
2045
+ results = wml_inference.infer(dataset["test"])
2046
+ """
2047
 
2048
+ concurrency_limit: int = 10
 
2049
 
2050
+ def verify(self):
2051
+ super().verify()
 
2052
 
2053
+ assert (
2054
+ isinstance(self.concurrency_limit, int)
2055
+ and 1 <= self.concurrency_limit <= 10
2056
+ ), (
2057
+ f"'concurrency_limit' must be a positive integer not greater than 10. "
2058
+ f"However, '{self.concurrency_limit}' was given."
2059
  )
2060
 
2061
+ def _set_logprobs_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
2062
+ user_return_options = params.pop("return_options", {})
2063
+ # currently this is the only configuration that returns generated
2064
+ # logprobs and behaves as expected
2065
+ logprobs_return_options = {
2066
+ "input_tokens": True,
2067
+ "generated_tokens": True,
2068
+ "token_logprobs": True,
2069
+ "top_n_tokens": user_return_options.get("top_n_tokens", 5),
2070
+ }
 
2071
 
2072
+ for key, value in logprobs_return_options.items():
2073
+ if key in user_return_options and user_return_options[key] != value:
2074
+ raise ValueError(
2075
+ f"'{key}={user_return_options[key]}' is not supported for the 'infer_log_probs' "
2076
+ f"method of {self.__class__.__name__}. For obtaining the logprobs of generated tokens "
2077
+ f"please use '{key}={value}'."
2078
+ )
2079
 
2080
+ return {
2081
+ **params,
2082
+ "return_options": logprobs_return_options,
2083
+ }
 
 
 
 
 
 
 
 
 
 
2084
 
2085
+ def _send_requests(
2086
  self,
2087
  dataset: Union[List[Dict[str, Any]], DatasetDict],
2088
+ return_logprobs: bool,
2089
+ return_meta_data: bool,
2090
+ ) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]:
2091
+ self.verify_not_chat_api(dataset)
2092
 
2093
+ params = self.to_dict([WMLGenerationParamsMixin], keep_empty=False)
2094
 
2095
+ if return_logprobs:
2096
+ generation_type = "generated_tokens"
2097
+ params = self._set_logprobs_params(params)
2098
+ else:
2099
+ generation_type = "generated_text"
2100
 
2101
+ inputs: List[str] = [instance["source"] for instance in dataset]
 
2102
 
2103
+ results = self._model.generate(
2104
+ prompt=inputs,
2105
+ params=params,
2106
+ concurrency_limit=self.concurrency_limit,
2107
+ )
2108
 
2109
+ final_results = []
2110
+ for result, inp in zip(results, inputs):
2111
+ result_metadata = result["results"][0]
2112
+ generated_content = result_metadata[generation_type]
2113
+ final_results.append(
2114
+ self.get_return_object(
2115
+ generated_content, result_metadata, inp, return_meta_data
2116
+ )
2117
  )
2118
+ return final_results
2119
 
2120
+ def get_return_object(self, predict_result, result, input_text, return_meta_data):
2121
+ if return_meta_data:
2122
+ return TextGenerationInferenceOutput(
2123
+ prediction=predict_result,
2124
+ input_tokens=result["input_token_count"],
2125
+ output_tokens=result["generated_token_count"],
2126
+ model_name=self.model_name or self.deployment_id,
2127
+ inference_type=self.label,
2128
+ stop_reason=result["stop_reason"],
2129
+ seed=self.random_seed,
2130
+ input_text=input_text,
2131
  )
2132
+ return predict_result
2133
+
2134
+
2135
+ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin):
2136
+ """Creates chat session and returns a model's response.
2137
+
2138
+ You can also include images in your inputs. If you use only textual input, it is
2139
+ recommended to use 'WMLInferenceEngineGeneration' instead as it is faster, and allows
2140
+ more parameters for text generation.
2141
+
2142
+ You can provide either already formatted messages, or a raw dataset as an input.
2143
+ In case of the former, all passed images should be base64-encoded strings given as
2144
+ an 'image_url' within a message. Moreover, only one image per list of messages
2145
+ may be sent.
2146
+ As for the latter, if there are multiple images per one instance, they will be sent
2147
+ separately with the same query. If that could possibly affect expected responses,
2148
+ concatenate images within an instance into a single image and adjust your query
2149
+ accordingly (if necessary).
2150
+
2151
+ Attributes:
2152
+ image_encoder (EncodeImageToString, optional): operator which encodes images in
2153
+ a given format to base64 strings required by the service. You should specify it when
2154
+ you are using images in your inputs.
2155
+
2156
+ Example:
2157
+ from .api import load_dataset
2158
+ from .image_operators import EncodeImageToString
2159
+
2160
+ image_encoder = EncodeImageToString(image_format="JPEG")
2161
+
2162
+ wml_credentials = {
2163
+ "url": "some_url", "project_id": "some_id", "api_key": "some_key"
2164
+ }
2165
+ model_name = "meta-llama/llama-3-2-11b-vision-instruct"
2166
+ wml_inference = WMLInferenceEngineChat(
2167
+ credentials=wml_credentials,
2168
+ model_name=model_name,
2169
+ image_encoder=image_encoder,
2170
+ data_classification_policy=["public"],
2171
+ max_tokens=1024,
2172
+ )
2173
+
2174
+ dataset = load_dataset(
2175
+ dataset_query="card=cards.doc_vqa.en,template=templates.qa.with_context.with_type,loader_limit=30"
2176
+ )
2177
+ results = wml_inference.infer(dataset["test"])
2178
+ """
2179
+
2180
+ image_encoder: Optional[EncodeImageToString] = None
2181
+
2182
+ @staticmethod
2183
+ def _extract_queries(instance: Dict[str, Any]) -> Tuple[Optional[str], List]:
2184
+ task_data = instance["task_data"]
2185
+ if isinstance(task_data, str):
2186
+ task_data = json.loads(task_data)
2187
+ question = task_data.get("question")
2188
+
2189
+ images = [None]
2190
+ if "images" in instance["media"]:
2191
+ images = extract_images(instance["source"], instance)
2192
+
2193
+ return question or instance["source"], images
2194
+
2195
+ def _create_messages_from_instance(
2196
+ self, instance: Dict[str, Any]
2197
+ ) -> List[List[Dict[str, Any]]]:
2198
+ """Method creates chat messages to be sent to a watsonx.ai model based on a given instance from a dataset."""
2199
+ text, images = self._extract_queries(instance)
2200
+
2201
+ messages: List[List[Dict[str, Any]]] = []
2202
+ base_message = {
2203
+ "role": "user",
2204
+ "content": [
2205
+ {
2206
+ "type": "text",
2207
+ "text": text,
2208
+ }
2209
+ ],
2210
+ }
2211
+
2212
+ # Iteration over all possible images to create a separate message for
2213
+ # every single image, since SDK allows only one image per request.
2214
+ for image in images:
2215
+ message = base_message.copy()
2216
+
2217
+ if image is not None:
2218
+ encoded_image = image
2219
+ if not isinstance(encoded_image, str):
2220
+ if self.image_encoder is None:
2221
+ raise ValueError(
2222
+ "If sending image queries as well, and they are not "
2223
+ "already encoded to base64 strings, you must specify "
2224
+ "the 'image_encoder' to be used."
2225
+ )
2226
+ encoded_image = self.image_encoder.encode_image_to_base64(image)
2227
+
2228
+ message["content"].append(
2229
+ {
2230
+ "type": "image_url",
2231
+ "image_url": {
2232
+ "url": "data:image/jpeg;base64," + encoded_image,
2233
+ },
2234
+ }
2235
+ )
2236
+
2237
+ messages.append([message])
2238
+
2239
+ return messages
2240
+
2241
+ @staticmethod
2242
+ def verify_messages(messages: List[Dict[str, Any]]):
2243
+ """Method verifies if externally provided messages containing images are compatible with the format required by ibm-watsonx-ai."""
2244
+ n_images = 0
2245
+ for message in messages:
2246
+ if isinstance(message["content"], str):
2247
+ continue
2248
+
2249
+ for content in message["content"]:
2250
+ if isinstance(content, dict):
2251
+ if "image" in content["type"] and content["type"] != "image_url":
2252
+ raise ValueError(
2253
+ f"ibm-watsonx-ai only supports sending images as base64-encoded "
2254
+ f"strings, which should be given as 'image_url' in a message. "
2255
+ f"However, '{content['type']}' was given."
2256
+ )
2257
+
2258
+ if content["type"] == "image_url":
2259
+ n_images += 1
2260
+ if n_images > 1:
2261
+ raise ValueError(
2262
+ "ibm-watsonx-ai only supports sending one image per a list "
2263
+ "of messages."
2264
+ )
2265
+
2266
+ def to_messages(self, instance: Union[Dict, List]) -> List[List[Dict[str, Any]]]:
2267
+ if isinstance(instance["source"], str) and "media" in instance:
2268
+ return self._create_messages_from_instance(instance)
2269
+
2270
+ messages = super().to_messages(instance)
2271
+ self.verify_messages(messages)
2272
+ # This is done to be compatible with inputs containing
2273
+ # images as SDK allows sending only one image per message.
2274
+ return [messages]
2275
+
2276
+ def _send_requests(
2277
+ self,
2278
+ dataset: Union[List[Dict[str, Any]], DatasetDict],
2279
+ return_logprobs: bool,
2280
+ return_meta_data: bool,
2281
+ ) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]:
2282
+ params = self.to_dict([WMLChatParamsMixin], keep_empty=False)
2283
+
2284
+ if return_logprobs:
2285
+ output_type = "logprobs"
2286
+ params["logprobs"] = True
2287
+ else:
2288
+ output_type = "message"
2289
+ params["logprobs"] = False
2290
+
2291
+ final_results = []
2292
+
2293
+ for instance in dataset:
2294
+ messages = self.to_messages(instance)
2295
+
2296
+ for message in messages:
2297
+ result = self._model.chat(
2298
+ messages=message,
2299
+ params=params,
2300
+ )
2301
+
2302
+ final_results.append(
2303
+ self.get_return_object(
2304
+ result["choices"][0][output_type]["content"],
2305
+ result,
2306
+ instance["source"],
2307
+ return_meta_data,
2308
+ )
2309
+ )
2310
+
2311
+ return final_results
2312
+
2313
+ def get_return_object(self, predict_result, result, input_text, return_meta_data):
2314
+ if return_meta_data:
2315
+ return TextGenerationInferenceOutput(
2316
+ prediction=predict_result,
2317
+ input_tokens=result["usage"]["prompt_tokens"],
2318
+ output_tokens=len(predict_result)
2319
+ if isinstance(predict_result, list)
2320
+ else None,
2321
+ model_name=self.model_name or self.deployment_id,
2322
+ inference_type=self.label,
2323
+ stop_reason=result["choices"][0]["finish_reason"],
2324
+ input_text=input_text,
2325
  )
2326
+ return predict_result
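When pre-formatted messages are passed to WMLInferenceEngineChat, images must already be base64-encoded and attached as a single 'image_url' entry, mirroring the structure built in _create_messages_from_instance above. A hedged sketch of such a message list (the base64 payload is a placeholder):

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is shown in the picture?"},
                {
                    "type": "image_url",
                    "image_url": {"url": "data:image/jpeg;base64,<BASE64_STRING>"},
                },
            ],
        }
    ]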
2327
 
2328
+
2329
+ @deprecation(
2330
+ version="2.0.0",
2331
+ msg=" Please use either 'WMLInferenceEngineGeneration' or 'WMLInferenceEngineChat'"
2332
+ " depending on your task.",
2333
+ )
2334
+ class WMLInferenceEngine(WMLInferenceEngineGeneration):
2335
+ def prepare_engine(self):
2336
+ super().prepare_engine()
2337
+ get_logger().warning("'WMLInferenceEngine' is deprecated")
2338
+
2339
+
2340
+ def get_images_without_text(instance):
2341
+ return extract_images(instance["source"], instance)
2342
+
2343
+
2344
+ def get_text_without_images(instance, image_token="<image>"):
2345
+ regex = r"<" + f"{constants.image_tag}" + r'\s+src=["\'](.*?)["\']\s*/?>'
2346
+ return re.sub(regex, image_token, instance["source"])
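The two helpers above split a multimodal "source" into its image and text parts; the text helper replaces each inline image tag with a placeholder token. A hedged sketch, assuming constants.image_tag is "img" so that the tags look like HTML img elements:

    instance = {"source": 'Describe the chart. <img src="chart.png"> Thanks.'}
    get_text_without_images(instance)
    # -> 'Describe the chart. <image> Thanks.'
    get_text_without_images(instance, image_token="<|image|>")
    # -> 'Describe the chart. <|image|> Thanks.'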
2347
 
2348
 
2349
  class LMMSEvalBaseInferenceEngine(
 
2354
  batch_size: int = 1
2355
  image_token = "<image>"
2356
 
2357
+ _requirements_list = {
2358
+ "lmms_eval": "Install llms-eval package using 'pip install lmms-eval==0.2.4'",
2359
+ }
2360
 
2361
  def prepare_engine(self):
2362
  if not self.lazy_load:
 
2403
  dataset: Union[List[Dict[str, Any]], DatasetDict],
2404
  return_meta_data: bool = False,
2405
  ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
 
2406
  if not self._is_loaded():
2407
  self._prepare_engine()
2408
 
llm_as_judge.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from abc import abstractmethod
2
  from typing import Any, Dict, List, Literal, Optional
3
 
@@ -23,7 +24,7 @@ def get_task_data_dict(task_data):
23
  return json.loads(task_data) if isinstance(task_data, str) else task_data
24
 
25
 
26
- class LLMAsJudgeBase(BulkInstanceMetric):
27
  """LLM-as-judge-base metric class for evaluating correctness of generated predictions.
28
 
29
  Attributes:
@@ -122,7 +123,7 @@ class LLMAsJudgeBase(BulkInstanceMetric):
122
  pass
123
 
124
 
125
- class LLMAsJudge(LLMAsJudgeBase, ArtifactFetcherMixin):
126
  """LLM-as-judge-based metric class for evaluating correctness of generated predictions.
127
 
128
  This class uses the source prompt given to the generator and the generator's predictions to evaluate
@@ -371,6 +372,17 @@ class TaskBasedLLMasJudge(LLMAsJudgeBase):
371
  super().prepare()
372
  self.reduction_map = {"mean": [self.main_score]}
373
  self.score_prefix = f"{self.inference_model.get_engine_id()}_"
 
 
 
 
 
 
 
 
 
 
 
374
 
375
  def get_full_task_name(self):
376
  return self.task
 
1
+ import re
2
  from abc import abstractmethod
3
  from typing import Any, Dict, List, Literal, Optional
4
 
 
24
  return json.loads(task_data) if isinstance(task_data, str) else task_data
25
 
26
 
27
+ class LLMAsJudgeBase(BulkInstanceMetric, ArtifactFetcherMixin):
28
  """LLM-as-judge-base metric class for evaluating correctness of generated predictions.
29
 
30
  Attributes:
 
123
  pass
124
 
125
 
126
+ class LLMAsJudge(LLMAsJudgeBase):
127
  """LLM-as-judge-based metric class for evaluating correctness of generated predictions.
128
 
129
  This class uses the source prompt given to the generator and the generator's predictions to evaluate
 
372
  super().prepare()
373
  self.reduction_map = {"mean": [self.main_score]}
374
  self.score_prefix = f"{self.inference_model.get_engine_id()}_"
375
+ if not self.format:
376
+ self.set_format_for_inference_engine()
377
+
378
+ # if format is not directly set in constructor, choose according to the inference model
379
+ def set_format_for_inference_engine(self):
380
+ model_name = self.inference_model.get_engine_id()
381
+ if re.search("llama.?3.*instruct", model_name):
382
+ format_name = "formats.llama3_instruct"
383
+ else:
384
+ format_name = "formats.empty"
385
+ self.format = self.get_artifact(format_name)
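# For illustration (engine ids below are hypothetical): an id such as
# "meta-llama/llama-3-8b-instruct_ibm_genai" matches the pattern above and selects
# "formats.llama3_instruct", while an id like "google/flan-t5-xxl_hf" does not match
# and falls back to "formats.empty".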
386
 
387
  def get_full_task_name(self):
388
  return self.task
loaders.py CHANGED
@@ -1,7 +1,7 @@
1
  """This section describes unitxt loaders.
2
 
3
 Loaders: Generators of Unitxt Multistreams from existing data sources
4
- ==============================================================
5
 
6
 Unitxt is all about readily preparing any given data source for feeding into any given language model, and then,
7
  post-processing the model's output, preparing it for any given evaluator.
@@ -16,14 +16,14 @@ All these loaders inherit from Loader, and hence, implementing a loader to expan
16
  straightforward.
17
 
18
  Available Loaders Overview:
19
- - :ref:`LoadHF <unitxt.loaders.LoadHF>` - Loads data from HuggingFace Datasets.
20
- - :ref:`LoadCSV <unitxt.loaders.LoadCSV>` - Imports data from CSV (Comma-Separated Values) files.
21
- - :ref:`LoadFromKaggle <unitxt.loaders.LoadFromKaggle>` - Retrieves datasets from the Kaggle community site.
22
- - :ref:`LoadFromIBMCloud <unitxt.loaders.LoadFromIBMCloud>` - Fetches datasets hosted on IBM Cloud.
23
- - :ref:`LoadFromSklearn <unitxt.loaders.LoadFromSklearn>` - Loads datasets available through the sklearn library.
24
- - :ref:`MultipleSourceLoader <unitxt.loaders.MultipleSourceLoader>` - Combines data from multiple different sources.
25
- - :ref:`LoadFromDictionary <unitxt.loaders.LoadFromDictionary>` - Loads data from a user-defined Python dictionary.
26
- - :ref:`LoadFromHFSpace <unitxt.loaders.LoadFromHFSpace>` - Downloads and loads data from HuggingFace Spaces.
27
 
28
 
29
 
 
1
  """This section describes unitxt loaders.
2
 
3
 Loaders: Generators of Unitxt Multistreams from existing data sources
4
+ =====================================================================
5
 
6
 Unitxt is all about readily preparing any given data source for feeding into any given language model, and then,
7
  post-processing the model's output, preparing it for any given evaluator.
 
16
  straightforward.
17
 
18
  Available Loaders Overview:
19
+ - :class:`LoadHF <unitxt.loaders.LoadHF>` - Loads data from HuggingFace Datasets.
20
+ - :class:`LoadCSV <unitxt.loaders.LoadCSV>` - Imports data from CSV (Comma-Separated Values) files.
21
+ - :class:`LoadFromKaggle <unitxt.loaders.LoadFromKaggle>` - Retrieves datasets from the Kaggle community site.
22
+ - :class:`LoadFromIBMCloud <unitxt.loaders.LoadFromIBMCloud>` - Fetches datasets hosted on IBM Cloud.
23
+ - :class:`LoadFromSklearn <unitxt.loaders.LoadFromSklearn>` - Loads datasets available through the sklearn library.
24
+ - :class:`MultipleSourceLoader <unitxt.loaders.MultipleSourceLoader>` - Combines data from multiple different sources.
25
+ - :class:`LoadFromDictionary <unitxt.loaders.LoadFromDictionary>` - Loads data from a user-defined Python dictionary.
26
+ - :class:`LoadFromHFSpace <unitxt.loaders.LoadFromHFSpace>` - Downloads and loads data from HuggingFace Spaces.
27
 
28
 
29
 
metrics.py CHANGED
@@ -18,6 +18,7 @@ from scipy.stats import bootstrap
18
  from scipy.stats._warnings_errors import DegenerateDataWarning
19
 
20
  from .artifact import Artifact
 
21
  from .dataclass import (
22
  AbstractField,
23
  InternalField,
@@ -50,6 +51,12 @@ settings = get_settings()
50
  warnings.filterwarnings("ignore", category=DegenerateDataWarning)
51
 
52
 
 
 
 
 
 
 
53
  def abstract_factory():
54
  return {}
55
 
 
18
  from scipy.stats._warnings_errors import DegenerateDataWarning
19
 
20
  from .artifact import Artifact
21
+ from .collections import ListCollection
22
  from .dataclass import (
23
  AbstractField,
24
  InternalField,
 
51
  warnings.filterwarnings("ignore", category=DegenerateDataWarning)
52
 
53
 
54
+ class MetricsList(ListCollection):
55
+ def verify(self):
56
+ for metric in self.items:
57
+ assert isinstance(metric, Metric)
58
+
59
+
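MetricsList above lets a single catalog entry bundle several metrics; Task.check_metrics_type and ApplyMetric (see task.py and operators.py below) unpack it back into individual Metric objects. A hedged sketch, assuming ListCollection supports no-argument construction and an append() method, with the member metrics fetched from the catalog elsewhere:

    bundle = MetricsList()
    bundle.append(rouge_metric)      # rouge_metric, accuracy_metric: Metric instances
    bundle.append(accuracy_metric)
    bundle.verify()                  # asserts that every item is a Metric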
60
  def abstract_factory():
61
  return {}
62
 
operators.py CHANGED
@@ -1617,7 +1617,7 @@ class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
1617
  calc_confidence_intervals: bool
1618
 
1619
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
1620
- from .metrics import Metric
1621
 
1622
  # Number of instances in input stream is assumed to be small. This is why
1623
  # each metric consumes all of them and lays them in its main memory, and even generates
@@ -1646,18 +1646,25 @@ class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
1646
  if isinstance(metric_names, str):
1647
  metric_names = [metric_names]
1648
 
 
 
 
 
 
 
 
 
 
 
 
 
1649
  # Each metric operator computes its score and then sets the main score, overwriting
1650
  # the previous main score value (if any). So, we need to reverse the order of the listed metrics.
1651
  # This will cause the first listed metric to run last, and the main score will be set
1652
  # by the first listed metric (as desired).
1653
- metric_names = list(reversed(metric_names))
1654
-
1655
- for metric_name in metric_names:
1656
- metric = self.get_artifact(metric_name)
1657
- assert isinstance(
1658
- metric, Metric
1659
- ), f"Operator {metric_name} must be a Metric"
1660
 
 
1661
  if not self.calc_confidence_intervals:
1662
  metric.disable_confidence_interval_calculation()
1663
  multi_stream = MultiStream(
 
1617
  calc_confidence_intervals: bool
1618
 
1619
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
1620
+ from .metrics import Metric, MetricsList
1621
 
1622
  # Number of instances in input stream is assumed to be small. This is why
1623
  # each metric consumes all of them and lays them in its main memory, and even generates
 
1646
  if isinstance(metric_names, str):
1647
  metric_names = [metric_names]
1648
 
1649
+ metrics_list = []
1650
+ for metric_name in metric_names:
1651
+ metric = self.get_artifact(metric_name)
1652
+ if isinstance(metric, MetricsList):
1653
+ metrics_list.extend(list(metric.items))
1654
+ elif isinstance(metric, Metric):
1655
+ metrics_list.append(metric)
1656
+ else:
1657
+ raise ValueError(
1658
+ f"Operator {metric_name} must be a Metric or MetricsList"
1659
+ )
1660
+
1661
  # Each metric operator computes its score and then sets the main score, overwriting
1662
  # the previous main score value (if any). So, we need to reverse the order of the listed metrics.
1663
  # This will cause the first listed metric to run last, and the main score will be set
1664
  # by the first listed metric (as desired).
1665
+ metrics_list = list(reversed(metrics_list))
 
 
 
 
 
 
1666
 
1667
+ for metric in metrics_list:
1668
  if not self.calc_confidence_intervals:
1669
  metric.disable_confidence_interval_calculation()
1670
  multi_stream = MultiStream(
settings_utils.py CHANGED
@@ -161,8 +161,8 @@ if Constants.is_uninitilized():
161
  constants.metric_file = os.path.join(os.path.dirname(__file__), "metric.py")
162
  constants.local_catalog_path = os.path.join(os.path.dirname(__file__), "catalog")
163
  unitxt_pkg = importlib.util.find_spec("unitxt")
164
- constants.package_dir = os.path.dirname(unitxt_pkg.origin)
165
  if unitxt_pkg and unitxt_pkg.origin:
 
166
  constants.default_catalog_path = os.path.join(constants.package_dir, "catalog")
167
  else:
168
  constants.default_catalog_path = constants.local_catalog_path
 
161
  constants.metric_file = os.path.join(os.path.dirname(__file__), "metric.py")
162
  constants.local_catalog_path = os.path.join(os.path.dirname(__file__), "catalog")
163
  unitxt_pkg = importlib.util.find_spec("unitxt")
 
164
  if unitxt_pkg and unitxt_pkg.origin:
165
+ constants.package_dir = os.path.dirname(unitxt_pkg.origin)
166
  constants.default_catalog_path = os.path.join(constants.package_dir, "catalog")
167
  else:
168
  constants.default_catalog_path = constants.local_catalog_path
standard.py CHANGED
@@ -1,9 +1,7 @@
1
  from typing import List, Optional, Union
2
 
3
  from .artifact import fetch_artifact
4
- from .augmentors import (
5
- Augmentor,
6
- )
7
  from .card import TaskCard
8
  from .collections_operators import GetLength
9
  from .dataclass import Field, InternalField, NonPositionalField, OptionalField
@@ -21,6 +19,7 @@ from .stream import MultiStream
21
  from .system_prompts import EmptySystemPrompt, SystemPrompt
22
  from .task import Task
23
  from .templates import ApplyRandomTemplate, ApplySingleTemplate, Template, TemplatesList
 
24
  from .utils import LRUCache
25
 
26
  constants = get_constants()
@@ -305,7 +304,7 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
305
 
306
  self.processing.steps.append(self.task)
307
 
308
- if self.augmentor is not None:
309
  if (
310
  self.card.task.augmentable_inputs is None
311
  or len(self.task.augmentable_inputs) == 0
@@ -484,14 +483,12 @@ class StandardRecipe(StandardRecipeWithIndexes):
484
  sampler (Sampler, optional): The Sampler used to select the demonstrations when num_demos > 0.
485
  steps (List[StreamingOperator], optional): List of StreamingOperator objects to be used in the recipe.
486
 augmentor (Augmentor): Augmentor to be used to pseudo-randomly augment the source text
487
- instruction_card_index (int, optional): Index of instruction card to be used
488
- for preparing the recipe.
489
- template_card_index (int, optional): Index of template card to be used for
490
- preparing the recipe.
491
 
492
  Methods:
493
  prepare(): This overridden method is used for preparing the recipe
494
- by arranging all the steps, refiners, and renderers in a sequential manner.
495
 
496
  Raises:
497
  AssertionError: If both template and template_card_index are specified at the same time.
 
1
  from typing import List, Optional, Union
2
 
3
  from .artifact import fetch_artifact
4
+ from .augmentors import Augmentor, NullAugmentor
 
 
5
  from .card import TaskCard
6
  from .collections_operators import GetLength
7
  from .dataclass import Field, InternalField, NonPositionalField, OptionalField
 
19
  from .system_prompts import EmptySystemPrompt, SystemPrompt
20
  from .task import Task
21
  from .templates import ApplyRandomTemplate, ApplySingleTemplate, Template, TemplatesList
22
+ from .type_utils import isoftype
23
  from .utils import LRUCache
24
 
25
  constants = get_constants()
 
304
 
305
  self.processing.steps.append(self.task)
306
 
307
+ if self.augmentor is not None and not isoftype(self.augmentor, NullAugmentor):
308
  if (
309
  self.card.task.augmentable_inputs is None
310
  or len(self.task.augmentable_inputs) == 0
 
483
  sampler (Sampler, optional): The Sampler used to select the demonstrations when num_demos > 0.
484
  steps (List[StreamingOperator], optional): List of StreamingOperator objects to be used in the recipe.
485
 augmentor (Augmentor): Augmentor to be used to pseudo-randomly augment the source text
486
+ instruction_card_index (int, optional): Index of instruction card to be used for preparing the recipe.
487
+ template_card_index (int, optional): Index of template card to be used for preparing the recipe.
 
 
488
 
489
  Methods:
490
  prepare(): This overridden method is used for preparing the recipe
491
+ by arranging all the steps, refiners, and renderers in a sequential manner.
492
 
493
  Raises:
494
  AssertionError: If both template and template_card_index are specified at the same time.
task.py CHANGED
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Union
5
  from .deprecation_utils import deprecation
6
  from .error_utils import Documentation, UnitxtError, UnitxtWarning
7
  from .logging_utils import get_logger
 
8
  from .operator import InstanceOperator
9
  from .operators import ArtifactFetcherMixin
10
  from .settings_utils import get_constants
@@ -186,31 +187,34 @@ class Task(InstanceOperator, ArtifactFetcherMixin):
186
 
187
  @classmethod
188
  @lru_cache(maxsize=None)
189
- def get_metric_prediction_type(cls, metric_id: str):
190
  metric = cls.get_artifact(metric_id)
191
- return metric.prediction_type
 
 
192
 
193
  def check_metrics_type(self) -> None:
194
  prediction_type = self.prediction_type
195
  for metric_id in self.metrics:
196
- metric_prediction_type = Task.get_metric_prediction_type(metric_id)
197
-
198
- if (
199
- prediction_type == metric_prediction_type
200
- or prediction_type == Any
201
- or metric_prediction_type == Any
202
- or (
203
- get_origin(metric_prediction_type) is Union
204
- and prediction_type in get_args(metric_prediction_type)
205
- )
206
- ):
207
- continue
 
208
 
209
- raise UnitxtError(
210
- f"The task's prediction type ({prediction_type}) and '{metric_id}' "
211
- f"metric's prediction type ({metric_prediction_type}) are different.",
212
- Documentation.ADDING_TASK,
213
- )
214
 
215
  def verify_defaults(self):
216
  if self.defaults:
 
5
  from .deprecation_utils import deprecation
6
  from .error_utils import Documentation, UnitxtError, UnitxtWarning
7
  from .logging_utils import get_logger
8
+ from .metrics import MetricsList
9
  from .operator import InstanceOperator
10
  from .operators import ArtifactFetcherMixin
11
  from .settings_utils import get_constants
 
187
 
188
  @classmethod
189
  @lru_cache(maxsize=None)
190
+ def get_metrics_artifacts(cls, metric_id: str):
191
  metric = cls.get_artifact(metric_id)
192
+ if isinstance(metric, MetricsList):
193
+ return metric.items
194
+ return [metric]
195
 
196
  def check_metrics_type(self) -> None:
197
  prediction_type = self.prediction_type
198
  for metric_id in self.metrics:
199
+ metric_artifacts_list = Task.get_metrics_artifacts(metric_id)
200
+ for metric_artifact in metric_artifacts_list:
201
+ metric_prediction_type = metric_artifact.prediction_type
202
+ if (
203
+ prediction_type == metric_prediction_type
204
+ or prediction_type == Any
205
+ or metric_prediction_type == Any
206
+ or (
207
+ get_origin(metric_prediction_type) is Union
208
+ and prediction_type in get_args(metric_prediction_type)
209
+ )
210
+ ):
211
+ continue
212
 
213
+ raise UnitxtError(
214
+ f"The task's prediction type ({prediction_type}) and '{metric_id}' "
215
+ f"metric's prediction type ({metric_prediction_type}) are different.",
216
+ Documentation.ADDING_TASK,
217
+ )
218
 
219
  def verify_defaults(self):
220
  if self.defaults:
text_utils.py CHANGED
@@ -137,7 +137,8 @@ def construct_dict_as_yaml_lines(d, indent_delta=2) -> List[str]:
137
  if len(d) == 0:
138
  return ["{}"]
139
  for key, val in d.items():
140
- res.append(key + ": ")
 
141
  yaml_for_val = construct_dict_as_yaml_lines(val, indent_delta=indent_delta)
142
  assert len(yaml_for_val) > 0
143
  if is_simple(val):
 
137
  if len(d) == 0:
138
  return ["{}"]
139
  for key, val in d.items():
140
+ printable_key = f'"{key}"' if (" " in key) or (key == "") else key
141
+ res.append(printable_key + ": ")
142
  yaml_for_val = construct_dict_as_yaml_lines(val, indent_delta=indent_delta)
143
  assert len(yaml_for_val) > 0
144
  if is_simple(val):
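The quoting fix above matters when a dictionary key contains spaces or is empty, which would otherwise produce ambiguous YAML-like lines. A hedged sketch of the intended effect (rendered output shown approximately as comments):

    construct_dict_as_yaml_lines({"has space": 1})  # key emitted as '"has space": ...'
    construct_dict_as_yaml_lines({"": 1})           # key emitted as '"": ...'
    construct_dict_as_yaml_lines({"plain": 1})      # key stays unquoted: 'plain: ...'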
version.py CHANGED
@@ -1 +1 @@
1
- version = "1.15.6"
 
1
+ version = "1.15.7"