import abc
import os
from dataclasses import field
from typing import Any, Dict, List, Literal, Optional, Union
from tqdm import tqdm
from .artifact import Artifact
from .operator import PackageRequirementsMixin
class InferenceEngine(abc.ABC, Artifact):
"""Abstract base class for inference."""
@abc.abstractmethod
def _infer(self, dataset):
"""Perform inference on the input dataset."""
pass
    def infer(self, dataset) -> List[str]:
        """Verifies the instances of a dataset and performs inference."""
        for instance in dataset:
            self.verify_instance(instance)
        return self._infer(dataset)
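# A minimal concrete engine only needs to implement ``prepare`` and ``_infer``
# (a sketch, assuming instances are dicts with a "source" field, as used by the
# engines below):
#
#     class EchoInferenceEngine(InferenceEngine):
#         def prepare(self):
#             pass
#
#         def _infer(self, dataset):
#             return [instance["source"] for instance in dataset]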
class LogProbInferenceEngine(abc.ABC, Artifact):
"""Abstract base class for inference with log probs."""
@abc.abstractmethod
def _infer_log_probs(self, dataset):
"""Perform inference on the input dataset that returns log probs."""
pass
    def infer_log_probs(self, dataset) -> List[Dict]:
        """Verifies the instances of a dataset and performs inference, returning the log probabilities of the top tokens.

        For each instance, returns a list of dictionaries, one per generated position, of the form:

            {"top_tokens": [{"text": ..., "logprob": ...}, ...]}
        """
        for instance in dataset:
            self.verify_instance(instance)
        return self._infer_log_probs(dataset)
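# For illustration, the structure returned per instance looks like the following
# (token strings and log-probability values are made up):
#
#     [
#         {"top_tokens": [{"text": "The", "logprob": -0.1},
#                         {"text": "A", "logprob": -2.3}]},
#         {"top_tokens": [{"text": "answer", "logprob": -0.4}]},
#     ]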
class HFPipelineBasedInferenceEngine(InferenceEngine, PackageRequirementsMixin):
model_name: str
max_new_tokens: int
use_fp16: bool = True
    _requirement = {
        "transformers": "Install huggingface package using 'pip install --upgrade transformers'."
    }
def prepare(self):
import torch
from transformers import AutoConfig, pipeline
model_args: Dict[str, Any] = (
{"torch_dtype": torch.float16} if self.use_fp16 else {}
)
model_args.update({"max_new_tokens": self.max_new_tokens})
        # Prefer MPS, then the first CUDA device, falling back to the CPU.
        if torch.backends.mps.is_available():
            device = torch.device("mps")
        elif torch.cuda.is_available():
            device = torch.device(0)
        else:
            device = torch.device("cpu")
        # On a single device, pin the pipeline to it explicitly: with
        # device_map="auto", some weights may be offloaded to the CPU even when
        # the model *just* fits on one GPU, and inference then fails because the
        # inputs are always placed on the GPU. With multiple GPUs, shard the
        # model with device_map="auto" instead.
        if torch.cuda.device_count() > 1:
            assert device == torch.device(0)
            model_args.update({"device_map": "auto"})
        else:
            model_args.update({"device": device})
task = (
"text2text-generation"
if AutoConfig.from_pretrained(
self.model_name, trust_remote_code=True
).is_encoder_decoder
else "text-generation"
)
if task == "text-generation":
model_args.update({"return_full_text": False})
self.model = pipeline(
model=self.model_name, trust_remote_code=True, **model_args
)
def _infer(self, dataset):
outputs = []
for output in self.model([instance["source"] for instance in dataset]):
if isinstance(output, list):
output = output[0]
outputs.append(output["generated_text"])
return outputs
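# Usage sketch (the model name and prompt are illustrative; any model supported
# by transformers' pipeline() should work):
#
#     engine = HFPipelineBasedInferenceEngine(
#         model_name="google/flan-t5-small", max_new_tokens=32
#     )
#     predictions = engine.infer([{"source": "Translate to French: Hello"}])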
class MockInferenceEngine(InferenceEngine):
model_name: str
def prepare(self):
return
def _infer(self, dataset):
return ["[[10]]" for instance in dataset]
class IbmGenAiInferenceEngineParams(Artifact):
decoding_method: Optional[Literal["greedy", "sample"]] = None
max_new_tokens: Optional[int] = None
min_new_tokens: Optional[int] = None
random_seed: Optional[int] = None
repetition_penalty: Optional[float] = None
stop_sequences: Optional[List[str]] = None
temperature: Optional[float] = None
top_k: Optional[int] = None
top_p: Optional[float] = None
typical_p: Optional[float] = None
class IbmGenAiInferenceEngine(InferenceEngine, PackageRequirementsMixin):
label: str = "ibm_genai"
model_name: str
parameters: IbmGenAiInferenceEngineParams = field(
default_factory=IbmGenAiInferenceEngineParams
)
_requirement = {
"genai": "Install ibm-genai package using 'pip install --upgrade ibm-generative-ai"
}
data_classification_policy = ["public", "proprietary"]
def prepare(self):
from genai import Client, Credentials
api_key_env_var_name = "GENAI_KEY"
api_key = os.environ.get(api_key_env_var_name)
        assert api_key is not None, (
            f"Error while trying to run IbmGenAiInferenceEngine."
            f" Please set the environment variable '{api_key_env_var_name}'."
        )
credentials = Credentials(api_key=api_key)
self.client = Client(credentials=credentials)
def _infer(self, dataset):
from genai.schema import TextGenerationParameters
genai_params = TextGenerationParameters(
max_new_tokens=self.parameters.max_new_tokens,
min_new_tokens=self.parameters.min_new_tokens,
random_seed=self.parameters.random_seed,
repetition_penalty=self.parameters.repetition_penalty,
stop_sequences=self.parameters.stop_sequences,
temperature=self.parameters.temperature,
top_p=self.parameters.top_p,
top_k=self.parameters.top_k,
typical_p=self.parameters.typical_p,
decoding_method=self.parameters.decoding_method,
)
return [
response.results[0].generated_text
for response in self.client.text.generation.create(
model_id=self.model_name,
inputs=[instance["source"] for instance in dataset],
parameters=genai_params,
)
]
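# Usage sketch (assumes the GENAI_KEY environment variable is set; the model id
# and parameters are illustrative):
#
#     engine = IbmGenAiInferenceEngine(
#         model_name="google/flan-t5-xxl",
#         parameters=IbmGenAiInferenceEngineParams(max_new_tokens=32),
#     )
#     predictions = engine.infer(dataset)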
class OpenAiInferenceEngineParams(Artifact):
frequency_penalty: Optional[float] = None
presence_penalty: Optional[float] = None
max_tokens: Optional[int] = None
seed: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
top_logprobs: Optional[int] = 20
class OpenAiInferenceEngine(
InferenceEngine, LogProbInferenceEngine, PackageRequirementsMixin
):
label: str = "openai"
model_name: str
parameters: OpenAiInferenceEngineParams = field(
default_factory=OpenAiInferenceEngineParams
)
_requirement = {
"openai": "Install openai package using 'pip install --upgrade openai"
}
data_classification_policy = ["public"]
def prepare(self):
from openai import OpenAI
api_key_env_var_name = "OPENAI_API_KEY"
api_key = os.environ.get(api_key_env_var_name)
        assert api_key is not None, (
            f"Error while trying to run OpenAiInferenceEngine."
            f" Please set the environment variable '{api_key_env_var_name}'."
        )
)
self.client = OpenAI(api_key=api_key)
def _infer(self, dataset):
outputs = []
        for instance in tqdm(dataset, desc="Inferring with OpenAI API"):
            response = self.client.chat.completions.create(
                messages=[
                    {
                        "role": "user",
                        "content": instance["source"],
                    }
                ],
model=self.model_name,
frequency_penalty=self.parameters.frequency_penalty,
presence_penalty=self.parameters.presence_penalty,
max_tokens=self.parameters.max_tokens,
seed=self.parameters.seed,
stop=self.parameters.stop,
temperature=self.parameters.temperature,
top_p=self.parameters.top_p,
)
output = response.choices[0].message.content
outputs.append(output)
return outputs
def _infer_log_probs(self, dataset):
outputs = []
        for instance in tqdm(dataset, desc="Inferring with OpenAI API"):
            response = self.client.chat.completions.create(
                messages=[
                    {
                        "role": "user",
                        "content": instance["source"],
                    }
                ],
model=self.model_name,
frequency_penalty=self.parameters.frequency_penalty,
presence_penalty=self.parameters.presence_penalty,
max_tokens=self.parameters.max_tokens,
seed=self.parameters.seed,
stop=self.parameters.stop,
temperature=self.parameters.temperature,
top_p=self.parameters.top_p,
logprobs=True,
top_logprobs=self.parameters.top_logprobs,
)
top_logprobs_response = response.choices[0].logprobs.content
output = [
{
"top_tokens": [
{"text": obj.token, "logprob": obj.logprob}
for obj in generated_token.top_logprobs
]
}
for generated_token in top_logprobs_response
]
outputs.append(output)
return outputs
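# Usage sketch (assumes the OPENAI_API_KEY environment variable is set; the
# model name is illustrative):
#
#     engine = OpenAiInferenceEngine(
#         model_name="gpt-4o",
#         parameters=OpenAiInferenceEngineParams(max_tokens=64),
#     )
#     texts = engine.infer(dataset)                # plain generations
#     log_probs = engine.infer_log_probs(dataset)  # top-token log-probabilities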
class WMLInferenceEngineParams(Artifact):
decoding_method: Optional[Literal["greedy", "sample"]] = None
length_penalty: Optional[Dict[str, Union[int, float]]] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
random_seed: Optional[int] = None
repetition_penalty: Optional[float] = None
min_new_tokens: Optional[int] = None
max_new_tokens: Optional[int] = None
stop_sequences: Optional[List[str]] = None
time_limit: Optional[int] = None
truncate_input_tokens: Optional[int] = None
prompt_variables: Optional[Dict[str, Any]] = None
return_options: Optional[Dict[str, bool]] = None
    def initialize_wml_parameters(self) -> Dict[str, Any]:
        from ibm_watsonx_ai.metanames import GenTextParamsMetaNames
        # Compare against None rather than truthiness, so that valid falsy
        # values (e.g. temperature=0.0) are not silently dropped.
        return {
            param_name.upper(): param_value
            for param_name, param_value in self.to_dict().items()
            if param_value is not None
            and param_name.upper() in GenTextParamsMetaNames().get()
        }
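# For example (illustrative values),
# WMLInferenceEngineParams(max_new_tokens=32, temperature=0.7)
# .initialize_wml_parameters() yields {"MAX_NEW_TOKENS": 32, "TEMPERATURE": 0.7},
# the upper-case names expected by ibm_watsonx_ai's GenTextParamsMetaNames.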
class WMLInferenceEngine(InferenceEngine, PackageRequirementsMixin):
"""Runs inference using ibm-watsonx-ai.
Attributes:
client: By default, it is created by a class instance but can be directly
provided instead as an instance of 'ibm_watsonx_ai.client.APIClient'.
credentials: By default, it is created by a class instance which tries to retrieve
proper environment variables ("WML_URL", "WML_PROJECT_ID", "WML_APIKEY").
However, either a dictionary with the following keys: "url", "apikey",
"project_id", or an instance of 'ibm_watsonx_ai.credentials.Credentials'
can be directly provided instead.
model_name (str, optional): ID of a model to be used for inference. Mutually
exclusive with 'deployment_id'.
deployment_id (str, optional): Deployment ID of a tuned model to be used for
inference. Mutually exclusive with 'model_name'.
parameters (WMLInferenceEngineParams): An instance of 'WMLInferenceEngineParams'
which defines parameters used for inference. All the parameters are optional.
Examples:
from .api import load_dataset
wml_parameters = WMLInferenceEngineParams(top_p=0.5, random_seed=123)
    wml_credentials = {
        "url": "some_url", "project_id": "some_id", "apikey": "some_key"
    }
model_name = "google/flan-t5-xxl"
wml_inference = WMLInferenceEngine(
credentials=wml_credentials,
parameters=wml_parameters,
model_name=model_name,
)
dataset = load_dataset(
dataset_query="card=cards.argument_topic,template_card_index=0,loader_limit=5"
)
results = wml_inference.infer(dataset["test"])
"""
    client: Any = None
    credentials: Any = None
model_name: Optional[str] = None
deployment_id: Optional[str] = None
parameters: WMLInferenceEngineParams = field(
default_factory=WMLInferenceEngineParams
)
_parameters: Dict[str, Any] = field(default_factory=dict)
label: str = "wml"
    _requirement = {
        "ibm-watsonx-ai": "Install ibm-watsonx-ai package using 'pip install --upgrade ibm-watsonx-ai'. "
        "It is advised to have Python version >=3.10 installed, as with lower versions this package "
        "may cause conflicts with other installed packages."
    }
data_classification_policy = ["proprietary"]
@staticmethod
def _read_wml_credentials_from_env() -> Dict[str, str]:
credentials = {}
for env_var_name in ["WML_URL", "WML_PROJECT_ID", "WML_APIKEY"]:
env_var = os.environ.get(env_var_name)
assert env_var, (
f"Error while trying to run 'WMLInferenceEngine'. "
f"Please set the env variable: '{env_var_name}', or "
f"directly provide an instance of ibm-watsonx-ai 'Credentials' "
f"to the engine."
)
name = env_var_name.lower().replace("wml_", "")
credentials[name] = env_var
return credentials
def _initialize_wml_client(self):
from ibm_watsonx_ai.client import APIClient
if self.credentials is None:
self.credentials = self._read_wml_credentials_from_env()
client = APIClient(credentials=self.credentials)
client.set.default_project(self.credentials["project_id"])
return client
def prepare(self):
if self.client is None:
self.client = self._initialize_wml_client()
self._parameters = self.parameters.initialize_wml_parameters()
    def verify(self):
        assert (self.model_name or self.deployment_id) and not (
            self.model_name and self.deployment_id
        ), "Either 'model_name' or 'deployment_id' must be specified, but not both at the same time."
        super().verify()
super().verify()
def _infer(self, dataset):
from ibm_watsonx_ai.foundation_models import ModelInference
model = ModelInference(
model_id=self.model_name,
deployment_id=self.deployment_id,
api_client=self.client,
)
return [
model.generate_text(
prompt=instance["source"],
params=self._parameters,
)
for instance in dataset
]
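# As an alternative to 'model_name', a tuned deployment can be targeted instead
# (a sketch; the deployment id is a placeholder):
#
#     wml_inference = WMLInferenceEngine(deployment_id="<your-deployment-id>")
#     results = wml_inference.infer(dataset["test"])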