# coding=utf-8
# Copyright 2021 The Google AI Flax Team Authors, and The HuggingFace Inc. team.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import inspect
import warnings
from functools import partial
from typing import Any, Dict, Optional, Union

import flax
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax

from ..models.auto import (
    FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
    FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
    FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
)
from ..utils import ModelOutput, logging
from .configuration_utils import GenerationConfig
from .flax_logits_process import (
    FlaxForcedBOSTokenLogitsProcessor,
    FlaxForcedEOSTokenLogitsProcessor,
    FlaxForceTokensLogitsProcessor,
    FlaxLogitsProcessorList,
    FlaxMinLengthLogitsProcessor,
    FlaxSuppressTokensAtBeginLogitsProcessor,
    FlaxSuppressTokensLogitsProcessor,
    FlaxTemperatureLogitsWarper,
    FlaxTopKLogitsWarper,
    FlaxTopPLogitsWarper,
)


logger = logging.get_logger(__name__)

@flax.struct.dataclass
class FlaxGreedySearchOutput(ModelOutput):
    """
    Flax Base class for outputs of decoder-only generation models using greedy search.

    Args:
        sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
            The generated sequences.
    """

    sequences: jnp.ndarray = None

@flax.struct.dataclass
class FlaxSampleOutput(ModelOutput):
    """
    Flax Base class for outputs of decoder-only generation models using sampling.

    Args:
        sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
            The generated sequences.
    """

    sequences: jnp.ndarray = None

@flax.struct.dataclass
class FlaxBeamSearchOutput(ModelOutput):
    """
    Flax Base class for outputs of decoder-only generation models using beam search.

    Args:
        sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
            The generated sequences.
        scores (`jnp.ndarray` of shape `(batch_size,)`):
            The scores (log probabilities) of the generated sequences.
    """

    sequences: jnp.ndarray = None
    scores: jnp.ndarray = None

@flax.struct.dataclass
class GreedyState:
    cur_len: jnp.ndarray
    sequences: jnp.ndarray
    running_token: jnp.ndarray
    is_sent_finished: jnp.ndarray
    model_kwargs: Dict[str, jnp.ndarray]

@flax.struct.dataclass
class SampleState:
    cur_len: jnp.ndarray
    sequences: jnp.ndarray
    running_token: jnp.ndarray
    is_sent_finished: jnp.ndarray
    prng_key: jnp.ndarray
    model_kwargs: Dict[str, jnp.ndarray]

@flax.struct.dataclass
class BeamSearchState:
    cur_len: jnp.ndarray
    running_sequences: jnp.ndarray
    running_scores: jnp.ndarray
    sequences: jnp.ndarray
    scores: jnp.ndarray
    is_sent_finished: jnp.ndarray
    model_kwargs: Dict[str, jnp.ndarray]

class FlaxGenerationMixin:
    """
    A class containing all functions for auto-regressive text generation, to be used as a mixin in
    [`FlaxPreTrainedModel`].

    The class exposes [`~generation.FlaxGenerationMixin.generate`], which can be used for:

        - *greedy decoding* by calling [`~generation.FlaxGenerationMixin._greedy_search`] if `num_beams=1` and
          `do_sample=False`
        - *multinomial sampling* by calling [`~generation.FlaxGenerationMixin._sample`] if `num_beams=1` and
          `do_sample=True`
        - *beam-search decoding* by calling [`~generation.FlaxGenerationMixin._beam_search`] if `num_beams>1` and
          `do_sample=False`

    You do not need to call any of the above methods directly. Pass custom parameter values to 'generate' instead. To
    learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
    """

    def prepare_inputs_for_generation(self, *args, **kwargs):
        raise NotImplementedError(
            "A model class needs to define a `prepare_inputs_for_generation` method in order to use `generate`."
        )

    def _run_loop_in_debug(self, cond_fn, body_fn, init_state):
        """
        Run generation in untraced mode. This should only be used for debugging purposes.
        """
        state = init_state
        while cond_fn(state):
            state = body_fn(state)
        return state
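
    # Illustrative sketch (assumes a `model` and `input_ids` as in the example above): passing `trace=False` to
    # `generate` routes the decoding loop through `_run_loop_in_debug`, a plain Python `while` loop, so prints and
    # breakpoints inside the loop body work, at the cost of a considerably slower, untraced runtime.
    #
    #     >>> out = model.generate(input_ids, max_new_tokens=5, trace=False)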

    def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids, params, model_kwargs):
        encoder_kwargs = {
            argument: value
            for argument, value in model_kwargs.items()
            if not (argument.startswith("decoder_") or argument.startswith("cross_attn"))
        }
        model_kwargs["encoder_outputs"] = self.encode(input_ids, params=params, return_dict=True, **encoder_kwargs)
        return model_kwargs

    def _prepare_decoder_input_ids_for_generation(
        self,
        batch_size: int,
        decoder_start_token_id: int = None,
        bos_token_id: int = None,
        model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
    ) -> jnp.ndarray:
        if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
            # Only use this arg if not None, otherwise just remove from model_kwargs
            decoder_input_ids = model_kwargs.pop("decoder_input_ids")
            if decoder_input_ids is not None:
                return decoder_input_ids
        decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
        return jnp.array(decoder_start_token_id, dtype="i4").reshape(1, -1).repeat(batch_size, axis=0)
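
    # Illustrative sketch (toy numbers, not part of the method): for `decoder_start_token_id=2` and `batch_size=3`,
    # the return value above is a (3, 1) int32 array with every entry equal to 2, i.e. one start token per batch item.
    #
    #     >>> import jax.numpy as jnp
    #     >>> jnp.array(2, dtype="i4").reshape(1, -1).repeat(3, axis=0).shape
    #     (3, 1)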

    def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
        # retrieve decoder_start_token_id for encoder-decoder models
        # fall back to bos_token_id if necessary
        decoder_start_token_id = (
            decoder_start_token_id
            if decoder_start_token_id is not None
            else self.generation_config.decoder_start_token_id
        )
        bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
        if decoder_start_token_id is not None:
            return decoder_start_token_id
        elif (
            hasattr(self.config, "decoder")
            and hasattr(self.config.decoder, "decoder_start_token_id")
            and self.config.decoder.decoder_start_token_id is not None
        ):
            return self.config.decoder.decoder_start_token_id
        elif bos_token_id is not None:
            return bos_token_id
        elif (
            hasattr(self.config, "decoder")
            and hasattr(self.config.decoder, "bos_token_id")
            and self.config.decoder.bos_token_id is not None
        ):
            return self.config.decoder.bos_token_id
        raise ValueError(
            "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
        )

    @staticmethod
    def _expand_to_num_beams(tensor, num_beams):
        return jnp.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:])
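
    # Illustrative shape check (toy numbers, not part of the method): a (batch, seq_len) prompt is broadcast to
    # (batch, num_beams, seq_len) so every beam starts from the same prompt.
    #
    #     >>> import jax.numpy as jnp
    #     >>> x = jnp.ones((2, 5), dtype=jnp.int32)
    #     >>> jnp.broadcast_to(x[:, None], (x.shape[0], 4) + x.shape[1:]).shape
    #     (2, 4, 5)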

    def _adapt_logits_for_beam_search(self, logits):
        """
        This function can be overwritten in the specific modeling_flax_<model-name>.py classes to allow for custom beam
        search behavior. Note that the only model that overwrites this method is [`~transformers.FlaxMarianMTModel`].
        """
        return logits

    def _validate_model_class(self):
        """
        Confirms that the model class is compatible with generation. If not, raises an exception that points to the
        right class to use.
        """
        if not self.can_generate():
            generate_compatible_mappings = [
                FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
                FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
                FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
            ]
            generate_compatible_classes = set()
            for model_mapping in generate_compatible_mappings:
                supported_models = model_mapping.get(type(self.config), default=None)
                if supported_models is not None:
                    generate_compatible_classes.add(supported_models.__name__)
            exception_message = (
                f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as "
                "it doesn't have a language model head."
            )
            if generate_compatible_classes:
                exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}"
            raise TypeError(exception_message)

    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
        """Validates model kwargs for generation. Generate argument typos will also be caught here."""
        unused_model_args = []
        model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
        # `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
        # `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;)
        if "kwargs" in model_args or "model_kwargs" in model_args:
            model_args |= set(inspect.signature(self.__call__).parameters)
        for key, value in model_kwargs.items():
            if value is not None and key not in model_args:
                unused_model_args.append(key)
        if unused_model_args:
            raise ValueError(
                f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
                " generate arguments will also show up in this list)"
            )

    def generate(
        self,
        input_ids: jnp.ndarray,
        generation_config: Optional[GenerationConfig] = None,
        prng_key: Optional[jnp.ndarray] = None,
        trace: bool = True,
        params: Optional[Dict[str, jnp.ndarray]] = None,
        logits_processor: Optional[FlaxLogitsProcessorList] = None,
        **kwargs,
    ):
        r"""
        Generates sequences of token ids for models with a language modeling head.

        Parameters:
            input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            generation_config (`~generation.GenerationConfig`, *optional*):
                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
                passed to generate matching the attributes of `generation_config` will override them. If
                `generation_config` is not provided, the default will be used, which has the following loading
                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
                default values, whose documentation should be checked to parameterize generation.
            prng_key (`jnp.ndarray`, *optional*):
                The PRNG key used for sampling-based generation. Defaults to `jax.random.PRNGKey(0)` if not provided.
            trace (`bool`, *optional*, defaults to `True`):
                Whether to trace generation. Setting `trace=False` should only be used for debugging and will lead to a
                considerably slower runtime.
            params (`Dict[str, jnp.ndarray]`, *optional*):
                Optionally the model parameters can be passed. Can be useful for parallelized generation.
            logits_processor (`FlaxLogitsProcessorList`, *optional*):
                Custom logits processors that complement the default logits processors built from arguments and
                generation config. If a logits processor is passed that is already created with the arguments or a
                generation config, an error is thrown. This feature is intended for advanced users.
            kwargs (`Dict[str, Any]`, *optional*):
                Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
                forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
                specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.

        Return:
            [`~utils.ModelOutput`].
        """
        # Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
        self._validate_model_class()

        # priority: `generation_config` argument > `model.generation_config` (the default generation config)
        if generation_config is None:
            # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
            # two conditions must be met
            # 1) the generation config must have been created from the model config (`_from_model_config` field);
            # 2) the generation config must have seen no modification since its creation (the hash is the same).
            if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash(
                self.generation_config
            ):
                new_generation_config = GenerationConfig.from_model_config(self.config)
                if new_generation_config != self.generation_config:
                    warnings.warn(
                        "You have modified the pretrained model configuration to control generation. This is a"
                        " deprecated strategy to control generation and will be removed soon, in a future version."
                        " Please use and modify the model generation configuration (see"
                        " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
                    )
                    self.generation_config = new_generation_config
            generation_config = self.generation_config

        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
        generation_config.validate()
        self._validate_model_kwargs(model_kwargs.copy())

        logits_processor = logits_processor if logits_processor is not None else FlaxLogitsProcessorList()

        # set init values
        prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)

        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
            if model_kwargs.get("attention_mask") is None:
                logger.warning(
                    "The attention mask and the pad token id were not set. As a consequence, you may observe "
                    "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
                )
            eos_token_id = generation_config.eos_token_id
            if isinstance(eos_token_id, list):
                eos_token_id = eos_token_id[0]
            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
            generation_config.pad_token_id = eos_token_id

        if generation_config.decoder_start_token_id is None and self.config.is_encoder_decoder:
            raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")
        # decoder-only models should use left-padding for generation (can't be checked with `trace=True`)
        if not self.config.is_encoder_decoder and not trace:
            if (
                generation_config.pad_token_id is not None
                and jnp.sum(input_ids[:, -1] == generation_config.pad_token_id) > 0
            ):
                logger.warning(
                    "A decoder-only architecture is being used, but right-padding was detected! For correct "
                    "generation results, please set `padding_side='left'` when initializing the tokenizer."
                )

        batch_size = input_ids.shape[0]

        if self.config.is_encoder_decoder:
            # add encoder_outputs to model_kwargs
            if model_kwargs.get("encoder_outputs") is None:
                model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs)
            # prepare decoder_input_ids for generation
            input_ids = self._prepare_decoder_input_ids_for_generation(
                batch_size,
                decoder_start_token_id=generation_config.decoder_start_token_id,
                bos_token_id=generation_config.bos_token_id,
                model_kwargs=model_kwargs,
            )
        # Prepare `max_length` depending on other stopping criteria.
        input_ids_seq_length = input_ids.shape[-1]
        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
            # 20 is the default max_length of the generation config
            warnings.warn(
                f"Using the model-agnostic default `max_length` (={generation_config.max_length}) "
                "to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.",
                UserWarning,
            )
        elif generation_config.max_new_tokens is not None:
            if not has_default_max_length and generation_config.max_length is not None:
                logger.warning(
                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                    "Please refer to the documentation for more information. "
                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                )
            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length

        if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
            raise ValueError(
                f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
                f" the maximum length ({generation_config.max_length})"
            )
        if input_ids_seq_length >= generation_config.max_length:
            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            logger.warning(
                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_new_tokens`."
            )
        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            logits_processor=logits_processor,
        )

        if not generation_config.do_sample and generation_config.num_beams == 1:
            return self._greedy_search(
                input_ids,
                generation_config.max_length,
                generation_config.pad_token_id,
                generation_config.eos_token_id,
                logits_processor=logits_processor,
                trace=trace,
                params=params,
                model_kwargs=model_kwargs,
            )
        elif generation_config.do_sample and generation_config.num_beams == 1:
            logits_warper = self._get_logits_warper(generation_config=generation_config)
            return self._sample(
                input_ids,
                generation_config.max_length,
                generation_config.pad_token_id,
                generation_config.eos_token_id,
                prng_key,
                logits_warper=logits_warper,
                logits_processor=logits_processor,
                trace=trace,
                params=params,
                model_kwargs=model_kwargs,
            )
        elif not generation_config.do_sample and generation_config.num_beams > 1:
            # broadcast input_ids & encoder_outputs
            input_ids = self._expand_to_num_beams(input_ids, num_beams=generation_config.num_beams)

            if "encoder_outputs" in model_kwargs:
                model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
                    model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=generation_config.num_beams
                )

            for kwarg in ["attention_mask", "decoder_attention_mask"]:
                if kwarg in model_kwargs:
                    model_kwargs[kwarg] = self._expand_to_num_beams(
                        model_kwargs[kwarg], num_beams=generation_config.num_beams
                    )

            return self._beam_search(
                input_ids,
                generation_config.max_length,
                generation_config.pad_token_id,
                generation_config.eos_token_id,
                length_penalty=generation_config.length_penalty,
                early_stopping=generation_config.early_stopping,
                logits_processor=logits_processor,
                trace=trace,
                params=params,
                num_return_sequences=generation_config.num_return_sequences,
                model_kwargs=model_kwargs,
            )
        else:
            raise NotImplementedError("Beam sampling is currently not implemented.")

    def _get_logits_warper(self, generation_config: GenerationConfig) -> FlaxLogitsProcessorList:
        """
        This method returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsWarper`]
        instances used for multinomial sampling.
        """
        warpers = FlaxLogitsProcessorList()

        if generation_config.temperature is not None and generation_config.temperature != 1.0:
            warpers.append(FlaxTemperatureLogitsWarper(generation_config.temperature))
        if generation_config.top_k is not None and generation_config.top_k != 0:
            warpers.append(FlaxTopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=1))
        if generation_config.top_p is not None and generation_config.top_p < 1.0:
            warpers.append(FlaxTopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=1))

        return warpers
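
    # Illustrative sketch (assumes `model` and `input_ids` as in the earlier examples): the warpers above are built
    # automatically from `generate` kwargs, so the typical user-facing call combining them looks like this.
    #
    #     >>> import jax
    #     >>> out = model.generate(
    #     ...     input_ids,
    #     ...     do_sample=True,
    #     ...     temperature=0.7,
    #     ...     top_k=50,
    #     ...     top_p=0.95,
    #     ...     prng_key=jax.random.PRNGKey(42),
    #     ...     max_new_tokens=20,
    #     ... )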

    def _get_logits_processor(
        self,
        generation_config: GenerationConfig,
        input_ids_seq_length: int,
        logits_processor: Optional[FlaxLogitsProcessorList],
    ) -> FlaxLogitsProcessorList:
        """
        This method returns a [`FlaxLogitsProcessorList`] list object that contains all relevant
        [`FlaxLogitsProcessor`] instances used to modify the scores of the language model head.
        """
        processors = FlaxLogitsProcessorList()

        if (
            generation_config.min_length is not None
            and generation_config.eos_token_id is not None
            and generation_config.min_length > -1
        ):
            processors.append(
                FlaxMinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id)
            )
        if generation_config.forced_bos_token_id is not None:
            processors.append(FlaxForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id))
        if generation_config.forced_eos_token_id is not None:
            processors.append(
                FlaxForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id)
            )
        if generation_config.suppress_tokens is not None:
            processors.append(FlaxSuppressTokensLogitsProcessor(generation_config.suppress_tokens))
        if generation_config.begin_suppress_tokens is not None:
            begin_index = input_ids_seq_length
            begin_index = (
                begin_index
                if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None)
                else begin_index + 1
            )
            if generation_config.forced_decoder_ids is not None and len(generation_config.forced_decoder_ids) > 0:
                # generation starts after the last token that is forced
                begin_index += generation_config.forced_decoder_ids[-1][0]
            processors.append(
                FlaxSuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index)
            )
        if generation_config.forced_decoder_ids is not None:
            forced_decoder_ids = [
                [input_ids_seq_length + i[0] - 1, i[1]] for i in generation_config.forced_decoder_ids
            ]
            processors.append(FlaxForceTokensLogitsProcessor(forced_decoder_ids))
        processors = self._merge_criteria_processor_list(processors, logits_processor)
        return processors
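
    # Illustrative sketch (assumptions: the `transformers.generation` import path for the Flax processors, and that
    # the chosen processor type is not already built from the generation config for this call): extra processors can
    # be passed to `generate` and are merged with the defaults constructed above.
    #
    #     >>> from transformers.generation import FlaxLogitsProcessorList, FlaxSuppressTokensLogitsProcessor
    #     >>> custom = FlaxLogitsProcessorList([FlaxSuppressTokensLogitsProcessor(suppress_tokens=[198])])
    #     >>> out = model.generate(input_ids, logits_processor=custom, max_new_tokens=20)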

    def _merge_criteria_processor_list(
        self,
        default_list: FlaxLogitsProcessorList,
        custom_list: FlaxLogitsProcessorList,
    ) -> FlaxLogitsProcessorList:
        if len(custom_list) == 0:
            return default_list
        for default in default_list:
            for custom in custom_list:
                if type(custom) is type(default):
                    object_type = "logits processor"
                    raise ValueError(
                        f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to"
                        f" `generate`, but it has already been created with the values {default}. {default} has been"
                        " created by passing the corresponding arguments to generate or by the model's config default"
                        f" values. If you just want to change the default values of {object_type} consider passing"
                        f" them as arguments to `generate` instead of using a custom {object_type}."
                    )
        default_list.extend(custom_list)
        return default_list

    def _greedy_search(
        self,
        input_ids: jnp.ndarray,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        logits_processor: Optional[FlaxLogitsProcessorList] = None,
        trace: bool = True,
        params: Optional[Dict[str, jnp.ndarray]] = None,
        model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
    ):
        # init values
        max_length = max_length if max_length is not None else self.generation_config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id

        batch_size, cur_len = input_ids.shape

        eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None)
        pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)
        cur_len = jnp.array(cur_len)

        # per batch-item holding current token in loop.
        sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
        sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))

        # per batch-item state bit indicating if sentence has finished.
        is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)

        # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
        # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
        model = self.decode if self.config.is_encoder_decoder else self

        # initialize model specific kwargs
        model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)

        # initialize state
        state = GreedyState(
            cur_len=cur_len,
            sequences=sequences,
            running_token=input_ids,
            is_sent_finished=is_sent_finished,
            model_kwargs=model_kwargs,
        )

        def greedy_search_cond_fn(state):
            """state termination condition fn."""
            has_reached_max_length = state.cur_len == max_length
            all_sequence_finished = jnp.all(state.is_sent_finished)
            finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished)
            return ~finish_generation

        def greedy_search_body_fn(state):
            """state update fn."""
            model_outputs = model(state.running_token, params=params, **state.model_kwargs)
            logits = model_outputs.logits[:, -1]

            # apply min_length, ...
            logits = logits_processor(state.sequences, logits, state.cur_len)

            next_token = jnp.argmax(logits, axis=-1)

            next_token = next_token * ~state.is_sent_finished + pad_token_id * state.is_sent_finished
            next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
            next_token = next_token[:, None]

            next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
            next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)
            return GreedyState(
                cur_len=state.cur_len + 1,
                sequences=next_sequences,
                running_token=next_token,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
            )

        # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
        if input_ids.shape[1] > 1:
            state = greedy_search_body_fn(state)

        if not trace:
            state = self._run_loop_in_debug(greedy_search_cond_fn, greedy_search_body_fn, state)
        else:
            state = lax.while_loop(greedy_search_cond_fn, greedy_search_body_fn, state)

        return FlaxGreedySearchOutput(sequences=state.sequences)
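
    # Illustrative, self-contained sketch of the `lax.while_loop` pattern used by `_greedy_search` above (toy state,
    # not the real `GreedyState`): the loop carries a pytree state, `cond_fn` returns a scalar bool, and `body_fn`
    # returns a state with the same structure.
    #
    #     >>> import jax.numpy as jnp
    #     >>> from jax import lax
    #     >>> def cond_fn(state):
    #     ...     counter, _ = state
    #     ...     return counter < 5
    #     >>> def body_fn(state):
    #     ...     counter, total = state
    #     ...     return counter + 1, total + counter
    #     >>> final = lax.while_loop(cond_fn, body_fn, (jnp.array(0), jnp.array(0)))
    #     >>> # final == (5, 10): the counter ran 0..4 and the total accumulated 0+1+2+3+4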

    def _sample(
        self,
        input_ids: jnp.ndarray,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        prng_key: Optional[jnp.ndarray] = None,
        logits_processor: Optional[FlaxLogitsProcessorList] = None,
        logits_warper: Optional[FlaxLogitsProcessorList] = None,
        trace: bool = True,
        params: Optional[Dict[str, jnp.ndarray]] = None,
        model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
    ):
        # init values
        max_length = max_length if max_length is not None else self.generation_config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
        prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)

        batch_size, cur_len = input_ids.shape

        eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None)
        pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)
        cur_len = jnp.array(cur_len)

        # per batch-item holding current token in loop.
        sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
        sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))

        # per batch-item state bit indicating if sentence has finished.
        is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)

        # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
        # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
        model = self.decode if self.config.is_encoder_decoder else self

        # initialize model specific kwargs
        model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)

        # initialize state
        state = SampleState(
            cur_len=cur_len,
            sequences=sequences,
            running_token=input_ids,
            is_sent_finished=is_sent_finished,
            prng_key=prng_key,
            model_kwargs=model_kwargs,
        )

        def sample_search_cond_fn(state):
            """state termination condition fn."""
            has_reached_max_length = state.cur_len == max_length
            all_sequence_finished = jnp.all(state.is_sent_finished)
            finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished)
            return ~finish_generation

        def sample_search_body_fn(state):
            """state update fn."""
            prng_key, prng_key_next = jax.random.split(state.prng_key)
            model_outputs = model(state.running_token, params=params, **state.model_kwargs)

            logits = model_outputs.logits[:, -1]

            # apply min_length, ...
            logits = logits_processor(state.sequences, logits, state.cur_len)
            # apply top_p, top_k, temperature
            logits = logits_warper(logits, logits, state.cur_len)

            next_token = jax.random.categorical(prng_key, logits, axis=-1)

            next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
            next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished
            next_token = next_token[:, None]

            next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
            next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)

            return SampleState(
                cur_len=state.cur_len + 1,
                sequences=next_sequences,
                running_token=next_token,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
                prng_key=prng_key_next,
            )

        # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
        if input_ids.shape[1] > 1:
            state = sample_search_body_fn(state)

        if not trace:
            state = self._run_loop_in_debug(sample_search_cond_fn, sample_search_body_fn, state)
        else:
            state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state)

        return FlaxSampleOutput(sequences=state.sequences)
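
    # Illustrative, self-contained sketch of the PRNG handling in `sample_search_body_fn` above: the key carried in
    # the state is split at every step so each sampled token uses fresh randomness, and `jax.random.categorical`
    # samples directly from (unnormalized) logits.
    #
    #     >>> import jax
    #     >>> import jax.numpy as jnp
    #     >>> key = jax.random.PRNGKey(0)
    #     >>> step_key, next_key = jax.random.split(key)
    #     >>> logits = jnp.array([[0.1, 2.0, -1.0]])  # (batch=1, vocab=3)
    #     >>> token = jax.random.categorical(step_key, logits, axis=-1)  # shape (1,), most likely index 1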

    def _beam_search(
        self,
        input_ids: jnp.ndarray,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        early_stopping: Optional[Union[bool, str]] = None,
        logits_processor: Optional[FlaxLogitsProcessorList] = None,
        trace: bool = True,
        params: Optional[Dict[str, jnp.ndarray]] = None,
        num_return_sequences: Optional[int] = None,
        model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
    ):
        """
        This beam search function is heavily inspired by Flax's official example:
        https://github.com/google/flax/blob/main/examples/wmt/decode.py
        """

        def flatten_beam_dim(tensor):
            """Flattens the first two dimensions of a non-scalar array."""
            # ignore scalars (e.g. cache index)
            if tensor.ndim == 0:
                return tensor
            return tensor.reshape((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])

        def unflatten_beam_dim(tensor, batch_size, num_beams):
            """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
            # ignore scalars (e.g. cache index)
            if tensor.ndim == 0:
                return tensor
            return tensor.reshape((batch_size, num_beams) + tensor.shape[1:])

        def gather_beams(nested, beam_indices, batch_size, new_num_beams):
            """
            Gathers the beam slices indexed by beam_indices into new beam array.
            """
            batch_indices = jnp.reshape(
                jnp.arange(batch_size * new_num_beams) // new_num_beams, (batch_size, new_num_beams)
            )

            def gather_fn(tensor):
                # ignore scalars (e.g. cache index)
                if tensor.ndim == 0:
                    return tensor
                else:
                    return tensor[batch_indices, beam_indices]

            return jax.tree_util.tree_map(gather_fn, nested)
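
        # Illustrative shape sketch (toy numbers, not part of the method): with batch_size=2 and num_beams=3,
        # `flatten_beam_dim` maps (2, 3, seq) -> (6, seq) for the model forward pass, `unflatten_beam_dim` maps it
        # back, and `gather_beams` reorders the beam axis with per-batch indices via advanced indexing.
        #
        #     >>> import jax.numpy as jnp
        #     >>> x = jnp.arange(2 * 3 * 4).reshape(2, 3, 4)
        #     >>> flat = x.reshape((2 * 3,) + x.shape[2:])          # (6, 4), what the decoder sees
        #     >>> back = flat.reshape((2, 3) + flat.shape[1:])      # (2, 3, 4) again
        #     >>> beam_indices = jnp.array([[2, 0, 1], [1, 1, 0]])  # new beam order per batch item
        #     >>> batch_indices = jnp.reshape(jnp.arange(2 * 3) // 3, (2, 3))
        #     >>> reordered = x[batch_indices, beam_indices]        # (2, 3, 4)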
        # init values
        max_length = max_length if max_length is not None else self.generation_config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
        length_penalty = length_penalty if length_penalty is not None else self.generation_config.length_penalty
        early_stopping = early_stopping if early_stopping is not None else self.generation_config.early_stopping
        num_return_sequences = (
            num_return_sequences if num_return_sequences is not None else self.generation_config.num_return_sequences
        )

        batch_size, num_beams, cur_len = input_ids.shape

        eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None)
        pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)
        cur_len = jnp.array(cur_len)

        # record the prompt length of decoder
        decoder_prompt_len = input_ids.shape[-1]

        # per batch,beam-item holding current token in loop.
        sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
        running_sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
        running_sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0, 0))

        # per batch,beam-item state bit indicating if sentence has finished.
        is_sent_finished = jnp.zeros((batch_size, num_beams), dtype=jnp.bool_)

        # per batch,beam-item score, logprobs
        running_scores = jnp.tile(jnp.array([0.0] + [np.array(-1.0e7)] * (num_beams - 1)), [batch_size, 1])
        scores = jnp.ones((batch_size, num_beams)) * np.array(-1.0e7)

        # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
        # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
        model = self.decode if self.config.is_encoder_decoder else self

        # flatten beam dim
        if "encoder_outputs" in model_kwargs:
            model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim(
                model_kwargs["encoder_outputs"]["last_hidden_state"]
            )
        for kwarg in ["attention_mask", "decoder_attention_mask"]:
            if kwarg in model_kwargs:
                model_kwargs[kwarg] = flatten_beam_dim(model_kwargs[kwarg])

        # initialize model specific kwargs
        model_kwargs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), max_length, **model_kwargs)

        # initialize state
        state = BeamSearchState(
            cur_len=cur_len,
            running_sequences=running_sequences,
            running_scores=running_scores,
            sequences=sequences,
            scores=scores,
            is_sent_finished=is_sent_finished,
            model_kwargs=model_kwargs,
        )

        def beam_search_cond_fn(state):
            """beam search state termination condition fn."""

            # 1. is less than max length?
            not_max_length_yet = state.cur_len < max_length

            # 2. can the new beams still improve?
            # early_stopping == False -> apply heuristic = always get the best score from `cur_len`. See the discussion
            # below for more details.
            # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
            # early_stopping == "never" -> compute the best score from max_length or cur_len, depending on the sign of
            # length_penalty. Positive length_penalty favors longer sequences, thus we use max_length there.
            if early_stopping == "never" and length_penalty > 0.0:
                best_running_score = state.running_scores[:, :1] / (
                    (max_length - decoder_prompt_len) ** length_penalty
                )
            else:
                best_running_score = state.running_scores[:, :1] / (
                    (state.cur_len - decoder_prompt_len) ** length_penalty
                )
            worst_finished_score = jnp.where(
                state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7)
            )
            improvement_still_possible = jnp.any(best_running_score > worst_finished_score)

            # 3. is there still a beam that has not finished?
            still_open_beam = ~(jnp.all(state.is_sent_finished) & (early_stopping is True))

            return not_max_length_yet & still_open_beam & improvement_still_possible

        def beam_search_body_fn(state, input_ids_length=1):
            """beam search state update fn."""
            # 1. Forward current tokens
            # Collect the current position slice along length to feed the fast
            # autoregressive decoder model. Flatten the beam dimension into batch
            # dimension for feeding into the model.
            input_token = flatten_beam_dim(
                lax.dynamic_slice(
                    state.running_sequences,
                    (0, 0, state.cur_len - input_ids_length),
                    (batch_size, num_beams, input_ids_length),
                )
            )
            model_outputs = model(input_token, params=params, **state.model_kwargs)

            # unflatten beam dimension
            logits = unflatten_beam_dim(model_outputs.logits[:, -1], batch_size, num_beams)

            # Unflatten beam dimension in attention cache arrays
            cache = jax.tree_util.tree_map(
                lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams), model_outputs.past_key_values
            )

            # adapt logits for FlaxMarianMTModel
            logits = self._adapt_logits_for_beam_search(logits)

            # 2. Compute log probs
            # get log probabilities from logits,
            # process logits with processors (*e.g.* min_length, ...), and
            # add new logprobs to existing running logprobs scores.
            log_probs = jax.nn.log_softmax(logits)
            log_probs = logits_processor(
                flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), state.cur_len
            )
            log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
            log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2)
            vocab_size = log_probs.shape[2]
            log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))

            # 3. Retrieve top-K
            # Each item in batch has num_beams * vocab_size candidate sequences.
            # For each item, get the top 2*k candidates with the highest log-
            # probabilities. We gather the top 2*K beams here so that even if the best
            # K sequences reach EOS simultaneously, we have another K sequences
            # remaining to continue the live beam search.
            beams_to_keep = 2 * num_beams
            # Gather the top 2*K scores from _all_ beams.
            topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep)
            # Recover the beam index by floor division.
            topk_beam_indices = topk_indices // vocab_size
            # Gather 2*k top beams.
            topk_running_sequences = gather_beams(
                state.running_sequences, topk_beam_indices, batch_size, beams_to_keep
            )
            # Recover token id by modulo division and expand Id array for broadcasting.
            topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
            # Update sequences for the 2*K top-k new sequences.
            topk_sequences = lax.dynamic_update_slice(topk_running_sequences, topk_ids, (0, 0, state.cur_len))

            # 4. Check which sequences have ended
            # Update current sequences:
            # Did any of these sequences reach an end marker?
            # To prevent these just finished sequences from being added to the current sequences
            # set of active beam search sequences, set their log probs to a very large
            # negative value.
            did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id
            running_topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7)

            # 5. Get running sequences scores for next
            # Determine the top k beam indices (from top 2*k beams) from log probs
            # and gather top k beams (from top 2*k beams).
            next_topk_indices = lax.top_k(running_topk_log_probs, k=num_beams)[1]
            next_running_sequences, next_running_scores = gather_beams(
                [topk_sequences, running_topk_log_probs], next_topk_indices, batch_size, num_beams
            )

            # 6. Process topk logits
            # Further process log probs:
            # - add length penalty
            # - make sure no scores can be added anymore if beam is full
            # - make sure still running sequences cannot be chosen as finalized beam
            topk_log_probs = topk_log_probs / ((state.cur_len + 1 - decoder_prompt_len) ** length_penalty)
            beams_in_batch_are_full = jnp.broadcast_to(
                state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape
            ) & (early_stopping is True)
            add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
            topk_log_probs += add_penalty * np.array(-1.0e7)

            # 7. Get scores, sequences, is sentence finished for next.
            # Combine sequences, scores, and flags along the beam dimension and compare
            # new finished sequence scores to existing finished scores and select the
            # best from the new set of beams
            merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1)
            merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1)
            merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1)
            topk_merged_indices = lax.top_k(merged_scores, k=num_beams)[1]
            next_sequences, next_scores, next_is_sent_finished = gather_beams(
                [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams
            )

            # 8. Update model kwargs.
            # Determine the top k beam indices from the original set of all beams.
            # With these, gather the top k beam-associated caches.
            next_running_indices = gather_beams(topk_beam_indices, next_topk_indices, batch_size, num_beams)
            next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams)
            model_outputs["past_key_values"] = jax.tree_util.tree_map(lambda x: flatten_beam_dim(x), next_cache)
            next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)

            return BeamSearchState(
                cur_len=state.cur_len + 1,
                running_scores=next_running_scores,
                running_sequences=next_running_sequences,
                scores=next_scores,
                sequences=next_sequences,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
            )

        # Always run first iteration outside of `lax.while_loop` to avoid calling `beam_search_cond_fn`
        # when `state.cur_len` equals `decoder_prompt_len`. This also helps to comply with TPU when
        # the very first prompt has sequence length > 1.
        state = partial(beam_search_body_fn, input_ids_length=input_ids.shape[-1])(state)

        if not trace:
            state = self._run_loop_in_debug(beam_search_cond_fn, beam_search_body_fn, state)
        else:
            state = lax.while_loop(beam_search_cond_fn, beam_search_body_fn, state)

        # Account for the edge-case where there are no finished sequences for a
        # particular batch item. If so, return running sequences for that batch item.
        none_finished = jnp.any(state.is_sent_finished, axis=1)
        sequences = jnp.where(none_finished[:, None, None], state.sequences, state.running_sequences)
        scores = jnp.where(none_finished[:, None], state.scores, state.running_scores)

        # Take best beams for each batch (scores are sorted in descending order)
        sequences = flatten_beam_dim(sequences[:, :num_return_sequences, :])
        scores = flatten_beam_dim(scores[:, :num_return_sequences])

        return FlaxBeamSearchOutput(sequences=sequences, scores=scores)
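
# Illustrative end-to-end sketch (assumptions: the "t5-small" checkpoint and the tokenizer call; not part of this
# module). Beam search is selected from `generate` by setting `num_beams > 1` with `do_sample=False`, and the
# returned `FlaxBeamSearchOutput` carries both the sequences and their length-penalized scores.
#
#     >>> from transformers import AutoTokenizer, FlaxAutoModelForSeq2SeqLM
#     >>> tokenizer = AutoTokenizer.from_pretrained("t5-small")
#     >>> model = FlaxAutoModelForSeq2SeqLM.from_pretrained("t5-small")
#     >>> inputs = tokenizer("summarize: Beam search keeps the k best partial hypotheses.", return_tensors="np")
#     >>> out = model.generate(
#     ...     inputs.input_ids,
#     ...     num_beams=4,
#     ...     num_return_sequences=2,
#     ...     length_penalty=1.0,
#     ...     early_stopping=True,
#     ...     max_new_tokens=32,
#     ... )
#     >>> texts = tokenizer.batch_decode(out.sequences, skip_special_tokens=True)
#     >>> out.scores.shape  # (batch_size * num_return_sequences,)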