diff --git a/.venv/lib/python3.11/site-packages/transformers/models/barthez/__init__.py b/.venv/lib/python3.11/site-packages/transformers/models/barthez/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..323fe2fe8af9823d4478957b2f94b078ec39b7f3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/transformers/models/barthez/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .tokenization_barthez import * + from .tokenization_barthez_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/.venv/lib/python3.11/site-packages/transformers/models/barthez/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/barthez/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58f8150de5b79155e57da82599d39f474e06c4d6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/barthez/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/transformers/models/barthez/__pycache__/tokenization_barthez.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/barthez/__pycache__/tokenization_barthez.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a4502abc166027cf50e790cba916e226f65c69f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/barthez/__pycache__/tokenization_barthez.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/transformers/models/barthez/__pycache__/tokenization_barthez_fast.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/barthez/__pycache__/tokenization_barthez_fast.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90ab51f504bd896abdac894a9d11d9344e0faa18 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/barthez/__pycache__/tokenization_barthez_fast.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/transformers/models/barthez/tokenization_barthez.py b/.venv/lib/python3.11/site-packages/transformers/models/barthez/tokenization_barthez.py new file mode 100644 index 0000000000000000000000000000000000000000..604f9c7c21519a655b02248e7fc2b41e786d5bd1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/transformers/models/barthez/tokenization_barthez.py @@ -0,0 +1,289 @@ +# coding=utf-8 +# Copyright 2020 Ecole Polytechnique and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License +"""Tokenization classes for the BARThez model.""" + +import os +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple + +import sentencepiece as spm + +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...utils import logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"} + + +SPIECE_UNDERLINE = "▁" + +# TODO this class is useless. This is the most standard sentencpiece model. Let's find which one is closest and nuke this. + + +class BarthezTokenizer(PreTrainedTokenizer): + """ + Adapted from [`CamembertTokenizer`] and [`BartTokenizer`]. Construct a BARThez tokenizer. Based on + [SentencePiece](https://github.com/google/sentencepiece). + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. + bos_token (`str`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + + + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + + + + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + sep_token (`str`, *optional*, defaults to `""`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + cls_token (`str`, *optional*, defaults to `""`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + mask_token (`str`, *optional*, defaults to `""`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + sp_model_kwargs (`dict`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. + + Attributes: + sp_model (`SentencePieceProcessor`): + The *SentencePiece* processor that is used for every conversion (string, tokens and IDs). + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + bos_token="", + eos_token="", + sep_token="", + cls_token="", + unk_token="", + pad_token="", + mask_token="", + sp_model_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> None: + # Mask token behave like a normal word, i.e. include the space before it. Will have normalized=False by default this way + mask_token = AddedToken(mask_token, lstrip=True, special=True) if isinstance(mask_token, str) else mask_token + + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + + self.vocab_file = vocab_file + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(str(vocab_file)) + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + sep_token=sep_token, + cls_token=cls_token, + pad_token=pad_token, + mask_token=mask_token, + sp_model_kwargs=self.sp_model_kwargs, + **kwargs, + ) + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A BARThez sequence has the following format: + + - single sequence: ` X ` + - pair of sequences: ` A B ` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + sep + token_ids_1 + sep + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is None: + return [1] + ([0] * len(token_ids_0)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of zeros. + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + @property + def vocab_size(self): + return len(self.sp_model) + + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text: str) -> List[str]: + return self.sp_model.encode(text, out_type=str) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.PieceToId(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.sp_model.IdToPiece(index) + + # Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.convert_tokens_to_string + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + current_sub_tokens = [] + out_string = "" + prev_is_special = False + for token in tokens: + # make sure that special tokens are not decoded using sentencepiece model + if token in self.all_special_tokens: + if not prev_is_special: + out_string += " " + out_string += self.sp_model.decode(current_sub_tokens) + token + prev_is_special = True + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + prev_is_special = False + out_string += self.sp_model.decode(current_sub_tokens) + return out_string.strip() + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + return state + + def __setstate__(self, d): + self.__dict__ = d + + # for backward compatibility + if not hasattr(self, "sp_model_kwargs"): + self.sp_model_kwargs = {} + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(self.vocab_file) + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + +__all__ = ["BarthezTokenizer"] diff --git a/.venv/lib/python3.11/site-packages/transformers/models/barthez/tokenization_barthez_fast.py b/.venv/lib/python3.11/site-packages/transformers/models/barthez/tokenization_barthez_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..a1d95ef03e48820f39dd91adf117fb6cdc94fe0a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/transformers/models/barthez/tokenization_barthez_fast.py @@ -0,0 +1,197 @@ +# coding=utf-8 +# Copyright 2020 Ecole Polytechnique and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License +"""Tokenization classes for the BARThez model.""" + +import os +from shutil import copyfile +from typing import List, Optional, Tuple + +from ...tokenization_utils import AddedToken +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import is_sentencepiece_available, logging + + +if is_sentencepiece_available(): + from .tokenization_barthez import BarthezTokenizer +else: + BarthezTokenizer = None + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"} + + +SPIECE_UNDERLINE = "▁" + + +class BarthezTokenizerFast(PreTrainedTokenizerFast): + """ + Adapted from [`CamembertTokenizer`] and [`BartTokenizer`]. Construct a "fast" BARThez tokenizer. Based on + [SentencePiece](https://github.com/google/sentencepiece). + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. + bos_token (`str`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + + + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + + + + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + sep_token (`str`, *optional*, defaults to `""`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + cls_token (`str`, *optional*, defaults to `""`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + mask_token (`str`, *optional*, defaults to `""`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + additional_special_tokens (`List[str]`, *optional*, defaults to `["NOTUSED", "NOTUSED"]`): + Additional special tokens used by the tokenizer. + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + slow_tokenizer_class = BarthezTokenizer + + def __init__( + self, + vocab_file=None, + tokenizer_file=None, + bos_token="", + eos_token="", + sep_token="", + cls_token="", + unk_token="", + pad_token="", + mask_token="", + **kwargs, + ): + # Mask token behave like a normal word, i.e. include the space before it + mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token + + super().__init__( + vocab_file, + tokenizer_file=tokenizer_file, + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + sep_token=sep_token, + cls_token=cls_token, + pad_token=pad_token, + mask_token=mask_token, + **kwargs, + ) + + self.vocab_file = vocab_file + + @property + def can_save_slow_tokenizer(self) -> bool: + return os.path.isfile(self.vocab_file) if self.vocab_file else False + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A BARThez sequence has the following format: + + - single sequence: ` X ` + - pair of sequences: ` A B ` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + sep + token_ids_1 + sep + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of zeros. + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not self.can_save_slow_tokenizer: + raise ValueError( + "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow " + "tokenizer." + ) + + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file,) + + +__all__ = ["BarthezTokenizerFast"] diff --git a/.venv/lib/python3.11/site-packages/transformers/models/beit/configuration_beit.py b/.venv/lib/python3.11/site-packages/transformers/models/beit/configuration_beit.py new file mode 100644 index 0000000000000000000000000000000000000000..834988258c6b755fcca9f275db139ed76b9a8b54 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/transformers/models/beit/configuration_beit.py @@ -0,0 +1,229 @@ +# coding=utf-8 +# Copyright Microsoft Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""BEiT model configuration""" + +import warnings +from collections import OrderedDict +from typing import Mapping + +from packaging import version + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices + + +class BeitConfig(BackboneConfigMixin, PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`BeitModel`]. It is used to instantiate an BEiT + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the BEiT + [microsoft/beit-base-patch16-224-pt22k](https://huggingface.co/microsoft/beit-base-patch16-224-pt22k) architecture. + + Args: + vocab_size (`int`, *optional*, defaults to 8192): + Vocabulary size of the BEiT model. Defines the number of different image tokens that can be used during + pre-training. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + use_mask_token (`bool`, *optional*, defaults to `False`): + Whether to use a mask token for masked image modeling. + use_absolute_position_embeddings (`bool`, *optional*, defaults to `False`): + Whether to use BERT-style absolute position embeddings. + use_relative_position_bias (`bool`, *optional*, defaults to `False`): + Whether to use T5-style relative position embeddings in the self-attention layers. + use_shared_relative_position_bias (`bool`, *optional*, defaults to `False`): + Whether to use the same relative position embeddings across all self-attention layers of the Transformer. + layer_scale_init_value (`float`, *optional*, defaults to 0.1): + Scale to use in the self-attention layers. 0.1 for base, 1e-5 for large. Set 0 to disable layer scale. + drop_path_rate (`float`, *optional*, defaults to 0.1): + Stochastic depth rate per sample (when applied in the main path of residual layers). + use_mean_pooling (`bool`, *optional*, defaults to `True`): + Whether to mean pool the final hidden states of the patches instead of using the final hidden state of the + CLS token, before applying the classification head. + pool_scales (`Tuple[int]`, *optional*, defaults to `[1, 2, 3, 6]`): + Pooling scales used in Pooling Pyramid Module applied on the last feature map. + use_auxiliary_head (`bool`, *optional*, defaults to `True`): + Whether to use an auxiliary head during training. + auxiliary_loss_weight (`float`, *optional*, defaults to 0.4): + Weight of the cross-entropy loss of the auxiliary head. + auxiliary_channels (`int`, *optional*, defaults to 256): + Number of channels to use in the auxiliary head. + auxiliary_num_convs (`int`, *optional*, defaults to 1): + Number of convolutional layers to use in the auxiliary head. + auxiliary_concat_input (`bool`, *optional*, defaults to `False`): + Whether to concatenate the output of the auxiliary head with the input before the classification layer. + semantic_loss_ignore_index (`int`, *optional*, defaults to 255): + The index that is ignored by the loss function of the semantic segmentation model. + out_features (`List[str]`, *optional*): + If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. + (depending on how many stages the model has). If unset and `out_indices` is set, will default to the + corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + out_indices (`List[int]`, *optional*): + If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how + many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. + If unset and `out_features` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + add_fpn (`bool`, *optional*, defaults to `False`): + Whether to add a FPN as part of the backbone. Only relevant for [`BeitBackbone`]. + reshape_hidden_states (`bool`, *optional*, defaults to `True`): + Whether to reshape the feature maps to 4D tensors of shape `(batch_size, hidden_size, height, width)` in + case the model is used as backbone. If `False`, the feature maps will be 3D tensors of shape `(batch_size, + seq_len, hidden_size)`. Only relevant for [`BeitBackbone`]. + + Example: + + ```python + >>> from transformers import BeitConfig, BeitModel + + >>> # Initializing a BEiT beit-base-patch16-224-pt22k style configuration + >>> configuration = BeitConfig() + + >>> # Initializing a model (with random weights) from the beit-base-patch16-224-pt22k style configuration + >>> model = BeitModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "beit" + + def __init__( + self, + vocab_size=8192, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-12, + image_size=224, + patch_size=16, + num_channels=3, + use_mask_token=False, + use_absolute_position_embeddings=False, + use_relative_position_bias=False, + use_shared_relative_position_bias=False, + layer_scale_init_value=0.1, + drop_path_rate=0.1, + use_mean_pooling=True, + pool_scales=[1, 2, 3, 6], + use_auxiliary_head=True, + auxiliary_loss_weight=0.4, + auxiliary_channels=256, + auxiliary_num_convs=1, + auxiliary_concat_input=False, + semantic_loss_ignore_index=255, + out_features=None, + out_indices=None, + add_fpn=False, + reshape_hidden_states=True, + **kwargs, + ): + super().__init__(**kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.use_mask_token = use_mask_token + self.use_absolute_position_embeddings = use_absolute_position_embeddings + self.use_relative_position_bias = use_relative_position_bias + self.use_shared_relative_position_bias = use_shared_relative_position_bias + self.layer_scale_init_value = layer_scale_init_value + self.drop_path_rate = drop_path_rate + self.use_mean_pooling = use_mean_pooling + # decode head attributes (semantic segmentation) + self.pool_scales = pool_scales + # auxiliary head attributes (semantic segmentation) + self.use_auxiliary_head = use_auxiliary_head + self.auxiliary_loss_weight = auxiliary_loss_weight + self.auxiliary_channels = auxiliary_channels + self.auxiliary_num_convs = auxiliary_num_convs + self.auxiliary_concat_input = auxiliary_concat_input + self.semantic_loss_ignore_index = semantic_loss_ignore_index + + # handle backwards compatibility + if "segmentation_indices" in kwargs: + warnings.warn( + "The `segmentation_indices` argument is deprecated and will be removed in a future version, use `out_indices` instead.", + FutureWarning, + ) + out_indices = kwargs.pop("segmentation_indices") + + # backbone attributes + self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, self.num_hidden_layers + 1)] + self._out_features, self._out_indices = get_aligned_output_features_output_indices( + out_features=out_features, out_indices=out_indices, stage_names=self.stage_names + ) + self.add_fpn = add_fpn + self.reshape_hidden_states = reshape_hidden_states + + +# Copied from transformers.models.vit.configuration_vit.ViTOnnxConfig +class BeitOnnxConfig(OnnxConfig): + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-4 + + +__all__ = ["BeitConfig", "BeitOnnxConfig"] diff --git a/.venv/lib/python3.11/site-packages/transformers/models/beit/feature_extraction_beit.py b/.venv/lib/python3.11/site-packages/transformers/models/beit/feature_extraction_beit.py new file mode 100644 index 0000000000000000000000000000000000000000..141d8bc36d2bbb80c2abcb0084c7369b38dc4f4d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/transformers/models/beit/feature_extraction_beit.py @@ -0,0 +1,36 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for BEiT.""" + +import warnings + +from ...utils import logging +from .image_processing_beit import BeitImageProcessor + + +logger = logging.get_logger(__name__) + + +class BeitFeatureExtractor(BeitImageProcessor): + def __init__(self, *args, **kwargs) -> None: + warnings.warn( + "The class BeitFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please" + " use BeitImageProcessor instead.", + FutureWarning, + ) + super().__init__(*args, **kwargs) + + +__all__ = ["BeitFeatureExtractor"] diff --git a/.venv/lib/python3.11/site-packages/transformers/models/beit/image_processing_beit.py b/.venv/lib/python3.11/site-packages/transformers/models/beit/image_processing_beit.py new file mode 100644 index 0000000000000000000000000000000000000000..ce75797f5ca68ec7295c0c5d190611d8f7c7b26d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/transformers/models/beit/image_processing_beit.py @@ -0,0 +1,515 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for Beit.""" + +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np + +from ...image_processing_utils import INIT_SERVICE_KWARGS, BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import resize, to_channel_dimension_format +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import ( + TensorType, + filter_out_non_signature_kwargs, + is_torch_available, + is_torch_tensor, + is_vision_available, + logging, +) +from ...utils.deprecation import deprecate_kwarg + + +if is_vision_available(): + import PIL + +if is_torch_available(): + import torch + + +logger = logging.get_logger(__name__) + + +class BeitImageProcessor(BaseImageProcessor): + r""" + Constructs a BEiT image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the + `do_resize` parameter in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"height": 256, "width": 256}`): + Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess` + method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the + `preprocess` method. + do_center_crop (`bool`, *optional*, defaults to `True`): + Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the image + is padded with 0's and then center cropped. Can be overridden by the `do_center_crop` parameter in the + `preprocess` method. + crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`): + Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`. + Can be overridden by the `crop_size` parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the + `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale` + parameter in the `preprocess` method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. + image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): + The mean to use if normalizing the image. This is a float or list of floats of length of the number of + channels of the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): + The standard deviation to use if normalizing the image. This is a float or list of floats of length of the + number of channels of the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + do_reduce_labels (`bool`, *optional*, defaults to `False`): + Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is + used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The + background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the + `preprocess` method. + """ + + model_input_names = ["pixel_values"] + + @deprecate_kwarg("reduce_labels", new_name="do_reduce_labels", version="4.41.0") + @filter_out_non_signature_kwargs(extra=INIT_SERVICE_KWARGS) + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_center_crop: bool = True, + crop_size: Dict[str, int] = None, + rescale_factor: Union[int, float] = 1 / 255, + do_rescale: bool = True, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_reduce_labels: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"height": 256, "width": 256} + size = get_size_dict(size) + crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224} + crop_size = get_size_dict(crop_size, param_name="crop_size") + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD + self.do_reduce_labels = do_reduce_labels + + @classmethod + def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs): + """ + Overrides the `from_dict` method from the base class to save support of deprecated `reduce_labels` in old configs + """ + image_processor_dict = image_processor_dict.copy() + if "reduce_labels" in image_processor_dict: + image_processor_dict["do_reduce_labels"] = image_processor_dict.pop("reduce_labels") + return super().from_dict(image_processor_dict, **kwargs) + + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image to (size["height"], size["width"]). + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PIL.Image.BICUBIC`): + Resampling filter to use when resiizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + size = get_size_dict(size, default_to_square=True, param_name="size") + if "height" not in size or "width" not in size: + raise ValueError(f"The `size` argument must contain `height` and `width` keys. Got {size.keys()}") + return resize( + image, + size=(size["height"], size["width"]), + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def reduce_label(self, label: ImageInput) -> np.ndarray: + label = to_numpy_array(label) + # Avoid using underflow conversion + label[label == 0] = 255 + label = label - 1 + label[label == 254] = 255 + return label + + def _preprocess( + self, + image: ImageInput, + do_reduce_labels: bool = None, + do_resize: bool = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_center_crop: bool = None, + crop_size: Dict[str, int] = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + if do_reduce_labels: + image = self.reduce_label(image) + + if do_resize: + image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + + if do_center_crop: + image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + + return image + + def _preprocess_image( + self, + image: ImageInput, + do_resize: bool = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_center_crop: bool = None, + crop_size: Dict[str, int] = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """Preprocesses a single image.""" + # All transformations expect numpy arrays. + image = to_numpy_array(image) + if do_rescale and is_scaled_image(image): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) + image = self._preprocess( + image, + do_reduce_labels=False, + do_resize=do_resize, + size=size, + resample=resample, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + input_data_format=input_data_format, + ) + if data_format is not None: + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + return image + + def _preprocess_segmentation_map( + self, + segmentation_map: ImageInput, + do_resize: bool = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_center_crop: bool = None, + crop_size: Dict[str, int] = None, + do_reduce_labels: bool = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """Preprocesses a single segmentation map.""" + # All transformations expect numpy arrays. + segmentation_map = to_numpy_array(segmentation_map) + # Add an axis to the segmentation maps for transformations. + if segmentation_map.ndim == 2: + segmentation_map = segmentation_map[None, ...] + added_dimension = True + input_data_format = ChannelDimension.FIRST + else: + added_dimension = False + if input_data_format is None: + input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1) + segmentation_map = self._preprocess( + image=segmentation_map, + do_reduce_labels=do_reduce_labels, + do_resize=do_resize, + resample=resample, + size=size, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_normalize=False, + do_rescale=False, + input_data_format=ChannelDimension.FIRST, + ) + # Remove extra axis if added + if added_dimension: + segmentation_map = np.squeeze(segmentation_map, axis=0) + segmentation_map = segmentation_map.astype(np.int64) + return segmentation_map + + def __call__(self, images, segmentation_maps=None, **kwargs): + # Overrides the `__call__` method of the `Preprocessor` class such that the images and segmentation maps can both + # be passed in as positional arguments. + return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs) + + @deprecate_kwarg("reduce_labels", new_name="do_reduce_labels", version="4.41.0") + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + segmentation_maps: Optional[ImageInput] = None, + do_resize: bool = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_center_crop: bool = None, + crop_size: Dict[str, int] = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_reduce_labels: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + segmentation_maps (`ImageInput`, *optional*) + Segmentation maps to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only + has an effect if `do_resize` is set to `True`. + do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): + Whether to center crop the image. + crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`): + Size of the image after center crop. If one edge the image is smaller than `crop_size`, it will be + padded with zeros and then cropped + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation. + do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`): + Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 + is used for background, and background itself is not included in all classes of a dataset (e.g. + ADE20k). The background label will be replaced by 255. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size, default_to_square=True, param_name="size") + resample = resample if resample is not None else self.resample + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size") + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels + + images = make_list_of_images(images) + + if segmentation_maps is not None: + segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2) + + if segmentation_maps is not None and not valid_images(segmentation_maps): + raise ValueError( + "Invalid segmentation_maps type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) + + images = [ + self._preprocess_image( + image=img, + do_resize=do_resize, + do_center_crop=do_center_crop, + do_rescale=do_rescale, + do_normalize=do_normalize, + resample=resample, + size=size, + rescale_factor=rescale_factor, + crop_size=crop_size, + image_mean=image_mean, + image_std=image_std, + data_format=data_format, + input_data_format=input_data_format, + ) + for img in images + ] + + data = {"pixel_values": images} + + if segmentation_maps is not None: + segmentation_maps = [ + self._preprocess_segmentation_map( + segmentation_map=segmentation_map, + do_reduce_labels=do_reduce_labels, + do_resize=do_resize, + resample=resample, + size=size, + do_center_crop=do_center_crop, + crop_size=crop_size, + ) + for segmentation_map in segmentation_maps + ] + data["labels"] = segmentation_maps + + return BatchFeature(data=data, tensor_type=return_tensors) + + def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None): + """ + Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch. + + Args: + outputs ([`BeitForSemanticSegmentation`]): + Raw outputs of the model. + target_sizes (`List[Tuple]` of length `batch_size`, *optional*): + List of tuples corresponding to the requested final size (height, width) of each prediction. If unset, + predictions will not be resized. + + Returns: + semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic + segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is + specified). Each entry of each `torch.Tensor` correspond to a semantic class id. + """ + # TODO: add support for other frameworks + logits = outputs.logits + + # Resize logits and compute semantic segmentation maps + if target_sizes is not None: + if len(logits) != len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits" + ) + + if is_torch_tensor(target_sizes): + target_sizes = target_sizes.numpy() + + semantic_segmentation = [] + + for idx in range(len(logits)): + resized_logits = torch.nn.functional.interpolate( + logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False + ) + semantic_map = resized_logits[0].argmax(dim=0) + semantic_segmentation.append(semantic_map) + else: + semantic_segmentation = logits.argmax(dim=1) + semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])] + + return semantic_segmentation + + +__all__ = ["BeitImageProcessor"] diff --git a/.venv/lib/python3.11/site-packages/transformers/models/beit/modeling_flax_beit.py b/.venv/lib/python3.11/site-packages/transformers/models/beit/modeling_flax_beit.py new file mode 100644 index 0000000000000000000000000000000000000000..2d79c1820088a1c0034ebdb467f9e5c6b3370950 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/transformers/models/beit/modeling_flax_beit.py @@ -0,0 +1,956 @@ +# coding=utf-8 +# Copyright 2021 Microsoft Research and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Callable, List, Optional, Tuple + +import flax +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict + +from ...modeling_flax_outputs import ( + FlaxBaseModelOutput, + FlaxBaseModelOutputWithPooling, + FlaxMaskedLMOutput, + FlaxSequenceClassifierOutput, +) +from ...modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + append_replace_return_docstrings, + overwrite_call_docstring, +) +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward +from .configuration_beit import BeitConfig + + +@flax.struct.dataclass +class FlaxBeitModelOutputWithPooling(FlaxBaseModelOutputWithPooling): + """ + Class for outputs of [`FlaxBeitModel`]. + + Args: + last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`): + Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if + *config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token + will be returned. + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus + the initial embedding outputs. + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + """ + + +BEIT_START_DOCSTRING = r""" + + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading, saving and converting weights from PyTorch models) + + This model is also a + [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as + a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and + behavior. + + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`BeitConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. +""" + +BEIT_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`AutoImageProcessor.__call__`] for details. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +def relative_position_index_init(window_size: Tuple[int, int]) -> jnp.ndarray: + """ + get pair-wise relative position index for each token inside the window + """ + num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + + coords_h = np.arange(window_size[0]) + coords_w = np.arange(window_size[1]) + coords = np.stack(np.meshgrid(coords_h, coords_w, indexing="ij")) # 2, Wh, Ww + coords_flatten = np.reshape(coords, (2, -1)) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = np.transpose(relative_coords, (1, 2, 0)) # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + + relative_position_index = np.zeros(shape=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = num_relative_distance - 3 + relative_position_index[0:, 0] = num_relative_distance - 2 + relative_position_index[0, 0] = num_relative_distance - 1 + return jnp.array(relative_position_index) + + +def ones_with_scale(key, shape, scale, dtype=jnp.float32): + return jnp.ones(shape, dtype) * scale + + +class FlaxBeitDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + rate: float + + @nn.module.compact + def __call__(self, inputs, deterministic: Optional[bool] = True): + if self.rate == 0.0: + return inputs + keep_prob = 1.0 - self.rate + if deterministic: + return inputs + else: + shape = (inputs.shape[0],) + (1,) * (inputs.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + rng = self.make_rng("droppath") + random_tensor = keep_prob + jax.random.uniform(rng, shape=shape, dtype=inputs.dtype) + binary_tensor = jnp.floor(random_tensor) + output = inputs / keep_prob * binary_tensor + return output + + +class FlaxBeitPatchEmbeddings(nn.Module): + config: BeitConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.num_channels = self.config.num_channels + image_size = self.config.image_size + patch_size = self.config.patch_size + num_patches = (image_size // patch_size) * (image_size // patch_size) + patch_shape = (image_size // patch_size, image_size // patch_size) + self.num_patches = num_patches + self.patch_shape = patch_shape + self.projection = nn.Conv( + self.config.hidden_size, + kernel_size=(patch_size, patch_size), + strides=(patch_size, patch_size), + padding="VALID", + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + + def __call__(self, pixel_values): + num_channels = pixel_values.shape[-1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + embeddings = self.projection(pixel_values) + batch_size, _, _, channels = embeddings.shape + return jnp.reshape(embeddings, (batch_size, -1, channels)) + + +class FlaxBeitEmbeddings(nn.Module): + """Construct the CLS token, position and patch embeddings.""" + + config: BeitConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.cls_token = self.param("cls_token", nn.initializers.zeros, (1, 1, self.config.hidden_size)) + if self.config.use_mask_token: + self.mask_token = self.param("mask_token", nn.initializers.zeros, (1, 1, self.config.hidden_size)) + self.patch_embeddings = FlaxBeitPatchEmbeddings(self.config, dtype=self.dtype) + num_patches = self.patch_embeddings.num_patches + if self.config.use_absolute_position_embeddings: + self.position_embeddings = self.param( + "position_embeddings", nn.initializers.zeros, (1, num_patches + 1, self.config.hidden_size) + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, pixel_values, bool_masked_pos=None, deterministic=True): + embeddings = self.patch_embeddings(pixel_values) + batch_size, seq_len, _ = embeddings.shape + + cls_tokens = jnp.broadcast_to(self.cls_token, (batch_size, 1, self.config.hidden_size)) + cls_tokens = cls_tokens.astype(embeddings.dtype) + + if bool_masked_pos is not None: + mask_tokens = jnp.broadcast_to(self.mask_token, (batch_size, seq_len, self.config.hidden_size)) + mask_tokens = mask_tokens.astype(embeddings.dtype) + # replace the masked visual tokens by mask_tokens + w = jnp.expand_dims(bool_masked_pos, axis=-1) + embeddings = embeddings * (1 - w) + mask_tokens * w + + embeddings = jnp.concatenate((cls_tokens, embeddings), axis=1) + + if self.config.use_absolute_position_embeddings: + embeddings = embeddings + self.position_embeddings.astype(embeddings.dtype) + + embeddings = self.dropout(embeddings, deterministic=deterministic) + return embeddings + + +class FlaxBeitRelativePositionBias(nn.Module): + config: BeitConfig + window_size: Tuple[int, int] + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + num_relative_distance = (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) + 3 + self.relative_position_bias_table = self.param( + "relative_position_bias_table", + nn.initializers.zeros, + (num_relative_distance, self.config.num_attention_heads), + ) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + self.relative_position_index = relative_position_index_init(self.window_size) + + def __call__(self): + index = self.relative_position_index.reshape(-1) + shape = (self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1) + relative_position_bias = self.relative_position_bias_table[index].reshape(shape) # Wh*Ww,Wh*Ww,nH + return jnp.transpose(relative_position_bias, (2, 0, 1)) + + +class FlaxBeitSelfAttention(nn.Module): + config: BeitConfig + window_size: Tuple[int, int] + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + if self.config.hidden_size % self.config.num_attention_heads != 0 and not hasattr( + self.config, "embedding_size" + ): + raise ValueError( + f"The hidden size {self.config.hidden_size,} is not a multiple of the number of attention " + f"heads {self.config.num_attention_heads}." + ) + + self.query = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + self.key = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + use_bias=False, + ) + self.value = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + + self.relative_position_bias = ( + FlaxBeitRelativePositionBias(self.config, window_size=self.window_size, dtype=self.dtype) + if self.window_size + else None + ) + + def __call__( + self, hidden_states, relative_position_bias=None, deterministic: bool = True, output_attentions: bool = False + ): + head_dim = self.config.hidden_size // self.config.num_attention_heads + + query_states = self.query(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + value_states = self.value(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + key_states = self.key(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + + dropout_rng = None + if not deterministic and self.config.attention_probs_dropout_prob > 0.0: + dropout_rng = self.make_rng("dropout") + + attention_bias = jnp.array(0.0, dtype=self.dtype) + # Add relative position bias if present. + if self.relative_position_bias is not None: + attention_bias = jnp.expand_dims(self.relative_position_bias(), 0) + attention_bias = attention_bias.astype(query_states.dtype) + + # Add shared relative position bias if provided. + if relative_position_bias is not None: + attention_bias = attention_bias + relative_position_bias.astype(attention_bias.dtype) + + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.config.attention_probs_dropout_prob, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) + + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) + return outputs + + +class FlaxBeitSelfOutput(nn.Module): + config: BeitConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, hidden_states, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + return hidden_states + + +class FlaxBeitAttention(nn.Module): + config: BeitConfig + window_size: Tuple[int, int] + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.attention = FlaxBeitSelfAttention(self.config, self.window_size, dtype=self.dtype) + self.output = FlaxBeitSelfOutput(self.config, dtype=self.dtype) + + def __call__( + self, hidden_states, relative_position_bias=None, deterministic=True, output_attentions: bool = False + ): + attn_outputs = self.attention( + hidden_states, relative_position_bias, deterministic=deterministic, output_attentions=output_attentions + ) + attn_output = attn_outputs[0] + attn_output = self.output(attn_output, deterministic=deterministic) + + outputs = (attn_output,) + + if output_attentions: + outputs += (attn_outputs[1],) + + return outputs + + +class FlaxBeitIntermediate(nn.Module): + config: BeitConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.intermediate_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.activation = ACT2FN[self.config.hidden_act] + + def __call__(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.activation(hidden_states) + + return hidden_states + + +class FlaxBeitOutput(nn.Module): + config: BeitConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, hidden_states, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + + return hidden_states + + +class FlaxBeitLayer(nn.Module): + config: BeitConfig + window_size: Tuple[int, int] + drop_path_rate: float + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.attention = FlaxBeitAttention(self.config, self.window_size, dtype=self.dtype) + self.intermediate = FlaxBeitIntermediate(self.config, dtype=self.dtype) + self.output = FlaxBeitOutput(self.config, dtype=self.dtype) + self.layernorm_before = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.drop_path = FlaxBeitDropPath(rate=self.drop_path_rate) + self.layernorm_after = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + + self.init_values = self.config.layer_scale_init_value + if self.init_values > 0: + self.lambda_1 = self.param("lambda_1", ones_with_scale, (self.config.hidden_size), self.init_values) + self.lambda_2 = self.param("lambda_2", ones_with_scale, (self.config.hidden_size), self.init_values) + else: + self.lambda_1 = None + self.lambda_2 = None + + def __call__( + self, hidden_states, relative_position_bias=None, deterministic: bool = True, output_attentions: bool = False + ): + self_attention_outputs = self.attention( + self.layernorm_before(hidden_states), # in BEiT, layernorm is applied before self-attention + relative_position_bias, + deterministic=deterministic, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + + # apply lambda_1 if present + if self.lambda_1 is not None: + attention_output = self.lambda_1.astype(attention_output.dtype) * attention_output + + # first residual connection + hidden_states = self.drop_path(attention_output, deterministic=deterministic) + hidden_states + + # in BEiT, layernorm is also applied after self-attention + layer_output = self.layernorm_after(hidden_states) + + layer_output = self.intermediate(layer_output) + layer_output = self.output(layer_output, deterministic=deterministic) + + # apply lambda_2 if present + if self.lambda_2 is not None: + layer_output = self.lambda_2.astype(layer_output.dtype) * layer_output + + # second residual connection + layer_output = self.drop_path(layer_output, deterministic=deterministic) + hidden_states + + outputs = (layer_output,) + + if output_attentions: + outputs += (self_attention_outputs[1],) + + return outputs + + +class FlaxBeitLayerCollection(nn.Module): + config: BeitConfig + window_size: Tuple[int, int] + drop_path_rates: List[float] + relative_position_bias: Callable[[], jnp.ndarray] + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layers = [ + FlaxBeitLayer( + self.config, + window_size=self.window_size if self.config.use_relative_position_bias else None, + drop_path_rate=self.drop_path_rates[i], + name=str(i), + dtype=self.dtype, + ) + for i in range(self.config.num_hidden_layers) + ] + + def __call__( + self, + hidden_states, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + relative_position_bias = self.relative_position_bias() if self.relative_position_bias is not None else None + layer_outputs = layer( + hidden_states, relative_position_bias, deterministic=deterministic, output_attentions=output_attentions + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = (hidden_states,) + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +class FlaxBeitEncoder(nn.Module): + config: BeitConfig + window_size: Tuple[int, int] + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + if self.config.use_shared_relative_position_bias: + self.relative_position_bias = FlaxBeitRelativePositionBias( + config=self.config, window_size=self.window_size, dtype=self.dtype + ) + + # stochastic depth decay rule + drop_path_rates = list(np.linspace(0, self.config.drop_path_rate, self.config.num_hidden_layers)) + self.layer = FlaxBeitLayerCollection( + self.config, + window_size=self.window_size, + drop_path_rates=drop_path_rates, + relative_position_bias=self.relative_position_bias + if self.config.use_shared_relative_position_bias + else None, + dtype=self.dtype, + ) + + def __call__( + self, + hidden_states, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + return self.layer( + hidden_states, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class FlaxBeitPreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BeitConfig + base_model_prefix = "beit" + main_input_name = "pixel_values" + module_class: nn.Module = None + + def __init__( + self, + config: BeitConfig, + input_shape=None, + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + if input_shape is None: + input_shape = (1, config.image_size, config.image_size, config.num_channels) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + pixel_values = jnp.zeros(input_shape, dtype=self.dtype) + + params_rng, dropout_rng = jax.random.split(rng) + dropout_rng, droppath_rng = jax.random.split(dropout_rng) + rngs = {"params": params_rng, "dropout": dropout_rng, "droppath": droppath_rng} + + random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def __call__( + self, + pixel_values, + bool_masked_pos=None, + params: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1)) + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + dropout_rng, droppath_rng = jax.random.split(dropout_rng) + rngs["dropout"] = dropout_rng + rngs["droppath"] = droppath_rng + + return self.module.apply( + {"params": params or self.params}, + jnp.array(pixel_values, dtype=jnp.float32), + bool_masked_pos, + not train, + output_attentions, + output_hidden_states, + return_dict, + rngs=rngs, + ) + + +class FlaxBeitPooler(nn.Module): + config: BeitConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + if self.config.use_mean_pooling: + self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + + def __call__(self, hidden_states): + if self.config.use_mean_pooling: + # Mean pool the final hidden states of the patch tokens + patch_tokens = hidden_states[:, 1:, :] + pooled_output = self.layernorm(jnp.mean(patch_tokens, axis=1)) + else: + # Pool by simply taking the final hidden state of the [CLS] token + pooled_output = hidden_states[:, 0] + + return pooled_output + + +class FlaxBeitModule(nn.Module): + config: BeitConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + add_pooling_layer: bool = True + + def setup(self): + self.embeddings = FlaxBeitEmbeddings(self.config, dtype=self.dtype) + self.encoder = FlaxBeitEncoder( + self.config, window_size=self.embeddings.patch_embeddings.patch_shape, dtype=self.dtype + ) + if not self.config.use_mean_pooling: + self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.pooler = FlaxBeitPooler(self.config, dtype=self.dtype) if self.add_pooling_layer else None + + def __call__( + self, + pixel_values, + bool_masked_pos=None, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + hidden_states = self.embeddings(pixel_values, bool_masked_pos, deterministic=deterministic) + + outputs = self.encoder( + hidden_states, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + if not self.config.use_mean_pooling: + hidden_states = self.layernorm(hidden_states) + pooled = self.pooler(hidden_states) if self.add_pooling_layer else None + + if not return_dict: + # if pooled is None, don't return it + if pooled is None: + return (hidden_states,) + outputs[1:] + return (hidden_states, pooled) + outputs[1:] + + return FlaxBeitModelOutputWithPooling( + last_hidden_state=hidden_states, + pooler_output=pooled, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + "The bare Beit Model transformer outputting raw hidden-states without any specific head on top.", + BEIT_START_DOCSTRING, +) +class FlaxBeitModel(FlaxBeitPreTrainedModel): + module_class = FlaxBeitModule + + +FLAX_BEIT_MODEL_DOCSTRING = """ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, FlaxBeitModel + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224-pt22k-ft22k") + >>> model = FlaxBeitModel.from_pretrained("microsoft/beit-base-patch16-224-pt22k-ft22k") + + >>> inputs = image_processor(images=image, return_tensors="np") + >>> outputs = model(**inputs) + >>> last_hidden_states = outputs.last_hidden_state + ``` +""" + +overwrite_call_docstring(FlaxBeitModel, FLAX_BEIT_MODEL_DOCSTRING) +append_replace_return_docstrings(FlaxBeitModel, output_type=FlaxBeitModelOutputWithPooling, config_class=BeitConfig) + + +class FlaxBeitForMaskedImageModelingModule(nn.Module): + config: BeitConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.beit = FlaxBeitModule(self.config, add_pooling_layer=False, dtype=self.dtype) + + # Classifier head + self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.lm_head = nn.Dense( + self.config.vocab_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + + def __call__( + self, + pixel_values=None, + bool_masked_pos=None, + deterministic: bool = True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.beit( + pixel_values, + bool_masked_pos, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + sequence_output = self.layernorm(sequence_output) + prediction_scores = self.lm_head(sequence_output[:, 1:]) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return output + + return FlaxMaskedLMOutput( + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + "Beit Model transformer with a 'language' modeling head on top (to predict visual tokens).", + BEIT_START_DOCSTRING, +) +class FlaxBeitForMaskedImageModeling(FlaxBeitPreTrainedModel): + module_class = FlaxBeitForMaskedImageModelingModule + + +FLAX_BEIT_MLM_DOCSTRING = """ + bool_masked_pos (`numpy.ndarray` of shape `(batch_size, num_patches)`): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, BeitForMaskedImageModeling + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224-pt22k") + >>> model = BeitForMaskedImageModeling.from_pretrained("microsoft/beit-base-patch16-224-pt22k") + + >>> inputs = image_processor(images=image, return_tensors="np") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + ``` +""" + +overwrite_call_docstring(FlaxBeitForMaskedImageModeling, FLAX_BEIT_MLM_DOCSTRING) +append_replace_return_docstrings( + FlaxBeitForMaskedImageModeling, output_type=FlaxMaskedLMOutput, config_class=BeitConfig +) + + +class FlaxBeitForImageClassificationModule(nn.Module): + config: BeitConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.beit = FlaxBeitModule(config=self.config, dtype=self.dtype, add_pooling_layer=True) + self.classifier = nn.Dense( + self.config.num_labels, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + + def __call__( + self, + pixel_values=None, + bool_masked_pos=None, + deterministic: bool = True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.beit( + pixel_values, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + logits = self.classifier(pooled_output) + + if not return_dict: + output = (logits,) + outputs[2:] + return output + + return FlaxSequenceClassifierOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Beit Model transformer with an image classification head on top (a linear layer on top of the average of the final + hidden states of the patch tokens) e.g. for ImageNet. + """, + BEIT_START_DOCSTRING, +) +class FlaxBeitForImageClassification(FlaxBeitPreTrainedModel): + module_class = FlaxBeitForImageClassificationModule + + +FLAX_BEIT_CLASSIF_DOCSTRING = """ + Returns: + + Example: + + ```python + >>> from transformers import AutoImageProcessor, FlaxBeitForImageClassification + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224") + >>> model = FlaxBeitForImageClassification.from_pretrained("microsoft/beit-base-patch16-224") + + >>> inputs = image_processor(images=image, return_tensors="np") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + >>> # model predicts one of the 1000 ImageNet classes + >>> predicted_class_idx = logits.argmax(-1).item() + >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) + ``` +""" + +overwrite_call_docstring(FlaxBeitForImageClassification, FLAX_BEIT_CLASSIF_DOCSTRING) +append_replace_return_docstrings( + FlaxBeitForImageClassification, output_type=FlaxSequenceClassifierOutput, config_class=BeitConfig +) + + +__all__ = [ + "FlaxBeitForImageClassification", + "FlaxBeitForMaskedImageModeling", + "FlaxBeitModel", + "FlaxBeitPreTrainedModel", +] diff --git a/.venv/lib/python3.11/site-packages/transformers/models/code_llama/__init__.py b/.venv/lib/python3.11/site-packages/transformers/models/code_llama/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b65c4bddb4b0cd3fa8dfd6a781a3c0f58e30e5a7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/transformers/models/code_llama/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .tokenization_code_llama import * + from .tokenization_code_llama_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/.venv/lib/python3.11/site-packages/transformers/models/code_llama/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/code_llama/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85c8ffc2eb66f097a9e9e0f8735a47c3912b1b2e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/code_llama/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/transformers/models/code_llama/__pycache__/tokenization_code_llama.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/code_llama/__pycache__/tokenization_code_llama.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e333dfa7db0322039efad2048de9611286134cf Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/code_llama/__pycache__/tokenization_code_llama.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/transformers/models/code_llama/__pycache__/tokenization_code_llama_fast.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/code_llama/__pycache__/tokenization_code_llama_fast.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e14adae06a0783d6721aca9d12fb806019dd4e8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/code_llama/__pycache__/tokenization_code_llama_fast.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/transformers/models/code_llama/tokenization_code_llama.py b/.venv/lib/python3.11/site-packages/transformers/models/code_llama/tokenization_code_llama.py new file mode 100644 index 0000000000000000000000000000000000000000..43386ecdaee4dce9e983470242648fd7ec03a90d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/transformers/models/code_llama/tokenization_code_llama.py @@ -0,0 +1,452 @@ +# coding=utf-8 +# Copyright 2023 MetaAI and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tokenization classes for Code LLaMA.""" + +import os +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple + +import sentencepiece as spm + +from ...convert_slow_tokenizer import import_protobuf +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...utils import logging, requires_backends + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} + +SPIECE_UNDERLINE = "▁" + +B_INST, E_INST = "[INST]", "[/INST]" +B_SYS, E_SYS = "<>\n", "\n<>\n\n" + +# fmt: off +DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \ +answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\ + that your responses are socially unbiased and positive in nature. + +If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \ +correct. If you don't know the answer to a question, please don't share false information.""" +# fmt: on + + +class CodeLlamaTokenizer(PreTrainedTokenizer): + """ + Construct a CodeLlama tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as + there is no padding token in the original model. + + The default configuration match that of + [codellama/CodeLlama-7b-Instruct-hf](https://huggingface.co/meta-llama/CodeLlama-7b-Instruct-hf/blob/main/tokenizer_config.json) + which supports prompt infilling. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + bos_token (`str`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + prefix_token (`str`, *optional*, defaults to `"▁
"`):
+            Prefix token used for infilling.
+        middle_token (`str`, *optional*, defaults to `"▁"`):
+            Middle token used for infilling.
+        suffix_token (`str`, *optional*, defaults to `"▁"`):
+            Suffix token used for infilling.
+        eot_token (`str`, *optional*, defaults to `"▁"`):
+            End of text token used for infilling.
+        fill_token (`str`, *optional*, defaults to `""`):
+            The token used to split the input between the prefix and suffix.
+        suffix_first (`bool`, *optional*, defaults to `False`):
+            Whether the input prompt and suffix should be formatted with the suffix first.
+        sp_model_kwargs (`dict`, *optional*):
+            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
+            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
+            to set:
+
+            - `enable_sampling`: Enable subword regularization.
+            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+
+              - `nbest_size = {0,1}`: No sampling is performed.
+              - `nbest_size > 1`: samples from the nbest_size results.
+              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
+                using forward-filtering-and-backward-sampling algorithm.
+
+            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+              BPE-dropout.
+        add_bos_token (`bool`, *optional*, defaults to `True`):
+            Whether to add a beginning of sequence token at the start of sequences.
+        add_eos_token (`bool`, *optional*, defaults to `False`):
+            Whether to add an end of sequence token at the end of sequences.
+        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
+            Whether or not to clean up the tokenization spaces.
+        additional_special_tokens (`List[str]`, *optional*):
+            Additional special tokens used by the tokenizer.
+        use_default_system_prompt (`bool`, *optional*, defaults to `False`):
+            Whether or not the default system prompt for Llama should be used.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        unk_token="",
+        bos_token="",
+        eos_token="",
+        prefix_token="▁
",
+        middle_token="▁",
+        suffix_token="▁",
+        eot_token="▁",
+        fill_token="",
+        suffix_first=False,
+        sp_model_kwargs: Optional[Dict[str, Any]] = None,
+        add_bos_token=True,
+        add_eos_token=False,
+        clean_up_tokenization_spaces=False,
+        additional_special_tokens=None,
+        use_default_system_prompt=False,
+        **kwargs,
+    ):
+        requires_backends(self, "protobuf")
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+        bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
+        eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
+        unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
+
+        self.use_default_system_prompt = use_default_system_prompt
+        # mark tokens special to skip them
+        additional_special_tokens = additional_special_tokens or []
+        for token in [prefix_token, middle_token, suffix_token, eot_token]:
+            additional_special_tokens += [token] if token is not None else []
+
+        self.vocab_file = vocab_file
+        self.add_bos_token = add_bos_token
+        self.add_eos_token = add_eos_token
+        self._prefix_token = prefix_token
+        self._middle_token = middle_token
+        self._suffix_token = suffix_token
+        self._eot_token = eot_token
+        self.fill_token = fill_token
+        self.suffix_first = suffix_first
+        self.sp_model = self.get_spm_processor()
+
+        super().__init__(
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            add_bos_token=add_bos_token,
+            add_eos_token=add_eos_token,
+            prefix_token=prefix_token,
+            middle_token=middle_token,
+            suffix_token=suffix_token,
+            eot_token=eot_token,
+            fill_token=fill_token,
+            sp_model_kwargs=self.sp_model_kwargs,
+            suffix_first=suffix_first,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            additional_special_tokens=additional_special_tokens,
+            use_default_system_prompt=use_default_system_prompt,
+            **kwargs,
+        )
+
+    @property
+    def unk_token_length(self):
+        return len(self.sp_model.encode(str(self.unk_token)))
+
+    def get_spm_processor(self):
+        tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        with open(self.vocab_file, "rb") as f:
+            sp_model = f.read()
+            model_pb2 = import_protobuf()
+            model = model_pb2.ModelProto.FromString(sp_model)
+            normalizer_spec = model_pb2.NormalizerSpec()
+            normalizer_spec.add_dummy_prefix = False
+            model.normalizer_spec.MergeFrom(normalizer_spec)
+            sp_model = model.SerializeToString()
+            tokenizer.LoadFromSerializedProto(sp_model)
+        return tokenizer
+
+    @property
+    def prefix_token(self):
+        return self._prefix_token
+
+    @property
+    def prefix_id(self):
+        if self._prefix_token is None:
+            return None
+        return self.convert_tokens_to_ids(self.prefix_token)
+
+    @property
+    def middle_token(self):
+        return self._middle_token
+
+    @property
+    def middle_id(self):
+        if self._middle_token is None:
+            return None
+        return self.convert_tokens_to_ids(self.middle_token)
+
+    @property
+    def suffix_token(self):
+        return self._suffix_token
+
+    @property
+    def suffix_id(self):
+        if self._suffix_token is None:
+            return None
+        return self.convert_tokens_to_ids(self.suffix_token)
+
+    @property
+    def eot_token(self):
+        return self._eot_token
+
+    @property
+    def eot_id(self):
+        if self._eot_token is None:
+            return None
+        return self.convert_tokens_to_ids(self.eot_token)
+
+    @property
+    def vocab_size(self):
+        """Returns vocab size"""
+        return self.sp_model.get_piece_size()
+
+    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_vocab
+    def get_vocab(self):
+        """Returns vocab as a dict"""
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
+    def tokenize(self, prefix, suffix=None, suffix_first=False, **kwargs) -> List[int]:
+        # add a prefix space to `prefix`
+        if self.fill_token is not None and self.fill_token in prefix and suffix is None:
+            prefix, suffix = prefix.split(self.fill_token)
+
+        if len(prefix) > 0:
+            prefix = SPIECE_UNDERLINE + prefix.replace(SPIECE_UNDERLINE, " ")
+
+        if suffix is None or len(suffix) < 1:
+            tokens = super().tokenize(prefix, **kwargs)
+            if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
+                tokens = tokens[1:]
+            return tokens
+
+        prefix_tokens = self._tokenize(prefix)  # prefix has an extra `SPIECE_UNDERLINE`
+
+        if None in (self.prefix_id, self.middle_id, self.suffix_id):
+            raise ValueError(
+                "The input either includes a `prefix` and a `suffix` used for the infilling task,"
+                f"  or can be split on the {self.fill_token} token, creating a suffix and prefix,"
+                " but the model does not support `infilling`."
+            )
+        suffix_tokens = self._tokenize(suffix)  # make sure CodeLlama sp model does not mess up
+
+        suffix_first = suffix_first if suffix_first is not None else self.suffix_first
+        if suffix_first:
+            # format as " 
 {suf}  {pre}"
+            return [self.prefix_token, self.suffix_token] + suffix_tokens + [self.middle_token] + prefix_tokens
+        else:
+            # format as " 
 {pre} {suf} "
+            return [self.prefix_token] + prefix_tokens + [self.suffix_token] + suffix_tokens + [self.middle_token]
+
+    def _tokenize(self, text, **kwargs):
+        """
+        Returns a tokenized string.
+
+        We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
+        SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give
+        `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the
+        `unk_token`. Here is an example with `unk_token = ""` and `unk_token_length = 4`.
+        `self.tokenizer.sp_model.encode(" Hey", out_type = str)[4:]`.
+        """
+        tokens = self.sp_model.encode(text, out_type=str)
+        if not text.startswith((SPIECE_UNDERLINE, " ")):
+            return tokens
+        # 1. Encode string + prefix ex: " Hey"
+        tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
+        # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
+        return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
+
+    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer._convert_token_to_id
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) in an id using the vocab."""
+        return self.sp_model.piece_to_id(token)
+
+    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer._convert_id_to_token
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) in a token (str) using the vocab."""
+        token = self.sp_model.IdToPiece(index)
+        return token
+
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        # since we manually add the prefix space, we have to remove it when decoding
+        if tokens[0].startswith(SPIECE_UNDERLINE):
+            tokens[0] = tokens[0][1:]
+
+        current_sub_tokens = []
+        out_string = ""
+        for _, token in enumerate(tokens):
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string
+
+    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.save_vocabulary
+    def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        """
+        Save the vocabulary and special tokens file to a directory.
+
+        Args:
+            save_directory (`str`):
+                The directory in which to save the vocabulary.
+
+        Returns:
+            `Tuple(str)`: Paths to the files saved.
+        """
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        out_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
+
+        return (out_vocab_file,)
+
+    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens
+    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
+        eos_token_id = [self.eos_token_id] if self.add_eos_token else []
+
+        output = bos_token_id + token_ids_0 + eos_token_id
+
+        if token_ids_1 is not None:
+            output = output + bos_token_id + token_ids_1 + eos_token_id
+
+        return output
+
+    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_special_tokens_mask
+    def get_special_tokens_mask(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+    ) -> List[int]:
+        """
+        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` method.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+
+        Returns:
+            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+        if already_has_special_tokens:
+            return super().get_special_tokens_mask(
+                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+            )
+
+        bos_token_id = [1] if self.add_bos_token else []
+        eos_token_id = [1] if self.add_eos_token else []
+
+        if token_ids_1 is None:
+            return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
+        return (
+            bos_token_id
+            + ([0] * len(token_ids_0))
+            + eos_token_id
+            + bos_token_id
+            + ([0] * len(token_ids_1))
+            + eos_token_id
+        )
+
+    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.create_token_type_ids_from_sequences
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
+        sequence pair mask has the following format:
+
+        ```
+        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
+        | first sequence    | second sequence |
+        ```
+
+        if token_ids_1 is None, only returns the first portion of the mask (0s).
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of ids.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
+        """
+        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
+        eos_token_id = [self.eos_token_id] if self.add_eos_token else []
+
+        output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
+
+        if token_ids_1 is not None:
+            output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
+
+        return output
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sp_model"] = None
+        state["sp_model_proto"] = self.sp_model.serialized_model_proto()
+        return state
+
+    def __setstate__(self, d):
+        self.__dict__ = d
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
+
+
+__all__ = ["CodeLlamaTokenizer"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/code_llama/tokenization_code_llama_fast.py b/.venv/lib/python3.11/site-packages/transformers/models/code_llama/tokenization_code_llama_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bc831cdd6a15b0af85982173e903b7462e9a2a2
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/code_llama/tokenization_code_llama_fast.py
@@ -0,0 +1,381 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from shutil import copyfile
+from typing import List, Optional, Tuple
+
+from tokenizers import normalizers, processors
+
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import is_sentencepiece_available, logging
+from ...utils.versions import require_version
+
+
+require_version("tokenizers>=0.13.3")
+
+if is_sentencepiece_available():
+    from .tokenization_code_llama import CodeLlamaTokenizer
+else:
+    CodeLlamaTokenizer = None
+
+logger = logging.get_logger(__name__)
+VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model", "tokenizer_file": "tokenizer.json"}
+
+SPIECE_UNDERLINE = "▁"
+
+
+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<>\n", "\n<>\n\n"
+
+# fmt: off
+DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
+answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
+ that your responses are socially unbiased and positive in nature.
+
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
+correct. If you don't know the answer to a question, please don't share false information."""
+# fmt: on
+
+
+class CodeLlamaTokenizerFast(PreTrainedTokenizerFast):
+    """
+    Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding.
+
+    This uses notably ByteFallback and no normalization.
+
+    ```python
+    >>> from transformers import CodeLlamaTokenizerFast
+
+    >>> tokenizer = CodeLlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
+    >>> tokenizer.encode("Hello this is a test")
+    [1, 15043, 445, 338, 263, 1243]
+    ```
+
+    If you want to change the `bos_token` or the `eos_token`, make sure to specify them when initializing the model, or
+    call `tokenizer.update_post_processor()` to make sure that the post-processing is correctly done (otherwise the
+    values of the first token and final token of an encoded sequence will not be correct). For more details, checkout
+    [post-processors] (https://huggingface.co/docs/tokenizers/api/post-processors) documentation.
+
+
+    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+    refer to this superclass for more information regarding those methods. The default configuration match that of
+    [meta-llama/CodeLlama-7b-Instruct-hf](https://huggingface.co/meta-llama/CodeLlama-7b-Instruct-hf/blob/main/tokenizer_config.json)
+    which supports prompt infilling.
+
+    Args:
+        vocab_file (`str`, *optional*):
+            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
+            contains the vocabulary necessary to instantiate a tokenizer.
+        tokenizer_file (`str`, *optional*):
+            [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
+            contains everything needed to load the tokenizer.
+        clean_up_tokenization_spaces (`str`, *optional*, defaults to `False`):
+            Wether to cleanup spaces after decoding, cleanup consists in removing potential artifacts like extra
+            spaces.
+        unk_token (`str`, *optional*, defaults to `""`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        bos_token (`str`, *optional*, defaults to `""`):
+            The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
+        eos_token (`str`, *optional*, defaults to `""`):
+            The end of sequence token.
+        prefix_token (`str`, *optional*, defaults to `"▁
"`):
+            Prefix token used for infilling.
+        middle_token (`str`, *optional*, defaults to `"▁"`):
+            Middle token used for infilling.
+        suffix_token (`str`, *optional*, defaults to `"▁"`):
+            Suffix token used for infilling.
+        eot_token (`str`, *optional*, defaults to `"▁"`):
+            End of text token used for infilling.
+        fill_token (`str`, *optional*, defaults to `""`):
+            The token used to split the input between the prefix and suffix.
+        additional_special_tokens (`List[str]`, *optional*):
+            Additional special tokens used by the tokenizer.
+        add_bos_token (`bool`, *optional*, defaults to `True`):
+            Whether to add a beginning of sequence token at the start of sequences.
+        add_eos_token (`bool`, *optional*, defaults to `False`):
+            Whether to add an end of sequence token at the end of sequences.
+        use_default_system_prompt (`bool`, *optional*, defaults to `False`):
+            Whether or not the default system prompt for Llama should be used.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    slow_tokenizer_class = CodeLlamaTokenizer
+    padding_side = "left"
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file=None,
+        tokenizer_file=None,
+        clean_up_tokenization_spaces=False,
+        unk_token="",
+        bos_token="",
+        eos_token="",
+        prefix_token="▁
",
+        middle_token="▁",
+        suffix_token="▁",
+        eot_token="▁",
+        fill_token="",
+        additional_special_tokens=None,
+        add_bos_token=True,
+        add_eos_token=False,
+        use_default_system_prompt=False,
+        **kwargs,
+    ):
+        # mark tokens special to skip them
+        additional_special_tokens = additional_special_tokens or []
+        for token in [prefix_token, middle_token, suffix_token, eot_token]:
+            additional_special_tokens += [token] if token is not None else []
+        self.use_default_system_prompt = use_default_system_prompt
+
+        super().__init__(
+            vocab_file=vocab_file,
+            tokenizer_file=tokenizer_file,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            additional_special_tokens=additional_special_tokens,
+            unk_token=unk_token,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            add_bos_token=add_bos_token,
+            add_eos_token=add_eos_token,
+            prefix_token=prefix_token,
+            middle_token=middle_token,
+            suffix_token=suffix_token,
+            eot_token=eot_token,
+            fill_token=fill_token,
+            use_default_system_prompt=use_default_system_prompt,
+            **kwargs,
+        )
+        self._add_bos_token = add_bos_token
+        self._add_eos_token = add_eos_token
+        self.update_post_processor()
+
+        self.vocab_file = vocab_file
+
+        self._prefix_token = prefix_token
+        self._middle_token = middle_token
+        self._suffix_token = suffix_token
+        self._eot_token = eot_token
+        self.fill_token = fill_token
+
+    @property
+    def can_save_slow_tokenizer(self) -> bool:
+        return os.path.isfile(self.vocab_file) if self.vocab_file else False
+
+    # Copied from transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast.update_post_processor
+    def update_post_processor(self):
+        """
+        Updates the underlying post processor with the current `bos_token` and `eos_token`.
+        """
+        bos = self.bos_token
+        bos_token_id = self.bos_token_id
+        if bos is None and self.add_bos_token:
+            raise ValueError("add_bos_token = True but bos_token = None")
+
+        eos = self.eos_token
+        eos_token_id = self.eos_token_id
+        if eos is None and self.add_eos_token:
+            raise ValueError("add_eos_token = True but eos_token = None")
+
+        single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
+        pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
+
+        special_tokens = []
+        if self.add_bos_token:
+            special_tokens.append((bos, bos_token_id))
+        if self.add_eos_token:
+            special_tokens.append((eos, eos_token_id))
+        self._tokenizer.post_processor = processors.TemplateProcessing(
+            single=single, pair=pair, special_tokens=special_tokens
+        )
+
+    @property
+    def prefix_token(self):
+        return self._prefix_token
+
+    @property
+    def prefix_id(self):
+        if self._prefix_token is None:
+            return None
+        return self.convert_tokens_to_ids(self.prefix_token)
+
+    @property
+    def middle_token(self):
+        return self._middle_token
+
+    @property
+    def middle_id(self):
+        if self._middle_token is None:
+            return None
+        return self.convert_tokens_to_ids(self.middle_token)
+
+    @property
+    def suffix_token(self):
+        return self._suffix_token
+
+    @property
+    def suffix_id(self):
+        if self._suffix_token is None:
+            return None
+        return self.convert_tokens_to_ids(self.suffix_token)
+
+    @property
+    def eot_id(self):
+        if self._eot_token is None:
+            return None
+        return self.convert_tokens_to_ids(self.eot_token)
+
+    @property
+    def eot_token(self):
+        return self._eot_token
+
+    @property
+    def add_eos_token(self):
+        return self._add_eos_token
+
+    @property
+    def add_bos_token(self):
+        return self._add_bos_token
+
+    @add_eos_token.setter
+    def add_eos_token(self, value):
+        self._add_eos_token = value
+        self.update_post_processor()
+
+    @add_bos_token.setter
+    def add_bos_token(self, value):
+        self._add_bos_token = value
+        self.update_post_processor()
+
+    def set_infilling_processor(self, reset, suffix_first=False, add_special_tokens=True):
+        """
+        Updates the normalizer to make sure the prompt format for `infilling` is respected. The infilling format is the
+        following: if suffix_first
+            " 
 {suf}  {pre}"
+        else:
+            " 
 {pre} {suf} "
+
+        If `reset` is set to `True`, the `normalizer` and `post_processor` are reset to their "normal" behaviour, which
+        is to add a prefix space for the normalizer, and add a `bos_token` to the input text for the `post_processor`.
+        """
+        if reset:
+            self._tokenizer.normalizer = normalizers.Sequence(
+                [
+                    normalizers.Prepend(prepend="▁"),
+                    normalizers.Replace(pattern=" ", content="▁"),
+                ]
+            )
+            self.update_post_processor()
+            return
+
+        self._tokenizer.normalizer = normalizers.Replace(pattern=" ", content="▁")
+        pair = [self.bos_token] if self.add_bos_token and add_special_tokens else []
+        special_tokens = [(self.bos_token, self.bos_token_id)] if self.add_bos_token and add_special_tokens else []
+        if suffix_first:
+            # format as " 
 {suf}  {pre}"
+            pair += [self.prefix_token, self.suffix_token, "$B", self.middle_token, "$A"]
+            special_tokens += [
+                (self.prefix_token, self.prefix_id),
+                (self.suffix_token, self.suffix_id),
+                (self.middle_token, self.middle_id),
+            ]
+        else:
+            # format as " 
 {pre} {suf} "
+            pair += [self.prefix_token, "$A", self.suffix_token, "$B", self.middle_token]
+            special_tokens += [
+                (self.prefix_token, self.prefix_id),
+                (self.suffix_token, self.suffix_id),
+                (self.middle_token, self.middle_id),
+            ]
+
+        if self.add_eos_token and add_special_tokens:
+            pair += [self.eos_token]
+            special_tokens += [(self.eos_token, self.eos_token_id)]
+        self._tokenizer.post_processor = processors.TemplateProcessing(
+            single="$A", pair=pair, special_tokens=special_tokens
+        )
+
+    def encode_plus(self, text, text_pair=None, suffix_first=False, add_special_tokens=True, **kwargs):
+        # hack to make sure the input is pre-process but outside rust
+        text_pair = kwargs.pop("suffix", text_pair)
+        if self.fill_token is not None and self.fill_token in text and text_pair is None:
+            text, text_pair = text.split(self.fill_token)
+
+        if text_pair is None or len(text_pair) < 1:
+            return super().encode_plus(text, text_pair, add_special_tokens=add_special_tokens, **kwargs)
+
+        if None in (self.prefix_id, self.middle_id, self.suffix_id):
+            raise ValueError(
+                "Then input includes a `prefix` and a `suffix` used for the infilling task,"
+                " the `prefix_id, middle_id, suffix_id` must all be initialized. Current"
+                f" values : {self.prefix_id, self.middle_id, self.suffix_id}"
+            )
+
+        self.set_infilling_processor(False, suffix_first=suffix_first, add_special_tokens=add_special_tokens)
+        tokens = super().encode_plus(" " + text, text_pair=text_pair, add_special_tokens=True, **kwargs)
+        self.set_infilling_processor(True)
+        return tokens
+
+    # Copied from transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast.save_vocabulary
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not self.can_save_slow_tokenizer:
+            raise ValueError(
+                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+                "tokenizer."
+            )
+
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        out_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+
+        return (out_vocab_file,)
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+        adding special tokens. The special tokens depend on calling set_lang.
+
+        An NLLB sequence has the following format, where `X` represents the sequence:
+
+        - `input_ids` (for encoder) `X [eos, src_lang_code]`
+        - `decoder_input_ids`: (for decoder) `X [eos, tgt_lang_code]`
+
+        BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
+        separator.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+        """
+        if token_ids_1 is None:
+            return self.bos_token_id + token_ids_0 + self.eos_token_id
+        return self.bos_token_id + token_ids_0 + token_ids_1 + self.eos_token_id
+
+
+__all__ = ["CodeLlamaTokenizerFast"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/convnext/__init__.py b/.venv/lib/python3.11/site-packages/transformers/models/convnext/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..796b9a48926fe6239195588496012842cf355299
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/convnext/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_convnext import *
+    from .feature_extraction_convnext import *
+    from .image_processing_convnext import *
+    from .modeling_convnext import *
+    from .modeling_tf_convnext import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/convnext/__pycache__/configuration_convnext.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/convnext/__pycache__/configuration_convnext.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f0377936552b0fb20458a7acadc29f4244df0504
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/convnext/__pycache__/configuration_convnext.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/convnext/__pycache__/image_processing_convnext.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/convnext/__pycache__/image_processing_convnext.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..26c7a4a5a4bae2df53bafe2d51eb266d4788ab4f
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/convnext/__pycache__/image_processing_convnext.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/convnext/configuration_convnext.py b/.venv/lib/python3.11/site-packages/transformers/models/convnext/configuration_convnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f9ed3bfd469df4a8ea03966de43afb9fb67623c
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/convnext/configuration_convnext.py
@@ -0,0 +1,142 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""ConvNeXT model configuration"""
+
+from collections import OrderedDict
+from typing import Mapping
+
+from packaging import version
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
+
+
+logger = logging.get_logger(__name__)
+
+
+class ConvNextConfig(BackboneConfigMixin, PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`ConvNextModel`]. It is used to instantiate an
+    ConvNeXT model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the ConvNeXT
+    [facebook/convnext-tiny-224](https://huggingface.co/facebook/convnext-tiny-224) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        num_channels (`int`, *optional*, defaults to 3):
+            The number of input channels.
+        patch_size (`int`, *optional*, defaults to 4):
+            Patch size to use in the patch embedding layer.
+        num_stages (`int`, *optional*, defaults to 4):
+            The number of stages in the model.
+        hidden_sizes (`List[int]`, *optional*, defaults to [96, 192, 384, 768]):
+            Dimensionality (hidden size) at each stage.
+        depths (`List[int]`, *optional*, defaults to [3, 3, 9, 3]):
+            Depth (number of blocks) for each stage.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in each block. If string, `"gelu"`, `"relu"`,
+            `"selu"` and `"gelu_new"` are supported.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        layer_scale_init_value (`float`, *optional*, defaults to 1e-6):
+            The initial value for the layer scale.
+        drop_path_rate (`float`, *optional*, defaults to 0.0):
+            The drop rate for stochastic depth.
+        out_features (`List[str]`, *optional*):
+            If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+            (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+            corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
+            same order as defined in the `stage_names` attribute.
+        out_indices (`List[int]`, *optional*):
+            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+            many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+            If unset and `out_features` is unset, will default to the last stage. Must be in the
+            same order as defined in the `stage_names` attribute.
+
+    Example:
+    ```python
+    >>> from transformers import ConvNextConfig, ConvNextModel
+
+    >>> # Initializing a ConvNext convnext-tiny-224 style configuration
+    >>> configuration = ConvNextConfig()
+
+    >>> # Initializing a model (with random weights) from the convnext-tiny-224 style configuration
+    >>> model = ConvNextModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "convnext"
+
+    def __init__(
+        self,
+        num_channels=3,
+        patch_size=4,
+        num_stages=4,
+        hidden_sizes=None,
+        depths=None,
+        hidden_act="gelu",
+        initializer_range=0.02,
+        layer_norm_eps=1e-12,
+        layer_scale_init_value=1e-6,
+        drop_path_rate=0.0,
+        image_size=224,
+        out_features=None,
+        out_indices=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.num_channels = num_channels
+        self.patch_size = patch_size
+        self.num_stages = num_stages
+        self.hidden_sizes = [96, 192, 384, 768] if hidden_sizes is None else hidden_sizes
+        self.depths = [3, 3, 9, 3] if depths is None else depths
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.layer_norm_eps = layer_norm_eps
+        self.layer_scale_init_value = layer_scale_init_value
+        self.drop_path_rate = drop_path_rate
+        self.image_size = image_size
+        self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(self.depths) + 1)]
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+        )
+
+
+class ConvNextOnnxConfig(OnnxConfig):
+    torch_onnx_minimum_version = version.parse("1.11")
+
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict(
+            [
+                ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
+            ]
+        )
+
+    @property
+    def atol_for_validation(self) -> float:
+        return 1e-5
+
+
+__all__ = ["ConvNextConfig", "ConvNextOnnxConfig"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/convnext/feature_extraction_convnext.py b/.venv/lib/python3.11/site-packages/transformers/models/convnext/feature_extraction_convnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b2208e5b1134fd998e767a15e1639a097b5fa41
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/convnext/feature_extraction_convnext.py
@@ -0,0 +1,36 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for ConvNeXT."""
+
+import warnings
+
+from ...utils import logging
+from .image_processing_convnext import ConvNextImageProcessor
+
+
+logger = logging.get_logger(__name__)
+
+
+class ConvNextFeatureExtractor(ConvNextImageProcessor):
+    def __init__(self, *args, **kwargs) -> None:
+        warnings.warn(
+            "The class ConvNextFeatureExtractor is deprecated and will be removed in version 5 of Transformers."
+            " Please use ConvNextImageProcessor instead.",
+            FutureWarning,
+        )
+        super().__init__(*args, **kwargs)
+
+
+__all__ = ["ConvNextFeatureExtractor"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/convnext/image_processing_convnext.py b/.venv/lib/python3.11/site-packages/transformers/models/convnext/image_processing_convnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..35cbc91797add39d04a436e67c6f8c1f8c4ad1ce
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/convnext/image_processing_convnext.py
@@ -0,0 +1,323 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for ConvNeXT."""
+
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import (
+    center_crop,
+    get_resize_output_image_size,
+    resize,
+    to_channel_dimension_format,
+)
+from ...image_utils import (
+    IMAGENET_STANDARD_MEAN,
+    IMAGENET_STANDARD_STD,
+    ChannelDimension,
+    ImageInput,
+    PILImageResampling,
+    infer_channel_dimension_format,
+    is_scaled_image,
+    make_list_of_images,
+    to_numpy_array,
+    valid_images,
+    validate_preprocess_arguments,
+)
+from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
+
+
+if is_vision_available():
+    import PIL
+
+
+logger = logging.get_logger(__name__)
+
+
+class ConvNextImageProcessor(BaseImageProcessor):
+    r"""
+    Constructs a ConvNeXT image processor.
+
+    Args:
+        do_resize (`bool`, *optional*, defaults to `True`):
+            Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be overriden
+            by `do_resize` in the `preprocess` method.
+        size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 384}`):
+            Resolution of the output image after `resize` is applied. If `size["shortest_edge"]` >= 384, the image is
+            resized to `(size["shortest_edge"], size["shortest_edge"])`. Otherwise, the smaller edge of the image will
+            be matched to `int(size["shortest_edge"]/crop_pct)`, after which the image is cropped to
+            `(size["shortest_edge"], size["shortest_edge"])`. Only has an effect if `do_resize` is set to `True`. Can
+            be overriden by `size` in the `preprocess` method.
+        crop_pct (`float` *optional*, defaults to 224 / 256):
+            Percentage of the image to crop. Only has an effect if `do_resize` is `True` and size < 384. Can be
+            overriden by `crop_pct` in the `preprocess` method.
+        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
+            Resampling filter to use if resizing the image. Can be overriden by `resample` in the `preprocess` method.
+        do_rescale (`bool`, *optional*, defaults to `True`):
+            Whether to rescale the image by the specified scale `rescale_factor`. Can be overriden by `do_rescale` in
+            the `preprocess` method.
+        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+            Scale factor to use if rescaling the image. Can be overriden by `rescale_factor` in the `preprocess`
+            method.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+            method.
+        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+    """
+
+    model_input_names = ["pixel_values"]
+
+    def __init__(
+        self,
+        do_resize: bool = True,
+        size: Dict[str, int] = None,
+        crop_pct: float = None,
+        resample: PILImageResampling = PILImageResampling.BILINEAR,
+        do_rescale: bool = True,
+        rescale_factor: Union[int, float] = 1 / 255,
+        do_normalize: bool = True,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        size = size if size is not None else {"shortest_edge": 384}
+        size = get_size_dict(size, default_to_square=False)
+
+        self.do_resize = do_resize
+        self.size = size
+        # Default value set here for backwards compatibility where the value in config is None
+        self.crop_pct = crop_pct if crop_pct is not None else 224 / 256
+        self.resample = resample
+        self.do_rescale = do_rescale
+        self.rescale_factor = rescale_factor
+        self.do_normalize = do_normalize
+        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+
+    def resize(
+        self,
+        image: np.ndarray,
+        size: Dict[str, int],
+        crop_pct: float,
+        resample: PILImageResampling = PILImageResampling.BICUBIC,
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        **kwargs,
+    ) -> np.ndarray:
+        """
+        Resize an image.
+
+        Args:
+            image (`np.ndarray`):
+                Image to resize.
+            size (`Dict[str, int]`):
+                Dictionary of the form `{"shortest_edge": int}`, specifying the size of the output image. If
+                `size["shortest_edge"]` >= 384 image is resized to `(size["shortest_edge"], size["shortest_edge"])`.
+                Otherwise, the smaller edge of the image will be matched to `int(size["shortest_edge"] / crop_pct)`,
+                after which the image is cropped to `(size["shortest_edge"], size["shortest_edge"])`.
+            crop_pct (`float`):
+                Percentage of the image to crop. Only has an effect if size < 384.
+            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+                Resampling filter to use when resizing the image.
+            data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format of the image. If not provided, it will be the same as the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format of the input image. If not provided, it will be inferred from the input
+                image.
+        """
+        size = get_size_dict(size, default_to_square=False)
+        if "shortest_edge" not in size:
+            raise ValueError(f"Size dictionary must contain 'shortest_edge' key. Got {size.keys()}")
+        shortest_edge = size["shortest_edge"]
+
+        if shortest_edge < 384:
+            # maintain same ratio, resizing shortest edge to shortest_edge/crop_pct
+            resize_shortest_edge = int(shortest_edge / crop_pct)
+            resize_size = get_resize_output_image_size(
+                image, size=resize_shortest_edge, default_to_square=False, input_data_format=input_data_format
+            )
+            image = resize(
+                image=image,
+                size=resize_size,
+                resample=resample,
+                data_format=data_format,
+                input_data_format=input_data_format,
+                **kwargs,
+            )
+            # then crop to (shortest_edge, shortest_edge)
+            return center_crop(
+                image=image,
+                size=(shortest_edge, shortest_edge),
+                data_format=data_format,
+                input_data_format=input_data_format,
+                **kwargs,
+            )
+        else:
+            # warping (no cropping) when evaluated at 384 or larger
+            return resize(
+                image,
+                size=(shortest_edge, shortest_edge),
+                resample=resample,
+                data_format=data_format,
+                input_data_format=input_data_format,
+                **kwargs,
+            )
+
+    @filter_out_non_signature_kwargs()
+    def preprocess(
+        self,
+        images: ImageInput,
+        do_resize: bool = None,
+        size: Dict[str, int] = None,
+        crop_pct: float = None,
+        resample: PILImageResampling = None,
+        do_rescale: bool = None,
+        rescale_factor: float = None,
+        do_normalize: bool = None,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        data_format: ChannelDimension = ChannelDimension.FIRST,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ) -> PIL.Image.Image:
+        """
+        Preprocess an image or batch of images.
+
+        Args:
+            images (`ImageInput`):
+                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+                Whether to resize the image.
+            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+                Size of the output image after `resize` has been applied. If `size["shortest_edge"]` >= 384, the image
+                is resized to `(size["shortest_edge"], size["shortest_edge"])`. Otherwise, the smaller edge of the
+                image will be matched to `int(size["shortest_edge"]/ crop_pct)`, after which the image is cropped to
+                `(size["shortest_edge"], size["shortest_edge"])`. Only has an effect if `do_resize` is set to `True`.
+            crop_pct (`float`, *optional*, defaults to `self.crop_pct`):
+                Percentage of the image to crop if size < 384.
+            resample (`int`, *optional*, defaults to `self.resample`):
+                Resampling filter to use if resizing the image. This can be one of `PILImageResampling`, filters. Only
+                has an effect if `do_resize` is set to `True`.
+            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+                Whether to rescale the image values between [0 - 1].
+            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+                Whether to normalize the image.
+            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+                Image mean.
+            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+                Image standard deviation.
+            return_tensors (`str` or `TensorType`, *optional*):
+                The type of tensors to return. Can be one of:
+                    - Unset: Return a list of `np.ndarray`.
+                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+                The channel dimension format for the output image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - Unset: Use the channel dimension format of the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+        """
+        do_resize = do_resize if do_resize is not None else self.do_resize
+        crop_pct = crop_pct if crop_pct is not None else self.crop_pct
+        resample = resample if resample is not None else self.resample
+        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+        image_mean = image_mean if image_mean is not None else self.image_mean
+        image_std = image_std if image_std is not None else self.image_std
+
+        size = size if size is not None else self.size
+        size = get_size_dict(size, default_to_square=False)
+
+        images = make_list_of_images(images)
+
+        if not valid_images(images):
+            raise ValueError(
+                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+                "torch.Tensor, tf.Tensor or jax.ndarray."
+            )
+
+        validate_preprocess_arguments(
+            do_rescale=do_rescale,
+            rescale_factor=rescale_factor,
+            do_normalize=do_normalize,
+            image_mean=image_mean,
+            image_std=image_std,
+            do_resize=do_resize,
+            size=size,
+            resample=resample,
+        )
+
+        # All transformations expect numpy arrays.
+        images = [to_numpy_array(image) for image in images]
+
+        if do_rescale and is_scaled_image(images[0]):
+            logger.warning_once(
+                "It looks like you are trying to rescale already rescaled images. If the input"
+                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+            )
+
+        if input_data_format is None:
+            # We assume that all images have the same channel dimension format.
+            input_data_format = infer_channel_dimension_format(images[0])
+
+        if do_resize:
+            images = [
+                self.resize(
+                    image=image, size=size, crop_pct=crop_pct, resample=resample, input_data_format=input_data_format
+                )
+                for image in images
+            ]
+
+        if do_rescale:
+            images = [
+                self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+                for image in images
+            ]
+
+        if do_normalize:
+            images = [
+                self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+                for image in images
+            ]
+
+        images = [
+            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
+        ]
+
+        data = {"pixel_values": images}
+        return BatchFeature(data=data, tensor_type=return_tensors)
+
+
+__all__ = ["ConvNextImageProcessor"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/convnext/modeling_convnext.py b/.venv/lib/python3.11/site-packages/transformers/models/convnext/modeling_convnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..155f466ac4ae68034752d9cf11d9e0d5fdfdee69
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/convnext/modeling_convnext.py
@@ -0,0 +1,551 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch ConvNext model."""
+
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...modeling_outputs import (
+    BackboneOutput,
+    BaseModelOutputWithNoAttention,
+    BaseModelOutputWithPoolingAndNoAttention,
+    ImageClassifierOutputWithNoAttention,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
+from ...utils.backbone_utils import BackboneMixin
+from .configuration_convnext import ConvNextConfig
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "ConvNextConfig"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "facebook/convnext-tiny-224"
+_EXPECTED_OUTPUT_SHAPE = [1, 768, 7, 7]
+
+# Image classification docstring
+_IMAGE_CLASS_CHECKPOINT = "facebook/convnext-tiny-224"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
+
+
+# Copied from transformers.models.beit.modeling_beit.drop_path
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+    """
+    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+    argument.
+    """
+    if drop_prob == 0.0 or not training:
+        return input
+    keep_prob = 1 - drop_prob
+    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
+    random_tensor.floor_()  # binarize
+    output = input.div(keep_prob) * random_tensor
+    return output
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->ConvNext
+class ConvNextDropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+    def __init__(self, drop_prob: Optional[float] = None) -> None:
+        super().__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        return drop_path(hidden_states, self.drop_prob, self.training)
+
+    def extra_repr(self) -> str:
+        return "p={}".format(self.drop_prob)
+
+
+class ConvNextLayerNorm(nn.Module):
+    r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
+    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
+    width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
+    """
+
+    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(normalized_shape))
+        self.bias = nn.Parameter(torch.zeros(normalized_shape))
+        self.eps = eps
+        self.data_format = data_format
+        if self.data_format not in ["channels_last", "channels_first"]:
+            raise NotImplementedError(f"Unsupported data format: {self.data_format}")
+        self.normalized_shape = (normalized_shape,)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.data_format == "channels_last":
+            x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        elif self.data_format == "channels_first":
+            input_dtype = x.dtype
+            x = x.float()
+            u = x.mean(1, keepdim=True)
+            s = (x - u).pow(2).mean(1, keepdim=True)
+            x = (x - u) / torch.sqrt(s + self.eps)
+            x = x.to(dtype=input_dtype)
+            x = self.weight[:, None, None] * x + self.bias[:, None, None]
+        return x
+
+
+class ConvNextEmbeddings(nn.Module):
+    """This class is comparable to (and inspired by) the SwinEmbeddings class
+    found in src/transformers/models/swin/modeling_swin.py.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.patch_embeddings = nn.Conv2d(
+            config.num_channels, config.hidden_sizes[0], kernel_size=config.patch_size, stride=config.patch_size
+        )
+        self.layernorm = ConvNextLayerNorm(config.hidden_sizes[0], eps=1e-6, data_format="channels_first")
+        self.num_channels = config.num_channels
+
+    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+        num_channels = pixel_values.shape[1]
+        if num_channels != self.num_channels:
+            raise ValueError(
+                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+            )
+        embeddings = self.patch_embeddings(pixel_values)
+        embeddings = self.layernorm(embeddings)
+        return embeddings
+
+
+class ConvNextLayer(nn.Module):
+    """This corresponds to the `Block` class in the original implementation.
+
+    There are two equivalent implementations: [DwConv, LayerNorm (channels_first), Conv, GELU,1x1 Conv]; all in (N, C,
+    H, W) (2) [DwConv, Permute to (N, H, W, C), LayerNorm (channels_last), Linear, GELU, Linear]; Permute back
+
+    The authors used (2) as they find it slightly faster in PyTorch.
+
+    Args:
+        config ([`ConvNextConfig`]): Model configuration class.
+        dim (`int`): Number of input channels.
+        drop_path (`float`): Stochastic depth rate. Default: 0.0.
+    """
+
+    def __init__(self, config, dim, drop_path=0):
+        super().__init__()
+        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
+        self.layernorm = ConvNextLayerNorm(dim, eps=1e-6)
+        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
+        self.act = ACT2FN[config.hidden_act]
+        self.pwconv2 = nn.Linear(4 * dim, dim)
+        self.layer_scale_parameter = (
+            nn.Parameter(config.layer_scale_init_value * torch.ones((dim)), requires_grad=True)
+            if config.layer_scale_init_value > 0
+            else None
+        )
+        self.drop_path = ConvNextDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
+        input = hidden_states
+        x = self.dwconv(hidden_states)
+        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
+        x = self.layernorm(x)
+        x = self.pwconv1(x)
+        x = self.act(x)
+        x = self.pwconv2(x)
+        if self.layer_scale_parameter is not None:
+            x = self.layer_scale_parameter * x
+        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
+
+        x = input + self.drop_path(x)
+        return x
+
+
+class ConvNextStage(nn.Module):
+    """ConvNeXT stage, consisting of an optional downsampling layer + multiple residual blocks.
+
+    Args:
+        config ([`ConvNextConfig`]): Model configuration class.
+        in_channels (`int`): Number of input channels.
+        out_channels (`int`): Number of output channels.
+        depth (`int`): Number of residual blocks.
+        drop_path_rates(`List[float]`): Stochastic depth rates for each layer.
+    """
+
+    def __init__(self, config, in_channels, out_channels, kernel_size=2, stride=2, depth=2, drop_path_rates=None):
+        super().__init__()
+
+        if in_channels != out_channels or stride > 1:
+            self.downsampling_layer = nn.Sequential(
+                ConvNextLayerNorm(in_channels, eps=1e-6, data_format="channels_first"),
+                nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride),
+            )
+        else:
+            self.downsampling_layer = nn.Identity()
+        drop_path_rates = drop_path_rates or [0.0] * depth
+        self.layers = nn.Sequential(
+            *[ConvNextLayer(config, dim=out_channels, drop_path=drop_path_rates[j]) for j in range(depth)]
+        )
+
+    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
+        hidden_states = self.downsampling_layer(hidden_states)
+        hidden_states = self.layers(hidden_states)
+        return hidden_states
+
+
+class ConvNextEncoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.stages = nn.ModuleList()
+        drop_path_rates = [
+            x.tolist() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths)).split(config.depths)
+        ]
+        prev_chs = config.hidden_sizes[0]
+        for i in range(config.num_stages):
+            out_chs = config.hidden_sizes[i]
+            stage = ConvNextStage(
+                config,
+                in_channels=prev_chs,
+                out_channels=out_chs,
+                stride=2 if i > 0 else 1,
+                depth=config.depths[i],
+                drop_path_rates=drop_path_rates[i],
+            )
+            self.stages.append(stage)
+            prev_chs = out_chs
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        output_hidden_states: Optional[bool] = False,
+        return_dict: Optional[bool] = True,
+    ) -> Union[Tuple, BaseModelOutputWithNoAttention]:
+        all_hidden_states = () if output_hidden_states else None
+
+        for i, layer_module in enumerate(self.stages):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            hidden_states = layer_module(hidden_states)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
+
+        return BaseModelOutputWithNoAttention(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+        )
+
+
+class ConvNextPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = ConvNextConfig
+    base_model_prefix = "convnext"
+    main_input_name = "pixel_values"
+    _no_split_modules = ["ConvNextLayer"]
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, (nn.Linear, nn.Conv2d)):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+
+CONVNEXT_START_DOCSTRING = r"""
+    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+    behavior.
+
+    Parameters:
+        config ([`ConvNextConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+CONVNEXT_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
+            [`ConvNextImageProcessor.__call__`] for details.
+
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    "The bare ConvNext model outputting raw features without any specific head on top.",
+    CONVNEXT_START_DOCSTRING,
+)
+class ConvNextModel(ConvNextPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.config = config
+
+        self.embeddings = ConvNextEmbeddings(config)
+        self.encoder = ConvNextEncoder(config)
+
+        # final layernorm layer
+        self.layernorm = nn.LayerNorm(config.hidden_sizes[-1], eps=config.layer_norm_eps)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=BaseModelOutputWithPoolingAndNoAttention,
+        config_class=_CONFIG_FOR_DOC,
+        modality="vision",
+        expected_output=_EXPECTED_OUTPUT_SHAPE,
+    )
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPoolingAndNoAttention]:
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        embedding_output = self.embeddings(pixel_values)
+
+        encoder_outputs = self.encoder(
+            embedding_output,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        last_hidden_state = encoder_outputs[0]
+
+        # global average pooling, (N, C, H, W) -> (N, C)
+        pooled_output = self.layernorm(last_hidden_state.mean([-2, -1]))
+
+        if not return_dict:
+            return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+        return BaseModelOutputWithPoolingAndNoAttention(
+            last_hidden_state=last_hidden_state,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+        )
+
+
+@add_start_docstrings(
+    """
+    ConvNext Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
+    ImageNet.
+    """,
+    CONVNEXT_START_DOCSTRING,
+)
+class ConvNextForImageClassification(ConvNextPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.num_labels = config.num_labels
+        self.convnext = ConvNextModel(config)
+
+        # Classifier head
+        self.classifier = (
+            nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity()
+        )
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_IMAGE_CLASS_CHECKPOINT,
+        output_type=ImageClassifierOutputWithNoAttention,
+        config_class=_CONFIG_FOR_DOC,
+        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+    )
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.convnext(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
+
+        pooled_output = outputs.pooler_output if return_dict else outputs[1]
+
+        logits = self.classifier(pooled_output)
+
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return ImageClassifierOutputWithNoAttention(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+        )
+
+
+@add_start_docstrings(
+    """
+    ConvNeXt backbone, to be used with frameworks like DETR and MaskFormer.
+    """,
+    CONVNEXT_START_DOCSTRING,
+)
+class ConvNextBackbone(ConvNextPreTrainedModel, BackboneMixin):
+    def __init__(self, config):
+        super().__init__(config)
+        super()._init_backbone(config)
+
+        self.embeddings = ConvNextEmbeddings(config)
+        self.encoder = ConvNextEncoder(config)
+        self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes
+
+        # Add layer norms to hidden states of out_features
+        hidden_states_norms = {}
+        for stage, num_channels in zip(self._out_features, self.channels):
+            hidden_states_norms[stage] = ConvNextLayerNorm(num_channels, data_format="channels_first")
+        self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
+
+        # initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> BackboneOutput:
+        """
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoImageProcessor, AutoBackbone
+        >>> import torch
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224")
+        >>> model = AutoBackbone.from_pretrained("facebook/convnext-tiny-224")
+
+        >>> inputs = processor(image, return_tensors="pt")
+        >>> outputs = model(**inputs)
+        ```"""
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+
+        embedding_output = self.embeddings(pixel_values)
+
+        outputs = self.encoder(
+            embedding_output,
+            output_hidden_states=True,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs.hidden_states if return_dict else outputs[1]
+
+        feature_maps = ()
+        for stage, hidden_state in zip(self.stage_names, hidden_states):
+            if stage in self.out_features:
+                hidden_state = self.hidden_states_norms[stage](hidden_state)
+                feature_maps += (hidden_state,)
+
+        if not return_dict:
+            output = (feature_maps,)
+            if output_hidden_states:
+                output += (hidden_states,)
+            return output
+
+        return BackboneOutput(
+            feature_maps=feature_maps,
+            hidden_states=hidden_states if output_hidden_states else None,
+            attentions=None,
+        )
+
+
+__all__ = ["ConvNextForImageClassification", "ConvNextModel", "ConvNextPreTrainedModel", "ConvNextBackbone"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/convnext/modeling_tf_convnext.py b/.venv/lib/python3.11/site-packages/transformers/models/convnext/modeling_tf_convnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..af1bae81db495f17d5fe1696efa020a8b4d28af2
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/convnext/modeling_tf_convnext.py
@@ -0,0 +1,669 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TF 2.0 ConvNext model."""
+
+from __future__ import annotations
+
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling, TFSequenceClassifierOutput
+from ...modeling_tf_utils import (
+    TFModelInputType,
+    TFPreTrainedModel,
+    TFSequenceClassificationLoss,
+    get_initializer,
+    keras,
+    keras_serializable,
+    unpack_inputs,
+)
+from ...tf_utils import shape_list
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from .configuration_convnext import ConvNextConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+_CONFIG_FOR_DOC = "ConvNextConfig"
+_CHECKPOINT_FOR_DOC = "facebook/convnext-tiny-224"
+
+
+class TFConvNextDropPath(keras.layers.Layer):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+    References:
+        (1) github.com:rwightman/pytorch-image-models
+    """
+
+    def __init__(self, drop_path: float, **kwargs):
+        super().__init__(**kwargs)
+        self.drop_path = drop_path
+
+    def call(self, x: tf.Tensor, training=None):
+        if training:
+            keep_prob = 1 - self.drop_path
+            shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
+            random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
+            random_tensor = tf.floor(random_tensor)
+            return (x / keep_prob) * random_tensor
+        return x
+
+
+class TFConvNextEmbeddings(keras.layers.Layer):
+    """This class is comparable to (and inspired by) the SwinEmbeddings class
+    found in src/transformers/models/swin/modeling_swin.py.
+    """
+
+    def __init__(self, config: ConvNextConfig, **kwargs):
+        super().__init__(**kwargs)
+        self.patch_embeddings = keras.layers.Conv2D(
+            filters=config.hidden_sizes[0],
+            kernel_size=config.patch_size,
+            strides=config.patch_size,
+            name="patch_embeddings",
+            kernel_initializer=get_initializer(config.initializer_range),
+            bias_initializer=keras.initializers.Zeros(),
+        )
+        self.layernorm = keras.layers.LayerNormalization(epsilon=1e-6, name="layernorm")
+        self.num_channels = config.num_channels
+        self.config = config
+
+    def call(self, pixel_values):
+        if isinstance(pixel_values, dict):
+            pixel_values = pixel_values["pixel_values"]
+
+        tf.debugging.assert_equal(
+            shape_list(pixel_values)[1],
+            self.num_channels,
+            message="Make sure that the channel dimension of the pixel values match with the one set in the configuration.",
+        )
+
+        # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format.
+        # So change the input format from `NCHW` to `NHWC`.
+        # shape = (batch_size, in_height, in_width, in_channels)
+        pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
+
+        embeddings = self.patch_embeddings(pixel_values)
+        embeddings = self.layernorm(embeddings)
+        return embeddings
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "patch_embeddings", None) is not None:
+            with tf.name_scope(self.patch_embeddings.name):
+                self.patch_embeddings.build([None, None, None, self.config.num_channels])
+        if getattr(self, "layernorm", None) is not None:
+            with tf.name_scope(self.layernorm.name):
+                self.layernorm.build([None, None, None, self.config.hidden_sizes[0]])
+
+
+class TFConvNextLayer(keras.layers.Layer):
+    """This corresponds to the `Block` class in the original implementation.
+
+    There are two equivalent implementations: [DwConv, LayerNorm (channels_first), Conv, GELU,1x1 Conv]; all in (N, C,
+    H, W) (2) [DwConv, Permute to (N, H, W, C), LayerNorm (channels_last), Linear, GELU, Linear]; Permute back
+
+    The authors used (2) as they find it slightly faster in PyTorch. Since we already permuted the inputs to follow
+    NHWC ordering, we can just apply the operations straight-away without the permutation.
+
+    Args:
+        config ([`ConvNextConfig`]): Model configuration class.
+        dim (`int`): Number of input channels.
+        drop_path (`float`): Stochastic depth rate. Default: 0.0.
+    """
+
+    def __init__(self, config, dim, drop_path=0.0, **kwargs):
+        super().__init__(**kwargs)
+        self.dim = dim
+        self.config = config
+        self.dwconv = keras.layers.Conv2D(
+            filters=dim,
+            kernel_size=7,
+            padding="same",
+            groups=dim,
+            kernel_initializer=get_initializer(config.initializer_range),
+            bias_initializer="zeros",
+            name="dwconv",
+        )  # depthwise conv
+        self.layernorm = keras.layers.LayerNormalization(
+            epsilon=1e-6,
+            name="layernorm",
+        )
+        self.pwconv1 = keras.layers.Dense(
+            units=4 * dim,
+            kernel_initializer=get_initializer(config.initializer_range),
+            bias_initializer="zeros",
+            name="pwconv1",
+        )  # pointwise/1x1 convs, implemented with linear layers
+        self.act = get_tf_activation(config.hidden_act)
+        self.pwconv2 = keras.layers.Dense(
+            units=dim,
+            kernel_initializer=get_initializer(config.initializer_range),
+            bias_initializer="zeros",
+            name="pwconv2",
+        )
+        # Using `layers.Activation` instead of `tf.identity` to better control `training`
+        # behaviour.
+        self.drop_path = (
+            TFConvNextDropPath(drop_path, name="drop_path")
+            if drop_path > 0.0
+            else keras.layers.Activation("linear", name="drop_path")
+        )
+
+    def build(self, input_shape: tf.TensorShape = None):
+        # PT's `nn.Parameters` must be mapped to a TF layer weight to inherit the same name hierarchy (and vice-versa)
+        self.layer_scale_parameter = (
+            self.add_weight(
+                shape=(self.dim,),
+                initializer=keras.initializers.Constant(value=self.config.layer_scale_init_value),
+                trainable=True,
+                name="layer_scale_parameter",
+            )
+            if self.config.layer_scale_init_value > 0
+            else None
+        )
+
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "dwconv", None) is not None:
+            with tf.name_scope(self.dwconv.name):
+                self.dwconv.build([None, None, None, self.dim])
+        if getattr(self, "layernorm", None) is not None:
+            with tf.name_scope(self.layernorm.name):
+                self.layernorm.build([None, None, None, self.dim])
+        if getattr(self, "pwconv1", None) is not None:
+            with tf.name_scope(self.pwconv1.name):
+                self.pwconv1.build([None, None, self.dim])
+        if getattr(self, "pwconv2", None) is not None:
+            with tf.name_scope(self.pwconv2.name):
+                self.pwconv2.build([None, None, 4 * self.dim])
+        if getattr(self, "drop_path", None) is not None:
+            with tf.name_scope(self.drop_path.name):
+                self.drop_path.build(None)
+
+    def call(self, hidden_states, training=False):
+        input = hidden_states
+        x = self.dwconv(hidden_states)
+        x = self.layernorm(x)
+        x = self.pwconv1(x)
+        x = self.act(x)
+        x = self.pwconv2(x)
+
+        if self.layer_scale_parameter is not None:
+            x = self.layer_scale_parameter * x
+
+        x = input + self.drop_path(x, training=training)
+        return x
+
+
+class TFConvNextStage(keras.layers.Layer):
+    """ConvNext stage, consisting of an optional downsampling layer + multiple residual blocks.
+
+    Args:
+        config (`ConvNextV2Config`):
+            Model configuration class.
+        in_channels (`int`):
+            Number of input channels.
+        out_channels (`int`):
+            Number of output channels.
+        depth (`int`):
+            Number of residual blocks.
+        drop_path_rates(`List[float]`):
+            Stochastic depth rates for each layer.
+    """
+
+    def __init__(
+        self,
+        config: ConvNextConfig,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int = 2,
+        stride: int = 2,
+        depth: int = 2,
+        drop_path_rates: Optional[List[float]] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        if in_channels != out_channels or stride > 1:
+            self.downsampling_layer = [
+                keras.layers.LayerNormalization(
+                    epsilon=1e-6,
+                    name="downsampling_layer.0",
+                ),
+                # Inputs to this layer will follow NHWC format since we
+                # transposed the inputs from NCHW to NHWC in the `TFConvNextEmbeddings`
+                # layer. All the outputs throughout the model will be in NHWC
+                # from this point on until the output where we again change to
+                # NCHW.
+                keras.layers.Conv2D(
+                    filters=out_channels,
+                    kernel_size=kernel_size,
+                    strides=stride,
+                    kernel_initializer=get_initializer(config.initializer_range),
+                    bias_initializer=keras.initializers.Zeros(),
+                    name="downsampling_layer.1",
+                ),
+            ]
+        else:
+            self.downsampling_layer = [tf.identity]
+
+        drop_path_rates = drop_path_rates or [0.0] * depth
+        self.layers = [
+            TFConvNextLayer(
+                config,
+                dim=out_channels,
+                drop_path=drop_path_rates[j],
+                name=f"layers.{j}",
+            )
+            for j in range(depth)
+        ]
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.stride = stride
+
+    def call(self, hidden_states):
+        for layer in self.downsampling_layer:
+            hidden_states = layer(hidden_states)
+        for layer in self.layers:
+            hidden_states = layer(hidden_states)
+        return hidden_states
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "layers", None) is not None:
+            for layer in self.layers:
+                with tf.name_scope(layer.name):
+                    layer.build(None)
+        if self.in_channels != self.out_channels or self.stride > 1:
+            with tf.name_scope(self.downsampling_layer[0].name):
+                self.downsampling_layer[0].build([None, None, None, self.in_channels])
+            with tf.name_scope(self.downsampling_layer[1].name):
+                self.downsampling_layer[1].build([None, None, None, self.in_channels])
+
+
+class TFConvNextEncoder(keras.layers.Layer):
+    def __init__(self, config, **kwargs):
+        super().__init__(**kwargs)
+        self.stages = []
+        drop_path_rates = tf.linspace(0.0, config.drop_path_rate, sum(config.depths))
+        drop_path_rates = tf.split(drop_path_rates, config.depths)
+        drop_path_rates = [x.numpy().tolist() for x in drop_path_rates]
+        prev_chs = config.hidden_sizes[0]
+        for i in range(config.num_stages):
+            out_chs = config.hidden_sizes[i]
+            stage = TFConvNextStage(
+                config,
+                in_channels=prev_chs,
+                out_channels=out_chs,
+                stride=2 if i > 0 else 1,
+                depth=config.depths[i],
+                drop_path_rates=drop_path_rates[i],
+                name=f"stages.{i}",
+            )
+            self.stages.append(stage)
+            prev_chs = out_chs
+
+    def call(self, hidden_states, output_hidden_states=False, return_dict=True):
+        all_hidden_states = () if output_hidden_states else None
+
+        for i, layer_module in enumerate(self.stages):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            hidden_states = layer_module(hidden_states)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
+
+        return TFBaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
+
+    def build(self, input_shape=None):
+        for stage in self.stages:
+            with tf.name_scope(stage.name):
+                stage.build(None)
+
+
+@keras_serializable
+class TFConvNextMainLayer(keras.layers.Layer):
+    config_class = ConvNextConfig
+
+    def __init__(self, config: ConvNextConfig, add_pooling_layer: bool = True, **kwargs):
+        super().__init__(**kwargs)
+
+        self.config = config
+        self.embeddings = TFConvNextEmbeddings(config, name="embeddings")
+        self.encoder = TFConvNextEncoder(config, name="encoder")
+        self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
+        # We are setting the `data_format` like so because from here on we will revert to the
+        # NCHW output format
+        self.pooler = keras.layers.GlobalAvgPool2D(data_format="channels_first") if add_pooling_layer else None
+
+    @unpack_inputs
+    def call(
+        self,
+        pixel_values: TFModelInputType | None = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: bool = False,
+    ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        embedding_output = self.embeddings(pixel_values, training=training)
+
+        encoder_outputs = self.encoder(
+            embedding_output,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+        last_hidden_state = encoder_outputs[0]
+        # Change to NCHW output format have uniformity in the modules
+        last_hidden_state = tf.transpose(last_hidden_state, perm=(0, 3, 1, 2))
+        pooled_output = self.layernorm(self.pooler(last_hidden_state))
+
+        # Change the other hidden state outputs to NCHW as well
+        if output_hidden_states:
+            hidden_states = tuple([tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]])
+
+        if not return_dict:
+            hidden_states = hidden_states if output_hidden_states else ()
+            return (last_hidden_state, pooled_output) + hidden_states
+
+        return TFBaseModelOutputWithPooling(
+            last_hidden_state=last_hidden_state,
+            pooler_output=pooled_output,
+            hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "embeddings", None) is not None:
+            with tf.name_scope(self.embeddings.name):
+                self.embeddings.build(None)
+        if getattr(self, "encoder", None) is not None:
+            with tf.name_scope(self.encoder.name):
+                self.encoder.build(None)
+        if getattr(self, "layernorm", None) is not None:
+            with tf.name_scope(self.layernorm.name):
+                self.layernorm.build([None, self.config.hidden_sizes[-1]])
+
+
+class TFConvNextPreTrainedModel(TFPreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = ConvNextConfig
+    base_model_prefix = "convnext"
+    main_input_name = "pixel_values"
+
+
+CONVNEXT_START_DOCSTRING = r"""
+    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
+    behavior.
+
+    
+
+    TensorFlow models and layers in `transformers` accept two formats as input:
+
+    - having all inputs as keyword arguments (like PyTorch models), or
+    - having all inputs as a list, tuple or dict in the first positional argument.
+
+    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
+    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
+    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+    positional argument:
+
+    - a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)`
+    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+    `model([pixel_values, attention_mask])` or `model([pixel_values, attention_mask, token_type_ids])`
+    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+    `model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})`
+
+    Note that when creating models and layers with
+    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
+    about any of this, as you can just pass inputs like you would to any other Python function!
+
+    
+
+    Parameters:
+        config ([`ConvNextConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+CONVNEXT_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]` ``Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
+            [`ConvNextImageProcessor.__call__`] for details.
+
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
+            used instead.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+            eager mode, in graph mode the value will always be set to True.
+"""
+
+
+@add_start_docstrings(
+    "The bare ConvNext model outputting raw features without any specific head on top.",
+    CONVNEXT_START_DOCSTRING,
+)
+class TFConvNextModel(TFConvNextPreTrainedModel):
+    def __init__(self, config, *inputs, add_pooling_layer=True, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+        self.convnext = TFConvNextMainLayer(config, add_pooling_layer=add_pooling_layer, name="convnext")
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
+    def call(
+        self,
+        pixel_values: TFModelInputType | None = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: bool = False,
+    ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
+        r"""
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoImageProcessor, TFConvNextModel
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224")
+        >>> model = TFConvNextModel.from_pretrained("facebook/convnext-tiny-224")
+
+        >>> inputs = image_processor(images=image, return_tensors="tf")
+        >>> outputs = model(**inputs)
+        >>> last_hidden_states = outputs.last_hidden_state
+        ```"""
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        outputs = self.convnext(
+            pixel_values=pixel_values,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+        if not return_dict:
+            return (outputs[0],) + outputs[1:]
+
+        return TFBaseModelOutputWithPooling(
+            last_hidden_state=outputs.last_hidden_state,
+            pooler_output=outputs.pooler_output,
+            hidden_states=outputs.hidden_states,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "convnext", None) is not None:
+            with tf.name_scope(self.convnext.name):
+                self.convnext.build(None)
+
+
+@add_start_docstrings(
+    """
+    ConvNext Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
+    ImageNet.
+    """,
+    CONVNEXT_START_DOCSTRING,
+)
+class TFConvNextForImageClassification(TFConvNextPreTrainedModel, TFSequenceClassificationLoss):
+    def __init__(self, config: ConvNextConfig, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+
+        self.num_labels = config.num_labels
+        self.convnext = TFConvNextMainLayer(config, name="convnext")
+
+        # Classifier head
+        self.classifier = keras.layers.Dense(
+            units=config.num_labels,
+            kernel_initializer=get_initializer(config.initializer_range),
+            bias_initializer="zeros",
+            name="classifier",
+        )
+        self.config = config
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
+    def call(
+        self,
+        pixel_values: TFModelInputType | None = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: np.ndarray | tf.Tensor | None = None,
+        training: Optional[bool] = False,
+    ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
+        r"""
+        labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
+            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoImageProcessor, TFConvNextForImageClassification
+        >>> import tensorflow as tf
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224")
+        >>> model = TFConvNextForImageClassification.from_pretrained("facebook/convnext-tiny-224")
+
+        >>> inputs = image_processor(images=image, return_tensors="tf")
+        >>> outputs = model(**inputs)
+        >>> logits = outputs.logits
+        >>> # model predicts one of the 1000 ImageNet classes
+        >>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
+        >>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)])
+        ```"""
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        outputs = self.convnext(
+            pixel_values,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+        pooled_output = outputs.pooler_output if return_dict else outputs[1]
+
+        logits = self.classifier(pooled_output)
+        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TFSequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "convnext", None) is not None:
+            with tf.name_scope(self.convnext.name):
+                self.convnext.build(None)
+        if getattr(self, "classifier", None) is not None:
+            if hasattr(self.classifier, "name"):
+                with tf.name_scope(self.classifier.name):
+                    self.classifier.build([None, None, self.config.hidden_sizes[-1]])
+
+
+__all__ = ["TFConvNextForImageClassification", "TFConvNextModel", "TFConvNextPreTrainedModel"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__init__.py b/.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..455f7ffec5dee0567039cfd844588cb991c15622
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_decision_transformer import *
+    from .modeling_decision_transformer import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b302319ce5786d77101087c8da66a7deb200be96
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__pycache__/configuration_decision_transformer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__pycache__/configuration_decision_transformer.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..935ba6dbf3ad967f81adcef9a6c7fb68f56bbfaa
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__pycache__/configuration_decision_transformer.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__pycache__/modeling_decision_transformer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__pycache__/modeling_decision_transformer.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3b0b36beec6dacf375ccbe131395eb50d668bcac
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__pycache__/modeling_decision_transformer.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/modeling_decision_transformer.py b/.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/modeling_decision_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..08c0f918c43578dd43405e9d68eb807cab0d8144
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/modeling_decision_transformer.py
@@ -0,0 +1,963 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Team The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch DecisionTransformer model."""
+
+import math
+import os
+from dataclasses import dataclass
+from typing import Callable, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
+from ...utils import (
+    ModelOutput,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
+from .configuration_decision_transformer import DecisionTransformerConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "edbeeching/decision-transformer-gym-hopper-medium"
+_CONFIG_FOR_DOC = "DecisionTransformerConfig"
+
+
+# Copied from transformers.models.gpt2.modeling_gpt2.load_tf_weights_in_gpt2
+def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
+    """Load tf checkpoints in a pytorch model"""
+    try:
+        import re
+
+        import tensorflow as tf
+    except ImportError:
+        logger.error(
+            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
+            "https://www.tensorflow.org/install/ for installation instructions."
+        )
+        raise
+    tf_path = os.path.abspath(gpt2_checkpoint_path)
+    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
+    # Load weights from TF model
+    init_vars = tf.train.list_variables(tf_path)
+    names = []
+    arrays = []
+    for name, shape in init_vars:
+        logger.info(f"Loading TF weight {name} with shape {shape}")
+        array = tf.train.load_variable(tf_path, name)
+        names.append(name)
+        arrays.append(array.squeeze())
+
+    for name, array in zip(names, arrays):
+        name = name[6:]  # skip "model/"
+        name = name.split("/")
+        pointer = model
+        for m_name in name:
+            if re.fullmatch(r"[A-Za-z]+\d+", m_name):
+                scope_names = re.split(r"(\d+)", m_name)
+            else:
+                scope_names = [m_name]
+            if scope_names[0] == "w" or scope_names[0] == "g":
+                pointer = getattr(pointer, "weight")
+            elif scope_names[0] == "b":
+                pointer = getattr(pointer, "bias")
+            elif scope_names[0] == "wpe" or scope_names[0] == "wte":
+                pointer = getattr(pointer, scope_names[0])
+                pointer = getattr(pointer, "weight")
+            else:
+                pointer = getattr(pointer, scope_names[0])
+            if len(scope_names) >= 2:
+                num = int(scope_names[1])
+                pointer = pointer[num]
+        try:
+            if pointer.shape != array.shape:
+                raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
+        except ValueError as e:
+            e.args += (pointer.shape, array.shape)
+            raise
+        logger.info(f"Initialize PyTorch weight {name}")
+        pointer.data = torch.from_numpy(array)
+    return model
+
+
+# Copied from transformers.models.gpt2.modeling_gpt2.eager_attention_forward
+def eager_attention_forward(module, query, key, value, attention_mask, head_mask=None, **kwargs):
+    attn_weights = torch.matmul(query, key.transpose(-1, -2))
+
+    if module.scale_attn_weights:
+        attn_weights = attn_weights / torch.full(
+            [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
+        )
+
+    # Layer-wise attention scaling
+    if module.scale_attn_by_inverse_layer_idx:
+        attn_weights = attn_weights / float(module.layer_idx + 1)
+
+    if not module.is_cross_attention:
+        # if only "normal" attention layer implements causal mask
+        query_length, key_length = query.size(-2), key.size(-2)
+        causal_mask = module.bias[:, :, key_length - query_length : key_length, :key_length]
+        mask_value = torch.finfo(attn_weights.dtype).min
+        # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+        # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+        mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
+        attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
+
+    if attention_mask is not None:
+        # Apply the attention mask
+        attn_weights = attn_weights + attention_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+    # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
+    attn_weights = attn_weights.type(value.dtype)
+    attn_weights = module.attn_dropout(attn_weights)
+
+    # Mask heads if we want to
+    if head_mask is not None:
+        attn_weights = attn_weights * head_mask
+
+    attn_output = torch.matmul(attn_weights, value)
+    attn_output = attn_output.transpose(1, 2)
+
+    return attn_output, attn_weights
+
+
+# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Attention with GPT2->DecisionTransformerGPT2
+class DecisionTransformerGPT2Attention(nn.Module):
+    def __init__(self, config, is_cross_attention=False, layer_idx=None):
+        super().__init__()
+        self.config = config
+        max_positions = config.max_position_embeddings
+        self.register_buffer(
+            "bias",
+            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
+                1, 1, max_positions, max_positions
+            ),
+            persistent=False,
+        )
+        self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
+
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.embed_dim // self.num_heads
+        self.split_size = self.embed_dim
+        if self.head_dim * self.num_heads != self.embed_dim:
+            raise ValueError(
+                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+                f" {self.num_heads})."
+            )
+
+        self.scale_attn_weights = config.scale_attn_weights
+        self.is_cross_attention = is_cross_attention
+
+        # Layer-wise attention scaling, reordering, and upcasting
+        self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
+        self.layer_idx = layer_idx
+        self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
+
+        if self.is_cross_attention:
+            self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
+            self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
+        else:
+            self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
+        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
+
+        self.attn_dropout = nn.Dropout(config.attn_pdrop)
+        self.resid_dropout = nn.Dropout(config.resid_pdrop)
+        self.is_causal = True
+
+        self.pruned_heads = set()
+
+    def prune_heads(self, heads):
+        if len(heads) == 0:
+            return
+        heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
+        index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
+
+        # Prune conv1d layers
+        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
+        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
+
+        # Update hyper params
+        self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
+        self.num_heads = self.num_heads - len(heads)
+        self.pruned_heads = self.pruned_heads.union(heads)
+
+    def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
+        # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
+        bsz, num_heads, q_seq_len, dk = query.size()
+        _, _, k_seq_len, _ = key.size()
+
+        # Preallocate attn_weights for `baddbmm`
+        attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
+
+        # Compute Scale Factor
+        scale_factor = 1.0
+        if self.scale_attn_weights:
+            scale_factor /= float(value.size(-1)) ** 0.5
+
+        if self.scale_attn_by_inverse_layer_idx:
+            scale_factor /= float(self.layer_idx + 1)
+
+        # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
+        with torch.amp.autocast(query.device.type, enabled=False):
+            q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
+            attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
+            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
+
+        if not self.is_cross_attention:
+            # if only "normal" attention layer implements causal mask
+            query_length, key_length = query.size(-2), key.size(-2)
+            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+            mask_value = torch.finfo(attn_weights.dtype).min
+            # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+            # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+            attn_weights = torch.where(causal_mask, attn_weights, mask_value)
+
+        if attention_mask is not None:
+            # Apply the attention mask
+            attn_weights = attn_weights + attention_mask
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
+        if attn_weights.dtype != torch.float32:
+            raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
+        attn_weights = attn_weights.type(value.dtype)
+        attn_weights = self.attn_dropout(attn_weights)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attn_weights = attn_weights * head_mask
+
+        attn_output = torch.matmul(attn_weights, value)
+        attn_output = attn_output.transpose(1, 2)
+
+        return attn_output, attn_weights
+
+    def forward(
+        self,
+        hidden_states: Optional[Tuple[torch.FloatTensor]],
+        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+        **kwargs,
+    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
+        if encoder_hidden_states is not None:
+            if not hasattr(self, "q_attn"):
+                raise ValueError(
+                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
+                    "Please make sure to instantiate class with `DecisionTransformerGPT2Attention(..., is_cross_attention=True)`."
+                )
+
+            query_states = self.q_attn(hidden_states)
+            key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
+            attention_mask = encoder_attention_mask
+        else:
+            query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
+
+        shape_q = (*query_states.shape[:-1], -1, self.head_dim)
+        shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
+
+        query_states = query_states.view(shape_q).transpose(1, 2)
+        key_states = key_states.view(shape_kv).transpose(1, 2)
+        value_states = value_states.view(shape_kv).transpose(1, 2)
+
+        if layer_past is not None:
+            past_key, past_value = layer_past
+            key_states = torch.cat((past_key, key_states), dim=-2)
+            value_states = torch.cat((past_value, value_states), dim=-2)
+
+        if use_cache is True:
+            present = (key_states, value_states)
+        else:
+            present = None
+
+        is_cross_attention = encoder_hidden_states is not None
+        is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
+
+        using_eager = self.config._attn_implementation == "eager"
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None):
+                using_eager = True
+                logger.warning_once(
+                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+                )
+            else:
+                # Attention functions are consistent with previous equivalent attention classes, however they do not support some options
+                # (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but
+                # not necessarily to eager (if mentionned options are provided).
+                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        if using_eager and self.reorder_and_upcast_attn:
+            attn_output, attn_weights = self._upcast_and_reordered_attn(
+                query_states, key_states, value_states, attention_mask, head_mask
+            )
+        else:
+            attn_output, attn_weights = attention_interface(
+                self,
+                query_states,
+                key_states,
+                value_states,
+                attention_mask,
+                head_mask=head_mask,
+                dropout=self.attn_dropout.p if self.training else 0.0,
+                is_causal=is_causal,
+                **kwargs,
+            )
+
+        attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
+        attn_output = self.c_proj(attn_output)
+        attn_output = self.resid_dropout(attn_output)
+
+        outputs = (attn_output, present)
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs  # a, present, (attentions)
+
+
+# Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP with GPT2->DecisionTransformerGPT2
+class DecisionTransformerGPT2MLP(nn.Module):
+    def __init__(self, intermediate_size, config):
+        super().__init__()
+        embed_dim = config.hidden_size
+        self.c_fc = Conv1D(intermediate_size, embed_dim)
+        self.c_proj = Conv1D(embed_dim, intermediate_size)
+        self.act = ACT2FN[config.activation_function]
+        self.dropout = nn.Dropout(config.resid_pdrop)
+
+    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+        hidden_states = self.c_fc(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.c_proj(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        return hidden_states
+
+
+# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Block with GPT2->DecisionTransformerGPT2
+class DecisionTransformerGPT2Block(nn.Module):
+    # Ignore copy
+    def __init__(self, config, layer_idx=None):
+        super().__init__()
+        hidden_size = config.hidden_size
+        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
+
+        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.attn = DecisionTransformerGPT2Attention(config, layer_idx=layer_idx)
+        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+
+        if config.add_cross_attention:
+            self.crossattention = DecisionTransformerGPT2Attention(
+                config, is_cross_attention=True, layer_idx=layer_idx
+            )
+            self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+
+        self.mlp = DecisionTransformerGPT2MLP(inner_dim, config)
+
+    def forward(
+        self,
+        hidden_states: Optional[Tuple[torch.FloatTensor]],
+        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
+        residual = hidden_states
+        hidden_states = self.ln_1(hidden_states)
+        attn_outputs = self.attn(
+            hidden_states,
+            layer_past=layer_past,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+        )
+        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
+        outputs = attn_outputs[1:]
+        # residual connection
+        hidden_states = attn_output + residual
+
+        if encoder_hidden_states is not None:
+            # add one self-attention block for cross-attention
+            if not hasattr(self, "crossattention"):
+                raise ValueError(
+                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
+                    "cross-attention layers by setting `config.add_cross_attention=True`"
+                )
+            residual = hidden_states
+            hidden_states = self.ln_cross_attn(hidden_states)
+            cross_attn_outputs = self.crossattention(
+                hidden_states,
+                attention_mask=attention_mask,
+                head_mask=head_mask,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                output_attentions=output_attentions,
+            )
+            attn_output = cross_attn_outputs[0]
+            # residual connection
+            hidden_states = residual + attn_output
+            outputs = outputs + cross_attn_outputs[2:]  # add cross attentions if we output attention weights
+
+        residual = hidden_states
+        hidden_states = self.ln_2(hidden_states)
+        feed_forward_hidden_states = self.mlp(hidden_states)
+        # residual connection
+        hidden_states = residual + feed_forward_hidden_states
+
+        if use_cache:
+            outputs = (hidden_states,) + outputs
+        else:
+            outputs = (hidden_states,) + outputs[1:]
+
+        return outputs  # hidden_states, present, (attentions, cross_attentions)
+
+
+class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = DecisionTransformerConfig
+    load_tf_weights = load_tf_weights_in_gpt2
+    base_model_prefix = "transformer"
+    is_parallelizable = True
+    supports_gradient_checkpointing = True
+
+    def __init__(self, *inputs, **kwargs):
+        super().__init__(*inputs, **kwargs)
+
+    def _init_weights(self, module):
+        """Initialize the weights."""
+        if isinstance(module, (nn.Linear, Conv1D)):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
+        #
+        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+        for name, p in module.named_parameters():
+            if "c_proj" in name and "weight" in name:
+                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+                p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
+
+
+class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.embed_dim = config.hidden_size
+
+        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
+        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
+
+        self.drop = nn.Dropout(config.embd_pdrop)
+        self.h = nn.ModuleList(
+            [DecisionTransformerGPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
+        )
+        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+        # Model parallel
+        self.model_parallel = False
+        self.device_map = None
+        self.gradient_checkpointing = False
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.wte
+
+    def set_input_embeddings(self, new_embeddings):
+        self.wte = new_embeddings
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+            batch_size = input_ids.shape[0]
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+            batch_size = inputs_embeds.shape[0]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+        if token_type_ids is not None:
+            token_type_ids = token_type_ids.view(-1, input_shape[-1])
+
+        if past_key_values is None:
+            past_length = 0
+            past_key_values = tuple([None] * len(self.h))
+        else:
+            past_length = past_key_values[0][0].size(-2)
+        if position_ids is None:
+            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0)
+
+        # Attention mask.
+        if attention_mask is not None:
+            if batch_size <= 0:
+                raise ValueError("batch_size has to be defined and > 0")
+            attention_mask = attention_mask.view(batch_size, -1)
+            # We create a 3D attention mask from a 2D tensor mask.
+            # Sizes are [batch_size, 1, 1, to_seq_length]
+            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+            # this attention mask is more simple than the triangular masking of causal attention
+            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+            attention_mask = attention_mask[:, None, None, :]
+
+            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+            # masked positions, this operation will create a tensor which is 0.0 for
+            # positions we want to attend and the dtype's smallest value for masked positions.
+            # Since we are adding it to the raw scores before the softmax, this is
+            # effectively the same as removing these entirely.
+            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
+            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+
+        # If a 2D or 3D attention mask is provided for the cross-attention
+        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        if self.config.add_cross_attention and encoder_hidden_states is not None:
+            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+            if encoder_attention_mask is None:
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+            encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+        else:
+            encoder_attention_mask = None
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # head_mask has shape n_layer x batch x n_heads x N x N
+        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.wte(input_ids)
+        position_embeds = self.wpe(position_ids)
+        hidden_states = inputs_embeds + position_embeds
+
+        if token_type_ids is not None:
+            token_type_embeds = self.wte(token_type_ids)
+            hidden_states = hidden_states + token_type_embeds
+
+        hidden_states = self.drop(hidden_states)
+
+        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        presents = () if use_cache else None
+        all_self_attentions = () if output_attentions else None
+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+        all_hidden_states = () if output_hidden_states else None
+        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+            # Model parallel
+            if self.model_parallel:
+                torch.cuda.set_device(hidden_states.device)
+                # Ensure layer_past is on same device as hidden_states (might not be correct)
+                if layer_past is not None:
+                    layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
+                # Ensure that attention_mask is always on the same device as hidden_states
+                if attention_mask is not None:
+                    attention_mask = attention_mask.to(hidden_states.device)
+                if isinstance(head_mask, torch.Tensor):
+                    head_mask = head_mask.to(hidden_states.device)
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+                outputs = self._gradient_checkpointing_func(
+                    block.__call__,
+                    hidden_states,
+                    None,
+                    attention_mask,
+                    head_mask[i],
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    use_cache,
+                    output_attentions,
+                )
+            else:
+                outputs = block(
+                    hidden_states,
+                    layer_past=layer_past,
+                    attention_mask=attention_mask,
+                    head_mask=head_mask[i],
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                )
+
+            hidden_states = outputs[0]
+            if use_cache is True:
+                presents = presents + (outputs[1],)
+
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+                if self.config.add_cross_attention:
+                    all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
+
+            # Model Parallel: If it's the last layer for that device, put things on the next device
+            if self.model_parallel:
+                for k, v in self.device_map.items():
+                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
+                        hidden_states = hidden_states.to("cuda:" + str(k + 1))
+
+        hidden_states = self.ln_f(hidden_states)
+
+        hidden_states = hidden_states.view(output_shape)
+        # Add last hidden state
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
+                if v is not None
+            )
+
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=presents,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            cross_attentions=all_cross_attentions,
+        )
+
+
+@dataclass
+class DecisionTransformerOutput(ModelOutput):
+    """
+    Base class for model's outputs that also contains a pooling of the last hidden states.
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        state_preds (`torch.FloatTensor` of shape `(batch_size, sequence_length, state_dim)`):
+            Environment state predictions
+        action_preds (`torch.FloatTensor` of shape `(batch_size, sequence_length, action_dim)`):
+            Model action predictions
+        return_preds (`torch.FloatTensor` of shape `(batch_size, sequence_length, 1)`):
+            Predicted returns for each state
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+            shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    state_preds: torch.FloatTensor = None
+    action_preds: torch.FloatTensor = None
+    return_preds: torch.FloatTensor = None
+    hidden_states: torch.FloatTensor = None
+    attentions: torch.FloatTensor = None
+    last_hidden_state: torch.FloatTensor = None
+
+
+class DecisionTransformerPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = DecisionTransformerConfig
+    base_model_prefix = "decision_transformer"
+    main_input_name = "states"
+    supports_gradient_checkpointing = False
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, nn.Linear):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+
+DECISION_TRANSFORMER_START_DOCSTRING = r"""
+    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+    behavior.
+
+    Parameters:
+        config ([`~DecisionTransformerConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+DECISION_TRANSFORMER_INPUTS_DOCSTRING = r"""
+    Args:
+        states (`torch.FloatTensor` of shape `(batch_size, episode_length, state_dim)`):
+            The states for each step in the trajectory
+        actions (`torch.FloatTensor` of shape `(batch_size, episode_length, act_dim)`):
+            The actions taken by the "expert" policy for the current state, these are masked for auto regressive
+            prediction
+        rewards (`torch.FloatTensor` of shape `(batch_size, episode_length, 1)`):
+            The rewards for each state, action
+        returns_to_go (`torch.FloatTensor` of shape `(batch_size, episode_length, 1)`):
+            The returns for each state in the trajectory
+        timesteps (`torch.LongTensor` of shape `(batch_size, episode_length)`):
+            The timestep for each step in the trajectory
+        attention_mask (`torch.FloatTensor` of shape `(batch_size, episode_length)`):
+            Masking, used to mask the actions when performing autoregressive prediction
+"""
+
+
+@add_start_docstrings("The Decision Transformer Model", DECISION_TRANSFORMER_START_DOCSTRING)
+class DecisionTransformerModel(DecisionTransformerPreTrainedModel):
+    """
+
+    The model builds upon the GPT2 architecture to perform autoregressive prediction of actions in an offline RL
+    setting. Refer to the paper for more details: https://arxiv.org/abs/2106.01345
+
+    """
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.config = config
+        self.hidden_size = config.hidden_size
+        # note: the only difference between this GPT2Model and the default Huggingface version
+        # is that the positional embeddings are removed (since we'll add those ourselves)
+        self.encoder = DecisionTransformerGPT2Model(config)
+
+        self.embed_timestep = nn.Embedding(config.max_ep_len, config.hidden_size)
+        self.embed_return = torch.nn.Linear(1, config.hidden_size)
+        self.embed_state = torch.nn.Linear(config.state_dim, config.hidden_size)
+        self.embed_action = torch.nn.Linear(config.act_dim, config.hidden_size)
+
+        self.embed_ln = nn.LayerNorm(config.hidden_size)
+
+        # note: we don't predict states or returns for the paper
+        self.predict_state = torch.nn.Linear(config.hidden_size, config.state_dim)
+        self.predict_action = nn.Sequential(
+            *([nn.Linear(config.hidden_size, config.act_dim)] + ([nn.Tanh()] if config.action_tanh else []))
+        )
+        self.predict_return = torch.nn.Linear(config.hidden_size, 1)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(DECISION_TRANSFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @replace_return_docstrings(output_type=DecisionTransformerOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        states: Optional[torch.FloatTensor] = None,
+        actions: Optional[torch.FloatTensor] = None,
+        rewards: Optional[torch.FloatTensor] = None,
+        returns_to_go: Optional[torch.FloatTensor] = None,
+        timesteps: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], DecisionTransformerOutput]:
+        r"""
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from transformers import DecisionTransformerModel
+        >>> import torch
+
+        >>> model = DecisionTransformerModel.from_pretrained("edbeeching/decision-transformer-gym-hopper-medium")
+        >>> # evaluation
+        >>> model = model.to(device)
+        >>> model.eval()
+
+        >>> env = gym.make("Hopper-v3")
+        >>> state_dim = env.observation_space.shape[0]
+        >>> act_dim = env.action_space.shape[0]
+
+        >>> state = env.reset()
+        >>> states = torch.from_numpy(state).reshape(1, 1, state_dim).to(device=device, dtype=torch.float32)
+        >>> actions = torch.zeros((1, 1, act_dim), device=device, dtype=torch.float32)
+        >>> rewards = torch.zeros(1, 1, device=device, dtype=torch.float32)
+        >>> target_return = torch.tensor(TARGET_RETURN, dtype=torch.float32).reshape(1, 1)
+        >>> timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)
+        >>> attention_mask = torch.zeros(1, 1, device=device, dtype=torch.float32)
+
+        >>> # forward pass
+        >>> with torch.no_grad():
+        ...     state_preds, action_preds, return_preds = model(
+        ...         states=states,
+        ...         actions=actions,
+        ...         rewards=rewards,
+        ...         returns_to_go=target_return,
+        ...         timesteps=timesteps,
+        ...         attention_mask=attention_mask,
+        ...         return_dict=False,
+        ...     )
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        batch_size, seq_length = states.shape[0], states.shape[1]
+
+        if attention_mask is None:
+            # attention mask for GPT: 1 if can be attended to, 0 if not
+            attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long)
+
+        # embed each modality with a different head
+        state_embeddings = self.embed_state(states)
+        action_embeddings = self.embed_action(actions)
+        returns_embeddings = self.embed_return(returns_to_go)
+        time_embeddings = self.embed_timestep(timesteps)
+
+        # time embeddings are treated similar to positional embeddings
+        state_embeddings = state_embeddings + time_embeddings
+        action_embeddings = action_embeddings + time_embeddings
+        returns_embeddings = returns_embeddings + time_embeddings
+
+        # this makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...)
+        # which works nice in an autoregressive sense since states predict actions
+        stacked_inputs = (
+            torch.stack((returns_embeddings, state_embeddings, action_embeddings), dim=1)
+            .permute(0, 2, 1, 3)
+            .reshape(batch_size, 3 * seq_length, self.hidden_size)
+        )
+        stacked_inputs = self.embed_ln(stacked_inputs)
+
+        # to make the attention mask fit the stacked inputs, have to stack it as well
+        stacked_attention_mask = (
+            torch.stack((attention_mask, attention_mask, attention_mask), dim=1)
+            .permute(0, 2, 1)
+            .reshape(batch_size, 3 * seq_length)
+        )
+        device = stacked_inputs.device
+        # we feed in the input embeddings (not word indices as in NLP) to the model
+        encoder_outputs = self.encoder(
+            inputs_embeds=stacked_inputs,
+            attention_mask=stacked_attention_mask,
+            position_ids=torch.zeros(stacked_attention_mask.shape, device=device, dtype=torch.long),
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        x = encoder_outputs[0]
+
+        # reshape x so that the second dimension corresponds to the original
+        # returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t
+        x = x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3)
+
+        # get predictions
+        return_preds = self.predict_return(x[:, 2])  # predict next return given state and action
+        state_preds = self.predict_state(x[:, 2])  # predict next state given state and action
+        action_preds = self.predict_action(x[:, 1])  # predict next action given state
+        if not return_dict:
+            return (state_preds, action_preds, return_preds)
+
+        return DecisionTransformerOutput(
+            last_hidden_state=encoder_outputs.last_hidden_state,
+            state_preds=state_preds,
+            action_preds=action_preds,
+            return_preds=return_preds,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+__all__ = [
+    "DecisionTransformerGPT2Model",
+    "DecisionTransformerGPT2PreTrainedModel",
+    "DecisionTransformerModel",
+    "DecisionTransformerPreTrainedModel",
+]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/focalnet/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/focalnet/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..71edfdf40cba6b49e3a7b564255f30f9613fc08c
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/focalnet/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/__init__.py b/.venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e477eb1d2cc2ff0e34d214ed99ad3f80afe8ab0a
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/__init__.py
@@ -0,0 +1,26 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .tokenization_gpt_sw3 import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..280f77ab5d6dfa4b749db348e25ed4acda973cd3
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/__pycache__/tokenization_gpt_sw3.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/__pycache__/tokenization_gpt_sw3.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bcd102a86b85d14b27f4eaa0d972d0a56f08da02
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/__pycache__/tokenization_gpt_sw3.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/tokenization_gpt_sw3.py b/.venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/tokenization_gpt_sw3.py
new file mode 100644
index 0000000000000000000000000000000000000000..7991988c74849ee8815ea315ebd3235f4c8d4012
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/tokenization_gpt_sw3.py
@@ -0,0 +1,299 @@
+"""The tokenizer used by the GPT-SW3 models."""
+
+import os
+import re
+import unicodedata
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import sentencepiece as spm
+
+from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import is_torch_available, logging
+
+
+if is_torch_available():
+    import torch
+
+
+logger = logging.get_logger(__name__)
+VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
+
+
+class GPTSw3Tokenizer(PreTrainedTokenizer):
+    """
+    Construct an GPTSw3 tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
+
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+    this superclass for more information regarding those methods.
+
+    Example usage:
+    ```python
+    >>> from transformers import GPTSw3Tokenizer
+
+    >>> tokenizer = GPTSw3Tokenizer.from_pretrained("AI-Sweden-Models/gpt-sw3-126m")
+    >>> tokenizer("Svenska är kul!")["input_ids"]
+    [1814, 377, 3617, 63504]
+    ```
+
+    Args:
+        vocab_file (`str`):
+            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
+            contains the vocabulary necessary to instantiate a tokenizer.
+        do_lower_case (`bool`, *optional*, defaults to `False`):
+            Whether or not to lowercase the input when tokenizing.
+        remove_space (`bool`, *optional*, defaults to `False`):
+            Whether or not to strip the text when tokenizing (removing excess spaces before and after the string).
+        keep_accents (`bool`, *optional*, defaults to `False`):
+            Whether or not to keep accents when tokenizing.
+        pad_token (`str`, *optional*):
+            The token used for padding, for example when batching sequences of different lengths. If not provided, will
+            default to '' or '' depending on model size.
+        unk_token (`str`, *optional*):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead. If not provided, will default to ''.
+        eos_token (`str`, *optional*):
+            The end of sequence token seen during pretraining. If not provided, will default to '<|endoftext|>'
+        bos_token (`str`, *optional*):
+            The beginning of sequence token that can be used for downstream task, was not seen during pretraining. If
+            not provided, will default to '' or '<|endoftext|>', depending on model size.
+        sp_model_kwargs (`dict`, *optional*):
+            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
+            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
+            to set:
+
+            - `enable_sampling`: Enable subword regularization.
+            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+
+              - `nbest_size = {0,1}`: No sampling is performed.
+              - `nbest_size > 1`: samples from the nbest_size results.
+              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
+                using forward-filtering-and-backward-sampling algorithm.
+
+            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+              BPE-dropout.
+
+    Attributes:
+        sp_model (`SentencePieceProcessor`):
+            The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
+        whitespaces (`set`):
+            The whitespaces that are replaced in the whitespace normalization in preprocessing.
+        non_printing_characters_re (`Pattern`):
+            The compiled regular expression to remove non-printing characters in preprocessing.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        do_lower_case=False,
+        remove_space=False,
+        keep_accents=False,
+        pad_token=None,
+        unk_token=None,
+        eos_token=None,
+        bos_token=None,
+        sp_model_kwargs: Optional[Dict[str, Any]] = None,
+        **kwargs,
+    ) -> None:
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+
+        name_or_path = kwargs.get("name_or_path")
+        if name_or_path is None:
+            logger.warning(
+                "name_or_path not provided, will work for all GPTSw3 models except gpt-sw3-7b,"
+                " you are testing the model, this can safely be ignored"
+            )
+            name_or_path = "None"
+
+        # Default definitions for our 2 tokenizer versions, with None-checks to enable proper testing
+        eos_token = "<|endoftext|>" if eos_token is None else eos_token
+        unk_token = "" if unk_token is None else unk_token
+        if "gpt-sw3-7b" in name_or_path:
+            pad_token = unk_token if pad_token is None else pad_token
+            bos_token = eos_token if bos_token is None else bos_token
+        else:
+            pad_token = "" if pad_token is None else pad_token
+            bos_token = "" if bos_token is None else bos_token
+
+        self.do_lower_case = do_lower_case
+        self.remove_space = remove_space
+        self.keep_accents = keep_accents
+        self.vocab_file = vocab_file
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(vocab_file)
+
+        # Used for whitespace normalization in input texts
+        # fmt : off
+        self.whitespaces = {" ", " ", " ", " ", " ", " ", " ", " ", " ", " ", "", "„"}
+        # fmt : on
+
+        # Regular expression to remove non-printing characters (e.g. some unicode control chars) in preprocessing
+        self.non_printing_characters_re = re.compile(
+            f"[{''.join(map(chr, list(range(0, 9)) + list(range(11, 32)) + list(range(127, 160)) + [160, 173, 8203]))}]"
+        )
+
+        super().__init__(
+            do_lower_case=do_lower_case,
+            remove_space=remove_space,
+            keep_accents=keep_accents,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            sp_model_kwargs=self.sp_model_kwargs,
+            **kwargs,
+        )
+
+    # Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.__getstate__
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sp_model"] = None
+        return state
+
+    # Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.__setstate__
+    def __setstate__(self, d):
+        self.__dict__ = d
+
+        # for backward compatibility
+        if not hasattr(self, "sp_model_kwargs"):
+            self.sp_model_kwargs = {}
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(self.vocab_file)
+
+    @property
+    # Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.vocab_size
+    def vocab_size(self) -> int:
+        return len(self.sp_model)
+
+    def preprocess_text(self, text: str) -> str:
+        """
+        Returns the preprocessed text. This procedure is identical to what was used when training the tokenizer.
+        """
+
+        # Remove non-printing characters
+        text = self.non_printing_characters_re.sub("", text)
+
+        # Normalize whitespaces
+        text = "".join([char if char not in self.whitespaces else " " for char in text])
+
+        # NFC Unicode normalization
+        text = unicodedata.normalize("NFC", text)
+        return text
+
+    def _tokenize(self, text: str, **kwargs) -> List[str]:
+        text = self.preprocess_text(text)
+        return self.sp_model.encode(text, out_type=str)
+
+    def _convert_token_to_id(self, token: str) -> int:
+        """Converts a token (str) to an id (int) using the vocab."""
+        return self.sp_model.PieceToId(token)
+
+    def _convert_id_to_token(self, index: int) -> str:
+        """Converts an index (int) to a token (str) using the vocab."""
+        return self.sp_model.IdToPiece(index)
+
+    @staticmethod
+    def clean_up_tokenization(out_string: str) -> str:
+        """Returns the input string, this function is overridden to remove the default clean up."""
+        return out_string
+
+    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+        """Converts a sequence of tokens (strings) to a single string. Special tokens remain intact."""
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                # TODO: Check if this is needed, as it ensures that decode(encode(doc)) != doc by adding extra whitespace in the decoded document
+                if not prev_is_special:
+                    out_string += " "
+
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+
+        return out_string
+
+    # Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.get_vocab
+    def get_vocab(self) -> Dict[str, int]:
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
+    # Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.save_vocabulary
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        out_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
+
+        return (out_vocab_file,)
+
+    def encode_fast(
+        self, text: Union[str, List[str]], return_tensors: Union[str, bool] = False
+    ) -> Union[List[int], List[List[int]], "torch.Tensor"]:
+        """
+        Encodes a text or batch of texts to token ids using preprocessing and the raw SP tokenizer. This has reduced
+        functionality but is often much faster.
+
+        Does NOT handle special tokens correctly, these can manually be added as ids afterwards.
+
+        Does NOT support padding, these can manually be added as ids afterwards.
+
+        Use default HuggingFace tokenization methods for full functionality.
+
+        Args:
+            text (`str` or `List[str]`): One or several text(s) to convert to token ids.
+            return_tensors (`str` or `bool`): Returns PyTorch tensors if set to True or "pt"
+
+        Returns:
+            `List[int]`, `List[List[int]]`, or `torch.Tensor`: The encoded text(s) as token ids.
+        """
+
+        if isinstance(text, str):
+            text = self.preprocess_text(text)
+            token_ids = self.sp_model.encode(text)
+        else:
+            text = [self.preprocess_text(t) for t in text]
+            token_ids = self.sp_model.encode(text)
+
+        if return_tensors is True or return_tensors == "pt":
+            token_ids = torch.tensor(token_ids)
+
+        return token_ids
+
+    def decode_fast(self, token_ids: Union[int, List[int]]) -> str:
+        """
+        Encodes a text or batch of texts to token ids using preprocessing and the raw SP tokenizer. This has reduced
+        functionality but is often much faster.
+
+        Args:
+            token_ids (`int` or `List[int]`): Encoded token or text as token id(s).
+
+        Returns:
+            `str`: Decoded text
+        """
+
+        return self.sp_model.decode(token_ids)
+
+
+__all__ = ["GPTSw3Tokenizer"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/musicgen/__init__.py b/.venv/lib/python3.11/site-packages/transformers/models/musicgen/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..880274309cbab486a90df05548a0e4d3f2ea0925
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/musicgen/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_musicgen import *
+    from .modeling_musicgen import *
+    from .processing_musicgen import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/musicgen/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/musicgen/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..40499a73df1fde3b01bebb5aa499b7a3332d5ffa
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/musicgen/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/musicgen/__pycache__/configuration_musicgen.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/musicgen/__pycache__/configuration_musicgen.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d384b0c59b210bd63df321cd2bcc15a86176eb25
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/musicgen/__pycache__/configuration_musicgen.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/musicgen/__pycache__/processing_musicgen.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/musicgen/__pycache__/processing_musicgen.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b9b359b033b3ceb60d92cd733389a6a2c4cc5c7d
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/musicgen/__pycache__/processing_musicgen.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/musicgen/configuration_musicgen.py b/.venv/lib/python3.11/site-packages/transformers/models/musicgen/configuration_musicgen.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c38caf20dc4136e794e5762bf0b58847817dd19
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/musicgen/configuration_musicgen.py
@@ -0,0 +1,247 @@
+# coding=utf-8
+# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MusicGen model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ..auto.configuration_auto import AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class MusicgenDecoderConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of an [`MusicgenDecoder`]. It is used to instantiate a
+    MusicGen decoder according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the MusicGen
+    [facebook/musicgen-small](https://huggingface.co/facebook/musicgen-small) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 2048):
+            Vocabulary size of the MusicgenDecoder model. Defines the number of different tokens that can be
+            represented by the `inputs_ids` passed when calling [`MusicgenDecoder`].
+        hidden_size (`int`, *optional*, defaults to 1024):
+            Dimensionality of the layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 24):
+            Number of decoder layers.
+        num_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer block.
+        ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer block.
+        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the decoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, text_encoder, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for activations inside the fully connected layer.
+        max_position_embeddings (`int`, *optional*, defaults to 2048):
+            The maximum sequence length that this model might ever be used with. Typically, set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        initializer_factor (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
+            for more details.
+        scale_embedding (`bool`, *optional*, defaults to `False`):
+            Scale embeddings by diving by sqrt(hidden_size).
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether the model should return the last key/values attentions (not used by all models)
+        num_codebooks (`int`, *optional*, defaults to 4):
+            The number of parallel codebooks forwarded to the model.
+        tie_word_embeddings(`bool`, *optional*, defaults to `False`):
+            Whether input and output word embeddings should be tied.
+        audio_channels (`int`, *optional*, defaults to 1
+            Number of channels in the audio data. Either 1 for mono or 2 for stereo. Stereo models generate a separate
+            audio stream for the left/right output channels. Mono models generate a single audio stream output.
+    """
+
+    model_type = "musicgen_decoder"
+    base_config_key = "decoder_config"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        vocab_size=2048,
+        max_position_embeddings=2048,
+        num_hidden_layers=24,
+        ffn_dim=4096,
+        num_attention_heads=16,
+        layerdrop=0.0,
+        use_cache=True,
+        activation_function="gelu",
+        hidden_size=1024,
+        dropout=0.1,
+        attention_dropout=0.0,
+        activation_dropout=0.0,
+        initializer_factor=0.02,
+        scale_embedding=False,
+        num_codebooks=4,
+        audio_channels=1,
+        pad_token_id=2048,
+        bos_token_id=2048,
+        eos_token_id=None,
+        tie_word_embeddings=False,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.ffn_dim = ffn_dim
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.dropout = dropout
+        self.attention_dropout = attention_dropout
+        self.activation_dropout = activation_dropout
+        self.activation_function = activation_function
+        self.initializer_factor = initializer_factor
+        self.layerdrop = layerdrop
+        self.use_cache = use_cache
+        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
+        self.num_codebooks = num_codebooks
+
+        if audio_channels not in [1, 2]:
+            raise ValueError(f"Expected 1 (mono) or 2 (stereo) audio channels, got {audio_channels} channels.")
+        self.audio_channels = audio_channels
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+
+
+class MusicgenConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`MusicgenModel`]. It is used to instantiate a
+    MusicGen model according to the specified arguments, defining the text encoder, audio encoder and MusicGen decoder
+    configs.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        kwargs (*optional*):
+            Dictionary of keyword arguments. Notably:
+
+                - **text_encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
+                  defines the text encoder config.
+                - **audio_encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
+                  defines the audio encoder config.
+                - **decoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
+                  the decoder config.
+
+    Example:
+
+    ```python
+    >>> from transformers import (
+    ...     MusicgenConfig,
+    ...     MusicgenDecoderConfig,
+    ...     T5Config,
+    ...     EncodecConfig,
+    ...     MusicgenForConditionalGeneration,
+    ... )
+
+    >>> # Initializing text encoder, audio encoder, and decoder model configurations
+    >>> text_encoder_config = T5Config()
+    >>> audio_encoder_config = EncodecConfig()
+    >>> decoder_config = MusicgenDecoderConfig()
+
+    >>> configuration = MusicgenConfig.from_sub_models_config(
+    ...     text_encoder_config, audio_encoder_config, decoder_config
+    ... )
+
+    >>> # Initializing a MusicgenForConditionalGeneration (with random weights) from the facebook/musicgen-small style configuration
+    >>> model = MusicgenForConditionalGeneration(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    >>> config_text_encoder = model.config.text_encoder
+    >>> config_audio_encoder = model.config.audio_encoder
+    >>> config_decoder = model.config.decoder
+
+    >>> # Saving the model, including its configuration
+    >>> model.save_pretrained("musicgen-model")
+
+    >>> # loading model and config from pretrained folder
+    >>> musicgen_config = MusicgenConfig.from_pretrained("musicgen-model")
+    >>> model = MusicgenForConditionalGeneration.from_pretrained("musicgen-model", config=musicgen_config)
+    ```"""
+
+    model_type = "musicgen"
+    sub_configs = {
+        "text_encoder": AutoConfig,
+        "audio_encoder": AutoConfig,
+        "decoder": MusicgenDecoderConfig,
+    }
+    is_composition = True
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        if "text_encoder" not in kwargs or "audio_encoder" not in kwargs or "decoder" not in kwargs:
+            raise ValueError("Config has to be initialized with text_encoder, audio_encoder and decoder config")
+
+        text_encoder_config = kwargs.pop("text_encoder")
+        text_encoder_model_type = text_encoder_config.pop("model_type")
+
+        audio_encoder_config = kwargs.pop("audio_encoder")
+        audio_encoder_model_type = audio_encoder_config.pop("model_type")
+
+        decoder_config = kwargs.pop("decoder")
+
+        self.text_encoder = AutoConfig.for_model(text_encoder_model_type, **text_encoder_config)
+        self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config)
+        self.decoder = MusicgenDecoderConfig(**decoder_config)
+        self.is_encoder_decoder = True
+
+    @classmethod
+    def from_sub_models_config(
+        cls,
+        text_encoder_config: PretrainedConfig,
+        audio_encoder_config: PretrainedConfig,
+        decoder_config: MusicgenDecoderConfig,
+        **kwargs,
+    ):
+        r"""
+        Instantiate a [`MusicgenConfig`] (or a derived class) from text encoder, audio encoder and decoder
+        configurations.
+
+        Returns:
+            [`MusicgenConfig`]: An instance of a configuration object
+        """
+
+        return cls(
+            text_encoder=text_encoder_config.to_dict(),
+            audio_encoder=audio_encoder_config.to_dict(),
+            decoder=decoder_config.to_dict(),
+            **kwargs,
+        )
+
+    @property
+    # This is a property because you might want to change the codec model on the fly
+    def sampling_rate(self):
+        return self.audio_encoder.sampling_rate
+
+
+__all__ = ["MusicgenConfig", "MusicgenDecoderConfig"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/musicgen/modeling_musicgen.py b/.venv/lib/python3.11/site-packages/transformers/models/musicgen/modeling_musicgen.py
new file mode 100644
index 0000000000000000000000000000000000000000..cab950995a97558af23c1a5cf856be4f568d3f1c
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/musicgen/modeling_musicgen.py
@@ -0,0 +1,2756 @@
+# coding=utf-8
+# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Musicgen model."""
+
+import copy
+import inspect
+import math
+import random
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...generation import (
+    ClassifierFreeGuidanceLogitsProcessor,
+    GenerationConfig,
+    GenerationMixin,
+    GenerationMode,
+    LogitsProcessorList,
+    StoppingCriteriaList,
+)
+from ...modeling_attn_mask_utils import (
+    _prepare_4d_attention_mask,
+    _prepare_4d_attention_mask_for_sdpa,
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
+from ...modeling_outputs import (
+    BaseModelOutput,
+    BaseModelOutputWithPastAndCrossAttentions,
+    CausalLMOutputWithCrossAttentions,
+    ModelOutput,
+    Seq2SeqLMOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    is_flash_attn_2_available,
+    is_flash_attn_greater_or_equal_2_10,
+    logging,
+    replace_return_docstrings,
+)
+from ..auto.configuration_auto import AutoConfig
+from ..auto.modeling_auto import AutoModel
+from .configuration_musicgen import MusicgenConfig, MusicgenDecoderConfig
+
+
+if is_flash_attn_2_available():
+    from ...modeling_flash_attention_utils import _flash_attention_forward
+
+if TYPE_CHECKING:
+    from ...generation.streamers import BaseStreamer
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "MusicgenConfig"
+_CHECKPOINT_FOR_DOC = "facebook/musicgen-small"
+
+
+@dataclass
+class MusicgenUnconditionalInput(ModelOutput):
+    """
+    Args:
+        encoder_outputs  (`Tuple[torch.FloatTensor]` of length 1, with tensor shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the text encoder model.
+        attention_mask (`torch.LongTensor`)  of shape `(batch_size, sequence_length)`, *optional*):
+            Encoder attention mask to avoid performing attention on padding token indices. Mask values selected in `[0,
+            1]`: 1 for tokens that are **not masked**, 0 for tokens that are **masked**.
+        guidance_scale (`float`, *optional*):
+            Guidance scale for classifier free guidance, setting the balance between the conditional logits (predicted
+            from the prompts) and the unconditional logits (predicted without prompts).
+    """
+
+    encoder_outputs: Tuple[torch.FloatTensor] = None
+    attention_mask: torch.LongTensor = None
+    guidance_scale: float = None
+
+
+def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
+    """
+    Shift input ids one token to the right.
+    """
+    # transpose to get (bsz, num_codebooks, seq_len)
+    input_ids = input_ids.transpose(1, 2)
+    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+    shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
+    if decoder_start_token_id is None:
+        raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
+    shifted_input_ids[..., 0] = decoder_start_token_id
+
+    if pad_token_id is None:
+        raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
+    # replace possible -100 values in labels by `pad_token_id`
+    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+    return shifted_input_ids
+
+
+class MusicgenSinusoidalPositionalEmbedding(nn.Module):
+    """This module produces sinusoidal positional embeddings of any length."""
+
+    def __init__(self, num_positions: int, embedding_dim: int):
+        super().__init__()
+        self.embedding_dim = embedding_dim
+        self.make_weights(num_positions, embedding_dim)
+
+    def make_weights(self, num_embeddings: int, embedding_dim: int):
+        emb_weights = self.get_embedding(num_embeddings, embedding_dim)
+        if hasattr(self, "weights"):
+            # in forward put the weights on the correct dtype and device of the param
+            emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
+
+        self.weights = nn.Parameter(emb_weights)
+        self.weights.requires_grad = False
+        self.weights.detach_()
+
+    @staticmethod
+    def get_embedding(num_embeddings: int, embedding_dim: int):
+        """
+        Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the
+        description in Section 3.5 of "Attention Is All You Need".
+        """
+        half_dim = embedding_dim // 2
+        emb = math.log(10000) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
+        emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
+        emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=1).view(num_embeddings, -1)
+        if embedding_dim % 2 == 1:
+            # zero pad
+            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
+        return emb.to(torch.get_default_dtype())
+
+    @torch.no_grad()
+    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+        bsz, codebooks, seq_len = input_ids.size()
+        # Create the position ids from the input token ids.
+        position_ids = (torch.arange(seq_len) + past_key_values_length).to(input_ids.device)
+        # expand embeddings if needed
+        if seq_len > self.weights.size(0):
+            self.make_weights(seq_len + self.offset, self.embedding_dim)
+        return self.weights.index_select(0, position_ids.view(-1)).detach()
+
+
+# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Musicgen
+class MusicgenAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        is_decoder: bool = False,
+        bias: bool = True,
+        is_causal: bool = False,
+        config: Optional[MusicgenConfig] = None,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.head_dim = embed_dim // num_heads
+        self.config = config
+
+        if (self.head_dim * num_heads) != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+                f" and `num_heads`: {num_heads})."
+            )
+        self.scaling = self.head_dim**-0.5
+        self.is_decoder = is_decoder
+        self.is_causal = is_causal
+
+        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+
+        bsz, tgt_len, _ = hidden_states.size()
+
+        # get query proj
+        query_states = self.q_proj(hidden_states) * self.scaling
+        # get key, value proj
+        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+        # is checking that the `sequence_length` of the `past_key_value` is the same as
+        # the provided `key_value_states` to support prefix tuning
+        if (
+            is_cross_attention
+            and past_key_value is not None
+            and past_key_value[0].shape[2] == key_value_states.shape[1]
+        ):
+            # reuse k,v, cross_attentions
+            key_states = past_key_value[0]
+            value_states = past_key_value[1]
+        elif is_cross_attention:
+            # cross_attentions
+            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+        elif past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+        else:
+            # self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_states, value_states)
+
+        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+        key_states = key_states.reshape(*proj_shape)
+        value_states = value_states.reshape(*proj_shape)
+
+        src_len = key_states.size(1)
+        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        if layer_head_mask is not None:
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+                    f" {layer_head_mask.size()}"
+                )
+            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        if output_attentions:
+            # this operation is a bit awkward, but it's required to
+            # make sure that attn_weights keeps its gradient.
+            # In order to do so, attn_weights have to be reshaped
+            # twice and have to be reused in the following
+            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+        else:
+            attn_weights_reshaped = None
+
+        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+        attn_output = torch.bmm(attn_probs, value_states)
+
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_weights_reshaped, past_key_value
+
+
+# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Musicgen
+class MusicgenFlashAttention2(MusicgenAttention):
+    """
+    Musicgen flash attention module. This module inherits from `MusicgenAttention` as the weights of the module stays
+    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+    def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # MusicgenFlashAttention2 attention does not support output_attentions
+        if output_attentions:
+            raise ValueError("MusicgenFlashAttention2 attention does not support output_attentions")
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+
+        bsz, q_len, _ = hidden_states.size()
+
+        # get query proj
+        query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
+        # get key, value proj
+        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+        # is checking that the `sequence_length` of the `past_key_value` is the same as
+        # the provided `key_value_states` to support prefix tuning
+        if (
+            is_cross_attention
+            and past_key_value is not None
+            and past_key_value[0].shape[2] == key_value_states.shape[1]
+        ):
+            # reuse k,v, cross_attentions
+            key_states = past_key_value[0].transpose(1, 2)
+            value_states = past_key_value[1].transpose(1, 2)
+        elif is_cross_attention:
+            # cross_attentions
+            key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
+        elif past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+            key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
+            value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
+        else:
+            # self_attention
+            key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+
+        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+        # therefore the input hidden states gets silently casted in float32. Hence, we need
+        # cast them back in the correct dtype just to be sure everything works as expected.
+        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+        # in fp32. (LlamaRMSNorm handles it correctly)
+
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.q_proj.weight.dtype
+
+            logger.warning_once(
+                f"The input hidden states seems to be silently casted in float32, this might be related to"
+                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+                f" {target_dtype}."
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
+        attn_output = _flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len,
+            dropout=self.dropout if self.training else 0.0,
+            is_causal=self.is_causal,
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+        )
+
+        attn_output = attn_output.reshape(bsz, q_len, -1)
+        attn_output = self.out_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+class MusicgenSdpaAttention(MusicgenAttention):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+        if output_attentions or layer_head_mask is not None:
+            # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
+            logger.warning_once(
+                "MusicgenModel is using MusicgenSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
+                ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states,
+                key_value_states=key_value_states,
+                past_key_value=past_key_value,
+                attention_mask=attention_mask,
+                layer_head_mask=layer_head_mask,
+                output_attentions=output_attentions,
+            )
+
+        if (
+            attention_mask is not None
+            and (attention_mask.mean(dim=[1, 2, 3]) <= torch.finfo(attention_mask.dtype).min).any()
+        ):
+            logger.warning_once(
+                '`torch.nn.functional.scaled_dot_product_attention` does not support having an empty attention mask. Falling back to the manual attention implementation. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+                "Note that this probably happens because `guidance_scale>1` or because you used `get_unconditional_inputs`. See https://github.com/huggingface/transformers/issues/31189 for more information."
+            )
+            return super().forward(
+                hidden_states,
+                key_value_states=key_value_states,
+                past_key_value=past_key_value,
+                attention_mask=attention_mask,
+                layer_head_mask=layer_head_mask,
+                output_attentions=output_attentions,
+            )
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+
+        bsz, tgt_len, _ = hidden_states.size()
+
+        # get query proj
+        query_states = self.q_proj(hidden_states)
+        # get key, value proj
+        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+        # is checking that the `sequence_length` of the `past_key_value` is the same as
+        # the provided `key_value_states` to support prefix tuning
+        if (
+            is_cross_attention
+            and past_key_value is not None
+            and past_key_value[0].shape[2] == key_value_states.shape[1]
+        ):
+            # reuse k,v, cross_attentions
+            key_states = past_key_value[0]
+            value_states = past_key_value[1]
+        elif is_cross_attention:
+            # cross_attentions
+            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+        elif past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+        else:
+            # self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_states, value_states)
+
+        query_states = self._shape(query_states, tgt_len, bsz)
+
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
+        is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
+
+        # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
+        # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=attention_mask,
+            dropout_p=self.dropout if self.training else 0.0,
+            is_causal=is_causal,
+        )
+
+        if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, None, past_key_value
+
+
+MUSICGEN_ATTENTION_CLASSES = {
+    "eager": MusicgenAttention,
+    "sdpa": MusicgenSdpaAttention,
+    "flash_attention_2": MusicgenFlashAttention2,
+}
+
+
+class MusicgenDecoderLayer(nn.Module):
+    def __init__(self, config: MusicgenDecoderConfig):
+        super().__init__()
+        self.embed_dim = config.hidden_size
+
+        self.self_attn = MUSICGEN_ATTENTION_CLASSES[config._attn_implementation](
+            embed_dim=self.embed_dim,
+            num_heads=config.num_attention_heads,
+            dropout=config.attention_dropout,
+            is_decoder=True,
+            bias=False,
+            is_causal=True,
+            config=config,
+        )
+        self.dropout = config.dropout
+        self.activation_fn = ACT2FN[config.activation_function]
+        self.activation_dropout = config.activation_dropout
+
+        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.encoder_attn = MUSICGEN_ATTENTION_CLASSES[config._attn_implementation](
+            self.embed_dim,
+            config.num_attention_heads,
+            dropout=config.attention_dropout,
+            is_decoder=True,
+            bias=False,
+            config=config,
+        )
+        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=False)
+        self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=False)
+        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+    # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer.forward
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = True,
+    ) -> torch.Tensor:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            encoder_hidden_states (`torch.FloatTensor`):
+                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
+            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+                `(encoder_attention_heads,)`.
+            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
+                size `(decoder_attention_heads,)`.
+            past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+        """
+        residual = hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+
+        # Self Attention
+        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+        # add present self-attn cache to positions 1,2 of present_key_value tuple
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            past_key_value=self_attn_past_key_value,
+            attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+
+        # Cross-Attention Block
+        cross_attn_present_key_value = None
+        cross_attn_weights = None
+        if encoder_hidden_states is not None:
+            residual = hidden_states
+            hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+            # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
+            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
+                hidden_states=hidden_states,
+                key_value_states=encoder_hidden_states,
+                attention_mask=encoder_attention_mask,
+                layer_head_mask=cross_attn_layer_head_mask,
+                past_key_value=cross_attn_past_key_value,
+                output_attentions=output_attentions,
+            )
+            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+            hidden_states = residual + hidden_states
+
+            # add cross-attn to positions 3,4 of present_key_value tuple
+            present_key_value = present_key_value + cross_attn_present_key_value
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights, cross_attn_weights)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+
+class MusicgenPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = MusicgenDecoderConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["MusicgenDecoderLayer", "MusicgenAttention"]
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+
+    def _init_weights(self, module):
+        std = self.config.initializer_factor
+        if isinstance(module, (nn.Linear, nn.Conv1d)):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+
+MUSICGEN_START_DOCSTRING = r"""
+
+    The Musicgen model was proposed in [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by
+    Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi, Alexandre Défossez. It is an
+    encoder decoder transformer trained on the task of conditional music generation
+
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`MusicgenConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+MUSICGEN_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)`, *optional*):
+            Indices of decoder input sequence tokens in the vocabulary, corresponding to the sequence of audio codes.
+
+            Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes,
+            such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details.
+
+            [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+            
+
+            The `decoder_input_ids` will automatically be converted from shape `(batch_size * num_codebooks,
+            target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If
+            you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of
+            frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks,
+            target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as
+            `decoder_input_ids`.
+
+            
+
+        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+            be used by default.
+        head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
+            1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
+            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
+            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
+            hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
+        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
+            representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
+            input (see `past_key_values`). This is useful if you want more control over how to convert
+            `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
+
+            If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
+            of `inputs_embeds`.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
+            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+MUSICGEN_DECODER_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary, corresponding to the sequence of audio codes.
+
+            Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes,
+            such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+
+            
+
+            The `input_ids` will automatically be converted from shape `(batch_size * num_codebooks,
+            target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If
+            you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of
+            frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks,
+            target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as
+            `input_ids`.
+
+            
+
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of
+            the decoder.
+        encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
+            Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
+            selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
+            cross-attention on hidden heads. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+class MusicgenDecoder(MusicgenPreTrainedModel):
+    """
+    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MusicgenDecoderLayer`]
+    """
+
+    def __init__(self, config: MusicgenDecoderConfig):
+        super().__init__(config)
+        self.dropout = config.dropout
+        self.layerdrop = config.layerdrop
+        self.max_target_positions = config.max_position_embeddings
+        self.d_model = config.hidden_size
+        self.num_codebooks = config.num_codebooks
+        self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
+
+        embed_dim = config.vocab_size + 1
+        self.embed_tokens = nn.ModuleList(
+            [nn.Embedding(embed_dim, config.hidden_size) for _ in range(config.num_codebooks)]
+        )
+
+        self.embed_positions = MusicgenSinusoidalPositionalEmbedding(
+            config.max_position_embeddings,
+            config.hidden_size,
+        )
+
+        self.layers = nn.ModuleList([MusicgenDecoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.layer_norm = nn.LayerNorm(config.hidden_size)
+        self.attn_implementation = config._attn_implementation
+
+        self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.embed_tokens = value
+
+    @add_start_docstrings_to_model_forward(MUSICGEN_DECODER_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            # (bsz * codebooks, seq_len) -> (bsz, codebooks, seq_len)
+            input = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1])
+            bsz, num_codebooks, seq_len = input.shape
+            input_shape = (bsz, seq_len)
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1:]
+        else:
+            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+        # past_key_values_length
+        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+        if inputs_embeds is None:
+            inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)])
+
+        if self.attn_implementation == "flash_attention_2":
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        elif self.attn_implementation == "sdpa" and head_mask is None and not output_attentions:
+            # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+            # the manual implementation that requires a 4D causal mask in all cases.
+            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask,
+                input_shape,
+                inputs_embeds,
+                past_key_values_length,
+            )
+        else:
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask, input_shape, inputs_embeds, past_key_values_length
+            )
+
+        # expand encoder attention mask
+        if encoder_hidden_states is not None and encoder_attention_mask is not None:
+            if self.attn_implementation == "flash_attention_2":
+                encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+            elif self.attn_implementation == "sdpa" and cross_attn_head_mask is None and not output_attentions:
+                # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+                # the manual implementation that requires a 4D causal mask in all cases.
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                    encoder_attention_mask,
+                    inputs_embeds.dtype,
+                    tgt_len=input_shape[-1],
+                )
+            else:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                encoder_attention_mask = _prepare_4d_attention_mask(
+                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+                )
+
+        # embed positions
+        positions = self.embed_positions(input, past_key_values_length)
+
+        hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
+
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+        next_decoder_cache = () if use_cache else None
+
+        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+            if attn_mask is not None:
+                if attn_mask.size()[0] != len(self.layers):
+                    raise ValueError(
+                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+                        f" {attn_mask.size()[0]}."
+                    )
+        for idx, decoder_layer in enumerate(self.layers):
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+            dropout_probability = random.uniform(0, 1)
+            if self.training and (dropout_probability < self.layerdrop):
+                continue
+
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.forward,
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    head_mask[idx] if head_mask is not None else None,
+                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
+                    None,
+                    output_attentions,
+                    use_cache,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                    cross_attn_layer_head_mask=(
+                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
+                    ),
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                )
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+                if encoder_hidden_states is not None:
+                    all_cross_attentions += (layer_outputs[2],)
+
+        hidden_states = self.layer_norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+                if v is not None
+            )
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+            cross_attentions=all_cross_attentions,
+        )
+
+
+@add_start_docstrings(
+    "The bare Musicgen decoder model outputting raw hidden-states without any specific head on top.",
+    MUSICGEN_START_DOCSTRING,
+)
+class MusicgenModel(MusicgenPreTrainedModel):
+    def __init__(self, config: MusicgenDecoderConfig):
+        super().__init__(config)
+        self.decoder = MusicgenDecoder(config)
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.decoder.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.decoder.embed_tokens = value
+
+    def get_decoder(self):
+        return self.decoder
+
+    @add_start_docstrings_to_model_forward(MUSICGEN_DECODER_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
+        decoder_outputs = self.decoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            encoder_attention_mask=encoder_attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            head_mask=head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        if not return_dict:
+            return decoder_outputs
+
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=decoder_outputs.last_hidden_state,
+            past_key_values=decoder_outputs.past_key_values,
+            hidden_states=decoder_outputs.hidden_states,
+            attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+        )
+
+
+@add_start_docstrings(
+    "The MusicGen decoder model with a language modelling head on top.",
+    MUSICGEN_START_DOCSTRING,
+)
+class MusicgenForCausalLM(MusicgenPreTrainedModel, GenerationMixin):
+    def __init__(self, config: MusicgenDecoderConfig):
+        super().__init__(config)
+
+        self.model = MusicgenModel(config)
+
+        self.num_codebooks = config.num_codebooks
+        self.lm_heads = nn.ModuleList(
+            [nn.Linear(config.hidden_size, config.vocab_size, bias=False) for _ in range(config.num_codebooks)]
+        )
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.decoder.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.decoder.embed_tokens = value
+
+    def get_output_embeddings(self):
+        return self.lm_heads
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_heads = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model.decoder = decoder
+
+    def get_decoder(self):
+        return self.model.decoder
+
+    @add_start_docstrings_to_model_forward(MUSICGEN_DECODER_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
+            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+        Returns:
+        """
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if (labels is not None) and (input_ids is None and inputs_embeds is None):
+            input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.bos_token_id)
+
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            head_mask=head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+
+        lm_logits = torch.stack([head(hidden_states) for head in self.lm_heads], dim=1)
+
+        loss = None
+        if labels is not None:
+            # since encoder hidden states have been concatenated to the decoder hidden states,
+            # we take the last timestamps corresponding to labels
+            logits = lm_logits[:, :, -labels.shape[1] :]
+
+            loss_fct = CrossEntropyLoss()
+            loss = torch.zeros([], device=self.device)
+
+            # per codebook cross-entropy
+            # -100 labels are ignored
+            labels = labels.masked_fill(labels == self.config.pad_token_id, -100)
+
+            # per codebook cross-entropy
+            # ref: https://github.com/facebookresearch/audiocraft/blob/69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85/audiocraft/solvers/musicgen.py#L242-L243
+            for codebook in range(self.config.num_codebooks):
+                codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1])
+                codebook_labels = labels[..., codebook].contiguous().view(-1)
+                loss += loss_fct(codebook_logits, codebook_labels)
+
+            loss = loss / self.config.num_codebooks
+
+        # (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size)
+        lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:])
+
+        if not return_dict:
+            output = (lm_logits,) + outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return CausalLMOutputWithCrossAttentions(
+            loss=loss,
+            logits=lm_logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        attention_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        head_mask=None,
+        cross_attn_head_mask=None,
+        past_key_values=None,
+        use_cache=True,
+        delay_pattern_mask=None,
+        guidance_scale=None,
+        **kwargs,
+    ):
+        # Overwritten -- MusicGen has custom processing
+        if delay_pattern_mask is None:
+            input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
+                input_ids,
+                pad_token_id=self.generation_config.pad_token_id,
+                max_length=self.generation_config.max_length,
+            )
+
+        # apply the delay pattern mask
+        input_ids = self.apply_delay_pattern_mask(input_ids, delay_pattern_mask)
+
+        if guidance_scale is not None and guidance_scale > 1:
+            # for classifier free guidance we need to replicate the decoder args across the batch dim (we'll split these
+            # before sampling)
+            input_ids = input_ids.repeat((2, 1))
+            if attention_mask is not None:
+                attention_mask = attention_mask.repeat((2, 1))
+
+        if past_key_values is not None:
+            input_ids = input_ids[:, -1:]
+
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "encoder_hidden_states": encoder_hidden_states,
+            "encoder_attention_mask": encoder_attention_mask,
+            "head_mask": head_mask,
+            "cross_attn_head_mask": cross_attn_head_mask,
+            "past_key_values": past_key_values,
+            "use_cache": use_cache,
+        }
+
+    def build_delay_pattern_mask(self, input_ids: torch.LongTensor, pad_token_id: int, max_length: int = None):
+        """Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by
+        one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
+        are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks,
+        seq_len)`:
+        - [P, -1, -1, -1, -1, P, P, P]
+        - [P, P, -1, -1, -1, -1, P, P]
+        - [P, P, P, -1, -1, -1, -1, P]
+        - [P, P, P, P, -1, -1, -1, -1]
+        where P is the special padding token id and -1 indicates that the token is valid for prediction. If we include
+        a prompt (decoder input ids), the -1 positions indicate where new tokens should be predicted. Otherwise, the
+        mask is set to the value in the prompt:
+        - [P, a, b, -1, -1, P, P, P]
+        - [P, P, c, d, -1, -1, P, P]
+        - [P, P, P, e, f, -1, -1, P]
+        - [P, P, P, P, g, h, -1, -1]
+        where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only override the -1
+        tokens in our prediction.
+        """
+        # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
+        input_ids = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1])
+        bsz, num_codebooks, seq_len = input_ids.shape
+
+        max_length = max_length if max_length is not None else self.generation_config.max_length
+        input_ids_shifted = (
+            torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1
+        )
+
+        channel_codebooks = num_codebooks // 2 if self.config.audio_channels == 2 else num_codebooks
+        # we only apply the mask if we have a large enough seq len - otherwise we return as is
+        if max_length < 2 * channel_codebooks - 1:
+            return input_ids.reshape(bsz * num_codebooks, -1), input_ids_shifted.reshape(bsz * num_codebooks, -1)
+
+        # fill the shifted ids with the prompt entries, offset by the codebook idx
+        for codebook in range(channel_codebooks):
+            if self.config.audio_channels == 1:
+                # mono channel - loop over the codebooks one-by-one
+                input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook]
+            else:
+                # left/right channels are interleaved in the generated codebooks, so handle one then the other
+                input_ids_shifted[:, 2 * codebook, codebook : seq_len + codebook] = input_ids[:, 2 * codebook]
+                input_ids_shifted[:, 2 * codebook + 1, codebook : seq_len + codebook] = input_ids[:, 2 * codebook + 1]
+
+        # construct a pattern mask that indicates the positions of padding tokens for each codebook
+        # first fill the upper triangular part (the EOS padding)
+        delay_pattern = torch.triu(
+            torch.ones((channel_codebooks, max_length), dtype=torch.bool), diagonal=max_length - channel_codebooks + 1
+        )
+        # then fill the lower triangular part (the BOS padding)
+        delay_pattern = delay_pattern + torch.tril(torch.ones((channel_codebooks, max_length), dtype=torch.bool))
+
+        if self.config.audio_channels == 2:
+            # for left/right channel we need to duplicate every row of the pattern mask in an interleaved fashion
+            delay_pattern = delay_pattern.repeat_interleave(2, dim=0)
+
+        mask = ~delay_pattern.to(input_ids.device)
+        input_ids = mask * input_ids_shifted + ~mask * pad_token_id
+
+        # find the first position to start generating - this is the first place we have the -1 token
+        # and will always be in the first codebook (since it has no codebook offset)
+        first_codebook_ids = input_ids[:, 0, :]
+        start_ids = (first_codebook_ids == -1).nonzero()[:, 1]
+        if len(start_ids) > 0:
+            first_start_id = min(start_ids)
+        else:
+            # we have no tokens that need to be filled - return entire matrix of input ids
+            first_start_id = seq_len
+
+        # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
+        pattern_mask = input_ids.reshape(bsz * num_codebooks, -1)
+        input_ids = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1)
+        return input_ids, pattern_mask
+
+    @staticmethod
+    def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
+        """Apply a delay pattern mask to the decoder input ids, only preserving predictions where
+        the mask is set to -1, and otherwise setting to the value detailed in the mask."""
+        seq_len = input_ids.shape[-1]
+        decoder_pad_token_mask = decoder_pad_token_mask[..., :seq_len]
+        input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask)
+        return input_ids
+
+    @torch.no_grad()
+    def generate(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        generation_config: Optional[GenerationConfig] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        synced_gpus: Optional[bool] = None,
+        streamer: Optional["BaseStreamer"] = None,
+        **kwargs,
+    ):
+        """
+
+        Generates sequences of token ids for models with a language modeling head.
+
+        
+
+        Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
+        model's default generation configuration. You can override any `generation_config` by passing the corresponding
+        parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
+
+        For an overview of generation strategies and code examples, check out the [following
+        guide](./generation_strategies).
+
+        
+
+        Parameters:
+            inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
+                The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
+                method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
+                should be in the format `input_ids`. For encoder-decoder models *inputs* can represent any of
+                `input_ids`, `input_values`, `input_features`, or `pixel_values`.
+            generation_config (`~generation.GenerationConfig`, *optional*):
+                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+                passed to generate matching the attributes of `generation_config` will override them. If
+                `generation_config` is not provided, the default will be used, which had the following loading
+                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+                default values, whose documentation should be checked to parameterize generation.
+            logits_processor (`LogitsProcessorList`, *optional*):
+                Custom logits processors that complement the default logits processors built from arguments and
+                generation config. If a logit processor is passed that is already created with the arguments or a
+                generation config an error is thrown. This feature is intended for advanced users.
+            stopping_criteria (`StoppingCriteriaList`, *optional*):
+                Custom stopping criteria that complement the default stopping criteria built from arguments and a
+                generation config. If a stopping criteria is passed that is already created with the arguments or a
+                generation config an error is thrown. This feature is intended for advanced users.
+            synced_gpus (`bool`, *optional*, defaults to `False`):
+                Whether to continue running the while loop until max_length (needed to avoid deadlocking with
+                `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
+            streamer (`BaseStreamer`, *optional*):
+                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
+                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+            kwargs (`Dict[str, Any]`, *optional*):
+                Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
+                forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
+                specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
+
+        Return:
+            [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
+            or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
+
+                If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
+                [`~utils.ModelOutput`] types are:
+
+                    - [`~generation.GenerateDecoderOnlyOutput`],
+                    - [`~generation.GenerateBeamDecoderOnlyOutput`]
+
+                If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
+                [`~utils.ModelOutput`] types are:
+
+                    - [`~generation.GenerateEncoderDecoderOutput`],
+                    - [`~generation.GenerateBeamEncoderDecoderOutput`]
+        """
+        # 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects
+        if generation_config is None:
+            generation_config = self.generation_config
+
+        generation_config = copy.deepcopy(generation_config)
+        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
+        generation_config.validate()
+        self._validate_model_kwargs(model_kwargs.copy())
+
+        # 2. Set generation parameters if not already defined
+        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+        requires_attention_mask = "encoder_outputs" not in model_kwargs
+        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+
+        # 3. Define model inputs`
+        input_ids, model_input_name, model_kwargs = self._prepare_model_inputs(
+            inputs, generation_config.bos_token_id, model_kwargs
+        )
+        batch_size = input_ids.shape[0] // self.num_codebooks
+        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device)
+
+        # 4. Define other model kwargs
+        model_kwargs["use_cache"] = generation_config.use_cache
+        model_kwargs["guidance_scale"] = generation_config.guidance_scale
+
+        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
+            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
+                input_ids, generation_config, model_kwargs
+            )
+
+        # 5. Prepare `max_length` depending on other stopping criteria.
+        input_ids_length = input_ids.shape[-1]
+        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+        has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
+        generation_config = self._prepare_generated_length(
+            generation_config=generation_config,
+            has_default_max_length=has_default_max_length,
+            has_default_min_length=has_default_min_length,
+            model_input_name=model_input_name,
+            inputs_tensor=input_ids,
+            input_ids_length=input_ids_length,
+        )
+
+        # 6. Prepare `input_ids` which will be used for auto-regressive generation
+        # Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen)
+        input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
+            input_ids,
+            pad_token_id=generation_config._decoder_start_token_tensor,
+            max_length=generation_config.max_length,
+        )
+
+        if streamer is not None:
+            streamer.put(input_ids.cpu())
+
+        # stash the delay mask so that we don't have to recompute it in each forward pass
+        model_kwargs["delay_pattern_mask"] = delay_pattern_mask
+
+        # 7. determine generation mode
+        generation_mode = generation_config.get_generation_mode()
+
+        # 8. prepare batched CFG externally (to enable coexistance with the unbatched CFG)
+        if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
+            logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
+            generation_config.guidance_scale = None
+
+        # 9. prepare distribution pre_processing samplers
+        logits_processor = self._get_logits_processor(
+            generation_config=generation_config,
+            input_ids_seq_length=input_ids_length,
+            encoder_input_ids=input_ids,
+            prefix_allowed_tokens_fn=None,
+            logits_processor=logits_processor,
+            device=input_ids.device,
+        )
+
+        # 10. prepare stopping criteria
+        stopping_criteria = self._get_stopping_criteria(
+            generation_config=generation_config, stopping_criteria=stopping_criteria
+        )
+
+        if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
+            # expand input_ids with `num_return_sequences` additional sequences per batch
+            input_ids, model_kwargs = self._expand_inputs_for_generation(
+                input_ids=input_ids,
+                expand_size=generation_config.num_return_sequences,
+                **model_kwargs,
+            )
+
+            # 11. run sample
+            outputs = self._sample(
+                input_ids,
+                logits_processor=logits_processor,
+                stopping_criteria=stopping_criteria,
+                generation_config=generation_config,
+                synced_gpus=synced_gpus,
+                streamer=streamer,
+                **model_kwargs,
+            )
+
+        else:
+            raise ValueError(
+                "Got incompatible mode for generation, should be one of greedy or sampling. "
+                "Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`."
+            )
+
+        if generation_config.return_dict_in_generate:
+            output_ids = outputs.sequences
+        else:
+            output_ids = outputs
+
+        # apply the pattern mask to the final ids
+        output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])
+
+        # revert the pattern delay mask by filtering the pad token id
+        output_ids = output_ids[output_ids != generation_config._pad_token_tensor].reshape(
+            batch_size, self.num_codebooks, -1
+        )
+
+        if generation_config.return_dict_in_generate:
+            outputs.sequences = output_ids
+            return outputs
+        else:
+            return output_ids
+
+
+@add_start_docstrings(
+    "The composite MusicGen model with a text encoder, audio encoder and Musicgen decoder, "
+    "for music generation tasks with one or both of text and audio prompts.",
+    MUSICGEN_START_DOCSTRING,
+)
+class MusicgenForConditionalGeneration(PreTrainedModel, GenerationMixin):
+    config_class = MusicgenConfig
+    base_model_prefix = "encoder_decoder"
+    main_input_name = "input_ids"
+    supports_gradient_checkpointing = True
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+
+    def __init__(
+        self,
+        config: Optional[MusicgenConfig] = None,
+        text_encoder: Optional[PreTrainedModel] = None,
+        audio_encoder: Optional[PreTrainedModel] = None,
+        decoder: Optional[MusicgenForCausalLM] = None,
+    ):
+        if config is None and (text_encoder is None or audio_encoder is None or decoder is None):
+            raise ValueError(
+                "Either a configuration has to be provided, or all three of text encoder, audio encoder and MusicGen decoder."
+            )
+        if config is None:
+            config = MusicgenConfig.from_sub_models_config(text_encoder.config, audio_encoder.config, decoder.config)
+        else:
+            if not isinstance(config, self.config_class):
+                raise ValueError(f"Config: {config} has to be of type {self.config_class}")
+
+        if config.decoder.cross_attention_hidden_size is not None:
+            if config.decoder.cross_attention_hidden_size != config.text_encoder.hidden_size:
+                raise ValueError(
+                    "If `cross_attention_hidden_size` is specified in the MusicGen decoder's configuration, it has to be equal"
+                    f" to the text encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
+                    f" `config.decoder.cross_attention_hidden_size` and {config.text_encoder.hidden_size} for"
+                    " `config.text_encoder.hidden_size`."
+                )
+
+        # initialize with config
+        super().__init__(config)
+
+        if text_encoder is None:
+            from ..auto.modeling_auto import AutoModelForTextEncoding
+
+            text_encoder = AutoModelForTextEncoding.from_config(config.text_encoder)
+
+        if audio_encoder is None:
+            from ..auto.modeling_auto import AutoModel
+
+            audio_encoder = AutoModel.from_config(config.audio_encoder)
+
+        if decoder is None:
+            decoder = MusicgenForCausalLM._from_config(config.decoder)
+
+        self.text_encoder = text_encoder
+        self.audio_encoder = audio_encoder
+        self.decoder = decoder
+
+        if self.text_encoder.config.to_dict() != self.config.text_encoder.to_dict():
+            logger.warning(
+                f"Config of the text_encoder: {self.text_encoder.__class__} is overwritten by shared text_encoder config:"
+                f" {self.config.text_encoder}"
+            )
+        if self.audio_encoder.config.to_dict() != self.config.audio_encoder.to_dict():
+            logger.warning(
+                f"Config of the audio_encoder: {self.audio_encoder.__class__} is overwritten by shared audio_encoder config:"
+                f" {self.config.audio_encoder}"
+            )
+        if self.decoder.config.to_dict() != self.config.decoder.to_dict():
+            logger.warning(
+                f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
+                f" {self.config.decoder}"
+            )
+
+        # make sure that the individual model's config refers to the shared config
+        # so that the updates to the config will be synced
+        self.config.text_encoder._attn_implementation = self.text_encoder.config._attn_implementation
+        self.config.audio_encoder._attn_implementation = self.audio_encoder.config._attn_implementation
+        self.config.decoder._attn_implementation = self.decoder.config._attn_implementation
+        self.text_encoder.config = self.config.text_encoder
+        self.audio_encoder.config = self.config.audio_encoder
+        self.decoder.config = self.config.decoder
+
+        # text encoder outputs might need to be projected to different dimension for decoder
+        if (
+            self.text_encoder.config.hidden_size != self.decoder.config.hidden_size
+            and self.decoder.config.cross_attention_hidden_size is None
+        ):
+            self.enc_to_dec_proj = nn.Linear(self.text_encoder.config.hidden_size, self.decoder.config.hidden_size)
+
+        if self.text_encoder.get_output_embeddings() is not None:
+            raise ValueError(
+                f"The encoder {self.text_encoder} should not have a LM Head. Please use a model without and LM Head"
+            )
+
+        decoder_signature = set(inspect.signature(self.decoder.forward).parameters.keys())
+        if "encoder_hidden_states" not in decoder_signature:
+            raise ValueError(
+                "The selected decoder is not prepared for the encoder hidden states to be passed. Please see the "
+                "following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350"
+            )
+
+        # tie text encoder, decoder weights if config set accordingly
+        self.tie_weights()
+
+    def tie_weights(self):
+        # tie text encoder & decoder if needed
+        if self.config.tie_encoder_decoder:
+            # tie text encoder and decoder base model
+            decoder_base_model_prefix = self.decoder.base_model_prefix
+            tied_weights = self._tie_encoder_decoder_weights(
+                self.text_encoder,
+                self.decoder._modules[decoder_base_model_prefix],
+                self.decoder.base_model_prefix,
+                "text_encoder",
+            )
+            # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class
+            # attributed not an instance member, therefore modifying it will modify the entire class
+            # Leading to issues on subsequent calls by different tests or subsequent calls.
+            self._dynamic_tied_weights_keys = tied_weights
+
+    def get_audio_encoder(self):
+        return self.audio_encoder
+
+    def get_text_encoder(self):
+        return self.text_encoder
+
+    def get_encoder(self):
+        # get the text encoder to compute the encoder hidden-states for generation
+        return self.get_text_encoder()
+
+    def get_decoder(self):
+        return self.decoder
+
+    def get_input_embeddings(self):
+        return self.text_encoder.get_input_embeddings()
+
+    def get_output_embeddings(self):
+        return self.decoder.get_output_embeddings()
+
+    def set_output_embeddings(self, new_embeddings):
+        return self.decoder.set_output_embeddings(new_embeddings)
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        r"""
+        Example:
+
+        ```python
+        >>> from transformers import MusicgenForConditionalGeneration
+
+        >>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
+        ```"""
+
+        # At the moment fast initialization is not supported for composite models
+        if kwargs.get("_fast_init", False):
+            logger.warning(
+                "Fast initialization is currently not supported for MusicgenForConditionalGeneration. "
+                "Falling back to slow initialization..."
+            )
+        kwargs["_fast_init"] = False
+
+        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+
+    @classmethod
+    def from_sub_models_pretrained(
+        cls,
+        text_encoder_pretrained_model_name_or_path: str = None,
+        audio_encoder_pretrained_model_name_or_path: str = None,
+        decoder_pretrained_model_name_or_path: str = None,
+        *model_args,
+        **kwargs,
+    ) -> PreTrainedModel:
+        r"""
+        Instantiate a text encoder, an audio encoder, and a MusicGen decoder from one, two or three base classes of the
+        library from pretrained model checkpoints.
+
+
+        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
+        the model, you need to first set it back in training mode with `model.train()`.
+
+        Params:
+            text_encoder_pretrained_model_name_or_path (`str`, *optional*):
+                Information necessary to initiate the text encoder. Can be either:
+
+                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+                    - A path to a *directory* containing model weights saved using
+                      [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+
+            audio_encoder_pretrained_model_name_or_path (`str`, *optional*):
+                Information necessary to initiate the audio encoder. Can be either:
+
+                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+                    - A path to a *directory* containing model weights saved using
+                      [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+
+            decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
+                Information necessary to initiate the decoder. Can be either:
+
+                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+                    - A path to a *directory* containing model weights saved using
+                      [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+
+            model_args (remaining positional arguments, *optional*):
+                All remaining positional arguments will be passed to the underlying model's `__init__` method.
+
+            kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
+                `output_attentions=True`).
+
+                - To update the text encoder configuration, use the prefix *text_encoder_* for each configuration
+                  parameter.
+                - To update the audio encoder configuration, use the prefix *audio_encoder_* for each configuration
+                  parameter.
+                - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
+                - To update the parent model configuration, do not use a prefix for each configuration parameter.
+
+                Behaves differently depending on whether a `config` is provided or automatically loaded.
+
+        Example:
+
+        ```python
+        >>> from transformers import MusicgenForConditionalGeneration
+
+        >>> # initialize a musicgen model from a t5 text encoder, encodec audio encoder, and musicgen decoder
+        >>> model = MusicgenForConditionalGeneration.from_sub_models_pretrained(
+        ...     text_encoder_pretrained_model_name_or_path="google-t5/t5-base",
+        ...     audio_encoder_pretrained_model_name_or_path="facebook/encodec_24khz",
+        ...     decoder_pretrained_model_name_or_path="facebook/musicgen-small",
+        ... )
+        >>> # saving model after fine-tuning
+        >>> model.save_pretrained("./musicgen-ft")
+        >>> # load fine-tuned model
+        >>> model = MusicgenForConditionalGeneration.from_pretrained("./musicgen-ft")
+        ```"""
+
+        kwargs_text_encoder = {
+            argument[len("text_encoder_") :]: value
+            for argument, value in kwargs.items()
+            if argument.startswith("text_encoder_")
+        }
+
+        kwargs_audio_encoder = {
+            argument[len("audio_encoder_") :]: value
+            for argument, value in kwargs.items()
+            if argument.startswith("audio_encoder_")
+        }
+
+        kwargs_decoder = {
+            argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
+        }
+
+        # remove text encoder, audio encoder and decoder kwargs from kwargs
+        for key in kwargs_text_encoder.keys():
+            del kwargs["text_encoder_" + key]
+        for key in kwargs_audio_encoder.keys():
+            del kwargs["audio_encoder_" + key]
+        for key in kwargs_decoder.keys():
+            del kwargs["decoder_" + key]
+
+        # Load and initialize the encoder and decoder
+        # The distinction between encoder and decoder at the model level is made
+        # by the value of the flag `is_decoder` that we need to set correctly.
+        text_encoder = kwargs_text_encoder.pop("model", None)
+        if text_encoder is None:
+            if text_encoder_pretrained_model_name_or_path is None:
+                raise ValueError(
+                    "If `text_encoder_model` is not defined as an argument, a `text_encoder_pretrained_model_name_or_path` has "
+                    "to be defined."
+                )
+
+            if "config" not in kwargs_text_encoder:
+                encoder_config, kwargs_text_encoder = AutoConfig.from_pretrained(
+                    text_encoder_pretrained_model_name_or_path, **kwargs_text_encoder, return_unused_kwargs=True
+                )
+
+                if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
+                    logger.info(
+                        f"Initializing {text_encoder_pretrained_model_name_or_path} as a text_encoder model "
+                        "from a decoder model. Cross-attention and casual mask are disabled."
+                    )
+                    encoder_config.is_decoder = False
+                    encoder_config.add_cross_attention = False
+
+                kwargs_text_encoder["config"] = encoder_config
+
+            text_encoder = AutoModel.from_pretrained(
+                text_encoder_pretrained_model_name_or_path, *model_args, **kwargs_text_encoder
+            )
+
+        audio_encoder = kwargs_audio_encoder.pop("model", None)
+        if audio_encoder is None:
+            if audio_encoder_pretrained_model_name_or_path is None:
+                raise ValueError(
+                    "If `audio_encoder_model` is not defined as an argument, an `audio_encoder_pretrained_model_name_or_path` has "
+                    "to be defined."
+                )
+
+            if "config" not in kwargs_audio_encoder:
+                encoder_config, kwargs_audio_encoder = AutoConfig.from_pretrained(
+                    audio_encoder_pretrained_model_name_or_path, **kwargs_audio_encoder, return_unused_kwargs=True
+                )
+
+                if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
+                    logger.info(
+                        f"Initializing {audio_encoder_pretrained_model_name_or_path} as an audio_encoder model "
+                        "from a decoder model. Cross-attention and casual mask are disabled."
+                    )
+                    encoder_config.is_decoder = False
+                    encoder_config.add_cross_attention = False
+
+                kwargs_audio_encoder["config"] = encoder_config
+
+            audio_encoder = AutoModel.from_pretrained(
+                audio_encoder_pretrained_model_name_or_path, *model_args, **kwargs_audio_encoder
+            )
+
+        decoder = kwargs_decoder.pop("model", None)
+        if decoder is None:
+            if decoder_pretrained_model_name_or_path is None:
+                raise ValueError(
+                    "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
+                    "to be defined."
+                )
+
+            if "config" not in kwargs_decoder:
+                decoder_config, kwargs_decoder = AutoConfig.from_pretrained(
+                    decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
+                )
+
+                if isinstance(decoder_config, MusicgenConfig):
+                    decoder_config = decoder_config.decoder
+
+                if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
+                    logger.info(
+                        f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
+                        f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
+                        f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
+                    )
+                    decoder_config.is_decoder = True
+                    decoder_config.add_cross_attention = True
+
+                kwargs_decoder["config"] = decoder_config
+
+            if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
+                logger.warning(
+                    f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
+                    f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
+                    "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
+                    "passed to `.from_sub_models_pretrained(...)` are set to `True` or do not pass a "
+                    "`decoder_config` to `.from_sub_models_pretrained(...)`"
+                )
+
+            decoder = MusicgenForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
+
+        # instantiate config with corresponding kwargs
+        config = MusicgenConfig.from_sub_models_config(
+            text_encoder.config, audio_encoder.config, decoder.config, **kwargs
+        )
+        return cls(text_encoder=text_encoder, audio_encoder=audio_encoder, decoder=decoder, config=config)
+
+    @add_start_docstrings_to_model_forward(MUSICGEN_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.BoolTensor] = None,
+        input_values: Optional[torch.FloatTensor] = None,
+        padding_mask: Optional[torch.BoolTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.BoolTensor] = None,
+        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
+        past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ) -> Union[Tuple, Seq2SeqLMOutput]:
+        r"""
+        Returns:
+
+        Examples:
+        ```python
+        >>> from transformers import AutoProcessor, MusicgenForConditionalGeneration
+        >>> import torch
+
+        >>> processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
+        >>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
+
+        >>> inputs = processor(
+        ...     text=["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"],
+        ...     padding=True,
+        ...     return_tensors="pt",
+        ... )
+
+        >>> pad_token_id = model.generation_config.pad_token_id
+        >>> decoder_input_ids = (
+        ...     torch.ones((inputs.input_ids.shape[0] * model.decoder.num_codebooks, 1), dtype=torch.long)
+        ...     * pad_token_id
+        ... )
+
+        >>> logits = model(**inputs, decoder_input_ids=decoder_input_ids).logits
+        >>> logits.shape  # (bsz * num_codebooks, tgt_len, vocab_size)
+        torch.Size([8, 1, 2048])
+        ```"""
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        kwargs_text_encoder = {
+            argument[len("text_encoder_")]: value
+            for argument, value in kwargs.items()
+            if argument.startswith("text_encoder_")
+        }
+
+        kwargs_audio_encoder = {
+            argument[len("audio_encoder_")]: value
+            for argument, value in kwargs.items()
+            if argument.startswith("audio_encoder_")
+        }
+
+        kwargs_decoder = {
+            argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
+        }
+
+        if encoder_outputs is None:
+            encoder_outputs = self.text_encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                inputs_embeds=inputs_embeds,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+                **kwargs_text_encoder,
+            )
+        elif isinstance(encoder_outputs, tuple):
+            encoder_outputs = BaseModelOutput(*encoder_outputs)
+
+        encoder_hidden_states = encoder_outputs[0]
+
+        # optionally project encoder_hidden_states
+        if (
+            self.text_encoder.config.hidden_size != self.decoder.config.hidden_size
+            and self.decoder.config.cross_attention_hidden_size is None
+        ):
+            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
+
+        if attention_mask is not None:
+            encoder_hidden_states = encoder_hidden_states * attention_mask[..., None]
+
+        if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
+            decoder_input_ids = shift_tokens_right(
+                labels, self.config.decoder.pad_token_id, self.config.decoder.decoder_start_token_id
+            )
+
+        elif decoder_input_ids is None and decoder_inputs_embeds is None:
+            audio_encoder_outputs = self.audio_encoder(
+                input_values=input_values,
+                padding_mask=padding_mask,
+                **kwargs_audio_encoder,
+            )
+            audio_codes = audio_encoder_outputs.audio_codes
+            frames, bsz, codebooks, seq_len = audio_codes.shape
+            if frames != 1:
+                raise ValueError(
+                    f"Expected 1 frame in the audio code outputs, got {frames} frames. Ensure chunking is "
+                    "disabled by setting `chunk_length=None` in the audio encoder."
+                )
+
+            if self.config.decoder.audio_channels == 2 and audio_codes.shape[2] == self.decoder.num_codebooks // 2:
+                # mono input through encodec that we convert to stereo
+                audio_codes = audio_codes.repeat_interleave(2, dim=2)
+
+            decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len)
+
+        # Decode
+        decoder_outputs = self.decoder(
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=attention_mask,
+            inputs_embeds=decoder_inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            use_cache=use_cache,
+            past_key_values=past_key_values,
+            return_dict=return_dict,
+            labels=labels,
+            **kwargs_decoder,
+        )
+
+        if not return_dict:
+            return decoder_outputs + encoder_outputs
+
+        return Seq2SeqLMOutput(
+            loss=decoder_outputs.loss,
+            logits=decoder_outputs.logits,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        decoder_input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        head_mask=None,
+        decoder_attention_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
+        use_cache=None,
+        encoder_outputs=None,
+        decoder_delay_pattern_mask=None,
+        guidance_scale=None,
+        **kwargs,
+    ):
+        # Overwritten -- MusicGen has custom processing
+        if decoder_delay_pattern_mask is None:
+            decoder_input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
+                decoder_input_ids,
+                self.generation_config.pad_token_id,
+                max_length=self.generation_config.max_length,
+            )
+
+        # apply the delay pattern mask
+        decoder_input_ids = self.decoder.apply_delay_pattern_mask(decoder_input_ids, decoder_delay_pattern_mask)
+
+        if guidance_scale is not None and guidance_scale > 1:
+            # for classifier free guidance we need to replicate the decoder args across the batch dim (we'll split these
+            # before sampling)
+            decoder_input_ids = decoder_input_ids.repeat((2, 1))
+            if decoder_attention_mask is not None:
+                decoder_attention_mask = decoder_attention_mask.repeat((2, 1))
+
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if decoder_input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
+
+        return {
+            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
+            "encoder_outputs": encoder_outputs,
+            "past_key_values": past_key_values,
+            "decoder_input_ids": decoder_input_ids,
+            "attention_mask": attention_mask,
+            "decoder_attention_mask": decoder_attention_mask,
+            "head_mask": head_mask,
+            "decoder_head_mask": decoder_head_mask,
+            "cross_attn_head_mask": cross_attn_head_mask,
+            "use_cache": use_cache,
+        }
+
+    def _prepare_decoder_input_ids_for_generation(
+        self,
+        batch_size: int,
+        model_input_name: str,
+        model_kwargs: Dict[str, torch.Tensor],
+        decoder_start_token_id: int = None,
+        bos_token_id: int = None,
+        device: torch.device = None,
+    ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]:
+        """Prepares `decoder_input_ids` for generation with encoder-decoder models"""
+
+        # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming,
+        # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input.
+        if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
+            decoder_input_ids = model_kwargs.pop("decoder_input_ids")
+        elif "input_ids" in model_kwargs and model_input_name != "input_ids":
+            decoder_input_ids = model_kwargs.pop("input_ids")
+        else:
+            decoder_input_ids = None
+
+        # 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that.
+        decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
+        if device is None:
+            device = self.device
+        decoder_input_ids_start = (
+            torch.ones((batch_size * self.decoder.num_codebooks, 1), dtype=torch.long, device=device)
+            * decoder_start_token_id
+        )
+
+        # no user input -> use decoder_start_token_id as decoder_input_ids
+        if decoder_input_ids is None:
+            decoder_input_ids = decoder_input_ids_start
+
+        # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
+        # decoder_attention_mask if provided)
+        elif (decoder_input_ids[..., 0] != decoder_start_token_id).all().item():
+            decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1)
+            if "decoder_attention_mask" in model_kwargs:
+                decoder_attention_mask = model_kwargs["decoder_attention_mask"]
+                decoder_attention_mask = torch.cat(
+                    (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask),
+                    dim=-1,
+                )
+                model_kwargs["decoder_attention_mask"] = decoder_attention_mask
+
+        return decoder_input_ids, model_kwargs
+
+    def _prepare_text_encoder_kwargs_for_generation(
+        self,
+        inputs_tensor: torch.Tensor,
+        model_kwargs,
+        model_input_name: Optional[str],
+        generation_config: GenerationConfig,
+    ) -> Dict[str, Any]:
+        # 1. get text encoder
+        encoder = self.get_text_encoder()
+        # Compatibility with Accelerate big model inference: we need the encoder to outputs stuff on the same device
+        # as the inputs.
+        if hasattr(encoder, "_hf_hook"):
+            encoder._hf_hook.io_same_device = True
+
+        # 2. Prepare encoder args and encoder kwargs from model kwargs.
+        irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
+        encoder_kwargs = {
+            argument: value
+            for argument, value in model_kwargs.items()
+            if not any(argument.startswith(p) for p in irrelevant_prefix)
+        }
+        encoder_signature = set(inspect.signature(encoder.forward).parameters)
+        encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
+        if not encoder_accepts_wildcard:
+            encoder_kwargs = {
+                argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
+            }
+        encoder_kwargs["output_attentions"] = generation_config.output_attentions
+        encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states
+        guidance_scale = generation_config.guidance_scale
+
+        # 3. make sure that encoder returns `ModelOutput`
+        model_input_name = model_input_name if model_input_name is not None else self.text_encoder.main_input_name
+        encoder_kwargs["return_dict"] = True
+        encoder_kwargs[model_input_name] = inputs_tensor
+        last_hidden_state = encoder(**encoder_kwargs).last_hidden_state
+
+        # for classifier free guidance we need to add a 'null' input to our encoder hidden states
+        if guidance_scale is not None and guidance_scale > 1:
+            last_hidden_state = torch.concatenate([last_hidden_state, torch.zeros_like(last_hidden_state)], dim=0)
+            if "attention_mask" in model_kwargs:
+                model_kwargs["attention_mask"] = torch.concatenate(
+                    [model_kwargs["attention_mask"], torch.zeros_like(model_kwargs["attention_mask"])], dim=0
+                )
+
+        model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=last_hidden_state)
+
+        return model_kwargs
+
+    def _prepare_audio_encoder_kwargs_for_generation(
+        self, input_values, model_kwargs, model_input_name: Optional[str] = None
+    ):
+        # 1. get audio encoder
+        encoder = self.get_audio_encoder()
+        # Compatibility with Accelerate big model inference: we need the encoder to outputs stuff on the same device
+        # as the inputs.
+        if hasattr(encoder, "_hf_hook"):
+            encoder._hf_hook.io_same_device = True
+
+        # 2. Prepare encoder args and encoder kwargs from model kwargs.
+        irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
+        encoder_kwargs = {
+            argument: value
+            for argument, value in model_kwargs.items()
+            if not any(argument.startswith(p) for p in irrelevant_prefix)
+        }
+        encoder_signature = set(inspect.signature(encoder.forward).parameters)
+        encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
+        if not encoder_accepts_wildcard:
+            encoder_kwargs = {
+                argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
+            }
+
+        # 3. make sure that encoder returns `ModelOutput`
+        model_input_name = model_input_name if model_input_name is not None else self.audio_encoder.main_input_name
+        encoder_kwargs["return_dict"] = True
+
+        if self.decoder.config.audio_channels == 1:
+            encoder_kwargs[model_input_name] = input_values
+            audio_encoder_outputs = encoder.encode(**encoder_kwargs)
+            audio_codes = audio_encoder_outputs.audio_codes
+            audio_scales = audio_encoder_outputs.audio_scales
+
+            frames, bsz, codebooks, seq_len = audio_codes.shape
+
+        else:
+            if input_values.shape[1] != 2:
+                raise ValueError(
+                    f"Expected stereo audio (2-channels) but example has {input_values.shape[1]} channel."
+                )
+
+            encoder_kwargs[model_input_name] = input_values[:, :1, :]
+            audio_encoder_outputs_left = encoder.encode(**encoder_kwargs)
+            audio_codes_left = audio_encoder_outputs_left.audio_codes
+            audio_scales_left = audio_encoder_outputs_left.audio_scales
+
+            encoder_kwargs[model_input_name] = input_values[:, 1:, :]
+            audio_encoder_outputs_right = encoder.encode(**encoder_kwargs)
+            audio_codes_right = audio_encoder_outputs_right.audio_codes
+            audio_scales_right = audio_encoder_outputs_right.audio_scales
+
+            frames, bsz, codebooks, seq_len = audio_codes_left.shape
+            # copy alternating left/right channel codes into stereo codebook
+            audio_codes = audio_codes_left.new_ones((frames, bsz, 2 * codebooks, seq_len))
+
+            audio_codes[:, :, ::2, :] = audio_codes_left
+            audio_codes[:, :, 1::2, :] = audio_codes_right
+
+            if audio_scales_left != [None] or audio_scales_right != [None]:
+                audio_scales = torch.stack([audio_scales_left, audio_scales_right], dim=1)
+            else:
+                audio_scales = [None] * bsz
+
+        if frames != 1:
+            raise ValueError(
+                f"Expected 1 frame in the audio code outputs, got {frames} frames. Ensure chunking is "
+                "disabled by setting `chunk_length=None` in the audio encoder."
+            )
+
+        decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len)
+
+        model_kwargs["decoder_input_ids"] = decoder_input_ids
+        model_kwargs["audio_scales"] = audio_scales
+        return model_kwargs
+
+    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+        return shift_tokens_right(labels, self.config.decoder.pad_token_id, self.config.decoder.bos_token_id)
+
+    def resize_token_embeddings(self, *args, **kwargs):
+        raise NotImplementedError(
+            "Resizing the embedding layers via the EncoderDecoderModel directly is not supported. Please use the"
+            " respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or"
+            " model.decoder.resize_token_embeddings(...))"
+        )
+
+    def freeze_audio_encoder(self):
+        """
+        Freeze the audio encoder weights.
+        """
+        for param in self.audio_encoder.parameters():
+            param.requires_grad = False
+        self.audio_encoder._requires_grad = False
+
+    def freeze_text_encoder(self):
+        """
+        Freeze the text encoder weights.
+        """
+        for param in self.text_encoder.parameters():
+            param.requires_grad = False
+        self.text_encoder._requires_grad = False
+
+    def _maybe_initialize_input_ids_for_generation(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        bos_token_id: Optional[int] = None,
+        model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> torch.LongTensor:
+        """Initializes input ids for generation, if necessary."""
+        if inputs is not None:
+            return inputs
+
+        encoder_outputs = model_kwargs.get("encoder_outputs")
+        if encoder_outputs is not None:
+            # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
+            shape = encoder_outputs[0].size()[:-1]
+            return torch.ones(shape, dtype=torch.long, device=self.device) * -100
+
+        if bos_token_id is None:
+            raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
+
+        # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with
+        # soft-prompting or in multimodal implementations built on top of decoder-only language models.
+        batch_size = 1
+        for value in model_kwargs.values():
+            if isinstance(value, torch.Tensor):
+                batch_size = value.shape[0]
+                break
+        return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
+
+    def _get_decoder_start_token_id(
+        self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None
+    ) -> int:
+        decoder_start_token_id = (
+            decoder_start_token_id
+            if decoder_start_token_id is not None
+            else self.generation_config.decoder_start_token_id
+        )
+        bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
+
+        if decoder_start_token_id is not None:
+            return decoder_start_token_id
+        elif bos_token_id is not None:
+            return bos_token_id
+        raise ValueError(
+            "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
+        )
+
+    @torch.no_grad()
+    def generate(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        generation_config: Optional[GenerationConfig] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        synced_gpus: Optional[bool] = None,
+        streamer: Optional["BaseStreamer"] = None,
+        **kwargs,
+    ):
+        """
+
+        Generates sequences of token ids for models with a language modeling head.
+
+        
+
+        Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
+        model's default generation configuration. You can override any `generation_config` by passing the corresponding
+        parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
+
+        For an overview of generation strategies and code examples, check out the [following
+        guide](./generation_strategies).
+
+        
+
+        Parameters:
+            inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
+                The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
+                method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
+                should be in the format `input_ids`. For encoder-decoder models *inputs* can represent any of
+                `input_ids`, `input_values`, `input_features`, or `pixel_values`.
+            generation_config (`~generation.GenerationConfig`, *optional*):
+                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+                passed to generate matching the attributes of `generation_config` will override them. If
+                `generation_config` is not provided, the default will be used, which had the following loading
+                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+                default values, whose documentation should be checked to parameterize generation.
+            logits_processor (`LogitsProcessorList`, *optional*):
+                Custom logits processors that complement the default logits processors built from arguments and
+                generation config. If a logit processor is passed that is already created with the arguments or a
+                generation config an error is thrown. This feature is intended for advanced users.
+            stopping_criteria (`StoppingCriteriaList`, *optional*):
+                Custom stopping criteria that complement the default stopping criteria built from arguments and a
+                generation config. If a stopping criteria is passed that is already created with the arguments or a
+                generation config an error is thrown. This feature is intended for advanced users.
+            synced_gpus (`bool`, *optional*, defaults to `False`):
+                Whether to continue running the while loop until max_length (needed to avoid deadlocking with
+                `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
+            streamer (`BaseStreamer`, *optional*):
+                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
+                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+            kwargs (`Dict[str, Any]`, *optional*):
+                Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
+                forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
+                specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
+
+        Return:
+            [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
+            or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
+
+                If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
+                [`~utils.ModelOutput`] types are:
+
+                    - [`~generation.GenerateDecoderOnlyOutput`],
+                    - [`~generation.GenerateBeamDecoderOnlyOutput`]
+
+                If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
+                [`~utils.ModelOutput`] types are:
+
+                    - [`~generation.GenerateEncoderDecoderOutput`],
+                    - [`~generation.GenerateBeamEncoderDecoderOutput`]
+        """
+        # 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects
+        if generation_config is None:
+            generation_config = self.generation_config
+
+        generation_config = copy.deepcopy(generation_config)
+        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
+        generation_config.validate()
+        self._validate_model_kwargs(model_kwargs.copy())
+
+        if model_kwargs.get("encoder_outputs") is not None and type(model_kwargs["encoder_outputs"]) is tuple:
+            # wrap the unconditional outputs as a BaseModelOutput for compatibility with the rest of generate
+            model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=model_kwargs["encoder_outputs"][0])
+
+        # 2. Set generation parameters if not already defined
+        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+        requires_attention_mask = "encoder_outputs" not in model_kwargs
+        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+
+        # 3. Define model inputs
+        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
+            inputs, generation_config.bos_token_id, model_kwargs
+        )
+        batch_size = inputs_tensor.shape[0]
+        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device)
+
+        # 4. Define other model kwargs
+        model_kwargs["use_cache"] = generation_config.use_cache
+        model_kwargs["guidance_scale"] = generation_config.guidance_scale
+
+        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
+            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
+                inputs_tensor, generation_config, model_kwargs
+            )
+
+        if "encoder_outputs" not in model_kwargs:
+            # encoder_outputs are created and added to `model_kwargs`
+            model_kwargs = self._prepare_text_encoder_kwargs_for_generation(
+                inputs_tensor, model_kwargs, model_input_name, generation_config
+            )
+
+        if "decoder_input_ids" not in model_kwargs and "input_values" in model_kwargs:
+            model_kwargs = self._prepare_audio_encoder_kwargs_for_generation(
+                model_kwargs["input_values"],
+                model_kwargs,
+            )
+
+        # 5. Prepare `input_ids` which will be used for auto-regressive generation
+        input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
+            batch_size=batch_size,
+            model_input_name=model_input_name,
+            model_kwargs=model_kwargs,
+            decoder_start_token_id=generation_config._decoder_start_token_tensor,
+            bos_token_id=generation_config._bos_token_tensor,
+            device=inputs_tensor.device,
+        )
+
+        # 6. Prepare `max_length` depending on other stopping criteria.
+        input_ids_length = input_ids.shape[-1]
+        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+        has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
+        generation_config = self._prepare_generated_length(
+            generation_config=generation_config,
+            has_default_max_length=has_default_max_length,
+            has_default_min_length=has_default_min_length,
+            model_input_name=model_input_name,
+            inputs_tensor=inputs_tensor,
+            input_ids_length=input_ids_length,
+        )
+
+        # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen)
+        input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
+            input_ids,
+            pad_token_id=generation_config._decoder_start_token_tensor,
+            max_length=generation_config.max_length,
+        )
+        # stash the delay mask so that we don't have to recompute in each forward pass
+        model_kwargs["decoder_delay_pattern_mask"] = decoder_delay_pattern_mask
+
+        # input_ids are ready to be placed on the streamer (if used)
+        if streamer is not None:
+            streamer.put(input_ids.cpu())
+
+        # 7. determine generation mode
+        generation_mode = generation_config.get_generation_mode()
+
+        # 8. prepare batched CFG externally (to enable coexistance with the unbatched CFG)
+        if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
+            logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
+            generation_config.guidance_scale = None
+
+        # 9. prepare distribution pre_processing samplers
+        logits_processor = self._get_logits_processor(
+            generation_config=generation_config,
+            input_ids_seq_length=input_ids_length,
+            encoder_input_ids=inputs_tensor,
+            prefix_allowed_tokens_fn=None,
+            logits_processor=logits_processor,
+            device=input_ids.device,
+        )
+
+        # 10. prepare stopping criteria
+        stopping_criteria = self._get_stopping_criteria(
+            generation_config=generation_config, stopping_criteria=stopping_criteria
+        )
+
+        if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
+            # expand input_ids with `num_return_sequences` additional sequences per batch
+            input_ids, model_kwargs = self._expand_inputs_for_generation(
+                input_ids=input_ids,
+                expand_size=generation_config.num_return_sequences,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+                **model_kwargs,
+            )
+
+            # 11. run sample
+            outputs = self._sample(
+                input_ids,
+                logits_processor=logits_processor,
+                stopping_criteria=stopping_criteria,
+                generation_config=generation_config,
+                synced_gpus=synced_gpus,
+                streamer=streamer,
+                **model_kwargs,
+            )
+
+        else:
+            raise ValueError(
+                "Got incompatible mode for generation, should be one of greedy or sampling. "
+                "Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`."
+            )
+
+        if generation_config.return_dict_in_generate:
+            output_ids = outputs.sequences
+        else:
+            output_ids = outputs
+
+        # apply the pattern mask to the final ids
+        output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])
+
+        # revert the pattern delay mask by filtering the pad token id
+        output_ids = output_ids[output_ids != generation_config._pad_token_tensor].reshape(
+            batch_size, self.decoder.num_codebooks, -1
+        )
+
+        # append the frame dimension back to the audio codes
+        output_ids = output_ids[None, ...]
+
+        audio_scales = model_kwargs.get("audio_scales")
+        if audio_scales is None:
+            audio_scales = [None] * batch_size
+
+        if self.decoder.config.audio_channels == 1:
+            output_values = self.audio_encoder.decode(
+                output_ids,
+                audio_scales=audio_scales,
+            ).audio_values
+        else:
+            codec_outputs_left = self.audio_encoder.decode(output_ids[:, :, ::2, :], audio_scales=audio_scales)
+            output_values_left = codec_outputs_left.audio_values
+
+            codec_outputs_right = self.audio_encoder.decode(output_ids[:, :, 1::2, :], audio_scales=audio_scales)
+            output_values_right = codec_outputs_right.audio_values
+
+            output_values = torch.cat([output_values_left, output_values_right], dim=1)
+
+        if generation_config.return_dict_in_generate:
+            outputs.sequences = output_values
+            return outputs
+        else:
+            return output_values
+
+    def get_unconditional_inputs(self, num_samples=1):
+        """
+        Helper function to get null inputs for unconditional generation, enabling the model to be used without the
+        feature extractor or tokenizer.
+
+        Args:
+            num_samples (int, *optional*):
+                Number of audio samples to unconditionally generate.
+            max_new_tokens (int, *optional*):
+                Number of tokens to generate for each sample. More tokens means longer audio samples, at the expense of
+                longer inference (since more audio tokens need to be generated per sample).
+
+        Example:
+        ```python
+        >>> from transformers import MusicgenForConditionalGeneration
+
+        >>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
+
+        >>> # get the unconditional (or 'null') inputs for the model
+        >>> unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
+        >>> audio_samples = model.generate(**unconditional_inputs, max_new_tokens=256)
+        ```"""
+        last_hidden_state = torch.zeros(
+            (num_samples, 1, self.config.text_encoder.hidden_size), device=self.device, dtype=self.dtype
+        )
+
+        attention_mask = torch.zeros((num_samples, 1), device=self.device, dtype=torch.long)
+
+        return MusicgenUnconditionalInput(
+            encoder_outputs=(last_hidden_state,),
+            attention_mask=attention_mask,
+            guidance_scale=1.0,
+        )
+
+
+__all__ = ["MusicgenForConditionalGeneration", "MusicgenForCausalLM", "MusicgenModel", "MusicgenPreTrainedModel"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/musicgen/processing_musicgen.py b/.venv/lib/python3.11/site-packages/transformers/models/musicgen/processing_musicgen.py
new file mode 100644
index 0000000000000000000000000000000000000000..deebf9045b4ffb7f9e438a1892fe651c3e2f54d6
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/musicgen/processing_musicgen.py
@@ -0,0 +1,144 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Text/audio processor class for MusicGen
+"""
+
+from typing import List, Optional
+
+import numpy as np
+
+from ...processing_utils import ProcessorMixin
+from ...utils import to_numpy
+
+
+class MusicgenProcessor(ProcessorMixin):
+    r"""
+    Constructs a MusicGen processor which wraps an EnCodec feature extractor and a T5 tokenizer into a single processor
+    class.
+
+    [`MusicgenProcessor`] offers all the functionalities of [`EncodecFeatureExtractor`] and [`TTokenizer`]. See
+    [`~MusicgenProcessor.__call__`] and [`~MusicgenProcessor.decode`] for more information.
+
+    Args:
+        feature_extractor (`EncodecFeatureExtractor`):
+            An instance of [`EncodecFeatureExtractor`]. The feature extractor is a required input.
+        tokenizer (`T5Tokenizer`):
+            An instance of [`T5Tokenizer`]. The tokenizer is a required input.
+    """
+
+    feature_extractor_class = "EncodecFeatureExtractor"
+    tokenizer_class = ("T5Tokenizer", "T5TokenizerFast")
+
+    def __init__(self, feature_extractor, tokenizer):
+        super().__init__(feature_extractor, tokenizer)
+        self.current_processor = self.feature_extractor
+        self._in_target_context_manager = False
+
+    def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
+        return self.tokenizer.get_decoder_prompt_ids(task=task, language=language, no_timestamps=no_timestamps)
+
+    def __call__(self, *args, **kwargs):
+        """
+        Forwards the `audio` argument to EncodecFeatureExtractor's [`~EncodecFeatureExtractor.__call__`] and the `text`
+        argument to [`~T5Tokenizer.__call__`]. Please refer to the doctsring of the above two methods for more
+        information.
+        """
+        # For backward compatibility
+        if self._in_target_context_manager:
+            return self.current_processor(*args, **kwargs)
+
+        audio = kwargs.pop("audio", None)
+        sampling_rate = kwargs.pop("sampling_rate", None)
+        text = kwargs.pop("text", None)
+        if len(args) > 0:
+            audio = args[0]
+            args = args[1:]
+
+        if audio is None and text is None:
+            raise ValueError("You need to specify either an `audio` or `text` input to process.")
+
+        if text is not None:
+            inputs = self.tokenizer(text, **kwargs)
+
+        if audio is not None:
+            audio_inputs = self.feature_extractor(audio, *args, sampling_rate=sampling_rate, **kwargs)
+
+        if audio is None:
+            return inputs
+
+        elif text is None:
+            return audio_inputs
+
+        else:
+            inputs["input_values"] = audio_inputs["input_values"]
+            if "padding_mask" in audio_inputs:
+                inputs["padding_mask"] = audio_inputs["padding_mask"]
+            return inputs
+
+    def batch_decode(self, *args, **kwargs):
+        """
+        This method is used to decode either batches of audio outputs from the MusicGen model, or batches of token ids
+        from the tokenizer. In the case of decoding token ids, this method forwards all its arguments to T5Tokenizer's
+        [`~PreTrainedTokenizer.batch_decode`]. Please refer to the docstring of this method for more information.
+        """
+        audio_values = kwargs.pop("audio", None)
+        padding_mask = kwargs.pop("padding_mask", None)
+
+        if len(args) > 0:
+            audio_values = args[0]
+            args = args[1:]
+
+        if audio_values is not None:
+            return self._decode_audio(audio_values, padding_mask=padding_mask)
+        else:
+            return self.tokenizer.batch_decode(*args, **kwargs)
+
+    def decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to T5Tokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to the
+        docstring of this method for more information.
+        """
+        return self.tokenizer.decode(*args, **kwargs)
+
+    def _decode_audio(self, audio_values, padding_mask: Optional = None) -> List[np.ndarray]:
+        """
+        This method strips any padding from the audio values to return a list of numpy audio arrays.
+        """
+        audio_values = to_numpy(audio_values)
+        bsz, channels, seq_len = audio_values.shape
+
+        if padding_mask is None:
+            return list(audio_values)
+
+        padding_mask = to_numpy(padding_mask)
+
+        # match the sequence length of the padding mask to the generated audio arrays by padding with the **non-padding**
+        # token (so that the generated audio values are **not** treated as padded tokens)
+        difference = seq_len - padding_mask.shape[-1]
+        padding_value = 1 - self.feature_extractor.padding_value
+        padding_mask = np.pad(padding_mask, ((0, 0), (0, difference)), "constant", constant_values=padding_value)
+
+        audio_values = audio_values.tolist()
+        for i in range(bsz):
+            sliced_audio = np.asarray(audio_values[i])[
+                padding_mask[i][None, :] != self.feature_extractor.padding_value
+            ]
+            audio_values[i] = sliced_audio.reshape(channels, -1)
+
+        return audio_values
+
+
+__all__ = ["MusicgenProcessor"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/olmoe/__init__.py b/.venv/lib/python3.11/site-packages/transformers/models/olmoe/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0e18d77cbad11bf218b6cd304f703b4d7935eef
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/olmoe/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_olmoe import *
+    from .modeling_olmoe import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/olmoe/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/olmoe/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..716df2b0da38cac8b121d452cceb4d471badb163
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/olmoe/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/olmoe/__pycache__/configuration_olmoe.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/olmoe/__pycache__/configuration_olmoe.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aceb1ebbcc9839e5c491178ab2a1c93dfe91e604
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/olmoe/__pycache__/configuration_olmoe.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/olmoe/__pycache__/modeling_olmoe.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/olmoe/__pycache__/modeling_olmoe.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..209c76a1f741b292cbb4a60322e50d7bf549c239
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/olmoe/__pycache__/modeling_olmoe.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/olmoe/configuration_olmoe.py b/.venv/lib/python3.11/site-packages/transformers/models/olmoe/configuration_olmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f24d5523bafd0398e6985959ffdd223261cc5ca
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/olmoe/configuration_olmoe.py
@@ -0,0 +1,182 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""OLMoE model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+
+
+class OlmoeConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`OlmoeModel`]. It is used to instantiate an OLMoE
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the [allenai/OLMoE-1B-7B-0924](https://huggingface.co/allenai/OLMoE-1B-7B-0924).
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 50304):
+            Vocabulary size of the OLMoE model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`OlmoeModel`]
+        hidden_size (`int`, *optional*, defaults to 2048):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 2048):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 16):
+            Number of hidden layers in the Transformer decoder.
+        num_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        num_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details checkout [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
+            `num_attention_heads`.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 4096):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        pad_token_id (`int`, *optional*, defaults to 1):
+            Padding token id.
+        bos_token_id (`int`, *optional*):
+            Beginning of stream token id.
+        eos_token_id (`int`, *optional*, defaults to 50279):
+            End of stream token id.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to tie weight embeddings
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
+            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
+            `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
+            `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
+            these scaling strategies behave:
+            https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
+            experimental feature, subject to breaking API changes in future versions.
+        attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        clip_qkv (`float`, *optional*):
+            If not `None`, elements of query, key and value attention states are clipped so that their
+            absolute value does not exceed this value.
+        num_experts_per_tok (`int`, *optional*, defaults to 8):
+            Number of selected experts.
+        num_experts (`int`, *optional*, defaults to 64):
+            Number of routed experts.
+        output_router_logits (`bool`, *optional*, defaults to `False`):
+            Whether or not the router logits should be returned by the model. Enabeling this will also
+            allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
+        router_aux_loss_coef (`float`, *optional*, defaults to 0.01):
+            The aux loss factor for the total loss.
+        norm_topk_prob (`bool`, *optional*, defaults to `False`):
+            Whether to normalize the topk probabilities.
+
+    ```python
+    >>> from transformers import OlmoeModel, OlmoeConfig
+
+    >>> # Initializing a OLMoE 7B A1B style configuration
+    >>> configuration = OlmoeConfig()
+
+    >>> # Initializing a model from the OLMoE 7B A1B style configuration
+    >>> model = OlmoeModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "olmoe"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        vocab_size=50304,
+        hidden_size=2048,
+        intermediate_size=2048,
+        num_hidden_layers=16,
+        num_attention_heads=16,
+        num_key_value_heads=None,
+        hidden_act="silu",
+        max_position_embeddings=4096,
+        initializer_range=0.02,
+        rms_norm_eps=1e-05,
+        use_cache=True,
+        pad_token_id=1,
+        bos_token_id=None,
+        eos_token_id=50279,
+        tie_word_embeddings=False,
+        rope_theta=10000.0,
+        rope_scaling=None,
+        attention_bias=False,
+        attention_dropout=0.0,
+        clip_qkv=None,
+        num_experts_per_tok=8,
+        num_experts=64,
+        output_router_logits=False,
+        router_aux_loss_coef=0.01,
+        norm_topk_prob=False,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout
+        self.clip_qkv = clip_qkv
+        self.num_experts_per_tok = num_experts_per_tok
+        self.num_experts = num_experts
+        self.output_router_logits = output_router_logits
+        self.router_aux_loss_coef = router_aux_loss_coef
+        self.norm_topk_prob = norm_topk_prob
+        # Validate the correctness of rotary position embeddings parameters
+        # BC: if there is a 'type' field, move it to 'rope_type'.
+        if self.rope_scaling is not None and "type" in self.rope_scaling:
+            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+        rope_config_validation(self)
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+
+
+__all__ = ["OlmoeConfig"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/olmoe/modeling_olmoe.py b/.venv/lib/python3.11/site-packages/transformers/models/olmoe/modeling_olmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c78138c1a035155564955fdf33819a19c6de64a
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/olmoe/modeling_olmoe.py
@@ -0,0 +1,1299 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch OLMoE model."""
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, StaticCache
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_outputs import (
+    MoeCausalLMOutputWithPast,
+    MoeModelOutputWithPast,
+)
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import ALL_LAYERNORM_LAYERS
+from ...utils import (
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    is_flash_attn_2_available,
+    is_flash_attn_greater_or_equal_2_10,
+    logging,
+    replace_return_docstrings,
+)
+from .configuration_olmoe import OlmoeConfig
+
+
+if is_flash_attn_2_available():
+    from ...modeling_flash_attention_utils import _flash_attention_forward
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "OlmoeConfig"
+
+
+# Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func
+def load_balancing_loss_func(
+    gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],
+    num_experts: Optional[int] = None,
+    top_k=2,
+    attention_mask: Optional[torch.Tensor] = None,
+) -> Union[torch.Tensor, int]:
+    r"""
+    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
+
+    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
+    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
+    experts is too unbalanced.
+
+    Args:
+        gate_logits:
+            Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
+            shape [batch_size X sequence_length, num_experts].
+        num_experts:
+            Number of experts
+        top_k:
+            The number of experts to route per-token, can be also interpreted as the `top-k` routing
+            parameter.
+        attention_mask (`torch.Tensor`, *optional*):
+            The attention_mask used in forward function
+            shape [batch_size X sequence_length] if not None.
+
+    Returns:
+        The auxiliary loss.
+    """
+    if gate_logits is None or not isinstance(gate_logits, tuple):
+        return 0
+
+    if isinstance(gate_logits, tuple):
+        compute_device = gate_logits[0].device
+        concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
+
+    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
+
+    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
+
+    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
+
+    if attention_mask is None:
+        # Compute the percentage of tokens routed to each experts
+        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
+
+        # Compute the average probability of routing to these experts
+        router_prob_per_expert = torch.mean(routing_weights, dim=0)
+    else:
+        batch_size, sequence_length = attention_mask.shape
+        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
+
+        # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
+        expert_attention_mask = (
+            attention_mask[None, :, :, None, None]
+            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+            .reshape(-1, top_k, num_experts)
+            .to(compute_device)
+        )
+
+        # Compute the percentage of tokens routed to each experts
+        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
+            expert_attention_mask, dim=0
+        )
+
+        # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
+        router_per_expert_attention_mask = (
+            attention_mask[None, :, :, None]
+            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
+            .reshape(-1, num_experts)
+            .to(compute_device)
+        )
+
+        # Compute the average probability of routing to these experts
+        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
+            router_per_expert_attention_mask, dim=0
+        )
+
+    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
+    return overall_loss * num_experts
+
+
+class OlmoeRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-5):
+        """
+        OlmoeRMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+ALL_LAYERNORM_LAYERS.append(OlmoeRMSNorm)
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Olmoe
+class OlmoeRotaryEmbedding(nn.Module):
+    def __init__(self, config: OlmoeConfig, device=None):
+        super().__init__()
+        # BC: "rope_type" was originally "type"
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+        else:
+            self.rope_type = "default"
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+
+        self.config = config
+        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.original_inv_freq = self.inv_freq
+
+    def _dynamic_frequency_update(self, position_ids, device):
+        """
+        dynamic RoPE layers should recompute `inv_freq` in the following situations:
+        1 - growing beyond the cached sequence length (allow scaling)
+        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
+        """
+        seq_len = torch.max(position_ids) + 1
+        if seq_len > self.max_seq_len_cached:  # growth
+            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
+            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
+            self.max_seq_len_cached = seq_len
+
+        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
+            # This .to() is needed if the model has been moved to a device after being initialized (because
+            # the buffer is automatically moved, but not the original copy)
+            self.original_inv_freq = self.original_inv_freq.to(device)
+            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+            self.max_seq_len_cached = self.original_max_seq_len
+
+    @torch.no_grad()
+    def forward(self, x, position_ids):
+        if "dynamic" in self.rope_type:
+            self._dynamic_frequency_update(position_ids, device=x.device)
+
+        # Core RoPE block
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
+        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
+        device_type = x.device.type
+        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+
+        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
+        cos = cos * self.attention_scaling
+        sin = sin * self.attention_scaling
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+# Copied from transformers.models.llama.modeling_llama.rotate_half
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+# Copied from transformers.models.olmo.modeling_olmo.OlmoMLP with Olmo->Olmoe
+class OlmoeMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, x):
+        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
+
+
+# Copied from transformers.models.llama.modeling_llama.repeat_kv
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+class OlmoeAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, config: OlmoeConfig, layer_idx: Optional[int] = None):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+        if layer_idx is None:
+            logger.warning_once(
+                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+                "when creating this class."
+            )
+
+        self.attention_dropout = config.attention_dropout
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.num_key_value_heads = config.num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        self.max_position_embeddings = config.max_position_embeddings
+        self.rope_theta = config.rope_theta
+        self.is_causal = True
+
+        if (self.head_dim * self.num_heads) != self.hidden_size:
+            raise ValueError(
+                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+                f" and `num_heads`: {self.num_heads})."
+            )
+
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
+        self.q_norm = OlmoeRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
+        self.k_norm = OlmoeRMSNorm(
+            (self.hidden_size // self.num_heads) * self.num_key_value_heads, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_norm(self.q_proj(hidden_states))
+        key_states = self.k_norm(self.k_proj(hidden_states))
+        value_states = self.v_proj(hidden_states)
+
+        if self.config.clip_qkv is not None:
+            query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+            key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+            value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attention_mask is not None:  # no matter the length, we just slice it
+            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+            attn_weights = attn_weights + causal_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+class OlmoeFlashAttention2(OlmoeAttention):
+    """
+    OLMoE flash attention module. This module inherits from `OlmoeAttention` as the weights of the module stays
+    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        output_attentions = False
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_norm(self.q_proj(hidden_states))
+        key_states = self.k_norm(self.k_proj(hidden_states))
+        value_states = self.v_proj(hidden_states)
+        if self.config.clip_qkv is not None:
+            query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+            key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+            value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+
+        # Flash attention requires the input to have the shape
+        # batch_size x seq_length x head_dim x hidden_dim
+        # therefore we just need to keep the original shape
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+        # to be able to avoid many of these transpose/reshape/view.
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        dropout_rate = self.attention_dropout if self.training else 0.0
+
+        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+        # therefore the input hidden states gets silently casted in float32. Hence, we need
+        # cast them back in the correct dtype just to be sure everything works as expected.
+        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+        # in fp32. (OlmoeRMSNorm handles it correctly)
+
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.q_proj.weight.dtype
+
+            logger.warning_once(
+                f"The input hidden states seems to be silently casted in float32, this might be related to"
+                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+                f" {target_dtype}."
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
+        attn_output = _flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len,
+            dropout=dropout_rate,
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+            is_causal=self.is_causal,
+        )
+
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+class OlmoeSdpaAttention(OlmoeAttention):
+    """
+    OLMoE attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+    `OlmoeAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+    SDPA API.
+    """
+
+    # Adapted from OlmoeAttention.forward
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if output_attentions:
+            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+            logger.warning_once(
+                "OlmoeModel is using OlmoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+            )
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_norm(self.q_proj(hidden_states))
+        key_states = self.k_norm(self.k_proj(hidden_states))
+        value_states = self.v_proj(hidden_states)
+
+        if self.config.clip_qkv is not None:
+            query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+            key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+            value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        causal_mask = attention_mask
+        # if attention_mask is not None and cache_position is not None:
+        if attention_mask is not None:
+            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+        # Reference: https://github.com/pytorch/pytorch/issues/112577.
+        if query_states.device.type == "cuda" and causal_mask is not None:
+            query_states = query_states.contiguous()
+            key_states = key_states.contiguous()
+            value_states = value_states.contiguous()
+
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        is_causal = True if causal_mask is None and q_len > 1 else False
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=causal_mask,
+            dropout_p=self.attention_dropout if self.training else 0.0,
+            is_causal=is_causal,
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, None, past_key_value
+
+
+OLMOE_ATTENTION_CLASSES = {
+    "eager": OlmoeAttention,
+    "flash_attention_2": OlmoeFlashAttention2,
+    "sdpa": OlmoeSdpaAttention,
+}
+
+
+class OlmoeSparseMoeBlock(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.num_experts = config.num_experts
+        self.top_k = config.num_experts_per_tok
+        self.norm_topk_prob = config.norm_topk_prob
+        self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
+        self.experts = nn.ModuleList([OlmoeMLP(config) for _ in range(self.num_experts)])
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+        if self.norm_topk_prob:
+            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        # we cast back to the input dtype
+        routing_weights = routing_weights.to(hidden_states.dtype)
+
+        final_hidden_states = torch.zeros(
+            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+        )
+
+        # One hot encode the selected experts to create an expert mask
+        # this will be used to easily index which expert is going to be selected
+        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+
+        # Loop over all available experts in the model and perform the computation on each expert
+        for expert_idx in range(self.num_experts):
+            expert_layer = self.experts[expert_idx]
+            idx, top_x = torch.where(expert_mask[expert_idx])
+
+            # Index the correct hidden states and compute the expert hidden state for
+            # the current expert. We need to make sure to multiply the output hidden
+            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+
+            # However `index_add_` only support torch tensors for indexing so we'll use
+            # the `top_x` tensor here.
+            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+        return final_hidden_states, router_logits
+
+
+class OlmoeDecoderLayer(nn.Module):
+    def __init__(self, config: OlmoeConfig, layer_idx: int):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+
+        self.self_attn = OLMOE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+
+        self.mlp = OlmoeSparseMoeBlock(config)
+        self.input_layernorm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: Optional[bool] = False,
+        output_router_logits: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*):
+                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+                query_sequence_length, key_sequence_length)` if default attention is used.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_router_logits (`bool`, *optional*):
+                Whether or not to return the logits of all the routers. They are useful for computing the router loss,
+                and should not be returned during inference.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence
+            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+                with `head_dim` being the embedding dimension of each attention head.
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+                into the model
+        """
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            **kwargs,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states, router_logits = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        if output_router_logits:
+            outputs += (router_logits,)
+
+        return outputs
+
+
+OLMOE_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`OlmoeConfig`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+    "The bare Olmoe Model outputting raw hidden-states without any specific head on top.",
+    OLMOE_START_DOCSTRING,
+)
+# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->Olmoe
+class OlmoePreTrainedModel(PreTrainedModel):
+    config_class = OlmoeConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["OlmoeDecoderLayer"]
+    _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+    _supports_flex_attn = True
+    _supports_cache_class = True
+    _supports_quantized_cache = True
+    _supports_static_cache = True
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+
+OLMOE_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+            `past_key_values`).
+
+            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+            information on the default strategy.
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.n_positions - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
+            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
+            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+            Two formats are allowed:
+            - a [`~cache_utils.Cache`] instance;
+            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
+            cache format.
+
+            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
+            legacy cache format will be returned.
+
+            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+            of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        output_router_logits (`bool`, *optional*):
+            Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
+            should not be returned during inference.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+            the complete sequence length.
+"""
+
+
+@add_start_docstrings(
+    "The bare Olmoe Model outputting raw hidden-states without any specific head on top.",
+    OLMOE_START_DOCSTRING,
+)
+# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Olmoe
+class OlmoeModel(OlmoePreTrainedModel):
+    """
+    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OlmoeDecoderLayer`]
+
+    Args:
+        config: OlmoeConfig
+    """
+
+    def __init__(self, config: OlmoeConfig):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.layers = nn.ModuleList(
+            [OlmoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+        self.norm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.rotary_emb = OlmoeRotaryEmbedding(config=config)
+        self.gradient_checkpointing = False
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.embed_tokens = value
+
+    @add_start_docstrings_to_model_forward(OLMOE_INPUTS_DOCSTRING)
+    # Ignore copy
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_router_logits: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple, MoeModelOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_router_logits = (
+            output_router_logits if output_router_logits is not None else self.config.output_router_logits
+        )
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if self.gradient_checkpointing and self.training and use_cache:
+            logger.warning_once(
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+            )
+            use_cache = False
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        # kept for BC (non `Cache` `past_key_values` inputs)
+        return_legacy_cache = False
+        if use_cache and not isinstance(past_key_values, Cache):
+            return_legacy_cache = True
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                logger.warning_once(
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+                )
+
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+        )
+
+        # embed positions
+        hidden_states = inputs_embeds
+
+        # create position embeddings to be shared across the decoder layers
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_router_logits = () if output_router_logits else None
+        next_decoder_cache = None
+
+        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
+                    hidden_states,
+                    causal_mask,
+                    position_ids,
+                    past_key_values,
+                    output_attentions,
+                    output_router_logits,
+                    use_cache,
+                    cache_position,
+                    position_embeddings,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=causal_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_values,
+                    output_attentions=output_attentions,
+                    output_router_logits=output_router_logits,
+                    use_cache=use_cache,
+                    cache_position=cache_position,
+                    position_embeddings=position_embeddings,
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+            if output_router_logits and layer_outputs[-1] is not None:
+                all_router_logits += (layer_outputs[-1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+        return MoeModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+            router_logits=all_router_logits,
+        )
+
+    def _update_causal_mask(
+        self,
+        attention_mask: torch.Tensor,
+        input_tensor: torch.Tensor,
+        cache_position: torch.Tensor,
+        past_key_values: Cache,
+        output_attentions: bool,
+    ):
+        if self.config._attn_implementation == "flash_attention_2":
+            if attention_mask is not None and 0.0 in attention_mask:
+                return attention_mask
+            return None
+
+        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+        # to infer the attention mask.
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        using_static_cache = isinstance(past_key_values, StaticCache)
+
+        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
+            if AttentionMaskConverter._ignore_causal_mask_sdpa(
+                attention_mask,
+                inputs_embeds=input_tensor,
+                past_key_values_length=past_seen_tokens,
+                is_training=self.training,
+            ):
+                return None
+
+        dtype, device = input_tensor.dtype, input_tensor.device
+        sequence_length = input_tensor.shape[1]
+        if using_static_cache:
+            target_length = past_key_values.get_max_cache_shape()
+        else:
+            target_length = (
+                attention_mask.shape[-1]
+                if isinstance(attention_mask, torch.Tensor)
+                else past_seen_tokens + sequence_length + 1
+            )
+
+        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+            attention_mask,
+            sequence_length=sequence_length,
+            target_length=target_length,
+            dtype=dtype,
+            device=device,
+            cache_position=cache_position,
+            batch_size=input_tensor.shape[0],
+        )
+
+        if (
+            self.config._attn_implementation == "sdpa"
+            and attention_mask is not None
+            and attention_mask.device.type == "cuda"
+            and not output_attentions
+        ):
+            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+            # Details: https://github.com/pytorch/pytorch/issues/110213
+            min_dtype = torch.finfo(dtype).min
+            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+        return causal_mask
+
+    @staticmethod
+    def _prepare_4d_causal_attention_mask_with_cache_position(
+        attention_mask: torch.Tensor,
+        sequence_length: int,
+        target_length: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        cache_position: torch.Tensor,
+        batch_size: int,
+        **kwargs,
+    ):
+        """
+        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+        Args:
+            attention_mask (`torch.Tensor`):
+                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+                `(batch_size, 1, query_length, key_value_length)`.
+            sequence_length (`int`):
+                The sequence length being processed.
+            target_length (`int`):
+                The target length: when generating with static cache, the mask should be as long as the static cache,
+                to account for the 0 padding, the part of the cache that is not filled yet.
+            dtype (`torch.dtype`):
+                The dtype to use for the 4D attention mask.
+            device (`torch.device`):
+                The device to plcae the 4D attention mask on.
+            cache_position (`torch.Tensor`):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            batch_size (`torch.Tensor`):
+                Batch size.
+        """
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+            causal_mask = attention_mask
+        else:
+            min_dtype = torch.finfo(dtype).min
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+            )
+            if sequence_length != 1:
+                causal_mask = torch.triu(causal_mask, diagonal=1)
+            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
+
+        return causal_mask
+
+
+class OlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = OlmoeModel(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        self.router_aux_loss_coef = config.router_aux_loss_coef
+        self.num_experts = config.num_experts
+        self.num_experts_per_tok = config.num_experts_per_tok
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
+    @add_start_docstrings_to_model_forward(OLMOE_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_router_logits: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        num_logits_to_keep: int = 0,
+        **loss_kwargs,
+    ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
+        r"""
+        Args:
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+            num_logits_to_keep (`int`, *optional*):
+                Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, OlmoeForCausalLM
+
+        >>> model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924")
+        >>> tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924")
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m'
+        ```
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_router_logits = (
+            output_router_logits if output_router_logits is not None else self.config.output_router_logits
+        )
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            output_router_logits=output_router_logits,
+            return_dict=return_dict,
+            cache_position=cache_position,
+        )
+
+        hidden_states = outputs[0]
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+
+        aux_loss = None
+        if output_router_logits:
+            aux_loss = load_balancing_loss_func(
+                outputs.router_logits if return_dict else outputs[-1],
+                self.num_experts,
+                self.num_experts_per_tok,
+                attention_mask,
+            )
+            if labels is not None:
+                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            if output_router_logits:
+                output = (aux_loss,) + output
+            return (loss,) + output if loss is not None else output
+
+        return MoeCausalLMOutputWithPast(
+            loss=loss,
+            aux_loss=aux_loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            router_logits=outputs.router_logits,
+        )
+
+
+__all__ = ["OlmoeForCausalLM", "OlmoeModel", "OlmoePreTrainedModel"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__init__.py b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4903c400f9827239f7920b7e2c9bfddfda48eecd
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_pegasus import *
+    from .modeling_flax_pegasus import *
+    from .modeling_pegasus import *
+    from .modeling_tf_pegasus import *
+    from .tokenization_pegasus import *
+    from .tokenization_pegasus_fast import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4019eb86fed9c828268939e36697a5032eb76685
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/configuration_pegasus.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/configuration_pegasus.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6a6157f2e8aac32727b728298493ea3cb048cd8f
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/configuration_pegasus.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/modeling_flax_pegasus.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/modeling_flax_pegasus.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9296509d0f21aadf0643827eecb7b19269aa2520
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/modeling_flax_pegasus.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/modeling_pegasus.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/modeling_pegasus.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..72040283b4b6eb8e9be61c58975e152922403cc5
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/modeling_pegasus.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/modeling_tf_pegasus.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/modeling_tf_pegasus.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b803946ecb775848dd2fee780d24bc4ae8f7219a
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/modeling_tf_pegasus.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/tokenization_pegasus.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/tokenization_pegasus.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c3e94382ee5736de452b7263c89904dbd0d8cd63
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/tokenization_pegasus.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/tokenization_pegasus_fast.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/tokenization_pegasus_fast.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8d46b699c5a1e250779504c7bc1d0a9feb8e6ac5
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/tokenization_pegasus_fast.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/pegasus/configuration_pegasus.py b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/configuration_pegasus.py
new file mode 100644
index 0000000000000000000000000000000000000000..23bff5d7719f45061328757c916dbfcfaaa0c365
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/configuration_pegasus.py
@@ -0,0 +1,164 @@
+# coding=utf-8
+# Copyright 2021, Google and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PEGASUS model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class PegasusConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`PegasusModel`]. It is used to instantiate an
+    PEGASUS model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the PEGASUS
+    [google/pegasus-large](https://huggingface.co/google/pegasus-large) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 50265):
+            Vocabulary size of the PEGASUS model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`PegasusModel`] or [`TFPegasusModel`].
+        d_model (`int`, *optional*, defaults to 1024):
+            Dimensionality of the layers and the pooler layer.
+        encoder_layers (`int`, *optional*, defaults to 12):
+            Number of encoder layers.
+        decoder_layers (`int`, *optional*, defaults to 12):
+            Number of decoder layers.
+        encoder_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        decoder_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        decoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
+        encoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
+        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for activations inside the fully connected layer.
+        max_position_embeddings (`int`, *optional*, defaults to 1024):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        init_std (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        encoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
+            for more details.
+        decoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
+            for more details.
+        scale_embedding (`bool`, *optional*, defaults to `False`):
+            Scale embeddings by diving by sqrt(d_model).
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models)
+        forced_eos_token_id (`int`, *optional*, defaults to 1):
+            The id of the token to force as the last generated token when `max_length` is reached. Usually set to
+            `eos_token_id`.
+
+    Example:
+
+    ```python
+    >>> from transformers import PegasusConfig, PegasusModel
+
+    >>> # Initializing a PEGASUS google/pegasus-large style configuration
+    >>> configuration = PegasusConfig()
+
+    >>> # Initializing a model (with random weights) from the google/pegasus-large style configuration
+    >>> model = PegasusModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "pegasus"
+    keys_to_ignore_at_inference = ["past_key_values"]
+    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
+
+    def __init__(
+        self,
+        vocab_size=50265,
+        max_position_embeddings=1024,
+        encoder_layers=12,
+        encoder_ffn_dim=4096,
+        encoder_attention_heads=16,
+        decoder_layers=12,
+        decoder_ffn_dim=4096,
+        decoder_attention_heads=16,
+        encoder_layerdrop=0.0,
+        decoder_layerdrop=0.0,
+        use_cache=True,
+        is_encoder_decoder=True,
+        activation_function="gelu",
+        d_model=1024,
+        dropout=0.1,
+        attention_dropout=0.0,
+        activation_dropout=0.0,
+        init_std=0.02,
+        decoder_start_token_id=0,
+        scale_embedding=False,
+        pad_token_id=0,
+        eos_token_id=1,
+        forced_eos_token_id=1,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.d_model = d_model
+        self.encoder_ffn_dim = encoder_ffn_dim
+        self.encoder_layers = encoder_layers
+        self.encoder_attention_heads = encoder_attention_heads
+        self.decoder_ffn_dim = decoder_ffn_dim
+        self.decoder_layers = decoder_layers
+        self.decoder_attention_heads = decoder_attention_heads
+        self.dropout = dropout
+        self.attention_dropout = attention_dropout
+        self.activation_dropout = activation_dropout
+        self.activation_function = activation_function
+        self.init_std = init_std
+        self.encoder_layerdrop = encoder_layerdrop
+        self.decoder_layerdrop = decoder_layerdrop
+        self.use_cache = use_cache
+        self.num_hidden_layers = encoder_layers
+        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
+        super().__init__(
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+            is_encoder_decoder=is_encoder_decoder,
+            decoder_start_token_id=decoder_start_token_id,
+            forced_eos_token_id=forced_eos_token_id,
+            **kwargs,
+        )
+
+    @property
+    def num_attention_heads(self) -> int:
+        return self.encoder_attention_heads
+
+    @property
+    def hidden_size(self) -> int:
+        return self.d_model
+
+
+__all__ = ["PegasusConfig"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/pegasus/modeling_flax_pegasus.py b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/modeling_flax_pegasus.py
new file mode 100644
index 0000000000000000000000000000000000000000..269a3268fed12770177f133b14177f7ca6347e98
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/modeling_flax_pegasus.py
@@ -0,0 +1,1532 @@
+# coding=utf-8
+# Copyright 2021, Google and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Flax PEGASUS model."""
+
+import math
+import random
+from functools import partial
+from typing import Callable, Optional, Tuple
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+import numpy as np
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+from flax.linen import combine_masks, make_causal_mask
+from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
+from jax import lax
+from jax.random import PRNGKey
+
+from ...modeling_flax_outputs import (
+    FlaxBaseModelOutput,
+    FlaxBaseModelOutputWithPastAndCrossAttentions,
+    FlaxCausalLMOutputWithCrossAttentions,
+    FlaxSeq2SeqLMOutput,
+    FlaxSeq2SeqModelOutput,
+)
+from ...modeling_flax_utils import (
+    ACT2FN,
+    FlaxPreTrainedModel,
+    add_start_docstrings_to_model_forward,
+    append_call_sample_docstring,
+    append_replace_return_docstrings,
+    overwrite_call_docstring,
+)
+from ...utils import add_start_docstrings, logging, replace_return_docstrings
+from .configuration_pegasus import PegasusConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "google/pegasus-large"
+_CONFIG_FOR_DOC = "PegasusConfig"
+
+PEGASUS_START_DOCSTRING = r"""
+    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a Flax Linen
+    [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
+    regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
+
+    Finally, this model supports inherent JAX features such as:
+
+    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
+    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
+    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
+
+    Parameters:
+        config ([`PegasusConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
+        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
+            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
+            `jax.numpy.bfloat16` (on TPUs).
+
+            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
+            specified all the computation will be performed with the given `dtype`.
+
+            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
+            parameters.**
+
+            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
+            [`~FlaxPreTrainedModel.to_bf16`].
+"""
+
+PEGASUS_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Indices of decoder input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are decoder input IDs?](../glossary#decoder-input-ids)
+        decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+            be used by default.
+
+            If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the
+            paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
+        position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.max_position_embeddings - 1]`.
+        decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
+            range `[0, config.max_position_embeddings - 1]`.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+PEGASUS_ENCODE_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.max_position_embeddings - 1]`.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+PEGASUS_DECODE_INPUTS_DOCSTRING = r"""
+    Args:
+        decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`):
+            Indices of decoder input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are decoder input IDs?](../glossary#decoder-input-ids)
+        encoder_outputs (`tuple(tuple(jnp.ndarray)`):
+            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
+            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
+            hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+        encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+            be used by default.
+
+            If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the
+            paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
+        decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
+            range `[0, config.max_position_embeddings - 1]`.
+        past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
+            Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
+            auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
+def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+    """
+    Shift input ids one token to the right.
+    """
+    shifted_input_ids = jnp.zeros_like(input_ids)
+    shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1])
+    shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id)
+
+    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
+    return shifted_input_ids
+
+
+# Copied from transformers.models.marian.modeling_flax_marian.create_sinusoidal_positions
+def create_sinusoidal_positions(n_pos, dim):
+    position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
+    sentinel = dim // 2 + dim % 2
+    out = np.zeros_like(position_enc)
+    out[:, 0:sentinel] = np.sin(position_enc[:, 0::2])
+    out[:, sentinel:] = np.cos(position_enc[:, 1::2])
+
+    return jnp.array(out)
+
+
+# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention with Bart->Pegasus
+class FlaxPegasusAttention(nn.Module):
+    config: PegasusConfig
+    embed_dim: int
+    num_heads: int
+    dropout: float = 0.0
+    causal: bool = False
+    bias: bool = True
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+    def setup(self) -> None:
+        self.head_dim = self.embed_dim // self.num_heads
+        if self.head_dim * self.num_heads != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+                f" and `num_heads`: {self.num_heads})."
+            )
+
+        dense = partial(
+            nn.Dense,
+            self.embed_dim,
+            use_bias=self.bias,
+            dtype=self.dtype,
+            kernel_init=jax.nn.initializers.normal(self.config.init_std),
+        )
+
+        self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
+        self.out_proj = dense()
+
+        self.dropout_layer = nn.Dropout(rate=self.dropout)
+
+        if self.causal:
+            self.causal_mask = make_causal_mask(
+                jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
+            )
+
+    def _split_heads(self, hidden_states):
+        return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
+
+    def _merge_heads(self, hidden_states):
+        return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
+
+    @nn.compact
+    def _concatenate_to_cache(self, key, value, query, attention_mask):
+        """
+        This function takes projected key, value states from a single input token and concatenates the states to cached
+        states from previous steps. This function is slighly adapted from the official Flax repository:
+        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
+        """
+        # detect if we're initializing by absence of existing cache data.
+        is_initialized = self.has_variable("cache", "cached_key")
+        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
+        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
+        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
+
+        if is_initialized:
+            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
+            # update key, value caches with our new 1d spatial slices
+            cur_index = cache_index.value
+            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
+            key = lax.dynamic_update_slice(cached_key.value, key, indices)
+            value = lax.dynamic_update_slice(cached_value.value, value, indices)
+            cached_key.value = key
+            cached_value.value = value
+            num_updated_cache_vectors = query.shape[1]
+            cache_index.value = cache_index.value + num_updated_cache_vectors
+            # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
+            pad_mask = jnp.broadcast_to(
+                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
+                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
+            )
+            attention_mask = combine_masks(pad_mask, attention_mask)
+        return key, value, attention_mask
+
+    def __call__(
+        self,
+        hidden_states: jnp.ndarray,
+        key_value_states: Optional[jnp.ndarray] = None,
+        attention_mask: Optional[jnp.ndarray] = None,
+        init_cache: bool = False,
+        deterministic: bool = True,
+    ) -> Tuple[jnp.ndarray]:
+        """Input shape: Batch x Time x Channel"""
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+        batch_size = hidden_states.shape[0]
+
+        # get query proj
+        query_states = self.q_proj(hidden_states)
+        # get key, value proj
+        if is_cross_attention:
+            # cross_attentions
+            key_states = self.k_proj(key_value_states)
+            value_states = self.v_proj(key_value_states)
+        else:
+            # self_attention
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+
+        query_states = self._split_heads(query_states)
+        key_states = self._split_heads(key_states)
+        value_states = self._split_heads(value_states)
+
+        # handle cache prepare causal attention mask
+        if self.causal:
+            query_length, key_length = query_states.shape[1], key_states.shape[1]
+            if self.has_variable("cache", "cached_key"):
+                mask_shift = self.variables["cache"]["cache_index"]
+                max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
+                causal_mask = lax.dynamic_slice(
+                    self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
+                )
+            else:
+                causal_mask = self.causal_mask[:, :, :query_length, :key_length]
+            causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
+
+        # combine masks if needed
+        if attention_mask is not None and self.causal:
+            attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
+            attention_mask = combine_masks(attention_mask, causal_mask)
+        elif self.causal:
+            attention_mask = causal_mask
+        elif attention_mask is not None:
+            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
+
+        # During fast autoregressive decoding, we feed one position at a time,
+        # and cache the keys and values step by step.
+        if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
+            key_states, value_states, attention_mask = self._concatenate_to_cache(
+                key_states, value_states, query_states, attention_mask
+            )
+
+        # Convert the boolean attention mask to an attention bias.
+        if attention_mask is not None:
+            # attention mask in the form of attention bias
+            attention_bias = lax.select(
+                attention_mask > 0,
+                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
+            )
+        else:
+            attention_bias = None
+
+        dropout_rng = None
+        if not deterministic and self.dropout > 0.0:
+            dropout_rng = self.make_rng("dropout")
+
+        attn_weights = dot_product_attention_weights(
+            query_states,
+            key_states,
+            bias=attention_bias,
+            dropout_rng=dropout_rng,
+            dropout_rate=self.dropout,
+            broadcast_dropout=True,
+            deterministic=deterministic,
+            dtype=self.dtype,
+            precision=None,
+        )
+
+        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
+        attn_output = self._merge_heads(attn_output)
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_weights
+
+
+# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartEncoderLayer with MBart->Pegasus
+class FlaxPegasusEncoderLayer(nn.Module):
+    config: PegasusConfig
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self) -> None:
+        self.embed_dim = self.config.d_model
+        self.self_attn = FlaxPegasusAttention(
+            config=self.config,
+            embed_dim=self.embed_dim,
+            num_heads=self.config.encoder_attention_heads,
+            dropout=self.config.attention_dropout,
+            dtype=self.dtype,
+        )
+        self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
+        self.dropout_layer = nn.Dropout(rate=self.config.dropout)
+        self.activation_fn = ACT2FN[self.config.activation_function]
+        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
+        self.fc1 = nn.Dense(
+            self.config.encoder_ffn_dim,
+            dtype=self.dtype,
+            kernel_init=jax.nn.initializers.normal(self.config.init_std),
+        )
+        self.fc2 = nn.Dense(
+            self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
+        )
+        self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
+
+    def __call__(
+        self,
+        hidden_states: jnp.ndarray,
+        attention_mask: jnp.ndarray,
+        output_attentions: bool = True,
+        deterministic: bool = True,
+    ) -> Tuple[jnp.ndarray]:
+        residual = hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+        hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask)
+        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs
+
+
+# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartEncoderLayerCollection with Bart->Pegasus
+class FlaxPegasusEncoderLayerCollection(nn.Module):
+    config: PegasusConfig
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+    def setup(self):
+        self.layers = [
+            FlaxPegasusEncoderLayer(self.config, name=str(i), dtype=self.dtype)
+            for i in range(self.config.encoder_layers)
+        ]
+        self.layerdrop = self.config.encoder_layerdrop
+
+    def __call__(
+        self,
+        hidden_states,
+        attention_mask,
+        deterministic: bool = True,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+    ):
+        all_attentions = () if output_attentions else None
+        all_hidden_states = () if output_hidden_states else None
+
+        for encoder_layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            dropout_probability = random.uniform(0, 1)
+            if not deterministic and (dropout_probability < self.layerdrop):  # skip the layer
+                layer_outputs = (None, None)
+            else:
+                layer_outputs = encoder_layer(
+                    hidden_states,
+                    attention_mask,
+                    output_attentions,
+                    deterministic,
+                )
+            hidden_states = layer_outputs[0]
+            if output_attentions:
+                all_attentions = all_attentions + (layer_outputs[1],)
+
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        outputs = (hidden_states, all_hidden_states, all_attentions)
+
+        if not return_dict:
+            return tuple(v for v in outputs if v is not None)
+
+        return FlaxBaseModelOutput(
+            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
+        )
+
+
+# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartDecoderLayer with MBart->Pegasus
+class FlaxPegasusDecoderLayer(nn.Module):
+    config: PegasusConfig
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self) -> None:
+        self.embed_dim = self.config.d_model
+        self.self_attn = FlaxPegasusAttention(
+            config=self.config,
+            embed_dim=self.embed_dim,
+            num_heads=self.config.decoder_attention_heads,
+            dropout=self.config.attention_dropout,
+            causal=True,
+            dtype=self.dtype,
+        )
+        self.dropout_layer = nn.Dropout(rate=self.config.dropout)
+        self.activation_fn = ACT2FN[self.config.activation_function]
+        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
+
+        self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
+        self.encoder_attn = FlaxPegasusAttention(
+            config=self.config,
+            embed_dim=self.embed_dim,
+            num_heads=self.config.decoder_attention_heads,
+            dropout=self.config.attention_dropout,
+            dtype=self.dtype,
+        )
+        self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
+        self.fc1 = nn.Dense(
+            self.config.decoder_ffn_dim,
+            dtype=self.dtype,
+            kernel_init=jax.nn.initializers.normal(self.config.init_std),
+        )
+        self.fc2 = nn.Dense(
+            self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
+        )
+        self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
+
+    def __call__(
+        self,
+        hidden_states: jnp.ndarray,
+        attention_mask: jnp.ndarray,
+        encoder_hidden_states: Optional[jnp.ndarray] = None,
+        encoder_attention_mask: Optional[jnp.ndarray] = None,
+        init_cache: bool = False,
+        output_attentions: bool = True,
+        deterministic: bool = True,
+    ) -> Tuple[jnp.ndarray]:
+        residual = hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights = self.self_attn(
+            hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache
+        )
+        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+        hidden_states = residual + hidden_states
+
+        # Cross-Attention Block
+        cross_attn_weights = None
+        if encoder_hidden_states is not None:
+            residual = hidden_states
+
+            hidden_states = self.encoder_attn_layer_norm(hidden_states)
+            hidden_states, cross_attn_weights = self.encoder_attn(
+                hidden_states=hidden_states,
+                key_value_states=encoder_hidden_states,
+                attention_mask=encoder_attention_mask,
+            )
+            hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+            hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights, cross_attn_weights)
+
+        return outputs
+
+
+# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderLayerCollection with Bart->Pegasus
+class FlaxPegasusDecoderLayerCollection(nn.Module):
+    config: PegasusConfig
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+    def setup(self):
+        self.layers = [
+            FlaxPegasusDecoderLayer(self.config, name=str(i), dtype=self.dtype)
+            for i in range(self.config.decoder_layers)
+        ]
+        self.layerdrop = self.config.decoder_layerdrop
+
+    def __call__(
+        self,
+        hidden_states,
+        attention_mask,
+        encoder_hidden_states: Optional[jnp.ndarray] = None,
+        encoder_attention_mask: Optional[jnp.ndarray] = None,
+        deterministic: bool = True,
+        init_cache: bool = False,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+    ):
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+
+        for decoder_layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+                # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            dropout_probability = random.uniform(0, 1)
+            if not deterministic and (dropout_probability < self.layerdrop):
+                layer_outputs = (None, None, None)
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    init_cache=init_cache,
+                    output_attentions=output_attentions,
+                    deterministic=deterministic,
+                )
+
+            hidden_states = layer_outputs[0]
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+                if encoder_hidden_states is not None:
+                    all_cross_attentions += (layer_outputs[2],)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions]
+
+        if not return_dict:
+            return tuple(v for v in outputs if v is not None)
+
+        return FlaxBaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+            cross_attentions=all_cross_attentions,
+        )
+
+
+class FlaxPegasusEncoder(nn.Module):
+    config: PegasusConfig
+    embed_tokens: nn.Embed
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+    def setup(self):
+        self.dropout_layer = nn.Dropout(rate=self.config.dropout)
+
+        embed_dim = self.config.d_model
+        self.padding_idx = self.config.pad_token_id
+        self.max_source_positions = self.config.max_position_embeddings
+        self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0
+
+        self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim)
+        self.layers = FlaxPegasusEncoderLayerCollection(self.config, self.dtype)
+        self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
+
+    def __call__(
+        self,
+        input_ids,
+        attention_mask,
+        position_ids,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+        deterministic: bool = True,
+    ):
+        input_shape = input_ids.shape
+        input_ids = input_ids.reshape(-1, input_shape[-1])
+
+        inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
+        # embed positions
+        embed_pos = jnp.take(self.embed_positions, position_ids, axis=0)
+        # explicitly cast the positions here, since self.embed_positions are not registered as parameters
+        embed_pos = embed_pos.astype(inputs_embeds.dtype)
+
+        hidden_states = inputs_embeds + embed_pos
+        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+        outputs = self.layers(
+            hidden_states,
+            attention_mask,
+            deterministic=deterministic,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        last_hidden_state = outputs[0]
+        last_hidden_state = self.layer_norm(last_hidden_state)
+
+        # update the last element in `hidden_states` after applying `layernorm` above
+        hidden_states = None
+        if output_hidden_states:
+            hidden_states = outputs[1]
+            hidden_states = hidden_states[:-1] + (last_hidden_state,)
+
+        if not return_dict:
+            outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
+            return tuple(v for v in outputs if v is not None)
+
+        return FlaxBaseModelOutput(
+            last_hidden_state=last_hidden_state,
+            hidden_states=hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+class FlaxPegasusDecoder(nn.Module):
+    config: PegasusConfig
+    embed_tokens: nn.Embed
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+    def setup(self):
+        self.dropout_layer = nn.Dropout(rate=self.config.dropout)
+
+        embed_dim = self.config.d_model
+        self.padding_idx = self.config.pad_token_id
+        self.max_target_positions = self.config.max_position_embeddings
+        self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
+
+        self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim)
+
+        self.layers = FlaxPegasusDecoderLayerCollection(self.config, self.dtype)
+        self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
+
+    def __call__(
+        self,
+        input_ids,
+        attention_mask,
+        position_ids,
+        encoder_hidden_states: Optional[jnp.ndarray] = None,
+        encoder_attention_mask: Optional[jnp.ndarray] = None,
+        init_cache: bool = False,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+        deterministic: bool = True,
+    ):
+        input_shape = input_ids.shape
+        input_ids = input_ids.reshape(-1, input_shape[-1])
+
+        inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
+        # embed positions
+        positions = jnp.take(self.embed_positions, position_ids, axis=0)
+        # explicitly cast the positions here, since self.embed_positions are not registered as parameters
+        positions = positions.astype(inputs_embeds.dtype)
+
+        hidden_states = inputs_embeds + positions
+        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+        outputs = self.layers(
+            hidden_states,
+            attention_mask,
+            encoder_hidden_states,
+            encoder_attention_mask,
+            deterministic=deterministic,
+            init_cache=init_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        last_hidden_state = outputs[0]
+        last_hidden_state = self.layer_norm(last_hidden_state)
+
+        # update the last element in `hidden_states` after applying `layernorm` above
+        hidden_states = None
+        if output_hidden_states:
+            hidden_states = outputs[1]
+            hidden_states = hidden_states[:-1] + (last_hidden_state,)
+
+        if not return_dict:
+            outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
+            return tuple(v for v in outputs if v is not None)
+
+        return FlaxBaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=last_hidden_state,
+            hidden_states=hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
+
+# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartModule with Bart->Pegasus
+class FlaxPegasusModule(nn.Module):
+    config: PegasusConfig
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+    def setup(self):
+        self.shared = nn.Embed(
+            self.config.vocab_size,
+            self.config.d_model,
+            embedding_init=jax.nn.initializers.normal(self.config.init_std),
+            dtype=self.dtype,
+        )
+
+        self.encoder = FlaxPegasusEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
+        self.decoder = FlaxPegasusDecoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
+
+    def _get_encoder_module(self):
+        return self.encoder
+
+    def _get_decoder_module(self):
+        return self.decoder
+
+    def __call__(
+        self,
+        input_ids,
+        attention_mask,
+        decoder_input_ids,
+        decoder_attention_mask,
+        position_ids,
+        decoder_position_ids,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+        deterministic: bool = True,
+    ):
+        encoder_outputs = self.encoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            deterministic=deterministic,
+        )
+
+        decoder_outputs = self.decoder(
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            position_ids=decoder_position_ids,
+            encoder_hidden_states=encoder_outputs[0],
+            encoder_attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            deterministic=deterministic,
+        )
+
+        if not return_dict:
+            return decoder_outputs + encoder_outputs
+
+        return FlaxSeq2SeqModelOutput(
+            last_hidden_state=decoder_outputs.last_hidden_state,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
+
+
+class FlaxPegasusPreTrainedModel(FlaxPreTrainedModel):
+    config_class = PegasusConfig
+    base_model_prefix: str = "model"
+    module_class: nn.Module = None
+
+    def __init__(
+        self,
+        config: PegasusConfig,
+        input_shape: Tuple[int] = (1, 1),
+        seed: int = 0,
+        dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
+        **kwargs,
+    ):
+        module = self.module_class(config=config, dtype=dtype, **kwargs)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
+
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
+        # init input tensors
+        input_ids = jnp.zeros(input_shape, dtype="i4")
+        attention_mask = jnp.ones_like(input_ids)
+        decoder_input_ids = input_ids
+        decoder_attention_mask = jnp.ones_like(input_ids)
+
+        batch_size, sequence_length = input_ids.shape
+        position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+        decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+
+        params_rng, dropout_rng = jax.random.split(rng)
+        rngs = {"params": params_rng, "dropout": dropout_rng}
+
+        random_params = self.module.init(
+            rngs,
+            input_ids,
+            attention_mask,
+            decoder_input_ids,
+            decoder_attention_mask,
+            position_ids,
+            decoder_position_ids,
+        )["params"]
+
+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params
+
+    def init_cache(self, batch_size, max_length, encoder_outputs):
+        r"""
+        Args:
+            batch_size (`int`):
+                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
+            max_length (`int`):
+                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
+                cache.
+            encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
+                `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
+                `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*)
+                is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
+                cross-attention of the decoder.
+        """
+        # init input variables to retrieve cache
+        decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
+        decoder_attention_mask = jnp.ones_like(decoder_input_ids)
+        decoder_position_ids = jnp.broadcast_to(
+            jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
+        )
+
+        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
+            decoder_module = module._get_decoder_module()
+            return decoder_module(
+                decoder_input_ids,
+                decoder_attention_mask,
+                decoder_position_ids,
+                **kwargs,
+            )
+
+        init_variables = self.module.init(
+            jax.random.PRNGKey(0),
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            decoder_position_ids=decoder_position_ids,
+            encoder_hidden_states=encoder_outputs[0],
+            init_cache=True,
+            method=_decoder_forward,  # we only need to call the decoder to init the cache
+        )
+        return unfreeze(init_variables["cache"])
+
+    @add_start_docstrings(PEGASUS_ENCODE_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=PegasusConfig)
+    def encode(
+        self,
+        input_ids: jnp.ndarray,
+        attention_mask: Optional[jnp.ndarray] = None,
+        position_ids: Optional[jnp.ndarray] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        train: bool = False,
+        params: dict = None,
+        dropout_rng: PRNGKey = None,
+    ):
+        r"""
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration
+
+        >>> model = FlaxPegasusForConditionalGeneration.from_pretrained("google/pegasus-large")
+        >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")
+
+        >>> text = "My friends are cool but they eat too many carbs."
+        >>> inputs = tokenizer(text, max_length=1024, return_tensors="np")
+        >>> encoder_outputs = model.encode(**inputs)
+        ```"""
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+        if attention_mask is None:
+            attention_mask = jnp.ones_like(input_ids)
+        if position_ids is None:
+            batch_size, sequence_length = input_ids.shape
+            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+
+        # Handle any PRNG if needed
+        rngs = {}
+        if dropout_rng is not None:
+            rngs["dropout"] = dropout_rng
+
+        def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs):
+            encode_module = module._get_encoder_module()
+            return encode_module(input_ids, attention_mask, position_ids, **kwargs)
+
+        return self.module.apply(
+            {"params": params or self.params},
+            input_ids=jnp.array(input_ids, dtype="i4"),
+            attention_mask=jnp.array(attention_mask, dtype="i4"),
+            position_ids=jnp.array(position_ids, dtype="i4"),
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            deterministic=not train,
+            rngs=rngs,
+            method=_encoder_forward,
+        )
+
+    @add_start_docstrings(PEGASUS_DECODE_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=PegasusConfig)
+    def decode(
+        self,
+        decoder_input_ids,
+        encoder_outputs,
+        encoder_attention_mask: Optional[jnp.ndarray] = None,
+        decoder_attention_mask: Optional[jnp.ndarray] = None,
+        decoder_position_ids: Optional[jnp.ndarray] = None,
+        past_key_values: dict = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        train: bool = False,
+        params: dict = None,
+        dropout_rng: PRNGKey = None,
+    ):
+        r"""
+        Returns:
+
+        Example:
+
+        ```python
+        >>> import jax.numpy as jnp
+        >>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration
+
+        >>> model = FlaxPegasusForConditionalGeneration.from_pretrained("google/pegasus-large")
+        >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")
+
+        >>> text = "My friends are cool but they eat too many carbs."
+        >>> inputs = tokenizer(text, max_length=1024, return_tensors="np")
+        >>> encoder_outputs = model.encode(**inputs)
+
+        >>> decoder_start_token_id = model.config.decoder_start_token_id
+        >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
+
+        >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
+        >>> last_decoder_hidden_states = outputs.last_hidden_state
+        ```"""
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+        encoder_hidden_states = encoder_outputs[0]
+        if encoder_attention_mask is None:
+            batch_size, sequence_length = encoder_hidden_states.shape[:2]
+            encoder_attention_mask = jnp.ones((batch_size, sequence_length))
+
+        batch_size, sequence_length = decoder_input_ids.shape
+        if decoder_attention_mask is None:
+            decoder_attention_mask = jnp.ones((batch_size, sequence_length))
+
+        if decoder_position_ids is None:
+            if past_key_values is not None:
+                raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
+
+            decoder_position_ids = jnp.broadcast_to(
+                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
+            )
+
+        # Handle any PRNG if needed
+        rngs = {}
+        if dropout_rng is not None:
+            rngs["dropout"] = dropout_rng
+
+        inputs = {"params": params or self.params}
+
+        # if past_key_values are passed then cache is already initialized a private flag init_cache has to be
+        # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
+        # it can be changed by FlaxPegasusAttention module
+        if past_key_values:
+            inputs["cache"] = past_key_values
+            mutable = ["cache"]
+        else:
+            mutable = False
+
+        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
+            decoder_module = module._get_decoder_module()
+            return decoder_module(
+                decoder_input_ids,
+                decoder_attention_mask,
+                decoder_position_ids,
+                **kwargs,
+            )
+
+        outputs = self.module.apply(
+            inputs,
+            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
+            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
+            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            deterministic=not train,
+            rngs=rngs,
+            mutable=mutable,
+            method=_decoder_forward,
+        )
+
+        # add updated cache to model output
+        if past_key_values is not None and return_dict:
+            outputs, past = outputs
+            outputs["past_key_values"] = unfreeze(past["cache"])
+            return outputs
+        elif past_key_values is not None and not return_dict:
+            outputs, past = outputs
+            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
+
+        return outputs
+
+    @add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING)
+    def __call__(
+        self,
+        input_ids: jnp.ndarray,
+        attention_mask: Optional[jnp.ndarray] = None,
+        decoder_input_ids: Optional[jnp.ndarray] = None,
+        decoder_attention_mask: Optional[jnp.ndarray] = None,
+        position_ids: Optional[jnp.ndarray] = None,
+        decoder_position_ids: Optional[jnp.ndarray] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        train: bool = False,
+        params: dict = None,
+        dropout_rng: PRNGKey = None,
+    ):
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+        # prepare encoder inputs
+        if attention_mask is None:
+            attention_mask = jnp.ones_like(input_ids)
+        if position_ids is None:
+            batch_size, sequence_length = input_ids.shape
+            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+
+        # prepare decoder inputs
+        if decoder_input_ids is None:
+            decoder_input_ids = shift_tokens_right(
+                input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id
+            )
+        if decoder_attention_mask is None:
+            decoder_attention_mask = jnp.ones_like(decoder_input_ids)
+        if decoder_position_ids is None:
+            batch_size, sequence_length = decoder_input_ids.shape
+            decoder_position_ids = jnp.broadcast_to(
+                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
+            )
+
+        # Handle any PRNG if needed
+        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
+
+        return self.module.apply(
+            {"params": params or self.params},
+            input_ids=jnp.array(input_ids, dtype="i4"),
+            attention_mask=jnp.array(attention_mask, dtype="i4"),
+            position_ids=jnp.array(position_ids, dtype="i4"),
+            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
+            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
+            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            deterministic=not train,
+            rngs=rngs,
+        )
+
+
+@add_start_docstrings(
+    "The bare Pegasus Model transformer outputting raw hidden-states without any specific head on top.",
+    PEGASUS_START_DOCSTRING,
+)
+class FlaxPegasusModel(FlaxPegasusPreTrainedModel):
+    config: PegasusConfig
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    module_class = FlaxPegasusModule
+
+
+append_call_sample_docstring(FlaxPegasusModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC)
+
+
+# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartForConditionalGenerationModule with Bart->Pegasus
+class FlaxPegasusForConditionalGenerationModule(nn.Module):
+    config: PegasusConfig
+    dtype: jnp.dtype = jnp.float32
+    bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros
+
+    def setup(self):
+        self.model = FlaxPegasusModule(config=self.config, dtype=self.dtype)
+        self.lm_head = nn.Dense(
+            self.model.shared.num_embeddings,
+            use_bias=False,
+            dtype=self.dtype,
+            kernel_init=jax.nn.initializers.normal(self.config.init_std),
+        )
+        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings))
+
+    def _get_encoder_module(self):
+        return self.model.encoder
+
+    def _get_decoder_module(self):
+        return self.model.decoder
+
+    def __call__(
+        self,
+        input_ids,
+        attention_mask,
+        decoder_input_ids,
+        decoder_attention_mask,
+        position_ids,
+        decoder_position_ids,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+        deterministic: bool = True,
+    ):
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            position_ids=position_ids,
+            decoder_position_ids=decoder_position_ids,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            deterministic=deterministic,
+        )
+
+        hidden_states = outputs[0]
+
+        if self.config.tie_word_embeddings:
+            shared_embedding = self.model.variables["params"]["shared"]["embedding"]
+            lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
+        else:
+            lm_logits = self.lm_head(hidden_states)
+
+        lm_logits += jax.lax.stop_gradient(self.final_logits_bias.astype(self.dtype))
+
+        if not return_dict:
+            output = (lm_logits,) + outputs[1:]
+            return output
+
+        return FlaxSeq2SeqLMOutput(
+            logits=lm_logits,
+            decoder_hidden_states=outputs.decoder_hidden_states,
+            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
+            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+            encoder_hidden_states=outputs.encoder_hidden_states,
+            encoder_attentions=outputs.encoder_attentions,
+        )
+
+
+@add_start_docstrings(
+    "The PEGASUS Model with a language modeling head. Can be used for summarization.", PEGASUS_START_DOCSTRING
+)
+class FlaxPegasusForConditionalGeneration(FlaxPegasusPreTrainedModel):
+    module_class = FlaxPegasusForConditionalGenerationModule
+    dtype: jnp.dtype = jnp.float32
+
+    @add_start_docstrings(PEGASUS_DECODE_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=PegasusConfig)
+    def decode(
+        self,
+        decoder_input_ids,
+        encoder_outputs,
+        encoder_attention_mask: Optional[jnp.ndarray] = None,
+        decoder_attention_mask: Optional[jnp.ndarray] = None,
+        decoder_position_ids: Optional[jnp.ndarray] = None,
+        past_key_values: dict = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        deterministic: bool = True,
+        params: dict = None,
+        dropout_rng: PRNGKey = None,
+    ):
+        r"""
+        Returns:
+
+        Example:
+
+        ```python
+        >>> import jax.numpy as jnp
+        >>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration
+
+        >>> model = FlaxPegasusForConditionalGeneration.from_pretrained("google/pegasus-large")
+        >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")
+
+        >>> text = "My friends are cool but they eat too many carbs."
+        >>> inputs = tokenizer(text, max_length=1024, return_tensors="np")
+        >>> encoder_outputs = model.encode(**inputs)
+
+        >>> decoder_start_token_id = model.config.decoder_start_token_id
+        >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
+
+        >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
+        >>> logits = outputs.logits
+        ```"""
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+        encoder_hidden_states = encoder_outputs[0]
+        if encoder_attention_mask is None:
+            batch_size, sequence_length = encoder_hidden_states.shape[:2]
+            encoder_attention_mask = jnp.ones((batch_size, sequence_length))
+
+        batch_size, sequence_length = decoder_input_ids.shape
+        if decoder_attention_mask is None:
+            decoder_attention_mask = jnp.ones((batch_size, sequence_length))
+
+        if decoder_position_ids is None:
+            if past_key_values is not None:
+                raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
+
+            decoder_position_ids = jnp.broadcast_to(
+                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
+            )
+
+        # Handle any PRNG if needed
+        rngs = {}
+        if dropout_rng is not None:
+            rngs["dropout"] = dropout_rng
+
+        inputs = {"params": params or self.params}
+
+        # if past_key_values are passed then cache is already initialized a private flag init_cache has to be
+        # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
+        # it can be changed by FlaxPegasusAttention module
+        if past_key_values:
+            inputs["cache"] = past_key_values
+            mutable = ["cache"]
+        else:
+            mutable = False
+
+        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
+            decoder_module = module._get_decoder_module()
+            outputs = decoder_module(
+                decoder_input_ids,
+                decoder_attention_mask,
+                decoder_position_ids,
+                **kwargs,
+            )
+            hidden_states = outputs[0]
+
+            if self.config.tie_word_embeddings:
+                shared_embedding = module.model.variables["params"]["shared"]["embedding"]
+                lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
+            else:
+                lm_logits = module.lm_head(hidden_states)
+
+            lm_logits += module.final_logits_bias.astype(self.dtype)
+            return lm_logits, outputs
+
+        outputs = self.module.apply(
+            inputs,
+            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
+            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
+            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            deterministic=deterministic,
+            rngs=rngs,
+            mutable=mutable,
+            method=_decoder_forward,
+        )
+
+        if past_key_values is None:
+            lm_logits, decoder_outputs = outputs
+        else:
+            (lm_logits, decoder_outputs), past = outputs
+
+        if return_dict:
+            outputs = FlaxCausalLMOutputWithCrossAttentions(
+                logits=lm_logits,
+                hidden_states=decoder_outputs.hidden_states,
+                attentions=decoder_outputs.attentions,
+                cross_attentions=decoder_outputs.cross_attentions,
+            )
+        else:
+            outputs = (lm_logits,) + decoder_outputs[1:]
+
+        # add updated cache to model output
+        if past_key_values is not None and return_dict:
+            outputs["past_key_values"] = unfreeze(past["cache"])
+            return outputs
+        elif past_key_values is not None and not return_dict:
+            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
+
+        return outputs
+
+    def prepare_inputs_for_generation(
+        self,
+        decoder_input_ids,
+        max_length,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
+        encoder_outputs=None,
+        **kwargs,
+    ):
+        # initializing the cache
+        batch_size, seq_length = decoder_input_ids.shape
+
+        past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
+        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
+        # But since the decoder uses a causal mask, those positions are masked anyways.
+        # Thus we can create a single static attention_mask here, which is more efficient for compilation
+        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
+        if decoder_attention_mask is not None:
+            position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
+            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
+        else:
+            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
+
+        return {
+            "past_key_values": past_key_values,
+            "encoder_outputs": encoder_outputs,
+            "encoder_attention_mask": attention_mask,
+            "decoder_attention_mask": extended_attention_mask,
+            "decoder_position_ids": position_ids,
+        }
+
+    def update_inputs_for_generation(self, model_outputs, model_kwargs):
+        model_kwargs["past_key_values"] = model_outputs.past_key_values
+        model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
+        return model_kwargs
+
+
+FLAX_PEGASUS_CONDITIONAL_GENERATION_DOCSTRING = """
+    Returns:
+
+    Summarization example:
+
+    ```pyton
+    >>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration
+
+    >>> model = FlaxPegasusForConditionalGeneration.from_pretrained('google/pegasus-large')
+    >>> tokenizer = AutoTokenizer.from_pretrained('google/pegasus-large')
+
+    >>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
+    >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='np')
+
+    >>> # Generate Summary
+    >>> summary_ids = model.generate(inputs['input_ids']).sequences
+    >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
+    ```
+
+    Mask filling example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration
+
+    >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")
+    >>> TXT = "My friends are  but they eat too many carbs."
+
+    >>> model = FlaxPegasusForConditionalGeneration.from_pretrained("google/pegasus-large")
+    >>> input_ids = tokenizer([TXT], return_tensors="np")["input_ids"]
+    >>> logits = model(input_ids).logits
+
+    >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
+    >>> probs = jax.nn.softmax(logits[0, masked_index], axis=0)
+    >>> values, predictions = jax.lax.top_k(probs)
+
+    >>> tokenizer.decode(predictions).split()
+    ```
+"""
+
+overwrite_call_docstring(
+    FlaxPegasusForConditionalGeneration, PEGASUS_INPUTS_DOCSTRING + FLAX_PEGASUS_CONDITIONAL_GENERATION_DOCSTRING
+)
+append_replace_return_docstrings(
+    FlaxPegasusForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
+)
+
+
+__all__ = ["FlaxPegasusForConditionalGeneration", "FlaxPegasusModel", "FlaxPegasusPreTrainedModel"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/pegasus/modeling_pegasus.py b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/modeling_pegasus.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b1cd70404c7e63147a8686cc90205463dd856a3
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/modeling_pegasus.py
@@ -0,0 +1,1634 @@
+# coding=utf-8
+# Copyright 2021, Google and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch PEGASUS model."""
+
+import copy
+import math
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
+from ...modeling_outputs import (
+    BaseModelOutput,
+    BaseModelOutputWithPastAndCrossAttentions,
+    CausalLMOutputWithCrossAttentions,
+    Seq2SeqLMOutput,
+    Seq2SeqModelOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+    add_end_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
+from .configuration_pegasus import PegasusConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "google/pegasus-large"
+_CONFIG_FOR_DOC = "PegasusConfig"
+
+
+# Copied from transformers.models.bart.modeling_bart.shift_tokens_right
+def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
+    """
+    Shift input ids one token to the right.
+    """
+    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
+    shifted_input_ids[:, 0] = decoder_start_token_id
+
+    if pad_token_id is None:
+        raise ValueError("self.model.config.pad_token_id has to be defined.")
+    # replace possible -100 values in labels by `pad_token_id`
+    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+    return shifted_input_ids
+
+
+# Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->Pegasus
+class PegasusSinusoidalPositionalEmbedding(nn.Embedding):
+    """This module produces sinusoidal positional embeddings of any length."""
+
+    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
+        super().__init__(num_positions, embedding_dim)
+        self.weight = self._init_weight(self.weight)
+
+    @staticmethod
+    def _init_weight(out: nn.Parameter) -> nn.Parameter:
+        """
+        Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
+        the 2nd half of the vector. [dim // 2:]
+        """
+        n_pos, dim = out.shape
+        position_enc = np.array(
+            [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
+        )
+        out.requires_grad = False  # set early to avoid an error in pytorch-1.8+
+        sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
+        out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
+        out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
+        out.detach_()
+        return out
+
+    @torch.no_grad()
+    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor:
+        """`input_ids_shape` is expected to be [bsz x seqlen]."""
+        bsz, seq_len = input_ids_shape[:2]
+        positions = torch.arange(
+            past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
+        )
+        return super().forward(positions)
+
+
+# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Pegasus
+class PegasusAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        is_decoder: bool = False,
+        bias: bool = True,
+        is_causal: bool = False,
+        config: Optional[PegasusConfig] = None,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.head_dim = embed_dim // num_heads
+        self.config = config
+
+        if (self.head_dim * num_heads) != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+                f" and `num_heads`: {num_heads})."
+            )
+        self.scaling = self.head_dim**-0.5
+        self.is_decoder = is_decoder
+        self.is_causal = is_causal
+
+        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+
+        bsz, tgt_len, _ = hidden_states.size()
+
+        # get query proj
+        query_states = self.q_proj(hidden_states) * self.scaling
+        # get key, value proj
+        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+        # is checking that the `sequence_length` of the `past_key_value` is the same as
+        # the provided `key_value_states` to support prefix tuning
+        if (
+            is_cross_attention
+            and past_key_value is not None
+            and past_key_value[0].shape[2] == key_value_states.shape[1]
+        ):
+            # reuse k,v, cross_attentions
+            key_states = past_key_value[0]
+            value_states = past_key_value[1]
+        elif is_cross_attention:
+            # cross_attentions
+            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+        elif past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+        else:
+            # self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_states, value_states)
+
+        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+        key_states = key_states.reshape(*proj_shape)
+        value_states = value_states.reshape(*proj_shape)
+
+        src_len = key_states.size(1)
+        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        if layer_head_mask is not None:
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+                    f" {layer_head_mask.size()}"
+                )
+            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        if output_attentions:
+            # this operation is a bit awkward, but it's required to
+            # make sure that attn_weights keeps its gradient.
+            # In order to do so, attn_weights have to be reshaped
+            # twice and have to be reused in the following
+            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+        else:
+            attn_weights_reshaped = None
+
+        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+        attn_output = torch.bmm(attn_probs, value_states)
+
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_weights_reshaped, past_key_value
+
+
+PEGASUS_ATTENTION_CLASSES = {"eager": PegasusAttention}
+
+
+# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Pegasus, MBART->PEGASUS
+class PegasusEncoderLayer(nn.Module):
+    def __init__(self, config: PegasusConfig):
+        super().__init__()
+        self.embed_dim = config.d_model
+
+        self.self_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation](
+            embed_dim=self.embed_dim,
+            num_heads=config.encoder_attention_heads,
+            dropout=config.attention_dropout,
+            config=config,
+        )
+        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.dropout = config.dropout
+        self.activation_fn = ACT2FN[config.activation_function]
+        self.activation_dropout = config.activation_dropout
+        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
+        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
+        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        layer_head_mask: torch.Tensor,
+        output_attentions: bool = False,
+    ) -> torch.Tensor:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+                `(encoder_attention_heads,)`.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+        """
+        residual = hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+        hidden_states, attn_weights, _ = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+
+        if hidden_states.dtype == torch.float16 and (
+            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
+        ):
+            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs
+
+
+# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Pegasus, MBART->PEGASUS
+class PegasusDecoderLayer(nn.Module):
+    def __init__(self, config: PegasusConfig):
+        super().__init__()
+        self.embed_dim = config.d_model
+
+        self.self_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation](
+            embed_dim=self.embed_dim,
+            num_heads=config.decoder_attention_heads,
+            dropout=config.attention_dropout,
+            is_decoder=True,
+            is_causal=True,
+            config=config,
+        )
+        self.dropout = config.dropout
+        self.activation_fn = ACT2FN[config.activation_function]
+        self.activation_dropout = config.activation_dropout
+
+        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.encoder_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation](
+            self.embed_dim,
+            config.decoder_attention_heads,
+            dropout=config.attention_dropout,
+            is_decoder=True,
+            config=config,
+        )
+        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
+        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
+        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = True,
+    ) -> torch.Tensor:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            encoder_hidden_states (`torch.FloatTensor`):
+                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
+            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+                `(encoder_attention_heads,)`.
+            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
+                size `(decoder_attention_heads,)`.
+            past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+        """
+        residual = hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+
+        # Self Attention
+        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+        # add present self-attn cache to positions 1,2 of present_key_value tuple
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            past_key_value=self_attn_past_key_value,
+            attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+
+        # Cross-Attention Block
+        cross_attn_present_key_value = None
+        cross_attn_weights = None
+        if encoder_hidden_states is not None:
+            residual = hidden_states
+            hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+            # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
+            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
+                hidden_states=hidden_states,
+                key_value_states=encoder_hidden_states,
+                attention_mask=encoder_attention_mask,
+                layer_head_mask=cross_attn_layer_head_mask,
+                past_key_value=cross_attn_past_key_value,
+                output_attentions=output_attentions,
+            )
+            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+            hidden_states = residual + hidden_states
+
+            # add cross-attn to positions 3,4 of present_key_value tuple
+            present_key_value = present_key_value + cross_attn_present_key_value
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights, cross_attn_weights)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+
+class PegasusPreTrainedModel(PreTrainedModel):
+    config_class = PegasusConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+
+    def _init_weights(self, module):
+        std = self.config.init_std
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, PegasusSinusoidalPositionalEmbedding):
+            pass
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+
+PEGASUS_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`PegasusConfig`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+PEGASUS_GENERATION_EXAMPLE = r"""
+    Summarization example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, PegasusForConditionalGeneration
+
+    >>> model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum")
+    >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-xsum")
+
+    >>> ARTICLE_TO_SUMMARIZE = (
+    ...     "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
+    ...     "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
+    ...     "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
+    ... )
+    >>> inputs = tokenizer(ARTICLE_TO_SUMMARIZE, max_length=1024, return_tensors="pt")
+
+    >>> # Generate Summary
+    >>> summary_ids = model.generate(inputs["input_ids"])
+    >>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "California's largest electricity provider has turned off power to hundreds of thousands of customers."
+    ```
+"""
+
+PEGASUS_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Indices of decoder input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+            Pegasus uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
+            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+            `past_key_values`).
+        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+            be used by default.
+        head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
+            1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
+            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
+            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
+            hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
+        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
+            representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
+            input (see `past_key_values`). This is useful if you want more control over how to convert
+            `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
+
+            If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
+            of `inputs_embeds`.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+class PegasusEncoder(PegasusPreTrainedModel):
+    """
+    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
+    [`PegasusEncoderLayer`].
+
+    Args:
+        config: PegasusConfig
+        embed_tokens (nn.Embedding): output embedding
+    """
+
+    def __init__(self, config: PegasusConfig, embed_tokens: Optional[nn.Embedding] = None):
+        super().__init__(config)
+
+        self.dropout = config.dropout
+        self.layerdrop = config.encoder_layerdrop
+
+        embed_dim = config.d_model
+        self.padding_idx = config.pad_token_id
+        self.max_source_positions = config.max_position_embeddings
+        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+
+        if embed_tokens is not None:
+            self.embed_tokens = embed_tokens
+        else:
+            self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+
+        self.embed_positions = PegasusSinusoidalPositionalEmbedding(
+            config.max_position_embeddings,
+            embed_dim,
+            self.padding_idx,
+        )
+        self.layers = nn.ModuleList([PegasusEncoderLayer(config) for _ in range(config.encoder_layers)])
+        self.layer_norm = nn.LayerNorm(config.d_model)
+
+        self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def resize_position_embeddings(self, new_num_position_embeddings: int):
+        """
+        Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
+        config.max_position_embeddings`.
+
+        Arguments:
+            new_num_position_embeddings (`int`):
+                The number of new position embeddings. If position embeddings are learned, increasing the size will add
+                newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
+                position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
+                add correct vectors at the end following the position encoding algorithm, whereas reducing the size
+                will remove vectors from the end.
+        """
+        logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...")
+        self.config.max_position_embeddings = new_num_position_embeddings
+
+        self.embed_positions = PegasusSinusoidalPositionalEmbedding(
+            self.config.max_position_embeddings,
+            self.config.d_model,
+            self.padding_idx,
+        )
+        self.embed_positions.to(self.device)
+
+    def get_position_embeddings(self) -> nn.Embedding:
+        """
+        Returns the position embeddings matrix
+        """
+        return self.embed_positions
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        head_mask=None,
+        inputs_embeds=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        r"""
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                provide it.
+
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
+        embed_pos = self.embed_positions(input_shape)
+
+        hidden_states = inputs_embeds + embed_pos
+
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+        # expand attention_mask
+        if attention_mask is not None:
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+        encoder_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
+
+        # check if head_mask has a correct number of layers specified if desired
+        if head_mask is not None:
+            if head_mask.size()[0] != len(self.layers):
+                raise ValueError(
+                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+                    f" {head_mask.size()[0]}."
+                )
+        for idx, encoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                encoder_states = encoder_states + (hidden_states,)
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            to_drop = False
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:  # skip the layer
+                    to_drop = True
+
+            if to_drop:
+                layer_outputs = (None, None)
+            else:
+                if self.gradient_checkpointing and self.training:
+                    layer_outputs = self._gradient_checkpointing_func(
+                        encoder_layer.__call__,
+                        hidden_states,
+                        attention_mask,
+                        (head_mask[idx] if head_mask is not None else None),
+                        output_attentions,
+                    )
+                else:
+                    layer_outputs = encoder_layer(
+                        hidden_states,
+                        attention_mask,
+                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                        output_attentions=output_attentions,
+                    )
+
+                hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_attentions = all_attentions + (layer_outputs[1],)
+
+        hidden_states = self.layer_norm(hidden_states)
+
+        if output_hidden_states:
+            encoder_states = encoder_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+        return BaseModelOutput(
+            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+        )
+
+
+class PegasusDecoder(PegasusPreTrainedModel):
+    """
+    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`PegasusDecoderLayer`]
+
+    Args:
+        config: PegasusConfig
+        embed_tokens (nn.Embedding): output embedding
+    """
+
+    def __init__(self, config: PegasusConfig, embed_tokens: Optional[nn.Embedding] = None):
+        super().__init__(config)
+        self.dropout = config.dropout
+        self.layerdrop = config.decoder_layerdrop
+        self.padding_idx = config.pad_token_id
+        self.max_target_positions = config.max_position_embeddings
+        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+
+        if embed_tokens is not None:
+            self.embed_tokens = embed_tokens
+        else:
+            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+
+        self.embed_positions = PegasusSinusoidalPositionalEmbedding(
+            config.max_position_embeddings,
+            config.d_model,
+            self.padding_idx,
+        )
+        self.layers = nn.ModuleList([PegasusDecoderLayer(config) for _ in range(config.decoder_layers)])
+        self.layer_norm = nn.LayerNorm(config.d_model)
+
+        self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.embed_tokens = value
+
+    def resize_position_embeddings(self, new_num_position_embeddings: int):
+        """
+        Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
+        config.max_position_embeddings`.
+
+        Arguments:
+            new_num_position_embeddings (`int`):
+                The number of new position embeddings. If position embeddings are learned, increasing the size will add
+                newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
+                position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
+                add correct vectors at the end following the position encoding algorithm, whereas reducing the size
+                will remove vectors from the end.
+        """
+        logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...")
+        self.config.max_position_embeddings = new_num_position_embeddings
+
+        self.embed_positions = PegasusSinusoidalPositionalEmbedding(
+            self.config.max_position_embeddings,
+            self.config.d_model,
+            self.padding_idx,
+        )
+        self.embed_positions.to(self.device)
+
+    def get_position_embeddings(self) -> nn.Embedding:
+        """
+        Returns the position embeddings matrix
+        """
+        return self.embed_positions
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        head_mask=None,
+        cross_attn_head_mask=None,
+        past_key_values=None,
+        inputs_embeds=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        r"""
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                provide it.
+
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+                of the decoder.
+            encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
+                Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
+                selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the cross-attention modules in decoder to avoid performing
+                cross-attention on hidden heads. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+        # past_key_values_length
+        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask, input_shape, inputs_embeds, past_key_values_length
+        )
+
+        # expand encoder attention mask
+        if encoder_hidden_states is not None and encoder_attention_mask is not None:
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            encoder_attention_mask = _prepare_4d_attention_mask(
+                encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+            )
+
+        # embed positions
+        positions = self.embed_positions(input_shape, past_key_values_length)
+
+        hidden_states = inputs_embeds + positions
+
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+        next_decoder_cache = () if use_cache else None
+
+        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+            if attn_mask is not None:
+                if attn_mask.size()[0] != len(self.layers):
+                    raise ValueError(
+                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+                        f" {head_mask.size()[0]}."
+                    )
+        for idx, decoder_layer in enumerate(self.layers):
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
+
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    head_mask[idx] if head_mask is not None else None,
+                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
+                    None,
+                    output_attentions,
+                    use_cache,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                    cross_attn_layer_head_mask=(
+                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
+                    ),
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                )
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+                if encoder_hidden_states is not None:
+                    all_cross_attentions += (layer_outputs[2],)
+
+        hidden_states = self.layer_norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+                if v is not None
+            )
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+            cross_attentions=all_cross_attentions,
+        )
+
+
+@add_start_docstrings(
+    "The bare PEGASUS Model outputting raw hidden-states without any specific head on top.",
+    PEGASUS_START_DOCSTRING,
+)
+class PegasusModel(PegasusPreTrainedModel):
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
+    def __init__(self, config: PegasusConfig):
+        super().__init__(config)
+
+        padding_idx, vocab_size = config.pad_token_id, config.vocab_size
+        self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
+
+        self.encoder = PegasusEncoder(config, self.shared)
+        self.decoder = PegasusDecoder(config, self.shared)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.shared
+
+    def set_input_embeddings(self, value):
+        self.shared = value
+        self.encoder.embed_tokens = self.shared
+        self.decoder.embed_tokens = self.shared
+
+    def get_encoder(self):
+        return self.encoder
+
+    def get_decoder(self):
+        return self.decoder
+
+    def resize_position_embeddings(self, new_num_position_embeddings: int):
+        """
+        Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
+        config.max_position_embeddings`.
+
+        Arguments:
+            new_num_position_embeddings (`int`):
+                The number of new position embeddings. If position embeddings are learned, increasing the size will add
+                newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
+                position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
+                add correct vectors at the end following the position encoding algorithm, whereas reducing the size
+                will remove vectors from the end.
+        """
+        self.config.max_position_embeddings = new_num_position_embeddings
+        self.encoder.resize_position_embeddings(new_num_position_embeddings)
+        self.decoder.resize_position_embeddings(new_num_position_embeddings)
+
+    def get_position_embeddings(self) -> Tuple[nn.Embedding]:
+        """
+        Returns the position embeddings matrix
+        """
+        return (self.encoder.get_position_embeddings(), self.decoder.get_position_embeddings())
+
+    @add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.Tensor] = None,
+        decoder_attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
+        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        decoder_inputs_embeds: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Seq2SeqModelOutput]:
+        r"""
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, PegasusModel
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")
+        >>> model = PegasusModel.from_pretrained("google/pegasus-large")
+
+        >>> inputs = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt")
+        >>> decoder_inputs = tokenizer("Studies show that", return_tensors="pt")
+        >>> outputs = model(input_ids=inputs.input_ids, decoder_input_ids=decoder_inputs.input_ids)
+
+        >>> last_hidden_states = outputs.last_hidden_state
+        >>> list(last_hidden_states.shape)
+        [1, 4, 1024]
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                head_mask=head_mask,
+                inputs_embeds=inputs_embeds,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
+        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+            encoder_outputs = BaseModelOutput(
+                last_hidden_state=encoder_outputs[0],
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+            )
+
+        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
+        decoder_outputs = self.decoder(
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            encoder_hidden_states=encoder_outputs[0],
+            encoder_attention_mask=attention_mask,
+            head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        if not return_dict:
+            return decoder_outputs + encoder_outputs
+
+        return Seq2SeqModelOutput(
+            last_hidden_state=decoder_outputs.last_hidden_state,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    "The PEGASUS Model with a language modeling head. Can be used for summarization.", PEGASUS_START_DOCSTRING
+)
+class PegasusForConditionalGeneration(PegasusPreTrainedModel, GenerationMixin):
+    base_model_prefix = "model"
+    _keys_to_ignore_on_load_missing = ["final_logits_bias"]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
+
+    def __init__(self, config: PegasusConfig):
+        super().__init__(config)
+        self.model = PegasusModel(config)
+        self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
+        self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_encoder(self):
+        return self.model.get_encoder()
+
+    def get_decoder(self):
+        return self.model.get_decoder()
+
+    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+        self._resize_final_logits_bias(new_embeddings.weight.shape[0])
+        return new_embeddings
+
+    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
+        old_num_tokens = self.final_logits_bias.shape[-1]
+        if new_num_tokens <= old_num_tokens:
+            new_bias = self.final_logits_bias[:, :new_num_tokens]
+        else:
+            extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
+            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
+        self.register_buffer("final_logits_bias", new_bias)
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def resize_position_embeddings(self, new_num_position_embeddings: int):
+        """
+        Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
+        config.max_position_embeddings`.
+
+        Arguments:
+            new_num_position_embeddings (`int`):
+                The number of new position embeddings. If position embeddings are learned, increasing the size will add
+                newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
+                position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
+                add correct vectors at the end following the position encoding algorithm, whereas reducing the size
+                will remove vectors from the end.
+        """
+        self.config.max_position_embeddings = new_num_position_embeddings
+        self.model.encoder.resize_position_embeddings(new_num_position_embeddings)
+        self.model.decoder.resize_position_embeddings(new_num_position_embeddings)
+
+    def get_position_embeddings(self) -> Tuple[nn.Embedding]:
+        """
+        Returns the position embeddings matrix
+        """
+        return (self.model.encoder.get_position_embeddings(), self.model.decoder.get_position_embeddings())
+
+    @add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+    @add_end_docstrings(PEGASUS_GENERATION_EXAMPLE)
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.Tensor] = None,
+        decoder_attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
+        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        decoder_inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Seq2SeqLMOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if labels is not None:
+            if use_cache:
+                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
+            use_cache = False
+            if decoder_input_ids is None and decoder_inputs_embeds is None:
+                decoder_input_ids = shift_tokens_right(
+                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
+                )
+
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            encoder_outputs=encoder_outputs,
+            decoder_attention_mask=decoder_attention_mask,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            decoder_inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
+
+        masked_lm_loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+        if not return_dict:
+            output = (lm_logits,) + outputs[1:]
+            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+        return Seq2SeqLMOutput(
+            loss=masked_lm_loss,
+            logits=lm_logits,
+            past_key_values=outputs.past_key_values,
+            decoder_hidden_states=outputs.decoder_hidden_states,
+            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
+            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+            encoder_hidden_states=outputs.encoder_hidden_states,
+            encoder_attentions=outputs.encoder_attentions,
+        )
+
+    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+
+    @staticmethod
+    def _reorder_cache(past_key_values, beam_idx):
+        reordered_past = ()
+        for layer_past in past_key_values:
+            # cached cross_attention states don't have to be reordered -> they are always the same
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
+            )
+        return reordered_past
+
+
+# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Pegasus
+class PegasusDecoderWrapper(PegasusPreTrainedModel):
+    """
+    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
+    used in combination with the [`EncoderDecoderModel`] framework.
+    """
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.decoder = PegasusDecoder(config)
+
+    def forward(self, *args, **kwargs):
+        return self.decoder(*args, **kwargs)
+
+
+class PegasusForCausalLM(PegasusPreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(self, config):
+        config = copy.deepcopy(config)
+        config.is_decoder = True
+        config.is_encoder_decoder = False
+        super().__init__(config)
+        self.model = PegasusDecoderWrapper(config)
+
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.decoder.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.decoder.embed_tokens = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model.decoder = decoder
+
+    def get_decoder(self):
+        return self.model.decoder
+
+    def get_position_embeddings(self) -> nn.Embedding:
+        """
+        Returns the position embeddings matrix
+        """
+        return self.model.decoder.get_position_embeddings()
+
+    def resize_position_embeddings(self, new_num_position_embeddings: int):
+        """
+        Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
+        config.max_position_embeddings`.
+
+        Arguments:
+            new_num_position_embeddings (`int`):
+                The number of new position embeddings. If position embeddings are learned, increasing the size will add
+                newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
+                position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
+                add correct vectors at the end following the position encoding algorithm, whereas reducing the size
+                will remove vectors from the end.
+        """
+        self.config.max_position_embeddings = new_num_position_embeddings
+        self.model.decoder.resize_position_embeddings(new_num_position_embeddings)
+
+    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
+    # Copied from transformers.models.bart.modeling_bart.BartForCausalLM.forward with Bart->Pegasus, facebook/bart-base->google/pegasus-large
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
+        r"""
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                provide it.
+
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            encoder_hidden_states  (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+                if the model is configured as a decoder.
+            encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
+                in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+            head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
+                tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
+
+                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, PegasusForCausalLM
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")
+        >>> model = PegasusForCausalLM.from_pretrained("google/pegasus-large", add_cross_attention=False)
+        >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
+        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+        >>> outputs = model(**inputs)
+
+        >>> logits = outputs.logits
+        >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
+        >>> list(logits.shape) == expected_shape
+        True
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model.decoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            head_mask=head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        logits = self.lm_head(outputs[0])
+
+        loss = None
+        if labels is not None:
+            labels = labels.to(logits.device)
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithCrossAttentions(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
+    @staticmethod
+    def _reorder_cache(past_key_values, beam_idx):
+        reordered_past = ()
+        for layer_past in past_key_values:
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
+        return reordered_past
+
+
+__all__ = ["PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel", "PegasusPreTrainedModel"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/pegasus/modeling_tf_pegasus.py b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/modeling_tf_pegasus.py
new file mode 100644
index 0000000000000000000000000000000000000000..a51835dcfa447f889913b3de7506d38af690b741
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/modeling_tf_pegasus.py
@@ -0,0 +1,1574 @@
+# coding=utf-8
+# Copyright 2021, Google Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TF 2.0 Pegasus model."""
+
+from __future__ import annotations
+
+import random
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...modeling_tf_outputs import (
+    TFBaseModelOutput,
+    TFBaseModelOutputWithPastAndCrossAttentions,
+    TFSeq2SeqLMOutput,
+    TFSeq2SeqModelOutput,
+)
+
+# Public API
+from ...modeling_tf_utils import (
+    TFCausalLanguageModelingLoss,
+    TFModelInputType,
+    TFPreTrainedModel,
+    keras,
+    keras_serializable,
+    unpack_inputs,
+)
+from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
+from ...utils import (
+    add_code_sample_docstrings,
+    add_end_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
+from .configuration_pegasus import PegasusConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "google/pegasus-large"
+_CONFIG_FOR_DOC = "PegasusConfig"
+
+
+LARGE_NEGATIVE = -1e8
+
+
+# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right
+def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
+    pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
+    decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
+    start_tokens = tf.fill(
+        (shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype)
+    )
+    shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
+    # replace possible -100 values in labels by `pad_token_id`
+    shifted_input_ids = tf.where(
+        shifted_input_ids == -100,
+        tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)),
+        shifted_input_ids,
+    )
+
+    # "Verify that `labels` has only positive values and -100"
+    assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
+
+    # Make sure the assertion op is called by wrapping the result in an identity no-op
+    with tf.control_dependencies([assert_gte0]):
+        shifted_input_ids = tf.identity(shifted_input_ids)
+
+    return shifted_input_ids
+
+
+# Copied from transformers.models.bart.modeling_tf_bart._make_causal_mask
+def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0):
+    """
+    Make causal mask used for bi-directional self-attention.
+    """
+    bsz = input_ids_shape[0]
+    tgt_len = input_ids_shape[1]
+    mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
+    mask_cond = tf.range(shape_list(mask)[-1])
+
+    mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
+
+    if past_key_values_length > 0:
+        mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)
+
+    return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
+
+
+# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
+def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
+    """
+    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+    """
+    src_len = shape_list(mask)[1]
+    tgt_len = tgt_len if tgt_len is not None else src_len
+    one_cst = tf.constant(1.0)
+    mask = tf.cast(mask, dtype=one_cst.dtype)
+    expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))
+
+    return (one_cst - expanded_mask) * LARGE_NEGATIVE
+
+
+# Copied from transformers.models.marian.modeling_tf_marian.TFMarianSinusoidalPositionalEmbedding with Marian->Pegasus
+class TFPegasusSinusoidalPositionalEmbedding(keras.layers.Layer):
+    """This module produces sinusoidal positional embeddings of any length."""
+
+    def __init__(self, num_positions: int, embedding_dim: int, **kwargs):
+        super().__init__(**kwargs)
+
+        if embedding_dim % 2 != 0:
+            raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported")
+
+        self.embedding_dim = embedding_dim
+        self.num_positions = num_positions
+
+    def build(self, input_shape: tf.TensorShape):
+        """
+        Build shared token embedding layer Shared weights logic adapted from
+        https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
+        """
+
+        weight = self._init_weight(self.num_positions, self.embedding_dim)
+
+        self.weight = self.add_weight(
+            name="embeddings",
+            shape=[self.num_positions, self.embedding_dim],
+        )
+        weight = tf.cast(weight, dtype=self.weight.dtype)
+
+        self.weight.assign(weight)
+
+        super().build(input_shape)
+
+    @staticmethod
+    def _init_weight(n_pos: int, dim: int):
+        """
+        Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
+        the 2nd half of the vector. [dim // 2:]
+        """
+        position_enc = np.array(
+            [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
+        )
+        table = np.zeros_like(position_enc)
+        # index 0 is all zero
+        table[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2])
+        table[:, dim // 2 :] = np.cos(position_enc[:, 1::2])
+        # convert to tensor
+        table = tf.convert_to_tensor(table)
+        tf.stop_gradient(table)
+        return table
+
+    def call(
+        self, input_shape: tf.TensorShape, past_key_values_length: int = 0, position_ids: tf.Tensor | None = None
+    ):
+        """Input is expected to be of size [bsz x seqlen]."""
+        if position_ids is None:
+            seq_len = input_shape[1]
+            position_ids = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
+        return tf.gather(self.weight, position_ids)
+
+
+# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->Pegasus
+class TFPegasusAttention(keras.layers.Layer):
+    """Multi-headed attention from "Attention Is All You Need"""
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        is_decoder: bool = False,
+        bias: bool = True,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.embed_dim = embed_dim
+
+        self.num_heads = num_heads
+        self.dropout = keras.layers.Dropout(dropout)
+        self.head_dim = embed_dim // num_heads
+        if (self.head_dim * num_heads) != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+                f" and `num_heads`: {num_heads})."
+            )
+        self.scaling = self.head_dim**-0.5
+        self.is_decoder = is_decoder
+
+        self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj")
+        self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj")
+        self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj")
+        self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj")
+
+    def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int):
+        return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3))
+
+    def call(
+        self,
+        hidden_states: tf.Tensor,
+        key_value_states: tf.Tensor | None = None,
+        past_key_value: Tuple[Tuple[tf.Tensor]] | None = None,
+        attention_mask: tf.Tensor | None = None,
+        layer_head_mask: tf.Tensor | None = None,
+        training: Optional[bool] = False,
+    ) -> Tuple[tf.Tensor, tf.Tensor | None]:
+        """Input shape: Batch x Time x Channel"""
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+        bsz, tgt_len, embed_dim = shape_list(hidden_states)
+
+        # get query proj
+        query_states = self.q_proj(hidden_states) * self.scaling
+        # get key, value proj
+        if is_cross_attention and past_key_value is not None:
+            # reuse k,v, cross_attentions
+            key_states = past_key_value[0]
+            value_states = past_key_value[1]
+        elif is_cross_attention:
+            # cross_attentions
+            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+        elif past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+            key_states = tf.concat([past_key_value[0], key_states], axis=2)
+            value_states = tf.concat([past_key_value[1], value_states], axis=2)
+        else:
+            # self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_states, value_states)
+
+        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+        query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape)
+        key_states = tf.reshape(key_states, proj_shape)
+        value_states = tf.reshape(value_states, proj_shape)
+
+        src_len = shape_list(key_states)[1]
+        attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
+
+        tf.debugging.assert_equal(
+            shape_list(attn_weights),
+            [bsz * self.num_heads, tgt_len, src_len],
+            message=(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+                f" {shape_list(attn_weights)}"
+            ),
+        )
+
+        if attention_mask is not None:
+            tf.debugging.assert_equal(
+                shape_list(attention_mask),
+                [bsz, 1, tgt_len, src_len],
+                message=(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+                    f" {shape_list(attention_mask)}"
+                ),
+            )
+
+            attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
+            attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
+            attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
+
+        attn_weights = stable_softmax(attn_weights, axis=-1)
+
+        if layer_head_mask is not None:
+            tf.debugging.assert_equal(
+                shape_list(layer_head_mask),
+                [self.num_heads],
+                message=(
+                    f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
+                    f" {shape_list(layer_head_mask)}"
+                ),
+            )
+
+            attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
+                attn_weights, (bsz, self.num_heads, tgt_len, src_len)
+            )
+            attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
+
+        attn_probs = self.dropout(attn_weights, training=training)
+        attn_output = tf.matmul(attn_probs, value_states)
+
+        tf.debugging.assert_equal(
+            shape_list(attn_output),
+            [bsz * self.num_heads, tgt_len, self.head_dim],
+            message=(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+                f" {shape_list(attn_output)}"
+            ),
+        )
+
+        attn_output = tf.transpose(
+            tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
+        )
+        attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim))
+
+        attn_output = self.out_proj(attn_output)
+        attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len))
+
+        return attn_output, attn_weights, past_key_value
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "k_proj", None) is not None:
+            with tf.name_scope(self.k_proj.name):
+                self.k_proj.build([None, None, self.embed_dim])
+        if getattr(self, "q_proj", None) is not None:
+            with tf.name_scope(self.q_proj.name):
+                self.q_proj.build([None, None, self.embed_dim])
+        if getattr(self, "v_proj", None) is not None:
+            with tf.name_scope(self.v_proj.name):
+                self.v_proj.build([None, None, self.embed_dim])
+        if getattr(self, "out_proj", None) is not None:
+            with tf.name_scope(self.out_proj.name):
+                self.out_proj.build([None, None, self.embed_dim])
+
+
+# Copied from transformers.models.mbart.modeling_tf_mbart.TFMBartEncoderLayer with MBart->Pegasus
+class TFPegasusEncoderLayer(keras.layers.Layer):
+    def __init__(self, config: PegasusConfig, **kwargs):
+        super().__init__(**kwargs)
+        self.embed_dim = config.d_model
+        self.self_attn = TFPegasusAttention(
+            self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn"
+        )
+        self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
+        self.dropout = keras.layers.Dropout(config.dropout)
+        self.activation_fn = get_tf_activation(config.activation_function)
+        self.activation_dropout = keras.layers.Dropout(config.activation_dropout)
+        self.fc1 = keras.layers.Dense(config.encoder_ffn_dim, name="fc1")
+        self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2")
+        self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
+        self.config = config
+
+    def call(
+        self,
+        hidden_states: tf.Tensor,
+        attention_mask: tf.Tensor,
+        layer_head_mask: tf.Tensor,
+        training: Optional[bool] = False,
+    ):
+        """
+        Args:
+            hidden_states (`tf.Tensor`): input to the layer of shape *(batch, seq_len, embed_dim)*
+            attention_mask (`tf.Tensor`): attention mask of size
+                *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
+            layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size
+                *(encoder_attention_heads,)*
+        """
+        residual = hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+        hidden_states, self_attn_weights, _ = self.self_attn(
+            hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
+        )
+
+        tf.debugging.assert_equal(
+            shape_list(hidden_states),
+            shape_list(residual),
+            message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
+        )
+
+        hidden_states = self.dropout(hidden_states, training=training)
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = self.activation_dropout(hidden_states, training=training)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = self.dropout(hidden_states, training=training)
+        hidden_states = residual + hidden_states
+
+        return hidden_states, self_attn_weights
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "self_attn", None) is not None:
+            with tf.name_scope(self.self_attn.name):
+                self.self_attn.build(None)
+        if getattr(self, "self_attn_layer_norm", None) is not None:
+            with tf.name_scope(self.self_attn_layer_norm.name):
+                self.self_attn_layer_norm.build([None, None, self.embed_dim])
+        if getattr(self, "fc1", None) is not None:
+            with tf.name_scope(self.fc1.name):
+                self.fc1.build([None, None, self.embed_dim])
+        if getattr(self, "fc2", None) is not None:
+            with tf.name_scope(self.fc2.name):
+                self.fc2.build([None, None, self.config.encoder_ffn_dim])
+        if getattr(self, "final_layer_norm", None) is not None:
+            with tf.name_scope(self.final_layer_norm.name):
+                self.final_layer_norm.build([None, None, self.embed_dim])
+
+
+# Copied from transformers.models.mbart.modeling_tf_mbart.TFMBartDecoderLayer with MBart->Pegasus
+class TFPegasusDecoderLayer(keras.layers.Layer):
+    def __init__(self, config: PegasusConfig, **kwargs):
+        super().__init__(**kwargs)
+        self.embed_dim = config.d_model
+        self.self_attn = TFPegasusAttention(
+            embed_dim=self.embed_dim,
+            num_heads=config.decoder_attention_heads,
+            dropout=config.attention_dropout,
+            name="self_attn",
+            is_decoder=True,
+        )
+        self.dropout = keras.layers.Dropout(config.dropout)
+        self.activation_fn = get_tf_activation(config.activation_function)
+        self.activation_dropout = keras.layers.Dropout(config.activation_dropout)
+
+        self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
+        self.encoder_attn = TFPegasusAttention(
+            self.embed_dim,
+            config.decoder_attention_heads,
+            dropout=config.attention_dropout,
+            name="encoder_attn",
+            is_decoder=True,
+        )
+        self.encoder_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm")
+        self.fc1 = keras.layers.Dense(config.decoder_ffn_dim, name="fc1")
+        self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2")
+        self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
+        self.config = config
+
+    def call(
+        self,
+        hidden_states: tf.Tensor,
+        attention_mask: tf.Tensor | None = None,
+        encoder_hidden_states: tf.Tensor | None = None,
+        encoder_attention_mask: tf.Tensor | None = None,
+        layer_head_mask: tf.Tensor | None = None,
+        cross_attn_layer_head_mask: tf.Tensor | None = None,
+        past_key_value: Tuple[tf.Tensor] | None = None,
+        training: Optional[bool] = False,
+    ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
+        """
+        Args:
+            hidden_states (`tf.Tensor`): input to the layer of shape *(batch, seq_len, embed_dim)*
+            attention_mask (`tf.Tensor`): attention mask of size
+                *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
+            encoder_hidden_states (`tf.Tensor`):
+                cross attention input to the layer of shape *(batch, seq_len, embed_dim)*
+            encoder_attention_mask (`tf.Tensor`): encoder attention mask of size
+                *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
+            layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size
+                *(decoder_attention_heads,)*
+            cross_attn_layer_head_mask (`tf.Tensor`): mask for heads of the cross-attention module.
+                *(decoder_attention_heads,)*
+            past_key_value (`Tuple(tf.Tensor)`): cached past key and value projection states
+        """
+        residual = hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+
+        # Self Attention
+        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+        # add present self-attn cache to positions 1,2 of present_key_value tuple
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            past_key_value=self_attn_past_key_value,
+            attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
+        )
+        hidden_states = self.dropout(hidden_states, training=training)
+        hidden_states = residual + hidden_states
+
+        # Cross-Attention Block
+        cross_attn_present_key_value = None
+        cross_attn_weights = None
+        if encoder_hidden_states is not None:
+            residual = hidden_states
+            hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+            # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
+            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
+                hidden_states=hidden_states,
+                key_value_states=encoder_hidden_states,
+                attention_mask=encoder_attention_mask,
+                layer_head_mask=cross_attn_layer_head_mask,
+                past_key_value=cross_attn_past_key_value,
+            )
+            hidden_states = self.dropout(hidden_states, training=training)
+            hidden_states = residual + hidden_states
+
+            # add cross-attn to positions 3,4 of present_key_value tuple
+            present_key_value = present_key_value + cross_attn_present_key_value
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = self.activation_dropout(hidden_states, training=training)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = self.dropout(hidden_states, training=training)
+        hidden_states = residual + hidden_states
+
+        return (
+            hidden_states,
+            self_attn_weights,
+            cross_attn_weights,
+            present_key_value,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "self_attn", None) is not None:
+            with tf.name_scope(self.self_attn.name):
+                self.self_attn.build(None)
+        if getattr(self, "self_attn_layer_norm", None) is not None:
+            with tf.name_scope(self.self_attn_layer_norm.name):
+                self.self_attn_layer_norm.build([None, None, self.embed_dim])
+        if getattr(self, "encoder_attn", None) is not None:
+            with tf.name_scope(self.encoder_attn.name):
+                self.encoder_attn.build(None)
+        if getattr(self, "encoder_attn_layer_norm", None) is not None:
+            with tf.name_scope(self.encoder_attn_layer_norm.name):
+                self.encoder_attn_layer_norm.build([None, None, self.embed_dim])
+        if getattr(self, "fc1", None) is not None:
+            with tf.name_scope(self.fc1.name):
+                self.fc1.build([None, None, self.embed_dim])
+        if getattr(self, "fc2", None) is not None:
+            with tf.name_scope(self.fc2.name):
+                self.fc2.build([None, None, self.config.decoder_ffn_dim])
+        if getattr(self, "final_layer_norm", None) is not None:
+            with tf.name_scope(self.final_layer_norm.name):
+                self.final_layer_norm.build([None, None, self.embed_dim])
+
+
+class TFPegasusPreTrainedModel(TFPreTrainedModel):
+    config_class = PegasusConfig
+    base_model_prefix = "model"
+
+
+PEGASUS_START_DOCSTRING = r"""
+    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
+    behavior.
+
+    
+
+    TensorFlow models and layers in `transformers` accept two formats as input:
+
+    - having all inputs as keyword arguments (like PyTorch models), or
+    - having all inputs as a list, tuple or dict in the first positional argument.
+
+    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
+    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
+    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+    positional argument:
+
+    - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
+    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+    `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
+    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+    `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
+
+    Note that when creating models and layers with
+    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
+    about any of this, as you can just pass inputs like you would to any other Python function!
+
+    
+
+    Args:
+        config ([`PegasusConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+PEGASUS_GENERATION_EXAMPLE = r"""
+    Summarization example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, TFPegasusForConditionalGeneration
+
+    >>> model = TFPegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum")
+    >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-xsum")
+
+    >>> ARTICLE_TO_SUMMARIZE = (
+    ...     "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
+    ...     "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
+    ...     "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
+    ... )
+    >>> inputs = tokenizer(ARTICLE_TO_SUMMARIZE, max_length=1024, return_tensors="tf")
+
+    >>> # Generate Summary
+    >>> summary_ids = model.generate(input_ids)
+    >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
+    ```
+"""
+
+PEGASUS_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`tf.Tensor` of shape `({0})`):
+            Indices of input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`tf.Tensor` of shape `({0})`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        decoder_input_ids (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Indices of decoder input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+            Pegasus uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
+            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+            `past_key_values`).
+        decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
+        decoder_position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
+            range `[0, config.max_position_embeddings - 1]`.
+        head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        decoder_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        encoder_outputs (`tf.FloatTensor`, *optional*):
+            hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+            of shape `(batch_size, sequence_length, hidden_size)` is a sequence of
+        past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`)
+            contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`). Set to `False` during training, `True` during generation output_attentions (`bool`,
+            *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions`
+            under returned tensors for more detail. This argument can be used only in eager mode, in graph mode the
+            value in the config will be used instead.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
+            config will be used instead.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
+            used instead.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+            eager mode, in graph mode the value will always be set to True.
+        training (`bool`, *optional*, defaults to `False`):
+            Whether or not to use the model in training mode (some modules like dropout modules have different
+            behaviors between training and evaluation).
+"""
+
+
+@keras_serializable
+class TFPegasusEncoder(keras.layers.Layer):
+    config_class = PegasusConfig
+    """
+    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
+    [`TFPegasusEncoderLayer`].
+
+    Args:
+        config: PegasusConfig
+    """
+
+    def __init__(self, config: PegasusConfig, embed_tokens: Optional[keras.layers.Embedding] = None, **kwargs):
+        super().__init__(**kwargs)
+        self.config = config
+        self.dropout = keras.layers.Dropout(config.dropout)
+        self.layerdrop = config.encoder_layerdrop
+        self.padding_idx = config.pad_token_id
+        self.max_source_positions = config.max_position_embeddings
+        self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0
+
+        self.embed_tokens = embed_tokens
+        self.embed_positions = TFPegasusSinusoidalPositionalEmbedding(
+            config.max_position_embeddings,
+            config.d_model,
+            name="embed_positions",
+        )
+        self.layers = [TFPegasusEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
+        self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
+
+    def get_embed_tokens(self):
+        return self.embed_tokens
+
+    def set_embed_tokens(self, embed_tokens):
+        self.embed_tokens = embed_tokens
+
+    @unpack_inputs
+    def call(
+        self,
+        input_ids: tf.Tensor | None = None,
+        inputs_embeds: tf.Tensor | None = None,
+        attention_mask: tf.Tensor | None = None,
+        head_mask: tf.Tensor | None = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
+    ):
+        """
+        Args:
+            input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                provide it.
+
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, `optional):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail. This argument can be used only in eager mode, in graph mode the value
+                in the config will be used instead.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail. This argument can be used only in eager mode, in graph mode the value in the config
+                will be used instead.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used
+                in eager mode, in graph mode the value will always be set to True.
+            training (`bool`, *optional*, defaults to `False`):
+                Whether or not to use the model in training mode (some modules like dropout modules have different
+                behaviors between training and evaluation).
+        """
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
+            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
+        embed_pos = self.embed_positions(input_shape)
+        hidden_states = inputs_embeds + embed_pos
+        hidden_states = self.dropout(hidden_states, training=training)
+
+        # check attention mask and invert
+        if attention_mask is not None:
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            attention_mask = _expand_mask(attention_mask)
+        else:
+            attention_mask = None
+
+        encoder_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
+
+        # check if head_mask has a correct number of layers specified if desired
+        if head_mask is not None:
+            tf.debugging.assert_equal(
+                shape_list(head_mask)[0],
+                len(self.layers),
+                message=(
+                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+                    f" {shape_list(head_mask)[0]}."
+                ),
+            )
+
+        # encoder layers
+        for idx, encoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                encoder_states = encoder_states + (hidden_states,)
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            dropout_probability = random.uniform(0, 1)
+            if training and (dropout_probability < self.layerdrop):  # skip the layer
+                continue
+
+            hidden_states, attn = encoder_layer(
+                hidden_states,
+                attention_mask,
+                head_mask[idx] if head_mask is not None else None,
+            )
+
+            if output_attentions:
+                all_attentions += (attn,)
+
+        hidden_states = self.layer_norm(hidden_states)
+
+        if output_hidden_states:
+            encoder_states = encoder_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+        return TFBaseModelOutput(
+            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "embed_positions", None) is not None:
+            with tf.name_scope(self.embed_positions.name):
+                self.embed_positions.build(None)
+        if getattr(self, "layer_norm", None) is not None:
+            with tf.name_scope(self.layer_norm.name):
+                self.layer_norm.build([None, None, self.config.d_model])
+        if getattr(self, "layers", None) is not None:
+            for layer in self.layers:
+                with tf.name_scope(layer.name):
+                    layer.build(None)
+
+
+@keras_serializable
+class TFPegasusDecoder(keras.layers.Layer):
+    config_class = PegasusConfig
+    """
+    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TFPegasusDecoderLayer`]
+
+    Args:
+        config: PegasusConfig
+        embed_tokens: output embedding
+    """
+
+    def __init__(self, config: PegasusConfig, embed_tokens: Optional[keras.layers.Embedding] = None, **kwargs):
+        super().__init__(**kwargs)
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.embed_tokens = embed_tokens
+        self.layerdrop = config.decoder_layerdrop
+        self.embed_positions = TFPegasusSinusoidalPositionalEmbedding(
+            config.max_position_embeddings,
+            config.d_model,
+            name="embed_positions",
+        )
+        self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0
+        self.layers = [TFPegasusDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)]
+        self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
+
+        self.dropout = keras.layers.Dropout(config.dropout)
+
+    def get_embed_tokens(self):
+        return self.embed_tokens
+
+    def set_embed_tokens(self, embed_tokens):
+        self.embed_tokens = embed_tokens
+
+    @unpack_inputs
+    def call(
+        self,
+        input_ids: tf.Tensor | None = None,
+        inputs_embeds: tf.Tensor | None = None,
+        attention_mask: tf.Tensor | None = None,
+        position_ids: tf.Tensor | None = None,
+        encoder_hidden_states: tf.Tensor | None = None,
+        encoder_attention_mask: tf.Tensor | None = None,
+        head_mask: tf.Tensor | None = None,
+        cross_attn_head_mask: tf.Tensor | None = None,
+        past_key_values: Tuple[Tuple[tf.Tensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
+    ):
+        r"""
+        Args:
+            input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                provide it.
+
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
+                range `[0, config.max_position_embeddings - 1]`.
+            encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+                of the decoder.
+            encoder_attention_mask (`tf.Tensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
+                Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
+                selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+                Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
+                decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+            inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail. This argument can be used only in eager mode, in graph mode the value
+                in the config will be used instead.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail. This argument can be used only in eager mode, in graph mode the value in the config
+                will be used instead.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used
+                in eager mode, in graph mode the value will always be set to True.
+            training (`bool`, *optional*, defaults to `False`):
+                Whether or not to use the model in training mode (some modules like dropout modules have different
+                behaviors between training and evaluation).
+        """
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
+        else:
+            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+        past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0
+
+        # embed positions
+        if position_ids is None:
+            positions = self.embed_positions(input_shape, past_key_values_length)
+        else:
+            positions = self.embed_positions(input_shape, position_ids=position_ids)
+
+        if inputs_embeds is None:
+            check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
+            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
+        hidden_states = inputs_embeds
+
+        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+        if input_shape[-1] > 1:
+            combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
+        else:
+            combined_attention_mask = _expand_mask(
+                tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
+            )
+
+        if attention_mask is not None:
+            combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1])
+
+        if encoder_hidden_states is not None and encoder_attention_mask is not None:
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1])
+
+        hidden_states = self.dropout(hidden_states + positions, training=training)
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_cross_attns = () if (output_attentions and encoder_hidden_states is not None) else None
+        present_key_values = () if use_cache else None
+
+        # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
+        for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
+            if attn_mask is not None:
+                tf.debugging.assert_equal(
+                    shape_list(attn_mask)[0],
+                    len(self.layers),
+                    message=(
+                        f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for"
+                        f" {shape_list(attn_mask)[0]}."
+                    ),
+                )
+
+        for idx, decoder_layer in enumerate(self.layers):
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+            dropout_probability = random.uniform(0, 1)
+
+            if training and (dropout_probability < self.layerdrop):
+                continue
+
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+            hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
+                hidden_states,
+                attention_mask=combined_attention_mask,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                layer_head_mask=head_mask[idx] if head_mask is not None else None,
+                cross_attn_layer_head_mask=cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
+                past_key_value=past_key_value,
+            )
+
+            if use_cache:
+                present_key_values += (present_key_value,)
+
+            if output_attentions:
+                all_self_attns += (layer_self_attn,)
+
+                if encoder_hidden_states is not None:
+                    all_cross_attns += (layer_cross_attn,)
+
+        hidden_states = self.layer_norm(hidden_states)
+
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        if not return_dict:
+            return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
+        else:
+            return TFBaseModelOutputWithPastAndCrossAttentions(
+                last_hidden_state=hidden_states,
+                past_key_values=present_key_values,
+                hidden_states=all_hidden_states,
+                attentions=all_self_attns,
+                cross_attentions=all_cross_attns,
+            )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "embed_positions", None) is not None:
+            with tf.name_scope(self.embed_positions.name):
+                self.embed_positions.build(None)
+        if getattr(self, "layer_norm", None) is not None:
+            with tf.name_scope(self.layer_norm.name):
+                self.layer_norm.build([None, None, self.config.d_model])
+        if getattr(self, "layers", None) is not None:
+            for layer in self.layers:
+                with tf.name_scope(layer.name):
+                    layer.build(None)
+
+
+@keras_serializable
+class TFPegasusMainLayer(keras.layers.Layer):
+    config_class = PegasusConfig
+
+    def __init__(self, config: PegasusConfig, **kwargs):
+        super().__init__(**kwargs)
+
+        self.config = config
+        self.shared = keras.layers.Embedding(
+            input_dim=config.vocab_size,
+            output_dim=config.d_model,
+            embeddings_initializer=keras.initializers.TruncatedNormal(stddev=self.config.init_std),
+            name="model.shared",
+        )
+        # Additional attribute to specify the expected name scope of the layer (for loading/storing weights)
+        self.shared.load_weight_prefix = "model.shared"
+
+        self.encoder = TFPegasusEncoder(config, self.shared, name="encoder")
+        self.decoder = TFPegasusDecoder(config, self.shared, name="decoder")
+
+    def get_input_embeddings(self):
+        return self.shared
+
+    def set_input_embeddings(self, new_embeddings):
+        self.shared = new_embeddings
+        self.encoder.embed_tokens = self.shared
+        self.decoder.embed_tokens = self.shared
+
+    @unpack_inputs
+    def call(
+        self,
+        input_ids: tf.Tensor | None = None,
+        attention_mask: tf.Tensor | None = None,
+        decoder_input_ids: tf.Tensor | None = None,
+        decoder_attention_mask: tf.Tensor | None = None,
+        decoder_position_ids: tf.Tensor | None = None,
+        head_mask: tf.Tensor | None = None,
+        decoder_head_mask: tf.Tensor | None = None,
+        cross_attn_head_mask: tf.Tensor | None = None,
+        encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
+        past_key_values: Tuple[Tuple[tf.Tensor]] = None,
+        inputs_embeds: tf.Tensor | None = None,
+        decoder_inputs_embeds: tf.Tensor | None = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
+        **kwargs,
+    ):
+        if decoder_input_ids is None and decoder_inputs_embeds is None:
+            use_cache = False
+
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                head_mask=head_mask,
+                inputs_embeds=inputs_embeds,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+                training=training,
+            )
+        # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True
+        elif return_dict and not isinstance(encoder_outputs, TFBaseModelOutput):
+            encoder_outputs = TFBaseModelOutput(
+                last_hidden_state=encoder_outputs[0],
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+            )
+        # If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False
+        elif not return_dict and not isinstance(encoder_outputs, tuple):
+            encoder_outputs = encoder_outputs.to_tuple()
+
+        decoder_outputs = self.decoder(
+            decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            position_ids=decoder_position_ids,
+            encoder_hidden_states=encoder_outputs[0],
+            encoder_attention_mask=attention_mask,
+            head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+        if not return_dict:
+            return decoder_outputs + encoder_outputs
+
+        return TFSeq2SeqModelOutput(
+            last_hidden_state=decoder_outputs.last_hidden_state,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        # The shared/tied weights expect to be in the model base namespace
+        # Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than
+        # the current one.
+        with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"):
+            self.shared.build(None)
+        if getattr(self, "encoder", None) is not None:
+            with tf.name_scope(self.encoder.name):
+                self.encoder.build(None)
+        if getattr(self, "decoder", None) is not None:
+            with tf.name_scope(self.decoder.name):
+                self.decoder.build(None)
+
+
+@add_start_docstrings(
+    "The bare PEGASUS Model outputting raw hidden-states without any specific head on top.",
+    PEGASUS_START_DOCSTRING,
+)
+class TFPegasusModel(TFPegasusPreTrainedModel):
+    def __init__(self, config: PegasusConfig, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+
+        self.model = TFPegasusMainLayer(config, name="model")
+
+    def get_encoder(self):
+        return self.model.encoder
+
+    def get_decoder(self):
+        return self.model.decoder
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TFSeq2SeqModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        decoder_input_ids: np.ndarray | tf.Tensor | None = None,
+        decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
+        decoder_position_ids: np.ndarray | tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        decoder_head_mask: np.ndarray | tf.Tensor | None = None,
+        cross_attn_head_mask: np.ndarray | tf.Tensor | None = None,
+        encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
+        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: bool = False,
+        **kwargs,
+    ) -> Union[TFSeq2SeqModelOutput, Tuple[tf.Tensor]]:
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            decoder_position_ids=decoder_position_ids,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            encoder_outputs=encoder_outputs,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            decoder_inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+        return outputs
+
+    # Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output
+    def serving_output(self, output):
+        pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
+        dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
+        dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
+        cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
+        enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
+        enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
+
+        return TFSeq2SeqModelOutput(
+            last_hidden_state=output.last_hidden_state,
+            past_key_values=pkv,
+            decoder_hidden_states=dec_hs,
+            decoder_attentions=dec_attns,
+            cross_attentions=cross_attns,
+            encoder_last_hidden_state=output.encoder_last_hidden_state,
+            encoder_hidden_states=enc_hs,
+            encoder_attentions=enc_attns,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "model", None) is not None:
+            with tf.name_scope(self.model.name):
+                self.model.build(None)
+
+
+# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer
+class BiasLayer(keras.layers.Layer):
+    """
+    Bias as a layer. It is used for serialization purposes: `keras.Model.save_weights` stores on a per-layer basis,
+    so all weights have to be registered in a layer.
+    """
+
+    def __init__(self, shape, initializer, trainable, name, **kwargs):
+        super().__init__(name=name, **kwargs)
+        # Note: the name of this variable will NOT be scoped when serialized, i.e. it will not be in the format of
+        # "outer_layer/inner_layer/.../name:0". Instead, it will be "name:0". For further details, see:
+        # https://github.com/huggingface/transformers/pull/18833#issuecomment-1233090214
+        self.bias = self.add_weight(name=name, shape=shape, initializer=initializer, trainable=trainable)
+
+    def call(self, x):
+        return x + self.bias
+
+
+@add_start_docstrings(
+    "The PEGASUS Model with a language modeling head. Can be used for summarization.",
+    PEGASUS_START_DOCSTRING,
+)
+class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLanguageModelingLoss):
+    _keys_to_ignore_on_load_unexpected = [
+        r"model.encoder.embed_tokens.weight",
+        r"model.decoder.embed_tokens.weight",
+    ]
+
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+        self.model = TFPegasusMainLayer(config, name="model")
+        self.use_cache = config.use_cache
+        # final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
+        self.bias_layer = BiasLayer(
+            name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
+        )
+
+    def get_decoder(self):
+        return self.model.decoder
+
+    def get_encoder(self):
+        return self.model.encoder
+
+    def get_output_embeddings(self):
+        return self.get_input_embeddings()
+
+    def set_output_embeddings(self, value):
+        self.set_input_embeddings(value)
+
+    def get_bias(self):
+        return {"final_logits_bias": self.bias_layer.bias}
+
+    def set_bias(self, value):
+        # Replaces the existing layers containing bias for correct (de)serialization.
+        vocab_size = value["final_logits_bias"].shape[-1]
+        self.bias_layer = BiasLayer(
+            name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=False
+        )
+        self.bias_layer.bias.assign(value["final_logits_bias"])
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+    @add_end_docstrings(PEGASUS_GENERATION_EXAMPLE)
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        decoder_input_ids: np.ndarray | tf.Tensor | None = None,
+        decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
+        decoder_position_ids: np.ndarray | tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        decoder_head_mask: np.ndarray | tf.Tensor | None = None,
+        cross_attn_head_mask: np.ndarray | tf.Tensor | None = None,
+        encoder_outputs: Optional[TFBaseModelOutput] = None,
+        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: np.ndarray | tf.Tensor | None = None,
+        training: bool = False,
+    ) -> Union[TFSeq2SeqLMOutput, Tuple[tf.Tensor]]:
+        """
+        labels (`tf.tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        """
+
+        if labels is not None:
+            labels = tf.where(
+                labels == self.config.pad_token_id,
+                tf.cast(tf.fill(shape_list(labels), -100), labels.dtype),
+                labels,
+            )
+            use_cache = False
+            if decoder_input_ids is None and decoder_inputs_embeds is None:
+                decoder_input_ids = shift_tokens_right(
+                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
+                )
+
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            encoder_outputs=encoder_outputs,
+            decoder_attention_mask=decoder_attention_mask,
+            decoder_position_ids=decoder_position_ids,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            decoder_inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+        lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True)
+        lm_logits = self.bias_layer(lm_logits)
+        masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)
+
+        if not return_dict:
+            output = (lm_logits,) + outputs[1:]
+            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+        return TFSeq2SeqLMOutput(
+            loss=masked_lm_loss,
+            logits=lm_logits,
+            past_key_values=outputs.past_key_values,  # index 1 of d outputs
+            decoder_hidden_states=outputs.decoder_hidden_states,  # index 2 of d outputs
+            decoder_attentions=outputs.decoder_attentions,  # index 3 of d outputs
+            cross_attentions=outputs.cross_attentions,  # index 4 of d outputs
+            encoder_last_hidden_state=outputs.encoder_last_hidden_state,  # index 0 of encoder outputs
+            encoder_hidden_states=outputs.encoder_hidden_states,  # 1 of e out
+            encoder_attentions=outputs.encoder_attentions,  # 2 of e out
+        )
+
+    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output
+    def serving_output(self, output):
+        pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
+        dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
+        dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
+        cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
+        enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
+        enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
+
+        return TFSeq2SeqLMOutput(
+            logits=output.logits,
+            past_key_values=pkv,
+            decoder_hidden_states=dec_hs,
+            decoder_attentions=dec_attns,
+            cross_attentions=cross_attns,
+            encoder_last_hidden_state=output.encoder_last_hidden_state,
+            encoder_hidden_states=enc_hs,
+            encoder_attentions=enc_attns,
+        )
+
+    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation
+    def prepare_inputs_for_generation(
+        self,
+        decoder_input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        decoder_attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
+        use_cache=None,
+        encoder_outputs=None,
+        **kwargs,
+    ):
+        # cut decoder_input_ids if past_key_values is used
+        if past_key_values is not None:
+            decoder_input_ids = decoder_input_ids[:, -1:]
+
+        if decoder_attention_mask is not None:  # xla
+            decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
+        elif past_key_values is not None:  # no xla + past_key_values
+            decoder_position_ids = past_key_values[0][0].shape[2]
+        else:  # no xla + no past_key_values
+            decoder_position_ids = tf.range(decoder_input_ids.shape[1])
+
+        return {
+            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
+            "encoder_outputs": encoder_outputs,
+            "past_key_values": past_key_values,
+            "decoder_input_ids": decoder_input_ids,
+            "attention_mask": attention_mask,
+            "decoder_attention_mask": decoder_attention_mask,
+            "decoder_position_ids": decoder_position_ids,
+            "head_mask": head_mask,
+            "decoder_head_mask": decoder_head_mask,
+            "cross_attn_head_mask": cross_attn_head_mask,
+            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
+        }
+
+    def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
+        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "model", None) is not None:
+            with tf.name_scope(self.model.name):
+                self.model.build(None)
+        if getattr(self, "bias_layer", None) is not None:
+            with tf.name_scope(self.bias_layer.name):
+                self.bias_layer.build(None)
+
+
+__all__ = ["TFPegasusForConditionalGeneration", "TFPegasusModel", "TFPegasusPreTrainedModel"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/pegasus/tokenization_pegasus.py b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/tokenization_pegasus.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bab5073322a291763fd4b75e538b2959cd45579
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/tokenization_pegasus.py
@@ -0,0 +1,288 @@
+# coding=utf-8
+# Copyright 2020 Google and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple
+
+import sentencepiece as spm
+
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
+from ...utils import logging
+
+
+SPIECE_UNDERLINE = "▁"
+
+VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
+
+
+logger = logging.get_logger(__name__)
+
+
+# TODO ArthurZ refactor this to only use the added_tokens_encoder
+class PegasusTokenizer(PreTrainedTokenizer):
+    r"""
+    Construct a PEGASUS tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
+
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+    this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
+            contains the vocabulary necessary to instantiate a tokenizer.
+        pad_token (`str`, *optional*, defaults to `""`):
+            The token used for padding, for example when batching sequences of different lengths.
+        eos_token (`str`, *optional*, defaults to `""`):
+            The end of sequence token.
+
+            
+
+            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+            The token used is the `sep_token`.
+
+            
+
+        unk_token (`str`, *optional*, defaults to `""`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        mask_token (`str`, *optional*, defaults to `""`):
+            The token used for masking single token values. This is the token used when training this model with masked
+            language modeling (MLM). This is the token that the PEGASUS encoder will try to predict during pretraining.
+            It corresponds to *[MASK2]* in [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive
+            Summarization](https://arxiv.org/pdf/1912.08777.pdf).
+        mask_token_sent (`str`, *optional*, defaults to `""`):
+            The token used for masking whole target sentences. This is the token used when training this model with gap
+            sentences generation (GSG). This is the sentence that the PEGASUS decoder will try to predict during
+            pretraining. It corresponds to *[MASK1]* in [PEGASUS: Pre-training with Extracted Gap-sentences for
+            Abstractive Summarization](https://arxiv.org/pdf/1912.08777.pdf).
+        additional_special_tokens (`List[str]`, *optional*):
+            Additional special tokens used by the tokenizer. If no additional_special_tokens are provided  and
+             are used as additional special tokens corresponding to the [original PEGASUS
+            tokenizer](https://github.com/google-research/pegasus/blob/939830367bcf411193d2b5eca2f2f90f3f9260ca/pegasus/ops/pretrain_parsing_ops.cc#L66)
+            that uses the tokens 2 - 104 only for pretraining
+        sp_model_kwargs (`dict`, *optional*):
+            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
+            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
+            to set:
+
+            - `enable_sampling`: Enable subword regularization.
+            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+
+              - `nbest_size = {0,1}`: No sampling is performed.
+              - `nbest_size > 1`: samples from the nbest_size results.
+              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
+                using forward-filtering-and-backward-sampling algorithm.
+
+            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+              BPE-dropout.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        pad_token="",
+        eos_token="",
+        unk_token="",
+        mask_token="",
+        mask_token_sent="",
+        additional_special_tokens=None,
+        offset=103,  # entries 2 - 104 are only used for pretraining
+        sp_model_kwargs: Optional[Dict[str, Any]] = None,
+        **kwargs,
+    ) -> None:
+        self.offset = offset
+        if additional_special_tokens is not None:
+            if not isinstance(additional_special_tokens, list):
+                raise TypeError(
+                    f"additional_special_tokens should be of type {type(list)}, but is"
+                    f" {type(additional_special_tokens)}"
+                )
+            additional_special_tokens_extended = (
+                ([mask_token_sent] + additional_special_tokens)
+                if mask_token_sent not in additional_special_tokens and mask_token_sent is not None
+                else additional_special_tokens
+            )
+            # fill additional tokens with ...,  in case not all additional tokens are already taken
+            additional_special_tokens_extended += [
+                f"" for i in range(len(additional_special_tokens_extended), self.offset - 1)
+            ]
+
+            if len(set(additional_special_tokens_extended)) != len(additional_special_tokens_extended):
+                raise ValueError(
+                    "Please make sure that the provided additional_special_tokens do not contain an incorrectly"
+                    f" shifted list of  tokens. Found {additional_special_tokens_extended}."
+                )
+            additional_special_tokens = additional_special_tokens_extended
+        else:
+            additional_special_tokens_extended = []
+            additional_special_tokens = [mask_token_sent] if mask_token_sent is not None else []
+            additional_special_tokens += [f"" for i in range(2, self.offset)]
+
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+        self.mask_token_sent = mask_token_sent
+        self.vocab_file = vocab_file
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(vocab_file)
+
+        _added_tokens_decoder = {
+            0: AddedToken(str(pad_token), special=True),
+            1: AddedToken(str(eos_token), special=True),
+        }
+
+        if self.mask_token_sent is not None:
+            _added_tokens_decoder[2] = AddedToken(mask_token_sent, special=True)
+            _added_tokens_decoder[3] = AddedToken(str(mask_token), special=True)
+
+        for i in range(2, self.offset):
+            _added_tokens_decoder[len(_added_tokens_decoder)] = AddedToken(f"", special=True)
+
+        # Force update as we want to make sure vocab is enforced (same as fast)
+        self._added_tokens_decoder = kwargs.pop("added_tokens_decoder", {})
+        self._added_tokens_decoder.update(_added_tokens_decoder)
+
+        super().__init__(
+            eos_token=eos_token,
+            unk_token=unk_token,
+            mask_token=mask_token,
+            pad_token=pad_token,
+            mask_token_sent=mask_token_sent,
+            offset=offset,
+            additional_special_tokens=additional_special_tokens,
+            sp_model_kwargs=self.sp_model_kwargs,
+            **kwargs,
+        )
+
+    @property
+    def vocab_size(self) -> int:
+        return len(self.sp_model) + self.offset
+
+    def get_vocab(self) -> Dict[str, int]:
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sp_model"] = None
+        return state
+
+    def __setstate__(self, d):
+        self.__dict__ = d
+
+        # for backward compatibility
+        if not hasattr(self, "sp_model_kwargs"):
+            self.sp_model_kwargs = {}
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(self.vocab_file)
+
+    def _tokenize(self, text: str) -> List[str]:
+        """Take as input a string and return a list of strings (tokens) for words/sub-words"""
+        return self.sp_model.encode(text, out_type=str)
+
+    def _convert_token_to_id(self, token: str) -> int:
+        """Converts a token (str) to an id using the vocab."""
+        sp_id = self.sp_model.piece_to_id(token)
+        return sp_id + self.offset
+
+    def _convert_id_to_token(self, index: int) -> str:
+        """Converts an index (integer) to a token (str) using the vocab."""
+        if index < self.offset:
+            return self.sp_model.IdToPiece(index)
+        token = self.sp_model.IdToPiece(index - self.offset)
+        return token
+
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
+
+    def num_special_tokens_to_add(self, pair=False):
+        """Just EOS"""
+        return 1
+
+    def _special_token_mask(self, seq):
+        all_special_ids = set(self.all_special_ids)  # call it once instead of inside list comp
+        all_special_ids.remove(self.unk_token_id)  #  is only sometimes special
+
+        return [1 if x in all_special_ids else 0 for x in seq]
+
+    def get_special_tokens_mask(
+        self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
+    ) -> List[int]:
+        """Get list where entries are [1] if a token is [eos] or [pad] else 0."""
+        if already_has_special_tokens:
+            return self._special_token_mask(token_ids_0)
+        elif token_ids_1 is None:
+            return self._special_token_mask(token_ids_0) + [1]
+        else:
+            return self._special_token_mask(token_ids_0 + token_ids_1) + [1]
+
+    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
+        """
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+        and adding special tokens. A PEGASUS sequence has the following format, where `X` represents the sequence:
+
+        - single sequence: `X `
+        - pair of sequences: `A B ` (not intended use)
+
+        BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
+        separator.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+        """
+        if token_ids_1 is None:
+            return token_ids_0 + [self.eos_token_id]
+        # We don't expect to process pairs, but leave the pair logic for API consistency
+        return token_ids_0 + token_ids_1 + [self.eos_token_id]
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        out_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
+
+        return (out_vocab_file,)
+
+
+__all__ = ["PegasusTokenizer"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/pegasus/tokenization_pegasus_fast.py b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/tokenization_pegasus_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..af62976cb751450699c389fad874811517124e0e
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/tokenization_pegasus_fast.py
@@ -0,0 +1,219 @@
+# coding=utf-8
+# Copyright 2020 Google and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization class for model PEGASUS."""
+
+import os
+from shutil import copyfile
+from typing import List, Optional, Tuple
+
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import is_sentencepiece_available, logging
+
+
+if is_sentencepiece_available():
+    from .tokenization_pegasus import PegasusTokenizer
+else:
+    PegasusTokenizer = None
+
+
+logger = logging.get_logger(__name__)
+
+
+SPIECE_UNDERLINE = "▁"
+
+VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"}
+
+
+class PegasusTokenizerFast(PreTrainedTokenizerFast):
+    r"""
+    Construct a "fast" PEGASUS tokenizer (backed by HuggingFace's *tokenizers* library). Based on
+    [Unigram](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=unigram#models).
+
+    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+    refer to this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
+            contains the vocabulary necessary to instantiate a tokenizer.
+        pad_token (`str`, *optional*, defaults to `""`):
+            The token used for padding, for example when batching sequences of different lengths.
+        eos_token (`str`, *optional*, defaults to `""`):
+            The end of sequence token.
+
+            
+
+            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+            The token used is the `sep_token`.
+
+            
+
+        unk_token (`str`, *optional*, defaults to `""`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        mask_token (`str`, *optional*, defaults to `""`):
+            The token used for masking single token values. This is the token used when training this model with masked
+            language modeling (MLM). This is the token that the PEGASUS encoder will try to predict during pretraining.
+            It corresponds to *[MASK2]* in [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive
+            Summarization](https://arxiv.org/pdf/1912.08777.pdf).
+        mask_token_sent (`str`, *optional*, defaults to `""`):
+            The token used for masking whole target sentences. This is the token used when training this model with gap
+            sentences generation (GSG). This is the sentence that the PEGASUS decoder will try to predict during
+            pretraining. It corresponds to *[MASK1]* in [PEGASUS: Pre-training with Extracted Gap-sentences for
+            Abstractive Summarization](https://arxiv.org/pdf/1912.08777.pdf).
+        additional_special_tokens (`List[str]`, *optional*):
+            Additional special tokens used by the tokenizer. If no additional_special_tokens are provided  and
+             are used as additional special tokens corresponding to the [original PEGASUS
+            tokenizer](https://github.com/google-research/pegasus/blob/939830367bcf411193d2b5eca2f2f90f3f9260ca/pegasus/ops/pretrain_parsing_ops.cc#L66)
+            that uses the tokens 2 - 104 only for pretraining
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    slow_tokenizer_class = PegasusTokenizer
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file=None,
+        tokenizer_file=None,
+        pad_token="",
+        eos_token="",
+        unk_token="",
+        mask_token="",
+        mask_token_sent="",
+        additional_special_tokens=None,
+        offset=103,  # entries 2 - 104 are only used for pretraining
+        **kwargs,
+    ):
+        self.offset = offset
+
+        if additional_special_tokens is not None:
+            if not isinstance(additional_special_tokens, list):
+                raise TypeError(
+                    f"additional_special_tokens should be of type {type(list)}, but is"
+                    f" {type(additional_special_tokens)}"
+                )
+
+            additional_special_tokens_extended = (
+                ([mask_token_sent] + additional_special_tokens)
+                if mask_token_sent not in additional_special_tokens and mask_token_sent is not None
+                else additional_special_tokens
+            )
+            # fill additional tokens with ...,  in case not all additional tokens are already taken
+            additional_special_tokens_extended += [
+                f"" for i in range(len(additional_special_tokens_extended), self.offset - 1)
+            ]
+
+            if len(set(additional_special_tokens_extended)) != len(additional_special_tokens_extended):
+                raise ValueError(
+                    "Please make sure that the provided additional_special_tokens do not contain an incorrectly"
+                    f" shifted list of  tokens. Found {additional_special_tokens_extended}."
+                )
+            additional_special_tokens = additional_special_tokens_extended
+        else:
+            additional_special_tokens = [mask_token_sent] if mask_token_sent is not None else []
+            additional_special_tokens += [f"" for i in range(2, self.offset)]
+
+        # pegasus was design to support changing the index of the first tokens. If one of the padding/eos/unk/mask token
+        # is different from default, we must rebuild the vocab
+        from_slow = kwargs.pop("from_slow", None)
+        from_slow = from_slow or str(pad_token) != "" or str(eos_token) != "" or str(unk_token) != ""
+
+        kwargs.pop("added_tokens_decoder", {})
+
+        super().__init__(
+            vocab_file,
+            tokenizer_file=tokenizer_file,
+            pad_token=pad_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            mask_token=mask_token,
+            mask_token_sent=mask_token_sent,
+            offset=offset,
+            additional_special_tokens=additional_special_tokens,
+            from_slow=from_slow,
+            **kwargs,
+        )
+        self.vocab_file = vocab_file
+
+    @property
+    def can_save_slow_tokenizer(self) -> bool:
+        return os.path.isfile(self.vocab_file) if self.vocab_file else False
+
+    def _special_token_mask(self, seq):
+        all_special_ids = set(self.all_special_ids)  # call it once instead of inside list comp
+        all_special_ids.remove(self.unk_token_id)  #  is only sometimes special
+
+        if all_special_ids != set(range(len(self.additional_special_tokens) + 3)):
+            raise ValueError(
+                "There should be 3 special tokens: mask_token, pad_token, and eos_token +"
+                f" {len(self.additional_special_tokens)} additional_special_tokens, but got {all_special_ids}"
+            )
+
+        return [1 if x in all_special_ids else 0 for x in seq]
+
+    def get_special_tokens_mask(
+        self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
+    ) -> List[int]:
+        """Get list where entries are [1] if a token is [eos] or [pad] else 0."""
+        if already_has_special_tokens:
+            return self._special_token_mask(token_ids_0)
+        elif token_ids_1 is None:
+            return self._special_token_mask(token_ids_0) + [1]
+        else:
+            return self._special_token_mask(token_ids_0 + token_ids_1) + [1]
+
+    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
+        """
+        Build model inputs from a sequence by adding eos to the end. no bos token is added to the front.
+
+        - single sequence: `X `
+        - pair of sequences: `A B ` (not intended use)
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+        """
+        if token_ids_1 is None:
+            return token_ids_0 + [self.eos_token_id]
+        # We don't expect to process pairs, but leave the pair logic for API consistency
+        return token_ids_0 + token_ids_1 + [self.eos_token_id]
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not self.can_save_slow_tokenizer:
+            raise ValueError(
+                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+                "tokenizer."
+            )
+
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        out_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+
+        return (out_vocab_file,)
+
+
+__all__ = ["PegasusTokenizerFast"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/__init__.py b/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f69ff8762af0e322d7038e00b892a1dbb0914e45
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/__init__.py
@@ -0,0 +1,54 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
+
+
+_import_structure = {"configuration_vitpose_backbone": ["VitPoseBackboneConfig"]}
+
+
+try:
+    if not is_torch_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    pass
+else:
+    _import_structure["modeling_vitpose_backbone"] = [
+        "VitPoseBackbonePreTrainedModel",
+        "VitPoseBackbone",
+    ]
+
+if TYPE_CHECKING:
+    from .configuration_vitpose_backbone import VitPoseBackboneConfig
+
+    try:
+        if not is_torch_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        pass
+    else:
+        from .modeling_vitpose_backbone import (
+            VitPoseBackbone,
+            VitPoseBackbonePreTrainedModel,
+        )
+
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c53e43a22bb2ba429c4fa1c45a01ee4d027449cb
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/__pycache__/configuration_vitpose_backbone.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/__pycache__/configuration_vitpose_backbone.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6febfd970839a75905e59e621c1c62ebcd9a9d80
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/__pycache__/configuration_vitpose_backbone.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/__pycache__/modeling_vitpose_backbone.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/__pycache__/modeling_vitpose_backbone.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..58ad2c3133bf79c9d8c6eb308bb3338f9f893203
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/__pycache__/modeling_vitpose_backbone.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/configuration_vitpose_backbone.py b/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/configuration_vitpose_backbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..2872d39a2a30df940307978b4c258a2c6cdfb811
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/configuration_vitpose_backbone.py
@@ -0,0 +1,136 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""VitPose backbone configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
+
+
+logger = logging.get_logger(__name__)
+
+
+class VitPoseBackboneConfig(BackboneConfigMixin, PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`VitPoseBackbone`]. It is used to instantiate a
+    VitPose model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the VitPose
+    [usyd-community/vitpose-base-simple](https://huggingface.co/usyd-community/vitpose-base-simple) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        image_size (`int`, *optional*, defaults to `[256, 192]`):
+            The size (resolution) of each image.
+        patch_size (`List[int]`, *optional*, defaults to `[16, 16]`):
+            The size (resolution) of each patch.
+        num_channels (`int`, *optional*, defaults to 3):
+            The number of input channels.
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        mlp_ratio (`int`, *optional*, defaults to 4):
+            The ratio of the hidden size in the feedforward network to the hidden size in the attention layers.
+        num_experts (`int`, *optional*, defaults to 1):
+            The number of experts in the MoE layer.
+        part_features (`int`, *optional*):
+            The number of part features to output. Only used in case `num_experts` is greater than 1.
+        hidden_act (`str`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        qkv_bias (`bool`, *optional*, defaults to `True`):
+            Whether to add a bias to the queries, keys and values.
+        out_features (`List[str]`, *optional*):
+            If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+            (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+            corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
+            same order as defined in the `stage_names` attribute.
+        out_indices (`List[int]`, *optional*):
+            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+            many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+            If unset and `out_features` is unset, will default to the last stage. Must be in the
+            same order as defined in the `stage_names` attribute.
+
+    Example:
+
+    ```python
+    >>> from transformers import VitPoseBackboneConfig, VitPoseBackbone
+
+    >>> # Initializing a VitPose configuration
+    >>> configuration = VitPoseBackboneConfig()
+
+    >>> # Initializing a model (with random weights) from the configuration
+    >>> model = VitPoseBackbone(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "vitpose_backbone"
+
+    def __init__(
+        self,
+        image_size=[256, 192],
+        patch_size=[16, 16],
+        num_channels=3,
+        hidden_size=768,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        mlp_ratio=4,
+        num_experts=1,
+        part_features=256,
+        hidden_act="gelu",
+        hidden_dropout_prob=0.0,
+        attention_probs_dropout_prob=0.0,
+        initializer_range=0.02,
+        layer_norm_eps=1e-12,
+        qkv_bias=True,
+        out_features=None,
+        out_indices=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.mlp_ratio = mlp_ratio
+        self.num_experts = num_experts
+        self.part_features = part_features
+        self.hidden_act = hidden_act
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.initializer_range = initializer_range
+        self.layer_norm_eps = layer_norm_eps
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.num_channels = num_channels
+        self.qkv_bias = qkv_bias
+        self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, num_hidden_layers + 1)]
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+        )
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py b/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3f057988370f06ded97df5c4555276ca06ea7d8
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py
@@ -0,0 +1,542 @@
+# coding=utf-8
+# Copyright 2024 University of Sydney and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch VitPose backbone model.
+
+This code is the same as the original Vision Transformer (ViT) with 2 modifications:
+- use of padding=2 in the patch embedding layer
+- addition of a mixture-of-experts MLP layer
+"""
+
+import collections.abc
+import math
+from typing import Optional, Set, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_outputs import BackboneOutput, BaseModelOutput
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import (
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
+from ...utils.backbone_utils import BackboneMixin
+from .configuration_vitpose_backbone import VitPoseBackboneConfig
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "VitPoseBackboneConfig"
+
+
+class VitPoseBackbonePatchEmbeddings(nn.Module):
+    """Image to Patch Embedding."""
+
+    def __init__(self, config):
+        super().__init__()
+
+        image_size = config.image_size
+        patch_size = config.patch_size
+        num_channels = config.num_channels
+        embed_dim = config.hidden_size
+
+        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.num_patches = num_patches
+
+        self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size, padding=2)
+
+    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+        height, width = pixel_values.shape[-2:]
+        if height != self.image_size[0] or width != self.image_size[1]:
+            raise ValueError(
+                f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
+            )
+        embeddings = self.projection(pixel_values)
+
+        embeddings = embeddings.flatten(2).transpose(1, 2)
+        return embeddings
+
+
+class VitPoseBackboneEmbeddings(nn.Module):
+    """
+    Construct the position and patch embeddings.
+    """
+
+    def __init__(self, config: VitPoseBackboneConfig) -> None:
+        super().__init__()
+
+        self.patch_embeddings = VitPoseBackbonePatchEmbeddings(config)
+        num_patches = self.patch_embeddings.num_patches
+        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+        embeddings = self.patch_embeddings(pixel_values)
+
+        # add positional encoding to each token
+        embeddings = embeddings + self.position_embeddings[:, 1:] + self.position_embeddings[:, :1]
+
+        embeddings = self.dropout(embeddings)
+
+        return embeddings
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->VitPoseBackbone
+class VitPoseBackboneSelfAttention(nn.Module):
+    def __init__(self, config: VitPoseBackboneConfig) -> None:
+        super().__init__()
+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+            raise ValueError(
+                f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
+                f"heads {config.num_attention_heads}."
+            )
+
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+        x = x.view(new_x_shape)
+        return x.permute(0, 2, 1, 3)
+
+    def forward(
+        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
+    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+        mixed_query_layer = self.query(hidden_states)
+
+        key_layer = self.transpose_for_scores(self.key(hidden_states))
+        value_layer = self.transpose_for_scores(self.value(hidden_states))
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        attention_probs = self.dropout(attention_probs)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attention_probs = attention_probs * head_mask
+
+        context_layer = torch.matmul(attention_probs, value_layer)
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+
+        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+        return outputs
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->VitPoseBackbone
+class VitPoseBackboneSelfOutput(nn.Module):
+    """
+    The residual connection is defined in VitPoseBackboneLayer instead of here (as is the case with other models), due to the
+    layernorm applied before each block.
+    """
+
+    def __init__(self, config: VitPoseBackboneConfig) -> None:
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+
+        return hidden_states
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->VitPoseBackbone
+class VitPoseBackboneAttention(nn.Module):
+    def __init__(self, config: VitPoseBackboneConfig) -> None:
+        super().__init__()
+        self.attention = VitPoseBackboneSelfAttention(config)
+        self.output = VitPoseBackboneSelfOutput(config)
+        self.pruned_heads = set()
+
+    def prune_heads(self, heads: Set[int]) -> None:
+        if len(heads) == 0:
+            return
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
+        )
+
+        # Prune linear layers
+        self.attention.query = prune_linear_layer(self.attention.query, index)
+        self.attention.key = prune_linear_layer(self.attention.key, index)
+        self.attention.value = prune_linear_layer(self.attention.value, index)
+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+        # Update hyper params and store pruned heads
+        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
+        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+        self_outputs = self.attention(hidden_states, head_mask, output_attentions)
+
+        attention_output = self.output(self_outputs[0], hidden_states)
+
+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
+        return outputs
+
+
+class VitPoseBackboneMoeMLP(nn.Module):
+    def __init__(self, config: VitPoseBackboneConfig):
+        super().__init__()
+
+        in_features = out_features = config.hidden_size
+        hidden_features = int(config.hidden_size * config.mlp_ratio)
+
+        num_experts = config.num_experts
+        part_features = config.part_features
+
+        self.part_features = part_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = ACT2FN[config.hidden_act]
+        self.fc2 = nn.Linear(hidden_features, out_features - part_features)
+        self.drop = nn.Dropout(config.hidden_dropout_prob)
+
+        self.num_experts = num_experts
+        experts = [nn.Linear(hidden_features, part_features) for _ in range(num_experts)]
+        self.experts = nn.ModuleList(experts)
+
+    def forward(self, hidden_state: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
+        expert_hidden_state = torch.zeros_like(hidden_state[:, :, -self.part_features :])
+
+        hidden_state = self.fc1(hidden_state)
+        hidden_state = self.act(hidden_state)
+        shared_hidden_state = self.fc2(hidden_state)
+        indices = indices.view(-1, 1, 1)
+
+        # to support ddp training
+        for i in range(self.num_experts):
+            selected_index = indices == i
+            current_hidden_state = self.experts[i](hidden_state) * selected_index
+            expert_hidden_state = expert_hidden_state + current_hidden_state
+
+        hidden_state = torch.cat([shared_hidden_state, expert_hidden_state], dim=-1)
+
+        return hidden_state
+
+
+class VitPoseBackboneMLP(nn.Module):
+    def __init__(self, config: VitPoseBackboneConfig) -> None:
+        super().__init__()
+        in_features = out_features = config.hidden_size
+        hidden_features = int(config.hidden_size * config.mlp_ratio)
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
+        self.activation = ACT2FN[config.hidden_act]
+        self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
+
+    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+        hidden_state = self.fc1(hidden_state)
+        hidden_state = self.activation(hidden_state)
+        hidden_state = self.fc2(hidden_state)
+        return hidden_state
+
+
+class VitPoseBackboneLayer(nn.Module):
+    def __init__(self, config: VitPoseBackboneConfig) -> None:
+        super().__init__()
+        self.num_experts = config.num_experts
+        self.attention = VitPoseBackboneAttention(config)
+        self.mlp = VitPoseBackboneMLP(config) if self.num_experts == 1 else VitPoseBackboneMoeMLP(config)
+        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        dataset_index: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+        # Validate dataset_index when using multiple experts
+        if self.num_experts > 1 and dataset_index is None:
+            raise ValueError(
+                "dataset_index must be provided when using multiple experts "
+                f"(num_experts={self.num_experts}). Please provide dataset_index "
+                "to the forward pass."
+            )
+        self_attention_outputs = self.attention(
+            self.layernorm_before(hidden_states),  # in VitPoseBackbone, layernorm is applied before self-attention
+            head_mask,
+            output_attentions=output_attentions,
+        )
+        attention_output = self_attention_outputs[0]
+        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
+
+        # first residual connection
+        hidden_states = attention_output + hidden_states
+
+        layer_output = self.layernorm_after(hidden_states)
+        if self.num_experts == 1:
+            layer_output = self.mlp(layer_output)
+        else:
+            layer_output = self.mlp(layer_output, indices=dataset_index)
+
+        # second residual connection
+        layer_output = layer_output + hidden_states
+
+        outputs = (layer_output,) + outputs
+
+        return outputs
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->VitPoseBackbone
+class VitPoseBackboneEncoder(nn.Module):
+    def __init__(self, config: VitPoseBackboneConfig) -> None:
+        super().__init__()
+        self.config = config
+        self.layer = nn.ModuleList([VitPoseBackboneLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False
+
+    # Ignore copy
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        dataset_index: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+    ) -> Union[tuple, BaseModelOutput]:
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attentions = () if output_attentions else None
+
+        for i, layer_module in enumerate(self.layer):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            layer_head_mask = head_mask[i] if head_mask is not None else None
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    layer_module.__call__,
+                    hidden_states,
+                    dataset_index,
+                    layer_head_mask,
+                    output_attentions,
+                )
+            else:
+                layer_outputs = layer_module(hidden_states, dataset_index, layer_head_mask, output_attentions)
+
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+        return BaseModelOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+        )
+
+
+class VitPoseBackbonePreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = VitPoseBackboneConfig
+    base_model_prefix = "vit"
+    main_input_name = "pixel_values"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["VitPoseBackboneEmbeddings", "VitPoseBackboneLayer"]
+
+    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
+        """Initialize the weights"""
+        if isinstance(module, (nn.Linear, nn.Conv2d)):
+            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
+            # `trunc_normal_cpu` not implemented in `half` issues
+            module.weight.data = nn.init.trunc_normal_(
+                module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
+            ).to(module.weight.dtype)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        elif isinstance(module, VitPoseBackboneEmbeddings):
+            module.position_embeddings.data = nn.init.trunc_normal_(
+                module.position_embeddings.data.to(torch.float32),
+                mean=0.0,
+                std=self.config.initializer_range,
+            ).to(module.position_embeddings.dtype)
+
+
+VITPOSE_BACKBONE_START_DOCSTRING = r"""
+    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+    behavior.
+
+    Parameters:
+        config ([`VitPoseBackboneConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+VITPOSE_BACKBONE_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values.
+
+        dataset_index (`torch.Tensor` of shape `(batch_size,)`):
+            Index to use in the Mixture-of-Experts (MoE) blocks of the backbone.
+
+            This corresponds to the dataset index used during training, e.g. index 0 refers to COCO.
+
+        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    "The VitPose backbone useful for downstream tasks.",
+    VITPOSE_BACKBONE_START_DOCSTRING,
+)
+class VitPoseBackbone(VitPoseBackbonePreTrainedModel, BackboneMixin):
+    def __init__(self, config: VitPoseBackboneConfig):
+        super().__init__(config)
+        super()._init_backbone(config)
+
+        self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]
+        self.embeddings = VitPoseBackboneEmbeddings(config)
+        self.encoder = VitPoseBackboneEncoder(config)
+
+        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(VITPOSE_BACKBONE_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        dataset_index: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        """
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from transformers import VitPoseBackboneConfig, VitPoseBackbone
+        >>> import torch
+
+        >>> config = VitPoseBackboneConfig(out_indices=[-1])
+        >>> model = VitPoseBackbone(config)
+
+        >>> pixel_values = torch.randn(1, 3, 256, 192)
+        >>> dataset_index = torch.tensor([1])
+        >>> outputs = model(pixel_values, dataset_index)
+        ```"""
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+        embedding_output = self.embeddings(pixel_values)
+
+        outputs = self.encoder(
+            embedding_output,
+            dataset_index=dataset_index,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=True,
+            return_dict=return_dict,
+        )
+        hidden_states = outputs.hidden_states if return_dict else outputs[1]
+
+        feature_maps = ()
+        for stage, hidden_state in zip(self.stage_names, hidden_states):
+            if stage in self.out_features:
+                hidden_state = self.layernorm(hidden_state)
+                feature_maps += (hidden_state,)
+
+        if not return_dict:
+            if output_hidden_states:
+                output = (feature_maps,) + outputs[1:]
+            else:
+                output = (feature_maps,) + outputs[2:]
+            return output
+
+        return BackboneOutput(
+            feature_maps=feature_maps,
+            hidden_states=outputs.hidden_states if output_hidden_states else None,
+            attentions=outputs.attentions,
+        )
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/x_clip/__init__.py b/.venv/lib/python3.11/site-packages/transformers/models/x_clip/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9c5fea1b3147333041a432f8d50861104dc9140
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/x_clip/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_x_clip import *
+    from .modeling_x_clip import *
+    from .processing_x_clip import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/x_clip/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/x_clip/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b5310142ff0f8c60b8c80f5bf89b9a6f1bdf5eb2
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/x_clip/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/x_clip/__pycache__/configuration_x_clip.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/x_clip/__pycache__/configuration_x_clip.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ecfbc6a0190ff5234d8419d1c5d04dab9db9d96b
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/x_clip/__pycache__/configuration_x_clip.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/x_clip/__pycache__/modeling_x_clip.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/x_clip/__pycache__/modeling_x_clip.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fdfba5f92b0cb642becf821eb7bea189d5adf78a
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/x_clip/__pycache__/modeling_x_clip.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/x_clip/__pycache__/processing_x_clip.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/x_clip/__pycache__/processing_x_clip.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d21dd48ab50a81740ebe59faa84164422d165ef9
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/x_clip/__pycache__/processing_x_clip.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/x_clip/configuration_x_clip.py b/.venv/lib/python3.11/site-packages/transformers/models/x_clip/configuration_x_clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..a500e5dccd90f2d1058b4a60b953ad3f26cc2d03
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/x_clip/configuration_x_clip.py
@@ -0,0 +1,381 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""X-CLIP model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class XCLIPTextConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`XCLIPModel`]. It is used to instantiate an X-CLIP
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the X-CLIP
+    [microsoft/xclip-base-patch32](https://huggingface.co/microsoft/xclip-base-patch32) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 49408):
+            Vocabulary size of the X-CLIP text model. Defines the number of different tokens that can be represented by
+            the `inputs_ids` passed when calling [`XCLIPModel`].
+        hidden_size (`int`, *optional*, defaults to 512):
+            Dimensionality of the encoder layers and the pooler layer.
+        intermediate_size (`int`, *optional*, defaults to 2048):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 8):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        max_position_embeddings (`int`, *optional*, defaults to 77):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-5):
+            The epsilon used by the layer normalization layers.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        initializer_factor (`float`, *optional*, defaults to 1):
+            A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
+            testing).
+
+    Example:
+
+    ```python
+    >>> from transformers import XCLIPTextModel, XCLIPTextConfig
+
+    >>> # Initializing a XCLIPTextModel with microsoft/xclip-base-patch32 style configuration
+    >>> configuration = XCLIPTextConfig()
+
+    >>> # Initializing a XCLIPTextConfig from the microsoft/xclip-base-patch32 style configuration
+    >>> model = XCLIPTextModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "xclip_text_model"
+    base_config_key = "text_config"
+
+    def __init__(
+        self,
+        vocab_size=49408,
+        hidden_size=512,
+        intermediate_size=2048,
+        num_hidden_layers=12,
+        num_attention_heads=8,
+        max_position_embeddings=77,
+        hidden_act="quick_gelu",
+        layer_norm_eps=1e-5,
+        attention_dropout=0.0,
+        initializer_range=0.02,
+        initializer_factor=1.0,
+        pad_token_id=1,
+        bos_token_id=0,
+        eos_token_id=2,
+        **kwargs,
+    ):
+        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.max_position_embeddings = max_position_embeddings
+        self.layer_norm_eps = layer_norm_eps
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.initializer_factor = initializer_factor
+        self.attention_dropout = attention_dropout
+
+
+class XCLIPVisionConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`XCLIPModel`]. It is used to instantiate an X-CLIP
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the X-CLIP
+    [microsoft/xclip-base-patch32](https://huggingface.co/microsoft/xclip-base-patch32) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        intermediate_size (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        mit_hidden_size (`int`, *optional*, defaults to 512):
+            Dimensionality of the encoder layers of the Multiframe Integration Transformer (MIT).
+        mit_intermediate_size (`int`, *optional*, defaults to 2048):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Multiframe Integration Transformer
+            (MIT).
+        mit_num_hidden_layers (`int`, *optional*, defaults to 1):
+            Number of hidden layers in the Multiframe Integration Transformer (MIT).
+        mit_num_attention_heads (`int`, *optional*, defaults to 8):
+            Number of attention heads for each attention layer in the Multiframe Integration Transformer (MIT).
+        image_size (`int`, *optional*, defaults to 224):
+            The size (resolution) of each image.
+        patch_size (`int`, *optional*, defaults to 32):
+            The size (resolution) of each patch.
+        hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-5):
+            The epsilon used by the layer normalization layers.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        initializer_factor (`float`, *optional*, defaults to 1):
+            A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
+            testing).
+        drop_path_rate (`float`, *optional*, defaults to 0.0):
+            Stochastic depth rate.
+
+    Example:
+
+    ```python
+    >>> from transformers import XCLIPVisionModel, XCLIPVisionConfig
+
+    >>> # Initializing a XCLIPVisionModel with microsoft/xclip-base-patch32 style configuration
+    >>> configuration = XCLIPVisionConfig()
+
+    >>> # Initializing a XCLIPVisionModel model from the microsoft/xclip-base-patch32 style configuration
+    >>> model = XCLIPVisionModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "xclip_vision_model"
+    base_config_key = "vision_config"
+
+    def __init__(
+        self,
+        hidden_size=768,
+        intermediate_size=3072,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        mit_hidden_size=512,
+        mit_intermediate_size=2048,
+        mit_num_hidden_layers=1,
+        mit_num_attention_heads=8,
+        num_channels=3,
+        image_size=224,
+        patch_size=32,
+        num_frames=8,
+        hidden_act="quick_gelu",
+        layer_norm_eps=1e-5,
+        attention_dropout=0.0,
+        initializer_range=0.02,
+        initializer_factor=1.0,
+        drop_path_rate=0.0,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.mit_hidden_size = mit_hidden_size
+        self.mit_intermediate_size = mit_intermediate_size
+        self.mit_num_hidden_layers = mit_num_hidden_layers
+        self.mit_num_attention_heads = mit_num_attention_heads
+        self.num_channels = num_channels
+        self.patch_size = patch_size
+        self.num_frames = num_frames
+        self.image_size = image_size
+        self.initializer_range = initializer_range
+        self.initializer_factor = initializer_factor
+        self.attention_dropout = attention_dropout
+        self.layer_norm_eps = layer_norm_eps
+        self.hidden_act = hidden_act
+        self.drop_path_rate = drop_path_rate
+
+
+class XCLIPConfig(PretrainedConfig):
+    r"""
+    [`XCLIPConfig`] is the configuration class to store the configuration of a [`XCLIPModel`]. It is used to
+    instantiate X-CLIP model according to the specified arguments, defining the text model and vision model configs.
+    Instantiating a configuration with the defaults will yield a similar configuration to that of the X-CLIP
+    [microsoft/xclip-base-patch32](https://huggingface.co/microsoft/xclip-base-patch32) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        text_config (`dict`, *optional*):
+            Dictionary of configuration options used to initialize [`XCLIPTextConfig`].
+        vision_config (`dict`, *optional*):
+            Dictionary of configuration options used to initialize [`XCLIPVisionConfig`].
+        projection_dim (`int`, *optional*, defaults to 512):
+            Dimensionality of text and vision projection layers.
+        prompt_layers (`int`, *optional*, defaults to 2):
+            Number of layers in the video specific prompt generator.
+        prompt_alpha (`float`, *optional*, defaults to 0.1):
+            Alpha value to use in the video specific prompt generator.
+        prompt_hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
+            The non-linear activation function (function or string) in the video specific prompt generator. If string,
+            `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
+        prompt_num_attention_heads (`int`, *optional*, defaults to 8):
+            Number of attention heads in the cross-attention of the video specific prompt generator.
+        prompt_attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability for the attention layers in the video specific prompt generator.
+        prompt_projection_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability for the projection layers in the video specific prompt generator.
+        logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
+            The inital value of the *logit_scale* parameter. Default is used as per the original XCLIP implementation.
+        kwargs (*optional*):
+            Dictionary of keyword arguments.
+    """
+
+    model_type = "xclip"
+    sub_configs = {"text_config": XCLIPTextConfig, "vision_config": XCLIPVisionConfig}
+
+    def __init__(
+        self,
+        text_config=None,
+        vision_config=None,
+        projection_dim=512,
+        prompt_layers=2,
+        prompt_alpha=0.1,
+        prompt_hidden_act="quick_gelu",
+        prompt_num_attention_heads=8,
+        prompt_attention_dropout=0.0,
+        prompt_projection_dropout=0.0,
+        logit_scale_init_value=2.6592,
+        **kwargs,
+    ):
+        # If `_config_dict` exist, we use them for the backward compatibility.
+        # We pop out these 2 attributes before calling `super().__init__` to avoid them being saved (which causes a lot
+        # of confusion!).
+        text_config_dict = kwargs.pop("text_config_dict", None)
+        vision_config_dict = kwargs.pop("vision_config_dict", None)
+
+        super().__init__(**kwargs)
+
+        # Instead of simply assigning `[text|vision]_config_dict` to `[text|vision]_config`, we use the values in
+        # `[text|vision]_config_dict` to update the values in `[text|vision]_config`. The values should be same in most
+        # cases, but we don't want to break anything regarding `_config_dict` that existed before commit `8827e1b2`.
+        if text_config_dict is not None:
+            if text_config is None:
+                text_config = {}
+
+            # This is the complete result when using `text_config_dict`.
+            _text_config_dict = XCLIPTextConfig(**text_config_dict).to_dict()
+
+            # Give a warning if the values exist in both `_text_config_dict` and `text_config` but being different.
+            for key, value in _text_config_dict.items():
+                if key in text_config and value != text_config[key] and key not in ["transformers_version"]:
+                    # If specified in `text_config_dict`
+                    if key in text_config_dict:
+                        message = (
+                            f"`{key}` is found in both `text_config_dict` and `text_config` but with different values. "
+                            f'The value `text_config_dict["{key}"]` will be used instead.'
+                        )
+                    # If inferred from default argument values (just to be super careful)
+                    else:
+                        message = (
+                            f"`text_config_dict` is provided which will be used to initialize `XCLIPTextConfig`. The "
+                            f'value `text_config["{key}"]` will be overridden.'
+                        )
+                    logger.info(message)
+
+            # Update all values in `text_config` with the ones in `_text_config_dict`.
+            text_config.update(_text_config_dict)
+
+        if vision_config_dict is not None:
+            if vision_config is None:
+                vision_config = {}
+
+            # This is the complete result when using `vision_config_dict`.
+            _vision_config_dict = XCLIPVisionConfig(**vision_config_dict).to_dict()
+            # convert keys to string instead of integer
+            if "id2label" in _vision_config_dict:
+                _vision_config_dict["id2label"] = {
+                    str(key): value for key, value in _vision_config_dict["id2label"].items()
+                }
+
+            # Give a warning if the values exist in both `_vision_config_dict` and `vision_config` but being different.
+            for key, value in _vision_config_dict.items():
+                if key in vision_config and value != vision_config[key] and key not in ["transformers_version"]:
+                    # If specified in `vision_config_dict`
+                    if key in vision_config_dict:
+                        message = (
+                            f"`{key}` is found in both `vision_config_dict` and `vision_config` but with different "
+                            f'values. The value `vision_config_dict["{key}"]` will be used instead.'
+                        )
+                    # If inferred from default argument values (just to be super careful)
+                    else:
+                        message = (
+                            f"`vision_config_dict` is provided which will be used to initialize `XCLIPVisionConfig`. "
+                            f'The value `vision_config["{key}"]` will be overridden.'
+                        )
+                    logger.info(message)
+
+            # Update all values in `vision_config` with the ones in `_vision_config_dict`.
+            vision_config.update(_vision_config_dict)
+
+        if text_config is None:
+            text_config = {}
+            logger.info("`text_config` is `None`. Initializing the `XCLIPTextConfig` with default values.")
+
+        if vision_config is None:
+            vision_config = {}
+            logger.info("`vision_config` is `None`. initializing the `XCLIPVisionConfig` with default values.")
+
+        self.text_config = XCLIPTextConfig(**text_config)
+        self.vision_config = XCLIPVisionConfig(**vision_config)
+
+        self.projection_dim = projection_dim
+        self.prompt_layers = prompt_layers
+        self.prompt_alpha = prompt_alpha
+        self.prompt_hidden_act = prompt_hidden_act
+        self.prompt_num_attention_heads = prompt_num_attention_heads
+        self.prompt_attention_dropout = prompt_attention_dropout
+        self.prompt_projection_dropout = prompt_projection_dropout
+        self.logit_scale_init_value = logit_scale_init_value
+        self.initializer_factor = 1.0
+
+    @classmethod
+    def from_text_vision_configs(cls, text_config: XCLIPTextConfig, vision_config: XCLIPVisionConfig, **kwargs):
+        r"""
+        Instantiate a [`XCLIPConfig`] (or a derived class) from xclip text model configuration and xclip vision model
+        configuration.
+
+        Returns:
+            [`XCLIPConfig`]: An instance of a configuration object
+        """
+
+        return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
+
+
+__all__ = ["XCLIPConfig", "XCLIPTextConfig", "XCLIPVisionConfig"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/x_clip/modeling_x_clip.py b/.venv/lib/python3.11/site-packages/transformers/models/x_clip/modeling_x_clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..bce1ea02e7b9e085cb8b663a233d84f83d41a14b
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/x_clip/modeling_x_clip.py
@@ -0,0 +1,1680 @@
+# coding=utf-8
+# Copyright 2022 Microsoft Research and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch X-CLIP model."""
+
+from copy import copy
+from dataclasses import dataclass
+from typing import Any, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+    ModelOutput,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+    torch_int,
+)
+from .configuration_x_clip import XCLIPConfig, XCLIPTextConfig, XCLIPVisionConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "microsoft/xclip-base-patch32"
+
+
+# contrastive loss function, adapted from
+# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
+def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
+    return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
+
+
+# Copied from transformers.models.clip.modeling_clip.clip_loss with clip->x_clip
+def x_clip_loss(similarity: torch.Tensor) -> torch.Tensor:
+    caption_loss = contrastive_loss(similarity)
+    image_loss = contrastive_loss(similarity.t())
+    return (caption_loss + image_loss) / 2.0
+
+
+@dataclass
+class XCLIPOutput(ModelOutput):
+    """
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
+            Contrastive loss for video-text similarity.
+        logits_per_video (`torch.FloatTensor` of shape `(video_batch_size, text_batch_size)`):
+            The scaled dot product scores between `video_embeds` and `text_embeds`. This represents the video-text
+            similarity scores.
+        logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, video_batch_size)`):
+            The scaled dot product scores between `text_embeds` and `video_embeds`. This represents the text-video
+            similarity scores.
+        text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
+            The text embeddings obtained by applying the projection layer to the pooled output of [`XCLIPTextModel`].
+        video_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
+            The video embeddings obtained by applying the projection layer to the pooled output of
+            [`XCLIPVisionModel`].
+        text_model_output (`BaseModelOutputWithPooling`):
+            The output of the [`XCLIPTextModel`].
+        vision_model_output (`BaseModelOutputWithPooling`):
+            The output of the [`XCLIPVisionModel`].
+        mit_output (`BaseModelOutputWithPooling`):
+            The output of `XCLIPMultiframeIntegrationTransformer` (MIT for short).
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits_per_video: torch.FloatTensor = None
+    logits_per_text: torch.FloatTensor = None
+    text_embeds: torch.FloatTensor = None
+    video_embeds: torch.FloatTensor = None
+    text_model_output: BaseModelOutputWithPooling = None
+    vision_model_output: BaseModelOutputWithPooling = None
+    mit_output: BaseModelOutputWithPooling = None
+
+    def to_tuple(self) -> Tuple[Any]:
+        return tuple(
+            self[k]
+            if k not in ["text_model_output", "vision_model_output", "mit_output"]
+            else getattr(self, k).to_tuple()
+            for k in self.keys()
+        )
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->XCLIP
+class XCLIPVisionEmbeddings(nn.Module):
+    def __init__(self, config: XCLIPVisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+
+        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
+
+        self.patch_embedding = nn.Conv2d(
+            in_channels=config.num_channels,
+            out_channels=self.embed_dim,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+            bias=False,
+        )
+
+        self.num_patches = (self.image_size // self.patch_size) ** 2
+        self.num_positions = self.num_patches + 1
+        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
+
+    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        """
+        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
+        images. This method is also adapted to support torch.jit tracing.
+
+        Adapted from:
+        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+        """
+
+        num_patches = embeddings.shape[1] - 1
+        position_embedding = self.position_embedding.weight.unsqueeze(0)
+        num_positions = position_embedding.shape[1] - 1
+
+        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+            return self.position_embedding(self.position_ids)
+
+        class_pos_embed = position_embedding[:, :1]
+        patch_pos_embed = position_embedding[:, 1:]
+
+        dim = embeddings.shape[-1]
+
+        new_height = height // self.patch_size
+        new_width = width // self.patch_size
+
+        sqrt_num_positions = torch_int(num_positions**0.5)
+        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed,
+            size=(new_height, new_width),
+            mode="bicubic",
+            align_corners=False,
+        )
+
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
+
+    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
+        batch_size, _, height, width = pixel_values.shape
+        if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
+            raise ValueError(
+                f"Input image size ({height}*{width}) doesn't match model" f" ({self.image_size}*{self.image_size})."
+            )
+        target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
+        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+
+        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
+        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+        if interpolate_pos_encoding:
+            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+        else:
+            embeddings = embeddings + self.position_embedding(self.position_ids)
+        return embeddings
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->XCLIP
+class XCLIPTextEmbeddings(nn.Module):
+    def __init__(self, config: XCLIPTextConfig):
+        super().__init__()
+        embed_dim = config.hidden_size
+
+        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
+        self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
+
+        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+    ) -> torch.Tensor:
+        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+
+        if position_ids is None:
+            position_ids = self.position_ids[:, :seq_length]
+
+        if inputs_embeds is None:
+            inputs_embeds = self.token_embedding(input_ids)
+
+        position_embeddings = self.position_embedding(position_ids)
+        embeddings = inputs_embeds + position_embeddings
+
+        return embeddings
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->XCLIP
+class XCLIPAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.embed_dim // self.num_heads
+        if self.head_dim * self.num_heads != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+                f" {self.num_heads})."
+            )
+        self.scale = self.head_dim**-0.5
+        self.dropout = config.attention_dropout
+
+        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        causal_attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Input shape: Batch x Time x Channel"""
+
+        bsz, tgt_len, embed_dim = hidden_states.size()
+
+        # get query proj
+        query_states = self.q_proj(hidden_states) * self.scale
+        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+        key_states = key_states.view(*proj_shape)
+        value_states = value_states.view(*proj_shape)
+
+        src_len = key_states.size(1)
+        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        # apply the causal_attention_mask first
+        if causal_attention_mask is not None:
+            if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+                    f" {causal_attention_mask.size()}"
+                )
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        if output_attentions:
+            # this operation is a bit akward, but it's required to
+            # make sure that attn_weights keeps its gradient.
+            # In order to do so, attn_weights have to reshaped
+            # twice and have to be reused in the following
+            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+        else:
+            attn_weights_reshaped = None
+
+        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+        attn_output = torch.bmm(attn_probs, value_states)
+
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_weights_reshaped
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->XCLIP
+class XCLIPMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.activation_fn = ACT2FN[config.hidden_act]
+        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.fc1(hidden_states)
+        hidden_states = self.activation_fn(hidden_states)
+        hidden_states = self.fc2(hidden_states)
+        return hidden_states
+
+
+# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->XCLIP
+class XCLIPEncoderLayer(nn.Module):
+    def __init__(self, config: XCLIPConfig):
+        super().__init__()
+        self.embed_dim = config.hidden_size
+        self.self_attn = XCLIPAttention(config)
+        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+        self.mlp = XCLIPMLP(config)
+        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        causal_attention_mask: torch.Tensor,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.FloatTensor]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+                `(config.encoder_attention_heads,)`.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+        """
+        residual = hidden_states
+
+        hidden_states = self.layer_norm1(hidden_states)
+        hidden_states, attn_weights = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            causal_attention_mask=causal_attention_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.layer_norm2(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs
+
+
+# Copied from transformers.models.beit.modeling_beit.drop_path
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+    """
+    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+    argument.
+    """
+    if drop_prob == 0.0 or not training:
+        return input
+    keep_prob = 1 - drop_prob
+    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
+    random_tensor.floor_()  # binarize
+    output = input.div(keep_prob) * random_tensor
+    return output
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->XCLIP
+class XCLIPDropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+    def __init__(self, drop_prob: Optional[float] = None) -> None:
+        super().__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        return drop_path(hidden_states, self.drop_prob, self.training)
+
+    def extra_repr(self) -> str:
+        return "p={}".format(self.drop_prob)
+
+
+class XCLIPVisionEncoderLayer(nn.Module):
+    """
+    This corresponds to the `CrossFramelAttentionBlock` class in the original implementation.
+    """
+
+    def __init__(self, config: XCLIPConfig):
+        super().__init__()
+        self.num_frames = config.num_frames
+        self.embed_dim = config.hidden_size
+
+        self.message_fc = nn.Linear(self.embed_dim, self.embed_dim)
+        self.message_ln = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+        self.message_attn = XCLIPAttention(config)
+
+        self.drop_path = XCLIPDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
+
+        self.self_attn = XCLIPAttention(config)
+        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+        self.mlp = XCLIPMLP(config)
+        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        causal_attention_mask: torch.Tensor,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.FloatTensor]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+                `(config.encoder_attention_heads,)`.
+            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Causal mask for the text model. Mask values selected in `[0, 1]`:
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+                [What are attention masks?](../glossary#attention-mask)
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+        """
+        batch_time, seq_length, hidden_size = hidden_states.size()
+        batch_size = batch_time // self.num_frames
+        msg_token = self.message_fc(hidden_states[:, 0, :])
+        msg_token = msg_token.view(batch_size, self.num_frames, hidden_size)
+
+        msg_token = msg_token + self.drop_path(self.message_attn(self.message_ln(msg_token))[0])
+        # add dummy sequence dimension
+        msg_token = msg_token.view(-1, 1, hidden_size)
+
+        hidden_states = torch.cat([hidden_states, msg_token], dim=1)
+
+        residual = hidden_states
+
+        hidden_states = self.layer_norm1(hidden_states)
+        hidden_states, attn_weights = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            causal_attention_mask=causal_attention_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = residual + hidden_states
+
+        hidden_states = hidden_states[:, :seq_length, :]
+
+        residual = hidden_states
+        hidden_states = self.layer_norm2(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs
+
+
+class XCLIPPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = XCLIPConfig
+    base_model_prefix = "x_clip"
+    supports_gradient_checkpointing = True
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        factor = self.config.initializer_factor
+        if isinstance(module, XCLIPTextEmbeddings):
+            module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
+            module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
+        elif isinstance(module, XCLIPVisionEmbeddings):
+            factor = self.config.initializer_factor
+            nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
+            nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
+            nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
+        elif isinstance(module, XCLIPAttention):
+            factor = self.config.initializer_factor
+            in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
+            out_proj_std = (module.embed_dim**-0.5) * factor
+            nn.init.normal_(module.q_proj.weight, std=in_proj_std)
+            nn.init.normal_(module.k_proj.weight, std=in_proj_std)
+            nn.init.normal_(module.v_proj.weight, std=in_proj_std)
+            nn.init.normal_(module.out_proj.weight, std=out_proj_std)
+        elif isinstance(module, XCLIPMLP):
+            factor = self.config.initializer_factor
+            in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
+            fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
+            nn.init.normal_(module.fc1.weight, std=fc_std)
+            nn.init.normal_(module.fc2.weight, std=in_proj_std)
+        elif isinstance(module, XCLIPModel):
+            factor = self.config.initializer_factor
+            nn.init.normal_(
+                module.text_projection.weight,
+                std=module.text_embed_dim**-0.5 * factor,
+            )
+            nn.init.normal_(
+                module.visual_projection.weight,
+                std=module.vision_embed_dim**-0.5 * factor,
+            )
+            nn.init.normal_(module.prompts_visual_projection, mean=0.0, std=module.vision_embed_dim**-0.5 * factor)
+        elif isinstance(module, XCLIPMultiframeIntegrationTransformer):
+            nn.init.normal_(module.position_embedding, std=self.config.initializer_factor)
+
+        if isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_factor)
+            if module.bias is not None:
+                module.bias.data.zero_()
+
+
+X_CLIP_START_DOCSTRING = r"""
+    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+    behavior.
+
+    Parameters:
+        config ([`XCLIPConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+X_CLIP_TEXT_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.max_position_embeddings - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+X_CLIP_VISION_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
+            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        interpolate_pos_encoding (`bool`, *optional*, defaults `False`):
+            Whether to interpolate the pre-trained position encodings.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+X_CLIP_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.max_position_embeddings - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
+            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
+        return_loss (`bool`, *optional*):
+            Whether or not to return the contrastive loss.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        interpolate_pos_encoding (`bool`, *optional*, defaults `False`):
+            Whether to interpolate the pre-trained position encodings.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->XCLIP
+class XCLIPEncoder(nn.Module):
+    """
+    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+    [`XCLIPEncoderLayer`].
+
+    Args:
+        config: XCLIPConfig
+    """
+
+    def __init__(self, config: XCLIPConfig):
+        super().__init__()
+        self.config = config
+        self.layers = nn.ModuleList([XCLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        inputs_embeds,
+        attention_mask: Optional[torch.Tensor] = None,
+        causal_attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutput]:
+        r"""
+        Args:
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Causal mask for the text model. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        encoder_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
+
+        hidden_states = inputs_embeds
+        for idx, encoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                encoder_states = encoder_states + (hidden_states,)
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    encoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    causal_attention_mask,
+                    output_attentions,
+                )
+            else:
+                layer_outputs = encoder_layer(
+                    hidden_states,
+                    attention_mask,
+                    causal_attention_mask,
+                    output_attentions=output_attentions,
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_attentions = all_attentions + (layer_outputs[1],)
+
+        if output_hidden_states:
+            encoder_states = encoder_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+        return BaseModelOutput(
+            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+        )
+
+
+class XCLIPTextTransformer(nn.Module):
+    def __init__(self, config: XCLIPTextConfig):
+        super().__init__()
+        self.config = config
+        embed_dim = config.hidden_size
+        self.embeddings = XCLIPTextEmbeddings(config)
+        self.encoder = XCLIPEncoder(config)
+        self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+    @add_start_docstrings_to_model_forward(X_CLIP_TEXT_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=XCLIPTextConfig)
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
+        r"""
+        Returns:
+
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if input_ids is None:
+            raise ValueError("You have to specify either input_ids")
+
+        input_shape = input_ids.size()
+        input_ids = input_ids.view(-1, input_shape[-1])
+
+        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
+
+        # X_CLIP's text model uses causal mask, prepare it here.
+        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
+        causal_attention_mask = _create_4d_causal_attention_mask(
+            input_shape, hidden_states.dtype, device=hidden_states.device
+        )
+        # expand attention_mask
+        if attention_mask is not None:
+            # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
+            attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
+
+        encoder_outputs = self.encoder(
+            inputs_embeds=hidden_states,
+            attention_mask=attention_mask,
+            causal_attention_mask=causal_attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        last_hidden_state = encoder_outputs[0]
+        last_hidden_state = self.final_layer_norm(last_hidden_state)
+
+        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
+        # take features from the eot embedding (eot_token is the highest number in each sequence)
+        pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), input_ids.argmax(dim=-1)]
+
+        if not return_dict:
+            return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+        return BaseModelOutputWithPooling(
+            last_hidden_state=last_hidden_state,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+class XCLIPTextModel(XCLIPPreTrainedModel):
+    config_class = XCLIPTextConfig
+
+    def __init__(self, config: XCLIPTextConfig):
+        super().__init__(config)
+        self.text_model = XCLIPTextTransformer(config)
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self) -> nn.Module:
+        return self.text_model.embeddings.token_embedding
+
+    def set_input_embeddings(self, value):
+        self.text_model.embeddings.token_embedding = value
+
+    @add_start_docstrings_to_model_forward(X_CLIP_TEXT_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=XCLIPTextConfig)
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
+        r"""
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoTokenizer, XCLIPTextModel
+
+        >>> model = XCLIPTextModel.from_pretrained("microsoft/xclip-base-patch32")
+        >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/xclip-base-patch32")
+
+        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
+
+        >>> outputs = model(**inputs)
+        >>> last_hidden_state = outputs.last_hidden_state
+        >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
+        ```"""
+        return self.text_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+
+class XCLIPVisionEncoder(nn.Module):
+    """
+    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+    [`XCLIPVisionEncoderLayer`].
+
+    Args:
+        config: XCLIPConfig
+    """
+
+    def __init__(self, config: XCLIPConfig):
+        super().__init__()
+        self.config = config
+        self.layers = nn.ModuleList([XCLIPVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        inputs_embeds,
+        attention_mask: Optional[torch.Tensor] = None,
+        causal_attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutput]:
+        r"""
+        Args:
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Causal mask for the text model. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        encoder_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
+
+        hidden_states = inputs_embeds
+        for idx, encoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                encoder_states = encoder_states + (hidden_states,)
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    encoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    causal_attention_mask,
+                    output_attentions,
+                )
+            else:
+                layer_outputs = encoder_layer(
+                    hidden_states,
+                    attention_mask,
+                    causal_attention_mask,
+                    output_attentions=output_attentions,
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_attentions = all_attentions + (layer_outputs[1],)
+
+        if output_hidden_states:
+            encoder_states = encoder_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+        return BaseModelOutput(
+            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+        )
+
+
+class XCLIPVisionTransformer(nn.Module):
+    """
+    This corresponds to the `CrossFrameCommunicationTransformer` class in the original implementation.
+    """
+
+    def __init__(self, config: XCLIPVisionConfig):
+        super().__init__()
+        self.config = config
+        embed_dim = config.hidden_size
+
+        self.embeddings = XCLIPVisionEmbeddings(config)
+        self.pre_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+        self.encoder = XCLIPVisionEncoder(config)
+        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+    @add_start_docstrings_to_model_forward(X_CLIP_VISION_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=XCLIPVisionConfig)
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
+        r"""
+        Returns:
+
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+        hidden_states = self.pre_layernorm(hidden_states)
+
+        encoder_outputs = self.encoder(
+            inputs_embeds=hidden_states,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        last_hidden_state = encoder_outputs[0]
+        pooled_output = last_hidden_state[:, 0, :]
+        pooled_output = self.post_layernorm(pooled_output)
+
+        if not return_dict:
+            return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+        return BaseModelOutputWithPooling(
+            last_hidden_state=last_hidden_state,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+class XCLIPVisionModel(XCLIPPreTrainedModel):
+    config_class = XCLIPVisionConfig
+    main_input_name = "pixel_values"
+
+    def __init__(self, config: XCLIPVisionConfig):
+        super().__init__(config)
+        self.vision_model = XCLIPVisionTransformer(config)
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self) -> nn.Module:
+        return self.vision_model.embeddings.patch_embedding
+
+    @add_start_docstrings_to_model_forward(X_CLIP_VISION_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=XCLIPVisionConfig)
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
+        r"""
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> import av
+        >>> import torch
+        >>> import numpy as np
+
+        >>> from transformers import AutoProcessor, XCLIPVisionModel
+        >>> from huggingface_hub import hf_hub_download
+
+        >>> np.random.seed(0)
+
+
+        >>> def read_video_pyav(container, indices):
+        ...     '''
+        ...     Decode the video with PyAV decoder.
+        ...     Args:
+        ...         container (`av.container.input.InputContainer`): PyAV container.
+        ...         indices (`List[int]`): List of frame indices to decode.
+        ...     Returns:
+        ...         result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
+        ...     '''
+        ...     frames = []
+        ...     container.seek(0)
+        ...     start_index = indices[0]
+        ...     end_index = indices[-1]
+        ...     for i, frame in enumerate(container.decode(video=0)):
+        ...         if i > end_index:
+        ...             break
+        ...         if i >= start_index and i in indices:
+        ...             frames.append(frame)
+        ...     return np.stack([x.to_ndarray(format="rgb24") for x in frames])
+
+
+        >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
+        ...     '''
+        ...     Sample a given number of frame indices from the video.
+        ...     Args:
+        ...         clip_len (`int`): Total number of frames to sample.
+        ...         frame_sample_rate (`int`): Sample every n-th frame.
+        ...         seg_len (`int`): Maximum allowed index of sample's last frame.
+        ...     Returns:
+        ...         indices (`List[int]`): List of sampled frame indices
+        ...     '''
+        ...     converted_len = int(clip_len * frame_sample_rate)
+        ...     end_idx = np.random.randint(converted_len, seg_len)
+        ...     start_idx = end_idx - converted_len
+        ...     indices = np.linspace(start_idx, end_idx, num=clip_len)
+        ...     indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
+        ...     return indices
+
+
+        >>> # video clip consists of 300 frames (10 seconds at 30 FPS)
+        >>> file_path = hf_hub_download(
+        ...     repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
+        ... )
+        >>> container = av.open(file_path)
+
+        >>> # sample 16 frames
+        >>> indices = sample_frame_indices(clip_len=8, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
+        >>> video = read_video_pyav(container, indices)
+
+        >>> processor = AutoProcessor.from_pretrained("microsoft/xclip-base-patch32")
+        >>> model = XCLIPVisionModel.from_pretrained("microsoft/xclip-base-patch32")
+
+        >>> pixel_values = processor(videos=list(video), return_tensors="pt").pixel_values
+
+        >>> batch_size, num_frames, num_channels, height, width = pixel_values.shape
+        >>> pixel_values = pixel_values.reshape(-1, num_channels, height, width)
+
+        >>> outputs = model(pixel_values)
+        >>> last_hidden_state = outputs.last_hidden_state
+        ```"""
+        return self.vision_model(
+            pixel_values=pixel_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+
+class XCLIPMultiframeIntegrationTransformer(nn.Module):
+    """
+    This corresponds to the `MultiframeIntegrationTransformer` class in the original implementation.
+    """
+
+    def __init__(self, config: XCLIPVisionConfig):
+        super().__init__()
+
+        self.position_embedding = nn.Parameter(torch.empty(1, config.num_frames, config.hidden_size))
+        self.encoder = XCLIPEncoder(config)
+
+    def forward(
+        self,
+        hidden_states,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutput]:
+        residual = hidden_states
+
+        # add position embeddings
+        hidden_states = hidden_states + self.position_embedding
+
+        encoder_outputs = self.encoder(
+            inputs_embeds=hidden_states,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        last_hidden_state = encoder_outputs[0]
+
+        last_hidden_state = last_hidden_state.type(hidden_states.dtype) + residual
+
+        pooled_output = last_hidden_state.mean(dim=1, keepdim=False)
+
+        if not return_dict:
+            return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+        return BaseModelOutputWithPooling(
+            last_hidden_state=last_hidden_state,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+class XCLIPCrossAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, config):
+        super().__init__()
+        self.num_heads = config.prompt_num_attention_heads
+
+        dim = config.projection_dim
+        head_dim = dim // self.num_heads
+        self.scale = head_dim**-0.5
+
+        self.q_proj = nn.Linear(dim, dim, bias=False)
+        self.k_proj = nn.Linear(dim, dim, bias=False)
+        self.v_proj = nn.Linear(dim, dim, bias=False)
+
+        self.attn_drop = nn.Dropout(config.prompt_attention_dropout)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(config.prompt_projection_dropout)
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
+        return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def forward(self, queries, keys, values):
+        """Input shape: Batch x Time x Channel"""
+        batch_size, query_seq_len, hidden_size = queries.shape
+        batch_size, key_seq_len, hidden_size = keys.shape
+        queries = (
+            self.q_proj(queries)
+            .reshape(batch_size, query_seq_len, self.num_heads, hidden_size // self.num_heads)
+            .permute(0, 2, 1, 3)
+        )
+        keys = (
+            self.k_proj(keys)
+            .reshape(batch_size, key_seq_len, self.num_heads, hidden_size // self.num_heads)
+            .permute(0, 2, 1, 3)
+        )
+        values = (
+            self.v_proj(values)
+            .reshape(batch_size, key_seq_len, self.num_heads, hidden_size // self.num_heads)
+            .permute(0, 2, 1, 3)
+        )
+
+        attn = (queries @ keys.transpose(-2, -1)) * self.scale
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ values).transpose(1, 2).reshape(batch_size, query_seq_len, hidden_size)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class PromptGeneratorLayer(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+
+        embed_dim = config.projection_dim
+        self.cross_attn = XCLIPCrossAttention(config)
+        self.norm1 = nn.LayerNorm(embed_dim, eps=config.text_config.layer_norm_eps)
+        self.norm3 = nn.LayerNorm(embed_dim, eps=config.text_config.layer_norm_eps)
+        self.mlp = nn.Sequential(
+            nn.Linear(embed_dim, embed_dim * 4),
+            ACT2FN[config.prompt_hidden_act],
+            nn.Dropout(config.prompt_attention_dropout),
+            nn.Linear(embed_dim * 4, embed_dim),
+        )
+
+    def forward(self, x, visual):
+        x = x + self.cross_attn(self.norm1(x), visual, visual)
+        x = x + self.mlp(self.norm3(x))
+        return x
+
+
+class XCLIPPromptGenerator(nn.Module):
+    """This corresponds to the `VideoSpecificPrompt` class in the original implementation."""
+
+    def __init__(self, config):
+        super().__init__()
+        embed_dim = config.projection_dim
+        self.layernorm = nn.LayerNorm(embed_dim, eps=config.vision_config.layer_norm_eps)
+        self.decoder = nn.ModuleList([PromptGeneratorLayer(config) for _ in range(config.prompt_layers)])
+        self.alpha = nn.Parameter(torch.ones(embed_dim) * config.prompt_alpha)
+
+    def forward(self, text, visual):
+        visual = self.layernorm(visual)
+        for layer in self.decoder:
+            text = layer(text, visual)
+
+        return self.alpha * text
+
+
+@add_start_docstrings(X_CLIP_START_DOCSTRING)
+class XCLIPModel(XCLIPPreTrainedModel):
+    config_class = XCLIPConfig
+
+    def __init__(self, config: XCLIPConfig):
+        super().__init__(config)
+
+        if not isinstance(config.text_config, XCLIPTextConfig):
+            raise TypeError(
+                "config.text_config is expected to be of type XCLIPTextConfig but is of type"
+                f" {type(config.text_config)}."
+            )
+
+        if not isinstance(config.vision_config, XCLIPVisionConfig):
+            raise TypeError(
+                "config.vision_config is expected to be of type XCLIPVisionConfig but is of type"
+                f" {type(config.vision_config)}."
+            )
+
+        text_config = config.text_config
+        vision_config = config.vision_config
+
+        self.projection_dim = config.projection_dim
+        self.text_embed_dim = text_config.hidden_size
+        self.vision_embed_dim = vision_config.hidden_size
+
+        self.text_model = XCLIPTextTransformer(text_config)
+        self.vision_model = XCLIPVisionTransformer(vision_config)
+
+        self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
+        self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
+        self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
+
+        self.prompts_visual_layernorm = nn.LayerNorm(self.vision_embed_dim, eps=config.vision_config.layer_norm_eps)
+        self.prompts_visual_projection = nn.Parameter(torch.randn(self.vision_embed_dim, self.projection_dim))
+
+        mit_config = copy(vision_config)
+        mit_config.hidden_size = vision_config.mit_hidden_size
+        mit_config.intermediate_size = vision_config.mit_intermediate_size
+        mit_config.num_hidden_layers = vision_config.mit_num_hidden_layers
+        mit_config.num_attention_heads = vision_config.mit_num_attention_heads
+        self.mit = XCLIPMultiframeIntegrationTransformer(mit_config)
+
+        self.prompts_generator = XCLIPPromptGenerator(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(X_CLIP_TEXT_INPUTS_DOCSTRING)
+    def get_text_features(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> torch.FloatTensor:
+        r"""
+        Returns:
+            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
+            applying the projection layer to the pooled output of [`XCLIPTextModel`].
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoTokenizer, AutoModel
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/xclip-base-patch32")
+        >>> model = AutoModel.from_pretrained("microsoft/xclip-base-patch32")
+
+        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
+        >>> text_features = model.get_text_features(**inputs)
+        ```"""
+        # Use X_CLIP model's config for some fields (if specified) instead of those of vision & text components.
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        text_outputs = self.text_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        text_embeds = text_outputs[1]
+        text_embeds = self.text_projection(text_embeds)
+
+        return text_embeds
+
+    @add_start_docstrings_to_model_forward(X_CLIP_VISION_INPUTS_DOCSTRING)
+    def get_video_features(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> torch.FloatTensor:
+        r"""
+        Returns:
+            video_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The video embeddings obtained by
+            applying the projection layer to the pooled output of [`XCLIPVisionModel`] and
+            [`XCLIPMultiframeIntegrationTransformer`].
+
+        Examples:
+
+        ```python
+        >>> import av
+        >>> import torch
+        >>> import numpy as np
+
+        >>> from transformers import AutoProcessor, AutoModel
+        >>> from huggingface_hub import hf_hub_download
+
+        >>> np.random.seed(0)
+
+
+        >>> def read_video_pyav(container, indices):
+        ...     '''
+        ...     Decode the video with PyAV decoder.
+        ...     Args:
+        ...         container (`av.container.input.InputContainer`): PyAV container.
+        ...         indices (`List[int]`): List of frame indices to decode.
+        ...     Returns:
+        ...         result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
+        ...     '''
+        ...     frames = []
+        ...     container.seek(0)
+        ...     start_index = indices[0]
+        ...     end_index = indices[-1]
+        ...     for i, frame in enumerate(container.decode(video=0)):
+        ...         if i > end_index:
+        ...             break
+        ...         if i >= start_index and i in indices:
+        ...             frames.append(frame)
+        ...     return np.stack([x.to_ndarray(format="rgb24") for x in frames])
+
+
+        >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
+        ...     '''
+        ...     Sample a given number of frame indices from the video.
+        ...     Args:
+        ...         clip_len (`int`): Total number of frames to sample.
+        ...         frame_sample_rate (`int`): Sample every n-th frame.
+        ...         seg_len (`int`): Maximum allowed index of sample's last frame.
+        ...     Returns:
+        ...         indices (`List[int]`): List of sampled frame indices
+        ...     '''
+        ...     converted_len = int(clip_len * frame_sample_rate)
+        ...     end_idx = np.random.randint(converted_len, seg_len)
+        ...     start_idx = end_idx - converted_len
+        ...     indices = np.linspace(start_idx, end_idx, num=clip_len)
+        ...     indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
+        ...     return indices
+
+
+        >>> # video clip consists of 300 frames (10 seconds at 30 FPS)
+        >>> file_path = hf_hub_download(
+        ...     repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
+        ... )
+        >>> container = av.open(file_path)
+
+        >>> # sample 8 frames
+        >>> indices = sample_frame_indices(clip_len=8, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
+        >>> video = read_video_pyav(container, indices)
+
+        >>> processor = AutoProcessor.from_pretrained("microsoft/xclip-base-patch32")
+        >>> model = AutoModel.from_pretrained("microsoft/xclip-base-patch32")
+
+        >>> inputs = processor(videos=list(video), return_tensors="pt")
+
+        >>> video_features = model.get_video_features(**inputs)
+        ```"""
+        # Use X_CLIP model's config for some fields (if specified) instead of those of vision & text components.
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        batch_size, num_frames, num_channels, height, width = pixel_values.shape
+        pixel_values = pixel_values.reshape(-1, num_channels, height, width)
+
+        vision_outputs = self.vision_model(
+            pixel_values=pixel_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        video_embeds = vision_outputs[1]
+        video_embeds = self.visual_projection(video_embeds)
+
+        cls_features = video_embeds.view(batch_size, num_frames, -1)
+
+        mit_outputs = self.mit(
+            cls_features,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        video_embeds = mit_outputs[1]
+
+        return video_embeds
+
+    @add_start_docstrings_to_model_forward(X_CLIP_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=XCLIPOutput, config_class=XCLIPConfig)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        return_loss: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, XCLIPOutput]:
+        r"""
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> import av
+        >>> import torch
+        >>> import numpy as np
+
+        >>> from transformers import AutoProcessor, AutoModel
+        >>> from huggingface_hub import hf_hub_download
+
+        >>> np.random.seed(0)
+
+
+        >>> def read_video_pyav(container, indices):
+        ...     '''
+        ...     Decode the video with PyAV decoder.
+        ...     Args:
+        ...         container (`av.container.input.InputContainer`): PyAV container.
+        ...         indices (`List[int]`): List of frame indices to decode.
+        ...     Returns:
+        ...         result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
+        ...     '''
+        ...     frames = []
+        ...     container.seek(0)
+        ...     start_index = indices[0]
+        ...     end_index = indices[-1]
+        ...     for i, frame in enumerate(container.decode(video=0)):
+        ...         if i > end_index:
+        ...             break
+        ...         if i >= start_index and i in indices:
+        ...             frames.append(frame)
+        ...     return np.stack([x.to_ndarray(format="rgb24") for x in frames])
+
+
+        >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
+        ...     '''
+        ...     Sample a given number of frame indices from the video.
+        ...     Args:
+        ...         clip_len (`int`): Total number of frames to sample.
+        ...         frame_sample_rate (`int`): Sample every n-th frame.
+        ...         seg_len (`int`): Maximum allowed index of sample's last frame.
+        ...     Returns:
+        ...         indices (`List[int]`): List of sampled frame indices
+        ...     '''
+        ...     converted_len = int(clip_len * frame_sample_rate)
+        ...     end_idx = np.random.randint(converted_len, seg_len)
+        ...     start_idx = end_idx - converted_len
+        ...     indices = np.linspace(start_idx, end_idx, num=clip_len)
+        ...     indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
+        ...     return indices
+
+
+        >>> # video clip consists of 300 frames (10 seconds at 30 FPS)
+        >>> file_path = hf_hub_download(
+        ...     repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
+        ... )
+        >>> container = av.open(file_path)
+
+        >>> # sample 8 frames
+        >>> indices = sample_frame_indices(clip_len=8, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
+        >>> video = read_video_pyav(container, indices)
+
+        >>> processor = AutoProcessor.from_pretrained("microsoft/xclip-base-patch32")
+        >>> model = AutoModel.from_pretrained("microsoft/xclip-base-patch32")
+
+        >>> inputs = processor(
+        ...     text=["playing sports", "eating spaghetti", "go shopping"],
+        ...     videos=list(video),
+        ...     return_tensors="pt",
+        ...     padding=True,
+        ... )
+
+        >>> # forward pass
+        >>> with torch.no_grad():
+        ...     outputs = model(**inputs)
+
+        >>> logits_per_video = outputs.logits_per_video  # this is the video-text similarity score
+        >>> probs = logits_per_video.softmax(dim=1)  # we can take the softmax to get the label probabilities
+        >>> print(probs)
+        tensor([[1.9496e-04, 9.9960e-01, 2.0825e-04]])
+        ```"""
+        # Use X_CLIP model's config for some fields (if specified) instead of those of vision & text components.
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        batch_size, num_frames, num_channels, height, width = pixel_values.shape
+        pixel_values = pixel_values.reshape(-1, num_channels, height, width)
+
+        vision_outputs = self.vision_model(
+            pixel_values=pixel_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            return_dict=return_dict,
+        )
+
+        video_embeds = vision_outputs[1]
+        video_embeds = self.visual_projection(video_embeds)
+
+        cls_features = video_embeds.view(batch_size, num_frames, -1)
+
+        mit_outputs = self.mit(
+            cls_features,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        video_embeds = mit_outputs[1]
+
+        img_features = vision_outputs[0][:, 1:, :]
+        img_features = self.prompts_visual_layernorm(img_features)
+        img_features = img_features @ self.prompts_visual_projection
+        img_features = img_features.view(batch_size, num_frames, -1, video_embeds.shape[-1])
+        img_features = img_features.mean(dim=1, keepdim=False)
+
+        text_outputs = self.text_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        text_embeds = text_outputs[1]
+        text_embeds = self.text_projection(text_embeds)
+
+        text_embeds = text_embeds.unsqueeze(0).expand(batch_size, -1, -1)
+        text_embeds = text_embeds + self.prompts_generator(text_embeds, img_features)
+
+        # normalized features
+        video_embeds = video_embeds / video_embeds.norm(p=2, dim=-1, keepdim=True)
+        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
+
+        # cosine similarity as logits
+        logit_scale = self.logit_scale.exp()
+        logits_per_video = torch.einsum("bd,bkd->bk", video_embeds, logit_scale * text_embeds)
+        logits_per_text = logits_per_video.T
+
+        loss = None
+        if return_loss:
+            loss = x_clip_loss(logits_per_text)
+
+        if not return_dict:
+            output = (logits_per_video, logits_per_text, text_embeds, video_embeds, text_outputs, vision_outputs)
+            return ((loss,) + output) if loss is not None else output
+
+        return XCLIPOutput(
+            loss=loss,
+            logits_per_video=logits_per_video,
+            logits_per_text=logits_per_text,
+            text_embeds=text_embeds,
+            video_embeds=video_embeds,
+            text_model_output=text_outputs,
+            vision_model_output=vision_outputs,
+            mit_output=mit_outputs,
+        )
+
+
+__all__ = ["XCLIPModel", "XCLIPPreTrainedModel", "XCLIPTextModel", "XCLIPVisionModel"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/x_clip/processing_x_clip.py b/.venv/lib/python3.11/site-packages/transformers/models/x_clip/processing_x_clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a17d3a15a208461e4b1f9f12ccfc084df73fa0d
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/x_clip/processing_x_clip.py
@@ -0,0 +1,151 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Image/Text processor class for XCLIP
+"""
+
+import warnings
+
+from ...processing_utils import ProcessorMixin
+from ...tokenization_utils_base import BatchEncoding
+
+
+class XCLIPProcessor(ProcessorMixin):
+    r"""
+    Constructs an X-CLIP processor which wraps a VideoMAE image processor and a CLIP tokenizer into a single processor.
+
+    [`XCLIPProcessor`] offers all the functionalities of [`VideoMAEImageProcessor`] and [`CLIPTokenizerFast`]. See the
+    [`~XCLIPProcessor.__call__`] and [`~XCLIPProcessor.decode`] for more information.
+
+    Args:
+        image_processor ([`VideoMAEImageProcessor`], *optional*):
+            The image processor is a required input.
+        tokenizer ([`CLIPTokenizerFast`], *optional*):
+            The tokenizer is a required input.
+    """
+
+    attributes = ["image_processor", "tokenizer"]
+    image_processor_class = "VideoMAEImageProcessor"
+    tokenizer_class = ("CLIPTokenizer", "CLIPTokenizerFast")
+
+    def __init__(self, image_processor=None, tokenizer=None, **kwargs):
+        feature_extractor = None
+        if "feature_extractor" in kwargs:
+            warnings.warn(
+                "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
+                " instead.",
+                FutureWarning,
+            )
+            feature_extractor = kwargs.pop("feature_extractor")
+
+        image_processor = image_processor if image_processor is not None else feature_extractor
+        if image_processor is None:
+            raise ValueError("You need to specify an `image_processor`.")
+        if tokenizer is None:
+            raise ValueError("You need to specify a `tokenizer`.")
+
+        super().__init__(image_processor, tokenizer)
+        self.current_processor = self.image_processor
+
+    def __call__(self, text=None, videos=None, return_tensors=None, **kwargs):
+        """
+        Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
+        and `kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode
+        the text. To prepare the image(s), this method forwards the `videos` and `kwargs` arguments to
+        VideoMAEImageProcessor's [`~VideoMAEImageProcessor.__call__`] if `videos` is not `None`. Please refer to the
+        doctsring of the above two methods for more information.
+
+        Args:
+            text (`str`, `List[str]`, `List[List[str]]`):
+                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+            videos (`List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, `List[List[PIL.Image.Image]]`, `List[List[np.ndarrray]]`,:
+                `List[List[torch.Tensor]]`): The video or batch of videos to be prepared. Each video should be a list
+                of frames, which can be either PIL images or NumPy arrays. In case of NumPy arrays/PyTorch tensors,
+                each frame should be of shape (H, W, C), where H and W are frame height and width, and C is a number of
+                channels.
+
+            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                If set, will return tensors of a particular framework. Acceptable values are:
+
+                - `'tf'`: Return TensorFlow `tf.constant` objects.
+                - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                - `'np'`: Return NumPy `np.ndarray` objects.
+                - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+        Returns:
+            [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
+
+            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+              `None`).
+            - **pixel_values** -- Pixel values to be fed to a model. Returned when `videos` is not `None`.
+        """
+
+        if text is None and videos is None:
+            raise ValueError("You have to specify either text or videos. Both cannot be none.")
+
+        if text is not None:
+            encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs)
+
+        if videos is not None:
+            image_features = self.image_processor(videos, return_tensors=return_tensors, **kwargs)
+
+        if text is not None and videos is not None:
+            encoding["pixel_values"] = image_features.pixel_values
+            return encoding
+        elif text is not None:
+            return encoding
+        else:
+            return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
+
+    def batch_decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+        refer to the docstring of this method for more information.
+        """
+        return self.tokenizer.batch_decode(*args, **kwargs)
+
+    def decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+        the docstring of this method for more information.
+        """
+        return self.tokenizer.decode(*args, **kwargs)
+
+    @property
+    def model_input_names(self):
+        return ["input_ids", "attention_mask", "position_ids", "pixel_values"]
+
+    @property
+    def feature_extractor_class(self):
+        warnings.warn(
+            "`feature_extractor_class` is deprecated and will be removed in v5. Use `image_processor_class` instead.",
+            FutureWarning,
+        )
+        return self.image_processor_class
+
+    @property
+    def feature_extractor(self):
+        warnings.warn(
+            "`feature_extractor` is deprecated and will be removed in v5. Use `image_processor` instead.",
+            FutureWarning,
+        )
+        return self.image_processor
+
+
+__all__ = ["XCLIPProcessor"]