Elron committed on
Commit
5b41acf
·
verified ·
1 Parent(s): 66630b0

Upload folder using huggingface_hub

Browse files
Files changed (14) hide show
  1. api.py +10 -8
  2. artifact.py +11 -2
  3. dataset.py +0 -1
  4. dataset_utils.py +8 -5
  5. inference.py +86 -29
  6. metric.py +0 -1
  7. metrics.py +76 -32
  8. operators.py +47 -0
  9. serializers.py +1 -6
  10. struct_data_operators.py +21 -1
  11. tool_calling.py +0 -119
  12. type_utils.py +15 -3
  13. types.py +12 -6
  14. version.py +1 -1
api.py CHANGED
@@ -37,12 +37,11 @@ def short_hex_hash(value, length=8):
37
  return h[:length]
38
 
39
 
40
- def _get_recipe_from_query(dataset_query: str) -> DatasetRecipe:
41
- dataset_query = dataset_query.replace("sys_prompt", "instruction")
42
  try:
43
- dataset_stream, _ = fetch_artifact(dataset_query)
44
  except:
45
- dataset_stream = get_dataset_artifact(dataset_query)
46
  return dataset_stream
47
 
48
 
@@ -82,14 +81,15 @@ def load_recipe(dataset_query: Optional[str] = None, **kwargs) -> DatasetRecipe:
82
  if isinstance(dataset_query, (DatasetRecipe, Benchmark)):
83
  return dataset_query
84
 
85
- _verify_dataset_args(dataset_query, kwargs)
86
-
87
  if dataset_query:
88
- recipe = _get_recipe_from_query(dataset_query)
89
 
90
- if kwargs:
91
  recipe = _get_recipe_from_dict(kwargs)
92
 
 
 
 
93
  return recipe
94
 
95
 
@@ -187,6 +187,8 @@ def load_dataset(
187
  Alternatively, dataset is loaded from a provided card based on explicitly
188
  given parameters.
189
 
 
 
190
  Args:
191
  dataset_query (str, optional):
192
  A string query which specifies a dataset to load from
 
37
  return h[:length]
38
 
39
 
40
+ def _get_recipe_from_query(dataset_query: str, overwrite_kwargs: Optional[Dict[str, Any]]=None) -> DatasetRecipe:
 
41
  try:
42
+ dataset_stream, _ = fetch_artifact(dataset_query, overwrite_kwargs=overwrite_kwargs)
43
  except:
44
+ dataset_stream = get_dataset_artifact(dataset_query, overwrite_kwargs=overwrite_kwargs)
45
  return dataset_stream
46
 
47
 
 
81
  if isinstance(dataset_query, (DatasetRecipe, Benchmark)):
82
  return dataset_query
83
 
 
 
84
  if dataset_query:
85
+ recipe = _get_recipe_from_query(dataset_query, kwargs)
86
 
87
+ elif kwargs:
88
  recipe = _get_recipe_from_dict(kwargs)
89
 
90
+ else:
91
+ raise UnitxtError("Specify either dataset recipe string artifact name or recipe args.")
92
+
93
  return recipe
94
 
95
 
 
187
  Alternatively, dataset is loaded from a provided card based on explicitly
188
  given parameters.
189
 
190
+ If both are given, the textual recipe is loaded and the keyword arguments override the corresponding arguments of the textual recipe.
191
+
192
  Args:
193
  dataset_query (str, optional):
194
  A string query which specifies a dataset to load from
artifact.py CHANGED
@@ -22,7 +22,7 @@ from .parsing_utils import (
22
  separate_inside_and_outside_square_brackets,
23
  )
24
  from .settings_utils import get_constants, get_settings
25
- from .text_utils import camel_to_snake_case, is_camel_case
26
  from .type_utils import isoftype, issubtype
27
  from .utils import (
28
  artifacts_json_cache,
@@ -369,6 +369,10 @@ class Artifact(Dataclass):
369
  data = self.to_dict()
370
  return json_dump(data)
371
 
 
 
 
 
372
  def serialize(self):
373
  if self.__id__ is not None:
374
  return self.__id__
@@ -528,7 +532,7 @@ class UnitxtArtifactNotFoundError(UnitxtError):
528
  super().__init__(msg)
529
 
530
 
531
- def fetch_artifact(artifact_rep) -> Tuple[Artifact, Union[AbstractCatalog, None]]:
532
  """Loads an artifict from one of possible representations.
533
 
534
  (1) If artifact representation is already an Artifact object, return it.
@@ -553,6 +557,11 @@ def fetch_artifact(artifact_rep) -> Tuple[Artifact, Union[AbstractCatalog, None]
553
  name, _ = separate_inside_and_outside_square_brackets(artifact_rep)
554
  if is_name_legal_for_catalog(name):
555
  catalog, artifact_rep, args = get_catalog_name_and_args(name=artifact_rep)
 
 
 
 
 
556
  artifact_to_return = catalog.get_with_overwrite(
557
  artifact_rep, overwrite_args=args
558
  )
 
22
  separate_inside_and_outside_square_brackets,
23
  )
24
  from .settings_utils import get_constants, get_settings
25
+ from .text_utils import camel_to_snake_case, is_camel_case, print_dict_as_yaml
26
  from .type_utils import isoftype, issubtype
27
  from .utils import (
28
  artifacts_json_cache,
 
369
  data = self.to_dict()
370
  return json_dump(data)
371
 
372
+ def to_yaml(self):
373
+ data = self.to_dict()
374
+ return print_dict_as_yaml(data)
375
+
376
  def serialize(self):
377
  if self.__id__ is not None:
378
  return self.__id__
 
532
  super().__init__(msg)
533
 
534
 
535
+ def fetch_artifact(artifact_rep, overwrite_kwargs: Optional[Dict[str, Any]]=None) -> Tuple[Artifact, Union[AbstractCatalog, None]]:
536
  """Loads an artifict from one of possible representations.
537
 
538
  (1) If artifact representation is already an Artifact object, return it.
 
557
  name, _ = separate_inside_and_outside_square_brackets(artifact_rep)
558
  if is_name_legal_for_catalog(name):
559
  catalog, artifact_rep, args = get_catalog_name_and_args(name=artifact_rep)
560
+ if overwrite_kwargs is not None:
561
+ if args is None:
562
+ args = overwrite_kwargs
563
+ else:
564
+ args.update(overwrite_kwargs)
565
  artifact_to_return = catalog.get_with_overwrite(
566
  artifact_rep, overwrite_args=args
567
  )
dataset.py CHANGED
@@ -68,7 +68,6 @@ from .system_prompts import __file__ as _
68
  from .task import __file__ as _
69
  from .templates import __file__ as _
70
  from .text_utils import __file__ as _
71
- from .tool_calling import __file__ as _
72
  from .type_utils import __file__ as _
73
  from .types import __file__ as _
74
  from .utils import __file__ as _
 
68
  from .task import __file__ as _
69
  from .templates import __file__ as _
70
  from .text_utils import __file__ as _
 
71
  from .type_utils import __file__ as _
72
  from .types import __file__ as _
73
  from .utils import __file__ as _
dataset_utils.py CHANGED
@@ -1,4 +1,5 @@
1
  from json.decoder import JSONDecodeError
 
2
 
3
  from .artifact import Artifact, UnitxtArtifactNotFoundError, fetch_artifact
4
  from .logging_utils import get_logger
@@ -11,19 +12,19 @@ logger = get_logger()
11
  settings = get_settings()
12
 
13
 
14
- def fetch(artifact_name):
15
  try:
16
- artifact, _ = fetch_artifact(artifact_name)
17
  return artifact
18
  except (UnitxtArtifactNotFoundError, JSONDecodeError):
19
  return None
20
 
21
 
22
- def parse(query: str):
23
  return parse_key_equals_value_string_to_dict(query)
24
 
25
 
26
- def get_dataset_artifact(dataset):
27
  if isinstance(dataset, DatasetRecipe):
28
  return dataset
29
  assert isinstance(
@@ -31,10 +32,12 @@ def get_dataset_artifact(dataset):
31
  ), "dataset should be string description of recipe, or recipe object."
32
  _reset_env_local_catalogs()
33
  register_all_artifacts()
34
- recipe = fetch(dataset)
35
  if recipe is None:
36
  args = parse(dataset)
37
  if "__type__" not in args:
38
  args["__type__"] = settings.default_recipe
 
 
39
  recipe = Artifact.from_dict(args)
40
  return recipe
 
1
  from json.decoder import JSONDecodeError
2
+ from typing import Any, Dict, Optional
3
 
4
  from .artifact import Artifact, UnitxtArtifactNotFoundError, fetch_artifact
5
  from .logging_utils import get_logger
 
12
  settings = get_settings()
13
 
14
 
15
+ def fetch(artifact_name: str, overwrite_kwargs: Optional[Dict[str, Any]]=None):
16
  try:
17
+ artifact, _ = fetch_artifact(artifact_name, overwrite_kwargs=overwrite_kwargs)
18
  return artifact
19
  except (UnitxtArtifactNotFoundError, JSONDecodeError):
20
  return None
21
 
22
 
23
+ def parse(query: str) -> dict:
24
  return parse_key_equals_value_string_to_dict(query)
25
 
26
 
27
+ def get_dataset_artifact(dataset, overwrite_kwargs: Optional[Dict[str, Any]]=None):
28
  if isinstance(dataset, DatasetRecipe):
29
  return dataset
30
  assert isinstance(
 
32
  ), "dataset should be string description of recipe, or recipe object."
33
  _reset_env_local_catalogs()
34
  register_all_artifacts()
35
+ recipe = fetch(dataset, overwrite_kwargs=overwrite_kwargs)
36
  if recipe is None:
37
  args = parse(dataset)
38
  if "__type__" not in args:
39
  args["__type__"] = settings.default_recipe
40
+ if overwrite_kwargs is not None:
41
+ args.update(overwrite_kwargs)
42
  recipe = Artifact.from_dict(args)
43
  return recipe
inference.py CHANGED
@@ -344,6 +344,8 @@ class InferenceEngine(Artifact):
344
 
345
  def to_tools(self, instance):
346
  task_data = instance.get("task_data")
 
 
347
  if isinstance(task_data, str):
348
  task_data = json.loads(task_data)
349
  if "__tools__" in task_data:
@@ -445,6 +447,8 @@ class HFInferenceEngineBase(
445
  model: Any = InternalField(default=None, name="Inference object")
446
  processor: Any = InternalField(default=None, name="Input processor (tokenizer)")
447
 
 
 
448
  _requirements_list = {
449
  "transformers": "Install huggingface package using 'pip install --upgrade transformers",
450
  "torch": "Install torch, go on PyTorch website for mode details.",
@@ -655,8 +659,6 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
655
  truncation: bool = True
656
  padding_side: str = "left" # for decoder only models
657
 
658
- chat_kwargs_dict: dict = {}
659
-
660
  def _init_processor(self):
661
  from transformers import AutoTokenizer
662
 
@@ -712,10 +714,9 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
712
  trust_remote_code=True,
713
  **model_args,
714
  )
715
- if self.device_map is None:
716
- self.model.to(self.device)
717
 
718
  def prepare_inputs(self, data: Iterable) -> Mapping:
 
719
  if isinstance(data[0], list):
720
  data = self.processor.apply_chat_template(
721
  data,
@@ -723,6 +724,7 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
723
  add_generation_prompt=True,
724
  **self.chat_kwargs_dict,
725
  )
 
726
 
727
  if self.processor.pad_token is None:
728
  self.processor.pad_token_id = self.model.config.eos_token_id[0]
@@ -733,6 +735,8 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
733
  padding=self.padding,
734
  truncation=self.truncation,
735
  padding_side=self.padding_side,
 
 
736
  ).to(self.device or self.device_map)
737
 
738
  def _infer_fn(
@@ -755,13 +759,14 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
755
  """
756
  all_final_outputs = [] # List to store results from all batches
757
 
758
- for i in tqdm(
759
- range(0, len(dataset), self.batch_size),
760
  desc=f"Running inference in batches of {self.batch_size}",
 
761
  ):
 
762
  # Get the current batch
763
- batch_data = dataset[i : i + self.batch_size]
764
- batch_sources = [instance["source"] for instance in batch_data]
765
 
766
  # --- Process the current batch ---
767
  # 1. Tokenize inputs for the batch
@@ -800,7 +805,7 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
800
  j
801
  ], # Output for the j-th item in the batch
802
  output_tokens=len(string_tokens_batch[j]),
803
- inp=batch_data[j]["source"], # Original input for the j-th item
804
  inp_tokens=len(tokenized_inputs.encodings[j].tokens)
805
  if tokenized_inputs.encodings is not None
806
  else None,
@@ -1840,15 +1845,26 @@ class OpenAiInferenceEngine(
1840
  @run_with_imap
1841
  def _get_chat_completion(self, instance, return_meta_data):
1842
  import openai
1843
-
1844
  messages = self.to_messages(instance)
1845
  try:
1846
  response = self.client.chat.completions.create(
1847
  messages=messages,
 
1848
  model=self.get_client_model_name(),
1849
  **self._get_completion_kwargs(),
 
1850
  )
1851
- prediction = response.choices[0].message.content
 
 
 
 
 
 
 
 
 
1852
  return self.get_return_object(prediction, response, return_meta_data)
1853
  # catch in case of content_filtering failure
1854
  except openai.BadRequestError as e:
@@ -2742,14 +2758,37 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin):
2742
  # images as SDK allows sending only one image per message.
2743
  return [messages]
2744
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2745
  def _handle_async_requests(
2746
  self,
2747
- messages: List[List[Dict[str, Any]]],
2748
  params: Dict[str, Any],
2749
  ) -> List[Dict[str, Any]]:
2750
  async def handle_async_requests(start_idx, end_idx):
2751
  coroutines = [
2752
- self._model.achat(messages=messages[idx], params=params)
 
 
 
 
 
2753
  for idx in range(start_idx, end_idx)
2754
  ]
2755
  batch_results = await asyncio.gather(*coroutines)
@@ -2758,10 +2797,10 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin):
2758
  loop = asyncio.get_event_loop()
2759
  results = []
2760
 
2761
- for batch_idx in range(0, len(messages), self.concurrency_limit):
2762
  batch_results = loop.run_until_complete(
2763
  handle_async_requests(
2764
- batch_idx, min(batch_idx + self.concurrency_limit, len(messages))
2765
  )
2766
  )
2767
  results.extend(batch_results)
@@ -2783,25 +2822,43 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin):
2783
  output_type = "message"
2784
  params["logprobs"] = False
2785
 
2786
- indexed_messages = [
2787
- (i, message)
 
 
 
 
2788
  for i in range(len(dataset))
2789
  for message in self.to_messages(dataset[i])
2790
  ]
2791
 
2792
- results = self._handle_async_requests(
2793
- [msg[1] for msg in indexed_messages], params
2794
- )
2795
 
2796
- return [
2797
- self.get_return_object(
2798
- result["choices"][0][output_type]["content"],
2799
- result,
2800
- dataset[idx[0]]["source"],
2801
- return_meta_data,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2802
  )
2803
- for result, idx in zip(results, indexed_messages)
2804
- ]
2805
 
2806
  def get_return_object(self, predict_result, result, input_text, return_meta_data):
2807
  if return_meta_data:
@@ -3439,7 +3496,7 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
3439
  "aws": LiteLLMInferenceEngine,
3440
  "ollama": OllamaInferenceEngine,
3441
  "bam": IbmGenAiInferenceEngine,
3442
- "watsonx-sdk": WMLInferenceEngine,
3443
  "rits": RITSInferenceEngine,
3444
  "azure": LiteLLMInferenceEngine,
3445
  "vertex-ai": LiteLLMInferenceEngine,
 
344
 
345
  def to_tools(self, instance):
346
  task_data = instance.get("task_data")
347
+ if task_data is None:
348
+ return None
349
  if isinstance(task_data, str):
350
  task_data = json.loads(task_data)
351
  if "__tools__" in task_data:
 
447
  model: Any = InternalField(default=None, name="Inference object")
448
  processor: Any = InternalField(default=None, name="Input processor (tokenizer)")
449
 
450
+ chat_kwargs_dict: dict = {}
451
+
452
  _requirements_list = {
453
  "transformers": "Install huggingface package using 'pip install --upgrade transformers",
454
  "torch": "Install torch, go on PyTorch website for mode details.",
 
659
  truncation: bool = True
660
  padding_side: str = "left" # for decoder only models
661
 
 
 
662
  def _init_processor(self):
663
  from transformers import AutoTokenizer
664
 
 
714
  trust_remote_code=True,
715
  **model_args,
716
  )
 
 
717
 
718
  def prepare_inputs(self, data: Iterable) -> Mapping:
719
+ tokenizer_kargs = {}
720
  if isinstance(data[0], list):
721
  data = self.processor.apply_chat_template(
722
  data,
 
724
  add_generation_prompt=True,
725
  **self.chat_kwargs_dict,
726
  )
727
+ tokenizer_kargs["add_special_tokens"] = False
728
 
729
  if self.processor.pad_token is None:
730
  self.processor.pad_token_id = self.model.config.eos_token_id[0]
 
735
  padding=self.padding,
736
  truncation=self.truncation,
737
  padding_side=self.padding_side,
738
+ **tokenizer_kargs
739
+
740
  ).to(self.device or self.device_map)
741
 
742
  def _infer_fn(
 
759
  """
760
  all_final_outputs = [] # List to store results from all batches
761
 
762
+ for batch in tqdm(
763
+ batched(dataset, self.batch_size),
764
  desc=f"Running inference in batches of {self.batch_size}",
765
+ total=len(dataset) // self.batch_size,
766
  ):
767
+
768
  # Get the current batch
769
+ batch_sources = [instance["source"] for instance in batch]
 
770
 
771
  # --- Process the current batch ---
772
  # 1. Tokenize inputs for the batch
 
805
  j
806
  ], # Output for the j-th item in the batch
807
  output_tokens=len(string_tokens_batch[j]),
808
+ inp=batch[j]["source"], # Original input for the j-th item
809
  inp_tokens=len(tokenized_inputs.encodings[j].tokens)
810
  if tokenized_inputs.encodings is not None
811
  else None,
 
1845
  @run_with_imap
1846
  def _get_chat_completion(self, instance, return_meta_data):
1847
  import openai
1848
+ tools = self.to_tools(instance)
1849
  messages = self.to_messages(instance)
1850
  try:
1851
  response = self.client.chat.completions.create(
1852
  messages=messages,
1853
+ tools=tools,
1854
  model=self.get_client_model_name(),
1855
  **self._get_completion_kwargs(),
1856
+ # tool_choice="auto"
1857
  )
1858
+
1859
+ if tools is None:
1860
+ prediction = response.choices[0].message.content
1861
+ else:
1862
+ try:
1863
+ func_call = response.choices[0].message.tool_calls[0].function
1864
+ prediction = f'{{"name": "{func_call.name}", "arguments": {func_call.arguments}}}'
1865
+ except:
1866
+ prediction = response.choices[0].message.content or ""
1867
+
1868
  return self.get_return_object(prediction, response, return_meta_data)
1869
  # catch in case of content_filtering failure
1870
  except openai.BadRequestError as e:
 
2758
  # images as SDK allows sending only one image per message.
2759
  return [messages]
2760
 
2761
+ def to_tools(
2762
+ self,
2763
+ instance: Dict[str, Any]
2764
+ ) -> Dict[str, Union[Optional[List[Dict[str, str]]], Optional[Dict[str, str]]]]:
2765
+ """watsonx.ai chat also allows specifying which tools models must use."""
2766
+ task_data = instance.get("task_data")
2767
+ if task_data is None:
2768
+ return {"tools": None, "tool_choice": None}
2769
+
2770
+ if isinstance(task_data, str):
2771
+ task_data = json.loads(task_data)
2772
+ if "__tools__" in task_data:
2773
+ tools: List[Dict[str, str]] = task_data["__tools__"]
2774
+ tool_choice: Optional[Dict[str, str]] = task_data.get("__tool_choice__")
2775
+ return {"tools": tools, "tool_choice": tool_choice}
2776
+
2777
+ return {"tools": None, "tool_choice": None}
2778
+
2779
  def _handle_async_requests(
2780
  self,
2781
+ data: List[Dict[str, Any]],
2782
  params: Dict[str, Any],
2783
  ) -> List[Dict[str, Any]]:
2784
  async def handle_async_requests(start_idx, end_idx):
2785
  coroutines = [
2786
+ self._model.achat(
2787
+ messages=data[idx]["msg"],
2788
+ params=params,
2789
+ tools=data[idx]["tools"]["tools"],
2790
+ tool_choice=data[idx]["tools"]["tool_choice"],
2791
+ )
2792
  for idx in range(start_idx, end_idx)
2793
  ]
2794
  batch_results = await asyncio.gather(*coroutines)
 
2797
  loop = asyncio.get_event_loop()
2798
  results = []
2799
 
2800
+ for batch_idx in range(0, len(data), self.concurrency_limit):
2801
  batch_results = loop.run_until_complete(
2802
  handle_async_requests(
2803
+ batch_idx, min(batch_idx + self.concurrency_limit, len(data))
2804
  )
2805
  )
2806
  results.extend(batch_results)
 
2822
  output_type = "message"
2823
  params["logprobs"] = False
2824
 
2825
+ data = [
2826
+ {
2827
+ "idx": i,
2828
+ "msg": message,
2829
+ "tools": self.to_tools(dataset[i]),
2830
+ }
2831
  for i in range(len(dataset))
2832
  for message in self.to_messages(dataset[i])
2833
  ]
2834
 
2835
+ responses = self._handle_async_requests(data, params)
 
 
2836
 
2837
+ results = []
2838
+ for inp, response in zip(data, responses):
2839
+ idx = inp["idx"]
2840
+ tool_call = data[idx]["tools"]["tools"] is not None
2841
+
2842
+ output = response["choices"][0][output_type]
2843
+ if tool_call:
2844
+ if "tool_calls" in output:
2845
+ func = output["tool_calls"][0]["function"]
2846
+ prediction = f'{{"name": "{func["name"]}", "arguments": {func["arguments"]}}}'
2847
+ else:
2848
+ prediction = output["content"]
2849
+ else:
2850
+ prediction = output["content"]
2851
+
2852
+ results.append(
2853
+ self.get_return_object(
2854
+ prediction,
2855
+ response,
2856
+ str(inp),
2857
+ return_meta_data,
2858
+ )
2859
  )
2860
+
2861
+ return results
2862
 
2863
  def get_return_object(self, predict_result, result, input_text, return_meta_data):
2864
  if return_meta_data:
 
3496
  "aws": LiteLLMInferenceEngine,
3497
  "ollama": OllamaInferenceEngine,
3498
  "bam": IbmGenAiInferenceEngine,
3499
+ "watsonx-sdk": WMLInferenceEngineChat,
3500
  "rits": RITSInferenceEngine,
3501
  "azure": LiteLLMInferenceEngine,
3502
  "vertex-ai": LiteLLMInferenceEngine,
metric.py CHANGED
@@ -65,7 +65,6 @@ from .system_prompts import __file__ as _
65
  from .task import __file__ as _
66
  from .templates import __file__ as _
67
  from .text_utils import __file__ as _
68
- from .tool_calling import __file__ as _
69
  from .type_utils import __file__ as _
70
  from .types import __file__ as _
71
  from .utils import __file__ as _
 
65
  from .task import __file__ as _
66
  from .templates import __file__ as _
67
  from .text_utils import __file__ as _
 
68
  from .type_utils import __file__ as _
69
  from .types import __file__ as _
70
  from .utils import __file__ as _
metrics.py CHANGED
@@ -63,7 +63,6 @@ from .operators import ArtifactFetcherMixin, Copy, Set
63
  from .random_utils import get_seed
64
  from .settings_utils import get_settings
65
  from .stream import MultiStream, Stream
66
- from .tool_calling import convert_chat_api_format_to_tool
67
  from .type_utils import Type, isoftype, parse_type_string, to_type_string
68
  from .types import ToolCall
69
  from .utils import deep_copy, recursive_copy, retry_connection_with_exponential_backoff
@@ -789,74 +788,92 @@ class F1Fast(MapReduceMetric[str, Tuple[int, int]]):
789
  return result
790
 
791
  class ToolCallingMetric(ReductionInstanceMetric[str, Dict[str, float]]):
 
792
  main_score = "exact_match"
793
  reduction = MeanReduction()
794
  prediction_type = ToolCall
 
 
 
 
 
 
795
 
796
  def map(
797
  self, prediction: ToolCall, references: List[ToolCall], task_data: Dict[str, Any]
798
  ) -> Dict[str, float]:
799
 
800
-
801
  exact_match = float(
802
- str(prediction) in [str(reference) for reference in references]
803
  )
804
 
805
- tool_choice = float(
806
  str(prediction["name"]) in [str(reference["name"]) for reference in references]
807
  )
808
 
809
- parameter_choice = 0.0
810
  for reference in references:
811
- if len(prediction["arguments"]) > 0:
 
 
 
 
 
812
 
 
 
 
813
  score = len(set(prediction["arguments"]).intersection(set(reference["arguments"]))) / len(set(prediction["arguments"]))
814
- else:
815
  score = 1.0
816
- if score > parameter_choice:
817
- parameter_choice = score
 
 
 
818
 
 
819
 
820
- parameter_values = 0.0
821
  for reference in references:
822
  value_matches = 0
 
823
  for key, val in prediction["arguments"].items():
824
  try:
825
- if val in reference["arguments"][key] or reference["arguments"][key] in val:
 
 
826
  value_matches += 1
827
  except:
828
  pass
829
 
830
  if len(prediction["arguments"]) > 0:
831
-
832
  score = value_matches / len(prediction["arguments"])
833
  else:
834
  score = 1.0
835
- if score > parameter_values:
836
- parameter_values = score
837
 
 
838
  for tool in task_data["__tools__"]:
839
- tool = convert_chat_api_format_to_tool(tool)
840
- tool_params_types = {}
841
- for param in tool["parameters"]:
842
- tool_params_types[param["name"]] = param["type"]
843
- correct_parameters_types = 0
844
- for key, value in prediction["arguments"].items():
845
- typing_type = tool_params_types.get(key, Any)
846
- if isoftype(value, typing_type):
847
- correct_parameters_types += 1
848
- if len(prediction["arguments"]) > 0:
849
- parameters_types = correct_parameters_types / len(prediction["arguments"])
850
- else:
851
- parameters_types = 1.0
852
 
 
 
 
 
 
 
 
 
853
 
854
  return {
855
  self.main_score: exact_match,
856
- "tool_choice": tool_choice,
857
- "parameter_choice": parameter_choice,
858
- "parameters_types": parameters_types,
859
- "parameter_values": parameter_values
 
860
  }
861
 
862
 
@@ -3499,7 +3516,7 @@ class CustomF1(GlobalMetric):
3499
  class KeyValueExtraction(GlobalMetric):
3500
  prediction_type = Dict[str, str]
3501
  metric: Metric
3502
- single_reference_per_prediction = True
3503
  main_score = ""
3504
 
3505
  def prepare(self):
@@ -3575,6 +3592,33 @@ class KeyValueExtraction(GlobalMetric):
3575
 
3576
  return result
3577
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3578
 
3579
  class NER(CustomF1):
3580
  """F1 Metrics that receives as input a list of (Entity,EntityType) pairs."""
 
63
  from .random_utils import get_seed
64
  from .settings_utils import get_settings
65
  from .stream import MultiStream, Stream
 
66
  from .type_utils import Type, isoftype, parse_type_string, to_type_string
67
  from .types import ToolCall
68
  from .utils import deep_copy, recursive_copy, retry_connection_with_exponential_backoff
 
788
  return result
789
 
790
  class ToolCallingMetric(ReductionInstanceMetric[str, Dict[str, float]]):
791
+ """Compares each predicted tool call against a list of reference tool calls."""
792
  main_score = "exact_match"
793
  reduction = MeanReduction()
794
  prediction_type = ToolCall
795
+ _requirements_list = ["jsonschema-rs"]
796
+
797
+ def prepare(self):
798
+ super().prepare()
799
+ import jsonschema_rs
800
+ self._schema = jsonschema_rs
801
 
802
  def map(
803
  self, prediction: ToolCall, references: List[ToolCall], task_data: Dict[str, Any]
804
  ) -> Dict[str, float]:
805
 
 
806
  exact_match = float(
807
+ json.dumps(prediction, sort_keys=True) in [json.dumps(reference, sort_keys=True) for reference in references]
808
  )
809
 
810
+ tool_name_accuracy = float(
811
  str(prediction["name"]) in [str(reference["name"]) for reference in references]
812
  )
813
 
814
+ argument_name_recall = 0.0
815
  for reference in references:
816
+ if len(reference["arguments"]) > 0:
817
+ score = len(set(prediction["arguments"]).intersection(set(reference["arguments"]))) / len(set(reference["arguments"]))
818
+ else:
819
+ score = 1.0
820
+ if score > argument_name_recall:
821
+ argument_name_recall = score
822
 
823
+ argument_name_precision = 0.0
824
+ for reference in references:
825
+ if len(prediction["arguments"]) > 0:
826
  score = len(set(prediction["arguments"]).intersection(set(reference["arguments"]))) / len(set(prediction["arguments"]))
827
+ elif len(reference["arguments"]) == 0:
828
  score = 1.0
829
+ else:
830
+ score = 0.0
831
+ if score > argument_name_precision:
832
+ argument_name_precision = score
833
+
834
 
835
+ argument_value_precision = 0.0
836
 
 
837
  for reference in references:
838
  value_matches = 0
839
+
840
  for key, val in prediction["arguments"].items():
841
  try:
842
+ predicted = json.dumps(val, sort_keys=True)
843
+ target = json.dumps(reference["arguments"][key], sort_keys=True)
844
+ if predicted == target:
845
  value_matches += 1
846
  except:
847
  pass
848
 
849
  if len(prediction["arguments"]) > 0:
 
850
  score = value_matches / len(prediction["arguments"])
851
  else:
852
  score = 1.0
853
+ if score > argument_value_precision:
854
+ argument_value_precision = score
855
 
856
+ parameters = None
857
  for tool in task_data["__tools__"]:
858
+ if tool["function"]["name"] == prediction["name"]:
859
+ parameters = tool["function"]["parameters"]
 
 
 
 
 
 
 
 
 
 
 
860
 
861
+ if parameters is None:
862
+ argument_schema_validation = 0.0
863
+ else:
864
+ try:
865
+ self._schema.validate(parameters, prediction["arguments"], )
866
+ argument_schema_validation = 1.0
867
+ except self._schema.ValidationError:
868
+ argument_schema_validation = 0.0
869
 
870
  return {
871
  self.main_score: exact_match,
872
+ "tool_name_accuracy": tool_name_accuracy,
873
+ "argument_name_recall": argument_name_recall,
874
+ "argument_name_precision": argument_name_precision,
875
+ "argument_value_precision": argument_value_precision,
876
+ "argument_schema_validation": argument_schema_validation,
877
  }
878
 
879
 
 
3516
  class KeyValueExtraction(GlobalMetric):
3517
  prediction_type = Dict[str, str]
3518
  metric: Metric
3519
+ single_reference_per_prediction = False
3520
  main_score = ""
3521
 
3522
  def prepare(self):
 
3592
 
3593
  return result
3594
 
3595
+ class ToolCallKeyValueExtraction(KeyValueExtraction):
3596
+ prediction_type = ToolCall
3597
+
3598
+ def flatten_dict(self,nested_dict, parent_key="", sep="."):
3599
+ flat_dict = {}
3600
+ for k, v in nested_dict.items():
3601
+ new_key = f"{parent_key}{sep}{k}" if parent_key else k
3602
+ if isinstance(v, list):
3603
+ for e in v:
3604
+ if isinstance(e,dict):
3605
+ flat_dict.update(self.flatten_dict(e, new_key, sep=sep))
3606
+ elif isinstance(v, dict):
3607
+ flat_dict.update(self.flatten_dict(v, new_key, sep=sep))
3608
+ else:
3609
+ flat_dict[new_key] = v
3610
+ return flat_dict
3611
+
3612
+ def compute(
3613
+ self,
3614
+ references: List[List[ToolCall]],
3615
+ predictions: List[ToolCall],
3616
+ task_data: List[Dict],
3617
+ ) -> dict:
3618
+ return super().compute([[ self.flatten_dict(r) for r in ref ] for ref in references],
3619
+ [ self.flatten_dict(p) for p in predictions],task_data)
3620
+
3621
+
3622
 
3623
  class NER(CustomF1):
3624
  """F1 Metrics that receives as input a list of (Entity,EntityType) pairs."""
operators.py CHANGED
@@ -283,6 +283,53 @@ class Set(InstanceOperator):
283
  dict_set(instance, key, value)
284
  return instance
285
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
 
287
  @deprecation(version="2.0.0", alternative=Set)
288
  class AddFields(Set):
 
283
  dict_set(instance, key, value)
284
  return instance
285
 
286
+ def recursive_key_value_replace(data, target_key, value_map, value_remove=None):
287
+ """Recursively traverses a data structure (dicts and lists), replaces values of target_key using value_map, and removes values listed in value_remove.
288
+
289
+ Args:
290
+ data: The data structure (dict or list) to traverse.
291
+ target_key: The specific key whose value needs to be checked and replaced or removed.
292
+ value_map: A dictionary mapping old values to new values.
293
+ value_remove: A list of values to completely remove if found as values of target_key.
294
+
295
+ Returns:
296
+ The modified data structure. Modification is done in-place.
297
+ """
298
+ if value_remove is None:
299
+ value_remove = []
300
+
301
+ if isinstance(data, dict):
302
+ keys_to_delete = []
303
+ for key, value in data.items():
304
+ if key == target_key:
305
+ if isinstance(value, list):
306
+ data[key] = [
307
+ value_map.get(item, item)
308
+ for item in value
309
+ if not isinstance(item, dict) and item not in value_remove
310
+ ]
311
+ elif isinstance(value, dict):
312
+ pass # Skip or handle dict values if needed
313
+ elif value in value_remove:
314
+ keys_to_delete.append(key)
315
+ elif value in value_map:
316
+ data[key] = value_map[value]
317
+ else:
318
+ recursive_key_value_replace(value, target_key, value_map, value_remove)
319
+ for key in keys_to_delete:
320
+ del data[key]
321
+ elif isinstance(data, list):
322
+ for item in data:
323
+ recursive_key_value_replace(item, target_key, value_map, value_remove)
324
+ return data
325
+
326
+ class RecursiveReplace(InstanceOperator):
327
+ key: str
328
+ map_values: dict
329
+ remove_values: Optional[list] = None
330
+
331
+ def process(self, instance: Dict[str, Any], stream_name: Optional[str] = None) -> Dict[str, Any]:
332
+ return recursive_key_value_replace(instance, self.key, self.map_values, self.remove_values)
333
 
334
  @deprecation(version="2.0.0", alternative=Set)
335
  class AddFields(Set):
serializers.py CHANGED
@@ -7,7 +7,6 @@ from typing import Any, Dict, List, Union
7
  from .dataclass import AbstractField, Field
8
  from .operators import InstanceFieldOperator
9
  from .settings_utils import get_constants
10
- from .tool_calling import convert_to_chat_api_format
11
  from .type_utils import isoftype, to_type_string
12
  from .types import (
13
  Dialog,
@@ -168,24 +167,20 @@ class MultiDocumentSerializer(DocumentSerializer):
168
  class ToolsSerializer(SingleTypeSerializer):
169
 
170
  serialized_type = List[Tool]
171
- _requirements_list: List[str] = ["pydantic"]
172
 
173
  def serialize(self, value: List[Tool], instance: Dict[str, Any]) -> str:
174
  if "__tools__" not in instance:
175
  instance["__tools__"] = []
176
  tool = []
177
  for tool in value:
178
- chat_api_tool = convert_to_chat_api_format(tool=tool)
179
  instance["__tools__"].append(
180
- chat_api_tool
181
  )
182
- tool["parameters"] = chat_api_tool["function"]["parameters"]
183
  return json.dumps(instance["__tools__"], indent=4)
184
 
185
  class ToolCallSerializer(SingleTypeSerializer):
186
 
187
  serialized_type = ToolCall
188
- _requirements_list: List[str] = ["pydantic"]
189
 
190
  def serialize(self, value: ToolCall, instance: Dict[str, Any]) -> str:
191
  return json.dumps(value)
 
7
  from .dataclass import AbstractField, Field
8
  from .operators import InstanceFieldOperator
9
  from .settings_utils import get_constants
 
10
  from .type_utils import isoftype, to_type_string
11
  from .types import (
12
  Dialog,
 
167
class ToolsSerializer(SingleTypeSerializer):
    """Serialize a list of tools as a JSON list of chat-API style function entries.

    Each tool is wrapped as ``{"type": "function", "function": tool}`` and
    appended to ``instance["__tools__"]`` (created when missing). The returned
    string is the JSON dump of that accumulated list, so it also contains any
    entries already stored on the instance.
    """

    serialized_type = List[Tool]

    def serialize(self, value: List[Tool], instance: Dict[str, Any]) -> str:
        tools = instance.setdefault("__tools__", [])
        tools.extend({"type": "function", "function": tool} for tool in value)
        return json.dumps(tools, indent=4)
180
 
181
  class ToolCallSerializer(SingleTypeSerializer):
182
 
183
  serialized_type = ToolCall
 
184
 
185
  def serialize(self, value: ToolCall, instance: Dict[str, Any]) -> str:
186
  return json.dumps(value)
struct_data_operators.py CHANGED
@@ -43,7 +43,7 @@ from .operators import FieldOperator, InstanceOperator
43
  from .random_utils import new_random_generator
44
  from .serializers import ImageSerializer, TableSerializer
45
  from .type_utils import isoftype
46
- from .types import Table
47
  from .utils import recursive_copy
48
 
49
 
@@ -754,6 +754,26 @@ class LoadJson(FieldOperator):
754
  return json.loads(value, strict=False)
755
 
756
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
757
  class DumpJson(FieldOperator):
758
  def process_value(self, value: str) -> str:
759
  return json.dumps(value)
 
43
  from .random_utils import new_random_generator
44
  from .serializers import ImageSerializer, TableSerializer
45
  from .type_utils import isoftype
46
+ from .types import Table, ToolCall
47
  from .utils import recursive_copy
48
 
49
 
 
754
  return json.loads(value, strict=False)
755
 
756
 
757
class ToolCallPostProcessor(FieldOperator):
    """Parse a JSON string produced by a model into a single tool call.

    The parsed value is returned when it is a ``ToolCall``, or the only
    element of a one-item list of tool calls. In every other case (a list
    with zero or several calls, or a value that is not a tool call)
    ``failure_value`` is returned.

    When ``allow_failure`` is True, text that is not valid JSON also yields
    ``failure_value``; otherwise the JSON decoding error propagates.
    """

    # Value returned whenever no single valid tool call can be extracted.
    failure_value: Any = None
    # Whether malformed JSON is tolerated instead of raising.
    allow_failure: bool = False

    def process_value(self, value: str) -> ToolCall:
        if self.allow_failure:
            try:
                result = json.loads(value)
            except json.JSONDecodeError:
                return self.failure_value
        else:
            result = json.loads(value, strict=False)
        if isoftype(result, List[ToolCall]):
            if not result:
                # An empty list holds no tool call; indexing it would raise
                # IndexError, so report it as a failure instead.
                return self.failure_value
            if len(result) > 1:
                UnitxtWarning(f"More than one tool returned from model: {result}" )
                return self.failure_value
            return result[0]
        if not isoftype(result, ToolCall):
            return self.failure_value
        return result
776
+
777
  class DumpJson(FieldOperator):
778
  def process_value(self, value: str) -> str:
779
  return json.dumps(value)
tool_calling.py DELETED
@@ -1,119 +0,0 @@
1
- from typing import Any, Dict, List, Type
2
-
3
- from .operators import FieldOperator
4
- from .types import Parameter, Tool
5
-
6
-
7
- def convert_to_chat_api_format(tool: Tool) -> Dict[str, Any]:
8
-
9
- from pydantic import create_model
10
-
11
- field_definitions = {}
12
- for param in tool["parameters"]:
13
- param_name = param["name"]
14
- param_type = param.get("type", Any)
15
- field_definitions[param_name] = (param_type, ...) # ... means required in Pydantic
16
-
17
- model = create_model(f"{tool['name']}Params", **field_definitions)
18
-
19
- schema = model.model_json_schema()
20
-
21
- return {
22
- "type": "function",
23
- "function": {
24
- "name": tool["name"],
25
- "description": tool["description"],
26
- "parameters": schema
27
- }
28
- }
29
-
30
-
31
- def convert_chat_api_format_to_tool(chat_api_tool: Dict[str, Any]) -> Tool:
32
- """Convert a Chat API formatted tool back to the original Tool structure.
33
-
34
- Args:
35
- chat_api_tool: A dictionary representing a tool in Chat API format
36
-
37
- Returns:
38
- A Tool dictionary with name, description, and parameters
39
- """
40
- # Extract function information
41
- function_info = chat_api_tool.get("function", {})
42
- name = function_info.get("name", chat_api_tool.get("name", ""))
43
- description = function_info.get("description", chat_api_tool.get("description", ""))
44
-
45
- # Extract parameters from schema
46
- parameters: List[Parameter] = []
47
- schema = function_info.get("parameters", chat_api_tool.get("parameters", ""))
48
- properties = schema.get("properties", {})
49
-
50
- for param_name, param_schema in properties.items():
51
- # Map JSON schema type to Python type
52
- param_type = json_schema_to_python_type(param_schema)
53
-
54
- parameter: Parameter = {
55
- "name": param_name,
56
- "type": param_type
57
- }
58
- parameters.append(parameter)
59
-
60
- # Construct and return the Tool
61
- tool: Tool = {
62
- "name": name,
63
- "description": description,
64
- "parameters": parameters
65
- }
66
-
67
- return tool
68
-
69
- def json_schema_to_python_type(schema: Dict[str, Any]) -> Type:
70
- """Convert JSON schema type to Python type."""
71
- from typing import Any, Dict, List, Union
72
-
73
- schema_type = schema.get("type")
74
-
75
- # Handle simple types
76
- simple_types = {
77
- "string": str,
78
- "integer": int,
79
- "number": float,
80
- "boolean": bool,
81
- "null": type(None)
82
- }
83
-
84
- if schema_type in simple_types:
85
- return simple_types[schema_type]
86
-
87
- # Handle arrays
88
- if schema_type == "array":
89
- items = schema.get("items", {})
90
- if not items:
91
- return List[Any]
92
-
93
- item_type = json_schema_to_python_type(items)
94
- return List[item_type]
95
-
96
- # Handle objects
97
- if schema_type == "object":
98
- return Dict[str, Any]
99
-
100
- # Handle unions with anyOf/oneOf
101
- if "anyOf" in schema or "oneOf" in schema:
102
- union_schemas = schema.get("anyOf", []) or schema.get("oneOf", [])
103
- union_types = [json_schema_to_python_type(s) for s in union_schemas]
104
- # Use Union for Python 3.9+ or create Union using typing module
105
- return Union[tuple(union_types)] if union_types else Any
106
-
107
- # Handle references (simplified)
108
- if "$ref" in schema:
109
- # In a real implementation, you'd resolve references
110
- return Any
111
-
112
- # Default to Any for unrecognized schema types
113
- return Any
114
-
115
-
116
- class ToTool(FieldOperator):
117
-
118
- def process_value(self, value: Dict[str, Any]) -> Tool:
119
- return convert_chat_api_format_to_tool(value)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
type_utils.py CHANGED
@@ -27,7 +27,7 @@ _registered_types = {
27
  def register_type(new_type):
28
  assert is_new_type(new_type) or is_typed_dict(
29
  new_type
30
- ), "Can register only typing.NewType or typing.TypedDict"
31
  _registered_types[new_type.__name__] = new_type
32
 
33
 
@@ -489,6 +489,9 @@ def isoftype(object, typing_type):
489
  if not is_type(typing_type):
490
  raise UnsupportedTypeError(typing_type)
491
 
 
 
 
492
  if typing_type is typing.Type:
493
  return is_type(object)
494
 
@@ -1066,9 +1069,18 @@ def verify_required_schema(
1066
  f"{class_name} description: {description}"
1067
  ) from e
1068
 
1069
- if not isoftype(value, data_type):
 
 
 
 
 
 
 
 
 
1070
  raise ValueError(
1071
- f"Passed value '{value}' of field '{field_name}' is not "
1072
  f"of required type: ({to_type_string(data_type)}) in {class_name} ('{id}').\n"
1073
  f"{class_name} description: {description}"
1074
  )
 
27
def register_type(new_type):
    """Add new_type to the module's type registry under its ``__name__``.

    Accepted are typing.NewType types, typing.TypedDict types, and any object
    exposing a ``__verify_type__`` attribute; anything else fails the assertion.
    """
    is_registrable = (
        is_new_type(new_type)
        or is_typed_dict(new_type)
        or hasattr(new_type, "__verify_type__")
    )
    assert is_registrable, "Can register only typing.NewType or typing.TypedDict or object with __verify_type__ class function"
    _registered_types[new_type.__name__] = new_type
32
 
33
 
 
489
  if not is_type(typing_type):
490
  raise UnsupportedTypeError(typing_type)
491
 
492
+ if hasattr(typing_type, "__verify_type__"):
493
+ return typing_type.__verify_type__(object)
494
+
495
  if typing_type is typing.Type:
496
  return is_type(object)
497
 
 
1069
  f"{class_name} description: {description}"
1070
  ) from e
1071
 
1072
+ try:
1073
+ valid = isoftype(value, data_type)
1074
+ except Exception as e:
1075
+ raise ValueError(
1076
+ f"Passed value {value} of field '{field_name}' is not "
1077
+ f"of required type: ({to_type_string(data_type)}) in {class_name} ('{id}').\n"
1078
+ f"{class_name} description: {description}\nReason:\n{e}"
1079
+ ) from e
1080
+
1081
+ if not valid:
1082
  raise ValueError(
1083
+ f"Passed value {value} of field '{field_name}' is not "
1084
  f"of required type: ({to_type_string(data_type)}) in {class_name} ('{id}').\n"
1085
  f"{class_name} description: {description}"
1086
  )
types.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Any, Dict, List, Literal, NewType, Optional, Type, TypedDict, Union
2
 
3
  from .type_utils import register_type
4
 
@@ -51,14 +51,20 @@ class SQLDatabase(TypedDict):
51
  dbms: Optional[str]
52
  data: Optional[Dict[str, Dict]]
53
 
54
- class Parameter(TypedDict):
55
- name: str
56
- type: Optional[Type] # Using actual Python type objects
 
 
 
 
 
 
57
 
58
  class Tool(TypedDict):
59
  name: str
60
  description: str
61
- parameters: List[Parameter]
62
 
63
  class ToolCall(TypedDict):
64
  name: str
@@ -76,7 +82,7 @@ register_type(Document)
76
  register_type(MultiDocument)
77
  register_type(RagResponse)
78
  register_type(SQLDatabase)
79
- register_type(Parameter)
80
  register_type(Tool)
 
81
  register_type(ToolCall)
82
 
 
1
+ from typing import Any, Dict, List, Literal, NewType, Optional, TypedDict, Union
2
 
3
  from .type_utils import register_type
4
 
 
51
  dbms: Optional[str]
52
  data: Optional[Dict[str, Dict]]
53
 
54
class JsonSchema:
    """Marker type for values that are JSON Schema documents.

    Type checking is delegated to ``__verify_type__`` instead of structural
    typing rules.
    """

    @classmethod
    def __verify_type__(cls, object):
        """Check whether object is a dict accepted by the JSON Schema meta-validator.

        Returns False for anything that is not a dict. For a dict,
        ``jsonschema_rs.meta.validate`` is called and True is returned once it
        completes; any exception it raises propagates to the caller
        (presumably raised for an invalid schema; confirm against the
        jsonschema_rs documentation).
        """
        if isinstance(object, dict):
            # Imported lazily so the dependency is only needed when a dict is checked.
            import jsonschema_rs

            jsonschema_rs.meta.validate(object)
            return True
        return False
63
 
64
  class Tool(TypedDict):
65
  name: str
66
  description: str
67
+ parameters: JsonSchema
68
 
69
  class ToolCall(TypedDict):
70
  name: str
 
82
  register_type(MultiDocument)
83
  register_type(RagResponse)
84
  register_type(SQLDatabase)
 
85
  register_type(Tool)
86
+ register_type(JsonSchema)
87
  register_type(ToolCall)
88
 
version.py CHANGED
@@ -1 +1 @@
1
- version = "1.22.4"
 
1
+ version = "1.23.0"