# coding=utf-8
# Copyright 2023-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Original implementation taken from the `text-generation` Python client (see https://pypi.org/project/text-generation/
# and https://github.com/huggingface/text-generation-inference/tree/main/clients/python)
#
# Changes compared to original implementation:
# - use pydantic.dataclasses instead of BaseModel
# - default to Python's dataclasses if Pydantic is not installed (same implementation but no validation)
# - added default values for all parameters (not required by BaseModel, but required by dataclasses)
# - integrated into `huggingface_hub.InferenceClient`
# - added `stream: bool` and `details: bool` in the `text_generation` method instead of having different methods for each use case
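#
# Illustrative usage of that integration (a sketch, not executed here; assumes a reachable
# text-generation backend for `InferenceClient`):
#
#     from huggingface_hub import InferenceClient
#
#     client = InferenceClient()  # or InferenceClient(model="<your-endpoint-url>")
#     # Returns a plain string by default, a `TextGenerationResponse` when `details=True`
#     response = client.text_generation("The capital of France is", max_new_tokens=5, details=True)
#     print(response.generated_text, response.details.finish_reason)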
import warnings
from dataclasses import field
from enum import Enum
from typing import List, NoReturn, Optional

from requests import HTTPError

from ..utils import is_pydantic_available


if is_pydantic_available():
    from pydantic import validator as pydantic_validator
    from pydantic.dataclasses import dataclass

    def validator(*args, **kwargs):
        # Pydantic v1's `@validator` is deprecated in favor of `@field_validator`. In order to support both pydantic v1
        # and v2 without changing the logic, we catch the warning message in pydantic v2 and ignore it. If we want to
        # support pydantic v3 in the future, we will drop support for pydantic v1 and use `pydantic.field_validator`
        # correctly.
        #
        # Related:
        # - https://docs.pydantic.dev/latest/migration/#changes-to-validators
        # - https://github.com/huggingface/huggingface_hub/pull/1837
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message="Pydantic V1 style `@validator` validators are deprecated.")
            return pydantic_validator(*args, **kwargs)

else:
    # No validation if Pydantic is not installed
    from dataclasses import dataclass  # type: ignore

    def validator(x):  # type: ignore
        return lambda y: y
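
# Example: without Pydantic, the shim above makes `@validator(...)` a no-op. Applied to a
# method, it returns the method unchanged, and nothing ever calls it, so no validation runs:
#
#     @validator("seed")       # `validator("seed")` returns the identity function
#     def valid_seed(cls, v):  # left as-is; only invoked when Pydantic is installed
#         ...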


@dataclass
class TextGenerationParameters:
    """
    Parameters for text generation.

    Args:
        do_sample (`bool`, *optional*):
            Activate logits sampling. Defaults to False.
        max_new_tokens (`int`, *optional*):
            Maximum number of generated tokens. Defaults to 20.
        repetition_penalty (`Optional[float]`, *optional*):
            The parameter for repetition penalty. A value of 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf)
            for more details. Defaults to None.
        return_full_text (`bool`, *optional*):
            Whether to prepend the prompt to the generated text. Defaults to False.
        stop (`List[str]`, *optional*):
            Stop generating tokens if a member of `stop_sequences` is generated. Defaults to an empty list.
        seed (`Optional[int]`, *optional*):
            Random sampling seed. Defaults to None.
        temperature (`Optional[float]`, *optional*):
            The value used to modulate the logits distribution. Defaults to None.
        top_k (`Optional[int]`, *optional*):
            The number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None.
        top_p (`Optional[float]`, *optional*):
            If set to a value less than 1, only the smallest set of most probable tokens with probabilities that add up
            to `top_p` or higher are kept for generation. Defaults to None.
        truncate (`Optional[int]`, *optional*):
            Truncate input tokens to the given size. Defaults to None.
        typical_p (`Optional[float]`, *optional*):
            Typical Decoding mass. See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666)
            for more information. Defaults to None.
        best_of (`Optional[int]`, *optional*):
            Generate `best_of` sequences and return the one with the highest token logprobs. Defaults to None.
        watermark (`bool`, *optional*):
            Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226). Defaults to False.
        details (`bool`, *optional*):
            Get generation details. Defaults to False.
        decoder_input_details (`bool`, *optional*):
            Get decoder input token logprobs and ids. Defaults to False.
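
    Example (an illustrative sketch; the value checks below only run when Pydantic is installed):

    ```python
    >>> parameters = TextGenerationParameters(max_new_tokens=50, temperature=0.7, do_sample=True)
    >>> parameters.max_new_tokens
    50

    >>> # With Pydantic installed, out-of-range values raise at instantiation time, e.g.
    >>> # `temperature=0.0` trips `valid_temp` ("`temperature` must be strictly positive").
    ```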
""" | |
    # Activate logits sampling
    do_sample: bool = False
    # Maximum number of generated tokens
    max_new_tokens: int = 20
    # The parameter for repetition penalty. 1.0 means no penalty.
    # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
    repetition_penalty: Optional[float] = None
    # Whether to prepend the prompt to the generated text
    return_full_text: bool = False
    # Stop generating tokens if a member of `stop_sequences` is generated
    stop: List[str] = field(default_factory=lambda: [])
    # Random sampling seed
    seed: Optional[int] = None
    # The value used to modulate the logits distribution.
    temperature: Optional[float] = None
    # The number of highest probability vocabulary tokens to keep for top-k-filtering.
    top_k: Optional[int] = None
    # If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
    # higher are kept for generation.
    top_p: Optional[float] = None
    # Truncate input tokens to the given size
    truncate: Optional[int] = None
    # Typical Decoding mass
    # See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
    typical_p: Optional[float] = None
    # Generate `best_of` sequences and return the one with the highest token logprobs
    best_of: Optional[int] = None
    # Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
    watermark: bool = False
    # Get generation details
    details: bool = False
    # Get decoder input token logprobs and ids
    decoder_input_details: bool = False

    @validator("best_of")
    def valid_best_of(cls, field_value, values):
        if field_value is not None:
            if field_value <= 0:
                raise ValueError("`best_of` must be strictly positive")
            if field_value > 1 and values["seed"] is not None:
                raise ValueError("`seed` must not be set when `best_of` is > 1")
            sampling = (
                values["do_sample"]
                | (values["temperature"] is not None)
                | (values["top_k"] is not None)
                | (values["top_p"] is not None)
                | (values["typical_p"] is not None)
            )
            if field_value > 1 and not sampling:
                raise ValueError("you must use sampling when `best_of` is > 1")
        return field_value

    @validator("repetition_penalty")
    def valid_repetition_penalty(cls, v):
        if v is not None and v <= 0:
            raise ValueError("`repetition_penalty` must be strictly positive")
        return v

    @validator("seed")
    def valid_seed(cls, v):
        if v is not None and v < 0:
            raise ValueError("`seed` must be non-negative")
        return v

    @validator("temperature")
    def valid_temp(cls, v):
        if v is not None and v <= 0:
            raise ValueError("`temperature` must be strictly positive")
        return v

    @validator("top_k")
    def valid_top_k(cls, v):
        if v is not None and v <= 0:
            raise ValueError("`top_k` must be strictly positive")
        return v

    @validator("top_p")
    def valid_top_p(cls, v):
        if v is not None and (v <= 0 or v >= 1.0):
            raise ValueError("`top_p` must be > 0.0 and < 1.0")
        return v

    @validator("truncate")
    def valid_truncate(cls, v):
        if v is not None and v <= 0:
            raise ValueError("`truncate` must be strictly positive")
        return v

    @validator("typical_p")
    def valid_typical_p(cls, v):
        if v is not None and (v <= 0 or v >= 1.0):
            raise ValueError("`typical_p` must be > 0.0 and < 1.0")
        return v


@dataclass
class TextGenerationRequest:
    """
    Request object for text generation (only for internal use).

    Args:
        inputs (`str`):
            The prompt for text generation.
        parameters (`Optional[TextGenerationParameters]`, *optional*):
            Generation parameters.
        stream (`bool`, *optional*):
            Whether to stream output tokens. Defaults to False.
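
    Example (a sketch; a plain dict passed as `parameters` is coerced to
    `TextGenerationParameters`, either by Pydantic or by `__post_init__` below):

    ```python
    >>> request = TextGenerationRequest(inputs="Hello", parameters={"max_new_tokens": 5})
    >>> request.parameters.max_new_tokens
    5
    ```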
""" | |
    # Prompt
    inputs: str
    # Generation parameters
    parameters: Optional[TextGenerationParameters] = None
    # Whether to stream output tokens
    stream: bool = False

    @validator("inputs")
    def valid_input(cls, v):
        if not v:
            raise ValueError("`inputs` cannot be empty")
        return v

    @validator("stream")
    def valid_best_of_stream(cls, field_value, values):
        parameters = values["parameters"]
        if parameters is not None and parameters.best_of is not None and parameters.best_of > 1 and field_value:
            raise ValueError("`best_of` != 1 is not supported when `stream` == True")
        return field_value

    def __post_init__(self):
        if not is_pydantic_available():
            # If pydantic is not installed, we need to instantiate the nested dataclasses manually
            if self.parameters is not None and isinstance(self.parameters, dict):
                self.parameters = TextGenerationParameters(**self.parameters)


# Decoder input tokens
@dataclass
class InputToken:
    """
    Represents an input token.

    Args:
        id (`int`):
            Token ID from the model tokenizer.
        text (`str`):
            Token text.
        logprob (`float` or `None`):
            Log probability of the token. Optional since the logprob of the first token cannot be computed.
    """

    # Token ID from the model tokenizer
    id: int
    # Token text
    text: str
    # Logprob
    # Optional since the logprob of the first token cannot be computed
    logprob: Optional[float] = None


# Generated tokens
@dataclass
class Token:
    """
    Represents a token.

    Args:
        id (`int`):
            Token ID from the model tokenizer.
        text (`str`):
            Token text.
        logprob (`float`):
            Log probability of the token.
        special (`bool`):
            Indicates whether the token is a special token. It can be used to ignore
            tokens when concatenating.
    """

    # Token ID from the model tokenizer
    id: int
    # Token text
    text: str
    # Logprob
    logprob: float
    # Is the token a special token
    # Can be used to ignore tokens when concatenating
    special: bool


# Generation finish reason
class FinishReason(str, Enum):
    # number of generated tokens == `max_new_tokens`
    Length = "length"
    # the model generated its end of sequence token
    EndOfSequenceToken = "eos_token"
    # the model generated a text included in `stop_sequences`
    StopSequence = "stop_sequence"


# Additional sequences when using the `best_of` parameter
@dataclass
class BestOfSequence:
    """
    Represents a best-of sequence generated during text generation.

    Args:
        generated_text (`str`):
            The generated text.
        finish_reason (`FinishReason`):
            The reason for the generation to finish, represented by a `FinishReason` value.
        generated_tokens (`int`):
            The number of generated tokens in the sequence.
        seed (`Optional[int]`):
            The sampling seed if sampling was activated.
        prefill (`List[InputToken]`):
            The decoder input tokens. Empty if `decoder_input_details` is False. Defaults to an empty list.
        tokens (`List[Token]`):
            The generated tokens. Defaults to an empty list.
    """

    # Generated text
    generated_text: str
    # Generation finish reason
    finish_reason: FinishReason
    # Number of generated tokens
    generated_tokens: int
    # Sampling seed if sampling was activated
    seed: Optional[int] = None
    # Decoder input tokens, empty if decoder_input_details is False
    prefill: List[InputToken] = field(default_factory=lambda: [])
    # Generated tokens
    tokens: List[Token] = field(default_factory=lambda: [])

    def __post_init__(self):
        if not is_pydantic_available():
            # If pydantic is not installed, we need to instantiate the nested dataclasses manually
            self.prefill = [
                InputToken(**input_token) if isinstance(input_token, dict) else input_token
                for input_token in self.prefill
            ]
            self.tokens = [Token(**token) if isinstance(token, dict) else token for token in self.tokens]


# `generate` details
@dataclass
class Details:
    """
    Represents details of a text generation.

    Args:
        finish_reason (`FinishReason`):
            The reason for the generation to finish, represented by a `FinishReason` value.
        generated_tokens (`int`):
            The number of generated tokens.
        seed (`Optional[int]`):
            The sampling seed if sampling was activated.
        prefill (`List[InputToken]`, *optional*):
            The decoder input tokens. Empty if `decoder_input_details` is False. Defaults to an empty list.
        tokens (`List[Token]`):
            The generated tokens. Defaults to an empty list.
        best_of_sequences (`Optional[List[BestOfSequence]]`):
            Additional sequences when using the `best_of` parameter.
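
    Example (a sketch of parsing a server payload; the field values are made up):

    ```python
    >>> details = Details(
    ...     finish_reason="eos_token",
    ...     generated_tokens=1,
    ...     tokens=[{"id": 5, "text": " Paris", "logprob": -0.1, "special": False}],
    ... )
    >>> [token.text for token in details.tokens]
    [' Paris']
    ```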
""" | |
    # Generation finish reason
    finish_reason: FinishReason
    # Number of generated tokens
    generated_tokens: int
    # Sampling seed if sampling was activated
    seed: Optional[int] = None
    # Decoder input tokens, empty if decoder_input_details is False
    prefill: List[InputToken] = field(default_factory=lambda: [])
    # Generated tokens
    tokens: List[Token] = field(default_factory=lambda: [])
    # Additional sequences when using the `best_of` parameter
    best_of_sequences: Optional[List[BestOfSequence]] = None

    def __post_init__(self):
        if not is_pydantic_available():
            # If pydantic is not installed, we need to instantiate the nested dataclasses manually
            self.prefill = [
                InputToken(**input_token) if isinstance(input_token, dict) else input_token
                for input_token in self.prefill
            ]
            self.tokens = [Token(**token) if isinstance(token, dict) else token for token in self.tokens]
            if self.best_of_sequences is not None:
                self.best_of_sequences = [
                    BestOfSequence(**best_of_sequence) if isinstance(best_of_sequence, dict) else best_of_sequence
                    for best_of_sequence in self.best_of_sequences
                ]


# `generate` return value
@dataclass
class TextGenerationResponse:
    """
    Represents a response for text generation.

    Only returned when `details=True`, otherwise a string is returned.

    Args:
        generated_text (`str`):
            The generated text.
        details (`Optional[Details]`):
            Generation details. Returned only if `details=True` is sent to the server.
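
    Example (a sketch; a nested `details` dict is coerced to `Details` on init):

    ```python
    >>> response = TextGenerationResponse(
    ...     generated_text=" Paris",
    ...     details={"finish_reason": "eos_token", "generated_tokens": 1},
    ... )
    >>> response.generated_text
    ' Paris'
    ```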
""" | |
    # Generated text
    generated_text: str
    # Generation details
    details: Optional[Details] = None

    def __post_init__(self):
        if not is_pydantic_available():
            # If pydantic is not installed, we need to instantiate the nested dataclasses manually
            if self.details is not None and isinstance(self.details, dict):
                self.details = Details(**self.details)


# `generate_stream` details
@dataclass
class StreamDetails:
    """
    Represents details of a text generation stream.

    Args:
        finish_reason (`FinishReason`):
            The reason for the generation to finish, represented by a `FinishReason` value.
        generated_tokens (`int`):
            The number of generated tokens.
        seed (`Optional[int]`):
            The sampling seed if sampling was activated.
    """

    # Generation finish reason
    finish_reason: FinishReason
    # Number of generated tokens
    generated_tokens: int
    # Sampling seed if sampling was activated
    seed: Optional[int] = None


# `generate_stream` return value
@dataclass
class TextGenerationStreamResponse:
    """
    Represents a response for streaming text generation.

    Only returned when `details=True` and `stream=True`.

    Args:
        token (`Token`):
            The generated token.
        generated_text (`Optional[str]`, *optional*):
            The complete generated text. Only available when the generation is finished.
        details (`Optional[StreamDetails]`, *optional*):
            Generation details. Only available when the generation is finished.
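
    Example (a sketch of a single streamed chunk; values are made up, and
    `generated_text`/`details` stay None until the final chunk):

    ```python
    >>> chunk = TextGenerationStreamResponse(token={"id": 42, "text": " Paris", "logprob": -0.1, "special": False})
    >>> chunk.token.text
    ' Paris'
    >>> chunk.generated_text is None
    True
    ```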
""" | |
    # Generated token
    token: Token
    # Complete generated text
    # Only available when the generation is finished
    generated_text: Optional[str] = None
    # Generation details
    # Only available when the generation is finished
    details: Optional[StreamDetails] = None

    def __post_init__(self):
        if not is_pydantic_available():
            # If pydantic is not installed, we need to instantiate the nested dataclasses manually
            if isinstance(self.token, dict):
                self.token = Token(**self.token)
            if self.details is not None and isinstance(self.details, dict):
                self.details = StreamDetails(**self.details)


# TEXT GENERATION ERRORS
# ----------------------
# Text-generation errors are parsed separately to handle as much as possible the errors returned by the text generation
# inference project (https://github.com/huggingface/text-generation-inference).
# ----------------------


class TextGenerationError(HTTPError):
    """Generic error raised if text-generation went wrong."""


# Text Generation Inference Errors
class ValidationError(TextGenerationError):
    """Server-side validation error."""


class GenerationError(TextGenerationError):
    pass


class OverloadedError(TextGenerationError):
    pass


class IncompleteGenerationError(TextGenerationError):
    pass


class UnknownError(TextGenerationError):
    pass


def raise_text_generation_error(http_error: HTTPError) -> NoReturn:
    """
    Try to parse the text-generation-inference error message and raise an HTTPError in any case.

    Args:
        http_error (`HTTPError`):
            The HTTPError that has been raised.
    """
    # Try to parse a Text Generation Inference error
    try:
        # Hacky way to retrieve payload in case of aiohttp error
        payload = getattr(http_error, "response_error_payload", None) or http_error.response.json()
        error = payload.get("error")
        error_type = payload.get("error_type")
    except Exception:  # no payload
        raise http_error

    # If error_type => more information than `hf_raise_for_status`
    if error_type is not None:
        exception = _parse_text_generation_error(error, error_type)
        raise exception from http_error

    # Otherwise, fallback to default error
    raise http_error


def _parse_text_generation_error(error: Optional[str], error_type: Optional[str]) -> TextGenerationError:
    if error_type == "generation":
        return GenerationError(error)  # type: ignore
    if error_type == "incomplete_generation":
        return IncompleteGenerationError(error)  # type: ignore
    if error_type == "overloaded":
        return OverloadedError(error)  # type: ignore
    if error_type == "validation":
        return ValidationError(error)  # type: ignore
    return UnknownError(error)  # type: ignore
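

# Example mapping (illustrative; the error messages are made up):
#
#     >>> _parse_text_generation_error("Model is overloaded", "overloaded")
#     OverloadedError('Model is overloaded')
#     >>> _parse_text_generation_error("unexpected failure", None)
#     UnknownError('unexpected failure')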