|
import abc |
|
import os |
|
import re |
|
from typing import Any, Dict, List, Literal, Optional, Union |
|
|
|
from tqdm import tqdm |
|
|
|
from .artifact import Artifact |
|
from .dataclass import InternalField, NonPositionalField |
|
from .deprecation_utils import deprecation |
|
from .image_operators import extract_images |
|
from .logging_utils import get_logger |
|
from .operator import PackageRequirementsMixin |
|
|
|
|
|
class InferenceEngine(abc.ABC, Artifact):
    """Abstract base class for inference.

    Concrete engines implement ``_infer``; callers go through ``infer``,
    which runs ``verify_instance`` on every instance before delegating.
    """

    @abc.abstractmethod
    def _infer(self, dataset):
        """Perform inference on the input dataset."""
        pass

    def infer(self, dataset) -> List[str]:
        """Verifies instances of a dataset and performs inference.

        Args:
            dataset: Iterable of instances. Each instance is handed to
                ``verify_instance`` and the dataset is then passed, as a
                whole, to ``_infer``.

        Returns:
            The value produced by the concrete engine's ``_infer``
            (for the engines in this module, typically one prediction
            per instance).
        """
        # Verification is executed purely for its side effects, so use a
        # plain loop rather than building a throwaway list.
        for instance in dataset:
            self.verify_instance(instance)
        return self._infer(dataset)

    @deprecation(version="2.0.0")
    def _set_inference_parameters(self):
        """Sets inference parameters of an instance based on 'parameters' attribute (if given).

        Values taken from the deprecated ``parameters`` object only fill in
        attributes of the engine that are currently ``None``; attributes that
        were set directly on the engine are left untouched.
        """
        if getattr(self, "parameters", None) is None:
            return

        get_logger().warning(
            f"The 'parameters' attribute of '{self.get_pretty_print_name()}' "
            f"is deprecated. Please pass inference parameters directly to the "
            f"inference engine instance instead."
        )

        legacy_values = self.parameters.to_dict([self.parameters])
        for param, param_dict_val in legacy_values.items():
            # Directly-set engine attributes take precedence over legacy ones.
            if getattr(self, param) is None:
                setattr(self, param, param_dict_val)
|
|
|
|
|
class LogProbInferenceEngine(abc.ABC, Artifact):
    """Abstract base class for inference with log probs."""

    @abc.abstractmethod
    def _infer_log_probs(self, dataset):
        """Perform inference on the input dataset that returns log probs."""
        pass

    def infer_log_probs(self, dataset) -> List[List[Dict]]:
        """Verifies instances of a dataset and performs inference that returns log probabilities of top tokens.

        For each instance, returns a list with one entry per generated
        position; each entry lists the top tokens for that position:

            [{"top_tokens": [{"text": ..., "logprob": ...}, ...]}, ...]

        Args:
            dataset: Iterable of instances. Each instance is handed to
                ``verify_instance`` and the dataset is then passed, as a
                whole, to ``_infer_log_probs``.

        Returns:
            The value produced by the concrete engine's ``_infer_log_probs``.
        """
        # Verification is executed purely for its side effects, so use a
        # plain loop rather than building a throwaway list.
        for instance in dataset:
            self.verify_instance(instance)
        return self._infer_log_probs(dataset)
|
|
|
|
|
class LazyLoadMixin(Artifact):
    """Mixin for engines that may postpone loading their model.

    ``lazy_load`` is a non-positional flag (default ``False``); the engines in
    this module that use it skip loading inside ``prepare`` when it is set and
    load on first inference instead.
    """

    lazy_load: bool = NonPositionalField(default=False)

    @abc.abstractmethod
    def _is_loaded(self):
        """Report whether the underlying model has been loaded."""
|
|
|
|
|
class HFPipelineBasedInferenceEngine(
    InferenceEngine, PackageRequirementsMixin, LazyLoadMixin
):
    """Inference engine backed by a ``transformers`` pipeline.

    The pipeline is built in ``prepare`` unless ``lazy_load`` is set, in
    which case it is built on the first call to ``_infer``.
    """

    model_name: str
    max_new_tokens: int
    use_fp16: bool = True

    _requirements_list = {
        "transformers": "Install huggingface package using 'pip install --upgrade transformers"
    }

    def _prepare_pipeline(self):
        """Build the ``transformers`` pipeline and store it in ``self.model``."""
        import torch
        from transformers import AutoConfig, pipeline

        pipeline_kwargs: Dict[str, Any] = {}
        if self.use_fp16:
            pipeline_kwargs["torch_dtype"] = torch.float16
        pipeline_kwargs["max_new_tokens"] = self.max_new_tokens

        # Device preference: MPS first, then the first CUDA device, then CPU.
        if torch.backends.mps.is_available():
            device_spec = "mps"
        elif torch.cuda.is_available():
            device_spec = 0
        else:
            device_spec = "cpu"
        device = torch.device(device_spec)

        if torch.cuda.device_count() > 1:
            # With several CUDA devices, placement is delegated to
            # device_map="auto" instead of pinning a single device.
            assert device == torch.device(0)
            pipeline_kwargs["device_map"] = "auto"
        else:
            pipeline_kwargs["device"] = device

        config = AutoConfig.from_pretrained(self.model_name, trust_remote_code=True)
        if config.is_encoder_decoder:
            task = "text2text-generation"
        else:
            task = "text-generation"

        # The task name is only used to decide whether the prompt should be
        # stripped from the output; it is not passed to pipeline().
        if task == "text-generation":
            pipeline_kwargs["return_full_text"] = False

        self.model = pipeline(
            model=self.model_name, trust_remote_code=True, **pipeline_kwargs
        )

    def prepare(self):
        """Build the pipeline eagerly unless lazy loading was requested."""
        if self.lazy_load:
            return
        self._prepare_pipeline()

    def _is_loaded(self):
        """Return True once ``self.model`` exists and is not None."""
        return getattr(self, "model", None) is not None

    def _infer(self, dataset):
        """Run the pipeline on each instance's "source" and collect the generated texts."""
        if not self._is_loaded():
            self._prepare_pipeline()

        prompts = [instance["source"] for instance in dataset]
        generated_texts = []
        for result in self.model(prompts):
            # A per-prompt result may be a list of candidates; keep the first.
            first = result[0] if isinstance(result, list) else result
            generated_texts.append(first["generated_text"])
        return generated_texts
|
|
|
|
|
class MockInferenceEngine(InferenceEngine):
    """Stand-in engine that answers every instance with the fixed string "[[10]]"."""

    model_name: str

    def prepare(self):
        """Nothing to set up for the mock engine."""
        return

    def _infer(self, dataset):
        """Return one constant prediction per instance in the dataset."""
        predictions = []
        for _ in dataset:
            predictions.append("[[10]]")
        return predictions
|
|
|
|
|
class IbmGenAiInferenceEngineParamsMixin(Artifact):
    """Optional generation parameters mixed into ``IbmGenAiInferenceEngine``.

    Every field defaults to ``None``. ``IbmGenAiInferenceEngine._infer``
    collects these fields with ``self.to_dict([IbmGenAiInferenceEngineParamsMixin])``
    and forwards them as keyword arguments to genai's ``TextGenerationParameters``.
    """

    # NOTE(review): field names presumably mirror the fields accepted by
    # genai.schema.TextGenerationParameters -- confirm against the ibm-generative-ai docs.
    beam_width: Optional[int] = None
    decoding_method: Optional[Literal["greedy", "sample"]] = None
    include_stop_sequence: Optional[bool] = None
    length_penalty: Any = None
    max_new_tokens: Optional[int] = None
    min_new_tokens: Optional[int] = None
    random_seed: Optional[int] = None
    repetition_penalty: Optional[float] = None
    return_options: Any = None
    stop_sequences: Optional[List[str]] = None
    temperature: Optional[float] = None
    time_limit: Optional[int] = None
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    truncate_input_tokens: Optional[int] = None
    typical_p: Optional[float] = None
|
|
|
|
|
@deprecation(version="2.0.0", alternative=IbmGenAiInferenceEngineParamsMixin)
class IbmGenAiInferenceEngineParams(Artifact):
    """Deprecated standalone container of IBM GenAI generation parameters.

    Declares the same fields as ``IbmGenAiInferenceEngineParamsMixin`` (named
    as the alternative in the deprecation decorator). It remains the type of
    ``IbmGenAiInferenceEngine.parameters``; ``_set_inference_parameters``
    copies its values onto engine attributes that are still ``None``.
    """

    # Keep this field list in sync with IbmGenAiInferenceEngineParamsMixin.
    beam_width: Optional[int] = None
    decoding_method: Optional[Literal["greedy", "sample"]] = None
    include_stop_sequence: Optional[bool] = None
    length_penalty: Any = None
    max_new_tokens: Optional[int] = None
    min_new_tokens: Optional[int] = None
    random_seed: Optional[int] = None
    repetition_penalty: Optional[float] = None
    return_options: Any = None
    stop_sequences: Optional[List[str]] = None
    temperature: Optional[float] = None
    time_limit: Optional[int] = None
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    truncate_input_tokens: Optional[int] = None
    typical_p: Optional[float] = None
|
|
|
|
|
class IbmGenAiInferenceEngine(
    InferenceEngine, IbmGenAiInferenceEngineParamsMixin, PackageRequirementsMixin
):
    """Inference engine that generates text through the ``genai`` client.

    The API key is read from the ``GENAI_KEY`` environment variable when the
    engine is prepared.
    """

    label: str = "ibm_genai"
    model_name: str
    _requirements_list = {
        "genai": "Install ibm-genai package using 'pip install --upgrade ibm-generative-ai"
    }
    data_classification_policy = ["public", "proprietary"]
    parameters: Optional[IbmGenAiInferenceEngineParams] = None

    def prepare(self):
        """Create the genai client and apply any deprecated ``parameters`` object."""
        from genai import Client, Credentials

        api_key_env_var_name = "GENAI_KEY"
        api_key = os.environ.get(api_key_env_var_name)
        missing_key_message = (
            "Error while trying to run IbmGenAiInferenceEngine."
            f" Please set the environment param '{api_key_env_var_name}'."
        )
        assert api_key is not None, missing_key_message

        self.client = Client(credentials=Credentials(api_key=api_key))

        self._set_inference_parameters()

    def _infer(self, dataset):
        """Send every instance's "source" in one request and return the generated texts."""
        from genai.schema import TextGenerationParameters

        generation_kwargs = self.to_dict([IbmGenAiInferenceEngineParamsMixin])
        genai_params = TextGenerationParameters(**generation_kwargs)

        prompts = [instance["source"] for instance in dataset]
        responses = self.client.text.generation.create(
            model_id=self.model_name,
            inputs=prompts,
            parameters=genai_params,
        )

        generated_texts = []
        for response in responses:
            # Only the first result of each response is used.
            generated_texts.append(response.results[0].generated_text)
        return generated_texts
|
|
|
|
|
class OpenAiInferenceEngineParamsMixin(Artifact):
    """Optional request parameters mixed into ``OpenAiInferenceEngine``.

    ``OpenAiInferenceEngine`` collects these fields with
    ``self.to_dict([OpenAiInferenceEngineParamsMixin])`` and forwards them as
    keyword arguments to ``client.chat.completions.create``. All fields
    default to ``None`` except ``top_logprobs``, which defaults to 20.
    """

    # NOTE(review): field names presumably mirror the OpenAI chat-completions
    # request parameters -- confirm against the OpenAI API reference.
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    max_tokens: Optional[int] = None
    seed: Optional[int] = None
    stop: Union[Optional[str], List[str]] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_logprobs: Optional[int] = 20
    logit_bias: Optional[Dict[str, int]] = None
    logprobs: Optional[bool] = None
    n: Optional[int] = None
    # Annotated Optional because the default is None.
    parallel_tool_calls: Optional[bool] = None
    service_tier: Optional[Literal["auto", "default"]] = None
|
|
|
|
|
@deprecation(version="2.0.0", alternative=OpenAiInferenceEngineParamsMixin)
class OpenAiInferenceEngineParams(Artifact):
    """Deprecated standalone container of OpenAI request parameters.

    Declares the same fields as ``OpenAiInferenceEngineParamsMixin`` (named as
    the alternative in the deprecation decorator). It remains the type of
    ``OpenAiInferenceEngine.parameters``; ``_set_inference_parameters`` copies
    its values onto engine attributes that are still ``None``.
    """

    # Keep this field list in sync with OpenAiInferenceEngineParamsMixin.
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    max_tokens: Optional[int] = None
    seed: Optional[int] = None
    stop: Union[Optional[str], List[str]] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_logprobs: Optional[int] = 20
    logit_bias: Optional[Dict[str, int]] = None
    logprobs: Optional[bool] = None
    n: Optional[int] = None
    # Annotated Optional because the default is None.
    parallel_tool_calls: Optional[bool] = None
    service_tier: Optional[Literal["auto", "default"]] = None
|
|
|
|
|
class OpenAiInferenceEngine(
    InferenceEngine,
    LogProbInferenceEngine,
    OpenAiInferenceEngineParamsMixin,
    PackageRequirementsMixin,
):
    """Inference engine that queries the OpenAI chat-completions API.

    Each instance's "source" is sent as a single user message, one request
    per instance. The API key is read from the ``OPENAI_API_KEY``
    environment variable when the engine is prepared.
    """

    label: str = "openai"
    model_name: str
    _requirements_list = {
        "openai": "Install openai package using 'pip install --upgrade openai"
    }
    data_classification_policy = ["public"]
    parameters: Optional[OpenAiInferenceEngineParams] = None

    def prepare(self):
        """Create the OpenAI client and apply any deprecated ``parameters`` object.

        Raises:
            AssertionError: If ``OPENAI_API_KEY`` is not set in the environment.
        """
        from openai import OpenAI

        api_key_env_var_name = "OPENAI_API_KEY"
        api_key = os.environ.get(api_key_env_var_name)
        assert api_key is not None, (
            f"Error while trying to run OpenAiInferenceEngine."
            f" Please set the environment param '{api_key_env_var_name}'."
        )

        self.client = OpenAI(api_key=api_key)

        self._set_inference_parameters()

    def _create_chat_completion(self, instance, **overrides):
        """Send one instance as a single user message and return the raw response.

        Args:
            instance: Mapping providing the prompt under the "source" key.
            **overrides: Request parameters that take precedence over the
                values collected from ``OpenAiInferenceEngineParamsMixin``.
        """
        # Merge into a fresh dict so the mapping returned by to_dict() is
        # never mutated.
        request_params = {
            **self.to_dict([OpenAiInferenceEngineParamsMixin]),
            **overrides,
        }
        return self.client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": instance["source"],
                }
            ],
            model=self.model_name,
            **request_params,
        )

    def _infer(self, dataset):
        """Return the content of the first choice's message for every instance."""
        outputs = []
        for instance in tqdm(dataset, desc="Inferring with openAI API"):
            response = self._create_chat_completion(instance)
            outputs.append(response.choices[0].message.content)
        return outputs

    def _infer_log_probs(self, dataset):
        """Return, per instance, the top tokens and log probs of each generated token.

        Returns:
            A list with one entry per instance. Each entry is a list with one
            ``{"top_tokens": [{"text": ..., "logprob": ...}, ...]}`` dict per
            generated token.
        """
        # The code below reads `choices[0].logprobs.content`, which is only
        # present when log probabilities were requested. Switch them on when
        # the user left 'logprobs' unset; an explicit value is respected.
        overrides = {"logprobs": True} if self.logprobs is None else {}

        outputs = []
        for instance in tqdm(dataset, desc="Inferring with openAI API"):
            response = self._create_chat_completion(instance, **overrides)
            top_logprobs_response = response.choices[0].logprobs.content
            output = [
                {
                    "top_tokens": [
                        {"text": obj.token, "logprob": obj.logprob}
                        for obj in generated_token.top_logprobs
                    ]
                }
                for generated_token in top_logprobs_response
            ]
            outputs.append(output)
        return outputs
|
|
|
|
|
class WMLInferenceEngineParamsMixin(Artifact):
    """Optional generation parameters mixed into ``WMLInferenceEngine``.

    Every field defaults to ``None``. ``WMLInferenceEngine._infer`` collects
    these fields with ``self.to_dict([WMLInferenceEngineParamsMixin], keep_empty=False)``
    and passes the result as ``params`` to ``ModelInference.generate_text``.
    """

    # NOTE(review): field names presumably mirror the text-generation
    # parameters of ibm-watsonx-ai -- confirm against its documentation.
    decoding_method: Optional[Literal["greedy", "sample"]] = None
    length_penalty: Optional[Dict[str, Union[int, float]]] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    random_seed: Optional[int] = None
    repetition_penalty: Optional[float] = None
    min_new_tokens: Optional[int] = None
    max_new_tokens: Optional[int] = None
    stop_sequences: Optional[List[str]] = None
    time_limit: Optional[int] = None
    truncate_input_tokens: Optional[int] = None
    prompt_variables: Optional[Dict[str, Any]] = None
    return_options: Optional[Dict[str, bool]] = None
|
|
|
|
|
@deprecation(version="2.0.0", alternative=WMLInferenceEngineParamsMixin)
class WMLInferenceEngineParams(Artifact):
    """Deprecated standalone container of WML generation parameters.

    Declares the same fields as ``WMLInferenceEngineParamsMixin`` (named as
    the alternative in the deprecation decorator). It remains the type of
    ``WMLInferenceEngine.parameters``; ``_set_inference_parameters`` copies
    its values onto engine attributes that are still ``None``.
    """

    # Keep this field list in sync with WMLInferenceEngineParamsMixin.
    decoding_method: Optional[Literal["greedy", "sample"]] = None
    length_penalty: Optional[Dict[str, Union[int, float]]] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    random_seed: Optional[int] = None
    repetition_penalty: Optional[float] = None
    min_new_tokens: Optional[int] = None
    max_new_tokens: Optional[int] = None
    stop_sequences: Optional[List[str]] = None
    time_limit: Optional[int] = None
    truncate_input_tokens: Optional[int] = None
    prompt_variables: Optional[Dict[str, Any]] = None
    return_options: Optional[Dict[str, bool]] = None
|
|
|
|
|
class WMLInferenceEngine(
    InferenceEngine, WMLInferenceEngineParamsMixin, PackageRequirementsMixin
):
    """Runs inference using ibm-watsonx-ai.

    Attributes:
        credentials (Dict[str, str], optional): By default, it is created by a class
            instance which tries to retrieve proper environment variables
            ("WML_URL", "WML_PROJECT_ID", "WML_APIKEY"). However, a dictionary with
            the following keys: "url", "apikey", "project_id" can be directly provided
            instead.
        model_name (str, optional): ID of a model to be used for inference. Mutually
            exclusive with 'deployment_id'.
        deployment_id (str, optional): Deployment ID of a tuned model to be used for
            inference. Mutually exclusive with 'model_name'.
        parameters (WMLInferenceEngineParams, optional): Instance of WMLInferenceEngineParams
            which defines inference parameters and their values. Deprecated attribute, please
            pass respective parameters directly to the WMLInferenceEngine class instead.

    Examples:
        from .api import load_dataset

        wml_credentials = {
            "url": "some_url", "project_id": "some_id", "apikey": "some_key"
        }
        model_name = "google/flan-t5-xxl"
        wml_inference = WMLInferenceEngine(
            credentials=wml_credentials,
            model_name=model_name,
            data_classification_policy=["public"],
            top_p=0.5,
            random_seed=123,
        )

        dataset = load_dataset(
            dataset_query="card=cards.argument_topic,template_card_index=0,loader_limit=5"
        )
        results = wml_inference.infer(dataset["test"])
    """

    credentials: Optional[Dict[Literal["url", "apikey", "project_id"], str]] = None
    model_name: Optional[str] = None
    deployment_id: Optional[str] = None
    label: str = "wml"
    _requirements_list = {
        "ibm_watsonx_ai": "Install ibm-watsonx-ai package using 'pip install --upgrade ibm-watsonx-ai'. "
        "It is advised to have Python version >=3.10 installed, as at lower version this package "
        "may cause conflicts with other installed packages."
    }
    data_classification_policy = ["public", "proprietary"]
    parameters: Optional[WMLInferenceEngineParams] = None

    _client: Any = InternalField(default=None, name="WML client")

    def verify(self):
        """Validate credential keys and the model/deployment selection.

        Raises:
            ValueError: If ``credentials`` contains a key other than
                "url", "apikey" or "project_id".
            AssertionError: Unless exactly one of ``model_name`` and
                ``deployment_id`` is specified.
        """
        super().verify()

        if self.credentials is not None:
            for key in self.credentials:
                if key not in ["url", "apikey", "project_id"]:
                    raise ValueError(
                        f'Illegal credential key: {key}, use only ["url", "apikey", "project_id"]'
                    )

        # Exactly one of the two must be given. The previous expression
        # `a or b and not (a and b)` parsed as `a or (b and ...)` and so
        # accepted the case where both were set.
        has_model_name = bool(self.model_name)
        has_deployment_id = bool(self.deployment_id)
        assert (
            has_model_name != has_deployment_id
        ), "Either 'model_name' or 'deployment_id' must be specified, but not both at the same time."

    def process_data_before_dump(self, data):
        """Mask every credential value except "url" with "<hidden>".

        Args:
            data: Mapping about to be dumped; its "credentials" entry, when
                present and not None, is replaced by a masked copy.

        Returns:
            The same ``data`` mapping.
        """
        if "credentials" in data and data["credentials"] is not None:
            # Build a new mapping instead of editing the nested dict in place,
            # so a credentials dict shared with another owner is not clobbered.
            data["credentials"] = {
                key: value if key == "url" else "<hidden>"
                for key, value in data["credentials"].items()
            }
        return data

    @staticmethod
    def _read_wml_credentials_from_env() -> (
        Dict[Literal["url", "apikey", "project_id"], str]
    ):
        """Read "url", "project_id" and "apikey" from the WML_* environment variables.

        Raises:
            AssertionError: If any of the variables is unset or empty.
        """
        credentials = {}
        for env_var_name in ["WML_URL", "WML_PROJECT_ID", "WML_APIKEY"]:
            env_var = os.environ.get(env_var_name)
            assert env_var, (
                f"Error while trying to run 'WMLInferenceEngine'. "
                f"Please set the env variable: '{env_var_name}', or "
                f"directly provide an instance of ibm-watsonx-ai 'Credentials' "
                f"to the engine."
            )

            # "WML_PROJECT_ID" -> "project_id", and so on.
            name = env_var_name.lower().replace("wml_", "")
            credentials[name] = env_var

        return credentials

    def _initialize_wml_client(self):
        """Create an ``APIClient`` and set its default project.

        Falls back to the environment variables when no credentials were
        supplied, storing them on ``self.credentials``.
        """
        from ibm_watsonx_ai.client import APIClient

        if self.credentials is None:
            self.credentials = self._read_wml_credentials_from_env()

        client = APIClient(credentials=self.credentials)
        client.set.default_project(self.credentials["project_id"])
        return client

    def prepare(self):
        """Create the WML client and apply any deprecated ``parameters`` object."""
        self._client = self._initialize_wml_client()

        self._set_inference_parameters()

    def _infer(self, dataset):
        """Generate text for every instance's "source" in a single call.

        Returns:
            The value returned by ``ModelInference.generate_text``.
        """
        from ibm_watsonx_ai.foundation_models import ModelInference

        model = ModelInference(
            model_id=self.model_name,
            deployment_id=self.deployment_id,
            api_client=self._client,
        )

        # Collect prompts instance by instance, like the other engines in this
        # module, so a plain list of instances works as well as a dataset
        # that supports column access.
        prompts = [instance["source"] for instance in dataset]

        return model.generate_text(
            prompt=prompts,
            params=self.to_dict([WMLInferenceEngineParamsMixin], keep_empty=False),
        )
|
|
|
|
|
class HFLlavaInferenceEngine(InferenceEngine, LazyLoadMixin):
    """Inference engine for LLaVA models loaded through ``transformers``.

    Lazy loading is on by default, so the model and processor are created on
    the first call to ``_infer`` rather than in ``prepare``.
    """

    model_name: str
    max_new_tokens: int
    lazy_load = True

    # NOTE(review): this class does not derive from PackageRequirementsMixin,
    # so it is unclear whether this list is ever checked -- confirm.
    _requirements_list = {
        "transformers": "Install huggingface package using 'pip install --upgrade transformers",
        "torch": "Install torch, go on PyTorch website for mode details.",
        "accelerate": "pip install accelerate",
    }

    def _prepare_engine(self):
        """Pick a device, then load the model (in float16) and its processor."""
        import torch
        from transformers import AutoProcessor, LlavaForConditionalGeneration

        # Device preference: MPS first, then the first CUDA device, then CPU.
        if torch.backends.mps.is_available():
            device_spec = "mps"
        elif torch.cuda.is_available():
            device_spec = 0
        else:
            device_spec = "cpu"
        self.device = torch.device(device_spec)

        loaded_model = LlavaForConditionalGeneration.from_pretrained(
            self.model_name,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )
        self.model = loaded_model.to(self.device)

        self.processor = AutoProcessor.from_pretrained(self.model_name)

    def prepare(self):
        """Load the engine eagerly unless lazy loading was requested."""
        if self.lazy_load:
            return
        self._prepare_engine()

    def _is_loaded(self):
        """Return True once ``self.model`` exists and is not None."""
        return getattr(self, "model", None) is not None

    def _infer(self, dataset):
        """Generate one decoded continuation per instance of the dataset."""
        if not self._is_loaded():
            self._prepare_engine()

        import torch

        # <img src="..."> tags in the source text become "<image>" placeholders.
        img_tag_pattern = re.compile(r'<img\s+src=["\'](.*?)["\']\s*/?>')

        results = []
        for instance in dataset:
            text = instance["source"]
            images = extract_images(text, instance)
            model_input = img_tag_pattern.sub("<image>", text)

            # A single image is passed bare rather than wrapped in a list.
            if len(images) == 1:
                images = images[0]

            processed = self.processor(
                images=images, text=model_input, return_tensors="pt"
            )
            inputs = processed.to(self.device, torch.float16)
            prompt_length = len(inputs["input_ids"][0])

            generated = self.model.generate(
                **inputs, max_new_tokens=self.max_new_tokens, do_sample=False
            )
            # Drop the prompt tokens so only the continuation is decoded.
            continuation = generated[0][prompt_length:]
            results.append(
                self.processor.decode(continuation, skip_special_tokens=True)
            )

        return results
|
|