"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ bos_token (`str`, *optional*, defaults to `""`):
+ The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
+ eos_token (`str`, *optional*, defaults to `""`):
+ The end of sequence token.
+
+
+
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+ The token used is the `sep_token`.
+
+
+
+ prefix_token (`str`, *optional*, defaults to `"▁"`):
+ Prefix token used for infilling.
+ middle_token (`str`, *optional*, defaults to `"▁"`):
+ Middle token used for infilling.
+ suffix_token (`str`, *optional*, defaults to `"▁"`):
+ Suffix token used for infilling.
+ eot_token (`str`, *optional*, defaults to `"▁"`):
+ End of text token used for infilling.
+ fill_token (`str`, *optional*, defaults to `""`):
+ The token used to split the input between the prefix and suffix.
+ suffix_first (`bool`, *optional*, defaults to `False`):
+ Whether the input prompt and suffix should be formatted with the suffix first.
+ sp_model_kwargs (`dict`, *optional*):
+ Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
+ SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
+ to set:
+
+ - `enable_sampling`: Enable subword regularization.
+ - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+
+ - `nbest_size = {0,1}`: No sampling is performed.
+ - `nbest_size > 1`: samples from the nbest_size results.
+ - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
+ using forward-filtering-and-backward-sampling algorithm.
+
+ - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+ BPE-dropout.
+ add_bos_token (`bool`, *optional*, defaults to `True`):
+ Whether to add a beginning of sequence token at the start of sequences.
+ add_eos_token (`bool`, *optional*, defaults to `False`):
+ Whether to add an end of sequence token at the end of sequences.
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
+ Whether or not to clean up the tokenization spaces.
+ additional_special_tokens (`List[str]`, *optional*):
+ Additional special tokens used by the tokenizer.
+ use_default_system_prompt (`bool`, *optional*, defaults to `False`):
+ Whether or not the default system prompt for Llama should be used.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ vocab_file,
+ unk_token="",
+ bos_token="",
+ eos_token="",
+ prefix_token="▁",
+ middle_token="▁",
+ suffix_token="▁",
+ eot_token="▁",
+ fill_token="",
+ suffix_first=False,
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
+ add_bos_token=True,
+ add_eos_token=False,
+ clean_up_tokenization_spaces=False,
+ additional_special_tokens=None,
+ use_default_system_prompt=False,
+ **kwargs,
+ ):
+ requires_backends(self, "protobuf")
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+ bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
+ eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
+ unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
+
+ self.use_default_system_prompt = use_default_system_prompt
+ # mark tokens special to skip them
+ additional_special_tokens = additional_special_tokens or []
+ for token in [prefix_token, middle_token, suffix_token, eot_token]:
+ additional_special_tokens += [token] if token is not None else []
+
+ self.vocab_file = vocab_file
+ self.add_bos_token = add_bos_token
+ self.add_eos_token = add_eos_token
+ self._prefix_token = prefix_token
+ self._middle_token = middle_token
+ self._suffix_token = suffix_token
+ self._eot_token = eot_token
+ self.fill_token = fill_token
+ self.suffix_first = suffix_first
+ self.sp_model = self.get_spm_processor()
+
+ super().__init__(
+ bos_token=bos_token,
+ eos_token=eos_token,
+ unk_token=unk_token,
+ add_bos_token=add_bos_token,
+ add_eos_token=add_eos_token,
+ prefix_token=prefix_token,
+ middle_token=middle_token,
+ suffix_token=suffix_token,
+ eot_token=eot_token,
+ fill_token=fill_token,
+ sp_model_kwargs=self.sp_model_kwargs,
+ suffix_first=suffix_first,
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ additional_special_tokens=additional_special_tokens,
+ use_default_system_prompt=use_default_system_prompt,
+ **kwargs,
+ )
+
+ @property
+ def unk_token_length(self):
+ return len(self.sp_model.encode(str(self.unk_token)))
+
+ def get_spm_processor(self):
+ tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+ with open(self.vocab_file, "rb") as f:
+ sp_model = f.read()
+ model_pb2 = import_protobuf()
+ model = model_pb2.ModelProto.FromString(sp_model)
+ normalizer_spec = model_pb2.NormalizerSpec()
+ normalizer_spec.add_dummy_prefix = False
+ model.normalizer_spec.MergeFrom(normalizer_spec)
+ sp_model = model.SerializeToString()
+ tokenizer.LoadFromSerializedProto(sp_model)
+ return tokenizer
+
+ @property
+ def prefix_token(self):
+ return self._prefix_token
+
+ @property
+ def prefix_id(self):
+ if self._prefix_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.prefix_token)
+
+ @property
+ def middle_token(self):
+ return self._middle_token
+
+ @property
+ def middle_id(self):
+ if self._middle_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.middle_token)
+
+ @property
+ def suffix_token(self):
+ return self._suffix_token
+
+ @property
+ def suffix_id(self):
+ if self._suffix_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.suffix_token)
+
+ @property
+ def eot_token(self):
+ return self._eot_token
+
+ @property
+ def eot_id(self):
+ if self._eot_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.eot_token)
+
+ @property
+ def vocab_size(self):
+ """Returns vocab size"""
+ return self.sp_model.get_piece_size()
+
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_vocab
+ def get_vocab(self):
+ """Returns vocab as a dict"""
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
+ def tokenize(self, prefix, suffix=None, suffix_first=False, **kwargs) -> List[int]:
+ # add a prefix space to `prefix`
+ if self.fill_token is not None and self.fill_token in prefix and suffix is None:
+ prefix, suffix = prefix.split(self.fill_token)
+
+ if len(prefix) > 0:
+ prefix = SPIECE_UNDERLINE + prefix.replace(SPIECE_UNDERLINE, " ")
+
+ if suffix is None or len(suffix) < 1:
+ tokens = super().tokenize(prefix, **kwargs)
+ if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
+ tokens = tokens[1:]
+ return tokens
+
+ prefix_tokens = self._tokenize(prefix) # prefix has an extra `SPIECE_UNDERLINE`
+
+ if None in (self.prefix_id, self.middle_id, self.suffix_id):
+ raise ValueError(
+ "The input either includes a `prefix` and a `suffix` used for the infilling task,"
+ f" or can be split on the {self.fill_token} token, creating a suffix and prefix,"
+ " but the model does not support `infilling`."
+ )
+ suffix_tokens = self._tokenize(suffix) # make sure CodeLlama sp model does not mess up
+
+ suffix_first = suffix_first if suffix_first is not None else self.suffix_first
+ if suffix_first:
+ # format as " {suf} {pre}"
+ return [self.prefix_token, self.suffix_token] + suffix_tokens + [self.middle_token] + prefix_tokens
+ else:
+ # format as " {pre} {suf} "
+ return [self.prefix_token] + prefix_tokens + [self.suffix_token] + suffix_tokens + [self.middle_token]
+
+ def _tokenize(self, text, **kwargs):
+ """
+ Returns a tokenized string.
+
+ We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
+ SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give
+ `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the
+ `unk_token`. Here is an example with `unk_token = ""` and `unk_token_length = 4`.
+ `self.tokenizer.sp_model.encode(" Hey", out_type = str)[4:]`.
+ """
+ tokens = self.sp_model.encode(text, out_type=str)
+ if not text.startswith((SPIECE_UNDERLINE, " ")):
+ return tokens
+ # 1. Encode string + prefix ex: " Hey"
+ tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
+ # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
+ return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
+
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer._convert_token_to_id
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ return self.sp_model.piece_to_id(token)
+
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer._convert_id_to_token
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ token = self.sp_model.IdToPiece(index)
+ return token
+
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+ # since we manually add the prefix space, we have to remove it when decoding
+ if tokens[0].startswith(SPIECE_UNDERLINE):
+ tokens[0] = tokens[0][1:]
+
+ current_sub_tokens = []
+ out_string = ""
+ for _, token in enumerate(tokens):
+ # make sure that special tokens are not decoded using sentencepiece model
+ if token in self.all_special_tokens:
+ out_string += self.sp_model.decode(current_sub_tokens) + token
+ current_sub_tokens = []
+ else:
+ current_sub_tokens.append(token)
+ out_string += self.sp_model.decode(current_sub_tokens)
+ return out_string
+
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.save_vocabulary
+ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
+ """
+ Save the vocabulary and special tokens file to a directory.
+
+ Args:
+ save_directory (`str`):
+ The directory in which to save the vocabulary.
+
+ Returns:
+ `Tuple(str)`: Paths to the files saved.
+ """
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ out_vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+ elif not os.path.isfile(self.vocab_file):
+ with open(out_vocab_file, "wb") as fi:
+ content_spiece_model = self.sp_model.serialized_model_proto()
+ fi.write(content_spiece_model)
+
+ return (out_vocab_file,)
+
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
+
+ output = bos_token_id + token_ids_0 + eos_token_id
+
+ if token_ids_1 is not None:
+ output = output + bos_token_id + token_ids_1 + eos_token_id
+
+ return output
+
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_special_tokens_mask
+ def get_special_tokens_mask(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+ ) -> List[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` method.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+
+ bos_token_id = [1] if self.add_bos_token else []
+ eos_token_id = [1] if self.add_eos_token else []
+
+ if token_ids_1 is None:
+ return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
+ return (
+ bos_token_id
+ + ([0] * len(token_ids_0))
+ + eos_token_id
+ + bos_token_id
+ + ([0] * len(token_ids_1))
+ + eos_token_id
+ )
+
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.create_token_type_ids_from_sequences
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
+ sequence pair mask has the following format:
+
+ ```
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
+ | first sequence | second sequence |
+ ```
+
+ if token_ids_1 is None, only returns the first portion of the mask (0s).
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of ids.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
+ """
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
+
+ output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
+
+ if token_ids_1 is not None:
+ output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
+
+ return output
+
+ def __getstate__(self):
+ state = self.__dict__.copy()
+ state["sp_model"] = None
+ state["sp_model_proto"] = self.sp_model.serialized_model_proto()
+ return state
+
+ def __setstate__(self, d):
+ self.__dict__ = d
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+ self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
+
+
+__all__ = ["CodeLlamaTokenizer"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/code_llama/tokenization_code_llama_fast.py b/.venv/lib/python3.11/site-packages/transformers/models/code_llama/tokenization_code_llama_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bc831cdd6a15b0af85982173e903b7462e9a2a2
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/code_llama/tokenization_code_llama_fast.py
@@ -0,0 +1,381 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from shutil import copyfile
+from typing import List, Optional, Tuple
+
+from tokenizers import normalizers, processors
+
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import is_sentencepiece_available, logging
+from ...utils.versions import require_version
+
+
+require_version("tokenizers>=0.13.3")
+
+if is_sentencepiece_available():
+ from .tokenization_code_llama import CodeLlamaTokenizer
+else:
+ CodeLlamaTokenizer = None
+
+logger = logging.get_logger(__name__)
+VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model", "tokenizer_file": "tokenizer.json"}
+
+SPIECE_UNDERLINE = "▁"
+
+
+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<>\n", "\n<>\n\n"
+
+# fmt: off
+DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
+answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
+ that your responses are socially unbiased and positive in nature.
+
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
+correct. If you don't know the answer to a question, please don't share false information."""
+# fmt: on
+
+
+class CodeLlamaTokenizerFast(PreTrainedTokenizerFast):
+ """
+ Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding.
+
+ This uses notably ByteFallback and no normalization.
+
+ ```python
+ >>> from transformers import CodeLlamaTokenizerFast
+
+ >>> tokenizer = CodeLlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
+ >>> tokenizer.encode("Hello this is a test")
+ [1, 15043, 445, 338, 263, 1243]
+ ```
+
+ If you want to change the `bos_token` or the `eos_token`, make sure to specify them when initializing the model, or
+ call `tokenizer.update_post_processor()` to make sure that the post-processing is correctly done (otherwise the
+ values of the first token and final token of an encoded sequence will not be correct). For more details, checkout
+ [post-processors] (https://huggingface.co/docs/tokenizers/api/post-processors) documentation.
+
+
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods. The default configuration match that of
+ [meta-llama/CodeLlama-7b-Instruct-hf](https://huggingface.co/meta-llama/CodeLlama-7b-Instruct-hf/blob/main/tokenizer_config.json)
+ which supports prompt infilling.
+
+ Args:
+ vocab_file (`str`, *optional*):
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
+ contains the vocabulary necessary to instantiate a tokenizer.
+ tokenizer_file (`str`, *optional*):
+ [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
+ contains everything needed to load the tokenizer.
+ clean_up_tokenization_spaces (`str`, *optional*, defaults to `False`):
+ Wether to cleanup spaces after decoding, cleanup consists in removing potential artifacts like extra
+ spaces.
+ unk_token (`str`, *optional*, defaults to `""`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ bos_token (`str`, *optional*, defaults to `""`):
+ The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
+ eos_token (`str`, *optional*, defaults to `""`):
+ The end of sequence token.
+ prefix_token (`str`, *optional*, defaults to `"▁"`):
+ Prefix token used for infilling.
+ middle_token (`str`, *optional*, defaults to `"▁"`):
+ Middle token used for infilling.
+ suffix_token (`str`, *optional*, defaults to `"▁"`):
+ Suffix token used for infilling.
+ eot_token (`str`, *optional*, defaults to `"▁"`):
+ End of text token used for infilling.
+ fill_token (`str`, *optional*, defaults to `""`):
+ The token used to split the input between the prefix and suffix.
+ additional_special_tokens (`List[str]`, *optional*):
+ Additional special tokens used by the tokenizer.
+ add_bos_token (`bool`, *optional*, defaults to `True`):
+ Whether to add a beginning of sequence token at the start of sequences.
+ add_eos_token (`bool`, *optional*, defaults to `False`):
+ Whether to add an end of sequence token at the end of sequences.
+ use_default_system_prompt (`bool`, *optional*, defaults to `False`):
+ Whether or not the default system prompt for Llama should be used.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ slow_tokenizer_class = CodeLlamaTokenizer
+ padding_side = "left"
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ vocab_file=None,
+ tokenizer_file=None,
+ clean_up_tokenization_spaces=False,
+ unk_token="",
+ bos_token="",
+ eos_token="",
+ prefix_token="▁",
+ middle_token="▁",
+ suffix_token="▁",
+ eot_token="▁",
+ fill_token="",
+ additional_special_tokens=None,
+ add_bos_token=True,
+ add_eos_token=False,
+ use_default_system_prompt=False,
+ **kwargs,
+ ):
+ # mark tokens special to skip them
+ additional_special_tokens = additional_special_tokens or []
+ for token in [prefix_token, middle_token, suffix_token, eot_token]:
+ additional_special_tokens += [token] if token is not None else []
+ self.use_default_system_prompt = use_default_system_prompt
+
+ super().__init__(
+ vocab_file=vocab_file,
+ tokenizer_file=tokenizer_file,
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ additional_special_tokens=additional_special_tokens,
+ unk_token=unk_token,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ add_bos_token=add_bos_token,
+ add_eos_token=add_eos_token,
+ prefix_token=prefix_token,
+ middle_token=middle_token,
+ suffix_token=suffix_token,
+ eot_token=eot_token,
+ fill_token=fill_token,
+ use_default_system_prompt=use_default_system_prompt,
+ **kwargs,
+ )
+ self._add_bos_token = add_bos_token
+ self._add_eos_token = add_eos_token
+ self.update_post_processor()
+
+ self.vocab_file = vocab_file
+
+ self._prefix_token = prefix_token
+ self._middle_token = middle_token
+ self._suffix_token = suffix_token
+ self._eot_token = eot_token
+ self.fill_token = fill_token
+
+ @property
+ def can_save_slow_tokenizer(self) -> bool:
+ return os.path.isfile(self.vocab_file) if self.vocab_file else False
+
+ # Copied from transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast.update_post_processor
+ def update_post_processor(self):
+ """
+ Updates the underlying post processor with the current `bos_token` and `eos_token`.
+ """
+ bos = self.bos_token
+ bos_token_id = self.bos_token_id
+ if bos is None and self.add_bos_token:
+ raise ValueError("add_bos_token = True but bos_token = None")
+
+ eos = self.eos_token
+ eos_token_id = self.eos_token_id
+ if eos is None and self.add_eos_token:
+ raise ValueError("add_eos_token = True but eos_token = None")
+
+ single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
+ pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
+
+ special_tokens = []
+ if self.add_bos_token:
+ special_tokens.append((bos, bos_token_id))
+ if self.add_eos_token:
+ special_tokens.append((eos, eos_token_id))
+ self._tokenizer.post_processor = processors.TemplateProcessing(
+ single=single, pair=pair, special_tokens=special_tokens
+ )
+
+ @property
+ def prefix_token(self):
+ return self._prefix_token
+
+ @property
+ def prefix_id(self):
+ if self._prefix_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.prefix_token)
+
+ @property
+ def middle_token(self):
+ return self._middle_token
+
+ @property
+ def middle_id(self):
+ if self._middle_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.middle_token)
+
+ @property
+ def suffix_token(self):
+ return self._suffix_token
+
+ @property
+ def suffix_id(self):
+ if self._suffix_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.suffix_token)
+
+ @property
+ def eot_id(self):
+ if self._eot_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.eot_token)
+
+ @property
+ def eot_token(self):
+ return self._eot_token
+
+ @property
+ def add_eos_token(self):
+ return self._add_eos_token
+
+ @property
+ def add_bos_token(self):
+ return self._add_bos_token
+
+ @add_eos_token.setter
+ def add_eos_token(self, value):
+ self._add_eos_token = value
+ self.update_post_processor()
+
+ @add_bos_token.setter
+ def add_bos_token(self, value):
+ self._add_bos_token = value
+ self.update_post_processor()
+
+ def set_infilling_processor(self, reset, suffix_first=False, add_special_tokens=True):
+ """
+ Updates the normalizer to make sure the prompt format for `infilling` is respected. The infilling format is the
+ following: if suffix_first
+ " {suf} {pre}"
+ else:
+ " {pre} {suf} "
+
+ If `reset` is set to `True`, the `normalizer` and `post_processor` are reset to their "normal" behaviour, which
+ is to add a prefix space for the normalizer, and add a `bos_token` to the input text for the `post_processor`.
+ """
+ if reset:
+ self._tokenizer.normalizer = normalizers.Sequence(
+ [
+ normalizers.Prepend(prepend="▁"),
+ normalizers.Replace(pattern=" ", content="▁"),
+ ]
+ )
+ self.update_post_processor()
+ return
+
+ self._tokenizer.normalizer = normalizers.Replace(pattern=" ", content="▁")
+ pair = [self.bos_token] if self.add_bos_token and add_special_tokens else []
+ special_tokens = [(self.bos_token, self.bos_token_id)] if self.add_bos_token and add_special_tokens else []
+ if suffix_first:
+ # format as " {suf} {pre}"
+ pair += [self.prefix_token, self.suffix_token, "$B", self.middle_token, "$A"]
+ special_tokens += [
+ (self.prefix_token, self.prefix_id),
+ (self.suffix_token, self.suffix_id),
+ (self.middle_token, self.middle_id),
+ ]
+ else:
+ # format as " {pre} {suf} "
+ pair += [self.prefix_token, "$A", self.suffix_token, "$B", self.middle_token]
+ special_tokens += [
+ (self.prefix_token, self.prefix_id),
+ (self.suffix_token, self.suffix_id),
+ (self.middle_token, self.middle_id),
+ ]
+
+ if self.add_eos_token and add_special_tokens:
+ pair += [self.eos_token]
+ special_tokens += [(self.eos_token, self.eos_token_id)]
+ self._tokenizer.post_processor = processors.TemplateProcessing(
+ single="$A", pair=pair, special_tokens=special_tokens
+ )
+
+ def encode_plus(self, text, text_pair=None, suffix_first=False, add_special_tokens=True, **kwargs):
+ # hack to make sure the input is pre-process but outside rust
+ text_pair = kwargs.pop("suffix", text_pair)
+ if self.fill_token is not None and self.fill_token in text and text_pair is None:
+ text, text_pair = text.split(self.fill_token)
+
+ if text_pair is None or len(text_pair) < 1:
+ return super().encode_plus(text, text_pair, add_special_tokens=add_special_tokens, **kwargs)
+
+ if None in (self.prefix_id, self.middle_id, self.suffix_id):
+ raise ValueError(
+ "Then input includes a `prefix` and a `suffix` used for the infilling task,"
+ " the `prefix_id, middle_id, suffix_id` must all be initialized. Current"
+ f" values : {self.prefix_id, self.middle_id, self.suffix_id}"
+ )
+
+ self.set_infilling_processor(False, suffix_first=suffix_first, add_special_tokens=add_special_tokens)
+ tokens = super().encode_plus(" " + text, text_pair=text_pair, add_special_tokens=True, **kwargs)
+ self.set_infilling_processor(True)
+ return tokens
+
+ # Copied from transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast.save_vocabulary
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+ if not self.can_save_slow_tokenizer:
+ raise ValueError(
+ "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+ "tokenizer."
+ )
+
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ out_vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+
+ return (out_vocab_file,)
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+ adding special tokens. The special tokens depend on calling set_lang.
+
+ An NLLB sequence has the following format, where `X` represents the sequence:
+
+ - `input_ids` (for encoder) `X [eos, src_lang_code]`
+ - `decoder_input_ids`: (for decoder) `X [eos, tgt_lang_code]`
+
+ BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
+ separator.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ if token_ids_1 is None:
+ return self.bos_token_id + token_ids_0 + self.eos_token_id
+ return self.bos_token_id + token_ids_0 + token_ids_1 + self.eos_token_id
+
+
+__all__ = ["CodeLlamaTokenizerFast"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/convnext/__init__.py b/.venv/lib/python3.11/site-packages/transformers/models/convnext/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..796b9a48926fe6239195588496012842cf355299
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/convnext/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_convnext import *
+ from .feature_extraction_convnext import *
+ from .image_processing_convnext import *
+ from .modeling_convnext import *
+ from .modeling_tf_convnext import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/convnext/__pycache__/configuration_convnext.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/convnext/__pycache__/configuration_convnext.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f0377936552b0fb20458a7acadc29f4244df0504
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/convnext/__pycache__/configuration_convnext.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/convnext/__pycache__/image_processing_convnext.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/convnext/__pycache__/image_processing_convnext.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..26c7a4a5a4bae2df53bafe2d51eb266d4788ab4f
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/convnext/__pycache__/image_processing_convnext.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/convnext/configuration_convnext.py b/.venv/lib/python3.11/site-packages/transformers/models/convnext/configuration_convnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f9ed3bfd469df4a8ea03966de43afb9fb67623c
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/convnext/configuration_convnext.py
@@ -0,0 +1,142 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""ConvNeXT model configuration"""
+
+from collections import OrderedDict
+from typing import Mapping
+
+from packaging import version
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
+
+
+logger = logging.get_logger(__name__)
+
+
+class ConvNextConfig(BackboneConfigMixin, PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`ConvNextModel`]. It is used to instantiate an
+ ConvNeXT model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the ConvNeXT
+ [facebook/convnext-tiny-224](https://huggingface.co/facebook/convnext-tiny-224) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ patch_size (`int`, *optional*, defaults to 4):
+ Patch size to use in the patch embedding layer.
+ num_stages (`int`, *optional*, defaults to 4):
+ The number of stages in the model.
+ hidden_sizes (`List[int]`, *optional*, defaults to [96, 192, 384, 768]):
+ Dimensionality (hidden size) at each stage.
+ depths (`List[int]`, *optional*, defaults to [3, 3, 9, 3]):
+ Depth (number of blocks) for each stage.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in each block. If string, `"gelu"`, `"relu"`,
+ `"selu"` and `"gelu_new"` are supported.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ layer_scale_init_value (`float`, *optional*, defaults to 1e-6):
+ The initial value for the layer scale.
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
+ The drop rate for stochastic depth.
+ out_features (`List[str]`, *optional*):
+ If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+ (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+ corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
+ same order as defined in the `stage_names` attribute.
+ out_indices (`List[int]`, *optional*):
+ If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+ many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+ If unset and `out_features` is unset, will default to the last stage. Must be in the
+ same order as defined in the `stage_names` attribute.
+
+ Example:
+ ```python
+ >>> from transformers import ConvNextConfig, ConvNextModel
+
+ >>> # Initializing a ConvNext convnext-tiny-224 style configuration
+ >>> configuration = ConvNextConfig()
+
+ >>> # Initializing a model (with random weights) from the convnext-tiny-224 style configuration
+ >>> model = ConvNextModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "convnext"
+
+ def __init__(
+ self,
+ num_channels=3,
+ patch_size=4,
+ num_stages=4,
+ hidden_sizes=None,
+ depths=None,
+ hidden_act="gelu",
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ layer_scale_init_value=1e-6,
+ drop_path_rate=0.0,
+ image_size=224,
+ out_features=None,
+ out_indices=None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.num_channels = num_channels
+ self.patch_size = patch_size
+ self.num_stages = num_stages
+ self.hidden_sizes = [96, 192, 384, 768] if hidden_sizes is None else hidden_sizes
+ self.depths = [3, 3, 9, 3] if depths is None else depths
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.layer_scale_init_value = layer_scale_init_value
+ self.drop_path_rate = drop_path_rate
+ self.image_size = image_size
+ self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(self.depths) + 1)]
+ self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+ out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+ )
+
+
+class ConvNextOnnxConfig(OnnxConfig):
+ torch_onnx_minimum_version = version.parse("1.11")
+
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ return OrderedDict(
+ [
+ ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
+ ]
+ )
+
+ @property
+ def atol_for_validation(self) -> float:
+ return 1e-5
+
+
+__all__ = ["ConvNextConfig", "ConvNextOnnxConfig"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/convnext/feature_extraction_convnext.py b/.venv/lib/python3.11/site-packages/transformers/models/convnext/feature_extraction_convnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b2208e5b1134fd998e767a15e1639a097b5fa41
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/convnext/feature_extraction_convnext.py
@@ -0,0 +1,36 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for ConvNeXT."""
+
+import warnings
+
+from ...utils import logging
+from .image_processing_convnext import ConvNextImageProcessor
+
+
+logger = logging.get_logger(__name__)
+
+
+class ConvNextFeatureExtractor(ConvNextImageProcessor):
+ def __init__(self, *args, **kwargs) -> None:
+ warnings.warn(
+ "The class ConvNextFeatureExtractor is deprecated and will be removed in version 5 of Transformers."
+ " Please use ConvNextImageProcessor instead.",
+ FutureWarning,
+ )
+ super().__init__(*args, **kwargs)
+
+
+__all__ = ["ConvNextFeatureExtractor"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/convnext/image_processing_convnext.py b/.venv/lib/python3.11/site-packages/transformers/models/convnext/image_processing_convnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..35cbc91797add39d04a436e67c6f8c1f8c4ad1ce
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/convnext/image_processing_convnext.py
@@ -0,0 +1,323 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for ConvNeXT."""
+
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import (
+ center_crop,
+ get_resize_output_image_size,
+ resize,
+ to_channel_dimension_format,
+)
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
+
+
+if is_vision_available():
+ import PIL
+
+
+logger = logging.get_logger(__name__)
+
+
+class ConvNextImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a ConvNeXT image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be overriden
+ by `do_resize` in the `preprocess` method.
+ size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 384}`):
+ Resolution of the output image after `resize` is applied. If `size["shortest_edge"]` >= 384, the image is
+ resized to `(size["shortest_edge"], size["shortest_edge"])`. Otherwise, the smaller edge of the image will
+ be matched to `int(size["shortest_edge"]/crop_pct)`, after which the image is cropped to
+ `(size["shortest_edge"], size["shortest_edge"])`. Only has an effect if `do_resize` is set to `True`. Can
+ be overriden by `size` in the `preprocess` method.
+ crop_pct (`float` *optional*, defaults to 224 / 256):
+ Percentage of the image to crop. Only has an effect if `do_resize` is `True` and size < 384. Can be
+ overriden by `crop_pct` in the `preprocess` method.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
+ Resampling filter to use if resizing the image. Can be overriden by `resample` in the `preprocess` method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overriden by `do_rescale` in
+ the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overriden by `rescale_factor` in the `preprocess`
+ method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+ method.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Dict[str, int] = None,
+ crop_pct: float = None,
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"shortest_edge": 384}
+ size = get_size_dict(size, default_to_square=False)
+
+ self.do_resize = do_resize
+ self.size = size
+ # Default value set here for backwards compatibility where the value in config is None
+ self.crop_pct = crop_pct if crop_pct is not None else 224 / 256
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ crop_pct: float,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize an image.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`Dict[str, int]`):
+ Dictionary of the form `{"shortest_edge": int}`, specifying the size of the output image. If
+ `size["shortest_edge"]` >= 384 image is resized to `(size["shortest_edge"], size["shortest_edge"])`.
+ Otherwise, the smaller edge of the image will be matched to `int(size["shortest_edge"] / crop_pct)`,
+ after which the image is cropped to `(size["shortest_edge"], size["shortest_edge"])`.
+ crop_pct (`float`):
+ Percentage of the image to crop. Only has an effect if size < 384.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ Resampling filter to use when resizing the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred from the input
+ image.
+ """
+ size = get_size_dict(size, default_to_square=False)
+ if "shortest_edge" not in size:
+ raise ValueError(f"Size dictionary must contain 'shortest_edge' key. Got {size.keys()}")
+ shortest_edge = size["shortest_edge"]
+
+ if shortest_edge < 384:
+ # maintain same ratio, resizing shortest edge to shortest_edge/crop_pct
+ resize_shortest_edge = int(shortest_edge / crop_pct)
+ resize_size = get_resize_output_image_size(
+ image, size=resize_shortest_edge, default_to_square=False, input_data_format=input_data_format
+ )
+ image = resize(
+ image=image,
+ size=resize_size,
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+ # then crop to (shortest_edge, shortest_edge)
+ return center_crop(
+ image=image,
+ size=(shortest_edge, shortest_edge),
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+ else:
+ # warping (no cropping) when evaluated at 384 or larger
+ return resize(
+ image,
+ size=(shortest_edge, shortest_edge),
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+
+ @filter_out_non_signature_kwargs()
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: bool = None,
+ size: Dict[str, int] = None,
+ crop_pct: float = None,
+ resample: PILImageResampling = None,
+ do_rescale: bool = None,
+ rescale_factor: float = None,
+ do_normalize: bool = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: ChannelDimension = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> PIL.Image.Image:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the output image after `resize` has been applied. If `size["shortest_edge"]` >= 384, the image
+ is resized to `(size["shortest_edge"], size["shortest_edge"])`. Otherwise, the smaller edge of the
+ image will be matched to `int(size["shortest_edge"]/ crop_pct)`, after which the image is cropped to
+ `(size["shortest_edge"], size["shortest_edge"])`. Only has an effect if `do_resize` is set to `True`.
+ crop_pct (`float`, *optional*, defaults to `self.crop_pct`):
+ Percentage of the image to crop if size < 384.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of `PILImageResampling`, filters. Only
+ has an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image values between [0 - 1].
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean.
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ crop_pct = crop_pct if crop_pct is not None else self.crop_pct
+ resample = resample if resample is not None else self.resample
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+
+ size = size if size is not None else self.size
+ size = get_size_dict(size, default_to_square=False)
+
+ images = make_list_of_images(images)
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_rescale and is_scaled_image(images[0]):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ if do_resize:
+ images = [
+ self.resize(
+ image=image, size=size, crop_pct=crop_pct, resample=resample, input_data_format=input_data_format
+ )
+ for image in images
+ ]
+
+ if do_rescale:
+ images = [
+ self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+ for image in images
+ ]
+
+ if do_normalize:
+ images = [
+ self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+ for image in images
+ ]
+
+ images = [
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
+ ]
+
+ data = {"pixel_values": images}
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+
+__all__ = ["ConvNextImageProcessor"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/convnext/modeling_convnext.py b/.venv/lib/python3.11/site-packages/transformers/models/convnext/modeling_convnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..155f466ac4ae68034752d9cf11d9e0d5fdfdee69
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/convnext/modeling_convnext.py
@@ -0,0 +1,551 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch ConvNext model."""
+
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...modeling_outputs import (
+ BackboneOutput,
+ BaseModelOutputWithNoAttention,
+ BaseModelOutputWithPoolingAndNoAttention,
+ ImageClassifierOutputWithNoAttention,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from ...utils.backbone_utils import BackboneMixin
+from .configuration_convnext import ConvNextConfig
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "ConvNextConfig"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "facebook/convnext-tiny-224"
+_EXPECTED_OUTPUT_SHAPE = [1, 768, 7, 7]
+
+# Image classification docstring
+_IMAGE_CLASS_CHECKPOINT = "facebook/convnext-tiny-224"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
+
+
+# Copied from transformers.models.beit.modeling_beit.drop_path
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+ """
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+ argument.
+ """
+ if drop_prob == 0.0 or not training:
+ return input
+ keep_prob = 1 - drop_prob
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
+ random_tensor.floor_() # binarize
+ output = input.div(keep_prob) * random_tensor
+ return output
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->ConvNext
+class ConvNextDropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
+ super().__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return drop_path(hidden_states, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return "p={}".format(self.drop_prob)
+
+
+class ConvNextLayerNorm(nn.Module):
+ r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
+ width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
+ """
+
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
+ self.eps = eps
+ self.data_format = data_format
+ if self.data_format not in ["channels_last", "channels_first"]:
+ raise NotImplementedError(f"Unsupported data format: {self.data_format}")
+ self.normalized_shape = (normalized_shape,)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if self.data_format == "channels_last":
+ x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+ elif self.data_format == "channels_first":
+ input_dtype = x.dtype
+ x = x.float()
+ u = x.mean(1, keepdim=True)
+ s = (x - u).pow(2).mean(1, keepdim=True)
+ x = (x - u) / torch.sqrt(s + self.eps)
+ x = x.to(dtype=input_dtype)
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
+ return x
+
+
+class ConvNextEmbeddings(nn.Module):
+ """This class is comparable to (and inspired by) the SwinEmbeddings class
+ found in src/transformers/models/swin/modeling_swin.py.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.patch_embeddings = nn.Conv2d(
+ config.num_channels, config.hidden_sizes[0], kernel_size=config.patch_size, stride=config.patch_size
+ )
+ self.layernorm = ConvNextLayerNorm(config.hidden_sizes[0], eps=1e-6, data_format="channels_first")
+ self.num_channels = config.num_channels
+
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+ num_channels = pixel_values.shape[1]
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ )
+ embeddings = self.patch_embeddings(pixel_values)
+ embeddings = self.layernorm(embeddings)
+ return embeddings
+
+
+class ConvNextLayer(nn.Module):
+ """This corresponds to the `Block` class in the original implementation.
+
+ There are two equivalent implementations: [DwConv, LayerNorm (channels_first), Conv, GELU,1x1 Conv]; all in (N, C,
+ H, W) (2) [DwConv, Permute to (N, H, W, C), LayerNorm (channels_last), Linear, GELU, Linear]; Permute back
+
+ The authors used (2) as they find it slightly faster in PyTorch.
+
+ Args:
+ config ([`ConvNextConfig`]): Model configuration class.
+ dim (`int`): Number of input channels.
+ drop_path (`float`): Stochastic depth rate. Default: 0.0.
+ """
+
+ def __init__(self, config, dim, drop_path=0):
+ super().__init__()
+ self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
+ self.layernorm = ConvNextLayerNorm(dim, eps=1e-6)
+ self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
+ self.act = ACT2FN[config.hidden_act]
+ self.pwconv2 = nn.Linear(4 * dim, dim)
+ self.layer_scale_parameter = (
+ nn.Parameter(config.layer_scale_init_value * torch.ones((dim)), requires_grad=True)
+ if config.layer_scale_init_value > 0
+ else None
+ )
+ self.drop_path = ConvNextDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
+ input = hidden_states
+ x = self.dwconv(hidden_states)
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
+ x = self.layernorm(x)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+ if self.layer_scale_parameter is not None:
+ x = self.layer_scale_parameter * x
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
+
+ x = input + self.drop_path(x)
+ return x
+
+
+class ConvNextStage(nn.Module):
+ """ConvNeXT stage, consisting of an optional downsampling layer + multiple residual blocks.
+
+ Args:
+ config ([`ConvNextConfig`]): Model configuration class.
+ in_channels (`int`): Number of input channels.
+ out_channels (`int`): Number of output channels.
+ depth (`int`): Number of residual blocks.
+ drop_path_rates(`List[float]`): Stochastic depth rates for each layer.
+ """
+
+ def __init__(self, config, in_channels, out_channels, kernel_size=2, stride=2, depth=2, drop_path_rates=None):
+ super().__init__()
+
+ if in_channels != out_channels or stride > 1:
+ self.downsampling_layer = nn.Sequential(
+ ConvNextLayerNorm(in_channels, eps=1e-6, data_format="channels_first"),
+ nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride),
+ )
+ else:
+ self.downsampling_layer = nn.Identity()
+ drop_path_rates = drop_path_rates or [0.0] * depth
+ self.layers = nn.Sequential(
+ *[ConvNextLayer(config, dim=out_channels, drop_path=drop_path_rates[j]) for j in range(depth)]
+ )
+
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
+ hidden_states = self.downsampling_layer(hidden_states)
+ hidden_states = self.layers(hidden_states)
+ return hidden_states
+
+
+class ConvNextEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.stages = nn.ModuleList()
+ drop_path_rates = [
+ x.tolist() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths)).split(config.depths)
+ ]
+ prev_chs = config.hidden_sizes[0]
+ for i in range(config.num_stages):
+ out_chs = config.hidden_sizes[i]
+ stage = ConvNextStage(
+ config,
+ in_channels=prev_chs,
+ out_channels=out_chs,
+ stride=2 if i > 0 else 1,
+ depth=config.depths[i],
+ drop_path_rates=drop_path_rates[i],
+ )
+ self.stages.append(stage)
+ prev_chs = out_chs
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ output_hidden_states: Optional[bool] = False,
+ return_dict: Optional[bool] = True,
+ ) -> Union[Tuple, BaseModelOutputWithNoAttention]:
+ all_hidden_states = () if output_hidden_states else None
+
+ for i, layer_module in enumerate(self.stages):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ hidden_states = layer_module(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
+
+ return BaseModelOutputWithNoAttention(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ )
+
+
+class ConvNextPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = ConvNextConfig
+ base_model_prefix = "convnext"
+ main_input_name = "pixel_values"
+ _no_split_modules = ["ConvNextLayer"]
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+
+CONVNEXT_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`ConvNextConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+CONVNEXT_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
+ [`ConvNextImageProcessor.__call__`] for details.
+
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare ConvNext model outputting raw features without any specific head on top.",
+ CONVNEXT_START_DOCSTRING,
+)
+class ConvNextModel(ConvNextPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = ConvNextEmbeddings(config)
+ self.encoder = ConvNextEncoder(config)
+
+ # final layernorm layer
+ self.layernorm = nn.LayerNorm(config.hidden_sizes[-1], eps=config.layer_norm_eps)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutputWithPoolingAndNoAttention,
+ config_class=_CONFIG_FOR_DOC,
+ modality="vision",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPoolingAndNoAttention]:
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ embedding_output = self.embeddings(pixel_values)
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+
+ # global average pooling, (N, C, H, W) -> (N, C)
+ pooled_output = self.layernorm(last_hidden_state.mean([-2, -1]))
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndNoAttention(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ )
+
+
+@add_start_docstrings(
+ """
+ ConvNext Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
+ ImageNet.
+ """,
+ CONVNEXT_START_DOCSTRING,
+)
+class ConvNextForImageClassification(ConvNextPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.convnext = ConvNextModel(config)
+
+ # Classifier head
+ self.classifier = (
+ nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity()
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
+ output_type=ImageClassifierOutputWithNoAttention,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+ )
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.convnext(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
+
+ pooled_output = outputs.pooler_output if return_dict else outputs[1]
+
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return ImageClassifierOutputWithNoAttention(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ )
+
+
+@add_start_docstrings(
+ """
+ ConvNeXt backbone, to be used with frameworks like DETR and MaskFormer.
+ """,
+ CONVNEXT_START_DOCSTRING,
+)
+class ConvNextBackbone(ConvNextPreTrainedModel, BackboneMixin):
+ def __init__(self, config):
+ super().__init__(config)
+ super()._init_backbone(config)
+
+ self.embeddings = ConvNextEmbeddings(config)
+ self.encoder = ConvNextEncoder(config)
+ self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes
+
+ # Add layer norms to hidden states of out_features
+ hidden_states_norms = {}
+ for stage, num_channels in zip(self._out_features, self.channels):
+ hidden_states_norms[stage] = ConvNextLayerNorm(num_channels, data_format="channels_first")
+ self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
+
+ # initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> BackboneOutput:
+ """
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, AutoBackbone
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224")
+ >>> model = AutoBackbone.from_pretrained("facebook/convnext-tiny-224")
+
+ >>> inputs = processor(image, return_tensors="pt")
+ >>> outputs = model(**inputs)
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ embedding_output = self.embeddings(pixel_values)
+
+ outputs = self.encoder(
+ embedding_output,
+ output_hidden_states=True,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs.hidden_states if return_dict else outputs[1]
+
+ feature_maps = ()
+ for stage, hidden_state in zip(self.stage_names, hidden_states):
+ if stage in self.out_features:
+ hidden_state = self.hidden_states_norms[stage](hidden_state)
+ feature_maps += (hidden_state,)
+
+ if not return_dict:
+ output = (feature_maps,)
+ if output_hidden_states:
+ output += (hidden_states,)
+ return output
+
+ return BackboneOutput(
+ feature_maps=feature_maps,
+ hidden_states=hidden_states if output_hidden_states else None,
+ attentions=None,
+ )
+
+
+__all__ = ["ConvNextForImageClassification", "ConvNextModel", "ConvNextPreTrainedModel", "ConvNextBackbone"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/convnext/modeling_tf_convnext.py b/.venv/lib/python3.11/site-packages/transformers/models/convnext/modeling_tf_convnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..af1bae81db495f17d5fe1696efa020a8b4d28af2
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/convnext/modeling_tf_convnext.py
@@ -0,0 +1,669 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TF 2.0 ConvNext model."""
+
+from __future__ import annotations
+
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling, TFSequenceClassifierOutput
+from ...modeling_tf_utils import (
+ TFModelInputType,
+ TFPreTrainedModel,
+ TFSequenceClassificationLoss,
+ get_initializer,
+ keras,
+ keras_serializable,
+ unpack_inputs,
+)
+from ...tf_utils import shape_list
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from .configuration_convnext import ConvNextConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+_CONFIG_FOR_DOC = "ConvNextConfig"
+_CHECKPOINT_FOR_DOC = "facebook/convnext-tiny-224"
+
+
+class TFConvNextDropPath(keras.layers.Layer):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ References:
+ (1) github.com:rwightman/pytorch-image-models
+ """
+
+ def __init__(self, drop_path: float, **kwargs):
+ super().__init__(**kwargs)
+ self.drop_path = drop_path
+
+ def call(self, x: tf.Tensor, training=None):
+ if training:
+ keep_prob = 1 - self.drop_path
+ shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
+ random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
+ random_tensor = tf.floor(random_tensor)
+ return (x / keep_prob) * random_tensor
+ return x
+
+
+class TFConvNextEmbeddings(keras.layers.Layer):
+ """This class is comparable to (and inspired by) the SwinEmbeddings class
+ found in src/transformers/models/swin/modeling_swin.py.
+ """
+
+ def __init__(self, config: ConvNextConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.patch_embeddings = keras.layers.Conv2D(
+ filters=config.hidden_sizes[0],
+ kernel_size=config.patch_size,
+ strides=config.patch_size,
+ name="patch_embeddings",
+ kernel_initializer=get_initializer(config.initializer_range),
+ bias_initializer=keras.initializers.Zeros(),
+ )
+ self.layernorm = keras.layers.LayerNormalization(epsilon=1e-6, name="layernorm")
+ self.num_channels = config.num_channels
+ self.config = config
+
+ def call(self, pixel_values):
+ if isinstance(pixel_values, dict):
+ pixel_values = pixel_values["pixel_values"]
+
+ tf.debugging.assert_equal(
+ shape_list(pixel_values)[1],
+ self.num_channels,
+ message="Make sure that the channel dimension of the pixel values match with the one set in the configuration.",
+ )
+
+ # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format.
+ # So change the input format from `NCHW` to `NHWC`.
+ # shape = (batch_size, in_height, in_width, in_channels)
+ pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
+
+ embeddings = self.patch_embeddings(pixel_values)
+ embeddings = self.layernorm(embeddings)
+ return embeddings
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "patch_embeddings", None) is not None:
+ with tf.name_scope(self.patch_embeddings.name):
+ self.patch_embeddings.build([None, None, None, self.config.num_channels])
+ if getattr(self, "layernorm", None) is not None:
+ with tf.name_scope(self.layernorm.name):
+ self.layernorm.build([None, None, None, self.config.hidden_sizes[0]])
+
+
+class TFConvNextLayer(keras.layers.Layer):
+ """This corresponds to the `Block` class in the original implementation.
+
+ There are two equivalent implementations: [DwConv, LayerNorm (channels_first), Conv, GELU,1x1 Conv]; all in (N, C,
+ H, W) (2) [DwConv, Permute to (N, H, W, C), LayerNorm (channels_last), Linear, GELU, Linear]; Permute back
+
+ The authors used (2) as they find it slightly faster in PyTorch. Since we already permuted the inputs to follow
+ NHWC ordering, we can just apply the operations straight-away without the permutation.
+
+ Args:
+ config ([`ConvNextConfig`]): Model configuration class.
+ dim (`int`): Number of input channels.
+ drop_path (`float`): Stochastic depth rate. Default: 0.0.
+ """
+
+ def __init__(self, config, dim, drop_path=0.0, **kwargs):
+ super().__init__(**kwargs)
+ self.dim = dim
+ self.config = config
+ self.dwconv = keras.layers.Conv2D(
+ filters=dim,
+ kernel_size=7,
+ padding="same",
+ groups=dim,
+ kernel_initializer=get_initializer(config.initializer_range),
+ bias_initializer="zeros",
+ name="dwconv",
+ ) # depthwise conv
+ self.layernorm = keras.layers.LayerNormalization(
+ epsilon=1e-6,
+ name="layernorm",
+ )
+ self.pwconv1 = keras.layers.Dense(
+ units=4 * dim,
+ kernel_initializer=get_initializer(config.initializer_range),
+ bias_initializer="zeros",
+ name="pwconv1",
+ ) # pointwise/1x1 convs, implemented with linear layers
+ self.act = get_tf_activation(config.hidden_act)
+ self.pwconv2 = keras.layers.Dense(
+ units=dim,
+ kernel_initializer=get_initializer(config.initializer_range),
+ bias_initializer="zeros",
+ name="pwconv2",
+ )
+ # Using `layers.Activation` instead of `tf.identity` to better control `training`
+ # behaviour.
+ self.drop_path = (
+ TFConvNextDropPath(drop_path, name="drop_path")
+ if drop_path > 0.0
+ else keras.layers.Activation("linear", name="drop_path")
+ )
+
+ def build(self, input_shape: tf.TensorShape = None):
+ # PT's `nn.Parameters` must be mapped to a TF layer weight to inherit the same name hierarchy (and vice-versa)
+ self.layer_scale_parameter = (
+ self.add_weight(
+ shape=(self.dim,),
+ initializer=keras.initializers.Constant(value=self.config.layer_scale_init_value),
+ trainable=True,
+ name="layer_scale_parameter",
+ )
+ if self.config.layer_scale_init_value > 0
+ else None
+ )
+
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dwconv", None) is not None:
+ with tf.name_scope(self.dwconv.name):
+ self.dwconv.build([None, None, None, self.dim])
+ if getattr(self, "layernorm", None) is not None:
+ with tf.name_scope(self.layernorm.name):
+ self.layernorm.build([None, None, None, self.dim])
+ if getattr(self, "pwconv1", None) is not None:
+ with tf.name_scope(self.pwconv1.name):
+ self.pwconv1.build([None, None, self.dim])
+ if getattr(self, "pwconv2", None) is not None:
+ with tf.name_scope(self.pwconv2.name):
+ self.pwconv2.build([None, None, 4 * self.dim])
+ if getattr(self, "drop_path", None) is not None:
+ with tf.name_scope(self.drop_path.name):
+ self.drop_path.build(None)
+
+ def call(self, hidden_states, training=False):
+ input = hidden_states
+ x = self.dwconv(hidden_states)
+ x = self.layernorm(x)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+
+ if self.layer_scale_parameter is not None:
+ x = self.layer_scale_parameter * x
+
+ x = input + self.drop_path(x, training=training)
+ return x
+
+
+class TFConvNextStage(keras.layers.Layer):
+ """ConvNext stage, consisting of an optional downsampling layer + multiple residual blocks.
+
+ Args:
+ config (`ConvNextV2Config`):
+ Model configuration class.
+ in_channels (`int`):
+ Number of input channels.
+ out_channels (`int`):
+ Number of output channels.
+ depth (`int`):
+ Number of residual blocks.
+ drop_path_rates(`List[float]`):
+ Stochastic depth rates for each layer.
+ """
+
+ def __init__(
+ self,
+ config: ConvNextConfig,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int = 2,
+ stride: int = 2,
+ depth: int = 2,
+ drop_path_rates: Optional[List[float]] = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ if in_channels != out_channels or stride > 1:
+ self.downsampling_layer = [
+ keras.layers.LayerNormalization(
+ epsilon=1e-6,
+ name="downsampling_layer.0",
+ ),
+ # Inputs to this layer will follow NHWC format since we
+ # transposed the inputs from NCHW to NHWC in the `TFConvNextEmbeddings`
+ # layer. All the outputs throughout the model will be in NHWC
+ # from this point on until the output where we again change to
+ # NCHW.
+ keras.layers.Conv2D(
+ filters=out_channels,
+ kernel_size=kernel_size,
+ strides=stride,
+ kernel_initializer=get_initializer(config.initializer_range),
+ bias_initializer=keras.initializers.Zeros(),
+ name="downsampling_layer.1",
+ ),
+ ]
+ else:
+ self.downsampling_layer = [tf.identity]
+
+ drop_path_rates = drop_path_rates or [0.0] * depth
+ self.layers = [
+ TFConvNextLayer(
+ config,
+ dim=out_channels,
+ drop_path=drop_path_rates[j],
+ name=f"layers.{j}",
+ )
+ for j in range(depth)
+ ]
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.stride = stride
+
+ def call(self, hidden_states):
+ for layer in self.downsampling_layer:
+ hidden_states = layer(hidden_states)
+ for layer in self.layers:
+ hidden_states = layer(hidden_states)
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "layers", None) is not None:
+ for layer in self.layers:
+ with tf.name_scope(layer.name):
+ layer.build(None)
+ if self.in_channels != self.out_channels or self.stride > 1:
+ with tf.name_scope(self.downsampling_layer[0].name):
+ self.downsampling_layer[0].build([None, None, None, self.in_channels])
+ with tf.name_scope(self.downsampling_layer[1].name):
+ self.downsampling_layer[1].build([None, None, None, self.in_channels])
+
+
+class TFConvNextEncoder(keras.layers.Layer):
+ def __init__(self, config, **kwargs):
+ super().__init__(**kwargs)
+ self.stages = []
+ drop_path_rates = tf.linspace(0.0, config.drop_path_rate, sum(config.depths))
+ drop_path_rates = tf.split(drop_path_rates, config.depths)
+ drop_path_rates = [x.numpy().tolist() for x in drop_path_rates]
+ prev_chs = config.hidden_sizes[0]
+ for i in range(config.num_stages):
+ out_chs = config.hidden_sizes[i]
+ stage = TFConvNextStage(
+ config,
+ in_channels=prev_chs,
+ out_channels=out_chs,
+ stride=2 if i > 0 else 1,
+ depth=config.depths[i],
+ drop_path_rates=drop_path_rates[i],
+ name=f"stages.{i}",
+ )
+ self.stages.append(stage)
+ prev_chs = out_chs
+
+ def call(self, hidden_states, output_hidden_states=False, return_dict=True):
+ all_hidden_states = () if output_hidden_states else None
+
+ for i, layer_module in enumerate(self.stages):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ hidden_states = layer_module(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
+
+ return TFBaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
+
+ def build(self, input_shape=None):
+ for stage in self.stages:
+ with tf.name_scope(stage.name):
+ stage.build(None)
+
+
+@keras_serializable
+class TFConvNextMainLayer(keras.layers.Layer):
+ config_class = ConvNextConfig
+
+ def __init__(self, config: ConvNextConfig, add_pooling_layer: bool = True, **kwargs):
+ super().__init__(**kwargs)
+
+ self.config = config
+ self.embeddings = TFConvNextEmbeddings(config, name="embeddings")
+ self.encoder = TFConvNextEncoder(config, name="encoder")
+ self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
+ # We are setting the `data_format` like so because from here on we will revert to the
+ # NCHW output format
+ self.pooler = keras.layers.GlobalAvgPool2D(data_format="channels_first") if add_pooling_layer else None
+
+ @unpack_inputs
+ def call(
+ self,
+ pixel_values: TFModelInputType | None = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: bool = False,
+ ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ embedding_output = self.embeddings(pixel_values, training=training)
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+ # Change to NCHW output format have uniformity in the modules
+ last_hidden_state = tf.transpose(last_hidden_state, perm=(0, 3, 1, 2))
+ pooled_output = self.layernorm(self.pooler(last_hidden_state))
+
+ # Change the other hidden state outputs to NCHW as well
+ if output_hidden_states:
+ hidden_states = tuple([tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]])
+
+ if not return_dict:
+ hidden_states = hidden_states if output_hidden_states else ()
+ return (last_hidden_state, pooled_output) + hidden_states
+
+ return TFBaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "embeddings", None) is not None:
+ with tf.name_scope(self.embeddings.name):
+ self.embeddings.build(None)
+ if getattr(self, "encoder", None) is not None:
+ with tf.name_scope(self.encoder.name):
+ self.encoder.build(None)
+ if getattr(self, "layernorm", None) is not None:
+ with tf.name_scope(self.layernorm.name):
+ self.layernorm.build([None, self.config.hidden_sizes[-1]])
+
+
+class TFConvNextPreTrainedModel(TFPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = ConvNextConfig
+ base_model_prefix = "convnext"
+ main_input_name = "pixel_values"
+
+
+CONVNEXT_START_DOCSTRING = r"""
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
+ behavior.
+
+
+
+ TensorFlow models and layers in `transformers` accept two formats as input:
+
+ - having all inputs as keyword arguments (like PyTorch models), or
+ - having all inputs as a list, tuple or dict in the first positional argument.
+
+ The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
+ and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+ pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
+ format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+ the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+ positional argument:
+
+ - a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)`
+ - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+ `model([pixel_values, attention_mask])` or `model([pixel_values, attention_mask, token_type_ids])`
+ - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+ `model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})`
+
+ Note that when creating models and layers with
+ [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
+ about any of this, as you can just pass inputs like you would to any other Python function!
+
+
+
+ Parameters:
+ config ([`ConvNextConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+CONVNEXT_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]` ``Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
+ [`ConvNextImageProcessor.__call__`] for details.
+
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
+ used instead.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+ eager mode, in graph mode the value will always be set to True.
+"""
+
+
+@add_start_docstrings(
+ "The bare ConvNext model outputting raw features without any specific head on top.",
+ CONVNEXT_START_DOCSTRING,
+)
+class TFConvNextModel(TFConvNextPreTrainedModel):
+ def __init__(self, config, *inputs, add_pooling_layer=True, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+ self.convnext = TFConvNextMainLayer(config, add_pooling_layer=add_pooling_layer, name="convnext")
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
+ def call(
+ self,
+ pixel_values: TFModelInputType | None = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: bool = False,
+ ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, TFConvNextModel
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224")
+ >>> model = TFConvNextModel.from_pretrained("facebook/convnext-tiny-224")
+
+ >>> inputs = image_processor(images=image, return_tensors="tf")
+ >>> outputs = model(**inputs)
+ >>> last_hidden_states = outputs.last_hidden_state
+ ```"""
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ outputs = self.convnext(
+ pixel_values=pixel_values,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ if not return_dict:
+ return (outputs[0],) + outputs[1:]
+
+ return TFBaseModelOutputWithPooling(
+ last_hidden_state=outputs.last_hidden_state,
+ pooler_output=outputs.pooler_output,
+ hidden_states=outputs.hidden_states,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "convnext", None) is not None:
+ with tf.name_scope(self.convnext.name):
+ self.convnext.build(None)
+
+
+@add_start_docstrings(
+ """
+ ConvNext Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
+ ImageNet.
+ """,
+ CONVNEXT_START_DOCSTRING,
+)
+class TFConvNextForImageClassification(TFConvNextPreTrainedModel, TFSequenceClassificationLoss):
+ def __init__(self, config: ConvNextConfig, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ self.num_labels = config.num_labels
+ self.convnext = TFConvNextMainLayer(config, name="convnext")
+
+ # Classifier head
+ self.classifier = keras.layers.Dense(
+ units=config.num_labels,
+ kernel_initializer=get_initializer(config.initializer_range),
+ bias_initializer="zeros",
+ name="classifier",
+ )
+ self.config = config
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
+ def call(
+ self,
+ pixel_values: TFModelInputType | None = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ training: Optional[bool] = False,
+ ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
+ r"""
+ labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, TFConvNextForImageClassification
+ >>> import tensorflow as tf
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224")
+ >>> model = TFConvNextForImageClassification.from_pretrained("facebook/convnext-tiny-224")
+
+ >>> inputs = image_processor(images=image, return_tensors="tf")
+ >>> outputs = model(**inputs)
+ >>> logits = outputs.logits
+ >>> # model predicts one of the 1000 ImageNet classes
+ >>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
+ >>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)])
+ ```"""
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ outputs = self.convnext(
+ pixel_values,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ pooled_output = outputs.pooler_output if return_dict else outputs[1]
+
+ logits = self.classifier(pooled_output)
+ loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFSequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "convnext", None) is not None:
+ with tf.name_scope(self.convnext.name):
+ self.convnext.build(None)
+ if getattr(self, "classifier", None) is not None:
+ if hasattr(self.classifier, "name"):
+ with tf.name_scope(self.classifier.name):
+ self.classifier.build([None, None, self.config.hidden_sizes[-1]])
+
+
+__all__ = ["TFConvNextForImageClassification", "TFConvNextModel", "TFConvNextPreTrainedModel"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__init__.py b/.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..455f7ffec5dee0567039cfd844588cb991c15622
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_decision_transformer import *
+ from .modeling_decision_transformer import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b302319ce5786d77101087c8da66a7deb200be96
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__pycache__/configuration_decision_transformer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__pycache__/configuration_decision_transformer.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..935ba6dbf3ad967f81adcef9a6c7fb68f56bbfaa
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__pycache__/configuration_decision_transformer.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__pycache__/modeling_decision_transformer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__pycache__/modeling_decision_transformer.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3b0b36beec6dacf375ccbe131395eb50d668bcac
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__pycache__/modeling_decision_transformer.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/modeling_decision_transformer.py b/.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/modeling_decision_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..08c0f918c43578dd43405e9d68eb807cab0d8144
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/modeling_decision_transformer.py
@@ -0,0 +1,963 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Team The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch DecisionTransformer model."""
+
+import math
+import os
+from dataclasses import dataclass
+from typing import Callable, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
+from ...utils import (
+ ModelOutput,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_decision_transformer import DecisionTransformerConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "edbeeching/decision-transformer-gym-hopper-medium"
+_CONFIG_FOR_DOC = "DecisionTransformerConfig"
+
+
+# Copied from transformers.models.gpt2.modeling_gpt2.load_tf_weights_in_gpt2
+def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
+ """Load tf checkpoints in a pytorch model"""
+ try:
+ import re
+
+ import tensorflow as tf
+ except ImportError:
+ logger.error(
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
+ "https://www.tensorflow.org/install/ for installation instructions."
+ )
+ raise
+ tf_path = os.path.abspath(gpt2_checkpoint_path)
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
+ # Load weights from TF model
+ init_vars = tf.train.list_variables(tf_path)
+ names = []
+ arrays = []
+ for name, shape in init_vars:
+ logger.info(f"Loading TF weight {name} with shape {shape}")
+ array = tf.train.load_variable(tf_path, name)
+ names.append(name)
+ arrays.append(array.squeeze())
+
+ for name, array in zip(names, arrays):
+ name = name[6:] # skip "model/"
+ name = name.split("/")
+ pointer = model
+ for m_name in name:
+ if re.fullmatch(r"[A-Za-z]+\d+", m_name):
+ scope_names = re.split(r"(\d+)", m_name)
+ else:
+ scope_names = [m_name]
+ if scope_names[0] == "w" or scope_names[0] == "g":
+ pointer = getattr(pointer, "weight")
+ elif scope_names[0] == "b":
+ pointer = getattr(pointer, "bias")
+ elif scope_names[0] == "wpe" or scope_names[0] == "wte":
+ pointer = getattr(pointer, scope_names[0])
+ pointer = getattr(pointer, "weight")
+ else:
+ pointer = getattr(pointer, scope_names[0])
+ if len(scope_names) >= 2:
+ num = int(scope_names[1])
+ pointer = pointer[num]
+ try:
+ if pointer.shape != array.shape:
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
+ except ValueError as e:
+ e.args += (pointer.shape, array.shape)
+ raise
+ logger.info(f"Initialize PyTorch weight {name}")
+ pointer.data = torch.from_numpy(array)
+ return model
+
+
+# Copied from transformers.models.gpt2.modeling_gpt2.eager_attention_forward
+def eager_attention_forward(module, query, key, value, attention_mask, head_mask=None, **kwargs):
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
+
+ if module.scale_attn_weights:
+ attn_weights = attn_weights / torch.full(
+ [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
+ )
+
+ # Layer-wise attention scaling
+ if module.scale_attn_by_inverse_layer_idx:
+ attn_weights = attn_weights / float(module.layer_idx + 1)
+
+ if not module.is_cross_attention:
+ # if only "normal" attention layer implements causal mask
+ query_length, key_length = query.size(-2), key.size(-2)
+ causal_mask = module.bias[:, :, key_length - query_length : key_length, :key_length]
+ mask_value = torch.finfo(attn_weights.dtype).min
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+ mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
+ attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
+
+ if attention_mask is not None:
+ # Apply the attention mask
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
+ attn_weights = attn_weights.type(value.dtype)
+ attn_weights = module.attn_dropout(attn_weights)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask
+
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2)
+
+ return attn_output, attn_weights
+
+
+# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Attention with GPT2->DecisionTransformerGPT2
+class DecisionTransformerGPT2Attention(nn.Module):
+ def __init__(self, config, is_cross_attention=False, layer_idx=None):
+ super().__init__()
+ self.config = config
+ max_positions = config.max_position_embeddings
+ self.register_buffer(
+ "bias",
+ torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
+ 1, 1, max_positions, max_positions
+ ),
+ persistent=False,
+ )
+ self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
+
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ self.split_size = self.embed_dim
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+
+ self.scale_attn_weights = config.scale_attn_weights
+ self.is_cross_attention = is_cross_attention
+
+ # Layer-wise attention scaling, reordering, and upcasting
+ self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
+ self.layer_idx = layer_idx
+ self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
+
+ if self.is_cross_attention:
+ self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
+ self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
+ else:
+ self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
+ self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
+
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
+ self.is_causal = True
+
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
+ index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
+
+ # Prune conv1d layers
+ self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
+ self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
+
+ # Update hyper params
+ self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
+ self.num_heads = self.num_heads - len(heads)
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
+ # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
+ bsz, num_heads, q_seq_len, dk = query.size()
+ _, _, k_seq_len, _ = key.size()
+
+ # Preallocate attn_weights for `baddbmm`
+ attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
+
+ # Compute Scale Factor
+ scale_factor = 1.0
+ if self.scale_attn_weights:
+ scale_factor /= float(value.size(-1)) ** 0.5
+
+ if self.scale_attn_by_inverse_layer_idx:
+ scale_factor /= float(self.layer_idx + 1)
+
+ # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
+ with torch.amp.autocast(query.device.type, enabled=False):
+ q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
+ attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
+ attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
+
+ if not self.is_cross_attention:
+ # if only "normal" attention layer implements causal mask
+ query_length, key_length = query.size(-2), key.size(-2)
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+ mask_value = torch.finfo(attn_weights.dtype).min
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+ mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+ attn_weights = torch.where(causal_mask, attn_weights, mask_value)
+
+ if attention_mask is not None:
+ # Apply the attention mask
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
+ if attn_weights.dtype != torch.float32:
+ raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
+ attn_weights = attn_weights.type(value.dtype)
+ attn_weights = self.attn_dropout(attn_weights)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask
+
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2)
+
+ return attn_output, attn_weights
+
+ def forward(
+ self,
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ **kwargs,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
+ if encoder_hidden_states is not None:
+ if not hasattr(self, "q_attn"):
+ raise ValueError(
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
+ "Please make sure to instantiate class with `DecisionTransformerGPT2Attention(..., is_cross_attention=True)`."
+ )
+
+ query_states = self.q_attn(hidden_states)
+ key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
+ attention_mask = encoder_attention_mask
+ else:
+ query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
+
+ shape_q = (*query_states.shape[:-1], -1, self.head_dim)
+ shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
+
+ query_states = query_states.view(shape_q).transpose(1, 2)
+ key_states = key_states.view(shape_kv).transpose(1, 2)
+ value_states = value_states.view(shape_kv).transpose(1, 2)
+
+ if layer_past is not None:
+ past_key, past_value = layer_past
+ key_states = torch.cat((past_key, key_states), dim=-2)
+ value_states = torch.cat((past_value, value_states), dim=-2)
+
+ if use_cache is True:
+ present = (key_states, value_states)
+ else:
+ present = None
+
+ is_cross_attention = encoder_hidden_states is not None
+ is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
+
+ using_eager = self.config._attn_implementation == "eager"
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None):
+ using_eager = True
+ logger.warning_once(
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ else:
+ # Attention functions are consistent with previous equivalent attention classes, however they do not support some options
+ # (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but
+ # not necessarily to eager (if mentionned options are provided).
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ if using_eager and self.reorder_and_upcast_attn:
+ attn_output, attn_weights = self._upcast_and_reordered_attn(
+ query_states, key_states, value_states, attention_mask, head_mask
+ )
+ else:
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ head_mask=head_mask,
+ dropout=self.attn_dropout.p if self.training else 0.0,
+ is_causal=is_causal,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
+ attn_output = self.c_proj(attn_output)
+ attn_output = self.resid_dropout(attn_output)
+
+ outputs = (attn_output, present)
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs # a, present, (attentions)
+
+
+# Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP with GPT2->DecisionTransformerGPT2
+class DecisionTransformerGPT2MLP(nn.Module):
+ def __init__(self, intermediate_size, config):
+ super().__init__()
+ embed_dim = config.hidden_size
+ self.c_fc = Conv1D(intermediate_size, embed_dim)
+ self.c_proj = Conv1D(embed_dim, intermediate_size)
+ self.act = ACT2FN[config.activation_function]
+ self.dropout = nn.Dropout(config.resid_pdrop)
+
+ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+ hidden_states = self.c_fc(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.c_proj(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Block with GPT2->DecisionTransformerGPT2
+class DecisionTransformerGPT2Block(nn.Module):
+ # Ignore copy
+ def __init__(self, config, layer_idx=None):
+ super().__init__()
+ hidden_size = config.hidden_size
+ inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
+
+ self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+ self.attn = DecisionTransformerGPT2Attention(config, layer_idx=layer_idx)
+ self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+
+ if config.add_cross_attention:
+ self.crossattention = DecisionTransformerGPT2Attention(
+ config, is_cross_attention=True, layer_idx=layer_idx
+ )
+ self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+
+ self.mlp = DecisionTransformerGPT2MLP(inner_dim, config)
+
+ def forward(
+ self,
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
+ residual = hidden_states
+ hidden_states = self.ln_1(hidden_states)
+ attn_outputs = self.attn(
+ hidden_states,
+ layer_past=layer_past,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+ attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
+ outputs = attn_outputs[1:]
+ # residual connection
+ hidden_states = attn_output + residual
+
+ if encoder_hidden_states is not None:
+ # add one self-attention block for cross-attention
+ if not hasattr(self, "crossattention"):
+ raise ValueError(
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
+ "cross-attention layers by setting `config.add_cross_attention=True`"
+ )
+ residual = hidden_states
+ hidden_states = self.ln_cross_attn(hidden_states)
+ cross_attn_outputs = self.crossattention(
+ hidden_states,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ )
+ attn_output = cross_attn_outputs[0]
+ # residual connection
+ hidden_states = residual + attn_output
+ outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
+
+ residual = hidden_states
+ hidden_states = self.ln_2(hidden_states)
+ feed_forward_hidden_states = self.mlp(hidden_states)
+ # residual connection
+ hidden_states = residual + feed_forward_hidden_states
+
+ if use_cache:
+ outputs = (hidden_states,) + outputs
+ else:
+ outputs = (hidden_states,) + outputs[1:]
+
+ return outputs # hidden_states, present, (attentions, cross_attentions)
+
+
+class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = DecisionTransformerConfig
+ load_tf_weights = load_tf_weights_in_gpt2
+ base_model_prefix = "transformer"
+ is_parallelizable = True
+ supports_gradient_checkpointing = True
+
+ def __init__(self, *inputs, **kwargs):
+ super().__init__(*inputs, **kwargs)
+
+ def _init_weights(self, module):
+ """Initialize the weights."""
+ if isinstance(module, (nn.Linear, Conv1D)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
+ #
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+ for name, p in module.named_parameters():
+ if "c_proj" in name and "weight" in name:
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+ p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
+
+
+class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.embed_dim = config.hidden_size
+
+ self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
+ self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
+
+ self.drop = nn.Dropout(config.embd_pdrop)
+ self.h = nn.ModuleList(
+ [DecisionTransformerGPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
+ )
+ self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+ # Model parallel
+ self.model_parallel = False
+ self.device_map = None
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.wte
+
+ def set_input_embeddings(self, new_embeddings):
+ self.wte = new_embeddings
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ batch_size = input_ids.shape[0]
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ batch_size = inputs_embeds.shape[0]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ if token_type_ids is not None:
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
+
+ if past_key_values is None:
+ past_length = 0
+ past_key_values = tuple([None] * len(self.h))
+ else:
+ past_length = past_key_values[0][0].size(-2)
+ if position_ids is None:
+ position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
+ position_ids = position_ids.unsqueeze(0)
+
+ # Attention mask.
+ if attention_mask is not None:
+ if batch_size <= 0:
+ raise ValueError("batch_size has to be defined and > 0")
+ attention_mask = attention_mask.view(batch_size, -1)
+ # We create a 3D attention mask from a 2D tensor mask.
+ # Sizes are [batch_size, 1, 1, to_seq_length]
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+ # this attention mask is more simple than the triangular masking of causal attention
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+ attention_mask = attention_mask[:, None, None, :]
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and the dtype's smallest value for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if self.config.add_cross_attention and encoder_hidden_states is not None:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+ if encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # head_mask has shape n_layer x batch x n_heads x N x N
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.wte(input_ids)
+ position_embeds = self.wpe(position_ids)
+ hidden_states = inputs_embeds + position_embeds
+
+ if token_type_ids is not None:
+ token_type_embeds = self.wte(token_type_ids)
+ hidden_states = hidden_states + token_type_embeds
+
+ hidden_states = self.drop(hidden_states)
+
+ output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ presents = () if use_cache else None
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+ all_hidden_states = () if output_hidden_states else None
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+ # Model parallel
+ if self.model_parallel:
+ torch.cuda.set_device(hidden_states.device)
+ # Ensure layer_past is on same device as hidden_states (might not be correct)
+ if layer_past is not None:
+ layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
+ # Ensure that attention_mask is always on the same device as hidden_states
+ if attention_mask is not None:
+ attention_mask = attention_mask.to(hidden_states.device)
+ if isinstance(head_mask, torch.Tensor):
+ head_mask = head_mask.to(hidden_states.device)
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ outputs = self._gradient_checkpointing_func(
+ block.__call__,
+ hidden_states,
+ None,
+ attention_mask,
+ head_mask[i],
+ encoder_hidden_states,
+ encoder_attention_mask,
+ use_cache,
+ output_attentions,
+ )
+ else:
+ outputs = block(
+ hidden_states,
+ layer_past=layer_past,
+ attention_mask=attention_mask,
+ head_mask=head_mask[i],
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = outputs[0]
+ if use_cache is True:
+ presents = presents + (outputs[1],)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+ if self.config.add_cross_attention:
+ all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
+
+ # Model Parallel: If it's the last layer for that device, put things on the next device
+ if self.model_parallel:
+ for k, v in self.device_map.items():
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
+
+ hidden_states = self.ln_f(hidden_states)
+
+ hidden_states = hidden_states.view(output_shape)
+ # Add last hidden state
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
+ if v is not None
+ )
+
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+@dataclass
+class DecisionTransformerOutput(ModelOutput):
+ """
+ Base class for model's outputs that also contains a pooling of the last hidden states.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ state_preds (`torch.FloatTensor` of shape `(batch_size, sequence_length, state_dim)`):
+ Environment state predictions
+ action_preds (`torch.FloatTensor` of shape `(batch_size, sequence_length, action_dim)`):
+ Model action predictions
+ return_preds (`torch.FloatTensor` of shape `(batch_size, sequence_length, 1)`):
+ Predicted returns for each state
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ state_preds: torch.FloatTensor = None
+ action_preds: torch.FloatTensor = None
+ return_preds: torch.FloatTensor = None
+ hidden_states: torch.FloatTensor = None
+ attentions: torch.FloatTensor = None
+ last_hidden_state: torch.FloatTensor = None
+
+
+class DecisionTransformerPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = DecisionTransformerConfig
+ base_model_prefix = "decision_transformer"
+ main_input_name = "states"
+ supports_gradient_checkpointing = False
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, nn.Linear):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+
+DECISION_TRANSFORMER_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`~DecisionTransformerConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+DECISION_TRANSFORMER_INPUTS_DOCSTRING = r"""
+ Args:
+ states (`torch.FloatTensor` of shape `(batch_size, episode_length, state_dim)`):
+ The states for each step in the trajectory
+ actions (`torch.FloatTensor` of shape `(batch_size, episode_length, act_dim)`):
+ The actions taken by the "expert" policy for the current state, these are masked for auto regressive
+ prediction
+ rewards (`torch.FloatTensor` of shape `(batch_size, episode_length, 1)`):
+ The rewards for each state, action
+ returns_to_go (`torch.FloatTensor` of shape `(batch_size, episode_length, 1)`):
+ The returns for each state in the trajectory
+ timesteps (`torch.LongTensor` of shape `(batch_size, episode_length)`):
+ The timestep for each step in the trajectory
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, episode_length)`):
+ Masking, used to mask the actions when performing autoregressive prediction
+"""
+
+
+@add_start_docstrings("The Decision Transformer Model", DECISION_TRANSFORMER_START_DOCSTRING)
+class DecisionTransformerModel(DecisionTransformerPreTrainedModel):
+ """
+
+ The model builds upon the GPT2 architecture to perform autoregressive prediction of actions in an offline RL
+ setting. Refer to the paper for more details: https://arxiv.org/abs/2106.01345
+
+ """
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+ self.hidden_size = config.hidden_size
+ # note: the only difference between this GPT2Model and the default Huggingface version
+ # is that the positional embeddings are removed (since we'll add those ourselves)
+ self.encoder = DecisionTransformerGPT2Model(config)
+
+ self.embed_timestep = nn.Embedding(config.max_ep_len, config.hidden_size)
+ self.embed_return = torch.nn.Linear(1, config.hidden_size)
+ self.embed_state = torch.nn.Linear(config.state_dim, config.hidden_size)
+ self.embed_action = torch.nn.Linear(config.act_dim, config.hidden_size)
+
+ self.embed_ln = nn.LayerNorm(config.hidden_size)
+
+ # note: we don't predict states or returns for the paper
+ self.predict_state = torch.nn.Linear(config.hidden_size, config.state_dim)
+ self.predict_action = nn.Sequential(
+ *([nn.Linear(config.hidden_size, config.act_dim)] + ([nn.Tanh()] if config.action_tanh else []))
+ )
+ self.predict_return = torch.nn.Linear(config.hidden_size, 1)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(DECISION_TRANSFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @replace_return_docstrings(output_type=DecisionTransformerOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ states: Optional[torch.FloatTensor] = None,
+ actions: Optional[torch.FloatTensor] = None,
+ rewards: Optional[torch.FloatTensor] = None,
+ returns_to_go: Optional[torch.FloatTensor] = None,
+ timesteps: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.FloatTensor], DecisionTransformerOutput]:
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import DecisionTransformerModel
+ >>> import torch
+
+ >>> model = DecisionTransformerModel.from_pretrained("edbeeching/decision-transformer-gym-hopper-medium")
+ >>> # evaluation
+ >>> model = model.to(device)
+ >>> model.eval()
+
+ >>> env = gym.make("Hopper-v3")
+ >>> state_dim = env.observation_space.shape[0]
+ >>> act_dim = env.action_space.shape[0]
+
+ >>> state = env.reset()
+ >>> states = torch.from_numpy(state).reshape(1, 1, state_dim).to(device=device, dtype=torch.float32)
+ >>> actions = torch.zeros((1, 1, act_dim), device=device, dtype=torch.float32)
+ >>> rewards = torch.zeros(1, 1, device=device, dtype=torch.float32)
+ >>> target_return = torch.tensor(TARGET_RETURN, dtype=torch.float32).reshape(1, 1)
+ >>> timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)
+ >>> attention_mask = torch.zeros(1, 1, device=device, dtype=torch.float32)
+
+ >>> # forward pass
+ >>> with torch.no_grad():
+ ... state_preds, action_preds, return_preds = model(
+ ... states=states,
+ ... actions=actions,
+ ... rewards=rewards,
+ ... returns_to_go=target_return,
+ ... timesteps=timesteps,
+ ... attention_mask=attention_mask,
+ ... return_dict=False,
+ ... )
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ batch_size, seq_length = states.shape[0], states.shape[1]
+
+ if attention_mask is None:
+ # attention mask for GPT: 1 if can be attended to, 0 if not
+ attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long)
+
+ # embed each modality with a different head
+ state_embeddings = self.embed_state(states)
+ action_embeddings = self.embed_action(actions)
+ returns_embeddings = self.embed_return(returns_to_go)
+ time_embeddings = self.embed_timestep(timesteps)
+
+ # time embeddings are treated similar to positional embeddings
+ state_embeddings = state_embeddings + time_embeddings
+ action_embeddings = action_embeddings + time_embeddings
+ returns_embeddings = returns_embeddings + time_embeddings
+
+ # this makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...)
+ # which works nice in an autoregressive sense since states predict actions
+ stacked_inputs = (
+ torch.stack((returns_embeddings, state_embeddings, action_embeddings), dim=1)
+ .permute(0, 2, 1, 3)
+ .reshape(batch_size, 3 * seq_length, self.hidden_size)
+ )
+ stacked_inputs = self.embed_ln(stacked_inputs)
+
+ # to make the attention mask fit the stacked inputs, have to stack it as well
+ stacked_attention_mask = (
+ torch.stack((attention_mask, attention_mask, attention_mask), dim=1)
+ .permute(0, 2, 1)
+ .reshape(batch_size, 3 * seq_length)
+ )
+ device = stacked_inputs.device
+ # we feed in the input embeddings (not word indices as in NLP) to the model
+ encoder_outputs = self.encoder(
+ inputs_embeds=stacked_inputs,
+ attention_mask=stacked_attention_mask,
+ position_ids=torch.zeros(stacked_attention_mask.shape, device=device, dtype=torch.long),
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ x = encoder_outputs[0]
+
+ # reshape x so that the second dimension corresponds to the original
+ # returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t
+ x = x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3)
+
+ # get predictions
+ return_preds = self.predict_return(x[:, 2]) # predict next return given state and action
+ state_preds = self.predict_state(x[:, 2]) # predict next state given state and action
+ action_preds = self.predict_action(x[:, 1]) # predict next action given state
+ if not return_dict:
+ return (state_preds, action_preds, return_preds)
+
+ return DecisionTransformerOutput(
+ last_hidden_state=encoder_outputs.last_hidden_state,
+ state_preds=state_preds,
+ action_preds=action_preds,
+ return_preds=return_preds,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+__all__ = [
+ "DecisionTransformerGPT2Model",
+ "DecisionTransformerGPT2PreTrainedModel",
+ "DecisionTransformerModel",
+ "DecisionTransformerPreTrainedModel",
+]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/focalnet/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/focalnet/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..71edfdf40cba6b49e3a7b564255f30f9613fc08c
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/focalnet/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/__init__.py b/.venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e477eb1d2cc2ff0e34d214ed99ad3f80afe8ab0a
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/__init__.py
@@ -0,0 +1,26 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .tokenization_gpt_sw3 import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..280f77ab5d6dfa4b749db348e25ed4acda973cd3
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/__pycache__/tokenization_gpt_sw3.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/__pycache__/tokenization_gpt_sw3.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bcd102a86b85d14b27f4eaa0d972d0a56f08da02
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/__pycache__/tokenization_gpt_sw3.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/tokenization_gpt_sw3.py b/.venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/tokenization_gpt_sw3.py
new file mode 100644
index 0000000000000000000000000000000000000000..7991988c74849ee8815ea315ebd3235f4c8d4012
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/tokenization_gpt_sw3.py
@@ -0,0 +1,299 @@
+"""The tokenizer used by the GPT-SW3 models."""
+
+import os
+import re
+import unicodedata
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import sentencepiece as spm
+
+from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import is_torch_available, logging
+
+
+if is_torch_available():
+ import torch
+
+
+logger = logging.get_logger(__name__)
+VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
+
+
+class GPTSw3Tokenizer(PreTrainedTokenizer):
+ """
+ Construct an GPTSw3 tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
+
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+ this superclass for more information regarding those methods.
+
+ Example usage:
+ ```python
+ >>> from transformers import GPTSw3Tokenizer
+
+ >>> tokenizer = GPTSw3Tokenizer.from_pretrained("AI-Sweden-Models/gpt-sw3-126m")
+ >>> tokenizer("Svenska är kul!")["input_ids"]
+ [1814, 377, 3617, 63504]
+ ```
+
+ Args:
+ vocab_file (`str`):
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
+ contains the vocabulary necessary to instantiate a tokenizer.
+ do_lower_case (`bool`, *optional*, defaults to `False`):
+ Whether or not to lowercase the input when tokenizing.
+ remove_space (`bool`, *optional*, defaults to `False`):
+ Whether or not to strip the text when tokenizing (removing excess spaces before and after the string).
+ keep_accents (`bool`, *optional*, defaults to `False`):
+ Whether or not to keep accents when tokenizing.
+ pad_token (`str`, *optional*):
+ The token used for padding, for example when batching sequences of different lengths. If not provided, will
+ default to '' or '' depending on model size.
+ unk_token (`str`, *optional*):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead. If not provided, will default to ''.
+ eos_token (`str`, *optional*):
+ The end of sequence token seen during pretraining. If not provided, will default to '<|endoftext|>'
+ bos_token (`str`, *optional*):
+ The beginning of sequence token that can be used for downstream task, was not seen during pretraining. If
+ not provided, will default to '' or '<|endoftext|>', depending on model size.
+ sp_model_kwargs (`dict`, *optional*):
+ Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
+ SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
+ to set:
+
+ - `enable_sampling`: Enable subword regularization.
+ - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+
+ - `nbest_size = {0,1}`: No sampling is performed.
+ - `nbest_size > 1`: samples from the nbest_size results.
+ - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
+ using forward-filtering-and-backward-sampling algorithm.
+
+ - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+ BPE-dropout.
+
+ Attributes:
+ sp_model (`SentencePieceProcessor`):
+ The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
+ whitespaces (`set`):
+ The whitespaces that are replaced in the whitespace normalization in preprocessing.
+ non_printing_characters_re (`Pattern`):
+ The compiled regular expression to remove non-printing characters in preprocessing.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ vocab_file,
+ do_lower_case=False,
+ remove_space=False,
+ keep_accents=False,
+ pad_token=None,
+ unk_token=None,
+ eos_token=None,
+ bos_token=None,
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
+ **kwargs,
+ ) -> None:
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+
+ name_or_path = kwargs.get("name_or_path")
+ if name_or_path is None:
+ logger.warning(
+ "name_or_path not provided, will work for all GPTSw3 models except gpt-sw3-7b,"
+ " you are testing the model, this can safely be ignored"
+ )
+ name_or_path = "None"
+
+ # Default definitions for our 2 tokenizer versions, with None-checks to enable proper testing
+ eos_token = "<|endoftext|>" if eos_token is None else eos_token
+ unk_token = "" if unk_token is None else unk_token
+ if "gpt-sw3-7b" in name_or_path:
+ pad_token = unk_token if pad_token is None else pad_token
+ bos_token = eos_token if bos_token is None else bos_token
+ else:
+ pad_token = "" if pad_token is None else pad_token
+ bos_token = "" if bos_token is None else bos_token
+
+ self.do_lower_case = do_lower_case
+ self.remove_space = remove_space
+ self.keep_accents = keep_accents
+ self.vocab_file = vocab_file
+
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+ self.sp_model.Load(vocab_file)
+
+ # Used for whitespace normalization in input texts
+ # fmt : off
+ self.whitespaces = {" ", " ", " ", " ", " ", " ", " ", " ", " ", " ", "", ""}
+ # fmt : on
+
+ # Regular expression to remove non-printing characters (e.g. some unicode control chars) in preprocessing
+ self.non_printing_characters_re = re.compile(
+ f"[{''.join(map(chr, list(range(0, 9)) + list(range(11, 32)) + list(range(127, 160)) + [160, 173, 8203]))}]"
+ )
+
+ super().__init__(
+ do_lower_case=do_lower_case,
+ remove_space=remove_space,
+ keep_accents=keep_accents,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ unk_token=unk_token,
+ pad_token=pad_token,
+ sp_model_kwargs=self.sp_model_kwargs,
+ **kwargs,
+ )
+
+ # Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.__getstate__
+ def __getstate__(self):
+ state = self.__dict__.copy()
+ state["sp_model"] = None
+ return state
+
+ # Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.__setstate__
+ def __setstate__(self, d):
+ self.__dict__ = d
+
+ # for backward compatibility
+ if not hasattr(self, "sp_model_kwargs"):
+ self.sp_model_kwargs = {}
+
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+ self.sp_model.Load(self.vocab_file)
+
+ @property
+ # Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.vocab_size
+ def vocab_size(self) -> int:
+ return len(self.sp_model)
+
+ def preprocess_text(self, text: str) -> str:
+ """
+ Returns the preprocessed text. This procedure is identical to what was used when training the tokenizer.
+ """
+
+ # Remove non-printing characters
+ text = self.non_printing_characters_re.sub("", text)
+
+ # Normalize whitespaces
+ text = "".join([char if char not in self.whitespaces else " " for char in text])
+
+ # NFC Unicode normalization
+ text = unicodedata.normalize("NFC", text)
+ return text
+
+ def _tokenize(self, text: str, **kwargs) -> List[str]:
+ text = self.preprocess_text(text)
+ return self.sp_model.encode(text, out_type=str)
+
+ def _convert_token_to_id(self, token: str) -> int:
+ """Converts a token (str) to an id (int) using the vocab."""
+ return self.sp_model.PieceToId(token)
+
+ def _convert_id_to_token(self, index: int) -> str:
+ """Converts an index (int) to a token (str) using the vocab."""
+ return self.sp_model.IdToPiece(index)
+
+ @staticmethod
+ def clean_up_tokenization(out_string: str) -> str:
+ """Returns the input string, this function is overridden to remove the default clean up."""
+ return out_string
+
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
+ """Converts a sequence of tokens (strings) to a single string. Special tokens remain intact."""
+ current_sub_tokens = []
+ out_string = ""
+ prev_is_special = False
+ for token in tokens:
+ # make sure that special tokens are not decoded using sentencepiece model
+ if token in self.all_special_tokens:
+ # TODO: Check if this is needed, as it ensures that decode(encode(doc)) != doc by adding extra whitespace in the decoded document
+ if not prev_is_special:
+ out_string += " "
+
+ out_string += self.sp_model.decode(current_sub_tokens) + token
+ prev_is_special = True
+ current_sub_tokens = []
+ else:
+ current_sub_tokens.append(token)
+ prev_is_special = False
+ out_string += self.sp_model.decode(current_sub_tokens)
+
+ return out_string
+
+ # Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.get_vocab
+ def get_vocab(self) -> Dict[str, int]:
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
+ # Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.save_vocabulary
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ out_vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+ elif not os.path.isfile(self.vocab_file):
+ with open(out_vocab_file, "wb") as fi:
+ content_spiece_model = self.sp_model.serialized_model_proto()
+ fi.write(content_spiece_model)
+
+ return (out_vocab_file,)
+
+ def encode_fast(
+ self, text: Union[str, List[str]], return_tensors: Union[str, bool] = False
+ ) -> Union[List[int], List[List[int]], "torch.Tensor"]:
+ """
+ Encodes a text or batch of texts to token ids using preprocessing and the raw SP tokenizer. This has reduced
+ functionality but is often much faster.
+
+ Does NOT handle special tokens correctly, these can manually be added as ids afterwards.
+
+ Does NOT support padding, these can manually be added as ids afterwards.
+
+ Use default HuggingFace tokenization methods for full functionality.
+
+ Args:
+ text (`str` or `List[str]`): One or several text(s) to convert to token ids.
+ return_tensors (`str` or `bool`): Returns PyTorch tensors if set to True or "pt"
+
+ Returns:
+ `List[int]`, `List[List[int]]`, or `torch.Tensor`: The encoded text(s) as token ids.
+ """
+
+ if isinstance(text, str):
+ text = self.preprocess_text(text)
+ token_ids = self.sp_model.encode(text)
+ else:
+ text = [self.preprocess_text(t) for t in text]
+ token_ids = self.sp_model.encode(text)
+
+ if return_tensors is True or return_tensors == "pt":
+ token_ids = torch.tensor(token_ids)
+
+ return token_ids
+
+ def decode_fast(self, token_ids: Union[int, List[int]]) -> str:
+ """
+ Encodes a text or batch of texts to token ids using preprocessing and the raw SP tokenizer. This has reduced
+ functionality but is often much faster.
+
+ Args:
+ token_ids (`int` or `List[int]`): Encoded token or text as token id(s).
+
+ Returns:
+ `str`: Decoded text
+ """
+
+ return self.sp_model.decode(token_ids)
+
+
+__all__ = ["GPTSw3Tokenizer"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/musicgen/__init__.py b/.venv/lib/python3.11/site-packages/transformers/models/musicgen/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..880274309cbab486a90df05548a0e4d3f2ea0925
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/musicgen/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_musicgen import *
+ from .modeling_musicgen import *
+ from .processing_musicgen import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/musicgen/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/musicgen/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..40499a73df1fde3b01bebb5aa499b7a3332d5ffa
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/musicgen/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/musicgen/__pycache__/configuration_musicgen.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/musicgen/__pycache__/configuration_musicgen.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d384b0c59b210bd63df321cd2bcc15a86176eb25
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/musicgen/__pycache__/configuration_musicgen.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/musicgen/__pycache__/processing_musicgen.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/musicgen/__pycache__/processing_musicgen.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b9b359b033b3ceb60d92cd733389a6a2c4cc5c7d
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/musicgen/__pycache__/processing_musicgen.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/musicgen/configuration_musicgen.py b/.venv/lib/python3.11/site-packages/transformers/models/musicgen/configuration_musicgen.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c38caf20dc4136e794e5762bf0b58847817dd19
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/musicgen/configuration_musicgen.py
@@ -0,0 +1,247 @@
+# coding=utf-8
+# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MusicGen model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ..auto.configuration_auto import AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class MusicgenDecoderConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of an [`MusicgenDecoder`]. It is used to instantiate a
+ MusicGen decoder according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the MusicGen
+ [facebook/musicgen-small](https://huggingface.co/facebook/musicgen-small) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 2048):
+ Vocabulary size of the MusicgenDecoder model. Defines the number of different tokens that can be
+ represented by the `inputs_ids` passed when calling [`MusicgenDecoder`].
+ hidden_size (`int`, *optional*, defaults to 1024):
+ Dimensionality of the layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 24):
+ Number of decoder layers.
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer block.
+ ffn_dim (`int`, *optional*, defaults to 4096):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer block.
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the decoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, text_encoder, and pooler.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ activation_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for activations inside the fully connected layer.
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with. Typically, set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ initializer_factor (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layerdrop (`float`, *optional*, defaults to 0.0):
+ The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
+ for more details.
+ scale_embedding (`bool`, *optional*, defaults to `False`):
+ Scale embeddings by diving by sqrt(hidden_size).
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether the model should return the last key/values attentions (not used by all models)
+ num_codebooks (`int`, *optional*, defaults to 4):
+ The number of parallel codebooks forwarded to the model.
+ tie_word_embeddings(`bool`, *optional*, defaults to `False`):
+ Whether input and output word embeddings should be tied.
+ audio_channels (`int`, *optional*, defaults to 1
+ Number of channels in the audio data. Either 1 for mono or 2 for stereo. Stereo models generate a separate
+ audio stream for the left/right output channels. Mono models generate a single audio stream output.
+ """
+
+ model_type = "musicgen_decoder"
+ base_config_key = "decoder_config"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=2048,
+ max_position_embeddings=2048,
+ num_hidden_layers=24,
+ ffn_dim=4096,
+ num_attention_heads=16,
+ layerdrop=0.0,
+ use_cache=True,
+ activation_function="gelu",
+ hidden_size=1024,
+ dropout=0.1,
+ attention_dropout=0.0,
+ activation_dropout=0.0,
+ initializer_factor=0.02,
+ scale_embedding=False,
+ num_codebooks=4,
+ audio_channels=1,
+ pad_token_id=2048,
+ bos_token_id=2048,
+ eos_token_id=None,
+ tie_word_embeddings=False,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.ffn_dim = ffn_dim
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.dropout = dropout
+ self.attention_dropout = attention_dropout
+ self.activation_dropout = activation_dropout
+ self.activation_function = activation_function
+ self.initializer_factor = initializer_factor
+ self.layerdrop = layerdrop
+ self.use_cache = use_cache
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
+ self.num_codebooks = num_codebooks
+
+ if audio_channels not in [1, 2]:
+ raise ValueError(f"Expected 1 (mono) or 2 (stereo) audio channels, got {audio_channels} channels.")
+ self.audio_channels = audio_channels
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+class MusicgenConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`MusicgenModel`]. It is used to instantiate a
+ MusicGen model according to the specified arguments, defining the text encoder, audio encoder and MusicGen decoder
+ configs.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ kwargs (*optional*):
+ Dictionary of keyword arguments. Notably:
+
+ - **text_encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
+ defines the text encoder config.
+ - **audio_encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
+ defines the audio encoder config.
+ - **decoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
+ the decoder config.
+
+ Example:
+
+ ```python
+ >>> from transformers import (
+ ... MusicgenConfig,
+ ... MusicgenDecoderConfig,
+ ... T5Config,
+ ... EncodecConfig,
+ ... MusicgenForConditionalGeneration,
+ ... )
+
+ >>> # Initializing text encoder, audio encoder, and decoder model configurations
+ >>> text_encoder_config = T5Config()
+ >>> audio_encoder_config = EncodecConfig()
+ >>> decoder_config = MusicgenDecoderConfig()
+
+ >>> configuration = MusicgenConfig.from_sub_models_config(
+ ... text_encoder_config, audio_encoder_config, decoder_config
+ ... )
+
+ >>> # Initializing a MusicgenForConditionalGeneration (with random weights) from the facebook/musicgen-small style configuration
+ >>> model = MusicgenForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ >>> config_text_encoder = model.config.text_encoder
+ >>> config_audio_encoder = model.config.audio_encoder
+ >>> config_decoder = model.config.decoder
+
+ >>> # Saving the model, including its configuration
+ >>> model.save_pretrained("musicgen-model")
+
+ >>> # loading model and config from pretrained folder
+ >>> musicgen_config = MusicgenConfig.from_pretrained("musicgen-model")
+ >>> model = MusicgenForConditionalGeneration.from_pretrained("musicgen-model", config=musicgen_config)
+ ```"""
+
+ model_type = "musicgen"
+ sub_configs = {
+ "text_encoder": AutoConfig,
+ "audio_encoder": AutoConfig,
+ "decoder": MusicgenDecoderConfig,
+ }
+ is_composition = True
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ if "text_encoder" not in kwargs or "audio_encoder" not in kwargs or "decoder" not in kwargs:
+ raise ValueError("Config has to be initialized with text_encoder, audio_encoder and decoder config")
+
+ text_encoder_config = kwargs.pop("text_encoder")
+ text_encoder_model_type = text_encoder_config.pop("model_type")
+
+ audio_encoder_config = kwargs.pop("audio_encoder")
+ audio_encoder_model_type = audio_encoder_config.pop("model_type")
+
+ decoder_config = kwargs.pop("decoder")
+
+ self.text_encoder = AutoConfig.for_model(text_encoder_model_type, **text_encoder_config)
+ self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config)
+ self.decoder = MusicgenDecoderConfig(**decoder_config)
+ self.is_encoder_decoder = True
+
+ @classmethod
+ def from_sub_models_config(
+ cls,
+ text_encoder_config: PretrainedConfig,
+ audio_encoder_config: PretrainedConfig,
+ decoder_config: MusicgenDecoderConfig,
+ **kwargs,
+ ):
+ r"""
+ Instantiate a [`MusicgenConfig`] (or a derived class) from text encoder, audio encoder and decoder
+ configurations.
+
+ Returns:
+ [`MusicgenConfig`]: An instance of a configuration object
+ """
+
+ return cls(
+ text_encoder=text_encoder_config.to_dict(),
+ audio_encoder=audio_encoder_config.to_dict(),
+ decoder=decoder_config.to_dict(),
+ **kwargs,
+ )
+
+ @property
+ # This is a property because you might want to change the codec model on the fly
+ def sampling_rate(self):
+ return self.audio_encoder.sampling_rate
+
+
+__all__ = ["MusicgenConfig", "MusicgenDecoderConfig"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/musicgen/modeling_musicgen.py b/.venv/lib/python3.11/site-packages/transformers/models/musicgen/modeling_musicgen.py
new file mode 100644
index 0000000000000000000000000000000000000000..cab950995a97558af23c1a5cf856be4f568d3f1c
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/musicgen/modeling_musicgen.py
@@ -0,0 +1,2756 @@
+# coding=utf-8
+# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Musicgen model."""
+
+import copy
+import inspect
+import math
+import random
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...generation import (
+ ClassifierFreeGuidanceLogitsProcessor,
+ GenerationConfig,
+ GenerationMixin,
+ GenerationMode,
+ LogitsProcessorList,
+ StoppingCriteriaList,
+)
+from ...modeling_attn_mask_utils import (
+ _prepare_4d_attention_mask,
+ _prepare_4d_attention_mask_for_sdpa,
+ _prepare_4d_causal_attention_mask,
+ _prepare_4d_causal_attention_mask_for_sdpa,
+)
+from ...modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPastAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ ModelOutput,
+ Seq2SeqLMOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_2_available,
+ is_flash_attn_greater_or_equal_2_10,
+ logging,
+ replace_return_docstrings,
+)
+from ..auto.configuration_auto import AutoConfig
+from ..auto.modeling_auto import AutoModel
+from .configuration_musicgen import MusicgenConfig, MusicgenDecoderConfig
+
+
+if is_flash_attn_2_available():
+ from ...modeling_flash_attention_utils import _flash_attention_forward
+
+if TYPE_CHECKING:
+ from ...generation.streamers import BaseStreamer
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "MusicgenConfig"
+_CHECKPOINT_FOR_DOC = "facebook/musicgen-small"
+
+
+@dataclass
+class MusicgenUnconditionalInput(ModelOutput):
+ """
+ Args:
+ encoder_outputs (`Tuple[torch.FloatTensor]` of length 1, with tensor shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the text encoder model.
+ attention_mask (`torch.LongTensor`) of shape `(batch_size, sequence_length)`, *optional*):
+ Encoder attention mask to avoid performing attention on padding token indices. Mask values selected in `[0,
+ 1]`: 1 for tokens that are **not masked**, 0 for tokens that are **masked**.
+ guidance_scale (`float`, *optional*):
+ Guidance scale for classifier free guidance, setting the balance between the conditional logits (predicted
+ from the prompts) and the unconditional logits (predicted without prompts).
+ """
+
+ encoder_outputs: Tuple[torch.FloatTensor] = None
+ attention_mask: torch.LongTensor = None
+ guidance_scale: float = None
+
+
+def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
+ """
+ Shift input ids one token to the right.
+ """
+ # transpose to get (bsz, num_codebooks, seq_len)
+ input_ids = input_ids.transpose(1, 2)
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
+ if decoder_start_token_id is None:
+ raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
+ shifted_input_ids[..., 0] = decoder_start_token_id
+
+ if pad_token_id is None:
+ raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
+ # replace possible -100 values in labels by `pad_token_id`
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+ return shifted_input_ids
+
+
+class MusicgenSinusoidalPositionalEmbedding(nn.Module):
+ """This module produces sinusoidal positional embeddings of any length."""
+
+ def __init__(self, num_positions: int, embedding_dim: int):
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.make_weights(num_positions, embedding_dim)
+
+ def make_weights(self, num_embeddings: int, embedding_dim: int):
+ emb_weights = self.get_embedding(num_embeddings, embedding_dim)
+ if hasattr(self, "weights"):
+ # in forward put the weights on the correct dtype and device of the param
+ emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
+
+ self.weights = nn.Parameter(emb_weights)
+ self.weights.requires_grad = False
+ self.weights.detach_()
+
+ @staticmethod
+ def get_embedding(num_embeddings: int, embedding_dim: int):
+ """
+ Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the
+ description in Section 3.5 of "Attention Is All You Need".
+ """
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
+ emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
+ emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=1).view(num_embeddings, -1)
+ if embedding_dim % 2 == 1:
+ # zero pad
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
+ return emb.to(torch.get_default_dtype())
+
+ @torch.no_grad()
+ def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+ bsz, codebooks, seq_len = input_ids.size()
+ # Create the position ids from the input token ids.
+ position_ids = (torch.arange(seq_len) + past_key_values_length).to(input_ids.device)
+ # expand embeddings if needed
+ if seq_len > self.weights.size(0):
+ self.make_weights(seq_len + self.offset, self.embedding_dim)
+ return self.weights.index_select(0, position_ids.view(-1)).detach()
+
+
+# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Musicgen
+class MusicgenAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
+ is_decoder: bool = False,
+ bias: bool = True,
+ is_causal: bool = False,
+ config: Optional[MusicgenConfig] = None,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ self.config = config
+
+ if (self.head_dim * num_heads) != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {num_heads})."
+ )
+ self.scaling = self.head_dim**-0.5
+ self.is_decoder = is_decoder
+ self.is_causal = is_causal
+
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, tgt_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self.q_proj(hidden_states) * self.scaling
+ # get key, value proj
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
+ # the provided `key_value_states` to support prefix tuning
+ if (
+ is_cross_attention
+ and past_key_value is not None
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
+ ):
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+ else:
+ # self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+ key_states = key_states.reshape(*proj_shape)
+ value_states = value_states.reshape(*proj_shape)
+
+ src_len = key_states.size(1)
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if layer_head_mask is not None:
+ if layer_head_mask.size() != (self.num_heads,):
+ raise ValueError(
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
+ )
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if output_attentions:
+ # this operation is a bit awkward, but it's required to
+ # make sure that attn_weights keeps its gradient.
+ # In order to do so, attn_weights have to be reshaped
+ # twice and have to be reused in the following
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+ else:
+ attn_weights_reshaped = None
+
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ attn_output = torch.bmm(attn_probs, value_states)
+
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+ attn_output = attn_output.transpose(1, 2)
+
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+ # partitioned across GPUs when using tensor-parallelism.
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights_reshaped, past_key_value
+
+
+# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Musicgen
+class MusicgenFlashAttention2(MusicgenAttention):
+ """
+ Musicgen flash attention module. This module inherits from `MusicgenAttention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # MusicgenFlashAttention2 attention does not support output_attentions
+ if output_attentions:
+ raise ValueError("MusicgenFlashAttention2 attention does not support output_attentions")
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, q_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
+ # get key, value proj
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
+ # the provided `key_value_states` to support prefix tuning
+ if (
+ is_cross_attention
+ and past_key_value is not None
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
+ ):
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0].transpose(1, 2)
+ value_states = past_key_value[1].transpose(1, 2)
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
+ value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
+ else:
+ # self_attention
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (LlamaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ dropout=self.dropout if self.training else 0.0,
+ is_causal=self.is_causal,
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+ attn_output = self.out_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class MusicgenSdpaAttention(MusicgenAttention):
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+ if output_attentions or layer_head_mask is not None:
+ # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "MusicgenModel is using MusicgenSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
+ ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states,
+ key_value_states=key_value_states,
+ past_key_value=past_key_value,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+
+ if (
+ attention_mask is not None
+ and (attention_mask.mean(dim=[1, 2, 3]) <= torch.finfo(attention_mask.dtype).min).any()
+ ):
+ logger.warning_once(
+ '`torch.nn.functional.scaled_dot_product_attention` does not support having an empty attention mask. Falling back to the manual attention implementation. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ "Note that this probably happens because `guidance_scale>1` or because you used `get_unconditional_inputs`. See https://github.com/huggingface/transformers/issues/31189 for more information."
+ )
+ return super().forward(
+ hidden_states,
+ key_value_states=key_value_states,
+ past_key_value=past_key_value,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, tgt_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self.q_proj(hidden_states)
+ # get key, value proj
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
+ # the provided `key_value_states` to support prefix tuning
+ if (
+ is_cross_attention
+ and past_key_value is not None
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
+ ):
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+ else:
+ # self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ query_states = self._shape(query_states, tgt_len, bsz)
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
+ is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
+
+ # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
+ # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=attention_mask,
+ dropout_p=self.dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+ # partitioned across GPUs when using tensor-parallelism.
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, None, past_key_value
+
+
+MUSICGEN_ATTENTION_CLASSES = {
+ "eager": MusicgenAttention,
+ "sdpa": MusicgenSdpaAttention,
+ "flash_attention_2": MusicgenFlashAttention2,
+}
+
+
+class MusicgenDecoderLayer(nn.Module):
+ def __init__(self, config: MusicgenDecoderConfig):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+
+ self.self_attn = MUSICGEN_ATTENTION_CLASSES[config._attn_implementation](
+ embed_dim=self.embed_dim,
+ num_heads=config.num_attention_heads,
+ dropout=config.attention_dropout,
+ is_decoder=True,
+ bias=False,
+ is_causal=True,
+ config=config,
+ )
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.activation_function]
+ self.activation_dropout = config.activation_dropout
+
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.encoder_attn = MUSICGEN_ATTENTION_CLASSES[config._attn_implementation](
+ self.embed_dim,
+ config.num_attention_heads,
+ dropout=config.attention_dropout,
+ is_decoder=True,
+ bias=False,
+ config=config,
+ )
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=False)
+ self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=False)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = True,
+ ) -> torch.Tensor:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ encoder_hidden_states (`torch.FloatTensor`):
+ cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
+ encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+ `(encoder_attention_heads,)`.
+ cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
+ size `(decoder_attention_heads,)`.
+ past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Self Attention
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+ # add present self-attn cache to positions 1,2 of present_key_value tuple
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ past_key_value=self_attn_past_key_value,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ # Cross-Attention Block
+ cross_attn_present_key_value = None
+ cross_attn_weights = None
+ if encoder_hidden_states is not None:
+ residual = hidden_states
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+ # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+ hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
+ hidden_states=hidden_states,
+ key_value_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ layer_head_mask=cross_attn_layer_head_mask,
+ past_key_value=cross_attn_past_key_value,
+ output_attentions=output_attentions,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ # add cross-attn to positions 3,4 of present_key_value tuple
+ present_key_value = present_key_value + cross_attn_present_key_value
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights, cross_attn_weights)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+class MusicgenPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = MusicgenDecoderConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["MusicgenDecoderLayer", "MusicgenAttention"]
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_factor
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+MUSICGEN_START_DOCSTRING = r"""
+
+ The Musicgen model was proposed in [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by
+ Jade Copet, Felix Kreuk, Itai Gat, Tal Remez, David Kant, Gabriel Synnaeve, Yossi Adi, Alexandre Défossez. It is an
+ encoder decoder transformer trained on the task of conditional music generation
+
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`MusicgenConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+MUSICGEN_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)`, *optional*):
+ Indices of decoder input sequence tokens in the vocabulary, corresponding to the sequence of audio codes.
+
+ Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes,
+ such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+
+
+ The `decoder_input_ids` will automatically be converted from shape `(batch_size * num_codebooks,
+ target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If
+ you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of
+ frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks,
+ target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as
+ `decoder_input_ids`.
+
+
+
+ decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
+ 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
+ representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
+ input (see `past_key_values`). This is useful if you want more control over how to convert
+ `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
+
+ If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
+ of `inputs_embeds`.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+MUSICGEN_DECODER_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary, corresponding to the sequence of audio codes.
+
+ Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes,
+ such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+
+
+
+ The `input_ids` will automatically be converted from shape `(batch_size * num_codebooks,
+ target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If
+ you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of
+ frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks,
+ target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as
+ `input_ids`.
+
+
+
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of
+ the decoder.
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
+ Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
+ selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
+ cross-attention on hidden heads. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+class MusicgenDecoder(MusicgenPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MusicgenDecoderLayer`]
+ """
+
+ def __init__(self, config: MusicgenDecoderConfig):
+ super().__init__(config)
+ self.dropout = config.dropout
+ self.layerdrop = config.layerdrop
+ self.max_target_positions = config.max_position_embeddings
+ self.d_model = config.hidden_size
+ self.num_codebooks = config.num_codebooks
+ self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
+
+ embed_dim = config.vocab_size + 1
+ self.embed_tokens = nn.ModuleList(
+ [nn.Embedding(embed_dim, config.hidden_size) for _ in range(config.num_codebooks)]
+ )
+
+ self.embed_positions = MusicgenSinusoidalPositionalEmbedding(
+ config.max_position_embeddings,
+ config.hidden_size,
+ )
+
+ self.layers = nn.ModuleList([MusicgenDecoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.layer_norm = nn.LayerNorm(config.hidden_size)
+ self.attn_implementation = config._attn_implementation
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ @add_start_docstrings_to_model_forward(MUSICGEN_DECODER_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ # (bsz * codebooks, seq_len) -> (bsz, codebooks, seq_len)
+ input = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1])
+ bsz, num_codebooks, seq_len = input.shape
+ input_shape = (bsz, seq_len)
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ input = inputs_embeds[:, :, -1:]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ # past_key_values_length
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+ if inputs_embeds is None:
+ inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)])
+
+ if self.attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ elif self.attn_implementation == "sdpa" and head_mask is None and not output_attentions:
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+ attention_mask,
+ input_shape,
+ inputs_embeds,
+ past_key_values_length,
+ )
+ else:
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
+ )
+
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ if self.attn_implementation == "flash_attention_2":
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+ elif self.attn_implementation == "sdpa" and cross_attn_head_mask is None and not output_attentions:
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+ encoder_attention_mask,
+ inputs_embeds.dtype,
+ tgt_len=input_shape[-1],
+ )
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask(
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+
+ # embed positions
+ positions = self.embed_positions(input, past_key_values_length)
+
+ hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
+
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+ next_decoder_cache = () if use_cache else None
+
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+ for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+ if attn_mask is not None:
+ if attn_mask.size()[0] != len(self.layers):
+ raise ValueError(
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {attn_mask.size()[0]}."
+ )
+ for idx, decoder_layer in enumerate(self.layers):
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ dropout_probability = random.uniform(0, 1)
+ if self.training and (dropout_probability < self.layerdrop):
+ continue
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.forward,
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ head_mask[idx] if head_mask is not None else None,
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
+ None,
+ output_attentions,
+ use_cache,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ cross_attn_layer_head_mask=(
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
+ ),
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if encoder_hidden_states is not None:
+ all_cross_attentions += (layer_outputs[2],)
+
+ hidden_states = self.layer_norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+@add_start_docstrings(
+ "The bare Musicgen decoder model outputting raw hidden-states without any specific head on top.",
+ MUSICGEN_START_DOCSTRING,
+)
+class MusicgenModel(MusicgenPreTrainedModel):
+ def __init__(self, config: MusicgenDecoderConfig):
+ super().__init__(config)
+ self.decoder = MusicgenDecoder(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.decoder.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.decoder.embed_tokens = value
+
+ def get_decoder(self):
+ return self.decoder
+
+ @add_start_docstrings_to_model_forward(MUSICGEN_DECODER_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
+ decoder_outputs = self.decoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ head_mask=head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if not return_dict:
+ return decoder_outputs
+
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ past_key_values=decoder_outputs.past_key_values,
+ hidden_states=decoder_outputs.hidden_states,
+ attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ )
+
+
+@add_start_docstrings(
+ "The MusicGen decoder model with a language modelling head on top.",
+ MUSICGEN_START_DOCSTRING,
+)
+class MusicgenForCausalLM(MusicgenPreTrainedModel, GenerationMixin):
+ def __init__(self, config: MusicgenDecoderConfig):
+ super().__init__(config)
+
+ self.model = MusicgenModel(config)
+
+ self.num_codebooks = config.num_codebooks
+ self.lm_heads = nn.ModuleList(
+ [nn.Linear(config.hidden_size, config.vocab_size, bias=False) for _ in range(config.num_codebooks)]
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.decoder.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.decoder.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_heads
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_heads = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model.decoder = decoder
+
+ def get_decoder(self):
+ return self.model.decoder
+
+ @add_start_docstrings_to_model_forward(MUSICGEN_DECODER_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+ Returns:
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (labels is not None) and (input_ids is None and inputs_embeds is None):
+ input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.bos_token_id)
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ head_mask=head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+
+ lm_logits = torch.stack([head(hidden_states) for head in self.lm_heads], dim=1)
+
+ loss = None
+ if labels is not None:
+ # since encoder hidden states have been concatenated to the decoder hidden states,
+ # we take the last timestamps corresponding to labels
+ logits = lm_logits[:, :, -labels.shape[1] :]
+
+ loss_fct = CrossEntropyLoss()
+ loss = torch.zeros([], device=self.device)
+
+ # per codebook cross-entropy
+ # -100 labels are ignored
+ labels = labels.masked_fill(labels == self.config.pad_token_id, -100)
+
+ # per codebook cross-entropy
+ # ref: https://github.com/facebookresearch/audiocraft/blob/69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85/audiocraft/solvers/musicgen.py#L242-L243
+ for codebook in range(self.config.num_codebooks):
+ codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1])
+ codebook_labels = labels[..., codebook].contiguous().view(-1)
+ loss += loss_fct(codebook_logits, codebook_labels)
+
+ loss = loss / self.config.num_codebooks
+
+ # (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size)
+ lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:])
+
+ if not return_dict:
+ output = (lm_logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=loss,
+ logits=lm_logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ attention_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ head_mask=None,
+ cross_attn_head_mask=None,
+ past_key_values=None,
+ use_cache=True,
+ delay_pattern_mask=None,
+ guidance_scale=None,
+ **kwargs,
+ ):
+ # Overwritten -- MusicGen has custom processing
+ if delay_pattern_mask is None:
+ input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
+ input_ids,
+ pad_token_id=self.generation_config.pad_token_id,
+ max_length=self.generation_config.max_length,
+ )
+
+ # apply the delay pattern mask
+ input_ids = self.apply_delay_pattern_mask(input_ids, delay_pattern_mask)
+
+ if guidance_scale is not None and guidance_scale > 1:
+ # for classifier free guidance we need to replicate the decoder args across the batch dim (we'll split these
+ # before sampling)
+ input_ids = input_ids.repeat((2, 1))
+ if attention_mask is not None:
+ attention_mask = attention_mask.repeat((2, 1))
+
+ if past_key_values is not None:
+ input_ids = input_ids[:, -1:]
+
+ return {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "encoder_hidden_states": encoder_hidden_states,
+ "encoder_attention_mask": encoder_attention_mask,
+ "head_mask": head_mask,
+ "cross_attn_head_mask": cross_attn_head_mask,
+ "past_key_values": past_key_values,
+ "use_cache": use_cache,
+ }
+
+ def build_delay_pattern_mask(self, input_ids: torch.LongTensor, pad_token_id: int, max_length: int = None):
+ """Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by
+ one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
+ are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks,
+ seq_len)`:
+ - [P, -1, -1, -1, -1, P, P, P]
+ - [P, P, -1, -1, -1, -1, P, P]
+ - [P, P, P, -1, -1, -1, -1, P]
+ - [P, P, P, P, -1, -1, -1, -1]
+ where P is the special padding token id and -1 indicates that the token is valid for prediction. If we include
+ a prompt (decoder input ids), the -1 positions indicate where new tokens should be predicted. Otherwise, the
+ mask is set to the value in the prompt:
+ - [P, a, b, -1, -1, P, P, P]
+ - [P, P, c, d, -1, -1, P, P]
+ - [P, P, P, e, f, -1, -1, P]
+ - [P, P, P, P, g, h, -1, -1]
+ where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only override the -1
+ tokens in our prediction.
+ """
+ # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
+ input_ids = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1])
+ bsz, num_codebooks, seq_len = input_ids.shape
+
+ max_length = max_length if max_length is not None else self.generation_config.max_length
+ input_ids_shifted = (
+ torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1
+ )
+
+ channel_codebooks = num_codebooks // 2 if self.config.audio_channels == 2 else num_codebooks
+ # we only apply the mask if we have a large enough seq len - otherwise we return as is
+ if max_length < 2 * channel_codebooks - 1:
+ return input_ids.reshape(bsz * num_codebooks, -1), input_ids_shifted.reshape(bsz * num_codebooks, -1)
+
+ # fill the shifted ids with the prompt entries, offset by the codebook idx
+ for codebook in range(channel_codebooks):
+ if self.config.audio_channels == 1:
+ # mono channel - loop over the codebooks one-by-one
+ input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook]
+ else:
+ # left/right channels are interleaved in the generated codebooks, so handle one then the other
+ input_ids_shifted[:, 2 * codebook, codebook : seq_len + codebook] = input_ids[:, 2 * codebook]
+ input_ids_shifted[:, 2 * codebook + 1, codebook : seq_len + codebook] = input_ids[:, 2 * codebook + 1]
+
+ # construct a pattern mask that indicates the positions of padding tokens for each codebook
+ # first fill the upper triangular part (the EOS padding)
+ delay_pattern = torch.triu(
+ torch.ones((channel_codebooks, max_length), dtype=torch.bool), diagonal=max_length - channel_codebooks + 1
+ )
+ # then fill the lower triangular part (the BOS padding)
+ delay_pattern = delay_pattern + torch.tril(torch.ones((channel_codebooks, max_length), dtype=torch.bool))
+
+ if self.config.audio_channels == 2:
+ # for left/right channel we need to duplicate every row of the pattern mask in an interleaved fashion
+ delay_pattern = delay_pattern.repeat_interleave(2, dim=0)
+
+ mask = ~delay_pattern.to(input_ids.device)
+ input_ids = mask * input_ids_shifted + ~mask * pad_token_id
+
+ # find the first position to start generating - this is the first place we have the -1 token
+ # and will always be in the first codebook (since it has no codebook offset)
+ first_codebook_ids = input_ids[:, 0, :]
+ start_ids = (first_codebook_ids == -1).nonzero()[:, 1]
+ if len(start_ids) > 0:
+ first_start_id = min(start_ids)
+ else:
+ # we have no tokens that need to be filled - return entire matrix of input ids
+ first_start_id = seq_len
+
+ # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
+ pattern_mask = input_ids.reshape(bsz * num_codebooks, -1)
+ input_ids = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1)
+ return input_ids, pattern_mask
+
+ @staticmethod
+ def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
+ """Apply a delay pattern mask to the decoder input ids, only preserving predictions where
+ the mask is set to -1, and otherwise setting to the value detailed in the mask."""
+ seq_len = input_ids.shape[-1]
+ decoder_pad_token_mask = decoder_pad_token_mask[..., :seq_len]
+ input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask)
+ return input_ids
+
+ @torch.no_grad()
+ def generate(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ generation_config: Optional[GenerationConfig] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ synced_gpus: Optional[bool] = None,
+ streamer: Optional["BaseStreamer"] = None,
+ **kwargs,
+ ):
+ """
+
+ Generates sequences of token ids for models with a language modeling head.
+
+
+
+ Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
+ model's default generation configuration. You can override any `generation_config` by passing the corresponding
+ parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
+
+ For an overview of generation strategies and code examples, check out the [following
+ guide](./generation_strategies).
+
+
+
+ Parameters:
+ inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
+ The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
+ method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
+ should be in the format `input_ids`. For encoder-decoder models *inputs* can represent any of
+ `input_ids`, `input_values`, `input_features`, or `pixel_values`.
+ generation_config (`~generation.GenerationConfig`, *optional*):
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+ passed to generate matching the attributes of `generation_config` will override them. If
+ `generation_config` is not provided, the default will be used, which had the following loading
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+ default values, whose documentation should be checked to parameterize generation.
+ logits_processor (`LogitsProcessorList`, *optional*):
+ Custom logits processors that complement the default logits processors built from arguments and
+ generation config. If a logit processor is passed that is already created with the arguments or a
+ generation config an error is thrown. This feature is intended for advanced users.
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
+ Custom stopping criteria that complement the default stopping criteria built from arguments and a
+ generation config. If a stopping criteria is passed that is already created with the arguments or a
+ generation config an error is thrown. This feature is intended for advanced users.
+ synced_gpus (`bool`, *optional*, defaults to `False`):
+ Whether to continue running the while loop until max_length (needed to avoid deadlocking with
+ `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
+ streamer (`BaseStreamer`, *optional*):
+ Streamer object that will be used to stream the generated sequences. Generated tokens are passed
+ through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+ kwargs (`Dict[str, Any]`, *optional*):
+ Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
+ forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
+ specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
+
+ Return:
+ [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
+ or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
+
+ If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
+ [`~utils.ModelOutput`] types are:
+
+ - [`~generation.GenerateDecoderOnlyOutput`],
+ - [`~generation.GenerateBeamDecoderOnlyOutput`]
+
+ If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
+ [`~utils.ModelOutput`] types are:
+
+ - [`~generation.GenerateEncoderDecoderOutput`],
+ - [`~generation.GenerateBeamEncoderDecoderOutput`]
+ """
+ # 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects
+ if generation_config is None:
+ generation_config = self.generation_config
+
+ generation_config = copy.deepcopy(generation_config)
+ model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
+ generation_config.validate()
+ self._validate_model_kwargs(model_kwargs.copy())
+
+ # 2. Set generation parameters if not already defined
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+ requires_attention_mask = "encoder_outputs" not in model_kwargs
+ kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+
+ # 3. Define model inputs`
+ input_ids, model_input_name, model_kwargs = self._prepare_model_inputs(
+ inputs, generation_config.bos_token_id, model_kwargs
+ )
+ batch_size = input_ids.shape[0] // self.num_codebooks
+ self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device)
+
+ # 4. Define other model kwargs
+ model_kwargs["use_cache"] = generation_config.use_cache
+ model_kwargs["guidance_scale"] = generation_config.guidance_scale
+
+ if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
+ model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
+ input_ids, generation_config, model_kwargs
+ )
+
+ # 5. Prepare `max_length` depending on other stopping criteria.
+ input_ids_length = input_ids.shape[-1]
+ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+ has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
+ generation_config = self._prepare_generated_length(
+ generation_config=generation_config,
+ has_default_max_length=has_default_max_length,
+ has_default_min_length=has_default_min_length,
+ model_input_name=model_input_name,
+ inputs_tensor=input_ids,
+ input_ids_length=input_ids_length,
+ )
+
+ # 6. Prepare `input_ids` which will be used for auto-regressive generation
+ # Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen)
+ input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
+ input_ids,
+ pad_token_id=generation_config._decoder_start_token_tensor,
+ max_length=generation_config.max_length,
+ )
+
+ if streamer is not None:
+ streamer.put(input_ids.cpu())
+
+ # stash the delay mask so that we don't have to recompute it in each forward pass
+ model_kwargs["delay_pattern_mask"] = delay_pattern_mask
+
+ # 7. determine generation mode
+ generation_mode = generation_config.get_generation_mode()
+
+ # 8. prepare batched CFG externally (to enable coexistance with the unbatched CFG)
+ if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
+ logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
+ generation_config.guidance_scale = None
+
+ # 9. prepare distribution pre_processing samplers
+ logits_processor = self._get_logits_processor(
+ generation_config=generation_config,
+ input_ids_seq_length=input_ids_length,
+ encoder_input_ids=input_ids,
+ prefix_allowed_tokens_fn=None,
+ logits_processor=logits_processor,
+ device=input_ids.device,
+ )
+
+ # 10. prepare stopping criteria
+ stopping_criteria = self._get_stopping_criteria(
+ generation_config=generation_config, stopping_criteria=stopping_criteria
+ )
+
+ if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
+ # expand input_ids with `num_return_sequences` additional sequences per batch
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
+ input_ids=input_ids,
+ expand_size=generation_config.num_return_sequences,
+ **model_kwargs,
+ )
+
+ # 11. run sample
+ outputs = self._sample(
+ input_ids,
+ logits_processor=logits_processor,
+ stopping_criteria=stopping_criteria,
+ generation_config=generation_config,
+ synced_gpus=synced_gpus,
+ streamer=streamer,
+ **model_kwargs,
+ )
+
+ else:
+ raise ValueError(
+ "Got incompatible mode for generation, should be one of greedy or sampling. "
+ "Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`."
+ )
+
+ if generation_config.return_dict_in_generate:
+ output_ids = outputs.sequences
+ else:
+ output_ids = outputs
+
+ # apply the pattern mask to the final ids
+ output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])
+
+ # revert the pattern delay mask by filtering the pad token id
+ output_ids = output_ids[output_ids != generation_config._pad_token_tensor].reshape(
+ batch_size, self.num_codebooks, -1
+ )
+
+ if generation_config.return_dict_in_generate:
+ outputs.sequences = output_ids
+ return outputs
+ else:
+ return output_ids
+
+
+@add_start_docstrings(
+ "The composite MusicGen model with a text encoder, audio encoder and Musicgen decoder, "
+ "for music generation tasks with one or both of text and audio prompts.",
+ MUSICGEN_START_DOCSTRING,
+)
+class MusicgenForConditionalGeneration(PreTrainedModel, GenerationMixin):
+ config_class = MusicgenConfig
+ base_model_prefix = "encoder_decoder"
+ main_input_name = "input_ids"
+ supports_gradient_checkpointing = True
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+
+ def __init__(
+ self,
+ config: Optional[MusicgenConfig] = None,
+ text_encoder: Optional[PreTrainedModel] = None,
+ audio_encoder: Optional[PreTrainedModel] = None,
+ decoder: Optional[MusicgenForCausalLM] = None,
+ ):
+ if config is None and (text_encoder is None or audio_encoder is None or decoder is None):
+ raise ValueError(
+ "Either a configuration has to be provided, or all three of text encoder, audio encoder and MusicGen decoder."
+ )
+ if config is None:
+ config = MusicgenConfig.from_sub_models_config(text_encoder.config, audio_encoder.config, decoder.config)
+ else:
+ if not isinstance(config, self.config_class):
+ raise ValueError(f"Config: {config} has to be of type {self.config_class}")
+
+ if config.decoder.cross_attention_hidden_size is not None:
+ if config.decoder.cross_attention_hidden_size != config.text_encoder.hidden_size:
+ raise ValueError(
+ "If `cross_attention_hidden_size` is specified in the MusicGen decoder's configuration, it has to be equal"
+ f" to the text encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
+ f" `config.decoder.cross_attention_hidden_size` and {config.text_encoder.hidden_size} for"
+ " `config.text_encoder.hidden_size`."
+ )
+
+ # initialize with config
+ super().__init__(config)
+
+ if text_encoder is None:
+ from ..auto.modeling_auto import AutoModelForTextEncoding
+
+ text_encoder = AutoModelForTextEncoding.from_config(config.text_encoder)
+
+ if audio_encoder is None:
+ from ..auto.modeling_auto import AutoModel
+
+ audio_encoder = AutoModel.from_config(config.audio_encoder)
+
+ if decoder is None:
+ decoder = MusicgenForCausalLM._from_config(config.decoder)
+
+ self.text_encoder = text_encoder
+ self.audio_encoder = audio_encoder
+ self.decoder = decoder
+
+ if self.text_encoder.config.to_dict() != self.config.text_encoder.to_dict():
+ logger.warning(
+ f"Config of the text_encoder: {self.text_encoder.__class__} is overwritten by shared text_encoder config:"
+ f" {self.config.text_encoder}"
+ )
+ if self.audio_encoder.config.to_dict() != self.config.audio_encoder.to_dict():
+ logger.warning(
+ f"Config of the audio_encoder: {self.audio_encoder.__class__} is overwritten by shared audio_encoder config:"
+ f" {self.config.audio_encoder}"
+ )
+ if self.decoder.config.to_dict() != self.config.decoder.to_dict():
+ logger.warning(
+ f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
+ f" {self.config.decoder}"
+ )
+
+ # make sure that the individual model's config refers to the shared config
+ # so that the updates to the config will be synced
+ self.config.text_encoder._attn_implementation = self.text_encoder.config._attn_implementation
+ self.config.audio_encoder._attn_implementation = self.audio_encoder.config._attn_implementation
+ self.config.decoder._attn_implementation = self.decoder.config._attn_implementation
+ self.text_encoder.config = self.config.text_encoder
+ self.audio_encoder.config = self.config.audio_encoder
+ self.decoder.config = self.config.decoder
+
+ # text encoder outputs might need to be projected to different dimension for decoder
+ if (
+ self.text_encoder.config.hidden_size != self.decoder.config.hidden_size
+ and self.decoder.config.cross_attention_hidden_size is None
+ ):
+ self.enc_to_dec_proj = nn.Linear(self.text_encoder.config.hidden_size, self.decoder.config.hidden_size)
+
+ if self.text_encoder.get_output_embeddings() is not None:
+ raise ValueError(
+ f"The encoder {self.text_encoder} should not have a LM Head. Please use a model without and LM Head"
+ )
+
+ decoder_signature = set(inspect.signature(self.decoder.forward).parameters.keys())
+ if "encoder_hidden_states" not in decoder_signature:
+ raise ValueError(
+ "The selected decoder is not prepared for the encoder hidden states to be passed. Please see the "
+ "following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350"
+ )
+
+ # tie text encoder, decoder weights if config set accordingly
+ self.tie_weights()
+
+ def tie_weights(self):
+ # tie text encoder & decoder if needed
+ if self.config.tie_encoder_decoder:
+ # tie text encoder and decoder base model
+ decoder_base_model_prefix = self.decoder.base_model_prefix
+ tied_weights = self._tie_encoder_decoder_weights(
+ self.text_encoder,
+ self.decoder._modules[decoder_base_model_prefix],
+ self.decoder.base_model_prefix,
+ "text_encoder",
+ )
+ # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class
+ # attributed not an instance member, therefore modifying it will modify the entire class
+ # Leading to issues on subsequent calls by different tests or subsequent calls.
+ self._dynamic_tied_weights_keys = tied_weights
+
+ def get_audio_encoder(self):
+ return self.audio_encoder
+
+ def get_text_encoder(self):
+ return self.text_encoder
+
+ def get_encoder(self):
+ # get the text encoder to compute the encoder hidden-states for generation
+ return self.get_text_encoder()
+
+ def get_decoder(self):
+ return self.decoder
+
+ def get_input_embeddings(self):
+ return self.text_encoder.get_input_embeddings()
+
+ def get_output_embeddings(self):
+ return self.decoder.get_output_embeddings()
+
+ def set_output_embeddings(self, new_embeddings):
+ return self.decoder.set_output_embeddings(new_embeddings)
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+ r"""
+ Example:
+
+ ```python
+ >>> from transformers import MusicgenForConditionalGeneration
+
+ >>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
+ ```"""
+
+ # At the moment fast initialization is not supported for composite models
+ if kwargs.get("_fast_init", False):
+ logger.warning(
+ "Fast initialization is currently not supported for MusicgenForConditionalGeneration. "
+ "Falling back to slow initialization..."
+ )
+ kwargs["_fast_init"] = False
+
+ return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+
+ @classmethod
+ def from_sub_models_pretrained(
+ cls,
+ text_encoder_pretrained_model_name_or_path: str = None,
+ audio_encoder_pretrained_model_name_or_path: str = None,
+ decoder_pretrained_model_name_or_path: str = None,
+ *model_args,
+ **kwargs,
+ ) -> PreTrainedModel:
+ r"""
+ Instantiate a text encoder, an audio encoder, and a MusicGen decoder from one, two or three base classes of the
+ library from pretrained model checkpoints.
+
+
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
+ the model, you need to first set it back in training mode with `model.train()`.
+
+ Params:
+ text_encoder_pretrained_model_name_or_path (`str`, *optional*):
+ Information necessary to initiate the text encoder. Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ - A path to a *directory* containing model weights saved using
+ [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+
+ audio_encoder_pretrained_model_name_or_path (`str`, *optional*):
+ Information necessary to initiate the audio encoder. Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ - A path to a *directory* containing model weights saved using
+ [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+
+ decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
+ Information necessary to initiate the decoder. Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ - A path to a *directory* containing model weights saved using
+ [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+
+ model_args (remaining positional arguments, *optional*):
+ All remaining positional arguments will be passed to the underlying model's `__init__` method.
+
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
+ `output_attentions=True`).
+
+ - To update the text encoder configuration, use the prefix *text_encoder_* for each configuration
+ parameter.
+ - To update the audio encoder configuration, use the prefix *audio_encoder_* for each configuration
+ parameter.
+ - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
+ - To update the parent model configuration, do not use a prefix for each configuration parameter.
+
+ Behaves differently depending on whether a `config` is provided or automatically loaded.
+
+ Example:
+
+ ```python
+ >>> from transformers import MusicgenForConditionalGeneration
+
+ >>> # initialize a musicgen model from a t5 text encoder, encodec audio encoder, and musicgen decoder
+ >>> model = MusicgenForConditionalGeneration.from_sub_models_pretrained(
+ ... text_encoder_pretrained_model_name_or_path="google-t5/t5-base",
+ ... audio_encoder_pretrained_model_name_or_path="facebook/encodec_24khz",
+ ... decoder_pretrained_model_name_or_path="facebook/musicgen-small",
+ ... )
+ >>> # saving model after fine-tuning
+ >>> model.save_pretrained("./musicgen-ft")
+ >>> # load fine-tuned model
+ >>> model = MusicgenForConditionalGeneration.from_pretrained("./musicgen-ft")
+ ```"""
+
+ kwargs_text_encoder = {
+ argument[len("text_encoder_") :]: value
+ for argument, value in kwargs.items()
+ if argument.startswith("text_encoder_")
+ }
+
+ kwargs_audio_encoder = {
+ argument[len("audio_encoder_") :]: value
+ for argument, value in kwargs.items()
+ if argument.startswith("audio_encoder_")
+ }
+
+ kwargs_decoder = {
+ argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
+ }
+
+ # remove text encoder, audio encoder and decoder kwargs from kwargs
+ for key in kwargs_text_encoder.keys():
+ del kwargs["text_encoder_" + key]
+ for key in kwargs_audio_encoder.keys():
+ del kwargs["audio_encoder_" + key]
+ for key in kwargs_decoder.keys():
+ del kwargs["decoder_" + key]
+
+ # Load and initialize the encoder and decoder
+ # The distinction between encoder and decoder at the model level is made
+ # by the value of the flag `is_decoder` that we need to set correctly.
+ text_encoder = kwargs_text_encoder.pop("model", None)
+ if text_encoder is None:
+ if text_encoder_pretrained_model_name_or_path is None:
+ raise ValueError(
+ "If `text_encoder_model` is not defined as an argument, a `text_encoder_pretrained_model_name_or_path` has "
+ "to be defined."
+ )
+
+ if "config" not in kwargs_text_encoder:
+ encoder_config, kwargs_text_encoder = AutoConfig.from_pretrained(
+ text_encoder_pretrained_model_name_or_path, **kwargs_text_encoder, return_unused_kwargs=True
+ )
+
+ if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
+ logger.info(
+ f"Initializing {text_encoder_pretrained_model_name_or_path} as a text_encoder model "
+ "from a decoder model. Cross-attention and casual mask are disabled."
+ )
+ encoder_config.is_decoder = False
+ encoder_config.add_cross_attention = False
+
+ kwargs_text_encoder["config"] = encoder_config
+
+ text_encoder = AutoModel.from_pretrained(
+ text_encoder_pretrained_model_name_or_path, *model_args, **kwargs_text_encoder
+ )
+
+ audio_encoder = kwargs_audio_encoder.pop("model", None)
+ if audio_encoder is None:
+ if audio_encoder_pretrained_model_name_or_path is None:
+ raise ValueError(
+ "If `audio_encoder_model` is not defined as an argument, an `audio_encoder_pretrained_model_name_or_path` has "
+ "to be defined."
+ )
+
+ if "config" not in kwargs_audio_encoder:
+ encoder_config, kwargs_audio_encoder = AutoConfig.from_pretrained(
+ audio_encoder_pretrained_model_name_or_path, **kwargs_audio_encoder, return_unused_kwargs=True
+ )
+
+ if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
+ logger.info(
+ f"Initializing {audio_encoder_pretrained_model_name_or_path} as an audio_encoder model "
+ "from a decoder model. Cross-attention and casual mask are disabled."
+ )
+ encoder_config.is_decoder = False
+ encoder_config.add_cross_attention = False
+
+ kwargs_audio_encoder["config"] = encoder_config
+
+ audio_encoder = AutoModel.from_pretrained(
+ audio_encoder_pretrained_model_name_or_path, *model_args, **kwargs_audio_encoder
+ )
+
+ decoder = kwargs_decoder.pop("model", None)
+ if decoder is None:
+ if decoder_pretrained_model_name_or_path is None:
+ raise ValueError(
+ "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
+ "to be defined."
+ )
+
+ if "config" not in kwargs_decoder:
+ decoder_config, kwargs_decoder = AutoConfig.from_pretrained(
+ decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
+ )
+
+ if isinstance(decoder_config, MusicgenConfig):
+ decoder_config = decoder_config.decoder
+
+ if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
+ logger.info(
+ f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
+ f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
+ f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
+ )
+ decoder_config.is_decoder = True
+ decoder_config.add_cross_attention = True
+
+ kwargs_decoder["config"] = decoder_config
+
+ if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
+ logger.warning(
+ f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
+ f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
+ "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
+ "passed to `.from_sub_models_pretrained(...)` are set to `True` or do not pass a "
+ "`decoder_config` to `.from_sub_models_pretrained(...)`"
+ )
+
+ decoder = MusicgenForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
+
+ # instantiate config with corresponding kwargs
+ config = MusicgenConfig.from_sub_models_config(
+ text_encoder.config, audio_encoder.config, decoder.config, **kwargs
+ )
+ return cls(text_encoder=text_encoder, audio_encoder=audio_encoder, decoder=decoder, config=config)
+
+ @add_start_docstrings_to_model_forward(MUSICGEN_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.BoolTensor] = None,
+ input_values: Optional[torch.FloatTensor] = None,
+ padding_mask: Optional[torch.BoolTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
+ encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
+ past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[Tuple, Seq2SeqLMOutput]:
+ r"""
+ Returns:
+
+ Examples:
+ ```python
+ >>> from transformers import AutoProcessor, MusicgenForConditionalGeneration
+ >>> import torch
+
+ >>> processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
+ >>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
+
+ >>> inputs = processor(
+ ... text=["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"],
+ ... padding=True,
+ ... return_tensors="pt",
+ ... )
+
+ >>> pad_token_id = model.generation_config.pad_token_id
+ >>> decoder_input_ids = (
+ ... torch.ones((inputs.input_ids.shape[0] * model.decoder.num_codebooks, 1), dtype=torch.long)
+ ... * pad_token_id
+ ... )
+
+ >>> logits = model(**inputs, decoder_input_ids=decoder_input_ids).logits
+ >>> logits.shape # (bsz * num_codebooks, tgt_len, vocab_size)
+ torch.Size([8, 1, 2048])
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ kwargs_text_encoder = {
+ argument[len("text_encoder_")]: value
+ for argument, value in kwargs.items()
+ if argument.startswith("text_encoder_")
+ }
+
+ kwargs_audio_encoder = {
+ argument[len("audio_encoder_")]: value
+ for argument, value in kwargs.items()
+ if argument.startswith("audio_encoder_")
+ }
+
+ kwargs_decoder = {
+ argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
+ }
+
+ if encoder_outputs is None:
+ encoder_outputs = self.text_encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ **kwargs_text_encoder,
+ )
+ elif isinstance(encoder_outputs, tuple):
+ encoder_outputs = BaseModelOutput(*encoder_outputs)
+
+ encoder_hidden_states = encoder_outputs[0]
+
+ # optionally project encoder_hidden_states
+ if (
+ self.text_encoder.config.hidden_size != self.decoder.config.hidden_size
+ and self.decoder.config.cross_attention_hidden_size is None
+ ):
+ encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
+
+ if attention_mask is not None:
+ encoder_hidden_states = encoder_hidden_states * attention_mask[..., None]
+
+ if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
+ decoder_input_ids = shift_tokens_right(
+ labels, self.config.decoder.pad_token_id, self.config.decoder.decoder_start_token_id
+ )
+
+ elif decoder_input_ids is None and decoder_inputs_embeds is None:
+ audio_encoder_outputs = self.audio_encoder(
+ input_values=input_values,
+ padding_mask=padding_mask,
+ **kwargs_audio_encoder,
+ )
+ audio_codes = audio_encoder_outputs.audio_codes
+ frames, bsz, codebooks, seq_len = audio_codes.shape
+ if frames != 1:
+ raise ValueError(
+ f"Expected 1 frame in the audio code outputs, got {frames} frames. Ensure chunking is "
+ "disabled by setting `chunk_length=None` in the audio encoder."
+ )
+
+ if self.config.decoder.audio_channels == 2 and audio_codes.shape[2] == self.decoder.num_codebooks // 2:
+ # mono input through encodec that we convert to stereo
+ audio_codes = audio_codes.repeat_interleave(2, dim=2)
+
+ decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len)
+
+ # Decode
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=attention_mask,
+ inputs_embeds=decoder_inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ use_cache=use_cache,
+ past_key_values=past_key_values,
+ return_dict=return_dict,
+ labels=labels,
+ **kwargs_decoder,
+ )
+
+ if not return_dict:
+ return decoder_outputs + encoder_outputs
+
+ return Seq2SeqLMOutput(
+ loss=decoder_outputs.loss,
+ logits=decoder_outputs.logits,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ decoder_input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ head_mask=None,
+ decoder_attention_mask=None,
+ decoder_head_mask=None,
+ cross_attn_head_mask=None,
+ use_cache=None,
+ encoder_outputs=None,
+ decoder_delay_pattern_mask=None,
+ guidance_scale=None,
+ **kwargs,
+ ):
+ # Overwritten -- MusicGen has custom processing
+ if decoder_delay_pattern_mask is None:
+ decoder_input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
+ decoder_input_ids,
+ self.generation_config.pad_token_id,
+ max_length=self.generation_config.max_length,
+ )
+
+ # apply the delay pattern mask
+ decoder_input_ids = self.decoder.apply_delay_pattern_mask(decoder_input_ids, decoder_delay_pattern_mask)
+
+ if guidance_scale is not None and guidance_scale > 1:
+ # for classifier free guidance we need to replicate the decoder args across the batch dim (we'll split these
+ # before sampling)
+ decoder_input_ids = decoder_input_ids.repeat((2, 1))
+ if decoder_attention_mask is not None:
+ decoder_attention_mask = decoder_attention_mask.repeat((2, 1))
+
+ if past_key_values is not None:
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if decoder_input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+ decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
+
+ return {
+ "input_ids": None, # encoder_outputs is defined. input_ids not needed
+ "encoder_outputs": encoder_outputs,
+ "past_key_values": past_key_values,
+ "decoder_input_ids": decoder_input_ids,
+ "attention_mask": attention_mask,
+ "decoder_attention_mask": decoder_attention_mask,
+ "head_mask": head_mask,
+ "decoder_head_mask": decoder_head_mask,
+ "cross_attn_head_mask": cross_attn_head_mask,
+ "use_cache": use_cache,
+ }
+
+ def _prepare_decoder_input_ids_for_generation(
+ self,
+ batch_size: int,
+ model_input_name: str,
+ model_kwargs: Dict[str, torch.Tensor],
+ decoder_start_token_id: int = None,
+ bos_token_id: int = None,
+ device: torch.device = None,
+ ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]:
+ """Prepares `decoder_input_ids` for generation with encoder-decoder models"""
+
+ # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming,
+ # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input.
+ if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
+ decoder_input_ids = model_kwargs.pop("decoder_input_ids")
+ elif "input_ids" in model_kwargs and model_input_name != "input_ids":
+ decoder_input_ids = model_kwargs.pop("input_ids")
+ else:
+ decoder_input_ids = None
+
+ # 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that.
+ decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
+ if device is None:
+ device = self.device
+ decoder_input_ids_start = (
+ torch.ones((batch_size * self.decoder.num_codebooks, 1), dtype=torch.long, device=device)
+ * decoder_start_token_id
+ )
+
+ # no user input -> use decoder_start_token_id as decoder_input_ids
+ if decoder_input_ids is None:
+ decoder_input_ids = decoder_input_ids_start
+
+ # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
+ # decoder_attention_mask if provided)
+ elif (decoder_input_ids[..., 0] != decoder_start_token_id).all().item():
+ decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1)
+ if "decoder_attention_mask" in model_kwargs:
+ decoder_attention_mask = model_kwargs["decoder_attention_mask"]
+ decoder_attention_mask = torch.cat(
+ (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask),
+ dim=-1,
+ )
+ model_kwargs["decoder_attention_mask"] = decoder_attention_mask
+
+ return decoder_input_ids, model_kwargs
+
+ def _prepare_text_encoder_kwargs_for_generation(
+ self,
+ inputs_tensor: torch.Tensor,
+ model_kwargs,
+ model_input_name: Optional[str],
+ generation_config: GenerationConfig,
+ ) -> Dict[str, Any]:
+ # 1. get text encoder
+ encoder = self.get_text_encoder()
+ # Compatibility with Accelerate big model inference: we need the encoder to outputs stuff on the same device
+ # as the inputs.
+ if hasattr(encoder, "_hf_hook"):
+ encoder._hf_hook.io_same_device = True
+
+ # 2. Prepare encoder args and encoder kwargs from model kwargs.
+ irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
+ encoder_kwargs = {
+ argument: value
+ for argument, value in model_kwargs.items()
+ if not any(argument.startswith(p) for p in irrelevant_prefix)
+ }
+ encoder_signature = set(inspect.signature(encoder.forward).parameters)
+ encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
+ if not encoder_accepts_wildcard:
+ encoder_kwargs = {
+ argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
+ }
+ encoder_kwargs["output_attentions"] = generation_config.output_attentions
+ encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states
+ guidance_scale = generation_config.guidance_scale
+
+ # 3. make sure that encoder returns `ModelOutput`
+ model_input_name = model_input_name if model_input_name is not None else self.text_encoder.main_input_name
+ encoder_kwargs["return_dict"] = True
+ encoder_kwargs[model_input_name] = inputs_tensor
+ last_hidden_state = encoder(**encoder_kwargs).last_hidden_state
+
+ # for classifier free guidance we need to add a 'null' input to our encoder hidden states
+ if guidance_scale is not None and guidance_scale > 1:
+ last_hidden_state = torch.concatenate([last_hidden_state, torch.zeros_like(last_hidden_state)], dim=0)
+ if "attention_mask" in model_kwargs:
+ model_kwargs["attention_mask"] = torch.concatenate(
+ [model_kwargs["attention_mask"], torch.zeros_like(model_kwargs["attention_mask"])], dim=0
+ )
+
+ model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=last_hidden_state)
+
+ return model_kwargs
+
+ def _prepare_audio_encoder_kwargs_for_generation(
+ self, input_values, model_kwargs, model_input_name: Optional[str] = None
+ ):
+ # 1. get audio encoder
+ encoder = self.get_audio_encoder()
+ # Compatibility with Accelerate big model inference: we need the encoder to outputs stuff on the same device
+ # as the inputs.
+ if hasattr(encoder, "_hf_hook"):
+ encoder._hf_hook.io_same_device = True
+
+ # 2. Prepare encoder args and encoder kwargs from model kwargs.
+ irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
+ encoder_kwargs = {
+ argument: value
+ for argument, value in model_kwargs.items()
+ if not any(argument.startswith(p) for p in irrelevant_prefix)
+ }
+ encoder_signature = set(inspect.signature(encoder.forward).parameters)
+ encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
+ if not encoder_accepts_wildcard:
+ encoder_kwargs = {
+ argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
+ }
+
+ # 3. make sure that encoder returns `ModelOutput`
+ model_input_name = model_input_name if model_input_name is not None else self.audio_encoder.main_input_name
+ encoder_kwargs["return_dict"] = True
+
+ if self.decoder.config.audio_channels == 1:
+ encoder_kwargs[model_input_name] = input_values
+ audio_encoder_outputs = encoder.encode(**encoder_kwargs)
+ audio_codes = audio_encoder_outputs.audio_codes
+ audio_scales = audio_encoder_outputs.audio_scales
+
+ frames, bsz, codebooks, seq_len = audio_codes.shape
+
+ else:
+ if input_values.shape[1] != 2:
+ raise ValueError(
+ f"Expected stereo audio (2-channels) but example has {input_values.shape[1]} channel."
+ )
+
+ encoder_kwargs[model_input_name] = input_values[:, :1, :]
+ audio_encoder_outputs_left = encoder.encode(**encoder_kwargs)
+ audio_codes_left = audio_encoder_outputs_left.audio_codes
+ audio_scales_left = audio_encoder_outputs_left.audio_scales
+
+ encoder_kwargs[model_input_name] = input_values[:, 1:, :]
+ audio_encoder_outputs_right = encoder.encode(**encoder_kwargs)
+ audio_codes_right = audio_encoder_outputs_right.audio_codes
+ audio_scales_right = audio_encoder_outputs_right.audio_scales
+
+ frames, bsz, codebooks, seq_len = audio_codes_left.shape
+ # copy alternating left/right channel codes into stereo codebook
+ audio_codes = audio_codes_left.new_ones((frames, bsz, 2 * codebooks, seq_len))
+
+ audio_codes[:, :, ::2, :] = audio_codes_left
+ audio_codes[:, :, 1::2, :] = audio_codes_right
+
+ if audio_scales_left != [None] or audio_scales_right != [None]:
+ audio_scales = torch.stack([audio_scales_left, audio_scales_right], dim=1)
+ else:
+ audio_scales = [None] * bsz
+
+ if frames != 1:
+ raise ValueError(
+ f"Expected 1 frame in the audio code outputs, got {frames} frames. Ensure chunking is "
+ "disabled by setting `chunk_length=None` in the audio encoder."
+ )
+
+ decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len)
+
+ model_kwargs["decoder_input_ids"] = decoder_input_ids
+ model_kwargs["audio_scales"] = audio_scales
+ return model_kwargs
+
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+ return shift_tokens_right(labels, self.config.decoder.pad_token_id, self.config.decoder.bos_token_id)
+
+ def resize_token_embeddings(self, *args, **kwargs):
+ raise NotImplementedError(
+ "Resizing the embedding layers via the EncoderDecoderModel directly is not supported. Please use the"
+ " respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or"
+ " model.decoder.resize_token_embeddings(...))"
+ )
+
+ def freeze_audio_encoder(self):
+ """
+ Freeze the audio encoder weights.
+ """
+ for param in self.audio_encoder.parameters():
+ param.requires_grad = False
+ self.audio_encoder._requires_grad = False
+
+ def freeze_text_encoder(self):
+ """
+ Freeze the text encoder weights.
+ """
+ for param in self.text_encoder.parameters():
+ param.requires_grad = False
+ self.text_encoder._requires_grad = False
+
+ def _maybe_initialize_input_ids_for_generation(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ bos_token_id: Optional[int] = None,
+ model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ ) -> torch.LongTensor:
+ """Initializes input ids for generation, if necessary."""
+ if inputs is not None:
+ return inputs
+
+ encoder_outputs = model_kwargs.get("encoder_outputs")
+ if encoder_outputs is not None:
+ # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
+ shape = encoder_outputs[0].size()[:-1]
+ return torch.ones(shape, dtype=torch.long, device=self.device) * -100
+
+ if bos_token_id is None:
+ raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
+
+ # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with
+ # soft-prompting or in multimodal implementations built on top of decoder-only language models.
+ batch_size = 1
+ for value in model_kwargs.values():
+ if isinstance(value, torch.Tensor):
+ batch_size = value.shape[0]
+ break
+ return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
+
+ def _get_decoder_start_token_id(
+ self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None
+ ) -> int:
+ decoder_start_token_id = (
+ decoder_start_token_id
+ if decoder_start_token_id is not None
+ else self.generation_config.decoder_start_token_id
+ )
+ bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
+
+ if decoder_start_token_id is not None:
+ return decoder_start_token_id
+ elif bos_token_id is not None:
+ return bos_token_id
+ raise ValueError(
+ "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
+ )
+
+ @torch.no_grad()
+ def generate(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ generation_config: Optional[GenerationConfig] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ synced_gpus: Optional[bool] = None,
+ streamer: Optional["BaseStreamer"] = None,
+ **kwargs,
+ ):
+ """
+
+ Generates sequences of token ids for models with a language modeling head.
+
+
+
+ Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
+ model's default generation configuration. You can override any `generation_config` by passing the corresponding
+ parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
+
+ For an overview of generation strategies and code examples, check out the [following
+ guide](./generation_strategies).
+
+
+
+ Parameters:
+ inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
+ The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
+ method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
+ should be in the format `input_ids`. For encoder-decoder models *inputs* can represent any of
+ `input_ids`, `input_values`, `input_features`, or `pixel_values`.
+ generation_config (`~generation.GenerationConfig`, *optional*):
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+ passed to generate matching the attributes of `generation_config` will override them. If
+ `generation_config` is not provided, the default will be used, which had the following loading
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+ default values, whose documentation should be checked to parameterize generation.
+ logits_processor (`LogitsProcessorList`, *optional*):
+ Custom logits processors that complement the default logits processors built from arguments and
+ generation config. If a logit processor is passed that is already created with the arguments or a
+ generation config an error is thrown. This feature is intended for advanced users.
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
+ Custom stopping criteria that complement the default stopping criteria built from arguments and a
+ generation config. If a stopping criteria is passed that is already created with the arguments or a
+ generation config an error is thrown. This feature is intended for advanced users.
+ synced_gpus (`bool`, *optional*, defaults to `False`):
+ Whether to continue running the while loop until max_length (needed to avoid deadlocking with
+ `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
+ streamer (`BaseStreamer`, *optional*):
+ Streamer object that will be used to stream the generated sequences. Generated tokens are passed
+ through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+ kwargs (`Dict[str, Any]`, *optional*):
+ Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
+ forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
+ specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
+
+ Return:
+ [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
+ or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
+
+ If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
+ [`~utils.ModelOutput`] types are:
+
+ - [`~generation.GenerateDecoderOnlyOutput`],
+ - [`~generation.GenerateBeamDecoderOnlyOutput`]
+
+ If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
+ [`~utils.ModelOutput`] types are:
+
+ - [`~generation.GenerateEncoderDecoderOutput`],
+ - [`~generation.GenerateBeamEncoderDecoderOutput`]
+ """
+ # 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects
+ if generation_config is None:
+ generation_config = self.generation_config
+
+ generation_config = copy.deepcopy(generation_config)
+ model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
+ generation_config.validate()
+ self._validate_model_kwargs(model_kwargs.copy())
+
+ if model_kwargs.get("encoder_outputs") is not None and type(model_kwargs["encoder_outputs"]) is tuple:
+ # wrap the unconditional outputs as a BaseModelOutput for compatibility with the rest of generate
+ model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=model_kwargs["encoder_outputs"][0])
+
+ # 2. Set generation parameters if not already defined
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+ requires_attention_mask = "encoder_outputs" not in model_kwargs
+ kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+
+ # 3. Define model inputs
+ inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
+ inputs, generation_config.bos_token_id, model_kwargs
+ )
+ batch_size = inputs_tensor.shape[0]
+ self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device)
+
+ # 4. Define other model kwargs
+ model_kwargs["use_cache"] = generation_config.use_cache
+ model_kwargs["guidance_scale"] = generation_config.guidance_scale
+
+ if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
+ model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
+ inputs_tensor, generation_config, model_kwargs
+ )
+
+ if "encoder_outputs" not in model_kwargs:
+ # encoder_outputs are created and added to `model_kwargs`
+ model_kwargs = self._prepare_text_encoder_kwargs_for_generation(
+ inputs_tensor, model_kwargs, model_input_name, generation_config
+ )
+
+ if "decoder_input_ids" not in model_kwargs and "input_values" in model_kwargs:
+ model_kwargs = self._prepare_audio_encoder_kwargs_for_generation(
+ model_kwargs["input_values"],
+ model_kwargs,
+ )
+
+ # 5. Prepare `input_ids` which will be used for auto-regressive generation
+ input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
+ batch_size=batch_size,
+ model_input_name=model_input_name,
+ model_kwargs=model_kwargs,
+ decoder_start_token_id=generation_config._decoder_start_token_tensor,
+ bos_token_id=generation_config._bos_token_tensor,
+ device=inputs_tensor.device,
+ )
+
+ # 6. Prepare `max_length` depending on other stopping criteria.
+ input_ids_length = input_ids.shape[-1]
+ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+ has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
+ generation_config = self._prepare_generated_length(
+ generation_config=generation_config,
+ has_default_max_length=has_default_max_length,
+ has_default_min_length=has_default_min_length,
+ model_input_name=model_input_name,
+ inputs_tensor=inputs_tensor,
+ input_ids_length=input_ids_length,
+ )
+
+ # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen)
+ input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
+ input_ids,
+ pad_token_id=generation_config._decoder_start_token_tensor,
+ max_length=generation_config.max_length,
+ )
+ # stash the delay mask so that we don't have to recompute in each forward pass
+ model_kwargs["decoder_delay_pattern_mask"] = decoder_delay_pattern_mask
+
+ # input_ids are ready to be placed on the streamer (if used)
+ if streamer is not None:
+ streamer.put(input_ids.cpu())
+
+ # 7. determine generation mode
+ generation_mode = generation_config.get_generation_mode()
+
+ # 8. prepare batched CFG externally (to enable coexistance with the unbatched CFG)
+ if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
+ logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
+ generation_config.guidance_scale = None
+
+ # 9. prepare distribution pre_processing samplers
+ logits_processor = self._get_logits_processor(
+ generation_config=generation_config,
+ input_ids_seq_length=input_ids_length,
+ encoder_input_ids=inputs_tensor,
+ prefix_allowed_tokens_fn=None,
+ logits_processor=logits_processor,
+ device=input_ids.device,
+ )
+
+ # 10. prepare stopping criteria
+ stopping_criteria = self._get_stopping_criteria(
+ generation_config=generation_config, stopping_criteria=stopping_criteria
+ )
+
+ if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
+ # expand input_ids with `num_return_sequences` additional sequences per batch
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
+ input_ids=input_ids,
+ expand_size=generation_config.num_return_sequences,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ **model_kwargs,
+ )
+
+ # 11. run sample
+ outputs = self._sample(
+ input_ids,
+ logits_processor=logits_processor,
+ stopping_criteria=stopping_criteria,
+ generation_config=generation_config,
+ synced_gpus=synced_gpus,
+ streamer=streamer,
+ **model_kwargs,
+ )
+
+ else:
+ raise ValueError(
+ "Got incompatible mode for generation, should be one of greedy or sampling. "
+ "Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`."
+ )
+
+ if generation_config.return_dict_in_generate:
+ output_ids = outputs.sequences
+ else:
+ output_ids = outputs
+
+ # apply the pattern mask to the final ids
+ output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])
+
+ # revert the pattern delay mask by filtering the pad token id
+ output_ids = output_ids[output_ids != generation_config._pad_token_tensor].reshape(
+ batch_size, self.decoder.num_codebooks, -1
+ )
+
+ # append the frame dimension back to the audio codes
+ output_ids = output_ids[None, ...]
+
+ audio_scales = model_kwargs.get("audio_scales")
+ if audio_scales is None:
+ audio_scales = [None] * batch_size
+
+ if self.decoder.config.audio_channels == 1:
+ output_values = self.audio_encoder.decode(
+ output_ids,
+ audio_scales=audio_scales,
+ ).audio_values
+ else:
+ codec_outputs_left = self.audio_encoder.decode(output_ids[:, :, ::2, :], audio_scales=audio_scales)
+ output_values_left = codec_outputs_left.audio_values
+
+ codec_outputs_right = self.audio_encoder.decode(output_ids[:, :, 1::2, :], audio_scales=audio_scales)
+ output_values_right = codec_outputs_right.audio_values
+
+ output_values = torch.cat([output_values_left, output_values_right], dim=1)
+
+ if generation_config.return_dict_in_generate:
+ outputs.sequences = output_values
+ return outputs
+ else:
+ return output_values
+
+ def get_unconditional_inputs(self, num_samples=1):
+ """
+ Helper function to get null inputs for unconditional generation, enabling the model to be used without the
+ feature extractor or tokenizer.
+
+ Args:
+ num_samples (int, *optional*):
+ Number of audio samples to unconditionally generate.
+ max_new_tokens (int, *optional*):
+ Number of tokens to generate for each sample. More tokens means longer audio samples, at the expense of
+ longer inference (since more audio tokens need to be generated per sample).
+
+ Example:
+ ```python
+ >>> from transformers import MusicgenForConditionalGeneration
+
+ >>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
+
+ >>> # get the unconditional (or 'null') inputs for the model
+ >>> unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
+ >>> audio_samples = model.generate(**unconditional_inputs, max_new_tokens=256)
+ ```"""
+ last_hidden_state = torch.zeros(
+ (num_samples, 1, self.config.text_encoder.hidden_size), device=self.device, dtype=self.dtype
+ )
+
+ attention_mask = torch.zeros((num_samples, 1), device=self.device, dtype=torch.long)
+
+ return MusicgenUnconditionalInput(
+ encoder_outputs=(last_hidden_state,),
+ attention_mask=attention_mask,
+ guidance_scale=1.0,
+ )
+
+
+__all__ = ["MusicgenForConditionalGeneration", "MusicgenForCausalLM", "MusicgenModel", "MusicgenPreTrainedModel"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/musicgen/processing_musicgen.py b/.venv/lib/python3.11/site-packages/transformers/models/musicgen/processing_musicgen.py
new file mode 100644
index 0000000000000000000000000000000000000000..deebf9045b4ffb7f9e438a1892fe651c3e2f54d6
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/musicgen/processing_musicgen.py
@@ -0,0 +1,144 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Text/audio processor class for MusicGen
+"""
+
+from typing import List, Optional
+
+import numpy as np
+
+from ...processing_utils import ProcessorMixin
+from ...utils import to_numpy
+
+
+class MusicgenProcessor(ProcessorMixin):
+ r"""
+ Constructs a MusicGen processor which wraps an EnCodec feature extractor and a T5 tokenizer into a single processor
+ class.
+
+ [`MusicgenProcessor`] offers all the functionalities of [`EncodecFeatureExtractor`] and [`TTokenizer`]. See
+ [`~MusicgenProcessor.__call__`] and [`~MusicgenProcessor.decode`] for more information.
+
+ Args:
+ feature_extractor (`EncodecFeatureExtractor`):
+ An instance of [`EncodecFeatureExtractor`]. The feature extractor is a required input.
+ tokenizer (`T5Tokenizer`):
+ An instance of [`T5Tokenizer`]. The tokenizer is a required input.
+ """
+
+ feature_extractor_class = "EncodecFeatureExtractor"
+ tokenizer_class = ("T5Tokenizer", "T5TokenizerFast")
+
+ def __init__(self, feature_extractor, tokenizer):
+ super().__init__(feature_extractor, tokenizer)
+ self.current_processor = self.feature_extractor
+ self._in_target_context_manager = False
+
+ def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
+ return self.tokenizer.get_decoder_prompt_ids(task=task, language=language, no_timestamps=no_timestamps)
+
+ def __call__(self, *args, **kwargs):
+ """
+ Forwards the `audio` argument to EncodecFeatureExtractor's [`~EncodecFeatureExtractor.__call__`] and the `text`
+ argument to [`~T5Tokenizer.__call__`]. Please refer to the doctsring of the above two methods for more
+ information.
+ """
+ # For backward compatibility
+ if self._in_target_context_manager:
+ return self.current_processor(*args, **kwargs)
+
+ audio = kwargs.pop("audio", None)
+ sampling_rate = kwargs.pop("sampling_rate", None)
+ text = kwargs.pop("text", None)
+ if len(args) > 0:
+ audio = args[0]
+ args = args[1:]
+
+ if audio is None and text is None:
+ raise ValueError("You need to specify either an `audio` or `text` input to process.")
+
+ if text is not None:
+ inputs = self.tokenizer(text, **kwargs)
+
+ if audio is not None:
+ audio_inputs = self.feature_extractor(audio, *args, sampling_rate=sampling_rate, **kwargs)
+
+ if audio is None:
+ return inputs
+
+ elif text is None:
+ return audio_inputs
+
+ else:
+ inputs["input_values"] = audio_inputs["input_values"]
+ if "padding_mask" in audio_inputs:
+ inputs["padding_mask"] = audio_inputs["padding_mask"]
+ return inputs
+
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method is used to decode either batches of audio outputs from the MusicGen model, or batches of token ids
+ from the tokenizer. In the case of decoding token ids, this method forwards all its arguments to T5Tokenizer's
+ [`~PreTrainedTokenizer.batch_decode`]. Please refer to the docstring of this method for more information.
+ """
+ audio_values = kwargs.pop("audio", None)
+ padding_mask = kwargs.pop("padding_mask", None)
+
+ if len(args) > 0:
+ audio_values = args[0]
+ args = args[1:]
+
+ if audio_values is not None:
+ return self._decode_audio(audio_values, padding_mask=padding_mask)
+ else:
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to T5Tokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to the
+ docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
+
+ def _decode_audio(self, audio_values, padding_mask: Optional = None) -> List[np.ndarray]:
+ """
+ This method strips any padding from the audio values to return a list of numpy audio arrays.
+ """
+ audio_values = to_numpy(audio_values)
+ bsz, channels, seq_len = audio_values.shape
+
+ if padding_mask is None:
+ return list(audio_values)
+
+ padding_mask = to_numpy(padding_mask)
+
+ # match the sequence length of the padding mask to the generated audio arrays by padding with the **non-padding**
+ # token (so that the generated audio values are **not** treated as padded tokens)
+ difference = seq_len - padding_mask.shape[-1]
+ padding_value = 1 - self.feature_extractor.padding_value
+ padding_mask = np.pad(padding_mask, ((0, 0), (0, difference)), "constant", constant_values=padding_value)
+
+ audio_values = audio_values.tolist()
+ for i in range(bsz):
+ sliced_audio = np.asarray(audio_values[i])[
+ padding_mask[i][None, :] != self.feature_extractor.padding_value
+ ]
+ audio_values[i] = sliced_audio.reshape(channels, -1)
+
+ return audio_values
+
+
+__all__ = ["MusicgenProcessor"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/olmoe/__init__.py b/.venv/lib/python3.11/site-packages/transformers/models/olmoe/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0e18d77cbad11bf218b6cd304f703b4d7935eef
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/olmoe/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_olmoe import *
+ from .modeling_olmoe import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/olmoe/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/olmoe/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..716df2b0da38cac8b121d452cceb4d471badb163
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/olmoe/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/olmoe/__pycache__/configuration_olmoe.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/olmoe/__pycache__/configuration_olmoe.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aceb1ebbcc9839e5c491178ab2a1c93dfe91e604
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/olmoe/__pycache__/configuration_olmoe.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/olmoe/__pycache__/modeling_olmoe.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/olmoe/__pycache__/modeling_olmoe.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..209c76a1f741b292cbb4a60322e50d7bf549c239
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/olmoe/__pycache__/modeling_olmoe.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/olmoe/configuration_olmoe.py b/.venv/lib/python3.11/site-packages/transformers/models/olmoe/configuration_olmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f24d5523bafd0398e6985959ffdd223261cc5ca
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/olmoe/configuration_olmoe.py
@@ -0,0 +1,182 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""OLMoE model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+
+
+class OlmoeConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`OlmoeModel`]. It is used to instantiate an OLMoE
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the [allenai/OLMoE-1B-7B-0924](https://huggingface.co/allenai/OLMoE-1B-7B-0924).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 50304):
+ Vocabulary size of the OLMoE model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`OlmoeModel`]
+ hidden_size (`int`, *optional*, defaults to 2048):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 2048):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 16):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details checkout [this
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 4096):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*, defaults to 1):
+ Padding token id.
+ bos_token_id (`int`, *optional*):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 50279):
+ End of stream token id.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
+ these scaling strategies behave:
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
+ experimental feature, subject to breaking API changes in future versions.
+ attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ clip_qkv (`float`, *optional*):
+ If not `None`, elements of query, key and value attention states are clipped so that their
+ absolute value does not exceed this value.
+ num_experts_per_tok (`int`, *optional*, defaults to 8):
+ Number of selected experts.
+ num_experts (`int`, *optional*, defaults to 64):
+ Number of routed experts.
+ output_router_logits (`bool`, *optional*, defaults to `False`):
+ Whether or not the router logits should be returned by the model. Enabeling this will also
+ allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
+ router_aux_loss_coef (`float`, *optional*, defaults to 0.01):
+ The aux loss factor for the total loss.
+ norm_topk_prob (`bool`, *optional*, defaults to `False`):
+ Whether to normalize the topk probabilities.
+
+ ```python
+ >>> from transformers import OlmoeModel, OlmoeConfig
+
+ >>> # Initializing a OLMoE 7B A1B style configuration
+ >>> configuration = OlmoeConfig()
+
+ >>> # Initializing a model from the OLMoE 7B A1B style configuration
+ >>> model = OlmoeModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "olmoe"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=50304,
+ hidden_size=2048,
+ intermediate_size=2048,
+ num_hidden_layers=16,
+ num_attention_heads=16,
+ num_key_value_heads=None,
+ hidden_act="silu",
+ max_position_embeddings=4096,
+ initializer_range=0.02,
+ rms_norm_eps=1e-05,
+ use_cache=True,
+ pad_token_id=1,
+ bos_token_id=None,
+ eos_token_id=50279,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ clip_qkv=None,
+ num_experts_per_tok=8,
+ num_experts=64,
+ output_router_logits=False,
+ router_aux_loss_coef=0.01,
+ norm_topk_prob=False,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.clip_qkv = clip_qkv
+ self.num_experts_per_tok = num_experts_per_tok
+ self.num_experts = num_experts
+ self.output_router_logits = output_router_logits
+ self.router_aux_loss_coef = router_aux_loss_coef
+ self.norm_topk_prob = norm_topk_prob
+ # Validate the correctness of rotary position embeddings parameters
+ # BC: if there is a 'type' field, move it to 'rope_type'.
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+ rope_config_validation(self)
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+__all__ = ["OlmoeConfig"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/olmoe/modeling_olmoe.py b/.venv/lib/python3.11/site-packages/transformers/models/olmoe/modeling_olmoe.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c78138c1a035155564955fdf33819a19c6de64a
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/olmoe/modeling_olmoe.py
@@ -0,0 +1,1299 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch OLMoE model."""
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, StaticCache
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_outputs import (
+ MoeCausalLMOutputWithPast,
+ MoeModelOutputWithPast,
+)
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import ALL_LAYERNORM_LAYERS
+from ...utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_2_available,
+ is_flash_attn_greater_or_equal_2_10,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_olmoe import OlmoeConfig
+
+
+if is_flash_attn_2_available():
+ from ...modeling_flash_attention_utils import _flash_attention_forward
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "OlmoeConfig"
+
+
+# Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func
+def load_balancing_loss_func(
+ gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],
+ num_experts: Optional[int] = None,
+ top_k=2,
+ attention_mask: Optional[torch.Tensor] = None,
+) -> Union[torch.Tensor, int]:
+ r"""
+ Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
+
+ See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
+ experts is too unbalanced.
+
+ Args:
+ gate_logits:
+ Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
+ shape [batch_size X sequence_length, num_experts].
+ num_experts:
+ Number of experts
+ top_k:
+ The number of experts to route per-token, can be also interpreted as the `top-k` routing
+ parameter.
+ attention_mask (`torch.Tensor`, *optional*):
+ The attention_mask used in forward function
+ shape [batch_size X sequence_length] if not None.
+
+ Returns:
+ The auxiliary loss.
+ """
+ if gate_logits is None or not isinstance(gate_logits, tuple):
+ return 0
+
+ if isinstance(gate_logits, tuple):
+ compute_device = gate_logits[0].device
+ concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
+
+ routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
+
+ _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
+
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
+
+ if attention_mask is None:
+ # Compute the percentage of tokens routed to each experts
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
+
+ # Compute the average probability of routing to these experts
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
+ else:
+ batch_size, sequence_length = attention_mask.shape
+ num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
+
+ # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
+ expert_attention_mask = (
+ attention_mask[None, :, :, None, None]
+ .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+ .reshape(-1, top_k, num_experts)
+ .to(compute_device)
+ )
+
+ # Compute the percentage of tokens routed to each experts
+ tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
+ expert_attention_mask, dim=0
+ )
+
+ # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
+ router_per_expert_attention_mask = (
+ attention_mask[None, :, :, None]
+ .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
+ .reshape(-1, num_experts)
+ .to(compute_device)
+ )
+
+ # Compute the average probability of routing to these experts
+ router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
+ router_per_expert_attention_mask, dim=0
+ )
+
+ overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
+ return overall_loss * num_experts
+
+
+class OlmoeRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-5):
+ """
+ OlmoeRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+ALL_LAYERNORM_LAYERS.append(OlmoeRMSNorm)
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Olmoe
+class OlmoeRotaryEmbedding(nn.Module):
+ def __init__(self, config: OlmoeConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ def _dynamic_frequency_update(self, position_ids, device):
+ """
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
+ 1 - growing beyond the cached sequence length (allow scaling)
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
+ """
+ seq_len = torch.max(position_ids) + 1
+ if seq_len > self.max_seq_len_cached: # growth
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
+ self.max_seq_len_cached = seq_len
+
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
+ # This .to() is needed if the model has been moved to a device after being initialized (because
+ # the buffer is automatically moved, but not the original copy)
+ self.original_inv_freq = self.original_inv_freq.to(device)
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+ self.max_seq_len_cached = self.original_max_seq_len
+
+ @torch.no_grad()
+ def forward(self, x, position_ids):
+ if "dynamic" in self.rope_type:
+ self._dynamic_frequency_update(position_ids, device=x.device)
+
+ # Core RoPE block
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
+ device_type = x.device.type
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos()
+ sin = emb.sin()
+
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
+ cos = cos * self.attention_scaling
+ sin = sin * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+# Copied from transformers.models.llama.modeling_llama.rotate_half
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+# Copied from transformers.models.olmo.modeling_olmo.OlmoMLP with Olmo->Olmoe
+class OlmoeMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+# Copied from transformers.models.llama.modeling_llama.repeat_kv
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+class OlmoeAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: OlmoeConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+ self.is_causal = True
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
+ self.q_norm = OlmoeRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
+ self.k_norm = OlmoeRMSNorm(
+ (self.hidden_size // self.num_heads) * self.num_key_value_heads, eps=config.rms_norm_eps
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_norm(self.q_proj(hidden_states))
+ key_states = self.k_norm(self.k_proj(hidden_states))
+ value_states = self.v_proj(hidden_states)
+
+ if self.config.clip_qkv is not None:
+ query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+ key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+ value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class OlmoeFlashAttention2(OlmoeAttention):
+ """
+ OLMoE flash attention module. This module inherits from `OlmoeAttention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_norm(self.q_proj(hidden_states))
+ key_states = self.k_norm(self.k_proj(hidden_states))
+ value_states = self.v_proj(hidden_states)
+ if self.config.clip_qkv is not None:
+ query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+ key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+ value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x head_dim x hidden_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (OlmoeRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ dropout=dropout_rate,
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class OlmoeSdpaAttention(OlmoeAttention):
+ """
+ OLMoE attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `OlmoeAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+ SDPA API.
+ """
+
+ # Adapted from OlmoeAttention.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "OlmoeModel is using OlmoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_norm(self.q_proj(hidden_states))
+ key_states = self.k_norm(self.k_proj(hidden_states))
+ value_states = self.v_proj(hidden_states)
+
+ if self.config.clip_qkv is not None:
+ query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+ key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+ value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ # if attention_mask is not None and cache_position is not None:
+ if attention_mask is not None:
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and causal_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ is_causal = True if causal_mask is None and q_len > 1 else False
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, None, past_key_value
+
+
+OLMOE_ATTENTION_CLASSES = {
+ "eager": OlmoeAttention,
+ "flash_attention_2": OlmoeFlashAttention2,
+ "sdpa": OlmoeSdpaAttention,
+}
+
+
+class OlmoeSparseMoeBlock(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.num_experts = config.num_experts
+ self.top_k = config.num_experts_per_tok
+ self.norm_topk_prob = config.norm_topk_prob
+ self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
+ self.experts = nn.ModuleList([OlmoeMLP(config) for _ in range(self.num_experts)])
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
+ hidden_states = hidden_states.view(-1, hidden_dim)
+ # router_logits: (batch * sequence_length, n_experts)
+ router_logits = self.gate(hidden_states)
+
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+ if self.norm_topk_prob:
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+ # we cast back to the input dtype
+ routing_weights = routing_weights.to(hidden_states.dtype)
+
+ final_hidden_states = torch.zeros(
+ (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+ )
+
+ # One hot encode the selected experts to create an expert mask
+ # this will be used to easily index which expert is going to be selected
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+
+ # Loop over all available experts in the model and perform the computation on each expert
+ for expert_idx in range(self.num_experts):
+ expert_layer = self.experts[expert_idx]
+ idx, top_x = torch.where(expert_mask[expert_idx])
+
+ # Index the correct hidden states and compute the expert hidden state for
+ # the current expert. We need to make sure to multiply the output hidden
+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+ current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+
+ # However `index_add_` only support torch tensors for indexing so we'll use
+ # the `top_x` tensor here.
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+ return final_hidden_states, router_logits
+
+
+class OlmoeDecoderLayer(nn.Module):
+ def __init__(self, config: OlmoeConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = OLMOE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+
+ self.mlp = OlmoeSparseMoeBlock(config)
+ self.input_layernorm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ output_router_logits: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_router_logits (`bool`, *optional*):
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss,
+ and should not be returned during inference.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+ into the model
+ """
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states, router_logits = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ if output_router_logits:
+ outputs += (router_logits,)
+
+ return outputs
+
+
+OLMOE_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`OlmoeConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare Olmoe Model outputting raw hidden-states without any specific head on top.",
+ OLMOE_START_DOCSTRING,
+)
+# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->Olmoe
+class OlmoePreTrainedModel(PreTrainedModel):
+ config_class = OlmoeConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["OlmoeDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+ _supports_cache_class = True
+ _supports_quantized_cache = True
+ _supports_static_cache = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+OLMOE_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+ Two formats are allowed:
+ - a [`~cache_utils.Cache`] instance;
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
+ cache format.
+
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
+ legacy cache format will be returned.
+
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+ of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ output_router_logits (`bool`, *optional*):
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
+ should not be returned during inference.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+ the complete sequence length.
+"""
+
+
+@add_start_docstrings(
+ "The bare Olmoe Model outputting raw hidden-states without any specific head on top.",
+ OLMOE_START_DOCSTRING,
+)
+# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Olmoe
+class OlmoeModel(OlmoePreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OlmoeDecoderLayer`]
+
+ Args:
+ config: OlmoeConfig
+ """
+
+ def __init__(self, config: OlmoeConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [OlmoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = OlmoeRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ @add_start_docstrings_to_model_forward(OLMOE_INPUTS_DOCSTRING)
+ # Ignore copy
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_router_logits: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, MoeModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_router_logits = (
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
+ )
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache):
+ return_legacy_cache = True
+ if past_key_values is None:
+ past_key_values = DynamicCache()
+ else:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ logger.warning_once(
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+ )
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ )
+
+ # embed positions
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_router_logits = () if output_router_logits else None
+ next_decoder_cache = None
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ causal_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ output_router_logits,
+ use_cache,
+ cache_position,
+ position_embeddings,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ output_router_logits=output_router_logits,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if output_router_logits and layer_outputs[-1] is not None:
+ all_router_logits += (layer_outputs[-1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return MoeModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ router_logits=all_router_logits,
+ )
+
+ def _update_causal_mask(
+ self,
+ attention_mask: torch.Tensor,
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ output_attentions: bool,
+ ):
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and 0.0 in attention_mask:
+ return attention_mask
+ return None
+
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ using_static_cache = isinstance(past_key_values, StaticCache)
+
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=past_seen_tokens,
+ is_training=self.training,
+ ):
+ return None
+
+ dtype, device = input_tensor.dtype, input_tensor.device
+ sequence_length = input_tensor.shape[1]
+ if using_static_cache:
+ target_length = past_key_values.get_max_cache_shape()
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else past_seen_tokens + sequence_length + 1
+ )
+
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=target_length,
+ dtype=dtype,
+ device=device,
+ cache_position=cache_position,
+ batch_size=input_tensor.shape[0],
+ )
+
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type == "cuda"
+ and not output_attentions
+ ):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+ return causal_mask
+
+ @staticmethod
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ device (`torch.device`):
+ The device to plcae the 4D attention mask on.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+
+class OlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = OlmoeModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ self.router_aux_loss_coef = config.router_aux_loss_coef
+ self.num_experts = config.num_experts
+ self.num_experts_per_tok = config.num_experts_per_tok
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ @add_start_docstrings_to_model_forward(OLMOE_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_router_logits: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ num_logits_to_keep: int = 0,
+ **loss_kwargs,
+ ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ num_logits_to_keep (`int`, *optional*):
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, OlmoeForCausalLM
+
+ >>> model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924")
+ >>> tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ 'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m'
+ ```
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_router_logits = (
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
+ )
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ output_router_logits=output_router_logits,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+
+ aux_loss = None
+ if output_router_logits:
+ aux_loss = load_balancing_loss_func(
+ outputs.router_logits if return_dict else outputs[-1],
+ self.num_experts,
+ self.num_experts_per_tok,
+ attention_mask,
+ )
+ if labels is not None:
+ loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ if output_router_logits:
+ output = (aux_loss,) + output
+ return (loss,) + output if loss is not None else output
+
+ return MoeCausalLMOutputWithPast(
+ loss=loss,
+ aux_loss=aux_loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ router_logits=outputs.router_logits,
+ )
+
+
+__all__ = ["OlmoeForCausalLM", "OlmoeModel", "OlmoePreTrainedModel"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__init__.py b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4903c400f9827239f7920b7e2c9bfddfda48eecd
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_pegasus import *
+ from .modeling_flax_pegasus import *
+ from .modeling_pegasus import *
+ from .modeling_tf_pegasus import *
+ from .tokenization_pegasus import *
+ from .tokenization_pegasus_fast import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4019eb86fed9c828268939e36697a5032eb76685
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/configuration_pegasus.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/configuration_pegasus.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6a6157f2e8aac32727b728298493ea3cb048cd8f
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/configuration_pegasus.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/modeling_flax_pegasus.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/modeling_flax_pegasus.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9296509d0f21aadf0643827eecb7b19269aa2520
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/modeling_flax_pegasus.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/modeling_pegasus.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/modeling_pegasus.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..72040283b4b6eb8e9be61c58975e152922403cc5
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/modeling_pegasus.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/modeling_tf_pegasus.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/modeling_tf_pegasus.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b803946ecb775848dd2fee780d24bc4ae8f7219a
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/modeling_tf_pegasus.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/tokenization_pegasus.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/tokenization_pegasus.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c3e94382ee5736de452b7263c89904dbd0d8cd63
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/tokenization_pegasus.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/tokenization_pegasus_fast.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/tokenization_pegasus_fast.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8d46b699c5a1e250779504c7bc1d0a9feb8e6ac5
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/tokenization_pegasus_fast.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/pegasus/configuration_pegasus.py b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/configuration_pegasus.py
new file mode 100644
index 0000000000000000000000000000000000000000..23bff5d7719f45061328757c916dbfcfaaa0c365
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/configuration_pegasus.py
@@ -0,0 +1,164 @@
+# coding=utf-8
+# Copyright 2021, Google and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PEGASUS model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class PegasusConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`PegasusModel`]. It is used to instantiate an
+ PEGASUS model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the PEGASUS
+ [google/pegasus-large](https://huggingface.co/google/pegasus-large) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 50265):
+ Vocabulary size of the PEGASUS model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`PegasusModel`] or [`TFPegasusModel`].
+ d_model (`int`, *optional*, defaults to 1024):
+ Dimensionality of the layers and the pooler layer.
+ encoder_layers (`int`, *optional*, defaults to 12):
+ Number of encoder layers.
+ decoder_layers (`int`, *optional*, defaults to 12):
+ Number of decoder layers.
+ encoder_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ decoder_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ decoder_ffn_dim (`int`, *optional*, defaults to 4096):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
+ encoder_ffn_dim (`int`, *optional*, defaults to 4096):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ activation_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for activations inside the fully connected layer.
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ init_std (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ encoder_layerdrop (`float`, *optional*, defaults to 0.0):
+ The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
+ for more details.
+ decoder_layerdrop (`float`, *optional*, defaults to 0.0):
+ The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
+ for more details.
+ scale_embedding (`bool`, *optional*, defaults to `False`):
+ Scale embeddings by diving by sqrt(d_model).
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models)
+ forced_eos_token_id (`int`, *optional*, defaults to 1):
+ The id of the token to force as the last generated token when `max_length` is reached. Usually set to
+ `eos_token_id`.
+
+ Example:
+
+ ```python
+ >>> from transformers import PegasusConfig, PegasusModel
+
+ >>> # Initializing a PEGASUS google/pegasus-large style configuration
+ >>> configuration = PegasusConfig()
+
+ >>> # Initializing a model (with random weights) from the google/pegasus-large style configuration
+ >>> model = PegasusModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "pegasus"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
+
+ def __init__(
+ self,
+ vocab_size=50265,
+ max_position_embeddings=1024,
+ encoder_layers=12,
+ encoder_ffn_dim=4096,
+ encoder_attention_heads=16,
+ decoder_layers=12,
+ decoder_ffn_dim=4096,
+ decoder_attention_heads=16,
+ encoder_layerdrop=0.0,
+ decoder_layerdrop=0.0,
+ use_cache=True,
+ is_encoder_decoder=True,
+ activation_function="gelu",
+ d_model=1024,
+ dropout=0.1,
+ attention_dropout=0.0,
+ activation_dropout=0.0,
+ init_std=0.02,
+ decoder_start_token_id=0,
+ scale_embedding=False,
+ pad_token_id=0,
+ eos_token_id=1,
+ forced_eos_token_id=1,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.d_model = d_model
+ self.encoder_ffn_dim = encoder_ffn_dim
+ self.encoder_layers = encoder_layers
+ self.encoder_attention_heads = encoder_attention_heads
+ self.decoder_ffn_dim = decoder_ffn_dim
+ self.decoder_layers = decoder_layers
+ self.decoder_attention_heads = decoder_attention_heads
+ self.dropout = dropout
+ self.attention_dropout = attention_dropout
+ self.activation_dropout = activation_dropout
+ self.activation_function = activation_function
+ self.init_std = init_std
+ self.encoder_layerdrop = encoder_layerdrop
+ self.decoder_layerdrop = decoder_layerdrop
+ self.use_cache = use_cache
+ self.num_hidden_layers = encoder_layers
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
+ super().__init__(
+ pad_token_id=pad_token_id,
+ eos_token_id=eos_token_id,
+ is_encoder_decoder=is_encoder_decoder,
+ decoder_start_token_id=decoder_start_token_id,
+ forced_eos_token_id=forced_eos_token_id,
+ **kwargs,
+ )
+
+ @property
+ def num_attention_heads(self) -> int:
+ return self.encoder_attention_heads
+
+ @property
+ def hidden_size(self) -> int:
+ return self.d_model
+
+
+__all__ = ["PegasusConfig"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/pegasus/modeling_flax_pegasus.py b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/modeling_flax_pegasus.py
new file mode 100644
index 0000000000000000000000000000000000000000..269a3268fed12770177f133b14177f7ca6347e98
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/modeling_flax_pegasus.py
@@ -0,0 +1,1532 @@
+# coding=utf-8
+# Copyright 2021, Google and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Flax PEGASUS model."""
+
+import math
+import random
+from functools import partial
+from typing import Callable, Optional, Tuple
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+import numpy as np
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+from flax.linen import combine_masks, make_causal_mask
+from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
+from jax import lax
+from jax.random import PRNGKey
+
+from ...modeling_flax_outputs import (
+ FlaxBaseModelOutput,
+ FlaxBaseModelOutputWithPastAndCrossAttentions,
+ FlaxCausalLMOutputWithCrossAttentions,
+ FlaxSeq2SeqLMOutput,
+ FlaxSeq2SeqModelOutput,
+)
+from ...modeling_flax_utils import (
+ ACT2FN,
+ FlaxPreTrainedModel,
+ add_start_docstrings_to_model_forward,
+ append_call_sample_docstring,
+ append_replace_return_docstrings,
+ overwrite_call_docstring,
+)
+from ...utils import add_start_docstrings, logging, replace_return_docstrings
+from .configuration_pegasus import PegasusConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "google/pegasus-large"
+_CONFIG_FOR_DOC = "PegasusConfig"
+
+PEGASUS_START_DOCSTRING = r"""
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a Flax Linen
+ [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
+ regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
+
+ Finally, this model supports inherent JAX features such as:
+
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
+
+ Parameters:
+ config ([`PegasusConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
+ `jax.numpy.bfloat16` (on TPUs).
+
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
+ specified all the computation will be performed with the given `dtype`.
+
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
+ parameters.**
+
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
+ [`~FlaxPreTrainedModel.to_bf16`].
+"""
+
+PEGASUS_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Indices of decoder input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+ decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+
+ If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the
+ paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
+ position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+ decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
+ range `[0, config.max_position_embeddings - 1]`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+PEGASUS_ENCODE_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+PEGASUS_DECODE_INPUTS_DOCSTRING = r"""
+ Args:
+ decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`):
+ Indices of decoder input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+ encoder_outputs (`tuple(tuple(jnp.ndarray)`):
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+ encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+
+ If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the
+ paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
+ decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
+ range `[0, config.max_position_embeddings - 1]`.
+ past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
+def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+ """
+ Shift input ids one token to the right.
+ """
+ shifted_input_ids = jnp.zeros_like(input_ids)
+ shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1])
+ shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id)
+
+ shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
+ return shifted_input_ids
+
+
+# Copied from transformers.models.marian.modeling_flax_marian.create_sinusoidal_positions
+def create_sinusoidal_positions(n_pos, dim):
+ position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
+ sentinel = dim // 2 + dim % 2
+ out = np.zeros_like(position_enc)
+ out[:, 0:sentinel] = np.sin(position_enc[:, 0::2])
+ out[:, sentinel:] = np.cos(position_enc[:, 1::2])
+
+ return jnp.array(out)
+
+
+# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention with Bart->Pegasus
+class FlaxPegasusAttention(nn.Module):
+ config: PegasusConfig
+ embed_dim: int
+ num_heads: int
+ dropout: float = 0.0
+ causal: bool = False
+ bias: bool = True
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self) -> None:
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+
+ dense = partial(
+ nn.Dense,
+ self.embed_dim,
+ use_bias=self.bias,
+ dtype=self.dtype,
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
+ )
+
+ self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
+ self.out_proj = dense()
+
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
+
+ if self.causal:
+ self.causal_mask = make_causal_mask(
+ jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
+ )
+
+ def _split_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
+
+ def _merge_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
+
+ @nn.compact
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
+ """
+ This function takes projected key, value states from a single input token and concatenates the states to cached
+ states from previous steps. This function is slighly adapted from the official Flax repository:
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
+ """
+ # detect if we're initializing by absence of existing cache data.
+ is_initialized = self.has_variable("cache", "cached_key")
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
+
+ if is_initialized:
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
+ # update key, value caches with our new 1d spatial slices
+ cur_index = cache_index.value
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
+ cached_key.value = key
+ cached_value.value = value
+ num_updated_cache_vectors = query.shape[1]
+ cache_index.value = cache_index.value + num_updated_cache_vectors
+ # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
+ pad_mask = jnp.broadcast_to(
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
+ )
+ attention_mask = combine_masks(pad_mask, attention_mask)
+ return key, value, attention_mask
+
+ def __call__(
+ self,
+ hidden_states: jnp.ndarray,
+ key_value_states: Optional[jnp.ndarray] = None,
+ attention_mask: Optional[jnp.ndarray] = None,
+ init_cache: bool = False,
+ deterministic: bool = True,
+ ) -> Tuple[jnp.ndarray]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+ batch_size = hidden_states.shape[0]
+
+ # get query proj
+ query_states = self.q_proj(hidden_states)
+ # get key, value proj
+ if is_cross_attention:
+ # cross_attentions
+ key_states = self.k_proj(key_value_states)
+ value_states = self.v_proj(key_value_states)
+ else:
+ # self_attention
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = self._split_heads(query_states)
+ key_states = self._split_heads(key_states)
+ value_states = self._split_heads(value_states)
+
+ # handle cache prepare causal attention mask
+ if self.causal:
+ query_length, key_length = query_states.shape[1], key_states.shape[1]
+ if self.has_variable("cache", "cached_key"):
+ mask_shift = self.variables["cache"]["cache_index"]
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
+ causal_mask = lax.dynamic_slice(
+ self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
+ )
+ else:
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
+
+ # combine masks if needed
+ if attention_mask is not None and self.causal:
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
+ attention_mask = combine_masks(attention_mask, causal_mask)
+ elif self.causal:
+ attention_mask = causal_mask
+ elif attention_mask is not None:
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
+
+ # During fast autoregressive decoding, we feed one position at a time,
+ # and cache the keys and values step by step.
+ if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
+ key_states, value_states, attention_mask = self._concatenate_to_cache(
+ key_states, value_states, query_states, attention_mask
+ )
+
+ # Convert the boolean attention mask to an attention bias.
+ if attention_mask is not None:
+ # attention mask in the form of attention bias
+ attention_bias = lax.select(
+ attention_mask > 0,
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
+ jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
+ )
+ else:
+ attention_bias = None
+
+ dropout_rng = None
+ if not deterministic and self.dropout > 0.0:
+ dropout_rng = self.make_rng("dropout")
+
+ attn_weights = dot_product_attention_weights(
+ query_states,
+ key_states,
+ bias=attention_bias,
+ dropout_rng=dropout_rng,
+ dropout_rate=self.dropout,
+ broadcast_dropout=True,
+ deterministic=deterministic,
+ dtype=self.dtype,
+ precision=None,
+ )
+
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
+ attn_output = self._merge_heads(attn_output)
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartEncoderLayer with MBart->Pegasus
+class FlaxPegasusEncoderLayer(nn.Module):
+ config: PegasusConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self) -> None:
+ self.embed_dim = self.config.d_model
+ self.self_attn = FlaxPegasusAttention(
+ config=self.config,
+ embed_dim=self.embed_dim,
+ num_heads=self.config.encoder_attention_heads,
+ dropout=self.config.attention_dropout,
+ dtype=self.dtype,
+ )
+ self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
+ self.activation_fn = ACT2FN[self.config.activation_function]
+ self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
+ self.fc1 = nn.Dense(
+ self.config.encoder_ffn_dim,
+ dtype=self.dtype,
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
+ )
+ self.fc2 = nn.Dense(
+ self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
+ )
+ self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
+
+ def __call__(
+ self,
+ hidden_states: jnp.ndarray,
+ attention_mask: jnp.ndarray,
+ output_attentions: bool = True,
+ deterministic: bool = True,
+ ) -> Tuple[jnp.ndarray]:
+ residual = hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+ hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask)
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartEncoderLayerCollection with Bart->Pegasus
+class FlaxPegasusEncoderLayerCollection(nn.Module):
+ config: PegasusConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.layers = [
+ FlaxPegasusEncoderLayer(self.config, name=str(i), dtype=self.dtype)
+ for i in range(self.config.encoder_layers)
+ ]
+ self.layerdrop = self.config.encoder_layerdrop
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask,
+ deterministic: bool = True,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ all_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+
+ for encoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ dropout_probability = random.uniform(0, 1)
+ if not deterministic and (dropout_probability < self.layerdrop): # skip the layer
+ layer_outputs = (None, None)
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ output_attentions,
+ deterministic,
+ )
+ hidden_states = layer_outputs[0]
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ outputs = (hidden_states, all_hidden_states, all_attentions)
+
+ if not return_dict:
+ return tuple(v for v in outputs if v is not None)
+
+ return FlaxBaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
+ )
+
+
+# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartDecoderLayer with MBart->Pegasus
+class FlaxPegasusDecoderLayer(nn.Module):
+ config: PegasusConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self) -> None:
+ self.embed_dim = self.config.d_model
+ self.self_attn = FlaxPegasusAttention(
+ config=self.config,
+ embed_dim=self.embed_dim,
+ num_heads=self.config.decoder_attention_heads,
+ dropout=self.config.attention_dropout,
+ causal=True,
+ dtype=self.dtype,
+ )
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
+ self.activation_fn = ACT2FN[self.config.activation_function]
+ self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
+
+ self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
+ self.encoder_attn = FlaxPegasusAttention(
+ config=self.config,
+ embed_dim=self.embed_dim,
+ num_heads=self.config.decoder_attention_heads,
+ dropout=self.config.attention_dropout,
+ dtype=self.dtype,
+ )
+ self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
+ self.fc1 = nn.Dense(
+ self.config.decoder_ffn_dim,
+ dtype=self.dtype,
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
+ )
+ self.fc2 = nn.Dense(
+ self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
+ )
+ self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
+
+ def __call__(
+ self,
+ hidden_states: jnp.ndarray,
+ attention_mask: jnp.ndarray,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ init_cache: bool = False,
+ output_attentions: bool = True,
+ deterministic: bool = True,
+ ) -> Tuple[jnp.ndarray]:
+ residual = hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache
+ )
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+ hidden_states = residual + hidden_states
+
+ # Cross-Attention Block
+ cross_attn_weights = None
+ if encoder_hidden_states is not None:
+ residual = hidden_states
+
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
+ hidden_states, cross_attn_weights = self.encoder_attn(
+ hidden_states=hidden_states,
+ key_value_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ )
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights, cross_attn_weights)
+
+ return outputs
+
+
+# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderLayerCollection with Bart->Pegasus
+class FlaxPegasusDecoderLayerCollection(nn.Module):
+ config: PegasusConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.layers = [
+ FlaxPegasusDecoderLayer(self.config, name=str(i), dtype=self.dtype)
+ for i in range(self.config.decoder_layers)
+ ]
+ self.layerdrop = self.config.decoder_layerdrop
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ deterministic: bool = True,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ):
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ dropout_probability = random.uniform(0, 1)
+ if not deterministic and (dropout_probability < self.layerdrop):
+ layer_outputs = (None, None, None)
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ init_cache=init_cache,
+ output_attentions=output_attentions,
+ deterministic=deterministic,
+ )
+
+ hidden_states = layer_outputs[0]
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if encoder_hidden_states is not None:
+ all_cross_attentions += (layer_outputs[2],)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions]
+
+ if not return_dict:
+ return tuple(v for v in outputs if v is not None)
+
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+class FlaxPegasusEncoder(nn.Module):
+ config: PegasusConfig
+ embed_tokens: nn.Embed
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
+
+ embed_dim = self.config.d_model
+ self.padding_idx = self.config.pad_token_id
+ self.max_source_positions = self.config.max_position_embeddings
+ self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0
+
+ self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim)
+ self.layers = FlaxPegasusEncoderLayerCollection(self.config, self.dtype)
+ self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ position_ids,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ deterministic: bool = True,
+ ):
+ input_shape = input_ids.shape
+ input_ids = input_ids.reshape(-1, input_shape[-1])
+
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
+ # embed positions
+ embed_pos = jnp.take(self.embed_positions, position_ids, axis=0)
+ # explicitly cast the positions here, since self.embed_positions are not registered as parameters
+ embed_pos = embed_pos.astype(inputs_embeds.dtype)
+
+ hidden_states = inputs_embeds + embed_pos
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+ outputs = self.layers(
+ hidden_states,
+ attention_mask,
+ deterministic=deterministic,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ last_hidden_state = outputs[0]
+ last_hidden_state = self.layer_norm(last_hidden_state)
+
+ # update the last element in `hidden_states` after applying `layernorm` above
+ hidden_states = None
+ if output_hidden_states:
+ hidden_states = outputs[1]
+ hidden_states = hidden_states[:-1] + (last_hidden_state,)
+
+ if not return_dict:
+ outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
+ return tuple(v for v in outputs if v is not None)
+
+ return FlaxBaseModelOutput(
+ last_hidden_state=last_hidden_state,
+ hidden_states=hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class FlaxPegasusDecoder(nn.Module):
+ config: PegasusConfig
+ embed_tokens: nn.Embed
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
+
+ embed_dim = self.config.d_model
+ self.padding_idx = self.config.pad_token_id
+ self.max_target_positions = self.config.max_position_embeddings
+ self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
+
+ self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim)
+
+ self.layers = FlaxPegasusDecoderLayerCollection(self.config, self.dtype)
+ self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ position_ids,
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ deterministic: bool = True,
+ ):
+ input_shape = input_ids.shape
+ input_ids = input_ids.reshape(-1, input_shape[-1])
+
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
+ # embed positions
+ positions = jnp.take(self.embed_positions, position_ids, axis=0)
+ # explicitly cast the positions here, since self.embed_positions are not registered as parameters
+ positions = positions.astype(inputs_embeds.dtype)
+
+ hidden_states = inputs_embeds + positions
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+ outputs = self.layers(
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ last_hidden_state = outputs[0]
+ last_hidden_state = self.layer_norm(last_hidden_state)
+
+ # update the last element in `hidden_states` after applying `layernorm` above
+ hidden_states = None
+ if output_hidden_states:
+ hidden_states = outputs[1]
+ hidden_states = hidden_states[:-1] + (last_hidden_state,)
+
+ if not return_dict:
+ outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
+ return tuple(v for v in outputs if v is not None)
+
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=last_hidden_state,
+ hidden_states=hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+
+# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartModule with Bart->Pegasus
+class FlaxPegasusModule(nn.Module):
+ config: PegasusConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.shared = nn.Embed(
+ self.config.vocab_size,
+ self.config.d_model,
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
+ dtype=self.dtype,
+ )
+
+ self.encoder = FlaxPegasusEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
+ self.decoder = FlaxPegasusDecoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
+
+ def _get_encoder_module(self):
+ return self.encoder
+
+ def _get_decoder_module(self):
+ return self.decoder
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ decoder_input_ids,
+ decoder_attention_mask,
+ position_ids,
+ decoder_position_ids,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ deterministic: bool = True,
+ ):
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=deterministic,
+ )
+
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ position_ids=decoder_position_ids,
+ encoder_hidden_states=encoder_outputs[0],
+ encoder_attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=deterministic,
+ )
+
+ if not return_dict:
+ return decoder_outputs + encoder_outputs
+
+ return FlaxSeq2SeqModelOutput(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
+
+class FlaxPegasusPreTrainedModel(FlaxPreTrainedModel):
+ config_class = PegasusConfig
+ base_model_prefix: str = "model"
+ module_class: nn.Module = None
+
+ def __init__(
+ self,
+ config: PegasusConfig,
+ input_shape: Tuple[int] = (1, 1),
+ seed: int = 0,
+ dtype: jnp.dtype = jnp.float32,
+ _do_init: bool = True,
+ **kwargs,
+ ):
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
+
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
+ # init input tensors
+ input_ids = jnp.zeros(input_shape, dtype="i4")
+ attention_mask = jnp.ones_like(input_ids)
+ decoder_input_ids = input_ids
+ decoder_attention_mask = jnp.ones_like(input_ids)
+
+ batch_size, sequence_length = input_ids.shape
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+ decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+
+ params_rng, dropout_rng = jax.random.split(rng)
+ rngs = {"params": params_rng, "dropout": dropout_rng}
+
+ random_params = self.module.init(
+ rngs,
+ input_ids,
+ attention_mask,
+ decoder_input_ids,
+ decoder_attention_mask,
+ position_ids,
+ decoder_position_ids,
+ )["params"]
+
+ if params is not None:
+ random_params = flatten_dict(unfreeze(random_params))
+ params = flatten_dict(unfreeze(params))
+ for missing_key in self._missing_keys:
+ params[missing_key] = random_params[missing_key]
+ self._missing_keys = set()
+ return freeze(unflatten_dict(params))
+ else:
+ return random_params
+
+ def init_cache(self, batch_size, max_length, encoder_outputs):
+ r"""
+ Args:
+ batch_size (`int`):
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
+ max_length (`int`):
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
+ cache.
+ encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
+ `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
+ `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*)
+ is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
+ cross-attention of the decoder.
+ """
+ # init input variables to retrieve cache
+ decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
+ decoder_position_ids = jnp.broadcast_to(
+ jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
+ )
+
+ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
+ decoder_module = module._get_decoder_module()
+ return decoder_module(
+ decoder_input_ids,
+ decoder_attention_mask,
+ decoder_position_ids,
+ **kwargs,
+ )
+
+ init_variables = self.module.init(
+ jax.random.PRNGKey(0),
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ decoder_position_ids=decoder_position_ids,
+ encoder_hidden_states=encoder_outputs[0],
+ init_cache=True,
+ method=_decoder_forward, # we only need to call the decoder to init the cache
+ )
+ return unfreeze(init_variables["cache"])
+
+ @add_start_docstrings(PEGASUS_ENCODE_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=PegasusConfig)
+ def encode(
+ self,
+ input_ids: jnp.ndarray,
+ attention_mask: Optional[jnp.ndarray] = None,
+ position_ids: Optional[jnp.ndarray] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ train: bool = False,
+ params: dict = None,
+ dropout_rng: PRNGKey = None,
+ ):
+ r"""
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration
+
+ >>> model = FlaxPegasusForConditionalGeneration.from_pretrained("google/pegasus-large")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")
+
+ >>> text = "My friends are cool but they eat too many carbs."
+ >>> inputs = tokenizer(text, max_length=1024, return_tensors="np")
+ >>> encoder_outputs = model.encode(**inputs)
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ if attention_mask is None:
+ attention_mask = jnp.ones_like(input_ids)
+ if position_ids is None:
+ batch_size, sequence_length = input_ids.shape
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+
+ # Handle any PRNG if needed
+ rngs = {}
+ if dropout_rng is not None:
+ rngs["dropout"] = dropout_rng
+
+ def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs):
+ encode_module = module._get_encoder_module()
+ return encode_module(input_ids, attention_mask, position_ids, **kwargs)
+
+ return self.module.apply(
+ {"params": params or self.params},
+ input_ids=jnp.array(input_ids, dtype="i4"),
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
+ position_ids=jnp.array(position_ids, dtype="i4"),
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=not train,
+ rngs=rngs,
+ method=_encoder_forward,
+ )
+
+ @add_start_docstrings(PEGASUS_DECODE_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=PegasusConfig)
+ def decode(
+ self,
+ decoder_input_ids,
+ encoder_outputs,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
+ decoder_position_ids: Optional[jnp.ndarray] = None,
+ past_key_values: dict = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ train: bool = False,
+ params: dict = None,
+ dropout_rng: PRNGKey = None,
+ ):
+ r"""
+ Returns:
+
+ Example:
+
+ ```python
+ >>> import jax.numpy as jnp
+ >>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration
+
+ >>> model = FlaxPegasusForConditionalGeneration.from_pretrained("google/pegasus-large")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")
+
+ >>> text = "My friends are cool but they eat too many carbs."
+ >>> inputs = tokenizer(text, max_length=1024, return_tensors="np")
+ >>> encoder_outputs = model.encode(**inputs)
+
+ >>> decoder_start_token_id = model.config.decoder_start_token_id
+ >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
+
+ >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
+ >>> last_decoder_hidden_states = outputs.last_hidden_state
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ encoder_hidden_states = encoder_outputs[0]
+ if encoder_attention_mask is None:
+ batch_size, sequence_length = encoder_hidden_states.shape[:2]
+ encoder_attention_mask = jnp.ones((batch_size, sequence_length))
+
+ batch_size, sequence_length = decoder_input_ids.shape
+ if decoder_attention_mask is None:
+ decoder_attention_mask = jnp.ones((batch_size, sequence_length))
+
+ if decoder_position_ids is None:
+ if past_key_values is not None:
+ raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
+
+ decoder_position_ids = jnp.broadcast_to(
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
+ )
+
+ # Handle any PRNG if needed
+ rngs = {}
+ if dropout_rng is not None:
+ rngs["dropout"] = dropout_rng
+
+ inputs = {"params": params or self.params}
+
+ # if past_key_values are passed then cache is already initialized a private flag init_cache has to be
+ # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
+ # it can be changed by FlaxPegasusAttention module
+ if past_key_values:
+ inputs["cache"] = past_key_values
+ mutable = ["cache"]
+ else:
+ mutable = False
+
+ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
+ decoder_module = module._get_decoder_module()
+ return decoder_module(
+ decoder_input_ids,
+ decoder_attention_mask,
+ decoder_position_ids,
+ **kwargs,
+ )
+
+ outputs = self.module.apply(
+ inputs,
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=not train,
+ rngs=rngs,
+ mutable=mutable,
+ method=_decoder_forward,
+ )
+
+ # add updated cache to model output
+ if past_key_values is not None and return_dict:
+ outputs, past = outputs
+ outputs["past_key_values"] = unfreeze(past["cache"])
+ return outputs
+ elif past_key_values is not None and not return_dict:
+ outputs, past = outputs
+ outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
+
+ return outputs
+
+ @add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING)
+ def __call__(
+ self,
+ input_ids: jnp.ndarray,
+ attention_mask: Optional[jnp.ndarray] = None,
+ decoder_input_ids: Optional[jnp.ndarray] = None,
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
+ position_ids: Optional[jnp.ndarray] = None,
+ decoder_position_ids: Optional[jnp.ndarray] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ train: bool = False,
+ params: dict = None,
+ dropout_rng: PRNGKey = None,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ # prepare encoder inputs
+ if attention_mask is None:
+ attention_mask = jnp.ones_like(input_ids)
+ if position_ids is None:
+ batch_size, sequence_length = input_ids.shape
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+
+ # prepare decoder inputs
+ if decoder_input_ids is None:
+ decoder_input_ids = shift_tokens_right(
+ input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id
+ )
+ if decoder_attention_mask is None:
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
+ if decoder_position_ids is None:
+ batch_size, sequence_length = decoder_input_ids.shape
+ decoder_position_ids = jnp.broadcast_to(
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
+ )
+
+ # Handle any PRNG if needed
+ rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
+
+ return self.module.apply(
+ {"params": params or self.params},
+ input_ids=jnp.array(input_ids, dtype="i4"),
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
+ position_ids=jnp.array(position_ids, dtype="i4"),
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=not train,
+ rngs=rngs,
+ )
+
+
+@add_start_docstrings(
+ "The bare Pegasus Model transformer outputting raw hidden-states without any specific head on top.",
+ PEGASUS_START_DOCSTRING,
+)
+class FlaxPegasusModel(FlaxPegasusPreTrainedModel):
+ config: PegasusConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+ module_class = FlaxPegasusModule
+
+
+append_call_sample_docstring(FlaxPegasusModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC)
+
+
+# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartForConditionalGenerationModule with Bart->Pegasus
+class FlaxPegasusForConditionalGenerationModule(nn.Module):
+ config: PegasusConfig
+ dtype: jnp.dtype = jnp.float32
+ bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros
+
+ def setup(self):
+ self.model = FlaxPegasusModule(config=self.config, dtype=self.dtype)
+ self.lm_head = nn.Dense(
+ self.model.shared.num_embeddings,
+ use_bias=False,
+ dtype=self.dtype,
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
+ )
+ self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings))
+
+ def _get_encoder_module(self):
+ return self.model.encoder
+
+ def _get_decoder_module(self):
+ return self.model.decoder
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ decoder_input_ids,
+ decoder_attention_mask,
+ position_ids,
+ decoder_position_ids,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ deterministic: bool = True,
+ ):
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ position_ids=position_ids,
+ decoder_position_ids=decoder_position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=deterministic,
+ )
+
+ hidden_states = outputs[0]
+
+ if self.config.tie_word_embeddings:
+ shared_embedding = self.model.variables["params"]["shared"]["embedding"]
+ lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
+ else:
+ lm_logits = self.lm_head(hidden_states)
+
+ lm_logits += jax.lax.stop_gradient(self.final_logits_bias.astype(self.dtype))
+
+ if not return_dict:
+ output = (lm_logits,) + outputs[1:]
+ return output
+
+ return FlaxSeq2SeqLMOutput(
+ logits=lm_logits,
+ decoder_hidden_states=outputs.decoder_hidden_states,
+ decoder_attentions=outputs.decoder_attentions,
+ cross_attentions=outputs.cross_attentions,
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+ encoder_hidden_states=outputs.encoder_hidden_states,
+ encoder_attentions=outputs.encoder_attentions,
+ )
+
+
+@add_start_docstrings(
+ "The PEGASUS Model with a language modeling head. Can be used for summarization.", PEGASUS_START_DOCSTRING
+)
+class FlaxPegasusForConditionalGeneration(FlaxPegasusPreTrainedModel):
+ module_class = FlaxPegasusForConditionalGenerationModule
+ dtype: jnp.dtype = jnp.float32
+
+ @add_start_docstrings(PEGASUS_DECODE_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=PegasusConfig)
+ def decode(
+ self,
+ decoder_input_ids,
+ encoder_outputs,
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
+ decoder_position_ids: Optional[jnp.ndarray] = None,
+ past_key_values: dict = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ deterministic: bool = True,
+ params: dict = None,
+ dropout_rng: PRNGKey = None,
+ ):
+ r"""
+ Returns:
+
+ Example:
+
+ ```python
+ >>> import jax.numpy as jnp
+ >>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration
+
+ >>> model = FlaxPegasusForConditionalGeneration.from_pretrained("google/pegasus-large")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")
+
+ >>> text = "My friends are cool but they eat too many carbs."
+ >>> inputs = tokenizer(text, max_length=1024, return_tensors="np")
+ >>> encoder_outputs = model.encode(**inputs)
+
+ >>> decoder_start_token_id = model.config.decoder_start_token_id
+ >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
+
+ >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
+ >>> logits = outputs.logits
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ encoder_hidden_states = encoder_outputs[0]
+ if encoder_attention_mask is None:
+ batch_size, sequence_length = encoder_hidden_states.shape[:2]
+ encoder_attention_mask = jnp.ones((batch_size, sequence_length))
+
+ batch_size, sequence_length = decoder_input_ids.shape
+ if decoder_attention_mask is None:
+ decoder_attention_mask = jnp.ones((batch_size, sequence_length))
+
+ if decoder_position_ids is None:
+ if past_key_values is not None:
+ raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
+
+ decoder_position_ids = jnp.broadcast_to(
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
+ )
+
+ # Handle any PRNG if needed
+ rngs = {}
+ if dropout_rng is not None:
+ rngs["dropout"] = dropout_rng
+
+ inputs = {"params": params or self.params}
+
+ # if past_key_values are passed then cache is already initialized a private flag init_cache has to be
+ # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
+ # it can be changed by FlaxPegasusAttention module
+ if past_key_values:
+ inputs["cache"] = past_key_values
+ mutable = ["cache"]
+ else:
+ mutable = False
+
+ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
+ decoder_module = module._get_decoder_module()
+ outputs = decoder_module(
+ decoder_input_ids,
+ decoder_attention_mask,
+ decoder_position_ids,
+ **kwargs,
+ )
+ hidden_states = outputs[0]
+
+ if self.config.tie_word_embeddings:
+ shared_embedding = module.model.variables["params"]["shared"]["embedding"]
+ lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
+ else:
+ lm_logits = module.lm_head(hidden_states)
+
+ lm_logits += module.final_logits_bias.astype(self.dtype)
+ return lm_logits, outputs
+
+ outputs = self.module.apply(
+ inputs,
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=deterministic,
+ rngs=rngs,
+ mutable=mutable,
+ method=_decoder_forward,
+ )
+
+ if past_key_values is None:
+ lm_logits, decoder_outputs = outputs
+ else:
+ (lm_logits, decoder_outputs), past = outputs
+
+ if return_dict:
+ outputs = FlaxCausalLMOutputWithCrossAttentions(
+ logits=lm_logits,
+ hidden_states=decoder_outputs.hidden_states,
+ attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ )
+ else:
+ outputs = (lm_logits,) + decoder_outputs[1:]
+
+ # add updated cache to model output
+ if past_key_values is not None and return_dict:
+ outputs["past_key_values"] = unfreeze(past["cache"])
+ return outputs
+ elif past_key_values is not None and not return_dict:
+ outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
+
+ return outputs
+
+ def prepare_inputs_for_generation(
+ self,
+ decoder_input_ids,
+ max_length,
+ attention_mask: Optional[jax.Array] = None,
+ decoder_attention_mask: Optional[jax.Array] = None,
+ encoder_outputs=None,
+ **kwargs,
+ ):
+ # initializing the cache
+ batch_size, seq_length = decoder_input_ids.shape
+
+ past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
+ # But since the decoder uses a causal mask, those positions are masked anyways.
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
+ if decoder_attention_mask is not None:
+ position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
+ else:
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
+
+ return {
+ "past_key_values": past_key_values,
+ "encoder_outputs": encoder_outputs,
+ "encoder_attention_mask": attention_mask,
+ "decoder_attention_mask": extended_attention_mask,
+ "decoder_position_ids": position_ids,
+ }
+
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
+ model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
+ return model_kwargs
+
+
+FLAX_PEGASUS_CONDITIONAL_GENERATION_DOCSTRING = """
+ Returns:
+
+ Summarization example:
+
+ ```pyton
+ >>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration
+
+ >>> model = FlaxPegasusForConditionalGeneration.from_pretrained('google/pegasus-large')
+ >>> tokenizer = AutoTokenizer.from_pretrained('google/pegasus-large')
+
+ >>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
+ >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='np')
+
+ >>> # Generate Summary
+ >>> summary_ids = model.generate(inputs['input_ids']).sequences
+ >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
+ ```
+
+ Mask filling example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, FlaxPegasusForConditionalGeneration
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")
+ >>> TXT = "My friends are but they eat too many carbs."
+
+ >>> model = FlaxPegasusForConditionalGeneration.from_pretrained("google/pegasus-large")
+ >>> input_ids = tokenizer([TXT], return_tensors="np")["input_ids"]
+ >>> logits = model(input_ids).logits
+
+ >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
+ >>> probs = jax.nn.softmax(logits[0, masked_index], axis=0)
+ >>> values, predictions = jax.lax.top_k(probs)
+
+ >>> tokenizer.decode(predictions).split()
+ ```
+"""
+
+overwrite_call_docstring(
+ FlaxPegasusForConditionalGeneration, PEGASUS_INPUTS_DOCSTRING + FLAX_PEGASUS_CONDITIONAL_GENERATION_DOCSTRING
+)
+append_replace_return_docstrings(
+ FlaxPegasusForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
+)
+
+
+__all__ = ["FlaxPegasusForConditionalGeneration", "FlaxPegasusModel", "FlaxPegasusPreTrainedModel"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/pegasus/modeling_pegasus.py b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/modeling_pegasus.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b1cd70404c7e63147a8686cc90205463dd856a3
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/modeling_pegasus.py
@@ -0,0 +1,1634 @@
+# coding=utf-8
+# Copyright 2021, Google and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch PEGASUS model."""
+
+import copy
+import math
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
+from ...modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPastAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ Seq2SeqLMOutput,
+ Seq2SeqModelOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ add_end_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_pegasus import PegasusConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "google/pegasus-large"
+_CONFIG_FOR_DOC = "PegasusConfig"
+
+
+# Copied from transformers.models.bart.modeling_bart.shift_tokens_right
+def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
+ """
+ Shift input ids one token to the right.
+ """
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+ shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
+ shifted_input_ids[:, 0] = decoder_start_token_id
+
+ if pad_token_id is None:
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
+ # replace possible -100 values in labels by `pad_token_id`
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+ return shifted_input_ids
+
+
+# Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->Pegasus
+class PegasusSinusoidalPositionalEmbedding(nn.Embedding):
+ """This module produces sinusoidal positional embeddings of any length."""
+
+ def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
+ super().__init__(num_positions, embedding_dim)
+ self.weight = self._init_weight(self.weight)
+
+ @staticmethod
+ def _init_weight(out: nn.Parameter) -> nn.Parameter:
+ """
+ Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
+ the 2nd half of the vector. [dim // 2:]
+ """
+ n_pos, dim = out.shape
+ position_enc = np.array(
+ [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
+ )
+ out.requires_grad = False # set early to avoid an error in pytorch-1.8+
+ sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
+ out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
+ out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
+ out.detach_()
+ return out
+
+ @torch.no_grad()
+ def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor:
+ """`input_ids_shape` is expected to be [bsz x seqlen]."""
+ bsz, seq_len = input_ids_shape[:2]
+ positions = torch.arange(
+ past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
+ )
+ return super().forward(positions)
+
+
+# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Pegasus
+class PegasusAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
+ is_decoder: bool = False,
+ bias: bool = True,
+ is_causal: bool = False,
+ config: Optional[PegasusConfig] = None,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ self.config = config
+
+ if (self.head_dim * num_heads) != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {num_heads})."
+ )
+ self.scaling = self.head_dim**-0.5
+ self.is_decoder = is_decoder
+ self.is_causal = is_causal
+
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, tgt_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self.q_proj(hidden_states) * self.scaling
+ # get key, value proj
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
+ # the provided `key_value_states` to support prefix tuning
+ if (
+ is_cross_attention
+ and past_key_value is not None
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
+ ):
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+ else:
+ # self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+ key_states = key_states.reshape(*proj_shape)
+ value_states = value_states.reshape(*proj_shape)
+
+ src_len = key_states.size(1)
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if layer_head_mask is not None:
+ if layer_head_mask.size() != (self.num_heads,):
+ raise ValueError(
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
+ )
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if output_attentions:
+ # this operation is a bit awkward, but it's required to
+ # make sure that attn_weights keeps its gradient.
+ # In order to do so, attn_weights have to be reshaped
+ # twice and have to be reused in the following
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+ else:
+ attn_weights_reshaped = None
+
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ attn_output = torch.bmm(attn_probs, value_states)
+
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+ attn_output = attn_output.transpose(1, 2)
+
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+ # partitioned across GPUs when using tensor-parallelism.
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights_reshaped, past_key_value
+
+
+PEGASUS_ATTENTION_CLASSES = {"eager": PegasusAttention}
+
+
+# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Pegasus, MBART->PEGASUS
+class PegasusEncoderLayer(nn.Module):
+ def __init__(self, config: PegasusConfig):
+ super().__init__()
+ self.embed_dim = config.d_model
+
+ self.self_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation](
+ embed_dim=self.embed_dim,
+ num_heads=config.encoder_attention_heads,
+ dropout=config.attention_dropout,
+ config=config,
+ )
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.activation_function]
+ self.activation_dropout = config.activation_dropout
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ layer_head_mask: torch.Tensor,
+ output_attentions: bool = False,
+ ) -> torch.Tensor:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+ `(encoder_attention_heads,)`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+ hidden_states, attn_weights, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ if hidden_states.dtype == torch.float16 and (
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
+ ):
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Pegasus, MBART->PEGASUS
+class PegasusDecoderLayer(nn.Module):
+ def __init__(self, config: PegasusConfig):
+ super().__init__()
+ self.embed_dim = config.d_model
+
+ self.self_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation](
+ embed_dim=self.embed_dim,
+ num_heads=config.decoder_attention_heads,
+ dropout=config.attention_dropout,
+ is_decoder=True,
+ is_causal=True,
+ config=config,
+ )
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.activation_function]
+ self.activation_dropout = config.activation_dropout
+
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.encoder_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation](
+ self.embed_dim,
+ config.decoder_attention_heads,
+ dropout=config.attention_dropout,
+ is_decoder=True,
+ config=config,
+ )
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = True,
+ ) -> torch.Tensor:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ encoder_hidden_states (`torch.FloatTensor`):
+ cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
+ encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+ `(encoder_attention_heads,)`.
+ cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
+ size `(decoder_attention_heads,)`.
+ past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Self Attention
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+ # add present self-attn cache to positions 1,2 of present_key_value tuple
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ past_key_value=self_attn_past_key_value,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ # Cross-Attention Block
+ cross_attn_present_key_value = None
+ cross_attn_weights = None
+ if encoder_hidden_states is not None:
+ residual = hidden_states
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+ # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+ hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
+ hidden_states=hidden_states,
+ key_value_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ layer_head_mask=cross_attn_layer_head_mask,
+ past_key_value=cross_attn_past_key_value,
+ output_attentions=output_attentions,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ # add cross-attn to positions 3,4 of present_key_value tuple
+ present_key_value = present_key_value + cross_attn_present_key_value
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights, cross_attn_weights)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+class PegasusPreTrainedModel(PreTrainedModel):
+ config_class = PegasusConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module):
+ std = self.config.init_std
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, PegasusSinusoidalPositionalEmbedding):
+ pass
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+PEGASUS_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`PegasusConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+PEGASUS_GENERATION_EXAMPLE = r"""
+ Summarization example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, PegasusForConditionalGeneration
+
+ >>> model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-xsum")
+
+ >>> ARTICLE_TO_SUMMARIZE = (
+ ... "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
+ ... "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
+ ... "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
+ ... )
+ >>> inputs = tokenizer(ARTICLE_TO_SUMMARIZE, max_length=1024, return_tensors="pt")
+
+ >>> # Generate Summary
+ >>> summary_ids = model.generate(inputs["input_ids"])
+ >>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "California's largest electricity provider has turned off power to hundreds of thousands of customers."
+ ```
+"""
+
+PEGASUS_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Indices of decoder input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+ Pegasus uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
+ `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+ decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
+ 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
+ representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
+ input (see `past_key_values`). This is useful if you want more control over how to convert
+ `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
+
+ If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
+ of `inputs_embeds`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+class PegasusEncoder(PegasusPreTrainedModel):
+ """
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
+ [`PegasusEncoderLayer`].
+
+ Args:
+ config: PegasusConfig
+ embed_tokens (nn.Embedding): output embedding
+ """
+
+ def __init__(self, config: PegasusConfig, embed_tokens: Optional[nn.Embedding] = None):
+ super().__init__(config)
+
+ self.dropout = config.dropout
+ self.layerdrop = config.encoder_layerdrop
+
+ embed_dim = config.d_model
+ self.padding_idx = config.pad_token_id
+ self.max_source_positions = config.max_position_embeddings
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+
+ if embed_tokens is not None:
+ self.embed_tokens = embed_tokens
+ else:
+ self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+
+ self.embed_positions = PegasusSinusoidalPositionalEmbedding(
+ config.max_position_embeddings,
+ embed_dim,
+ self.padding_idx,
+ )
+ self.layers = nn.ModuleList([PegasusEncoderLayer(config) for _ in range(config.encoder_layers)])
+ self.layer_norm = nn.LayerNorm(config.d_model)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def resize_position_embeddings(self, new_num_position_embeddings: int):
+ """
+ Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
+ config.max_position_embeddings`.
+
+ Arguments:
+ new_num_position_embeddings (`int`):
+ The number of new position embeddings. If position embeddings are learned, increasing the size will add
+ newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
+ position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
+ add correct vectors at the end following the position encoding algorithm, whereas reducing the size
+ will remove vectors from the end.
+ """
+ logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...")
+ self.config.max_position_embeddings = new_num_position_embeddings
+
+ self.embed_positions = PegasusSinusoidalPositionalEmbedding(
+ self.config.max_position_embeddings,
+ self.config.d_model,
+ self.padding_idx,
+ )
+ self.embed_positions.to(self.device)
+
+ def get_position_embeddings(self) -> nn.Embedding:
+ """
+ Returns the position embeddings matrix
+ """
+ return self.embed_positions
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ head_mask=None,
+ inputs_embeds=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
+ embed_pos = self.embed_positions(input_shape)
+
+ hidden_states = inputs_embeds + embed_pos
+
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ # expand attention_mask
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ # check if head_mask has a correct number of layers specified if desired
+ if head_mask is not None:
+ if head_mask.size()[0] != len(self.layers):
+ raise ValueError(
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
+ )
+ for idx, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ to_drop = False
+ if self.training:
+ dropout_probability = torch.rand([])
+ if dropout_probability < self.layerdrop: # skip the layer
+ to_drop = True
+
+ if to_drop:
+ layer_outputs = (None, None)
+ else:
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ encoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ (head_mask[idx] if head_mask is not None else None),
+ output_attentions,
+ )
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ hidden_states = self.layer_norm(hidden_states)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+
+class PegasusDecoder(PegasusPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`PegasusDecoderLayer`]
+
+ Args:
+ config: PegasusConfig
+ embed_tokens (nn.Embedding): output embedding
+ """
+
+ def __init__(self, config: PegasusConfig, embed_tokens: Optional[nn.Embedding] = None):
+ super().__init__(config)
+ self.dropout = config.dropout
+ self.layerdrop = config.decoder_layerdrop
+ self.padding_idx = config.pad_token_id
+ self.max_target_positions = config.max_position_embeddings
+ self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+
+ if embed_tokens is not None:
+ self.embed_tokens = embed_tokens
+ else:
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+
+ self.embed_positions = PegasusSinusoidalPositionalEmbedding(
+ config.max_position_embeddings,
+ config.d_model,
+ self.padding_idx,
+ )
+ self.layers = nn.ModuleList([PegasusDecoderLayer(config) for _ in range(config.decoder_layers)])
+ self.layer_norm = nn.LayerNorm(config.d_model)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ def resize_position_embeddings(self, new_num_position_embeddings: int):
+ """
+ Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
+ config.max_position_embeddings`.
+
+ Arguments:
+ new_num_position_embeddings (`int`):
+ The number of new position embeddings. If position embeddings are learned, increasing the size will add
+ newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
+ position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
+ add correct vectors at the end following the position encoding algorithm, whereas reducing the size
+ will remove vectors from the end.
+ """
+ logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...")
+ self.config.max_position_embeddings = new_num_position_embeddings
+
+ self.embed_positions = PegasusSinusoidalPositionalEmbedding(
+ self.config.max_position_embeddings,
+ self.config.d_model,
+ self.padding_idx,
+ )
+ self.embed_positions.to(self.device)
+
+ def get_position_embeddings(self) -> nn.Embedding:
+ """
+ Returns the position embeddings matrix
+ """
+ return self.embed_positions
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ head_mask=None,
+ cross_attn_head_mask=None,
+ past_key_values=None,
+ inputs_embeds=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+ of the decoder.
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
+ Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
+ selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules in decoder to avoid performing
+ cross-attention on hidden heads. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ # past_key_values_length
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
+ )
+
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask(
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+
+ # embed positions
+ positions = self.embed_positions(input_shape, past_key_values_length)
+
+ hidden_states = inputs_embeds + positions
+
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+ next_decoder_cache = () if use_cache else None
+
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+ for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+ if attn_mask is not None:
+ if attn_mask.size()[0] != len(self.layers):
+ raise ValueError(
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
+ )
+ for idx, decoder_layer in enumerate(self.layers):
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ if self.training:
+ dropout_probability = torch.rand([])
+ if dropout_probability < self.layerdrop:
+ continue
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ head_mask[idx] if head_mask is not None else None,
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
+ None,
+ output_attentions,
+ use_cache,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ cross_attn_layer_head_mask=(
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
+ ),
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if encoder_hidden_states is not None:
+ all_cross_attentions += (layer_outputs[2],)
+
+ hidden_states = self.layer_norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+@add_start_docstrings(
+ "The bare PEGASUS Model outputting raw hidden-states without any specific head on top.",
+ PEGASUS_START_DOCSTRING,
+)
+class PegasusModel(PegasusPreTrainedModel):
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
+ def __init__(self, config: PegasusConfig):
+ super().__init__(config)
+
+ padding_idx, vocab_size = config.pad_token_id, config.vocab_size
+ self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
+
+ self.encoder = PegasusEncoder(config, self.shared)
+ self.decoder = PegasusDecoder(config, self.shared)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.shared
+
+ def set_input_embeddings(self, value):
+ self.shared = value
+ self.encoder.embed_tokens = self.shared
+ self.decoder.embed_tokens = self.shared
+
+ def get_encoder(self):
+ return self.encoder
+
+ def get_decoder(self):
+ return self.decoder
+
+ def resize_position_embeddings(self, new_num_position_embeddings: int):
+ """
+ Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
+ config.max_position_embeddings`.
+
+ Arguments:
+ new_num_position_embeddings (`int`):
+ The number of new position embeddings. If position embeddings are learned, increasing the size will add
+ newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
+ position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
+ add correct vectors at the end following the position encoding algorithm, whereas reducing the size
+ will remove vectors from the end.
+ """
+ self.config.max_position_embeddings = new_num_position_embeddings
+ self.encoder.resize_position_embeddings(new_num_position_embeddings)
+ self.decoder.resize_position_embeddings(new_num_position_embeddings)
+
+ def get_position_embeddings(self) -> Tuple[nn.Embedding]:
+ """
+ Returns the position embeddings matrix
+ """
+ return (self.encoder.get_position_embeddings(), self.decoder.get_position_embeddings())
+
+ @add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ decoder_input_ids: Optional[torch.Tensor] = None,
+ decoder_attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ decoder_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
+ past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ decoder_inputs_embeds: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, Seq2SeqModelOutput]:
+ r"""
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, PegasusModel
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")
+ >>> model = PegasusModel.from_pretrained("google/pegasus-large")
+
+ >>> inputs = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt")
+ >>> decoder_inputs = tokenizer("Studies show that", return_tensors="pt")
+ >>> outputs = model(input_ids=inputs.input_ids, decoder_input_ids=decoder_inputs.input_ids)
+
+ >>> last_hidden_states = outputs.last_hidden_state
+ >>> list(last_hidden_states.shape)
+ [1, 4, 1024]
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if encoder_outputs is None:
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+ encoder_outputs = BaseModelOutput(
+ last_hidden_state=encoder_outputs[0],
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+ )
+
+ # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ encoder_hidden_states=encoder_outputs[0],
+ encoder_attention_mask=attention_mask,
+ head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if not return_dict:
+ return decoder_outputs + encoder_outputs
+
+ return Seq2SeqModelOutput(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ "The PEGASUS Model with a language modeling head. Can be used for summarization.", PEGASUS_START_DOCSTRING
+)
+class PegasusForConditionalGeneration(PegasusPreTrainedModel, GenerationMixin):
+ base_model_prefix = "model"
+ _keys_to_ignore_on_load_missing = ["final_logits_bias"]
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
+
+ def __init__(self, config: PegasusConfig):
+ super().__init__(config)
+ self.model = PegasusModel(config)
+ self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
+ self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_encoder(self):
+ return self.model.get_encoder()
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
+ new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+ self._resize_final_logits_bias(new_embeddings.weight.shape[0])
+ return new_embeddings
+
+ def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
+ old_num_tokens = self.final_logits_bias.shape[-1]
+ if new_num_tokens <= old_num_tokens:
+ new_bias = self.final_logits_bias[:, :new_num_tokens]
+ else:
+ extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
+ new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
+ self.register_buffer("final_logits_bias", new_bias)
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def resize_position_embeddings(self, new_num_position_embeddings: int):
+ """
+ Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
+ config.max_position_embeddings`.
+
+ Arguments:
+ new_num_position_embeddings (`int`):
+ The number of new position embeddings. If position embeddings are learned, increasing the size will add
+ newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
+ position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
+ add correct vectors at the end following the position encoding algorithm, whereas reducing the size
+ will remove vectors from the end.
+ """
+ self.config.max_position_embeddings = new_num_position_embeddings
+ self.model.encoder.resize_position_embeddings(new_num_position_embeddings)
+ self.model.decoder.resize_position_embeddings(new_num_position_embeddings)
+
+ def get_position_embeddings(self) -> Tuple[nn.Embedding]:
+ """
+ Returns the position embeddings matrix
+ """
+ return (self.model.encoder.get_position_embeddings(), self.model.decoder.get_position_embeddings())
+
+ @add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+ @add_end_docstrings(PEGASUS_GENERATION_EXAMPLE)
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ decoder_input_ids: Optional[torch.Tensor] = None,
+ decoder_attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ decoder_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
+ past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ decoder_inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, Seq2SeqLMOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if labels is not None:
+ if use_cache:
+ logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
+ use_cache = False
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
+ decoder_input_ids = shift_tokens_right(
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
+ )
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ encoder_outputs=encoder_outputs,
+ decoder_attention_mask=decoder_attention_mask,
+ head_mask=head_mask,
+ decoder_head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ decoder_inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
+
+ masked_lm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (lm_logits,) + outputs[1:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+ return Seq2SeqLMOutput(
+ loss=masked_lm_loss,
+ logits=lm_logits,
+ past_key_values=outputs.past_key_values,
+ decoder_hidden_states=outputs.decoder_hidden_states,
+ decoder_attentions=outputs.decoder_attentions,
+ cross_attentions=outputs.cross_attentions,
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+ encoder_hidden_states=outputs.encoder_hidden_states,
+ encoder_attentions=outputs.encoder_attentions,
+ )
+
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+ return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+
+ @staticmethod
+ def _reorder_cache(past_key_values, beam_idx):
+ reordered_past = ()
+ for layer_past in past_key_values:
+ # cached cross_attention states don't have to be reordered -> they are always the same
+ reordered_past += (
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+ + layer_past[2:],
+ )
+ return reordered_past
+
+
+# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Pegasus
+class PegasusDecoderWrapper(PegasusPreTrainedModel):
+ """
+ This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
+ used in combination with the [`EncoderDecoderModel`] framework.
+ """
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.decoder = PegasusDecoder(config)
+
+ def forward(self, *args, **kwargs):
+ return self.decoder(*args, **kwargs)
+
+
+class PegasusForCausalLM(PegasusPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ config = copy.deepcopy(config)
+ config.is_decoder = True
+ config.is_encoder_decoder = False
+ super().__init__(config)
+ self.model = PegasusDecoderWrapper(config)
+
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.decoder.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.decoder.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model.decoder = decoder
+
+ def get_decoder(self):
+ return self.model.decoder
+
+ def get_position_embeddings(self) -> nn.Embedding:
+ """
+ Returns the position embeddings matrix
+ """
+ return self.model.decoder.get_position_embeddings()
+
+ def resize_position_embeddings(self, new_num_position_embeddings: int):
+ """
+ Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
+ config.max_position_embeddings`.
+
+ Arguments:
+ new_num_position_embeddings (`int`):
+ The number of new position embeddings. If position embeddings are learned, increasing the size will add
+ newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
+ position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
+ add correct vectors at the end following the position encoding algorithm, whereas reducing the size
+ will remove vectors from the end.
+ """
+ self.config.max_position_embeddings = new_num_position_embeddings
+ self.model.decoder.resize_position_embeddings(new_num_position_embeddings)
+
+ @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
+ # Copied from transformers.models.bart.modeling_bart.BartForCausalLM.forward with Bart->Pegasus, facebook/bart-base->google/pegasus-large
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
+ r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+ if the model is configured as a decoder.
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
+ in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+ head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
+ tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, PegasusForCausalLM
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")
+ >>> model = PegasusForCausalLM.from_pretrained("google/pegasus-large", add_cross_attention=False)
+ >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+
+ >>> logits = outputs.logits
+ >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
+ >>> list(logits.shape) == expected_shape
+ True
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model.decoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ head_mask=head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ logits = self.lm_head(outputs[0])
+
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+ @staticmethod
+ def _reorder_cache(past_key_values, beam_idx):
+ reordered_past = ()
+ for layer_past in past_key_values:
+ reordered_past += (
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+ )
+ return reordered_past
+
+
+__all__ = ["PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel", "PegasusPreTrainedModel"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/pegasus/modeling_tf_pegasus.py b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/modeling_tf_pegasus.py
new file mode 100644
index 0000000000000000000000000000000000000000..a51835dcfa447f889913b3de7506d38af690b741
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/modeling_tf_pegasus.py
@@ -0,0 +1,1574 @@
+# coding=utf-8
+# Copyright 2021, Google Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TF 2.0 Pegasus model."""
+
+from __future__ import annotations
+
+import random
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...modeling_tf_outputs import (
+ TFBaseModelOutput,
+ TFBaseModelOutputWithPastAndCrossAttentions,
+ TFSeq2SeqLMOutput,
+ TFSeq2SeqModelOutput,
+)
+
+# Public API
+from ...modeling_tf_utils import (
+ TFCausalLanguageModelingLoss,
+ TFModelInputType,
+ TFPreTrainedModel,
+ keras,
+ keras_serializable,
+ unpack_inputs,
+)
+from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
+from ...utils import (
+ add_code_sample_docstrings,
+ add_end_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_pegasus import PegasusConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "google/pegasus-large"
+_CONFIG_FOR_DOC = "PegasusConfig"
+
+
+LARGE_NEGATIVE = -1e8
+
+
+# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right
+def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
+ pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
+ decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
+ start_tokens = tf.fill(
+ (shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype)
+ )
+ shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
+ # replace possible -100 values in labels by `pad_token_id`
+ shifted_input_ids = tf.where(
+ shifted_input_ids == -100,
+ tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)),
+ shifted_input_ids,
+ )
+
+ # "Verify that `labels` has only positive values and -100"
+ assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
+
+ # Make sure the assertion op is called by wrapping the result in an identity no-op
+ with tf.control_dependencies([assert_gte0]):
+ shifted_input_ids = tf.identity(shifted_input_ids)
+
+ return shifted_input_ids
+
+
+# Copied from transformers.models.bart.modeling_tf_bart._make_causal_mask
+def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0):
+ """
+ Make causal mask used for bi-directional self-attention.
+ """
+ bsz = input_ids_shape[0]
+ tgt_len = input_ids_shape[1]
+ mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
+ mask_cond = tf.range(shape_list(mask)[-1])
+
+ mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
+
+ if past_key_values_length > 0:
+ mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)
+
+ return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
+
+
+# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
+def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ src_len = shape_list(mask)[1]
+ tgt_len = tgt_len if tgt_len is not None else src_len
+ one_cst = tf.constant(1.0)
+ mask = tf.cast(mask, dtype=one_cst.dtype)
+ expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))
+
+ return (one_cst - expanded_mask) * LARGE_NEGATIVE
+
+
+# Copied from transformers.models.marian.modeling_tf_marian.TFMarianSinusoidalPositionalEmbedding with Marian->Pegasus
+class TFPegasusSinusoidalPositionalEmbedding(keras.layers.Layer):
+ """This module produces sinusoidal positional embeddings of any length."""
+
+ def __init__(self, num_positions: int, embedding_dim: int, **kwargs):
+ super().__init__(**kwargs)
+
+ if embedding_dim % 2 != 0:
+ raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported")
+
+ self.embedding_dim = embedding_dim
+ self.num_positions = num_positions
+
+ def build(self, input_shape: tf.TensorShape):
+ """
+ Build shared token embedding layer Shared weights logic adapted from
+ https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
+ """
+
+ weight = self._init_weight(self.num_positions, self.embedding_dim)
+
+ self.weight = self.add_weight(
+ name="embeddings",
+ shape=[self.num_positions, self.embedding_dim],
+ )
+ weight = tf.cast(weight, dtype=self.weight.dtype)
+
+ self.weight.assign(weight)
+
+ super().build(input_shape)
+
+ @staticmethod
+ def _init_weight(n_pos: int, dim: int):
+ """
+ Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
+ the 2nd half of the vector. [dim // 2:]
+ """
+ position_enc = np.array(
+ [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
+ )
+ table = np.zeros_like(position_enc)
+ # index 0 is all zero
+ table[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2])
+ table[:, dim // 2 :] = np.cos(position_enc[:, 1::2])
+ # convert to tensor
+ table = tf.convert_to_tensor(table)
+ tf.stop_gradient(table)
+ return table
+
+ def call(
+ self, input_shape: tf.TensorShape, past_key_values_length: int = 0, position_ids: tf.Tensor | None = None
+ ):
+ """Input is expected to be of size [bsz x seqlen]."""
+ if position_ids is None:
+ seq_len = input_shape[1]
+ position_ids = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
+ return tf.gather(self.weight, position_ids)
+
+
+# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->Pegasus
+class TFPegasusAttention(keras.layers.Layer):
+ """Multi-headed attention from "Attention Is All You Need"""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
+ is_decoder: bool = False,
+ bias: bool = True,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.embed_dim = embed_dim
+
+ self.num_heads = num_heads
+ self.dropout = keras.layers.Dropout(dropout)
+ self.head_dim = embed_dim // num_heads
+ if (self.head_dim * num_heads) != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {num_heads})."
+ )
+ self.scaling = self.head_dim**-0.5
+ self.is_decoder = is_decoder
+
+ self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj")
+ self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj")
+ self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj")
+ self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj")
+
+ def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int):
+ return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3))
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ key_value_states: tf.Tensor | None = None,
+ past_key_value: Tuple[Tuple[tf.Tensor]] | None = None,
+ attention_mask: tf.Tensor | None = None,
+ layer_head_mask: tf.Tensor | None = None,
+ training: Optional[bool] = False,
+ ) -> Tuple[tf.Tensor, tf.Tensor | None]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+ bsz, tgt_len, embed_dim = shape_list(hidden_states)
+
+ # get query proj
+ query_states = self.q_proj(hidden_states) * self.scaling
+ # get key, value proj
+ if is_cross_attention and past_key_value is not None:
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = tf.concat([past_key_value[0], key_states], axis=2)
+ value_states = tf.concat([past_key_value[1], value_states], axis=2)
+ else:
+ # self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape)
+ key_states = tf.reshape(key_states, proj_shape)
+ value_states = tf.reshape(value_states, proj_shape)
+
+ src_len = shape_list(key_states)[1]
+ attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
+
+ tf.debugging.assert_equal(
+ shape_list(attn_weights),
+ [bsz * self.num_heads, tgt_len, src_len],
+ message=(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {shape_list(attn_weights)}"
+ ),
+ )
+
+ if attention_mask is not None:
+ tf.debugging.assert_equal(
+ shape_list(attention_mask),
+ [bsz, 1, tgt_len, src_len],
+ message=(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {shape_list(attention_mask)}"
+ ),
+ )
+
+ attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
+ attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
+ attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
+
+ attn_weights = stable_softmax(attn_weights, axis=-1)
+
+ if layer_head_mask is not None:
+ tf.debugging.assert_equal(
+ shape_list(layer_head_mask),
+ [self.num_heads],
+ message=(
+ f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
+ f" {shape_list(layer_head_mask)}"
+ ),
+ )
+
+ attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
+ attn_weights, (bsz, self.num_heads, tgt_len, src_len)
+ )
+ attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
+
+ attn_probs = self.dropout(attn_weights, training=training)
+ attn_output = tf.matmul(attn_probs, value_states)
+
+ tf.debugging.assert_equal(
+ shape_list(attn_output),
+ [bsz * self.num_heads, tgt_len, self.head_dim],
+ message=(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {shape_list(attn_output)}"
+ ),
+ )
+
+ attn_output = tf.transpose(
+ tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
+ )
+ attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim))
+
+ attn_output = self.out_proj(attn_output)
+ attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len))
+
+ return attn_output, attn_weights, past_key_value
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "k_proj", None) is not None:
+ with tf.name_scope(self.k_proj.name):
+ self.k_proj.build([None, None, self.embed_dim])
+ if getattr(self, "q_proj", None) is not None:
+ with tf.name_scope(self.q_proj.name):
+ self.q_proj.build([None, None, self.embed_dim])
+ if getattr(self, "v_proj", None) is not None:
+ with tf.name_scope(self.v_proj.name):
+ self.v_proj.build([None, None, self.embed_dim])
+ if getattr(self, "out_proj", None) is not None:
+ with tf.name_scope(self.out_proj.name):
+ self.out_proj.build([None, None, self.embed_dim])
+
+
+# Copied from transformers.models.mbart.modeling_tf_mbart.TFMBartEncoderLayer with MBart->Pegasus
+class TFPegasusEncoderLayer(keras.layers.Layer):
+ def __init__(self, config: PegasusConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.embed_dim = config.d_model
+ self.self_attn = TFPegasusAttention(
+ self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn"
+ )
+ self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
+ self.dropout = keras.layers.Dropout(config.dropout)
+ self.activation_fn = get_tf_activation(config.activation_function)
+ self.activation_dropout = keras.layers.Dropout(config.activation_dropout)
+ self.fc1 = keras.layers.Dense(config.encoder_ffn_dim, name="fc1")
+ self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2")
+ self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
+ self.config = config
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ attention_mask: tf.Tensor,
+ layer_head_mask: tf.Tensor,
+ training: Optional[bool] = False,
+ ):
+ """
+ Args:
+ hidden_states (`tf.Tensor`): input to the layer of shape *(batch, seq_len, embed_dim)*
+ attention_mask (`tf.Tensor`): attention mask of size
+ *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
+ layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size
+ *(encoder_attention_heads,)*
+ """
+ residual = hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+ hidden_states, self_attn_weights, _ = self.self_attn(
+ hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
+ )
+
+ tf.debugging.assert_equal(
+ shape_list(hidden_states),
+ shape_list(residual),
+ message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
+ )
+
+ hidden_states = self.dropout(hidden_states, training=training)
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = self.activation_dropout(hidden_states, training=training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = self.dropout(hidden_states, training=training)
+ hidden_states = residual + hidden_states
+
+ return hidden_states, self_attn_weights
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "self_attn", None) is not None:
+ with tf.name_scope(self.self_attn.name):
+ self.self_attn.build(None)
+ if getattr(self, "self_attn_layer_norm", None) is not None:
+ with tf.name_scope(self.self_attn_layer_norm.name):
+ self.self_attn_layer_norm.build([None, None, self.embed_dim])
+ if getattr(self, "fc1", None) is not None:
+ with tf.name_scope(self.fc1.name):
+ self.fc1.build([None, None, self.embed_dim])
+ if getattr(self, "fc2", None) is not None:
+ with tf.name_scope(self.fc2.name):
+ self.fc2.build([None, None, self.config.encoder_ffn_dim])
+ if getattr(self, "final_layer_norm", None) is not None:
+ with tf.name_scope(self.final_layer_norm.name):
+ self.final_layer_norm.build([None, None, self.embed_dim])
+
+
+# Copied from transformers.models.mbart.modeling_tf_mbart.TFMBartDecoderLayer with MBart->Pegasus
+class TFPegasusDecoderLayer(keras.layers.Layer):
+ def __init__(self, config: PegasusConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.embed_dim = config.d_model
+ self.self_attn = TFPegasusAttention(
+ embed_dim=self.embed_dim,
+ num_heads=config.decoder_attention_heads,
+ dropout=config.attention_dropout,
+ name="self_attn",
+ is_decoder=True,
+ )
+ self.dropout = keras.layers.Dropout(config.dropout)
+ self.activation_fn = get_tf_activation(config.activation_function)
+ self.activation_dropout = keras.layers.Dropout(config.activation_dropout)
+
+ self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
+ self.encoder_attn = TFPegasusAttention(
+ self.embed_dim,
+ config.decoder_attention_heads,
+ dropout=config.attention_dropout,
+ name="encoder_attn",
+ is_decoder=True,
+ )
+ self.encoder_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm")
+ self.fc1 = keras.layers.Dense(config.decoder_ffn_dim, name="fc1")
+ self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2")
+ self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
+ self.config = config
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ attention_mask: tf.Tensor | None = None,
+ encoder_hidden_states: tf.Tensor | None = None,
+ encoder_attention_mask: tf.Tensor | None = None,
+ layer_head_mask: tf.Tensor | None = None,
+ cross_attn_layer_head_mask: tf.Tensor | None = None,
+ past_key_value: Tuple[tf.Tensor] | None = None,
+ training: Optional[bool] = False,
+ ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
+ """
+ Args:
+ hidden_states (`tf.Tensor`): input to the layer of shape *(batch, seq_len, embed_dim)*
+ attention_mask (`tf.Tensor`): attention mask of size
+ *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
+ encoder_hidden_states (`tf.Tensor`):
+ cross attention input to the layer of shape *(batch, seq_len, embed_dim)*
+ encoder_attention_mask (`tf.Tensor`): encoder attention mask of size
+ *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
+ layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size
+ *(decoder_attention_heads,)*
+ cross_attn_layer_head_mask (`tf.Tensor`): mask for heads of the cross-attention module.
+ *(decoder_attention_heads,)*
+ past_key_value (`Tuple(tf.Tensor)`): cached past key and value projection states
+ """
+ residual = hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Self Attention
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+ # add present self-attn cache to positions 1,2 of present_key_value tuple
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ past_key_value=self_attn_past_key_value,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ )
+ hidden_states = self.dropout(hidden_states, training=training)
+ hidden_states = residual + hidden_states
+
+ # Cross-Attention Block
+ cross_attn_present_key_value = None
+ cross_attn_weights = None
+ if encoder_hidden_states is not None:
+ residual = hidden_states
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+ # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+ hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
+ hidden_states=hidden_states,
+ key_value_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ layer_head_mask=cross_attn_layer_head_mask,
+ past_key_value=cross_attn_past_key_value,
+ )
+ hidden_states = self.dropout(hidden_states, training=training)
+ hidden_states = residual + hidden_states
+
+ # add cross-attn to positions 3,4 of present_key_value tuple
+ present_key_value = present_key_value + cross_attn_present_key_value
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = self.activation_dropout(hidden_states, training=training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = self.dropout(hidden_states, training=training)
+ hidden_states = residual + hidden_states
+
+ return (
+ hidden_states,
+ self_attn_weights,
+ cross_attn_weights,
+ present_key_value,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "self_attn", None) is not None:
+ with tf.name_scope(self.self_attn.name):
+ self.self_attn.build(None)
+ if getattr(self, "self_attn_layer_norm", None) is not None:
+ with tf.name_scope(self.self_attn_layer_norm.name):
+ self.self_attn_layer_norm.build([None, None, self.embed_dim])
+ if getattr(self, "encoder_attn", None) is not None:
+ with tf.name_scope(self.encoder_attn.name):
+ self.encoder_attn.build(None)
+ if getattr(self, "encoder_attn_layer_norm", None) is not None:
+ with tf.name_scope(self.encoder_attn_layer_norm.name):
+ self.encoder_attn_layer_norm.build([None, None, self.embed_dim])
+ if getattr(self, "fc1", None) is not None:
+ with tf.name_scope(self.fc1.name):
+ self.fc1.build([None, None, self.embed_dim])
+ if getattr(self, "fc2", None) is not None:
+ with tf.name_scope(self.fc2.name):
+ self.fc2.build([None, None, self.config.decoder_ffn_dim])
+ if getattr(self, "final_layer_norm", None) is not None:
+ with tf.name_scope(self.final_layer_norm.name):
+ self.final_layer_norm.build([None, None, self.embed_dim])
+
+
+class TFPegasusPreTrainedModel(TFPreTrainedModel):
+ config_class = PegasusConfig
+ base_model_prefix = "model"
+
+
+PEGASUS_START_DOCSTRING = r"""
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
+ behavior.
+
+
+
+ TensorFlow models and layers in `transformers` accept two formats as input:
+
+ - having all inputs as keyword arguments (like PyTorch models), or
+ - having all inputs as a list, tuple or dict in the first positional argument.
+
+ The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
+ and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+ pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
+ format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+ the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+ positional argument:
+
+ - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
+ - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+ `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
+ - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+ `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
+
+ Note that when creating models and layers with
+ [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
+ about any of this, as you can just pass inputs like you would to any other Python function!
+
+
+
+ Args:
+ config ([`PegasusConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+PEGASUS_GENERATION_EXAMPLE = r"""
+ Summarization example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, TFPegasusForConditionalGeneration
+
+ >>> model = TFPegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-xsum")
+
+ >>> ARTICLE_TO_SUMMARIZE = (
+ ... "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
+ ... "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
+ ... "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
+ ... )
+ >>> inputs = tokenizer(ARTICLE_TO_SUMMARIZE, max_length=1024, return_tensors="tf")
+
+ >>> # Generate Summary
+ >>> summary_ids = model.generate(input_ids)
+ >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
+ ```
+"""
+
+PEGASUS_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`tf.Tensor` of shape `({0})`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`tf.Tensor` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ decoder_input_ids (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Indices of decoder input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+ Pegasus uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
+ `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+ decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
+ decoder_position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
+ range `[0, config.max_position_embeddings - 1]`.
+ head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ decoder_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ encoder_outputs (`tf.FloatTensor`, *optional*):
+ hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+ of shape `(batch_size, sequence_length, hidden_size)` is a sequence of
+ past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`)
+ contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`). Set to `False` during training, `True` during generation output_attentions (`bool`,
+ *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions`
+ under returned tensors for more detail. This argument can be used only in eager mode, in graph mode the
+ value in the config will be used instead.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
+ config will be used instead.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
+ used instead.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+ eager mode, in graph mode the value will always be set to True.
+ training (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+"""
+
+
+@keras_serializable
+class TFPegasusEncoder(keras.layers.Layer):
+ config_class = PegasusConfig
+ """
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
+ [`TFPegasusEncoderLayer`].
+
+ Args:
+ config: PegasusConfig
+ """
+
+ def __init__(self, config: PegasusConfig, embed_tokens: Optional[keras.layers.Embedding] = None, **kwargs):
+ super().__init__(**kwargs)
+ self.config = config
+ self.dropout = keras.layers.Dropout(config.dropout)
+ self.layerdrop = config.encoder_layerdrop
+ self.padding_idx = config.pad_token_id
+ self.max_source_positions = config.max_position_embeddings
+ self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0
+
+ self.embed_tokens = embed_tokens
+ self.embed_positions = TFPegasusSinusoidalPositionalEmbedding(
+ config.max_position_embeddings,
+ config.d_model,
+ name="embed_positions",
+ )
+ self.layers = [TFPegasusEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
+ self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
+
+ def get_embed_tokens(self):
+ return self.embed_tokens
+
+ def set_embed_tokens(self, embed_tokens):
+ self.embed_tokens = embed_tokens
+
+ @unpack_inputs
+ def call(
+ self,
+ input_ids: tf.Tensor | None = None,
+ inputs_embeds: tf.Tensor | None = None,
+ attention_mask: tf.Tensor | None = None,
+ head_mask: tf.Tensor | None = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: Optional[bool] = False,
+ ):
+ """
+ Args:
+ input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, `optional):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail. This argument can be used only in eager mode, in graph mode the value
+ in the config will be used instead.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail. This argument can be used only in eager mode, in graph mode the value in the config
+ will be used instead.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used
+ in eager mode, in graph mode the value will always be set to True.
+ training (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+ """
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = shape_list(input_ids)
+ elif inputs_embeds is not None:
+ input_shape = shape_list(inputs_embeds)[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
+ embed_pos = self.embed_positions(input_shape)
+ hidden_states = inputs_embeds + embed_pos
+ hidden_states = self.dropout(hidden_states, training=training)
+
+ # check attention mask and invert
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _expand_mask(attention_mask)
+ else:
+ attention_mask = None
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ # check if head_mask has a correct number of layers specified if desired
+ if head_mask is not None:
+ tf.debugging.assert_equal(
+ shape_list(head_mask)[0],
+ len(self.layers),
+ message=(
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {shape_list(head_mask)[0]}."
+ ),
+ )
+
+ # encoder layers
+ for idx, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ dropout_probability = random.uniform(0, 1)
+ if training and (dropout_probability < self.layerdrop): # skip the layer
+ continue
+
+ hidden_states, attn = encoder_layer(
+ hidden_states,
+ attention_mask,
+ head_mask[idx] if head_mask is not None else None,
+ )
+
+ if output_attentions:
+ all_attentions += (attn,)
+
+ hidden_states = self.layer_norm(hidden_states)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return TFBaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "embed_positions", None) is not None:
+ with tf.name_scope(self.embed_positions.name):
+ self.embed_positions.build(None)
+ if getattr(self, "layer_norm", None) is not None:
+ with tf.name_scope(self.layer_norm.name):
+ self.layer_norm.build([None, None, self.config.d_model])
+ if getattr(self, "layers", None) is not None:
+ for layer in self.layers:
+ with tf.name_scope(layer.name):
+ layer.build(None)
+
+
+@keras_serializable
+class TFPegasusDecoder(keras.layers.Layer):
+ config_class = PegasusConfig
+ """
+ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TFPegasusDecoderLayer`]
+
+ Args:
+ config: PegasusConfig
+ embed_tokens: output embedding
+ """
+
+ def __init__(self, config: PegasusConfig, embed_tokens: Optional[keras.layers.Embedding] = None, **kwargs):
+ super().__init__(**kwargs)
+ self.config = config
+ self.padding_idx = config.pad_token_id
+ self.embed_tokens = embed_tokens
+ self.layerdrop = config.decoder_layerdrop
+ self.embed_positions = TFPegasusSinusoidalPositionalEmbedding(
+ config.max_position_embeddings,
+ config.d_model,
+ name="embed_positions",
+ )
+ self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0
+ self.layers = [TFPegasusDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)]
+ self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
+
+ self.dropout = keras.layers.Dropout(config.dropout)
+
+ def get_embed_tokens(self):
+ return self.embed_tokens
+
+ def set_embed_tokens(self, embed_tokens):
+ self.embed_tokens = embed_tokens
+
+ @unpack_inputs
+ def call(
+ self,
+ input_ids: tf.Tensor | None = None,
+ inputs_embeds: tf.Tensor | None = None,
+ attention_mask: tf.Tensor | None = None,
+ position_ids: tf.Tensor | None = None,
+ encoder_hidden_states: tf.Tensor | None = None,
+ encoder_attention_mask: tf.Tensor | None = None,
+ head_mask: tf.Tensor | None = None,
+ cross_attn_head_mask: tf.Tensor | None = None,
+ past_key_values: Tuple[Tuple[tf.Tensor]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: Optional[bool] = False,
+ ):
+ r"""
+ Args:
+ input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
+ range `[0, config.max_position_embeddings - 1]`.
+ encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+ of the decoder.
+ encoder_attention_mask (`tf.Tensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
+ Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
+ selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
+ decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail. This argument can be used only in eager mode, in graph mode the value
+ in the config will be used instead.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail. This argument can be used only in eager mode, in graph mode the value in the config
+ will be used instead.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used
+ in eager mode, in graph mode the value will always be set to True.
+ training (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+ """
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = shape_list(input_ids)
+ elif inputs_embeds is not None:
+ input_shape = shape_list(inputs_embeds)[:-1]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0
+
+ # embed positions
+ if position_ids is None:
+ positions = self.embed_positions(input_shape, past_key_values_length)
+ else:
+ positions = self.embed_positions(input_shape, position_ids=position_ids)
+
+ if inputs_embeds is None:
+ check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
+ hidden_states = inputs_embeds
+
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ if input_shape[-1] > 1:
+ combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
+ else:
+ combined_attention_mask = _expand_mask(
+ tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
+ )
+
+ if attention_mask is not None:
+ combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1])
+
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1])
+
+ hidden_states = self.dropout(hidden_states + positions, training=training)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attns = () if (output_attentions and encoder_hidden_states is not None) else None
+ present_key_values = () if use_cache else None
+
+ # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
+ for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
+ if attn_mask is not None:
+ tf.debugging.assert_equal(
+ shape_list(attn_mask)[0],
+ len(self.layers),
+ message=(
+ f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for"
+ f" {shape_list(attn_mask)[0]}."
+ ),
+ )
+
+ for idx, decoder_layer in enumerate(self.layers):
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ dropout_probability = random.uniform(0, 1)
+
+ if training and (dropout_probability < self.layerdrop):
+ continue
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
+ hidden_states,
+ attention_mask=combined_attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ layer_head_mask=head_mask[idx] if head_mask is not None else None,
+ cross_attn_layer_head_mask=cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
+ past_key_value=past_key_value,
+ )
+
+ if use_cache:
+ present_key_values += (present_key_value,)
+
+ if output_attentions:
+ all_self_attns += (layer_self_attn,)
+
+ if encoder_hidden_states is not None:
+ all_cross_attns += (layer_cross_attn,)
+
+ hidden_states = self.layer_norm(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if not return_dict:
+ return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
+ else:
+ return TFBaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=present_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attns,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "embed_positions", None) is not None:
+ with tf.name_scope(self.embed_positions.name):
+ self.embed_positions.build(None)
+ if getattr(self, "layer_norm", None) is not None:
+ with tf.name_scope(self.layer_norm.name):
+ self.layer_norm.build([None, None, self.config.d_model])
+ if getattr(self, "layers", None) is not None:
+ for layer in self.layers:
+ with tf.name_scope(layer.name):
+ layer.build(None)
+
+
+@keras_serializable
+class TFPegasusMainLayer(keras.layers.Layer):
+ config_class = PegasusConfig
+
+ def __init__(self, config: PegasusConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.config = config
+ self.shared = keras.layers.Embedding(
+ input_dim=config.vocab_size,
+ output_dim=config.d_model,
+ embeddings_initializer=keras.initializers.TruncatedNormal(stddev=self.config.init_std),
+ name="model.shared",
+ )
+ # Additional attribute to specify the expected name scope of the layer (for loading/storing weights)
+ self.shared.load_weight_prefix = "model.shared"
+
+ self.encoder = TFPegasusEncoder(config, self.shared, name="encoder")
+ self.decoder = TFPegasusDecoder(config, self.shared, name="decoder")
+
+ def get_input_embeddings(self):
+ return self.shared
+
+ def set_input_embeddings(self, new_embeddings):
+ self.shared = new_embeddings
+ self.encoder.embed_tokens = self.shared
+ self.decoder.embed_tokens = self.shared
+
+ @unpack_inputs
+ def call(
+ self,
+ input_ids: tf.Tensor | None = None,
+ attention_mask: tf.Tensor | None = None,
+ decoder_input_ids: tf.Tensor | None = None,
+ decoder_attention_mask: tf.Tensor | None = None,
+ decoder_position_ids: tf.Tensor | None = None,
+ head_mask: tf.Tensor | None = None,
+ decoder_head_mask: tf.Tensor | None = None,
+ cross_attn_head_mask: tf.Tensor | None = None,
+ encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
+ past_key_values: Tuple[Tuple[tf.Tensor]] = None,
+ inputs_embeds: tf.Tensor | None = None,
+ decoder_inputs_embeds: tf.Tensor | None = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: Optional[bool] = False,
+ **kwargs,
+ ):
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
+ use_cache = False
+
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ if encoder_outputs is None:
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True
+ elif return_dict and not isinstance(encoder_outputs, TFBaseModelOutput):
+ encoder_outputs = TFBaseModelOutput(
+ last_hidden_state=encoder_outputs[0],
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+ )
+ # If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False
+ elif not return_dict and not isinstance(encoder_outputs, tuple):
+ encoder_outputs = encoder_outputs.to_tuple()
+
+ decoder_outputs = self.decoder(
+ decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ position_ids=decoder_position_ids,
+ encoder_hidden_states=encoder_outputs[0],
+ encoder_attention_mask=attention_mask,
+ head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ if not return_dict:
+ return decoder_outputs + encoder_outputs
+
+ return TFSeq2SeqModelOutput(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ # The shared/tied weights expect to be in the model base namespace
+ # Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than
+ # the current one.
+ with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"):
+ self.shared.build(None)
+ if getattr(self, "encoder", None) is not None:
+ with tf.name_scope(self.encoder.name):
+ self.encoder.build(None)
+ if getattr(self, "decoder", None) is not None:
+ with tf.name_scope(self.decoder.name):
+ self.decoder.build(None)
+
+
+@add_start_docstrings(
+ "The bare PEGASUS Model outputting raw hidden-states without any specific head on top.",
+ PEGASUS_START_DOCSTRING,
+)
+class TFPegasusModel(TFPegasusPreTrainedModel):
+ def __init__(self, config: PegasusConfig, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ self.model = TFPegasusMainLayer(config, name="model")
+
+ def get_encoder(self):
+ return self.model.encoder
+
+ def get_decoder(self):
+ return self.model.decoder
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFSeq2SeqModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ decoder_input_ids: np.ndarray | tf.Tensor | None = None,
+ decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
+ decoder_position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ decoder_head_mask: np.ndarray | tf.Tensor | None = None,
+ cross_attn_head_mask: np.ndarray | tf.Tensor | None = None,
+ encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
+ past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: bool = False,
+ **kwargs,
+ ) -> Union[TFSeq2SeqModelOutput, Tuple[tf.Tensor]]:
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ decoder_position_ids=decoder_position_ids,
+ head_mask=head_mask,
+ decoder_head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ encoder_outputs=encoder_outputs,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ decoder_inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ return outputs
+
+ # Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output
+ def serving_output(self, output):
+ pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
+ dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
+ dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
+ cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
+ enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
+ enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
+
+ return TFSeq2SeqModelOutput(
+ last_hidden_state=output.last_hidden_state,
+ past_key_values=pkv,
+ decoder_hidden_states=dec_hs,
+ decoder_attentions=dec_attns,
+ cross_attentions=cross_attns,
+ encoder_last_hidden_state=output.encoder_last_hidden_state,
+ encoder_hidden_states=enc_hs,
+ encoder_attentions=enc_attns,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "model", None) is not None:
+ with tf.name_scope(self.model.name):
+ self.model.build(None)
+
+
+# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer
+class BiasLayer(keras.layers.Layer):
+ """
+ Bias as a layer. It is used for serialization purposes: `keras.Model.save_weights` stores on a per-layer basis,
+ so all weights have to be registered in a layer.
+ """
+
+ def __init__(self, shape, initializer, trainable, name, **kwargs):
+ super().__init__(name=name, **kwargs)
+ # Note: the name of this variable will NOT be scoped when serialized, i.e. it will not be in the format of
+ # "outer_layer/inner_layer/.../name:0". Instead, it will be "name:0". For further details, see:
+ # https://github.com/huggingface/transformers/pull/18833#issuecomment-1233090214
+ self.bias = self.add_weight(name=name, shape=shape, initializer=initializer, trainable=trainable)
+
+ def call(self, x):
+ return x + self.bias
+
+
+@add_start_docstrings(
+ "The PEGASUS Model with a language modeling head. Can be used for summarization.",
+ PEGASUS_START_DOCSTRING,
+)
+class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLanguageModelingLoss):
+ _keys_to_ignore_on_load_unexpected = [
+ r"model.encoder.embed_tokens.weight",
+ r"model.decoder.embed_tokens.weight",
+ ]
+
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+ self.model = TFPegasusMainLayer(config, name="model")
+ self.use_cache = config.use_cache
+ # final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
+ self.bias_layer = BiasLayer(
+ name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
+ )
+
+ def get_decoder(self):
+ return self.model.decoder
+
+ def get_encoder(self):
+ return self.model.encoder
+
+ def get_output_embeddings(self):
+ return self.get_input_embeddings()
+
+ def set_output_embeddings(self, value):
+ self.set_input_embeddings(value)
+
+ def get_bias(self):
+ return {"final_logits_bias": self.bias_layer.bias}
+
+ def set_bias(self, value):
+ # Replaces the existing layers containing bias for correct (de)serialization.
+ vocab_size = value["final_logits_bias"].shape[-1]
+ self.bias_layer = BiasLayer(
+ name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=False
+ )
+ self.bias_layer.bias.assign(value["final_logits_bias"])
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+ @add_end_docstrings(PEGASUS_GENERATION_EXAMPLE)
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ attention_mask: np.ndarray | tf.Tensor | None = None,
+ decoder_input_ids: np.ndarray | tf.Tensor | None = None,
+ decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
+ decoder_position_ids: np.ndarray | tf.Tensor | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ decoder_head_mask: np.ndarray | tf.Tensor | None = None,
+ cross_attn_head_mask: np.ndarray | tf.Tensor | None = None,
+ encoder_outputs: Optional[TFBaseModelOutput] = None,
+ past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ training: bool = False,
+ ) -> Union[TFSeq2SeqLMOutput, Tuple[tf.Tensor]]:
+ """
+ labels (`tf.tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ """
+
+ if labels is not None:
+ labels = tf.where(
+ labels == self.config.pad_token_id,
+ tf.cast(tf.fill(shape_list(labels), -100), labels.dtype),
+ labels,
+ )
+ use_cache = False
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
+ decoder_input_ids = shift_tokens_right(
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
+ )
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ encoder_outputs=encoder_outputs,
+ decoder_attention_mask=decoder_attention_mask,
+ decoder_position_ids=decoder_position_ids,
+ head_mask=head_mask,
+ decoder_head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ decoder_inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True)
+ lm_logits = self.bias_layer(lm_logits)
+ masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)
+
+ if not return_dict:
+ output = (lm_logits,) + outputs[1:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+ return TFSeq2SeqLMOutput(
+ loss=masked_lm_loss,
+ logits=lm_logits,
+ past_key_values=outputs.past_key_values, # index 1 of d outputs
+ decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
+ decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
+ cross_attentions=outputs.cross_attentions, # index 4 of d outputs
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs
+ encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
+ encoder_attentions=outputs.encoder_attentions, # 2 of e out
+ )
+
+ # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output
+ def serving_output(self, output):
+ pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
+ dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
+ dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
+ cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
+ enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
+ enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
+
+ return TFSeq2SeqLMOutput(
+ logits=output.logits,
+ past_key_values=pkv,
+ decoder_hidden_states=dec_hs,
+ decoder_attentions=dec_attns,
+ cross_attentions=cross_attns,
+ encoder_last_hidden_state=output.encoder_last_hidden_state,
+ encoder_hidden_states=enc_hs,
+ encoder_attentions=enc_attns,
+ )
+
+ # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation
+ def prepare_inputs_for_generation(
+ self,
+ decoder_input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ decoder_attention_mask=None,
+ head_mask=None,
+ decoder_head_mask=None,
+ cross_attn_head_mask=None,
+ use_cache=None,
+ encoder_outputs=None,
+ **kwargs,
+ ):
+ # cut decoder_input_ids if past_key_values is used
+ if past_key_values is not None:
+ decoder_input_ids = decoder_input_ids[:, -1:]
+
+ if decoder_attention_mask is not None: # xla
+ decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
+ elif past_key_values is not None: # no xla + past_key_values
+ decoder_position_ids = past_key_values[0][0].shape[2]
+ else: # no xla + no past_key_values
+ decoder_position_ids = tf.range(decoder_input_ids.shape[1])
+
+ return {
+ "input_ids": None, # encoder_outputs is defined. input_ids not needed
+ "encoder_outputs": encoder_outputs,
+ "past_key_values": past_key_values,
+ "decoder_input_ids": decoder_input_ids,
+ "attention_mask": attention_mask,
+ "decoder_attention_mask": decoder_attention_mask,
+ "decoder_position_ids": decoder_position_ids,
+ "head_mask": head_mask,
+ "decoder_head_mask": decoder_head_mask,
+ "cross_attn_head_mask": cross_attn_head_mask,
+ "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
+ }
+
+ def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
+ return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "model", None) is not None:
+ with tf.name_scope(self.model.name):
+ self.model.build(None)
+ if getattr(self, "bias_layer", None) is not None:
+ with tf.name_scope(self.bias_layer.name):
+ self.bias_layer.build(None)
+
+
+__all__ = ["TFPegasusForConditionalGeneration", "TFPegasusModel", "TFPegasusPreTrainedModel"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/pegasus/tokenization_pegasus.py b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/tokenization_pegasus.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bab5073322a291763fd4b75e538b2959cd45579
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/tokenization_pegasus.py
@@ -0,0 +1,288 @@
+# coding=utf-8
+# Copyright 2020 Google and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple
+
+import sentencepiece as spm
+
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
+from ...utils import logging
+
+
+SPIECE_UNDERLINE = "▁"
+
+VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
+
+
+logger = logging.get_logger(__name__)
+
+
+# TODO ArthurZ refactor this to only use the added_tokens_encoder
+class PegasusTokenizer(PreTrainedTokenizer):
+ r"""
+ Construct a PEGASUS tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
+
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+ this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
+ contains the vocabulary necessary to instantiate a tokenizer.
+ pad_token (`str`, *optional*, defaults to `""`):
+ The token used for padding, for example when batching sequences of different lengths.
+ eos_token (`str`, *optional*, defaults to `""`):
+ The end of sequence token.
+
+
+
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+ The token used is the `sep_token`.
+
+
+
+ unk_token (`str`, *optional*, defaults to `""`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ mask_token (`str`, *optional*, defaults to `""`):
+ The token used for masking single token values. This is the token used when training this model with masked
+ language modeling (MLM). This is the token that the PEGASUS encoder will try to predict during pretraining.
+ It corresponds to *[MASK2]* in [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive
+ Summarization](https://arxiv.org/pdf/1912.08777.pdf).
+ mask_token_sent (`str`, *optional*, defaults to `""`):
+ The token used for masking whole target sentences. This is the token used when training this model with gap
+ sentences generation (GSG). This is the sentence that the PEGASUS decoder will try to predict during
+ pretraining. It corresponds to *[MASK1]* in [PEGASUS: Pre-training with Extracted Gap-sentences for
+ Abstractive Summarization](https://arxiv.org/pdf/1912.08777.pdf).
+ additional_special_tokens (`List[str]`, *optional*):
+ Additional special tokens used by the tokenizer. If no additional_special_tokens are provided and
+ are used as additional special tokens corresponding to the [original PEGASUS
+ tokenizer](https://github.com/google-research/pegasus/blob/939830367bcf411193d2b5eca2f2f90f3f9260ca/pegasus/ops/pretrain_parsing_ops.cc#L66)
+ that uses the tokens 2 - 104 only for pretraining
+ sp_model_kwargs (`dict`, *optional*):
+ Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
+ SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
+ to set:
+
+ - `enable_sampling`: Enable subword regularization.
+ - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+
+ - `nbest_size = {0,1}`: No sampling is performed.
+ - `nbest_size > 1`: samples from the nbest_size results.
+ - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
+ using forward-filtering-and-backward-sampling algorithm.
+
+ - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+ BPE-dropout.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ vocab_file,
+ pad_token="",
+ eos_token="",
+ unk_token="",
+ mask_token="",
+ mask_token_sent="",
+ additional_special_tokens=None,
+ offset=103, # entries 2 - 104 are only used for pretraining
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
+ **kwargs,
+ ) -> None:
+ self.offset = offset
+ if additional_special_tokens is not None:
+ if not isinstance(additional_special_tokens, list):
+ raise TypeError(
+ f"additional_special_tokens should be of type {type(list)}, but is"
+ f" {type(additional_special_tokens)}"
+ )
+ additional_special_tokens_extended = (
+ ([mask_token_sent] + additional_special_tokens)
+ if mask_token_sent not in additional_special_tokens and mask_token_sent is not None
+ else additional_special_tokens
+ )
+ # fill additional tokens with ..., in case not all additional tokens are already taken
+ additional_special_tokens_extended += [
+ f"" for i in range(len(additional_special_tokens_extended), self.offset - 1)
+ ]
+
+ if len(set(additional_special_tokens_extended)) != len(additional_special_tokens_extended):
+ raise ValueError(
+ "Please make sure that the provided additional_special_tokens do not contain an incorrectly"
+ f" shifted list of tokens. Found {additional_special_tokens_extended}."
+ )
+ additional_special_tokens = additional_special_tokens_extended
+ else:
+ additional_special_tokens_extended = []
+ additional_special_tokens = [mask_token_sent] if mask_token_sent is not None else []
+ additional_special_tokens += [f"" for i in range(2, self.offset)]
+
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+ self.mask_token_sent = mask_token_sent
+ self.vocab_file = vocab_file
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+ self.sp_model.Load(vocab_file)
+
+ _added_tokens_decoder = {
+ 0: AddedToken(str(pad_token), special=True),
+ 1: AddedToken(str(eos_token), special=True),
+ }
+
+ if self.mask_token_sent is not None:
+ _added_tokens_decoder[2] = AddedToken(mask_token_sent, special=True)
+ _added_tokens_decoder[3] = AddedToken(str(mask_token), special=True)
+
+ for i in range(2, self.offset):
+ _added_tokens_decoder[len(_added_tokens_decoder)] = AddedToken(f"", special=True)
+
+ # Force update as we want to make sure vocab is enforced (same as fast)
+ self._added_tokens_decoder = kwargs.pop("added_tokens_decoder", {})
+ self._added_tokens_decoder.update(_added_tokens_decoder)
+
+ super().__init__(
+ eos_token=eos_token,
+ unk_token=unk_token,
+ mask_token=mask_token,
+ pad_token=pad_token,
+ mask_token_sent=mask_token_sent,
+ offset=offset,
+ additional_special_tokens=additional_special_tokens,
+ sp_model_kwargs=self.sp_model_kwargs,
+ **kwargs,
+ )
+
+ @property
+ def vocab_size(self) -> int:
+ return len(self.sp_model) + self.offset
+
+ def get_vocab(self) -> Dict[str, int]:
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
+ def __getstate__(self):
+ state = self.__dict__.copy()
+ state["sp_model"] = None
+ return state
+
+ def __setstate__(self, d):
+ self.__dict__ = d
+
+ # for backward compatibility
+ if not hasattr(self, "sp_model_kwargs"):
+ self.sp_model_kwargs = {}
+
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+ self.sp_model.Load(self.vocab_file)
+
+ def _tokenize(self, text: str) -> List[str]:
+ """Take as input a string and return a list of strings (tokens) for words/sub-words"""
+ return self.sp_model.encode(text, out_type=str)
+
+ def _convert_token_to_id(self, token: str) -> int:
+ """Converts a token (str) to an id using the vocab."""
+ sp_id = self.sp_model.piece_to_id(token)
+ return sp_id + self.offset
+
+ def _convert_id_to_token(self, index: int) -> str:
+ """Converts an index (integer) to a token (str) using the vocab."""
+ if index < self.offset:
+ return self.sp_model.IdToPiece(index)
+ token = self.sp_model.IdToPiece(index - self.offset)
+ return token
+
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+ current_sub_tokens = []
+ out_string = ""
+ for token in tokens:
+ # make sure that special tokens are not decoded using sentencepiece model
+ if token in self.all_special_tokens:
+ out_string += self.sp_model.decode(current_sub_tokens) + token
+ current_sub_tokens = []
+ else:
+ current_sub_tokens.append(token)
+ out_string += self.sp_model.decode(current_sub_tokens)
+ return out_string.strip()
+
+ def num_special_tokens_to_add(self, pair=False):
+ """Just EOS"""
+ return 1
+
+ def _special_token_mask(self, seq):
+ all_special_ids = set(self.all_special_ids) # call it once instead of inside list comp
+ all_special_ids.remove(self.unk_token_id) # is only sometimes special
+
+ return [1 if x in all_special_ids else 0 for x in seq]
+
+ def get_special_tokens_mask(
+ self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
+ ) -> List[int]:
+ """Get list where entries are [1] if a token is [eos] or [pad] else 0."""
+ if already_has_special_tokens:
+ return self._special_token_mask(token_ids_0)
+ elif token_ids_1 is None:
+ return self._special_token_mask(token_ids_0) + [1]
+ else:
+ return self._special_token_mask(token_ids_0 + token_ids_1) + [1]
+
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
+ """
+ Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+ and adding special tokens. A PEGASUS sequence has the following format, where `X` represents the sequence:
+
+ - single sequence: `X `
+ - pair of sequences: `A B ` (not intended use)
+
+ BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
+ separator.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ if token_ids_1 is None:
+ return token_ids_0 + [self.eos_token_id]
+ # We don't expect to process pairs, but leave the pair logic for API consistency
+ return token_ids_0 + token_ids_1 + [self.eos_token_id]
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ out_vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+ elif not os.path.isfile(self.vocab_file):
+ with open(out_vocab_file, "wb") as fi:
+ content_spiece_model = self.sp_model.serialized_model_proto()
+ fi.write(content_spiece_model)
+
+ return (out_vocab_file,)
+
+
+__all__ = ["PegasusTokenizer"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/pegasus/tokenization_pegasus_fast.py b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/tokenization_pegasus_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..af62976cb751450699c389fad874811517124e0e
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/pegasus/tokenization_pegasus_fast.py
@@ -0,0 +1,219 @@
+# coding=utf-8
+# Copyright 2020 Google and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization class for model PEGASUS."""
+
+import os
+from shutil import copyfile
+from typing import List, Optional, Tuple
+
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import is_sentencepiece_available, logging
+
+
+if is_sentencepiece_available():
+ from .tokenization_pegasus import PegasusTokenizer
+else:
+ PegasusTokenizer = None
+
+
+logger = logging.get_logger(__name__)
+
+
+SPIECE_UNDERLINE = "▁"
+
+VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"}
+
+
+class PegasusTokenizerFast(PreTrainedTokenizerFast):
+ r"""
+ Construct a "fast" PEGASUS tokenizer (backed by HuggingFace's *tokenizers* library). Based on
+ [Unigram](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=unigram#models).
+
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
+ contains the vocabulary necessary to instantiate a tokenizer.
+ pad_token (`str`, *optional*, defaults to `""`):
+ The token used for padding, for example when batching sequences of different lengths.
+ eos_token (`str`, *optional*, defaults to `""`):
+ The end of sequence token.
+
+
+
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+ The token used is the `sep_token`.
+
+
+
+ unk_token (`str`, *optional*, defaults to `""`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ mask_token (`str`, *optional*, defaults to `""`):
+ The token used for masking single token values. This is the token used when training this model with masked
+ language modeling (MLM). This is the token that the PEGASUS encoder will try to predict during pretraining.
+ It corresponds to *[MASK2]* in [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive
+ Summarization](https://arxiv.org/pdf/1912.08777.pdf).
+ mask_token_sent (`str`, *optional*, defaults to `""`):
+ The token used for masking whole target sentences. This is the token used when training this model with gap
+ sentences generation (GSG). This is the sentence that the PEGASUS decoder will try to predict during
+ pretraining. It corresponds to *[MASK1]* in [PEGASUS: Pre-training with Extracted Gap-sentences for
+ Abstractive Summarization](https://arxiv.org/pdf/1912.08777.pdf).
+ additional_special_tokens (`List[str]`, *optional*):
+ Additional special tokens used by the tokenizer. If no additional_special_tokens are provided and
+ are used as additional special tokens corresponding to the [original PEGASUS
+ tokenizer](https://github.com/google-research/pegasus/blob/939830367bcf411193d2b5eca2f2f90f3f9260ca/pegasus/ops/pretrain_parsing_ops.cc#L66)
+ that uses the tokens 2 - 104 only for pretraining
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ slow_tokenizer_class = PegasusTokenizer
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ vocab_file=None,
+ tokenizer_file=None,
+ pad_token="",
+ eos_token="",
+ unk_token="",
+ mask_token="",
+ mask_token_sent="",
+ additional_special_tokens=None,
+ offset=103, # entries 2 - 104 are only used for pretraining
+ **kwargs,
+ ):
+ self.offset = offset
+
+ if additional_special_tokens is not None:
+ if not isinstance(additional_special_tokens, list):
+ raise TypeError(
+ f"additional_special_tokens should be of type {type(list)}, but is"
+ f" {type(additional_special_tokens)}"
+ )
+
+ additional_special_tokens_extended = (
+ ([mask_token_sent] + additional_special_tokens)
+ if mask_token_sent not in additional_special_tokens and mask_token_sent is not None
+ else additional_special_tokens
+ )
+ # fill additional tokens with ..., in case not all additional tokens are already taken
+ additional_special_tokens_extended += [
+ f"" for i in range(len(additional_special_tokens_extended), self.offset - 1)
+ ]
+
+ if len(set(additional_special_tokens_extended)) != len(additional_special_tokens_extended):
+ raise ValueError(
+ "Please make sure that the provided additional_special_tokens do not contain an incorrectly"
+ f" shifted list of tokens. Found {additional_special_tokens_extended}."
+ )
+ additional_special_tokens = additional_special_tokens_extended
+ else:
+ additional_special_tokens = [mask_token_sent] if mask_token_sent is not None else []
+ additional_special_tokens += [f"" for i in range(2, self.offset)]
+
+ # pegasus was design to support changing the index of the first tokens. If one of the padding/eos/unk/mask token
+ # is different from default, we must rebuild the vocab
+ from_slow = kwargs.pop("from_slow", None)
+ from_slow = from_slow or str(pad_token) != "" or str(eos_token) != "" or str(unk_token) != ""
+
+ kwargs.pop("added_tokens_decoder", {})
+
+ super().__init__(
+ vocab_file,
+ tokenizer_file=tokenizer_file,
+ pad_token=pad_token,
+ eos_token=eos_token,
+ unk_token=unk_token,
+ mask_token=mask_token,
+ mask_token_sent=mask_token_sent,
+ offset=offset,
+ additional_special_tokens=additional_special_tokens,
+ from_slow=from_slow,
+ **kwargs,
+ )
+ self.vocab_file = vocab_file
+
+ @property
+ def can_save_slow_tokenizer(self) -> bool:
+ return os.path.isfile(self.vocab_file) if self.vocab_file else False
+
+ def _special_token_mask(self, seq):
+ all_special_ids = set(self.all_special_ids) # call it once instead of inside list comp
+ all_special_ids.remove(self.unk_token_id) # is only sometimes special
+
+ if all_special_ids != set(range(len(self.additional_special_tokens) + 3)):
+ raise ValueError(
+ "There should be 3 special tokens: mask_token, pad_token, and eos_token +"
+ f" {len(self.additional_special_tokens)} additional_special_tokens, but got {all_special_ids}"
+ )
+
+ return [1 if x in all_special_ids else 0 for x in seq]
+
+ def get_special_tokens_mask(
+ self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
+ ) -> List[int]:
+ """Get list where entries are [1] if a token is [eos] or [pad] else 0."""
+ if already_has_special_tokens:
+ return self._special_token_mask(token_ids_0)
+ elif token_ids_1 is None:
+ return self._special_token_mask(token_ids_0) + [1]
+ else:
+ return self._special_token_mask(token_ids_0 + token_ids_1) + [1]
+
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
+ """
+ Build model inputs from a sequence by adding eos to the end. no bos token is added to the front.
+
+ - single sequence: `X `
+ - pair of sequences: `A B ` (not intended use)
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs to which the special tokens will be added
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ if token_ids_1 is None:
+ return token_ids_0 + [self.eos_token_id]
+ # We don't expect to process pairs, but leave the pair logic for API consistency
+ return token_ids_0 + token_ids_1 + [self.eos_token_id]
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+ if not self.can_save_slow_tokenizer:
+ raise ValueError(
+ "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+ "tokenizer."
+ )
+
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ out_vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+
+ return (out_vocab_file,)
+
+
+__all__ = ["PegasusTokenizerFast"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/__init__.py b/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f69ff8762af0e322d7038e00b892a1dbb0914e45
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/__init__.py
@@ -0,0 +1,54 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
+
+
+_import_structure = {"configuration_vitpose_backbone": ["VitPoseBackboneConfig"]}
+
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_vitpose_backbone"] = [
+ "VitPoseBackbonePreTrainedModel",
+ "VitPoseBackbone",
+ ]
+
+if TYPE_CHECKING:
+ from .configuration_vitpose_backbone import VitPoseBackboneConfig
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_vitpose_backbone import (
+ VitPoseBackbone,
+ VitPoseBackbonePreTrainedModel,
+ )
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c53e43a22bb2ba429c4fa1c45a01ee4d027449cb
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/__pycache__/configuration_vitpose_backbone.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/__pycache__/configuration_vitpose_backbone.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6febfd970839a75905e59e621c1c62ebcd9a9d80
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/__pycache__/configuration_vitpose_backbone.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/__pycache__/modeling_vitpose_backbone.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/__pycache__/modeling_vitpose_backbone.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..58ad2c3133bf79c9d8c6eb308bb3338f9f893203
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/__pycache__/modeling_vitpose_backbone.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/configuration_vitpose_backbone.py b/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/configuration_vitpose_backbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..2872d39a2a30df940307978b4c258a2c6cdfb811
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/configuration_vitpose_backbone.py
@@ -0,0 +1,136 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""VitPose backbone configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
+
+
+logger = logging.get_logger(__name__)
+
+
+class VitPoseBackboneConfig(BackboneConfigMixin, PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`VitPoseBackbone`]. It is used to instantiate a
+ VitPose model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the VitPose
+ [usyd-community/vitpose-base-simple](https://huggingface.co/usyd-community/vitpose-base-simple) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ image_size (`int`, *optional*, defaults to `[256, 192]`):
+ The size (resolution) of each image.
+ patch_size (`List[int]`, *optional*, defaults to `[16, 16]`):
+ The size (resolution) of each patch.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ mlp_ratio (`int`, *optional*, defaults to 4):
+ The ratio of the hidden size in the feedforward network to the hidden size in the attention layers.
+ num_experts (`int`, *optional*, defaults to 1):
+ The number of experts in the MoE layer.
+ part_features (`int`, *optional*):
+ The number of part features to output. Only used in case `num_experts` is greater than 1.
+ hidden_act (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the queries, keys and values.
+ out_features (`List[str]`, *optional*):
+ If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+ (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+ corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
+ same order as defined in the `stage_names` attribute.
+ out_indices (`List[int]`, *optional*):
+ If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+ many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+ If unset and `out_features` is unset, will default to the last stage. Must be in the
+ same order as defined in the `stage_names` attribute.
+
+ Example:
+
+ ```python
+ >>> from transformers import VitPoseBackboneConfig, VitPoseBackbone
+
+ >>> # Initializing a VitPose configuration
+ >>> configuration = VitPoseBackboneConfig()
+
+ >>> # Initializing a model (with random weights) from the configuration
+ >>> model = VitPoseBackbone(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "vitpose_backbone"
+
+ def __init__(
+ self,
+ image_size=[256, 192],
+ patch_size=[16, 16],
+ num_channels=3,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ mlp_ratio=4,
+ num_experts=1,
+ part_features=256,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ qkv_bias=True,
+ out_features=None,
+ out_indices=None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.mlp_ratio = mlp_ratio
+ self.num_experts = num_experts
+ self.part_features = part_features
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.qkv_bias = qkv_bias
+ self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, num_hidden_layers + 1)]
+ self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+ out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+ )
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py b/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3f057988370f06ded97df5c4555276ca06ea7d8
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py
@@ -0,0 +1,542 @@
+# coding=utf-8
+# Copyright 2024 University of Sydney and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch VitPose backbone model.
+
+This code is the same as the original Vision Transformer (ViT) with 2 modifications:
+- use of padding=2 in the patch embedding layer
+- addition of a mixture-of-experts MLP layer
+"""
+
+import collections.abc
+import math
+from typing import Optional, Set, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_outputs import BackboneOutput, BaseModelOutput
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from ...utils.backbone_utils import BackboneMixin
+from .configuration_vitpose_backbone import VitPoseBackboneConfig
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "VitPoseBackboneConfig"
+
+
+class VitPoseBackbonePatchEmbeddings(nn.Module):
+ """Image to Patch Embedding."""
+
+ def __init__(self, config):
+ super().__init__()
+
+ image_size = config.image_size
+ patch_size = config.patch_size
+ num_channels = config.num_channels
+ embed_dim = config.hidden_size
+
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+
+ self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size, padding=2)
+
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+ height, width = pixel_values.shape[-2:]
+ if height != self.image_size[0] or width != self.image_size[1]:
+ raise ValueError(
+ f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
+ )
+ embeddings = self.projection(pixel_values)
+
+ embeddings = embeddings.flatten(2).transpose(1, 2)
+ return embeddings
+
+
+class VitPoseBackboneEmbeddings(nn.Module):
+ """
+ Construct the position and patch embeddings.
+ """
+
+ def __init__(self, config: VitPoseBackboneConfig) -> None:
+ super().__init__()
+
+ self.patch_embeddings = VitPoseBackbonePatchEmbeddings(config)
+ num_patches = self.patch_embeddings.num_patches
+ self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+ embeddings = self.patch_embeddings(pixel_values)
+
+ # add positional encoding to each token
+ embeddings = embeddings + self.position_embeddings[:, 1:] + self.position_embeddings[:, :1]
+
+ embeddings = self.dropout(embeddings)
+
+ return embeddings
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->VitPoseBackbone
+class VitPoseBackboneSelfAttention(nn.Module):
+ def __init__(self, config: VitPoseBackboneConfig) -> None:
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
+ f"heads {config.num_attention_heads}."
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ mixed_query_layer = self.query(hidden_states)
+
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ return outputs
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->VitPoseBackbone
+class VitPoseBackboneSelfOutput(nn.Module):
+ """
+ The residual connection is defined in VitPoseBackboneLayer instead of here (as is the case with other models), due to the
+ layernorm applied before each block.
+ """
+
+ def __init__(self, config: VitPoseBackboneConfig) -> None:
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ return hidden_states
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->VitPoseBackbone
+class VitPoseBackboneAttention(nn.Module):
+ def __init__(self, config: VitPoseBackboneConfig) -> None:
+ super().__init__()
+ self.attention = VitPoseBackboneSelfAttention(config)
+ self.output = VitPoseBackboneSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads: Set[int]) -> None:
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.attention.query = prune_linear_layer(self.attention.query, index)
+ self.attention.key = prune_linear_layer(self.attention.key, index)
+ self.attention.value = prune_linear_layer(self.attention.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ self_outputs = self.attention(hidden_states, head_mask, output_attentions)
+
+ attention_output = self.output(self_outputs[0], hidden_states)
+
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+class VitPoseBackboneMoeMLP(nn.Module):
+ def __init__(self, config: VitPoseBackboneConfig):
+ super().__init__()
+
+ in_features = out_features = config.hidden_size
+ hidden_features = int(config.hidden_size * config.mlp_ratio)
+
+ num_experts = config.num_experts
+ part_features = config.part_features
+
+ self.part_features = part_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = ACT2FN[config.hidden_act]
+ self.fc2 = nn.Linear(hidden_features, out_features - part_features)
+ self.drop = nn.Dropout(config.hidden_dropout_prob)
+
+ self.num_experts = num_experts
+ experts = [nn.Linear(hidden_features, part_features) for _ in range(num_experts)]
+ self.experts = nn.ModuleList(experts)
+
+ def forward(self, hidden_state: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
+ expert_hidden_state = torch.zeros_like(hidden_state[:, :, -self.part_features :])
+
+ hidden_state = self.fc1(hidden_state)
+ hidden_state = self.act(hidden_state)
+ shared_hidden_state = self.fc2(hidden_state)
+ indices = indices.view(-1, 1, 1)
+
+ # to support ddp training
+ for i in range(self.num_experts):
+ selected_index = indices == i
+ current_hidden_state = self.experts[i](hidden_state) * selected_index
+ expert_hidden_state = expert_hidden_state + current_hidden_state
+
+ hidden_state = torch.cat([shared_hidden_state, expert_hidden_state], dim=-1)
+
+ return hidden_state
+
+
+class VitPoseBackboneMLP(nn.Module):
+ def __init__(self, config: VitPoseBackboneConfig) -> None:
+ super().__init__()
+ in_features = out_features = config.hidden_size
+ hidden_features = int(config.hidden_size * config.mlp_ratio)
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
+ self.activation = ACT2FN[config.hidden_act]
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
+
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ hidden_state = self.fc1(hidden_state)
+ hidden_state = self.activation(hidden_state)
+ hidden_state = self.fc2(hidden_state)
+ return hidden_state
+
+
+class VitPoseBackboneLayer(nn.Module):
+ def __init__(self, config: VitPoseBackboneConfig) -> None:
+ super().__init__()
+ self.num_experts = config.num_experts
+ self.attention = VitPoseBackboneAttention(config)
+ self.mlp = VitPoseBackboneMLP(config) if self.num_experts == 1 else VitPoseBackboneMoeMLP(config)
+ self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ dataset_index: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ # Validate dataset_index when using multiple experts
+ if self.num_experts > 1 and dataset_index is None:
+ raise ValueError(
+ "dataset_index must be provided when using multiple experts "
+ f"(num_experts={self.num_experts}). Please provide dataset_index "
+ "to the forward pass."
+ )
+ self_attention_outputs = self.attention(
+ self.layernorm_before(hidden_states), # in VitPoseBackbone, layernorm is applied before self-attention
+ head_mask,
+ output_attentions=output_attentions,
+ )
+ attention_output = self_attention_outputs[0]
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ # first residual connection
+ hidden_states = attention_output + hidden_states
+
+ layer_output = self.layernorm_after(hidden_states)
+ if self.num_experts == 1:
+ layer_output = self.mlp(layer_output)
+ else:
+ layer_output = self.mlp(layer_output, indices=dataset_index)
+
+ # second residual connection
+ layer_output = layer_output + hidden_states
+
+ outputs = (layer_output,) + outputs
+
+ return outputs
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->VitPoseBackbone
+class VitPoseBackboneEncoder(nn.Module):
+ def __init__(self, config: VitPoseBackboneConfig) -> None:
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([VitPoseBackboneLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ # Ignore copy
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ dataset_index: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ) -> Union[tuple, BaseModelOutput]:
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ layer_module.__call__,
+ hidden_states,
+ dataset_index,
+ layer_head_mask,
+ output_attentions,
+ )
+ else:
+ layer_outputs = layer_module(hidden_states, dataset_index, layer_head_mask, output_attentions)
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+class VitPoseBackbonePreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = VitPoseBackboneConfig
+ base_model_prefix = "vit"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["VitPoseBackboneEmbeddings", "VitPoseBackboneLayer"]
+
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
+ # `trunc_normal_cpu` not implemented in `half` issues
+ module.weight.data = nn.init.trunc_normal_(
+ module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
+ ).to(module.weight.dtype)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, VitPoseBackboneEmbeddings):
+ module.position_embeddings.data = nn.init.trunc_normal_(
+ module.position_embeddings.data.to(torch.float32),
+ mean=0.0,
+ std=self.config.initializer_range,
+ ).to(module.position_embeddings.dtype)
+
+
+VITPOSE_BACKBONE_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`VitPoseBackboneConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+VITPOSE_BACKBONE_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values.
+
+ dataset_index (`torch.Tensor` of shape `(batch_size,)`):
+ Index to use in the Mixture-of-Experts (MoE) blocks of the backbone.
+
+ This corresponds to the dataset index used during training, e.g. index 0 refers to COCO.
+
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The VitPose backbone useful for downstream tasks.",
+ VITPOSE_BACKBONE_START_DOCSTRING,
+)
+class VitPoseBackbone(VitPoseBackbonePreTrainedModel, BackboneMixin):
+ def __init__(self, config: VitPoseBackboneConfig):
+ super().__init__(config)
+ super()._init_backbone(config)
+
+ self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]
+ self.embeddings = VitPoseBackboneEmbeddings(config)
+ self.encoder = VitPoseBackboneEncoder(config)
+
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(VITPOSE_BACKBONE_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ dataset_index: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ """
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import VitPoseBackboneConfig, VitPoseBackbone
+ >>> import torch
+
+ >>> config = VitPoseBackboneConfig(out_indices=[-1])
+ >>> model = VitPoseBackbone(config)
+
+ >>> pixel_values = torch.randn(1, 3, 256, 192)
+ >>> dataset_index = torch.tensor([1])
+ >>> outputs = model(pixel_values, dataset_index)
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ embedding_output = self.embeddings(pixel_values)
+
+ outputs = self.encoder(
+ embedding_output,
+ dataset_index=dataset_index,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=True,
+ return_dict=return_dict,
+ )
+ hidden_states = outputs.hidden_states if return_dict else outputs[1]
+
+ feature_maps = ()
+ for stage, hidden_state in zip(self.stage_names, hidden_states):
+ if stage in self.out_features:
+ hidden_state = self.layernorm(hidden_state)
+ feature_maps += (hidden_state,)
+
+ if not return_dict:
+ if output_hidden_states:
+ output = (feature_maps,) + outputs[1:]
+ else:
+ output = (feature_maps,) + outputs[2:]
+ return output
+
+ return BackboneOutput(
+ feature_maps=feature_maps,
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
+ attentions=outputs.attentions,
+ )
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/x_clip/__init__.py b/.venv/lib/python3.11/site-packages/transformers/models/x_clip/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9c5fea1b3147333041a432f8d50861104dc9140
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/x_clip/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_x_clip import *
+ from .modeling_x_clip import *
+ from .processing_x_clip import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/x_clip/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/x_clip/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b5310142ff0f8c60b8c80f5bf89b9a6f1bdf5eb2
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/x_clip/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/x_clip/__pycache__/configuration_x_clip.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/x_clip/__pycache__/configuration_x_clip.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ecfbc6a0190ff5234d8419d1c5d04dab9db9d96b
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/x_clip/__pycache__/configuration_x_clip.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/x_clip/__pycache__/modeling_x_clip.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/x_clip/__pycache__/modeling_x_clip.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fdfba5f92b0cb642becf821eb7bea189d5adf78a
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/x_clip/__pycache__/modeling_x_clip.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/x_clip/__pycache__/processing_x_clip.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/models/x_clip/__pycache__/processing_x_clip.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d21dd48ab50a81740ebe59faa84164422d165ef9
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/models/x_clip/__pycache__/processing_x_clip.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/x_clip/configuration_x_clip.py b/.venv/lib/python3.11/site-packages/transformers/models/x_clip/configuration_x_clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..a500e5dccd90f2d1058b4a60b953ad3f26cc2d03
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/x_clip/configuration_x_clip.py
@@ -0,0 +1,381 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""X-CLIP model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class XCLIPTextConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`XCLIPModel`]. It is used to instantiate an X-CLIP
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the X-CLIP
+ [microsoft/xclip-base-patch32](https://huggingface.co/microsoft/xclip-base-patch32) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 49408):
+ Vocabulary size of the X-CLIP text model. Defines the number of different tokens that can be represented by
+ the `inputs_ids` passed when calling [`XCLIPModel`].
+ hidden_size (`int`, *optional*, defaults to 512):
+ Dimensionality of the encoder layers and the pooler layer.
+ intermediate_size (`int`, *optional*, defaults to 2048):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ max_position_embeddings (`int`, *optional*, defaults to 77):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
+ The epsilon used by the layer normalization layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ initializer_factor (`float`, *optional*, defaults to 1):
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
+ testing).
+
+ Example:
+
+ ```python
+ >>> from transformers import XCLIPTextModel, XCLIPTextConfig
+
+ >>> # Initializing a XCLIPTextModel with microsoft/xclip-base-patch32 style configuration
+ >>> configuration = XCLIPTextConfig()
+
+ >>> # Initializing a XCLIPTextConfig from the microsoft/xclip-base-patch32 style configuration
+ >>> model = XCLIPTextModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "xclip_text_model"
+ base_config_key = "text_config"
+
+ def __init__(
+ self,
+ vocab_size=49408,
+ hidden_size=512,
+ intermediate_size=2048,
+ num_hidden_layers=12,
+ num_attention_heads=8,
+ max_position_embeddings=77,
+ hidden_act="quick_gelu",
+ layer_norm_eps=1e-5,
+ attention_dropout=0.0,
+ initializer_range=0.02,
+ initializer_factor=1.0,
+ pad_token_id=1,
+ bos_token_id=0,
+ eos_token_id=2,
+ **kwargs,
+ ):
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.max_position_embeddings = max_position_embeddings
+ self.layer_norm_eps = layer_norm_eps
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.initializer_factor = initializer_factor
+ self.attention_dropout = attention_dropout
+
+
+class XCLIPVisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`XCLIPModel`]. It is used to instantiate an X-CLIP
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the X-CLIP
+ [microsoft/xclip-base-patch32](https://huggingface.co/microsoft/xclip-base-patch32) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ mit_hidden_size (`int`, *optional*, defaults to 512):
+ Dimensionality of the encoder layers of the Multiframe Integration Transformer (MIT).
+ mit_intermediate_size (`int`, *optional*, defaults to 2048):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Multiframe Integration Transformer
+ (MIT).
+ mit_num_hidden_layers (`int`, *optional*, defaults to 1):
+ Number of hidden layers in the Multiframe Integration Transformer (MIT).
+ mit_num_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the Multiframe Integration Transformer (MIT).
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 32):
+ The size (resolution) of each patch.
+ hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
+ The epsilon used by the layer normalization layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ initializer_factor (`float`, *optional*, defaults to 1):
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
+ testing).
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
+ Stochastic depth rate.
+
+ Example:
+
+ ```python
+ >>> from transformers import XCLIPVisionModel, XCLIPVisionConfig
+
+ >>> # Initializing a XCLIPVisionModel with microsoft/xclip-base-patch32 style configuration
+ >>> configuration = XCLIPVisionConfig()
+
+ >>> # Initializing a XCLIPVisionModel model from the microsoft/xclip-base-patch32 style configuration
+ >>> model = XCLIPVisionModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "xclip_vision_model"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ hidden_size=768,
+ intermediate_size=3072,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ mit_hidden_size=512,
+ mit_intermediate_size=2048,
+ mit_num_hidden_layers=1,
+ mit_num_attention_heads=8,
+ num_channels=3,
+ image_size=224,
+ patch_size=32,
+ num_frames=8,
+ hidden_act="quick_gelu",
+ layer_norm_eps=1e-5,
+ attention_dropout=0.0,
+ initializer_range=0.02,
+ initializer_factor=1.0,
+ drop_path_rate=0.0,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.mit_hidden_size = mit_hidden_size
+ self.mit_intermediate_size = mit_intermediate_size
+ self.mit_num_hidden_layers = mit_num_hidden_layers
+ self.mit_num_attention_heads = mit_num_attention_heads
+ self.num_channels = num_channels
+ self.patch_size = patch_size
+ self.num_frames = num_frames
+ self.image_size = image_size
+ self.initializer_range = initializer_range
+ self.initializer_factor = initializer_factor
+ self.attention_dropout = attention_dropout
+ self.layer_norm_eps = layer_norm_eps
+ self.hidden_act = hidden_act
+ self.drop_path_rate = drop_path_rate
+
+
+class XCLIPConfig(PretrainedConfig):
+ r"""
+ [`XCLIPConfig`] is the configuration class to store the configuration of a [`XCLIPModel`]. It is used to
+ instantiate X-CLIP model according to the specified arguments, defining the text model and vision model configs.
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the X-CLIP
+ [microsoft/xclip-base-patch32](https://huggingface.co/microsoft/xclip-base-patch32) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ text_config (`dict`, *optional*):
+ Dictionary of configuration options used to initialize [`XCLIPTextConfig`].
+ vision_config (`dict`, *optional*):
+ Dictionary of configuration options used to initialize [`XCLIPVisionConfig`].
+ projection_dim (`int`, *optional*, defaults to 512):
+ Dimensionality of text and vision projection layers.
+ prompt_layers (`int`, *optional*, defaults to 2):
+ Number of layers in the video specific prompt generator.
+ prompt_alpha (`float`, *optional*, defaults to 0.1):
+ Alpha value to use in the video specific prompt generator.
+ prompt_hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
+ The non-linear activation function (function or string) in the video specific prompt generator. If string,
+ `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
+ prompt_num_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads in the cross-attention of the video specific prompt generator.
+ prompt_attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability for the attention layers in the video specific prompt generator.
+ prompt_projection_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability for the projection layers in the video specific prompt generator.
+ logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
+ The inital value of the *logit_scale* parameter. Default is used as per the original XCLIP implementation.
+ kwargs (*optional*):
+ Dictionary of keyword arguments.
+ """
+
+ model_type = "xclip"
+ sub_configs = {"text_config": XCLIPTextConfig, "vision_config": XCLIPVisionConfig}
+
+ def __init__(
+ self,
+ text_config=None,
+ vision_config=None,
+ projection_dim=512,
+ prompt_layers=2,
+ prompt_alpha=0.1,
+ prompt_hidden_act="quick_gelu",
+ prompt_num_attention_heads=8,
+ prompt_attention_dropout=0.0,
+ prompt_projection_dropout=0.0,
+ logit_scale_init_value=2.6592,
+ **kwargs,
+ ):
+ # If `_config_dict` exist, we use them for the backward compatibility.
+ # We pop out these 2 attributes before calling `super().__init__` to avoid them being saved (which causes a lot
+ # of confusion!).
+ text_config_dict = kwargs.pop("text_config_dict", None)
+ vision_config_dict = kwargs.pop("vision_config_dict", None)
+
+ super().__init__(**kwargs)
+
+ # Instead of simply assigning `[text|vision]_config_dict` to `[text|vision]_config`, we use the values in
+ # `[text|vision]_config_dict` to update the values in `[text|vision]_config`. The values should be same in most
+ # cases, but we don't want to break anything regarding `_config_dict` that existed before commit `8827e1b2`.
+ if text_config_dict is not None:
+ if text_config is None:
+ text_config = {}
+
+ # This is the complete result when using `text_config_dict`.
+ _text_config_dict = XCLIPTextConfig(**text_config_dict).to_dict()
+
+ # Give a warning if the values exist in both `_text_config_dict` and `text_config` but being different.
+ for key, value in _text_config_dict.items():
+ if key in text_config and value != text_config[key] and key not in ["transformers_version"]:
+ # If specified in `text_config_dict`
+ if key in text_config_dict:
+ message = (
+ f"`{key}` is found in both `text_config_dict` and `text_config` but with different values. "
+ f'The value `text_config_dict["{key}"]` will be used instead.'
+ )
+ # If inferred from default argument values (just to be super careful)
+ else:
+ message = (
+ f"`text_config_dict` is provided which will be used to initialize `XCLIPTextConfig`. The "
+ f'value `text_config["{key}"]` will be overridden.'
+ )
+ logger.info(message)
+
+ # Update all values in `text_config` with the ones in `_text_config_dict`.
+ text_config.update(_text_config_dict)
+
+ if vision_config_dict is not None:
+ if vision_config is None:
+ vision_config = {}
+
+ # This is the complete result when using `vision_config_dict`.
+ _vision_config_dict = XCLIPVisionConfig(**vision_config_dict).to_dict()
+ # convert keys to string instead of integer
+ if "id2label" in _vision_config_dict:
+ _vision_config_dict["id2label"] = {
+ str(key): value for key, value in _vision_config_dict["id2label"].items()
+ }
+
+ # Give a warning if the values exist in both `_vision_config_dict` and `vision_config` but being different.
+ for key, value in _vision_config_dict.items():
+ if key in vision_config and value != vision_config[key] and key not in ["transformers_version"]:
+ # If specified in `vision_config_dict`
+ if key in vision_config_dict:
+ message = (
+ f"`{key}` is found in both `vision_config_dict` and `vision_config` but with different "
+ f'values. The value `vision_config_dict["{key}"]` will be used instead.'
+ )
+ # If inferred from default argument values (just to be super careful)
+ else:
+ message = (
+ f"`vision_config_dict` is provided which will be used to initialize `XCLIPVisionConfig`. "
+ f'The value `vision_config["{key}"]` will be overridden.'
+ )
+ logger.info(message)
+
+ # Update all values in `vision_config` with the ones in `_vision_config_dict`.
+ vision_config.update(_vision_config_dict)
+
+ if text_config is None:
+ text_config = {}
+ logger.info("`text_config` is `None`. Initializing the `XCLIPTextConfig` with default values.")
+
+ if vision_config is None:
+ vision_config = {}
+ logger.info("`vision_config` is `None`. initializing the `XCLIPVisionConfig` with default values.")
+
+ self.text_config = XCLIPTextConfig(**text_config)
+ self.vision_config = XCLIPVisionConfig(**vision_config)
+
+ self.projection_dim = projection_dim
+ self.prompt_layers = prompt_layers
+ self.prompt_alpha = prompt_alpha
+ self.prompt_hidden_act = prompt_hidden_act
+ self.prompt_num_attention_heads = prompt_num_attention_heads
+ self.prompt_attention_dropout = prompt_attention_dropout
+ self.prompt_projection_dropout = prompt_projection_dropout
+ self.logit_scale_init_value = logit_scale_init_value
+ self.initializer_factor = 1.0
+
+ @classmethod
+ def from_text_vision_configs(cls, text_config: XCLIPTextConfig, vision_config: XCLIPVisionConfig, **kwargs):
+ r"""
+ Instantiate a [`XCLIPConfig`] (or a derived class) from xclip text model configuration and xclip vision model
+ configuration.
+
+ Returns:
+ [`XCLIPConfig`]: An instance of a configuration object
+ """
+
+ return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
+
+
+__all__ = ["XCLIPConfig", "XCLIPTextConfig", "XCLIPVisionConfig"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/x_clip/modeling_x_clip.py b/.venv/lib/python3.11/site-packages/transformers/models/x_clip/modeling_x_clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..bce1ea02e7b9e085cb8b663a233d84f83d41a14b
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/x_clip/modeling_x_clip.py
@@ -0,0 +1,1680 @@
+# coding=utf-8
+# Copyright 2022 Microsoft Research and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch X-CLIP model."""
+
+from copy import copy
+from dataclasses import dataclass
+from typing import Any, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ ModelOutput,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+ torch_int,
+)
+from .configuration_x_clip import XCLIPConfig, XCLIPTextConfig, XCLIPVisionConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "microsoft/xclip-base-patch32"
+
+
+# contrastive loss function, adapted from
+# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
+def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
+ return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
+
+
+# Copied from transformers.models.clip.modeling_clip.clip_loss with clip->x_clip
+def x_clip_loss(similarity: torch.Tensor) -> torch.Tensor:
+ caption_loss = contrastive_loss(similarity)
+ image_loss = contrastive_loss(similarity.t())
+ return (caption_loss + image_loss) / 2.0
+
+
+@dataclass
+class XCLIPOutput(ModelOutput):
+ """
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
+ Contrastive loss for video-text similarity.
+ logits_per_video (`torch.FloatTensor` of shape `(video_batch_size, text_batch_size)`):
+ The scaled dot product scores between `video_embeds` and `text_embeds`. This represents the video-text
+ similarity scores.
+ logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, video_batch_size)`):
+ The scaled dot product scores between `text_embeds` and `video_embeds`. This represents the text-video
+ similarity scores.
+ text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
+ The text embeddings obtained by applying the projection layer to the pooled output of [`XCLIPTextModel`].
+ video_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
+ The video embeddings obtained by applying the projection layer to the pooled output of
+ [`XCLIPVisionModel`].
+ text_model_output (`BaseModelOutputWithPooling`):
+ The output of the [`XCLIPTextModel`].
+ vision_model_output (`BaseModelOutputWithPooling`):
+ The output of the [`XCLIPVisionModel`].
+ mit_output (`BaseModelOutputWithPooling`):
+ The output of `XCLIPMultiframeIntegrationTransformer` (MIT for short).
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits_per_video: torch.FloatTensor = None
+ logits_per_text: torch.FloatTensor = None
+ text_embeds: torch.FloatTensor = None
+ video_embeds: torch.FloatTensor = None
+ text_model_output: BaseModelOutputWithPooling = None
+ vision_model_output: BaseModelOutputWithPooling = None
+ mit_output: BaseModelOutputWithPooling = None
+
+ def to_tuple(self) -> Tuple[Any]:
+ return tuple(
+ self[k]
+ if k not in ["text_model_output", "vision_model_output", "mit_output"]
+ else getattr(self, k).to_tuple()
+ for k in self.keys()
+ )
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->XCLIP
+class XCLIPVisionEmbeddings(nn.Module):
+ def __init__(self, config: XCLIPVisionConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ bias=False,
+ )
+
+ self.num_patches = (self.image_size // self.patch_size) ** 2
+ self.num_positions = self.num_patches + 1
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
+
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
+ images. This method is also adapted to support torch.jit tracing.
+
+ Adapted from:
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+ """
+
+ num_patches = embeddings.shape[1] - 1
+ position_embedding = self.position_embedding.weight.unsqueeze(0)
+ num_positions = position_embedding.shape[1] - 1
+
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+ return self.position_embedding(self.position_ids)
+
+ class_pos_embed = position_embedding[:, :1]
+ patch_pos_embed = position_embedding[:, 1:]
+
+ dim = embeddings.shape[-1]
+
+ new_height = height // self.patch_size
+ new_width = width // self.patch_size
+
+ sqrt_num_positions = torch_int(num_positions**0.5)
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed,
+ size=(new_height, new_width),
+ mode="bicubic",
+ align_corners=False,
+ )
+
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+ return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
+
+ def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
+ batch_size, _, height, width = pixel_values.shape
+ if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
+ raise ValueError(
+ f"Input image size ({height}*{width}) doesn't match model" f" ({self.image_size}*{self.image_size})."
+ )
+ target_dtype = self.patch_embedding.weight.dtype
+ patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+ if interpolate_pos_encoding:
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+ else:
+ embeddings = embeddings + self.position_embedding(self.position_ids)
+ return embeddings
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->XCLIP
+class XCLIPTextEmbeddings(nn.Module):
+ def __init__(self, config: XCLIPTextConfig):
+ super().__init__()
+ embed_dim = config.hidden_size
+
+ self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
+ self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer(
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+ )
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ ) -> torch.Tensor:
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, :seq_length]
+
+ if inputs_embeds is None:
+ inputs_embeds = self.token_embedding(input_ids)
+
+ position_embeddings = self.position_embedding(position_ids)
+ embeddings = inputs_embeds + position_embeddings
+
+ return embeddings
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->XCLIP
+class XCLIPAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.scale = self.head_dim**-0.5
+ self.dropout = config.attention_dropout
+
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Input shape: Batch x Time x Channel"""
+
+ bsz, tgt_len, embed_dim = hidden_states.size()
+
+ # get query proj
+ query_states = self.q_proj(hidden_states) * self.scale
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+ key_states = key_states.view(*proj_shape)
+ value_states = value_states.view(*proj_shape)
+
+ src_len = key_states.size(1)
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ # apply the causal_attention_mask first
+ if causal_attention_mask is not None:
+ if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {causal_attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if output_attentions:
+ # this operation is a bit akward, but it's required to
+ # make sure that attn_weights keeps its gradient.
+ # In order to do so, attn_weights have to reshaped
+ # twice and have to be reused in the following
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+ else:
+ attn_weights_reshaped = None
+
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ attn_output = torch.bmm(attn_probs, value_states)
+
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights_reshaped
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->XCLIP
+class XCLIPMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->XCLIP
+class XCLIPEncoderLayer(nn.Module):
+ def __init__(self, config: XCLIPConfig):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attn = XCLIPAttention(config)
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+ self.mlp = XCLIPMLP(config)
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ causal_attention_mask: torch.Tensor,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ `(config.encoder_attention_heads,)`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states, attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ causal_attention_mask=causal_attention_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+# Copied from transformers.models.beit.modeling_beit.drop_path
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+ """
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+ argument.
+ """
+ if drop_prob == 0.0 or not training:
+ return input
+ keep_prob = 1 - drop_prob
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
+ random_tensor.floor_() # binarize
+ output = input.div(keep_prob) * random_tensor
+ return output
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->XCLIP
+class XCLIPDropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
+ super().__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return drop_path(hidden_states, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return "p={}".format(self.drop_prob)
+
+
+class XCLIPVisionEncoderLayer(nn.Module):
+ """
+ This corresponds to the `CrossFramelAttentionBlock` class in the original implementation.
+ """
+
+ def __init__(self, config: XCLIPConfig):
+ super().__init__()
+ self.num_frames = config.num_frames
+ self.embed_dim = config.hidden_size
+
+ self.message_fc = nn.Linear(self.embed_dim, self.embed_dim)
+ self.message_ln = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+ self.message_attn = XCLIPAttention(config)
+
+ self.drop_path = XCLIPDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
+
+ self.self_attn = XCLIPAttention(config)
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+ self.mlp = XCLIPMLP(config)
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ causal_attention_mask: torch.Tensor,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ `(config.encoder_attention_heads,)`.
+ causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Causal mask for the text model. Mask values selected in `[0, 1]`:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ [What are attention masks?](../glossary#attention-mask)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ batch_time, seq_length, hidden_size = hidden_states.size()
+ batch_size = batch_time // self.num_frames
+ msg_token = self.message_fc(hidden_states[:, 0, :])
+ msg_token = msg_token.view(batch_size, self.num_frames, hidden_size)
+
+ msg_token = msg_token + self.drop_path(self.message_attn(self.message_ln(msg_token))[0])
+ # add dummy sequence dimension
+ msg_token = msg_token.view(-1, 1, hidden_size)
+
+ hidden_states = torch.cat([hidden_states, msg_token], dim=1)
+
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states, attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ causal_attention_mask=causal_attention_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = residual + hidden_states
+
+ hidden_states = hidden_states[:, :seq_length, :]
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+class XCLIPPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = XCLIPConfig
+ base_model_prefix = "x_clip"
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ factor = self.config.initializer_factor
+ if isinstance(module, XCLIPTextEmbeddings):
+ module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
+ module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
+ elif isinstance(module, XCLIPVisionEmbeddings):
+ factor = self.config.initializer_factor
+ nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
+ nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
+ nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
+ elif isinstance(module, XCLIPAttention):
+ factor = self.config.initializer_factor
+ in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
+ out_proj_std = (module.embed_dim**-0.5) * factor
+ nn.init.normal_(module.q_proj.weight, std=in_proj_std)
+ nn.init.normal_(module.k_proj.weight, std=in_proj_std)
+ nn.init.normal_(module.v_proj.weight, std=in_proj_std)
+ nn.init.normal_(module.out_proj.weight, std=out_proj_std)
+ elif isinstance(module, XCLIPMLP):
+ factor = self.config.initializer_factor
+ in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
+ fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
+ nn.init.normal_(module.fc1.weight, std=fc_std)
+ nn.init.normal_(module.fc2.weight, std=in_proj_std)
+ elif isinstance(module, XCLIPModel):
+ factor = self.config.initializer_factor
+ nn.init.normal_(
+ module.text_projection.weight,
+ std=module.text_embed_dim**-0.5 * factor,
+ )
+ nn.init.normal_(
+ module.visual_projection.weight,
+ std=module.vision_embed_dim**-0.5 * factor,
+ )
+ nn.init.normal_(module.prompts_visual_projection, mean=0.0, std=module.vision_embed_dim**-0.5 * factor)
+ elif isinstance(module, XCLIPMultiframeIntegrationTransformer):
+ nn.init.normal_(module.position_embedding, std=self.config.initializer_factor)
+
+ if isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_factor)
+ if module.bias is not None:
+ module.bias.data.zero_()
+
+
+X_CLIP_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`XCLIPConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+X_CLIP_TEXT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+X_CLIP_VISION_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ interpolate_pos_encoding (`bool`, *optional*, defaults `False`):
+ Whether to interpolate the pre-trained position encodings.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+X_CLIP_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
+ return_loss (`bool`, *optional*):
+ Whether or not to return the contrastive loss.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ interpolate_pos_encoding (`bool`, *optional*, defaults `False`):
+ Whether to interpolate the pre-trained position encodings.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->XCLIP
+class XCLIPEncoder(nn.Module):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+ [`XCLIPEncoderLayer`].
+
+ Args:
+ config: XCLIPConfig
+ """
+
+ def __init__(self, config: XCLIPConfig):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList([XCLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutput]:
+ r"""
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Causal mask for the text model. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ hidden_states = inputs_embeds
+ for idx, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ encoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ causal_attention_mask,
+ output_attentions,
+ )
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ causal_attention_mask,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+
+class XCLIPTextTransformer(nn.Module):
+ def __init__(self, config: XCLIPTextConfig):
+ super().__init__()
+ self.config = config
+ embed_dim = config.hidden_size
+ self.embeddings = XCLIPTextEmbeddings(config)
+ self.encoder = XCLIPEncoder(config)
+ self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+ @add_start_docstrings_to_model_forward(X_CLIP_TEXT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=XCLIPTextConfig)
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ r"""
+ Returns:
+
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is None:
+ raise ValueError("You have to specify either input_ids")
+
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
+
+ # X_CLIP's text model uses causal mask, prepare it here.
+ # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
+ causal_attention_mask = _create_4d_causal_attention_mask(
+ input_shape, hidden_states.dtype, device=hidden_states.device
+ )
+ # expand attention_mask
+ if attention_mask is not None:
+ # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ attention_mask=attention_mask,
+ causal_attention_mask=causal_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
+
+ # text_embeds.shape = [batch_size, sequence_length, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), input_ids.argmax(dim=-1)]
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+class XCLIPTextModel(XCLIPPreTrainedModel):
+ config_class = XCLIPTextConfig
+
+ def __init__(self, config: XCLIPTextConfig):
+ super().__init__(config)
+ self.text_model = XCLIPTextTransformer(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.text_model.embeddings.token_embedding
+
+ def set_input_embeddings(self, value):
+ self.text_model.embeddings.token_embedding = value
+
+ @add_start_docstrings_to_model_forward(X_CLIP_TEXT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=XCLIPTextConfig)
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer, XCLIPTextModel
+
+ >>> model = XCLIPTextModel.from_pretrained("microsoft/xclip-base-patch32")
+ >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/xclip-base-patch32")
+
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
+ ```"""
+ return self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+
+class XCLIPVisionEncoder(nn.Module):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+ [`XCLIPVisionEncoderLayer`].
+
+ Args:
+ config: XCLIPConfig
+ """
+
+ def __init__(self, config: XCLIPConfig):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList([XCLIPVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutput]:
+ r"""
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Causal mask for the text model. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ hidden_states = inputs_embeds
+ for idx, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ encoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ causal_attention_mask,
+ output_attentions,
+ )
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ causal_attention_mask,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+
+class XCLIPVisionTransformer(nn.Module):
+ """
+ This corresponds to the `CrossFrameCommunicationTransformer` class in the original implementation.
+ """
+
+ def __init__(self, config: XCLIPVisionConfig):
+ super().__init__()
+ self.config = config
+ embed_dim = config.hidden_size
+
+ self.embeddings = XCLIPVisionEmbeddings(config)
+ self.pre_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+ self.encoder = XCLIPVisionEncoder(config)
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+ @add_start_docstrings_to_model_forward(X_CLIP_VISION_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=XCLIPVisionConfig)
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ r"""
+ Returns:
+
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+ hidden_states = self.pre_layernorm(hidden_states)
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+ pooled_output = last_hidden_state[:, 0, :]
+ pooled_output = self.post_layernorm(pooled_output)
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+class XCLIPVisionModel(XCLIPPreTrainedModel):
+ config_class = XCLIPVisionConfig
+ main_input_name = "pixel_values"
+
+ def __init__(self, config: XCLIPVisionConfig):
+ super().__init__(config)
+ self.vision_model = XCLIPVisionTransformer(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.vision_model.embeddings.patch_embedding
+
+ @add_start_docstrings_to_model_forward(X_CLIP_VISION_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=XCLIPVisionConfig)
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> import av
+ >>> import torch
+ >>> import numpy as np
+
+ >>> from transformers import AutoProcessor, XCLIPVisionModel
+ >>> from huggingface_hub import hf_hub_download
+
+ >>> np.random.seed(0)
+
+
+ >>> def read_video_pyav(container, indices):
+ ... '''
+ ... Decode the video with PyAV decoder.
+ ... Args:
+ ... container (`av.container.input.InputContainer`): PyAV container.
+ ... indices (`List[int]`): List of frame indices to decode.
+ ... Returns:
+ ... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
+ ... '''
+ ... frames = []
+ ... container.seek(0)
+ ... start_index = indices[0]
+ ... end_index = indices[-1]
+ ... for i, frame in enumerate(container.decode(video=0)):
+ ... if i > end_index:
+ ... break
+ ... if i >= start_index and i in indices:
+ ... frames.append(frame)
+ ... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
+
+
+ >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
+ ... '''
+ ... Sample a given number of frame indices from the video.
+ ... Args:
+ ... clip_len (`int`): Total number of frames to sample.
+ ... frame_sample_rate (`int`): Sample every n-th frame.
+ ... seg_len (`int`): Maximum allowed index of sample's last frame.
+ ... Returns:
+ ... indices (`List[int]`): List of sampled frame indices
+ ... '''
+ ... converted_len = int(clip_len * frame_sample_rate)
+ ... end_idx = np.random.randint(converted_len, seg_len)
+ ... start_idx = end_idx - converted_len
+ ... indices = np.linspace(start_idx, end_idx, num=clip_len)
+ ... indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
+ ... return indices
+
+
+ >>> # video clip consists of 300 frames (10 seconds at 30 FPS)
+ >>> file_path = hf_hub_download(
+ ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
+ ... )
+ >>> container = av.open(file_path)
+
+ >>> # sample 16 frames
+ >>> indices = sample_frame_indices(clip_len=8, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
+ >>> video = read_video_pyav(container, indices)
+
+ >>> processor = AutoProcessor.from_pretrained("microsoft/xclip-base-patch32")
+ >>> model = XCLIPVisionModel.from_pretrained("microsoft/xclip-base-patch32")
+
+ >>> pixel_values = processor(videos=list(video), return_tensors="pt").pixel_values
+
+ >>> batch_size, num_frames, num_channels, height, width = pixel_values.shape
+ >>> pixel_values = pixel_values.reshape(-1, num_channels, height, width)
+
+ >>> outputs = model(pixel_values)
+ >>> last_hidden_state = outputs.last_hidden_state
+ ```"""
+ return self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+
+class XCLIPMultiframeIntegrationTransformer(nn.Module):
+ """
+ This corresponds to the `MultiframeIntegrationTransformer` class in the original implementation.
+ """
+
+ def __init__(self, config: XCLIPVisionConfig):
+ super().__init__()
+
+ self.position_embedding = nn.Parameter(torch.empty(1, config.num_frames, config.hidden_size))
+ self.encoder = XCLIPEncoder(config)
+
+ def forward(
+ self,
+ hidden_states,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutput]:
+ residual = hidden_states
+
+ # add position embeddings
+ hidden_states = hidden_states + self.position_embedding
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ last_hidden_state = encoder_outputs[0]
+
+ last_hidden_state = last_hidden_state.type(hidden_states.dtype) + residual
+
+ pooled_output = last_hidden_state.mean(dim=1, keepdim=False)
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+class XCLIPCrossAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config):
+ super().__init__()
+ self.num_heads = config.prompt_num_attention_heads
+
+ dim = config.projection_dim
+ head_dim = dim // self.num_heads
+ self.scale = head_dim**-0.5
+
+ self.q_proj = nn.Linear(dim, dim, bias=False)
+ self.k_proj = nn.Linear(dim, dim, bias=False)
+ self.v_proj = nn.Linear(dim, dim, bias=False)
+
+ self.attn_drop = nn.Dropout(config.prompt_attention_dropout)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(config.prompt_projection_dropout)
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
+ return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(self, queries, keys, values):
+ """Input shape: Batch x Time x Channel"""
+ batch_size, query_seq_len, hidden_size = queries.shape
+ batch_size, key_seq_len, hidden_size = keys.shape
+ queries = (
+ self.q_proj(queries)
+ .reshape(batch_size, query_seq_len, self.num_heads, hidden_size // self.num_heads)
+ .permute(0, 2, 1, 3)
+ )
+ keys = (
+ self.k_proj(keys)
+ .reshape(batch_size, key_seq_len, self.num_heads, hidden_size // self.num_heads)
+ .permute(0, 2, 1, 3)
+ )
+ values = (
+ self.v_proj(values)
+ .reshape(batch_size, key_seq_len, self.num_heads, hidden_size // self.num_heads)
+ .permute(0, 2, 1, 3)
+ )
+
+ attn = (queries @ keys.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ values).transpose(1, 2).reshape(batch_size, query_seq_len, hidden_size)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class PromptGeneratorLayer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ embed_dim = config.projection_dim
+ self.cross_attn = XCLIPCrossAttention(config)
+ self.norm1 = nn.LayerNorm(embed_dim, eps=config.text_config.layer_norm_eps)
+ self.norm3 = nn.LayerNorm(embed_dim, eps=config.text_config.layer_norm_eps)
+ self.mlp = nn.Sequential(
+ nn.Linear(embed_dim, embed_dim * 4),
+ ACT2FN[config.prompt_hidden_act],
+ nn.Dropout(config.prompt_attention_dropout),
+ nn.Linear(embed_dim * 4, embed_dim),
+ )
+
+ def forward(self, x, visual):
+ x = x + self.cross_attn(self.norm1(x), visual, visual)
+ x = x + self.mlp(self.norm3(x))
+ return x
+
+
+class XCLIPPromptGenerator(nn.Module):
+ """This corresponds to the `VideoSpecificPrompt` class in the original implementation."""
+
+ def __init__(self, config):
+ super().__init__()
+ embed_dim = config.projection_dim
+ self.layernorm = nn.LayerNorm(embed_dim, eps=config.vision_config.layer_norm_eps)
+ self.decoder = nn.ModuleList([PromptGeneratorLayer(config) for _ in range(config.prompt_layers)])
+ self.alpha = nn.Parameter(torch.ones(embed_dim) * config.prompt_alpha)
+
+ def forward(self, text, visual):
+ visual = self.layernorm(visual)
+ for layer in self.decoder:
+ text = layer(text, visual)
+
+ return self.alpha * text
+
+
+@add_start_docstrings(X_CLIP_START_DOCSTRING)
+class XCLIPModel(XCLIPPreTrainedModel):
+ config_class = XCLIPConfig
+
+ def __init__(self, config: XCLIPConfig):
+ super().__init__(config)
+
+ if not isinstance(config.text_config, XCLIPTextConfig):
+ raise TypeError(
+ "config.text_config is expected to be of type XCLIPTextConfig but is of type"
+ f" {type(config.text_config)}."
+ )
+
+ if not isinstance(config.vision_config, XCLIPVisionConfig):
+ raise TypeError(
+ "config.vision_config is expected to be of type XCLIPVisionConfig but is of type"
+ f" {type(config.vision_config)}."
+ )
+
+ text_config = config.text_config
+ vision_config = config.vision_config
+
+ self.projection_dim = config.projection_dim
+ self.text_embed_dim = text_config.hidden_size
+ self.vision_embed_dim = vision_config.hidden_size
+
+ self.text_model = XCLIPTextTransformer(text_config)
+ self.vision_model = XCLIPVisionTransformer(vision_config)
+
+ self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
+ self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
+ self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
+
+ self.prompts_visual_layernorm = nn.LayerNorm(self.vision_embed_dim, eps=config.vision_config.layer_norm_eps)
+ self.prompts_visual_projection = nn.Parameter(torch.randn(self.vision_embed_dim, self.projection_dim))
+
+ mit_config = copy(vision_config)
+ mit_config.hidden_size = vision_config.mit_hidden_size
+ mit_config.intermediate_size = vision_config.mit_intermediate_size
+ mit_config.num_hidden_layers = vision_config.mit_num_hidden_layers
+ mit_config.num_attention_heads = vision_config.mit_num_attention_heads
+ self.mit = XCLIPMultiframeIntegrationTransformer(mit_config)
+
+ self.prompts_generator = XCLIPPromptGenerator(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(X_CLIP_TEXT_INPUTS_DOCSTRING)
+ def get_text_features(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> torch.FloatTensor:
+ r"""
+ Returns:
+ text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
+ applying the projection layer to the pooled output of [`XCLIPTextModel`].
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer, AutoModel
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/xclip-base-patch32")
+ >>> model = AutoModel.from_pretrained("microsoft/xclip-base-patch32")
+
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
+ >>> text_features = model.get_text_features(**inputs)
+ ```"""
+ # Use X_CLIP model's config for some fields (if specified) instead of those of vision & text components.
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ text_outputs = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ text_embeds = text_outputs[1]
+ text_embeds = self.text_projection(text_embeds)
+
+ return text_embeds
+
+ @add_start_docstrings_to_model_forward(X_CLIP_VISION_INPUTS_DOCSTRING)
+ def get_video_features(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> torch.FloatTensor:
+ r"""
+ Returns:
+ video_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The video embeddings obtained by
+ applying the projection layer to the pooled output of [`XCLIPVisionModel`] and
+ [`XCLIPMultiframeIntegrationTransformer`].
+
+ Examples:
+
+ ```python
+ >>> import av
+ >>> import torch
+ >>> import numpy as np
+
+ >>> from transformers import AutoProcessor, AutoModel
+ >>> from huggingface_hub import hf_hub_download
+
+ >>> np.random.seed(0)
+
+
+ >>> def read_video_pyav(container, indices):
+ ... '''
+ ... Decode the video with PyAV decoder.
+ ... Args:
+ ... container (`av.container.input.InputContainer`): PyAV container.
+ ... indices (`List[int]`): List of frame indices to decode.
+ ... Returns:
+ ... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
+ ... '''
+ ... frames = []
+ ... container.seek(0)
+ ... start_index = indices[0]
+ ... end_index = indices[-1]
+ ... for i, frame in enumerate(container.decode(video=0)):
+ ... if i > end_index:
+ ... break
+ ... if i >= start_index and i in indices:
+ ... frames.append(frame)
+ ... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
+
+
+ >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
+ ... '''
+ ... Sample a given number of frame indices from the video.
+ ... Args:
+ ... clip_len (`int`): Total number of frames to sample.
+ ... frame_sample_rate (`int`): Sample every n-th frame.
+ ... seg_len (`int`): Maximum allowed index of sample's last frame.
+ ... Returns:
+ ... indices (`List[int]`): List of sampled frame indices
+ ... '''
+ ... converted_len = int(clip_len * frame_sample_rate)
+ ... end_idx = np.random.randint(converted_len, seg_len)
+ ... start_idx = end_idx - converted_len
+ ... indices = np.linspace(start_idx, end_idx, num=clip_len)
+ ... indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
+ ... return indices
+
+
+ >>> # video clip consists of 300 frames (10 seconds at 30 FPS)
+ >>> file_path = hf_hub_download(
+ ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
+ ... )
+ >>> container = av.open(file_path)
+
+ >>> # sample 8 frames
+ >>> indices = sample_frame_indices(clip_len=8, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
+ >>> video = read_video_pyav(container, indices)
+
+ >>> processor = AutoProcessor.from_pretrained("microsoft/xclip-base-patch32")
+ >>> model = AutoModel.from_pretrained("microsoft/xclip-base-patch32")
+
+ >>> inputs = processor(videos=list(video), return_tensors="pt")
+
+ >>> video_features = model.get_video_features(**inputs)
+ ```"""
+ # Use X_CLIP model's config for some fields (if specified) instead of those of vision & text components.
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ batch_size, num_frames, num_channels, height, width = pixel_values.shape
+ pixel_values = pixel_values.reshape(-1, num_channels, height, width)
+
+ vision_outputs = self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ video_embeds = vision_outputs[1]
+ video_embeds = self.visual_projection(video_embeds)
+
+ cls_features = video_embeds.view(batch_size, num_frames, -1)
+
+ mit_outputs = self.mit(
+ cls_features,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ video_embeds = mit_outputs[1]
+
+ return video_embeds
+
+ @add_start_docstrings_to_model_forward(X_CLIP_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=XCLIPOutput, config_class=XCLIPConfig)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ return_loss: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, XCLIPOutput]:
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> import av
+ >>> import torch
+ >>> import numpy as np
+
+ >>> from transformers import AutoProcessor, AutoModel
+ >>> from huggingface_hub import hf_hub_download
+
+ >>> np.random.seed(0)
+
+
+ >>> def read_video_pyav(container, indices):
+ ... '''
+ ... Decode the video with PyAV decoder.
+ ... Args:
+ ... container (`av.container.input.InputContainer`): PyAV container.
+ ... indices (`List[int]`): List of frame indices to decode.
+ ... Returns:
+ ... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
+ ... '''
+ ... frames = []
+ ... container.seek(0)
+ ... start_index = indices[0]
+ ... end_index = indices[-1]
+ ... for i, frame in enumerate(container.decode(video=0)):
+ ... if i > end_index:
+ ... break
+ ... if i >= start_index and i in indices:
+ ... frames.append(frame)
+ ... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
+
+
+ >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
+ ... '''
+ ... Sample a given number of frame indices from the video.
+ ... Args:
+ ... clip_len (`int`): Total number of frames to sample.
+ ... frame_sample_rate (`int`): Sample every n-th frame.
+ ... seg_len (`int`): Maximum allowed index of sample's last frame.
+ ... Returns:
+ ... indices (`List[int]`): List of sampled frame indices
+ ... '''
+ ... converted_len = int(clip_len * frame_sample_rate)
+ ... end_idx = np.random.randint(converted_len, seg_len)
+ ... start_idx = end_idx - converted_len
+ ... indices = np.linspace(start_idx, end_idx, num=clip_len)
+ ... indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
+ ... return indices
+
+
+ >>> # video clip consists of 300 frames (10 seconds at 30 FPS)
+ >>> file_path = hf_hub_download(
+ ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
+ ... )
+ >>> container = av.open(file_path)
+
+ >>> # sample 8 frames
+ >>> indices = sample_frame_indices(clip_len=8, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
+ >>> video = read_video_pyav(container, indices)
+
+ >>> processor = AutoProcessor.from_pretrained("microsoft/xclip-base-patch32")
+ >>> model = AutoModel.from_pretrained("microsoft/xclip-base-patch32")
+
+ >>> inputs = processor(
+ ... text=["playing sports", "eating spaghetti", "go shopping"],
+ ... videos=list(video),
+ ... return_tensors="pt",
+ ... padding=True,
+ ... )
+
+ >>> # forward pass
+ >>> with torch.no_grad():
+ ... outputs = model(**inputs)
+
+ >>> logits_per_video = outputs.logits_per_video # this is the video-text similarity score
+ >>> probs = logits_per_video.softmax(dim=1) # we can take the softmax to get the label probabilities
+ >>> print(probs)
+ tensor([[1.9496e-04, 9.9960e-01, 2.0825e-04]])
+ ```"""
+ # Use X_CLIP model's config for some fields (if specified) instead of those of vision & text components.
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ batch_size, num_frames, num_channels, height, width = pixel_values.shape
+ pixel_values = pixel_values.reshape(-1, num_channels, height, width)
+
+ vision_outputs = self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ return_dict=return_dict,
+ )
+
+ video_embeds = vision_outputs[1]
+ video_embeds = self.visual_projection(video_embeds)
+
+ cls_features = video_embeds.view(batch_size, num_frames, -1)
+
+ mit_outputs = self.mit(
+ cls_features,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ video_embeds = mit_outputs[1]
+
+ img_features = vision_outputs[0][:, 1:, :]
+ img_features = self.prompts_visual_layernorm(img_features)
+ img_features = img_features @ self.prompts_visual_projection
+ img_features = img_features.view(batch_size, num_frames, -1, video_embeds.shape[-1])
+ img_features = img_features.mean(dim=1, keepdim=False)
+
+ text_outputs = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ text_embeds = text_outputs[1]
+ text_embeds = self.text_projection(text_embeds)
+
+ text_embeds = text_embeds.unsqueeze(0).expand(batch_size, -1, -1)
+ text_embeds = text_embeds + self.prompts_generator(text_embeds, img_features)
+
+ # normalized features
+ video_embeds = video_embeds / video_embeds.norm(p=2, dim=-1, keepdim=True)
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
+
+ # cosine similarity as logits
+ logit_scale = self.logit_scale.exp()
+ logits_per_video = torch.einsum("bd,bkd->bk", video_embeds, logit_scale * text_embeds)
+ logits_per_text = logits_per_video.T
+
+ loss = None
+ if return_loss:
+ loss = x_clip_loss(logits_per_text)
+
+ if not return_dict:
+ output = (logits_per_video, logits_per_text, text_embeds, video_embeds, text_outputs, vision_outputs)
+ return ((loss,) + output) if loss is not None else output
+
+ return XCLIPOutput(
+ loss=loss,
+ logits_per_video=logits_per_video,
+ logits_per_text=logits_per_text,
+ text_embeds=text_embeds,
+ video_embeds=video_embeds,
+ text_model_output=text_outputs,
+ vision_model_output=vision_outputs,
+ mit_output=mit_outputs,
+ )
+
+
+__all__ = ["XCLIPModel", "XCLIPPreTrainedModel", "XCLIPTextModel", "XCLIPVisionModel"]
diff --git a/.venv/lib/python3.11/site-packages/transformers/models/x_clip/processing_x_clip.py b/.venv/lib/python3.11/site-packages/transformers/models/x_clip/processing_x_clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a17d3a15a208461e4b1f9f12ccfc084df73fa0d
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/models/x_clip/processing_x_clip.py
@@ -0,0 +1,151 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Image/Text processor class for XCLIP
+"""
+
+import warnings
+
+from ...processing_utils import ProcessorMixin
+from ...tokenization_utils_base import BatchEncoding
+
+
+class XCLIPProcessor(ProcessorMixin):
+ r"""
+ Constructs an X-CLIP processor which wraps a VideoMAE image processor and a CLIP tokenizer into a single processor.
+
+ [`XCLIPProcessor`] offers all the functionalities of [`VideoMAEImageProcessor`] and [`CLIPTokenizerFast`]. See the
+ [`~XCLIPProcessor.__call__`] and [`~XCLIPProcessor.decode`] for more information.
+
+ Args:
+ image_processor ([`VideoMAEImageProcessor`], *optional*):
+ The image processor is a required input.
+ tokenizer ([`CLIPTokenizerFast`], *optional*):
+ The tokenizer is a required input.
+ """
+
+ attributes = ["image_processor", "tokenizer"]
+ image_processor_class = "VideoMAEImageProcessor"
+ tokenizer_class = ("CLIPTokenizer", "CLIPTokenizerFast")
+
+ def __init__(self, image_processor=None, tokenizer=None, **kwargs):
+ feature_extractor = None
+ if "feature_extractor" in kwargs:
+ warnings.warn(
+ "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
+ " instead.",
+ FutureWarning,
+ )
+ feature_extractor = kwargs.pop("feature_extractor")
+
+ image_processor = image_processor if image_processor is not None else feature_extractor
+ if image_processor is None:
+ raise ValueError("You need to specify an `image_processor`.")
+ if tokenizer is None:
+ raise ValueError("You need to specify a `tokenizer`.")
+
+ super().__init__(image_processor, tokenizer)
+ self.current_processor = self.image_processor
+
+ def __call__(self, text=None, videos=None, return_tensors=None, **kwargs):
+ """
+ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
+ and `kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode
+ the text. To prepare the image(s), this method forwards the `videos` and `kwargs` arguments to
+ VideoMAEImageProcessor's [`~VideoMAEImageProcessor.__call__`] if `videos` is not `None`. Please refer to the
+ doctsring of the above two methods for more information.
+
+ Args:
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ videos (`List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, `List[List[PIL.Image.Image]]`, `List[List[np.ndarrray]]`,:
+ `List[List[torch.Tensor]]`): The video or batch of videos to be prepared. Each video should be a list
+ of frames, which can be either PIL images or NumPy arrays. In case of NumPy arrays/PyTorch tensors,
+ each frame should be of shape (H, W, C), where H and W are frame height and width, and C is a number of
+ channels.
+
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `videos` is not `None`.
+ """
+
+ if text is None and videos is None:
+ raise ValueError("You have to specify either text or videos. Both cannot be none.")
+
+ if text is not None:
+ encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs)
+
+ if videos is not None:
+ image_features = self.image_processor(videos, return_tensors=return_tensors, **kwargs)
+
+ if text is not None and videos is not None:
+ encoding["pixel_values"] = image_features.pixel_values
+ return encoding
+ elif text is not None:
+ return encoding
+ else:
+ return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
+
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+ refer to the docstring of this method for more information.
+ """
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+ the docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
+
+ @property
+ def model_input_names(self):
+ return ["input_ids", "attention_mask", "position_ids", "pixel_values"]
+
+ @property
+ def feature_extractor_class(self):
+ warnings.warn(
+ "`feature_extractor_class` is deprecated and will be removed in v5. Use `image_processor_class` instead.",
+ FutureWarning,
+ )
+ return self.image_processor_class
+
+ @property
+ def feature_extractor(self):
+ warnings.warn(
+ "`feature_extractor` is deprecated and will be removed in v5. Use `image_processor` instead.",
+ FutureWarning,
+ )
+ return self.image_processor
+
+
+__all__ = ["XCLIPProcessor"]