"""Wrapper around Aleph Alpha APIs.""" | |
from typing import Any, Dict, List, Optional, Sequence | |
from pydantic import BaseModel, Extra, root_validator | |
from langchain.llms.base import LLM | |
from langchain.llms.utils import enforce_stop_tokens | |
from langchain.utils import get_from_dict_or_env | |
class AlephAlpha(LLM, BaseModel):
    """Wrapper around Aleph Alpha large language models.

    To use, you should have the ``aleph_alpha_client`` python package installed,
    and the environment variable ``ALEPH_ALPHA_API_KEY`` set with your API key,
    or pass it as a named parameter to the constructor.

    Parameters are explained in more depth here:
    https://github.com/Aleph-Alpha/aleph-alpha-client/blob/c14b7dd2b4325c7da0d6a119f6e76385800e097b/aleph_alpha_client/completion.py#L10

    Example:
        .. code-block:: python

            from langchain.llms import AlephAlpha

            aleph_alpha = AlephAlpha(aleph_alpha_api_key="my-api-key")
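            # A minimal usage sketch (illustrative prompt): calling the
            # instance returns the generated completion as a string.
            response = aleph_alpha("Tell me a joke.")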
""" | |
    client: Any  #: :meta private:
    model: Optional[str] = "luminous-base"
    """Model name to use."""
    maximum_tokens: int = 64
    """The maximum number of tokens to be generated."""
    temperature: float = 0.0
    """A non-negative float that tunes the degree of randomness in generation."""
    top_k: int = 0
    """Number of most likely tokens to consider at each step."""
    top_p: float = 0.0
    """Total probability mass of tokens to consider at each step."""
    presence_penalty: float = 0.0
    """Penalizes repeated tokens."""
    frequency_penalty: float = 0.0
    """Penalizes repeated tokens according to frequency."""
    repetition_penalties_include_prompt: Optional[bool] = False
    """Flag deciding whether presence penalty or frequency penalty are
    updated from the prompt."""
    use_multiplicative_presence_penalty: Optional[bool] = False
    """Flag deciding whether presence penalty is applied
    multiplicatively (True) or additively (False)."""
    penalty_bias: Optional[str] = None
    """Penalty bias for the completion."""
    penalty_exceptions: Optional[List[str]] = None
    """List of strings that may be generated without penalty,
    regardless of other penalty settings."""
    penalty_exceptions_include_stop_sequences: Optional[bool] = None
    """Should stop_sequences be included in penalty_exceptions."""
    best_of: Optional[int] = None
    """Generates best_of completions server-side and returns the one
    with the highest log probability per token."""
    n: int = 1
    """How many completions to generate for each prompt."""
    logit_bias: Optional[Dict[int, float]] = None
    """The logit bias allows you to influence the likelihood of generating tokens."""
    log_probs: Optional[int] = None
    """Number of top log probabilities to be returned for each generated token."""
    tokens: Optional[bool] = False
    """Return tokens of completion."""
    disable_optimizations: Optional[bool] = False
    """Disable the optimizations the API applies to the prompt by default."""
    minimum_tokens: Optional[int] = 0
    """Generate at least this number of tokens."""
    echo: bool = False
    """Echo the prompt in the completion."""
    use_multiplicative_frequency_penalty: bool = False
    """Flag deciding whether frequency penalty is applied
    multiplicatively (True) or additively (False)."""
    sequence_penalty: float = 0.0
    """Penalizes repeated token sequences that already appear in the
    prompt or the completion."""
    sequence_penalty_min_length: int = 2
    """Minimal number of tokens to be considered as a sequence."""
    use_multiplicative_sequence_penalty: bool = False
    """Flag deciding whether sequence penalty is applied
    multiplicatively (True) or additively (False)."""
    completion_bias_inclusion: Optional[Sequence[str]] = None
    """Bias the completion to only generate options within this list."""
    completion_bias_inclusion_first_token_only: bool = False
    """Only consider the first token for the completion_bias_inclusion."""
    completion_bias_exclusion: Optional[Sequence[str]] = None
    """Bias the completion to not generate options within this list."""
    completion_bias_exclusion_first_token_only: bool = False
    """Only consider the first token for the completion_bias_exclusion."""
    contextual_control_threshold: Optional[float] = None
    """If set to None, attention control parameters only apply to those tokens
    that have explicitly been set in the request.
    If set to a non-None value, control parameters are also applied to
    similar tokens."""
    control_log_additive: Optional[bool] = True
    """True: apply control by adding the log(control_factor) to attention scores.
    False: apply control by (attention_scores - attention_scores.min(-1)) * control_factor
    """
    repetition_penalties_include_completion: bool = True
    """Flag deciding whether presence penalty or frequency penalty
    are updated from the completion."""
    raw_completion: bool = False
    """Force the raw completion of the model to be returned."""
    aleph_alpha_api_key: Optional[str] = None
    """API key for Aleph Alpha API."""
    stop_sequences: Optional[List[str]] = None
    """Stop sequences to use."""
    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key and python package exist in the environment."""
        aleph_alpha_api_key = get_from_dict_or_env(
            values, "aleph_alpha_api_key", "ALEPH_ALPHA_API_KEY"
        )
        try:
            import aleph_alpha_client

            values["client"] = aleph_alpha_client.Client(token=aleph_alpha_api_key)
        except ImportError:
            raise ValueError(
                "Could not import aleph_alpha_client python package. "
                "Please install it with `pip install aleph_alpha_client`."
            )
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the Aleph Alpha API."""
        return {
            "maximum_tokens": self.maximum_tokens,
            "temperature": self.temperature,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "presence_penalty": self.presence_penalty,
            "frequency_penalty": self.frequency_penalty,
            "n": self.n,
            "repetition_penalties_include_prompt": self.repetition_penalties_include_prompt,  # noqa: E501
            "use_multiplicative_presence_penalty": self.use_multiplicative_presence_penalty,  # noqa: E501
            "penalty_bias": self.penalty_bias,
            "penalty_exceptions": self.penalty_exceptions,
            "penalty_exceptions_include_stop_sequences": self.penalty_exceptions_include_stop_sequences,  # noqa: E501
            "best_of": self.best_of,
            "logit_bias": self.logit_bias,
            "log_probs": self.log_probs,
            "tokens": self.tokens,
            "disable_optimizations": self.disable_optimizations,
            "minimum_tokens": self.minimum_tokens,
            "echo": self.echo,
            "use_multiplicative_frequency_penalty": self.use_multiplicative_frequency_penalty,  # noqa: E501
            "sequence_penalty": self.sequence_penalty,
            "sequence_penalty_min_length": self.sequence_penalty_min_length,
            "use_multiplicative_sequence_penalty": self.use_multiplicative_sequence_penalty,  # noqa: E501
            "completion_bias_inclusion": self.completion_bias_inclusion,
            "completion_bias_inclusion_first_token_only": self.completion_bias_inclusion_first_token_only,  # noqa: E501
            "completion_bias_exclusion": self.completion_bias_exclusion,
            "completion_bias_exclusion_first_token_only": self.completion_bias_exclusion_first_token_only,  # noqa: E501
            "contextual_control_threshold": self.contextual_control_threshold,
            "control_log_additive": self.control_log_additive,
            "repetition_penalties_include_completion": self.repetition_penalties_include_completion,  # noqa: E501
            "raw_completion": self.raw_completion,
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {**{"model": self.model}, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "aleph_alpha"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Call out to Aleph Alpha's completion endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = aleph_alpha("Tell me a joke.")
        """
        from aleph_alpha_client import CompletionRequest, Prompt

        params = self._default_params
        # The class-level stop_sequences and the per-call stop argument are
        # mutually exclusive; whichever is set is forwarded to the API.
        if self.stop_sequences is not None and stop is not None:
            raise ValueError(
                "stop sequences found in both the input and default params."
            )
        elif self.stop_sequences is not None:
            params["stop_sequences"] = self.stop_sequences
        else:
            params["stop_sequences"] = stop
        request = CompletionRequest(prompt=Prompt.from_text(prompt), **params)
        response = self.client.complete(model=self.model, request=request)
        text = response.completions[0].completion
        # If stop tokens are provided, Aleph Alpha's endpoint returns them.
        # In order to make this consistent with other endpoints, we strip them.
        if stop is not None or self.stop_sequences is not None:
            text = enforce_stop_tokens(text, params["stop_sequences"])
        return text
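

if __name__ == "__main__":
    # Minimal usage sketch, not part of the wrapper itself. It assumes the
    # ``aleph_alpha_client`` package is installed and a valid key is set in
    # the ALEPH_ALPHA_API_KEY environment variable; the prompt and parameter
    # values below are illustrative.
    llm = AlephAlpha(maximum_tokens=32, stop_sequences=["\n"])
    # Calling the LLM sends a CompletionRequest for the default
    # "luminous-base" model and returns the completion text.
    print(llm("Q: What is the capital of France?\nA:"))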