BayesTensor's picture
Upload folder using huggingface_hub
9d5b280 verified
import copy
import json
import os
from functools import lru_cache
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, cast
from tqdm import tqdm
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from lm_eval.models.api_models import JsonChatStr
from lm_eval.utils import eval_logger, simple_parse_args_string
# Result of scoring one request: the summed log-probability of the scored
# tokens and whether every scored token was the model's top-ranked choice.
# Functional NamedTuple form: same type name, field names, order and types.
LogLikelihoodResult = NamedTuple(
    "LogLikelihoodResult",
    [
        ("log_likelihood", float),  # sum of per-token logprobs over the scored span
        ("is_greedy", bool),  # True when all scored tokens had rank 1
    ],
)
def _verify_credentials(creds: Any) -> None:
    """
    Ensure every required Watsonx credential is present and non-empty.

    Args:
        creds (Any): Mapping from credential name (``apikey``, ``url``,
            ``project_id``) to its value.

    Raises:
        ValueError: If one or more required credentials are absent or falsy.
            The message lists the missing keys and the environment variables
            that should be set to provide them.
    """
    # Insertion order fixes the order in which missing entries are reported.
    key_to_env = {
        "apikey": "WATSONX_API_KEY",
        "url": "WATSONX_URL",
        "project_id": "WATSONX_PROJECT_ID",
    }
    absent = []
    for name in key_to_env:
        if name in creds and creds[name]:
            continue
        absent.append(name)
    if not absent:
        return
    env_names = ", ".join(key_to_env[name] for name in absent)
    raise ValueError(
        f"Missing required credentials: {', '.join(absent)}. Please set the following environment variables: {env_names}"
    )
@lru_cache(maxsize=None)
def get_watsonx_credentials() -> Dict[str, str]:
    """
    Retrieves Watsonx API credentials from environment variables.

    Reads ``WATSONX_API_KEY``, ``WATSONX_URL`` and ``WATSONX_PROJECT_ID`` and
    validates them with ``_verify_credentials``.

    The result is memoized by ``lru_cache``. After the first successful call,
    later calls return the *same* dict object and ignore any subsequent changes
    to the environment. Callers should therefore not mutate the returned dict.
    A call that raises is not cached, so it is re-evaluated next time.

    Returns:
        Dict[str, str]: A dictionary containing the credentials necessary for
        authentication, with keys ``apikey``, ``url`` and ``project_id``.

    Raises:
        ValueError: If any of the required environment variables is unset or empty.
    """
    credentials = {
        "apikey": os.getenv("WATSONX_API_KEY"),
        "url": os.getenv("WATSONX_URL"),
        "project_id": os.getenv("WATSONX_PROJECT_ID"),
    }
    _verify_credentials(credentials)
    return credentials
@register_model("watsonx_llm")
class WatsonxLLM(LM):
    """
    Implementation of LM model interface for evaluating Watsonx model with the lm_eval framework.
    See https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/model_guide.md for reference.
    """

    @classmethod
    def create_from_arg_string(
        cls: Type["WatsonxLLM"],
        arg_string: str,
        additional_config: Optional[Dict] = None,
    ) -> "WatsonxLLM":
        """
        Allow the user to specify model parameters (TextGenerationParameters) in CLI arguments.

        Args:
            arg_string: ``key=value`` pairs parsed by ``simple_parse_args_string``.
                ``model_id`` is required.
            additional_config: Optional extra settings merged over the parsed
                arguments. ``None`` is treated as an empty mapping.

        Returns:
            WatsonxLLM: A model built with credentials from ``get_watsonx_credentials``.

        Raises:
            ImportError: If ``ibm_watsonx_ai`` is not installed.
            ValueError: If ``model_id`` is missing or the credentials are incomplete.
        """
        try:
            from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
        except ImportError as err:
            raise ImportError(
                "Could not import ibm_watsonx_ai: Please install lm_eval[ibm_watsonx_ai] package."
            ) from err

        args = simple_parse_args_string(arg_string)
        # additional_config defaults to None, and dict.update(None) raises TypeError.
        args.update(additional_config or {})

        model_id = args.pop("model_id", None)
        if model_id is None:
            raise ValueError("'model_id' is required, please pass it in 'model_args'")

        do_sample = args.get("do_sample", None)
        if not do_sample:
            # Greedy decoding: blank out the sampling-only settings so they are dropped below.
            for sampling_key in ("temperature", "top_p", "top_k", "seed"):
                args[sampling_key] = None

        generate_params = {
            GenParams.DECODING_METHOD: "sample" if do_sample else "greedy",
            GenParams.LENGTH_PENALTY: args.get("length_penalty", None),
            GenParams.TEMPERATURE: args.get("temperature", None),
            GenParams.TOP_P: args.get("top_p", None),
            GenParams.TOP_K: args.get("top_k", None),
            GenParams.RANDOM_SEED: args.get("seed", None),
            GenParams.REPETITION_PENALTY: args.get("repetition_penalty", None),
            GenParams.MIN_NEW_TOKENS: args.get("min_new_tokens", None),
            GenParams.MAX_NEW_TOKENS: args.get("max_new_tokens", 256),
            GenParams.STOP_SEQUENCES: args.get("stop_sequences", None),
            GenParams.TIME_LIMIT: args.get("time_limit", None),
            GenParams.TRUNCATE_INPUT_TOKENS: args.get("truncate_input_tokens", None),
            # Token-level details are needed by the loglikelihood methods.
            GenParams.RETURN_OPTIONS: {
                "generated_tokens": True,
                "input_tokens": True,
                "token_logprobs": True,
                "token_ranks": True,
            },
        }
        # Send only the parameters that were actually set.
        generate_params = {k: v for k, v in generate_params.items() if v is not None}

        return cls(
            watsonx_credentials=get_watsonx_credentials(),
            model_id=model_id,
            generate_params=generate_params,
        )

    def __init__(
        self,
        watsonx_credentials: Dict,
        model_id,
        generate_params: Optional[Dict[Any, Any]] = None,
    ) -> None:
        """
        Args:
            watsonx_credentials: Passed to ``APIClient``. ``project_id`` and the
                optional ``deployment_id`` are also read from it.
            model_id: Watsonx foundation model identifier.
            generate_params: Generation parameters sent with every request (may be None).

        Raises:
            ImportError: If ``ibm_watsonx_ai`` is not installed.
        """
        try:
            from ibm_watsonx_ai import APIClient
            from ibm_watsonx_ai.foundation_models import ModelInference
        except ImportError as err:
            raise ImportError(
                "Could not import ibm_watsonx_ai: Please install lm_eval[ibm_watsonx_ai] package."
            ) from err

        super().__init__()
        client = APIClient(watsonx_credentials)
        project_id = watsonx_credentials.get("project_id", None)
        deployment_id = watsonx_credentials.get("deployment_id", None)
        client.set.default_project(project_id)
        self.generate_params = generate_params
        self.model = ModelInference(
            model_id=model_id,
            deployment_id=deployment_id,
            api_client=client,
            project_id=project_id,
        )
        self._model_id = model_id

    @staticmethod
    def _has_stop_token(response_tokens: List[str], context_tokens: List[str]) -> bool:
        """
        Decide whether the standalone context tokenization ends with a token
        that has no counterpart in the combined (context + continuation) tokens.

        All context tokens except the last must match the start of
        ``response_tokens``. The last context token is then compared with the
        token at the *same* position in ``response_tokens``. A difference
        indicates an extra trailing token in the context tokenization (e.g. a
        stop sequence such as ``</s>``) or a token merged across the
        context/continuation boundary.

        Args:
            response_tokens (List[str]): Token texts of context + continuation as seen by the model.
            context_tokens (List[str]): Token texts of the context alone.

        Returns:
            bool: True if the final context token has no matching counterpart.
            False if it matches, or if ``context_tokens`` is empty.

        Raises:
            RuntimeError: If the leading context tokens do not line up with ``response_tokens``.
        """
        if not context_tokens:
            # Nothing to align (e.g. whole-text scoring), so nothing to strip.
            return False
        context_length = len(context_tokens)
        if response_tokens[: context_length - 1] == context_tokens[:-1]:
            if len(response_tokens) < context_length:
                # The last context token has no counterpart at all (e.g. empty continuation).
                return True
            # Only the last context token may differ; compare it position-for-position.
            return response_tokens[context_length - 1] != context_tokens[-1]
        raise RuntimeError(
            f"There is an unexpected difference between tokenizer and model tokens:\n"
            f"context_tokens={context_tokens}\n"
            f"response_tokens={response_tokens[:context_length]}"
        )

    def _single_token_params(self) -> Dict[Any, Any]:
        """
        Return a shallow copy of ``generate_params`` with ``MAX_NEW_TOKENS`` set to 1.

        Scoring reads only the ``input_tokens`` of the raw response, so at most
        one new token is requested. ``self.generate_params`` is never mutated.

        Raises:
            ImportError: If ``ibm_watsonx_ai`` is not installed.
        """
        try:
            from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
        except ImportError as err:
            raise ImportError(
                "Could not import ibm_watsonx_ai: Please install lm_eval[ibm_watsonx_ai] package."
            ) from err
        params = dict(self.generate_params or {})
        params[GenParams.MAX_NEW_TOKENS] = 1
        return params

    def _check_model_logprobs_support(self):
        """
        Verifies if the model supports returning log probabilities for input tokens.

        This function sends a probe prompt to the model and checks whether the
        model's response includes log probabilities for the input tokens.

        Raises:
            RuntimeError: If no input token in the response carries a ``logprob``.
        """
        tokens = self.model.generate_text(
            prompt=["The best ice cream flavor is:"],
            params=self._single_token_params(),
            raw_response=True,
        )[0]["results"][0]
        if all(token.get("logprob", None) is None for token in tokens["input_tokens"]):
            raise RuntimeError(
                f"Model {self._model_id} is not supported: does not return logprobs for input tokens"
            )

    @staticmethod
    def _sum_logprobs(tokens: List[Dict[str, Any]]) -> float:
        """
        Sum the ``logprob`` field of ``tokens``.

        A missing or ``None`` logprob counts as 0. Presumably this is the first
        input token, which has no preceding context; verify against the API.
        """
        return sum((token.get("logprob") or 0.0) for token in tokens)

    def _get_log_likelihood(
        self,
        input_tokens: List[Dict[str, Any]],
        context_tokens: List[str],
    ) -> LogLikelihoodResult:
        """
        Calculates the log likelihood of the continuation part of ``input_tokens``.

        Args:
            input_tokens (List[Dict[str, Any]]): Token dictionaries for context +
                continuation, each containing ``text``, ``logprob`` and ``rank``.
            context_tokens (List[str]): Token texts of the context alone.

        Returns:
            LogLikelihoodResult: The summed logprob of the continuation tokens and
            whether all of them had rank 1.
        """
        response_tokens = [token["text"] for token in input_tokens]
        context_length = len(context_tokens)
        if self._has_stop_token(response_tokens, context_tokens):
            # The context's last token is not part of the combined input; do not skip it.
            context_length -= 1
        continuation_tokens = input_tokens[context_length:]
        return LogLikelihoodResult(
            log_likelihood=self._sum_logprobs(continuation_tokens),
            is_greedy=all(token["rank"] == 1 for token in continuation_tokens),
        )

    def generate_until(self, requests: List[Instance]) -> List[str]:
        """
        Generates text responses for a List of requests, with progress tracking and caching.

        Args:
            requests (List[Instance]): A List of instances whose ``args`` are
                ``(context, <second arg>)``. The context is either plain text or a
                ``JsonChatStr`` holding chat messages.

        Returns:
            List[str]: A List of generated responses, in request order.
        """
        results: List[str] = []
        for context, gen_kwargs in tqdm(
            [request.args for request in requests],
            desc="Running generate_until function ...",
        ):
            # NOTE(review): the second arg is presumably the harness gen_kwargs
            # (e.g. `until`); it is only used as part of the cache key here.
            try:
                if isinstance(context, JsonChatStr):
                    # Parse into a separate name so `context` remains the exact
                    # request argument used as the cache key below.
                    messages = json.loads(context.prompt)
                    response = self.model.chat(messages, self.generate_params)
                    response = response["choices"][0]["message"]["content"]
                else:
                    response = self.model.generate_text(context, self.generate_params)
            except Exception:
                eval_logger.error("Error while generating text.")
                raise
            results.append(response)
            self.cache_hook.add_partial(
                "generate_until", (context, gen_kwargs), response
            )
        return results

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        """
        Args:
            requests: Each request contains Instance.args : Tuple[str, str] containing:
                1. an input string to the LM and
                2. a target string on which the loglikelihood of the LM producing this target,
                   conditioned on the input, will be returned.
        Returns:
            Tuple (loglikelihood, is_greedy) for each request according to the input order:
                loglikelihood: probability of generating the target string conditioned on the input
                is_greedy: True if and only if the target string would be generated by greedy sampling from the LM
        """
        generate_params = self._single_token_params()
        self._check_model_logprobs_support()

        results: List[LogLikelihoodResult] = []
        # Note: We're not using batching due to (current) indeterminism of loglikelihood values when sending batch of requests
        for context, continuation in tqdm(
            [request.args for request in requests],
            desc="Running loglikelihood function ...",
        ):
            try:
                tokenized_context = self.model.tokenize(
                    prompt=context, return_tokens=True
                )["result"]["tokens"]
            except Exception:
                eval_logger.error("Error while model tokenize.")
                raise

            try:
                response = self.model.generate_text(
                    prompt=context + continuation,
                    params=generate_params,
                    raw_response=True,
                )
            except Exception:
                eval_logger.error("Error while model generate text.")
                raise

            log_likelihood_response = self._get_log_likelihood(
                response["results"][0]["input_tokens"], tokenized_context
            )
            results.append(log_likelihood_response)
            self.cache_hook.add_partial(
                "loglikelihood",
                (context, continuation),
                (
                    log_likelihood_response.log_likelihood,
                    log_likelihood_response.is_greedy,
                ),
            )
        return cast(List[Tuple[float, bool]], results)

    def loglikelihood_rolling(self, requests) -> List[float]:
        """
        Used to evaluate perplexity on a data distribution.

        Args:
            requests: Each request contains Instance.args : Tuple[str] containing an input string
                whose entire loglikelihood will be calculated.
        Returns:
            One float per request, in input order: the summed logprob of all input
            tokens. Tokens without a logprob contribute 0, which presumably applies
            to the first token; verify against the API.
        """
        generate_params = self._single_token_params()
        self._check_model_logprobs_support()

        results: List[float] = []
        # Note: We're not using batching due to (current) indeterminism of loglikelihood values when sending batch of requests
        for request in tqdm(
            [request.args for request in requests],
            desc="Running loglikelihood_rolling function ...",
        ):
            # Rolling requests carry a single string, not a (context, continuation) pair.
            context = request[0]
            try:
                response = self.model.generate_text(
                    prompt=context, params=generate_params, raw_response=True
                )
            except Exception:
                eval_logger.error("Error while model generate text.")
                raise

            log_likelihood = self._sum_logprobs(response["results"][0]["input_tokens"])
            results.append(log_likelihood)
            # Key matches the request args `(text,)` so cached lookups can hit.
            self.cache_hook.add_partial(
                "loglikelihood_rolling",
                (context,),
                log_likelihood,
            )
        return results

    @property
    def tokenizer_name(self) -> str:
        """No local tokenizer is used, so an empty name is reported."""
        return ""

    def apply_chat_template(
        self, chat_history: List[Dict[str, str]]
    ) -> JsonChatStr:
        """
        Serialize ``chat_history`` to JSON wrapped in ``JsonChatStr``.

        ``generate_until`` recognizes this wrapper and sends the messages through
        the chat API. The string form also lets the request be hashed for caching.
        """
        # A hack similar from api_model to allow encoding for cache
        return JsonChatStr(json.dumps(chat_history))