Upload folder using huggingface_hub
- inference.py +72 -31
- metrics.py +54 -0
- operators.py +1 -1
- parsing_utils.py +1 -1
- standard.py +4 -0
- version.py +1 -1
inference.py
CHANGED
@@ -19,6 +19,7 @@ from typing import (
     Optional,
     Sequence,
     Tuple,
+    TypedDict,
     Union,
 )
 
@@ -1407,6 +1408,11 @@ class IbmGenAiInferenceEngine(
         return dataset
 
 
+class CredentialsOpenAi(TypedDict, total=False):
+    api_key: str
+    api_url: str
+
+
 class OpenAiInferenceEngineParamsMixin(Artifact):
     frequency_penalty: Optional[float] = None
     presence_penalty: Optional[float] = None
@@ -1453,27 +1459,40 @@ class OpenAiInferenceEngine(
     }
     data_classification_policy = ["public"]
     parameters: Optional[OpenAiInferenceEngineParams] = None
+    base_url: Optional[str] = None
+    default_headers: Dict[str, str] = {}
+    credentials: CredentialsOpenAi = {}
 
-    def get_engine_id(self):
+    def get_engine_id(self) -> str:
         return get_model_and_label_id(self.model_name, self.label)
 
-    @classmethod
-    def get_api_param(cls, inference_engine, api_param_env_var_name):
-        api_key = os.environ.get(api_param_env_var_name)
-        assert api_key is not None, (
-            f"Error while trying to run {inference_engine}."
-            f" Please set the environment param '{api_param_env_var_name}'."
+    def _prepare_credentials(self) -> CredentialsOpenAi:
+        api_key = self.credentials.get(
+            "api_key", os.environ.get(f"{self.label.upper()}_API_KEY", None)
         )
-        return api_key
+        assert api_key, (
+            f"Error while trying to run {self.label}. "
+            f"Please set the env variable: '{self.label.upper()}_API_KEY'"
+        )
+
+        api_url = self.credentials.get(
+            "api_url", os.environ.get(f"{self.label.upper()}_API_URL", None)
+        )
+
+        return {"api_key": api_key, "api_url": api_url}
+
+    def get_default_headers(self) -> Dict[str, str]:
+        return self.default_headers
 
     def create_client(self):
         from openai import OpenAI
 
-        api_key = self.get_api_param(
-            inference_engine="OpenAiInferenceEngine",
-            api_param_env_var_name="OPENAI_API_KEY",
+        self.credentials = self._prepare_credentials()
+        return OpenAI(
+            api_key=self.credentials["api_key"],
+            base_url=self.base_url or self.credentials["api_url"],
+            default_headers=self.get_default_headers(),
         )
-        return OpenAI(api_key=api_key)
 
     def prepare_engine(self):
         self.client = self.create_client()
@@ -1553,6 +1572,32 @@ class OpenAiInferenceEngine(
         return predict_result
 
 
+class VLLMRemoteInferenceEngine(OpenAiInferenceEngine):
+    label: str = "vllm"
+
+
+class RITSInferenceEngine(OpenAiInferenceEngine):
+    label: str = "rits"
+
+    def get_default_headers(self):
+        return {"RITS_API_KEY": self.credentials["api_key"]}
+
+    def prepare_engine(self):
+        base_url_template = "https://inference-3scale-apicast-production.apps.rits.fmaas.res.ibm.com/{}/v1"
+        self.base_url = base_url_template.format(self._get_model_name_for_endpoint())
+        logger.info(f"Created RITS inference engine with endpoint: {self.base_url}")
+        super().prepare_engine()
+
+    def _get_model_name_for_endpoint(self):
+        return (
+            self.model_name.split("/")[-1]
+            .lower()
+            .replace("v0.1", "v01")
+            .replace("vision-", "")
+            .replace(".", "-")
+        )
+
+
 class TogetherAiInferenceEngineParamsMixin(Artifact):
     max_tokens: Optional[int] = None
     stop: Optional[List[str]] = None
@@ -1652,23 +1697,6 @@ class TogetherAiInferenceEngine(
         return outputs
 
 
-class VLLMRemoteInferenceEngine(OpenAiInferenceEngine):
-    label: str = "vllm"
-
-    def create_client(self):
-        from openai import OpenAI
-
-        api_key = self.get_api_param(
-            inference_engine="VLLMRemoteInferenceEngine",
-            api_param_env_var_name="VLLM_API_KEY",
-        )
-        api_url = self.get_api_param(
-            inference_engine="VLLMRemoteInferenceEngine",
-            api_param_env_var_name="VLLM_API_URL",
-        )
-        return OpenAI(api_key=api_key, base_url=api_url)
-
-
 @deprecation(
     version="2.0.0",
     msg=" You can specify inference parameters directly when initializing an inference engine.",
@@ -2667,7 +2695,7 @@ class LiteLLMInferenceEngine(
 
 
 _supported_apis = Literal[
-    "watsonx", "together-ai", "open-ai", "aws", "ollama", "bam", "watsonx-sdk"
+    "watsonx", "together-ai", "open-ai", "aws", "ollama", "bam", "watsonx-sdk", "rits"
 ]
 
 
@@ -2698,6 +2726,8 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
             "granite-3-8b-instruct": "watsonx/ibm/granite-3-8b-instruct",
             "flan-t5-xxl": "watsonx/google/flan-t5-xxl",
             "llama-3-2-1b-instruct": "watsonx/meta-llama/llama-3-2-1b-instruct",
+            "llama-3-2-11b-vision-instruct": "watsonx/meta-llama/llama-3-2-11b-vision-instruct",
+            "llama-3-2-90b-vision-instruct": "watsonx/meta-llama/llama-3-2-90b-vision-instruct",
         },
         "watsonx-sdk": {
             "llama-3-8b-instruct": "meta-llama/llama-3-8b-instruct",
@@ -2723,6 +2753,15 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
             "llama-3-2-1b-instruct": "meta-llama/llama-3-2-1b-instruct",
             "flan-t5-xxl": "google/flan-t5-xxl",
         },
+        "rits": {
+            "granite-3-8b-instruct": "ibm-granite/granite-3.0-8b-instruct",
+            "llama-3-1-8b-instruct": "meta-llama/llama-3-1-8b-instruct",
+            "llama-3-1-70b-instruct": "meta-llama/llama-3-1-70b-instruct",
+            "llama-3-2-11b-vision-instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct",
+            "llama-3-2-90b-vision-instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct",
+            "mistral-large-instruct": "mistralai/mistral-large-instruct-2407",
+            "mixtral-8x7b-instruct": "mistralai/mixtral-8x7B-instruct-v0.1",
+        },
     }
 
     _provider_to_base_class = {
@@ -2733,11 +2772,13 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
         "ollama": OllamaInferenceEngine,
        "bam": IbmGenAiInferenceEngine,
        "watsonx-sdk": WMLInferenceEngine,
+        "rits": RITSInferenceEngine,
    }

    _provider_param_renaming = {
        "bam": {"max_tokens": "max_new_tokens", "model": "model_name"},
        "watsonx-sdk": {"max_tokens": "max_new_tokens", "model": "model_name"},
+        "rits": {"model": "model_name"},
    }

    def get_provider_name(self):
@@ -2747,7 +2788,7 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
         provider = self.get_provider_name()
         if provider not in self._provider_to_base_class:
             raise UnitxtError(
-                f"{provider} a
+                f"{provider} is not a configured API for CrossProviderInferenceEngine. Supported apis: {','.join(self.provider_model_map.keys())}"
             )
         if self.model not in self.provider_model_map[provider]:
             raise UnitxtError(
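For context, a minimal usage sketch of the reworked credentials flow (assuming the import path unitxt.inference and that engine construction matches the fields in the diff above; the model name and key values are illustrative):

    import os

    from unitxt.inference import OpenAiInferenceEngine

    # Option 1: pass credentials explicitly via the new CredentialsOpenAi TypedDict.
    engine = OpenAiInferenceEngine(
        model_name="gpt-4o-mini",  # hypothetical model choice
        credentials={"api_key": "sk-...", "api_url": "https://api.openai.com/v1"},
    )

    # Option 2: rely on the env-var fallback derived from the engine label,
    # i.e. OPENAI_API_KEY / OPENAI_API_URL for label "openai".
    os.environ["OPENAI_API_KEY"] = "sk-..."
    engine = OpenAiInferenceEngine(model_name="gpt-4o-mini")

Deriving the env-var names from the engine label is what lets subclasses such as VLLMRemoteInferenceEngine drop their custom create_client and reuse the same logic via VLLM_API_KEY / VLLM_API_URL.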
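The practical effect of the registrations above is that "rits" becomes selectable like any other provider. A hedged sketch (import path and constructor arguments are assumed from the diff, not verified):

    from unitxt.inference import CrossProviderInferenceEngine

    # "rits" must be paired with a model key from provider_model_map["rits"].
    engine = CrossProviderInferenceEngine(
        provider="rits",
        model="llama-3-1-70b-instruct",
        max_tokens=256,
    )
    # RITSInferenceEngine then derives the endpoint from the model name
    # (".../llama-3-1-70b-instruct/v1") and authenticates via a RITS_API_KEY header.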
metrics.py
CHANGED
@@ -3536,6 +3536,60 @@ class Perplexity(BulkInstanceMetric):
         return shifted_logits, shifted_labels
 
 
+class FaithfulnessHHEM(BulkInstanceMetric):
+    reduction_map = {"mean": ["score"]}
+    main_score = "score"
+    batch_size: int = 2
+    model_name: str = "vectara/hallucination_evaluation_model"
+    prediction_type = str
+    single_reference_per_prediction = True
+    max_context_words = 4096
+
+    _requirements_list: List[str] = ["transformers", "torch"]
+
+    def prepare(self):
+        super().prepare()
+        import torch
+
+        if torch.cuda.is_available():
+            device = "cuda"
+        elif torch.backends.mps.is_available():
+            device = "mps"
+        else:
+            device = "cpu"
+        from transformers import AutoModelForSequenceClassification
+
+        self.model = AutoModelForSequenceClassification.from_pretrained(
+            self.model_name, trust_remote_code=True
+        ).to(device)
+
+    def compute(
+        self,
+        references: List[List[Any]],
+        predictions: List[Any],
+        task_data: List[Dict],
+    ) -> List[Dict[str, Any]]:
+        from tqdm import tqdm
+
+        # treat the references as the contexts and the predictions as answers
+        # concat references
+        contexts = ["\n".join(refs) for refs in references]
+        contexts = [" ".join(c.split(" ")[: self.max_context_words]) for c in contexts]
+        answers = predictions
+
+        # prepare for computation
+        inputs = [[c, a] for c, a in zip(contexts, answers)]
+        scores = []
+        input_batches = [
+            inputs[x : x + self.batch_size]
+            for x in range(0, len(inputs), self.batch_size)
+        ]
+        for input_batch in tqdm(input_batches, "input batch"):
+            batch_scores = self.model.predict(input_batch).cpu().tolist()
+            scores.extend(batch_scores)
+        return [{"score": score} for score in scores]
+
+
 class Squad(HuggingfaceMetric):
     hf_metric_name = "squad"
     main_score = "f1"
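To make the scoring flow concrete, here is a self-contained sketch of what compute() does per batch. It assumes, as the diff does, that the Vectara model loaded with trust_remote_code=True exposes a predict() method over (context, answer) pairs; the example strings are illustrative:

    from transformers import AutoModelForSequenceClassification

    model = AutoModelForSequenceClassification.from_pretrained(
        "vectara/hallucination_evaluation_model", trust_remote_code=True
    )

    contexts = ["The capital of France is Paris."]  # references, joined with "\n"
    answers = ["Paris is the capital of France."]   # model predictions

    pairs = [[c, a] for c, a in zip(contexts, answers)]
    batch_size = 2
    scores = []
    for start in range(0, len(pairs), batch_size):  # same slicing as the diff
        batch = pairs[start : start + batch_size]
        scores.extend(model.predict(batch).cpu().tolist())
    print(scores)  # one faithfulness score per (context, answer) pair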
operators.py
CHANGED
@@ -450,7 +450,7 @@ class InstanceFieldOperator(InstanceOperator):
                 )
                 if old_value is default_place_holder:
                     if self.not_exist_do_nothing:
-                        return instance
+                        continue
                     old_value = self.get_default
             except Exception as e:
                 raise ValueError(
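A small sketch of why this one-liner matters (field names are hypothetical): assuming the removed line returned early, as reconstructed above, one missing source field used to abort the copies for all remaining fields, whereas continue skips only the missing one:

    instance = {"a": 1}  # "b" is absent
    field_to_field = [("b", "b_copy"), ("a", "a_copy")]

    for from_field, to_field in field_to_field:
        if from_field not in instance:
            continue  # fixed behavior: skip just this missing field
        instance[to_field] = instance[from_field]

    print(instance)  # {'a': 1, 'a_copy': 1} -- "a" is still copied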
parsing_utils.py
CHANGED
@@ -45,7 +45,7 @@ from typing import Any, Tuple
 def consume_name_val(instring: str) -> Tuple[Any, str]:
     name_val = ""
     for char in instring:
-        if char in "[]
+        if char in "[],{}=":
             break
         name_val += char
     instring = instring[len(name_val) :].strip()
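To see what the delimiter set buys, here is a self-contained reimplementation of just the lines shown (a return is added for illustration): a name now stops cleanly at any structural character, including "=".

    from typing import Any, Tuple

    def consume_name_val(instring: str) -> Tuple[Any, str]:
        name_val = ""
        for char in instring:
            if char in "[],{}=":  # the fixed delimiter set
                break
            name_val += char
        instring = instring[len(name_val) :].strip()
        return name_val, instring

    print(consume_name_val("template=templates.default,num_demos=5"))
    # -> ('template', '=templates.default,num_demos=5')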
standard.py
CHANGED
@@ -140,6 +140,10 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
                 f"post processors must be a list of post processor. Got postprocessors = {self.postprocessors}"
             )
 
+        if self.format is not None and not isinstance(self.format, Format):
+            raise ValueError(
+                f"format parameter must be a list of of class derived from Format. Got format = {self.format}"
+            )
         if self.template is None:
             raise ValueError(
                 "You must set in the recipe either `template`, `template_card_index`."
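A sketch of what the new guard catches (import paths and the card name are assumptions, not verified against this commit): a recipe built with a proper Format artifact passes verification, while a raw string now fails fast instead of breaking later in the pipeline:

    from unitxt.formats import SystemFormat
    from unitxt.standard import StandardRecipe  # assumed import path

    recipe = StandardRecipe(
        card="cards.wnli",  # illustrative card
        template_card_index=0,
        format=SystemFormat(model_input_format="{source}\n"),  # a Format instance: OK
    )

    # format="some string" would now raise immediately, with the message
    # committed above: "format parameter must be a list of of class derived
    # from Format. Got format = some string"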
version.py
CHANGED
@@ -1 +1 @@
-version = "1.15.7"
+version = "1.15.8"